import sys

# Dev-machine path hack so the `starVLA` package resolves when running this file directly.
sys.path.append("/mnt/data/fangyu/code/reward_new")

import math
from typing import List

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributions import Beta

from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, logging

from starVLA.model.modules.action_model.ActionModel import (
    Qwen3Attention,
    Qwen3Layer,
    Qwen3MLP,
    Qwen3RMSNorm,
    Qwen3RotaryEmbedding,
    ActionPreTrainedModel,
)
from starVLA.model.modules.action_model.configuration_actionmodel import ActionModelConfig
from starVLA.model.tools import FRAMEWORK_REGISTRY

logger = logging.get_logger(__name__)


class _GradientReversalFunction(torch.autograd.Function):
    """
    Forward: identity. Backward: scale the gradient by -lambda (gradient reversal).
    Used for domain-adversarial training: the encoder receives a reversed gradient
    and is encouraged to produce domain-invariant embeddings.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, lambda_: float) -> torch.Tensor:
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        return -ctx.lambda_ * grad_output, None
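

# Illustrative usage (not called in this file): reverse the gradients flowing from a
# hypothetical domain classifier back into encoder features, so the encoder learns to
# *confuse* the classifier while the classifier itself trains normally:
#
#   reversed_feats = _GradientReversalFunction.apply(features, 1.0)
#   domain_logits = domain_classifier(reversed_feats)  # `domain_classifier` is a stand-in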


def _timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0) -> torch.Tensor:
    """
    Standard sinusoidal timestep embedding.

    Args:
        t: (B,) float tensor, typically in [0, 1].
        dim: output embedding dimension.
        max_period: controls the lowest frequency of the embedding.

    Returns:
        (B, dim)
    """
    if t.ndim != 1:
        raise ValueError(f"Expected `t` to have shape (B,), got {tuple(t.shape)}")
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(0, half, device=t.device, dtype=torch.float32) / max(half, 1)
    )
    args = t.to(torch.float32)[:, None] * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2 == 1:
        # Pad with a zero column so the output width is exactly `dim`.
        emb = torch.cat([emb, torch.zeros((emb.shape[0], 1), device=t.device, dtype=emb.dtype)], dim=-1)
    return emb.to(dtype=t.dtype)
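

# Worked micro-example: with dim=4 and max_period=1e4, half=2 and
# freqs = exp(-ln(1e4) * [0, 1] / 2) = [1.0, 0.01], so each scalar t maps to
# [cos(t), cos(0.01 * t), sin(t), sin(0.01 * t)] -- early channels vary quickly
# with t, later channels slowly, as in the Transformer positional encoding.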


class Qwen3AdaRMSNorm(nn.Module):
    """
    RMSNorm + timestep conditioning (adaLN-style):

        y = RMSNorm(x) * (1 + scale(t)) + shift(t)
    """

    def __init__(self, hidden_size: int, cond_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.cond_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_size, 2 * hidden_size, bias=True),
        )

    def forward(self, hidden_states: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        if cond is None:
            raise ValueError("Qwen3AdaRMSNorm requires `cond` but got None.")
        if cond.ndim != 2:
            raise ValueError(f"Expected `cond` to have shape (B, C), got {tuple(cond.shape)}")

        # RMSNorm in float32 for numerical stability, then cast back.
        input_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = self.weight * x.to(input_dtype)

        scale, shift = self.cond_mlp(cond).chunk(2, dim=-1)
        return x * (1 + scale[:, None, :]) + shift[:, None, :]

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
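

# Shape walkthrough: hidden_states is (B, S, H) and cond is (B, C); cond_mlp produces
# (B, 2H), chunked into scale/shift of shape (B, H), which broadcast over the sequence
# dimension via scale[:, None, :]. Note the conditioning Linear is default-initialized
# (no zero-init as in DiT's adaLN-Zero); the (1 + scale) form still keeps the layer
# close to a plain RMSNorm while scale stays small.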


class Qwen3LayerFM(nn.Module):
    """
    Same block structure as `Qwen3Layer`, but the decoder-side RMSNorms are
    timestep-conditioned. Attention/MLP are unchanged.
    """

    def __init__(self, config: ActionModelConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen3MLP(config)
        self.input_layernorm = Qwen3AdaRMSNorm(
            config.hidden_size, cond_size=config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Qwen3AdaRMSNorm(
            config.hidden_size, cond_size=config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor:
        # Pre-norm attention sub-block, conditioned on the timestep embedding.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states, temb)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Pre-norm MLP sub-block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states, temb)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
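

# Data flow per layer (pre-norm residual block, temb injected at both norms):
#   x -> AdaRMSNorm(x, temb) -> self-attention -> + x
#     -> AdaRMSNorm(., temb) -> MLP            -> + .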


class ActionModelFM(ActionPreTrainedModel):
    """
    Flow-matching decoder variant of the StarVLA `ActionModel`.
    The encoder is unchanged; the decoder predicts velocity under
    linear-interpolation (rectified-flow) noising.
    """

    def __init__(self, config: ActionModelConfig):
        super().__init__(config)
        self.config = config

        # Learned special tokens.
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.action_mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.state_mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        # Per-dataset soft prompt: one embedding row expands into `num_data_tokens` tokens.
        self.dataset_embed = nn.Embedding(
            config.dataset_vocab_size,
            config.hidden_size * config.num_data_tokens,
        )

        self.action_proj_in = nn.Linear(config.action_size, config.hidden_size)
        self.state_proj_in = nn.Linear(config.state_size, config.hidden_size)
        self.use_state = config.use_state
        logger.info(f"use_state: {self.use_state}")

        # Encoder: plain Qwen3 layers, identical to `ActionModel`.
        self.action_encoder = nn.ModuleList(
            [Qwen3Layer(config, layer_idx) for layer_idx in range(config.num_encoder_layers)]
        )

        # Decoder: timestep-conditioned Qwen3 layers.
        self.action_decoder = nn.ModuleList(
            [Qwen3LayerFM(config, layer_idx) for layer_idx in range(config.num_decoder_layers)]
        )
        self.norm = Qwen3AdaRMSNorm(config.hidden_size, cond_size=config.hidden_size, eps=config.rms_norm_eps)
        self.action_proj_out = nn.Linear(config.hidden_size, config.action_size)

        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Flow-matching hyperparameters (all optional on the config).
        self.fm_time_min = float(getattr(config, "fm_time_min", 0.001))
        self.fm_time_max = float(getattr(config, "fm_time_max", 0.999))
        self.fm_num_inference_steps = int(getattr(config, "fm_num_inference_steps", 10))
        self.fm_time_sampling = str(getattr(config, "fm_time_sampling", "uniform"))
        self.fm_beta_alpha = float(getattr(config, "fm_beta_alpha", 1.5))
        self.fm_beta_beta = float(getattr(config, "fm_beta_beta", 1.0))
        self._beta_dist = Beta(self.fm_beta_alpha, self.fm_beta_beta)

        # Timestep-embedding MLP (sinusoidal features -> hidden_size condition).
        self.fm_timestep_mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 4, bias=True),
            nn.SiLU(),
            nn.Linear(config.hidden_size * 4, config.hidden_size, bias=True),
        )

        # Optional MAE-style masked-action reconstruction branch.
        self.use_masked_action_recon = bool(getattr(config, "use_masked_action_recon", False))
        self.post_init()

        self._maybe_init_from_qwen3()

    def _sample_fm_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        if self.fm_time_sampling == "beta":
            t = self._beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
        else:
            t = torch.rand((batch_size,), device=device, dtype=dtype)
        # Squash into [fm_time_min, fm_time_max] to avoid the degenerate endpoints.
        t = t * (self.fm_time_max - self.fm_time_min) + self.fm_time_min
        return t

    def _fm_temb(self, t: torch.Tensor) -> torch.Tensor:
        return self.fm_timestep_mlp(_timestep_embedding(t, self.config.hidden_size))
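
    # Noising convention used throughout: x_t = t * noise + (1 - t) * x_0, so t=0 is
    # clean data and t=1 is pure noise; the regression target is the constant velocity
    # v* = dx_t/dt = noise - x_0. With the default Beta(1.5, 1.0) time sampling the
    # density is proportional to sqrt(t), i.e. biased toward the noisy end.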

    def _gather_embeddings(self, x: torch.Tensor) -> tuple[torch.Tensor, int]:
        """
        Gather embeddings from all ranks.
        Returns (gathered_tensor, offset), where offset is the start index of this
        rank's data in the global batch. Single-GPU: returns (x, 0).
        """
        # `contrastive_use_distributed` is expected to be set externally (e.g. by the
        # trainer); default to the single-process path when it is absent.
        use_distributed = bool(getattr(self, "contrastive_use_distributed", False))
        if not (use_distributed and dist.is_initialized() and dist.get_world_size() > 1):
            return x, 0
        world_size = dist.get_world_size()
        local_size = x.shape[0]
        # Exchange per-rank batch sizes so ragged batches can be padded, gathered, and trimmed.
        size_list = [torch.tensor([0], dtype=torch.long, device=x.device) for _ in range(world_size)]
        dist.all_gather(size_list, torch.tensor([local_size], dtype=torch.long, device=x.device))
        sizes = [s.item() for s in size_list]
        max_size = max(sizes)
        offset = sum(sizes[: dist.get_rank()])
        if max_size > local_size:
            padding = torch.zeros(max_size - local_size, x.shape[1], device=x.device, dtype=x.dtype)
            x = torch.cat([x, padding], dim=0)
        gather_list = [torch.zeros_like(x) for _ in range(world_size)]
        dist.all_gather(gather_list, x)
        out = torch.cat([g[: sizes[i]] for i, g in enumerate(gather_list)], dim=0)
        return out, offset
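
    # NOTE: torch.distributed.all_gather is not differentiable, so the returned tensor
    # is detached from the autograd graph; a caller that needs gradients typically
    # splices its local `x` back in at `offset` before computing a contrastive loss.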

    def random_masking(self, x: torch.Tensor, mask_ratio: float | torch.Tensor):
        """
        MAE-style per-sample random masking by shuffling (argsort of noise).

        This version does NOT drop tokens; it returns `x_masked` with the same shape
        as `x`, where masked positions are replaced by `self.action_mask_token`.

        Args:
            x: [N, L, D]
            mask_ratio: float in [0, 1) OR tensor of shape [N] with per-sample ratios

        Returns:
            x_masked: [N, L, D]
            mask: [N, L] (0 = keep, 1 = mask)
            ids_restore: [N, L]
        """
        N, L, D = x.shape
        token_dim = int(self.action_mask_token.shape[-1])
        if D != token_dim:
            raise ValueError(
                f"`random_masking` expects last dim D == {token_dim} (same as action_mask_token), got D == {D}."
            )
        if isinstance(mask_ratio, torch.Tensor):
            if mask_ratio.ndim != 1 or mask_ratio.shape[0] != N:
                raise ValueError(
                    f"When `mask_ratio` is a tensor it must have shape (N,), got {tuple(mask_ratio.shape)}"
                )
            mask_ratio = mask_ratio.to(device=x.device, dtype=torch.float32).clamp(min=0.0, max=0.999)
            len_keep = torch.floor(L * (1.0 - mask_ratio)).to(dtype=torch.long)
        else:
            mr = max(0.0, min(0.999, float(mask_ratio)))
            len_keep = int(L * (1.0 - mr))

        # Random permutation per sample: argsort of iid noise.
        noise = torch.rand(N, L, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Build the mask in shuffled order (first `len_keep` slots kept), then
        # un-shuffle it back to the original token order.
        if isinstance(len_keep, torch.Tensor):
            keep = torch.arange(L, device=x.device)[None, :].expand(N, L) < len_keep[:, None]
            mask = (~keep).to(dtype=torch.float32)
        else:
            mask = torch.ones([N, L], device=x.device, dtype=torch.float32)
            mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        # Replace masked positions with the learned mask token.
        mask_token = self.action_mask_token.expand(N, L, -1).to(dtype=x.dtype, device=x.device)
        x_masked = x * (1.0 - mask[:, :, None]) + mask[:, :, None] * mask_token

        return x_masked, mask, ids_restore
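
    # Example: N=1, L=4, mask_ratio=0.5 -> len_keep=2. If ids_shuffle is [2, 0, 3, 1],
    # positions 2 and 0 are kept and positions 3 and 1 are replaced by the mask token,
    # giving mask = [0, 1, 0, 1] after un-shuffling with ids_restore.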

    def random_masking_interleaved(
        self,
        interleaved: torch.Tensor,
        mask_ratio: float | torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        MAE-style random masking for an interleaved sequence
        [state_0, action_0, state_1, action_1, ...].
        Even positions (0, 2, 4, ...) are states (replaced with state_mask_token when
        masked); odd positions (1, 3, 5, ...) are actions (replaced with
        action_mask_token when masked).

        Args:
            interleaved: [N, 2*L, D] (state, action, state, action, ...)
            mask_ratio: float in [0, 1) OR tensor [N] with per-sample ratios

        Returns:
            x_masked: [N, 2*L, D]
            mask: [N, 2*L] (0 = keep, 1 = mask)
            ids_restore: [N, 2*L]
        """
        N, two_L, D = interleaved.shape
        if isinstance(mask_ratio, torch.Tensor):
            mask_ratio = mask_ratio.to(device=interleaved.device, dtype=torch.float32).clamp(min=0.0, max=0.999)
            len_keep = torch.floor(two_L * (1.0 - mask_ratio)).to(dtype=torch.long)
        else:
            mr = max(0.0, min(0.999, float(mask_ratio)))
            len_keep = int(two_L * (1.0 - mr))

        noise = torch.rand(N, two_L, device=interleaved.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        if isinstance(len_keep, torch.Tensor):
            keep = torch.arange(two_L, device=interleaved.device)[None, :].expand(N, two_L) < len_keep[:, None]
            mask = (~keep).to(dtype=torch.float32)
        else:
            mask = torch.ones(N, two_L, device=interleaved.device, dtype=torch.float32)
            mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        # Position-dependent mask token: state token at even slots, action token at odd slots.
        state_mtk = self.state_mask_token.expand(N, two_L, -1).to(dtype=interleaved.dtype, device=interleaved.device)
        action_mtk = self.action_mask_token.expand(N, two_L, -1).to(dtype=interleaved.dtype, device=interleaved.device)
        state_pos = torch.zeros(two_L, device=interleaved.device, dtype=torch.float32)
        state_pos[0::2] = 1.0
        state_pos = state_pos.view(1, two_L, 1)
        action_pos = 1.0 - state_pos
        mask_expand = mask[:, :, None]
        replacement = mask_expand * (state_pos * state_mtk + action_pos * action_mtk)
        x_masked = interleaved * (1.0 - mask_expand) + replacement
        return x_masked, mask, ids_restore
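
    # Unlike `random_masking`, states and actions share one permutation here, so a
    # given timestep can lose its state, its action, neither, or both independently.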

    # ----- optional Qwen3 weight initialization -----
    def _maybe_init_from_qwen3(self) -> None:
        from transformers import AutoModel

        name_or_path = getattr(self.config, "qwen3_pretrained_name_or_path", None)
        if not name_or_path:
            return

        pretrained = AutoModel.from_pretrained(
            name_or_path,
            torch_dtype="auto",
            low_cpu_mem_usage=True,
        )

        # Detect the key prefixes used by the checkpoint (AutoModel vs. CausalLM layouts).
        src_sd = pretrained.state_dict()
        layer_prefix = None
        for p in ("model.layers.", "layers."):
            if any(k.startswith(p) for k in src_sd.keys()):
                layer_prefix = p
                break

        norm_prefix = None
        for p in ("model.norm.", "norm."):
            if any(k.startswith(p) for k in src_sd.keys()):
                norm_prefix = p
                break

        if layer_prefix is None:
            return

        def _map_layer_key(target_key: str, module_prefix: str, layer_offset: int) -> str | None:
            # e.g. "action_encoder.3.self_attn.q_proj.weight" -> "<layer_prefix>{3+offset}.self_attn.q_proj.weight"
            rest = target_key[len(module_prefix) + 1 :]
            parts = rest.split(".", 1)
            if len(parts) != 2:
                return None
            try:
                tgt_idx = int(parts[0])
            except ValueError:
                return None
            src_idx = tgt_idx + int(layer_offset)
            return f"{layer_prefix}{src_idx}.{parts[1]}"

        own_sd = self.state_dict()
        to_load: dict[str, torch.Tensor] = {}
        matched = 0
        missing = 0
        shape_mismatch = 0

        init_enc = bool(getattr(self.config, "qwen3_init_action_encoder", True))
        init_dec = bool(getattr(self.config, "qwen3_init_action_decoder", True))
        init_norm = bool(getattr(self.config, "qwen3_init_norm", True))
        enc_off = int(getattr(self.config, "qwen3_encoder_layer_offset", 0))
        dec_off = int(getattr(self.config, "qwen3_decoder_layer_offset", 0))

        for k, tgt_tensor in own_sd.items():
            src_key = None
            if init_enc and k.startswith("action_encoder."):
                src_key = _map_layer_key(k, "action_encoder", enc_off)
            elif init_dec and k.startswith("action_decoder."):
                # The adaLN conditioning MLP has no Qwen3 counterpart; leave it as-is.
                if ".cond_mlp." in k:
                    continue
                src_key = _map_layer_key(k, "action_decoder", dec_off)
            elif init_norm and k == "norm.weight" and norm_prefix is not None:
                src_key = f"{norm_prefix}weight"

            if not src_key:
                continue
            src_tensor = src_sd.get(src_key, None)
            if src_tensor is None:
                missing += 1
                continue
            if src_tensor.shape != tgt_tensor.shape:
                shape_mismatch += 1
                continue

            to_load[k] = src_tensor.to(device=tgt_tensor.device, dtype=tgt_tensor.dtype)
            matched += 1

        self.load_state_dict(to_load, strict=False)
        logger.info(
            f"Initialized from Qwen3 checkpoint {name_or_path}. "
            f"matched={matched} missing={missing} shape_mismatch={shape_mismatch} prefix={layer_prefix}"
        )

        if matched == 0:
            # Nothing matched: diagnose by diffing the architectural config fields.
            src_cfg = getattr(pretrained, "config", None)
            if src_cfg is not None:
                fields = [
                    "hidden_size",
                    "intermediate_size",
                    "num_hidden_layers",
                    "num_attention_heads",
                    "num_key_value_heads",
                    "head_dim",
                    "rms_norm_eps",
                ]
                diffs = []
                for f in fields:
                    if hasattr(src_cfg, f) and hasattr(self.config, f):
                        a = getattr(self.config, f)
                        b = getattr(src_cfg, f)
                        if a != b:
                            diffs.append((f, a, b))
                if diffs:
                    logger.warning("[ActionModelFM] Qwen3 init got 0 matches. Config differs from checkpoint:")
                    for f, a, b in diffs:
                        logger.warning(f"  - {f}: ActionModelConfig={a} vs Qwen3={b}")

    def forward(
        self,
        examples: List[dict] = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        device = next(self.parameters()).device
        batch_size = len(examples)

        # --- collate raw actions (and optionally states) from the example dicts ---
        raw_actions = torch.tensor(
            np.array([ex["action"] for ex in examples]),
            device=device,
            dtype=torch.float32,
        )

        use_state = self.use_state
        raw_states = None
        if use_state:
            raw_states = torch.tensor(
                np.array([ex["state"] for ex in examples]),
                device=device,
                dtype=torch.float32,
            )

        # --- project into hidden space; interleave [state_0, action_0, ...] if states are used ---
        # NOTE: float32 is not a supported CUDA autocast dtype, so this context disables
        # autocast and keeps the input projections in full precision.
        with torch.autocast("cuda", dtype=torch.float32):
            clean_action_embeds = self.action_proj_in(raw_actions)
            if use_state:
                clean_state_embeds = self.state_proj_in(raw_states)
                clean_inputs_embeds = torch.stack(
                    [clean_state_embeds, clean_action_embeds], dim=2
                ).reshape(batch_size, 2 * raw_actions.shape[1], -1)
            else:
                clean_inputs_embeds = clean_action_embeds

        masked_inputs_embeds = clean_inputs_embeds
        if self.use_masked_action_recon:
            # Mask ratio: either fixed, or sampled uniformly per trajectory from
            # [mask_ratio_min, mask_ratio_max].
            if getattr(self.config, "mask_ratio_mode", "fixed") == "uniform_per_traj":
                mr_min = float(getattr(self.config, "mask_ratio_min", self.config.mask_ratio))
                mr_max = float(getattr(self.config, "mask_ratio_max", self.config.mask_ratio))
                mask_ratio = torch.rand((batch_size,), device=device) * (mr_max - mr_min) + mr_min
            else:
                mask_ratio = float(self.config.mask_ratio)
            if use_state:
                masked_inputs_embeds, _, _ = self.random_masking_interleaved(clean_inputs_embeds, mask_ratio)
            else:
                masked_inputs_embeds, _, _ = self.random_masking(clean_inputs_embeds, mask_ratio)

        # --- prepend [CLS] + dataset soft-prompt tokens ---
        dataset_ids = [ex.get("dataset_id") for ex in examples]
        dataset_ids_tensor = torch.tensor(dataset_ids, device=device, dtype=torch.long)
        ds_embeds = self.dataset_embed(dataset_ids_tensor).view(
            batch_size, self.config.num_data_tokens, self.config.hidden_size
        )

        cls_token_expanded = self.cls_token.expand(batch_size, -1, -1)
        encoder_inputs_clean = torch.cat((cls_token_expanded, ds_embeds, clean_inputs_embeds), dim=1)
        encoder_inputs_masked = torch.cat((cls_token_expanded, ds_embeds, masked_inputs_embeds), dim=1)

        # --- encoder: full (non-causal) attention over the whole sequence ---
        seq_len = encoder_inputs_clean.shape[1]
        enc_bs = batch_size * 2 if self.use_masked_action_recon else batch_size
        encoder_attention_mask = torch.ones((enc_bs, 1, seq_len, seq_len), device=device, dtype=torch.bool)
        encoder_pos_ids = torch.arange(seq_len, device=device).unsqueeze(0)
        enc_pos_emb = self.rotary_emb(encoder_inputs_clean, encoder_pos_ids)

        # When masked reconstruction is on, run the masked and clean views in one batch.
        hidden_states = (
            torch.cat((encoder_inputs_masked, encoder_inputs_clean), dim=0)
            if self.use_masked_action_recon
            else encoder_inputs_clean
        )
        for encoder_layer in self.action_encoder:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=encoder_attention_mask,
                position_embeddings=enc_pos_emb,
                position_ids=encoder_pos_ids,
                **kwargs,
            )

        # L2-normalized CLS embedding(s) condition the decoder.
        if self.use_masked_action_recon:
            hidden_masked, hidden_clean = hidden_states.chunk(2, dim=0)
            action_embedding_masked = F.normalize(hidden_masked[:, :1, :], p=2, dim=-1)
            action_embedding_clean = F.normalize(hidden_clean[:, :1, :], p=2, dim=-1)
        else:
            action_embedding_clean = F.normalize(hidden_states[:, :1, :], p=2, dim=-1)
            action_embedding_masked = None

        # --- flow-matching noising: x_t = t * noise + (1 - t) * x_0, target v* = noise - x_0 ---
        t = self._sample_fm_time(batch_size, device=device, dtype=raw_actions.dtype)
        noise = torch.randn_like(raw_actions)
        noisy_actions = t[:, None, None] * noise + (1 - t[:, None, None]) * raw_actions
        target_velocity = noise - raw_actions

        noisy_embeds = self.action_proj_in(noisy_actions)
        if self.use_masked_action_recon:
            # Decode the same noisy actions under both conditions (clean and masked CLS).
            decoder_cond = torch.cat((action_embedding_clean, action_embedding_masked), dim=0)
            noisy_embeds = torch.cat((noisy_embeds, noisy_embeds), dim=0)
            t = torch.cat((t, t), dim=0)
            target_velocity = torch.cat((target_velocity, target_velocity), dim=0)
        else:
            decoder_cond = action_embedding_clean

        decoder_inputs = torch.cat((decoder_cond, noisy_embeds), dim=1)

        dec_seq_len = decoder_inputs.shape[1]
        dec_bs = decoder_inputs.shape[0]
        decoder_attention_mask = torch.ones((dec_bs, 1, dec_seq_len, dec_seq_len), device=device, dtype=torch.bool)
        dec_pos_ids = torch.arange(dec_seq_len, device=device).unsqueeze(0)
        dec_pos_emb = self.rotary_emb(decoder_inputs, dec_pos_ids)
        temb = self._fm_temb(t)

        hidden_states = decoder_inputs
        for decoder_layer in self.action_decoder:
            hidden_states = decoder_layer(
                hidden_states,
                temb=temb,
                attention_mask=decoder_attention_mask,
                position_embeddings=dec_pos_emb,
                position_ids=dec_pos_ids,
            )

        hidden_states = self.norm(hidden_states, temb)
        # Drop the conditioning token; the remaining positions predict per-step velocity.
        pred_velocity = self.action_proj_out(hidden_states[:, 1:, :])

        if self.use_masked_action_recon:
            pred_clean, pred_masked = pred_velocity.chunk(2, dim=0)
            target_clean, target_masked = target_velocity.chunk(2, dim=0)
            recon_loss_clean = F.mse_loss(pred_clean, target_clean)
            recon_loss_masked = F.mse_loss(pred_masked, target_masked)
            recon_loss = 0.5 * (recon_loss_clean + recon_loss_masked)
        else:
            recon_loss = F.mse_loss(pred_velocity, target_velocity)
        return recon_loss
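
    # Minimal training-style usage sketch (assumes each example dict carries an
    # "action" array of shape (L, action_size), plus "dataset_id" and, when
    # config.use_state is True, a matching "state" array):
    #
    #   loss = model([{"action": a1, "dataset_id": 0}, {"action": a2, "dataset_id": 1}])
    #   loss.backward()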

    def recon_loss(self, actions, dataset_ids: list[int], state=None, **kwargs):
        """
        Same interface as `ActionModel.recon_loss`, but using the flow-matching
        decoder loss.

        Args:
            actions: (B, L, action_dim)
            dataset_ids: list[int]; used for the dataset soft prompt.
            state: optional (B, L, state_dim); if provided, the encoder sees the
                interleaved sequence [state_0, action_0, state_1, action_1, ...].

        Returns:
            scalar loss
        """
        # Allow callers to supply a precomputed embedding and/or fixed t / noise.
        action_embedding = kwargs.pop("action_embedding", None)
        t = kwargs.pop("t", None)
        noise = kwargs.pop("noise", None)

        if action_embedding is None:
            action_embedding = self.encode_actions(actions, dataset_ids, state, **kwargs)

        return self.recon_loss_from_embedding(
            action_embedding=action_embedding,
            actions=actions,
            t=t,
            noise=noise,
        )

    def recon_loss_from_embedding(
        self,
        action_embedding: torch.Tensor,
        actions: torch.Tensor,
        t: torch.Tensor | None = None,
        noise: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Flow-matching velocity loss conditioned on a provided action embedding.

        This is the preferred interface when an action embedding is already available
        (e.g., from a VLM projector), since it avoids an extra action-encoder forward.

        Args:
            action_embedding: (B, H) or (B, 1, H), assumed L2-normalized (recommended).
            actions: (B, L, action_dim)
            t: optional (B,) times; sampled internally if None.
            noise: optional (B, L, action_dim) noise; sampled internally if None.
        """
        if action_embedding.dim() == 2:
            action_embedding = action_embedding.unsqueeze(1)
        if action_embedding.dim() != 3 or action_embedding.shape[1] != 1:
            raise ValueError(f"Expected action_embedding shape (B,1,H) or (B,H); got {tuple(action_embedding.shape)}")

        batch_size = actions.shape[0]
        device = actions.device
        dtype = actions.dtype

        if t is None:
            t = self._sample_fm_time(batch_size, device=device, dtype=dtype)
        if noise is None:
            noise = torch.randn_like(actions)

        noisy_actions = t[:, None, None] * noise + (1 - t[:, None, None]) * actions
        target_velocity = noise - actions

        temb = self._fm_temb(t)
        action_embeds = self.action_proj_in(noisy_actions)
        hidden_states = torch.cat((action_embedding, action_embeds), dim=1)

        dec_seq_len = hidden_states.shape[1]
        decoder_attention_mask = torch.ones(
            (batch_size, 1, dec_seq_len, dec_seq_len),
            device=device,
            dtype=torch.bool,
        )
        dec_pos_ids = torch.arange(dec_seq_len, device=device).unsqueeze(0)
        dec_pos_emb = self.rotary_emb(hidden_states, dec_pos_ids)

        for decoder_layer in self.action_decoder:
            hidden_states = decoder_layer(
                hidden_states,
                temb=temb,
                attention_mask=decoder_attention_mask,
                position_embeddings=dec_pos_emb,
                position_ids=dec_pos_ids,
            )

        hidden_states = self.norm(hidden_states, temb)
        pred_velocity = self.action_proj_out(hidden_states[:, 1:, :])
        return F.mse_loss(pred_velocity, target_velocity)
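
    # Shape sketch: with B=2, L=16, action_dim = config.action_size and an embedding of
    # shape (2, hidden_size), the decoder sees a (2, 1 + 16, hidden_size) sequence and
    # the loss is an MSE over the predicted (2, 16, action_dim) velocities.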

    def encode_actions(self, actions, dataset_ids: list[int], state=None, **kwargs):
        """
        Encode an action chunk (and optional state chunk) into a single CLS embedding.

        Args:
            actions: (B, L, action_dim)
            dataset_ids: list[int]; used for the dataset soft prompt.
            state: optional (B, L, state_dim); if provided, the encoder sees the
                interleaved sequence [state_0, action_0, state_1, action_1, ...].
        """
        action_embeds = self.action_proj_in(actions)
        batch_size = action_embeds.shape[0]
        use_state = state is not None and self.state_proj_in is not None
        if use_state:
            state_embeds = self.state_proj_in(state)
            L = action_embeds.shape[1]
            inputs_embeds = torch.stack(
                [state_embeds, action_embeds], dim=2
            ).reshape(batch_size, 2 * L, -1)
        else:
            inputs_embeds = action_embeds

        cls_token_expanded = self.cls_token.expand(batch_size, -1, -1)

        dataset_ids_tensor = torch.tensor(dataset_ids, device=action_embeds.device, dtype=torch.long)
        ds_embeds = self.dataset_embed(dataset_ids_tensor).view(
            batch_size, self.config.num_data_tokens, self.config.hidden_size
        )
        inputs_embeds = torch.cat((cls_token_expanded, ds_embeds, inputs_embeds), dim=1)

        seq_len = inputs_embeds.shape[1]
        encoder_attention_mask = torch.ones(
            (batch_size, 1, seq_len, seq_len),
            device=inputs_embeds.device,
            dtype=torch.bool,
        )
        encoder_pos_ids = torch.arange(seq_len, device=inputs_embeds.device).unsqueeze(0)
        enc_pos_emb = self.rotary_emb(inputs_embeds, encoder_pos_ids)

        hidden_states = inputs_embeds
        for encoder_layer in self.action_encoder:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=encoder_attention_mask,
                position_embeddings=enc_pos_emb,
                position_ids=encoder_pos_ids,
                **kwargs,
            )

        # CLS token (position 0), L2-normalized onto the unit sphere.
        action_embedding = hidden_states[:, :1, :]
        return F.normalize(action_embedding, p=2, dim=-1)
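
    # `encode_actions` and `decode_actions` form an approximate round trip:
    # decode_actions(encode_actions(a, ids), chunk_size=a.shape[1]) should recover
    # actions close to `a` once the flow-matching decoder is trained.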

    @torch.no_grad()
    def decode_actions(self, action_embedding, chunk_size, **kwargs):
        """
        Flow-matching sampling via simple Euler integration of the learned velocity
        field, from t=1 (pure noise) down to t=0 (data).
        """
        if chunk_size is None:
            chunk_size = self.config.max_action_chunk_size

        if action_embedding.dim() == 2:
            action_embedding = action_embedding.unsqueeze(1)

        batch_size = action_embedding.shape[0]
        device = action_embedding.device
        dtype = action_embedding.dtype

        # Start from pure noise; integrate backwards in time, so dt is negative.
        actions = torch.randn((batch_size, chunk_size, self.config.action_size), device=device, dtype=dtype)
        num_steps = max(int(self.fm_num_inference_steps), 1)
        dt = -1.0 / float(num_steps)

        for step in range(num_steps):
            t = torch.full((batch_size,), 1.0 - step / float(num_steps), device=device, dtype=dtype)
            temb = self._fm_temb(t)

            action_embeds = self.action_proj_in(actions)
            hidden_states = torch.cat((action_embedding, action_embeds), dim=1)

            dec_seq_len = hidden_states.shape[1]
            decoder_attention_mask = torch.ones(
                (batch_size, 1, dec_seq_len, dec_seq_len), device=device, dtype=torch.bool
            )
            dec_pos_ids = torch.arange(dec_seq_len, device=device).unsqueeze(0)
            dec_pos_emb = self.rotary_emb(hidden_states, dec_pos_ids)

            for decoder_layer in self.action_decoder:
                hidden_states = decoder_layer(
                    hidden_states,
                    temb=temb,
                    attention_mask=decoder_attention_mask,
                    position_embeddings=dec_pos_emb,
                    position_ids=dec_pos_ids,
                )

            hidden_states = self.norm(hidden_states, temb)
            pred_velocity = self.action_proj_out(hidden_states[:, 1:, :])
            # Euler step: x_{t+dt} = x_t + dt * v_theta(x_t, t), with dt < 0.
            actions = actions + dt * pred_velocity

        return actions
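
    # Intuition for the small default step count (fm_num_inference_steps=10): the
    # training targets are straight-line paths x_t = t * noise + (1 - t) * x_0, so the
    # learned velocity field tends to be nearly constant along a trajectory and a few
    # Euler steps are usually sufficient in practice.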


__all__ = [
    "ActionModelFM",
]


if __name__ == "__main__":
    config = ActionModelConfig()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    action_model = ActionModelFM(config).to(device)
    print(action_model)

    print(
        "Total number of trainable parameters:",
        sum(p.numel() for p in action_model.parameters() if p.requires_grad),
    )

    # Smoke-test the forward (loss) on a dummy batch. `dataset_id` selects the dataset
    # soft prompt; add a "state" entry as well if config.use_state is True.
    sample = {
        "action": np.random.uniform(-1, 1, size=(16, config.action_size)).astype(np.float32),
        "dataset_id": 0,
        "lang": "put the ball on the table",
    }
    batch = [sample, sample]
    loss = action_model(batch)
    print(f"loss: {loss}")

    # Smoke-test the encode/decode round trip.
    fake_actions = torch.randn(10, 15, config.action_size, device=device)
    action_embedding = action_model.encode_actions(fake_actions, dataset_ids=[0] * 10)
    print(f"action_embedding: {action_embedding.shape}")

    reconstructed_actions = action_model.decode_actions(action_embedding, chunk_size=15)
    print(f"reconstructed_actions: {reconstructed_actions.shape}")