File size: 22,813 Bytes
cc2596c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 |
from typing import List
import einops
import numpy as np
import torch
from transformers import AutoModel, PreTrainedModel
from vector_quantize_pytorch import VectorQuantize
from .configuration_actioncodec import ActionCodecConfig
from .modular_actioncodec import PerceiverDecoder, PerceiverEncoder
from .rvq import ResidualVectorQuantize
def trim_trailing_zeros(arr: np.ndarray) -> list[list]:
    """Strip the trailing run of zeros from every row of a 2-D array.

    Interior zeros are kept; only zeros after the last non-zero element of a
    row are removed. A row that is entirely zero becomes an empty list.

    Args:
        arr: Array of shape (b, n).

    Returns:
        A list of ``b`` Python lists (built with ``ndarray.tolist``), one per
        row, each truncated right after its last non-zero element. An input
        with no rows returns ``[]``.
    """
    if arr.shape[0] == 0:
        return []
    b, n = arr.shape
    if n == 0:
        # np.argmax raises on an empty axis, so column-less rows are handled here.
        return [[] for _ in range(b)]
    is_nonzero = arr != 0
    # argmax over the column-reversed mask finds the first True from the right,
    # i.e. the position of the last non-zero element.
    last_nonzero = n - 1 - np.argmax(is_nonzero[:, ::-1], axis=1)
    # All-zero rows report argmax 0 (-> last column), so force their length to 0.
    lengths = np.where(is_nonzero.any(axis=1), last_nonzero + 1, 0)
    return [row[:length].tolist() for row, length in zip(arr, lengths)]
class ActionCodec(PreTrainedModel):
    """Tokenizer that maps action sequences to discrete codes and back.

    Actions are encoded by ``PerceiverEncoder`` into latent vectors, quantized
    by ``VectorQuantize`` (``vq_type="vq"``) or ``ResidualVectorQuantize``
    (``vq_type="rvq"``), and decoded by ``PerceiverDecoder``. The public
    ``encode`` / ``decode`` methods accept numpy arrays / Python lists and lay
    the flat token sequence out quantizer by quantizer.
    """

    config_class = ActionCodecConfig

    def __init__(self, config: ActionCodecConfig):
        super().__init__(config)
        # Embodiment used whenever callers pass ``embodiment_ids=None``.
        self.default_embodiment_id = 0
        self.encoder = PerceiverEncoder(config)
        self.decoder = PerceiverDecoder(config)
        if config.vq_type == "vq":
            assert config.n_quantizers == 1, "Only one quantizer is supported for VQ"
            self.vq = VectorQuantize(
                dim=config.z_dim,
                codebook_size=config.vq_codebook_size,
                commitment_weight=config.vq_commitment_weight,
                decay=config.vq_decay,
                kmeans_init=config.vq_kmeans_init,
                threshold_ema_dead_code=config.vq_threshold_ema_dead_code,
                rotation_trick=False,
                straight_through=True,
            )
        elif config.vq_type == "rvq":
            assert config.n_quantizers > 1, "At least two quantizers are supported for RVQ"
            self.vq = ResidualVectorQuantize(
                dim=config.z_dim,
                n_codebooks=config.n_quantizers,
                codebook_size=config.vq_codebook_size,
                codebook_dim=config.z_dim,
                quantizer_dropout=config.vq_quantizer_dropout,
                commitment=config.vq_commitment_weight,
            )
        else:
            raise NotImplementedError(f"VQ type {config.vq_type} not implemented")
        self.vocab_size = config.vq_codebook_size
        self.num_quantizers = config.n_quantizers
        # Number of latent vectors per sequence; each one yields one code per quantizer.
        self.n_tokens_per_quantizer = config.n_tokens // config.n_quantizers

    def expand_embodiment(self, embodiment_config: dict):
        """
        Delegates expansion to the underlying Encoder and Decoder.
        This allows the Codec to adapt to new robots dynamically.

        Args:
            embodiment_config: Mapping of new embodiment names to their settings;
                it is also merged into ``self.config.embodiment_config``.

        Returns:
            ``self``, so calls can be chained.
        """
        self.encoder.expand_embodiment(embodiment_config)
        self.decoder.expand_embodiment(embodiment_config)
        self.config.embodiment_config.update(embodiment_config)
        return self

    def _to_tensor(self, value, dtype: torch.dtype) -> torch.Tensor:
        """Return a fresh copy of ``value`` on this model's device with ``dtype``.

        ``torch.tensor`` on an existing tensor emits a copy-construct warning,
        so tensors are detached and copied with ``Tensor.to(copy=True)``;
        array-likes (numpy arrays, lists, scalars) go through ``torch.tensor``.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().to(device=self.device, dtype=dtype, copy=True)
        return torch.tensor(value, dtype=dtype, device=self.device)

    def _encode(
        self,
        x: torch.Tensor,
        embodiment_ids: torch.Tensor | int | None = None,
        padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode action sequences into latent representations.

        Args:
            x (torch.Tensor): Action sequences to encode. Shape: (b, seq_len, max_action_dim).
                Assumes that the action dimension is zero-padded to the max action dimension.
                `seq_len` is supposed to be `int(duration * freq)` for each embodiment and padded to the max sequence length.
            embodiment_ids (torch.Tensor | int): Embodiment IDs. Shape: (b,).
                If int, the same embodiment ID is repeated for all sequences in the batch.
                It specifies the embodiment to encode.
            padding_mask (Optional[torch.Tensor], optional): Padding mask, where `False` values indicate padding. Shape: (b, seq_len). Defaults to None.
                It is used to mask the padding tokens on `seq_len` dimension.

        Returns:
            torch.Tensor: Encoded latent representations. Shape: (b, n_tokens_per_quantizer, z_dim).
        """
        embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
        z_e = self.encoder(x, embodiment_ids, padding_mask)
        return z_e

    def _quantize(
        self, z_e: torch.Tensor, return_perplexity: bool = True
    ) -> tuple[torch.Tensor, torch.Tensor, list[float] | int, torch.Tensor]:
        """Quantize latents with the configured quantizer.

        Args:
            z_e: Encoder output to quantize.
            return_perplexity: If True, also compute per-codebook usage perplexity
                (all-reduced across ranks when running distributed).

        Returns:
            ``(z_q, indices, perplexity, commit_loss)``. ``perplexity`` is a list with
            one float per codebook, or ``0`` when ``return_perplexity`` is False.
            When perplexity is requested, 2-D ``indices`` are returned with a
            trailing singleton quantizer axis added.

        Raises:
            NotImplementedError: If ``self.vq`` is neither supported quantizer type.
        """
        if isinstance(self.vq, ResidualVectorQuantize):
            z_q, indices, _, commitment_loss, codebook_loss = self.vq(z_e)
            commit_loss = commitment_loss.mean() + codebook_loss.mean()
        elif isinstance(self.vq, VectorQuantize):
            z_q, indices, commit_loss = self.vq(z_e)
        else:
            raise NotImplementedError(f"VQ type {type(self.vq)} not implemented")

        if not return_perplexity:
            return z_q, indices, 0, commit_loss

        if indices.dim() < 3:
            indices = indices.unsqueeze(-1)
        # is_initialized() is only defined when torch was built with distributed support.
        distributed = (
            torch.distributed.is_available()
            and torch.distributed.is_initialized()
            and torch.distributed.get_world_size() > 1
        )
        perplexity = []
        for k in range(indices.size(-1)):
            # reshape (not view): the per-codebook slice is a strided view of ``indices``.
            counts = torch.bincount(indices[:, :, k].reshape(-1), minlength=self.vq.codebook_size)
            if distributed:
                torch.distributed.all_reduce(counts)
            # clamp avoids 0/0 -> NaN when there are no indices at all.
            avg_probs = counts.float() / counts.sum().clamp_min(1)
            entropy = -(avg_probs * torch.log(avg_probs + 1e-10)).sum()
            perplexity.append(entropy.exp().item())
        return z_q, indices, perplexity, commit_loss

    def _dequantize(self, indices: torch.Tensor) -> torch.Tensor:
        """Map code indices back to quantized latents.

        With a single quantizer, a trailing singleton quantizer axis is dropped
        before lookup.
        """
        if self.num_quantizers == 1 and indices.dim() == 3:
            indices = indices.squeeze(-1)
        if isinstance(self.vq, ResidualVectorQuantize):
            # from_codes returns a tuple; the quantized latents are its first element.
            z_q = self.vq.from_codes(indices)[0]
        else:
            z_q = self.vq.get_output_from_indices(indices)
        return z_q

    def _decode(
        self, z_q: torch.Tensor, embodiment_ids: torch.Tensor | int | None = None, durations: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Decode quantized latents into ``(x_recon, padding_mask)`` via the Perceiver decoder."""
        embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
        x_recon, padding_mask = self.decoder(z_q, embodiment_ids, durations)
        return x_recon, padding_mask

    @torch.no_grad()
    def encode(
        self,
        x: np.ndarray,
        embodiment_ids: List[int] | int | None = None,
        padding_mask: List[bool] | None = None,
    ) -> List[List[int]]:
        """Encode action sequences into discrete tokens.

        Puts the model in eval mode and runs without gradients.

        Args:
            x (np.ndarray): Action sequences to encode. Shape: (b, seq_len, max_action_dim).
                Assumes that the action dimension is zero-padded to the max action dimension.
                `seq_len` is supposed to be `int(duration * freq)` for each embodiment and padded to the max sequence length.
                A torch.Tensor is also accepted (it is copied, not shared).
            embodiment_ids (List[int] | int): Embodiment IDs. Shape: (b,).
                If int, the same embodiment ID is repeated for all sequences in the batch.
                It specifies the embodiment to encode.
            padding_mask (List[bool] | None): Padding mask, where `False` values indicate padding. Shape: (b, seq_len). Defaults to None.
                It is used to mask the padding tokens on `seq_len` dimension.

        Returns:
            List[List[int]]: List of token sequences. Shape: (b, n_tokens).
                Multi-quantizer codes are flattened quantizer by quantizer
                (``b n s -> b (s n)``).
        """
        self.eval()
        if embodiment_ids is None:
            embodiment_ids = self.default_embodiment_id
        x_tensor = self._to_tensor(x, self.dtype)
        if not isinstance(embodiment_ids, int):
            embodiment_ids = self._to_tensor(embodiment_ids, torch.long)
        if padding_mask is not None:
            padding_mask = self._to_tensor(padding_mask, torch.bool)
        z_e = self._encode(x_tensor, embodiment_ids, padding_mask)
        _, indices, _, _ = self._quantize(z_e, return_perplexity=False)
        if indices.dim() > 2:
            indices = einops.rearrange(indices, "b n s -> b (s n)")
        return indices.cpu().tolist()

    @torch.no_grad()
    def decode(
        self, tokens: List[List[int]], embodiment_ids: List[int] | int | None = None, durations: List[float] | None = None
    ) -> tuple[np.ndarray, np.ndarray]:
        """Decode flat token sequences (as produced by ``encode``) into actions.

        Puts the model in eval mode and runs without gradients.

        Args:
            tokens: Token sequences, shape (b, n). ``n`` must be a multiple of
                ``n_tokens_per_quantizer``; tokens are read quantizer by quantizer.
            embodiment_ids: Per-sample embodiment IDs, or one int for the whole batch.
                Defaults to ``self.default_embodiment_id``.
            durations: Optional per-sample durations forwarded to the decoder.

        Returns:
            ``(x_recon, padding_mask)`` as numpy arrays from the decoder
            (presumably (b, max_seq_len, max_action_dim) and (b, max_seq_len) — TODO confirm).

        Raises:
            AssertionError: If ``n`` is not a multiple of ``n_tokens_per_quantizer``.
        """
        self.eval()
        if embodiment_ids is None:
            embodiment_ids = self.default_embodiment_id
        tokens = self._to_tensor(tokens, torch.long)
        if not isinstance(embodiment_ids, int):
            embodiment_ids = self._to_tensor(embodiment_ids, torch.long)
        if durations is not None:
            durations = self._to_tensor(durations, torch.float32)
        _, n = tokens.shape
        assert n % self.n_tokens_per_quantizer == 0, (
            f"Expected {self.n_tokens_per_quantizer} tokens per quantizer, got {n} in total."
        )
        # Inverse of the flattening in ``encode``: (b, n_q * m) -> (b, m, n_q).
        indices = einops.rearrange(tokens, "b (n m) -> b m n", m=self.n_tokens_per_quantizer)
        z_q = self._dequantize(indices)
        x_recon, padding_mask = self._decode(z_q, embodiment_ids, durations)
        return x_recon.cpu().numpy(), padding_mask.cpu().numpy()

    # NOTE(review): the sparse-encoding prototype below is commented out and stale:
    # it calls ``self._numpy_to_tensor`` (not defined on this class), passes
    # ``action_encoding`` where ``_encode`` / ``_decode`` now take ``embodiment_ids``,
    # and treats ``_decode``'s result as a single tensor although it now returns
    # ``(x_recon, padding_mask)``. Update those call sites before re-enabling it.
    # def sparse_encode(
    #     self,
    #     x: np.ndarray,
    #     search_num: int = 10,
    #     threshold: float = 0.1,
    #     action_encoding: str | None = None,
    #     remove_padding: bool = True,
    # ) -> List[List[int]]:
    #     """
    #     Sparse encoding with adaptive token selection based on reconstruction error threshold.
    #     Uses quaternary search to find optimal token length.
    #     Args:
    #         x: Input action arrays of shape (b, n, d)
    #         search_num: Maximum number of search iterations
    #         threshold: Reconstruction error threshold
    #         action_encoding: Action encoding type
    #         remove_padding: Whether to remove trailing zeros
    #     Returns:
    #         List of sparse token sequences
    #     """
    #     self.eval()
    #     with torch.no_grad():
    #         x_tensor = self._numpy_to_tensor(x)
    #         # Get initial encoding
    #         z_e = self._encode(x_tensor, action_encoding)
    #         _, indices, _, _ = self._quantize(z_e, return_perplexity=False)
    #         # Convert indices to proper format
    #         if len(indices.size()) > 2:
    #             indices_flat = einops.rearrange(indices, "b n s -> b (s n)")
    #         else:
    #             indices_flat = indices
    #         # Use quaternary search to find optimal token lengths
    #         optimal_lengths = self._quaternary_search(x_tensor, indices_flat, threshold, search_num, action_encoding)
    #         # Create final sparse tokens based on optimal lengths
    #         final_tokens = self._create_sparse_tokens_from_lengths(indices_flat, optimal_lengths)
    #         # Convert to list format
    #         if remove_padding:
    #             final_tokens = trim_trailing_zeros(final_tokens.cpu().numpy())
    #         else:
    #             final_tokens = final_tokens.cpu().tolist()
    #     return final_tokens
    # def _quaternary_search(
    #     self,
    #     x_tensor: torch.Tensor,
    #     indices_flat: torch.Tensor,
    #     threshold: float,
    #     search_num: int,
    #     action_encoding: str | None = None,
    # ) -> torch.Tensor:
    #     """
    #     Quaternary search to find optimal token lengths for each batch item.
    #     Returns tensor of shape (batch_size,) containing optimal lengths.
    #     """
    #     batch_size, seq_len = indices_flat.shape
    #     # Initialize search bounds
    #     device = indices_flat.device
    #     left = torch.ones(batch_size, dtype=torch.long, device=device)
    #     right = torch.full((batch_size,), seq_len, dtype=torch.long, device=device)
    #     # Perform quaternary search
    #     for _ in range(search_num):
    #         # Calculate three division points
    #         range_size = right - left
    #         q1 = left + range_size // 4
    #         q2 = left + range_size // 2
    #         q3 = left + 3 * range_size // 4
    #         # Ensure q1, q2, q3 are within bounds and distinct
    #         q1 = torch.clamp(q1, left, right)
    #         q2 = torch.clamp(q2, q1 + 1, right)
    #         q3 = torch.clamp(q3, q2 + 1, right)
    #         # Create test lengths: [left, q1, q2, q3, right]
    #         test_lengths = torch.stack([left, q1, q2, q3, right], dim=1)  # (batch_size, 5)
    #         # Calculate errors for all test lengths
    #         errors = self._calculate_errors_for_lengths(x_tensor, indices_flat, test_lengths, action_encoding)
    #         # Update search bounds based on results (vectorized)
    #         # Find which lengths meet threshold for each batch item
    #         meets_threshold = errors <= threshold
    #         # For each batch item, find the smallest length that meets threshold
    #         valid_indices = torch.argmax(meets_threshold.float(), dim=1)  # First True index
    #         has_valid = meets_threshold.any(dim=1)  # Whether any length meets threshold
    #         # Create batch indices for advanced indexing
    #         batch_indices = torch.arange(batch_size, device=device)
    #         # Get the smallest valid length for each batch
    #         smallest_valid_lengths = test_lengths[batch_indices, valid_indices]
    #         # Update bounds based on results
    #         # If has valid length, use it; otherwise use longest length
    #         right = torch.where(has_valid, smallest_valid_lengths, test_lengths[:, -1])
    #         # Update left bound: if we found a valid length and it's not the first one,
    #         # use the previous length; otherwise keep current left
    #         prev_lengths = torch.where(valid_indices > 0, test_lengths[batch_indices, valid_indices - 1], left)
    #         left = torch.where(has_valid & (valid_indices > 0), prev_lengths, left)
    #         # Check convergence
    #         if (right - left).max() <= 1:
    #             break
    #     return right  # Return optimal lengths
    # def _calculate_errors_for_lengths(
    #     self,
    #     x_tensor: torch.Tensor,
    #     indices_flat: torch.Tensor,
    #     test_lengths: torch.Tensor,
    #     action_encoding: str | None = None,
    # ) -> torch.Tensor:
    #     """
    #     Calculate reconstruction errors for given token lengths.
    #     Args:
    #         x_tensor: Original input tensor (batch_size, ...)
    #         indices_flat: Full token indices (batch_size, seq_len)
    #         test_lengths: Test lengths tensor (batch_size, num_tests)
    #         action_encoding: Action encoding type
    #     Returns:
    #         Error tensor (batch_size, num_tests)
    #     """
    #     # Create sparse tokens for all test lengths (vectorized)
    #     batch_size, num_tests = test_lengths.shape
    #     seq_len = indices_flat.shape[1]
    #     device = indices_flat.device
    #     # Create position tensor for all combinations
    #     positions = torch.arange(seq_len, device=device).unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len)
    #     positions = positions.expand(batch_size, num_tests, -1)  # (batch_size, num_tests, seq_len)
    #     # Create length mask: positions < test_lengths
    #     length_mask = positions < test_lengths.unsqueeze(2)  # (batch_size, num_tests, seq_len)
    #     # Create sparse tokens using advanced indexing
    #     sparse_tokens = torch.where(
    #         length_mask,
    #         indices_flat.unsqueeze(1).expand(-1, num_tests, -1),
    #         torch.zeros_like(indices_flat).unsqueeze(1).expand(-1, num_tests, -1),
    #     )
    #     # Reshape for parallel processing
    #     sparse_flat = sparse_tokens.view(batch_size * num_tests, seq_len)
    #     # Decode all sparse tokens in parallel
    #     reconstructed_flat = self._decode_sparse_tokens(sparse_flat, action_encoding)
    #     # Reshape back and calculate errors
    #     reconstructed = reconstructed_flat.view(batch_size, num_tests, *x_tensor.shape[1:])
    #     # Calculate errors
    #     x_expanded = x_tensor.unsqueeze(1).expand(-1, num_tests, -1, -1)
    #     errors = (x_expanded - reconstructed).abs().mean((-1, -2))  # (batch_size, num_tests)
    #     return errors
    # def _decode_sparse_tokens(self, sparse_tokens: torch.Tensor, action_encoding: str | None = None) -> torch.Tensor:
    #     """Decode sparse tokens to reconstructed data."""
    #     batch_size, seq_len = sparse_tokens.shape
    #     # Convert to proper indices format for dequantization
    #     if self.num_quantizers > 1:
    #         seq_len_per_quantizer = seq_len // self.num_quantizers
    #         if seq_len % self.num_quantizers != 0:
    #             raise ValueError("Sequence length must be divisible by num_quantizers")
    #         indices_for_decode = sparse_tokens.view(batch_size, self.num_quantizers, seq_len_per_quantizer).transpose(
    #             1, 2
    #         )  # (batch_size, seq_len_per_quantizer, num_quantizers)
    #     else:
    #         indices_for_decode = sparse_tokens.unsqueeze(-1)  # (batch_size, seq_len, 1)
    #     # Dequantize and decode
    #     z_q = self._dequantize(indices_for_decode)
    #     reconstructed = self._decode(z_q, action_encoding)
    #     return reconstructed
    # def _create_sparse_tokens_from_lengths(
    #     self, indices_flat: torch.Tensor, optimal_lengths: torch.Tensor
    # ) -> torch.Tensor:
    #     """Create sparse tokens based on optimal lengths (vectorized)."""
    #     batch_size, seq_len = indices_flat.shape
    #     device = indices_flat.device
    #     # Create position mask for all batch items simultaneously
    #     positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)  # (batch_size, seq_len)
    #     length_mask = positions < optimal_lengths.unsqueeze(1)  # (batch_size, seq_len)
    #     # Apply mask to create sparse tokens
    #     result = torch.where(length_mask, indices_flat, torch.zeros_like(indices_flat))
    #     return result

    def forward(self, x: torch.Tensor, embodiment_ids: int | None = None, padding_mask: List[bool] | None = None):
        """Alias for ``encode``: returns discrete token lists (no gradients)."""
        return self.encode(x, embodiment_ids, padding_mask)
# Public API of this module.
__all__ = ["ActionCodec"]

# Map ActionCodecConfig to ActionCodec in transformers' AutoModel registry.
AutoModel.register(ActionCodecConfig, ActionCodec)
if __name__ == "__main__":
    print("=== ActionCodec Comprehensive Test ===\n")
    # Seed weights (torch) and inputs (numpy) so smoke-test runs are reproducible.
    torch.manual_seed(0)
    rng = np.random.default_rng(0)

    # 1. Configuration Setup (RVQ enabled with n_quantizers=4)
    initial_config = {
        "robot_A": {"action_dim": 7, "freq": 10, "duration": 1, "description": "Robot A"},
    }
    # We set n_quantizers=4 to test Residual VQ logic
    config = ActionCodecConfig(
        embodiment_config=initial_config,
        n_tokens=16,  # Total tokens per sequence (latent_len * n_quantizers)
        n_quantizers=4,  # RVQ depth
        vq_type="rvq",
        vq_codebook_size=256,
        encoder_dim=128,
        decoder_dim=128,
    )
    # Expected latent sequence length = n_tokens / n_quantizers = 16 / 4 = 4
    latent_seq_len = int(config.n_tokens // config.n_quantizers)
    print(f"Config: {config.n_quantizers} quantizers, {latent_seq_len} latent vectors per sequence.")
    codec = ActionCodec(config)
    codec.eval()

    # 2. Basic Encode/Decode Test
    print("\n--- Test 1: Basic Encode/Decode ---")
    batch_size = 2
    seq_len_A = 10  # 10Hz * 1s
    # Create random action data for Robot A (ID 0)
    x = rng.standard_normal((batch_size, seq_len_A, 7)).astype(np.float32)
    # Masking: Second item in batch is half padding
    padding_mask = np.ones((batch_size, seq_len_A), dtype=bool)
    padding_mask[1, 5:] = False
    embodiment_ids = [0, 0]
    # Encode
    codes = codec.encode(x, embodiment_ids, padding_mask)
    print(f"Encoded codes shape (list length): {len(codes)} x {len(codes[0])}")
    # Validate code length
    assert len(codes[0]) == config.n_tokens, f"Expected {config.n_tokens} tokens, got {len(codes[0])}"
    # Decode
    x_recon, recon_mask = codec.decode(codes, embodiment_ids)
    print(f"Reconstructed shape: {x_recon.shape}")
    print(f"Recon mask shape: {recon_mask.shape}")
    # Robot A is the only embodiment so far, so the max action dim is 7.
    assert x_recon.shape == (batch_size, seq_len_A, 7)

    # 3. Expansion Test
    print("\n--- Test 2: Dynamic Expansion ---")
    new_robot_config = {"robot_B": {"action_dim": 10, "freq": 20, "duration": 1, "description": "Robot B (Larger)"}}
    print("Expanding codec to include Robot B (10 dims, 20Hz)...")
    codec.expand_embodiment(new_robot_config)
    assert codec.encoder.max_action_dim == 10
    assert codec.decoder.max_action_dim == 10
    print("✅ Expansion successful.")

    # 4. Mixed Batch Test (Old + New Robot)
    print("\n--- Test 3: Mixed Batch Inference ---")
    # Batch: [Robot A, Robot B]
    # Robot A: 10Hz, 1s -> 10 steps. Dims 7.
    # Robot B: 20Hz, 1s -> 20 steps. Dims 10.
    # Batch Max Steps: 20. Batch Max Dims: 10.
    batch_x_mixed = np.zeros((2, 20, 10), dtype=np.float32)
    # Fill Robot A data (index 0)
    data_A = rng.standard_normal((10, 7))
    batch_x_mixed[0, :10, :7] = data_A
    # Fill Robot B data (index 1)
    data_B = rng.standard_normal((20, 10))
    batch_x_mixed[1, :20, :10] = data_B
    # Embodiment IDs: 0 for A, 1 for B
    # Note: expand_embodiment appends. Original was 0, new is 1.
    mixed_ids = [0, 1]
    # Encode Mask
    mixed_mask = np.zeros((2, 20), dtype=bool)
    mixed_mask[0, :10] = True
    mixed_mask[1, :20] = True
    print("Encoding mixed batch...")
    mixed_codes = codec.encode(batch_x_mixed, mixed_ids, mixed_mask)
    print("Decoding mixed batch...")
    # Explicit durations (optional, but good for verification if we wanted to override defaults)
    durations = [1, 1]
    x_recon_mixed, dec_mask_mixed = codec.decode(mixed_codes, mixed_ids, durations)
    print(f"Mixed Recon Shape: {x_recon_mixed.shape}")
    # Validation
    # Robot A output check (mask should be True for first 10, False for rest)
    valid_A = dec_mask_mixed[0].sum()
    valid_B = dec_mask_mixed[1].sum()
    print(f"Valid steps detected by Decoder: Robot A={valid_A}, Robot B={valid_B}")
    assert valid_A == 10
    assert valid_B == 20
    # TODO: dimensionality preservation is not checked yet. Robot A's
    # reconstruction in dims 7-9 may be noise or zero depending on the decoder,
    # while dims 0-6 should carry signal.
    print("✅ Mixed batch processed successfully.")
    print("\n✨ All systems go.")
|