from typing import List

import einops
import numpy as np
import torch
from transformers import AutoModel, PreTrainedModel
from vector_quantize_pytorch import VectorQuantize

from .configuration_actioncodec import ActionCodecConfig
from .modular_actioncodec import PerceiverDecoder, PerceiverEncoder
from .rvq import ResidualVectorQuantize


def trim_trailing_zeros(arr: np.ndarray) -> list[list]:
    """Strip trailing zeros from every row of a 2-D array.

    Args:
        arr: Array of shape (b, n).

    Returns:
        list[list]: One plain Python list per row (``tolist()`` output, not
        ``np.ndarray``), truncated just after the row's last non-zero element.
        A row that is entirely zero yields an empty list; an empty batch
        (``b == 0``) yields ``[]``.
    """
    if arr.shape[0] == 0:
        return []
    b, n = arr.shape
    is_nonzero = arr != 0
    # argmax over the flipped mask locates the first non-zero element from
    # the right, i.e. the last non-zero element of the original row.
    flipped_mask = np.flip(is_nonzero, axis=1)
    last_nonzero_indices = n - 1 - np.argmax(flipped_mask, axis=1)
    # Rows with no non-zero entries must map to length 0: argmax alone would
    # report position 0 and produce a spurious length of 1, so multiply by
    # the boolean "row has any non-zero" mask.
    any_nonzero_in_row = is_nonzero.any(axis=1)
    new_lengths = (last_nonzero_indices + 1) * any_nonzero_in_row
    result = [arr[i, :length].tolist() for i, length in enumerate(new_lengths)]
    return result


class ActionCodec(PreTrainedModel):
    """Perceiver encoder/decoder pair with (residual) vector quantization.

    Encodes robot action sequences into discrete token sequences and decodes
    them back. Supports plain VQ (one quantizer) and residual VQ (several
    quantizer levels), selected by ``config.vq_type``.
    """

    config_class = ActionCodecConfig

    def __init__(self, config: ActionCodecConfig):
        """Build encoder, decoder, and the configured quantizer.

        Args:
            config: Model configuration; ``vq_type`` must be "vq" (requires
                ``n_quantizers == 1``) or "rvq" (requires ``n_quantizers > 1``).

        Raises:
            NotImplementedError: If ``config.vq_type`` is unknown.
        """
        super().__init__(config)
        # Embodiment used when callers pass no explicit embodiment id.
        self.default_embodiment_id = 0
        self.encoder = PerceiverEncoder(config)
        self.decoder = PerceiverDecoder(config)
        if config.vq_type == "vq":
            assert config.n_quantizers == 1, "Only one quantizer is supported for VQ"
            self.vq = VectorQuantize(
                dim=config.z_dim,
                codebook_size=config.vq_codebook_size,
                commitment_weight=config.vq_commitment_weight,
                decay=config.vq_decay,
                kmeans_init=config.vq_kmeans_init,
                threshold_ema_dead_code=config.vq_threshold_ema_dead_code,
                rotation_trick=False,
                straight_through=True,
            )
        elif config.vq_type == "rvq":
            assert config.n_quantizers > 1, "At least two quantizers are supported for RVQ"
            self.vq = ResidualVectorQuantize(
                dim=config.z_dim,
                n_codebooks=config.n_quantizers,
                codebook_size=config.vq_codebook_size,
                codebook_dim=config.z_dim,
                quantizer_dropout=config.vq_quantizer_dropout,
                commitment=config.vq_commitment_weight,
            )
        else:
            raise NotImplementedError(f"VQ type {config.vq_type} not implemented")
        self.vocab_size = config.vq_codebook_size
        self.num_quantizers = config.n_quantizers
        # n_tokens is the total token budget per sequence; each quantizer
        # level contributes n_tokens // n_quantizers latent tokens.
        self.n_tokens_per_quantizer = config.n_tokens // config.n_quantizers
    def expand_embodiment(self, embodiment_config: dict):
        """
        Delegates expansion to the underlying Encoder and Decoder.
        This allows the Codec to adapt to new robots dynamically.

        Args:
            embodiment_config (dict): New embodiment entries to merge into
                `self.config.embodiment_config`.

        Returns:
            ActionCodec: self, for call chaining.
        """
        self.encoder.expand_embodiment(embodiment_config)
        self.decoder.expand_embodiment(embodiment_config)
        self.config.embodiment_config.update(embodiment_config)
        return self

    def _encode(
        self,
        x: torch.Tensor,
        embodiment_ids: torch.Tensor | int | None = None,
        padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode action sequences into latent representations.

        Args:
            x (torch.Tensor): Action sequences to encode. Shape: (b, seq_len, max_action_dim).
                Assumes that the action dimension is zero-padded to the max action dimension.
                `seq_len` is supposed to be `int(duration * freq)` for each embodiment and
                padded to the max sequence length.
            embodiment_ids (torch.Tensor | int): Embodiment IDs. Shape: (b,). If int, the same
                embodiment ID is repeated for all sequences in the batch. It specifies the
                embodiment to encode.
            padding_mask (Optional[torch.Tensor], optional): Padding mask, where `False` values
                indicate padding. Shape: (b, seq_len). Defaults to None. It is used to mask
                the padding tokens on `seq_len` dimension.

        Returns:
            torch.Tensor: Encoded latent representations. Shape: (b, n_tokens_per_quantizer, z_dim).
        """
        embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
        z_e = self.encoder(x, embodiment_ids, padding_mask)
        return z_e

    def _quantize(
        self, z_e: torch.Tensor, return_perplexity: bool = True
    ) -> tuple[torch.Tensor, torch.Tensor, list[float] | int, torch.Tensor]:
        """Quantize latents with the configured VQ/RVQ module.

        Args:
            z_e (torch.Tensor): Latents from `_encode`.
            return_perplexity (bool): If True, compute per-codebook perplexity.

        Returns:
            tuple: `(z_q, indices, perplexity, commit_loss)` where `perplexity`
            is a list with one float per codebook, or the int 0 when
            `return_perplexity` is False.
        """
        # NOTE: the two VQ libraries return different tuples — keep the
        # unpacking order exactly as each library defines it.
        if isinstance(self.vq, ResidualVectorQuantize):
            z_q, indices, _, commitment_loss, codebook_loss = self.vq(z_e)
            commit_loss = commitment_loss.mean() + codebook_loss.mean()
        elif isinstance(self.vq, VectorQuantize):
            z_q, indices, commit_loss = self.vq(z_e)
        else:
            raise NotImplementedError(f"VQ type {type(self.vq)} not implemented")
        if return_perplexity:
            # Normalize indices to rank 3: (b, n, n_codebooks).
            if len(indices.size()) < 3:
                indices = indices.unsqueeze(-1)
            perplexity = []
            for k in range(indices.size(-1)):
                this_indices = indices[:, :, k]
                indices_count = torch.bincount(this_indices.view(-1), minlength=self.vq.codebook_size)
                # Aggregate codebook usage counts across all ranks so the
                # perplexity reflects the global batch.
                if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
                    torch.distributed.all_reduce(indices_count)
                this_avg_probs = indices_count.float() / indices_count.sum()
                # Perplexity = exp(entropy); epsilon guards log(0) for unused codes.
                perplexity.append(((-(this_avg_probs * torch.log(this_avg_probs + 1e-10)).sum()).exp().item()))
        else:
            perplexity = 0
        return z_q, indices, perplexity, commit_loss

    def _dequantize(self, indices: torch.Tensor) -> torch.Tensor:
        """Map code indices back to continuous latents via the codebook(s).

        Args:
            indices (torch.Tensor): Code indices; (b, n) or (b, n, n_quantizers).

        Returns:
            torch.Tensor: Dequantized latents `z_q`.
        """
        if self.num_quantizers == 1:
            # Single-quantizer path expects rank-2 indices.
            if len(indices.size()) == 3:
                indices = indices.squeeze(-1)
        if isinstance(self.vq, ResidualVectorQuantize):
            z_q = self.vq.from_codes(indices)[0]
        else:
            z_q = self.vq.get_output_from_indices(indices)
        return z_q

    def _decode(
        self,
        z_q: torch.Tensor,
        embodiment_ids: torch.Tensor | int | None = None,
        durations: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Decode quantized latents into action sequences.

        Args:
            z_q (torch.Tensor): Quantized latents.
            embodiment_ids (torch.Tensor | int | None): Target embodiment(s);
                defaults to `self.default_embodiment_id`.
            durations (torch.Tensor | None): Optional per-sequence durations
                forwarded to the decoder.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: `(x_recon, padding_mask)`.
        """
        embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
        x_recon, padding_mask = self.decoder(z_q, embodiment_ids, durations)
        return x_recon, padding_mask

    @torch.no_grad()
    def encode(
        self,
        x: np.ndarray,
        embodiment_ids: List[int] | int | None = None,
        padding_mask: List[bool] | None = None,
    ) -> List[List[int]]:
        """Encode action sequences into latent representations.

        Args:
            x (np.ndarray): Action sequences to encode. Shape: (b, seq_len, max_action_dim).
                Assumes that the action dimension is zero-padded to the max action dimension.
                `seq_len` is supposed to be `int(duration * freq)` for each embodiment and
                padded to the max sequence length.
            embodiment_ids (List[int] | int): Embodiment IDs. Shape: (b,). If int, the same
                embodiment ID is repeated for all sequences in the batch. It specifies the
                embodiment to encode.
            padding_mask (List[bool] | None): Padding mask, where `False` values indicate
                padding. Shape: (b, seq_len). Defaults to None. It is used to mask the
                padding tokens on `seq_len` dimension.

        Returns:
            List[List[int]]: List of token sequences. Shape: (b, n_tokens).
        """
        self.eval()
        embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
        with torch.no_grad():
            x_tensor = torch.tensor(x, dtype=self.dtype, device=self.device)
            if not isinstance(embodiment_ids, int):
                embodiment_ids = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device)
            if padding_mask is not None:
                padding_mask = torch.tensor(padding_mask, dtype=torch.bool, device=self.device)
            z_e = self._encode(x_tensor, embodiment_ids, padding_mask)
            _, indices, _, _ = self._quantize(z_e, return_perplexity=False)
            # Flatten (b, n, n_quantizers) to (b, n_tokens): quantizer-major
            # layout, i.e. all tokens of level 0, then level 1, etc.
            if len(indices.size()) > 2:
                codes_list = einops.rearrange(indices, "b n s -> b (s n)").cpu()
            else:
                codes_list = indices.cpu()
            codes_list = codes_list.tolist()
            return codes_list

    @torch.no_grad()
    def decode(
        self,
        tokens: List[List[int]],
        embodiment_ids: List[int] | int | None = None,
        durations: List[float] | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Decode token sequences back into action sequences.

        Args:
            tokens (List[List[int]]): Token sequences of shape (b, n_tokens),
                in the quantizer-major layout produced by `encode`.
            embodiment_ids (List[int] | int | None): Embodiment(s) to decode
                for; defaults to `self.default_embodiment_id`.
            durations (List[float] | None): Optional per-sequence durations.

        Returns:
            tuple[np.ndarray, np.ndarray]: `(x_recon, padding_mask)` on CPU.
        """
        self.eval()
        embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
        tokens = torch.tensor(tokens, dtype=torch.long, device=self.device)
        if not isinstance(embodiment_ids, int):
            embodiment_ids = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device)
        if durations is not None:
            durations = torch.tensor(durations, dtype=torch.float32, device=self.device)
        b, n = tokens.shape
        assert n % self.n_tokens_per_quantizer == 0, (
            f"Expected {self.n_tokens_per_quantizer} tokens per quantizer, got {n} in total."
        )
        # Inverse of encode's "b n s -> b (s n)": recover (b, tokens_per_level,
        # n_quantizers) from the flat quantizer-major token stream.
        indices = einops.rearrange(tokens, "b (n m) -> b m n", m=self.n_tokens_per_quantizer)
        z_q = self._dequantize(indices)
        x_recon, padding_mask = self._decode(z_q, embodiment_ids, durations)
        return x_recon.cpu().numpy(), padding_mask.cpu().numpy()

    # def sparse_encode(
    #     self,
    #     x: np.ndarray,
    #     search_num: int = 10,
    #     threshold: float = 0.1,
    #     action_encoding: str | None = None,
    #     remove_padding: bool = True,
    # ) -> List[List[int]]:
    #     """
    #     Sparse encoding with adaptive token selection based on reconstruction error threshold.
    #     Uses quaternary search to find optimal token length.
    #     Args:
    #         x: Input action arrays of shape (b, n, d)
    #         search_num: Maximum number of search iterations
    #         threshold: Reconstruction error threshold
    #         action_encoding: Action encoding type
    #         remove_padding: Whether to remove trailing zeros
    #     Returns:
    #         List of sparse token sequences
    #     """
    #     self.eval()
    #     with torch.no_grad():
    #         x_tensor = self._numpy_to_tensor(x)
    #         # Get initial encoding
    #         z_e = self._encode(x_tensor, action_encoding)
    #         _, indices, _, _ = self._quantize(z_e, return_perplexity=False)
    #         # Convert indices to proper format
    #         if len(indices.size()) > 2:
    #             indices_flat = einops.rearrange(indices, "b n s -> b (s n)")
    #         else:
    #             indices_flat = indices
    #         # Use quaternary search to find optimal token lengths
    #         optimal_lengths = self._quaternary_search(x_tensor, indices_flat, threshold, search_num, action_encoding)
    #         # Create final sparse tokens based on optimal lengths
    #         final_tokens = self._create_sparse_tokens_from_lengths(indices_flat, optimal_lengths)
    #         # Convert to list format
    #         if remove_padding:
    #             final_tokens = trim_trailing_zeros(final_tokens.cpu().numpy())
    #         else:
    #             final_tokens = final_tokens.cpu().tolist()
    #         return final_tokens

    # def _quaternary_search(
    #     self,
    #
x_tensor: torch.Tensor, # indices_flat: torch.Tensor, # threshold: float, # search_num: int, # action_encoding: str | None = None, # ) -> torch.Tensor: # """ # Quaternary search to find optimal token lengths for each batch item. # Returns tensor of shape (batch_size,) containing optimal lengths. # """ # batch_size, seq_len = indices_flat.shape # # Initialize search bounds # device = indices_flat.device # left = torch.ones(batch_size, dtype=torch.long, device=device) # right = torch.full((batch_size,), seq_len, dtype=torch.long, device=device) # # Perform quaternary search # for _ in range(search_num): # # Calculate three division points # range_size = right - left # q1 = left + range_size // 4 # q2 = left + range_size // 2 # q3 = left + 3 * range_size // 4 # # Ensure q1, q2, q3 are within bounds and distinct # q1 = torch.clamp(q1, left, right) # q2 = torch.clamp(q2, q1 + 1, right) # q3 = torch.clamp(q3, q2 + 1, right) # # Create test lengths: [left, q1, q2, q3, right] # test_lengths = torch.stack([left, q1, q2, q3, right], dim=1) # (batch_size, 5) # # Calculate errors for all test lengths # errors = self._calculate_errors_for_lengths(x_tensor, indices_flat, test_lengths, action_encoding) # # Update search bounds based on results (vectorized) # # Find which lengths meet threshold for each batch item # meets_threshold = errors <= threshold # # For each batch item, find the smallest length that meets threshold # valid_indices = torch.argmax(meets_threshold.float(), dim=1) # First True index # has_valid = meets_threshold.any(dim=1) # Whether any length meets threshold # # Create batch indices for advanced indexing # batch_indices = torch.arange(batch_size, device=device) # # Get the smallest valid length for each batch # smallest_valid_lengths = test_lengths[batch_indices, valid_indices] # # Update bounds based on results # # If has valid length, use it; otherwise use longest length # right = torch.where(has_valid, smallest_valid_lengths, test_lengths[:, -1]) # # 
Update left bound: if we found a valid length and it's not the first one, # # use the previous length; otherwise keep current left # prev_lengths = torch.where(valid_indices > 0, test_lengths[batch_indices, valid_indices - 1], left) # left = torch.where(has_valid & (valid_indices > 0), prev_lengths, left) # # Check convergence # if (right - left).max() <= 1: # break # return right # Return optimal lengths # def _calculate_errors_for_lengths( # self, # x_tensor: torch.Tensor, # indices_flat: torch.Tensor, # test_lengths: torch.Tensor, # action_encoding: str | None = None, # ) -> torch.Tensor: # """ # Calculate reconstruction errors for given token lengths. # Args: # x_tensor: Original input tensor (batch_size, ...) # indices_flat: Full token indices (batch_size, seq_len) # test_lengths: Test lengths tensor (batch_size, num_tests) # action_encoding: Action encoding type # Returns: # Error tensor (batch_size, num_tests) # """ # # Create sparse tokens for all test lengths (vectorized) # batch_size, num_tests = test_lengths.shape # seq_len = indices_flat.shape[1] # device = indices_flat.device # # Create position tensor for all combinations # positions = torch.arange(seq_len, device=device).unsqueeze(0).unsqueeze(0) # (1, 1, seq_len) # positions = positions.expand(batch_size, num_tests, -1) # (batch_size, num_tests, seq_len) # # Create length mask: positions < test_lengths # length_mask = positions < test_lengths.unsqueeze(2) # (batch_size, num_tests, seq_len) # # Create sparse tokens using advanced indexing # sparse_tokens = torch.where( # length_mask, # indices_flat.unsqueeze(1).expand(-1, num_tests, -1), # torch.zeros_like(indices_flat).unsqueeze(1).expand(-1, num_tests, -1), # ) # # Reshape for parallel processing # sparse_flat = sparse_tokens.view(batch_size * num_tests, seq_len) # # Decode all sparse tokens in parallel # reconstructed_flat = self._decode_sparse_tokens(sparse_flat, action_encoding) # # Reshape back and calculate errors # reconstructed = 
    #         reconstructed_flat.view(batch_size, num_tests, *x_tensor.shape[1:])
    #         # Calculate errors
    #         x_expanded = x_tensor.unsqueeze(1).expand(-1, num_tests, -1, -1)
    #         errors = (x_expanded - reconstructed).abs().mean((-1, -2))  # (batch_size, num_tests)
    #         return errors

    # def _decode_sparse_tokens(self, sparse_tokens: torch.Tensor, action_encoding: str | None = None) -> torch.Tensor:
    #     """Decode sparse tokens to reconstructed data."""
    #     batch_size, seq_len = sparse_tokens.shape
    #     # Convert to proper indices format for dequantization
    #     if self.num_quantizers > 1:
    #         seq_len_per_quantizer = seq_len // self.num_quantizers
    #         if seq_len % self.num_quantizers != 0:
    #             raise ValueError("Sequence length must be divisible by num_quantizers")
    #         indices_for_decode = sparse_tokens.view(batch_size, self.num_quantizers, seq_len_per_quantizer).transpose(
    #             1, 2
    #         )  # (batch_size, seq_len_per_quantizer, num_quantizers)
    #     else:
    #         indices_for_decode = sparse_tokens.unsqueeze(-1)  # (batch_size, seq_len, 1)
    #     # Dequantize and decode
    #     z_q = self._dequantize(indices_for_decode)
    #     reconstructed = self._decode(z_q, action_encoding)
    #     return reconstructed

    # def _create_sparse_tokens_from_lengths(
    #     self, indices_flat: torch.Tensor, optimal_lengths: torch.Tensor
    # ) -> torch.Tensor:
    #     """Create sparse tokens based on optimal lengths (vectorized)."""
    #     batch_size, seq_len = indices_flat.shape
    #     device = indices_flat.device
    #     # Create position mask for all batch items simultaneously
    #     positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)  # (batch_size, seq_len)
    #     length_mask = positions < optimal_lengths.unsqueeze(1)  # (batch_size, seq_len)
    #     # Apply mask to create sparse tokens
    #     result = torch.where(length_mask, indices_flat, torch.zeros_like(indices_flat))
    #     return result

    def forward(
        self, x: torch.Tensor, embodiment_ids: List[int] | int | None = None, padding_mask: List[bool] | None = None
    ):
        """Tokenize a batch of action sequences (thin wrapper over `encode`).

        NOTE(review): `x` is annotated `torch.Tensor` while `encode` documents
        an `np.ndarray` input; `torch.tensor(...)` inside `encode` accepts
        both, but the intended input type should be confirmed with callers.

        Returns:
            List[List[int]]: Token sequences, as returned by `encode`.
        """
        return self.encode(x, embodiment_ids, padding_mask)


# Make the codec discoverable through transformers' AutoModel machinery.
AutoModel.register(ActionCodecConfig, ActionCodec)

__all__ = ["ActionCodec"]


if __name__ == "__main__":
    # Smoke test exercising construction, round-trip encode/decode, dynamic
    # embodiment expansion, and a mixed-embodiment batch.
    print("=== ActionCodec Comprehensive Test ===\n")

    # 1. Configuration Setup (RVQ enabled with n_quantizers=4)
    initial_config = {
        "robot_A": {"action_dim": 7, "freq": 10, "duration": 1, "description": "Robot A"},
    }
    # We set n_quantizers=4 to test Residual VQ logic
    config = ActionCodecConfig(
        embodiment_config=initial_config,
        n_tokens=16,  # Total tokens per sequence (latent_len * n_quantizers)
        n_quantizers=4,  # RVQ depth
        vq_type="rvq",
        vq_codebook_size=256,
        encoder_dim=128,
        decoder_dim=128,
    )
    # Expected latent sequence length = n_tokens / n_quantizers = 16 / 4 = 4
    latent_seq_len = int(config.n_tokens // config.n_quantizers)
    print(f"Config: {config.n_quantizers} quantizers, {latent_seq_len} latent vectors per sequence.")
    codec = ActionCodec(config)
    codec.eval()

    # 2. Basic Encode/Decode Test
    print("\n--- Test 1: Basic Encode/Decode ---")
    batch_size = 2
    seq_len_A = 10  # 10Hz * 1s
    # Create random action data for Robot A (ID 0)
    x = np.random.randn(batch_size, seq_len_A, 7).astype(np.float32)
    # Masking: Second item in batch is half padding
    padding_mask = np.ones((batch_size, seq_len_A), dtype=bool)
    padding_mask[1, 5:] = False
    embodiment_ids = [0, 0]
    # Encode
    codes = codec.encode(x, embodiment_ids, padding_mask)
    print(f"Encoded codes shape (list length): {len(codes)} x {len(codes[0])}")
    # Validate code length
    assert len(codes[0]) == config.n_tokens, f"Expected {config.n_tokens} tokens, got {len(codes[0])}"
    # Decode
    x_recon, recon_mask = codec.decode(codes, embodiment_ids)
    print(f"Reconstructed shape: {x_recon.shape}")
    print(f"Recon mask shape: {recon_mask.shape}")
    assert x_recon.shape == (batch_size, seq_len_A, 7)  # Should imply zero-padding to max dim 7

    # 3. Expansion Test
    print("\n--- Test 2: Dynamic Expansion ---")
    new_robot_config = {"robot_B": {"action_dim": 10, "freq": 20, "duration": 1, "description": "Robot B (Larger)"}}
    print("Expanding codec to include Robot B (10 dims, 20Hz)...")
    codec.expand_embodiment(new_robot_config)
    assert codec.encoder.max_action_dim == 10
    assert codec.decoder.max_action_dim == 10
    print("✅ Expansion successful.")

    # 4. Mixed Batch Test (Old + New Robot)
    print("\n--- Test 3: Mixed Batch Inference ---")
    # Batch: [Robot A, Robot B]
    # Robot A: 10Hz, 1s -> 10 steps. Dims 7.
    # Robot B: 20Hz, 1s -> 20 steps. Dims 10.
    # Batch Max Steps: 20. Batch Max Dims: 10.
    batch_x_mixed = np.zeros((2, 20, 10), dtype=np.float32)
    # Fill Robot A data (index 0)
    data_A = np.random.randn(10, 7)
    batch_x_mixed[0, :10, :7] = data_A
    # Fill Robot B data (index 1)
    data_B = np.random.randn(20, 10)
    batch_x_mixed[1, :20, :10] = data_B
    # Embodiment IDs: 0 for A, 1 for B
    # Note: expand_embodiment appends. Original was 0, new is 1.
    mixed_ids = [0, 1]
    # Encode Mask
    mixed_mask = np.zeros((2, 20), dtype=bool)
    mixed_mask[0, :10] = True
    mixed_mask[1, :20] = True
    print("Encoding mixed batch...")
    mixed_codes = codec.encode(batch_x_mixed, mixed_ids, mixed_mask)
    print("Decoding mixed batch...")
    # Explicit durations (optional, but good for verification if we wanted to override defaults)
    durations = [1, 1]
    x_recon_mixed, dec_mask_mixed = codec.decode(mixed_codes, mixed_ids, durations)
    print(f"Mixed Recon Shape: {x_recon_mixed.shape}")
    # Validation
    # Robot A output check (mask should be True for first 10, False for rest)
    valid_A = dec_mask_mixed[0].sum()
    valid_B = dec_mask_mixed[1].sum()
    print(f"Valid steps detected by Decoder: Robot A={valid_A}, Robot B={valid_B}")
    assert valid_A == 10
    assert valid_B == 20
    # Check dimensionality preservation
    # Robot A's reconstruction in dims 7-9 should be noise or zero (depending on implementation),
    # but dims 0-6 should contain signal.
    print("✅ Mixed batch processed successfully.")
    print("\n✨ All systems go.")