Robotics
Transformers
Safetensors
English
prts_qwen3_vl
feature-extraction
vision-language-action
vla
contrastive-reinforcement-learning
goal-conditioned-rl
qwen3-vl
prts
custom_code
How to use TeleEmbodied/PRTS-4B with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("TeleEmbodied/PRTS-4B", trust_remote_code=True, dtype="auto")
```
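For GPU inference the usual `from_pretrained` options apply. A minimal loading sketch, assuming `accelerate` is installed for `device_map="auto"` (the concrete model class is supplied by this repository's remote code):

```python
# Minimal GPU-loading sketch. Assumes accelerate is installed; the actual
# model class behind AutoModel comes from this repo's custom code.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "TeleEmbodied/PRTS-4B",
    trust_remote_code=True,
    dtype="auto",        # keep the checkpoint's native precision
    device_map="auto",   # place weights across available devices
)
model.eval()
```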
| """ | |
| DiT (Diffusion Transformer) based flow matching action head for PRTS. | |
| Replaces the Qwen3VLTextModel-based fm_action_expert with a lightweight DiT | |
| that uses explicit cross-attention to VLM hidden states, following the architecture | |
| from GR00T / pi05. | |
| Architecture: | |
| ActionEncoder(noisy_actions + dof_mask, timestep) | |
| → action_features | |
| → DiT(cross-attn to VLM hidden states, ada-norm timestep conditioning) | |
| → ActionDecoder → predicted velocity | |
| """ | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.distributions import Beta | |
| from typing import Optional | |
| from transformers.cache_utils import Cache | |
| from transformers.modeling_flash_attention_utils import _flash_attention_forward | |
| # DIT_PRESETS = { | |
| # "DiT-B": {"num_attention_heads": 12, "attention_head_dim": 64, "output_dim": 768}, | |
| # "DiT-L": {"num_attention_heads": 32, "attention_head_dim": 48, "output_dim": 1536}, | |
| # } | |
| class SinusoidalPositionalEncoding(nn.Module): | |
| """Sinusoidal positional encoding for sequence positions or timesteps.""" | |
| def __init__(self, embedding_dim: int): | |
| super().__init__() | |
| self.embedding_dim = embedding_dim | |
| def forward(self, timesteps: torch.Tensor) -> torch.Tensor: | |
| timesteps = timesteps.float() | |
| squeeze = False | |
| if timesteps.dim() == 1: | |
| timesteps = timesteps.unsqueeze(1) | |
| squeeze = True | |
| half_dim = self.embedding_dim // 2 | |
| exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * ( | |
| math.log(10000.0) / half_dim | |
| ) | |
| freqs = timesteps.unsqueeze(-1) * exponent.exp() | |
| enc = torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1) | |
| if squeeze: | |
| enc = enc.squeeze(1) | |
| return enc | |
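A quick shape check for the encoder above (illustrative only): a batch of scalar timesteps maps to one embedding vector each, and a `(B, T)` input maps to `(B, T, D)`.

```python
# Illustrative shape check for SinusoidalPositionalEncoding.
import torch

enc = SinusoidalPositionalEncoding(embedding_dim=256)
t = torch.tensor([0.0, 0.5, 1.0])              # (B,) scalar timesteps
print(enc(t).shape)                            # torch.Size([3, 256])
print(enc(t.unsqueeze(1).expand(3, 8)).shape)  # torch.Size([3, 8, 256])
```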
```python
class TimestepEncoder(nn.Module):
    """Projects scalar timesteps to embedding space via sinusoidal encoding + MLP."""

    def __init__(self, embedding_dim: int):
        super().__init__()
        self.sinusoidal = SinusoidalPositionalEncoding(256)
        self.linear_1 = nn.Linear(256, embedding_dim)
        self.act = nn.SiLU()
        self.linear_2 = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        t_emb = self.sinusoidal(timesteps)
        t_emb = self.linear_1(t_emb.to(dtype=self.linear_1.weight.dtype))
        t_emb = self.act(t_emb)
        t_emb = self.linear_2(t_emb)
        return t_emb


class AdaLayerNorm(nn.Module):
    """Adaptive Layer Normalization conditioned on timestep embeddings.

    Applies scale-shift modulation: out = norm(x) * (1 + scale) + shift,
    where (scale, shift) are linearly projected from the timestep embedding.
    """

    def __init__(self, embedding_dim: int, eps: float = 1e-5):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=False)

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        temb = self.linear(self.silu(temb))
        scale, shift = temb.chunk(2, dim=-1)
        x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
        return x
```
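The modulation broadcasts a per-sample `(B, D)` timestep embedding over all `T` tokens; a minimal sketch:

```python
# AdaLayerNorm broadcasting sketch: one (scale, shift) pair per sample,
# applied across the full token dimension.
import torch

ada = AdaLayerNorm(embedding_dim=64)
x = torch.randn(2, 10, 64)   # (B, T, D) token features
temb = torch.randn(2, 64)    # (B, D) timestep embedding
print(ada(x, temb).shape)    # torch.Size([2, 10, 64])
```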
```python
class DiTAttention(nn.Module):
    """Multi-head attention supporting both self-attention and cross-attention.

    Supports two backends selected via ``attn_implementation``:

    * ``"sdpa"`` (default) – uses :func:`F.scaled_dot_product_attention`, which
      dispatches automatically to FlashAttention / memory-efficient attention
      depending on the installed PyTorch build. The encoder padding mask is
      expanded to ``(B, 1, 1, S)`` and passed as ``attn_mask``.
    * ``"flash_attention_2"`` – routes through HuggingFace's
      ``_flash_attention_forward`` wrapper around the ``flash_attn`` package for
      lower memory usage and higher throughput. When a padding mask is present
      the k/v tensors are unpadded and the varlen kernel is used, so padding
      tokens are never processed; without a mask the simpler dense kernel is
      used.
    """
    def __init__(
        self,
        query_dim: int,
        num_heads: int,
        head_dim: int,
        cross_attention_dim: Optional[int] = None,
        dropout: float = 0.0,
        bias: bool = True,
        attn_implementation: str = "sdpa",
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.attn_implementation = attn_implementation
        inner_dim = num_heads * head_dim
        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
        kv_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.to_k = nn.Linear(kv_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(kv_dim, inner_dim, bias=bias)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim, bias=bias),
            nn.Dropout(dropout),
        )

    # ------------------------------------------------------------------
    # Flash-Attention backend
    # ------------------------------------------------------------------
    def _flash_attn_forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run Flash Attention via HuggingFace's ``_flash_attention_forward``.

        Args:
            q: ``(B, T_q, H, D)``
            k: ``(B, T_k, H, D)``
            v: ``(B, T_k, H, D)``
            attention_mask: ``(B, T_k)`` bool, True = valid token.

        Returns:
            ``(B, T_q, H*D)``
        """
        B, T_q, H, D = q.shape
        # _flash_attention_forward returns (B, T_q, H, D); handles unpad/varlen internally.
        out = _flash_attention_forward(
            q, k, v,
            attention_mask=attention_mask,
            query_length=T_q,
            is_causal=False,
            dropout=0.0,
        )
        return out.reshape(B, T_q, H * D)

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, T, _ = hidden_states.shape
        q = self.to_q(hidden_states)
        kv_input = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        k = self.to_k(kv_input)
        v = self.to_v(kv_input)
        if self.attn_implementation == "flash_attention_2":
            # Flash Attention expects (B, S, H, D)
            q = q.view(B, T, self.num_heads, self.head_dim)
            k = k.view(B, -1, self.num_heads, self.head_dim)
            v = v.view(B, -1, self.num_heads, self.head_dim)
            attn_output = self._flash_attn_forward(q, k, v, attention_mask)
        else:
            # SDPA expects (B, H, S, D)
            q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
            k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
            v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
            # Expand (B, S) bool mask → (B, 1, 1, S) for broadcasting.
            sdpa_mask = None
            if attention_mask is not None:
                if attention_mask.dim() == 2:
                    sdpa_mask = attention_mask[:, None, None, :]
                else:
                    sdpa_mask = attention_mask
            attn_output = F.scaled_dot_product_attention(
                q, k, v, attn_mask=sdpa_mask, dropout_p=0.0
            )
            attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, -1)
        return self.to_out(attn_output)
```
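Cross-attention with a padding mask under the default SDPA backend, as a shape-level sanity check (dimensions here are arbitrary):

```python
# DiTAttention cross-attention sketch: 16 action tokens attend to 50 VLM
# tokens, the last 10 of which are masked out as padding.
import torch

attn = DiTAttention(query_dim=768, num_heads=12, head_dim=64, cross_attention_dim=2048)
x = torch.randn(2, 16, 768)                  # (B, T_q, query_dim) action tokens
ctx = torch.randn(2, 50, 2048)               # (B, S, cross_attention_dim) VLM states
mask = torch.ones(2, 50, dtype=torch.bool)
mask[:, 40:] = False                         # padding positions are never attended
out = attn(x, encoder_hidden_states=ctx, attention_mask=mask)
print(out.shape)                             # torch.Size([2, 16, 768])
```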
```python
class FeedForward(nn.Module):
    """Feed-forward network with GELU activation."""

    def __init__(self, dim: int, dropout: float = 0.0, mult: int = 4):
        super().__init__()
        inner_dim = dim * mult
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU(approximate="tanh"),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class BasicTransformerBlock(nn.Module):
    """Transformer block with self/cross-attention, optional AdaLayerNorm, and feed-forward.

    When cross_attention_dim is set, the attention block performs cross-attention
    to encoder_hidden_states. Otherwise, it performs self-attention.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        norm_type: str = "ada_norm",
        final_dropout: bool = False,
        attn_implementation: str = "sdpa",
    ):
        super().__init__()
        self.norm_type = norm_type
        if norm_type == "ada_norm":
            self.norm1 = AdaLayerNorm(dim)
        else:
            self.norm1 = nn.LayerNorm(dim)
        self.attn1 = DiTAttention(
            query_dim=dim,
            num_heads=num_attention_heads,
            head_dim=attention_head_dim,
            cross_attention_dim=cross_attention_dim,
            dropout=dropout,
            attn_implementation=attn_implementation,
        )
        self.norm3 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, dropout=dropout)
        self.final_dropout = nn.Dropout(dropout) if final_dropout else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.norm_type == "ada_norm":
            norm_hidden_states = self.norm1(hidden_states, temb)
        else:
            norm_hidden_states = self.norm1(hidden_states)
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
        )
        if self.final_dropout is not None:
            attn_output = self.final_dropout(attn_output)
        hidden_states = attn_output + hidden_states
        norm_hidden_states = self.norm3(hidden_states)
        ff_output = self.ff(norm_hidden_states)
        hidden_states = ff_output + hidden_states
        return hidden_states


class DiT(nn.Module):
    """Diffusion Transformer with cross-attention to VLM context features.

    Interleaves cross-attention blocks (attending to encoder_hidden_states)
    with self-attention blocks when interleave_self_attention=True.
    Uses AdaLayerNorm for timestep conditioning throughout.
    Output block applies timestep-conditioned scale-shift before final projection.
    """

    def __init__(
        self,
        num_attention_heads: int = 12,
        attention_head_dim: int = 64,
        output_dim: int = 768,
        num_layers: int = 12,
        dropout: float = 0.1,
        norm_type: str = "ada_norm",
        final_dropout: bool = True,
        interleave_self_attention: bool = False,
        cross_attention_dim: Optional[int] = None,
        attn_implementation: str = "sdpa",
    ):
        super().__init__()
        self.inner_dim = num_attention_heads * attention_head_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.interleave_self_attention = interleave_self_attention
        self.timestep_encoder = TimestepEncoder(self.inner_dim)
        all_blocks = []
        for idx in range(num_layers):
            use_self_attn = idx % 2 == 1 and interleave_self_attention
            curr_cross_attention_dim = cross_attention_dim if not use_self_attn else None
            all_blocks.append(
                BasicTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=curr_cross_attention_dim,
                    norm_type=norm_type,
                    final_dropout=final_dropout,
                    attn_implementation=attn_implementation,
                )
            )
        self.transformer_blocks = nn.ModuleList(all_blocks)
        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
        self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        temb = self.timestep_encoder(timestep)
        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()
        for idx, block in enumerate(self.transformer_blocks):
            if idx % 2 == 1 and self.interleave_self_attention:
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    temb=temb,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    temb=temb,
                )
        conditioning = temb
        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=-1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        return self.proj_out_2(hidden_states)
```
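Putting the blocks together, a tiny DiT can be smoke-tested end to end; with `interleave_self_attention=True`, even blocks cross-attend to the context and odd blocks self-attend (sizes below are arbitrary):

```python
# Tiny DiT smoke test: inner_dim = 4 * 16 = 64, two blocks (cross, self).
import torch

dit = DiT(
    num_attention_heads=4, attention_head_dim=16, output_dim=32,
    num_layers=2, cross_attention_dim=128, interleave_self_attention=True,
)
action_tokens = torch.randn(2, 8, 64)        # (B, T, inner_dim)
context = torch.randn(2, 20, 128)            # (B, S, cross_attention_dim)
timestep = torch.randint(0, 1000, (2,))      # discretized timestep buckets
print(dit(action_tokens, context, timestep).shape)  # torch.Size([2, 8, 32])
```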
```python
class AlternateVLDiT(DiT):
    """DiT variant that separates visual and text tokens during cross-attention.

    Mirrors GR00T's AlternateVLDiT: even-indexed blocks do cross-attention,
    alternating every ``attend_text_every_n_blocks`` between text tokens and
    visual tokens. Odd-indexed blocks do self-attention (requires
    ``interleave_self_attention=True``).

    When no visual tokens are present (``image_mask`` is None or all-False),
    all valid tokens are treated as text.
    """

    def __init__(self, *args, attend_text_every_n_blocks: int = 2, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.interleave_self_attention, (
            "AlternateVLDiT requires interleave_self_attention=True"
        )
        self.attend_text_every_n_blocks = attend_text_every_n_blocks

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        image_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            encoder_attention_mask: (B, S) bool – True = valid VLM token.
            image_mask: (B, S) bool – True = visual token position.
                If None, all valid tokens are treated as text.
        """
        temb = self.timestep_encoder(timestep)
        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()
        B, S, _ = encoder_hidden_states.shape
        backbone_mask = (
            encoder_attention_mask.bool()
            if encoder_attention_mask is not None
            else torch.ones(B, S, dtype=torch.bool, device=hidden_states.device)
        )
        if image_mask is not None and image_mask.any():
            vis_mask = image_mask.bool() & backbone_mask      # visual tokens
            text_mask = (~image_mask.bool()) & backbone_mask  # text tokens
        else:
            # No visual tokens – treat everything as text.
            vis_mask = torch.zeros_like(backbone_mask)
            text_mask = backbone_mask
        for idx, block in enumerate(self.transformer_blocks):
            if idx % 2 == 1:
                # Self-attention block.
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    temb=temb,
                )
            else:
                # Cross-attention block: alternate text / visual every N blocks.
                if idx % (2 * self.attend_text_every_n_blocks) == 0:
                    curr_mask = text_mask
                else:
                    curr_mask = vis_mask
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=curr_mask,
                    temb=temb,
                )
        conditioning = temb
        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=-1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        return self.proj_out_2(hidden_states)
```
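The variant's routing can be exercised the same way; with `attend_text_every_n_blocks=2`, cross-attention blocks at layer indices 0, 4, 8, … see text tokens and those at 2, 6, 10, … see visual tokens (a sketch with arbitrary sizes):

```python
# AlternateVLDiT routing sketch: block 0 cross-attends to text tokens,
# block 2 to visual tokens; blocks 1 and 3 self-attend.
import torch

dit = AlternateVLDiT(
    num_attention_heads=4, attention_head_dim=16, output_dim=32,
    num_layers=4, cross_attention_dim=128, interleave_self_attention=True,
    attend_text_every_n_blocks=2,
)
x = torch.randn(2, 8, 64)
ctx = torch.randn(2, 20, 128)
image_mask = torch.zeros(2, 20, dtype=torch.bool)
image_mask[:, :12] = True                    # first 12 context tokens are visual
timestep = torch.randint(0, 1000, (2,))
print(dit(x, ctx, timestep, image_mask=image_mask).shape)  # torch.Size([2, 8, 32])
```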
```python
class ActionEncoder(nn.Module):
    """Encodes noisy actions (optionally concatenated with DOF mask) and timestep
    into hidden features via MLP + sinusoidal time encoding.

    Architecture: Linear → concat(action_emb, time_emb) → SiLU + Linear → Linear
    """

    def __init__(self, action_input_dim: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.layer1 = nn.Linear(action_input_dim, hidden_size)
        self.layer2 = nn.Linear(2 * hidden_size, hidden_size)
        self.layer3 = nn.Linear(hidden_size, hidden_size)
        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)

    def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        """
        Args:
            actions: (B, T, action_input_dim) noisy actions (+ DOF mask)
            timesteps: (B,) discretized timesteps
        """
        B, T, _ = actions.shape
        timesteps_expanded = timesteps.unsqueeze(1).expand(-1, T)
        a_emb = self.layer1(actions)
        tau_emb = self.pos_encoding(timesteps_expanded).to(dtype=a_emb.dtype)
        x = torch.cat([a_emb, tau_emb], dim=-1)
        x = F.silu(self.layer2(x))
        x = self.layer3(x)
        return x


class ActionDecoder(nn.Module):
    """2-layer MLP that decodes DiT output to action-space velocity."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer2(F.relu(self.layer1(x)))


class FlowMatchingDiTHead(nn.Module):
    """Flow matching action head using DiT (Diffusion Transformer).

    Replaces the fm_action_expert (Qwen3VLTextModel-based) with a DiT that uses
    explicit cross-attention to VLM hidden states instead of KV cache continuation.

    Training:
        1. Sample noise and timestep from Beta distribution
        2. Compute noisy trajectory: x_t = (1-t)*noise + t*actions
        3. Compute velocity target: v = actions - noise
        4. Encode noisy actions + DOF mask + timestep → action features
        5. Add learned positional embeddings to the action features
        6. Run DiT with cross-attention to VLM hidden states
        7. Decode to action-space velocity prediction

    Inference:
        Euler integration from pure noise (t=0) to clean actions (t=1)
        over num_inference_timesteps steps.
    """

    def __init__(
        self,
        action_dim: int,
        action_chunk_size: int,
        cross_attention_dim: int,
        num_inference_timesteps: int = 4,
        config: Optional[dict] = None,
    ):
        super().__init__()
        cfg = {
            "num_layers": 16,
            "num_attention_heads": 12,
            "attention_head_dim": 64,
            "output_dim": 1024,
            "dropout": 0.2,
            "interleave_self_attention": True,
            "norm_type": "ada_norm",
            "final_dropout": True,
            "add_pos_embed": True,
            "noise_beta_alpha": 1.5,
            "noise_beta_beta": 1.0,
            "noise_s": 0.999,
            "num_timestep_buckets": 1000,
            "attn_implementation": "sdpa",
            "use_alternate_vl_dit": False,
            "attend_text_every_n_blocks": 2,
        }
        if config is not None:
            cfg.update(config)
        # dit_model_type = config.get("dit_model_type")
        # if dit_model_type and dit_model_type in DIT_PRESETS:
        #     cfg.update(DIT_PRESETS[dit_model_type])
        # cfg.pop("dit_model_type", None)
        self.action_dim = action_dim
        self.action_chunk_size = action_chunk_size
        self.num_inference_timesteps = num_inference_timesteps
        self.num_timestep_buckets = cfg["num_timestep_buckets"]
        self.noise_s = cfg["noise_s"]
        self.use_alternate_vl_dit = cfg["use_alternate_vl_dit"]
        self.add_pos_embed = cfg["add_pos_embed"]
        num_attention_heads = cfg["num_attention_heads"]
        attention_head_dim = cfg["attention_head_dim"]
        output_dim = cfg["output_dim"]
        inner_dim = num_attention_heads * attention_head_dim
        dit_kwargs = dict(
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            output_dim=output_dim,
            num_layers=cfg["num_layers"],
            dropout=cfg["dropout"],
            norm_type=cfg["norm_type"],
            final_dropout=cfg["final_dropout"],
            interleave_self_attention=cfg["interleave_self_attention"],
            cross_attention_dim=cross_attention_dim,
            attn_implementation=cfg["attn_implementation"],
        )
        if self.use_alternate_vl_dit:
            self.dit = AlternateVLDiT(
                **dit_kwargs,
                attend_text_every_n_blocks=cfg["attend_text_every_n_blocks"],
            )
        else:
            self.dit = DiT(**dit_kwargs)
        # action_dim * 2: noisy action + DOF mask concatenated
        self.action_encoder = ActionEncoder(action_dim * 2, inner_dim)
        self.action_decoder = ActionDecoder(output_dim, inner_dim, action_dim)
        if self.add_pos_embed:
            max_seq_len = max(action_chunk_size, 256)
            self.position_embedding = nn.Embedding(max_seq_len, inner_dim)
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
        # self.beta_dist = Beta(cfg["noise_beta_alpha"], cfg["noise_beta_beta"])
        self._beta_alpha = cfg["noise_beta_alpha"]
        self._beta_beta = cfg["noise_beta_beta"]

    def reset_parameters(self):
        """Re-apply proper initialization.

        HuggingFace from_pretrained calls _init_weights on modules whose
        parameters are absent from the checkpoint, overwriting any custom
        init done in __init__. Call this after from_pretrained when loading
        from a base VLM checkpoint that does not contain DiT weights.
        """
        if self.add_pos_embed:
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
                if module.bias is not None:
                    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                    nn.init.uniform_(module.bias, -bound, bound)
            elif isinstance(module, nn.LayerNorm):
                if module.elementwise_affine:
                    nn.init.ones_(module.weight)
                    nn.init.zeros_(module.bias)

    def sample_time(self, batch_size: int, device, dtype) -> torch.Tensor:
        beta_dist = Beta(self._beta_alpha, self._beta_beta)
        sample = beta_dist.sample([batch_size]).to(device, dtype=dtype).clamp(max=self.noise_s)
        return (self.noise_s - sample) / self.noise_s

    def _encode_actions(
        self,
        noisy_actions: torch.Tensor,
        t_discretized: torch.Tensor,
        action_dof_mask: Optional[torch.Tensor],
        device,
    ) -> torch.Tensor:
        """Encode noisy actions with DOF mask and timestep, add position embeddings."""
        if action_dof_mask is not None:
            encoder_input = torch.cat(
                [noisy_actions, action_dof_mask.to(noisy_actions.dtype)], dim=-1
            )
        else:
            encoder_input = torch.cat(
                [noisy_actions, torch.ones_like(noisy_actions)], dim=-1
            )
        action_features = self.action_encoder(encoder_input, t_discretized)
        if self.add_pos_embed:
            pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
            pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
            action_features = action_features + pos_embs
        return action_features

    def _dit_forward(
        self,
        sa_embs: torch.Tensor,
        vl_embs: torch.Tensor,
        t_discretized: torch.LongTensor,
        encoder_attention_mask: Optional[torch.Tensor],
        image_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        if self.use_alternate_vl_dit:
            return self.dit(
                hidden_states=sa_embs,
                encoder_hidden_states=vl_embs,
                timestep=t_discretized,
                encoder_attention_mask=encoder_attention_mask,
                image_mask=image_mask,
            )
        return self.dit(
            hidden_states=sa_embs,
            encoder_hidden_states=vl_embs,
            timestep=t_discretized,
            encoder_attention_mask=encoder_attention_mask,
        )

    def forward(
        self,
        vl_embs: torch.Tensor,
        actions: torch.Tensor,
        action_dof_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        image_mask: Optional[torch.Tensor] = None,
    ) -> tuple:
        """Training forward pass.

        Args:
            vl_embs: (B, S, D) VLM hidden states for cross-attention
            actions: (B, T, action_dim) ground truth action trajectories
            action_dof_mask: (B, T, action_dim) DOF validity mask
            encoder_attention_mask: (B, S) bool – True = valid VLM token
            image_mask: (B, S) bool – True = visual token (used by AlternateVLDiT)

        Returns:
            (pred_v, velocity): predicted velocity and target velocity, both (B, T, action_dim)
        """
        device = vl_embs.device
        B = actions.shape[0]
        noise = torch.randn(actions.shape, device=device, dtype=actions.dtype)
        t = self.sample_time(B, device=device, dtype=actions.dtype)
        t_expanded = t[:, None, None]
        noisy_trajectory = (1 - t_expanded) * noise + t_expanded * actions
        velocity = actions - noise
        t_discretized = (t * self.num_timestep_buckets).long()
        action_features = self._encode_actions(noisy_trajectory, t_discretized, action_dof_mask, device)
        model_output = self._dit_forward(
            action_features, vl_embs, t_discretized, encoder_attention_mask, image_mask
        )
        pred = self.action_decoder(model_output)
        pred_v = pred[:, :actions.shape[1]]
        return pred_v, velocity

    def predict_action(
        self,
        vl_embs: torch.Tensor,
        action_dof_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        image_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Inference: denoise actions from noise using Euler integration.

        Args:
            vl_embs: (B, S, D) VLM hidden states
            action_dof_mask: optional (B, T, action_dim) or (1, T, action_dim) DOF mask
            encoder_attention_mask: (B, S) bool – True = valid VLM token
            image_mask: (B, S) bool – True = visual token (used by AlternateVLDiT)

        Returns:
            (B, T, action_dim) denoised action trajectories
        """
        B = vl_embs.shape[0]
        device = vl_embs.device
        dtype = vl_embs.dtype
        actions = torch.randn(
            (B, self.action_chunk_size, self.action_dim),
            device=device, dtype=dtype,
        )
        dt = 1.0 / self.num_inference_timesteps
        for step in range(self.num_inference_timesteps):
            t_cont = step / float(self.num_inference_timesteps)
            t_discretized_val = int(t_cont * self.num_timestep_buckets)
            timesteps_tensor = torch.full((B,), t_discretized_val, device=device, dtype=torch.long)
            action_features = self._encode_actions(actions, timesteps_tensor, action_dof_mask, device)
            model_output = self._dit_forward(
                action_features, vl_embs, timesteps_tensor, encoder_attention_mask, image_mask
            )
            pred = self.action_decoder(model_output)
            pred_velocity = pred[:, :self.action_chunk_size]
            actions = actions + dt * pred_velocity
        return actions
```
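A hedged training-step sketch for the head above: the forward pass returns predicted and target velocities, and the flow-matching loss is simply their MSE (the small config here is for illustration; any DOF-mask loss weighting is omitted):

```python
# Flow-matching training step sketch: loss = MSE(v_pred, v_target),
# where v_target = actions - noise. Small dims for illustration only.
import torch
import torch.nn.functional as F

head = FlowMatchingDiTHead(
    action_dim=7, action_chunk_size=16, cross_attention_dim=2048,
    config={"num_layers": 4, "num_attention_heads": 4,
            "attention_head_dim": 16, "output_dim": 64},
)
vl_embs = torch.randn(2, 50, 2048)           # VLM hidden states (e.g. precomputed)
actions = torch.randn(2, 16, 7)              # ground-truth action chunk
pred_v, target_v = head(vl_embs, actions)
loss = F.mse_loss(pred_v, target_v)
loss.backward()
```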
```python
# ============================================================================
# Pi0.5-style KV-cache action expert (VLM K/V concat + GQA + SwiGLU FFN)
# ============================================================================


class AdaRMSNorm(nn.Module):
    """Adaptive RMS normalization: (scale, shift, gate) from cond; zero-init."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.modulation = nn.Linear(dim, dim * 3)
        nn.init.zeros_(self.modulation.weight)
        nn.init.zeros_(self.modulation.bias)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        var = x.float().pow(2).mean(-1, keepdim=True)
        normed = (x * torch.rsqrt(var + self.eps)).to(x.dtype)
        scale, shift, gate = self.modulation(cond).chunk(3, dim=-1)
        normed = normed * (1 + scale[:, None]) + shift[:, None]
        return normed, gate[:, None]


class SwiGLUFeedForward(nn.Module):
    """SiLU(gate_proj(x)) * up_proj(x) → down_proj."""

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0, bias: bool = True):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=bias)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.dropout(F.silu(self.gate_proj(x)) * self.up_proj(x)))


class MoTAttention(nn.Module):
    """Action Q attends to concatenated [VLM KV cache ; action KV]; GQA expand for SDPA."""

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_kv_heads: int,
        head_dim: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        if num_attention_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_attention_heads ({num_attention_heads}) must be divisible by "
                f"num_kv_heads ({num_kv_heads})"
            )
        self.num_attention_heads = num_attention_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        q_dim = num_attention_heads * head_dim
        kv_dim = num_kv_heads * head_dim
        self.q_proj = nn.Linear(hidden_size, q_dim, bias=bias)
        self.k_proj = nn.Linear(hidden_size, kv_dim, bias=bias)
        self.v_proj = nn.Linear(hidden_size, kv_dim, bias=bias)
        self.o_proj = nn.Linear(q_dim, hidden_size, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        action_hidden: torch.Tensor,
        vlm_cached_k: torch.Tensor,
        vlm_cached_v: torch.Tensor,
        vlm_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, T_a, _ = action_hidden.shape
        q = self.q_proj(action_hidden)
        act_k = self.k_proj(action_hidden)
        act_v = self.v_proj(action_hidden)
        q = q.view(B, T_a, self.num_attention_heads, self.head_dim).transpose(1, 2)
        act_k = act_k.view(B, T_a, self.num_kv_heads, self.head_dim).transpose(1, 2)
        act_v = act_v.view(B, T_a, self.num_kv_heads, self.head_dim).transpose(1, 2)
        k = torch.cat([vlm_cached_k, act_k], dim=2)
        v = torch.cat([vlm_cached_v, act_v], dim=2)
        repeat_factor = self.num_attention_heads // self.num_kv_heads
        k = k.repeat_interleave(repeat_factor, dim=1)
        v = v.repeat_interleave(repeat_factor, dim=1)
        sdpa_mask = None
        if vlm_attention_mask is not None:
            action_mask = vlm_attention_mask.new_ones(B, T_a)
            combined_mask = torch.cat([vlm_attention_mask, action_mask], dim=1)
            sdpa_mask = combined_mask[:, None, None, :]
        attn_out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=sdpa_mask, dropout_p=0.0,
        )
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T_a, -1)
        return self.dropout(self.o_proj(attn_out))
```
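The GQA expansion and KV concatenation can be verified at the shape level; each cached VLM tensor arrives as `(B, num_kv_heads, S, head_dim)` and is repeated to match the query head count:

```python
# MoTAttention shape check: 8 query heads share 4 KV heads; the 50 cached
# VLM positions are concatenated ahead of the 16 action positions.
import torch

attn = MoTAttention(hidden_size=256, num_attention_heads=8, num_kv_heads=4, head_dim=32)
act = torch.randn(2, 16, 256)                # (B, T_a, hidden_size)
vlm_k = torch.randn(2, 4, 50, 32)            # (B, num_kv_heads, S, head_dim)
vlm_v = torch.randn(2, 4, 50, 32)
mask = torch.ones(2, 50, dtype=torch.bool)
print(attn(act, vlm_k, vlm_v, mask).shape)   # torch.Size([2, 16, 256])
```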
```python
class MoTBlock(nn.Module):
    """AdaRMSNorm → attention → gated residual → AdaRMSNorm → SwiGLU FFN → gated residual."""

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_kv_heads: int,
        head_dim: int,
        intermediate_size: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.pre_attn_norm = AdaRMSNorm(hidden_size)
        self.attn = MoTAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            dropout=dropout,
        )
        self.pre_ffn_norm = AdaRMSNorm(hidden_size)
        self.ffn = SwiGLUFeedForward(hidden_size, intermediate_size, dropout=dropout)

    def forward(
        self,
        action_hidden: torch.Tensor,
        vlm_cached_k: torch.Tensor,
        vlm_cached_v: torch.Tensor,
        adarms_cond: torch.Tensor,
        vlm_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        normed, gate1 = self.pre_attn_norm(action_hidden, adarms_cond)
        attn_out = self.attn(normed, vlm_cached_k, vlm_cached_v, vlm_attention_mask)
        action_hidden = action_hidden + attn_out * gate1
        normed2, gate2 = self.pre_ffn_norm(action_hidden, adarms_cond)
        action_hidden = action_hidden + self.ffn(normed2) * gate2
        return action_hidden


class MoTDiT(nn.Module):
| """Stack of ActionBlocks; each block uses one VLM layer's KV pair.""" | |
    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_kv_heads: int,
        head_dim: int,
        intermediate_size: int,
        num_layers: int,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.blocks = nn.ModuleList([
            MoTBlock(
                hidden_size=hidden_size,
                num_attention_heads=num_attention_heads,
                num_kv_heads=num_kv_heads,
                head_dim=head_dim,
                intermediate_size=intermediate_size,
                dropout=dropout,
            )
            for _ in range(num_layers)
        ])
        self.final_norm = AdaRMSNorm(hidden_size)

    def forward(
        self,
        action_hidden: torch.Tensor,
        vlm_kv_cache: list,
        adarms_cond: torch.Tensor,
        vlm_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        for idx, block in enumerate(self.blocks):
            cached_k, cached_v = vlm_kv_cache[idx]
            action_hidden = block(
                action_hidden, cached_k, cached_v, adarms_cond, vlm_attention_mask,
            )
        action_hidden, _ = self.final_norm(action_hidden, adarms_cond)
        return action_hidden


def _kv_pairs_from_past_key_values(past_key_values: Cache) -> list[tuple[torch.Tensor, torch.Tensor]]:
    """Per-layer (K, V) from a HuggingFace decoder KV cache (order matches transformer layers)."""
    return [
        (past_key_values[i][0], past_key_values[i][1])
        for i in range(len(past_key_values))
    ]


class MoTFlowMatchingHead(nn.Module):
    """Flow matching head: MoT-style action expert over VLM KV cache (concat + GQA)."""

    def __init__(
        self,
        action_dim: int,
        action_chunk_size: int,
        vlm_config,
        num_inference_timesteps: int = 10,
        config: Optional[dict] = None,
    ):
        super().__init__()
        _vlm_num_q_heads = 8  # alternatively: vlm_config.num_attention_heads // 2
        _vlm_num_kv_heads = vlm_config.num_key_value_heads  # 8
        _vlm_head_dim = getattr(
            vlm_config, "head_dim", vlm_config.hidden_size // vlm_config.num_attention_heads
        )  # 128
        cfg = {
            "hidden_size": 1024,  # alternatively: vlm_config.hidden_size // 2
            "intermediate_size": vlm_config.intermediate_size // 4,
            "expert_num_layers": vlm_config.num_hidden_layers,
            # Attention dims default to VLM values (required for KV cache compat)
            "num_attention_heads": _vlm_num_q_heads,
            "num_kv_heads": _vlm_num_kv_heads,
            "head_dim": _vlm_head_dim,
            # Noise schedule
            "dropout": 0.2,
            "add_pos_embed": True,
            "noise_beta_alpha": 1.5,
            "noise_beta_beta": 1.0,
            "noise_s": 0.999,
            "num_timestep_buckets": 1000,
        }
        if config is not None:
            cfg.update(config)
        num_attention_heads = cfg["num_attention_heads"]
        num_kv_heads = cfg["num_kv_heads"]
        head_dim = cfg["head_dim"]
        hidden_size = cfg["hidden_size"]
        intermediate_size = cfg["intermediate_size"]
        num_layers = cfg["expert_num_layers"]
        self.action_dim = action_dim
        self.action_chunk_size = action_chunk_size
        self.num_inference_timesteps = num_inference_timesteps
        self.num_timestep_buckets = cfg["num_timestep_buckets"]
        self.noise_s = cfg["noise_s"]
        self.add_pos_embed = cfg["add_pos_embed"]
        self.action_in_proj = nn.Linear(action_dim * 2, hidden_size)
        self.action_out_proj = nn.Linear(hidden_size, action_dim)
        self.time_sinusoidal = SinusoidalPositionalEncoding(hidden_size)
        self.time_mlp_1 = nn.Linear(hidden_size, hidden_size)
        self.time_mlp_2 = nn.Linear(hidden_size, hidden_size)
        if self.add_pos_embed:
            max_seq = max(action_chunk_size, 256)
            self.position_embedding = nn.Embedding(max_seq, hidden_size)
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
        self.dit = MoTDiT(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            intermediate_size=intermediate_size,
            num_layers=num_layers,
            dropout=cfg["dropout"],
        )
        self._beta_alpha = cfg["noise_beta_alpha"]
        self._beta_beta = cfg["noise_beta_beta"]
    @property
    def num_dit_layers(self) -> int:
        """Number of expert blocks; must match ``len(past_key_values)``."""
        return self.dit.num_layers

    def _vlm_kv_list_from_past(self, past_key_values: Cache) -> list[tuple[torch.Tensor, torch.Tensor]]:
        n = len(past_key_values)
        if n != self.num_dit_layers:
            raise ValueError(
                f"MoT expert has {self.num_dit_layers} blocks but `past_key_values` has {n} "
                "layers. Set `dit_action_head_config['expert_num_layers']` to match "
                "`text_config.num_hidden_layers`."
            )
        return _kv_pairs_from_past_key_values(past_key_values)

    def reset_parameters(self):
        """Re-apply proper initialization after from_pretrained."""
        if self.add_pos_embed:
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
        for module in self.modules():
            if isinstance(module, AdaRMSNorm):
                nn.init.zeros_(module.modulation.weight)
                nn.init.zeros_(module.modulation.bias)
            elif isinstance(module, nn.Linear):
                nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
                if module.bias is not None:
                    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                    nn.init.uniform_(module.bias, -bound, bound)

    def _compute_adarms_cond(self, t_discretized: torch.Tensor) -> torch.Tensor:
        t_emb = self.time_sinusoidal(t_discretized.float())
        t_emb = t_emb.to(dtype=self.time_mlp_1.weight.dtype)
        t_emb = F.silu(self.time_mlp_1(t_emb))
        t_emb = F.silu(self.time_mlp_2(t_emb))
        return t_emb

    def sample_time(self, batch_size: int, device, dtype) -> torch.Tensor:
        beta_dist = Beta(self._beta_alpha, self._beta_beta)
        sample = beta_dist.sample([batch_size]).to(device, dtype=dtype).clamp(max=self.noise_s)
        return (self.noise_s - sample) / self.noise_s

    def _prepare_action_embeds(
        self,
        noisy_actions: torch.Tensor,
        action_dof_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        if action_dof_mask is not None:
            x = torch.cat(
                [noisy_actions, action_dof_mask.to(noisy_actions.dtype)], dim=-1,
            )
        else:
            x = torch.cat([noisy_actions, torch.ones_like(noisy_actions)], dim=-1)
        tokens = self.action_in_proj(x)
        if self.add_pos_embed:
            pos_ids = torch.arange(tokens.shape[1], dtype=torch.long, device=noisy_actions.device)
            tokens = tokens + self.position_embedding(pos_ids).unsqueeze(0)
        return tokens

    def forward(
        self,
        past_key_values: Cache,
        actions: torch.Tensor,
        action_dof_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple:
        """Training: returns (pred_velocity, target_velocity).

        Args:
            past_key_values: VLM decoder KV cache; layer count must equal ``num_dit_layers``.
        """
        vlm_kv_cache = self._vlm_kv_list_from_past(past_key_values)
        device = actions.device
        B = actions.shape[0]
        noise = torch.randn(actions.shape, device=device, dtype=actions.dtype)
        t = self.sample_time(B, device=device, dtype=actions.dtype)
        t_expanded = t[:, None, None]
        noisy_trajectory = (1 - t_expanded) * noise + t_expanded * actions
        velocity = actions - noise
        t_discretized = (t * self.num_timestep_buckets).long()
        adarms_cond = self._compute_adarms_cond(t_discretized)
        action_tokens = self._prepare_action_embeds(noisy_trajectory, action_dof_mask)
        output = self.dit(
            action_tokens, vlm_kv_cache, adarms_cond, encoder_attention_mask,
        )
        pred = self.action_out_proj(output)
        pred_v = pred[:, :actions.shape[1]]
        return pred_v, velocity

    def compute_velocity(
        self,
        past_key_values: Cache,
        actions: torch.Tensor,
        noise: torch.Tensor,
        t: torch.Tensor,
        action_dof_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute velocity prediction for pre-sampled noise and timestep.

        Used by DiffusionNFT where noise and timestep must be shared between
        the current policy (v_θ) and the reference policy (v_old).

        Args:
            past_key_values: VLM decoder KV cache
            actions: (B, T, action_dim) ground truth actions (x_0)
            noise: (B, T, action_dim) pre-sampled noise (ε)
            t: (B,) continuous timesteps in [0, 1)
            action_dof_mask, encoder_attention_mask: as in ``forward``.
        Returns:
            pred_v: (B, T, action_dim) predicted velocity
        """
        vlm_kv_cache = self._vlm_kv_list_from_past(past_key_values)
        device = actions.device
        t_expanded = t[:, None, None]
        noisy_trajectory = (1 - t_expanded) * noise + t_expanded * actions
        t_discretized = (t * self.num_timestep_buckets).long()
        adarms_cond = self._compute_adarms_cond(t_discretized)
        action_tokens = self._prepare_action_embeds(noisy_trajectory, action_dof_mask)
        output = self.dit(
            action_tokens, vlm_kv_cache, adarms_cond, encoder_attention_mask,
        )
        pred = self.action_out_proj(output)
        return pred[:, :actions.shape[1]]

    def predict_action(
        self,
        past_key_values: Cache,
        action_dof_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Inference: Euler integration, returns (B, chunk_size, action_dim)."""
        k0 = past_key_values[0][0]
        B = k0.shape[0]
        device = k0.device
        dtype = k0.dtype
        vlm_kv_cache = self._vlm_kv_list_from_past(past_key_values)
        actions = torch.randn(
            (B, self.action_chunk_size, self.action_dim),
            device=device, dtype=dtype,
        )
        dt = 1.0 / self.num_inference_timesteps
        for step in range(self.num_inference_timesteps):
            t_cont = step / float(self.num_inference_timesteps)
            t_disc_val = int(t_cont * self.num_timestep_buckets)
            t_tensor = torch.full((B,), t_disc_val, device=device, dtype=torch.long)
            adarms_cond = self._compute_adarms_cond(t_tensor)
            action_tokens = self._prepare_action_embeds(actions, action_dof_mask)
            output = self.dit(
                action_tokens, vlm_kv_cache, adarms_cond, encoder_attention_mask,
            )
            pred_velocity = self.action_out_proj(output)[:, :self.action_chunk_size]
            actions = actions + dt * pred_velocity
        return actions
```
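A hedged end-to-end inference sketch for the KV-cache head: run the VLM prefix once with `use_cache=True`, then integrate the action chunk against the returned cache. Here `vlm`, `vlm_inputs`, and the `config.text_config` attribute path are placeholders for the actual PRTS backbone and processor outputs, and the cache's layer count must equal the expert's block count:

```python
# Hypothetical usage: `vlm`, `vlm_inputs`, and the config attribute paths are
# placeholders; only the head's own API below comes from this file.
import torch

head = MoTFlowMatchingHead(
    action_dim=7, action_chunk_size=16,
    vlm_config=vlm.config.text_config,       # assumed location of the text config
)
head.eval()

with torch.no_grad():
    vlm_out = vlm(**vlm_inputs, use_cache=True)   # one prefix pass over the VLM
    actions = head.predict_action(
        past_key_values=vlm_out.past_key_values,
        encoder_attention_mask=vlm_inputs["attention_mask"].bool(),
    )
print(actions.shape)                         # (B, 16, 7)
```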