# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file is a StarVLA-local variant of the action encoder/decoder. # It keeps the overall structure but replaces the decoder with a # flow-matching based decoder (velocity prediction) and injects timestep # conditioning into RMSNorm (AdaRMSNorm) in the decoder. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import sys sys.path.append("/mnt/data/fangyu/code/reward_new") import math from typing import List import numpy as np import torch import torch.distributed as dist import torch.nn.functional as F from torch import nn from torch.distributions import Beta from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, logging from starVLA.model.modules.action_model.ActionModel import ( Qwen3Attention, Qwen3MLP, Qwen3RMSNorm, Qwen3RotaryEmbedding, ActionPreTrainedModel, ) from starVLA.model.modules.action_model.configuration_actionmodel import ActionModelConfig from starVLA.model.tools import FRAMEWORK_REGISTRY logger = logging.get_logger(__name__) class _GradientReversalFunction(torch.autograd.Function): """ Forward: identity. Backward: scale gradient by -lambda (inverse gradient). Used for domain adversarial training so the encoder receives reversed gradient and is encouraged to produce domain-invariant embeddings. """ @staticmethod def forward(ctx, x: torch.Tensor, lambda_: float) -> torch.Tensor: ctx.lambda_ = lambda_ return x.view_as(x) @staticmethod def backward(ctx, grad_output: torch.Tensor): return -ctx.lambda_ * grad_output, None def _timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0) -> torch.Tensor: """ Standard sinusoidal timestep embedding. Args: t: (B,) float tensor, typically in [0, 1]. Returns: (B, dim) """ if t.ndim != 1: raise ValueError(f"Expected `t` to have shape (B,), got {tuple(t.shape)}") half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(0, half, device=t.device, dtype=torch.float32) / max(half, 1) ) args = t.to(torch.float32)[:, None] * freqs[None] emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2 == 1: emb = torch.cat([emb, torch.zeros((emb.shape[0], 1), device=t.device, dtype=emb.dtype)], dim=-1) return emb.to(dtype=t.dtype) class Qwen3AdaRMSNorm(nn.Module): """ RMSNorm + timestep conditioning. y = RMSNorm(x) * (1 + scale(t)) + shift(t) """ def __init__(self, hidden_size: int, cond_size: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps self.cond_mlp = nn.Sequential( nn.SiLU(), nn.Linear(cond_size, 2 * hidden_size, bias=True), ) def forward(self, hidden_states: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: if cond is None: raise ValueError("Qwen3AdaRMSNorm requires `cond` but got None.") if cond.ndim != 2: raise ValueError(f"Expected `cond` to have shape (B, C), got {tuple(cond.shape)}") input_dtype = hidden_states.dtype x = hidden_states.to(torch.float32) variance = x.pow(2).mean(-1, keepdim=True) x = x * torch.rsqrt(variance + self.variance_epsilon) x = self.weight * x.to(input_dtype) scale, shift = self.cond_mlp(cond).chunk(2, dim=-1) return x * (1 + scale[:, None, :]) + shift[:, None, :] def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class Qwen3LayerFM(nn.Module): """ Same block structure as `Qwen3Layer`, but decoder-side RMSNorms are timestep-conditioned. Attention/MLP are unchanged. """ def __init__(self, config: ActionModelConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx) self.mlp = Qwen3MLP(config) self.input_layernorm = Qwen3AdaRMSNorm(config.hidden_size, cond_size=config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen3AdaRMSNorm( config.hidden_size, cond_size=config.hidden_size, eps=config.rms_norm_eps ) def forward( self, hidden_states: torch.Tensor, temb: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs, ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states, temb) hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states, temb) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states class ActionModelFM(ActionPreTrainedModel): """ Flow-matching based decoder variant for StarVLA `ActionModel`. Encoder stays the same; decoder predicts velocity under linear interpolation noise. """ def __init__(self, config: ActionModelConfig): super().__init__(config) self.config = config # ===== tokens / embeddings (same as original) ===== self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) self.action_mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) self.state_mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) self.dataset_embed = nn.Embedding( config.dataset_vocab_size, config.hidden_size * config.num_data_tokens, ) self.action_proj_in = nn.Linear(config.action_size, config.hidden_size) self.state_proj_in = nn.Linear(config.state_size, config.hidden_size) self.use_state = config.use_state print(f"use_state: {self.use_state}") # ===== encoder (unchanged blocks) ===== # Reuse the original Qwen3Layer implementation from ActionModel.py through `ActionPreTrainedModel` machinery from starVLA.model.modules.action_model.ActionModel import Qwen3Layer # local import self.action_encoder = nn.ModuleList([Qwen3Layer(config, layer_idx) for layer_idx in range(config.num_encoder_layers)]) # ===== decoder (FM) ===== self.action_decoder = nn.ModuleList([Qwen3LayerFM(config, layer_idx) for layer_idx in range(config.num_decoder_layers)]) self.norm = Qwen3AdaRMSNorm(config.hidden_size, cond_size=config.hidden_size, eps=config.rms_norm_eps) self.action_proj_out = nn.Linear(config.hidden_size, config.action_size) self.rotary_emb = Qwen3RotaryEmbedding(config=config) self.gradient_checkpointing = False # ===== FM hyperparams ===== self.fm_time_min = float(getattr(config, "fm_time_min", 0.001)) self.fm_time_max = float(getattr(config, "fm_time_max", 0.999)) self.fm_num_inference_steps = int(getattr(config, "fm_num_inference_steps", 10)) self.fm_time_sampling = str(getattr(config, "fm_time_sampling", "uniform")) # "uniform" | "beta" self.fm_beta_alpha = float(getattr(config, "fm_beta_alpha", 1.5)) self.fm_beta_beta = float(getattr(config, "fm_beta_beta", 1.0)) self._beta_dist = Beta(self.fm_beta_alpha, self.fm_beta_beta) # timestep -> temb (B,H) self.fm_timestep_mlp = nn.Sequential( nn.Linear(config.hidden_size, config.hidden_size * 4, bias=True), nn.SiLU(), nn.Linear(config.hidden_size * 4, config.hidden_size, bias=True), ) # ===== Loss mode: masked-action recon ===== self.use_masked_action_recon = bool(getattr(config, "use_masked_action_recon", False)) self.post_init() self._maybe_init_from_qwen3() def _sample_fm_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: if self.fm_time_sampling == "beta": t = self._beta_dist.sample([batch_size]).to(device=device, dtype=dtype) else: t = torch.rand((batch_size,), device=device, dtype=dtype) t = t * (self.fm_time_max - self.fm_time_min) + self.fm_time_min return t def _fm_temb(self, t: torch.Tensor) -> torch.Tensor: return self.fm_timestep_mlp(_timestep_embedding(t, self.config.hidden_size)) def _gather_embeddings(self, x: torch.Tensor) -> tuple[torch.Tensor, int]: """ Gather embeddings from all ranks. Returns (gathered_tensor, offset) where offset is the start index of this rank's data in the global batch. Single-GPU: returns (x, 0). """ if not (self.contrastive_use_distributed and dist.is_initialized() and dist.get_world_size() > 1): return x, 0 world_size = dist.get_world_size() local_size = x.shape[0] size_list = [torch.tensor([0], dtype=torch.long, device=x.device) for _ in range(world_size)] dist.all_gather(size_list, torch.tensor([local_size], dtype=torch.long, device=x.device)) sizes = [s.item() for s in size_list] max_size = max(sizes) offset = sum(sizes[: dist.get_rank()]) if max_size > local_size: padding = torch.zeros(max_size - local_size, x.shape[1], device=x.device, dtype=x.dtype) x = torch.cat([x, padding], dim=0) gather_list = [torch.zeros_like(x) for _ in range(world_size)] dist.all_gather(gather_list, x) out = torch.cat([g[: sizes[i]] for i, g in enumerate(gather_list)], dim=0) return out, offset def random_masking(self, x: torch.Tensor, mask_ratio: float | torch.Tensor): """ MAE-style per-sample random masking by shuffling (argsort noise). This version DOES NOT drop tokens; it returns `x_masked` with the same shape as `x`, where masked positions are replaced by `self.action_mask_token`. Args: x: [N, L, D] mask_ratio: float in [0, 1) OR tensor of shape [N] with per-sample ratios Returns: x_masked: [N, L, D] mask: [N, L] (0=keep, 1=mask) ids_restore: [N, L] """ N, L, D = x.shape token_dim = int(self.action_mask_token.shape[-1]) if D != token_dim: raise ValueError( f"`random_masking` expects last dim D=={token_dim} (same as action_mask_token), got D=={D}." ) if isinstance(mask_ratio, torch.Tensor): if mask_ratio.ndim != 1 or mask_ratio.shape[0] != N: raise ValueError( f"When `mask_ratio` is a tensor it must have shape (N,), got {tuple(mask_ratio.shape)}" ) # clamp to safe range mask_ratio = mask_ratio.to(device=x.device, dtype=torch.float32).clamp(min=0.0, max=0.999) len_keep = torch.floor(L * (1.0 - mask_ratio)).to(dtype=torch.long) # (N,) else: mr = float(mask_ratio) mr = max(0.0, min(0.999, mr)) len_keep = int(L * (1.0 - mr)) noise = torch.rand(N, L, device=x.device) # noise in [0, 1] ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is mask ids_restore = torch.argsort(ids_shuffle, dim=1) # generate the binary mask: 0 is keep, 1 is mask mask = torch.ones([N, L], device=x.device, dtype=torch.float32) if isinstance(len_keep, torch.Tensor): # build mask in shuffled order then unshuffle keep = torch.arange(L, device=x.device)[None, :].expand(N, L) < len_keep[:, None] # (N,L) mask = (~keep).to(dtype=torch.float32) else: mask[:, :len_keep] = 0 mask = torch.gather(mask, dim=1, index=ids_restore) # unshuffle # replace masked tokens with action_mask_token (keep sequence length) mask_token = self.action_mask_token.expand(N, L, -1).to(dtype=x.dtype, device=x.device) x_masked = x * (1.0 - mask[:, :, None]) + mask[:, :, None] * mask_token return x_masked, mask, ids_restore def random_masking_interleaved( self, interleaved: torch.Tensor, mask_ratio: float | torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ MAE-style random masking for interleaved [state_0, action_0, state_1, action_1, ...]. Positions 0, 2, 4, ... are state (replaced with state_mask_token when masked); positions 1, 3, 5, ... are action (replaced with action_mask_token when masked). Args: interleaved: [N, 2*L, D] (state, action, state, action, ...) mask_ratio: float in [0, 1) OR tensor [N] per-sample Returns: x_masked: [N, 2*L, D] mask: [N, 2*L] (0=keep, 1=mask) ids_restore: [N, 2*L] """ N, two_L, D = interleaved.shape L = two_L // 2 if isinstance(mask_ratio, torch.Tensor): mask_ratio = mask_ratio.to(device=interleaved.device, dtype=torch.float32).clamp(min=0.0, max=0.999) len_keep = torch.floor(two_L * (1.0 - mask_ratio)).to(dtype=torch.long) else: mr = max(0.0, min(0.999, float(mask_ratio))) len_keep = int(two_L * (1.0 - mr)) noise = torch.rand(N, two_L, device=interleaved.device) ids_shuffle = torch.argsort(noise, dim=1) ids_restore = torch.argsort(ids_shuffle, dim=1) if isinstance(len_keep, torch.Tensor): keep = torch.arange(two_L, device=interleaved.device)[None, :].expand(N, two_L) < len_keep[:, None] mask = (~keep).to(dtype=torch.float32) else: mask = torch.ones(N, two_L, device=interleaved.device, dtype=torch.float32) mask[:, :len_keep] = 0 mask = torch.gather(mask, dim=1, index=ids_restore) state_mtk = self.state_mask_token.expand(N, two_L, -1).to(dtype=interleaved.dtype, device=interleaved.device) action_mtk = self.action_mask_token.expand(N, two_L, -1).to(dtype=interleaved.dtype, device=interleaved.device) # even indices -> state, odd -> action state_pos = torch.zeros(two_L, device=interleaved.device, dtype=torch.float32) state_pos[0::2] = 1.0 state_pos = state_pos.view(1, two_L, 1) action_pos = 1.0 - state_pos mask_expand = mask[:, :, None] replacement = mask_expand * (state_pos * state_mtk + action_pos * action_mtk) x_masked = interleaved * (1.0 - mask_expand) + replacement return x_masked, mask, ids_restore # --- copied optional init helper from original --- def _maybe_init_from_qwen3(self) -> None: from transformers import AutoModel name_or_path = getattr(self.config, "qwen3_pretrained_name_or_path", None) if not name_or_path: return pretrained = AutoModel.from_pretrained( name_or_path, torch_dtype="auto", low_cpu_mem_usage=True, ) src_sd = pretrained.state_dict() layer_prefix = None for p in ("model.layers.", "layers."): if any(k.startswith(p) for k in src_sd.keys()): layer_prefix = p break norm_prefix = None for p in ("model.norm.", "norm."): if any(k.startswith(p) for k in src_sd.keys()): norm_prefix = p break if layer_prefix is None: return def _map_layer_key(target_key: str, module_prefix: str, layer_offset: int) -> str | None: rest = target_key[len(module_prefix) + 1 :] parts = rest.split(".", 1) if len(parts) != 2: return None try: tgt_idx = int(parts[0]) except ValueError: return None src_idx = tgt_idx + int(layer_offset) return f"{layer_prefix}{src_idx}.{parts[1]}" own_sd = self.state_dict() to_load: dict[str, torch.Tensor] = {} matched = 0 missing = 0 shape_mismatch = 0 init_enc = bool(getattr(self.config, "qwen3_init_action_encoder", True)) init_dec = bool(getattr(self.config, "qwen3_init_action_decoder", True)) init_norm = bool(getattr(self.config, "qwen3_init_norm", True)) enc_off = int(getattr(self.config, "qwen3_encoder_layer_offset", 0)) dec_off = int(getattr(self.config, "qwen3_decoder_layer_offset", 0)) # NOTE: decoder has AdaRMSNorm (extra cond_mlp weights), but many weights still match: # - action_decoder.*.self_attn.* # - action_decoder.*.mlp.* # - action_decoder.*.(input_layernorm|post_attention_layernorm).weight (load RMS weight only) # - norm.weight (load RMS weight only) for k, tgt_tensor in own_sd.items(): src_key = None if init_enc and k.startswith("action_encoder."): src_key = _map_layer_key(k, "action_encoder", enc_off) elif init_dec and k.startswith("action_decoder."): # Skip timestep-conditioned MLP weights (no counterpart in Qwen3) if ".cond_mlp." in k: continue src_key = _map_layer_key(k, "action_decoder", dec_off) elif init_norm and k == "norm.weight" and norm_prefix is not None: src_key = f"{norm_prefix}weight" if not src_key: continue src_tensor = src_sd.get(src_key, None) if src_tensor is None: missing += 1 continue if src_tensor.shape != tgt_tensor.shape: shape_mismatch += 1 continue to_load[k] = src_tensor.to(device=tgt_tensor.device, dtype=tgt_tensor.dtype) matched += 1 self.load_state_dict(to_load, strict=False) print( f"Initialized from Qwen3 checkpoint {name_or_path}. " f"matched={matched} missing={missing} shape_mismatch={shape_mismatch} prefix={layer_prefix}" ) if matched == 0: # Most common culprit: config dims don't match Qwen3 checkpoint. src_cfg = getattr(pretrained, "config", None) if src_cfg is not None: fields = [ "hidden_size", "intermediate_size", "num_hidden_layers", "num_attention_heads", "num_key_value_heads", "head_dim", "rms_norm_eps", ] diffs = [] for f in fields: if hasattr(src_cfg, f) and hasattr(self.config, f): a = getattr(self.config, f) b = getattr(src_cfg, f) if a != b: diffs.append((f, a, b)) if diffs: print("[ActionModelFM] Qwen3 init got 0 matches. Config differs from checkpoint:") for f, a, b in diffs: print(f" - {f}: ActionModelConfig={a} vs Qwen3={b}") def forward( self, examples: List[dict] = None, **kwargs: Unpack[TransformersKwargs], ): device = next(self.parameters()).device batch_size = len(examples) # ========================================================================= # 1. Variable-length Horizon (same as original) # ========================================================================= raw_actions = torch.tensor( np.array([ex["action"] for ex in examples]), device=device, dtype=torch.float32, ) # [B, L, D] use_state = self.use_state raw_states = None if use_state: raw_states = torch.tensor( np.array([ex["state"] for ex in examples]), device=device, dtype=torch.float32, ) # [B, L, state_dim] # ========================================================================= # 2. Action (and optional State) Input Construction & Masking (DAE) # Encoder sequence: cls, dataset_tokens, [state_0, action_0, state_1, action_1, ...] # Two-view (masked + clean) when use_masked_action_recon. # ========================================================================= with torch.autocast("cuda", dtype=torch.float32): clean_action_embeds = self.action_proj_in(raw_actions) # [B, L, H] if use_state: clean_state_embeds = self.state_proj_in(raw_states) # [B, L, H] # Interleave: [s0, a0, s1, a1, ...] -> [B, 2*L, H] clean_inputs_embeds = torch.stack( [clean_state_embeds, clean_action_embeds], dim=2 ).reshape(batch_size, 2 * raw_actions.shape[1], -1) else: clean_inputs_embeds = clean_action_embeds masked_inputs_embeds = clean_inputs_embeds if self.use_masked_action_recon: if use_state: if getattr(self.config, "mask_ratio_mode", "fixed") == "uniform_per_traj": mr_min = float(getattr(self.config, "mask_ratio_min", self.config.mask_ratio)) mr_max = float(getattr(self.config, "mask_ratio_max", self.config.mask_ratio)) per_traj_mr = torch.rand((batch_size,), device=device) * (mr_max - mr_min) + mr_min masked_inputs_embeds, _, _ = self.random_masking_interleaved(clean_inputs_embeds, per_traj_mr) else: masked_inputs_embeds, _, _ = self.random_masking_interleaved( clean_inputs_embeds, float(self.config.mask_ratio) ) else: if getattr(self.config, "mask_ratio_mode", "fixed") == "uniform_per_traj": mr_min = float(getattr(self.config, "mask_ratio_min", self.config.mask_ratio)) mr_max = float(getattr(self.config, "mask_ratio_max", self.config.mask_ratio)) per_traj_mr = torch.rand((batch_size,), device=device) * (mr_max - mr_min) + mr_min masked_inputs_embeds, _, _ = self.random_masking(clean_inputs_embeds, per_traj_mr) else: masked_inputs_embeds, _, _ = self.random_masking(clean_inputs_embeds, float(self.config.mask_ratio)) # ========================================================================= # 3. Dataset Soft Prompt (same as original) # ========================================================================= dataset_ids = [ex.get("dataset_id") for ex in examples] dataset_ids_tensor = torch.tensor(dataset_ids, device=device, dtype=torch.long) ds_embeds = self.dataset_embed(dataset_ids_tensor).view( batch_size, self.config.num_data_tokens, self.config.hidden_size ) cls_token_expanded = self.cls_token.expand(batch_size, -1, -1) encoder_inputs_clean = torch.cat((cls_token_expanded, ds_embeds, clean_inputs_embeds), dim=1) encoder_inputs_masked = torch.cat((cls_token_expanded, ds_embeds, masked_inputs_embeds), dim=1) seq_len = encoder_inputs_clean.shape[1] enc_bs = batch_size * 2 if self.use_masked_action_recon else batch_size encoder_attention_mask = torch.ones((enc_bs, 1, seq_len, seq_len), device=device, dtype=torch.bool) encoder_pos_ids = torch.arange(seq_len, device=device).unsqueeze(0) # rotary embeddings are position-based; we keep position_ids batch=1 and broadcast. enc_pos_emb = self.rotary_emb(encoder_inputs_clean, encoder_pos_ids) hidden_states = ( torch.cat((encoder_inputs_masked, encoder_inputs_clean), dim=0) if self.use_masked_action_recon else encoder_inputs_clean ) for encoder_layer in self.action_encoder: hidden_states = encoder_layer( hidden_states, attention_mask=encoder_attention_mask, position_embeddings=enc_pos_emb, position_ids=encoder_pos_ids, **kwargs, ) if self.use_masked_action_recon: hidden_masked, hidden_clean = hidden_states.chunk(2, dim=0) action_embedding_masked = F.normalize(hidden_masked[:, :1, :], p=2, dim=-1) action_embedding_clean = F.normalize(hidden_clean[:, :1, :], p=2, dim=-1) else: action_embedding_clean = F.normalize(hidden_states[:, :1, :], p=2, dim=-1) action_embedding_masked = None # ========================================================================= # 4. Flow-matching Decoder # ========================================================================= t = self._sample_fm_time(batch_size, device=device, dtype=raw_actions.dtype) # (B,) noise = torch.randn_like(raw_actions) noisy_actions = t[:, None, None] * noise + (1 - t[:, None, None]) * raw_actions target_velocity = noise - raw_actions noisy_embeds = self.action_proj_in(noisy_actions) if self.use_masked_action_recon: # Single decoder forward for both views in one batch. decoder_cond = torch.cat((action_embedding_clean, action_embedding_masked), dim=0) noisy_embeds = torch.cat((noisy_embeds, noisy_embeds), dim=0) t = torch.cat((t, t), dim=0) target_velocity = torch.cat((target_velocity, target_velocity), dim=0) else: decoder_cond = action_embedding_clean decoder_inputs = torch.cat((decoder_cond, noisy_embeds), dim=1) # [B or 2B, 1+L, H] dec_seq_len = decoder_inputs.shape[1] dec_bs = decoder_inputs.shape[0] decoder_attention_mask = torch.ones((dec_bs, 1, dec_seq_len, dec_seq_len), device=device, dtype=torch.bool) dec_pos_ids = torch.arange(dec_seq_len, device=device).unsqueeze(0) dec_pos_emb = self.rotary_emb(decoder_inputs, dec_pos_ids) temb = self._fm_temb(t) hidden_states = decoder_inputs for decoder_layer in self.action_decoder: hidden_states = decoder_layer( hidden_states, temb=temb, attention_mask=decoder_attention_mask, position_embeddings=dec_pos_emb, position_ids=dec_pos_ids, ) hidden_states = self.norm(hidden_states, temb) pred_velocity = self.action_proj_out(hidden_states[:, 1:, :]) if self.use_masked_action_recon: pred_clean, pred_masked = pred_velocity.chunk(2, dim=0) target_clean, target_masked = target_velocity.chunk(2, dim=0) recon_loss_clean = F.mse_loss(pred_clean, target_clean) recon_loss_masked = F.mse_loss(pred_masked, target_masked) recon_loss = 0.5 * (recon_loss_clean + recon_loss_masked) else: recon_loss = F.mse_loss(pred_velocity, target_velocity) return recon_loss def recon_loss(self, actions, dataset_ids: list[int], state=None, **kwargs): """ Same interface as `ActionModel.recon_loss`, but using flow-matching decoder loss. Args: actions: (B, L, action_dim) dataset_ids: list[int]; used for dataset soft prompt when state is provided. state: optional (B, L, state_dim); if provided and state_proj_in exists, encoder sees interleaved sequence [state_0, action_0, state_1, action_1, ...]. Returns: scalar loss """ # Optional fast-path: pass a precomputed action embedding to avoid another encoder forward. action_embedding = kwargs.pop("action_embedding", None) t = kwargs.pop("t", None) noise = kwargs.pop("noise", None) if action_embedding is None: action_embedding = self.encode_actions(actions, dataset_ids, state, **kwargs) return self.recon_loss_from_embedding( action_embedding=action_embedding, actions=actions, t=t, noise=noise, ) def recon_loss_from_embedding( self, action_embedding: torch.Tensor, actions: torch.Tensor, t: torch.Tensor | None = None, noise: torch.Tensor | None = None, ) -> torch.Tensor: """ Flow-matching velocity loss conditioned on a provided action embedding. This is the preferred interface when you already have an action embedding (e.g., from VLM projector), since it avoids an extra action-encoder forward. Args: action_embedding: (B, H) or (B, 1, H), assumed L2-normalized (recommended). actions: (B, L, action_dim) t: optional (B,) time; if None sample internally noise: optional (B, L, action_dim) noise; if None sample internally """ if action_embedding.dim() == 2: action_embedding = action_embedding.unsqueeze(1) if action_embedding.dim() != 3 or action_embedding.shape[1] != 1: raise ValueError(f"Expected action_embedding shape (B,1,H) or (B,H); got {tuple(action_embedding.shape)}") batch_size = actions.shape[0] device = actions.device dtype = actions.dtype if t is None: t = self._sample_fm_time(batch_size, device=device, dtype=dtype) if noise is None: noise = torch.randn_like(actions) noisy_actions = t[:, None, None] * noise + (1 - t[:, None, None]) * actions target_velocity = noise - actions temb = self._fm_temb(t) action_embeds = self.action_proj_in(noisy_actions) hidden_states = torch.cat((action_embedding, action_embeds), dim=1) dec_seq_len = hidden_states.shape[1] decoder_attention_mask = torch.ones( (batch_size, 1, dec_seq_len, dec_seq_len), device=device, dtype=torch.bool, ) dec_pos_ids = torch.arange(dec_seq_len, device=device).unsqueeze(0) dec_pos_emb = self.rotary_emb(hidden_states, dec_pos_ids) for decoder_layer in self.action_decoder: hidden_states = decoder_layer( hidden_states, temb=temb, attention_mask=decoder_attention_mask, position_embeddings=dec_pos_emb, position_ids=dec_pos_ids, ) hidden_states = self.norm(hidden_states, temb) pred_velocity = self.action_proj_out(hidden_states[:, 1:, :]) return F.mse_loss(pred_velocity, target_velocity) def encode_actions(self, actions, dataset_ids: list[int], state=None, **kwargs): """ Encode action chunk (and optional state chunk) to a single CLS embedding. Args: actions: (B, L, action_dim) state: optional (B, L, state_dim); if provided and state_proj_in exists, encoder sees interleaved sequence [state_0, action_0, state_1, action_1, ...]. dataset_ids: list[int]; used for dataset soft prompt when state is provided. """ action_embeds = self.action_proj_in(actions) batch_size = action_embeds.shape[0] use_state = state is not None and self.state_proj_in is not None if use_state: state_embeds = self.state_proj_in(state) L = action_embeds.shape[1] inputs_embeds = torch.stack( [state_embeds, action_embeds], dim=2 ).reshape(batch_size, 2 * L, -1) else: inputs_embeds = action_embeds cls_token_expanded = self.cls_token.expand(batch_size, -1, -1) dataset_ids_tensor = torch.tensor(dataset_ids, device=action_embeds.device, dtype=torch.long) ds_embeds = self.dataset_embed(dataset_ids_tensor).view( batch_size, self.config.num_data_tokens, self.config.hidden_size ) inputs_embeds = torch.cat((cls_token_expanded, ds_embeds, inputs_embeds), dim=1) seq_len = inputs_embeds.shape[1] encoder_attention_mask = torch.ones( (batch_size, 1, seq_len, seq_len), device=inputs_embeds.device, dtype=torch.bool, ) encoder_pos_ids = torch.arange(seq_len, device=inputs_embeds.device).unsqueeze(0) enc_pos_emb = self.rotary_emb(inputs_embeds, encoder_pos_ids) hidden_states = inputs_embeds for encoder_layer in self.action_encoder: hidden_states = encoder_layer( hidden_states, attention_mask=encoder_attention_mask, position_embeddings=enc_pos_emb, position_ids=encoder_pos_ids, **kwargs, ) action_embedding = hidden_states[:, :1, :] return F.normalize(action_embedding, p=2, dim=-1) @torch.no_grad() def decode_actions(self, action_embedding, chunk_size, **kwargs): """ FM sampling via simple Euler integration of the learned velocity field. """ if chunk_size is None: chunk_size = self.config.max_action_chunk_size if action_embedding.dim() == 2: action_embedding = action_embedding.unsqueeze(1) batch_size = action_embedding.shape[0] device = action_embedding.device dtype = action_embedding.dtype actions = torch.randn((batch_size, chunk_size, self.config.action_size), device=device, dtype=dtype) num_steps = max(int(self.fm_num_inference_steps), 1) dt = -1.0 / float(num_steps) for step in range(num_steps): t = torch.full((batch_size,), 1.0 - step / float(num_steps), device=device, dtype=dtype) temb = self._fm_temb(t) action_embeds = self.action_proj_in(actions) hidden_states = torch.cat((action_embedding, action_embeds), dim=1) dec_seq_len = hidden_states.shape[1] decoder_attention_mask = torch.ones((batch_size, 1, dec_seq_len, dec_seq_len), device=device, dtype=torch.bool) dec_pos_ids = torch.arange(dec_seq_len, device=device).unsqueeze(0) dec_pos_emb = self.rotary_emb(hidden_states, dec_pos_ids) for decoder_layer in self.action_decoder: hidden_states = decoder_layer( hidden_states, temb=temb, attention_mask=decoder_attention_mask, position_embeddings=dec_pos_emb, position_ids=dec_pos_ids, ) hidden_states = self.norm(hidden_states, temb) pred_velocity = self.action_proj_out(hidden_states[:, 1:, :]) actions = actions + dt * pred_velocity return actions __all__ = [ "ActionModelFM", ] if __name__ == "__main__": config = ActionModelConfig() action_model = ActionModelFM(config) print(action_model) print("Total number of DiT parameters: ", sum(p.numel() for p in action_model.parameters() if p.requires_grad)) fake_actions = torch.randn(10, 15, 64).to("cuda:7") sample = { "action": np.random.uniform(-1, 1, size=(16, 29)).astype(np.float16), # action_chunk, action_dim (unified 29D) "lang": "put the ball on the table", } batch = [sample, sample] device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu") action_model = action_model.to(device) outputs = action_model(batch) print(outputs) action_embedding = action_model.encode_actions(fake_actions) print(f"action_embedding: {action_embedding.shape}") reconstructed_actions = action_model.decode_actions(action_embedding, chunk_size=15) print(f"reconstructed_actions: {reconstructed_actions.shape}")