|
|
from typing import List |
|
|
|
|
|
import einops |
|
|
import numpy as np |
|
|
import torch |
|
|
from transformers import AutoModel, PreTrainedModel |
|
|
from vector_quantize_pytorch import VectorQuantize |
|
|
|
|
|
from .configuration_actioncodec import ActionCodecConfig |
|
|
from .modular_actioncodec import PerceiverDecoder, PerceiverEncoder |
|
|
from .rvq import ResidualVectorQuantize |
|
|
|
|
|
|
|
|
def trim_trailing_zeros(arr: np.ndarray) -> list[list]:
    """Strip trailing zeros from each row of a 2-D array.

    Args:
        arr (np.ndarray): Array of shape (b, n). Rows are assumed to be
            zero-padded on the right.

    Returns:
        list[list]: One Python list per row, with trailing zeros removed.
            A row that is entirely zero yields an empty list. An empty
            batch (b == 0) yields an empty outer list.
    """
    if arr.shape[0] == 0:
        return []

    n = arr.shape[1]

    is_nonzero = arr != 0
    # argmax on the reversed mask finds the first non-zero element from the
    # right, i.e. the last non-zero element of the original row.
    flipped_mask = np.flip(is_nonzero, axis=1)
    last_nonzero_indices = n - 1 - np.argmax(flipped_mask, axis=1)
    # All-zero rows would otherwise report index n-1 (argmax of an all-False
    # mask is 0); multiplying by the any-mask forces their length to 0.
    any_nonzero_in_row = is_nonzero.any(axis=1)
    new_lengths = (last_nonzero_indices + 1) * any_nonzero_in_row

    return [arr[i, :length].tolist() for i, length in enumerate(new_lengths)]
|
|
|
|
|
|
|
|
class ActionCodec(PreTrainedModel):
    """Codec that maps continuous robot action sequences to discrete tokens and back.

    Pipeline: ``PerceiverEncoder`` -> vector quantization (single-codebook VQ or
    residual VQ, selected by ``config.vq_type``) -> ``PerceiverDecoder``.
    The public `encode`/`decode` methods operate on NumPy/Python inputs; the
    underscore-prefixed methods operate on tensors.
    """

    config_class = ActionCodecConfig

    def __init__(self, config: ActionCodecConfig):
        super().__init__(config)
        # Fallback embodiment used whenever callers pass embodiment_ids=None.
        self.default_embodiment_id = 0

        self.encoder = PerceiverEncoder(config)
        self.decoder = PerceiverDecoder(config)

        if config.vq_type == "vq":
            assert config.n_quantizers == 1, "Only one quantizer is supported for VQ"
            self.vq = VectorQuantize(
                dim=config.z_dim,
                codebook_size=config.vq_codebook_size,
                commitment_weight=config.vq_commitment_weight,
                decay=config.vq_decay,
                kmeans_init=config.vq_kmeans_init,
                threshold_ema_dead_code=config.vq_threshold_ema_dead_code,
                rotation_trick=False,
                straight_through=True,
            )
        elif config.vq_type == "rvq":
            assert config.n_quantizers > 1, "At least two quantizers are supported for RVQ"
            self.vq = ResidualVectorQuantize(
                dim=config.z_dim,
                n_codebooks=config.n_quantizers,
                codebook_size=config.vq_codebook_size,
                codebook_dim=config.z_dim,
                quantizer_dropout=config.vq_quantizer_dropout,
                commitment=config.vq_commitment_weight,
            )
        else:
            raise NotImplementedError(f"VQ type {config.vq_type} not implemented")

        self.vocab_size = config.vq_codebook_size
        self.num_quantizers = config.n_quantizers
        # Latent tokens produced per quantizer level; n_tokens is split evenly
        # across quantizers.
        self.n_tokens_per_quantizer = config.n_tokens // config.n_quantizers

    def expand_embodiment(self, embodiment_config: dict):
        """
        Delegates expansion to the underlying Encoder and Decoder.
        This allows the Codec to adapt to new robots dynamically.

        Args:
            embodiment_config (dict): Mapping of new embodiment names to their
                specs; merged into ``self.config.embodiment_config``.

        Returns:
            ActionCodec: self, to allow chained calls.
        """
        self.encoder.expand_embodiment(embodiment_config)
        self.decoder.expand_embodiment(embodiment_config)
        self.config.embodiment_config.update(embodiment_config)
        return self

    def _encode(
        self,
        x: torch.Tensor,
        embodiment_ids: torch.Tensor | int | None = None,
        padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode action sequences into latent representations.

        Args:
            x (torch.Tensor): Action sequences to encode. Shape: (b, seq_len, max_action_dim).
                Assumes that the action dimension is zero-padded to the max action dimension.
                `seq_len` is supposed to be `int(duration * freq)` for each embodiment and padded to the max sequence length.
            embodiment_ids (torch.Tensor | int): Embodiment IDs. Shape: (b,).
                If int, the same embodiment ID is repeated for all sequences in the batch.
                It specifies the embodiment to encode.
            padding_mask (Optional[torch.Tensor], optional): Padding mask, where `False` values indicate padding. Shape: (b, seq_len). Defaults to None.
                It is used to mask the padding tokens on `seq_len` dimension.

        Returns:
            torch.Tensor: Encoded latent representations. Shape: (b, n_tokens_per_quantizer, z_dim).
        """
        embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
        z_e = self.encoder(x, embodiment_ids, padding_mask)
        return z_e

    def _quantize(self, z_e: torch.Tensor, return_perplexity: bool = True) -> List[torch.Tensor]:
        """Quantize continuous latents with the configured VQ/RVQ module.

        Args:
            z_e (torch.Tensor): Continuous latents. Shape: (b, n_tokens_per_quantizer, z_dim).
            return_perplexity (bool): If True, also compute per-codebook perplexity
                from the codebook usage statistics.

        Returns:
            Tuple of (z_q, indices, perplexity, commit_loss) where `perplexity`
            is a list of floats (one per codebook) or 0 when disabled.
        """
        if isinstance(self.vq, ResidualVectorQuantize):
            z_q, indices, _, commitment_loss, codebook_loss = self.vq(z_e)
            # Fold both RVQ loss terms into a single scalar commit loss.
            commit_loss = commitment_loss.mean() + codebook_loss.mean()
        elif isinstance(self.vq, VectorQuantize):
            z_q, indices, commit_loss = self.vq(z_e)
        else:
            raise NotImplementedError(f"VQ type {type(self.vq)} not implemented")

        if return_perplexity:
            # Single-quantizer indices are (b, n); add a codebook axis so the
            # loop below handles VQ and RVQ uniformly.
            if len(indices.size()) < 3:
                indices = indices.unsqueeze(-1)
            perplexity = []
            for k in range(indices.size(-1)):
                this_indices = indices[:, :, k]
                indices_count = torch.bincount(this_indices.view(-1), minlength=self.vq.codebook_size)
                # Aggregate codebook usage across ranks so perplexity reflects
                # the global batch under distributed training.
                if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
                    torch.distributed.all_reduce(indices_count)
                this_avg_probs = indices_count.float() / indices_count.sum()
                # exp of the entropy of the usage distribution; epsilon guards log(0).
                perplexity.append(((-(this_avg_probs * torch.log(this_avg_probs + 1e-10)).sum()).exp().item()))
        else:
            perplexity = 0

        return z_q, indices, perplexity, commit_loss

    def _dequantize(self, indices: torch.Tensor) -> torch.Tensor:
        """Map codebook indices back to continuous latents.

        Args:
            indices (torch.Tensor): Code indices. Shape: (b, n_tokens_per_quantizer, n_quantizers),
                or (b, n_tokens_per_quantizer) for a single quantizer.

        Returns:
            torch.Tensor: Dequantized latents z_q.
        """
        if self.num_quantizers == 1:
            # Single-quantizer VQ expects 2-D indices; drop the trailing axis.
            if len(indices.size()) == 3:
                indices = indices.squeeze(-1)
        if isinstance(self.vq, ResidualVectorQuantize):
            z_q = self.vq.from_codes(indices)[0]
        else:
            z_q = self.vq.get_output_from_indices(indices)
        return z_q

    def _decode(
        self, z_q: torch.Tensor, embodiment_ids: torch.Tensor | int | None = None, durations: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Decode quantized latents into action sequences.

        Args:
            z_q (torch.Tensor): Quantized latents. Shape: (b, n_tokens_per_quantizer, z_dim).
            embodiment_ids (torch.Tensor | int | None): Embodiment IDs; defaults
                to `self.default_embodiment_id` when None.
            durations (torch.Tensor | None): Optional per-sequence durations
                forwarded to the decoder.

        Returns:
            Tuple of (x_recon, padding_mask) as produced by the decoder.
        """
        embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
        x_recon, padding_mask = self.decoder(z_q, embodiment_ids, durations)
        return x_recon, padding_mask

    @torch.no_grad()
    def encode(
        self,
        x: np.ndarray,
        embodiment_ids: List[int] | int | None = None,
        padding_mask: List[bool] | None = None,
    ) -> List[List[int]]:
        """Encode action sequences into latent representations.

        Args:
            x (np.ndarray): Action sequences to encode. Shape: (b, seq_len, max_action_dim).
                Assumes that the action dimension is zero-padded to the max action dimension.
                `seq_len` is supposed to be `int(duration * freq)` for each embodiment and padded to the max sequence length.
            embodiment_ids (List[int] | int): Embodiment IDs. Shape: (b,).
                If int, the same embodiment ID is repeated for all sequences in the batch.
                It specifies the embodiment to encode.
            padding_mask (List[bool] | None): Padding mask, where `False` values indicate padding. Shape: (b, seq_len). Defaults to None.
                It is used to mask the padding tokens on `seq_len` dimension.

        Returns:
            List[List[int]]: List of token sequences. Shape: (b, n_tokens).
        """
        self.eval()
        embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id

        # NOTE(review): the decorator already disables grad; the inner
        # torch.no_grad() is redundant but harmless.
        with torch.no_grad():
            x_tensor = torch.tensor(x, dtype=self.dtype, device=self.device)
            if not isinstance(embodiment_ids, int):
                embodiment_ids = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device)
            if padding_mask is not None:
                padding_mask = torch.tensor(padding_mask, dtype=torch.bool, device=self.device)

            z_e = self._encode(x_tensor, embodiment_ids, padding_mask)
            _, indices, _, _ = self._quantize(z_e, return_perplexity=False)
            # RVQ indices are (b, n_tokens_per_quantizer, n_quantizers); flatten
            # quantizer-major so decode's rearrange can invert the layout.
            if len(indices.size()) > 2:
                codes_list = einops.rearrange(indices, "b n s -> b (s n)").cpu()
            else:
                codes_list = indices.cpu()
            codes_list = codes_list.tolist()
            return codes_list

    @torch.no_grad()
    def decode(
        self, tokens: List[List[int]], embodiment_ids: List[int] | int | None = None, durations: List[float] | None = None
    ) -> tuple[np.ndarray, np.ndarray]:
        """Decode token sequences back into action sequences.

        Args:
            tokens (List[List[int]]): Token sequences. Shape: (b, n_tokens),
                in the quantizer-major layout produced by `encode`.
            embodiment_ids (List[int] | int | None): Embodiment IDs; defaults
                to `self.default_embodiment_id` when None.
            durations (List[float] | None): Optional per-sequence durations.

        Returns:
            Tuple of (x_recon, padding_mask) as NumPy arrays.
        """
        self.eval()
        embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
        tokens = torch.tensor(tokens, dtype=torch.long, device=self.device)
        if not isinstance(embodiment_ids, int):
            embodiment_ids = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device)
        if durations is not None:
            durations = torch.tensor(durations, dtype=torch.float32, device=self.device)

        b, n = tokens.shape
        assert n % self.n_tokens_per_quantizer == 0, (
            f"Expected {self.n_tokens_per_quantizer} tokens per quantizer, got {n} in total."
        )
        # Inverse of encode's "b n s -> b (s n)": recover (b, tokens, quantizers).
        indices = einops.rearrange(tokens, "b (n m) -> b m n", m=self.n_tokens_per_quantizer)
        z_q = self._dequantize(indices)
        x_recon, padding_mask = self._decode(z_q, embodiment_ids, durations)
        return x_recon.cpu().numpy(), padding_mask.cpu().numpy()

    def forward(self, x: torch.Tensor, embodiment_ids: int | None = None, padding_mask: List[bool] | None = None):
        """Alias for `encode`, making the module directly callable.

        NOTE(review): `encode` documents `x` as np.ndarray while this signature
        annotates torch.Tensor — confirm which callers rely on.
        """
        return self.encode(x, embodiment_ids, padding_mask)
|
|
|
|
|
|
|
|
# Make the codec loadable through transformers' AutoModel machinery.
AutoModel.register(ActionCodecConfig, ActionCodec)

__all__ = ["ActionCodec"]
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: exercises encode/decode, dynamic embodiment expansion, and
    # mixed-embodiment batches end to end on a small random config.
    print("=== ActionCodec Comprehensive Test ===\n")

    initial_config = {
        "robot_A": {"action_dim": 7, "freq": 10, "duration": 1, "description": "Robot A"},
    }

    config = ActionCodecConfig(
        embodiment_config=initial_config,
        n_tokens=16,
        n_quantizers=4,
        vq_type="rvq",
        vq_codebook_size=256,
        encoder_dim=128,
        decoder_dim=128,
    )

    latent_seq_len = int(config.n_tokens // config.n_quantizers)
    print(f"Config: {config.n_quantizers} quantizers, {latent_seq_len} latent vectors per sequence.")

    codec = ActionCodec(config)
    codec.eval()

    # --- Test 1: round-trip a padded batch for a single embodiment ---
    print("\n--- Test 1: Basic Encode/Decode ---")
    batch_size = 2
    seq_len_A = 10

    x = np.random.randn(batch_size, seq_len_A, 7).astype(np.float32)

    # Second sequence is only 5 steps long; the rest is padding.
    padding_mask = np.ones((batch_size, seq_len_A), dtype=bool)
    padding_mask[1, 5:] = False

    embodiment_ids = [0, 0]

    codes = codec.encode(x, embodiment_ids, padding_mask)
    print(f"Encoded codes shape (list length): {len(codes)} x {len(codes[0])}")

    assert len(codes[0]) == config.n_tokens, f"Expected {config.n_tokens} tokens, got {len(codes[0])}"

    x_recon, recon_mask = codec.decode(codes, embodiment_ids)
    print(f"Reconstructed shape: {x_recon.shape}")
    print(f"Recon mask shape: {recon_mask.shape}")

    assert x_recon.shape == (batch_size, seq_len_A, 7)

    # --- Test 2: grow the codec to a new, larger robot at runtime ---
    print("\n--- Test 2: Dynamic Expansion ---")
    new_robot_config = {"robot_B": {"action_dim": 10, "freq": 20, "duration": 1, "description": "Robot B (Larger)"}}

    print("Expanding codec to include Robot B (10 dims, 20Hz)...")
    codec.expand_embodiment(new_robot_config)

    assert codec.encoder.max_action_dim == 10
    assert codec.decoder.max_action_dim == 10
    print("✅ Expansion successful.")

    # --- Test 3: one batch containing both embodiments, zero-padded to the
    # larger robot's sequence length and action dimension ---
    print("\n--- Test 3: Mixed Batch Inference ---")

    batch_x_mixed = np.zeros((2, 20, 10), dtype=np.float32)

    # Robot A: 10 steps, 7 action dims.
    data_A = np.random.randn(10, 7)
    batch_x_mixed[0, :10, :7] = data_A

    # Robot B: 20 steps, 10 action dims.
    data_B = np.random.randn(20, 10)
    batch_x_mixed[1, :20, :10] = data_B

    mixed_ids = [0, 1]

    mixed_mask = np.zeros((2, 20), dtype=bool)
    mixed_mask[0, :10] = True
    mixed_mask[1, :20] = True

    print("Encoding mixed batch...")
    mixed_codes = codec.encode(batch_x_mixed, mixed_ids, mixed_mask)

    print("Decoding mixed batch...")

    durations = [1, 1]
    x_recon_mixed, dec_mask_mixed = codec.decode(mixed_codes, mixed_ids, durations)

    print(f"Mixed Recon Shape: {x_recon_mixed.shape}")

    # The decoder's padding mask should recover each robot's true length.
    valid_A = dec_mask_mixed[0].sum()
    valid_B = dec_mask_mixed[1].sum()

    print(f"Valid steps detected by Decoder: Robot A={valid_A}, Robot B={valid_B}")

    assert valid_A == 10
    assert valid_B == 20

    print("✅ Mixed batch processed successfully.")

    print("\n✨ All systems go.")
|
|
|