ZibinDong commited on
Commit
b8cd73d
·
verified ·
1 Parent(s): 0e0b680

Upload folder using huggingface_hub

Browse files
README.md ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+ # Model Card for Model ID
6
+
7
+ <!-- Provide a quick summary of what the model is/does. -->
8
+
9
+
10
+
11
+ ## Model Details
12
+
13
+ ### Model Description
14
+
15
+ ActionCodec model trained only on bridgedata:
16
+ - franka_libero_20hz_0s (dummy)
17
+ - widowx_bridge_5hz_2s
18
+ - franka_droid_15hz_0s (dummy)
19
+
20
+ ### Model Sources [optional]
21
+
22
+ <!-- Provide the basic links for the model. -->
23
+
24
+ TODO
25
+
26
+ ## Uses
27
+
28
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
29
+
30
+ ### Direct Use
31
+
32
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
33
+
34
+ ```python
35
+ import numpy as np
36
+ from transformers import AutoModel
37
+ np.set_printoptions(suppress=True)
38
+ if __name__ == "__main__":
39
+ tokenizer = AutoModel.from_pretrained("ZibinDong/ActionCodec-bridge-RVQft", trust_remote_code=True)
40
+ q99 = np.array([0.9375, 0.91071427, 0.9375, 0.20357142, 0.26357144, 0.375, 1.0])
41
+ q01 = np.array([-0.87857145, -0.87589288, -0.9375, -0.15107143, -0.20678571, -0.27964285, 0.0])
42
+ # an example action from physical-intelligence/libero
43
+ action = np.array(
44
+ [
45
+ [0.3268, 0.2089, -0.3295, 0.0000, -0.0868, -0.0611, 1.0000],
46
+ [0.3696, 0.1955, -0.2866, 0.0000, -0.0793, -0.0643, 1.0000],
47
+ [0.3857, 0.1929, -0.2759, 0.0000, -0.0782, -0.0654, 1.0000],
48
+ [0.3964, 0.2089, -0.2786, 0.0000, -0.0761, -0.0654, 1.0000],
49
+ [0.3321, 0.1741, -0.3268, 0.0000, -0.0793, -0.0686, 1.0000],
50
+ [0.2250, 0.0964, -0.4232, 0.0000, -0.0932, -0.0761, 1.0000],
51
+ [0.0723, 0.0000, -0.5625, 0.0000, -0.1339, -0.0879, 1.0000],
52
+ [0.0536, 0.0000, -0.5652, 0.0000, -0.1521, -0.0921, 1.0000],
53
+ [0.0750, 0.0000, -0.5464, 0.0000, -0.1511, -0.0964, 1.0000],
54
+ [0.0723, 0.0000, -0.5411, 0.0000, -0.1414, -0.0986, 1.0000],
55
+ [0.0402, 0.0000, -0.5196, 0.0000, -0.1350, -0.1007, 1.0000],
56
+ [0.0080, 0.0000, -0.4795, 0.0000, -0.1189, -0.1018, 1.0000],
57
+ [0.0000, 0.0000, -0.4527, 0.0000, -0.0986, -0.1018, 1.0000],
58
+ [0.0000, 0.0000, -0.4313, 0.0000, -0.0846, -0.1018, 1.0000],
59
+ [-0.0455, -0.0268, -0.3509, 0.0000, -0.0568, -0.1018, 1.0000],
60
+ [-0.0964, -0.0482, -0.3321, 0.0000, -0.0439, -0.1039, 1.0000],
61
+ [-0.1768, -0.0562, -0.3402, 0.0000, -0.0300, -0.1050, 1.0000],
62
+ [-0.2438, -0.0429, -0.3187, 0.0000, -0.0193, -0.0996, 1.0000],
63
+ [-0.3054, -0.0054, -0.2893, 0.0000, -0.0139, -0.0932, 1.0000],
64
+ [-0.3509, 0.0000, -0.2598, 0.0000, -0.0054, -0.0879, 1.0000],
65
+ ],
66
+ )[None]
67
+ # normalization
68
+ normalized_action = np.copy(action)
69
+ normalized_action[..., :-1] = normalized_action[..., :-1] / np.maximum(np.abs(q99), np.abs(q01))[..., :-1]
70
+ normalized_action[..., -1] = normalized_action[..., -1] * 2.0 - 1.0 # scale to [-1, 1]
71
+ normalized_action = normalized_action.clip(-1.0, 1.0)
72
+ # tokenization
73
+ tokens = tokenizer.encode(normalized_action) # numpy (b, n, d) -> list of ints
74
+ print(tokens)
75
+ # decoding
76
+ decoded_action, padding_mask = tokenizer.decode(tokens) # list of ints -> numpy (b, n, d)
77
+ # calculate reconstruction error
78
+ mse_error = np.mean((normalized_action - decoded_action) ** 2)
79
+ l1_error = np.mean(np.abs(normalized_action - decoded_action))
80
+ print(f"Reconstruction MSE error: {mse_error:.6f}")
81
+ print(f"Reconstruction L1 error: {l1_error:.6f}")
82
+ ```
83
+
84
+ ### Downstream Use [optional]
85
+
86
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
87
+
88
+ TODO
89
+
90
+ ### Out-of-Scope Use
91
+
92
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
93
+
94
+ TODO
95
+
96
+ ## Bias, Risks, and Limitations
97
+
98
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
99
+
100
+ TODO
101
+
102
+ ### Recommendations
103
+
104
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
105
+
106
+ TODO
107
+
108
+ ## How to Get Started with the Model
109
+
110
+ Use the code below to get started with the model.
111
+
112
+ TODO
113
+
114
+ ## Training Details
115
+
116
+ ### Training Data
117
+
118
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
119
+
120
+ [More Information Needed]
121
+
122
+ ### Training Procedure
123
+
124
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
125
+
126
+ #### Preprocessing [optional]
127
+
128
+ [More Information Needed]
129
+
130
+
131
+ #### Training Hyperparameters
132
+
133
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
134
+
135
+ #### Speeds, Sizes, Times [optional]
136
+
137
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Evaluation
142
+
143
+ <!-- This section describes the evaluation protocols and provides the results. -->
144
+
145
+ ### Testing Data, Factors & Metrics
146
+
147
+ #### Testing Data
148
+
149
+ <!-- This should link to a Dataset Card if possible. -->
150
+
151
+ [More Information Needed]
152
+
153
+ #### Factors
154
+
155
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
156
+
157
+ [More Information Needed]
158
+
159
+ #### Metrics
160
+
161
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
162
+
163
+ [More Information Needed]
164
+
165
+ ### Results
166
+
167
+ [More Information Needed]
168
+
169
+ #### Summary
170
+
171
+
172
+
173
+ ## Model Examination [optional]
174
+
175
+ <!-- Relevant interpretability work for the model goes here -->
176
+
177
+ [More Information Needed]
178
+
179
+ ## Environmental Impact
180
+
181
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
182
+
183
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
184
+
185
+ - **Hardware Type:** [More Information Needed]
186
+ - **Hours used:** [More Information Needed]
187
+ - **Cloud Provider:** [More Information Needed]
188
+ - **Compute Region:** [More Information Needed]
189
+ - **Carbon Emitted:** [More Information Needed]
190
+
191
+ ## Technical Specifications [optional]
192
+
193
+ ### Model Architecture and Objective
194
+
195
+ [More Information Needed]
196
+
197
+ ### Compute Infrastructure
198
+
199
+ [More Information Needed]
200
+
201
+ #### Hardware
202
+
203
+ [More Information Needed]
204
+
205
+ #### Software
206
+
207
+ [More Information Needed]
208
+
209
+ ## Citation [optional]
210
+
211
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
212
+
213
+ **BibTeX:**
214
+
215
+ [More Information Needed]
216
+
217
+ **APA:**
218
+
219
+ [More Information Needed]
220
+
221
+ ## Glossary [optional]
222
+
223
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
224
+
225
+ [More Information Needed]
226
+
227
+ ## More Information [optional]
228
+
229
+ [More Information Needed]
230
+
231
+ ## Model Card Authors [optional]
232
+
233
+ [More Information Needed]
234
+
235
+ ## Model Card Contact
236
+
237
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ActionCodec"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_actioncodec.ActionCodecConfig",
7
+ "AutoModel": "modeling_actioncodec.ActionCodec"
8
+ },
9
+ "decoder_add_causal_mask": false,
10
+ "decoder_add_self_attn": false,
11
+ "decoder_cls_size": 1,
12
+ "decoder_dim": 384,
13
+ "decoder_n_heads": 6,
14
+ "decoder_n_layers": 12,
15
+ "decoder_pos_encoding_type": "fourier",
16
+ "dtype": "float32",
17
+ "embodiment_config": {
18
+ "a_franka_libero_20hz": {
19
+ "action_dim": 7,
20
+ "description": "20Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
21
+ "duration": 0,
22
+ "freq": 20
23
+ },
24
+ "b_widowx_bridge_5hz": {
25
+ "action_dim": 7,
26
+ "description": "5Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
27
+ "duration": 2,
28
+ "freq": 5
29
+ },
30
+ "c_franka_droid_15hz": {
31
+ "action_dim": 7,
32
+ "description": "15Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
33
+ "duration": 0,
34
+ "freq": 15
35
+ }
36
+ },
37
+ "encoder_add_causal_mask": false,
38
+ "encoder_add_self_attn": false,
39
+ "encoder_dim": 384,
40
+ "encoder_n_heads": 6,
41
+ "encoder_n_layers": 12,
42
+ "encoder_pos_encoding_type": "fourier",
43
+ "model_type": "action_codec",
44
+ "n_quantizers": 3,
45
+ "n_tokens": 48,
46
+ "transformers_version": "4.57.3",
47
+ "vq_codebook_size": 2048,
48
+ "vq_commitment_weight": 0.25,
49
+ "vq_decay": 0.99,
50
+ "vq_kmeans_init": true,
51
+ "vq_quantizer_dropout": 0.25,
52
+ "vq_threshold_ema_dead_code": 2,
53
+ "vq_type": "rvq",
54
+ "z_dim": 512
55
+ }
configuration_actioncodec.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import Any, Dict
3
+
4
+ from transformers import AutoConfig, PretrainedConfig
5
+
6
+
7
+ class ActionCodecConfig(PretrainedConfig):
8
+ model_type = "action_codec"
9
+
10
+ def __init__(
11
+ self,
12
+ embodiment_config: Dict[str, Any] = None,
13
+ n_tokens: int = 16,
14
+ n_quantizers: int = 1,
15
+ z_dim: int = 512,
16
+ vq_type: str = "vq",
17
+ vq_codebook_size: int = 2048,
18
+ vq_commitment_weight: float = 0.25,
19
+ vq_decay: float = 0.99,
20
+ vq_kmeans_init: bool = True,
21
+ vq_threshold_ema_dead_code: int = 2,
22
+ vq_quantizer_dropout: float = 0.25,
23
+ encoder_dim: int = 256,
24
+ encoder_n_layers: int = 6,
25
+ encoder_n_heads: int = 8,
26
+ encoder_add_self_attn: bool = False,
27
+ encoder_add_causal_mask: bool = False,
28
+ encoder_pos_encoding_type: str = "fourier",
29
+ decoder_dim: int = 256,
30
+ decoder_n_layers: int = 6,
31
+ decoder_n_heads: int = 8,
32
+ decoder_add_self_attn: bool = False,
33
+ decoder_add_causal_mask: bool = False,
34
+ decoder_pos_encoding_type: str = "fourier",
35
+ decoder_cls_size: int = 1,
36
+ **kwargs,
37
+ ):
38
+ super().__init__(**kwargs)
39
+
40
+ if embodiment_config is None:
41
+ default_config = {
42
+ "franka_libero_20hz": {
43
+ "action_dim": 7,
44
+ "freq": 20,
45
+ "duration": 1,
46
+ "description": "20Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
47
+ },
48
+ "widowx_bridge_5hz": {
49
+ "action_dim": 7,
50
+ "freq": 5,
51
+ "duration": 1,
52
+ "description": "5Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
53
+ },
54
+ "franka_droid_15hz": {
55
+ "action_dim": 7,
56
+ "freq": 15,
57
+ "duration": 1,
58
+ "description": "15Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
59
+ },
60
+ }
61
+ self.embodiment_config = copy.deepcopy(default_config)
62
+ else:
63
+ self.embodiment_config = copy.deepcopy(embodiment_config)
64
+
65
+ self.n_tokens = n_tokens
66
+ self.n_quantizers = n_quantizers
67
+ self.z_dim = z_dim
68
+
69
+ self.encoder_dim = encoder_dim
70
+ self.encoder_n_layers = encoder_n_layers
71
+ self.encoder_n_heads = encoder_n_heads
72
+ self.encoder_add_self_attn = encoder_add_self_attn
73
+ self.encoder_add_causal_mask = encoder_add_causal_mask
74
+ self.encoder_pos_encoding_type = encoder_pos_encoding_type
75
+
76
+ self.decoder_dim = decoder_dim
77
+ self.decoder_n_layers = decoder_n_layers
78
+ self.decoder_n_heads = decoder_n_heads
79
+ self.decoder_add_self_attn = decoder_add_self_attn
80
+ self.decoder_add_causal_mask = decoder_add_causal_mask
81
+ self.decoder_pos_encoding_type = decoder_pos_encoding_type
82
+ self.decoder_cls_size = decoder_cls_size
83
+
84
+ self.vq_type = vq_type
85
+ self.vq_codebook_size = vq_codebook_size
86
+ self.vq_commitment_weight = vq_commitment_weight
87
+ self.vq_decay = vq_decay
88
+ self.vq_kmeans_init = vq_kmeans_init
89
+ self.vq_threshold_ema_dead_code = vq_threshold_ema_dead_code
90
+ self.vq_quantizer_dropout = vq_quantizer_dropout
91
+
92
+
93
+ class ActionCodecConfigOld(PretrainedConfig):
94
+ model_type = "action_codec"
95
+
96
+ def __init__(
97
+ self,
98
+ horizon: int = 20,
99
+ action_dim: int = 7,
100
+ action_encoding: str = "independent_v2",
101
+ horizon_patch_size: int = 1,
102
+ encoder_class: str = "action_codec.modules.perceiver.PerceiverEncoder",
103
+ decoder_class: str = "action_codec.modules.perceiver.PerceiverDecoder",
104
+ vq_class: str = "vector_quantize_pytorch.VectorQuantize",
105
+ encoder_kwargs: Dict[str, Any] = None,
106
+ decoder_kwargs: Dict[str, Any] = None,
107
+ vq_kwargs: Dict[str, Any] = None,
108
+ **kwargs,
109
+ ):
110
+ super().__init__(**kwargs)
111
+ self.horizon = horizon
112
+ self.action_dim = action_dim
113
+ self.action_encoding = action_encoding
114
+ self.horizon_patch_size = horizon_patch_size
115
+ self.encoder_class = encoder_class
116
+ self.decoder_class = decoder_class
117
+ self.vq_class = vq_class
118
+ self.encoder_kwargs = (
119
+ dict(encoder_kwargs)
120
+ if encoder_kwargs is not None
121
+ else {
122
+ "dim": 384,
123
+ "in_len": horizon,
124
+ "out_len": 16,
125
+ "num_layers": 12,
126
+ "num_heads": 4,
127
+ "output_round": -1.0,
128
+ }
129
+ )
130
+ self.decoder_kwargs = (
131
+ dict(decoder_kwargs)
132
+ if decoder_kwargs is not None
133
+ else {
134
+ "dim": 384,
135
+ "in_len": 16,
136
+ "out_len": horizon,
137
+ "num_layers": 12,
138
+ "num_heads": 4,
139
+ }
140
+ )
141
+ self.vq_kwargs = (
142
+ dict(vq_kwargs)
143
+ if vq_kwargs is not None
144
+ else {
145
+ "dim": 512,
146
+ "codebook_size": 2048,
147
+ "kmeans_init": True,
148
+ "kmeans_iters": 10,
149
+ "decay": 0.99,
150
+ "commitment_weight": 0.25,
151
+ "rotation_trick": False,
152
+ "threshold_ema_dead_code": 2,
153
+ "use_cosine_sim": False,
154
+ "codebook_diversity_loss_weight": 0.0,
155
+ }
156
+ )
157
+
158
+
159
+ class BPEActionCodecConfig(PretrainedConfig):
160
+ model_type = "bpe_action_codec"
161
+
162
+ def __init__(
163
+ self,
164
+ horizon: int = 20,
165
+ action_dim: int = 7,
166
+ action_encoding: str = "independent_v2",
167
+ horizon_patch_size: int = 1,
168
+ encoder_class: str = "action_codec.modules.perceiver.PerceiverEncoder",
169
+ decoder_class: str = "action_codec.modules.perceiver.PerceiverDecoder",
170
+ vq_class: str = "vector_quantize_pytorch.VectorQuantize",
171
+ encoder_kwargs: Dict[str, Any] = None,
172
+ decoder_kwargs: Dict[str, Any] = None,
173
+ vq_kwargs: Dict[str, Any] = None,
174
+ **kwargs,
175
+ ):
176
+ super().__init__(**kwargs)
177
+ self.horizon = horizon
178
+ self.action_dim = action_dim
179
+ self.action_encoding = action_encoding
180
+ self.horizon_patch_size = horizon_patch_size
181
+ self.encoder_class = encoder_class
182
+ self.decoder_class = decoder_class
183
+ self.vq_class = vq_class
184
+ self.encoder_kwargs = (
185
+ dict(encoder_kwargs)
186
+ if encoder_kwargs is not None
187
+ else {
188
+ "dim": 384,
189
+ "in_len": horizon,
190
+ "out_len": 16,
191
+ "num_layers": 12,
192
+ "num_heads": 4,
193
+ "output_round": -1.0,
194
+ }
195
+ )
196
+ self.decoder_kwargs = (
197
+ dict(decoder_kwargs)
198
+ if decoder_kwargs is not None
199
+ else {
200
+ "dim": 384,
201
+ "in_len": 16,
202
+ "out_len": horizon,
203
+ "num_layers": 12,
204
+ "num_heads": 4,
205
+ }
206
+ )
207
+ self.vq_kwargs = (
208
+ dict(vq_kwargs)
209
+ if vq_kwargs is not None
210
+ else {
211
+ "dim": 512,
212
+ "codebook_size": 2048,
213
+ "kmeans_init": True,
214
+ "kmeans_iters": 10,
215
+ "decay": 0.99,
216
+ "commitment_weight": 0.25,
217
+ "rotation_trick": False,
218
+ "threshold_ema_dead_code": 2,
219
+ "use_cosine_sim": False,
220
+ "codebook_diversity_loss_weight": 0.0,
221
+ }
222
+ )
223
+
224
+
225
+ AutoConfig.register("action_codec", ActionCodecConfig)
226
+ AutoConfig.register("bpe_action_codec", BPEActionCodecConfig)
227
+
228
+ ActionCodecConfig.register_for_auto_class()
229
+
230
+ __all__ = ["ActionCodecConfig", "BPEActionCodecConfig"]
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67e53d45f8bfefe368b5c6ad35cf271368e85df03f1ecb7edc7674037050c73b
3
+ size 197182335
modeling_actioncodec.py ADDED
@@ -0,0 +1,743 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple, Union
2
+
3
+ import einops
4
+ import numpy as np
5
+ import torch
6
+ from transformers import AutoModel, PreTrainedModel
7
+ from vector_quantize_pytorch import VectorQuantize
8
+
9
+ from .configuration_actioncodec import ActionCodecConfig
10
+ from .modular_actioncodec import PerceiverDecoder, PerceiverEncoder
11
+ from .rvq import ResidualVectorQuantize
12
+
13
+
14
+ def trim_trailing_zeros(arr: np.ndarray) -> list[np.ndarray]:
15
+ if arr.shape[0] == 0:
16
+ return []
17
+
18
+ b, n = arr.shape
19
+
20
+ is_nonzero = arr != 0
21
+ flipped_mask = np.flip(is_nonzero, axis=1)
22
+ last_nonzero_indices = n - 1 - np.argmax(flipped_mask, axis=1)
23
+ any_nonzero_in_row = is_nonzero.any(axis=1)
24
+ new_lengths = (last_nonzero_indices + 1) * any_nonzero_in_row
25
+ result = [arr[i, :length].tolist() for i, length in enumerate(new_lengths)]
26
+
27
+ return result
28
+
29
+
30
+ class ActionCodec(PreTrainedModel):
31
+ """ActionCodec: A neural codec for encoding and decoding robot action sequences.
32
+
33
+ This model uses a Perceiver-based encoder-decoder architecture with vector quantization
34
+ to convert continuous action sequences into discrete token sequences. It supports
35
+ multiple robot embodiments with different action dimensions and control frequencies.
36
+
37
+ The model supports two vector quantization types:
38
+ - VQ (Vector Quantization): Single quantizer
39
+ - RVQ (Residual Vector Quantization): Multiple quantizers for hierarchical encoding
40
+
41
+ Key features:
42
+ - Multi-embodiment support: Handle different robots with varying action dimensions
43
+ - Dynamic expansion: Add new robot configurations without retraining
44
+ - Flexible input/output: Support numpy arrays and torch tensors
45
+ """
46
+
47
+ config_class = ActionCodecConfig
48
+
49
+ def __init__(self, config: ActionCodecConfig):
50
+ """Initialize the ActionCodec model.
51
+
52
+ Args:
53
+ config (ActionCodecConfig): Model configuration containing hyperparameters
54
+ and embodiment configurations.
55
+
56
+ Raises:
57
+ ValueError: If configuration parameters are invalid.
58
+ NotImplementedError: If the specified VQ type is not supported.
59
+ """
60
+ super().__init__(config)
61
+
62
+ # Validate configuration
63
+ if config.n_tokens % config.n_quantizers != 0:
64
+ raise ValueError(f"n_tokens ({config.n_tokens}) must be divisible by n_quantizers ({config.n_quantizers})")
65
+
66
+ if config.n_quantizers < 1:
67
+ raise ValueError(f"n_quantizers must be at least 1, got {config.n_quantizers}")
68
+
69
+ if config.vq_codebook_size < 1:
70
+ raise ValueError(f"vq_codebook_size must be at least 1, got {config.vq_codebook_size}")
71
+
72
+ if config.z_dim < 1:
73
+ raise ValueError(f"z_dim must be at least 1, got {config.z_dim}")
74
+
75
+ if not isinstance(config.embodiment_config, dict) or len(config.embodiment_config) == 0:
76
+ raise ValueError(
77
+ "embodiment_config must be a non-empty dictionary mapping embodiment names to configurations"
78
+ )
79
+
80
+ self.default_embodiment_id = 0
81
+
82
+ # Initialize encoder and decoder
83
+ self.encoder = PerceiverEncoder(config)
84
+ self.decoder = PerceiverDecoder(config)
85
+
86
+ # Initialize vector quantizer based on type
87
+ if config.vq_type == "vq":
88
+ if config.n_quantizers != 1:
89
+ raise ValueError(
90
+ f"VQ type requires n_quantizers=1, got {config.n_quantizers}. Use RVQ type for multiple quantizers."
91
+ )
92
+ self.vq = VectorQuantize(
93
+ dim=config.z_dim,
94
+ codebook_size=config.vq_codebook_size,
95
+ commitment_weight=config.vq_commitment_weight,
96
+ decay=config.vq_decay,
97
+ kmeans_init=config.vq_kmeans_init,
98
+ threshold_ema_dead_code=config.vq_threshold_ema_dead_code,
99
+ rotation_trick=False,
100
+ straight_through=True,
101
+ )
102
+ elif config.vq_type == "rvq":
103
+ if config.n_quantizers < 2:
104
+ raise ValueError(
105
+ f"RVQ type requires n_quantizers >= 2, got {config.n_quantizers}. Use VQ type for single quantizer."
106
+ )
107
+ self.vq = ResidualVectorQuantize(
108
+ dim=config.z_dim,
109
+ n_codebooks=config.n_quantizers,
110
+ codebook_size=config.vq_codebook_size,
111
+ codebook_dim=config.z_dim,
112
+ quantizer_dropout=config.vq_quantizer_dropout,
113
+ commitment=config.vq_commitment_weight,
114
+ )
115
+ else:
116
+ raise NotImplementedError(f"VQ type '{config.vq_type}' not implemented. Supported types: 'vq', 'rvq'")
117
+
118
+ # Store quantization-related attributes
119
+ self.vocab_size = config.vq_codebook_size
120
+ self.num_quantizers = config.n_quantizers
121
+ self.n_tokens_per_quantizer = config.n_tokens // config.n_quantizers
122
+
123
+ def expand_embodiment(self, embodiment_config: dict):
124
+ """Dynamically expand the model to support new robot embodiments.
125
+
126
+ This method allows adding new robot configurations to the codec without retraining
127
+ the entire model. It updates the encoder and decoder to handle the new action dimensions
128
+ and frequencies while preserving existing functionality for previously configured robots.
129
+
130
+ Args:
131
+ embodiment_config (dict): Dictionary mapping embodiment names to their configurations.
132
+ Each configuration should be a dict with keys:
133
+ - "action_dim" (int): Action dimensionality for this embodiment.
134
+ - "freq" (float): Control frequency in Hz.
135
+ - "duration" (float): Default action sequence duration in seconds.
136
+ - "description" (str, optional): Human-readable description.
137
+
138
+ Example:
139
+ {
140
+ "robot_B": {
141
+ "action_dim": 10,
142
+ "freq": 20,
143
+ "duration": 1.0,
144
+ "description": "10-dim robot at 20Hz"
145
+ }
146
+ }
147
+
148
+ Returns:
149
+ ActionCodec: Returns self for method chaining.
150
+
151
+ Note:
152
+ - New embodiment keys must not already exist in the current configuration.
153
+ - The model will automatically update max_action_dim if the new embodiment
154
+ has a larger action dimension.
155
+ - Existing embodiments will continue to work with their original configurations.
156
+ """
157
+ if not isinstance(embodiment_config, dict):
158
+ raise TypeError(f"embodiment_config must be a dict, got {type(embodiment_config)}")
159
+ if len(embodiment_config) == 0:
160
+ raise ValueError("embodiment_config cannot be empty")
161
+
162
+ # Check for duplicate keys
163
+ overlapping_keys = set(embodiment_config.keys()) & set(self.config.embodiment_config.keys())
164
+ if overlapping_keys:
165
+ raise ValueError(f"The following embodiment keys already exist and cannot be redefined: {overlapping_keys}")
166
+
167
+ self.encoder.expand_embodiment(embodiment_config)
168
+ self.decoder.expand_embodiment(embodiment_config)
169
+ self.config.embodiment_config.update(embodiment_config)
170
+ return self
171
+
172
+ def _encode(
173
+ self,
174
+ x: torch.Tensor,
175
+ embodiment_ids: torch.Tensor | int | None = None,
176
+ padding_mask: torch.Tensor | None = None,
177
+ ) -> torch.Tensor:
178
+ """Encode action sequences into latent representations.
179
+
180
+ Args:
181
+ x (torch.Tensor): Action sequences to encode. Shape: (b, seq_len, max_action_dim).
182
+ Assumes that the action dimension is zero-padded to the max action dimension.
183
+ `seq_len` is supposed to be `int(duration * freq)` for each embodiment and padded to the max sequence length.
184
+ embodiment_ids (torch.Tensor | int): Embodiment IDs. Shape: (b,).
185
+ If int, the same embodiment ID is repeated for all sequences in the batch.
186
+ It specifies the embodiment to encode.
187
+ padding_mask (Optional[torch.Tensor], optional): Padding mask, where `False` values indicate padding. Shape: (b, seq_len). Defaults to None.
188
+ It is used to mask the padding tokens on `seq_len` dimension.
189
+
190
+ Returns:
191
+ torch.Tensor: Encoded latent representations. Shape: (b, n_tokens_per_quantizer, z_dim).
192
+ """
193
+ embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
194
+ z_e = self.encoder(x, embodiment_ids, padding_mask)
195
+ return z_e
196
+
197
+ def _quantize(
198
+ self, z_e: torch.Tensor, return_perplexity: bool = True
199
+ ) -> Tuple[torch.Tensor, torch.Tensor, Union[float, List[float]], torch.Tensor]:
200
+ """Quantize encoded representations using vector quantization.
201
+
202
+ Args:
203
+ z_e (torch.Tensor): Encoded latent representations to quantize.
204
+ Shape: (b, n_tokens_per_quantizer, z_dim).
205
+ return_perplexity (bool, optional): Whether to compute and return perplexity.
206
+ Defaults to True.
207
+
208
+ Returns:
209
+ Tuple[torch.Tensor, torch.Tensor, Union[float, List[float]], torch.Tensor]:
210
+ A tuple containing:
211
+ - z_q (torch.Tensor): Quantized representations.
212
+ Shape: (b, n_tokens_per_quantizer, z_dim).
213
+ - indices (torch.Tensor): Quantization indices.
214
+ Shape: (b, n_tokens_per_quantizer) for VQ or (b, n_tokens_per_quantizer, n_quantizers) for RVQ.
215
+ - perplexity (Union[float, List[float]]): Codebook perplexity.
216
+ Float for single quantizer, List[float] for multiple quantizers.
217
+ - commit_loss (torch.Tensor): Commitment loss scalar tensor.
218
+ """
219
+ if isinstance(self.vq, ResidualVectorQuantize):
220
+ z_q, indices, _, commitment_loss, codebook_loss = self.vq(z_e)
221
+ commit_loss = commitment_loss.mean() + codebook_loss.mean()
222
+ elif isinstance(self.vq, VectorQuantize):
223
+ z_q, indices, commit_loss = self.vq(z_e)
224
+ else:
225
+ raise NotImplementedError(f"VQ type {type(self.vq)} not implemented")
226
+
227
+ if return_perplexity:
228
+ if len(indices.size()) < 3:
229
+ indices = indices.unsqueeze(-1)
230
+ perplexity = []
231
+ for k in range(indices.size(-1)):
232
+ this_indices = indices[:, :, k]
233
+ indices_count = torch.bincount(this_indices.view(-1), minlength=self.vq.codebook_size)
234
+ if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
235
+ torch.distributed.all_reduce(indices_count)
236
+ this_avg_probs = indices_count.float() / indices_count.sum()
237
+ perplexity.append(((-(this_avg_probs * torch.log(this_avg_probs + 1e-10)).sum()).exp().item()))
238
+ else:
239
+ perplexity = 0
240
+
241
+ return z_q, indices, perplexity, commit_loss
242
+
243
+ def _dequantize(self, indices: torch.Tensor) -> torch.Tensor:
244
+ """Dequantize token indices back to continuous latent representations.
245
+
246
+ Args:
247
+ indices (torch.Tensor): Quantization indices. Shape depends on quantizer type:
248
+ - For VQ: (b, n_tokens) or (b, n_tokens, 1)
249
+ - For RVQ: (b, n_tokens_per_quantizer, n_quantizers)
250
+
251
+ Returns:
252
+ torch.Tensor: Dequantized latent representations.
253
+ Shape: (b, n_tokens_per_quantizer, z_dim)
254
+ """
255
+ if self.num_quantizers == 1:
256
+ if len(indices.size()) == 3:
257
+ indices = indices.squeeze(-1)
258
+ if isinstance(self.vq, ResidualVectorQuantize):
259
+ z_q = self.vq.from_codes(indices)[0]
260
+ elif isinstance(self.vq, VectorQuantize):
261
+ z_q = self.vq.get_output_from_indices(indices)
262
+ else:
263
+ raise NotImplementedError(f"VQ type {type(self.vq)} not implemented in _dequantize")
264
+ return z_q
265
+
266
+ def _decode(
267
+ self, z_q: torch.Tensor, embodiment_ids: torch.Tensor | int | None = None, durations: torch.Tensor | None = None
268
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
269
+ """Decode quantized latent representations into action sequences.
270
+
271
+ Args:
272
+ z_q (torch.Tensor): Quantized latent representations.
273
+ Shape: (b, n_tokens_per_quantizer, z_dim).
274
+ embodiment_ids (Union[torch.Tensor, int, None], optional): Embodiment IDs.
275
+ Shape: (b,) if tensor. If int, the same embodiment ID is used for all
276
+ sequences. Defaults to None, which uses `self.default_embodiment_id`.
277
+ durations (torch.Tensor | None, optional): Duration of each action sequence in seconds.
278
+ Shape: (b,). If None, uses default duration from embodiment_config.
279
+ Defaults to None.
280
+
281
+ Returns:
282
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
283
+ - x_recon (torch.Tensor): Reconstructed action sequences.
284
+ Shape: (b, seq_len, max_action_dim).
285
+ - padding_mask (torch.Tensor): Padding mask indicating valid timesteps.
286
+ Shape: (b, seq_len), where True indicates valid timesteps.
287
+ """
288
+ embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
289
+ x_recon, padding_mask = self.decoder(z_q, embodiment_ids, durations)
290
+ return x_recon, padding_mask
291
+
292
+ @torch.no_grad()
293
+ def encode(
294
+ self,
295
+ x: Union[np.ndarray, torch.Tensor],
296
+ embodiment_ids: Union[List[int], int, None] = None,
297
+ padding_mask: Union[List[bool], np.ndarray, torch.Tensor, None] = None,
298
+ **kwargs,
299
+ ) -> List[List[int]]:
300
+ """Encode action sequences into latent representations (token indices).
301
+
302
+ This method converts action sequences into discrete token indices using the encoder
303
+ and vector quantizer. The input can be either a numpy array or torch tensor.
304
+
305
+ Args:
306
+ x (Union[np.ndarray, torch.Tensor]): Action sequences to encode.
307
+ Shape: (b, seq_len, max_action_dim).
308
+ Assumes that the action dimension is zero-padded to the max action dimension.
309
+ `seq_len` is supposed to be `int(duration * freq)` for each embodiment and
310
+ padded to the max sequence length.
311
+ embodiment_ids (Union[List[int], int, None], optional): Embodiment IDs.
312
+ Shape: (b,) if list. If int, the same embodiment ID is repeated for all
313
+ sequences in the batch. It specifies the embodiment to encode.
314
+ Defaults to None, which uses `self.default_embodiment_id`.
315
+ padding_mask (Union[List[bool], np.ndarray, torch.Tensor, None], optional):
316
+ Padding mask, where `False` values indicate padding. Shape: (b, seq_len).
317
+ Defaults to None. It is used to mask the padding tokens on `seq_len` dimension.
318
+ **kwargs: Additional keyword arguments (currently unused, reserved for future use).
319
+
320
+ Returns:
321
+ List[List[int]]: List of token sequences. Shape: (b, n_tokens), where n_tokens
322
+ is determined by the model configuration (typically `config.n_tokens`).
323
+
324
+ Raises:
325
+ ValueError: If input shapes are invalid or incompatible with the model configuration.
326
+ TypeError: If input types are not supported.
327
+
328
+ Examples:
329
+ >>> import numpy as np
330
+ >>> # Using numpy array
331
+ >>> x = np.random.randn(2, 10, 7).astype(np.float32)
332
+ >>> tokens = model.encode(x, embodiment_ids=[0, 0])
333
+ >>> # Using torch tensor
334
+ >>> x_tensor = torch.randn(2, 10, 7)
335
+ >>> tokens = model.encode(x_tensor, embodiment_ids=[0, 0])
336
+ """
337
+ self.eval()
338
+
339
+ # Validate and convert input x
340
+ if isinstance(x, np.ndarray):
341
+ if x.ndim != 3:
342
+ raise ValueError(
343
+ f"Expected 3D input array (batch, seq_len, action_dim), got {x.ndim}D array with shape {x.shape}"
344
+ )
345
+ x_tensor = torch.tensor(x, dtype=self.dtype, device=self.device)
346
+ elif isinstance(x, torch.Tensor):
347
+ if x.ndim != 3:
348
+ raise ValueError(
349
+ f"Expected 3D tensor (batch, seq_len, action_dim), got {x.ndim}D tensor with shape {x.shape}"
350
+ )
351
+ x_tensor = x.to(dtype=self.dtype, device=self.device)
352
+ else:
353
+ raise TypeError(f"Input x must be numpy.ndarray or torch.Tensor, got {type(x)}")
354
+
355
+ # Validate batch size
356
+ batch_size = x_tensor.shape[0]
357
+ if batch_size == 0:
358
+ raise ValueError("Batch size must be at least 1")
359
+
360
+ # Handle embodiment_ids
361
+ embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
362
+ if isinstance(embodiment_ids, int):
363
+ if not 0 <= embodiment_ids < len(self.config.embodiment_config):
364
+ raise ValueError(
365
+ f"embodiment_id {embodiment_ids} is out of range [0, {len(self.config.embodiment_config)}). "
366
+ f"Available embodiment IDs: {list(range(len(self.config.embodiment_config)))}"
367
+ )
368
+ embodiment_ids_tensor = torch.tensor([embodiment_ids] * batch_size, dtype=torch.long, device=self.device)
369
+ elif isinstance(embodiment_ids, list):
370
+ if len(embodiment_ids) != batch_size:
371
+ raise ValueError(
372
+ f"Length of embodiment_ids ({len(embodiment_ids)}) must match batch size ({batch_size})"
373
+ )
374
+ for eid in embodiment_ids:
375
+ if not isinstance(eid, int) or not 0 <= eid < len(self.config.embodiment_config):
376
+ raise ValueError(
377
+ f"Invalid embodiment_id {eid}. Must be an integer in range [0, {len(self.config.embodiment_config)})"
378
+ )
379
+ embodiment_ids_tensor = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device)
380
+ else:
381
+ raise TypeError(f"embodiment_ids must be int, List[int], or None, got {type(embodiment_ids)}")
382
+
383
+ # Handle padding_mask
384
+ padding_mask_tensor = None
385
+ if padding_mask is not None:
386
+ if isinstance(padding_mask, (list, np.ndarray)):
387
+ padding_mask_tensor = torch.tensor(padding_mask, dtype=torch.bool, device=self.device)
388
+ elif isinstance(padding_mask, torch.Tensor):
389
+ padding_mask_tensor = padding_mask.to(dtype=torch.bool, device=self.device)
390
+ else:
391
+ raise TypeError(
392
+ f"padding_mask must be List[bool], np.ndarray, torch.Tensor, or None, got {type(padding_mask)}"
393
+ )
394
+ if padding_mask_tensor.shape != (batch_size, x_tensor.shape[1]):
395
+ raise ValueError(
396
+ f"padding_mask shape {padding_mask_tensor.shape} does not match expected shape "
397
+ f"({batch_size}, {x_tensor.shape[1]})"
398
+ )
399
+
400
+ with torch.no_grad():
401
+ z_e = self._encode(x_tensor, embodiment_ids_tensor, padding_mask_tensor)
402
+ _, indices, _, _ = self._quantize(z_e, return_perplexity=False)
403
+
404
+ # Reshape indices: for RVQ, indices shape is (b, n, s), for VQ it's (b, n)
405
+ if len(indices.size()) > 2:
406
+ codes_list = einops.rearrange(indices, "b n s -> b (s n)").cpu()
407
+ else:
408
+ codes_list = indices.cpu()
409
+
410
+ codes_list = codes_list.tolist()
411
+ return codes_list
412
+
413
+ @torch.no_grad()
414
+ def decode(
415
+ self,
416
+ tokens: Union[List[List[int]], np.ndarray, torch.Tensor],
417
+ embodiment_ids: Union[List[int], int, None] = None,
418
+ durations: Union[List[float], np.ndarray, torch.Tensor, None] = None,
419
+ **kwargs,
420
+ ) -> Tuple[np.ndarray, np.ndarray]:
421
+ """Decode token sequences into action sequences.
422
+
423
+ This method reconstructs action sequences from discrete token indices using the
424
+ vector quantizer and decoder. The input tokens can be a list of lists, numpy array,
425
+ or torch tensor.
426
+
427
+ Args:
428
+ tokens (Union[List[List[int]], np.ndarray, torch.Tensor]): Token sequences to decode.
429
+ Shape: (b, n_tokens), where n_tokens must be divisible by `n_tokens_per_quantizer`.
430
+ For RVQ, tokens are interleaved: [q0_t0, q1_t0, ..., qN_t0, q0_t1, ...].
431
+ embodiment_ids (Union[List[int], int, None], optional): Embodiment IDs.
432
+ Shape: (b,) if list. If int, the same embodiment ID is repeated for all
433
+ sequences in the batch. It specifies the embodiment to decode.
434
+ Defaults to None, which uses `self.default_embodiment_id`.
435
+ durations (Union[List[float], np.ndarray, torch.Tensor, None], optional):
436
+ Duration of each action sequence in seconds. Shape: (b,).
437
+ If None, the duration is inferred from the default values in `embodiment_config`.
438
+ Defaults to None.
439
+ **kwargs: Additional keyword arguments (currently unused, reserved for future use).
440
+
441
+ Returns:
442
+ Tuple[np.ndarray, np.ndarray]: A tuple containing:
443
+ - reconstructed_actions: Reconstructed action sequences.
444
+ Shape: (b, seq_len, max_action_dim).
445
+ - padding_mask: Padding mask indicating valid timesteps.
446
+ Shape: (b, seq_len), where True indicates valid timesteps.
447
+
448
+ Raises:
449
+ ValueError: If token sequence length is invalid or incompatible with the model configuration.
450
+ TypeError: If input types are not supported.
451
+
452
+ Examples:
453
+ >>> # Using list of lists
454
+ >>> tokens = [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16]]
455
+ >>> actions, mask = model.decode(tokens, embodiment_ids=[0, 0])
456
+ >>> # Using numpy array
457
+ >>> tokens_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
458
+ >>> actions, mask = model.decode(tokens_np, embodiment_ids=[0, 0])
459
+ >>> # Using torch tensor
460
+ >>> tokens_tensor = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
461
+ >>> actions, mask = model.decode(tokens_tensor, embodiment_ids=[0, 0])
462
+ """
463
+ self.eval()
464
+
465
+ # Validate and convert input tokens
466
+ if isinstance(tokens, list):
467
+ if not all(isinstance(seq, list) for seq in tokens):
468
+ raise TypeError("If tokens is a list, all elements must be lists")
469
+ if len(tokens) == 0:
470
+ raise ValueError("Tokens list cannot be empty")
471
+ if not all(isinstance(val, (int, np.integer)) for seq in tokens for val in seq):
472
+ raise TypeError("All token values must be integers")
473
+ tokens_tensor = torch.tensor(tokens, dtype=torch.long, device=self.device)
474
+ elif isinstance(tokens, np.ndarray):
475
+ if tokens.ndim != 2:
476
+ raise ValueError(
477
+ f"Expected 2D array (batch, n_tokens), got {tokens.ndim}D array with shape {tokens.shape}"
478
+ )
479
+ if not np.issubdtype(tokens.dtype, np.integer):
480
+ raise TypeError(f"Tokens array must have integer dtype, got {tokens.dtype}")
481
+ tokens_tensor = torch.tensor(tokens, dtype=torch.long, device=self.device)
482
+ elif isinstance(tokens, torch.Tensor):
483
+ if tokens.ndim != 2:
484
+ raise ValueError(
485
+ f"Expected 2D tensor (batch, n_tokens), got {tokens.ndim}D tensor with shape {tokens.shape}"
486
+ )
487
+ if not tokens.dtype.is_integer:
488
+ raise TypeError(f"Tokens tensor must have integer dtype, got {tokens.dtype}")
489
+ tokens_tensor = tokens.to(dtype=torch.long, device=self.device)
490
+ else:
491
+ raise TypeError(f"tokens must be List[List[int]], np.ndarray, or torch.Tensor, got {type(tokens)}")
492
+
493
+ batch_size, n_tokens = tokens_tensor.shape
494
+ if batch_size == 0:
495
+ raise ValueError("Batch size must be at least 1")
496
+ if n_tokens == 0:
497
+ raise ValueError("Token sequence length must be at least 1")
498
+
499
+ # Validate token sequence length
500
+ if n_tokens % self.n_tokens_per_quantizer != 0:
501
+ raise ValueError(
502
+ f"Token sequence length ({n_tokens}) must be divisible by tokens per quantizer "
503
+ f"({self.n_tokens_per_quantizer}). Total tokens: {n_tokens}, "
504
+ f"Expected multiple of: {self.n_tokens_per_quantizer}. "
505
+ f"Number of quantizers: {self.num_quantizers}, Total tokens per sequence: {self.config.n_tokens}"
506
+ )
507
+
508
+ # Validate token values are within codebook range
509
+ if tokens_tensor.min() < 0 or tokens_tensor.max() >= self.vocab_size:
510
+ raise ValueError(
511
+ f"Token values must be in range [0, {self.vocab_size}), "
512
+ f"got range [{tokens_tensor.min().item()}, {tokens_tensor.max().item()}]"
513
+ )
514
+
515
+ # Handle embodiment_ids
516
+ embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
517
+ if isinstance(embodiment_ids, int):
518
+ if not 0 <= embodiment_ids < len(self.config.embodiment_config):
519
+ raise ValueError(
520
+ f"embodiment_id {embodiment_ids} is out of range [0, {len(self.config.embodiment_config)}). "
521
+ f"Available embodiment IDs: {list(range(len(self.config.embodiment_config)))}"
522
+ )
523
+ embodiment_ids_tensor = torch.tensor([embodiment_ids] * batch_size, dtype=torch.long, device=self.device)
524
+ elif isinstance(embodiment_ids, list):
525
+ if len(embodiment_ids) != batch_size:
526
+ raise ValueError(
527
+ f"Length of embodiment_ids ({len(embodiment_ids)}) must match batch size ({batch_size})"
528
+ )
529
+ for eid in embodiment_ids:
530
+ if not isinstance(eid, int) or not 0 <= eid < len(self.config.embodiment_config):
531
+ raise ValueError(
532
+ f"Invalid embodiment_id {eid}. Must be an integer in range [0, {len(self.config.embodiment_config)})"
533
+ )
534
+ embodiment_ids_tensor = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device)
535
+ else:
536
+ raise TypeError(f"embodiment_ids must be int, List[int], or None, got {type(embodiment_ids)}")
537
+
538
+ # Handle durations
539
+ durations_tensor = None
540
+ if durations is not None:
541
+ if isinstance(durations, (list, np.ndarray)):
542
+ durations_tensor = torch.tensor(durations, dtype=torch.float32, device=self.device)
543
+ elif isinstance(durations, torch.Tensor):
544
+ durations_tensor = durations.to(dtype=torch.float32, device=self.device)
545
+ else:
546
+ raise TypeError(
547
+ f"durations must be List[float], np.ndarray, torch.Tensor, or None, got {type(durations)}"
548
+ )
549
+ if durations_tensor.ndim != 1:
550
+ raise ValueError(
551
+ f"durations must be 1D, got {durations_tensor.ndim}D with shape {durations_tensor.shape}"
552
+ )
553
+ if len(durations_tensor) != batch_size:
554
+ raise ValueError(f"Length of durations ({len(durations_tensor)}) must match batch size ({batch_size})")
555
+ if (durations_tensor <= 0).any():
556
+ raise ValueError("All durations must be positive")
557
+
558
+ # Reshape tokens for dequantization: (b, n_tokens) -> (b, n_tokens_per_quantizer, n_quantizers)
559
+ indices = einops.rearrange(tokens_tensor, "b (n m) -> b m n", m=self.n_tokens_per_quantizer)
560
+
561
+ with torch.no_grad():
562
+ z_q = self._dequantize(indices)
563
+ x_recon, padding_mask = self._decode(z_q, embodiment_ids_tensor, durations_tensor)
564
+
565
+ return x_recon.float().cpu().numpy(), padding_mask.float().cpu().numpy()
566
+
567
+ def forward(
568
+ self,
569
+ x: Union[torch.Tensor, np.ndarray],
570
+ embodiment_ids: Union[torch.Tensor, int, List[int], None] = None,
571
+ padding_mask: Union[torch.Tensor, List[bool], np.ndarray, None] = None,
572
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
573
+ """Forward pass through the full ActionCodec pipeline.
574
+
575
+ This method performs encoding, quantization, and decoding in a single forward pass.
576
+ It is primarily used during training to compute reconstruction loss and commitment loss.
577
+ Both numpy arrays and torch tensors are supported as input.
578
+
579
+ Args:
580
+ x (Union[torch.Tensor, np.ndarray]): Action sequences to process.
581
+ Shape: (b, seq_len, max_action_dim).
582
+ embodiment_ids (Union[torch.Tensor, int, List[int], None], optional):
583
+ Embodiment IDs. Shape: (b,) if tensor or list. If int, same ID for all sequences.
584
+ Defaults to None, which uses `self.default_embodiment_id`.
585
+ padding_mask (Union[torch.Tensor, List[bool], np.ndarray, None], optional):
586
+ Padding mask. Shape: (b, seq_len). Defaults to None.
587
+
588
+ Returns:
589
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
590
+ - x_recon (torch.Tensor): Reconstructed action sequences.
591
+ Shape: (b, seq_len, max_action_dim).
592
+ - recon_mask (torch.Tensor): Reconstruction mask indicating valid timesteps.
593
+ Shape: (b, seq_len), where True indicates valid timesteps.
594
+
595
+ Note:
596
+ - For inference use cases, prefer using `encode()` and `decode()` methods separately.
597
+ - If you need token indices, use the `encode()` method instead.
598
+ """
599
+ # Convert numpy array to torch tensor if needed
600
+ if isinstance(x, np.ndarray):
601
+ x = torch.tensor(x, dtype=self.dtype, device=self.device)
602
+
603
+ # Handle embodiment_ids conversion
604
+ if isinstance(embodiment_ids, list):
605
+ embodiment_ids = torch.tensor(embodiment_ids, device=x.device, dtype=torch.long)
606
+ elif isinstance(embodiment_ids, int):
607
+ # Keep as int, will be handled by _encode
608
+ pass
609
+
610
+ # Handle padding_mask conversion
611
+ if isinstance(padding_mask, (list, np.ndarray)):
612
+ padding_mask = torch.tensor(padding_mask, device=x.device, dtype=torch.bool)
613
+
614
+ # Full forward pass: encode -> quantize -> decode
615
+ z_e = self._encode(x, embodiment_ids, padding_mask)
616
+ z_q, indices, perplexity, commit_loss = self._quantize(z_e, return_perplexity=True)
617
+ x_recon, recon_mask = self._decode(z_q, embodiment_ids)
618
+
619
+ return x_recon, recon_mask
620
+
621
+
622
+ AutoModel.register(ActionCodecConfig, ActionCodec)
623
+
624
+ __all__ = ["ActionCodec"]
625
+
626
+
627
+ if __name__ == "__main__":
628
+ print("=== ActionCodec Comprehensive Test ===\n")
629
+
630
+ # 1. Configuration Setup (RVQ enabled with n_quantizers=4)
631
+ initial_config = {
632
+ "robot_A": {"action_dim": 7, "freq": 10, "duration": 1, "description": "Robot A"},
633
+ }
634
+
635
+ # We set n_quantizers=4 to test Residual VQ logic
636
+ config = ActionCodecConfig(
637
+ embodiment_config=initial_config,
638
+ n_tokens=16, # Total tokens per sequence (latent_len * n_quantizers)
639
+ n_quantizers=4, # RVQ depth
640
+ vq_type="rvq",
641
+ vq_codebook_size=256,
642
+ encoder_dim=128,
643
+ decoder_dim=128,
644
+ )
645
+
646
+ # Expected latent sequence length = n_tokens / n_quantizers = 16 / 4 = 4
647
+ latent_seq_len = int(config.n_tokens // config.n_quantizers)
648
+ print(f"Config: {config.n_quantizers} quantizers, {latent_seq_len} latent vectors per sequence.")
649
+
650
+ codec = ActionCodec(config)
651
+ codec.eval()
652
+
653
+ # 2. Basic Encode/Decode Test
654
+ print("\n--- Test 1: Basic Encode/Decode ---")
655
+ batch_size = 2
656
+ seq_len_A = 10 # 10Hz * 1s
657
+
658
+ # Create random action data for Robot A (ID 0)
659
+ x = np.random.randn(batch_size, seq_len_A, 7).astype(np.float32)
660
+ # Masking: Second item in batch is half padding
661
+ padding_mask = np.ones((batch_size, seq_len_A), dtype=bool)
662
+ padding_mask[1, 5:] = False
663
+
664
+ embodiment_ids = [0, 0]
665
+
666
+ # Encode
667
+ codes = codec.encode(x, embodiment_ids, padding_mask)
668
+ print(f"Encoded codes shape (list length): {len(codes)} x {len(codes[0])}")
669
+
670
+ # Validate code length
671
+ assert len(codes[0]) == config.n_tokens, f"Expected {config.n_tokens} tokens, got {len(codes[0])}"
672
+
673
+ # Decode
674
+ x_recon, recon_mask = codec.decode(codes, embodiment_ids)
675
+ print(f"Reconstructed shape: {x_recon.shape}")
676
+ print(f"Recon mask shape: {recon_mask.shape}")
677
+
678
+ assert x_recon.shape == (batch_size, seq_len_A, 7) # Should imply zero-padding to max dim 7
679
+
680
+ # 3. Expansion Test
681
+ print("\n--- Test 2: Dynamic Expansion ---")
682
+ new_robot_config = {"robot_B": {"action_dim": 10, "freq": 20, "duration": 1, "description": "Robot B (Larger)"}}
683
+
684
+ print("Expanding codec to include Robot B (10 dims, 20Hz)...")
685
+ codec.expand_embodiment(new_robot_config)
686
+
687
+ assert codec.encoder.max_action_dim == 10
688
+ assert codec.decoder.max_action_dim == 10
689
+ print("✅ Expansion successful.")
690
+
691
+ # 4. Mixed Batch Test (Old + New Robot)
692
+ print("\n--- Test 3: Mixed Batch Inference ---")
693
+
694
+ # Batch: [Robot A, Robot B]
695
+ # Robot A: 10Hz, 1s -> 10 steps. Dims 7.
696
+ # Robot B: 20Hz, 1s -> 20 steps. Dims 10.
697
+ # Batch Max Steps: 20. Batch Max Dims: 10.
698
+
699
+ batch_x_mixed = np.zeros((2, 20, 10), dtype=np.float32)
700
+
701
+ # Fill Robot A data (index 0)
702
+ data_A = np.random.randn(10, 7)
703
+ batch_x_mixed[0, :10, :7] = data_A
704
+
705
+ # Fill Robot B data (index 1)
706
+ data_B = np.random.randn(20, 10)
707
+ batch_x_mixed[1, :20, :10] = data_B
708
+
709
+ # Embodiment IDs: 0 for A, 1 for B
710
+ # Note: expand_embodiment appends. Original was 0, new is 1.
711
+ mixed_ids = [0, 1]
712
+
713
+ # Encode Mask
714
+ mixed_mask = np.zeros((2, 20), dtype=bool)
715
+ mixed_mask[0, :10] = True
716
+ mixed_mask[1, :20] = True
717
+
718
+ print("Encoding mixed batch...")
719
+ mixed_codes = codec.encode(batch_x_mixed, mixed_ids, mixed_mask)
720
+
721
+ print("Decoding mixed batch...")
722
+ # Explicit durations (optional, but good for verification if we wanted to override defaults)
723
+ durations = [1, 1]
724
+ x_recon_mixed, dec_mask_mixed = codec.decode(mixed_codes, mixed_ids, durations)
725
+
726
+ print(f"Mixed Recon Shape: {x_recon_mixed.shape}")
727
+
728
+ # Validation
729
+ # Robot A output check (mask should be True for first 10, False for rest)
730
+ valid_A = dec_mask_mixed[0].sum()
731
+ valid_B = dec_mask_mixed[1].sum()
732
+
733
+ print(f"Valid steps detected by Decoder: Robot A={valid_A}, Robot B={valid_B}")
734
+
735
+ assert valid_A == 10
736
+ assert valid_B == 20
737
+
738
+ # Check dimensionality preservation
739
+ # Robot A's reconstruction in dims 7-9 should be noise or zero (depending on implementation),
740
+ # but dims 0-6 should contain signal.
741
+ print("✅ Mixed batch processed successfully.")
742
+
743
+ print("\n✨ All systems go.")
modular_actioncodec.py ADDED
@@ -0,0 +1,787 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from copy import deepcopy
3
+ from typing import List, Literal, Optional, Tuple, Union
4
+
5
+ import einops
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from .configuration_actioncodec import ActionCodecConfig
12
+
13
+
14
+ def apply_rotary_pos_emb(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
15
+ original_dtype = x.dtype
16
+
17
+ x = x.to(torch.float32)
18
+ sin = sin.to(torch.float32)
19
+ cos = cos.to(torch.float32)
20
+
21
+ x1 = x[..., 0::2]
22
+ x2 = x[..., 1::2]
23
+
24
+ rotated_x1 = x1 * cos - x2 * sin
25
+ rotated_x2 = x1 * sin + x2 * cos
26
+
27
+ x_out = torch.empty_like(x)
28
+ x_out[..., 0::2] = rotated_x1
29
+ x_out[..., 1::2] = rotated_x2
30
+
31
+ return x_out.to(original_dtype)
32
+
33
+
34
+ def attention_op(
35
+ q: torch.Tensor,
36
+ k: torch.Tensor,
37
+ v: torch.Tensor,
38
+ mask: torch.Tensor | None = None,
39
+ is_causal: bool = False,
40
+ ) -> torch.Tensor:
41
+ """
42
+
43
+ Args:
44
+ q (torch.Tensor): (*b, h, l, d)
45
+ k (torch.Tensor): (*b, k, s, d)
46
+ v (torch.Tensor): (*b, k, s, d)
47
+ mask (torch.Tensor | None, optional): (*b, l, s), where `True` indicates the element should take part in attention. Defaults to None.
48
+ is_causal (bool, optional): Whether to apply causal mask. Defaults to False.
49
+
50
+ Returns:
51
+ torch.Tensor: (*b, h, l, d)
52
+ """
53
+ heads, kv_heads = q.shape[-3], k.shape[-3]
54
+ if heads != kv_heads:
55
+ assert heads % kv_heads == 0, f"q_heads must be divisible by kv_heads, but got {heads} and {kv_heads}"
56
+ heads_per_kv_head = heads // kv_heads
57
+ k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v))
58
+
59
+ if mask is not None:
60
+ if mask.dim() == 3:
61
+ mask = mask.unsqueeze(1)
62
+ mask = mask.expand(mask.shape[0], heads, -1, -1)
63
+
64
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=is_causal)
65
+ return out
66
+
67
+
68
+ class L2Norm(nn.Module):
69
+ def forward(self, x: torch.Tensor):
70
+ return F.normalize(x, p=2, dim=-1)
71
+
72
+
73
+ class Attention(nn.Module):
74
+ """
75
+ Args:
76
+ hidden_size (int): Hidden size of the input tensor.
77
+ num_heads (int): Number of attention heads.
78
+ num_kv_heads (int, optional): Number of key/value heads. Defaults to None.
79
+ qk_norm (Literal["l2", "ln", "none"], optional): Type of normalization to apply to query/key. Defaults to "none".
80
+ bias (bool, optional): Whether to use bias in linear layers. Defaults to False.
81
+
82
+ """
83
+
84
+ def __init__(
85
+ self,
86
+ hidden_size: int,
87
+ num_heads: int,
88
+ num_kv_heads: int | None = None,
89
+ qk_norm: Literal["l2", "ln", "none"] = "none",
90
+ bias: bool = False,
91
+ zero_init_output: bool = False,
92
+ ):
93
+ super().__init__()
94
+ num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
95
+ self.dim = hidden_size // num_heads
96
+ self.num_heads, self.num_kv_heads = num_heads, num_kv_heads
97
+
98
+ self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
99
+ self.k_proj = nn.Linear(hidden_size, self.dim * num_kv_heads, bias=bias)
100
+ self.v_proj = nn.Linear(hidden_size, self.dim * num_kv_heads, bias=bias)
101
+ self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
102
+
103
+ if qk_norm == "l2":
104
+ self.q_norm = L2Norm()
105
+ self.k_norm = L2Norm()
106
+ elif qk_norm == "ln":
107
+ self.q_norm = nn.LayerNorm(self.dim, elementwise_affine=False)
108
+ self.k_norm = nn.LayerNorm(self.dim, elementwise_affine=False)
109
+ else:
110
+ self.q_norm = nn.Identity()
111
+ self.k_norm = nn.Identity()
112
+
113
+ if zero_init_output:
114
+ nn.init.zeros_(self.out_proj.weight)
115
+ if self.out_proj.bias is not None:
116
+ nn.init.zeros_(self.out_proj.bias)
117
+
118
+ def forward(
119
+ self,
120
+ x: torch.Tensor,
121
+ context: torch.Tensor | None = None,
122
+ mask: torch.Tensor | None = None,
123
+ rotary_pos_emb: Tuple[torch.Tensor, torch.Tensor] | None = None,
124
+ is_causal: bool = False,
125
+ ) -> torch.Tensor:
126
+ context = x if context is None else context
127
+
128
+ q = self.q_proj(x)
129
+ k, v = self.k_proj(context), self.v_proj(context)
130
+
131
+ q = einops.rearrange(q, "b l (h d) -> b h l d", h=self.num_heads)
132
+ k = einops.rearrange(k, "b s (h d) -> b h s d", h=self.num_kv_heads)
133
+ v = einops.rearrange(v, "b s (h d) -> b h s d", h=self.num_kv_heads)
134
+
135
+ q, k = self.q_norm(q), self.k_norm(k)
136
+
137
+ if rotary_pos_emb is not None:
138
+ q, k = map(lambda t: apply_rotary_pos_emb(t, *rotary_pos_emb), (q, k))
139
+
140
+ out = attention_op(q, k, v, mask=mask, is_causal=is_causal)
141
+ out = einops.rearrange(out, "b h l d -> b l (h d)")
142
+ out = self.out_proj(out)
143
+
144
+ return out
145
+
146
+
147
+ class PositionalEmbedding(nn.Module):
148
+ def __init__(
149
+ self,
150
+ dim: int,
151
+ encoding_type: Literal["sincos", "fourier"] = "sincos",
152
+ scale: float = 2.0,
153
+ ):
154
+ super().__init__()
155
+ self.dim = dim
156
+ self.encoding_type = encoding_type
157
+
158
+ if encoding_type == "fourier":
159
+ self.register_buffer("freqs", torch.randn(dim // 2) * scale, persistent=True)
160
+ elif encoding_type == "sincos":
161
+ pass
162
+ else:
163
+ raise ValueError(f"encoding_type must be 'sincos' or 'fourier', but got {encoding_type}")
164
+
165
+ def _create_sincos_emb(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
166
+ position = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
167
+ div_term = torch.exp(
168
+ torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) * -(math.log(10000.0) / self.dim)
169
+ )
170
+
171
+ pos_emb = torch.zeros(seq_len, self.dim, device=device, dtype=dtype)
172
+ pos_emb[:, 0::2] = torch.sin(position * div_term).to(dtype)
173
+ pos_emb[:, 1::2] = torch.cos(position * div_term).to(dtype)
174
+
175
+ return pos_emb
176
+
177
+ def _create_fourier_emb(self, timestamps: torch.Tensor, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
178
+ # Ensure freqs is on the correct device
179
+ freqs = self.freqs.to(device)
180
+ pos_emb = torch.einsum("b t, d -> b t d", timestamps, 2 * np.pi * freqs).to(device, torch.float32)
181
+ pos_emb = torch.cat([pos_emb.cos(), pos_emb.sin()], dim=-1).to(dtype)
182
+ return pos_emb
183
+
184
+ def forward(
185
+ self, x: torch.Tensor, freq: Optional[Union[float, torch.Tensor]] = None, dtype: torch.dtype = torch.float32
186
+ ) -> torch.Tensor:
187
+ b, t = x.shape[0], x.shape[1]
188
+ device = x.device
189
+
190
+ if self.encoding_type == "sincos":
191
+ pos_emb = self._create_sincos_emb(t, device, dtype)
192
+ pos_emb = pos_emb.unsqueeze(0).expand(b, -1, -1)
193
+ return pos_emb * 0.1
194
+
195
+ elif self.encoding_type == "fourier":
196
+ if freq is None:
197
+ raise ValueError(
198
+ "freq must be provided when encoding_type is 'fourier'. Please provide the sequence frequency."
199
+ )
200
+ if isinstance(freq, float):
201
+ freq = torch.tensor(freq, dtype=dtype, device=device)[None].expand(b)
202
+ timestamps = torch.einsum("t, b -> b t", torch.arange(t, dtype=dtype, device=device), 1 / freq)
203
+ pos_emb = self._create_fourier_emb(timestamps, device, dtype)
204
+ return pos_emb * 0.1
205
+ else:
206
+ raise ValueError(f"Unknown encoding_type: {self.encoding_type}")
207
+
208
+
209
+ class SinusoidalPositionalEmbedding(PositionalEmbedding):
210
+ def __init__(self, dim: int):
211
+ super().__init__(dim=dim, encoding_type="sincos")
212
+
213
+ def forward(self, x: torch.Tensor, pos: Optional[torch.Tensor] = None) -> torch.Tensor:
214
+ return super().forward(x, freq=None)
215
+
216
+
217
+ class FeedForward(nn.Module):
218
+ def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
219
+ super().__init__()
220
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
221
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
222
+ self.act_fn = nn.GELU()
223
+
224
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
225
+ down_proj = self.down_proj(self.act_fn(self.up_proj(x)))
226
+ return down_proj
227
+
228
+
229
+ class LayerScale(nn.Module):
230
+ def __init__(self, dim, init_val=1e-2):
231
+ super().__init__()
232
+ self.scale = nn.Parameter(torch.full([dim], init_val))
233
+
234
+ def forward(self, x):
235
+ return x * self.scale
236
+
237
+
238
+ class PerceiverTransformerBlock(nn.Module):
239
+ def __init__(
240
+ self,
241
+ dim: int,
242
+ num_heads: int,
243
+ mlp_ratio: int = 4,
244
+ dropout: float = 0.0,
245
+ qk_norm: str = "ln",
246
+ layer_scale: bool = True,
247
+ zero_init_output: bool = False,
248
+ add_self_attn: bool = False,
249
+ add_causal_mask: bool = False,
250
+ ):
251
+ super().__init__()
252
+ self.add_self_attn = add_self_attn
253
+ self.add_causal_mask = add_causal_mask
254
+
255
+ self.norm1 = nn.LayerNorm(dim, eps=1e-2)
256
+ self.cross_attn = Attention(
257
+ hidden_size=dim, num_heads=num_heads, qk_norm=qk_norm, bias=False, zero_init_output=zero_init_output
258
+ )
259
+
260
+ if add_self_attn:
261
+ self.norm_self_attn = nn.LayerNorm(dim, eps=1e-2)
262
+ self.self_attn = Attention(
263
+ hidden_size=dim, num_heads=num_heads, qk_norm=qk_norm, bias=False, zero_init_output=zero_init_output
264
+ )
265
+ else:
266
+ self.self_attn = None
267
+
268
+ self.norm2 = nn.LayerNorm(dim, eps=1e-2)
269
+ self.mlp = FeedForward(hidden_size=dim, intermediate_size=int(mlp_ratio * dim), bias=True)
270
+ self.dropout = nn.Dropout(dropout)
271
+
272
+ self.attn_scale = LayerScale(dim) if layer_scale else nn.Identity()
273
+ self.mlp_scale = LayerScale(dim) if layer_scale else nn.Identity()
274
+
275
+ if zero_init_output:
276
+ nn.init.zeros_(self.mlp.down_proj.weight)
277
+ if self.mlp.down_proj.bias is not None:
278
+ nn.init.zeros_(self.mlp.down_proj.bias)
279
+
280
+ def forward(
281
+ self,
282
+ x: torch.Tensor,
283
+ context: torch.Tensor,
284
+ context_mask: Optional[torch.Tensor] = None,
285
+ rotary_pos_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
286
+ ) -> torch.Tensor:
287
+ residual = x
288
+ x = self.norm1(x)
289
+ x = self.cross_attn(x=x, context=context, mask=context_mask, rotary_pos_emb=rotary_pos_emb, is_causal=False)
290
+ x = self.dropout(x)
291
+ x = self.attn_scale(x)
292
+ x = x + residual
293
+
294
+ if self.add_self_attn:
295
+ residual = x
296
+ x = self.norm_self_attn(x)
297
+ x = self.self_attn(
298
+ x=x,
299
+ context=None,
300
+ mask=None,
301
+ rotary_pos_emb=rotary_pos_emb,
302
+ is_causal=self.add_causal_mask,
303
+ )
304
+ x = self.dropout(x)
305
+ x = self.attn_scale(x)
306
+ x = x + residual
307
+
308
+ residual = x
309
+ x = self.norm2(x)
310
+ x = self.mlp(x)
311
+ x = self.dropout(x)
312
+ x = self.mlp_scale(x)
313
+ x = x + residual
314
+
315
+ return x
316
+
317
+
318
+ class EmbodimentEmbedding(nn.Module):
319
+ def __init__(self, embodiment_config: dict, out_len: int, out_dim: int) -> None:
320
+ super().__init__()
321
+ self.out_len, self.out_dim = out_len, out_dim
322
+
323
+ self.embodiment_config = embodiment_config
324
+ self.num_embodiments = len(self.embodiment_config)
325
+
326
+ self.embedding = nn.Embedding(self.num_embodiments, out_dim * out_len)
327
+
328
+ @torch.no_grad()
329
+ def expand_embodiment(self, embodiment_config: dict):
330
+ for k in embodiment_config.keys():
331
+ assert k not in self.embodiment_config.keys()
332
+ self.embodiment_config.update(embodiment_config)
333
+ self.num_embodiments = len(self.embodiment_config)
334
+
335
+ extra_embodiments = len(embodiment_config)
336
+
337
+ old_weights = torch.clone(self.embedding.weight)
338
+ self.embedding = nn.Embedding(self.num_embodiments, self.out_dim * self.out_len)
339
+ self.embedding.weight.data[:-extra_embodiments] = old_weights
340
+ return self
341
+
342
+ def keys(self) -> list[str]:
343
+ return list(self.embodiment_config.keys())
344
+
345
+ def ids_to_keys(self, ids: torch.Tensor) -> List[str]:
346
+ return [self.keys()[i] for i in ids]
347
+
348
+ def keys_to_ids(self, keys: List[str]) -> torch.Tensor:
349
+ return torch.tensor([self.keys().index(k) for k in keys])
350
+
351
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
352
+ return einops.rearrange(self.embedding(x), "b (l d) -> b l d", d=self.out_dim)
353
+
354
+
355
+ class PerceiverEncoder(nn.Module):
356
+ def __init__(self, config: ActionCodecConfig):
357
+ super().__init__()
358
+ self.config = config
359
+ self.embodiment_config = deepcopy(config.embodiment_config)
360
+
361
+ out_len = int(config.n_tokens // config.n_quantizers)
362
+ dim = config.encoder_dim
363
+
364
+ _action_dim, _freq, _duration = list(), list(), list()
365
+ for k, v in self.embodiment_config.items():
366
+ _action_dim.append(v["action_dim"])
367
+ _freq.append(v["freq"])
368
+ _duration.append(v["duration"])
369
+ self.register_buffer("_action_dim", torch.tensor(_action_dim), persistent=False)
370
+ self.register_buffer("_freq", torch.tensor(_freq), persistent=False)
371
+ self.register_buffer("_duration", torch.tensor(_duration), persistent=False)
372
+
373
+ self.max_action_dim = max(v["action_dim"] for v in self.embodiment_config.values())
374
+ self.input_proj = nn.Linear(self.max_action_dim, dim)
375
+
376
+ self.cls_tokens = EmbodimentEmbedding(self.embodiment_config, out_len, dim)
377
+
378
+ self.pos_emb_q = PositionalEmbedding(dim, encoding_type="sincos")
379
+ self.pos_emb_kv = PositionalEmbedding(dim, encoding_type=config.encoder_pos_encoding_type)
380
+
381
+ self.layers = nn.ModuleList(
382
+ [
383
+ PerceiverTransformerBlock(
384
+ dim=dim,
385
+ num_heads=config.encoder_n_heads,
386
+ add_self_attn=config.encoder_add_self_attn,
387
+ add_causal_mask=config.encoder_add_causal_mask,
388
+ )
389
+ for _ in range(config.encoder_n_layers)
390
+ ]
391
+ )
392
+
393
+ self.output_proj = nn.Linear(dim, config.z_dim)
394
+ self._init_weights()
395
+
396
+ def _init_weights(self):
397
+ nn.init.trunc_normal_(self.input_proj.weight, std=0.02)
398
+ if self.input_proj.bias is not None:
399
+ nn.init.zeros_(self.input_proj.bias)
400
+ nn.init.trunc_normal_(self.output_proj.weight, std=0.02)
401
+ if self.output_proj.bias is not None:
402
+ nn.init.zeros_(self.output_proj.bias)
403
+
404
+ nn.init.trunc_normal_(self.cls_tokens.embedding.weight, std=0.02)
405
+
406
+ @torch.no_grad()
407
+ def expand_embodiment(self, embodiment_config: dict):
408
+ self.cls_tokens.expand_embodiment(embodiment_config)
409
+ self.embodiment_config = self.cls_tokens.embodiment_config
410
+ _action_dim, _freq, _duration = list(), list(), list()
411
+ for k, v in self.embodiment_config.items():
412
+ _action_dim.append(v["action_dim"])
413
+ _freq.append(v["freq"])
414
+ _duration.append(v["duration"])
415
+ self._action_dim = torch.tensor(_action_dim)
416
+ self._freq = torch.tensor(_freq)
417
+ self._duration = torch.tensor(_duration)
418
+
419
+ max_action_dim = max(v["action_dim"] for v in self.embodiment_config.values())
420
+ if max_action_dim > self.max_action_dim:
421
+ old_weights = torch.clone(self.input_proj.weight)
422
+ old_bias = torch.clone(self.input_proj.bias)
423
+ self.input_proj = nn.Linear(max_action_dim, self.config.encoder_dim)
424
+ self.input_proj.weight.data[:, : self.max_action_dim] = old_weights
425
+ self.input_proj.bias.data = old_bias
426
+ self.max_action_dim = max_action_dim
427
+
428
+ return self
429
+
430
+ def forward(
431
+ self,
432
+ x: torch.Tensor,
433
+ embodiment_ids: torch.Tensor | int,
434
+ padding_mask: Optional[torch.Tensor] = None,
435
+ ) -> torch.Tensor:
436
+ """Encode action sequences into latent representations.
437
+
438
+ Args:
439
+ x (torch.Tensor): Action sequences to encode. Shape: (b, seq_len, max_action_dim).
440
+ Assumes that the action dimension is zero-padded to the max action dimension.
441
+ `seq_len` is supposed to be `int(duration * freq)` for each embodiment and padded to the max sequence length.
442
+ embodiment_ids (torch.Tensor | int): Embodiment IDs. Shape: (b,).
443
+ If int, the same embodiment ID is repeated for all sequences in the batch.
444
+ It specifies the embodiment to encode.
445
+ padding_mask (Optional[torch.Tensor], optional): Padding mask, where `False` values indicate padding. Shape: (b, seq_len). Defaults to None.
446
+ It is used to mask the padding tokens on `seq_len` dimension.
447
+
448
+ Returns:
449
+ torch.Tensor: Encoded latent representations. Shape: (b, n_tokens_per_quantizer, z_dim).
450
+ """
451
+ b, seq_len, _ = x.shape
452
+
453
+ x = self.input_proj(x)
454
+
455
+ if isinstance(embodiment_ids, int):
456
+ embodiment_ids = torch.tensor([embodiment_ids], dtype=torch.long, device=x.device).repeat(b)
457
+
458
+ cls_tokens = self.cls_tokens(embodiment_ids)
459
+
460
+ freqs = self._freq[embodiment_ids].to(x.device, x.dtype)
461
+
462
+ pos_emb_q = self.pos_emb_q(cls_tokens)
463
+ pos_emb_kv = self.pos_emb_kv(x, freqs)
464
+
465
+ cls_tokens = cls_tokens + pos_emb_q
466
+ x = x + pos_emb_kv
467
+
468
+ if padding_mask is not None:
469
+ padding_mask = padding_mask.unsqueeze(1).expand(-1, cls_tokens.shape[1], -1)
470
+
471
+ for layer in self.layers:
472
+ cls_tokens = layer(x=cls_tokens, context=x, context_mask=padding_mask)
473
+
474
+ return self.output_proj(cls_tokens)
475
+
476
+
477
+ class PerceiverDecoder(nn.Module):
478
+ def __init__(self, config: ActionCodecConfig):
479
+ super().__init__()
480
+ self.config = config
481
+ self.embodiment_config = deepcopy(config.embodiment_config)
482
+
483
+ dim = config.decoder_dim
484
+
485
+ _action_dim, _freq, _duration = list(), list(), list()
486
+ for k, v in self.embodiment_config.items():
487
+ _action_dim.append(v["action_dim"])
488
+ _freq.append(v["freq"])
489
+ _duration.append(v["duration"])
490
+ self.register_buffer("_action_dim", torch.tensor(_action_dim), persistent=False)
491
+ self.register_buffer("_freq", torch.tensor(_freq), persistent=False)
492
+ self.register_buffer("_duration", torch.tensor(_duration), persistent=False)
493
+
494
+ self.max_action_dim = max(v["action_dim"] for v in self.embodiment_config.values())
495
+ self.input_proj = nn.Linear(config.z_dim, dim)
496
+
497
+ self.cls_tokens = EmbodimentEmbedding(self.embodiment_config, config.decoder_cls_size, dim)
498
+
499
+ self.pos_emb_q = PositionalEmbedding(dim, encoding_type=config.decoder_pos_encoding_type)
500
+ self.pos_emb_kv = PositionalEmbedding(dim, encoding_type="sincos")
501
+
502
+ self.layers = nn.ModuleList(
503
+ [
504
+ PerceiverTransformerBlock(
505
+ dim=dim,
506
+ num_heads=config.decoder_n_heads,
507
+ add_self_attn=config.decoder_add_self_attn,
508
+ add_causal_mask=config.decoder_add_causal_mask,
509
+ )
510
+ for _ in range(config.decoder_n_layers)
511
+ ]
512
+ )
513
+
514
+ self.output_proj = nn.Linear(dim, self.max_action_dim)
515
+ self._init_weights()
516
+
517
+ def _init_weights(self):
518
+ nn.init.trunc_normal_(self.input_proj.weight, std=0.02)
519
+ if self.input_proj.bias is not None:
520
+ nn.init.zeros_(self.input_proj.bias)
521
+ nn.init.trunc_normal_(self.output_proj.weight, std=0.02)
522
+ if self.output_proj.bias is not None:
523
+ nn.init.zeros_(self.output_proj.bias)
524
+ nn.init.trunc_normal_(self.cls_tokens.embedding.weight, std=0.02)
525
+
526
+ @torch.no_grad()
527
+ def expand_embodiment(self, embodiment_config: dict):
528
+ self.cls_tokens.expand_embodiment(embodiment_config)
529
+ self.embodiment_config = self.cls_tokens.embodiment_config
530
+
531
+ _action_dim, _freq, _duration = list(), list(), list()
532
+ for k, v in self.embodiment_config.items():
533
+ _action_dim.append(v["action_dim"])
534
+ _freq.append(v["freq"])
535
+ _duration.append(v["duration"])
536
+ self._action_dim = torch.tensor(_action_dim)
537
+ self._freq = torch.tensor(_freq)
538
+ self._duration = torch.tensor(_duration)
539
+
540
+ max_action_dim = max(v["action_dim"] for v in self.embodiment_config.values())
541
+
542
+ if max_action_dim > self.max_action_dim:
543
+ old_weights = torch.clone(self.output_proj.weight)
544
+ old_bias = torch.clone(self.output_proj.bias)
545
+
546
+ self.output_proj = nn.Linear(self.config.decoder_dim, max_action_dim)
547
+
548
+ self.output_proj.weight.data[: self.max_action_dim, :] = old_weights
549
+ self.output_proj.bias.data[: self.max_action_dim] = old_bias
550
+
551
+ self.max_action_dim = max_action_dim
552
+
553
+ return self
554
+
555
+ def forward(
556
+ self, x: torch.Tensor, embodiment_ids: torch.Tensor | int, durations: torch.Tensor | None = None
557
+ ) -> torch.Tensor:
558
+ """Decode latent representations into action sequences.
559
+
560
+ Args:
561
+ x (torch.Tensor): Latent representations to decode. Shape: (b, n_tokens_per_quantizer, z_dim).
562
+ embodiment_ids (torch.Tensor | int): Embodiment IDs. Shape: (b,).
563
+ If int, the same embodiment ID is repeated for all sequences in the batch.
564
+ It specifies the embodiment to decode.
565
+ durations (torch.Tensor | None, optional): Duration of each action sequence. Shape: (b,).
566
+ If `None`, the duration is inferred from the default values in `embodiment_config`.
567
+
568
+ Returns:
569
+ torch.Tensor: Decoded action sequences. Shape: (b, seq_len, max_action_dim).
570
+ Assumes that the action dimension is zero-padded to the max action dimension.
571
+ `seq_len` is supposed to be `int(duration * freq)` for each embodiment and padded to the max sequence length.
572
+ """
573
+ b, seq_len, _ = x.shape
574
+ x = self.input_proj(x)
575
+
576
+ if isinstance(embodiment_ids, int):
577
+ embodiment_ids = torch.tensor([embodiment_ids], dtype=torch.long, device=x.device).repeat(b)
578
+
579
+ cls_tokens = self.cls_tokens(embodiment_ids)
580
+
581
+ freqs = self._freq[embodiment_ids]
582
+ if freqs.device != x.device:
583
+ freqs = freqs.to(x.device)
584
+
585
+ durations = self._duration[embodiment_ids] if durations is None else durations
586
+ if isinstance(durations, torch.Tensor) and durations.device != x.device:
587
+ durations = durations.to(x.device)
588
+
589
+ action_horizons = (durations * freqs).long()
590
+ max_horizon = action_horizons.max().item()
591
+ padding_mask = torch.arange(max_horizon, device=x.device).expand(b, -1) < action_horizons.unsqueeze(1)
592
+
593
+ if self.config.decoder_cls_size == 1:
594
+ cls_tokens = cls_tokens.repeat(1, max_horizon, 1)
595
+
596
+ pos_emb_q = self.pos_emb_q(cls_tokens, freqs)
597
+ pos_emb_kv = self.pos_emb_kv(x)
598
+
599
+ cls_tokens = cls_tokens + pos_emb_q
600
+ x = x + pos_emb_kv
601
+
602
+ for layer in self.layers:
603
+ cls_tokens = layer(x=cls_tokens, context=x)
604
+
605
+ output = self.output_proj(cls_tokens)
606
+
607
+ return output, padding_mask
608
+
609
+
610
+ if __name__ == "__main__":
611
+ # ------------------------------------------
612
+ # 1. Initialization
613
+ # ------------------------------------------
614
+ print("=== Test 1: Initialization ===")
615
+
616
+ # Define initial config with two smaller robots
617
+ initial_embodiment_config = {
618
+ "robot_small_7d": {"action_dim": 7, "freq": 20, "duration": 1, "description": "Original Robot"},
619
+ "robot_tiny_3d": {"action_dim": 3, "freq": 10, "duration": 2, "description": "Tiny Robot"},
620
+ }
621
+
622
+ config = ActionCodecConfig(embodiment_config=initial_embodiment_config)
623
+
624
+ # Set seed for reproducibility
625
+ torch.manual_seed(42)
626
+
627
+ encoder = PerceiverEncoder(config)
628
+ decoder = PerceiverDecoder(config)
629
+
630
+ encoder.eval()
631
+ decoder.eval()
632
+ print("✅ Models initialized successfully.")
633
+
634
+ # ------------------------------------------
635
+ # 2. Baseline Inference (Before Expansion)
636
+ # ------------------------------------------
637
+ print("\n=== Test 2: Baseline Inference (Before Expansion) ===")
638
+
639
+ # Simulate Robot 1 (7-dim) data
640
+ # Max action dim currently is 7.
641
+ batch_size = 1
642
+ seq_len = 20 # 20Hz * 1s
643
+
644
+ # Input: (1, 20, 7)
645
+ input_action_v0 = torch.randn(batch_size, seq_len, 7)
646
+ emb_id_v0 = torch.tensor([0], dtype=torch.long) # ID 0 -> robot_small_7d
647
+
648
+ with torch.no_grad():
649
+ z_ref = encoder(input_action_v0, emb_id_v0)
650
+ rec_action_ref, _ = decoder(z_ref, emb_id_v0)
651
+
652
+ print(f"Reference Latent Shape: {z_ref.shape}")
653
+ print(f"Reference Recon Shape: {rec_action_ref.shape}")
654
+
655
+ # ------------------------------------------
656
+ # 3. Model Expansion (Add New Embodiment)
657
+ # ------------------------------------------
658
+ print("\n=== Test 3: Model Expansion ===")
659
+
660
+ # Add a larger robot: 10-dim, high frequency
661
+ new_embodiment_config = {
662
+ "robot_large_10d": {"action_dim": 10, "freq": 30, "duration": 1, "description": "New Large Robot"}
663
+ }
664
+
665
+ print(f"Expanding from Max Dim {encoder.max_action_dim} to 10...")
666
+ encoder.expand_embodiment(new_embodiment_config)
667
+ decoder.expand_embodiment(new_embodiment_config)
668
+
669
+ # Verify buffer updates
670
+ assert encoder._action_dim[-1] == 10
671
+ assert encoder.max_action_dim == 10
672
+ assert decoder.max_action_dim == 10
673
+ print(f"✅ Expansion successful. New Encoder Input Dim: {encoder.input_proj.weight.shape[1]}")
674
+ print(f"✅ New Decoder Output Dim: {decoder.output_proj.weight.shape[0]}")
675
+
676
+ # ------------------------------------------
677
+ # 4. Encoder Invariance Check
678
+ # ------------------------------------------
679
+ print("\n=== Test 4: Encoder Invariance Check ===")
680
+
681
+ # Pad old data (7 dims) to new max dim (10 dims) with ZEROS.
682
+ input_action_padded = torch.zeros(batch_size, seq_len, 10)
683
+ input_action_padded[:, :, :7] = input_action_v0
684
+
685
+ with torch.no_grad():
686
+ z_new = encoder(input_action_padded, emb_id_v0)
687
+
688
+ # Compare latents
689
+ diff_z = (z_ref - z_new).abs().max().item()
690
+ print(f"Latent Difference (Max Abs): {diff_z:.8f}")
691
+
692
+ if diff_z < 1e-6:
693
+ print("✅ PASS: Encoder produces identical latents for old data.")
694
+ else:
695
+ print("❌ FAIL: Encoder outputs changed after expansion!")
696
+
697
+ # ------------------------------------------
698
+ # 5. Decoder Invariance Check
699
+ # ------------------------------------------
700
+ print("\n=== Test 5: Decoder Invariance Check ===")
701
+
702
+ with torch.no_grad():
703
+ # Feed old latent to expanded decoder
704
+ rec_action_new_full, _ = decoder(z_ref, emb_id_v0)
705
+
706
+ # Output shape should be (1, 20, 10)
707
+ print(f"Expanded Decoder Output Shape: {rec_action_new_full.shape}")
708
+
709
+ # Slice first 7 dims, should match reference
710
+ rec_action_new_sliced = rec_action_new_full[:, :, :7]
711
+
712
+ diff_rec = (rec_action_ref - rec_action_new_sliced).abs().max().item()
713
+ print(f"Reconstruction Difference (Max Abs on valid dims): {diff_rec:.8f}")
714
+
715
+ if diff_rec < 1e-6:
716
+ print("✅ PASS: Decoder produces identical action values for valid dimensions.")
717
+ else:
718
+ print("❌ FAIL: Decoder outputs changed!")
719
+
720
+ # Check phantom dimensions (7-9)
721
+ # For old embodiment, these are driven by random weights and should be random
722
+ new_dims_mean = rec_action_new_full[:, :, 7:].abs().mean().item()
723
+ print(f"Values in new phantom dimensions (should be random garbage): {new_dims_mean:.4f}")
724
+
725
+ # ------------------------------------------
726
+ # 6. New Embodiment Inference
727
+ # ------------------------------------------
728
+ print("\n=== Test 6: New Embodiment Inference ===")
729
+
730
+ # ID 2 -> robot_large_10d
731
+ emb_id_new = torch.tensor([2], dtype=torch.long)
732
+ seq_len_new = 30 # 30Hz * 1s
733
+
734
+ input_action_new = torch.randn(1, seq_len_new, 10)
735
+
736
+ with torch.no_grad():
737
+ z_large = encoder(input_action_new, emb_id_new)
738
+ rec_large, mask_large = decoder(z_large, emb_id_new)
739
+
740
+ print(f"New Embodiment Output Shape: {rec_large.shape}")
741
+
742
+ if rec_large.shape == (1, 30, 10):
743
+ print("✅ PASS: New embodiment handled correctly with full dimensions.")
744
+ else:
745
+ print(f"❌ FAIL: Expected (1, 30, 10), got {rec_large.shape}")
746
+
747
+ # ------------------------------------------
748
+ # 7. Mixed Batch Processing (Masking)
749
+ # ------------------------------------------
750
+ print("\n=== Test 7: Mixed Batch Processing ===")
751
+
752
+ # Batch size 2: [Robot 0 (20Hz, 7dim), Robot 2 (30Hz, 10dim)]
753
+ mixed_emb_ids = torch.tensor([0, 2], dtype=torch.long)
754
+
755
+ # Max seq len is 30. Max action dim is 10.
756
+ batch_input = torch.zeros(2, 30, 10)
757
+
758
+ # Fill data
759
+ # Batch 0: Length 20, Dim 7 valid
760
+ batch_input[0, :20, :7] = torch.randn(20, 7)
761
+ # Batch 1: Length 30, Dim 10 valid
762
+ batch_input[1, :30, :10] = torch.randn(30, 10)
763
+
764
+ # Encoder Mask: True = Valid
765
+ enc_padding_mask = torch.zeros(2, 30, dtype=torch.bool)
766
+ enc_padding_mask[0, :20] = True
767
+ enc_padding_mask[1, :30] = True
768
+
769
+ print("Running mixed batch...")
770
+ with torch.no_grad():
771
+ z_mixed = encoder(batch_input, mixed_emb_ids, padding_mask=enc_padding_mask)
772
+ rec_mixed, dec_padding_mask = decoder(z_mixed, mixed_emb_ids)
773
+
774
+ print(f"Mixed Reconstruction Shape: {rec_mixed.shape}") # Should be (2, 30, 10)
775
+
776
+ # Verify Decoder Generated Mask
777
+ valid_len_0 = dec_padding_mask[0].sum().item()
778
+ valid_len_1 = dec_padding_mask[1].sum().item()
779
+
780
+ print(f"Decoder Mask Valid Lengths: Batch 0={valid_len_0}, Batch 1={valid_len_1}")
781
+
782
+ if valid_len_0 == 20 and valid_len_1 == 30:
783
+ print("✅ PASS: Decoder correctly generated masks based on frequency and duration.")
784
+ else:
785
+ print("❌ FAIL: Decoder masks are incorrect.")
786
+
787
+ print("\n✨ All Tests Completed ✨")
rvq.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.distributed as dist
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from vector_quantize_pytorch import VectorQuantize as torchVQ
10
+
11
+
12
+ def sample_vectors(samples, num):
13
+ # samples: (N, D), num_samples: N, feature dim: D
14
+ num_samples, device = samples.shape[0], samples.device
15
+ if num_samples >= num:
16
+ indices = torch.randperm(num_samples, device=device)[:num]
17
+ else:
18
+ indices = torch.randint(0, num_samples, (num,), device=device)
19
+ return samples[indices].float() # (num, D), ensure fp32
20
+
21
+
22
+ def ema_inplace(moving_avg, new, decay):
23
+ # moving_avg: (codebook_size) or (codebook_size, D'), new: same as moving_avg
24
+ """Update exponential moving average in-place"""
25
+ moving_avg.data.mul_(decay).add_(new.float(), alpha=(1 - decay)) # ensure fp32
26
+
27
+
28
+ def kmeans(samples, num_clusters, num_iters=10):
29
+ # samples: (N, D), N samples with D dimensions
30
+ dim, _ = samples.shape[-1], torch.float32 # Force fp32
31
+ means = sample_vectors(samples, num_clusters).float() # (num_clusters, D), ensure fp32
32
+
33
+ for _ in range(num_iters):
34
+ dists = -(
35
+ samples.float().pow(2).sum(1, keepdim=True) # (N, 1), ensure fp32
36
+ - 2 * samples.float() @ means.t() # (N, num_clusters), ensure fp32
37
+ + means.t().float().pow(2).sum(0, keepdim=True)
38
+ ) # (1, num_clusters), ensure fp32
39
+ # dists: (N, num_clusters)
40
+ buckets = dists.max(dim=-1).indices # (N)
41
+ bins = torch.bincount(buckets, minlength=num_clusters) # (num_clusters)
42
+ zero_mask = bins == 0 # (num_clusters)
43
+ bins_min_clamped = bins.masked_fill(zero_mask, 1) # (num_clusters)
44
+
45
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=torch.float32) # (num_clusters, D), ensure fp32
46
+ new_means.scatter_add_(
47
+ 0, buckets.unsqueeze(1).expand(-1, dim), samples.float()
48
+ ) # (num_clusters, D), ensure fp32
49
+ new_means = new_means / bins_min_clamped[..., None] # (num_clusters, D)
50
+ means = torch.where(zero_mask[..., None], means, new_means) # (num_clusters, D)
51
+
52
+ # Final cluster assignments for returning cluster sizes
53
+ dists = -(
54
+ samples.float().pow(2).sum(1, keepdim=True)
55
+ - 2 * samples.float() @ means.t()
56
+ + means.t().float().pow(2).sum(0, keepdim=True)
57
+ ) # (N, num_clusters), ensure fp32
58
+ buckets = dists.max(dim=-1).indices # (N)
59
+ bins = torch.bincount(buckets, minlength=num_clusters).float() # (num_clusters), ensure fp32
60
+
61
+ return means, bins # (num_clusters, D), (num_clusters)
62
+
63
+
64
+ class VectorQuantize(nn.Module):
65
+ def __init__(
66
+ self,
67
+ input_dim,
68
+ codebook_size,
69
+ codebook_dim,
70
+ commitment=1.0,
71
+ decay=0.99, # EMA decay
72
+ epsilon=1e-5, # Laplace smoothing epsilon
73
+ threshold_ema_dead=2, # Dead code threshold
74
+ kmeans_init=True, # Use kmeans initialization
75
+ kmeans_iters=10, # Kmeans iterations
76
+ rotation_trick=False, # Use rotation trick
77
+ **kwargs,
78
+ ):
79
+ super().__init__()
80
+ self.input_dim = input_dim
81
+ self.codebook_size = codebook_size
82
+ self.codebook_dim = codebook_dim
83
+ self.commitment = commitment
84
+ self.decay = decay
85
+ self.epsilon = epsilon
86
+ self.threshold_ema_dead = threshold_ema_dead
87
+ self.kmeans_init = kmeans_init
88
+ self.kmeans_iters = kmeans_iters
89
+ self.rotation_trick = rotation_trick
90
+
91
+ if self.input_dim != self.codebook_dim:
92
+ self.in_project = nn.Linear(input_dim, codebook_dim)
93
+ self.out_project = nn.Linear(codebook_dim, input_dim)
94
+ else:
95
+ self.in_project = nn.Identity()
96
+ self.out_project = nn.Identity()
97
+
98
+ # Initialize codebook and EMA buffers
99
+ init_fn = torch.zeros if kmeans_init else lambda x, y: torch.randn(x, y)
100
+ self.register_buffer(
101
+ "codebook", init_fn(codebook_size, codebook_dim).float()
102
+ ) # (codebook_size, D'), ensure fp32
103
+ self.register_buffer("inited", torch.tensor([not kmeans_init], dtype=torch.bool)) # (1)
104
+ self.register_buffer("cluster_size", torch.zeros(codebook_size).float()) # (codebook_size), ensure fp32
105
+ self.register_buffer("embed_avg", self.codebook.clone().float()) # (codebook_size, D'), ensure fp32
106
+
107
+ def ema_update(self, encodings, embed_onehot):
108
+ # encodings: (B*T, D'), embed_onehot: (B*T, codebook_size)
109
+ """Update codebook using EMA"""
110
+ encodings = encodings.float() # Ensure fp32
111
+ embed_onehot = embed_onehot.float() # Ensure fp32
112
+ cluster_size_new = embed_onehot.sum(0) # (codebook_size)
113
+ embed_sum = encodings.t() @ embed_onehot # (D', codebook_size)
114
+
115
+ # Distributed reduction
116
+ if dist.is_initialized():
117
+ dist.all_reduce(cluster_size_new, op=dist.ReduceOp.SUM)
118
+ dist.all_reduce(embed_sum, op=dist.ReduceOp.SUM)
119
+
120
+ ema_inplace(self.cluster_size, cluster_size_new, self.decay) # (codebook_size)
121
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay) # (codebook_size, D')
122
+
123
+ # Laplace smoothing
124
+ cluster_size = (self.cluster_size + self.epsilon) / (
125
+ self.cluster_size.sum() + self.codebook_size * self.epsilon
126
+ ) # (codebook_size)
127
+ cluster_size = cluster_size * self.cluster_size.sum() # (codebook_size)
128
+ self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1)) # (codebook_size, D')
129
+
130
+ def replace_dead_codes(self, encodings):
131
+ # encodings: (B*T, D')
132
+ """Replace dead codes with random samples from current batch"""
133
+ if self.threshold_ema_dead == 0:
134
+ return
135
+
136
+ dead_mask = self.cluster_size < self.threshold_ema_dead # (codebook_size)
137
+ if dead_mask.any():
138
+ if dist.is_initialized() and dist.get_rank() == 0:
139
+ samples = sample_vectors(encodings.float(), self.codebook_size) # (codebook_size, D'), ensure fp32
140
+ print(f"Replace {dead_mask.sum().item()} dead codes")
141
+ else:
142
+ samples = torch.zeros_like(self.codebook).float() # Placeholder, ensure fp32
143
+
144
+ # Broadcast samples
145
+ if dist.is_initialized():
146
+ dist.broadcast(samples, src=0)
147
+
148
+ self.codebook[dead_mask] = samples[: dead_mask.sum()].to(self.codebook.dtype) # Update dead codes
149
+
150
+ def init_codebook(self, encodings):
151
+ # encodings: (B*T, D')
152
+ """Initialize codebook with k-means and update cluster_size"""
153
+ if self.inited.item():
154
+ return
155
+
156
+ if dist.is_initialized() and dist.get_rank() == 0:
157
+ embed, cluster_sizes = kmeans(
158
+ encodings.float(), self.codebook_size, self.kmeans_iters
159
+ ) # (codebook_size, D'), (codebook_size), ensure fp32
160
+ else:
161
+ embed = torch.zeros(self.codebook_size, self.codebook_dim, device=encodings.device).float() # ensure fp32
162
+ cluster_sizes = torch.zeros(self.codebook_size, device=encodings.device, dtype=torch.float32) # ensure fp32
163
+
164
+ # Broadcast results
165
+ if dist.is_initialized():
166
+ dist.broadcast(embed, src=0)
167
+ dist.broadcast(cluster_sizes, src=0)
168
+
169
+ self.codebook.copy_(embed) # (codebook_size, D')
170
+ self.embed_avg.copy_(embed.clone()) # (codebook_size, D')
171
+ self.cluster_size.copy_(cluster_sizes.float()) # (codebook_size)
172
+ self.inited.fill_(True)
173
+
174
+ def forward(self, z):
175
+ self = self.to(torch.float32)
176
+ z = z.float()
177
+ z_e = self.in_project(z).float()
178
+
179
+ # Rearrange for quantization
180
+ encodings = rearrange(z_e, "b t d -> (b t) d").float() # (B*T, D'), ensure fp32
181
+
182
+ # Initialize codebook if needed
183
+ if self.kmeans_init and not self.inited.item():
184
+ self.init_codebook(encodings)
185
+
186
+ dist = (
187
+ encodings.pow(2).sum(1, keepdim=True)
188
+ - 2 * encodings @ self.codebook.float().t()
189
+ + self.codebook.float().pow(2).sum(1, keepdim=True).t()
190
+ )
191
+ indices = (-dist).max(1)[1]
192
+
193
+ # cosine_similarity = F.cosine_similarity(encodings[None], self.codebook[:, None], dim=-1)
194
+ # indices = cosine_similarity.max(dim=0)[1]
195
+
196
+ indices = rearrange(indices, "(b t) -> b t", b=z.size(0))
197
+ z_q = self.decode_code(indices).float()
198
+ commit_loss = F.mse_loss(z_e, z_q.detach()) * self.commitment
199
+
200
+ if self.training and torch.is_grad_enabled():
201
+ embed_onehot = F.one_hot(indices.view(-1), self.codebook_size).float()
202
+ self.ema_update(encodings, embed_onehot)
203
+ self.replace_dead_codes(encodings)
204
+
205
+ z_q = (z_q - z_e).detach() + z_e
206
+ z_q = self.out_project(z_q).float()
207
+
208
+ return (
209
+ z_q,
210
+ commit_loss,
211
+ torch.tensor(0.0, device=z.device, dtype=torch.float32),
212
+ indices,
213
+ z_e,
214
+ )
215
+
216
+ def decode_code(self, embed_id): # embed_id: (B, T)
217
+ return F.embedding(embed_id, self.codebook).float() # (B, D', T), ensure fp32
218
+
219
+
220
+ # class VectorQuantize(nn.Module):
221
+ # """
222
+ # Implementation of VQ similar to Karpathy's repo:
223
+ # https://github.com/karpathy/deep-vector-quantization
224
+ # Additionally uses following tricks from Improved VQGAN
225
+ # (https://arxiv.org/pdf/2110.04627.pdf):
226
+ # 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
227
+ # for improved codebook usage
228
+ # 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
229
+ # improves training stability
230
+ # """
231
+
232
+ # def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
233
+ # super().__init__()
234
+ # self.codebook_size = codebook_size
235
+ # self.codebook_dim = codebook_dim
236
+
237
+ # self.in_proj = nn.Linear(input_dim, codebook_dim)
238
+ # self.out_proj = nn.Linear(codebook_dim, input_dim)
239
+ # self.codebook = nn.Embedding(codebook_size, codebook_dim)
240
+
241
+ # def forward(self, z: torch.Tensor):
242
+ # """
243
+ # Args:
244
+ # z (torch.Tensor): shape (b, t, d)
245
+
246
+ # Returns:
247
+ # z_q (torch.Tensor): shape (b, t, d)
248
+ # commitment_loss (torch.Tensor): shape (1)
249
+ # codebook_loss (torch.Tensor): shape (1)
250
+ # indices (torch.Tensor): shape (b, t)
251
+ # z_e (torch.Tensor): shape (b, t, d)
252
+ # """
253
+
254
+ # # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
255
+ # z_e = self.in_proj(z)
256
+ # z_q, indices = self.decode_latents(z_e)
257
+
258
+ # commitment_loss = F.mse_loss(z_e, z_q.detach()) * 0.25
259
+ # codebook_loss = F.mse_loss(z_q, z_e.detach())
260
+
261
+ # z_q = z_e + (z_q - z_e).detach() # noop in forward pass, straight-through gradient estimator in backward pass
262
+
263
+ # z_q = self.out_proj(z_q)
264
+
265
+ # return z_q, commitment_loss, codebook_loss, indices, z_e
266
+
267
+ # def embed_code(self, embed_id):
268
+ # return F.embedding(embed_id, self.codebook.weight)
269
+
270
+ # def decode_code(self, embed_id):
271
+ # return self.embed_code(embed_id)
272
+
273
+ # def decode_latents(self, latents: torch.Tensor):
274
+ # codebook = self.codebook.weight
275
+ # encodings = rearrange(latents, "b t d -> (b t) d")
276
+
277
+ # cosine_similarity = F.cosine_similarity(encodings[None], codebook[:, None], dim=-1)
278
+ # indices = cosine_similarity.max(dim=0)[1]
279
+ # indices = rearrange(indices, "(b t) -> b t", b=latents.size(0))
280
+
281
+ # # encodings = F.normalize(encodings)
282
+ # # codebook = F.normalize(codebook)
283
+ # # dist = (
284
+ # # encodings.pow(2).sum(1, keepdim=True)
285
+ # # - 2 * encodings @ codebook.t()
286
+ # # + codebook.pow(2).sum(1, keepdim=True).t()
287
+ # # )
288
+ # # indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
289
+
290
+ # z_q = self.decode_code(indices)
291
+ # return z_q, indices
292
+
293
+
294
+ class ResidualVectorQuantize(nn.Module):
295
+ def __init__(
296
+ self,
297
+ dim: int = 256,
298
+ n_codebooks: int = 4,
299
+ codebook_size: int = 512,
300
+ codebook_dim: Union[int, list] = 8,
301
+ quantizer_dropout: float = 0.25,
302
+ commitment: float = 0.25,
303
+ decay: float = 0.99,
304
+ epsilon: float = 1e-5,
305
+ threshold_ema_dead: int = 2,
306
+ kmeans_init: bool = True,
307
+ kmeans_iters: int = 10,
308
+ rotation_trick: bool = False,
309
+ ):
310
+ super().__init__()
311
+ if isinstance(codebook_dim, int):
312
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
313
+
314
+ self.n_codebooks = n_codebooks
315
+ self.codebook_dim = codebook_dim
316
+ self.codebook_size = codebook_size
317
+
318
+ self.quantizers = nn.ModuleList(
319
+ [
320
+ VectorQuantize(
321
+ input_dim=dim,
322
+ codebook_size=codebook_size,
323
+ codebook_dim=codebook_dim[i],
324
+ commitment=commitment,
325
+ decay=decay,
326
+ epsilon=epsilon,
327
+ threshold_ema_dead=threshold_ema_dead,
328
+ kmeans_init=kmeans_init,
329
+ kmeans_iters=kmeans_iters,
330
+ rotation_trick=rotation_trick,
331
+ )
332
+ for i in range(n_codebooks)
333
+ ]
334
+ )
335
+ self.quantizer_dropout = quantizer_dropout
336
+
337
+ def forward(self, z, n_quantizers: int = None):
338
+ """Quantized the input tensor using a fixed set of `n` codebooks and returns
339
+ the corresponding codebook vectors
340
+ Parameters
341
+ ----------
342
+ z : Tensor[B x D x T]
343
+ n_quantizers : int, optional
344
+ No. of quantizers to use
345
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
346
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
347
+ when in training mode, and a random number of quantizers is used.
348
+ Returns
349
+ -------
350
+ dict
351
+ A dictionary with the following keys:
352
+
353
+ "z" : Tensor[B x D x T]
354
+ Quantized continuous representation of input
355
+ "codes" : Tensor[B x N x T]
356
+ Codebook indices for each codebook
357
+ (quantized discrete representation of input)
358
+ "latents" : Tensor[B x N*D x T]
359
+ Projected latents (continuous representation of input before quantization)
360
+ "vq/commitment_loss" : Tensor[1]
361
+ Commitment loss to train encoder to predict vectors closer to codebook
362
+ entries
363
+ "vq/codebook_loss" : Tensor[1]
364
+ Codebook loss to update the codebook
365
+ """
366
+ z_q, residual = 0, z
367
+ commitment_loss, codebook_loss = 0, 0
368
+
369
+ codebook_indices, latents = [], []
370
+
371
+ if n_quantizers is None:
372
+ n_quantizers = self.n_codebooks
373
+ if self.training:
374
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
375
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
376
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
377
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
378
+ n_quantizers = n_quantizers.to(z.device)
379
+
380
+ for i, quantizer in enumerate(self.quantizers):
381
+ if self.training is False and i >= n_quantizers:
382
+ break
383
+
384
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual)
385
+
386
+ # Create mask to apply quantizer dropout
387
+ mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
388
+ z_q = z_q + z_q_i * mask[:, None, None]
389
+ residual = residual - z_q_i
390
+
391
+ # Sum losses
392
+ commitment_loss += (commitment_loss_i * mask).mean()
393
+ codebook_loss += (codebook_loss_i * mask).mean()
394
+
395
+ codebook_indices.append(indices_i)
396
+ latents.append(z_e_i)
397
+
398
+ codes = torch.stack(codebook_indices, dim=-1)
399
+ latents = torch.cat(latents, dim=1)
400
+
401
+ return z_q, codes, latents, commitment_loss, codebook_loss
402
+
403
+ def from_codes(self, codes: torch.Tensor):
404
+ """Given the quantized codes, reconstruct the continuous representation
405
+ Parameters
406
+ ----------
407
+ codes : Tensor[B x N x T]
408
+ Quantized discrete representation of input
409
+ Returns
410
+ -------
411
+ Tensor[B x D x T]
412
+ Quantized continuous representation of input
413
+ """
414
+ z_q = 0.0
415
+ z_p = []
416
+ n_codebooks = codes.shape[-1]
417
+ for i in range(n_codebooks):
418
+ z_p_i = self.quantizers[i].decode_code(codes[..., i])
419
+ z_p.append(z_p_i)
420
+
421
+ z_q_i = self.quantizers[i].out_project(z_p_i)
422
+ z_q = z_q + z_q_i
423
+ return z_q, torch.cat(z_p, dim=-1), codes
424
+
425
+ def from_latents(self, latents: torch.Tensor):
426
+ """Given the unquantized latents, reconstruct the
427
+ continuous representation after quantization.
428
+
429
+ Parameters
430
+ ----------
431
+ latents : Tensor[B x N x T]
432
+ Continuous representation of input after projection
433
+
434
+ Returns
435
+ -------
436
+ Tensor[B x D x T]
437
+ Quantized representation of full-projected space
438
+ Tensor[B x D x T]
439
+ Quantized representation of latent space
440
+ """
441
+ z_q = 0
442
+ z_p = []
443
+ codes = []
444
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
445
+
446
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
447
+ for i in range(n_codebooks):
448
+ j, k = dims[i], dims[i + 1]
449
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
450
+ z_p.append(z_p_i)
451
+ codes.append(codes_i)
452
+
453
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
454
+ z_q = z_q + z_q_i
455
+
456
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
457
+
458
+
459
+ class IndependentVectorQuantize(nn.Module):
460
+ def __init__(self, num_codebooks: int = 1, **kwargs):
461
+ super().__init__()
462
+ self.vector_quantizers = nn.ModuleList([torchVQ(**kwargs) for _ in range(num_codebooks)])
463
+ self.num_codebooks = num_codebooks
464
+ self.codebook_size = self.vector_quantizers[0].codebook_size
465
+
466
+ @property
467
+ def ema_update(self):
468
+ return [vq.ema_update for vq in self.vector_quantizers]
469
+
470
+ @property
471
+ def codebook(self):
472
+ return torch.stack([vq.codebook for vq in self.vector_quantizers], dim=0)
473
+
474
+ @codebook.setter
475
+ def codebook(self, codes: List[torch.Tensor]):
476
+ assert len(codes) == self.num_codebooks, "Number of codebooks must match"
477
+ if not self.separate_codebook_per_head:
478
+ codes = rearrange(codes, "... -> 1 ...")
479
+
480
+ for i, code in enumerate(codes):
481
+ self.vector_quantizers[i].codebook.copy_(code)
482
+
483
+ def get_codes_from_indices(self, indices: torch.Tensor):
484
+ codes = list()
485
+ for i in range(self.num_codebooks):
486
+ codes.append(self.vector_quantizers[i].get_codes_from_indices(indices[..., i : i + 1]))
487
+ return torch.cat(codes, dim=-2)
488
+
489
+ def get_output_from_indices(self, indices: torch.Tensor):
490
+ outputs = list()
491
+ for i in range(self.num_codebooks):
492
+ outputs.append(self.vector_quantizers[i].get_output_from_indices(indices[..., i : i + 1]))
493
+ return torch.cat(outputs, dim=-2)
494
+
495
+ def update_in_place_optimizer(self):
496
+ for i in range(self.num_codebooks):
497
+ self.vector_quantizers[i].update_in_place_optimizer()
498
+
499
+ def forward(self, x: torch.Tensor, *args, **kwargs):
500
+ assert x.shape[1] == self.num_codebooks
501
+ quantized, indices, commit_losses = list(), list(), 0
502
+ for i in range(self.num_codebooks):
503
+ quantized_i, indices_i, commit_loss_i = self.vector_quantizers[i](x[:, i : i + 1])
504
+ quantized.append(quantized_i)
505
+ indices.append(indices_i)
506
+ commit_losses += commit_loss_i
507
+ quantized = torch.cat(quantized, dim=-2)
508
+ indices = torch.cat(indices, dim=-1)
509
+ return quantized, indices, commit_losses / self.num_codebooks
510
+
511
+
512
+ if __name__ == "__main__":
513
+ vq = IndependentVectorQuantize(
514
+ num_codebooks=16,
515
+ dim=256,
516
+ codebook_size=2048,
517
+ decay=0.8, # the exponential moving average decay, lower means the dictionary will change faster
518
+ commitment_weight=1.0, # the weight on the commitment loss
519
+ )
520
+
521
+ x = torch.randn(1, 16, 256)
522
+ quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1)