ZibinDong commited on
Commit
1bc40b2
·
verified ·
1 Parent(s): ae7b990

Upload folder using huggingface_hub

Browse files
config.json CHANGED
@@ -2,10 +2,6 @@
2
  "architectures": [
3
  "ActionCodec"
4
  ],
5
- "auto_map": {
6
- "AutoConfig": "configuration_actioncodec.ActionCodecConfig",
7
- "AutoModel": "modeling_actioncodec.ActionCodec"
8
- },
9
  "decoder_add_causal_mask": false,
10
  "decoder_add_self_attn": false,
11
  "decoder_cls_size": 1,
 
2
  "architectures": [
3
  "ActionCodec"
4
  ],
 
 
 
 
5
  "decoder_add_causal_mask": false,
6
  "decoder_add_self_attn": false,
7
  "decoder_cls_size": 1,
configuration_actioncodec.py CHANGED
@@ -225,4 +225,6 @@ class BPEActionCodecConfig(PretrainedConfig):
225
  AutoConfig.register("action_codec", ActionCodecConfig)
226
  AutoConfig.register("bpe_action_codec", BPEActionCodecConfig)
227
 
 
 
228
  __all__ = ["ActionCodecConfig", "BPEActionCodecConfig"]
 
225
  AutoConfig.register("action_codec", ActionCodecConfig)
226
  AutoConfig.register("bpe_action_codec", BPEActionCodecConfig)
227
 
228
+ ActionCodecConfig.register_for_auto_class()
229
+
230
  __all__ = ["ActionCodecConfig", "BPEActionCodecConfig"]
modeling_actioncodec.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import List
2
 
3
  import einops
4
  import numpy as np
@@ -28,17 +28,67 @@ def trim_trailing_zeros(arr: np.ndarray) -> list[np.ndarray]:
28
 
29
 
30
  class ActionCodec(PreTrainedModel):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  config_class = ActionCodecConfig
32
 
33
  def __init__(self, config: ActionCodecConfig):
 
 
 
 
 
 
 
 
 
 
34
  super().__init__(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  self.default_embodiment_id = 0
36
 
 
37
  self.encoder = PerceiverEncoder(config)
38
  self.decoder = PerceiverDecoder(config)
39
 
 
40
  if config.vq_type == "vq":
41
- assert config.n_quantizers == 1, "Only one quantizer is supported for VQ"
 
 
 
42
  self.vq = VectorQuantize(
43
  dim=config.z_dim,
44
  codebook_size=config.vq_codebook_size,
@@ -50,7 +100,10 @@ class ActionCodec(PreTrainedModel):
50
  straight_through=True,
51
  )
52
  elif config.vq_type == "rvq":
53
- assert config.n_quantizers > 1, "At least two quantizers are supported for RVQ"
 
 
 
54
  self.vq = ResidualVectorQuantize(
55
  dim=config.z_dim,
56
  n_codebooks=config.n_quantizers,
@@ -60,17 +113,57 @@ class ActionCodec(PreTrainedModel):
60
  commitment=config.vq_commitment_weight,
61
  )
62
  else:
63
- raise NotImplementedError(f"VQ type {config.vq_type} not implemented")
64
 
 
65
  self.vocab_size = config.vq_codebook_size
66
  self.num_quantizers = config.n_quantizers
67
  self.n_tokens_per_quantizer = config.n_tokens // config.n_quantizers
68
 
69
  def expand_embodiment(self, embodiment_config: dict):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  """
71
- Delegates expansion to the underlying Encoder and Decoder.
72
- This allows the Codec to adapt to new robots dynamically.
73
- """
 
 
 
 
 
 
 
74
  self.encoder.expand_embodiment(embodiment_config)
75
  self.decoder.expand_embodiment(embodiment_config)
76
  self.config.embodiment_config.update(embodiment_config)
@@ -101,7 +194,28 @@ class ActionCodec(PreTrainedModel):
101
  z_e = self.encoder(x, embodiment_ids, padding_mask)
102
  return z_e
103
 
104
- def _quantize(self, z_e: torch.Tensor, return_perplexity: bool = True) -> List[torch.Tensor]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  if isinstance(self.vq, ResidualVectorQuantize):
106
  z_q, indices, _, commitment_loss, codebook_loss = self.vq(z_e)
107
  commit_loss = commitment_loss.mean() + codebook_loss.mean()
@@ -127,18 +241,50 @@ class ActionCodec(PreTrainedModel):
127
  return z_q, indices, perplexity, commit_loss
128
 
129
  def _dequantize(self, indices: torch.Tensor) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
130
  if self.num_quantizers == 1:
131
  if len(indices.size()) == 3:
132
  indices = indices.squeeze(-1)
133
  if isinstance(self.vq, ResidualVectorQuantize):
134
  z_q = self.vq.from_codes(indices)[0]
135
- else:
136
  z_q = self.vq.get_output_from_indices(indices)
 
 
137
  return z_q
138
 
139
  def _decode(
140
  self, z_q: torch.Tensor, embodiment_ids: torch.Tensor | int | None = None, durations: torch.Tensor | None = None
141
- ) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
143
  x_recon, padding_mask = self.decoder(z_q, embodiment_ids, durations)
144
  return x_recon, padding_mask
@@ -146,275 +292,331 @@ class ActionCodec(PreTrainedModel):
146
  @torch.no_grad()
147
  def encode(
148
  self,
149
- x: np.ndarray,
150
- embodiment_ids: List[int] | int | None = None,
151
- padding_mask: List[bool] | None = None,
 
152
  ) -> List[List[int]]:
153
- """Encode action sequences into latent representations.
 
 
 
154
 
155
  Args:
156
- x (np.ndarray): Action sequences to encode. Shape: (b, seq_len, max_action_dim).
 
157
  Assumes that the action dimension is zero-padded to the max action dimension.
158
- `seq_len` is supposed to be `int(duration * freq)` for each embodiment and padded to the max sequence length.
159
- embodiment_ids (List[int] | int): Embodiment IDs. Shape: (b,).
160
- If int, the same embodiment ID is repeated for all sequences in the batch.
161
- It specifies the embodiment to encode.
162
- padding_mask (List[bool] | None): Padding mask, where `False` values indicate padding. Shape: (b, seq_len). Defaults to None.
163
- It is used to mask the padding tokens on `seq_len` dimension.
 
 
 
 
164
 
165
  Returns:
166
- List[List[int]]: List of token sequences. Shape: (b, n_tokens).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  """
168
  self.eval()
169
- embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
170
 
171
- with torch.no_grad():
 
 
 
 
 
172
  x_tensor = torch.tensor(x, dtype=self.dtype, device=self.device)
173
- if not isinstance(embodiment_ids, int):
174
- embodiment_ids = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device)
175
- if padding_mask is not None:
176
- padding_mask = torch.tensor(padding_mask, dtype=torch.bool, device=self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
- z_e = self._encode(x_tensor, embodiment_ids, padding_mask)
 
179
  _, indices, _, _ = self._quantize(z_e, return_perplexity=False)
 
 
180
  if len(indices.size()) > 2:
181
  codes_list = einops.rearrange(indices, "b n s -> b (s n)").cpu()
182
  else:
183
  codes_list = indices.cpu()
 
184
  codes_list = codes_list.tolist()
185
  return codes_list
186
 
187
  @torch.no_grad()
188
  def decode(
189
- self, tokens: List[List[int]], embodiment_ids: List[int] | int | None = None, durations: List[float] | None = None
190
- ) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  self.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
193
- tokens = torch.tensor(tokens, dtype=torch.long, device=self.device)
194
- if not isinstance(embodiment_ids, int):
195
- embodiment_ids = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  if durations is not None:
197
- durations = torch.tensor(durations, dtype=torch.float32, device=self.device)
198
-
199
- b, n = tokens.shape
200
- assert n % self.n_tokens_per_quantizer == 0, (
201
- f"Expected {self.n_tokens_per_quantizer} tokens per quantizer, got {n} in total."
202
- )
203
- indices = einops.rearrange(tokens, "b (n m) -> b m n", m=self.n_tokens_per_quantizer)
204
- z_q = self._dequantize(indices)
205
- x_recon, padding_mask = self._decode(z_q, embodiment_ids, durations)
206
- return x_recon.cpu().numpy(), padding_mask.cpu().numpy()
207
-
208
- # def sparse_encode(
209
- # self,
210
- # x: np.ndarray,
211
- # search_num: int = 10,
212
- # threshold: float = 0.1,
213
- # action_encoding: str | None = None,
214
- # remove_padding: bool = True,
215
- # ) -> List[List[int]]:
216
- # """
217
- # Sparse encoding with adaptive token selection based on reconstruction error threshold.
218
- # Uses quaternary search to find optimal token length.
219
-
220
- # Args:
221
- # x: Input action arrays of shape (b, n, d)
222
- # search_num: Maximum number of search iterations
223
- # threshold: Reconstruction error threshold
224
- # action_encoding: Action encoding type
225
- # remove_padding: Whether to remove trailing zeros
226
-
227
- # Returns:
228
- # List of sparse token sequences
229
- # """
230
- # self.eval()
231
- # with torch.no_grad():
232
- # x_tensor = self._numpy_to_tensor(x)
233
-
234
- # # Get initial encoding
235
- # z_e = self._encode(x_tensor, action_encoding)
236
- # _, indices, _, _ = self._quantize(z_e, return_perplexity=False)
237
-
238
- # # Convert indices to proper format
239
- # if len(indices.size()) > 2:
240
- # indices_flat = einops.rearrange(indices, "b n s -> b (s n)")
241
- # else:
242
- # indices_flat = indices
243
-
244
- # # Use quaternary search to find optimal token lengths
245
- # optimal_lengths = self._quaternary_search(x_tensor, indices_flat, threshold, search_num, action_encoding)
246
-
247
- # # Create final sparse tokens based on optimal lengths
248
- # final_tokens = self._create_sparse_tokens_from_lengths(indices_flat, optimal_lengths)
249
-
250
- # # Convert to list format
251
- # if remove_padding:
252
- # final_tokens = trim_trailing_zeros(final_tokens.cpu().numpy())
253
- # else:
254
- # final_tokens = final_tokens.cpu().tolist()
255
-
256
- # return final_tokens
257
-
258
- # def _quaternary_search(
259
- # self,
260
- # x_tensor: torch.Tensor,
261
- # indices_flat: torch.Tensor,
262
- # threshold: float,
263
- # search_num: int,
264
- # action_encoding: str | None = None,
265
- # ) -> torch.Tensor:
266
- # """
267
- # Quaternary search to find optimal token lengths for each batch item.
268
- # Returns tensor of shape (batch_size,) containing optimal lengths.
269
- # """
270
- # batch_size, seq_len = indices_flat.shape
271
-
272
- # # Initialize search bounds
273
- # device = indices_flat.device
274
- # left = torch.ones(batch_size, dtype=torch.long, device=device)
275
- # right = torch.full((batch_size,), seq_len, dtype=torch.long, device=device)
276
-
277
- # # Perform quaternary search
278
- # for _ in range(search_num):
279
- # # Calculate three division points
280
- # range_size = right - left
281
- # q1 = left + range_size // 4
282
- # q2 = left + range_size // 2
283
- # q3 = left + 3 * range_size // 4
284
-
285
- # # Ensure q1, q2, q3 are within bounds and distinct
286
- # q1 = torch.clamp(q1, left, right)
287
- # q2 = torch.clamp(q2, q1 + 1, right)
288
- # q3 = torch.clamp(q3, q2 + 1, right)
289
-
290
- # # Create test lengths: [left, q1, q2, q3, right]
291
- # test_lengths = torch.stack([left, q1, q2, q3, right], dim=1) # (batch_size, 5)
292
-
293
- # # Calculate errors for all test lengths
294
- # errors = self._calculate_errors_for_lengths(x_tensor, indices_flat, test_lengths, action_encoding)
295
-
296
- # # Update search bounds based on results (vectorized)
297
- # # Find which lengths meet threshold for each batch item
298
- # meets_threshold = errors <= threshold
299
-
300
- # # For each batch item, find the smallest length that meets threshold
301
- # valid_indices = torch.argmax(meets_threshold.float(), dim=1) # First True index
302
- # has_valid = meets_threshold.any(dim=1) # Whether any length meets threshold
303
-
304
- # # Create batch indices for advanced indexing
305
- # batch_indices = torch.arange(batch_size, device=device)
306
-
307
- # # Get the smallest valid length for each batch
308
- # smallest_valid_lengths = test_lengths[batch_indices, valid_indices]
309
-
310
- # # Update bounds based on results
311
- # # If has valid length, use it; otherwise use longest length
312
- # right = torch.where(has_valid, smallest_valid_lengths, test_lengths[:, -1])
313
-
314
- # # Update left bound: if we found a valid length and it's not the first one,
315
- # # use the previous length; otherwise keep current left
316
- # prev_lengths = torch.where(valid_indices > 0, test_lengths[batch_indices, valid_indices - 1], left)
317
- # left = torch.where(has_valid & (valid_indices > 0), prev_lengths, left)
318
-
319
- # # Check convergence
320
- # if (right - left).max() <= 1:
321
- # break
322
-
323
- # return right # Return optimal lengths
324
-
325
- # def _calculate_errors_for_lengths(
326
- # self,
327
- # x_tensor: torch.Tensor,
328
- # indices_flat: torch.Tensor,
329
- # test_lengths: torch.Tensor,
330
- # action_encoding: str | None = None,
331
- # ) -> torch.Tensor:
332
- # """
333
- # Calculate reconstruction errors for given token lengths.
334
-
335
- # Args:
336
- # x_tensor: Original input tensor (batch_size, ...)
337
- # indices_flat: Full token indices (batch_size, seq_len)
338
- # test_lengths: Test lengths tensor (batch_size, num_tests)
339
- # action_encoding: Action encoding type
340
-
341
- # Returns:
342
- # Error tensor (batch_size, num_tests)
343
- # """
344
- # # Create sparse tokens for all test lengths (vectorized)
345
- # batch_size, num_tests = test_lengths.shape
346
- # seq_len = indices_flat.shape[1]
347
- # device = indices_flat.device
348
-
349
- # # Create position tensor for all combinations
350
- # positions = torch.arange(seq_len, device=device).unsqueeze(0).unsqueeze(0) # (1, 1, seq_len)
351
- # positions = positions.expand(batch_size, num_tests, -1) # (batch_size, num_tests, seq_len)
352
-
353
- # # Create length mask: positions < test_lengths
354
- # length_mask = positions < test_lengths.unsqueeze(2) # (batch_size, num_tests, seq_len)
355
-
356
- # # Create sparse tokens using advanced indexing
357
- # sparse_tokens = torch.where(
358
- # length_mask,
359
- # indices_flat.unsqueeze(1).expand(-1, num_tests, -1),
360
- # torch.zeros_like(indices_flat).unsqueeze(1).expand(-1, num_tests, -1),
361
- # )
362
-
363
- # # Reshape for parallel processing
364
- # sparse_flat = sparse_tokens.view(batch_size * num_tests, seq_len)
365
-
366
- # # Decode all sparse tokens in parallel
367
- # reconstructed_flat = self._decode_sparse_tokens(sparse_flat, action_encoding)
368
-
369
- # # Reshape back and calculate errors
370
- # reconstructed = reconstructed_flat.view(batch_size, num_tests, *x_tensor.shape[1:])
371
-
372
- # # Calculate errors
373
- # x_expanded = x_tensor.unsqueeze(1).expand(-1, num_tests, -1, -1)
374
- # errors = (x_expanded - reconstructed).abs().mean((-1, -2)) # (batch_size, num_tests)
375
-
376
- # return errors
377
-
378
- # def _decode_sparse_tokens(self, sparse_tokens: torch.Tensor, action_encoding: str | None = None) -> torch.Tensor:
379
- # """Decode sparse tokens to reconstructed data."""
380
- # batch_size, seq_len = sparse_tokens.shape
381
-
382
- # # Convert to proper indices format for dequantization
383
- # if self.num_quantizers > 1:
384
- # seq_len_per_quantizer = seq_len // self.num_quantizers
385
- # if seq_len % self.num_quantizers != 0:
386
- # raise ValueError("Sequence length must be divisible by num_quantizers")
387
-
388
- # indices_for_decode = sparse_tokens.view(batch_size, self.num_quantizers, seq_len_per_quantizer).transpose(
389
- # 1, 2
390
- # ) # (batch_size, seq_len_per_quantizer, num_quantizers)
391
- # else:
392
- # indices_for_decode = sparse_tokens.unsqueeze(-1) # (batch_size, seq_len, 1)
393
-
394
- # # Dequantize and decode
395
- # z_q = self._dequantize(indices_for_decode)
396
- # reconstructed = self._decode(z_q, action_encoding)
397
-
398
- # return reconstructed
399
-
400
- # def _create_sparse_tokens_from_lengths(
401
- # self, indices_flat: torch.Tensor, optimal_lengths: torch.Tensor
402
- # ) -> torch.Tensor:
403
- # """Create sparse tokens based on optimal lengths (vectorized)."""
404
- # batch_size, seq_len = indices_flat.shape
405
- # device = indices_flat.device
406
-
407
- # # Create position mask for all batch items simultaneously
408
- # positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) # (batch_size, seq_len)
409
- # length_mask = positions < optimal_lengths.unsqueeze(1) # (batch_size, seq_len)
410
-
411
- # # Apply mask to create sparse tokens
412
- # result = torch.where(length_mask, indices_flat, torch.zeros_like(indices_flat))
413
-
414
- # return result
415
-
416
- def forward(self, x: torch.Tensor, embodiment_ids: int | None = None, padding_mask: List[bool] | None = None):
417
- return self.encode(x, embodiment_ids, padding_mask)
418
 
419
 
420
  AutoModel.register(ActionCodecConfig, ActionCodec)
 
1
+ from typing import List, Tuple, Union
2
 
3
  import einops
4
  import numpy as np
 
28
 
29
 
30
  class ActionCodec(PreTrainedModel):
31
+ """ActionCodec: A neural codec for encoding and decoding robot action sequences.
32
+
33
+ This model uses a Perceiver-based encoder-decoder architecture with vector quantization
34
+ to convert continuous action sequences into discrete token sequences. It supports
35
+ multiple robot embodiments with different action dimensions and control frequencies.
36
+
37
+ The model supports two vector quantization types:
38
+ - VQ (Vector Quantization): Single quantizer
39
+ - RVQ (Residual Vector Quantization): Multiple quantizers for hierarchical encoding
40
+
41
+ Key features:
42
+ - Multi-embodiment support: Handle different robots with varying action dimensions
43
+ - Dynamic expansion: Add new robot configurations without retraining
44
+ - Flexible input/output: Support numpy arrays and torch tensors
45
+ """
46
+
47
  config_class = ActionCodecConfig
48
 
49
  def __init__(self, config: ActionCodecConfig):
50
+ """Initialize the ActionCodec model.
51
+
52
+ Args:
53
+ config (ActionCodecConfig): Model configuration containing hyperparameters
54
+ and embodiment configurations.
55
+
56
+ Raises:
57
+ ValueError: If configuration parameters are invalid.
58
+ NotImplementedError: If the specified VQ type is not supported.
59
+ """
60
  super().__init__(config)
61
+
62
+ # Validate configuration
63
+ if config.n_tokens % config.n_quantizers != 0:
64
+ raise ValueError(f"n_tokens ({config.n_tokens}) must be divisible by n_quantizers ({config.n_quantizers})")
65
+
66
+ if config.n_quantizers < 1:
67
+ raise ValueError(f"n_quantizers must be at least 1, got {config.n_quantizers}")
68
+
69
+ if config.vq_codebook_size < 1:
70
+ raise ValueError(f"vq_codebook_size must be at least 1, got {config.vq_codebook_size}")
71
+
72
+ if config.z_dim < 1:
73
+ raise ValueError(f"z_dim must be at least 1, got {config.z_dim}")
74
+
75
+ if not isinstance(config.embodiment_config, dict) or len(config.embodiment_config) == 0:
76
+ raise ValueError(
77
+ "embodiment_config must be a non-empty dictionary mapping embodiment names to configurations"
78
+ )
79
+
80
  self.default_embodiment_id = 0
81
 
82
+ # Initialize encoder and decoder
83
  self.encoder = PerceiverEncoder(config)
84
  self.decoder = PerceiverDecoder(config)
85
 
86
+ # Initialize vector quantizer based on type
87
  if config.vq_type == "vq":
88
+ if config.n_quantizers != 1:
89
+ raise ValueError(
90
+ f"VQ type requires n_quantizers=1, got {config.n_quantizers}. Use RVQ type for multiple quantizers."
91
+ )
92
  self.vq = VectorQuantize(
93
  dim=config.z_dim,
94
  codebook_size=config.vq_codebook_size,
 
100
  straight_through=True,
101
  )
102
  elif config.vq_type == "rvq":
103
+ if config.n_quantizers < 2:
104
+ raise ValueError(
105
+ f"RVQ type requires n_quantizers >= 2, got {config.n_quantizers}. Use VQ type for single quantizer."
106
+ )
107
  self.vq = ResidualVectorQuantize(
108
  dim=config.z_dim,
109
  n_codebooks=config.n_quantizers,
 
113
  commitment=config.vq_commitment_weight,
114
  )
115
  else:
116
+ raise NotImplementedError(f"VQ type '{config.vq_type}' not implemented. Supported types: 'vq', 'rvq'")
117
 
118
+ # Store quantization-related attributes
119
  self.vocab_size = config.vq_codebook_size
120
  self.num_quantizers = config.n_quantizers
121
  self.n_tokens_per_quantizer = config.n_tokens // config.n_quantizers
122
 
123
  def expand_embodiment(self, embodiment_config: dict):
124
+ """Dynamically expand the model to support new robot embodiments.
125
+
126
+ This method allows adding new robot configurations to the codec without retraining
127
+ the entire model. It updates the encoder and decoder to handle the new action dimensions
128
+ and frequencies while preserving existing functionality for previously configured robots.
129
+
130
+ Args:
131
+ embodiment_config (dict): Dictionary mapping embodiment names to their configurations.
132
+ Each configuration should be a dict with keys:
133
+ - "action_dim" (int): Action dimensionality for this embodiment.
134
+ - "freq" (float): Control frequency in Hz.
135
+ - "duration" (float): Default action sequence duration in seconds.
136
+ - "description" (str, optional): Human-readable description.
137
+
138
+ Example:
139
+ {
140
+ "robot_B": {
141
+ "action_dim": 10,
142
+ "freq": 20,
143
+ "duration": 1.0,
144
+ "description": "10-dim robot at 20Hz"
145
+ }
146
+ }
147
+
148
+ Returns:
149
+ ActionCodec: Returns self for method chaining.
150
+
151
+ Note:
152
+ - New embodiment keys must not already exist in the current configuration.
153
+ - The model will automatically update max_action_dim if the new embodiment
154
+ has a larger action dimension.
155
+ - Existing embodiments will continue to work with their original configurations.
156
  """
157
+ if not isinstance(embodiment_config, dict):
158
+ raise TypeError(f"embodiment_config must be a dict, got {type(embodiment_config)}")
159
+ if len(embodiment_config) == 0:
160
+ raise ValueError("embodiment_config cannot be empty")
161
+
162
+ # Check for duplicate keys
163
+ overlapping_keys = set(embodiment_config.keys()) & set(self.config.embodiment_config.keys())
164
+ if overlapping_keys:
165
+ raise ValueError(f"The following embodiment keys already exist and cannot be redefined: {overlapping_keys}")
166
+
167
  self.encoder.expand_embodiment(embodiment_config)
168
  self.decoder.expand_embodiment(embodiment_config)
169
  self.config.embodiment_config.update(embodiment_config)
 
194
  z_e = self.encoder(x, embodiment_ids, padding_mask)
195
  return z_e
196
 
197
+ def _quantize(
198
+ self, z_e: torch.Tensor, return_perplexity: bool = True
199
+ ) -> Tuple[torch.Tensor, torch.Tensor, Union[float, List[float]], torch.Tensor]:
200
+ """Quantize encoded representations using vector quantization.
201
+
202
+ Args:
203
+ z_e (torch.Tensor): Encoded latent representations to quantize.
204
+ Shape: (b, n_tokens_per_quantizer, z_dim).
205
+ return_perplexity (bool, optional): Whether to compute and return perplexity.
206
+ Defaults to True.
207
+
208
+ Returns:
209
+ Tuple[torch.Tensor, torch.Tensor, Union[float, List[float]], torch.Tensor]:
210
+ A tuple containing:
211
+ - z_q (torch.Tensor): Quantized representations.
212
+ Shape: (b, n_tokens_per_quantizer, z_dim).
213
+ - indices (torch.Tensor): Quantization indices.
214
+ Shape: (b, n_tokens_per_quantizer) for VQ or (b, n_tokens_per_quantizer, n_quantizers) for RVQ.
215
+ - perplexity (Union[float, List[float]]): Codebook perplexity.
216
+ Float for single quantizer, List[float] for multiple quantizers.
217
+ - commit_loss (torch.Tensor): Commitment loss scalar tensor.
218
+ """
219
  if isinstance(self.vq, ResidualVectorQuantize):
220
  z_q, indices, _, commitment_loss, codebook_loss = self.vq(z_e)
221
  commit_loss = commitment_loss.mean() + codebook_loss.mean()
 
241
  return z_q, indices, perplexity, commit_loss
242
 
243
  def _dequantize(self, indices: torch.Tensor) -> torch.Tensor:
244
+ """Dequantize token indices back to continuous latent representations.
245
+
246
+ Args:
247
+ indices (torch.Tensor): Quantization indices. Shape depends on quantizer type:
248
+ - For VQ: (b, n_tokens) or (b, n_tokens, 1)
249
+ - For RVQ: (b, n_tokens_per_quantizer, n_quantizers)
250
+
251
+ Returns:
252
+ torch.Tensor: Dequantized latent representations.
253
+ Shape: (b, n_tokens_per_quantizer, z_dim)
254
+ """
255
  if self.num_quantizers == 1:
256
  if len(indices.size()) == 3:
257
  indices = indices.squeeze(-1)
258
  if isinstance(self.vq, ResidualVectorQuantize):
259
  z_q = self.vq.from_codes(indices)[0]
260
+ elif isinstance(self.vq, VectorQuantize):
261
  z_q = self.vq.get_output_from_indices(indices)
262
+ else:
263
+ raise NotImplementedError(f"VQ type {type(self.vq)} not implemented in _dequantize")
264
  return z_q
265
 
266
  def _decode(
267
  self, z_q: torch.Tensor, embodiment_ids: torch.Tensor | int | None = None, durations: torch.Tensor | None = None
268
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
269
+ """Decode quantized latent representations into action sequences.
270
+
271
+ Args:
272
+ z_q (torch.Tensor): Quantized latent representations.
273
+ Shape: (b, n_tokens_per_quantizer, z_dim).
274
+ embodiment_ids (Union[torch.Tensor, int, None], optional): Embodiment IDs.
275
+ Shape: (b,) if tensor. If int, the same embodiment ID is used for all
276
+ sequences. Defaults to None, which uses `self.default_embodiment_id`.
277
+ durations (torch.Tensor | None, optional): Duration of each action sequence in seconds.
278
+ Shape: (b,). If None, uses default duration from embodiment_config.
279
+ Defaults to None.
280
+
281
+ Returns:
282
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
283
+ - x_recon (torch.Tensor): Reconstructed action sequences.
284
+ Shape: (b, seq_len, max_action_dim).
285
+ - padding_mask (torch.Tensor): Padding mask indicating valid timesteps.
286
+ Shape: (b, seq_len), where True indicates valid timesteps.
287
+ """
288
  embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
289
  x_recon, padding_mask = self.decoder(z_q, embodiment_ids, durations)
290
  return x_recon, padding_mask
 
292
  @torch.no_grad()
293
  def encode(
294
  self,
295
+ x: Union[np.ndarray, torch.Tensor],
296
+ embodiment_ids: Union[List[int], int, None] = None,
297
+ padding_mask: Union[List[bool], np.ndarray, torch.Tensor, None] = None,
298
+ **kwargs,
299
  ) -> List[List[int]]:
300
+ """Encode action sequences into latent representations (token indices).
301
+
302
+ This method converts action sequences into discrete token indices using the encoder
303
+ and vector quantizer. The input can be either a numpy array or torch tensor.
304
 
305
  Args:
306
+ x (Union[np.ndarray, torch.Tensor]): Action sequences to encode.
307
+ Shape: (b, seq_len, max_action_dim).
308
  Assumes that the action dimension is zero-padded to the max action dimension.
309
+ `seq_len` is supposed to be `int(duration * freq)` for each embodiment and
310
+ padded to the max sequence length.
311
+ embodiment_ids (Union[List[int], int, None], optional): Embodiment IDs.
312
+ Shape: (b,) if list. If int, the same embodiment ID is repeated for all
313
+ sequences in the batch. It specifies the embodiment to encode.
314
+ Defaults to None, which uses `self.default_embodiment_id`.
315
+ padding_mask (Union[List[bool], np.ndarray, torch.Tensor, None], optional):
316
+ Padding mask, where `False` values indicate padding. Shape: (b, seq_len).
317
+ Defaults to None. It is used to mask the padding tokens on `seq_len` dimension.
318
+ **kwargs: Additional keyword arguments (currently unused, reserved for future use).
319
 
320
  Returns:
321
+ List[List[int]]: List of token sequences. Shape: (b, n_tokens), where n_tokens
322
+ is determined by the model configuration (typically `config.n_tokens`).
323
+
324
+ Raises:
325
+ ValueError: If input shapes are invalid or incompatible with the model configuration.
326
+ TypeError: If input types are not supported.
327
+
328
+ Examples:
329
+ >>> import numpy as np
330
+ >>> # Using numpy array
331
+ >>> x = np.random.randn(2, 10, 7).astype(np.float32)
332
+ >>> tokens = model.encode(x, embodiment_ids=[0, 0])
333
+ >>> # Using torch tensor
334
+ >>> x_tensor = torch.randn(2, 10, 7)
335
+ >>> tokens = model.encode(x_tensor, embodiment_ids=[0, 0])
336
  """
337
  self.eval()
 
338
 
339
+ # Validate and convert input x
340
+ if isinstance(x, np.ndarray):
341
+ if x.ndim != 3:
342
+ raise ValueError(
343
+ f"Expected 3D input array (batch, seq_len, action_dim), got {x.ndim}D array with shape {x.shape}"
344
+ )
345
  x_tensor = torch.tensor(x, dtype=self.dtype, device=self.device)
346
+ elif isinstance(x, torch.Tensor):
347
+ if x.ndim != 3:
348
+ raise ValueError(
349
+ f"Expected 3D tensor (batch, seq_len, action_dim), got {x.ndim}D tensor with shape {x.shape}"
350
+ )
351
+ x_tensor = x.to(dtype=self.dtype, device=self.device)
352
+ else:
353
+ raise TypeError(f"Input x must be numpy.ndarray or torch.Tensor, got {type(x)}")
354
+
355
+ # Validate batch size
356
+ batch_size = x_tensor.shape[0]
357
+ if batch_size == 0:
358
+ raise ValueError("Batch size must be at least 1")
359
+
360
+ # Handle embodiment_ids
361
+ embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
362
+ if isinstance(embodiment_ids, int):
363
+ if not 0 <= embodiment_ids < len(self.config.embodiment_config):
364
+ raise ValueError(
365
+ f"embodiment_id {embodiment_ids} is out of range [0, {len(self.config.embodiment_config)}). "
366
+ f"Available embodiment IDs: {list(range(len(self.config.embodiment_config)))}"
367
+ )
368
+ embodiment_ids_tensor = torch.tensor([embodiment_ids] * batch_size, dtype=torch.long, device=self.device)
369
+ elif isinstance(embodiment_ids, list):
370
+ if len(embodiment_ids) != batch_size:
371
+ raise ValueError(
372
+ f"Length of embodiment_ids ({len(embodiment_ids)}) must match batch size ({batch_size})"
373
+ )
374
+ for eid in embodiment_ids:
375
+ if not isinstance(eid, int) or not 0 <= eid < len(self.config.embodiment_config):
376
+ raise ValueError(
377
+ f"Invalid embodiment_id {eid}. Must be an integer in range [0, {len(self.config.embodiment_config)})"
378
+ )
379
+ embodiment_ids_tensor = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device)
380
+ else:
381
+ raise TypeError(f"embodiment_ids must be int, List[int], or None, got {type(embodiment_ids)}")
382
+
383
+ # Handle padding_mask
384
+ padding_mask_tensor = None
385
+ if padding_mask is not None:
386
+ if isinstance(padding_mask, (list, np.ndarray)):
387
+ padding_mask_tensor = torch.tensor(padding_mask, dtype=torch.bool, device=self.device)
388
+ elif isinstance(padding_mask, torch.Tensor):
389
+ padding_mask_tensor = padding_mask.to(dtype=torch.bool, device=self.device)
390
+ else:
391
+ raise TypeError(
392
+ f"padding_mask must be List[bool], np.ndarray, torch.Tensor, or None, got {type(padding_mask)}"
393
+ )
394
+ if padding_mask_tensor.shape != (batch_size, x_tensor.shape[1]):
395
+ raise ValueError(
396
+ f"padding_mask shape {padding_mask_tensor.shape} does not match expected shape "
397
+ f"({batch_size}, {x_tensor.shape[1]})"
398
+ )
399
 
400
+ with torch.no_grad():
401
+ z_e = self._encode(x_tensor, embodiment_ids_tensor, padding_mask_tensor)
402
  _, indices, _, _ = self._quantize(z_e, return_perplexity=False)
403
+
404
+ # Reshape indices: for RVQ, indices shape is (b, n, s), for VQ it's (b, n)
405
  if len(indices.size()) > 2:
406
  codes_list = einops.rearrange(indices, "b n s -> b (s n)").cpu()
407
  else:
408
  codes_list = indices.cpu()
409
+
410
  codes_list = codes_list.tolist()
411
  return codes_list
412
 
413
  @torch.no_grad()
414
  def decode(
415
+ self,
416
+ tokens: Union[List[List[int]], np.ndarray, torch.Tensor],
417
+ embodiment_ids: Union[List[int], int, None] = None,
418
+ durations: Union[List[float], np.ndarray, torch.Tensor, None] = None,
419
+ **kwargs,
420
+ ) -> Tuple[np.ndarray, np.ndarray]:
421
+ """Decode token sequences into action sequences.
422
+
423
+ This method reconstructs action sequences from discrete token indices using the
424
+ vector quantizer and decoder. The input tokens can be a list of lists, numpy array,
425
+ or torch tensor.
426
+
427
+ Args:
428
+ tokens (Union[List[List[int]], np.ndarray, torch.Tensor]): Token sequences to decode.
429
+ Shape: (b, n_tokens), where n_tokens must be divisible by `n_tokens_per_quantizer`.
430
+ For RVQ, tokens are interleaved: [q0_t0, q1_t0, ..., qN_t0, q0_t1, ...].
431
+ embodiment_ids (Union[List[int], int, None], optional): Embodiment IDs.
432
+ Shape: (b,) if list. If int, the same embodiment ID is repeated for all
433
+ sequences in the batch. It specifies the embodiment to decode.
434
+ Defaults to None, which uses `self.default_embodiment_id`.
435
+ durations (Union[List[float], np.ndarray, torch.Tensor, None], optional):
436
+ Duration of each action sequence in seconds. Shape: (b,).
437
+ If None, the duration is inferred from the default values in `embodiment_config`.
438
+ Defaults to None.
439
+ **kwargs: Additional keyword arguments (currently unused, reserved for future use).
440
+
441
+ Returns:
442
+ Tuple[np.ndarray, np.ndarray]: A tuple containing:
443
+ - reconstructed_actions: Reconstructed action sequences.
444
+ Shape: (b, seq_len, max_action_dim).
445
+ - padding_mask: Padding mask indicating valid timesteps.
446
+ Shape: (b, seq_len), where True indicates valid timesteps.
447
+
448
+ Raises:
449
+ ValueError: If token sequence length is invalid or incompatible with the model configuration.
450
+ TypeError: If input types are not supported.
451
+
452
+ Examples:
453
+ >>> # Using list of lists
454
+ >>> tokens = [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16]]
455
+ >>> actions, mask = model.decode(tokens, embodiment_ids=[0, 0])
456
+ >>> # Using numpy array
457
+ >>> tokens_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
458
+ >>> actions, mask = model.decode(tokens_np, embodiment_ids=[0, 0])
459
+ >>> # Using torch tensor
460
+ >>> tokens_tensor = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
461
+ >>> actions, mask = model.decode(tokens_tensor, embodiment_ids=[0, 0])
462
+ """
463
  self.eval()
464
+
465
+ # Validate and convert input tokens
466
+ if isinstance(tokens, list):
467
+ if not all(isinstance(seq, list) for seq in tokens):
468
+ raise TypeError("If tokens is a list, all elements must be lists")
469
+ if len(tokens) == 0:
470
+ raise ValueError("Tokens list cannot be empty")
471
+ if not all(isinstance(val, (int, np.integer)) for seq in tokens for val in seq):
472
+ raise TypeError("All token values must be integers")
473
+ tokens_tensor = torch.tensor(tokens, dtype=torch.long, device=self.device)
474
+ elif isinstance(tokens, np.ndarray):
475
+ if tokens.ndim != 2:
476
+ raise ValueError(
477
+ f"Expected 2D array (batch, n_tokens), got {tokens.ndim}D array with shape {tokens.shape}"
478
+ )
479
+ if not np.issubdtype(tokens.dtype, np.integer):
480
+ raise TypeError(f"Tokens array must have integer dtype, got {tokens.dtype}")
481
+ tokens_tensor = torch.tensor(tokens, dtype=torch.long, device=self.device)
482
+ elif isinstance(tokens, torch.Tensor):
483
+ if tokens.ndim != 2:
484
+ raise ValueError(
485
+ f"Expected 2D tensor (batch, n_tokens), got {tokens.ndim}D tensor with shape {tokens.shape}"
486
+ )
487
+ if not tokens.dtype.is_integer:
488
+ raise TypeError(f"Tokens tensor must have integer dtype, got {tokens.dtype}")
489
+ tokens_tensor = tokens.to(dtype=torch.long, device=self.device)
490
+ else:
491
+ raise TypeError(f"tokens must be List[List[int]], np.ndarray, or torch.Tensor, got {type(tokens)}")
492
+
493
+ batch_size, n_tokens = tokens_tensor.shape
494
+ if batch_size == 0:
495
+ raise ValueError("Batch size must be at least 1")
496
+ if n_tokens == 0:
497
+ raise ValueError("Token sequence length must be at least 1")
498
+
499
+ # Validate token sequence length
500
+ if n_tokens % self.n_tokens_per_quantizer != 0:
501
+ raise ValueError(
502
+ f"Token sequence length ({n_tokens}) must be divisible by tokens per quantizer "
503
+ f"({self.n_tokens_per_quantizer}). Total tokens: {n_tokens}, "
504
+ f"Expected multiple of: {self.n_tokens_per_quantizer}. "
505
+ f"Number of quantizers: {self.num_quantizers}, Total tokens per sequence: {self.config.n_tokens}"
506
+ )
507
+
508
+ # Validate token values are within codebook range
509
+ if tokens_tensor.min() < 0 or tokens_tensor.max() >= self.vocab_size:
510
+ raise ValueError(
511
+ f"Token values must be in range [0, {self.vocab_size}), "
512
+ f"got range [{tokens_tensor.min().item()}, {tokens_tensor.max().item()}]"
513
+ )
514
+
515
+ # Handle embodiment_ids
516
  embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
517
+ if isinstance(embodiment_ids, int):
518
+ if not 0 <= embodiment_ids < len(self.config.embodiment_config):
519
+ raise ValueError(
520
+ f"embodiment_id {embodiment_ids} is out of range [0, {len(self.config.embodiment_config)}). "
521
+ f"Available embodiment IDs: {list(range(len(self.config.embodiment_config)))}"
522
+ )
523
+ embodiment_ids_tensor = torch.tensor([embodiment_ids] * batch_size, dtype=torch.long, device=self.device)
524
+ elif isinstance(embodiment_ids, list):
525
+ if len(embodiment_ids) != batch_size:
526
+ raise ValueError(
527
+ f"Length of embodiment_ids ({len(embodiment_ids)}) must match batch size ({batch_size})"
528
+ )
529
+ for eid in embodiment_ids:
530
+ if not isinstance(eid, int) or not 0 <= eid < len(self.config.embodiment_config):
531
+ raise ValueError(
532
+ f"Invalid embodiment_id {eid}. Must be an integer in range [0, {len(self.config.embodiment_config)})"
533
+ )
534
+ embodiment_ids_tensor = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device)
535
+ else:
536
+ raise TypeError(f"embodiment_ids must be int, List[int], or None, got {type(embodiment_ids)}")
537
+
538
+ # Handle durations
539
+ durations_tensor = None
540
  if durations is not None:
541
+ if isinstance(durations, (list, np.ndarray)):
542
+ durations_tensor = torch.tensor(durations, dtype=torch.float32, device=self.device)
543
+ elif isinstance(durations, torch.Tensor):
544
+ durations_tensor = durations.to(dtype=torch.float32, device=self.device)
545
+ else:
546
+ raise TypeError(
547
+ f"durations must be List[float], np.ndarray, torch.Tensor, or None, got {type(durations)}"
548
+ )
549
+ if durations_tensor.ndim != 1:
550
+ raise ValueError(
551
+ f"durations must be 1D, got {durations_tensor.ndim}D with shape {durations_tensor.shape}"
552
+ )
553
+ if len(durations_tensor) != batch_size:
554
+ raise ValueError(f"Length of durations ({len(durations_tensor)}) must match batch size ({batch_size})")
555
+ if (durations_tensor <= 0).any():
556
+ raise ValueError("All durations must be positive")
557
+
558
+ # Reshape tokens for dequantization: (b, n_tokens) -> (b, n_tokens_per_quantizer, n_quantizers)
559
+ indices = einops.rearrange(tokens_tensor, "b (n m) -> b m n", m=self.n_tokens_per_quantizer)
560
+
561
+ with torch.no_grad():
562
+ z_q = self._dequantize(indices)
563
+ x_recon, padding_mask = self._decode(z_q, embodiment_ids_tensor, durations_tensor)
564
+
565
+ return x_recon.float().cpu().numpy(), padding_mask.float().cpu().numpy()
566
+
567
+ def forward(
568
+ self,
569
+ x: Union[torch.Tensor, np.ndarray],
570
+ embodiment_ids: Union[torch.Tensor, int, List[int], None] = None,
571
+ padding_mask: Union[torch.Tensor, List[bool], np.ndarray, None] = None,
572
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
573
+ """Forward pass through the full ActionCodec pipeline.
574
+
575
+ This method performs encoding, quantization, and decoding in a single forward pass.
576
+ It is primarily used during training to compute reconstruction loss and commitment loss.
577
+ Both numpy arrays and torch tensors are supported as input.
578
+
579
+ Args:
580
+ x (Union[torch.Tensor, np.ndarray]): Action sequences to process.
581
+ Shape: (b, seq_len, max_action_dim).
582
+ embodiment_ids (Union[torch.Tensor, int, List[int], None], optional):
583
+ Embodiment IDs. Shape: (b,) if tensor or list. If int, same ID for all sequences.
584
+ Defaults to None, which uses `self.default_embodiment_id`.
585
+ padding_mask (Union[torch.Tensor, List[bool], np.ndarray, None], optional):
586
+ Padding mask. Shape: (b, seq_len). Defaults to None.
587
+
588
+ Returns:
589
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
590
+ - x_recon (torch.Tensor): Reconstructed action sequences.
591
+ Shape: (b, seq_len, max_action_dim).
592
+ - recon_mask (torch.Tensor): Reconstruction mask indicating valid timesteps.
593
+ Shape: (b, seq_len), where True indicates valid timesteps.
594
+
595
+ Note:
596
+ - For inference use cases, prefer using `encode()` and `decode()` methods separately.
597
+ - If you need token indices, use the `encode()` method instead.
598
+ """
599
+ # Convert numpy array to torch tensor if needed
600
+ if isinstance(x, np.ndarray):
601
+ x = torch.tensor(x, dtype=self.dtype, device=self.device)
602
+
603
+ # Handle embodiment_ids conversion
604
+ if isinstance(embodiment_ids, list):
605
+ embodiment_ids = torch.tensor(embodiment_ids, device=x.device, dtype=torch.long)
606
+ elif isinstance(embodiment_ids, int):
607
+ # Keep as int, will be handled by _encode
608
+ pass
609
+
610
+ # Handle padding_mask conversion
611
+ if isinstance(padding_mask, (list, np.ndarray)):
612
+ padding_mask = torch.tensor(padding_mask, device=x.device, dtype=torch.bool)
613
+
614
+ # Full forward pass: encode -> quantize -> decode
615
+ z_e = self._encode(x, embodiment_ids, padding_mask)
616
+ z_q, indices, perplexity, commit_loss = self._quantize(z_e, return_perplexity=True)
617
+ x_recon, recon_mask = self._decode(z_q, embodiment_ids)
618
+
619
+ return x_recon, recon_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
620
 
621
 
622
# Register with transformers' Auto* machinery. `register_for_auto_class` mirrors
# the pattern used in configuration_actioncodec.py so that trust_remote_code
# auto-mapping works without an explicit `auto_map` entry in config.json.
AutoModel.register(ActionCodecConfig, ActionCodec)
ActionCodec.register_for_auto_class("AutoModel")