# ActionCodec-Base-RVQft / modeling_actioncodec.py
# Uploaded to the Hugging Face Hub by ZibinDong via huggingface_hub
# (commit cc2596c, verified; file size 22.8 kB).
from typing import List
import einops
import numpy as np
import torch
from transformers import AutoModel, PreTrainedModel
from vector_quantize_pytorch import VectorQuantize
from .configuration_actioncodec import ActionCodecConfig
from .modular_actioncodec import PerceiverDecoder, PerceiverEncoder
from .rvq import ResidualVectorQuantize
def trim_trailing_zeros(arr: np.ndarray) -> list[list]:
    """Strip trailing zeros from every row of a 2-D array.

    Args:
        arr: Array of shape (b, n); each row is a (possibly zero-padded)
            token sequence.

    Returns:
        A list of b Python lists (note: plain lists, not arrays — rows are
        converted via ``.tolist()``). Row i is truncated just after its last
        non-zero entry; an all-zero row yields ``[]``, and an empty batch
        (b == 0) yields ``[]``.
    """
    if arr.shape[0] == 0:
        return []
    n = arr.shape[1]
    is_nonzero = arr != 0
    # argmax on the reversed mask finds the first True from the right,
    # i.e. the last non-zero position of each row.
    flipped_mask = np.flip(is_nonzero, axis=1)
    last_nonzero_indices = n - 1 - np.argmax(flipped_mask, axis=1)
    # argmax returns 0 for all-False rows, which would wrongly map to length
    # n; multiplying by the any-nonzero mask forces those rows to length 0.
    any_nonzero_in_row = is_nonzero.any(axis=1)
    new_lengths = (last_nonzero_indices + 1) * any_nonzero_in_row
    return [arr[i, :length].tolist() for i, length in enumerate(new_lengths)]
class ActionCodec(PreTrainedModel):
    """Perceiver-based action tokenizer.

    Encodes variable-length, multi-embodiment robot action sequences into
    discrete token ids via vector quantization (plain VQ or residual VQ), and
    decodes token ids back into continuous action sequences.
    """

    # Lets `transformers` pair this model class with its config class.
    config_class = ActionCodecConfig

    def __init__(self, config: ActionCodecConfig):
        """Build encoder, decoder, and the configured quantizer.

        Args:
            config: Model configuration; `config.vq_type` selects between a
                single `VectorQuantize` ("vq", requires n_quantizers == 1)
                and a `ResidualVectorQuantize` ("rvq", requires
                n_quantizers > 1).

        Raises:
            NotImplementedError: If `config.vq_type` is neither "vq" nor "rvq".
        """
        super().__init__(config)
        # Fallback embodiment when callers do not pass embodiment ids.
        self.default_embodiment_id = 0
        self.encoder = PerceiverEncoder(config)
        self.decoder = PerceiverDecoder(config)
        if config.vq_type == "vq":
            assert config.n_quantizers == 1, "Only one quantizer is supported for VQ"
            self.vq = VectorQuantize(
                dim=config.z_dim,
                codebook_size=config.vq_codebook_size,
                commitment_weight=config.vq_commitment_weight,
                decay=config.vq_decay,
                kmeans_init=config.vq_kmeans_init,
                threshold_ema_dead_code=config.vq_threshold_ema_dead_code,
                rotation_trick=False,
                straight_through=True,
            )
        elif config.vq_type == "rvq":
            assert config.n_quantizers > 1, "At least two quantizers are supported for RVQ"
            self.vq = ResidualVectorQuantize(
                dim=config.z_dim,
                n_codebooks=config.n_quantizers,
                codebook_size=config.vq_codebook_size,
                codebook_dim=config.z_dim,
                quantizer_dropout=config.vq_quantizer_dropout,
                commitment=config.vq_commitment_weight,
            )
        else:
            raise NotImplementedError(f"VQ type {config.vq_type} not implemented")
        self.vocab_size = config.vq_codebook_size
        self.num_quantizers = config.n_quantizers
        # Latent sequence length per quantizer; total tokens emitted by
        # `encode` is n_tokens_per_quantizer * num_quantizers.
        self.n_tokens_per_quantizer = config.n_tokens // config.n_quantizers

    def expand_embodiment(self, embodiment_config: dict):
        """
        Delegates expansion to the underlying Encoder and Decoder.
        This allows the Codec to adapt to new robots dynamically.
        """
        self.encoder.expand_embodiment(embodiment_config)
        self.decoder.expand_embodiment(embodiment_config)
        # Keep the stored config in sync so save/load round-trips the new
        # embodiment entries.
        self.config.embodiment_config.update(embodiment_config)
        return self

    def _encode(
        self,
        x: torch.Tensor,
        embodiment_ids: torch.Tensor | int | None = None,
        padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode action sequences into latent representations.

        Args:
            x (torch.Tensor): Action sequences to encode. Shape: (b, seq_len, max_action_dim).
                Assumes that the action dimension is zero-padded to the max action dimension.
                `seq_len` is supposed to be `int(duration * freq)` for each embodiment and padded to the max sequence length.
            embodiment_ids (torch.Tensor | int): Embodiment IDs. Shape: (b,).
                If int, the same embodiment ID is repeated for all sequences in the batch.
                It specifies the embodiment to encode.
            padding_mask (Optional[torch.Tensor], optional): Padding mask, where `False` values indicate padding. Shape: (b, seq_len). Defaults to None.
                It is used to mask the padding tokens on `seq_len` dimension.

        Returns:
            torch.Tensor: Encoded latent representations. Shape: (b, n_tokens_per_quantizer, z_dim).
        """
        embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
        z_e = self.encoder(x, embodiment_ids, padding_mask)
        return z_e

    def _quantize(
        self, z_e: torch.Tensor, return_perplexity: bool = True
    ) -> tuple[torch.Tensor, torch.Tensor, list[float] | int, torch.Tensor]:
        """Quantize encoder latents with the configured VQ/RVQ module.

        Args:
            z_e: Encoder latents of shape (b, n_tokens_per_quantizer, z_dim).
            return_perplexity: If True, also compute per-quantizer codebook
                perplexity; otherwise perplexity is returned as the int 0.

        Returns:
            Tuple of (z_q, indices, perplexity, commit_loss).

        Raises:
            NotImplementedError: If `self.vq` is an unsupported quantizer type.
        """
        if isinstance(self.vq, ResidualVectorQuantize):
            z_q, indices, _, commitment_loss, codebook_loss = self.vq(z_e)
            # Fold both RVQ loss terms into one scalar for the caller.
            commit_loss = commitment_loss.mean() + codebook_loss.mean()
        elif isinstance(self.vq, VectorQuantize):
            z_q, indices, commit_loss = self.vq(z_e)
        else:
            raise NotImplementedError(f"VQ type {type(self.vq)} not implemented")
        if return_perplexity:
            # Add a trailing quantizer axis so single-VQ (b, n) indices are
            # handled by the same loop as RVQ (b, n, q) indices.
            if len(indices.size()) < 3:
                indices = indices.unsqueeze(-1)
            perplexity = []
            for k in range(indices.size(-1)):
                this_indices = indices[:, :, k]
                indices_count = torch.bincount(this_indices.view(-1), minlength=self.vq.codebook_size)
                # Aggregate counts across ranks so perplexity reflects the
                # global batch under distributed training.
                if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
                    torch.distributed.all_reduce(indices_count)
                this_avg_probs = indices_count.float() / indices_count.sum()
                # exp(entropy) of the empirical code usage; 1e-10 guards log(0).
                perplexity.append(((-(this_avg_probs * torch.log(this_avg_probs + 1e-10)).sum()).exp().item()))
        else:
            perplexity = 0
        return z_q, indices, perplexity, commit_loss

    def _dequantize(self, indices: torch.Tensor) -> torch.Tensor:
        """Map code indices back to quantized latent vectors.

        Args:
            indices: Code indices of shape (b, n) or (b, n, n_quantizers).

        Returns:
            Quantized latents z_q of shape (b, n, z_dim).
        """
        if self.num_quantizers == 1:
            # Plain VQ lookup expects (b, n); drop a trailing singleton
            # quantizer axis if present.
            if len(indices.size()) == 3:
                indices = indices.squeeze(-1)
        if isinstance(self.vq, ResidualVectorQuantize):
            z_q = self.vq.from_codes(indices)[0]
        else:
            z_q = self.vq.get_output_from_indices(indices)
        return z_q

    def _decode(
        self, z_q: torch.Tensor, embodiment_ids: torch.Tensor | int | None = None, durations: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Decode quantized latents into action sequences.

        Args:
            z_q: Quantized latents of shape (b, n_tokens_per_quantizer, z_dim).
            embodiment_ids: Per-sequence embodiment ids, or a single int
                broadcast over the batch. Defaults to `default_embodiment_id`.
            durations: Optional per-sequence durations passed through to the
                decoder; None lets the decoder use embodiment defaults.

        Returns:
            Tuple of (x_recon, padding_mask) as produced by the decoder.
        """
        embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
        x_recon, padding_mask = self.decoder(z_q, embodiment_ids, durations)
        return x_recon, padding_mask

    @torch.no_grad()
    def encode(
        self,
        x: np.ndarray,
        embodiment_ids: List[int] | int | None = None,
        padding_mask: List[bool] | None = None,
    ) -> List[List[int]]:
        """Encode action sequences into latent representations.

        Args:
            x (np.ndarray): Action sequences to encode. Shape: (b, seq_len, max_action_dim).
                Assumes that the action dimension is zero-padded to the max action dimension.
                `seq_len` is supposed to be `int(duration * freq)` for each embodiment and padded to the max sequence length.
            embodiment_ids (List[int] | int): Embodiment IDs. Shape: (b,).
                If int, the same embodiment ID is repeated for all sequences in the batch.
                It specifies the embodiment to encode.
            padding_mask (List[bool] | None): Padding mask, where `False` values indicate padding. Shape: (b, seq_len). Defaults to None.
                It is used to mask the padding tokens on `seq_len` dimension.

        Returns:
            List[List[int]]: List of token sequences. Shape: (b, n_tokens).
        """
        self.eval()
        embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
        with torch.no_grad():
            # Move host-side inputs onto the model's device/dtype.
            x_tensor = torch.tensor(x, dtype=self.dtype, device=self.device)
            if not isinstance(embodiment_ids, int):
                embodiment_ids = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device)
            if padding_mask is not None:
                padding_mask = torch.tensor(padding_mask, dtype=torch.bool, device=self.device)
            z_e = self._encode(x_tensor, embodiment_ids, padding_mask)
            _, indices, _, _ = self._quantize(z_e, return_perplexity=False)
            if len(indices.size()) > 2:
                # RVQ indices (b, n, q): flatten quantizer-major so the token
                # stream is the concatenation of the per-quantizer sequences.
                codes_list = einops.rearrange(indices, "b n s -> b (s n)").cpu()
            else:
                codes_list = indices.cpu()
            codes_list = codes_list.tolist()
        return codes_list

    @torch.no_grad()
    def decode(
        self, tokens: List[List[int]], embodiment_ids: List[int] | int | None = None, durations: List[float] | None = None
    ) -> tuple[np.ndarray, np.ndarray]:
        """Decode token sequences back into action arrays.

        Args:
            tokens: Token ids of shape (b, n_tokens), as produced by `encode`.
            embodiment_ids: Per-sequence embodiment ids, or a single int
                broadcast over the batch. Defaults to `default_embodiment_id`.
            durations: Optional per-sequence durations in the decoder's units;
                None uses embodiment defaults.

        Returns:
            Tuple of numpy arrays (x_recon, padding_mask).
        """
        self.eval()
        embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
        tokens = torch.tensor(tokens, dtype=torch.long, device=self.device)
        if not isinstance(embodiment_ids, int):
            embodiment_ids = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device)
        if durations is not None:
            durations = torch.tensor(durations, dtype=torch.float32, device=self.device)
        b, n = tokens.shape
        assert n % self.n_tokens_per_quantizer == 0, (
            f"Expected {self.n_tokens_per_quantizer} tokens per quantizer, got {n} in total."
        )
        # Inverse of the "b n s -> b (s n)" flattening done in `encode`:
        # split the token stream back into (b, n_tokens_per_quantizer, n_quantizers).
        indices = einops.rearrange(tokens, "b (n m) -> b m n", m=self.n_tokens_per_quantizer)
        z_q = self._dequantize(indices)
        x_recon, padding_mask = self._decode(z_q, embodiment_ids, durations)
        return x_recon.cpu().numpy(), padding_mask.cpu().numpy()

    # NOTE: A draft `sparse_encode` feature (adaptive token-length selection by
    # quaternary search over reconstruction error, with helpers
    # `_quaternary_search`, `_calculate_errors_for_lengths`,
    # `_decode_sparse_tokens`, and `_create_sparse_tokens_from_lengths`)
    # previously lived here as commented-out code; see repository history for
    # the full draft.

    def forward(self, x: torch.Tensor, embodiment_ids: int | None = None, padding_mask: List[bool] | None = None):
        # Calling the module directly is an alias for `encode`.
        return self.encode(x, embodiment_ids, padding_mask)
# Register with transformers' AutoModel so that
# `AutoModel.from_pretrained(...)` resolves an ActionCodecConfig to ActionCodec.
AutoModel.register(ActionCodecConfig, ActionCodec)

# Public API of this module.
__all__ = ["ActionCodec"]
if __name__ == "__main__":
    # Smoke test: exercises encode/decode, dynamic embodiment expansion, and
    # mixed-embodiment batches end to end.
    print("=== ActionCodec Comprehensive Test ===\n")
    # 1. Configuration Setup (RVQ enabled with n_quantizers=4)
    initial_config = {
        "robot_A": {"action_dim": 7, "freq": 10, "duration": 1, "description": "Robot A"},
    }
    # We set n_quantizers=4 to test Residual VQ logic
    config = ActionCodecConfig(
        embodiment_config=initial_config,
        n_tokens=16,  # Total tokens per sequence (latent_len * n_quantizers)
        n_quantizers=4,  # RVQ depth
        vq_type="rvq",
        vq_codebook_size=256,
        encoder_dim=128,
        decoder_dim=128,
    )
    # Expected latent sequence length = n_tokens / n_quantizers = 16 / 4 = 4
    latent_seq_len = int(config.n_tokens // config.n_quantizers)
    print(f"Config: {config.n_quantizers} quantizers, {latent_seq_len} latent vectors per sequence.")
    codec = ActionCodec(config)
    codec.eval()
    # 2. Basic Encode/Decode Test
    print("\n--- Test 1: Basic Encode/Decode ---")
    batch_size = 2
    seq_len_A = 10  # 10Hz * 1s
    # Create random action data for Robot A (ID 0)
    x = np.random.randn(batch_size, seq_len_A, 7).astype(np.float32)
    # Masking: Second item in batch is half padding
    padding_mask = np.ones((batch_size, seq_len_A), dtype=bool)
    padding_mask[1, 5:] = False
    embodiment_ids = [0, 0]
    # Encode
    codes = codec.encode(x, embodiment_ids, padding_mask)
    print(f"Encoded codes shape (list length): {len(codes)} x {len(codes[0])}")
    # Validate code length
    assert len(codes[0]) == config.n_tokens, f"Expected {config.n_tokens} tokens, got {len(codes[0])}"
    # Decode
    x_recon, recon_mask = codec.decode(codes, embodiment_ids)
    print(f"Reconstructed shape: {x_recon.shape}")
    print(f"Recon mask shape: {recon_mask.shape}")
    assert x_recon.shape == (batch_size, seq_len_A, 7)  # Should imply zero-padding to max dim 7
    # 3. Expansion Test
    print("\n--- Test 2: Dynamic Expansion ---")
    new_robot_config = {"robot_B": {"action_dim": 10, "freq": 20, "duration": 1, "description": "Robot B (Larger)"}}
    print("Expanding codec to include Robot B (10 dims, 20Hz)...")
    codec.expand_embodiment(new_robot_config)
    assert codec.encoder.max_action_dim == 10
    assert codec.decoder.max_action_dim == 10
    print("✅ Expansion successful.")
    # 4. Mixed Batch Test (Old + New Robot)
    print("\n--- Test 3: Mixed Batch Inference ---")
    # Batch: [Robot A, Robot B]
    # Robot A: 10Hz, 1s -> 10 steps. Dims 7.
    # Robot B: 20Hz, 1s -> 20 steps. Dims 10.
    # Batch Max Steps: 20. Batch Max Dims: 10.
    batch_x_mixed = np.zeros((2, 20, 10), dtype=np.float32)
    # Fill Robot A data (index 0)
    data_A = np.random.randn(10, 7)
    batch_x_mixed[0, :10, :7] = data_A
    # Fill Robot B data (index 1)
    data_B = np.random.randn(20, 10)
    batch_x_mixed[1, :20, :10] = data_B
    # Embodiment IDs: 0 for A, 1 for B
    # Note: expand_embodiment appends. Original was 0, new is 1.
    mixed_ids = [0, 1]
    # Encode Mask
    mixed_mask = np.zeros((2, 20), dtype=bool)
    mixed_mask[0, :10] = True
    mixed_mask[1, :20] = True
    print("Encoding mixed batch...")
    mixed_codes = codec.encode(batch_x_mixed, mixed_ids, mixed_mask)
    print("Decoding mixed batch...")
    # Explicit durations (optional, but good for verification if we wanted to override defaults)
    durations = [1, 1]
    x_recon_mixed, dec_mask_mixed = codec.decode(mixed_codes, mixed_ids, durations)
    print(f"Mixed Recon Shape: {x_recon_mixed.shape}")
    # Validation
    # Robot A output check (mask should be True for first 10, False for rest)
    valid_A = dec_mask_mixed[0].sum()
    valid_B = dec_mask_mixed[1].sum()
    print(f"Valid steps detected by Decoder: Robot A={valid_A}, Robot B={valid_B}")
    assert valid_A == 10
    assert valid_B == 20
    # Check dimensionality preservation
    # Robot A's reconstruction in dims 7-9 should be noise or zero (depending on implementation),
    # but dims 0-6 should contain signal.
    print("✅ Mixed batch processed successfully.")
    print("\n✨ All systems go.")