Upload folder using huggingface_hub
Browse files- README.md +237 -0
- config.json +55 -0
- configuration_actioncodec.py +230 -0
- model.safetensors +3 -0
- modeling_actioncodec.py +743 -0
- modular_actioncodec.py +787 -0
- rvq.py +522 -0
README.md
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: transformers
|
| 3 |
+
tags: []
|
| 4 |
+
---
|
| 5 |
+
# Model Card for Model ID
|
| 6 |
+
|
| 7 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
## Model Details
|
| 12 |
+
|
| 13 |
+
### Model Description
|
| 14 |
+
|
| 15 |
+
ActionCodec model trained only on bridgedata:
|
| 16 |
+
- franka_libero_20hz_0s (dummy)
|
| 17 |
+
- widowx_bridge_5hz_2s
|
| 18 |
+
- franka_droid_15hz_0s (dummy)
|
| 19 |
+
|
| 20 |
+
### Model Sources [optional]
|
| 21 |
+
|
| 22 |
+
<!-- Provide the basic links for the model. -->
|
| 23 |
+
|
| 24 |
+
TODO
|
| 25 |
+
|
| 26 |
+
## Uses
|
| 27 |
+
|
| 28 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 29 |
+
|
| 30 |
+
### Direct Use
|
| 31 |
+
|
| 32 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 33 |
+
|
| 34 |
+
```python
|
| 35 |
+
import numpy as np
|
| 36 |
+
from transformers import AutoModel
|
| 37 |
+
np.set_printoptions(suppress=True)
|
| 38 |
+
if __name__ == "__main__":
|
| 39 |
+
tokenizer = AutoModel.from_pretrained("ZibinDong/ActionCodec-bridge-RVQft", trust_remote_code=True)
|
| 40 |
+
q99 = np.array([0.9375, 0.91071427, 0.9375, 0.20357142, 0.26357144, 0.375, 1.0])
|
| 41 |
+
q01 = np.array([-0.87857145, -0.87589288, -0.9375, -0.15107143, -0.20678571, -0.27964285, 0.0])
|
| 42 |
+
# an example action from physical-intelligence/libero
|
| 43 |
+
action = np.array(
|
| 44 |
+
[
|
| 45 |
+
[0.3268, 0.2089, -0.3295, 0.0000, -0.0868, -0.0611, 1.0000],
|
| 46 |
+
[0.3696, 0.1955, -0.2866, 0.0000, -0.0793, -0.0643, 1.0000],
|
| 47 |
+
[0.3857, 0.1929, -0.2759, 0.0000, -0.0782, -0.0654, 1.0000],
|
| 48 |
+
[0.3964, 0.2089, -0.2786, 0.0000, -0.0761, -0.0654, 1.0000],
|
| 49 |
+
[0.3321, 0.1741, -0.3268, 0.0000, -0.0793, -0.0686, 1.0000],
|
| 50 |
+
[0.2250, 0.0964, -0.4232, 0.0000, -0.0932, -0.0761, 1.0000],
|
| 51 |
+
[0.0723, 0.0000, -0.5625, 0.0000, -0.1339, -0.0879, 1.0000],
|
| 52 |
+
[0.0536, 0.0000, -0.5652, 0.0000, -0.1521, -0.0921, 1.0000],
|
| 53 |
+
[0.0750, 0.0000, -0.5464, 0.0000, -0.1511, -0.0964, 1.0000],
|
| 54 |
+
[0.0723, 0.0000, -0.5411, 0.0000, -0.1414, -0.0986, 1.0000],
|
| 55 |
+
[0.0402, 0.0000, -0.5196, 0.0000, -0.1350, -0.1007, 1.0000],
|
| 56 |
+
[0.0080, 0.0000, -0.4795, 0.0000, -0.1189, -0.1018, 1.0000],
|
| 57 |
+
[0.0000, 0.0000, -0.4527, 0.0000, -0.0986, -0.1018, 1.0000],
|
| 58 |
+
[0.0000, 0.0000, -0.4313, 0.0000, -0.0846, -0.1018, 1.0000],
|
| 59 |
+
[-0.0455, -0.0268, -0.3509, 0.0000, -0.0568, -0.1018, 1.0000],
|
| 60 |
+
[-0.0964, -0.0482, -0.3321, 0.0000, -0.0439, -0.1039, 1.0000],
|
| 61 |
+
[-0.1768, -0.0562, -0.3402, 0.0000, -0.0300, -0.1050, 1.0000],
|
| 62 |
+
[-0.2438, -0.0429, -0.3187, 0.0000, -0.0193, -0.0996, 1.0000],
|
| 63 |
+
[-0.3054, -0.0054, -0.2893, 0.0000, -0.0139, -0.0932, 1.0000],
|
| 64 |
+
[-0.3509, 0.0000, -0.2598, 0.0000, -0.0054, -0.0879, 1.0000],
|
| 65 |
+
],
|
| 66 |
+
)[None]
|
| 67 |
+
# normalization
|
| 68 |
+
normalized_action = np.copy(action)
|
| 69 |
+
normalized_action[..., :-1] = normalized_action[..., :-1] / np.maximum(np.abs(q99), np.abs(q01))[..., :-1]
|
| 70 |
+
normalized_action[..., -1] = normalized_action[..., -1] * 2.0 - 1.0 # scale to [-1, 1]
|
| 71 |
+
normalized_action = normalized_action.clip(-1.0, 1.0)
|
| 72 |
+
# tokenization
|
| 73 |
+
tokens = tokenizer.encode(normalized_action) # numpy (b, n, d) -> list of ints
|
| 74 |
+
print(tokens)
|
| 75 |
+
# decoding
|
| 76 |
+
decoded_action, padding_mask = tokenizer.decode(tokens) # list of ints -> numpy (b, n, d)
|
| 77 |
+
# calculate reconstruction error
|
| 78 |
+
mse_error = np.mean((normalized_action - decoded_action) ** 2)
|
| 79 |
+
l1_error = np.mean(np.abs(normalized_action - decoded_action))
|
| 80 |
+
print(f"Reconstruction MSE error: {mse_error:.6f}")
|
| 81 |
+
print(f"Reconstruction L1 error: {l1_error:.6f}")
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
### Downstream Use [optional]
|
| 85 |
+
|
| 86 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 87 |
+
|
| 88 |
+
TODO
|
| 89 |
+
|
| 90 |
+
### Out-of-Scope Use
|
| 91 |
+
|
| 92 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 93 |
+
|
| 94 |
+
TODO
|
| 95 |
+
|
| 96 |
+
## Bias, Risks, and Limitations
|
| 97 |
+
|
| 98 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 99 |
+
|
| 100 |
+
TODO
|
| 101 |
+
|
| 102 |
+
### Recommendations
|
| 103 |
+
|
| 104 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 105 |
+
|
| 106 |
+
TODO
|
| 107 |
+
|
| 108 |
+
## How to Get Started with the Model
|
| 109 |
+
|
| 110 |
+
Use the code below to get started with the model.
|
| 111 |
+
|
| 112 |
+
TODO
|
| 113 |
+
|
| 114 |
+
## Training Details
|
| 115 |
+
|
| 116 |
+
### Training Data
|
| 117 |
+
|
| 118 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 119 |
+
|
| 120 |
+
[More Information Needed]
|
| 121 |
+
|
| 122 |
+
### Training Procedure
|
| 123 |
+
|
| 124 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 125 |
+
|
| 126 |
+
#### Preprocessing [optional]
|
| 127 |
+
|
| 128 |
+
[More Information Needed]
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
#### Training Hyperparameters
|
| 132 |
+
|
| 133 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 134 |
+
|
| 135 |
+
#### Speeds, Sizes, Times [optional]
|
| 136 |
+
|
| 137 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 138 |
+
|
| 139 |
+
[More Information Needed]
|
| 140 |
+
|
| 141 |
+
## Evaluation
|
| 142 |
+
|
| 143 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 144 |
+
|
| 145 |
+
### Testing Data, Factors & Metrics
|
| 146 |
+
|
| 147 |
+
#### Testing Data
|
| 148 |
+
|
| 149 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 150 |
+
|
| 151 |
+
[More Information Needed]
|
| 152 |
+
|
| 153 |
+
#### Factors
|
| 154 |
+
|
| 155 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 156 |
+
|
| 157 |
+
[More Information Needed]
|
| 158 |
+
|
| 159 |
+
#### Metrics
|
| 160 |
+
|
| 161 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 162 |
+
|
| 163 |
+
[More Information Needed]
|
| 164 |
+
|
| 165 |
+
### Results
|
| 166 |
+
|
| 167 |
+
[More Information Needed]
|
| 168 |
+
|
| 169 |
+
#### Summary
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
## Model Examination [optional]
|
| 174 |
+
|
| 175 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 176 |
+
|
| 177 |
+
[More Information Needed]
|
| 178 |
+
|
| 179 |
+
## Environmental Impact
|
| 180 |
+
|
| 181 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 182 |
+
|
| 183 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 184 |
+
|
| 185 |
+
- **Hardware Type:** [More Information Needed]
|
| 186 |
+
- **Hours used:** [More Information Needed]
|
| 187 |
+
- **Cloud Provider:** [More Information Needed]
|
| 188 |
+
- **Compute Region:** [More Information Needed]
|
| 189 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 190 |
+
|
| 191 |
+
## Technical Specifications [optional]
|
| 192 |
+
|
| 193 |
+
### Model Architecture and Objective
|
| 194 |
+
|
| 195 |
+
[More Information Needed]
|
| 196 |
+
|
| 197 |
+
### Compute Infrastructure
|
| 198 |
+
|
| 199 |
+
[More Information Needed]
|
| 200 |
+
|
| 201 |
+
#### Hardware
|
| 202 |
+
|
| 203 |
+
[More Information Needed]
|
| 204 |
+
|
| 205 |
+
#### Software
|
| 206 |
+
|
| 207 |
+
[More Information Needed]
|
| 208 |
+
|
| 209 |
+
## Citation [optional]
|
| 210 |
+
|
| 211 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 212 |
+
|
| 213 |
+
**BibTeX:**
|
| 214 |
+
|
| 215 |
+
[More Information Needed]
|
| 216 |
+
|
| 217 |
+
**APA:**
|
| 218 |
+
|
| 219 |
+
[More Information Needed]
|
| 220 |
+
|
| 221 |
+
## Glossary [optional]
|
| 222 |
+
|
| 223 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 224 |
+
|
| 225 |
+
[More Information Needed]
|
| 226 |
+
|
| 227 |
+
## More Information [optional]
|
| 228 |
+
|
| 229 |
+
[More Information Needed]
|
| 230 |
+
|
| 231 |
+
## Model Card Authors [optional]
|
| 232 |
+
|
| 233 |
+
[More Information Needed]
|
| 234 |
+
|
| 235 |
+
## Model Card Contact
|
| 236 |
+
|
| 237 |
+
[More Information Needed]
|
config.json
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"ActionCodec"
|
| 4 |
+
],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_actioncodec.ActionCodecConfig",
|
| 7 |
+
"AutoModel": "modeling_actioncodec.ActionCodec"
|
| 8 |
+
},
|
| 9 |
+
"decoder_add_causal_mask": false,
|
| 10 |
+
"decoder_add_self_attn": false,
|
| 11 |
+
"decoder_cls_size": 1,
|
| 12 |
+
"decoder_dim": 384,
|
| 13 |
+
"decoder_n_heads": 6,
|
| 14 |
+
"decoder_n_layers": 12,
|
| 15 |
+
"decoder_pos_encoding_type": "fourier",
|
| 16 |
+
"dtype": "float32",
|
| 17 |
+
"embodiment_config": {
|
| 18 |
+
"a_franka_libero_20hz": {
|
| 19 |
+
"action_dim": 7,
|
| 20 |
+
"description": "20Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
|
| 21 |
+
"duration": 0,
|
| 22 |
+
"freq": 20
|
| 23 |
+
},
|
| 24 |
+
"b_widowx_bridge_5hz": {
|
| 25 |
+
"action_dim": 7,
|
| 26 |
+
"description": "5Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
|
| 27 |
+
"duration": 2,
|
| 28 |
+
"freq": 5
|
| 29 |
+
},
|
| 30 |
+
"c_franka_droid_15hz": {
|
| 31 |
+
"action_dim": 7,
|
| 32 |
+
"description": "15Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
|
| 33 |
+
"duration": 0,
|
| 34 |
+
"freq": 15
|
| 35 |
+
}
|
| 36 |
+
},
|
| 37 |
+
"encoder_add_causal_mask": false,
|
| 38 |
+
"encoder_add_self_attn": false,
|
| 39 |
+
"encoder_dim": 384,
|
| 40 |
+
"encoder_n_heads": 6,
|
| 41 |
+
"encoder_n_layers": 12,
|
| 42 |
+
"encoder_pos_encoding_type": "fourier",
|
| 43 |
+
"model_type": "action_codec",
|
| 44 |
+
"n_quantizers": 3,
|
| 45 |
+
"n_tokens": 48,
|
| 46 |
+
"transformers_version": "4.57.3",
|
| 47 |
+
"vq_codebook_size": 2048,
|
| 48 |
+
"vq_commitment_weight": 0.25,
|
| 49 |
+
"vq_decay": 0.99,
|
| 50 |
+
"vq_kmeans_init": true,
|
| 51 |
+
"vq_quantizer_dropout": 0.25,
|
| 52 |
+
"vq_threshold_ema_dead_code": 2,
|
| 53 |
+
"vq_type": "rvq",
|
| 54 |
+
"z_dim": 512
|
| 55 |
+
}
|
configuration_actioncodec.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from typing import Any, Dict
|
| 3 |
+
|
| 4 |
+
from transformers import AutoConfig, PretrainedConfig
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ActionCodecConfig(PretrainedConfig):
|
| 8 |
+
model_type = "action_codec"
|
| 9 |
+
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
embodiment_config: Dict[str, Any] = None,
|
| 13 |
+
n_tokens: int = 16,
|
| 14 |
+
n_quantizers: int = 1,
|
| 15 |
+
z_dim: int = 512,
|
| 16 |
+
vq_type: str = "vq",
|
| 17 |
+
vq_codebook_size: int = 2048,
|
| 18 |
+
vq_commitment_weight: float = 0.25,
|
| 19 |
+
vq_decay: float = 0.99,
|
| 20 |
+
vq_kmeans_init: bool = True,
|
| 21 |
+
vq_threshold_ema_dead_code: int = 2,
|
| 22 |
+
vq_quantizer_dropout: float = 0.25,
|
| 23 |
+
encoder_dim: int = 256,
|
| 24 |
+
encoder_n_layers: int = 6,
|
| 25 |
+
encoder_n_heads: int = 8,
|
| 26 |
+
encoder_add_self_attn: bool = False,
|
| 27 |
+
encoder_add_causal_mask: bool = False,
|
| 28 |
+
encoder_pos_encoding_type: str = "fourier",
|
| 29 |
+
decoder_dim: int = 256,
|
| 30 |
+
decoder_n_layers: int = 6,
|
| 31 |
+
decoder_n_heads: int = 8,
|
| 32 |
+
decoder_add_self_attn: bool = False,
|
| 33 |
+
decoder_add_causal_mask: bool = False,
|
| 34 |
+
decoder_pos_encoding_type: str = "fourier",
|
| 35 |
+
decoder_cls_size: int = 1,
|
| 36 |
+
**kwargs,
|
| 37 |
+
):
|
| 38 |
+
super().__init__(**kwargs)
|
| 39 |
+
|
| 40 |
+
if embodiment_config is None:
|
| 41 |
+
default_config = {
|
| 42 |
+
"franka_libero_20hz": {
|
| 43 |
+
"action_dim": 7,
|
| 44 |
+
"freq": 20,
|
| 45 |
+
"duration": 1,
|
| 46 |
+
"description": "20Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
|
| 47 |
+
},
|
| 48 |
+
"widowx_bridge_5hz": {
|
| 49 |
+
"action_dim": 7,
|
| 50 |
+
"freq": 5,
|
| 51 |
+
"duration": 1,
|
| 52 |
+
"description": "5Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
|
| 53 |
+
},
|
| 54 |
+
"franka_droid_15hz": {
|
| 55 |
+
"action_dim": 7,
|
| 56 |
+
"freq": 15,
|
| 57 |
+
"duration": 1,
|
| 58 |
+
"description": "15Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
|
| 59 |
+
},
|
| 60 |
+
}
|
| 61 |
+
self.embodiment_config = copy.deepcopy(default_config)
|
| 62 |
+
else:
|
| 63 |
+
self.embodiment_config = copy.deepcopy(embodiment_config)
|
| 64 |
+
|
| 65 |
+
self.n_tokens = n_tokens
|
| 66 |
+
self.n_quantizers = n_quantizers
|
| 67 |
+
self.z_dim = z_dim
|
| 68 |
+
|
| 69 |
+
self.encoder_dim = encoder_dim
|
| 70 |
+
self.encoder_n_layers = encoder_n_layers
|
| 71 |
+
self.encoder_n_heads = encoder_n_heads
|
| 72 |
+
self.encoder_add_self_attn = encoder_add_self_attn
|
| 73 |
+
self.encoder_add_causal_mask = encoder_add_causal_mask
|
| 74 |
+
self.encoder_pos_encoding_type = encoder_pos_encoding_type
|
| 75 |
+
|
| 76 |
+
self.decoder_dim = decoder_dim
|
| 77 |
+
self.decoder_n_layers = decoder_n_layers
|
| 78 |
+
self.decoder_n_heads = decoder_n_heads
|
| 79 |
+
self.decoder_add_self_attn = decoder_add_self_attn
|
| 80 |
+
self.decoder_add_causal_mask = decoder_add_causal_mask
|
| 81 |
+
self.decoder_pos_encoding_type = decoder_pos_encoding_type
|
| 82 |
+
self.decoder_cls_size = decoder_cls_size
|
| 83 |
+
|
| 84 |
+
self.vq_type = vq_type
|
| 85 |
+
self.vq_codebook_size = vq_codebook_size
|
| 86 |
+
self.vq_commitment_weight = vq_commitment_weight
|
| 87 |
+
self.vq_decay = vq_decay
|
| 88 |
+
self.vq_kmeans_init = vq_kmeans_init
|
| 89 |
+
self.vq_threshold_ema_dead_code = vq_threshold_ema_dead_code
|
| 90 |
+
self.vq_quantizer_dropout = vq_quantizer_dropout
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class ActionCodecConfigOld(PretrainedConfig):
|
| 94 |
+
model_type = "action_codec"
|
| 95 |
+
|
| 96 |
+
def __init__(
|
| 97 |
+
self,
|
| 98 |
+
horizon: int = 20,
|
| 99 |
+
action_dim: int = 7,
|
| 100 |
+
action_encoding: str = "independent_v2",
|
| 101 |
+
horizon_patch_size: int = 1,
|
| 102 |
+
encoder_class: str = "action_codec.modules.perceiver.PerceiverEncoder",
|
| 103 |
+
decoder_class: str = "action_codec.modules.perceiver.PerceiverDecoder",
|
| 104 |
+
vq_class: str = "vector_quantize_pytorch.VectorQuantize",
|
| 105 |
+
encoder_kwargs: Dict[str, Any] = None,
|
| 106 |
+
decoder_kwargs: Dict[str, Any] = None,
|
| 107 |
+
vq_kwargs: Dict[str, Any] = None,
|
| 108 |
+
**kwargs,
|
| 109 |
+
):
|
| 110 |
+
super().__init__(**kwargs)
|
| 111 |
+
self.horizon = horizon
|
| 112 |
+
self.action_dim = action_dim
|
| 113 |
+
self.action_encoding = action_encoding
|
| 114 |
+
self.horizon_patch_size = horizon_patch_size
|
| 115 |
+
self.encoder_class = encoder_class
|
| 116 |
+
self.decoder_class = decoder_class
|
| 117 |
+
self.vq_class = vq_class
|
| 118 |
+
self.encoder_kwargs = (
|
| 119 |
+
dict(encoder_kwargs)
|
| 120 |
+
if encoder_kwargs is not None
|
| 121 |
+
else {
|
| 122 |
+
"dim": 384,
|
| 123 |
+
"in_len": horizon,
|
| 124 |
+
"out_len": 16,
|
| 125 |
+
"num_layers": 12,
|
| 126 |
+
"num_heads": 4,
|
| 127 |
+
"output_round": -1.0,
|
| 128 |
+
}
|
| 129 |
+
)
|
| 130 |
+
self.decoder_kwargs = (
|
| 131 |
+
dict(decoder_kwargs)
|
| 132 |
+
if decoder_kwargs is not None
|
| 133 |
+
else {
|
| 134 |
+
"dim": 384,
|
| 135 |
+
"in_len": 16,
|
| 136 |
+
"out_len": horizon,
|
| 137 |
+
"num_layers": 12,
|
| 138 |
+
"num_heads": 4,
|
| 139 |
+
}
|
| 140 |
+
)
|
| 141 |
+
self.vq_kwargs = (
|
| 142 |
+
dict(vq_kwargs)
|
| 143 |
+
if vq_kwargs is not None
|
| 144 |
+
else {
|
| 145 |
+
"dim": 512,
|
| 146 |
+
"codebook_size": 2048,
|
| 147 |
+
"kmeans_init": True,
|
| 148 |
+
"kmeans_iters": 10,
|
| 149 |
+
"decay": 0.99,
|
| 150 |
+
"commitment_weight": 0.25,
|
| 151 |
+
"rotation_trick": False,
|
| 152 |
+
"threshold_ema_dead_code": 2,
|
| 153 |
+
"use_cosine_sim": False,
|
| 154 |
+
"codebook_diversity_loss_weight": 0.0,
|
| 155 |
+
}
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class BPEActionCodecConfig(PretrainedConfig):
|
| 160 |
+
model_type = "bpe_action_codec"
|
| 161 |
+
|
| 162 |
+
def __init__(
|
| 163 |
+
self,
|
| 164 |
+
horizon: int = 20,
|
| 165 |
+
action_dim: int = 7,
|
| 166 |
+
action_encoding: str = "independent_v2",
|
| 167 |
+
horizon_patch_size: int = 1,
|
| 168 |
+
encoder_class: str = "action_codec.modules.perceiver.PerceiverEncoder",
|
| 169 |
+
decoder_class: str = "action_codec.modules.perceiver.PerceiverDecoder",
|
| 170 |
+
vq_class: str = "vector_quantize_pytorch.VectorQuantize",
|
| 171 |
+
encoder_kwargs: Dict[str, Any] = None,
|
| 172 |
+
decoder_kwargs: Dict[str, Any] = None,
|
| 173 |
+
vq_kwargs: Dict[str, Any] = None,
|
| 174 |
+
**kwargs,
|
| 175 |
+
):
|
| 176 |
+
super().__init__(**kwargs)
|
| 177 |
+
self.horizon = horizon
|
| 178 |
+
self.action_dim = action_dim
|
| 179 |
+
self.action_encoding = action_encoding
|
| 180 |
+
self.horizon_patch_size = horizon_patch_size
|
| 181 |
+
self.encoder_class = encoder_class
|
| 182 |
+
self.decoder_class = decoder_class
|
| 183 |
+
self.vq_class = vq_class
|
| 184 |
+
self.encoder_kwargs = (
|
| 185 |
+
dict(encoder_kwargs)
|
| 186 |
+
if encoder_kwargs is not None
|
| 187 |
+
else {
|
| 188 |
+
"dim": 384,
|
| 189 |
+
"in_len": horizon,
|
| 190 |
+
"out_len": 16,
|
| 191 |
+
"num_layers": 12,
|
| 192 |
+
"num_heads": 4,
|
| 193 |
+
"output_round": -1.0,
|
| 194 |
+
}
|
| 195 |
+
)
|
| 196 |
+
self.decoder_kwargs = (
|
| 197 |
+
dict(decoder_kwargs)
|
| 198 |
+
if decoder_kwargs is not None
|
| 199 |
+
else {
|
| 200 |
+
"dim": 384,
|
| 201 |
+
"in_len": 16,
|
| 202 |
+
"out_len": horizon,
|
| 203 |
+
"num_layers": 12,
|
| 204 |
+
"num_heads": 4,
|
| 205 |
+
}
|
| 206 |
+
)
|
| 207 |
+
self.vq_kwargs = (
|
| 208 |
+
dict(vq_kwargs)
|
| 209 |
+
if vq_kwargs is not None
|
| 210 |
+
else {
|
| 211 |
+
"dim": 512,
|
| 212 |
+
"codebook_size": 2048,
|
| 213 |
+
"kmeans_init": True,
|
| 214 |
+
"kmeans_iters": 10,
|
| 215 |
+
"decay": 0.99,
|
| 216 |
+
"commitment_weight": 0.25,
|
| 217 |
+
"rotation_trick": False,
|
| 218 |
+
"threshold_ema_dead_code": 2,
|
| 219 |
+
"use_cosine_sim": False,
|
| 220 |
+
"codebook_diversity_loss_weight": 0.0,
|
| 221 |
+
}
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
AutoConfig.register("action_codec", ActionCodecConfig)
|
| 226 |
+
AutoConfig.register("bpe_action_codec", BPEActionCodecConfig)
|
| 227 |
+
|
| 228 |
+
ActionCodecConfig.register_for_auto_class()
|
| 229 |
+
|
| 230 |
+
__all__ = ["ActionCodecConfig", "BPEActionCodecConfig"]
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:67e53d45f8bfefe368b5c6ad35cf271368e85df03f1ecb7edc7674037050c73b
|
| 3 |
+
size 197182335
|
modeling_actioncodec.py
ADDED
|
@@ -0,0 +1,743 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import einops
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import AutoModel, PreTrainedModel
|
| 7 |
+
from vector_quantize_pytorch import VectorQuantize
|
| 8 |
+
|
| 9 |
+
from .configuration_actioncodec import ActionCodecConfig
|
| 10 |
+
from .modular_actioncodec import PerceiverDecoder, PerceiverEncoder
|
| 11 |
+
from .rvq import ResidualVectorQuantize
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def trim_trailing_zeros(arr: np.ndarray) -> list[np.ndarray]:
|
| 15 |
+
if arr.shape[0] == 0:
|
| 16 |
+
return []
|
| 17 |
+
|
| 18 |
+
b, n = arr.shape
|
| 19 |
+
|
| 20 |
+
is_nonzero = arr != 0
|
| 21 |
+
flipped_mask = np.flip(is_nonzero, axis=1)
|
| 22 |
+
last_nonzero_indices = n - 1 - np.argmax(flipped_mask, axis=1)
|
| 23 |
+
any_nonzero_in_row = is_nonzero.any(axis=1)
|
| 24 |
+
new_lengths = (last_nonzero_indices + 1) * any_nonzero_in_row
|
| 25 |
+
result = [arr[i, :length].tolist() for i, length in enumerate(new_lengths)]
|
| 26 |
+
|
| 27 |
+
return result
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class ActionCodec(PreTrainedModel):
|
| 31 |
+
"""ActionCodec: A neural codec for encoding and decoding robot action sequences.
|
| 32 |
+
|
| 33 |
+
This model uses a Perceiver-based encoder-decoder architecture with vector quantization
|
| 34 |
+
to convert continuous action sequences into discrete token sequences. It supports
|
| 35 |
+
multiple robot embodiments with different action dimensions and control frequencies.
|
| 36 |
+
|
| 37 |
+
The model supports two vector quantization types:
|
| 38 |
+
- VQ (Vector Quantization): Single quantizer
|
| 39 |
+
- RVQ (Residual Vector Quantization): Multiple quantizers for hierarchical encoding
|
| 40 |
+
|
| 41 |
+
Key features:
|
| 42 |
+
- Multi-embodiment support: Handle different robots with varying action dimensions
|
| 43 |
+
- Dynamic expansion: Add new robot configurations without retraining
|
| 44 |
+
- Flexible input/output: Support numpy arrays and torch tensors
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
config_class = ActionCodecConfig
|
| 48 |
+
|
| 49 |
+
def __init__(self, config: ActionCodecConfig):
|
| 50 |
+
"""Initialize the ActionCodec model.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
config (ActionCodecConfig): Model configuration containing hyperparameters
|
| 54 |
+
and embodiment configurations.
|
| 55 |
+
|
| 56 |
+
Raises:
|
| 57 |
+
ValueError: If configuration parameters are invalid.
|
| 58 |
+
NotImplementedError: If the specified VQ type is not supported.
|
| 59 |
+
"""
|
| 60 |
+
super().__init__(config)
|
| 61 |
+
|
| 62 |
+
# Validate configuration
|
| 63 |
+
if config.n_tokens % config.n_quantizers != 0:
|
| 64 |
+
raise ValueError(f"n_tokens ({config.n_tokens}) must be divisible by n_quantizers ({config.n_quantizers})")
|
| 65 |
+
|
| 66 |
+
if config.n_quantizers < 1:
|
| 67 |
+
raise ValueError(f"n_quantizers must be at least 1, got {config.n_quantizers}")
|
| 68 |
+
|
| 69 |
+
if config.vq_codebook_size < 1:
|
| 70 |
+
raise ValueError(f"vq_codebook_size must be at least 1, got {config.vq_codebook_size}")
|
| 71 |
+
|
| 72 |
+
if config.z_dim < 1:
|
| 73 |
+
raise ValueError(f"z_dim must be at least 1, got {config.z_dim}")
|
| 74 |
+
|
| 75 |
+
if not isinstance(config.embodiment_config, dict) or len(config.embodiment_config) == 0:
|
| 76 |
+
raise ValueError(
|
| 77 |
+
"embodiment_config must be a non-empty dictionary mapping embodiment names to configurations"
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
self.default_embodiment_id = 0
|
| 81 |
+
|
| 82 |
+
# Initialize encoder and decoder
|
| 83 |
+
self.encoder = PerceiverEncoder(config)
|
| 84 |
+
self.decoder = PerceiverDecoder(config)
|
| 85 |
+
|
| 86 |
+
# Initialize vector quantizer based on type
|
| 87 |
+
if config.vq_type == "vq":
|
| 88 |
+
if config.n_quantizers != 1:
|
| 89 |
+
raise ValueError(
|
| 90 |
+
f"VQ type requires n_quantizers=1, got {config.n_quantizers}. Use RVQ type for multiple quantizers."
|
| 91 |
+
)
|
| 92 |
+
self.vq = VectorQuantize(
|
| 93 |
+
dim=config.z_dim,
|
| 94 |
+
codebook_size=config.vq_codebook_size,
|
| 95 |
+
commitment_weight=config.vq_commitment_weight,
|
| 96 |
+
decay=config.vq_decay,
|
| 97 |
+
kmeans_init=config.vq_kmeans_init,
|
| 98 |
+
threshold_ema_dead_code=config.vq_threshold_ema_dead_code,
|
| 99 |
+
rotation_trick=False,
|
| 100 |
+
straight_through=True,
|
| 101 |
+
)
|
| 102 |
+
elif config.vq_type == "rvq":
|
| 103 |
+
if config.n_quantizers < 2:
|
| 104 |
+
raise ValueError(
|
| 105 |
+
f"RVQ type requires n_quantizers >= 2, got {config.n_quantizers}. Use VQ type for single quantizer."
|
| 106 |
+
)
|
| 107 |
+
self.vq = ResidualVectorQuantize(
|
| 108 |
+
dim=config.z_dim,
|
| 109 |
+
n_codebooks=config.n_quantizers,
|
| 110 |
+
codebook_size=config.vq_codebook_size,
|
| 111 |
+
codebook_dim=config.z_dim,
|
| 112 |
+
quantizer_dropout=config.vq_quantizer_dropout,
|
| 113 |
+
commitment=config.vq_commitment_weight,
|
| 114 |
+
)
|
| 115 |
+
else:
|
| 116 |
+
raise NotImplementedError(f"VQ type '{config.vq_type}' not implemented. Supported types: 'vq', 'rvq'")
|
| 117 |
+
|
| 118 |
+
# Store quantization-related attributes
|
| 119 |
+
self.vocab_size = config.vq_codebook_size
|
| 120 |
+
self.num_quantizers = config.n_quantizers
|
| 121 |
+
self.n_tokens_per_quantizer = config.n_tokens // config.n_quantizers
|
| 122 |
+
|
| 123 |
+
def expand_embodiment(self, embodiment_config: dict):
|
| 124 |
+
"""Dynamically expand the model to support new robot embodiments.
|
| 125 |
+
|
| 126 |
+
This method allows adding new robot configurations to the codec without retraining
|
| 127 |
+
the entire model. It updates the encoder and decoder to handle the new action dimensions
|
| 128 |
+
and frequencies while preserving existing functionality for previously configured robots.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
embodiment_config (dict): Dictionary mapping embodiment names to their configurations.
|
| 132 |
+
Each configuration should be a dict with keys:
|
| 133 |
+
- "action_dim" (int): Action dimensionality for this embodiment.
|
| 134 |
+
- "freq" (float): Control frequency in Hz.
|
| 135 |
+
- "duration" (float): Default action sequence duration in seconds.
|
| 136 |
+
- "description" (str, optional): Human-readable description.
|
| 137 |
+
|
| 138 |
+
Example:
|
| 139 |
+
{
|
| 140 |
+
"robot_B": {
|
| 141 |
+
"action_dim": 10,
|
| 142 |
+
"freq": 20,
|
| 143 |
+
"duration": 1.0,
|
| 144 |
+
"description": "10-dim robot at 20Hz"
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
ActionCodec: Returns self for method chaining.
|
| 150 |
+
|
| 151 |
+
Note:
|
| 152 |
+
- New embodiment keys must not already exist in the current configuration.
|
| 153 |
+
- The model will automatically update max_action_dim if the new embodiment
|
| 154 |
+
has a larger action dimension.
|
| 155 |
+
- Existing embodiments will continue to work with their original configurations.
|
| 156 |
+
"""
|
| 157 |
+
if not isinstance(embodiment_config, dict):
|
| 158 |
+
raise TypeError(f"embodiment_config must be a dict, got {type(embodiment_config)}")
|
| 159 |
+
if len(embodiment_config) == 0:
|
| 160 |
+
raise ValueError("embodiment_config cannot be empty")
|
| 161 |
+
|
| 162 |
+
# Check for duplicate keys
|
| 163 |
+
overlapping_keys = set(embodiment_config.keys()) & set(self.config.embodiment_config.keys())
|
| 164 |
+
if overlapping_keys:
|
| 165 |
+
raise ValueError(f"The following embodiment keys already exist and cannot be redefined: {overlapping_keys}")
|
| 166 |
+
|
| 167 |
+
self.encoder.expand_embodiment(embodiment_config)
|
| 168 |
+
self.decoder.expand_embodiment(embodiment_config)
|
| 169 |
+
self.config.embodiment_config.update(embodiment_config)
|
| 170 |
+
return self
|
| 171 |
+
|
| 172 |
+
def _encode(
|
| 173 |
+
self,
|
| 174 |
+
x: torch.Tensor,
|
| 175 |
+
embodiment_ids: torch.Tensor | int | None = None,
|
| 176 |
+
padding_mask: torch.Tensor | None = None,
|
| 177 |
+
) -> torch.Tensor:
|
| 178 |
+
"""Encode action sequences into latent representations.
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
x (torch.Tensor): Action sequences to encode. Shape: (b, seq_len, max_action_dim).
|
| 182 |
+
Assumes that the action dimension is zero-padded to the max action dimension.
|
| 183 |
+
`seq_len` is supposed to be `int(duration * freq)` for each embodiment and padded to the max sequence length.
|
| 184 |
+
embodiment_ids (torch.Tensor | int): Embodiment IDs. Shape: (b,).
|
| 185 |
+
If int, the same embodiment ID is repeated for all sequences in the batch.
|
| 186 |
+
It specifies the embodiment to encode.
|
| 187 |
+
padding_mask (Optional[torch.Tensor], optional): Padding mask, where `False` values indicate padding. Shape: (b, seq_len). Defaults to None.
|
| 188 |
+
It is used to mask the padding tokens on `seq_len` dimension.
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
torch.Tensor: Encoded latent representations. Shape: (b, n_tokens_per_quantizer, z_dim).
|
| 192 |
+
"""
|
| 193 |
+
embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
|
| 194 |
+
z_e = self.encoder(x, embodiment_ids, padding_mask)
|
| 195 |
+
return z_e
|
| 196 |
+
|
| 197 |
+
def _quantize(
|
| 198 |
+
self, z_e: torch.Tensor, return_perplexity: bool = True
|
| 199 |
+
) -> Tuple[torch.Tensor, torch.Tensor, Union[float, List[float]], torch.Tensor]:
|
| 200 |
+
"""Quantize encoded representations using vector quantization.
|
| 201 |
+
|
| 202 |
+
Args:
|
| 203 |
+
z_e (torch.Tensor): Encoded latent representations to quantize.
|
| 204 |
+
Shape: (b, n_tokens_per_quantizer, z_dim).
|
| 205 |
+
return_perplexity (bool, optional): Whether to compute and return perplexity.
|
| 206 |
+
Defaults to True.
|
| 207 |
+
|
| 208 |
+
Returns:
|
| 209 |
+
Tuple[torch.Tensor, torch.Tensor, Union[float, List[float]], torch.Tensor]:
|
| 210 |
+
A tuple containing:
|
| 211 |
+
- z_q (torch.Tensor): Quantized representations.
|
| 212 |
+
Shape: (b, n_tokens_per_quantizer, z_dim).
|
| 213 |
+
- indices (torch.Tensor): Quantization indices.
|
| 214 |
+
Shape: (b, n_tokens_per_quantizer) for VQ or (b, n_tokens_per_quantizer, n_quantizers) for RVQ.
|
| 215 |
+
- perplexity (Union[float, List[float]]): Codebook perplexity.
|
| 216 |
+
Float for single quantizer, List[float] for multiple quantizers.
|
| 217 |
+
- commit_loss (torch.Tensor): Commitment loss scalar tensor.
|
| 218 |
+
"""
|
| 219 |
+
if isinstance(self.vq, ResidualVectorQuantize):
|
| 220 |
+
z_q, indices, _, commitment_loss, codebook_loss = self.vq(z_e)
|
| 221 |
+
commit_loss = commitment_loss.mean() + codebook_loss.mean()
|
| 222 |
+
elif isinstance(self.vq, VectorQuantize):
|
| 223 |
+
z_q, indices, commit_loss = self.vq(z_e)
|
| 224 |
+
else:
|
| 225 |
+
raise NotImplementedError(f"VQ type {type(self.vq)} not implemented")
|
| 226 |
+
|
| 227 |
+
if return_perplexity:
|
| 228 |
+
if len(indices.size()) < 3:
|
| 229 |
+
indices = indices.unsqueeze(-1)
|
| 230 |
+
perplexity = []
|
| 231 |
+
for k in range(indices.size(-1)):
|
| 232 |
+
this_indices = indices[:, :, k]
|
| 233 |
+
indices_count = torch.bincount(this_indices.view(-1), minlength=self.vq.codebook_size)
|
| 234 |
+
if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
|
| 235 |
+
torch.distributed.all_reduce(indices_count)
|
| 236 |
+
this_avg_probs = indices_count.float() / indices_count.sum()
|
| 237 |
+
perplexity.append(((-(this_avg_probs * torch.log(this_avg_probs + 1e-10)).sum()).exp().item()))
|
| 238 |
+
else:
|
| 239 |
+
perplexity = 0
|
| 240 |
+
|
| 241 |
+
return z_q, indices, perplexity, commit_loss
|
| 242 |
+
|
| 243 |
+
def _dequantize(self, indices: torch.Tensor) -> torch.Tensor:
|
| 244 |
+
"""Dequantize token indices back to continuous latent representations.
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
indices (torch.Tensor): Quantization indices. Shape depends on quantizer type:
|
| 248 |
+
- For VQ: (b, n_tokens) or (b, n_tokens, 1)
|
| 249 |
+
- For RVQ: (b, n_tokens_per_quantizer, n_quantizers)
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
torch.Tensor: Dequantized latent representations.
|
| 253 |
+
Shape: (b, n_tokens_per_quantizer, z_dim)
|
| 254 |
+
"""
|
| 255 |
+
if self.num_quantizers == 1:
|
| 256 |
+
if len(indices.size()) == 3:
|
| 257 |
+
indices = indices.squeeze(-1)
|
| 258 |
+
if isinstance(self.vq, ResidualVectorQuantize):
|
| 259 |
+
z_q = self.vq.from_codes(indices)[0]
|
| 260 |
+
elif isinstance(self.vq, VectorQuantize):
|
| 261 |
+
z_q = self.vq.get_output_from_indices(indices)
|
| 262 |
+
else:
|
| 263 |
+
raise NotImplementedError(f"VQ type {type(self.vq)} not implemented in _dequantize")
|
| 264 |
+
return z_q
|
| 265 |
+
|
| 266 |
+
def _decode(
|
| 267 |
+
self, z_q: torch.Tensor, embodiment_ids: torch.Tensor | int | None = None, durations: torch.Tensor | None = None
|
| 268 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 269 |
+
"""Decode quantized latent representations into action sequences.
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
z_q (torch.Tensor): Quantized latent representations.
|
| 273 |
+
Shape: (b, n_tokens_per_quantizer, z_dim).
|
| 274 |
+
embodiment_ids (Union[torch.Tensor, int, None], optional): Embodiment IDs.
|
| 275 |
+
Shape: (b,) if tensor. If int, the same embodiment ID is used for all
|
| 276 |
+
sequences. Defaults to None, which uses `self.default_embodiment_id`.
|
| 277 |
+
durations (torch.Tensor | None, optional): Duration of each action sequence in seconds.
|
| 278 |
+
Shape: (b,). If None, uses default duration from embodiment_config.
|
| 279 |
+
Defaults to None.
|
| 280 |
+
|
| 281 |
+
Returns:
|
| 282 |
+
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
| 283 |
+
- x_recon (torch.Tensor): Reconstructed action sequences.
|
| 284 |
+
Shape: (b, seq_len, max_action_dim).
|
| 285 |
+
- padding_mask (torch.Tensor): Padding mask indicating valid timesteps.
|
| 286 |
+
Shape: (b, seq_len), where True indicates valid timesteps.
|
| 287 |
+
"""
|
| 288 |
+
embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
|
| 289 |
+
x_recon, padding_mask = self.decoder(z_q, embodiment_ids, durations)
|
| 290 |
+
return x_recon, padding_mask
|
| 291 |
+
|
| 292 |
+
@torch.no_grad()
|
| 293 |
+
def encode(
|
| 294 |
+
self,
|
| 295 |
+
x: Union[np.ndarray, torch.Tensor],
|
| 296 |
+
embodiment_ids: Union[List[int], int, None] = None,
|
| 297 |
+
padding_mask: Union[List[bool], np.ndarray, torch.Tensor, None] = None,
|
| 298 |
+
**kwargs,
|
| 299 |
+
) -> List[List[int]]:
|
| 300 |
+
"""Encode action sequences into latent representations (token indices).
|
| 301 |
+
|
| 302 |
+
This method converts action sequences into discrete token indices using the encoder
|
| 303 |
+
and vector quantizer. The input can be either a numpy array or torch tensor.
|
| 304 |
+
|
| 305 |
+
Args:
|
| 306 |
+
x (Union[np.ndarray, torch.Tensor]): Action sequences to encode.
|
| 307 |
+
Shape: (b, seq_len, max_action_dim).
|
| 308 |
+
Assumes that the action dimension is zero-padded to the max action dimension.
|
| 309 |
+
`seq_len` is supposed to be `int(duration * freq)` for each embodiment and
|
| 310 |
+
padded to the max sequence length.
|
| 311 |
+
embodiment_ids (Union[List[int], int, None], optional): Embodiment IDs.
|
| 312 |
+
Shape: (b,) if list. If int, the same embodiment ID is repeated for all
|
| 313 |
+
sequences in the batch. It specifies the embodiment to encode.
|
| 314 |
+
Defaults to None, which uses `self.default_embodiment_id`.
|
| 315 |
+
padding_mask (Union[List[bool], np.ndarray, torch.Tensor, None], optional):
|
| 316 |
+
Padding mask, where `False` values indicate padding. Shape: (b, seq_len).
|
| 317 |
+
Defaults to None. It is used to mask the padding tokens on `seq_len` dimension.
|
| 318 |
+
**kwargs: Additional keyword arguments (currently unused, reserved for future use).
|
| 319 |
+
|
| 320 |
+
Returns:
|
| 321 |
+
List[List[int]]: List of token sequences. Shape: (b, n_tokens), where n_tokens
|
| 322 |
+
is determined by the model configuration (typically `config.n_tokens`).
|
| 323 |
+
|
| 324 |
+
Raises:
|
| 325 |
+
ValueError: If input shapes are invalid or incompatible with the model configuration.
|
| 326 |
+
TypeError: If input types are not supported.
|
| 327 |
+
|
| 328 |
+
Examples:
|
| 329 |
+
>>> import numpy as np
|
| 330 |
+
>>> # Using numpy array
|
| 331 |
+
>>> x = np.random.randn(2, 10, 7).astype(np.float32)
|
| 332 |
+
>>> tokens = model.encode(x, embodiment_ids=[0, 0])
|
| 333 |
+
>>> # Using torch tensor
|
| 334 |
+
>>> x_tensor = torch.randn(2, 10, 7)
|
| 335 |
+
>>> tokens = model.encode(x_tensor, embodiment_ids=[0, 0])
|
| 336 |
+
"""
|
| 337 |
+
self.eval()
|
| 338 |
+
|
| 339 |
+
# Validate and convert input x
|
| 340 |
+
if isinstance(x, np.ndarray):
|
| 341 |
+
if x.ndim != 3:
|
| 342 |
+
raise ValueError(
|
| 343 |
+
f"Expected 3D input array (batch, seq_len, action_dim), got {x.ndim}D array with shape {x.shape}"
|
| 344 |
+
)
|
| 345 |
+
x_tensor = torch.tensor(x, dtype=self.dtype, device=self.device)
|
| 346 |
+
elif isinstance(x, torch.Tensor):
|
| 347 |
+
if x.ndim != 3:
|
| 348 |
+
raise ValueError(
|
| 349 |
+
f"Expected 3D tensor (batch, seq_len, action_dim), got {x.ndim}D tensor with shape {x.shape}"
|
| 350 |
+
)
|
| 351 |
+
x_tensor = x.to(dtype=self.dtype, device=self.device)
|
| 352 |
+
else:
|
| 353 |
+
raise TypeError(f"Input x must be numpy.ndarray or torch.Tensor, got {type(x)}")
|
| 354 |
+
|
| 355 |
+
# Validate batch size
|
| 356 |
+
batch_size = x_tensor.shape[0]
|
| 357 |
+
if batch_size == 0:
|
| 358 |
+
raise ValueError("Batch size must be at least 1")
|
| 359 |
+
|
| 360 |
+
# Handle embodiment_ids
|
| 361 |
+
embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
|
| 362 |
+
if isinstance(embodiment_ids, int):
|
| 363 |
+
if not 0 <= embodiment_ids < len(self.config.embodiment_config):
|
| 364 |
+
raise ValueError(
|
| 365 |
+
f"embodiment_id {embodiment_ids} is out of range [0, {len(self.config.embodiment_config)}). "
|
| 366 |
+
f"Available embodiment IDs: {list(range(len(self.config.embodiment_config)))}"
|
| 367 |
+
)
|
| 368 |
+
embodiment_ids_tensor = torch.tensor([embodiment_ids] * batch_size, dtype=torch.long, device=self.device)
|
| 369 |
+
elif isinstance(embodiment_ids, list):
|
| 370 |
+
if len(embodiment_ids) != batch_size:
|
| 371 |
+
raise ValueError(
|
| 372 |
+
f"Length of embodiment_ids ({len(embodiment_ids)}) must match batch size ({batch_size})"
|
| 373 |
+
)
|
| 374 |
+
for eid in embodiment_ids:
|
| 375 |
+
if not isinstance(eid, int) or not 0 <= eid < len(self.config.embodiment_config):
|
| 376 |
+
raise ValueError(
|
| 377 |
+
f"Invalid embodiment_id {eid}. Must be an integer in range [0, {len(self.config.embodiment_config)})"
|
| 378 |
+
)
|
| 379 |
+
embodiment_ids_tensor = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device)
|
| 380 |
+
else:
|
| 381 |
+
raise TypeError(f"embodiment_ids must be int, List[int], or None, got {type(embodiment_ids)}")
|
| 382 |
+
|
| 383 |
+
# Handle padding_mask
|
| 384 |
+
padding_mask_tensor = None
|
| 385 |
+
if padding_mask is not None:
|
| 386 |
+
if isinstance(padding_mask, (list, np.ndarray)):
|
| 387 |
+
padding_mask_tensor = torch.tensor(padding_mask, dtype=torch.bool, device=self.device)
|
| 388 |
+
elif isinstance(padding_mask, torch.Tensor):
|
| 389 |
+
padding_mask_tensor = padding_mask.to(dtype=torch.bool, device=self.device)
|
| 390 |
+
else:
|
| 391 |
+
raise TypeError(
|
| 392 |
+
f"padding_mask must be List[bool], np.ndarray, torch.Tensor, or None, got {type(padding_mask)}"
|
| 393 |
+
)
|
| 394 |
+
if padding_mask_tensor.shape != (batch_size, x_tensor.shape[1]):
|
| 395 |
+
raise ValueError(
|
| 396 |
+
f"padding_mask shape {padding_mask_tensor.shape} does not match expected shape "
|
| 397 |
+
f"({batch_size}, {x_tensor.shape[1]})"
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
with torch.no_grad():
|
| 401 |
+
z_e = self._encode(x_tensor, embodiment_ids_tensor, padding_mask_tensor)
|
| 402 |
+
_, indices, _, _ = self._quantize(z_e, return_perplexity=False)
|
| 403 |
+
|
| 404 |
+
# Reshape indices: for RVQ, indices shape is (b, n, s), for VQ it's (b, n)
|
| 405 |
+
if len(indices.size()) > 2:
|
| 406 |
+
codes_list = einops.rearrange(indices, "b n s -> b (s n)").cpu()
|
| 407 |
+
else:
|
| 408 |
+
codes_list = indices.cpu()
|
| 409 |
+
|
| 410 |
+
codes_list = codes_list.tolist()
|
| 411 |
+
return codes_list
|
| 412 |
+
|
| 413 |
+
@torch.no_grad()
|
| 414 |
+
def decode(
|
| 415 |
+
self,
|
| 416 |
+
tokens: Union[List[List[int]], np.ndarray, torch.Tensor],
|
| 417 |
+
embodiment_ids: Union[List[int], int, None] = None,
|
| 418 |
+
durations: Union[List[float], np.ndarray, torch.Tensor, None] = None,
|
| 419 |
+
**kwargs,
|
| 420 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
| 421 |
+
"""Decode token sequences into action sequences.
|
| 422 |
+
|
| 423 |
+
This method reconstructs action sequences from discrete token indices using the
|
| 424 |
+
vector quantizer and decoder. The input tokens can be a list of lists, numpy array,
|
| 425 |
+
or torch tensor.
|
| 426 |
+
|
| 427 |
+
Args:
|
| 428 |
+
tokens (Union[List[List[int]], np.ndarray, torch.Tensor]): Token sequences to decode.
|
| 429 |
+
Shape: (b, n_tokens), where n_tokens must be divisible by `n_tokens_per_quantizer`.
|
| 430 |
+
For RVQ, tokens are interleaved: [q0_t0, q1_t0, ..., qN_t0, q0_t1, ...].
|
| 431 |
+
embodiment_ids (Union[List[int], int, None], optional): Embodiment IDs.
|
| 432 |
+
Shape: (b,) if list. If int, the same embodiment ID is repeated for all
|
| 433 |
+
sequences in the batch. It specifies the embodiment to decode.
|
| 434 |
+
Defaults to None, which uses `self.default_embodiment_id`.
|
| 435 |
+
durations (Union[List[float], np.ndarray, torch.Tensor, None], optional):
|
| 436 |
+
Duration of each action sequence in seconds. Shape: (b,).
|
| 437 |
+
If None, the duration is inferred from the default values in `embodiment_config`.
|
| 438 |
+
Defaults to None.
|
| 439 |
+
**kwargs: Additional keyword arguments (currently unused, reserved for future use).
|
| 440 |
+
|
| 441 |
+
Returns:
|
| 442 |
+
Tuple[np.ndarray, np.ndarray]: A tuple containing:
|
| 443 |
+
- reconstructed_actions: Reconstructed action sequences.
|
| 444 |
+
Shape: (b, seq_len, max_action_dim).
|
| 445 |
+
- padding_mask: Padding mask indicating valid timesteps.
|
| 446 |
+
Shape: (b, seq_len), where True indicates valid timesteps.
|
| 447 |
+
|
| 448 |
+
Raises:
|
| 449 |
+
ValueError: If token sequence length is invalid or incompatible with the model configuration.
|
| 450 |
+
TypeError: If input types are not supported.
|
| 451 |
+
|
| 452 |
+
Examples:
|
| 453 |
+
>>> # Using list of lists
|
| 454 |
+
>>> tokens = [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16]]
|
| 455 |
+
>>> actions, mask = model.decode(tokens, embodiment_ids=[0, 0])
|
| 456 |
+
>>> # Using numpy array
|
| 457 |
+
>>> tokens_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
|
| 458 |
+
>>> actions, mask = model.decode(tokens_np, embodiment_ids=[0, 0])
|
| 459 |
+
>>> # Using torch tensor
|
| 460 |
+
>>> tokens_tensor = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
|
| 461 |
+
>>> actions, mask = model.decode(tokens_tensor, embodiment_ids=[0, 0])
|
| 462 |
+
"""
|
| 463 |
+
self.eval()
|
| 464 |
+
|
| 465 |
+
# Validate and convert input tokens
|
| 466 |
+
if isinstance(tokens, list):
|
| 467 |
+
if not all(isinstance(seq, list) for seq in tokens):
|
| 468 |
+
raise TypeError("If tokens is a list, all elements must be lists")
|
| 469 |
+
if len(tokens) == 0:
|
| 470 |
+
raise ValueError("Tokens list cannot be empty")
|
| 471 |
+
if not all(isinstance(val, (int, np.integer)) for seq in tokens for val in seq):
|
| 472 |
+
raise TypeError("All token values must be integers")
|
| 473 |
+
tokens_tensor = torch.tensor(tokens, dtype=torch.long, device=self.device)
|
| 474 |
+
elif isinstance(tokens, np.ndarray):
|
| 475 |
+
if tokens.ndim != 2:
|
| 476 |
+
raise ValueError(
|
| 477 |
+
f"Expected 2D array (batch, n_tokens), got {tokens.ndim}D array with shape {tokens.shape}"
|
| 478 |
+
)
|
| 479 |
+
if not np.issubdtype(tokens.dtype, np.integer):
|
| 480 |
+
raise TypeError(f"Tokens array must have integer dtype, got {tokens.dtype}")
|
| 481 |
+
tokens_tensor = torch.tensor(tokens, dtype=torch.long, device=self.device)
|
| 482 |
+
elif isinstance(tokens, torch.Tensor):
|
| 483 |
+
if tokens.ndim != 2:
|
| 484 |
+
raise ValueError(
|
| 485 |
+
f"Expected 2D tensor (batch, n_tokens), got {tokens.ndim}D tensor with shape {tokens.shape}"
|
| 486 |
+
)
|
| 487 |
+
if not tokens.dtype.is_integer:
|
| 488 |
+
raise TypeError(f"Tokens tensor must have integer dtype, got {tokens.dtype}")
|
| 489 |
+
tokens_tensor = tokens.to(dtype=torch.long, device=self.device)
|
| 490 |
+
else:
|
| 491 |
+
raise TypeError(f"tokens must be List[List[int]], np.ndarray, or torch.Tensor, got {type(tokens)}")
|
| 492 |
+
|
| 493 |
+
batch_size, n_tokens = tokens_tensor.shape
|
| 494 |
+
if batch_size == 0:
|
| 495 |
+
raise ValueError("Batch size must be at least 1")
|
| 496 |
+
if n_tokens == 0:
|
| 497 |
+
raise ValueError("Token sequence length must be at least 1")
|
| 498 |
+
|
| 499 |
+
# Validate token sequence length
|
| 500 |
+
if n_tokens % self.n_tokens_per_quantizer != 0:
|
| 501 |
+
raise ValueError(
|
| 502 |
+
f"Token sequence length ({n_tokens}) must be divisible by tokens per quantizer "
|
| 503 |
+
f"({self.n_tokens_per_quantizer}). Total tokens: {n_tokens}, "
|
| 504 |
+
f"Expected multiple of: {self.n_tokens_per_quantizer}. "
|
| 505 |
+
f"Number of quantizers: {self.num_quantizers}, Total tokens per sequence: {self.config.n_tokens}"
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
# Validate token values are within codebook range
|
| 509 |
+
if tokens_tensor.min() < 0 or tokens_tensor.max() >= self.vocab_size:
|
| 510 |
+
raise ValueError(
|
| 511 |
+
f"Token values must be in range [0, {self.vocab_size}), "
|
| 512 |
+
f"got range [{tokens_tensor.min().item()}, {tokens_tensor.max().item()}]"
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
# Handle embodiment_ids
|
| 516 |
+
embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
|
| 517 |
+
if isinstance(embodiment_ids, int):
|
| 518 |
+
if not 0 <= embodiment_ids < len(self.config.embodiment_config):
|
| 519 |
+
raise ValueError(
|
| 520 |
+
f"embodiment_id {embodiment_ids} is out of range [0, {len(self.config.embodiment_config)}). "
|
| 521 |
+
f"Available embodiment IDs: {list(range(len(self.config.embodiment_config)))}"
|
| 522 |
+
)
|
| 523 |
+
embodiment_ids_tensor = torch.tensor([embodiment_ids] * batch_size, dtype=torch.long, device=self.device)
|
| 524 |
+
elif isinstance(embodiment_ids, list):
|
| 525 |
+
if len(embodiment_ids) != batch_size:
|
| 526 |
+
raise ValueError(
|
| 527 |
+
f"Length of embodiment_ids ({len(embodiment_ids)}) must match batch size ({batch_size})"
|
| 528 |
+
)
|
| 529 |
+
for eid in embodiment_ids:
|
| 530 |
+
if not isinstance(eid, int) or not 0 <= eid < len(self.config.embodiment_config):
|
| 531 |
+
raise ValueError(
|
| 532 |
+
f"Invalid embodiment_id {eid}. Must be an integer in range [0, {len(self.config.embodiment_config)})"
|
| 533 |
+
)
|
| 534 |
+
embodiment_ids_tensor = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device)
|
| 535 |
+
else:
|
| 536 |
+
raise TypeError(f"embodiment_ids must be int, List[int], or None, got {type(embodiment_ids)}")
|
| 537 |
+
|
| 538 |
+
# Handle durations
|
| 539 |
+
durations_tensor = None
|
| 540 |
+
if durations is not None:
|
| 541 |
+
if isinstance(durations, (list, np.ndarray)):
|
| 542 |
+
durations_tensor = torch.tensor(durations, dtype=torch.float32, device=self.device)
|
| 543 |
+
elif isinstance(durations, torch.Tensor):
|
| 544 |
+
durations_tensor = durations.to(dtype=torch.float32, device=self.device)
|
| 545 |
+
else:
|
| 546 |
+
raise TypeError(
|
| 547 |
+
f"durations must be List[float], np.ndarray, torch.Tensor, or None, got {type(durations)}"
|
| 548 |
+
)
|
| 549 |
+
if durations_tensor.ndim != 1:
|
| 550 |
+
raise ValueError(
|
| 551 |
+
f"durations must be 1D, got {durations_tensor.ndim}D with shape {durations_tensor.shape}"
|
| 552 |
+
)
|
| 553 |
+
if len(durations_tensor) != batch_size:
|
| 554 |
+
raise ValueError(f"Length of durations ({len(durations_tensor)}) must match batch size ({batch_size})")
|
| 555 |
+
if (durations_tensor <= 0).any():
|
| 556 |
+
raise ValueError("All durations must be positive")
|
| 557 |
+
|
| 558 |
+
# Reshape tokens for dequantization: (b, n_tokens) -> (b, n_tokens_per_quantizer, n_quantizers)
|
| 559 |
+
indices = einops.rearrange(tokens_tensor, "b (n m) -> b m n", m=self.n_tokens_per_quantizer)
|
| 560 |
+
|
| 561 |
+
with torch.no_grad():
|
| 562 |
+
z_q = self._dequantize(indices)
|
| 563 |
+
x_recon, padding_mask = self._decode(z_q, embodiment_ids_tensor, durations_tensor)
|
| 564 |
+
|
| 565 |
+
return x_recon.float().cpu().numpy(), padding_mask.float().cpu().numpy()
|
| 566 |
+
|
| 567 |
+
def forward(
|
| 568 |
+
self,
|
| 569 |
+
x: Union[torch.Tensor, np.ndarray],
|
| 570 |
+
embodiment_ids: Union[torch.Tensor, int, List[int], None] = None,
|
| 571 |
+
padding_mask: Union[torch.Tensor, List[bool], np.ndarray, None] = None,
|
| 572 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 573 |
+
"""Forward pass through the full ActionCodec pipeline.
|
| 574 |
+
|
| 575 |
+
This method performs encoding, quantization, and decoding in a single forward pass.
|
| 576 |
+
It is primarily used during training to compute reconstruction loss and commitment loss.
|
| 577 |
+
Both numpy arrays and torch tensors are supported as input.
|
| 578 |
+
|
| 579 |
+
Args:
|
| 580 |
+
x (Union[torch.Tensor, np.ndarray]): Action sequences to process.
|
| 581 |
+
Shape: (b, seq_len, max_action_dim).
|
| 582 |
+
embodiment_ids (Union[torch.Tensor, int, List[int], None], optional):
|
| 583 |
+
Embodiment IDs. Shape: (b,) if tensor or list. If int, same ID for all sequences.
|
| 584 |
+
Defaults to None, which uses `self.default_embodiment_id`.
|
| 585 |
+
padding_mask (Union[torch.Tensor, List[bool], np.ndarray, None], optional):
|
| 586 |
+
Padding mask. Shape: (b, seq_len). Defaults to None.
|
| 587 |
+
|
| 588 |
+
Returns:
|
| 589 |
+
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
| 590 |
+
- x_recon (torch.Tensor): Reconstructed action sequences.
|
| 591 |
+
Shape: (b, seq_len, max_action_dim).
|
| 592 |
+
- recon_mask (torch.Tensor): Reconstruction mask indicating valid timesteps.
|
| 593 |
+
Shape: (b, seq_len), where True indicates valid timesteps.
|
| 594 |
+
|
| 595 |
+
Note:
|
| 596 |
+
- For inference use cases, prefer using `encode()` and `decode()` methods separately.
|
| 597 |
+
- If you need token indices, use the `encode()` method instead.
|
| 598 |
+
"""
|
| 599 |
+
# Convert numpy array to torch tensor if needed
|
| 600 |
+
if isinstance(x, np.ndarray):
|
| 601 |
+
x = torch.tensor(x, dtype=self.dtype, device=self.device)
|
| 602 |
+
|
| 603 |
+
# Handle embodiment_ids conversion
|
| 604 |
+
if isinstance(embodiment_ids, list):
|
| 605 |
+
embodiment_ids = torch.tensor(embodiment_ids, device=x.device, dtype=torch.long)
|
| 606 |
+
elif isinstance(embodiment_ids, int):
|
| 607 |
+
# Keep as int, will be handled by _encode
|
| 608 |
+
pass
|
| 609 |
+
|
| 610 |
+
# Handle padding_mask conversion
|
| 611 |
+
if isinstance(padding_mask, (list, np.ndarray)):
|
| 612 |
+
padding_mask = torch.tensor(padding_mask, device=x.device, dtype=torch.bool)
|
| 613 |
+
|
| 614 |
+
# Full forward pass: encode -> quantize -> decode
|
| 615 |
+
z_e = self._encode(x, embodiment_ids, padding_mask)
|
| 616 |
+
z_q, indices, perplexity, commit_loss = self._quantize(z_e, return_perplexity=True)
|
| 617 |
+
x_recon, recon_mask = self._decode(z_q, embodiment_ids)
|
| 618 |
+
|
| 619 |
+
return x_recon, recon_mask
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
AutoModel.register(ActionCodecConfig, ActionCodec)
|
| 623 |
+
|
| 624 |
+
__all__ = ["ActionCodec"]
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
if __name__ == "__main__":
|
| 628 |
+
print("=== ActionCodec Comprehensive Test ===\n")
|
| 629 |
+
|
| 630 |
+
# 1. Configuration Setup (RVQ enabled with n_quantizers=4)
|
| 631 |
+
initial_config = {
|
| 632 |
+
"robot_A": {"action_dim": 7, "freq": 10, "duration": 1, "description": "Robot A"},
|
| 633 |
+
}
|
| 634 |
+
|
| 635 |
+
# We set n_quantizers=4 to test Residual VQ logic
|
| 636 |
+
config = ActionCodecConfig(
|
| 637 |
+
embodiment_config=initial_config,
|
| 638 |
+
n_tokens=16, # Total tokens per sequence (latent_len * n_quantizers)
|
| 639 |
+
n_quantizers=4, # RVQ depth
|
| 640 |
+
vq_type="rvq",
|
| 641 |
+
vq_codebook_size=256,
|
| 642 |
+
encoder_dim=128,
|
| 643 |
+
decoder_dim=128,
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
# Expected latent sequence length = n_tokens / n_quantizers = 16 / 4 = 4
|
| 647 |
+
latent_seq_len = int(config.n_tokens // config.n_quantizers)
|
| 648 |
+
print(f"Config: {config.n_quantizers} quantizers, {latent_seq_len} latent vectors per sequence.")
|
| 649 |
+
|
| 650 |
+
codec = ActionCodec(config)
|
| 651 |
+
codec.eval()
|
| 652 |
+
|
| 653 |
+
# 2. Basic Encode/Decode Test
|
| 654 |
+
print("\n--- Test 1: Basic Encode/Decode ---")
|
| 655 |
+
batch_size = 2
|
| 656 |
+
seq_len_A = 10 # 10Hz * 1s
|
| 657 |
+
|
| 658 |
+
# Create random action data for Robot A (ID 0)
|
| 659 |
+
x = np.random.randn(batch_size, seq_len_A, 7).astype(np.float32)
|
| 660 |
+
# Masking: Second item in batch is half padding
|
| 661 |
+
padding_mask = np.ones((batch_size, seq_len_A), dtype=bool)
|
| 662 |
+
padding_mask[1, 5:] = False
|
| 663 |
+
|
| 664 |
+
embodiment_ids = [0, 0]
|
| 665 |
+
|
| 666 |
+
# Encode
|
| 667 |
+
codes = codec.encode(x, embodiment_ids, padding_mask)
|
| 668 |
+
print(f"Encoded codes shape (list length): {len(codes)} x {len(codes[0])}")
|
| 669 |
+
|
| 670 |
+
# Validate code length
|
| 671 |
+
assert len(codes[0]) == config.n_tokens, f"Expected {config.n_tokens} tokens, got {len(codes[0])}"
|
| 672 |
+
|
| 673 |
+
# Decode
|
| 674 |
+
x_recon, recon_mask = codec.decode(codes, embodiment_ids)
|
| 675 |
+
print(f"Reconstructed shape: {x_recon.shape}")
|
| 676 |
+
print(f"Recon mask shape: {recon_mask.shape}")
|
| 677 |
+
|
| 678 |
+
assert x_recon.shape == (batch_size, seq_len_A, 7) # Should imply zero-padding to max dim 7
|
| 679 |
+
|
| 680 |
+
# 3. Expansion Test
|
| 681 |
+
print("\n--- Test 2: Dynamic Expansion ---")
|
| 682 |
+
new_robot_config = {"robot_B": {"action_dim": 10, "freq": 20, "duration": 1, "description": "Robot B (Larger)"}}
|
| 683 |
+
|
| 684 |
+
print("Expanding codec to include Robot B (10 dims, 20Hz)...")
|
| 685 |
+
codec.expand_embodiment(new_robot_config)
|
| 686 |
+
|
| 687 |
+
assert codec.encoder.max_action_dim == 10
|
| 688 |
+
assert codec.decoder.max_action_dim == 10
|
| 689 |
+
print("✅ Expansion successful.")
|
| 690 |
+
|
| 691 |
+
# 4. Mixed Batch Test (Old + New Robot)
|
| 692 |
+
print("\n--- Test 3: Mixed Batch Inference ---")
|
| 693 |
+
|
| 694 |
+
# Batch: [Robot A, Robot B]
|
| 695 |
+
# Robot A: 10Hz, 1s -> 10 steps. Dims 7.
|
| 696 |
+
# Robot B: 20Hz, 1s -> 20 steps. Dims 10.
|
| 697 |
+
# Batch Max Steps: 20. Batch Max Dims: 10.
|
| 698 |
+
|
| 699 |
+
batch_x_mixed = np.zeros((2, 20, 10), dtype=np.float32)
|
| 700 |
+
|
| 701 |
+
# Fill Robot A data (index 0)
|
| 702 |
+
data_A = np.random.randn(10, 7)
|
| 703 |
+
batch_x_mixed[0, :10, :7] = data_A
|
| 704 |
+
|
| 705 |
+
# Fill Robot B data (index 1)
|
| 706 |
+
data_B = np.random.randn(20, 10)
|
| 707 |
+
batch_x_mixed[1, :20, :10] = data_B
|
| 708 |
+
|
| 709 |
+
# Embodiment IDs: 0 for A, 1 for B
|
| 710 |
+
# Note: expand_embodiment appends. Original was 0, new is 1.
|
| 711 |
+
mixed_ids = [0, 1]
|
| 712 |
+
|
| 713 |
+
# Encode Mask
|
| 714 |
+
mixed_mask = np.zeros((2, 20), dtype=bool)
|
| 715 |
+
mixed_mask[0, :10] = True
|
| 716 |
+
mixed_mask[1, :20] = True
|
| 717 |
+
|
| 718 |
+
print("Encoding mixed batch...")
|
| 719 |
+
mixed_codes = codec.encode(batch_x_mixed, mixed_ids, mixed_mask)
|
| 720 |
+
|
| 721 |
+
print("Decoding mixed batch...")
|
| 722 |
+
# Explicit durations (optional, but good for verification if we wanted to override defaults)
|
| 723 |
+
durations = [1, 1]
|
| 724 |
+
x_recon_mixed, dec_mask_mixed = codec.decode(mixed_codes, mixed_ids, durations)
|
| 725 |
+
|
| 726 |
+
print(f"Mixed Recon Shape: {x_recon_mixed.shape}")
|
| 727 |
+
|
| 728 |
+
# Validation
|
| 729 |
+
# Robot A output check (mask should be True for first 10, False for rest)
|
| 730 |
+
valid_A = dec_mask_mixed[0].sum()
|
| 731 |
+
valid_B = dec_mask_mixed[1].sum()
|
| 732 |
+
|
| 733 |
+
print(f"Valid steps detected by Decoder: Robot A={valid_A}, Robot B={valid_B}")
|
| 734 |
+
|
| 735 |
+
assert valid_A == 10
|
| 736 |
+
assert valid_B == 20
|
| 737 |
+
|
| 738 |
+
# Check dimensionality preservation
|
| 739 |
+
# Robot A's reconstruction in dims 7-9 should be noise or zero (depending on implementation),
|
| 740 |
+
# but dims 0-6 should contain signal.
|
| 741 |
+
print("✅ Mixed batch processed successfully.")
|
| 742 |
+
|
| 743 |
+
print("\n✨ All systems go.")
|
modular_actioncodec.py
ADDED
|
@@ -0,0 +1,787 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
from typing import List, Literal, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import einops
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from .configuration_actioncodec import ActionCodecConfig
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def apply_rotary_pos_emb(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
|
| 15 |
+
original_dtype = x.dtype
|
| 16 |
+
|
| 17 |
+
x = x.to(torch.float32)
|
| 18 |
+
sin = sin.to(torch.float32)
|
| 19 |
+
cos = cos.to(torch.float32)
|
| 20 |
+
|
| 21 |
+
x1 = x[..., 0::2]
|
| 22 |
+
x2 = x[..., 1::2]
|
| 23 |
+
|
| 24 |
+
rotated_x1 = x1 * cos - x2 * sin
|
| 25 |
+
rotated_x2 = x1 * sin + x2 * cos
|
| 26 |
+
|
| 27 |
+
x_out = torch.empty_like(x)
|
| 28 |
+
x_out[..., 0::2] = rotated_x1
|
| 29 |
+
x_out[..., 1::2] = rotated_x2
|
| 30 |
+
|
| 31 |
+
return x_out.to(original_dtype)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def attention_op(
|
| 35 |
+
q: torch.Tensor,
|
| 36 |
+
k: torch.Tensor,
|
| 37 |
+
v: torch.Tensor,
|
| 38 |
+
mask: torch.Tensor | None = None,
|
| 39 |
+
is_causal: bool = False,
|
| 40 |
+
) -> torch.Tensor:
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
q (torch.Tensor): (*b, h, l, d)
|
| 45 |
+
k (torch.Tensor): (*b, k, s, d)
|
| 46 |
+
v (torch.Tensor): (*b, k, s, d)
|
| 47 |
+
mask (torch.Tensor | None, optional): (*b, l, s), where `True` indicates the element should take part in attention. Defaults to None.
|
| 48 |
+
is_causal (bool, optional): Whether to apply causal mask. Defaults to False.
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
torch.Tensor: (*b, h, l, d)
|
| 52 |
+
"""
|
| 53 |
+
heads, kv_heads = q.shape[-3], k.shape[-3]
|
| 54 |
+
if heads != kv_heads:
|
| 55 |
+
assert heads % kv_heads == 0, f"q_heads must be divisible by kv_heads, but got {heads} and {kv_heads}"
|
| 56 |
+
heads_per_kv_head = heads // kv_heads
|
| 57 |
+
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v))
|
| 58 |
+
|
| 59 |
+
if mask is not None:
|
| 60 |
+
if mask.dim() == 3:
|
| 61 |
+
mask = mask.unsqueeze(1)
|
| 62 |
+
mask = mask.expand(mask.shape[0], heads, -1, -1)
|
| 63 |
+
|
| 64 |
+
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=is_causal)
|
| 65 |
+
return out
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class L2Norm(nn.Module):
|
| 69 |
+
def forward(self, x: torch.Tensor):
|
| 70 |
+
return F.normalize(x, p=2, dim=-1)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class Attention(nn.Module):
|
| 74 |
+
"""
|
| 75 |
+
Args:
|
| 76 |
+
hidden_size (int): Hidden size of the input tensor.
|
| 77 |
+
num_heads (int): Number of attention heads.
|
| 78 |
+
num_kv_heads (int, optional): Number of key/value heads. Defaults to None.
|
| 79 |
+
qk_norm (Literal["l2", "ln", "none"], optional): Type of normalization to apply to query/key. Defaults to "none".
|
| 80 |
+
bias (bool, optional): Whether to use bias in linear layers. Defaults to False.
|
| 81 |
+
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
def __init__(
|
| 85 |
+
self,
|
| 86 |
+
hidden_size: int,
|
| 87 |
+
num_heads: int,
|
| 88 |
+
num_kv_heads: int | None = None,
|
| 89 |
+
qk_norm: Literal["l2", "ln", "none"] = "none",
|
| 90 |
+
bias: bool = False,
|
| 91 |
+
zero_init_output: bool = False,
|
| 92 |
+
):
|
| 93 |
+
super().__init__()
|
| 94 |
+
num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
| 95 |
+
self.dim = hidden_size // num_heads
|
| 96 |
+
self.num_heads, self.num_kv_heads = num_heads, num_kv_heads
|
| 97 |
+
|
| 98 |
+
self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
|
| 99 |
+
self.k_proj = nn.Linear(hidden_size, self.dim * num_kv_heads, bias=bias)
|
| 100 |
+
self.v_proj = nn.Linear(hidden_size, self.dim * num_kv_heads, bias=bias)
|
| 101 |
+
self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
|
| 102 |
+
|
| 103 |
+
if qk_norm == "l2":
|
| 104 |
+
self.q_norm = L2Norm()
|
| 105 |
+
self.k_norm = L2Norm()
|
| 106 |
+
elif qk_norm == "ln":
|
| 107 |
+
self.q_norm = nn.LayerNorm(self.dim, elementwise_affine=False)
|
| 108 |
+
self.k_norm = nn.LayerNorm(self.dim, elementwise_affine=False)
|
| 109 |
+
else:
|
| 110 |
+
self.q_norm = nn.Identity()
|
| 111 |
+
self.k_norm = nn.Identity()
|
| 112 |
+
|
| 113 |
+
if zero_init_output:
|
| 114 |
+
nn.init.zeros_(self.out_proj.weight)
|
| 115 |
+
if self.out_proj.bias is not None:
|
| 116 |
+
nn.init.zeros_(self.out_proj.bias)
|
| 117 |
+
|
| 118 |
+
def forward(
|
| 119 |
+
self,
|
| 120 |
+
x: torch.Tensor,
|
| 121 |
+
context: torch.Tensor | None = None,
|
| 122 |
+
mask: torch.Tensor | None = None,
|
| 123 |
+
rotary_pos_emb: Tuple[torch.Tensor, torch.Tensor] | None = None,
|
| 124 |
+
is_causal: bool = False,
|
| 125 |
+
) -> torch.Tensor:
|
| 126 |
+
context = x if context is None else context
|
| 127 |
+
|
| 128 |
+
q = self.q_proj(x)
|
| 129 |
+
k, v = self.k_proj(context), self.v_proj(context)
|
| 130 |
+
|
| 131 |
+
q = einops.rearrange(q, "b l (h d) -> b h l d", h=self.num_heads)
|
| 132 |
+
k = einops.rearrange(k, "b s (h d) -> b h s d", h=self.num_kv_heads)
|
| 133 |
+
v = einops.rearrange(v, "b s (h d) -> b h s d", h=self.num_kv_heads)
|
| 134 |
+
|
| 135 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 136 |
+
|
| 137 |
+
if rotary_pos_emb is not None:
|
| 138 |
+
q, k = map(lambda t: apply_rotary_pos_emb(t, *rotary_pos_emb), (q, k))
|
| 139 |
+
|
| 140 |
+
out = attention_op(q, k, v, mask=mask, is_causal=is_causal)
|
| 141 |
+
out = einops.rearrange(out, "b h l d -> b l (h d)")
|
| 142 |
+
out = self.out_proj(out)
|
| 143 |
+
|
| 144 |
+
return out
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class PositionalEmbedding(nn.Module):
|
| 148 |
+
def __init__(
|
| 149 |
+
self,
|
| 150 |
+
dim: int,
|
| 151 |
+
encoding_type: Literal["sincos", "fourier"] = "sincos",
|
| 152 |
+
scale: float = 2.0,
|
| 153 |
+
):
|
| 154 |
+
super().__init__()
|
| 155 |
+
self.dim = dim
|
| 156 |
+
self.encoding_type = encoding_type
|
| 157 |
+
|
| 158 |
+
if encoding_type == "fourier":
|
| 159 |
+
self.register_buffer("freqs", torch.randn(dim // 2) * scale, persistent=True)
|
| 160 |
+
elif encoding_type == "sincos":
|
| 161 |
+
pass
|
| 162 |
+
else:
|
| 163 |
+
raise ValueError(f"encoding_type must be 'sincos' or 'fourier', but got {encoding_type}")
|
| 164 |
+
|
| 165 |
+
def _create_sincos_emb(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
| 166 |
+
position = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
|
| 167 |
+
div_term = torch.exp(
|
| 168 |
+
torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) * -(math.log(10000.0) / self.dim)
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
pos_emb = torch.zeros(seq_len, self.dim, device=device, dtype=dtype)
|
| 172 |
+
pos_emb[:, 0::2] = torch.sin(position * div_term).to(dtype)
|
| 173 |
+
pos_emb[:, 1::2] = torch.cos(position * div_term).to(dtype)
|
| 174 |
+
|
| 175 |
+
return pos_emb
|
| 176 |
+
|
| 177 |
+
def _create_fourier_emb(self, timestamps: torch.Tensor, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
| 178 |
+
# Ensure freqs is on the correct device
|
| 179 |
+
freqs = self.freqs.to(device)
|
| 180 |
+
pos_emb = torch.einsum("b t, d -> b t d", timestamps, 2 * np.pi * freqs).to(device, torch.float32)
|
| 181 |
+
pos_emb = torch.cat([pos_emb.cos(), pos_emb.sin()], dim=-1).to(dtype)
|
| 182 |
+
return pos_emb
|
| 183 |
+
|
| 184 |
+
def forward(
|
| 185 |
+
self, x: torch.Tensor, freq: Optional[Union[float, torch.Tensor]] = None, dtype: torch.dtype = torch.float32
|
| 186 |
+
) -> torch.Tensor:
|
| 187 |
+
b, t = x.shape[0], x.shape[1]
|
| 188 |
+
device = x.device
|
| 189 |
+
|
| 190 |
+
if self.encoding_type == "sincos":
|
| 191 |
+
pos_emb = self._create_sincos_emb(t, device, dtype)
|
| 192 |
+
pos_emb = pos_emb.unsqueeze(0).expand(b, -1, -1)
|
| 193 |
+
return pos_emb * 0.1
|
| 194 |
+
|
| 195 |
+
elif self.encoding_type == "fourier":
|
| 196 |
+
if freq is None:
|
| 197 |
+
raise ValueError(
|
| 198 |
+
"freq must be provided when encoding_type is 'fourier'. Please provide the sequence frequency."
|
| 199 |
+
)
|
| 200 |
+
if isinstance(freq, float):
|
| 201 |
+
freq = torch.tensor(freq, dtype=dtype, device=device)[None].expand(b)
|
| 202 |
+
timestamps = torch.einsum("t, b -> b t", torch.arange(t, dtype=dtype, device=device), 1 / freq)
|
| 203 |
+
pos_emb = self._create_fourier_emb(timestamps, device, dtype)
|
| 204 |
+
return pos_emb * 0.1
|
| 205 |
+
else:
|
| 206 |
+
raise ValueError(f"Unknown encoding_type: {self.encoding_type}")
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class SinusoidalPositionalEmbedding(PositionalEmbedding):
|
| 210 |
+
def __init__(self, dim: int):
|
| 211 |
+
super().__init__(dim=dim, encoding_type="sincos")
|
| 212 |
+
|
| 213 |
+
def forward(self, x: torch.Tensor, pos: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 214 |
+
return super().forward(x, freq=None)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class FeedForward(nn.Module):
|
| 218 |
+
def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
|
| 219 |
+
super().__init__()
|
| 220 |
+
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
|
| 221 |
+
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
|
| 222 |
+
self.act_fn = nn.GELU()
|
| 223 |
+
|
| 224 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 225 |
+
down_proj = self.down_proj(self.act_fn(self.up_proj(x)))
|
| 226 |
+
return down_proj
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class LayerScale(nn.Module):
|
| 230 |
+
def __init__(self, dim, init_val=1e-2):
|
| 231 |
+
super().__init__()
|
| 232 |
+
self.scale = nn.Parameter(torch.full([dim], init_val))
|
| 233 |
+
|
| 234 |
+
def forward(self, x):
|
| 235 |
+
return x * self.scale
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class PerceiverTransformerBlock(nn.Module):
|
| 239 |
+
def __init__(
|
| 240 |
+
self,
|
| 241 |
+
dim: int,
|
| 242 |
+
num_heads: int,
|
| 243 |
+
mlp_ratio: int = 4,
|
| 244 |
+
dropout: float = 0.0,
|
| 245 |
+
qk_norm: str = "ln",
|
| 246 |
+
layer_scale: bool = True,
|
| 247 |
+
zero_init_output: bool = False,
|
| 248 |
+
add_self_attn: bool = False,
|
| 249 |
+
add_causal_mask: bool = False,
|
| 250 |
+
):
|
| 251 |
+
super().__init__()
|
| 252 |
+
self.add_self_attn = add_self_attn
|
| 253 |
+
self.add_causal_mask = add_causal_mask
|
| 254 |
+
|
| 255 |
+
self.norm1 = nn.LayerNorm(dim, eps=1e-2)
|
| 256 |
+
self.cross_attn = Attention(
|
| 257 |
+
hidden_size=dim, num_heads=num_heads, qk_norm=qk_norm, bias=False, zero_init_output=zero_init_output
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
if add_self_attn:
|
| 261 |
+
self.norm_self_attn = nn.LayerNorm(dim, eps=1e-2)
|
| 262 |
+
self.self_attn = Attention(
|
| 263 |
+
hidden_size=dim, num_heads=num_heads, qk_norm=qk_norm, bias=False, zero_init_output=zero_init_output
|
| 264 |
+
)
|
| 265 |
+
else:
|
| 266 |
+
self.self_attn = None
|
| 267 |
+
|
| 268 |
+
self.norm2 = nn.LayerNorm(dim, eps=1e-2)
|
| 269 |
+
self.mlp = FeedForward(hidden_size=dim, intermediate_size=int(mlp_ratio * dim), bias=True)
|
| 270 |
+
self.dropout = nn.Dropout(dropout)
|
| 271 |
+
|
| 272 |
+
self.attn_scale = LayerScale(dim) if layer_scale else nn.Identity()
|
| 273 |
+
self.mlp_scale = LayerScale(dim) if layer_scale else nn.Identity()
|
| 274 |
+
|
| 275 |
+
if zero_init_output:
|
| 276 |
+
nn.init.zeros_(self.mlp.down_proj.weight)
|
| 277 |
+
if self.mlp.down_proj.bias is not None:
|
| 278 |
+
nn.init.zeros_(self.mlp.down_proj.bias)
|
| 279 |
+
|
| 280 |
+
def forward(
|
| 281 |
+
self,
|
| 282 |
+
x: torch.Tensor,
|
| 283 |
+
context: torch.Tensor,
|
| 284 |
+
context_mask: Optional[torch.Tensor] = None,
|
| 285 |
+
rotary_pos_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 286 |
+
) -> torch.Tensor:
|
| 287 |
+
residual = x
|
| 288 |
+
x = self.norm1(x)
|
| 289 |
+
x = self.cross_attn(x=x, context=context, mask=context_mask, rotary_pos_emb=rotary_pos_emb, is_causal=False)
|
| 290 |
+
x = self.dropout(x)
|
| 291 |
+
x = self.attn_scale(x)
|
| 292 |
+
x = x + residual
|
| 293 |
+
|
| 294 |
+
if self.add_self_attn:
|
| 295 |
+
residual = x
|
| 296 |
+
x = self.norm_self_attn(x)
|
| 297 |
+
x = self.self_attn(
|
| 298 |
+
x=x,
|
| 299 |
+
context=None,
|
| 300 |
+
mask=None,
|
| 301 |
+
rotary_pos_emb=rotary_pos_emb,
|
| 302 |
+
is_causal=self.add_causal_mask,
|
| 303 |
+
)
|
| 304 |
+
x = self.dropout(x)
|
| 305 |
+
x = self.attn_scale(x)
|
| 306 |
+
x = x + residual
|
| 307 |
+
|
| 308 |
+
residual = x
|
| 309 |
+
x = self.norm2(x)
|
| 310 |
+
x = self.mlp(x)
|
| 311 |
+
x = self.dropout(x)
|
| 312 |
+
x = self.mlp_scale(x)
|
| 313 |
+
x = x + residual
|
| 314 |
+
|
| 315 |
+
return x
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class EmbodimentEmbedding(nn.Module):
|
| 319 |
+
def __init__(self, embodiment_config: dict, out_len: int, out_dim: int) -> None:
|
| 320 |
+
super().__init__()
|
| 321 |
+
self.out_len, self.out_dim = out_len, out_dim
|
| 322 |
+
|
| 323 |
+
self.embodiment_config = embodiment_config
|
| 324 |
+
self.num_embodiments = len(self.embodiment_config)
|
| 325 |
+
|
| 326 |
+
self.embedding = nn.Embedding(self.num_embodiments, out_dim * out_len)
|
| 327 |
+
|
| 328 |
+
@torch.no_grad()
|
| 329 |
+
def expand_embodiment(self, embodiment_config: dict):
|
| 330 |
+
for k in embodiment_config.keys():
|
| 331 |
+
assert k not in self.embodiment_config.keys()
|
| 332 |
+
self.embodiment_config.update(embodiment_config)
|
| 333 |
+
self.num_embodiments = len(self.embodiment_config)
|
| 334 |
+
|
| 335 |
+
extra_embodiments = len(embodiment_config)
|
| 336 |
+
|
| 337 |
+
old_weights = torch.clone(self.embedding.weight)
|
| 338 |
+
self.embedding = nn.Embedding(self.num_embodiments, self.out_dim * self.out_len)
|
| 339 |
+
self.embedding.weight.data[:-extra_embodiments] = old_weights
|
| 340 |
+
return self
|
| 341 |
+
|
| 342 |
+
def keys(self) -> list[str]:
|
| 343 |
+
return list(self.embodiment_config.keys())
|
| 344 |
+
|
| 345 |
+
def ids_to_keys(self, ids: torch.Tensor) -> List[str]:
|
| 346 |
+
return [self.keys()[i] for i in ids]
|
| 347 |
+
|
| 348 |
+
def keys_to_ids(self, keys: List[str]) -> torch.Tensor:
|
| 349 |
+
return torch.tensor([self.keys().index(k) for k in keys])
|
| 350 |
+
|
| 351 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 352 |
+
return einops.rearrange(self.embedding(x), "b (l d) -> b l d", d=self.out_dim)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
class PerceiverEncoder(nn.Module):
|
| 356 |
+
def __init__(self, config: ActionCodecConfig):
|
| 357 |
+
super().__init__()
|
| 358 |
+
self.config = config
|
| 359 |
+
self.embodiment_config = deepcopy(config.embodiment_config)
|
| 360 |
+
|
| 361 |
+
out_len = int(config.n_tokens // config.n_quantizers)
|
| 362 |
+
dim = config.encoder_dim
|
| 363 |
+
|
| 364 |
+
_action_dim, _freq, _duration = list(), list(), list()
|
| 365 |
+
for k, v in self.embodiment_config.items():
|
| 366 |
+
_action_dim.append(v["action_dim"])
|
| 367 |
+
_freq.append(v["freq"])
|
| 368 |
+
_duration.append(v["duration"])
|
| 369 |
+
self.register_buffer("_action_dim", torch.tensor(_action_dim), persistent=False)
|
| 370 |
+
self.register_buffer("_freq", torch.tensor(_freq), persistent=False)
|
| 371 |
+
self.register_buffer("_duration", torch.tensor(_duration), persistent=False)
|
| 372 |
+
|
| 373 |
+
self.max_action_dim = max(v["action_dim"] for v in self.embodiment_config.values())
|
| 374 |
+
self.input_proj = nn.Linear(self.max_action_dim, dim)
|
| 375 |
+
|
| 376 |
+
self.cls_tokens = EmbodimentEmbedding(self.embodiment_config, out_len, dim)
|
| 377 |
+
|
| 378 |
+
self.pos_emb_q = PositionalEmbedding(dim, encoding_type="sincos")
|
| 379 |
+
self.pos_emb_kv = PositionalEmbedding(dim, encoding_type=config.encoder_pos_encoding_type)
|
| 380 |
+
|
| 381 |
+
self.layers = nn.ModuleList(
|
| 382 |
+
[
|
| 383 |
+
PerceiverTransformerBlock(
|
| 384 |
+
dim=dim,
|
| 385 |
+
num_heads=config.encoder_n_heads,
|
| 386 |
+
add_self_attn=config.encoder_add_self_attn,
|
| 387 |
+
add_causal_mask=config.encoder_add_causal_mask,
|
| 388 |
+
)
|
| 389 |
+
for _ in range(config.encoder_n_layers)
|
| 390 |
+
]
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
self.output_proj = nn.Linear(dim, config.z_dim)
|
| 394 |
+
self._init_weights()
|
| 395 |
+
|
| 396 |
+
def _init_weights(self):
|
| 397 |
+
nn.init.trunc_normal_(self.input_proj.weight, std=0.02)
|
| 398 |
+
if self.input_proj.bias is not None:
|
| 399 |
+
nn.init.zeros_(self.input_proj.bias)
|
| 400 |
+
nn.init.trunc_normal_(self.output_proj.weight, std=0.02)
|
| 401 |
+
if self.output_proj.bias is not None:
|
| 402 |
+
nn.init.zeros_(self.output_proj.bias)
|
| 403 |
+
|
| 404 |
+
nn.init.trunc_normal_(self.cls_tokens.embedding.weight, std=0.02)
|
| 405 |
+
|
| 406 |
+
@torch.no_grad()
|
| 407 |
+
def expand_embodiment(self, embodiment_config: dict):
|
| 408 |
+
self.cls_tokens.expand_embodiment(embodiment_config)
|
| 409 |
+
self.embodiment_config = self.cls_tokens.embodiment_config
|
| 410 |
+
_action_dim, _freq, _duration = list(), list(), list()
|
| 411 |
+
for k, v in self.embodiment_config.items():
|
| 412 |
+
_action_dim.append(v["action_dim"])
|
| 413 |
+
_freq.append(v["freq"])
|
| 414 |
+
_duration.append(v["duration"])
|
| 415 |
+
self._action_dim = torch.tensor(_action_dim)
|
| 416 |
+
self._freq = torch.tensor(_freq)
|
| 417 |
+
self._duration = torch.tensor(_duration)
|
| 418 |
+
|
| 419 |
+
max_action_dim = max(v["action_dim"] for v in self.embodiment_config.values())
|
| 420 |
+
if max_action_dim > self.max_action_dim:
|
| 421 |
+
old_weights = torch.clone(self.input_proj.weight)
|
| 422 |
+
old_bias = torch.clone(self.input_proj.bias)
|
| 423 |
+
self.input_proj = nn.Linear(max_action_dim, self.config.encoder_dim)
|
| 424 |
+
self.input_proj.weight.data[:, : self.max_action_dim] = old_weights
|
| 425 |
+
self.input_proj.bias.data = old_bias
|
| 426 |
+
self.max_action_dim = max_action_dim
|
| 427 |
+
|
| 428 |
+
return self
|
| 429 |
+
|
| 430 |
+
def forward(
|
| 431 |
+
self,
|
| 432 |
+
x: torch.Tensor,
|
| 433 |
+
embodiment_ids: torch.Tensor | int,
|
| 434 |
+
padding_mask: Optional[torch.Tensor] = None,
|
| 435 |
+
) -> torch.Tensor:
|
| 436 |
+
"""Encode action sequences into latent representations.
|
| 437 |
+
|
| 438 |
+
Args:
|
| 439 |
+
x (torch.Tensor): Action sequences to encode. Shape: (b, seq_len, max_action_dim).
|
| 440 |
+
Assumes that the action dimension is zero-padded to the max action dimension.
|
| 441 |
+
`seq_len` is supposed to be `int(duration * freq)` for each embodiment and padded to the max sequence length.
|
| 442 |
+
embodiment_ids (torch.Tensor | int): Embodiment IDs. Shape: (b,).
|
| 443 |
+
If int, the same embodiment ID is repeated for all sequences in the batch.
|
| 444 |
+
It specifies the embodiment to encode.
|
| 445 |
+
padding_mask (Optional[torch.Tensor], optional): Padding mask, where `False` values indicate padding. Shape: (b, seq_len). Defaults to None.
|
| 446 |
+
It is used to mask the padding tokens on `seq_len` dimension.
|
| 447 |
+
|
| 448 |
+
Returns:
|
| 449 |
+
torch.Tensor: Encoded latent representations. Shape: (b, n_tokens_per_quantizer, z_dim).
|
| 450 |
+
"""
|
| 451 |
+
b, seq_len, _ = x.shape
|
| 452 |
+
|
| 453 |
+
x = self.input_proj(x)
|
| 454 |
+
|
| 455 |
+
if isinstance(embodiment_ids, int):
|
| 456 |
+
embodiment_ids = torch.tensor([embodiment_ids], dtype=torch.long, device=x.device).repeat(b)
|
| 457 |
+
|
| 458 |
+
cls_tokens = self.cls_tokens(embodiment_ids)
|
| 459 |
+
|
| 460 |
+
freqs = self._freq[embodiment_ids].to(x.device, x.dtype)
|
| 461 |
+
|
| 462 |
+
pos_emb_q = self.pos_emb_q(cls_tokens)
|
| 463 |
+
pos_emb_kv = self.pos_emb_kv(x, freqs)
|
| 464 |
+
|
| 465 |
+
cls_tokens = cls_tokens + pos_emb_q
|
| 466 |
+
x = x + pos_emb_kv
|
| 467 |
+
|
| 468 |
+
if padding_mask is not None:
|
| 469 |
+
padding_mask = padding_mask.unsqueeze(1).expand(-1, cls_tokens.shape[1], -1)
|
| 470 |
+
|
| 471 |
+
for layer in self.layers:
|
| 472 |
+
cls_tokens = layer(x=cls_tokens, context=x, context_mask=padding_mask)
|
| 473 |
+
|
| 474 |
+
return self.output_proj(cls_tokens)
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
class PerceiverDecoder(nn.Module):
|
| 478 |
+
def __init__(self, config: ActionCodecConfig):
|
| 479 |
+
super().__init__()
|
| 480 |
+
self.config = config
|
| 481 |
+
self.embodiment_config = deepcopy(config.embodiment_config)
|
| 482 |
+
|
| 483 |
+
dim = config.decoder_dim
|
| 484 |
+
|
| 485 |
+
_action_dim, _freq, _duration = list(), list(), list()
|
| 486 |
+
for k, v in self.embodiment_config.items():
|
| 487 |
+
_action_dim.append(v["action_dim"])
|
| 488 |
+
_freq.append(v["freq"])
|
| 489 |
+
_duration.append(v["duration"])
|
| 490 |
+
self.register_buffer("_action_dim", torch.tensor(_action_dim), persistent=False)
|
| 491 |
+
self.register_buffer("_freq", torch.tensor(_freq), persistent=False)
|
| 492 |
+
self.register_buffer("_duration", torch.tensor(_duration), persistent=False)
|
| 493 |
+
|
| 494 |
+
self.max_action_dim = max(v["action_dim"] for v in self.embodiment_config.values())
|
| 495 |
+
self.input_proj = nn.Linear(config.z_dim, dim)
|
| 496 |
+
|
| 497 |
+
self.cls_tokens = EmbodimentEmbedding(self.embodiment_config, config.decoder_cls_size, dim)
|
| 498 |
+
|
| 499 |
+
self.pos_emb_q = PositionalEmbedding(dim, encoding_type=config.decoder_pos_encoding_type)
|
| 500 |
+
self.pos_emb_kv = PositionalEmbedding(dim, encoding_type="sincos")
|
| 501 |
+
|
| 502 |
+
self.layers = nn.ModuleList(
|
| 503 |
+
[
|
| 504 |
+
PerceiverTransformerBlock(
|
| 505 |
+
dim=dim,
|
| 506 |
+
num_heads=config.decoder_n_heads,
|
| 507 |
+
add_self_attn=config.decoder_add_self_attn,
|
| 508 |
+
add_causal_mask=config.decoder_add_causal_mask,
|
| 509 |
+
)
|
| 510 |
+
for _ in range(config.decoder_n_layers)
|
| 511 |
+
]
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
self.output_proj = nn.Linear(dim, self.max_action_dim)
|
| 515 |
+
self._init_weights()
|
| 516 |
+
|
| 517 |
+
def _init_weights(self):
|
| 518 |
+
nn.init.trunc_normal_(self.input_proj.weight, std=0.02)
|
| 519 |
+
if self.input_proj.bias is not None:
|
| 520 |
+
nn.init.zeros_(self.input_proj.bias)
|
| 521 |
+
nn.init.trunc_normal_(self.output_proj.weight, std=0.02)
|
| 522 |
+
if self.output_proj.bias is not None:
|
| 523 |
+
nn.init.zeros_(self.output_proj.bias)
|
| 524 |
+
nn.init.trunc_normal_(self.cls_tokens.embedding.weight, std=0.02)
|
| 525 |
+
|
| 526 |
+
@torch.no_grad()
|
| 527 |
+
def expand_embodiment(self, embodiment_config: dict):
|
| 528 |
+
self.cls_tokens.expand_embodiment(embodiment_config)
|
| 529 |
+
self.embodiment_config = self.cls_tokens.embodiment_config
|
| 530 |
+
|
| 531 |
+
_action_dim, _freq, _duration = list(), list(), list()
|
| 532 |
+
for k, v in self.embodiment_config.items():
|
| 533 |
+
_action_dim.append(v["action_dim"])
|
| 534 |
+
_freq.append(v["freq"])
|
| 535 |
+
_duration.append(v["duration"])
|
| 536 |
+
self._action_dim = torch.tensor(_action_dim)
|
| 537 |
+
self._freq = torch.tensor(_freq)
|
| 538 |
+
self._duration = torch.tensor(_duration)
|
| 539 |
+
|
| 540 |
+
max_action_dim = max(v["action_dim"] for v in self.embodiment_config.values())
|
| 541 |
+
|
| 542 |
+
if max_action_dim > self.max_action_dim:
|
| 543 |
+
old_weights = torch.clone(self.output_proj.weight)
|
| 544 |
+
old_bias = torch.clone(self.output_proj.bias)
|
| 545 |
+
|
| 546 |
+
self.output_proj = nn.Linear(self.config.decoder_dim, max_action_dim)
|
| 547 |
+
|
| 548 |
+
self.output_proj.weight.data[: self.max_action_dim, :] = old_weights
|
| 549 |
+
self.output_proj.bias.data[: self.max_action_dim] = old_bias
|
| 550 |
+
|
| 551 |
+
self.max_action_dim = max_action_dim
|
| 552 |
+
|
| 553 |
+
return self
|
| 554 |
+
|
| 555 |
+
def forward(
|
| 556 |
+
self, x: torch.Tensor, embodiment_ids: torch.Tensor | int, durations: torch.Tensor | None = None
|
| 557 |
+
) -> torch.Tensor:
|
| 558 |
+
"""Decode latent representations into action sequences.
|
| 559 |
+
|
| 560 |
+
Args:
|
| 561 |
+
x (torch.Tensor): Latent representations to decode. Shape: (b, n_tokens_per_quantizer, z_dim).
|
| 562 |
+
embodiment_ids (torch.Tensor | int): Embodiment IDs. Shape: (b,).
|
| 563 |
+
If int, the same embodiment ID is repeated for all sequences in the batch.
|
| 564 |
+
It specifies the embodiment to decode.
|
| 565 |
+
durations (torch.Tensor | None, optional): Duration of each action sequence. Shape: (b,).
|
| 566 |
+
If `None`, the duration is inferred from the default values in `embodiment_config`.
|
| 567 |
+
|
| 568 |
+
Returns:
|
| 569 |
+
torch.Tensor: Decoded action sequences. Shape: (b, seq_len, max_action_dim).
|
| 570 |
+
Assumes that the action dimension is zero-padded to the max action dimension.
|
| 571 |
+
`seq_len` is supposed to be `int(duration * freq)` for each embodiment and padded to the max sequence length.
|
| 572 |
+
"""
|
| 573 |
+
b, seq_len, _ = x.shape
|
| 574 |
+
x = self.input_proj(x)
|
| 575 |
+
|
| 576 |
+
if isinstance(embodiment_ids, int):
|
| 577 |
+
embodiment_ids = torch.tensor([embodiment_ids], dtype=torch.long, device=x.device).repeat(b)
|
| 578 |
+
|
| 579 |
+
cls_tokens = self.cls_tokens(embodiment_ids)
|
| 580 |
+
|
| 581 |
+
freqs = self._freq[embodiment_ids]
|
| 582 |
+
if freqs.device != x.device:
|
| 583 |
+
freqs = freqs.to(x.device)
|
| 584 |
+
|
| 585 |
+
durations = self._duration[embodiment_ids] if durations is None else durations
|
| 586 |
+
if isinstance(durations, torch.Tensor) and durations.device != x.device:
|
| 587 |
+
durations = durations.to(x.device)
|
| 588 |
+
|
| 589 |
+
action_horizons = (durations * freqs).long()
|
| 590 |
+
max_horizon = action_horizons.max().item()
|
| 591 |
+
padding_mask = torch.arange(max_horizon, device=x.device).expand(b, -1) < action_horizons.unsqueeze(1)
|
| 592 |
+
|
| 593 |
+
if self.config.decoder_cls_size == 1:
|
| 594 |
+
cls_tokens = cls_tokens.repeat(1, max_horizon, 1)
|
| 595 |
+
|
| 596 |
+
pos_emb_q = self.pos_emb_q(cls_tokens, freqs)
|
| 597 |
+
pos_emb_kv = self.pos_emb_kv(x)
|
| 598 |
+
|
| 599 |
+
cls_tokens = cls_tokens + pos_emb_q
|
| 600 |
+
x = x + pos_emb_kv
|
| 601 |
+
|
| 602 |
+
for layer in self.layers:
|
| 603 |
+
cls_tokens = layer(x=cls_tokens, context=x)
|
| 604 |
+
|
| 605 |
+
output = self.output_proj(cls_tokens)
|
| 606 |
+
|
| 607 |
+
return output, padding_mask
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
if __name__ == "__main__":
|
| 611 |
+
# ------------------------------------------
|
| 612 |
+
# 1. Initialization
|
| 613 |
+
# ------------------------------------------
|
| 614 |
+
print("=== Test 1: Initialization ===")
|
| 615 |
+
|
| 616 |
+
# Define initial config with two smaller robots
|
| 617 |
+
initial_embodiment_config = {
|
| 618 |
+
"robot_small_7d": {"action_dim": 7, "freq": 20, "duration": 1, "description": "Original Robot"},
|
| 619 |
+
"robot_tiny_3d": {"action_dim": 3, "freq": 10, "duration": 2, "description": "Tiny Robot"},
|
| 620 |
+
}
|
| 621 |
+
|
| 622 |
+
config = ActionCodecConfig(embodiment_config=initial_embodiment_config)
|
| 623 |
+
|
| 624 |
+
# Set seed for reproducibility
|
| 625 |
+
torch.manual_seed(42)
|
| 626 |
+
|
| 627 |
+
encoder = PerceiverEncoder(config)
|
| 628 |
+
decoder = PerceiverDecoder(config)
|
| 629 |
+
|
| 630 |
+
encoder.eval()
|
| 631 |
+
decoder.eval()
|
| 632 |
+
print("✅ Models initialized successfully.")
|
| 633 |
+
|
| 634 |
+
# ------------------------------------------
|
| 635 |
+
# 2. Baseline Inference (Before Expansion)
|
| 636 |
+
# ------------------------------------------
|
| 637 |
+
print("\n=== Test 2: Baseline Inference (Before Expansion) ===")
|
| 638 |
+
|
| 639 |
+
# Simulate Robot 1 (7-dim) data
|
| 640 |
+
# Max action dim currently is 7.
|
| 641 |
+
batch_size = 1
|
| 642 |
+
seq_len = 20 # 20Hz * 1s
|
| 643 |
+
|
| 644 |
+
# Input: (1, 20, 7)
|
| 645 |
+
input_action_v0 = torch.randn(batch_size, seq_len, 7)
|
| 646 |
+
emb_id_v0 = torch.tensor([0], dtype=torch.long) # ID 0 -> robot_small_7d
|
| 647 |
+
|
| 648 |
+
with torch.no_grad():
|
| 649 |
+
z_ref = encoder(input_action_v0, emb_id_v0)
|
| 650 |
+
rec_action_ref, _ = decoder(z_ref, emb_id_v0)
|
| 651 |
+
|
| 652 |
+
print(f"Reference Latent Shape: {z_ref.shape}")
|
| 653 |
+
print(f"Reference Recon Shape: {rec_action_ref.shape}")
|
| 654 |
+
|
| 655 |
+
# ------------------------------------------
|
| 656 |
+
# 3. Model Expansion (Add New Embodiment)
|
| 657 |
+
# ------------------------------------------
|
| 658 |
+
print("\n=== Test 3: Model Expansion ===")
|
| 659 |
+
|
| 660 |
+
# Add a larger robot: 10-dim, high frequency
|
| 661 |
+
new_embodiment_config = {
|
| 662 |
+
"robot_large_10d": {"action_dim": 10, "freq": 30, "duration": 1, "description": "New Large Robot"}
|
| 663 |
+
}
|
| 664 |
+
|
| 665 |
+
print(f"Expanding from Max Dim {encoder.max_action_dim} to 10...")
|
| 666 |
+
encoder.expand_embodiment(new_embodiment_config)
|
| 667 |
+
decoder.expand_embodiment(new_embodiment_config)
|
| 668 |
+
|
| 669 |
+
# Verify buffer updates
|
| 670 |
+
assert encoder._action_dim[-1] == 10
|
| 671 |
+
assert encoder.max_action_dim == 10
|
| 672 |
+
assert decoder.max_action_dim == 10
|
| 673 |
+
print(f"✅ Expansion successful. New Encoder Input Dim: {encoder.input_proj.weight.shape[1]}")
|
| 674 |
+
print(f"✅ New Decoder Output Dim: {decoder.output_proj.weight.shape[0]}")
|
| 675 |
+
|
| 676 |
+
# ------------------------------------------
|
| 677 |
+
# 4. Encoder Invariance Check
|
| 678 |
+
# ------------------------------------------
|
| 679 |
+
print("\n=== Test 4: Encoder Invariance Check ===")
|
| 680 |
+
|
| 681 |
+
# Pad old data (7 dims) to new max dim (10 dims) with ZEROS.
|
| 682 |
+
input_action_padded = torch.zeros(batch_size, seq_len, 10)
|
| 683 |
+
input_action_padded[:, :, :7] = input_action_v0
|
| 684 |
+
|
| 685 |
+
with torch.no_grad():
|
| 686 |
+
z_new = encoder(input_action_padded, emb_id_v0)
|
| 687 |
+
|
| 688 |
+
# Compare latents
|
| 689 |
+
diff_z = (z_ref - z_new).abs().max().item()
|
| 690 |
+
print(f"Latent Difference (Max Abs): {diff_z:.8f}")
|
| 691 |
+
|
| 692 |
+
if diff_z < 1e-6:
|
| 693 |
+
print("✅ PASS: Encoder produces identical latents for old data.")
|
| 694 |
+
else:
|
| 695 |
+
print("❌ FAIL: Encoder outputs changed after expansion!")
|
| 696 |
+
|
| 697 |
+
# ------------------------------------------
|
| 698 |
+
# 5. Decoder Invariance Check
|
| 699 |
+
# ------------------------------------------
|
| 700 |
+
print("\n=== Test 5: Decoder Invariance Check ===")
|
| 701 |
+
|
| 702 |
+
with torch.no_grad():
|
| 703 |
+
# Feed old latent to expanded decoder
|
| 704 |
+
rec_action_new_full, _ = decoder(z_ref, emb_id_v0)
|
| 705 |
+
|
| 706 |
+
# Output shape should be (1, 20, 10)
|
| 707 |
+
print(f"Expanded Decoder Output Shape: {rec_action_new_full.shape}")
|
| 708 |
+
|
| 709 |
+
# Slice first 7 dims, should match reference
|
| 710 |
+
rec_action_new_sliced = rec_action_new_full[:, :, :7]
|
| 711 |
+
|
| 712 |
+
diff_rec = (rec_action_ref - rec_action_new_sliced).abs().max().item()
|
| 713 |
+
print(f"Reconstruction Difference (Max Abs on valid dims): {diff_rec:.8f}")
|
| 714 |
+
|
| 715 |
+
if diff_rec < 1e-6:
|
| 716 |
+
print("✅ PASS: Decoder produces identical action values for valid dimensions.")
|
| 717 |
+
else:
|
| 718 |
+
print("❌ FAIL: Decoder outputs changed!")
|
| 719 |
+
|
| 720 |
+
# Check phantom dimensions (7-9)
|
| 721 |
+
# For old embodiment, these are driven by random weights and should be random
|
| 722 |
+
new_dims_mean = rec_action_new_full[:, :, 7:].abs().mean().item()
|
| 723 |
+
print(f"Values in new phantom dimensions (should be random garbage): {new_dims_mean:.4f}")
|
| 724 |
+
|
| 725 |
+
# ------------------------------------------
|
| 726 |
+
# 6. New Embodiment Inference
|
| 727 |
+
# ------------------------------------------
|
| 728 |
+
print("\n=== Test 6: New Embodiment Inference ===")
|
| 729 |
+
|
| 730 |
+
# ID 2 -> robot_large_10d
|
| 731 |
+
emb_id_new = torch.tensor([2], dtype=torch.long)
|
| 732 |
+
seq_len_new = 30 # 30Hz * 1s
|
| 733 |
+
|
| 734 |
+
input_action_new = torch.randn(1, seq_len_new, 10)
|
| 735 |
+
|
| 736 |
+
with torch.no_grad():
|
| 737 |
+
z_large = encoder(input_action_new, emb_id_new)
|
| 738 |
+
rec_large, mask_large = decoder(z_large, emb_id_new)
|
| 739 |
+
|
| 740 |
+
print(f"New Embodiment Output Shape: {rec_large.shape}")
|
| 741 |
+
|
| 742 |
+
if rec_large.shape == (1, 30, 10):
|
| 743 |
+
print("✅ PASS: New embodiment handled correctly with full dimensions.")
|
| 744 |
+
else:
|
| 745 |
+
print(f"❌ FAIL: Expected (1, 30, 10), got {rec_large.shape}")
|
| 746 |
+
|
| 747 |
+
# ------------------------------------------
|
| 748 |
+
# 7. Mixed Batch Processing (Masking)
|
| 749 |
+
# ------------------------------------------
|
| 750 |
+
print("\n=== Test 7: Mixed Batch Processing ===")
|
| 751 |
+
|
| 752 |
+
# Batch size 2: [Robot 0 (20Hz, 7dim), Robot 2 (30Hz, 10dim)]
|
| 753 |
+
mixed_emb_ids = torch.tensor([0, 2], dtype=torch.long)
|
| 754 |
+
|
| 755 |
+
# Max seq len is 30. Max action dim is 10.
|
| 756 |
+
batch_input = torch.zeros(2, 30, 10)
|
| 757 |
+
|
| 758 |
+
# Fill data
|
| 759 |
+
# Batch 0: Length 20, Dim 7 valid
|
| 760 |
+
batch_input[0, :20, :7] = torch.randn(20, 7)
|
| 761 |
+
# Batch 1: Length 30, Dim 10 valid
|
| 762 |
+
batch_input[1, :30, :10] = torch.randn(30, 10)
|
| 763 |
+
|
| 764 |
+
# Encoder Mask: True = Valid
|
| 765 |
+
enc_padding_mask = torch.zeros(2, 30, dtype=torch.bool)
|
| 766 |
+
enc_padding_mask[0, :20] = True
|
| 767 |
+
enc_padding_mask[1, :30] = True
|
| 768 |
+
|
| 769 |
+
print("Running mixed batch...")
|
| 770 |
+
with torch.no_grad():
|
| 771 |
+
z_mixed = encoder(batch_input, mixed_emb_ids, padding_mask=enc_padding_mask)
|
| 772 |
+
rec_mixed, dec_padding_mask = decoder(z_mixed, mixed_emb_ids)
|
| 773 |
+
|
| 774 |
+
print(f"Mixed Reconstruction Shape: {rec_mixed.shape}") # Should be (2, 30, 10)
|
| 775 |
+
|
| 776 |
+
# Verify Decoder Generated Mask
|
| 777 |
+
valid_len_0 = dec_padding_mask[0].sum().item()
|
| 778 |
+
valid_len_1 = dec_padding_mask[1].sum().item()
|
| 779 |
+
|
| 780 |
+
print(f"Decoder Mask Valid Lengths: Batch 0={valid_len_0}, Batch 1={valid_len_1}")
|
| 781 |
+
|
| 782 |
+
if valid_len_0 == 20 and valid_len_1 == 30:
|
| 783 |
+
print("✅ PASS: Decoder correctly generated masks based on frequency and duration.")
|
| 784 |
+
else:
|
| 785 |
+
print("❌ FAIL: Decoder masks are incorrect.")
|
| 786 |
+
|
| 787 |
+
print("\n✨ All Tests Completed ✨")
|
rvq.py
ADDED
|
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.distributed as dist
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from vector_quantize_pytorch import VectorQuantize as torchVQ
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def sample_vectors(samples, num):
|
| 13 |
+
# samples: (N, D), num_samples: N, feature dim: D
|
| 14 |
+
num_samples, device = samples.shape[0], samples.device
|
| 15 |
+
if num_samples >= num:
|
| 16 |
+
indices = torch.randperm(num_samples, device=device)[:num]
|
| 17 |
+
else:
|
| 18 |
+
indices = torch.randint(0, num_samples, (num,), device=device)
|
| 19 |
+
return samples[indices].float() # (num, D), ensure fp32
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def ema_inplace(moving_avg, new, decay):
|
| 23 |
+
# moving_avg: (codebook_size) or (codebook_size, D'), new: same as moving_avg
|
| 24 |
+
"""Update exponential moving average in-place"""
|
| 25 |
+
moving_avg.data.mul_(decay).add_(new.float(), alpha=(1 - decay)) # ensure fp32
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def kmeans(samples, num_clusters, num_iters=10):
|
| 29 |
+
# samples: (N, D), N samples with D dimensions
|
| 30 |
+
dim, _ = samples.shape[-1], torch.float32 # Force fp32
|
| 31 |
+
means = sample_vectors(samples, num_clusters).float() # (num_clusters, D), ensure fp32
|
| 32 |
+
|
| 33 |
+
for _ in range(num_iters):
|
| 34 |
+
dists = -(
|
| 35 |
+
samples.float().pow(2).sum(1, keepdim=True) # (N, 1), ensure fp32
|
| 36 |
+
- 2 * samples.float() @ means.t() # (N, num_clusters), ensure fp32
|
| 37 |
+
+ means.t().float().pow(2).sum(0, keepdim=True)
|
| 38 |
+
) # (1, num_clusters), ensure fp32
|
| 39 |
+
# dists: (N, num_clusters)
|
| 40 |
+
buckets = dists.max(dim=-1).indices # (N)
|
| 41 |
+
bins = torch.bincount(buckets, minlength=num_clusters) # (num_clusters)
|
| 42 |
+
zero_mask = bins == 0 # (num_clusters)
|
| 43 |
+
bins_min_clamped = bins.masked_fill(zero_mask, 1) # (num_clusters)
|
| 44 |
+
|
| 45 |
+
new_means = buckets.new_zeros(num_clusters, dim, dtype=torch.float32) # (num_clusters, D), ensure fp32
|
| 46 |
+
new_means.scatter_add_(
|
| 47 |
+
0, buckets.unsqueeze(1).expand(-1, dim), samples.float()
|
| 48 |
+
) # (num_clusters, D), ensure fp32
|
| 49 |
+
new_means = new_means / bins_min_clamped[..., None] # (num_clusters, D)
|
| 50 |
+
means = torch.where(zero_mask[..., None], means, new_means) # (num_clusters, D)
|
| 51 |
+
|
| 52 |
+
# Final cluster assignments for returning cluster sizes
|
| 53 |
+
dists = -(
|
| 54 |
+
samples.float().pow(2).sum(1, keepdim=True)
|
| 55 |
+
- 2 * samples.float() @ means.t()
|
| 56 |
+
+ means.t().float().pow(2).sum(0, keepdim=True)
|
| 57 |
+
) # (N, num_clusters), ensure fp32
|
| 58 |
+
buckets = dists.max(dim=-1).indices # (N)
|
| 59 |
+
bins = torch.bincount(buckets, minlength=num_clusters).float() # (num_clusters), ensure fp32
|
| 60 |
+
|
| 61 |
+
return means, bins # (num_clusters, D), (num_clusters)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class VectorQuantize(nn.Module):
|
| 65 |
+
def __init__(
|
| 66 |
+
self,
|
| 67 |
+
input_dim,
|
| 68 |
+
codebook_size,
|
| 69 |
+
codebook_dim,
|
| 70 |
+
commitment=1.0,
|
| 71 |
+
decay=0.99, # EMA decay
|
| 72 |
+
epsilon=1e-5, # Laplace smoothing epsilon
|
| 73 |
+
threshold_ema_dead=2, # Dead code threshold
|
| 74 |
+
kmeans_init=True, # Use kmeans initialization
|
| 75 |
+
kmeans_iters=10, # Kmeans iterations
|
| 76 |
+
rotation_trick=False, # Use rotation trick
|
| 77 |
+
**kwargs,
|
| 78 |
+
):
|
| 79 |
+
super().__init__()
|
| 80 |
+
self.input_dim = input_dim
|
| 81 |
+
self.codebook_size = codebook_size
|
| 82 |
+
self.codebook_dim = codebook_dim
|
| 83 |
+
self.commitment = commitment
|
| 84 |
+
self.decay = decay
|
| 85 |
+
self.epsilon = epsilon
|
| 86 |
+
self.threshold_ema_dead = threshold_ema_dead
|
| 87 |
+
self.kmeans_init = kmeans_init
|
| 88 |
+
self.kmeans_iters = kmeans_iters
|
| 89 |
+
self.rotation_trick = rotation_trick
|
| 90 |
+
|
| 91 |
+
if self.input_dim != self.codebook_dim:
|
| 92 |
+
self.in_project = nn.Linear(input_dim, codebook_dim)
|
| 93 |
+
self.out_project = nn.Linear(codebook_dim, input_dim)
|
| 94 |
+
else:
|
| 95 |
+
self.in_project = nn.Identity()
|
| 96 |
+
self.out_project = nn.Identity()
|
| 97 |
+
|
| 98 |
+
# Initialize codebook and EMA buffers
|
| 99 |
+
init_fn = torch.zeros if kmeans_init else lambda x, y: torch.randn(x, y)
|
| 100 |
+
self.register_buffer(
|
| 101 |
+
"codebook", init_fn(codebook_size, codebook_dim).float()
|
| 102 |
+
) # (codebook_size, D'), ensure fp32
|
| 103 |
+
self.register_buffer("inited", torch.tensor([not kmeans_init], dtype=torch.bool)) # (1)
|
| 104 |
+
self.register_buffer("cluster_size", torch.zeros(codebook_size).float()) # (codebook_size), ensure fp32
|
| 105 |
+
self.register_buffer("embed_avg", self.codebook.clone().float()) # (codebook_size, D'), ensure fp32
|
| 106 |
+
|
| 107 |
+
def ema_update(self, encodings, embed_onehot):
|
| 108 |
+
# encodings: (B*T, D'), embed_onehot: (B*T, codebook_size)
|
| 109 |
+
"""Update codebook using EMA"""
|
| 110 |
+
encodings = encodings.float() # Ensure fp32
|
| 111 |
+
embed_onehot = embed_onehot.float() # Ensure fp32
|
| 112 |
+
cluster_size_new = embed_onehot.sum(0) # (codebook_size)
|
| 113 |
+
embed_sum = encodings.t() @ embed_onehot # (D', codebook_size)
|
| 114 |
+
|
| 115 |
+
# Distributed reduction
|
| 116 |
+
if dist.is_initialized():
|
| 117 |
+
dist.all_reduce(cluster_size_new, op=dist.ReduceOp.SUM)
|
| 118 |
+
dist.all_reduce(embed_sum, op=dist.ReduceOp.SUM)
|
| 119 |
+
|
| 120 |
+
ema_inplace(self.cluster_size, cluster_size_new, self.decay) # (codebook_size)
|
| 121 |
+
ema_inplace(self.embed_avg, embed_sum.t(), self.decay) # (codebook_size, D')
|
| 122 |
+
|
| 123 |
+
# Laplace smoothing
|
| 124 |
+
cluster_size = (self.cluster_size + self.epsilon) / (
|
| 125 |
+
self.cluster_size.sum() + self.codebook_size * self.epsilon
|
| 126 |
+
) # (codebook_size)
|
| 127 |
+
cluster_size = cluster_size * self.cluster_size.sum() # (codebook_size)
|
| 128 |
+
self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1)) # (codebook_size, D')
|
| 129 |
+
|
| 130 |
+
def replace_dead_codes(self, encodings):
|
| 131 |
+
# encodings: (B*T, D')
|
| 132 |
+
"""Replace dead codes with random samples from current batch"""
|
| 133 |
+
if self.threshold_ema_dead == 0:
|
| 134 |
+
return
|
| 135 |
+
|
| 136 |
+
dead_mask = self.cluster_size < self.threshold_ema_dead # (codebook_size)
|
| 137 |
+
if dead_mask.any():
|
| 138 |
+
if dist.is_initialized() and dist.get_rank() == 0:
|
| 139 |
+
samples = sample_vectors(encodings.float(), self.codebook_size) # (codebook_size, D'), ensure fp32
|
| 140 |
+
print(f"Replace {dead_mask.sum().item()} dead codes")
|
| 141 |
+
else:
|
| 142 |
+
samples = torch.zeros_like(self.codebook).float() # Placeholder, ensure fp32
|
| 143 |
+
|
| 144 |
+
# Broadcast samples
|
| 145 |
+
if dist.is_initialized():
|
| 146 |
+
dist.broadcast(samples, src=0)
|
| 147 |
+
|
| 148 |
+
self.codebook[dead_mask] = samples[: dead_mask.sum()].to(self.codebook.dtype) # Update dead codes
|
| 149 |
+
|
| 150 |
+
def init_codebook(self, encodings):
|
| 151 |
+
# encodings: (B*T, D')
|
| 152 |
+
"""Initialize codebook with k-means and update cluster_size"""
|
| 153 |
+
if self.inited.item():
|
| 154 |
+
return
|
| 155 |
+
|
| 156 |
+
if dist.is_initialized() and dist.get_rank() == 0:
|
| 157 |
+
embed, cluster_sizes = kmeans(
|
| 158 |
+
encodings.float(), self.codebook_size, self.kmeans_iters
|
| 159 |
+
) # (codebook_size, D'), (codebook_size), ensure fp32
|
| 160 |
+
else:
|
| 161 |
+
embed = torch.zeros(self.codebook_size, self.codebook_dim, device=encodings.device).float() # ensure fp32
|
| 162 |
+
cluster_sizes = torch.zeros(self.codebook_size, device=encodings.device, dtype=torch.float32) # ensure fp32
|
| 163 |
+
|
| 164 |
+
# Broadcast results
|
| 165 |
+
if dist.is_initialized():
|
| 166 |
+
dist.broadcast(embed, src=0)
|
| 167 |
+
dist.broadcast(cluster_sizes, src=0)
|
| 168 |
+
|
| 169 |
+
self.codebook.copy_(embed) # (codebook_size, D')
|
| 170 |
+
self.embed_avg.copy_(embed.clone()) # (codebook_size, D')
|
| 171 |
+
self.cluster_size.copy_(cluster_sizes.float()) # (codebook_size)
|
| 172 |
+
self.inited.fill_(True)
|
| 173 |
+
|
| 174 |
+
def forward(self, z):
|
| 175 |
+
self = self.to(torch.float32)
|
| 176 |
+
z = z.float()
|
| 177 |
+
z_e = self.in_project(z).float()
|
| 178 |
+
|
| 179 |
+
# Rearrange for quantization
|
| 180 |
+
encodings = rearrange(z_e, "b t d -> (b t) d").float() # (B*T, D'), ensure fp32
|
| 181 |
+
|
| 182 |
+
# Initialize codebook if needed
|
| 183 |
+
if self.kmeans_init and not self.inited.item():
|
| 184 |
+
self.init_codebook(encodings)
|
| 185 |
+
|
| 186 |
+
dist = (
|
| 187 |
+
encodings.pow(2).sum(1, keepdim=True)
|
| 188 |
+
- 2 * encodings @ self.codebook.float().t()
|
| 189 |
+
+ self.codebook.float().pow(2).sum(1, keepdim=True).t()
|
| 190 |
+
)
|
| 191 |
+
indices = (-dist).max(1)[1]
|
| 192 |
+
|
| 193 |
+
# cosine_similarity = F.cosine_similarity(encodings[None], self.codebook[:, None], dim=-1)
|
| 194 |
+
# indices = cosine_similarity.max(dim=0)[1]
|
| 195 |
+
|
| 196 |
+
indices = rearrange(indices, "(b t) -> b t", b=z.size(0))
|
| 197 |
+
z_q = self.decode_code(indices).float()
|
| 198 |
+
commit_loss = F.mse_loss(z_e, z_q.detach()) * self.commitment
|
| 199 |
+
|
| 200 |
+
if self.training and torch.is_grad_enabled():
|
| 201 |
+
embed_onehot = F.one_hot(indices.view(-1), self.codebook_size).float()
|
| 202 |
+
self.ema_update(encodings, embed_onehot)
|
| 203 |
+
self.replace_dead_codes(encodings)
|
| 204 |
+
|
| 205 |
+
z_q = (z_q - z_e).detach() + z_e
|
| 206 |
+
z_q = self.out_project(z_q).float()
|
| 207 |
+
|
| 208 |
+
return (
|
| 209 |
+
z_q,
|
| 210 |
+
commit_loss,
|
| 211 |
+
torch.tensor(0.0, device=z.device, dtype=torch.float32),
|
| 212 |
+
indices,
|
| 213 |
+
z_e,
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
def decode_code(self, embed_id): # embed_id: (B, T)
|
| 217 |
+
return F.embedding(embed_id, self.codebook).float() # (B, D', T), ensure fp32
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# class VectorQuantize(nn.Module):
|
| 221 |
+
# """
|
| 222 |
+
# Implementation of VQ similar to Karpathy's repo:
|
| 223 |
+
# https://github.com/karpathy/deep-vector-quantization
|
| 224 |
+
# Additionally uses following tricks from Improved VQGAN
|
| 225 |
+
# (https://arxiv.org/pdf/2110.04627.pdf):
|
| 226 |
+
# 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
| 227 |
+
# for improved codebook usage
|
| 228 |
+
# 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
| 229 |
+
# improves training stability
|
| 230 |
+
# """
|
| 231 |
+
|
| 232 |
+
# def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
|
| 233 |
+
# super().__init__()
|
| 234 |
+
# self.codebook_size = codebook_size
|
| 235 |
+
# self.codebook_dim = codebook_dim
|
| 236 |
+
|
| 237 |
+
# self.in_proj = nn.Linear(input_dim, codebook_dim)
|
| 238 |
+
# self.out_proj = nn.Linear(codebook_dim, input_dim)
|
| 239 |
+
# self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
| 240 |
+
|
| 241 |
+
# def forward(self, z: torch.Tensor):
|
| 242 |
+
# """
|
| 243 |
+
# Args:
|
| 244 |
+
# z (torch.Tensor): shape (b, t, d)
|
| 245 |
+
|
| 246 |
+
# Returns:
|
| 247 |
+
# z_q (torch.Tensor): shape (b, t, d)
|
| 248 |
+
# commitment_loss (torch.Tensor): shape (1)
|
| 249 |
+
# codebook_loss (torch.Tensor): shape (1)
|
| 250 |
+
# indices (torch.Tensor): shape (b, t)
|
| 251 |
+
# z_e (torch.Tensor): shape (b, t, d)
|
| 252 |
+
# """
|
| 253 |
+
|
| 254 |
+
# # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
|
| 255 |
+
# z_e = self.in_proj(z)
|
| 256 |
+
# z_q, indices = self.decode_latents(z_e)
|
| 257 |
+
|
| 258 |
+
# commitment_loss = F.mse_loss(z_e, z_q.detach()) * 0.25
|
| 259 |
+
# codebook_loss = F.mse_loss(z_q, z_e.detach())
|
| 260 |
+
|
| 261 |
+
# z_q = z_e + (z_q - z_e).detach() # noop in forward pass, straight-through gradient estimator in backward pass
|
| 262 |
+
|
| 263 |
+
# z_q = self.out_proj(z_q)
|
| 264 |
+
|
| 265 |
+
# return z_q, commitment_loss, codebook_loss, indices, z_e
|
| 266 |
+
|
| 267 |
+
# def embed_code(self, embed_id):
|
| 268 |
+
# return F.embedding(embed_id, self.codebook.weight)
|
| 269 |
+
|
| 270 |
+
# def decode_code(self, embed_id):
|
| 271 |
+
# return self.embed_code(embed_id)
|
| 272 |
+
|
| 273 |
+
# def decode_latents(self, latents: torch.Tensor):
|
| 274 |
+
# codebook = self.codebook.weight
|
| 275 |
+
# encodings = rearrange(latents, "b t d -> (b t) d")
|
| 276 |
+
|
| 277 |
+
# cosine_similarity = F.cosine_similarity(encodings[None], codebook[:, None], dim=-1)
|
| 278 |
+
# indices = cosine_similarity.max(dim=0)[1]
|
| 279 |
+
# indices = rearrange(indices, "(b t) -> b t", b=latents.size(0))
|
| 280 |
+
|
| 281 |
+
# # encodings = F.normalize(encodings)
|
| 282 |
+
# # codebook = F.normalize(codebook)
|
| 283 |
+
# # dist = (
|
| 284 |
+
# # encodings.pow(2).sum(1, keepdim=True)
|
| 285 |
+
# # - 2 * encodings @ codebook.t()
|
| 286 |
+
# # + codebook.pow(2).sum(1, keepdim=True).t()
|
| 287 |
+
# # )
|
| 288 |
+
# # indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
| 289 |
+
|
| 290 |
+
# z_q = self.decode_code(indices)
|
| 291 |
+
# return z_q, indices
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class ResidualVectorQuantize(nn.Module):
|
| 295 |
+
def __init__(
|
| 296 |
+
self,
|
| 297 |
+
dim: int = 256,
|
| 298 |
+
n_codebooks: int = 4,
|
| 299 |
+
codebook_size: int = 512,
|
| 300 |
+
codebook_dim: Union[int, list] = 8,
|
| 301 |
+
quantizer_dropout: float = 0.25,
|
| 302 |
+
commitment: float = 0.25,
|
| 303 |
+
decay: float = 0.99,
|
| 304 |
+
epsilon: float = 1e-5,
|
| 305 |
+
threshold_ema_dead: int = 2,
|
| 306 |
+
kmeans_init: bool = True,
|
| 307 |
+
kmeans_iters: int = 10,
|
| 308 |
+
rotation_trick: bool = False,
|
| 309 |
+
):
|
| 310 |
+
super().__init__()
|
| 311 |
+
if isinstance(codebook_dim, int):
|
| 312 |
+
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
| 313 |
+
|
| 314 |
+
self.n_codebooks = n_codebooks
|
| 315 |
+
self.codebook_dim = codebook_dim
|
| 316 |
+
self.codebook_size = codebook_size
|
| 317 |
+
|
| 318 |
+
self.quantizers = nn.ModuleList(
|
| 319 |
+
[
|
| 320 |
+
VectorQuantize(
|
| 321 |
+
input_dim=dim,
|
| 322 |
+
codebook_size=codebook_size,
|
| 323 |
+
codebook_dim=codebook_dim[i],
|
| 324 |
+
commitment=commitment,
|
| 325 |
+
decay=decay,
|
| 326 |
+
epsilon=epsilon,
|
| 327 |
+
threshold_ema_dead=threshold_ema_dead,
|
| 328 |
+
kmeans_init=kmeans_init,
|
| 329 |
+
kmeans_iters=kmeans_iters,
|
| 330 |
+
rotation_trick=rotation_trick,
|
| 331 |
+
)
|
| 332 |
+
for i in range(n_codebooks)
|
| 333 |
+
]
|
| 334 |
+
)
|
| 335 |
+
self.quantizer_dropout = quantizer_dropout
|
| 336 |
+
|
| 337 |
+
def forward(self, z, n_quantizers: int = None):
|
| 338 |
+
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
|
| 339 |
+
the corresponding codebook vectors
|
| 340 |
+
Parameters
|
| 341 |
+
----------
|
| 342 |
+
z : Tensor[B x D x T]
|
| 343 |
+
n_quantizers : int, optional
|
| 344 |
+
No. of quantizers to use
|
| 345 |
+
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
| 346 |
+
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
| 347 |
+
when in training mode, and a random number of quantizers is used.
|
| 348 |
+
Returns
|
| 349 |
+
-------
|
| 350 |
+
dict
|
| 351 |
+
A dictionary with the following keys:
|
| 352 |
+
|
| 353 |
+
"z" : Tensor[B x D x T]
|
| 354 |
+
Quantized continuous representation of input
|
| 355 |
+
"codes" : Tensor[B x N x T]
|
| 356 |
+
Codebook indices for each codebook
|
| 357 |
+
(quantized discrete representation of input)
|
| 358 |
+
"latents" : Tensor[B x N*D x T]
|
| 359 |
+
Projected latents (continuous representation of input before quantization)
|
| 360 |
+
"vq/commitment_loss" : Tensor[1]
|
| 361 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
| 362 |
+
entries
|
| 363 |
+
"vq/codebook_loss" : Tensor[1]
|
| 364 |
+
Codebook loss to update the codebook
|
| 365 |
+
"""
|
| 366 |
+
z_q, residual = 0, z
|
| 367 |
+
commitment_loss, codebook_loss = 0, 0
|
| 368 |
+
|
| 369 |
+
codebook_indices, latents = [], []
|
| 370 |
+
|
| 371 |
+
if n_quantizers is None:
|
| 372 |
+
n_quantizers = self.n_codebooks
|
| 373 |
+
if self.training:
|
| 374 |
+
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
|
| 375 |
+
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
|
| 376 |
+
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
| 377 |
+
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
| 378 |
+
n_quantizers = n_quantizers.to(z.device)
|
| 379 |
+
|
| 380 |
+
for i, quantizer in enumerate(self.quantizers):
|
| 381 |
+
if self.training is False and i >= n_quantizers:
|
| 382 |
+
break
|
| 383 |
+
|
| 384 |
+
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual)
|
| 385 |
+
|
| 386 |
+
# Create mask to apply quantizer dropout
|
| 387 |
+
mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
| 388 |
+
z_q = z_q + z_q_i * mask[:, None, None]
|
| 389 |
+
residual = residual - z_q_i
|
| 390 |
+
|
| 391 |
+
# Sum losses
|
| 392 |
+
commitment_loss += (commitment_loss_i * mask).mean()
|
| 393 |
+
codebook_loss += (codebook_loss_i * mask).mean()
|
| 394 |
+
|
| 395 |
+
codebook_indices.append(indices_i)
|
| 396 |
+
latents.append(z_e_i)
|
| 397 |
+
|
| 398 |
+
codes = torch.stack(codebook_indices, dim=-1)
|
| 399 |
+
latents = torch.cat(latents, dim=1)
|
| 400 |
+
|
| 401 |
+
return z_q, codes, latents, commitment_loss, codebook_loss
|
| 402 |
+
|
| 403 |
+
def from_codes(self, codes: torch.Tensor):
|
| 404 |
+
"""Given the quantized codes, reconstruct the continuous representation
|
| 405 |
+
Parameters
|
| 406 |
+
----------
|
| 407 |
+
codes : Tensor[B x N x T]
|
| 408 |
+
Quantized discrete representation of input
|
| 409 |
+
Returns
|
| 410 |
+
-------
|
| 411 |
+
Tensor[B x D x T]
|
| 412 |
+
Quantized continuous representation of input
|
| 413 |
+
"""
|
| 414 |
+
z_q = 0.0
|
| 415 |
+
z_p = []
|
| 416 |
+
n_codebooks = codes.shape[-1]
|
| 417 |
+
for i in range(n_codebooks):
|
| 418 |
+
z_p_i = self.quantizers[i].decode_code(codes[..., i])
|
| 419 |
+
z_p.append(z_p_i)
|
| 420 |
+
|
| 421 |
+
z_q_i = self.quantizers[i].out_project(z_p_i)
|
| 422 |
+
z_q = z_q + z_q_i
|
| 423 |
+
return z_q, torch.cat(z_p, dim=-1), codes
|
| 424 |
+
|
| 425 |
+
def from_latents(self, latents: torch.Tensor):
|
| 426 |
+
"""Given the unquantized latents, reconstruct the
|
| 427 |
+
continuous representation after quantization.
|
| 428 |
+
|
| 429 |
+
Parameters
|
| 430 |
+
----------
|
| 431 |
+
latents : Tensor[B x N x T]
|
| 432 |
+
Continuous representation of input after projection
|
| 433 |
+
|
| 434 |
+
Returns
|
| 435 |
+
-------
|
| 436 |
+
Tensor[B x D x T]
|
| 437 |
+
Quantized representation of full-projected space
|
| 438 |
+
Tensor[B x D x T]
|
| 439 |
+
Quantized representation of latent space
|
| 440 |
+
"""
|
| 441 |
+
z_q = 0
|
| 442 |
+
z_p = []
|
| 443 |
+
codes = []
|
| 444 |
+
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
|
| 445 |
+
|
| 446 |
+
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
|
| 447 |
+
for i in range(n_codebooks):
|
| 448 |
+
j, k = dims[i], dims[i + 1]
|
| 449 |
+
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
|
| 450 |
+
z_p.append(z_p_i)
|
| 451 |
+
codes.append(codes_i)
|
| 452 |
+
|
| 453 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
| 454 |
+
z_q = z_q + z_q_i
|
| 455 |
+
|
| 456 |
+
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
class IndependentVectorQuantize(nn.Module):
|
| 460 |
+
def __init__(self, num_codebooks: int = 1, **kwargs):
|
| 461 |
+
super().__init__()
|
| 462 |
+
self.vector_quantizers = nn.ModuleList([torchVQ(**kwargs) for _ in range(num_codebooks)])
|
| 463 |
+
self.num_codebooks = num_codebooks
|
| 464 |
+
self.codebook_size = self.vector_quantizers[0].codebook_size
|
| 465 |
+
|
| 466 |
+
@property
|
| 467 |
+
def ema_update(self):
|
| 468 |
+
return [vq.ema_update for vq in self.vector_quantizers]
|
| 469 |
+
|
| 470 |
+
@property
|
| 471 |
+
def codebook(self):
|
| 472 |
+
return torch.stack([vq.codebook for vq in self.vector_quantizers], dim=0)
|
| 473 |
+
|
| 474 |
+
@codebook.setter
|
| 475 |
+
def codebook(self, codes: List[torch.Tensor]):
|
| 476 |
+
assert len(codes) == self.num_codebooks, "Number of codebooks must match"
|
| 477 |
+
if not self.separate_codebook_per_head:
|
| 478 |
+
codes = rearrange(codes, "... -> 1 ...")
|
| 479 |
+
|
| 480 |
+
for i, code in enumerate(codes):
|
| 481 |
+
self.vector_quantizers[i].codebook.copy_(code)
|
| 482 |
+
|
| 483 |
+
def get_codes_from_indices(self, indices: torch.Tensor):
|
| 484 |
+
codes = list()
|
| 485 |
+
for i in range(self.num_codebooks):
|
| 486 |
+
codes.append(self.vector_quantizers[i].get_codes_from_indices(indices[..., i : i + 1]))
|
| 487 |
+
return torch.cat(codes, dim=-2)
|
| 488 |
+
|
| 489 |
+
def get_output_from_indices(self, indices: torch.Tensor):
|
| 490 |
+
outputs = list()
|
| 491 |
+
for i in range(self.num_codebooks):
|
| 492 |
+
outputs.append(self.vector_quantizers[i].get_output_from_indices(indices[..., i : i + 1]))
|
| 493 |
+
return torch.cat(outputs, dim=-2)
|
| 494 |
+
|
| 495 |
+
def update_in_place_optimizer(self):
|
| 496 |
+
for i in range(self.num_codebooks):
|
| 497 |
+
self.vector_quantizers[i].update_in_place_optimizer()
|
| 498 |
+
|
| 499 |
+
def forward(self, x: torch.Tensor, *args, **kwargs):
|
| 500 |
+
assert x.shape[1] == self.num_codebooks
|
| 501 |
+
quantized, indices, commit_losses = list(), list(), 0
|
| 502 |
+
for i in range(self.num_codebooks):
|
| 503 |
+
quantized_i, indices_i, commit_loss_i = self.vector_quantizers[i](x[:, i : i + 1])
|
| 504 |
+
quantized.append(quantized_i)
|
| 505 |
+
indices.append(indices_i)
|
| 506 |
+
commit_losses += commit_loss_i
|
| 507 |
+
quantized = torch.cat(quantized, dim=-2)
|
| 508 |
+
indices = torch.cat(indices, dim=-1)
|
| 509 |
+
return quantized, indices, commit_losses / self.num_codebooks
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
if __name__ == "__main__":
|
| 513 |
+
vq = IndependentVectorQuantize(
|
| 514 |
+
num_codebooks=16,
|
| 515 |
+
dim=256,
|
| 516 |
+
codebook_size=2048,
|
| 517 |
+
decay=0.8, # the exponential moving average decay, lower means the dictionary will change faster
|
| 518 |
+
commitment_weight=1.0, # the weight on the commitment loss
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
x = torch.randn(1, 16, 256)
|
| 522 |
+
quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1)
|