Commit a8bf2f3 (verified) · 0 parent(s)
Super-squash branch 'main' using huggingface_hub

Files changed:
- .gitattributes +35 -0
- README.md +100 -0
- common_spear.py +702 -0
- config.json +167 -0
- configuration_spear.py +347 -0
- generation_config.json +3 -0
- model-00001-of-00003.safetensors +3 -0
- model-00002-of-00003.safetensors +3 -0
- model-00003-of-00003.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_spear.py +0 -0
- processing_spear.py +1897 -0
.gitattributes (new file, +35 lines)

*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text

README.md (new file, +100 lines)

---
license: gemma
library_name: transformers
pipeline_tag: visual-question-answering
---
# SPEAR-1 model card

SPEAR-1 is a cutting-edge Vision-Language-Action (VLA) model that achieves performance __superior to or on par with state-of-the-art models such as pi0-FAST and pi0.5__
on multiple embodiments while being trained __on 20x less robot data__.

This model was developed by [INSAIT](https://insait.ai/), a special unit of Sofia University St. Kliment Ohridski, in Sofia, Bulgaria.

Code and model weights for SPEAR-1 models are free to use under the Gemma license.

This repo provides model weights fine-tuned for a Franka setup with one wrist camera and one external camera.

## Model description

The key to SPEAR-1's data efficiency is SPEAR-VLM, a 3D-aware VLM. SPEAR-VLM extends PaliGemma with the MoGe depth encoder and is trained on 3D VQA tasks using
primarily non-robot data sources, such as EgoExo-4D.

SPEAR-1's architecture combines SPEAR-VLM with a DiT action expert. It is first pre-trained on a mixture of robot demonstration datasets from Open X-Embodiment and
then fine-tuned for specific embodiments.

## Use with 🤗 Transformers

We provide a fully `AutoModel`-compatible implementation of SPEAR-1 that can be used via 🤗 Transformers.

### Environment setup

The current implementation requires the following additional dependencies: `roma`, `timm`, `flash-attn`.

Here is a snippet to set up a working environment for inference via `uv`:

```
uv venv --python 3.10.12
source .venv/bin/activate
uv pip install --torch-backend=cu126 roma==1.5.0 numpy==2.2.4 torch==2.6.0 torchvision==0.21.0 transformers==4.47.0 timm==1.0.15
uv pip install --no-build-isolation setuptools psutil flash-attn==2.7.3
```

### Example usage

```python
from typing import Dict

import numpy as np
import torch
from PIL import Image
from transformers import AutoModel

model = AutoModel.from_pretrained("INSAIT-Institute/spear1-franka")
model = model.to(dtype=torch.bfloat16, device="cuda").eval()

main_image = np.asarray(Image.open("path/to/main_image.png"))
wrist_image = np.asarray(Image.open("path/to/wrist_image.png"))

ee_translation = np.array([0.36, 0.0, 0.56])
ee_rotation = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
gripper = np.array(1.0)

model_input: Dict[str, np.ndarray | str | Dict[str, np.ndarray]] = {
    "images": {
        "main": main_image,    # (H, W, C)
        "wrist": wrist_image,  # (H, W, C)
    },
    "ee_translation": ee_translation,  # (3,)
    "ee_rotation": ee_rotation,        # (3, 3)
    "gripper": gripper,                # (1,)
    "language_instruction": "put the carrot on the blue plate",
    "dataset_name": "droid",
}

model_output: Dict[str, np.ndarray] = model.predict_action(model_input)

ctrl_translation: np.ndarray = model_output["translation"]  # (S, 3)
ctrl_rotation: np.ndarray = model_output["rotation"]        # (S, 3, 3)
ctrl_gripper: np.ndarray = model_output["gripper"]          # (S, 1)
```

## Action space

SPEAR-1 predicts action chunks of delta end-effector poses. Each step in the predicted action chunk is relative to the input state.

Given the current end-effector pose `[R, t]` and a model prediction `A_rel = [[R_1, t_1], ..., [R_n, t_n]]`, absolute end-effector pose commands can be computed as:
```
A_abs = [[R * R_1, t + t_1], ..., [R * R_n, t + t_n]]
```

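For reference, this composition can be written in a few lines of NumPy. The sketch below assumes the `model_output`, `ee_rotation`, and `ee_translation` variables from the example above; it is illustrative and not part of the packaged API:

```python
R = ee_rotation          # current rotation, (3, 3)
t = ee_translation       # current translation, (3,)

R_rel = model_output["rotation"]     # (S, 3, 3)
t_rel = model_output["translation"]  # (S, 3)

# Every step of the chunk is composed with the same input state.
R_abs = R @ R_rel   # (S, 3, 3), i.e. R * R_i for each step i
t_abs = t + t_rel   # (S, 3),    i.e. t + t_i for each step i
```
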
## Community Feedback

We welcome feedback from the community to help improve SPEAR-1. If you have suggestions, encounter any issues, or have ideas for improvements, please contact us.

## Summary

- __Model type__: Vision-Language-Action with flow-matching action decoding
- __Contact__: contact@insait.ai
- __License__: Gemma Terms of Use

common_spear.py (new file, +702 lines)

import collections.abc
import dataclasses
import enum
import inspect
import types
from collections.abc import Mapping as MappingABC
from functools import cached_property
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

import torch
import transformers


class StrEnum(str, enum.Enum):
    """
    A minimal drop-in replacement for backports.strenum.StrEnum
    """

    def __str__(self):
        return str(self.value)

    def __new__(cls, value):
        # Create new instance that properly handles string initialization
        if isinstance(value, str):
            obj = str.__new__(cls, value)
            obj._value_ = value
            return obj
        return super().__new__(cls, value)

    @classmethod
    def _missing_(cls, value):
        # Enhanced lookup by string value with better error handling
        if isinstance(value, str):
            for member in cls:
                if member.value == value:
                    return member
        # Return None to let enum handle the KeyError
        return None

    def __eq__(self, other):
        # Allow comparison with string values
        if isinstance(other, str):
            return self.value == other
        return super().__eq__(other)

    def __hash__(self):
        # Ensure consistent hashing
        return hash(self.value)


class _cached_classproperty:
    def __init__(self, func):
        self.func = func
        self._values = {}

    def __get__(self, obj, klass):
        if klass not in self._values.keys():
            self._values[klass] = self.func.__get__(obj, klass)()
        return self._values[klass]


def cached_classproperty(func):
    if not isinstance(func, (classmethod, staticmethod)):
        func = classmethod(func)
    return _cached_classproperty(func)

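# Note: StrEnum members compare equal to, and hash like, their plain string values
# (see __eq__/__hash__ above), so a member whose value is "foo" satisfies `member == "foo"`.
# cached_classproperty evaluates the wrapped function once per class and caches the result.
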
@dataclasses.dataclass
class Dataclass:
    def __post_init__(self):
        pass

    @classmethod
    def make_empty(cls) -> "Dataclass":
        return cls(
            **{
                k: (v.make_empty() if inspect.isclass(v) and issubclass(v, Dataclass) else None)
                for (k, v) in cls.types.items()
            }
        )

    @cached_classproperty
    def fields(cls) -> Tuple[dataclasses.Field, ...]:
        """Returns a sorted list of the Field objects"""
        return tuple(sorted(dataclasses.fields(cls), key=lambda x: x.name))

    @cached_classproperty
    def types(cls) -> Dict[str, type]:
        return {f.name: f.type for f in cls.fields}

    def as_json(self, recursive: bool = True) -> dict:
        return {k: v.as_json() if isinstance(v, Dataclass) and recursive else v for (k, v) in self.items()}

    @classmethod
    def keys(cls) -> List[str]:
        return [field.name for field in cls.fields]

    def values(self):
        return [getattr(self, field.name) for field in self.fields]

    def items(self, recursive: bool = False):
        for key, value in zip(self.keys(), self.values(), strict=True):
            if recursive and isinstance(value, Dataclass):
                for subkey, subvalue in value.items(recursive=True):
                    yield (f"{key}.{subkey}", subvalue)
            else:
                yield (key, value)

    def replace(self, **kwargs):
        """
        Return a new instance of Dataclass with the kwargs overwritten.
        """
        kwargs = maybe_chained_keys_to_nested_dict(kwargs)
        data = self.as_json(recursive=False)
        for key, value in kwargs.items():
            value_type = self.types.get(key, None)
            if value_type is None:
                raise KeyError(f"Dataclass {self.__class__} does not have a field {key}")
            value_type = get_maybe_optional_type(value_type)
            if inspect.isclass(value_type) and issubclass(value_type, Dataclass):
                if isinstance(value, dict):
                    data[key] = data[key].replace(**value)
                else:
                    data[key] = value
            else:
                data[key] = value
        return self.__class__(**data)

    def apply(self, fcn: Callable, recursive: bool = True, skip_nones: bool = False) -> "Dataclass":
        def fcn_wrapper(value: Any) -> Any:
            if value is None and skip_nones:
                return None
            if isinstance(value, dict) and recursive:
                return type(value)(**{k: fcn(v) for (k, v) in value.items()})
            if isinstance(value, (list, tuple)) and recursive:
                return type(value)([fcn(v) for v in value])
            if isinstance(value, Dataclass) and recursive:
                return value.apply(fcn, recursive=True, skip_nones=skip_nones)
            return fcn(value)

        return self.__class__(**{key: fcn_wrapper(value) for (key, value) in self.items()})

    def __getitem__(self, index) -> "Dataclass":
        def extract(obj):
            if obj is None:
                return None
            if isinstance(obj, torch.Tensor):
                return obj[index]
            raise ValueError(f"Cannot slice {obj.__class__.__name__} object")

        return self.apply(extract)

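# Usage note: `replace` accepts dotted keys for nested Dataclass fields, so
# `obj.replace(**{"flow_input.timestep": t})` behaves the same as
# `obj.replace(flow_input={"timestep": t})` (see maybe_chained_keys_to_nested_dict below).

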
class Config:
    def __init__(self, **kwargs):
        self._apply_defaults()
        self._set_attributes(**kwargs)
        super().__init__()
        self.__post_init__()

    def _apply_defaults(self):
        """
        Initializes all annotated fields with defaults or sensible instances.
        """
        annotations = getattr(self, "__annotations__", {})
        for key, type_hint in annotations.items():
            # Skip if already set via class-level value or __init__ kwarg
            if hasattr(self, key):
                continue

            # Case 1: class variable has a default (declared at class level)
            if key in self.__class__.__dict__:
                setattr(self, key, getattr(self.__class__, key))
                continue

            # Case 2: if the type is another Config subclass, default-construct it
            if inspect.isclass(type_hint) and issubclass(type_hint, Config):
                setattr(self, key, type_hint())
                continue

            # Case 3: fallback None (or empty dict for mappings)
            if hasattr(type_hint, "__origin__") and type_hint.__origin__ in (
                dict,
                Dict,
                MappingABC,
            ):
                setattr(self, key, {})
            else:
                setattr(self, key, None)

    def _set_attributes(self, **kwargs):
        subconfig_types = self._subconfig_types
        for key, value in kwargs.items():
            if key in subconfig_types:
                if not isinstance(value, Mapping):
                    raise ValueError(
                        f"{self.__class__.__name__}.{key} expects dict-like object for nested config, but got: {value}"
                    )
                setattr(self, key, subconfig_types[key](**value))
            else:
                setattr(self, key, value)

    def keys(self) -> List[str]:
        """Get all annotated keys including those from parent classes."""
        all_keys = {}
        # Walk through MRO in reverse to respect inheritance order
        for cls in reversed(self.__class__.__mro__):
            if cls is object:
                continue
            all_keys.update(getattr(cls, "__annotations__", {}))
        return list(all_keys.keys())

    def items(self) -> Iterable[Tuple[str, Any]]:
        for key in self.keys():
            yield (key, getattr(self, key))

    @cached_classproperty
    def _subconfig_types(cls) -> dict[str, Type]:
        keys = {
            key: value
            for (key, value) in cls.__annotations__.items()
            if inspect.isclass(value) and issubclass(value, Config)
        }
        for base in cls.__bases__:
            if not issubclass(base, Config):
                continue
            keys = {**keys, **base._subconfig_types}
        return keys

    def __post_init__(self):
        pass

    def as_json(self) -> dict:
        data = {}
        for key, value in self.items():
            if isinstance(value, Config):
                data[key] = value.as_json()
            elif (
                isinstance(value, collections.abc.Sequence)
                and len(value) > 0
                and isinstance(value[0], Config)
            ):
                data[key] = [v.as_json() for v in value]
            elif (
                isinstance(value, collections.abc.Mapping)
                and len(value) > 0
                and isinstance(next(iter(value.values())), Config)
            ):
                data[key] = {k: v.as_json() for k, v in value.items()}
            else:
                data[key] = value

        return data


class HFConfigMixin(transformers.PretrainedConfig):
    """
    Bridge between your Config system and HF PretrainedConfig.

    Usage:
        class SPEAR1Config(HFConfigMixin, Config):
            model_type = "spear1"
            processor_config: PaliGemmaProcessorConfig
            ...
    """

    def __init__(self, **kwargs):
        # Let HF's machinery initialize its own attributes / defaults first.
        # PretrainedConfig.__init__ will set things like `model_type`,
        # `_name_or_path`, `architectures`, and keep a `kwargs`->dict of extra items.
        super().__init__(**kwargs)

        # Now initialize your Config behavior: set defaults and construct nested configs.
        # We call Config.__init__ explicitly because HFConfigMixin inherits from PretrainedConfig,
        # and the user's concrete class will use multiple-inheritance with Config.
        # (This approach mirrors the earlier MRO design: class Concrete(HFConfigMixin, Config).)
        # We pass kwargs again so nested configs get overridden by user kwargs.
        # Note: Config.__init__ itself calls super().__init__() — but because we are calling
        # Config.__init__ directly (not via super()) the MRO won't re-call PretrainedConfig.__init__ here.
        # (I.e., we are deliberately calling the concrete base initializer.)
        Config.__init__(self, **kwargs)  # type: ignore[name-defined]

    def to_dict(self) -> Dict[str, Any]:
        """
        Merge HF PretrainedConfig serialization and Config.as_json().

        Strategy:
        1. Take HF dict (super().to_dict()) so HF metadata/defaults are present.
        2. Take our nested config dict (Config.as_json(self)).
        3. Update the HF dict with our nested config dict so annotated fields
           (nested configs, lists/dicts that should be recursively serialized)
           take precedence.
        """
        # HF's representation (contains model_type, etc.). This is trusted HF serialization.
        hf = super().to_dict()

        # Our nested config representation (recursively serializes Config objects).
        # Do not call self.to_dict() because that would recurse back here.
        cfg_json = Config.as_json(self)  # type: ignore[name-defined]

        # Merge: prefer cfg_json values for keys present in our config (so nested configs
        # are represented as dicts rather than raw objects or omitted).
        merged: Dict[str, Any] = dict(hf)
        merged.update(cfg_json)
        return merged

    @classmethod
    def from_dict(
        cls: Type["HFConfigMixin"],
        config_dict: Dict[str, Any],
        **kwargs,
    ) -> "HFConfigMixin":
        """
        Construct by delegating to the class constructor — that will instantiate nested configs.
        This is simple and consistent with PretrainedConfig.from_dict/from_pretrained behavior.
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        instance = cls(**config_dict)

        if return_unused_kwargs:
            # Return tuple of (instance, unused_kwargs) if requested
            # Since we consume everything in __init__, unused is typically empty
            return instance, {}
        return instance


class Configurable:
    def __init__(self, config: Config):
        self._config = config

    @property
    def config(self) -> Config:
        return self._config


class RotationFormat(StrEnum):
    """Determines how rotations will be encoded in the loaded batch"""

    EULER = "euler"
    QUATERNION = "quaternion"
    ROTMAT = "rotmat"


class ResizeMode(StrEnum):
    """
    Different modes for resizing images.
    """

    MATCH_WIDTH = "match_width"
    MATCH_HEIGHT = "match_height"
    MATCH_MAX = "match_max"
    NAIVE = "naive"
    SMART = "smart"
    PAD = "pad"
    CROP = "crop"


class Normalization(StrEnum):
    """Action normalization types"""

    NONE = "none"
    BOUNDS = "bounds"
    BOUNDS_Q99 = "bounds_q99"
    MEAN = "mean"


def expand_dims(tensor: torch.Tensor, ndim: int, order: Sequence[int]) -> torch.Tensor:
    """
    Expand the dimensions of `tensor` to `ndim` such that all new dimensions have size of 1
    Args:
        tensor: torch.Tensor of any shape
        ndim: Number of output dimensions. Must be >= `tensor.ndim`
        order: Sequence of size `tensor.ndim + 1`. Contains only values of 1 and a single value of -1,
            indicating where the new `ndim - tensor.ndim` dimensions will be inserted
    Returns:
        torch.Tensor with dimensions `ndim`, a view of `tensor`

    Ex:
        expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[1, -1, 1, 1]).shape -> [2, 1, 1, 3, 4]
        expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[-1, 1, 1, 1]).shape -> [1, 1, 2, 3, 4]
        expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[1, 1, 1, -1]).shape -> [2, 3, 4, 1, 1]
    """
    assert tensor.ndim <= ndim, f"{tensor.ndim} > {ndim}; shape={tensor.shape}"
    assert len(order) == tensor.ndim + 1, f"{len(order)} != {tensor.ndim + 1}; shape={tensor.shape}"
    order = list(order)
    assert order.count(-1) == 1, "Order must have exactly one value of -1"
    assert order.count(1) == len(order) - 1, "Order must have exactly len(order) - 1 values of 1"
    if tensor.ndim == ndim:
        return tensor
    insert_index = order.index(-1)
    view = list(tensor.shape[:insert_index]) + [1] * (ndim - tensor.ndim) + list(tensor.shape[insert_index:])
    tensor = tensor.view(view)
    return tensor


def merge_dicts_recursive(dict_1: Dict[str, Any], dict_2: Dict[str, Any]) -> Dict[str, Any]:
    """
    Merges dict_1 with dict_2 recursively.
    Handles clashing keys:
    1. If both values are dicts, merges them recursively
    2. If any value is not a dict, raises ValueError
    """
    merged = dict(dict_1)
    for key, value in dict_2.items():
        if key in merged:
            if not type(merged[key]) is type(value) is dict:
                raise ValueError(f"Multiple values provided for key {key}: {merged[key]} and {value}")
            merged[key] = merge_dicts_recursive(merged[key], value)
        else:
            merged[key] = value
    return merged


def maybe_chained_keys_to_nested_dict(data: Dict[str, Any]) -> Dict[str, Any]:
    """Converts a dict with keys of the form "key1.key2.key3" to a nested dict"""
    unpacked_data: Dict[str, Any] = {}
    for key, value in data.items():
        if "." not in key:
            unpacked_data = merge_dicts_recursive(unpacked_data, {key: value})
        else:
            (mainkey, subkey) = key.split(".", maxsplit=1)
            nested_value = maybe_chained_keys_to_nested_dict({subkey: value})
            unpacked_data = merge_dicts_recursive(unpacked_data, {mainkey: nested_value})
    return unpacked_data
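
# Example: maybe_chained_keys_to_nested_dict({"a.b": 1, "a.c": 2, "d": 3})
# returns {"a": {"b": 1, "c": 2}, "d": 3}.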


def annotation_is_union(type_value: Type) -> bool:
    return getattr(type_value, "__origin__", None) is Union or type(type_value) is types.UnionType


def annotation_is_optional(type_value: Type) -> bool:
    if annotation_is_union(type_value):
        union_args = set(type_value.__args__)
        if len(union_args) == 2 and type(None) in union_args:
            return True
    return False


def get_maybe_optional_type(type_value: Type[Optional[Any]]) -> Type[Any]:
    if annotation_is_optional(type_value):
        type_args = type_value.__args__
        if type_args[1] is type(None):
            return type_args[0]
        return type_args[1]
    return type_value

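# Example: get_maybe_optional_type(Optional[int]) is int; non-Optional annotations
# are returned unchanged.

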
@dataclasses.dataclass
class RoboticsTarget(Dataclass):
    control_tokens_ids: Optional[torch.Tensor]
    text_tokens_ids: Optional[torch.Tensor]
    translation: torch.Tensor
    rotation: torch.Tensor
    gripper: torch.Tensor
    valid_mask: torch.Tensor


@dataclasses.dataclass
class RoboticsControlPlan(Dataclass):
    translation_m: torch.Tensor
    rotmat: torch.Tensor
    gripper_prob: torch.Tensor
    valid_mask: torch.Tensor

    def __post_init__(self):
        super().__post_init__()
        assert self.translation_m.ndim == 3, self.translation_m.shape
        assert self.rotmat.ndim == 3, self.rotmat.shape
        assert self.gripper_prob.ndim == 3, self.gripper_prob.shape


@dataclasses.dataclass
class RoboticsInput(Dataclass):
    images: Dict[str, torch.Tensor]
    input_ids: torch.Tensor
    attn_mask: torch.Tensor
    ee_pose_translation: torch.Tensor
    ee_pose_rotation: torch.Tensor
    gripper: torch.Tensor
    joints: torch.Tensor
    control_tokens_ids: Optional[torch.Tensor]

    @property
    def inputs_embeds(self) -> Optional[torch.Tensor]:
        return None

    @property
    def past_key_values(self) -> Optional[List[torch.Tensor]]:
        return None

    @cached_property
    def multimodal_indices(self) -> torch.Tensor:
        """
        Returns a torch.Tensor containing only the indices of the batch examples which are multimodal.
        Return shape is [B]
        """
        return torch.arange(self.input_ids.shape[0], dtype=torch.int64, device=self.input_ids.device)

    @cached_property
    def unimodal_indices(self) -> torch.Tensor:
        """
        Returns a torch.Tensor containing only the indices of the batch examples which are unimodal.
        Return shape is [B]
        """
        return torch.tensor([], dtype=torch.int64, device=self.input_ids.device)


@dataclasses.dataclass
class FlowInput(Dataclass):
    timestep: torch.Tensor
    translation_t: torch.Tensor
    rotation_t: torch.Tensor
    gripper_t: torch.Tensor
    translation_t0: torch.Tensor
    rotation_t0: torch.Tensor
    gripper_t0: torch.Tensor


@dataclasses.dataclass
class RoboticsFlowInput(RoboticsInput):
    """Input to the entire Robotics VLM"""

    flow_input: FlowInput


@dataclasses.dataclass
class DiffusionInput(Dataclass):
    timestep: torch.Tensor
    noised_translation: torch.Tensor
    noised_rotation: torch.Tensor
    noised_gripper: torch.Tensor


@dataclasses.dataclass
class LLMOutput(Dataclass):
    """Fork of transformers.modeling_outputs.CausalLMOutputWithPast"""

    input_ids: torch.Tensor
    logits: Optional[torch.Tensor]
    output_ids: Optional[torch.Tensor]
    loss: Optional[torch.Tensor]
    past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]
    hidden_states: List[torch.Tensor]
    text_indices: torch.Tensor
    image_indices: torch.Tensor

    @classmethod
    def from_transformers(
        cls,
        input_ids: torch.Tensor,
        llm_output: transformers.modeling_outputs.CausalLMOutputWithPast,
        text_indices: Optional[torch.Tensor],
        image_indices: Optional[torch.Tensor],
    ) -> "LLMOutput":
        return LLMOutput(
            input_ids=input_ids,
            logits=llm_output.logits,
            output_ids=None,
            loss=llm_output.loss,
            past_key_values=(
                list(llm_output.past_key_values) if llm_output.past_key_values is not None else []
            ),
            hidden_states=(list(llm_output.hidden_states) if llm_output.hidden_states is not None else []),
            text_indices=text_indices,
            image_indices=image_indices,
        )

    def compress(self) -> "LLMOutput":
        """
        Compress the data contained in the class so it can be moved between CPU and GPU or concatenated
        much faster:
        - hidden_states - huge tensors; take a lot of CPU time to move across devices or concat
        - past_key_values - huge tensors; take a lot of CPU time to move across devices or concat
        - logits - huge last dimension; takes a lot of CPU time to move across devices or concat
        """
        replace: Dict[str, Any] = {
            "hidden_states": [],
            "past_key_values": [],
            "loss": None,
            "input_ids": None,
        }
        if self.logits is not None:
            replace["logits"] = None
            if self.output_ids is None or self.output_ids.shape[1] != self.text_indices.shape[0]:
                replace["output_ids"] = (
                    torch.index_select(self.logits, dim=1, index=self.text_indices)
                    .argmax(dim=-1)
                    .to(dtype=torch.int64)
                )
        return self.replace(**replace)


@dataclasses.dataclass
class RoboticsOutput(Dataclass):
    translation: Optional[torch.Tensor]
    rotation: Optional[torch.Tensor]
    gripper: Optional[torch.Tensor]
    token_logits: Optional[torch.Tensor]
    token_ids: Optional[torch.Tensor]
    llm_output: LLMOutput

    def compress(self) -> "RoboticsOutput":
        """
        Compress output and drop unnecessary components to speed up transfer GPU <-> CPU.
        Note that LLM logits can be extremely expensive since their size is [B, S, vocab_size], which
        can reach millions or billions of values for large vocab_size
        """
        replace: Dict[str, Any] = {
            "llm_output": self.llm_output.compress(),
            "token_logits": None,
        }
        if self.token_logits is not None and self.token_ids is None:
            replace["token_ids"] = torch.argmax(self.token_logits, dim=-1)
        return self.replace(**replace)


@dataclasses.dataclass
class VLMOutput(Dataclass):
    llm_output: LLMOutput
    vit_tokens: Optional[torch.Tensor]
    attn_mask: torch.Tensor

    def compress(self) -> "VLMOutput":
        """
        Compress output and drop unnecessary components to speed up transfer GPU <-> CPU.
        Note that LLM logits can be extremely expensive since their size is [B, S, vocab_size], which
        can reach millions or billions of values for large vocab_size
        """
        return self.replace(llm_output=self.llm_output.compress())


def is_quaternion(quaternion: torch.Tensor) -> bool:
    return quaternion.shape[-1] == 4


def quaternion_half_cover(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Flip quaternions so they cover only half of the space. If q_w is negative, flip the quaternion.
    If q_w is 0, then choose such that the first non-zero component is positive. Note that geometrically,
    this doesn't correspond to a single hemisphere of the unit sphere. Follows
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.as_quat.html#scipy.spatial.transform.Rotation.as_quat
    """
    assert is_quaternion(quaternion), quaternion.shape
    with torch.no_grad():
        is_zero = quaternion == 0
        flip_condition = (
            (quaternion[..., -1:] < 0)
            | is_zero[..., -1:] & (quaternion[..., 0:1] < 0)
            | is_zero[..., -1:] & is_zero[..., 0:1] & (quaternion[..., 1:2] < 0)
            | is_zero[..., -1:] & is_zero[..., 0:1] & is_zero[..., 1:2] & (quaternion[..., 2:3] < 0)
        )
    quaternion = torch.where(flip_condition, -quaternion, quaternion)
    return quaternion


def is_rotmat_3x3(rotmat: torch.Tensor) -> bool:
    return rotmat.shape[-2:] == torch.Size([3, 3])


def is_rotmat_9(rotmat: torch.Tensor) -> bool:
    return rotmat.shape[-1] == 9


def rotmat_as_9(rotmat: torch.Tensor) -> torch.Tensor:
    """Convert any rotmat input to [..., 9] shape"""
    if is_rotmat_9(rotmat):
        return rotmat
    if is_rotmat_3x3(rotmat):
        return rotmat.reshape(*rotmat.shape[:-2], 9)
    raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a 3x3 rotation matrix")


def is_rotmat(rotmat: torch.Tensor) -> bool:
    """
    Checks if the tensor shape matches that of a rotmat. However, it's not guaranteed the data is a
    valid rotmat. `is_orthonormal_rotmat` performs this additional check.
    NOTE: This might incorrectly return True if the underlying data is euler angles and accidentally
    `rotmat.shape[-2:] == [3, 3]`. This would happen very rarely, but use with caution
    """
    return is_rotmat_3x3(rotmat) or is_rotmat_9(rotmat)


def rotmat_as_3x3(rotmat: torch.Tensor) -> torch.Tensor:
    """Convert any rotmat input to [..., 3, 3] shape"""
    if rotmat.shape[-1] == 9:
        return rotmat.reshape(*rotmat.shape[:-1], 3, 3)
    if rotmat.shape[-2:] == torch.Size([3, 3]):
        return rotmat
    raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a 3x3 rotation matrix")
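
# Shape examples:
#   rotmat_as_9(torch.eye(3).expand(5, 3, 3)).shape  -> torch.Size([5, 9])
#   rotmat_as_3x3(torch.zeros(5, 9)).shape           -> torch.Size([5, 3, 3])
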
config.json (new file, +167 lines)

{
  "_auto_class": null,
  "_name_or_path": "/scratch/giuliano_albanese/spear-hf",
  "architectures": [
    "SPEAR1"
  ],
  "attribute_map": {},
  "auto_map": {
    "AutoConfig": "configuration_spear.SPEAR1Config",
    "AutoModel": "modeling_spear.SPEAR1"
  },
  "autoclass": "barrel.pipes.vlams.models.vlams.vlam.VLAM",
  "base_config_key": "",
  "control_module_config": {
    "control_decoder_config": {
      "block_config": {
        "activation": "GELU",
        "attn_implementation": "sdpa",
        "dropout": 0.0,
        "feature_size": 1024,
        "head_dim": 256,
        "hidden_size": 4096,
        "norm": "RMSNorm",
        "num_heads": 8,
        "num_kv_heads": 1,
        "position_embed_config": {
          "base": 10000,
          "cached": true,
          "embedding_dim": 256,
          "num_embeddings": 512
        }
      },
      "num_blocks": 18
    },
    "noised_control_proj_config": {
      "activation": "SiLU",
      "layers": [
        8,
        2048,
        1024,
        1024
      ],
      "norm": null,
      "time_embed": {
        "activation": "SiLU",
        "layers": [],
        "learnable_features": false,
        "max_period": 10000.0,
        "norm": null,
        "num_features": 1024
      }
    },
    "robot_state_proj_config": {
      "activation": "SiLU",
      "fourier": false,
      "layers": [
        8,
        1024
      ],
      "mode": "ee_pose_gripper"
    },
    "rotation_components": 4,
    "token_size": 1024
  },
  "is_composition": false,
  "model_type": "spear1",
  "processor_config": {
    "control_io_config": {
      "future_control_offset_sec": 0.0,
      "future_controls_sequence_length": 5,
      "future_controls_sequence_stride_sec": 0.2,
      "future_frames_sequence_length": 1,
      "future_frames_sequence_stride_sec": null,
      "past_frames_sequence_length": 1,
      "past_frames_stride_sec": null,
      "past_scalars_sequence_length": 1,
      "past_scalars_stride_sec": null,
      "sequence_frames": 1,
      "sequence_frames_stride_sec": null
    },
    "control_stats_path": "barrel/pipes/vlams/types/control_stats.yaml",
    "control_tokenizer_config": {},
    "delta_controls": true,
    "distribution_hyperparams": {
      "alpha": 1.5,
      "beta": 1.0
    },
    "eef_control_frame": false,
    "image_resize": "smart",
    "joints_norm": {
      "high": [
        3.141592653589793,
        3.141592653589793,
        3.141592653589793,
        3.141592653589793,
        3.141592653589793,
        3.141592653589793,
        3.141592653589793
      ],
      "low": [
        -3.141592653589793,
        -3.141592653589793,
        -3.141592653589793,
        -3.141592653589793,
        -3.141592653589793,
        -3.141592653589793,
        -3.141592653589793
      ]
    },
    "num_inference_steps": 10,
    "obs_rotation_norm": "none",
    "obs_translation_norm": "bounds_q99",
    "observation_stats_path": "barrel/pipes/vlams/types/observation_stats.yaml",
    "r0_distribution": "uniform",
    "rotation_format": "quaternion",
    "rotation_norm": "none",
    "sig_min": 0.001,
    "timestep_distribution": "beta",
    "translation_norm": {
      "high": [
        0.04,
        0.04,
        0.04
      ],
      "low": [
        -0.04,
        -0.04,
        -0.04
      ]
    }
  },
  "sub_configs": {},
  "torch_dtype": "float32",
  "transformers_version": "4.47.0",
  "vlm_config": {
    "attn_implementation": "flash_attention_2",
    "depth_tokens": 1024,
    "lm_head": false,
    "mean_resizing": false,
    "model_id": "google/paligemma-3b-mix-224",
    "paligemma_3d_config": {
      "depth_config": {
        "hf_filename": "moge/moge-vit-large-patch-14-backbone.pt",
        "hf_hub_repo": "nikonikolov/vlams"
      },
      "depth_layers": 4,
      "depth_only": false,
      "mask_prob": 0.0,
      "projection": "features_add"
    },
    "processor_config": {
      "image_sizes": {
        "main": {
          "height": 210,
          "width": 280
        },
        "wrist": {
          "height": 112,
          "width": 112
        }
      },
      "image_token": "<image>",
      "max_language_tokens": 75
    },
    "train_only_depth_tokens": false
  }
}
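
The `auto_map` entry above routes `AutoConfig`/`AutoModel` to the `configuration_spear.SPEAR1Config` and `modeling_spear.SPEAR1` classes bundled in this repo. A minimal loading sketch (it assumes the repo id used in the README; checkpoints that ship custom code like this generally need `trust_remote_code=True`):

```python
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("INSAIT-Institute/spear1-franka", trust_remote_code=True)
model = AutoModel.from_pretrained("INSAIT-Institute/spear1-franka", trust_remote_code=True)
```
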
configuration_spear.py (new file, +347 lines)

import collections
import collections.abc
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from .common_spear import (
    Config,
    HFConfigMixin,
    Normalization,
    ResizeMode,
    RotationFormat,
)


class InputSequencingConfig(Config):
    """
    past_frames_sequence_length: number of past images needed in a single robot state
    past_scalars_sequence_length: number of past scalar state data, e.g. actions, poses, etc,
        needed in a single robot state
    past_frames_stride_sec: sampling rate, determines how far apart in time each point in the sequence
        is. If None, ignored and takes the default data collection frequency from the dataset
    past_scalars_stride_sec: similar to past_frames_stride_sec

    sequence_frames: number of temporally-sequential points in a single example in the batch
    sequence_frames_stride_sec: sampling rate

    Understanding sequence_frames:
    TODO: sequences are possibly useful in some rare cases, maybe sequence modeling problems,
        but yet to be confirmed. Keeping for now, but could be removed if proved unnecessary

    - past_scalars_sequence_length, past_frames_sequence_length, future_controls_sequence_length,
      future_frames_sequence_length are hyperparameters referring to a SINGLE dataset example / 'state'.
      It is assumed that `past_scalars_sequence_length` and `past_frames_sequence_length` are the min
      number of observations that comprise a single 'state'
    - sequence_frames is a hyperparameter referring to the entire learning process. It controls the size
      of the sequence dimension in the batch. It's treated similarly to the batch dimension, with the
      difference that points in the sequence dimensions are temporally aligned. Unlike `past_*`
      attributes, in supervised learning a label is loaded for every point in the sequence dimension
      and the loss usually computed over the entire sequence dimension.
    """

    past_scalars_sequence_length: int = 1
    past_frames_sequence_length: int = 1
    past_scalars_stride_sec: Optional[float] = None
    past_frames_stride_sec: Optional[float] = None
    sequence_frames: int = 1
    sequence_frames_stride_sec: Optional[float] = None

    def __post_init__(self):
        super().__post_init__()
        assert self.past_scalars_sequence_length >= 1, self.past_scalars_sequence_length
        assert self.past_frames_sequence_length >= 1, self.past_frames_sequence_length
        assert self.sequence_frames >= 1, self.sequence_frames
        if self.past_frames_stride_sec is not None:
            assert self.past_frames_stride_sec >= 0.0, self.past_frames_stride_sec
        if self.past_scalars_stride_sec is not None:
            assert self.past_scalars_stride_sec >= 0.0, self.past_scalars_stride_sec
        if self.sequence_frames_stride_sec is not None:
            assert self.sequence_frames_stride_sec >= 0.0, self.sequence_frames_stride_sec

    def assert_same_past(self) -> None:
        assert (
            self.past_frames_stride_sec == self.past_scalars_stride_sec
        ), f"{self.past_frames_stride_sec} != {self.past_scalars_stride_sec}"
        assert (
            self.past_frames_sequence_length == self.past_scalars_sequence_length
        ), f"{self.past_frames_sequence_length} != {self.past_scalars_sequence_length}"


class OutputSequencingConfig(Config):
    """
    future_controls_sequence_length: number of control steps in the future the model predicts
    future_frames_sequence_length: number of future frames the model predicts
        (only relevant for neural networks that learn some sort of a world model)

    future_controls_sequence_stride_sec / future_frames_sequence_stride_sec: sampling rate
        that determines how far apart in time each point in the sequence is. If None,
        ignored and takes the default data collection frequency from the dataset

    future_control_offset_sec: time interval between the last observation and the first
        point at which control is predicted. Serves as a 'causality hyperparameter', allowing
        for predicting controls slightly further into the future in environments with dynamics
        where the observed effects of an action appear slightly later
    """

    future_controls_sequence_length: int = 1
    future_controls_sequence_stride_sec: Optional[float] = None
    future_frames_sequence_length: int = 1
    future_frames_sequence_stride_sec: Optional[float] = None
    future_control_offset_sec: float = 0.0

    def __post_init__(self):
        super().__post_init__()
        assert self.future_controls_sequence_length >= 1, self.future_controls_sequence_length
        assert self.future_frames_sequence_length >= 1, self.future_frames_sequence_length
        assert self.future_control_offset_sec >= 0.0, self.future_control_offset_sec
        if self.future_controls_sequence_stride_sec is not None:
            assert self.future_controls_sequence_stride_sec >= 0.0, self.future_controls_sequence_stride_sec
        if self.future_frames_sequence_stride_sec is not None:
            assert self.future_frames_sequence_stride_sec >= 0.0, self.future_frames_sequence_stride_sec


class ControlDataIOConfig(InputSequencingConfig, OutputSequencingConfig):
    pass


class ControlTokenizerConfig(Config):
    pass


class EmptyTokenizerConfig(ControlTokenizerConfig):
    pass


class VLAMProcessorConfig(Config):
    control_io_config: ControlDataIOConfig = ControlDataIOConfig()
    obs_translation_norm: Normalization | Dict[str, Tuple[float, float, float]] = Normalization.NONE
    obs_rotation_norm: Normalization = Normalization.NONE
    translation_norm: Normalization | Dict[str, Tuple[float, float, float]] = Normalization.NONE
    rotation_norm: Normalization = Normalization.NONE
    joints_norm: Dict[str, Tuple[float, ...]] = {
        "low": (-np.pi,) * 7,
        "high": (np.pi,) * 7,
    }
    rotation_format: RotationFormat = RotationFormat.QUATERNION
    eef_control_frame: bool = False
    delta_controls: bool = False
    image_resize: ResizeMode = ResizeMode.SMART
    control_tokenizer_config: EmptyTokenizerConfig = EmptyTokenizerConfig()
    control_stats_path: str = "barrel/pipes/vlams/types/control_stats.yaml"
    observation_stats_path: str = "barrel/pipes/vlams/types/observation_stats.yaml"

    def __post_init__(self):
        super().__post_init__()
        if isinstance(self.translation_norm, collections.abc.Mapping):
            assert all((len(value) == 3 for value in self.translation_norm.values())), self.translation_norm
            assert set(self.translation_norm.keys()) in (
                {"low", "high"},
                {"mean", "std"},
            ), self.translation_norm
        assert isinstance(self.joints_norm, collections.abc.Mapping), type(self.joints_norm)
        assert all((len(value) == 7 for value in self.joints_norm.values())), self.joints_norm
        assert set(self.joints_norm.keys()) in (
            {"low", "high"},
            {"mean", "std"},
        ), self.joints_norm


class RegressionProcessorConfig(VLAMProcessorConfig):
    pass


class PiZeroFlowProcessorConfig(RegressionProcessorConfig):
    num_inference_steps: int
    r0_distribution: str = "uniform"
    timestep_distribution: str
    distribution_hyperparams: Dict[str, Any] = {}
    sig_min: float = 0.001

    def __post_init__(self):
        super().__post_init__()
        assert self.r0_distribution in ["normal", "uniform"]


class VLMConfig(Config):
    pass


class VLMProcessorConfig(Config):
    pass


class ImageSizeConfig(Config):
    width: int
    height: int

    def to_dict(self):
        return {"width": self.width, "height": self.height}


class PaliGemmaProcessorConfig(Config):
    image_token: str = "<image>"
    image_sizes: Dict[str, ImageSizeConfig] = {"main": ImageSizeConfig(width=224, height=224)}
    max_language_tokens: int = 75

    def __post_init__(self):
        super().__post_init__()
        self.image_sizes = {
            camera_name: (
                ImageSizeConfig(**camera_image_size)
                if not isinstance(camera_image_size, ImageSizeConfig)
                else camera_image_size
            )
            for camera_name, camera_image_size in self.image_sizes.items()
        }
        for camera_name, camera_image_size in self.image_sizes.items():
            assert camera_image_size.height % 14 == 0, f"{camera_name}: {camera_image_size}"
            assert camera_image_size.width % 14 == 0, f"{camera_name}: {camera_image_size}"

    @property
    def num_image_tokens(self) -> Dict[str, int]:
        return {
            camera_name: camera_image_size.height // 14 * (camera_image_size.width // 14)
            for (camera_name, camera_image_size) in self.image_sizes.items()
        }
| 207 |
+
|
| 208 |
+
@property
|
| 209 |
+
def is_single_image_size(self) -> bool:
|
| 210 |
+
return (
|
| 211 |
+
len(self.image_sizes) == 1
|
| 212 |
+
or len(set(((image_size.height, image_size.width) for image_size in self.image_sizes.values())))
|
| 213 |
+
== 1
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
@property
|
| 217 |
+
def camera_names(self) -> List[str]:
|
| 218 |
+
return list(self.image_sizes.keys())
|
| 219 |
+
|
| 220 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 221 |
+
base_dict = {
|
| 222 |
+
"image_token": self.image_token,
|
| 223 |
+
"max_language_tokens": self.max_language_tokens,
|
| 224 |
+
}
|
| 225 |
+
base_dict["image_sizes"] = {
|
| 226 |
+
camera_name: camera_image_size.to_dict()
|
| 227 |
+
for camera_name, camera_image_size in self.image_sizes.items()
|
| 228 |
+
}
|
| 229 |
+
return base_dict
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class PaliGemmaVLMConfig(Config):
|
| 233 |
+
model_id: str = "google/paligemma-3b-mix-224"
|
| 234 |
+
attn_implementation: str = "flash_attention_2"
|
| 235 |
+
processor_config: PaliGemmaProcessorConfig
|
| 236 |
+
lm_head: bool = False
|
| 237 |
+
paligemma_3d_config: Dict[str, Any] = {}
|
| 238 |
+
depth_tokens: int = 0
|
| 239 |
+
train_only_depth_tokens: bool = False
|
| 240 |
+
mean_resizing: bool = False
|
| 241 |
+
|
| 242 |
+
def __post_init__(self):
|
| 243 |
+
super().__post_init__()
|
| 244 |
+
if self.train_only_depth_tokens:
|
| 245 |
+
assert self.depth_tokens > 0, self.depth_tokens
|
| 246 |
+
if self.paligemma_3d_config.get("mask_prob", 0.0) != 0.0:
|
| 247 |
+
raise NotImplementedError(
|
| 248 |
+
f"Masking is deprecated, but got mask_prob={self.paligemma_3d_config['mask_prob']}"
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
@property
|
| 252 |
+
def paligemma_3d_config_dict(self) -> Dict[str, Any]:
|
| 253 |
+
if len(self.paligemma_3d_config) == 0:
|
| 254 |
+
return {}
|
| 255 |
+
config = dict(self.paligemma_3d_config)
|
| 256 |
+
config["depth_config"] = dict(config["depth_config"])
|
| 257 |
+
config["depth_config"]["image_sizes"] = {
|
| 258 |
+
camera_name: camera_image_size.to_dict()
|
| 259 |
+
for camera_name, camera_image_size in self.processor_config.image_sizes.items()
|
| 260 |
+
}
|
| 261 |
+
return config
|
| 262 |
+
|
| 263 |
+
@property
|
| 264 |
+
def with_depth(self) -> bool:
|
| 265 |
+
return len(self.paligemma_3d_config) > 0
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class FourierFeaturesConfig(Config):
|
| 269 |
+
num_features: int = 256
|
| 270 |
+
learnable_features: bool = False
|
| 271 |
+
max_period: float = 10000.0
|
| 272 |
+
layers: List[int] = [256, 512, 256]
|
| 273 |
+
activation: str = "SiLU"
|
| 274 |
+
norm: Optional[str] = None
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class NoisedControlProjectorConfig(Config):
|
| 278 |
+
time_embed: FourierFeaturesConfig
|
| 279 |
+
layers: List[int] = []
|
| 280 |
+
activation: str = "SiLU"
|
| 281 |
+
norm: Optional[str] = None
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class RobotStateProjectorConfig(Config):
|
| 285 |
+
layers: List[int] = []
|
| 286 |
+
mode: str = "none"
|
| 287 |
+
activation: str = "GELU"
|
| 288 |
+
fourier: bool = False
|
| 289 |
+
|
| 290 |
+
def __post_init__(self):
|
| 291 |
+
super().__post_init__()
|
| 292 |
+
assert self.mode in [
|
| 293 |
+
"ee_pose",
|
| 294 |
+
"ee_pose_gripper",
|
| 295 |
+
"ee_pose_joints",
|
| 296 |
+
"joints",
|
| 297 |
+
"all",
|
| 298 |
+
"none",
|
| 299 |
+
], self.mode
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
class RotaryPositionalEncodingConfig(Config):
|
| 303 |
+
num_embeddings: int
|
| 304 |
+
embedding_dim: int
|
| 305 |
+
base: int = 10000
|
| 306 |
+
cached: bool = True
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class PiZeroFlowMatchingDecoderBlockConfig(Config):
|
| 310 |
+
feature_size: int
|
| 311 |
+
head_dim: int = 128
|
| 312 |
+
num_heads: int = 32
|
| 313 |
+
num_kv_heads: int = 1
|
| 314 |
+
hidden_size: int
|
| 315 |
+
activation: str = "GELU"
|
| 316 |
+
norm: str = "RMSNorm"
|
| 317 |
+
dropout: float = 0.0
|
| 318 |
+
attn_implementation: str = "sdpa"
|
| 319 |
+
position_embed_config: RotaryPositionalEncodingConfig
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class PiZeroFlowMatchingDecoderConfig(Config):
|
| 323 |
+
num_blocks: int
|
| 324 |
+
block_config: PiZeroFlowMatchingDecoderBlockConfig
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
class PiZeroFlowMatchingModuleConfig(Config):
|
| 328 |
+
token_size: int = 1024
|
| 329 |
+
noised_control_proj_config: NoisedControlProjectorConfig
|
| 330 |
+
robot_state_proj_config: RobotStateProjectorConfig
|
| 331 |
+
control_decoder_config: PiZeroFlowMatchingDecoderConfig
|
| 332 |
+
rotation_components: int = 3
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
class SPEAR1Config(HFConfigMixin, Config):
|
| 336 |
+
model_type: str = "spear1"
|
| 337 |
+
processor_config: PiZeroFlowProcessorConfig
|
| 338 |
+
vlm_config: PaliGemmaVLMConfig
|
| 339 |
+
control_module_config: PiZeroFlowMatchingModuleConfig
|
| 340 |
+
|
| 341 |
+
def __init__(self, **kwargs):
|
| 342 |
+
if "auto_map" not in kwargs:
|
| 343 |
+
kwargs["auto_map"] = {
|
| 344 |
+
"AutoConfig": "configuration_spear.SPEAR1Config",
|
| 345 |
+
"AutoModel": "modeling_spear.SPEAR1",
|
| 346 |
+
}
|
| 347 |
+
super().__init__(**kwargs)
|
generation_config.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"transformers_version": "4.47.0"
|
| 3 |
+
}
|
model-00001-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a0992d3b5ffdc8b896812ed19801bc9ebda65708237681ced90e642c90e0a0d2
|
| 3 |
+
size 4962008480
|
model-00002-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:db48d29ee9567705a81718181eac6c644d2d996f1e91c497e8c891702050c36e
|
| 3 |
+
size 4999821656
|
model-00003-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1c7e1d6dae46553546f53a3c9fa76a8a2d2e07664a575ce38962ae2930eb7562
|
| 3 |
+
size 4245980072
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modeling_spear.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
processing_spear.py
ADDED
|
@@ -0,0 +1,1897 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import collections.abc
|
| 3 |
+
import re
|
| 4 |
+
import warnings
|
| 5 |
+
from abc import abstractmethod
|
| 6 |
+
from functools import cached_property
|
| 7 |
+
from typing import Dict, List, Optional, Sequence, Tuple, TypeVar
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import PIL.Image
|
| 11 |
+
import roma
|
| 12 |
+
import torch
|
| 13 |
+
import torchvision.transforms.v2
|
| 14 |
+
import transformers
|
| 15 |
+
import yaml
|
| 16 |
+
|
| 17 |
+
from .common_spear import (
|
| 18 |
+
Configurable,
|
| 19 |
+
FlowInput,
|
| 20 |
+
Normalization,
|
| 21 |
+
ResizeMode,
|
| 22 |
+
RoboticsControlPlan,
|
| 23 |
+
RoboticsFlowInput,
|
| 24 |
+
RoboticsInput,
|
| 25 |
+
RoboticsOutput,
|
| 26 |
+
RoboticsTarget,
|
| 27 |
+
RotationFormat,
|
| 28 |
+
expand_dims,
|
| 29 |
+
is_quaternion,
|
| 30 |
+
is_rotmat,
|
| 31 |
+
is_rotmat_3x3,
|
| 32 |
+
is_rotmat_9,
|
| 33 |
+
quaternion_half_cover,
|
| 34 |
+
rotmat_as_3x3,
|
| 35 |
+
rotmat_as_9,
|
| 36 |
+
)
|
| 37 |
+
from .configuration_spear import (
|
| 38 |
+
ControlDataIOConfig,
|
| 39 |
+
ImageSizeConfig,
|
| 40 |
+
PaliGemmaProcessorConfig,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class VLMProcessor(Configurable):
|
| 45 |
+
@abstractmethod
|
| 46 |
+
def preprocess_inputs(
|
| 47 |
+
self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
|
| 48 |
+
) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]: ...
|
| 49 |
+
|
| 50 |
+
@property
|
| 51 |
+
@abstractmethod
|
| 52 |
+
def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
|
| 53 |
+
pass
|
| 54 |
+
|
| 55 |
+
@property
|
| 56 |
+
@abstractmethod
|
| 57 |
+
def image_sizes(self) -> Dict[str, ImageSizeConfig]:
|
| 58 |
+
pass
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class EmptyTokenizer(Configurable):
|
| 62 |
+
"""
|
| 63 |
+
Takes the LLM hidden states from `llm_layer_indices` and concatenates them to produce the
|
| 64 |
+
desired result. Includes the hidden states for the image tokens.
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
def __init__(self, config, tokenizer: transformers.PreTrainedTokenizerBase) -> None:
|
| 68 |
+
super().__init__(config)
|
| 69 |
+
self.tokenizer = tokenizer
|
| 70 |
+
|
| 71 |
+
def __call__(self, *_) -> str:
|
| 72 |
+
return ""
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def np_unique(
|
| 76 |
+
data: np.ndarray,
|
| 77 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
| 78 |
+
"""
|
| 79 |
+
Compute unique elements in data and corresponding indices.
|
| 80 |
+
|
| 81 |
+
np.unique returns the values in a sorted order, even if the source is not sorted. Thus, if you simply
|
| 82 |
+
run np.unique on unsorted data, the indices you will get will be invalid.
|
| 83 |
+
|
| 84 |
+
"""
|
| 85 |
+
(_, indices, inverse) = np.unique(data, return_index=True, return_inverse=True)
|
| 86 |
+
(_, indices_of_first_occurence, inverse_indices, counts) = np.unique(
|
| 87 |
+
indices[inverse], return_index=True, return_inverse=True, return_counts=True
|
| 88 |
+
)
|
| 89 |
+
unique_ids = data[indices_of_first_occurence]
|
| 90 |
+
return unique_ids, indices_of_first_occurence, inverse_indices, counts
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def euler_to_rotmat(angles: torch.Tensor) -> torch.Tensor:
|
| 94 |
+
"""
|
| 95 |
+
Args:
|
| 96 |
+
angles: Euler angles in radians in the format 'xyz', shape [..., 3]
|
| 97 |
+
Returns:
|
| 98 |
+
torch.Tensor of shape [..., 3, 3] containing rotation matrices
|
| 99 |
+
"""
|
| 100 |
+
return roma.euler_to_rotmat(convention="xyz", angles=angles, degrees=False)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def euler_to_unit_quaternion(angles: torch.Tensor) -> torch.Tensor:
|
| 104 |
+
"""
|
| 105 |
+
Args:
|
| 106 |
+
angles: Euler angles in radians in the format 'xyz', shape [..., 3]
|
| 107 |
+
Returns:
|
| 108 |
+
torch.Tensor of shape [..., 4] containing unit quaternions
|
| 109 |
+
"""
|
| 110 |
+
return roma.euler_to_unitquat(convention="xyz", angles=angles, degrees=False, normalize=True)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def normalize_quaternion(quaternion: torch.Tensor, eps: float = 1e-08) -> torch.Tensor:
|
| 114 |
+
"""
|
| 115 |
+
Args:
|
| 116 |
+
quaternion: Unnormalized quaternion, torch.Tensor of shape [..., 4]
|
| 117 |
+
eps: Small constant to prevent division by zero
|
| 118 |
+
Returns:
|
| 119 |
+
torch.Tensor of shape [..., 4] of unit quaternions
|
| 120 |
+
"""
|
| 121 |
+
return quaternion / (quaternion.norm(dim=-1, keepdim=True).detach() + eps)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def quaternion_to_euler(quaternion: torch.Tensor) -> torch.Tensor:
|
| 125 |
+
"""
|
| 126 |
+
Args:
|
| 127 |
+
quaternion: torch.Tensor of shape [..., 4]; Can be non-normalized
|
| 128 |
+
Returns:
|
| 129 |
+
torch.Tensor of shape [..., 3, 3] containing rotation matrices in SO(3)
|
| 130 |
+
"""
|
| 131 |
+
unit_quat = normalize_quaternion(quaternion)
|
| 132 |
+
rotmat = roma.unitquat_to_euler(convention="xyz", quat=unit_quat, as_tuple=False, degrees=False)
|
| 133 |
+
return rotmat
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def quaternion_to_rotmat(quaternion: torch.Tensor) -> torch.Tensor:
|
| 137 |
+
"""
|
| 138 |
+
Args:
|
| 139 |
+
quaternion: torch.Tensor of shape [..., 4]; Can be non-normalized
|
| 140 |
+
Returns:
|
| 141 |
+
torch.Tensor of shape [..., 3, 3] containing rotation matrices in SO(3)
|
| 142 |
+
"""
|
| 143 |
+
unit_quat = normalize_quaternion(quaternion)
|
| 144 |
+
rotmat = roma.unitquat_to_rotmat(unit_quat)
|
| 145 |
+
return rotmat
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def rotmat_to_unit_quaternion(rotmat: torch.Tensor) -> torch.Tensor:
|
| 149 |
+
"""
|
| 150 |
+
Args:
|
| 151 |
+
rotmat: Batch of rotation matrices, shape [..., 3, 3]
|
| 152 |
+
Returns:
|
| 153 |
+
Batch of unit quaternions, shape [..., 4]
|
| 154 |
+
"""
|
| 155 |
+
rotmat = rotmat_as_3x3(rotmat)
|
| 156 |
+
return roma.rotmat_to_unitquat(rotmat)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def rotmat_to_euler(rotmat: torch.Tensor) -> torch.Tensor:
|
| 160 |
+
"""
|
| 161 |
+
Args:
|
| 162 |
+
rotmat: Batch of rotation matrices, shape [..., 3, 3]
|
| 163 |
+
Returns:
|
| 164 |
+
Batch of Euler angles in radiant, shape [..., 3]
|
| 165 |
+
"""
|
| 166 |
+
rotmat = rotmat_as_3x3(rotmat)
|
| 167 |
+
return roma.rotmat_to_euler(convention="xyz", rotmat=rotmat, as_tuple=False, degrees=False)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def symmetric_orthogonalization(x: torch.Tensor) -> torch.Tensor:
|
| 171 |
+
"""
|
| 172 |
+
Maps 9D input vectors onto SO(3) via symmetric orthogonalization.
|
| 173 |
+
- Let SVD(M) = U \Sigma V^T
|
| 174 |
+
- Returned value is SVD+(M) = U diag(1, 1, det(UV^T)) V^T
|
| 175 |
+
- det(UV^T) ensures that det(SVD+(M)) = 1
|
| 176 |
+
- The return value is a rotation matrix (ortonormal) with the least-squares distance to M
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
x: Input matrices, not necessarily orthonormal, shape [..., 9] or [..., 3, 3]
|
| 180 |
+
Returns:
|
| 181 |
+
torch.Tensor with the same shape as x, where each inner 3x3 matrix is in SO(3)
|
| 182 |
+
"""
|
| 183 |
+
with warnings.catch_warnings():
|
| 184 |
+
warnings.filterwarnings(
|
| 185 |
+
"ignore",
|
| 186 |
+
message="In CPU autocast, but the target dtype is not supported. Disabling autocast.",
|
| 187 |
+
)
|
| 188 |
+
with torch.autocast(device_type=x.device.type, dtype=torch.float32):
|
| 189 |
+
matrices = x.view(-1, 3, 3)
|
| 190 |
+
matrices = matrices.to(dtype=torch.float32)
|
| 191 |
+
(u, s, v) = torch.svd(matrices)
|
| 192 |
+
vt = torch.transpose(v, 1, 2)
|
| 193 |
+
det = torch.det(torch.matmul(u, vt)).view(-1, 1, 1)
|
| 194 |
+
diag_vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), dim=1)
|
| 195 |
+
result = torch.matmul(u, diag_vt)
|
| 196 |
+
result = result.view(*x.shape)
|
| 197 |
+
result = result.to(dtype=x.dtype)
|
| 198 |
+
return result
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def is_rotmat_orthonormal(
|
| 202 |
+
rotmat: torch.Tensor, epsilon: float = 1e-06, reduction: str = "none"
|
| 203 |
+
) -> torch.Tensor | bool:
|
| 204 |
+
"""
|
| 205 |
+
Check if a rotation matrix is orthonormal or not.
|
| 206 |
+
Args:
|
| 207 |
+
rotmat: torch.Tensor of shape [..., 3, 3] or [..., 9]
|
| 208 |
+
epsilon: Tolerance for numerical comparisons. Bigger values allow for more freedom. Generally,
|
| 209 |
+
anything smaller than 1e-6 might incorrectly detect some otrhonormal matrices as not
|
| 210 |
+
reduction:
|
| 211 |
+
'none' - returns torch.Tensor of bools with the same batch shape
|
| 212 |
+
'all' - returns a bool, True is ALL matrices in the batch are orthonormal
|
| 213 |
+
Returns:
|
| 214 |
+
torch.Tensor with the same batch shape or bool
|
| 215 |
+
"""
|
| 216 |
+
assert is_rotmat(rotmat)
|
| 217 |
+
rotmat = rotmat_as_3x3(rotmat.to(dtype=torch.float32))
|
| 218 |
+
is_orthonormal = roma.is_orthonormal_matrix(rotmat, epsilon=epsilon)
|
| 219 |
+
if reduction == "none":
|
| 220 |
+
return is_orthonormal
|
| 221 |
+
if reduction == "all":
|
| 222 |
+
return bool(torch.all(is_orthonormal).item())
|
| 223 |
+
raise ValueError(f"Unknown reduction mode {reduction}")
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def is_orthonormal_rotmat(rotmat: torch.Tensor) -> bool:
|
| 227 |
+
"""
|
| 228 |
+
Checks if the tensor shape matches that of a rotmat. If the last dimensions of shape are 3x3,
|
| 229 |
+
also checks if the data is a valid rotmat. This is to avoid a possible clash with euler angles
|
| 230 |
+
when accidentally `rotmat.shape[-2:] == [3, 3]`
|
| 231 |
+
"""
|
| 232 |
+
return (
|
| 233 |
+
is_rotmat_9(rotmat)
|
| 234 |
+
or is_rotmat_3x3(rotmat)
|
| 235 |
+
and is_rotmat_orthonormal(rotmat, epsilon=0.01, reduction="all")
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def is_euler(euler: torch.Tensor) -> bool:
|
| 240 |
+
return euler.shape[-1] == 3 and not is_orthonormal_rotmat(euler)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def normalize_rotation(rotation: torch.Tensor) -> torch.Tensor:
|
| 244 |
+
if is_quaternion(rotation):
|
| 245 |
+
return normalize_quaternion(rotation)
|
| 246 |
+
if is_euler(rotation):
|
| 247 |
+
return rotation
|
| 248 |
+
if is_rotmat(rotation):
|
| 249 |
+
is_flat = is_rotmat_9(rotation)
|
| 250 |
+
rotation = rotmat_as_3x3(rotation) if is_flat else rotation
|
| 251 |
+
rotmat = roma.special_gramschmidt(rotation)
|
| 252 |
+
rotmat = rotmat_as_9(rotmat) if is_flat else rotmat
|
| 253 |
+
return rotmat
|
| 254 |
+
raise ValueError(f"Unknown rotation format: {rotation.shape}")
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def rotation_format_from_tensor(rotation) -> RotationFormat:
|
| 258 |
+
if is_quaternion(rotation):
|
| 259 |
+
return RotationFormat.QUATERNION
|
| 260 |
+
if is_orthonormal_rotmat(rotation):
|
| 261 |
+
return RotationFormat.ROTMAT
|
| 262 |
+
if is_euler(rotation):
|
| 263 |
+
return RotationFormat.EULER
|
| 264 |
+
raise ValueError(f"Tensor shape {rotation.shape} is not a valid rotation format")
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def is_unit_quaternion(
|
| 268 |
+
quaternion: torch.Tensor, epsilon: float = 1e-08, reduction: str = "none"
|
| 269 |
+
) -> torch.Tensor | bool:
|
| 270 |
+
"""
|
| 271 |
+
Check if a quternion is normalized or not.
|
| 272 |
+
Args:
|
| 273 |
+
quaternion: torch.Tensor of shape [..., 4]
|
| 274 |
+
tolerance: Tolerance for numerical comparisons
|
| 275 |
+
reduction:
|
| 276 |
+
'none' - returns torch.Tensor of bools with the same batch shape
|
| 277 |
+
'all' - returns a bool, True if ALL quaternions in the batch are normalized
|
| 278 |
+
Returns:
|
| 279 |
+
torch.Tensor with the same batch shape or bool
|
| 280 |
+
"""
|
| 281 |
+
assert is_quaternion(quaternion)
|
| 282 |
+
is_norm = torch.isclose(
|
| 283 |
+
quaternion.norm(dim=-1, keepdim=True),
|
| 284 |
+
torch.tensor(1.0, dtype=quaternion.dtype, device=quaternion.device),
|
| 285 |
+
atol=epsilon,
|
| 286 |
+
)
|
| 287 |
+
if reduction == "none":
|
| 288 |
+
return is_norm
|
| 289 |
+
if reduction == "all":
|
| 290 |
+
return bool(torch.all(is_norm).item())
|
| 291 |
+
raise ValueError(f"Unknown reduction mode {reduction}")
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def convert_rotation(
|
| 295 |
+
rotation: torch.Tensor | np.ndarray,
|
| 296 |
+
output_format: RotationFormat,
|
| 297 |
+
autonorm: bool = True,
|
| 298 |
+
half_cover: bool = True,
|
| 299 |
+
) -> torch.Tensor | np.ndarray:
|
| 300 |
+
is_np = isinstance(rotation, np.ndarray)
|
| 301 |
+
if is_np:
|
| 302 |
+
rotation = torch.from_numpy(rotation)
|
| 303 |
+
if is_quaternion(rotation):
|
| 304 |
+
if autonorm and not is_unit_quaternion(rotation, reduction="all"):
|
| 305 |
+
rotation = normalize_quaternion(rotation)
|
| 306 |
+
if output_format == RotationFormat.QUATERNION:
|
| 307 |
+
output = rotation
|
| 308 |
+
elif output_format == RotationFormat.ROTMAT:
|
| 309 |
+
output = rotmat_as_9(quaternion_to_rotmat(rotation))
|
| 310 |
+
elif output_format == RotationFormat.EULER:
|
| 311 |
+
output = quaternion_to_euler(rotation)
|
| 312 |
+
else:
|
| 313 |
+
raise NotImplementedError(f"Unsupported rotation format: {output_format}")
|
| 314 |
+
elif is_orthonormal_rotmat(rotation):
|
| 315 |
+
if autonorm and not is_rotmat_orthonormal(rotation, epsilon=0.01, reduction="all"):
|
| 316 |
+
rotation = symmetric_orthogonalization(rotation)
|
| 317 |
+
if output_format == RotationFormat.QUATERNION:
|
| 318 |
+
output = rotmat_to_unit_quaternion(rotation)
|
| 319 |
+
elif output_format == RotationFormat.ROTMAT:
|
| 320 |
+
output = rotmat_as_9(rotation)
|
| 321 |
+
elif output_format == RotationFormat.EULER:
|
| 322 |
+
output = rotmat_to_euler(rotation)
|
| 323 |
+
else:
|
| 324 |
+
raise NotImplementedError(f"Unsupported rotation format: {output_format}")
|
| 325 |
+
elif is_euler(rotation):
|
| 326 |
+
if output_format == RotationFormat.QUATERNION:
|
| 327 |
+
output = euler_to_unit_quaternion(rotation)
|
| 328 |
+
elif output_format == RotationFormat.ROTMAT:
|
| 329 |
+
output = rotmat_as_9(euler_to_rotmat(rotation))
|
| 330 |
+
elif output_format == RotationFormat.EULER:
|
| 331 |
+
output = rotation
|
| 332 |
+
else:
|
| 333 |
+
raise NotImplementedError(f"Unsupported rotation format: {output_format}")
|
| 334 |
+
else:
|
| 335 |
+
raise ValueError(f"Unknown rotation encoding with shape {rotation.shape}")
|
| 336 |
+
if output_format == RotationFormat.QUATERNION and half_cover:
|
| 337 |
+
output = quaternion_half_cover(output)
|
| 338 |
+
if is_np:
|
| 339 |
+
output = output.numpy()
|
| 340 |
+
return output
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def delta_to_relative_rotations(rotation_sequence: torch.Tensor) -> torch.Tensor:
|
| 344 |
+
"""
|
| 345 |
+
Transform a sequence of rotation representations encoded w.r.t. the PREVIOUS rotation frame in the
|
| 346 |
+
sequence to the 0-th element preceding the sequence
|
| 347 |
+
|
| 348 |
+
Ex:
|
| 349 |
+
`rotation_sequence` contains the rotations: R_01, R_12, R_23, R_34, where R0 is the base frame,
|
| 350 |
+
implicitly encoded in R_01 and R_10 converts from R0 frame to R1 frame
|
| 351 |
+
Output: R_01, R_02, R_03, R_04
|
| 352 |
+
|
| 353 |
+
Args:
|
| 354 |
+
rotation_sequence: torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4], containing
|
| 355 |
+
either rotation matrices (R_01, R_12, R_23, R_34, ...) or quaternions
|
| 356 |
+
Returns:
|
| 357 |
+
torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4] containing transformed rotations
|
| 358 |
+
(R_01, R_02, R_03, R_04, ...)
|
| 359 |
+
|
| 360 |
+
TODO: Can you make it work without for loop
|
| 361 |
+
"""
|
| 362 |
+
assert rotation_sequence.ndim >= 3, rotation_sequence.shape
|
| 363 |
+
rotation_format: RotationFormat = rotation_format_from_tensor(rotation_sequence)
|
| 364 |
+
rotation_sequence = convert_rotation(rotation_sequence, RotationFormat.QUATERNION)
|
| 365 |
+
batch_dims = np.arange(rotation_sequence.ndim - 2)
|
| 366 |
+
delta_rotations = torch.cat(
|
| 367 |
+
[rotation_sequence[..., :1, :]]
|
| 368 |
+
+ [
|
| 369 |
+
roma.quat_composition(rotation_sequence[..., :i, :].permute(-2, *batch_dims, -1).unsqueeze(-2))
|
| 370 |
+
for i in range(2, rotation_sequence.shape[-2] + 1)
|
| 371 |
+
],
|
| 372 |
+
dim=-2,
|
| 373 |
+
)
|
| 374 |
+
delta_rotations = convert_rotation(delta_rotations, rotation_format)
|
| 375 |
+
return delta_rotations
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def assert_np_hwc_or_hw_image(image: np.ndarray | PIL.Image.Image) -> np.ndarray:
|
| 379 |
+
"""Make sure image is of type np.ndarray and HWC format"""
|
| 380 |
+
if isinstance(image, PIL.Image.Image):
|
| 381 |
+
image = np.asarray(image)
|
| 382 |
+
assert isinstance(image, np.ndarray), type(image)
|
| 383 |
+
assert image.ndim in [2, 3], image.shape
|
| 384 |
+
if image.ndim == 3:
|
| 385 |
+
assert image.shape[-1] <= 4, image.shape
|
| 386 |
+
return image
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def hw_from_image(image: PIL.Image.Image | np.ndarray) -> tuple[int, int]:
|
| 390 |
+
if isinstance(image, np.ndarray):
|
| 391 |
+
(height, width) = image.shape[:2]
|
| 392 |
+
else:
|
| 393 |
+
(width, height) = image.size
|
| 394 |
+
return height, width
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def pad_image(
|
| 398 |
+
image: PIL.Image.Image | np.ndarray,
|
| 399 |
+
target_size: dict[str, int],
|
| 400 |
+
pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
|
| 401 |
+
) -> PIL.Image.Image | np.ndarray:
|
| 402 |
+
"""Pad image adding a symmetric border around the height/width."""
|
| 403 |
+
assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)
|
| 404 |
+
(height, width) = hw_from_image(image)
|
| 405 |
+
(target_width, target_height) = (target_size["width"], target_size["height"])
|
| 406 |
+
if width == target_width and height == target_height:
|
| 407 |
+
return image
|
| 408 |
+
assert target_width >= width, f"Can't pad image of width {width} to {target_width}"
|
| 409 |
+
assert target_height >= height, f"Can't pad image of height {height} to {target_height}"
|
| 410 |
+
(horizontal_pad, vertical_pad) = (
|
| 411 |
+
int((target_width - width) / 2),
|
| 412 |
+
int((target_height - height) / 2),
|
| 413 |
+
)
|
| 414 |
+
if isinstance(image, np.ndarray):
|
| 415 |
+
padding = ((vertical_pad, vertical_pad), (horizontal_pad, horizontal_pad)) + ((0, 0),) * (
|
| 416 |
+
image.ndim - 2
|
| 417 |
+
)
|
| 418 |
+
image = np.pad(image, padding, mode="constant", constant_values=pad_value)
|
| 419 |
+
else:
|
| 420 |
+
padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
|
| 421 |
+
image = torchvision.transforms.v2.functional.pad(
|
| 422 |
+
image, padding=padding, fill=pad_value, padding_mode="constant"
|
| 423 |
+
)
|
| 424 |
+
return image
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def pad_image_to_ratio(
|
| 428 |
+
image: PIL.Image.Image | np.ndarray,
|
| 429 |
+
target_wh_ratio: float,
|
| 430 |
+
pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
|
| 431 |
+
) -> PIL.Image.Image | np.ndarray:
|
| 432 |
+
"""Pad image to a target aspect ratio."""
|
| 433 |
+
(height, width) = hw_from_image(image)
|
| 434 |
+
wh_ratio = width / height
|
| 435 |
+
if target_wh_ratio >= wh_ratio:
|
| 436 |
+
pad_size = {"width": round(height * target_wh_ratio), "height": height}
|
| 437 |
+
else:
|
| 438 |
+
pad_size = {"width": width, "height": round(width / target_wh_ratio)}
|
| 439 |
+
image = pad_image(image, target_size=pad_size, pad_value=pad_value)
|
| 440 |
+
return image
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def crop_image(
|
| 444 |
+
image: np.ndarray | PIL.Image.Image,
|
| 445 |
+
start_height: int,
|
| 446 |
+
start_width: int,
|
| 447 |
+
target_height: int,
|
| 448 |
+
target_width: int,
|
| 449 |
+
) -> np.ndarray | PIL.Image.Image:
|
| 450 |
+
np_image = assert_np_hwc_or_hw_image(image)
|
| 451 |
+
(height, width) = hw_from_image(image)
|
| 452 |
+
assert target_width <= width, f"Can't crop image of width {width} to {target_width}"
|
| 453 |
+
assert target_height <= height, f"Can't crop image of width {height} to {target_height}"
|
| 454 |
+
(start_height, start_width) = (round(start_height), round(start_width))
|
| 455 |
+
(target_height, target_width) = (round(target_height), round(target_width))
|
| 456 |
+
np_image = np_image[
|
| 457 |
+
start_height : start_height + target_height,
|
| 458 |
+
start_width : start_width + target_width,
|
| 459 |
+
...,
|
| 460 |
+
]
|
| 461 |
+
image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image
|
| 462 |
+
return image
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
def crop_image_center(
|
| 466 |
+
image: np.ndarray | PIL.Image.Image, target_size: dict[str, int]
|
| 467 |
+
) -> np.ndarray | PIL.Image.Image:
|
| 468 |
+
np_image = assert_np_hwc_or_hw_image(image)
|
| 469 |
+
(height, width) = np_image.shape[:2]
|
| 470 |
+
(target_height, target_width) = (target_size["height"], target_size["width"])
|
| 471 |
+
assert target_width <= width, f"Can't crop image of width {width} to {target_width}"
|
| 472 |
+
assert target_height <= height, f"Can't crop image of width {height} to {target_height}"
|
| 473 |
+
top = (height - target_height) // 2
|
| 474 |
+
left = (width - target_width) // 2
|
| 475 |
+
np_image = crop_image(np_image, top, left, target_height, target_width)
|
| 476 |
+
image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image
|
| 477 |
+
return image
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def crop_image_to_ratio(
|
| 481 |
+
image: PIL.Image.Image | np.ndarray, target_wh_ratio: float
|
| 482 |
+
) -> PIL.Image.Image | np.ndarray:
|
| 483 |
+
"""Pad image to a target aspect ratio."""
|
| 484 |
+
(height, width) = hw_from_image(image)
|
| 485 |
+
wh_ratio = width / height
|
| 486 |
+
if target_wh_ratio >= wh_ratio:
|
| 487 |
+
crop_size = {"width": width, "height": round(width / target_wh_ratio)}
|
| 488 |
+
else:
|
| 489 |
+
crop_size = {"width": round(height * target_wh_ratio), "height": height}
|
| 490 |
+
image = crop_image_center(image, target_size=crop_size)
|
| 491 |
+
return image
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
def crop_and_pad_image_to_ratio(
|
| 495 |
+
image: PIL.Image.Image | np.ndarray,
|
| 496 |
+
target_wh_ratio: float,
|
| 497 |
+
mode: ResizeMode | str,
|
| 498 |
+
pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
|
| 499 |
+
) -> PIL.Image.Image | np.ndarray:
|
| 500 |
+
"""
|
| 501 |
+
Crop and pad an image to a target size depending on the mode.
|
| 502 |
+
It's expected that the source image and target size have different aspect ratios.
|
| 503 |
+
|
| 504 |
+
Args:
|
| 505 |
+
image: The image to crop and pad.
|
| 506 |
+
target_size: The target size to crop and pad the image to.
|
| 507 |
+
mode: The mode to use for cropping and padding.
|
| 508 |
+
"""
|
| 509 |
+
(height, width) = hw_from_image(image)
|
| 510 |
+
wh_ratio = width / height
|
| 511 |
+
if np.isclose(wh_ratio, target_wh_ratio, rtol=0.01, atol=0.0001):
|
| 512 |
+
return image
|
| 513 |
+
if mode == ResizeMode.SMART:
|
| 514 |
+
aspect_ratio = max(width, height) / min(width, height)
|
| 515 |
+
target_ratio = max(target_wh_ratio, 1 / target_wh_ratio)
|
| 516 |
+
if aspect_ratio == 1:
|
| 517 |
+
if target_ratio >= 4 / 3 - 0.01:
|
| 518 |
+
crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4
|
| 519 |
+
image = crop_image_to_ratio(image, crop_wh_ratio)
|
| 520 |
+
else:
|
| 521 |
+
pass
|
| 522 |
+
elif aspect_ratio <= 4 / 3 + 0.01:
|
| 523 |
+
if wh_ratio >= 1.0 != (target_wh_ratio >= 1.0):
|
| 524 |
+
image = crop_image_to_ratio(image, 1.0)
|
| 525 |
+
elif wh_ratio >= 1.0 != (target_wh_ratio >= 1.0):
|
| 526 |
+
image = crop_image_to_ratio(image, 1.0)
|
| 527 |
+
elif target_ratio >= 4 / 3 + 0.01:
|
| 528 |
+
pass
|
| 529 |
+
else:
|
| 530 |
+
crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4
|
| 531 |
+
image = crop_image_to_ratio(image, crop_wh_ratio)
|
| 532 |
+
image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value)
|
| 533 |
+
elif mode == ResizeMode.PAD:
|
| 534 |
+
image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value)
|
| 535 |
+
elif mode == ResizeMode.CROP:
|
| 536 |
+
image = crop_image_to_ratio(image, target_wh_ratio)
|
| 537 |
+
else:
|
| 538 |
+
raise ValueError(f"Mode {mode} not supported")
|
| 539 |
+
return image
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
def is_single_channel_image(image: np.ndarray | PIL.Image.Image) -> bool:
|
| 543 |
+
if isinstance(image, PIL.Image.Image):
|
| 544 |
+
return image.mode in [
|
| 545 |
+
"1",
|
| 546 |
+
"L",
|
| 547 |
+
"LA",
|
| 548 |
+
"La",
|
| 549 |
+
"P",
|
| 550 |
+
"PA",
|
| 551 |
+
"F",
|
| 552 |
+
"I",
|
| 553 |
+
"I;16",
|
| 554 |
+
"I;16L",
|
| 555 |
+
"I;16B",
|
| 556 |
+
"I;16N",
|
| 557 |
+
]
|
| 558 |
+
if isinstance(image, np.ndarray):
|
| 559 |
+
return image.ndim == 2 or image.ndim == 3 and image.shape[2] == 1
|
| 560 |
+
raise ValueError(f"Unsupported image type: {type(image)}")
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
def is_binary_mask(image: np.ndarray | PIL.Image.Image) -> bool:
|
| 564 |
+
image = np.asarray(image)
|
| 565 |
+
return image.dtype in [np.uint8, np.bool_] and np.max(image) == 1
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def resize_image(
|
| 569 |
+
image: PIL.Image.Image | np.ndarray,
|
| 570 |
+
target_size: dict[str, int],
|
| 571 |
+
mode: ResizeMode | str,
|
| 572 |
+
resample: PIL.Image.Resampling | str = "auto",
|
| 573 |
+
pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
|
| 574 |
+
) -> PIL.Image.Image | np.ndarray:
|
| 575 |
+
(target_width, target_height) = (target_size["width"], target_size["height"])
|
| 576 |
+
(height, width) = hw_from_image(image)
|
| 577 |
+
if height == target_height and width == target_width:
|
| 578 |
+
return image
|
| 579 |
+
if resample == "auto":
|
| 580 |
+
if is_single_channel_image(image):
|
| 581 |
+
resample = PIL.Image.Resampling.BILINEAR
|
| 582 |
+
else:
|
| 583 |
+
resample = PIL.Image.Resampling.LANCZOS
|
| 584 |
+
else:
|
| 585 |
+
assert isinstance(resample, PIL.Image.Resampling), resample
|
| 586 |
+
if is_single_channel_image(image) and resample not in [
|
| 587 |
+
PIL.Image.Resampling.BILINEAR,
|
| 588 |
+
PIL.Image.Resampling.BICUBIC,
|
| 589 |
+
]:
|
| 590 |
+
raise ValueError(
|
| 591 |
+
f"Single channel images must be resized with bilinear or bicubic, but got {resample}"
|
| 592 |
+
)
|
| 593 |
+
if is_bin_mask := is_binary_mask(image):
|
| 594 |
+
image = np.asarray(image).astype(np.uint8) * 255
|
| 595 |
+
if mode == ResizeMode.SMART:
|
| 596 |
+
image = crop_and_pad_image_to_ratio(
|
| 597 |
+
image,
|
| 598 |
+
target_wh_ratio=target_width / target_height,
|
| 599 |
+
mode=mode,
|
| 600 |
+
pad_value=pad_value,
|
| 601 |
+
)
|
| 602 |
+
pil_image = PIL.Image.fromarray(image) if isinstance(image, np.ndarray) else image
|
| 603 |
+
if mode in [ResizeMode.NAIVE, ResizeMode.SMART]:
|
| 604 |
+
pil_image = pil_image.resize((target_width, target_height), resample=resample)
|
| 605 |
+
else:
|
| 606 |
+
raise NotImplementedError(f"Mode {mode} not supported")
|
| 607 |
+
image = np.asarray(pil_image) if isinstance(image, np.ndarray) else pil_image
|
| 608 |
+
if is_bin_mask:
|
| 609 |
+
image = image.astype(np.uint8) > 127
|
| 610 |
+
return image
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
def is_global_norm(
|
| 614 |
+
norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list],
|
| 615 |
+
) -> bool:
|
| 616 |
+
"""Return true if norm is NONE or global for all datasets"""
|
| 617 |
+
return norm == Normalization.NONE or isinstance(norm, collections.abc.Mapping)
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
def is_mean_norm(
|
| 621 |
+
norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list],
|
| 622 |
+
) -> bool:
|
| 623 |
+
"""Return true if norm is based on mean and std"""
|
| 624 |
+
return (
|
| 625 |
+
norm == Normalization.MEAN
|
| 626 |
+
or isinstance(norm, collections.abc.Mapping)
|
| 627 |
+
and set(norm.keys()) == {"mean", "std"}
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
def _broadcast_shapes(
|
| 632 |
+
value: torch.Tensor, low: torch.Tensor, high: torch.Tensor
|
| 633 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 634 |
+
"""
|
| 635 |
+
Broadcast shapes for normalization:
|
| 636 |
+
Args:
|
| 637 |
+
value: torch.Tensor of shape [..., num_components]. The entire shape might be:
|
| 638 |
+
- [num_components]: `value` has no batch dimension
|
| 639 |
+
- [num_datasets, num_components]: `value` contains entries *aligned* with the dataset bounds
|
| 640 |
+
contained in `low` and `high`
|
| 641 |
+
- [num_datasets, ..., num_components]: `value` contains entries *aligned* with the dataset bounds
|
| 642 |
+
contained in `low` and `high`
|
| 643 |
+
- [..., num_components]: `value` contains multiple dimensions. In this case, `low` and `high`
|
| 644 |
+
must be for a single dataset, i.e. `num_datasets = 1`
|
| 645 |
+
|
| 646 |
+
low: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when `low`
|
| 647 |
+
contains normalization bounds for a single dataset
|
| 648 |
+
high: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when `high`
|
| 649 |
+
contains normalization bounds for a single dataset
|
| 650 |
+
Returns:
|
| 651 |
+
Tuple of torch.Tensors (low, high), where `low` and `high` have the same number of dimensions as `value`
|
| 652 |
+
"""
|
| 653 |
+
assert low.ndim == high.ndim == 2, f"{low.shape} != {high.shape} or ndim != 2"
|
| 654 |
+
assert value.shape[-1] == low.shape[-1] == high.shape[-1], f"{value.shape} != {low.shape} / {high.shape}"
|
| 655 |
+
if value.ndim == low.ndim == high.ndim:
|
| 656 |
+
return low, high
|
| 657 |
+
if value.ndim < low.ndim:
|
| 658 |
+
assert low.ndim == high.ndim == 2, f"{low.shape}, {high.shape}"
|
| 659 |
+
assert low.shape[0] == high.shape[0] == 1, f"{low.shape}, {high.shape}"
|
| 660 |
+
(low, high) = (low.view(-1), high.view(-1))
|
| 661 |
+
return low, high
|
| 662 |
+
if low.shape[0] == high.shape[0] == 1:
|
| 663 |
+
low = expand_dims(low.view(-1), ndim=value.ndim, order=[-1, 1])
|
| 664 |
+
high = expand_dims(high.view(-1), ndim=value.ndim, order=[-1, 1])
|
| 665 |
+
else:
|
| 666 |
+
assert value.shape[0] == low.shape[0] == high.shape[0], f"{value.shape} != {low.shape} / {high.shape}"
|
| 667 |
+
low = expand_dims(low, ndim=value.ndim, order=[1, -1, 1])
|
| 668 |
+
high = expand_dims(high, ndim=value.ndim, order=[1, -1, 1])
|
| 669 |
+
return low, high
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
def unnormalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
|
| 673 |
+
(mean, std) = _broadcast_shapes(value, mean, std)
|
| 674 |
+
(mean, std) = (mean.to(device=value.device), std.to(device=value.device))
|
| 675 |
+
return value * (std + 1e-08) + mean
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
def unnormalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
|
| 679 |
+
(low, high) = _broadcast_shapes(value, low, high)
|
| 680 |
+
(low, high) = (low.to(device=value.device), high.to(device=value.device))
|
| 681 |
+
return 0.5 * (value + 1) * (high - low) + low
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
def normalize_gripper_by_bounds(
|
| 685 |
+
value: torch.Tensor, low: torch.Tensor, high: torch.Tensor, binary: bool = True
|
| 686 |
+
) -> torch.Tensor:
|
| 687 |
+
"""
|
| 688 |
+
If binary, normalize to [0, 1], otherwise normalize to [-1, 1]
|
| 689 |
+
"""
|
| 690 |
+
(low, high) = _broadcast_shapes(value, low, high)
|
| 691 |
+
(low, high) = (low.to(device=value.device), high.to(device=value.device))
|
| 692 |
+
if binary:
|
| 693 |
+
return torch.clamp((value - low) / torch.clamp(high - low, min=1e-08), min=0.0, max=1.0)
|
| 694 |
+
return torch.clamp(2 * (value - low) / torch.clamp(high - low, min=1e-08) - 1, min=-1.0, max=1.0)
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
def normalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
|
| 698 |
+
(mean, std) = _broadcast_shapes(value, mean, std)
|
| 699 |
+
(mean, std) = (mean.to(device=value.device), std.to(device=value.device))
|
| 700 |
+
return (value - mean) / (std + 1e-08)
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
def normalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
|
| 704 |
+
(low, high) = _broadcast_shapes(value, low, high)
|
| 705 |
+
(low, high) = (low.to(device=value.device), high.to(device=value.device))
|
| 706 |
+
return torch.clamp(2 * (value - low) / torch.clamp(high - low, min=1e-08) - 1, min=-1.0, max=1.0)
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
def invert_gripper(gripper: np.ndarray, low: float, high: float) -> np.ndarray:
|
| 710 |
+
if low < 0.0:
|
| 711 |
+
return np.clip(-gripper, low, high)
|
| 712 |
+
return high - np.clip(gripper, low, high)
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
GRIPPER_BOUNDS = {
|
| 716 |
+
"bridge": (0.0, 1.0),
|
| 717 |
+
"bridge_orig": (0.0, 1.0),
|
| 718 |
+
"droid": (0.0, 1.0),
|
| 719 |
+
"roboset": (0.0, 1.0),
|
| 720 |
+
}
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
def preprocess_gripper_observation(
|
| 724 |
+
gripper: np.ndarray, dataset_name: str | np.ndarray, binary: bool = True
|
| 725 |
+
) -> np.ndarray:
|
| 726 |
+
"""
|
| 727 |
+
Preprocess gripper observation depending on dataset. Input is the raw gripper observation from the dataset
|
| 728 |
+
or from the robot and output is normalized continuous value.
|
| 729 |
+
- if `binary`, output is in [0, 1], with 0 = closed and 1 = open.
|
| 730 |
+
- otherwise, output is in [-1, 1], with -1 = closed and 1 = open.
|
| 731 |
+
|
| 732 |
+
Dataset-specific gripper observations:
|
| 733 |
+
bridge: continuous; ~[0=closed; 1=open]
|
| 734 |
+
bridge_orig: continuous; ~[0=closed; 1=open]
|
| 735 |
+
droid: continuous; [0=open, 1=closed]
|
| 736 |
+
roboset: continuous; [0=open, 1=closed]
|
| 737 |
+
"""
|
| 738 |
+
if isinstance(dataset_name, np.ndarray):
|
| 739 |
+
assert np.unique(dataset_name).size == 1, dataset_name
|
| 740 |
+
dataset_name = str(dataset_name[0])
|
| 741 |
+
if dataset_name in [
|
| 742 |
+
"droid",
|
| 743 |
+
"roboset",
|
| 744 |
+
]:
|
| 745 |
+
(low, high) = GRIPPER_BOUNDS[dataset_name]
|
| 746 |
+
gripper = normalize_gripper_by_bounds(
|
| 747 |
+
torch.from_numpy(invert_gripper(gripper, low=low, high=high)),
|
| 748 |
+
low=torch.full(gripper.shape, GRIPPER_BOUNDS[dataset_name][0], dtype=torch.float32),
|
| 749 |
+
high=torch.full(gripper.shape, GRIPPER_BOUNDS[dataset_name][1], dtype=torch.float32),
|
| 750 |
+
binary=binary,
|
| 751 |
+
).numpy()
|
| 752 |
+
elif dataset_name in [
|
| 753 |
+
"bridge",
|
| 754 |
+
"bridge_orig",
|
| 755 |
+
]:
|
| 756 |
+
(low, high) = GRIPPER_BOUNDS[dataset_name]
|
| 757 |
+
gripper = normalize_gripper_by_bounds(
|
| 758 |
+
torch.from_numpy(gripper),
|
| 759 |
+
low=torch.full(gripper.shape, low, dtype=torch.float32),
|
| 760 |
+
high=torch.full(gripper.shape, high, dtype=torch.float32),
|
| 761 |
+
binary=binary,
|
| 762 |
+
).numpy()
|
| 763 |
+
else:
|
| 764 |
+
raise NotImplementedError(f"Unknown dataset: {dataset_name}")
|
| 765 |
+
return gripper
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
def rotation_norm_bounds(
|
| 769 |
+
rotation_norm: Normalization,
|
| 770 |
+
rotation_format: RotationFormat,
|
| 771 |
+
stats: Dict[str, Dict[str, Dict[str, List[float]]]],
|
| 772 |
+
dataset_names: List[str],
|
| 773 |
+
) -> Dict[str, Dict[str, torch.Tensor]]:
|
| 774 |
+
if rotation_format == RotationFormat.EULER and rotation_norm != Normalization.NONE:
|
| 775 |
+
if rotation_norm == Normalization.BOUNDS:
|
| 776 |
+
results = {
|
| 777 |
+
dataset_name: {
|
| 778 |
+
"low": torch.tensor(dataset_stats["euler"]["min"]),
|
| 779 |
+
"high": torch.tensor(dataset_stats["euler"]["max"]),
|
| 780 |
+
}
|
| 781 |
+
for (dataset_name, dataset_stats) in stats.items()
|
| 782 |
+
}
|
| 783 |
+
elif rotation_norm == Normalization.BOUNDS_Q99:
|
| 784 |
+
results = {
|
| 785 |
+
dataset_name: {
|
| 786 |
+
"low": torch.tensor(dataset_stats["euler"]["q01"]),
|
| 787 |
+
"high": torch.tensor(dataset_stats["euler"]["q99"]),
|
| 788 |
+
}
|
| 789 |
+
for (dataset_name, dataset_stats) in stats.items()
|
| 790 |
+
}
|
| 791 |
+
else:
|
| 792 |
+
raise NotImplementedError(f"Normalization type {rotation_norm} not yet implemented")
|
| 793 |
+
else:
|
| 794 |
+
assert rotation_norm == Normalization.NONE, rotation_norm
|
| 795 |
+
if rotation_format == RotationFormat.EULER:
|
| 796 |
+
rotation_size = 3
|
| 797 |
+
elif rotation_format == RotationFormat.QUATERNION:
|
| 798 |
+
rotation_size = 4
|
| 799 |
+
else:
|
| 800 |
+
rotation_size = 9
|
| 801 |
+
results = {
|
| 802 |
+
dataset_name: {
|
| 803 |
+
"low": -1 * torch.ones(rotation_size, dtype=torch.float32),
|
| 804 |
+
"high": 1 * torch.ones(rotation_size, dtype=torch.float32),
|
| 805 |
+
}
|
| 806 |
+
for dataset_name in dataset_names
|
| 807 |
+
}
|
| 808 |
+
return results
|
| 809 |
+
|
| 810 |
+
|
| 811 |
+
def translation_norm_bounds(
|
| 812 |
+
translation_norm: Normalization | tuple,
|
| 813 |
+
stats: Dict[str, Dict[str, Dict[str, List[float]]]],
|
| 814 |
+
dataset_names: List[str],
|
| 815 |
+
) -> Dict[str, Dict[str, torch.Tensor]]:
|
| 816 |
+
if isinstance(translation_norm, (Normalization, str)) and translation_norm != Normalization.NONE:
|
| 817 |
+
if translation_norm == Normalization.BOUNDS:
|
| 818 |
+
results = {
|
| 819 |
+
dataset_name: {
|
| 820 |
+
"low": torch.tensor(dataset_stats["translation"]["min"]),
|
| 821 |
+
"high": torch.tensor(dataset_stats["translation"]["max"]),
|
| 822 |
+
}
|
| 823 |
+
for (dataset_name, dataset_stats) in stats.items()
|
| 824 |
+
}
|
| 825 |
+
elif translation_norm == Normalization.BOUNDS_Q99:
|
| 826 |
+
results = {
|
| 827 |
+
dataset_name: {
|
| 828 |
+
"low": torch.tensor(dataset_stats["translation"]["q01"]),
|
| 829 |
+
"high": torch.tensor(dataset_stats["translation"]["q99"]),
|
| 830 |
+
}
|
| 831 |
+
for (dataset_name, dataset_stats) in stats.items()
|
| 832 |
+
}
|
| 833 |
+
elif translation_norm == Normalization.MEAN:
|
| 834 |
+
results = {
|
| 835 |
+
dataset_name: {
|
| 836 |
+
"mean": torch.tensor(dataset_stats["translation"]["mean"]),
|
| 837 |
+
"std": torch.tensor(dataset_stats["translation"]["std"]),
|
| 838 |
+
}
|
| 839 |
+
for (dataset_name, dataset_stats) in stats.items()
|
| 840 |
+
}
|
| 841 |
+
else:
|
| 842 |
+
raise NotImplementedError(f"Normalization type {translation_norm} not yet implemented")
|
| 843 |
+
elif isinstance(translation_norm, Normalization) and translation_norm == Normalization.NONE:
|
| 844 |
+
results = {
|
| 845 |
+
dataset_name: {
|
| 846 |
+
"low": -1 * torch.ones(3, dtype=torch.float32),
|
| 847 |
+
"high": 1 * torch.ones(3, dtype=torch.float32),
|
| 848 |
+
}
|
| 849 |
+
for dataset_name in dataset_names
|
| 850 |
+
}
|
| 851 |
+
else:
|
| 852 |
+
assert isinstance(translation_norm, collections.abc.Mapping), type(translation_norm)
|
| 853 |
+
assert all((len(value) == 3 for value in translation_norm.values())), translation_norm
|
| 854 |
+
assert set(translation_norm.keys()) in (
|
| 855 |
+
{"low", "high"},
|
| 856 |
+
{"mean", "std"},
|
| 857 |
+
), translation_norm
|
| 858 |
+
results = {
|
| 859 |
+
dataset_name: {
|
| 860 |
+
key: torch.tensor(value, dtype=torch.float32) for (key, value) in translation_norm.items()
|
| 861 |
+
}
|
| 862 |
+
for dataset_name in dataset_names
|
| 863 |
+
}
|
| 864 |
+
return results
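# Editor's note (illustrative, values hypothetical): for Normalization.BOUNDS_Q99 both
# rotation_norm_bounds and translation_norm_bounds return one {"low", "high"} entry per
# dataset, e.g.
#   stats = {"bridge": {"translation": {"q01": [0.17, -0.16, -0.05],
#                                       "q99": [0.46, 0.24, 0.19]}}}
#   translation_norm_bounds(Normalization.BOUNDS_Q99, stats, ["bridge"])
#   # -> {"bridge": {"low": tensor([0.17, -0.16, -0.05]),
#   #                "high": tensor([0.46, 0.24, 0.19])}}
# With Normalization.NONE (or a literal low/high or mean/std mapping) the same structure
# is filled with [-1, 1] bounds (or the provided values) for every dataset name.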
|
| 865 |
+
|
| 866 |
+
|
| 867 |
+
VLAMProcessorConfigT = TypeVar("VLAMProcessorConfigT")
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
class VLAMProcessor(Configurable):
|
| 871 |
+
def __init__(self, config: VLAMProcessorConfigT, vlm_processor: VLMProcessor):
|
| 872 |
+
super().__init__(config)
|
| 873 |
+
self.vlm_processor = vlm_processor
|
| 874 |
+
self.control_tokenizer = EmptyTokenizer(
|
| 875 |
+
config=self.config.control_tokenizer_config, tokenizer=self.tokenizer
|
| 876 |
+
)
|
| 877 |
+
self.norm_bounds: Dict[str, Dict[str, Dict[str, torch.Tensor]]] = {
|
| 878 |
+
"obs_translation": self.obs_translation_norm_bounds,
|
| 879 |
+
"obs_rotation": self.obs_rotation_norm_bounds,
|
| 880 |
+
"translation": self.translation_norm_bounds,
|
| 881 |
+
"rotation": self.rotation_norm_bounds,
|
| 882 |
+
"joints": self.joints_norm_bounds,
|
| 883 |
+
}
|
| 884 |
+
|
| 885 |
+
@property
|
| 886 |
+
def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
|
| 887 |
+
return self.vlm_processor.tokenizer
|
| 888 |
+
|
| 889 |
+
@property
|
| 890 |
+
def image_sizes(self) -> Dict[str, ImageSizeConfig]:
|
| 891 |
+
return self.vlm_processor.image_sizes
|
| 892 |
+
|
| 893 |
+
@property
|
| 894 |
+
def camera_names(self) -> List[str]:
|
| 895 |
+
return list(self.vlm_processor.image_sizes.keys())
|
| 896 |
+
|
| 897 |
+
@property
|
| 898 |
+
def control_io_config(self) -> ControlDataIOConfig:
|
| 899 |
+
return self.config.control_io_config
|
| 900 |
+
|
| 901 |
+
@cached_property
|
| 902 |
+
def rotation_components(self) -> int:
|
| 903 |
+
if self.config.rotation_format == RotationFormat.EULER:
|
| 904 |
+
return 3
|
| 905 |
+
if self.config.rotation_format == RotationFormat.QUATERNION:
|
| 906 |
+
return 4
|
| 907 |
+
if self.config.rotation_format == RotationFormat.ROTMAT:
|
| 908 |
+
return 9
|
| 909 |
+
raise NotImplementedError(self.config.rotation_format)
|
| 910 |
+
|
| 911 |
+
@abstractmethod
|
| 912 |
+
def policy_control_plan_from_model_target(
|
| 913 |
+
self, target: RoboticsTarget, dataset_name: np.ndarray
|
| 914 |
+
) -> RoboticsControlPlan:
|
| 915 |
+
pass
|
| 916 |
+
|
| 917 |
+
@abstractmethod
|
| 918 |
+
def policy_control_plan_from_model_output(
|
| 919 |
+
self,
|
| 920 |
+
model_output: RoboticsOutput,
|
| 921 |
+
dataset_name: np.ndarray,
|
| 922 |
+
valid_mask: torch.Tensor,
|
| 923 |
+
) -> RoboticsControlPlan:
|
| 924 |
+
pass
|
| 925 |
+
|
| 926 |
+
def resize_image(
|
| 927 |
+
self, camera_name: str, image: PIL.Image.Image | np.ndarray
|
| 928 |
+
) -> PIL.Image.Image | np.ndarray:
|
| 929 |
+
return resize_image(
|
| 930 |
+
image,
|
| 931 |
+
target_size={
|
| 932 |
+
"width": self.image_sizes[camera_name].width,
|
| 933 |
+
"height": self.image_sizes[camera_name].height,
|
| 934 |
+
},
|
| 935 |
+
mode=self.config.image_resize,
|
| 936 |
+
resample=PIL.Image.Resampling.LANCZOS,
|
| 937 |
+
)
|
| 938 |
+
|
| 939 |
+
def preprocess_inputs(
|
| 940 |
+
self,
|
| 941 |
+
chat: List[str],
|
| 942 |
+
images: Dict[str, PIL.Image.Image | List[PIL.Image.Image]],
|
| 943 |
+
ee_pose_translation: np.ndarray,
|
| 944 |
+
ee_pose_rotation: np.ndarray,
|
| 945 |
+
gripper: np.ndarray,
|
| 946 |
+
joints: np.ndarray,
|
| 947 |
+
dataset_name: np.ndarray,
|
| 948 |
+
inference_mode: bool,
|
| 949 |
+
control_target: Optional[RoboticsTarget] = None,
|
| 950 |
+
) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
|
| 951 |
+
"""
|
| 952 |
+
Preprocess the inputs for a single example
|
| 953 |
+
Args:
|
| 954 |
+
chat: List of conversation turns; the first user turn carries the language instruction
|
| 955 |
+
images: History of input images with increasing timestamps
|
| 956 |
+
ee_pose_translation: np.ndarray, shape [..., num_past_scalars, 3]
|
| 957 |
+
ee_pose_rotation: np.ndarray, shape [..., num_past_scalars, 3 | 4 | 9]
|
| 958 |
+
joints: np.ndarray, shape [..., num_past_scalars, <= 7]
|
| 959 |
+
dataset_name: 1D np.ndarray
|
| 960 |
+
inference_mode: If True, prepare the input for inference (e.g. don't include target
|
| 961 |
+
tokens in the input if relevant). If control_target is available, it should
|
| 962 |
+
still be preprocessed for test dataset comparison
|
| 963 |
+
control_target: RoboticsTarget, each component of shape
|
| 964 |
+
[..., num_control_steps, num_control_components]. Provided only when available, usually
|
| 965 |
+
during training and dataset test
|
| 966 |
+
Returns:
|
| 967 |
+
Dict containing torch.Tensor with inputs
|
| 968 |
+
"""
|
| 969 |
+
del control_target
|
| 970 |
+
del inference_mode
|
| 971 |
+
inputs = self.vlm_processor.preprocess_inputs(chat=chat, images=images)
|
| 972 |
+
images: Dict[str, torch.Tensor] = inputs["images"]
|
| 973 |
+
input_ids: torch.Tensor = inputs["input_ids"][..., : self.tokenizer.model_max_length]
|
| 974 |
+
target_text_tokens_ids: torch.Tensor = inputs["target_ids"][..., : self.tokenizer.model_max_length]
|
| 975 |
+
attn_mask = torch.ones(input_ids.shape, dtype=torch.bool)
|
| 976 |
+
ee_pose_translation = torch.tensor(ee_pose_translation, dtype=torch.float32)
|
| 977 |
+
ee_pose_rotation = torch.tensor(ee_pose_rotation, dtype=torch.float32)
|
| 978 |
+
ee_pose_rotation = convert_rotation(ee_pose_rotation, self.config.rotation_format, autonorm=True)
|
| 979 |
+
gripper = preprocess_gripper_observation(gripper, dataset_name)
|
| 980 |
+
gripper = torch.tensor(gripper, dtype=torch.float32)
|
| 981 |
+
ee_pose_translation = self.normalize(
|
| 982 |
+
ee_pose_translation, dataset_name=dataset_name, key="obs_translation"
|
| 983 |
+
)
|
| 984 |
+
ee_pose_rotation = self.normalize(ee_pose_rotation, dataset_name=dataset_name, key="obs_rotation")
|
| 985 |
+
joints = torch.tensor(joints, dtype=torch.float32)
|
| 986 |
+
if joints.shape[-1] < 7:
|
| 987 |
+
missing_size = 7 - joints.shape[-1]
|
| 988 |
+
joints = torch.cat([joints, torch.zeros([*joints.shape[:-1], missing_size])], dim=-1)
|
| 989 |
+
joints = self.normalize(joints, dataset_name=dataset_name, key="joints")
|
| 990 |
+
outputs = {
|
| 991 |
+
"images": images,
|
| 992 |
+
"input_ids": input_ids,
|
| 993 |
+
"target_text_tokens_ids": target_text_tokens_ids,
|
| 994 |
+
"attn_mask": attn_mask,
|
| 995 |
+
"ee_pose_translation": ee_pose_translation,
|
| 996 |
+
"ee_pose_rotation": ee_pose_rotation,
|
| 997 |
+
"gripper": gripper,
|
| 998 |
+
"joints": joints,
|
| 999 |
+
"control_tokens_ids": None,
|
| 1000 |
+
"target_control_tokens_ids": None,
|
| 1001 |
+
}
|
| 1002 |
+
return outputs
|
| 1003 |
+
|
| 1004 |
+
def create_input(
|
| 1005 |
+
self,
|
| 1006 |
+
chat: List[str],
|
| 1007 |
+
images: Dict[str, List[PIL.Image.Image]],
|
| 1008 |
+
ee_pose_translation: np.ndarray,
|
| 1009 |
+
ee_pose_rotation: np.ndarray,
|
| 1010 |
+
gripper: np.ndarray,
|
| 1011 |
+
joints: np.ndarray,
|
| 1012 |
+
dataset_name: np.ndarray,
|
| 1013 |
+
inference_mode: bool,
|
| 1014 |
+
control_target: Optional[RoboticsTarget] = None,
|
| 1015 |
+
) -> RoboticsInput:
|
| 1016 |
+
inputs = self.preprocess_inputs(
|
| 1017 |
+
chat=chat,
|
| 1018 |
+
images=images,
|
| 1019 |
+
ee_pose_translation=ee_pose_translation,
|
| 1020 |
+
ee_pose_rotation=ee_pose_rotation,
|
| 1021 |
+
gripper=gripper,
|
| 1022 |
+
joints=joints,
|
| 1023 |
+
dataset_name=dataset_name,
|
| 1024 |
+
inference_mode=inference_mode,
|
| 1025 |
+
control_target=control_target,
|
| 1026 |
+
)
|
| 1027 |
+
inputs.pop("target_text_tokens_ids")
|
| 1028 |
+
inputs.pop("target_control_tokens_ids")
|
| 1029 |
+
return RoboticsInput(**inputs)
|
| 1030 |
+
|
| 1031 |
+
def normalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor:
|
| 1032 |
+
if is_mean_norm(getattr(self.config, f"{key}_norm")):
|
| 1033 |
+
(mean, std) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key)
|
| 1034 |
+
output = normalize_by_moments(value, mean=mean, std=std)
|
| 1035 |
+
else:
|
| 1036 |
+
(low, high) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key)
|
| 1037 |
+
output = normalize_by_bounds(value, low=low, high=high)
|
| 1038 |
+
return output
|
| 1039 |
+
|
| 1040 |
+
def unnormalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor:
|
| 1041 |
+
if is_mean_norm(getattr(self.config, f"{key}_norm")):
|
| 1042 |
+
(mean, std) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key)
|
| 1043 |
+
output = unnormalize_by_moments(value, mean=mean, std=std)
|
| 1044 |
+
else:
|
| 1045 |
+
(low, high) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key)
|
| 1046 |
+
output = unnormalize_by_bounds(value, low=low, high=high)
|
| 1047 |
+
return output
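# Editor's note: normalize()/unnormalize() dispatch on the configured Normalization for
# each component key ("obs_translation", "obs_rotation", "translation", "rotation",
# "joints"): mean/std statistics go through normalize_by_moments / unnormalize_by_moments,
# all other modes through normalize_by_bounds / unnormalize_by_bounds (helpers defined
# elsewhere in this repo). unnormalize() is the inverse mapping and is what turns model
# outputs back into metric controls in the processors below.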
|
| 1048 |
+
|
| 1049 |
+
def _norm_bounds_from_dataset_name(
|
| 1050 |
+
self, dataset_name: np.ndarray, component_key: str
|
| 1051 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 1052 |
+
"""
|
| 1053 |
+
Create an array of normalization bounds corresponding to dataset names
|
| 1054 |
+
Args:
|
| 1055 |
+
dataset_name: Array of shape [B] of dataset names for which to fetch the low and high
|
| 1056 |
+
normalization bounds. Note the values can be repeating
|
| 1057 |
+
component_key: str. One of 'obs_translation', 'obs_rotation', 'translation', 'rotation', 'joints'. Indicates for which component to
|
| 1058 |
+
compute the normalization bounds
|
| 1059 |
+
Returns:
|
| 1060 |
+
Tuple of low and high bounds or mean and std, each of shape [B, -1]
|
| 1061 |
+
"""
|
| 1062 |
+
norm = getattr(self.config, f"{component_key}_norm")
|
| 1063 |
+
if is_mean_norm(norm):
|
| 1064 |
+
(stats_key_1, stats_key_2) = ("mean", "std")
|
| 1065 |
+
else:
|
| 1066 |
+
(stats_key_1, stats_key_2) = ("low", "high")
|
| 1067 |
+
if component_key == "joints":
|
| 1068 |
+
if not isinstance(norm, collections.abc.Mapping):
|
| 1069 |
+
raise NotImplementedError()
|
| 1070 |
+
stats = {
|
| 1071 |
+
key: torch.from_numpy(np.tile(np.reshape(value, [1, -1]), [len(dataset_name), 1]))
|
| 1072 |
+
for (key, value) in self.joints_norm_bounds["ANY"].items()
|
| 1073 |
+
}
|
| 1074 |
+
return tuple(stats.values())
|
| 1075 |
+
component_size = list(list(self.norm_bounds[component_key].values())[0].values())[0].shape[-1]
|
| 1076 |
+
if self.dataset_names == ["ANY"]:
|
| 1077 |
+
stats_1 = self.norm_bounds[component_key]["ANY"][stats_key_1]
|
| 1078 |
+
stats_2 = self.norm_bounds[component_key]["ANY"][stats_key_2]
|
| 1079 |
+
stats_1 = np.repeat(np.expand_dims(stats_1, axis=0), len(dataset_name), axis=0)
|
| 1080 |
+
stats_2 = np.repeat(np.expand_dims(stats_2, axis=0), len(dataset_name), axis=0)
|
| 1081 |
+
else:
|
| 1082 |
+
(unique_names, _, inverse_indices, _) = np_unique(dataset_name)
|
| 1083 |
+
stats_1 = np.zeros([len(unique_names), component_size], dtype=np.float32)
|
| 1084 |
+
stats_2 = np.zeros([len(unique_names), component_size], dtype=np.float32)
|
| 1085 |
+
for i, ds_name in enumerate(unique_names):
|
| 1086 |
+
stats_1[i] = self.norm_bounds[component_key][ds_name][stats_key_1].numpy()
|
| 1087 |
+
stats_2[i] = self.norm_bounds[component_key][ds_name][stats_key_2].numpy()
|
| 1088 |
+
stats_1 = stats_1[inverse_indices]
|
| 1089 |
+
stats_2 = stats_2[inverse_indices]
|
| 1090 |
+
return torch.from_numpy(stats_1), torch.from_numpy(stats_2)
|
| 1091 |
+
|
| 1092 |
+
@cached_property
|
| 1093 |
+
def obs_rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
|
| 1094 |
+
return rotation_norm_bounds(
|
| 1095 |
+
rotation_norm=self.config.obs_rotation_norm,
|
| 1096 |
+
rotation_format=self.config.rotation_format,
|
| 1097 |
+
stats=self._observation_stats,
|
| 1098 |
+
dataset_names=self.dataset_names,
|
| 1099 |
+
)
|
| 1100 |
+
|
| 1101 |
+
@cached_property
|
| 1102 |
+
def obs_translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
|
| 1103 |
+
return translation_norm_bounds(
|
| 1104 |
+
translation_norm=self.config.obs_translation_norm,
|
| 1105 |
+
stats=self._observation_stats,
|
| 1106 |
+
dataset_names=self.dataset_names,
|
| 1107 |
+
)
|
| 1108 |
+
|
| 1109 |
+
@cached_property
|
| 1110 |
+
def rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
|
| 1111 |
+
return rotation_norm_bounds(
|
| 1112 |
+
rotation_norm=self.config.rotation_norm,
|
| 1113 |
+
rotation_format=self.config.rotation_format,
|
| 1114 |
+
stats=self._control_stats,
|
| 1115 |
+
dataset_names=self.dataset_names,
|
| 1116 |
+
)
|
| 1117 |
+
|
| 1118 |
+
@cached_property
|
| 1119 |
+
def translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
|
| 1120 |
+
return translation_norm_bounds(
|
| 1121 |
+
translation_norm=self.config.translation_norm,
|
| 1122 |
+
stats=self._control_stats,
|
| 1123 |
+
dataset_names=self.dataset_names,
|
| 1124 |
+
)
|
| 1125 |
+
|
| 1126 |
+
@cached_property
|
| 1127 |
+
def joints_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
|
| 1128 |
+
"""
|
| 1129 |
+
NOTE:
|
| 1130 |
+
- Joint values across all joints and all datasets vary in the range [-2pi; 2pi]
|
| 1131 |
+
- The effective range of a single joint is in practice one of [-2pi; 0], [-pi; pi], [0; 2pi]
|
| 1132 |
+
- It's possible to shift all ranges to [-pi; pi], but it requires careful handling for each joint
|
| 1133 |
+
"""
|
| 1134 |
+
low = torch.tensor(self.config.joints_norm["low"], dtype=torch.float32)
|
| 1135 |
+
high = torch.tensor(self.config.joints_norm["high"], dtype=torch.float32)
|
| 1136 |
+
results = {"ANY": {"low": low, "high": high}}
|
| 1137 |
+
return results
|
| 1138 |
+
|
| 1139 |
+
@cached_property
|
| 1140 |
+
def _observation_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]:
|
| 1141 |
+
return {
|
| 1142 |
+
"bridge": {
|
| 1143 |
+
"euler": {
|
| 1144 |
+
"max": [3.141592653589793, 1.570796251296997, 3.141204357147217],
|
| 1145 |
+
"mean": [
|
| 1146 |
+
-0.25754162314671525,
|
| 1147 |
+
-0.12370228389510128,
|
| 1148 |
+
0.1620053749182691,
|
| 1149 |
+
],
|
| 1150 |
+
"min": [-3.141592653492551, -1.4832241535186768, -3.14153790473938],
|
| 1151 |
+
"q01": [-3.138795563420751, -0.56544608771801, -1.4952478170394896],
|
| 1152 |
+
"q99": [3.138720980629329, 0.2677614077925682, 2.0032371997833236],
|
| 1153 |
+
"std": [3.0257414011616577, 0.1622662085147332, 0.6404942954645315],
|
| 1154 |
+
},
|
| 1155 |
+
"gripper": {
|
| 1156 |
+
"max": [1.0370277166366577],
|
| 1157 |
+
"min": [0.04637829214334488],
|
| 1158 |
+
"q01": [0.05192930996417999],
|
| 1159 |
+
"q99": [1.0118417739868164],
|
| 1160 |
+
},
|
| 1161 |
+
"joints": {
|
| 1162 |
+
"max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
| 1163 |
+
"mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
| 1164 |
+
"min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
| 1165 |
+
"q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
| 1166 |
+
"q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
| 1167 |
+
"std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
| 1168 |
+
},
|
| 1169 |
+
"translation": {
|
| 1170 |
+
"max": [0.5862360596656799, 0.4034728705883026, 0.3568263053894043],
|
| 1171 |
+
"mean": [
|
| 1172 |
+
0.309032678604126,
|
| 1173 |
+
0.03403777256608009,
|
| 1174 |
+
0.061277542263269424,
|
| 1175 |
+
],
|
| 1176 |
+
"min": [
|
| 1177 |
+
-0.04167502000927925,
|
| 1178 |
+
-0.2889411449432373,
|
| 1179 |
+
-0.13934996724128723,
|
| 1180 |
+
],
|
| 1181 |
+
"q01": [
|
| 1182 |
+
0.1711955964565277,
|
| 1183 |
+
-0.15639324486255646,
|
| 1184 |
+
-0.048255354166030884,
|
| 1185 |
+
],
|
| 1186 |
+
"q99": [
|
| 1187 |
+
0.4604376256465912,
|
| 1188 |
+
0.24112474918365479,
|
| 1189 |
+
0.18886254727840424,
|
| 1190 |
+
],
|
| 1191 |
+
"std": [
|
| 1192 |
+
0.0635896623134613,
|
| 1193 |
+
0.09153717756271362,
|
| 1194 |
+
0.049334850162267685,
|
| 1195 |
+
],
|
| 1196 |
+
},
|
| 1197 |
+
},
|
| 1198 |
+
"bridge_orig": {
|
| 1199 |
+
"euler": {
|
| 1200 |
+
"max": [3.141592653589793, 1.570796251296997, 3.141204357147217],
|
| 1201 |
+
"mean": [
|
| 1202 |
+
-0.25754162314671525,
|
| 1203 |
+
-0.12370228389510128,
|
| 1204 |
+
0.1620053749182691,
|
| 1205 |
+
],
|
| 1206 |
+
"min": [-3.141592653492551, -1.4832241535186768, -3.14153790473938],
|
| 1207 |
+
"q01": [-3.138795563420751, -0.56544608771801, -1.4952478170394896],
|
| 1208 |
+
"q99": [3.138720980629329, 0.2677614077925682, 2.0032371997833236],
|
| 1209 |
+
"std": [3.0257414011616577, 0.1622662085147332, 0.6404942954645315],
|
| 1210 |
+
},
|
| 1211 |
+
"gripper": {
|
| 1212 |
+
"max": [1.0370277166366577],
|
| 1213 |
+
"min": [0.04637829214334488],
|
| 1214 |
+
"q01": [0.05192930996417999],
|
| 1215 |
+
"q99": [1.0118417739868164],
|
| 1216 |
+
},
|
| 1217 |
+
"joints": {
|
| 1218 |
+
"max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
| 1219 |
+
"mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
| 1220 |
+
"min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
| 1221 |
+
"q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
| 1222 |
+
"q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
| 1223 |
+
"std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
| 1224 |
+
},
|
| 1225 |
+
"translation": {
|
| 1226 |
+
"max": [0.5862360596656799, 0.4034728705883026, 0.3568263053894043],
|
| 1227 |
+
"mean": [
|
| 1228 |
+
0.309032678604126,
|
| 1229 |
+
0.03403777256608009,
|
| 1230 |
+
0.061277542263269424,
|
| 1231 |
+
],
|
| 1232 |
+
"min": [
|
| 1233 |
+
-0.04167502000927925,
|
| 1234 |
+
-0.2889411449432373,
|
| 1235 |
+
-0.13934996724128723,
|
| 1236 |
+
],
|
| 1237 |
+
"q01": [
|
| 1238 |
+
0.1711955964565277,
|
| 1239 |
+
-0.15639324486255646,
|
| 1240 |
+
-0.048255354166030884,
|
| 1241 |
+
],
|
| 1242 |
+
"q99": [
|
| 1243 |
+
0.4604376256465912,
|
| 1244 |
+
0.24112474918365479,
|
| 1245 |
+
0.18886254727840424,
|
| 1246 |
+
],
|
| 1247 |
+
"std": [
|
| 1248 |
+
0.0635896623134613,
|
| 1249 |
+
0.09153717756271362,
|
| 1250 |
+
0.049334850162267685,
|
| 1251 |
+
],
|
| 1252 |
+
},
|
| 1253 |
+
},
|
| 1254 |
+
"droid": {
|
| 1255 |
+
"euler": {
|
| 1256 |
+
"max": [3.141592502593994, 1.5705928802490234, 3.1415867805480957],
|
| 1257 |
+
"mean": [
|
| 1258 |
+
0.3140628098409554,
|
| 1259 |
+
-0.09296274023036387,
|
| 1260 |
+
-0.07227215454779846,
|
| 1261 |
+
],
|
| 1262 |
+
"min": [
|
| 1263 |
+
-3.141592502593994,
|
| 1264 |
+
-1.5691150426864624,
|
| 1265 |
+
-3.1415374279022217,
|
| 1266 |
+
],
|
| 1267 |
+
"q01": [
|
| 1268 |
+
-3.1378602981567383,
|
| 1269 |
+
-1.2125312042236327,
|
| 1270 |
+
-2.1614069032669065,
|
| 1271 |
+
],
|
| 1272 |
+
"q99": [3.137854380607605, 0.9200375998020163, 1.9367506909370364],
|
| 1273 |
+
"std": [2.926265757944871, 0.363273475703332, 0.7576065217938824],
|
| 1274 |
+
},
|
| 1275 |
+
"gripper": {
|
| 1276 |
+
"max": [1.0],
|
| 1277 |
+
"min": [0.0],
|
| 1278 |
+
"q01": [0.0],
|
| 1279 |
+
"q99": [0.9911894202232361],
|
| 1280 |
+
},
|
| 1281 |
+
"joints": {
|
| 1282 |
+
"max": [
|
| 1283 |
+
2.668445110321045,
|
| 1284 |
+
1.5691218376159668,
|
| 1285 |
+
2.666306734085083,
|
| 1286 |
+
-0.3114914000034332,
|
| 1287 |
+
2.6624162197113037,
|
| 1288 |
+
4.28157901763916,
|
| 1289 |
+
2.752457857131958,
|
| 1290 |
+
],
|
| 1291 |
+
"mean": [
|
| 1292 |
+
0.023137084334640106,
|
| 1293 |
+
0.2704989977282293,
|
| 1294 |
+
-0.01451389357228282,
|
| 1295 |
+
-2.018709403792315,
|
| 1296 |
+
-0.042720520800030394,
|
| 1297 |
+
2.350281188152209,
|
| 1298 |
+
0.12424663946659845,
|
| 1299 |
+
],
|
| 1300 |
+
"min": [
|
| 1301 |
+
-2.6536705493927,
|
| 1302 |
+
-1.547789216041565,
|
| 1303 |
+
-2.6781487464904785,
|
| 1304 |
+
-2.9409868717193604,
|
| 1305 |
+
-2.6705946922302246,
|
| 1306 |
+
0.24893812835216522,
|
| 1307 |
+
-2.7615714073181152,
|
| 1308 |
+
],
|
| 1309 |
+
"q01": [
|
| 1310 |
+
-0.9026106441020965,
|
| 1311 |
+
-0.8547340619564057,
|
| 1312 |
+
-0.9028875434398651,
|
| 1313 |
+
-2.7698556280136106,
|
| 1314 |
+
-1.6851656341552732,
|
| 1315 |
+
1.2335169839859008,
|
| 1316 |
+
-1.9587260699272155,
|
| 1317 |
+
],
|
| 1318 |
+
"q99": [
|
| 1319 |
+
0.9569852340221403,
|
| 1320 |
+
1.4148830294609054,
|
| 1321 |
+
0.7693877756595566,
|
| 1322 |
+
-0.4545914208889008,
|
| 1323 |
+
1.5623322343826267,
|
| 1324 |
+
3.475611729621887,
|
| 1325 |
+
2.263479118347167,
|
| 1326 |
+
],
|
| 1327 |
+
"std": [
|
| 1328 |
+
0.31695080251469465,
|
| 1329 |
+
0.49522214687158767,
|
| 1330 |
+
0.27993538230553827,
|
| 1331 |
+
0.478161574676113,
|
| 1332 |
+
0.4969961591445458,
|
| 1333 |
+
0.45101008525403846,
|
| 1334 |
+
0.7287264344068457,
|
| 1335 |
+
],
|
| 1336 |
+
},
|
| 1337 |
+
"translation": {
|
| 1338 |
+
"max": [0.8575563430786133, 0.799155592918396, 1.0043904781341553],
|
| 1339 |
+
"mean": [
|
| 1340 |
+
0.5283099395864883,
|
| 1341 |
+
0.005363794653877434,
|
| 1342 |
+
0.3120132207021294,
|
| 1343 |
+
],
|
| 1344 |
+
"min": [
|
| 1345 |
+
-0.15604186058044434,
|
| 1346 |
+
-0.827903687953949,
|
| 1347 |
+
-0.2347021996974945,
|
| 1348 |
+
],
|
| 1349 |
+
"q01": [
|
| 1350 |
+
0.26669957995414734,
|
| 1351 |
+
-0.43774398624897004,
|
| 1352 |
+
-0.048167889714241026,
|
| 1353 |
+
],
|
| 1354 |
+
"q99": [0.7774086785316463, 0.428325751423835, 0.776091011762619],
|
| 1355 |
+
"std": [
|
| 1356 |
+
0.1148424841779685,
|
| 1357 |
+
0.17489566608140428,
|
| 1358 |
+
0.16541062032731538,
|
| 1359 |
+
],
|
| 1360 |
+
},
|
| 1361 |
+
},
|
| 1362 |
+
"roboset": {
|
| 1363 |
+
"euler": {
|
| 1364 |
+
"max": [3.1415449294818236, 1.5705575529715636, 3.141527342124582],
|
| 1365 |
+
"mean": [
|
| 1366 |
+
-0.0398455755412464,
|
| 1367 |
+
1.0518070390619125,
|
| 1368 |
+
-0.015345692503002759,
|
| 1369 |
+
],
|
| 1370 |
+
"min": [
|
| 1371 |
+
-3.1415813300509536,
|
| 1372 |
+
-1.5222832468962035,
|
| 1373 |
+
-3.141575300866071,
|
| 1374 |
+
],
|
| 1375 |
+
"q01": [
|
| 1376 |
+
-2.9414386317311187,
|
| 1377 |
+
-0.24976770655101155,
|
| 1378 |
+
-2.985256521212579,
|
| 1379 |
+
],
|
| 1380 |
+
"q99": [2.9380437893235993, 1.5403010739503078, 2.9746912523985025],
|
| 1381 |
+
"std": [1.7866587696177456, 0.40620530263065, 1.7288511340250616],
|
| 1382 |
+
},
|
| 1383 |
+
"gripper": {
|
| 1384 |
+
"max": [0.83056640625],
|
| 1385 |
+
"min": [0.0001499652862548828],
|
| 1386 |
+
"q01": [0.0001499652862548828],
|
| 1387 |
+
"q99": [0.82666015625],
|
| 1388 |
+
},
|
| 1389 |
+
"joints": {
|
| 1390 |
+
"max": [
|
| 1391 |
+
0.96240234375,
|
| 1392 |
+
1.1162109375,
|
| 1393 |
+
1.1064453125,
|
| 1394 |
+
-0.98095703125,
|
| 1395 |
+
2.30859375,
|
| 1396 |
+
1.576171875,
|
| 1397 |
+
1.7412109375,
|
| 1398 |
+
],
|
| 1399 |
+
"mean": [
|
| 1400 |
+
0.005913593806326389,
|
| 1401 |
+
0.1877261847257614,
|
| 1402 |
+
0.04653879255056381,
|
| 1403 |
+
-2.0529513359069824,
|
| 1404 |
+
-0.011298442259430885,
|
| 1405 |
+
0.6185526251792908,
|
| 1406 |
+
-0.01701134257018566,
|
| 1407 |
+
],
|
| 1408 |
+
"min": [
|
| 1409 |
+
-0.8330078125,
|
| 1410 |
+
-0.74658203125,
|
| 1411 |
+
-0.8642578125,
|
| 1412 |
+
-2.892578125,
|
| 1413 |
+
-1.390625,
|
| 1414 |
+
-0.24658203125,
|
| 1415 |
+
-2.953125,
|
| 1416 |
+
],
|
| 1417 |
+
"q01": [
|
| 1418 |
+
-0.41015625,
|
| 1419 |
+
-0.5302734375,
|
| 1420 |
+
-0.6455078125,
|
| 1421 |
+
-2.57421875,
|
| 1422 |
+
-0.76416015625,
|
| 1423 |
+
-0.0386962890625,
|
| 1424 |
+
-1.435546875,
|
| 1425 |
+
],
|
| 1426 |
+
"q99": [
|
| 1427 |
+
0.66455078125,
|
| 1428 |
+
0.9501953125,
|
| 1429 |
+
0.7529296875,
|
| 1430 |
+
-1.251953125,
|
| 1431 |
+
0.75244140625,
|
| 1432 |
+
1.2314453125,
|
| 1433 |
+
1.384765625,
|
| 1434 |
+
],
|
| 1435 |
+
"std": [
|
| 1436 |
+
0.17915399372577667,
|
| 1437 |
+
0.32234326004981995,
|
| 1438 |
+
0.26069700717926025,
|
| 1439 |
+
0.31767210364341736,
|
| 1440 |
+
0.205329030752182,
|
| 1441 |
+
0.33385637402534485,
|
| 1442 |
+
0.6263682842254639,
|
| 1443 |
+
],
|
| 1444 |
+
},
|
| 1445 |
+
"translation": {
|
| 1446 |
+
"max": [0.5747738480567932, 0.3972920775413513, 0.7443570494651794],
|
| 1447 |
+
"mean": [
|
| 1448 |
+
0.3331542909145355,
|
| 1449 |
+
0.019357483834028244,
|
| 1450 |
+
0.37330344319343567,
|
| 1451 |
+
],
|
| 1452 |
+
"min": [
|
| 1453 |
+
0.09978063404560089,
|
| 1454 |
+
-0.29593944549560547,
|
| 1455 |
+
0.10065606236457825,
|
| 1456 |
+
],
|
| 1457 |
+
"q01": [
|
| 1458 |
+
0.18437016010284424,
|
| 1459 |
+
-0.25699371099472046,
|
| 1460 |
+
0.15134164690971375,
|
| 1461 |
+
],
|
| 1462 |
+
"q99": [0.543661892414093, 0.29646238684654236, 0.6682320833206177],
|
| 1463 |
+
"std": [
|
| 1464 |
+
0.07849054038524628,
|
| 1465 |
+
0.12241040915250778,
|
| 1466 |
+
0.1460595279932022,
|
| 1467 |
+
],
|
| 1468 |
+
},
|
| 1469 |
+
},
|
| 1470 |
+
}
|
| 1471 |
+
|
| 1472 |
+
@cached_property
|
| 1473 |
+
def _control_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]:
|
| 1474 |
+
if is_global_norm(self.config.rotation_norm) and is_global_norm(self.config.translation_norm):
|
| 1475 |
+
return {}
|
| 1476 |
+
with open(self.config.control_stats_path, "r") as file:
|
| 1477 |
+
stats = yaml.safe_load(file)
|
| 1478 |
+
if self.config.delta_controls:
|
| 1479 |
+
if self.control_io_config.future_controls_sequence_stride_sec is None:
|
| 1480 |
+
horizon = 0.0
|
| 1481 |
+
else:
|
| 1482 |
+
horizon = self.control_io_config.future_controls_sequence_stride_sec
|
| 1483 |
+
elif self.control_io_config.future_controls_sequence_stride_sec is None:
|
| 1484 |
+
if self.control_io_config.future_controls_sequence_length == 1:
|
| 1485 |
+
horizon = 0.0
|
| 1486 |
+
else:
|
| 1487 |
+
raise NotImplementedError()
|
| 1488 |
+
else:
|
| 1489 |
+
horizon = (
|
| 1490 |
+
self.control_io_config.future_controls_sequence_length
|
| 1491 |
+
* self.control_io_config.future_controls_sequence_stride_sec
|
| 1492 |
+
)
|
| 1493 |
+
key = f"horizon_{round(horizon, 2)}s"
|
| 1494 |
+
if key in stats:
|
| 1495 |
+
stats = stats[key]
|
| 1496 |
+
else:
|
| 1497 |
+
raise ValueError(
|
| 1498 |
+
f"Missing control statistics key {key} for future_controls_sequence_length={self.config.control_io_config.future_controls_sequence_length} future_controls_sequence_stride_sec={self.config.control_io_config.future_controls_sequence_stride_sec}. Available keys: [{stats.keys()}]"
|
| 1499 |
+
)
|
| 1500 |
+
return stats
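# Editor's note: the control statistics file is keyed by prediction horizon in seconds.
# For example, future_controls_sequence_length=10 with
# future_controls_sequence_stride_sec=0.1 looks up "horizon_1.0s" (10 * 0.1, rounded to
# 2 decimals); with delta_controls=True the horizon is a single stride, i.e. "horizon_0.1s".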
|
| 1501 |
+
|
| 1502 |
+
@cached_property
|
| 1503 |
+
def dataset_names(self) -> List[str]:
|
| 1504 |
+
if (
|
| 1505 |
+
is_global_norm(self.config.rotation_norm)
|
| 1506 |
+
and is_global_norm(self.config.obs_rotation_norm)
|
| 1507 |
+
and is_global_norm(self.config.translation_norm)
|
| 1508 |
+
and is_global_norm(self.config.obs_translation_norm)
|
| 1509 |
+
):
|
| 1510 |
+
return ["ANY"]
|
| 1511 |
+
return list(set(self._control_stats.keys()) | set(self._observation_stats.keys()))
|
| 1512 |
+
|
| 1513 |
+
|
| 1514 |
+
def delta_to_relative_translations(translation_sequence: torch.Tensor) -> torch.Tensor:
|
| 1515 |
+
"""
|
| 1516 |
+
Transform a sequence of translation vectors encoded w.r.t. PREVIOUS frame in the sequence to encoding
|
| 1517 |
+
w.r.t. the 0-th element preceding the sequence
|
| 1518 |
+
Ex:
|
| 1519 |
+
Sequence of points: T1, T2, T3, T4
|
| 1520 |
+
`translation_sequence` contains the vectors: T0T1, T1T2, T2T3, T3T4, where T0 is the base frame,
|
| 1521 |
+
implicitly encoded in T0T1
|
| 1522 |
+
Output: T0T1, T0T2, T0T3, T0T4
|
| 1523 |
+
|
| 1524 |
+
Args:
|
| 1525 |
+
translation_sequence: torch.Tensor of shape [..., S, 3], containing the translation vectors, where S
|
| 1526 |
+
corresponds to the sequence dimension
|
| 1527 |
+
Returns:
|
| 1528 |
+
torch.Tensor of the same shape as translation_sequence, containing delta translations
|
| 1529 |
+
"""
|
| 1530 |
+
assert translation_sequence.ndim >= 3, translation_sequence.shape
|
| 1531 |
+
delta_translations = torch.cumsum(translation_sequence, dim=-2)
|
| 1532 |
+
return delta_translations
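# Editor's example (illustrative): with per-step deltas
#   [[1, 0, 0], [0, 1, 0], [0, 0, 1]]        # shape [1, 3, 3] after adding a batch dim
# the cumulative sum over the sequence axis yields offsets from the base frame T0:
#   [[1, 0, 0], [1, 1, 0], [1, 1, 1]]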
|
| 1533 |
+
|
| 1534 |
+
|
| 1535 |
+
class RegressionProcessor(VLAMProcessor):
|
| 1536 |
+
def policy_control_plan_from_model_target(
|
| 1537 |
+
self, target: RoboticsTarget, dataset_name: np.ndarray
|
| 1538 |
+
) -> RoboticsControlPlan:
|
| 1539 |
+
translation_m = self.unnormalize(target.translation, dataset_name=dataset_name, key="translation")
|
| 1540 |
+
rotation = self.unnormalize(target.rotation, dataset_name=dataset_name, key="rotation")
|
| 1541 |
+
rotmat = convert_rotation(rotation, RotationFormat.ROTMAT)
|
| 1542 |
+
gripper_prob = target.gripper
|
| 1543 |
+
if self.config.delta_controls:
|
| 1544 |
+
translation_m = delta_to_relative_translations(translation_m)
|
| 1545 |
+
rotmat = delta_to_relative_rotations(rotmat)
|
| 1546 |
+
return RoboticsControlPlan(
|
| 1547 |
+
translation_m=translation_m,
|
| 1548 |
+
rotmat=rotmat,
|
| 1549 |
+
gripper_prob=gripper_prob,
|
| 1550 |
+
valid_mask=target.valid_mask,
|
| 1551 |
+
)
|
| 1552 |
+
|
| 1553 |
+
def policy_control_plan_from_model_output(
|
| 1554 |
+
self,
|
| 1555 |
+
model_output: RoboticsOutput,
|
| 1556 |
+
dataset_name: np.ndarray,
|
| 1557 |
+
valid_mask: torch.Tensor,
|
| 1558 |
+
) -> RoboticsControlPlan:
|
| 1559 |
+
"""Called during inference to create control plan from model output"""
|
| 1560 |
+
translation_m = self.unnormalize(
|
| 1561 |
+
model_output.translation, dataset_name=dataset_name, key="translation"
|
| 1562 |
+
)
|
| 1563 |
+
rotation = self.unnormalize(model_output.rotation, dataset_name=dataset_name, key="rotation")
|
| 1564 |
+
rotmat = convert_rotation(rotation, RotationFormat.ROTMAT, autonorm=True)
|
| 1565 |
+
gripper_prob = torch.sigmoid(model_output.gripper)
|
| 1566 |
+
if self.config.delta_controls:
|
| 1567 |
+
translation_m = delta_to_relative_translations(translation_m)
|
| 1568 |
+
rotmat = delta_to_relative_rotations(rotmat)
|
| 1569 |
+
return RoboticsControlPlan(
|
| 1570 |
+
translation_m=translation_m,
|
| 1571 |
+
rotmat=rotmat,
|
| 1572 |
+
gripper_prob=gripper_prob,
|
| 1573 |
+
valid_mask=valid_mask,
|
| 1574 |
+
)
|
| 1575 |
+
|
| 1576 |
+
|
| 1577 |
+
class PiZeroFlowMatchingProcessor(RegressionProcessor):
|
| 1578 |
+
def __init__(self, **kwargs):
|
| 1579 |
+
super().__init__(**kwargs)
|
| 1580 |
+
self.generator: torch.Generator = torch.Generator()
|
| 1581 |
+
|
| 1582 |
+
@cached_property
|
| 1583 |
+
def beta_distribution(self) -> torch.distributions.Beta:
|
| 1584 |
+
return torch.distributions.Beta(
|
| 1585 |
+
self.config.distribution_hyperparams.get("alpha", 1.5),
|
| 1586 |
+
self.config.distribution_hyperparams.get("beta", 1.0),
|
| 1587 |
+
)
|
| 1588 |
+
|
| 1589 |
+
def create_input(self, *args, **kwargs) -> RoboticsFlowInput:
|
| 1590 |
+
"""In practice used only during inference"""
|
| 1591 |
+
inputs = super().create_input(*args, **kwargs)
|
| 1592 |
+
flow_input: FlowInput = self.sample_t0_input(batch_size=1, device=torch.device("cpu"))
|
| 1593 |
+
inputs = RoboticsFlowInput(**inputs.as_json(), flow_input=flow_input[0, ...])
|
| 1594 |
+
return inputs
|
| 1595 |
+
|
| 1596 |
+
def sample_timestep(self, batch_size: int) -> torch.Tensor:
|
| 1597 |
+
if self.config.timestep_distribution.lower() == "uniform":
|
| 1598 |
+
eps = 1e-05
|
| 1599 |
+
sample = (torch.rand(1, generator=self.generator) + torch.arange(batch_size) / batch_size) % (
|
| 1600 |
+
1 - eps
|
| 1601 |
+
)
|
| 1602 |
+
elif self.config.timestep_distribution.lower() == "beta":
|
| 1603 |
+
sample = self.beta_distribution.sample([batch_size, 1, 1])
|
| 1604 |
+
sample = (1 - self.config.sig_min) * (1 - sample)
|
| 1605 |
+
else:
|
| 1606 |
+
raise NotImplementedError(self.config.timestep_distribution)
|
| 1607 |
+
sample = sample.view(batch_size, 1, 1)
|
| 1608 |
+
return sample
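# --- Editor's sketch (not part of the original file; torch is already imported at the
# top of this module, and the helper name is hypothetical). It mirrors the "uniform"
# branch above: one shared random offset plus an evenly spaced lattice i / batch_size
# gives a low-discrepancy set of flow-matching timesteps in [0, 1 - eps), instead of
# batch_size independent uniform draws. ---
def _sketch_stratified_timesteps(batch_size: int, seed: int = 0) -> torch.Tensor:
    gen = torch.Generator().manual_seed(seed)
    eps = 1e-05
    u = torch.rand(1, generator=gen)  # one shared random offset
    t = (u + torch.arange(batch_size) / batch_size) % (1 - eps)
    return t.view(batch_size, 1, 1)  # same shape convention as sample_timestep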
|
| 1609 |
+
|
| 1610 |
+
def _psi_t(self, timestep: torch.Tensor, x_0: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor:
|
| 1611 |
+
return (1 - (1 - self.config.sig_min) * timestep) * x_0 + timestep * x_1
|
| 1612 |
+
|
| 1613 |
+
def _dpsi_dt(self, x_0: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor:
|
| 1614 |
+
return x_1 - (1 - self.config.sig_min) * x_0
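# Editor's note: _psi_t / _dpsi_dt implement the optimal-transport conditional path used
# in flow matching: psi_t(x0, x1) = (1 - (1 - sig_min) * t) * x0 + t * x1, whose time
# derivative x1 - (1 - sig_min) * x0 is the regression target for the predicted velocity.
# At t=0 the path sits at the noise sample x0; at t=1 it reaches x1 up to a residual
# sig_min * x0.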
|
| 1615 |
+
|
| 1616 |
+
def sample_t0_input(self, batch_size: int, device: torch.device) -> FlowInput:
|
| 1617 |
+
if self.config.r0_distribution == "normal":
|
| 1618 |
+
controls_t0 = torch.randn(
|
| 1619 |
+
[
|
| 1620 |
+
batch_size,
|
| 1621 |
+
self.config.control_io_config.future_controls_sequence_length,
|
| 1622 |
+
3 + self.rotation_components + 1,
|
| 1623 |
+
],
|
| 1624 |
+
generator=self.generator,
|
| 1625 |
+
).to(device=device)
|
| 1626 |
+
(translation_t0, rotation_t0, gripper_t0) = torch.split(
|
| 1627 |
+
controls_t0, [3, self.rotation_components, 1], dim=-1
|
| 1628 |
+
)
|
| 1629 |
+
rotation_t0 = normalize_rotation(rotation_t0)
|
| 1630 |
+
elif self.config.r0_distribution == "uniform":
|
| 1631 |
+
controls_t0 = torch.randn(
|
| 1632 |
+
[
|
| 1633 |
+
batch_size,
|
| 1634 |
+
self.config.control_io_config.future_controls_sequence_length,
|
| 1635 |
+
4,
|
| 1636 |
+
],
|
| 1637 |
+
generator=self.generator,
|
| 1638 |
+
).to(device=device)
|
| 1639 |
+
(translation_t0, gripper_t0) = torch.split(controls_t0, [3, 1], dim=-1)
|
| 1640 |
+
rotation_t0 = convert_rotation(
|
| 1641 |
+
roma.random_unitquat(
|
| 1642 |
+
(
|
| 1643 |
+
batch_size,
|
| 1644 |
+
self.config.control_io_config.future_controls_sequence_length,
|
| 1645 |
+
),
|
| 1646 |
+
device=device,
|
| 1647 |
+
),
|
| 1648 |
+
self.config.rotation_format,
|
| 1649 |
+
)
|
| 1650 |
+
else:
|
| 1651 |
+
raise NotImplementedError(self.config.r0_distribution)
|
| 1652 |
+
if self.config.rotation_format == RotationFormat.QUATERNION:
|
| 1653 |
+
rotation_t0 = quaternion_half_cover(rotation_t0)
|
| 1654 |
+
timestep = torch.zeros([batch_size, 1, 1], device=device)
|
| 1655 |
+
return FlowInput(
|
| 1656 |
+
timestep=timestep,
|
| 1657 |
+
translation_t0=translation_t0,
|
| 1658 |
+
rotation_t0=rotation_t0,
|
| 1659 |
+
gripper_t0=gripper_t0,
|
| 1660 |
+
translation_t=None,
|
| 1661 |
+
rotation_t=None,
|
| 1662 |
+
gripper_t=None,
|
| 1663 |
+
)
|
| 1664 |
+
|
| 1665 |
+
def policy_control_plan_from_model_output(
|
| 1666 |
+
self,
|
| 1667 |
+
model_output: RoboticsOutput,
|
| 1668 |
+
dataset_name: np.ndarray,
|
| 1669 |
+
valid_mask: torch.Tensor,
|
| 1670 |
+
) -> RoboticsControlPlan:
|
| 1671 |
+
if self.config.translation_norm == Normalization.NONE or is_mean_norm(self.config.translation_norm):
|
| 1672 |
+
model_output = model_output.replace(translation=torch.clamp(model_output.translation, -1, 1))
|
| 1673 |
+
if self.config.rotation_norm == Normalization.NONE or is_mean_norm(self.config.rotation_norm):
|
| 1674 |
+
model_output = model_output.replace(rotation=torch.clamp(model_output.rotation, -1, 1))
|
| 1675 |
+
control_plan = super().policy_control_plan_from_model_output(model_output, dataset_name, valid_mask)
|
| 1676 |
+
control_plan = control_plan.replace(gripper_prob=torch.clamp(model_output.gripper, 0, 1))
|
| 1677 |
+
return control_plan
|
| 1678 |
+
|
| 1679 |
+
|
| 1680 |
+
def make_causal_mask(shape: Sequence[int]) -> torch.Tensor:
|
| 1681 |
+
"""
|
| 1682 |
+
Create a causal attention mask of shape `shape`
|
| 1683 |
+
Args:
|
| 1684 |
+
shape: Shape of the output mask, the last two dimensions correspond to [query_seq_len, kv_seq_len]
|
| 1685 |
+
Returns:
|
| 1686 |
+
torch.Tensor of dtype torch.bool. False values indicate that the row (i.e. query) can't attend
|
| 1687 |
+
to the corresponding column (i.e. key)
|
| 1688 |
+
|
| 1689 |
+
Example:
|
| 1690 |
+
shape = (3, 5) -> Mask the upper triangular part
|
| 1691 |
+
[
|
| 1692 |
+
[ 1, 0, 0, 0, 0],
|
| 1693 |
+
[ 1, 1, 0, 0, 0],
|
| 1694 |
+
[ 1, 1, 1, 0, 0]
|
| 1695 |
+
]
|
| 1696 |
+
"""
|
| 1697 |
+
return torch.tril(torch.ones(shape, dtype=torch.bool), diagonal=0)
|
| 1698 |
+
|
| 1699 |
+
|
| 1700 |
+
def enable_full_attn_blocks(attn_mask: torch.Tensor, full_attn: torch.Tensor) -> torch.Tensor:
|
| 1701 |
+
"""
|
| 1702 |
+
Enable full bi-directional attention in `attn_mask` inside specific blocks
|
| 1703 |
+
Args:
|
| 1704 |
+
attn_mask: Existing attention mask of shape [..., query_seq_len, kv_seq_len] and dtype torch.bool
|
| 1705 |
+
where False values indicate disabled attention
|
| 1706 |
+
full_attn: torch.Tensor of shape [query_seq_len], dtype torch.bool. Blocks of True values indicate
|
| 1707 |
+
positions where full bi-directional attention should be enabled
|
| 1708 |
+
|
| 1709 |
+
Example:
|
| 1710 |
+
1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
|
| 1711 |
+
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
|
| 1712 |
+
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
|
| 1713 |
+
1, 1, 1, 1, 0, 0, 0, 0, -> 1, 1, 1, 1, 0, 0, 0, 0,
|
| 1714 |
+
1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
|
| 1715 |
+
1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
|
| 1716 |
+
1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,
|
| 1717 |
+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
| 1718 |
+
|
| 1719 |
+
"""
|
| 1720 |
+
assert full_attn.dtype == torch.bool, full_attn.dtype
|
| 1721 |
+
assert full_attn.ndim == 1, full_attn.shape
|
| 1722 |
+
assert full_attn.shape[0] == attn_mask.shape[-2], f"{full_attn.shape[0]}, {attn_mask.shape}"
|
| 1723 |
+
if attn_mask.shape[-1] != attn_mask.shape[-2]:
|
| 1724 |
+
raise NotImplementedError("Only self-attention supported right now.")
|
| 1725 |
+
x = full_attn.view(-1, 1) & full_attn.view(1, -1)
|
| 1726 |
+
x = x | make_causal_mask([full_attn.shape[0], full_attn.shape[0]])
|
| 1727 |
+
x = torch.cumprod(x, dim=1).to(dtype=torch.bool)
|
| 1728 |
+
x = x & x.permute(1, 0)
|
| 1729 |
+
mask_positions = (torch.sum(x, dim=0) == 1) & ~full_attn
|
| 1730 |
+
mask_indices = torch.where(mask_positions)[0]
|
| 1731 |
+
x[mask_indices, mask_indices] = 0
|
| 1732 |
+
attn_mask = attn_mask | expand_dims(x, ndim=attn_mask.ndim, order=[-1, 1, 1])
|
| 1733 |
+
return attn_mask
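# Editor's note: combined with make_causal_mask, this roughly mirrors PaliGemma's
# prefix-LM masking. PaliGemmaProcessor.preprocess_inputs passes
# full_attn=(target_ids == IGNORE_INDEX), so positions that carry no training target
# (image tokens, the instruction, separators) attend bidirectionally within their
# contiguous block, while positions that are trained (model turns, <eos>) stay causal.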
|
| 1734 |
+
|
| 1735 |
+
|
| 1736 |
+
IGNORE_INDEX = -100
|
| 1737 |
+
|
| 1738 |
+
|
| 1739 |
+
class PaliGemmaProcessor(VLMProcessor):
|
| 1740 |
+
def __init__(
|
| 1741 |
+
self,
|
| 1742 |
+
config: PaliGemmaProcessorConfig,
|
| 1743 |
+
hf_processor: transformers.models.paligemma.processing_paligemma.PaliGemmaProcessor,
|
| 1744 |
+
**kwargs,
|
| 1745 |
+
):
|
| 1746 |
+
del kwargs
|
| 1747 |
+
super().__init__(config)
|
| 1748 |
+
self.hf_processor = hf_processor
|
| 1749 |
+
self.hf_processor.image_processor.size = dict(self.config.image_sizes["main"].as_json())
|
| 1750 |
+
self.hf_processor.image_seq_length = self.config.num_image_tokens["main"]
|
| 1751 |
+
self.hf_processor.image_processor.image_seq_length = self.config.num_image_tokens["main"]
|
| 1752 |
+
self.bos_id: int = self.tokenizer.bos_token_id
|
| 1753 |
+
self.eos_id: int = self.tokenizer.eos_token_id
|
| 1754 |
+
self.sep_token = "\n"
|
| 1755 |
+
self.sep_id: int = self.tokenizer(
|
| 1756 |
+
self.sep_token,
|
| 1757 |
+
padding=False,
|
| 1758 |
+
add_special_tokens=False,
|
| 1759 |
+
return_attention_mask=False,
|
| 1760 |
+
)["input_ids"][0]
|
| 1761 |
+
self.image_token_id: int = self.tokenizer(
|
| 1762 |
+
self.config.image_token,
|
| 1763 |
+
padding=False,
|
| 1764 |
+
add_special_tokens=False,
|
| 1765 |
+
return_attention_mask=False,
|
| 1766 |
+
)["input_ids"][0]
|
| 1767 |
+
self.image_tokens: list[int] = [self.image_token_id] * sum(self.config.num_image_tokens.values())
|
| 1768 |
+
self.bbox_pattern = re.compile(
|
| 1769 |
+
"\\[(\\d+\\.\\d+),\\s*(\\d+\\.\\d+),\\s*(\\d+\\.\\d+),\\s*(\\d+\\.\\d+)\\]"
|
| 1770 |
+
)
|
| 1771 |
+
|
| 1772 |
+
def preprocess_inputs(
|
| 1773 |
+
self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
|
| 1774 |
+
) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
|
| 1775 |
+
"""
|
| 1776 |
+
Based on PaliGemma paper https://arxiv.org/pdf/2407.07726 and example code at
|
| 1777 |
+
https://ai.google.dev/gemma/docs/paligemma/fine-tuning-paligemma#create_model_inputs
|
| 1778 |
+
The chat must always consist of alternating user and model messages, starting with the user
|
| 1779 |
+
|
| 1780 |
+
<image><image> ... <bos><instruction><sep><assistant><sep><instruction><sep><assistant>...<eos>
|
| 1781 |
+
|
| 1782 |
+
Args:
|
| 1783 |
+
chat: List[str] of even size where each entry corresponds to a different turn in the conversation
|
| 1784 |
+
images: Dict[str, List[PIL.Image.Image]] where different cameras correspond to different keys
|
| 1785 |
+
in the Dict and the List corresponds to history of images
|
| 1786 |
+
"""
|
| 1787 |
+
for key, value in images.items():
|
| 1788 |
+
if not isinstance(value, list):
|
| 1789 |
+
raise TypeError(f"Camera {key} contains values of type {type(value)} instead of list")
|
| 1790 |
+
(input_ids, target_ids) = ([], [])
|
| 1791 |
+
for i, text in enumerate(chat):
|
| 1792 |
+
text = text.replace(self.sep_token, " ").replace("<image>", "")
|
| 1793 |
+
text = self.bbox_pattern.sub(self._bbox_to_loc_tokens, text)
|
| 1794 |
+
turn_input_ids: List[int] = self.tokenizer(
|
| 1795 |
+
text,
|
| 1796 |
+
padding=False,
|
| 1797 |
+
add_special_tokens=False,
|
| 1798 |
+
return_attention_mask=False,
|
| 1799 |
+
)["input_ids"]
|
| 1800 |
+
if i % 2 == 0:
|
| 1801 |
+
turn_target_ids = [IGNORE_INDEX] * len(turn_input_ids)
|
| 1802 |
+
else:
|
| 1803 |
+
turn_target_ids = turn_input_ids
|
| 1804 |
+
if i != len(chat) - 1:
|
| 1805 |
+
turn_input_ids = turn_input_ids + [self.sep_id]
|
| 1806 |
+
turn_target_ids = turn_target_ids + [IGNORE_INDEX]
|
| 1807 |
+
input_ids = input_ids + turn_input_ids
|
| 1808 |
+
target_ids = target_ids + turn_target_ids
|
| 1809 |
+
input_ids = [self.bos_id] + input_ids + [self.eos_id]
|
| 1810 |
+
target_ids = [IGNORE_INDEX] + target_ids + [self.eos_id]
|
| 1811 |
+
image_tokens = self.image_tokens
|
| 1812 |
+
if self.config.max_language_tokens > 0:
|
| 1813 |
+
input_ids = input_ids[: self.config.max_language_tokens]
|
| 1814 |
+
target_ids = target_ids[: self.config.max_language_tokens]
|
| 1815 |
+
input_ids = image_tokens + input_ids
|
| 1816 |
+
target_ids = [IGNORE_INDEX] * len(image_tokens) + target_ids
|
| 1817 |
+
input_ids = torch.tensor(input_ids, dtype=torch.int64)
|
| 1818 |
+
target_ids = torch.tensor(target_ids, dtype=torch.int64)
|
| 1819 |
+
image_tensors: Dict[str, torch.Tensor] = {
|
| 1820 |
+
f"{camera_name}.siglip": self.hf_processor.image_processor(
|
| 1821 |
+
camera_images,
|
| 1822 |
+
size=self.config.image_sizes[camera_name].as_json(),
|
| 1823 |
+
return_tensors="pt",
|
| 1824 |
+
)["pixel_values"]
|
| 1825 |
+
for (camera_name, camera_images) in images.items()
|
| 1826 |
+
}
|
| 1827 |
+
attn_mask = make_causal_mask([len(input_ids), len(input_ids)])
|
| 1828 |
+
attn_mask = enable_full_attn_blocks(attn_mask, full_attn=target_ids == IGNORE_INDEX)
|
| 1829 |
+
return {
|
| 1830 |
+
"input_ids": input_ids,
|
| 1831 |
+
"target_ids": target_ids,
|
| 1832 |
+
"images": image_tensors,
|
| 1833 |
+
"attn_mask": attn_mask,
|
| 1834 |
+
}
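# Editor's example (illustrative; ids are placeholders): for a two-turn chat
#   ["pick up the cube", "<model answer>"]
# the packed input is
#   [<image> x num_image_tokens] [<bos>] [user ids] [<sep>] [model ids] [<eos>]
# and target_ids mirrors input_ids except that the image block, <bos>, the user turn and
# <sep> are set to IGNORE_INDEX, so the language-modelling loss only covers the model
# turn and <eos>.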
|
| 1835 |
+
|
| 1836 |
+
@property
|
| 1837 |
+
def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
|
| 1838 |
+
return self.hf_processor.tokenizer
|
| 1839 |
+
|
| 1840 |
+
@staticmethod
|
| 1841 |
+
def _bbox_to_loc_tokens(match: re.Match) -> str:
|
| 1842 |
+
"""
|
| 1843 |
+
https://developers.googleblog.com/en/gemma-explained-paligemma-architecture/
|
| 1844 |
+
"""
|
| 1845 |
+
floats = list(map(float, match.groups()))
|
| 1846 |
+
transformed = [f"<loc{np.clip(round(num * 1024), 0, 1023):04d}>" for num in floats]
|
| 1847 |
+
return f"[{', '.join(transformed)}]"
|
| 1848 |
+
|
| 1849 |
+
@property
|
| 1850 |
+
def image_sizes(self) -> Dict[str, ImageSizeConfig]:
|
| 1851 |
+
return self.config.image_sizes
|
| 1852 |
+
|
| 1853 |
+
|
| 1854 |
+
class PaliGemmaDepthProcessor(PaliGemmaProcessor):
|
| 1855 |
+
def __init__(
|
| 1856 |
+
self,
|
| 1857 |
+
config: PaliGemmaProcessorConfig,
|
| 1858 |
+
hf_processor: transformers.models.paligemma.processing_paligemma.PaliGemmaProcessor,
|
| 1859 |
+
depth_tokens: int,
|
| 1860 |
+
):
|
| 1861 |
+
super().__init__(config, hf_processor)
|
| 1862 |
+
vocab_size = len(self.tokenizer)
|
| 1863 |
+
self.depth_token_ids = np.arange(vocab_size - depth_tokens, vocab_size)
|
| 1864 |
+
self.depth_input_transforms = {
|
| 1865 |
+
camera_name: torchvision.transforms.v2.Compose(
|
| 1866 |
+
[
|
| 1867 |
+
torchvision.transforms.v2.Resize(
|
| 1868 |
+
size=(camera_image_size.height, camera_image_size.width),
|
| 1869 |
+
interpolation=torchvision.transforms.v2.InterpolationMode.BICUBIC,
|
| 1870 |
+
max_size=None,
|
| 1871 |
+
antialias=True,
|
| 1872 |
+
),
|
| 1873 |
+
torchvision.transforms.v2.ToTensor(),
|
| 1874 |
+
torchvision.transforms.v2.Normalize(
|
| 1875 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
| 1876 |
+
),
|
| 1877 |
+
]
|
| 1878 |
+
)
|
| 1879 |
+
for (camera_name, camera_image_size) in self.config.image_sizes.items()
|
| 1880 |
+
}
|
| 1881 |
+
|
| 1882 |
+
def preprocess_inputs(
|
| 1883 |
+
self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
|
| 1884 |
+
) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
|
| 1885 |
+
inputs = super().preprocess_inputs(chat=chat, images=images)
|
| 1886 |
+
depth_images: Dict[str, torch.Tensor] = {
|
| 1887 |
+
f"{camera_name}.depth": torch.stack(
|
| 1888 |
+
self.depth_input_transforms[camera_name](camera_images), dim=0
|
| 1889 |
+
)
|
| 1890 |
+
for (camera_name, camera_images) in images.items()
|
| 1891 |
+
}
|
| 1892 |
+
inputs["images"] = {**inputs["images"], **depth_images}
|
| 1893 |
+
return inputs
|
| 1894 |
+
|
| 1895 |
+
@property
|
| 1896 |
+
def num_depth_tokens(self) -> int:
|
| 1897 |
+
return len(self.depth_token_ids)
|