# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file is a StarVLA-local variant of the action encoder/decoder.
# It keeps the overall structure but replaces the decoder with a
# flow-matching based decoder (velocity prediction) and injects timestep
# conditioning into RMSNorm (AdaRMSNorm) in the decoder.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import sys
sys.path.append("/mnt/data/fangyu/code/reward_new")
import math
from typing import List
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributions import Beta
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, logging
from starVLA.model.modules.action_model.ActionModel import (
Qwen3Attention,
Qwen3MLP,
Qwen3RMSNorm,
Qwen3RotaryEmbedding,
ActionPreTrainedModel,
)
from starVLA.model.modules.action_model.configuration_actionmodel import ActionModelConfig
from starVLA.model.tools import FRAMEWORK_REGISTRY
logger = logging.get_logger(__name__)
class _GradientReversalFunction(torch.autograd.Function):
"""
    Forward: identity. Backward: scales the incoming gradient by -lambda (gradient reversal).
    Used for domain-adversarial training: the encoder receives a reversed gradient and is
    encouraged to produce domain-invariant embeddings.
"""
@staticmethod
def forward(ctx, x: torch.Tensor, lambda_: float) -> torch.Tensor:
ctx.lambda_ = lambda_
return x.view_as(x)
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
return -ctx.lambda_ * grad_output, None
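# Usage sketch (illustrative; the encoder/domain-head names and the fixed lambda
# below are assumptions, not defined in this file):
#   feats = encoder(x)
#   feats_rev = _GradientReversalFunction.apply(feats, 0.1)
#   domain_logits = domain_head(feats_rev)  # domain loss backprops reversed into encoder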
def _timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0) -> torch.Tensor:
"""
Standard sinusoidal timestep embedding.
Args:
t: (B,) float tensor, typically in [0, 1].
Returns:
(B, dim)
"""
if t.ndim != 1:
raise ValueError(f"Expected `t` to have shape (B,), got {tuple(t.shape)}")
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(0, half, device=t.device, dtype=torch.float32) / max(half, 1)
)
args = t.to(torch.float32)[:, None] * freqs[None]
emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2 == 1:
emb = torch.cat([emb, torch.zeros((emb.shape[0], 1), device=t.device, dtype=emb.dtype)], dim=-1)
return emb.to(dtype=t.dtype)
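# Example (illustrative): _timestep_embedding(torch.rand(8), 256) returns an (8, 256)
# tensor whose first half holds cosine terms and second half sine terms over
# log-spaced frequencies, matching the construction above.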
class Qwen3AdaRMSNorm(nn.Module):
"""
RMSNorm + timestep conditioning.
y = RMSNorm(x) * (1 + scale(t)) + shift(t)
"""
def __init__(self, hidden_size: int, cond_size: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.cond_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(cond_size, 2 * hidden_size, bias=True),
)
def forward(self, hidden_states: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
if cond is None:
raise ValueError("Qwen3AdaRMSNorm requires `cond` but got None.")
if cond.ndim != 2:
raise ValueError(f"Expected `cond` to have shape (B, C), got {tuple(cond.shape)}")
input_dtype = hidden_states.dtype
x = hidden_states.to(torch.float32)
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = self.weight * x.to(input_dtype)
scale, shift = self.cond_mlp(cond).chunk(2, dim=-1)
return x * (1 + scale[:, None, :]) + shift[:, None, :]
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Qwen3LayerFM(nn.Module):
"""
Same block structure as `Qwen3Layer`, but decoder-side RMSNorms are timestep-conditioned.
Attention/MLP are unchanged.
"""
def __init__(self, config: ActionModelConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)
self.mlp = Qwen3MLP(config)
self.input_layernorm = Qwen3AdaRMSNorm(config.hidden_size, cond_size=config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen3AdaRMSNorm(
config.hidden_size, cond_size=config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states, temb)
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states, temb)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class ActionModelFM(ActionPreTrainedModel):
"""
Flow-matching based decoder variant for StarVLA `ActionModel`.
Encoder stays the same; decoder predicts velocity under linear interpolation noise.
"""
def __init__(self, config: ActionModelConfig):
super().__init__(config)
self.config = config
# ===== tokens / embeddings (same as original) =====
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.action_mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.state_mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.dataset_embed = nn.Embedding(
config.dataset_vocab_size,
config.hidden_size * config.num_data_tokens,
)
self.action_proj_in = nn.Linear(config.action_size, config.hidden_size)
self.state_proj_in = nn.Linear(config.state_size, config.hidden_size)
self.use_state = config.use_state
print(f"use_state: {self.use_state}")
# ===== encoder (unchanged blocks) =====
# Reuse the original Qwen3Layer implementation from ActionModel.py through `ActionPreTrainedModel` machinery
from starVLA.model.modules.action_model.ActionModel import Qwen3Layer # local import
self.action_encoder = nn.ModuleList([Qwen3Layer(config, layer_idx) for layer_idx in range(config.num_encoder_layers)])
# ===== decoder (FM) =====
self.action_decoder = nn.ModuleList([Qwen3LayerFM(config, layer_idx) for layer_idx in range(config.num_decoder_layers)])
self.norm = Qwen3AdaRMSNorm(config.hidden_size, cond_size=config.hidden_size, eps=config.rms_norm_eps)
self.action_proj_out = nn.Linear(config.hidden_size, config.action_size)
self.rotary_emb = Qwen3RotaryEmbedding(config=config)
self.gradient_checkpointing = False
# ===== FM hyperparams =====
self.fm_time_min = float(getattr(config, "fm_time_min", 0.001))
self.fm_time_max = float(getattr(config, "fm_time_max", 0.999))
self.fm_num_inference_steps = int(getattr(config, "fm_num_inference_steps", 10))
self.fm_time_sampling = str(getattr(config, "fm_time_sampling", "uniform")) # "uniform" | "beta"
self.fm_beta_alpha = float(getattr(config, "fm_beta_alpha", 1.5))
self.fm_beta_beta = float(getattr(config, "fm_beta_beta", 1.0))
self._beta_dist = Beta(self.fm_beta_alpha, self.fm_beta_beta)
# timestep -> temb (B,H)
self.fm_timestep_mlp = nn.Sequential(
nn.Linear(config.hidden_size, config.hidden_size * 4, bias=True),
nn.SiLU(),
nn.Linear(config.hidden_size * 4, config.hidden_size, bias=True),
)
# ===== Loss mode: masked-action recon =====
self.use_masked_action_recon = bool(getattr(config, "use_masked_action_recon", False))
self.post_init()
self._maybe_init_from_qwen3()
def _sample_fm_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
if self.fm_time_sampling == "beta":
t = self._beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
else:
t = torch.rand((batch_size,), device=device, dtype=dtype)
t = t * (self.fm_time_max - self.fm_time_min) + self.fm_time_min
return t
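    # With noisy = t * noise + (1 - t) * x_0 (so t = 1 is pure noise, see `forward`),
    # Beta(1.5, 1.0) has density ~ sqrt(t) on [0, 1], biasing samples toward noisier t.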
def _fm_temb(self, t: torch.Tensor) -> torch.Tensor:
return self.fm_timestep_mlp(_timestep_embedding(t, self.config.hidden_size))
def _gather_embeddings(self, x: torch.Tensor) -> tuple[torch.Tensor, int]:
"""
Gather embeddings from all ranks.
Returns (gathered_tensor, offset) where offset is the start index of this rank's data in the global batch.
Single-GPU: returns (x, 0).
"""
        # `contrastive_use_distributed` is never set in this class's __init__, so read it defensively.
        if not (getattr(self, "contrastive_use_distributed", False) and dist.is_initialized() and dist.get_world_size() > 1):
            return x, 0
world_size = dist.get_world_size()
local_size = x.shape[0]
size_list = [torch.tensor([0], dtype=torch.long, device=x.device) for _ in range(world_size)]
dist.all_gather(size_list, torch.tensor([local_size], dtype=torch.long, device=x.device))
sizes = [s.item() for s in size_list]
max_size = max(sizes)
offset = sum(sizes[: dist.get_rank()])
if max_size > local_size:
padding = torch.zeros(max_size - local_size, x.shape[1], device=x.device, dtype=x.dtype)
x = torch.cat([x, padding], dim=0)
gather_list = [torch.zeros_like(x) for _ in range(world_size)]
dist.all_gather(gather_list, x)
out = torch.cat([g[: sizes[i]] for i, g in enumerate(gather_list)], dim=0)
return out, offset
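    # Note: dist.all_gather does not propagate gradients to the gathered copies;
    # callers that need gradients w.r.t. the local shard typically splice the
    # local tensor back in at `offset`.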
def random_masking(self, x: torch.Tensor, mask_ratio: float | torch.Tensor):
"""
MAE-style per-sample random masking by shuffling (argsort noise).
This version DOES NOT drop tokens; it returns `x_masked` with the same shape as `x`,
where masked positions are replaced by `self.action_mask_token`.
Args:
x: [N, L, D]
mask_ratio: float in [0, 1) OR tensor of shape [N] with per-sample ratios
Returns:
x_masked: [N, L, D]
mask: [N, L] (0=keep, 1=mask)
ids_restore: [N, L]
"""
N, L, D = x.shape
token_dim = int(self.action_mask_token.shape[-1])
if D != token_dim:
raise ValueError(
f"`random_masking` expects last dim D=={token_dim} (same as action_mask_token), got D=={D}."
)
if isinstance(mask_ratio, torch.Tensor):
if mask_ratio.ndim != 1 or mask_ratio.shape[0] != N:
raise ValueError(
f"When `mask_ratio` is a tensor it must have shape (N,), got {tuple(mask_ratio.shape)}"
)
# clamp to safe range
mask_ratio = mask_ratio.to(device=x.device, dtype=torch.float32).clamp(min=0.0, max=0.999)
len_keep = torch.floor(L * (1.0 - mask_ratio)).to(dtype=torch.long) # (N,)
else:
mr = float(mask_ratio)
mr = max(0.0, min(0.999, mr))
len_keep = int(L * (1.0 - mr))
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is mask
ids_restore = torch.argsort(ids_shuffle, dim=1)
# generate the binary mask: 0 is keep, 1 is mask
mask = torch.ones([N, L], device=x.device, dtype=torch.float32)
if isinstance(len_keep, torch.Tensor):
# build mask in shuffled order then unshuffle
keep = torch.arange(L, device=x.device)[None, :].expand(N, L) < len_keep[:, None] # (N,L)
mask = (~keep).to(dtype=torch.float32)
else:
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore) # unshuffle
# replace masked tokens with action_mask_token (keep sequence length)
mask_token = self.action_mask_token.expand(N, L, -1).to(dtype=x.dtype, device=x.device)
x_masked = x * (1.0 - mask[:, :, None]) + mask[:, :, None] * mask_token
return x_masked, mask, ids_restore
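    # Example (illustrative): for x of shape (2, 16, H) and mask_ratio=0.5,
    # len_keep = 8, so x_masked is (2, 16, H) with 8 random positions per sample
    # replaced by `action_mask_token`, and mask (2, 16) flags those positions with 1.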
def random_masking_interleaved(
self,
interleaved: torch.Tensor,
mask_ratio: float | torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
MAE-style random masking for interleaved [state_0, action_0, state_1, action_1, ...].
Positions 0, 2, 4, ... are state (replaced with state_mask_token when masked);
positions 1, 3, 5, ... are action (replaced with action_mask_token when masked).
Args:
interleaved: [N, 2*L, D] (state, action, state, action, ...)
mask_ratio: float in [0, 1) OR tensor [N] per-sample
Returns:
x_masked: [N, 2*L, D]
mask: [N, 2*L] (0=keep, 1=mask)
ids_restore: [N, 2*L]
"""
        N, two_L, _ = interleaved.shape
if isinstance(mask_ratio, torch.Tensor):
mask_ratio = mask_ratio.to(device=interleaved.device, dtype=torch.float32).clamp(min=0.0, max=0.999)
len_keep = torch.floor(two_L * (1.0 - mask_ratio)).to(dtype=torch.long)
else:
mr = max(0.0, min(0.999, float(mask_ratio)))
len_keep = int(two_L * (1.0 - mr))
noise = torch.rand(N, two_L, device=interleaved.device)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
if isinstance(len_keep, torch.Tensor):
keep = torch.arange(two_L, device=interleaved.device)[None, :].expand(N, two_L) < len_keep[:, None]
mask = (~keep).to(dtype=torch.float32)
else:
mask = torch.ones(N, two_L, device=interleaved.device, dtype=torch.float32)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)
state_mtk = self.state_mask_token.expand(N, two_L, -1).to(dtype=interleaved.dtype, device=interleaved.device)
action_mtk = self.action_mask_token.expand(N, two_L, -1).to(dtype=interleaved.dtype, device=interleaved.device)
# even indices -> state, odd -> action
state_pos = torch.zeros(two_L, device=interleaved.device, dtype=torch.float32)
state_pos[0::2] = 1.0
state_pos = state_pos.view(1, two_L, 1)
action_pos = 1.0 - state_pos
mask_expand = mask[:, :, None]
replacement = mask_expand * (state_pos * state_mtk + action_pos * action_mtk)
x_masked = interleaved * (1.0 - mask_expand) + replacement
return x_masked, mask, ids_restore
# --- copied optional init helper from original ---
def _maybe_init_from_qwen3(self) -> None:
from transformers import AutoModel
name_or_path = getattr(self.config, "qwen3_pretrained_name_or_path", None)
if not name_or_path:
return
pretrained = AutoModel.from_pretrained(
name_or_path,
torch_dtype="auto",
low_cpu_mem_usage=True,
)
src_sd = pretrained.state_dict()
layer_prefix = None
for p in ("model.layers.", "layers."):
if any(k.startswith(p) for k in src_sd.keys()):
layer_prefix = p
break
norm_prefix = None
for p in ("model.norm.", "norm."):
if any(k.startswith(p) for k in src_sd.keys()):
norm_prefix = p
break
if layer_prefix is None:
return
def _map_layer_key(target_key: str, module_prefix: str, layer_offset: int) -> str | None:
rest = target_key[len(module_prefix) + 1 :]
parts = rest.split(".", 1)
if len(parts) != 2:
return None
try:
tgt_idx = int(parts[0])
except ValueError:
return None
src_idx = tgt_idx + int(layer_offset)
return f"{layer_prefix}{src_idx}.{parts[1]}"
own_sd = self.state_dict()
to_load: dict[str, torch.Tensor] = {}
matched = 0
missing = 0
shape_mismatch = 0
init_enc = bool(getattr(self.config, "qwen3_init_action_encoder", True))
init_dec = bool(getattr(self.config, "qwen3_init_action_decoder", True))
init_norm = bool(getattr(self.config, "qwen3_init_norm", True))
enc_off = int(getattr(self.config, "qwen3_encoder_layer_offset", 0))
dec_off = int(getattr(self.config, "qwen3_decoder_layer_offset", 0))
# NOTE: decoder has AdaRMSNorm (extra cond_mlp weights), but many weights still match:
# - action_decoder.*.self_attn.*
# - action_decoder.*.mlp.*
# - action_decoder.*.(input_layernorm|post_attention_layernorm).weight (load RMS weight only)
# - norm.weight (load RMS weight only)
for k, tgt_tensor in own_sd.items():
src_key = None
if init_enc and k.startswith("action_encoder."):
src_key = _map_layer_key(k, "action_encoder", enc_off)
elif init_dec and k.startswith("action_decoder."):
# Skip timestep-conditioned MLP weights (no counterpart in Qwen3)
if ".cond_mlp." in k:
continue
src_key = _map_layer_key(k, "action_decoder", dec_off)
elif init_norm and k == "norm.weight" and norm_prefix is not None:
src_key = f"{norm_prefix}weight"
if not src_key:
continue
src_tensor = src_sd.get(src_key, None)
if src_tensor is None:
missing += 1
continue
if src_tensor.shape != tgt_tensor.shape:
shape_mismatch += 1
continue
to_load[k] = src_tensor.to(device=tgt_tensor.device, dtype=tgt_tensor.dtype)
matched += 1
self.load_state_dict(to_load, strict=False)
print(
f"Initialized from Qwen3 checkpoint {name_or_path}. "
f"matched={matched} missing={missing} shape_mismatch={shape_mismatch} prefix={layer_prefix}"
)
if matched == 0:
# Most common culprit: config dims don't match Qwen3 checkpoint.
src_cfg = getattr(pretrained, "config", None)
if src_cfg is not None:
fields = [
"hidden_size",
"intermediate_size",
"num_hidden_layers",
"num_attention_heads",
"num_key_value_heads",
"head_dim",
"rms_norm_eps",
]
diffs = []
for f in fields:
if hasattr(src_cfg, f) and hasattr(self.config, f):
a = getattr(self.config, f)
b = getattr(src_cfg, f)
if a != b:
diffs.append((f, a, b))
if diffs:
print("[ActionModelFM] Qwen3 init got 0 matches. Config differs from checkpoint:")
for f, a, b in diffs:
print(f" - {f}: ActionModelConfig={a} vs Qwen3={b}")
def forward(
self,
        examples: List[dict],
**kwargs: Unpack[TransformersKwargs],
):
device = next(self.parameters()).device
batch_size = len(examples)
# =========================================================================
# 1. Variable-length Horizon (same as original)
# =========================================================================
raw_actions = torch.tensor(
np.array([ex["action"] for ex in examples]),
device=device,
dtype=torch.float32,
) # [B, L, D]
use_state = self.use_state
raw_states = None
if use_state:
raw_states = torch.tensor(
np.array([ex["state"] for ex in examples]),
device=device,
dtype=torch.float32,
) # [B, L, state_dim]
# =========================================================================
# 2. Action (and optional State) Input Construction & Masking (DAE)
# Encoder sequence: cls, dataset_tokens, [state_0, action_0, state_1, action_1, ...]
# Two-view (masked + clean) when use_masked_action_recon.
# =========================================================================
        # Keep the input projections in full precision. torch.autocast does not support
        # dtype=torch.float32 on CUDA (it would warn and disable itself), so disable explicitly.
        with torch.autocast(device_type="cuda", enabled=False):
clean_action_embeds = self.action_proj_in(raw_actions) # [B, L, H]
if use_state:
clean_state_embeds = self.state_proj_in(raw_states) # [B, L, H]
# Interleave: [s0, a0, s1, a1, ...] -> [B, 2*L, H]
clean_inputs_embeds = torch.stack(
[clean_state_embeds, clean_action_embeds], dim=2
).reshape(batch_size, 2 * raw_actions.shape[1], -1)
else:
clean_inputs_embeds = clean_action_embeds
masked_inputs_embeds = clean_inputs_embeds
if self.use_masked_action_recon:
if use_state:
if getattr(self.config, "mask_ratio_mode", "fixed") == "uniform_per_traj":
mr_min = float(getattr(self.config, "mask_ratio_min", self.config.mask_ratio))
mr_max = float(getattr(self.config, "mask_ratio_max", self.config.mask_ratio))
per_traj_mr = torch.rand((batch_size,), device=device) * (mr_max - mr_min) + mr_min
masked_inputs_embeds, _, _ = self.random_masking_interleaved(clean_inputs_embeds, per_traj_mr)
else:
masked_inputs_embeds, _, _ = self.random_masking_interleaved(
clean_inputs_embeds, float(self.config.mask_ratio)
)
else:
if getattr(self.config, "mask_ratio_mode", "fixed") == "uniform_per_traj":
mr_min = float(getattr(self.config, "mask_ratio_min", self.config.mask_ratio))
mr_max = float(getattr(self.config, "mask_ratio_max", self.config.mask_ratio))
per_traj_mr = torch.rand((batch_size,), device=device) * (mr_max - mr_min) + mr_min
masked_inputs_embeds, _, _ = self.random_masking(clean_inputs_embeds, per_traj_mr)
else:
masked_inputs_embeds, _, _ = self.random_masking(clean_inputs_embeds, float(self.config.mask_ratio))
# =========================================================================
# 3. Dataset Soft Prompt (same as original)
# =========================================================================
dataset_ids = [ex.get("dataset_id") for ex in examples]
dataset_ids_tensor = torch.tensor(dataset_ids, device=device, dtype=torch.long)
ds_embeds = self.dataset_embed(dataset_ids_tensor).view(
batch_size, self.config.num_data_tokens, self.config.hidden_size
)
cls_token_expanded = self.cls_token.expand(batch_size, -1, -1)
encoder_inputs_clean = torch.cat((cls_token_expanded, ds_embeds, clean_inputs_embeds), dim=1)
encoder_inputs_masked = torch.cat((cls_token_expanded, ds_embeds, masked_inputs_embeds), dim=1)
seq_len = encoder_inputs_clean.shape[1]
enc_bs = batch_size * 2 if self.use_masked_action_recon else batch_size
encoder_attention_mask = torch.ones((enc_bs, 1, seq_len, seq_len), device=device, dtype=torch.bool)
encoder_pos_ids = torch.arange(seq_len, device=device).unsqueeze(0)
# rotary embeddings are position-based; we keep position_ids batch=1 and broadcast.
enc_pos_emb = self.rotary_emb(encoder_inputs_clean, encoder_pos_ids)
hidden_states = (
torch.cat((encoder_inputs_masked, encoder_inputs_clean), dim=0)
if self.use_masked_action_recon
else encoder_inputs_clean
)
for encoder_layer in self.action_encoder:
hidden_states = encoder_layer(
hidden_states,
attention_mask=encoder_attention_mask,
position_embeddings=enc_pos_emb,
position_ids=encoder_pos_ids,
**kwargs,
)
if self.use_masked_action_recon:
hidden_masked, hidden_clean = hidden_states.chunk(2, dim=0)
action_embedding_masked = F.normalize(hidden_masked[:, :1, :], p=2, dim=-1)
action_embedding_clean = F.normalize(hidden_clean[:, :1, :], p=2, dim=-1)
else:
action_embedding_clean = F.normalize(hidden_states[:, :1, :], p=2, dim=-1)
action_embedding_masked = None
# =========================================================================
# 4. Flow-matching Decoder
# =========================================================================
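        # Linear-interpolation flow matching: x_t = t * noise + (1 - t) * x_0,
        # so the regression target is the constant velocity dx_t/dt = noise - x_0
        # (t = 1 is pure noise, t = 0 is data).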
t = self._sample_fm_time(batch_size, device=device, dtype=raw_actions.dtype) # (B,)
noise = torch.randn_like(raw_actions)
noisy_actions = t[:, None, None] * noise + (1 - t[:, None, None]) * raw_actions
target_velocity = noise - raw_actions
noisy_embeds = self.action_proj_in(noisy_actions)
if self.use_masked_action_recon:
# Single decoder forward for both views in one batch.
decoder_cond = torch.cat((action_embedding_clean, action_embedding_masked), dim=0)
noisy_embeds = torch.cat((noisy_embeds, noisy_embeds), dim=0)
t = torch.cat((t, t), dim=0)
target_velocity = torch.cat((target_velocity, target_velocity), dim=0)
else:
decoder_cond = action_embedding_clean
decoder_inputs = torch.cat((decoder_cond, noisy_embeds), dim=1) # [B or 2B, 1+L, H]
dec_seq_len = decoder_inputs.shape[1]
dec_bs = decoder_inputs.shape[0]
decoder_attention_mask = torch.ones((dec_bs, 1, dec_seq_len, dec_seq_len), device=device, dtype=torch.bool)
dec_pos_ids = torch.arange(dec_seq_len, device=device).unsqueeze(0)
dec_pos_emb = self.rotary_emb(decoder_inputs, dec_pos_ids)
temb = self._fm_temb(t)
hidden_states = decoder_inputs
for decoder_layer in self.action_decoder:
hidden_states = decoder_layer(
hidden_states,
temb=temb,
attention_mask=decoder_attention_mask,
position_embeddings=dec_pos_emb,
position_ids=dec_pos_ids,
)
hidden_states = self.norm(hidden_states, temb)
pred_velocity = self.action_proj_out(hidden_states[:, 1:, :])
if self.use_masked_action_recon:
pred_clean, pred_masked = pred_velocity.chunk(2, dim=0)
target_clean, target_masked = target_velocity.chunk(2, dim=0)
recon_loss_clean = F.mse_loss(pred_clean, target_clean)
recon_loss_masked = F.mse_loss(pred_masked, target_masked)
recon_loss = 0.5 * (recon_loss_clean + recon_loss_masked)
else:
recon_loss = F.mse_loss(pred_velocity, target_velocity)
return recon_loss
def recon_loss(self, actions, dataset_ids: list[int], state=None, **kwargs):
"""
Same interface as `ActionModel.recon_loss`, but using flow-matching decoder loss.
Args:
actions: (B, L, action_dim)
            dataset_ids: list[int]; used for the dataset soft prompt (prepended regardless of `state`).
state: optional (B, L, state_dim); if provided and state_proj_in exists,
encoder sees interleaved sequence [state_0, action_0, state_1, action_1, ...].
Returns:
scalar loss
"""
# Optional fast-path: pass a precomputed action embedding to avoid another encoder forward.
action_embedding = kwargs.pop("action_embedding", None)
t = kwargs.pop("t", None)
noise = kwargs.pop("noise", None)
if action_embedding is None:
action_embedding = self.encode_actions(actions, dataset_ids, state, **kwargs)
return self.recon_loss_from_embedding(
action_embedding=action_embedding,
actions=actions,
t=t,
noise=noise,
)
def recon_loss_from_embedding(
self,
action_embedding: torch.Tensor,
actions: torch.Tensor,
t: torch.Tensor | None = None,
noise: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Flow-matching velocity loss conditioned on a provided action embedding.
This is the preferred interface when you already have an action embedding (e.g., from VLM projector),
since it avoids an extra action-encoder forward.
Args:
action_embedding: (B, H) or (B, 1, H), assumed L2-normalized (recommended).
actions: (B, L, action_dim)
t: optional (B,) time; if None sample internally
noise: optional (B, L, action_dim) noise; if None sample internally
"""
if action_embedding.dim() == 2:
action_embedding = action_embedding.unsqueeze(1)
if action_embedding.dim() != 3 or action_embedding.shape[1] != 1:
raise ValueError(f"Expected action_embedding shape (B,1,H) or (B,H); got {tuple(action_embedding.shape)}")
batch_size = actions.shape[0]
device = actions.device
dtype = actions.dtype
if t is None:
t = self._sample_fm_time(batch_size, device=device, dtype=dtype)
if noise is None:
noise = torch.randn_like(actions)
noisy_actions = t[:, None, None] * noise + (1 - t[:, None, None]) * actions
target_velocity = noise - actions
temb = self._fm_temb(t)
action_embeds = self.action_proj_in(noisy_actions)
hidden_states = torch.cat((action_embedding, action_embeds), dim=1)
dec_seq_len = hidden_states.shape[1]
decoder_attention_mask = torch.ones(
(batch_size, 1, dec_seq_len, dec_seq_len),
device=device,
dtype=torch.bool,
)
dec_pos_ids = torch.arange(dec_seq_len, device=device).unsqueeze(0)
dec_pos_emb = self.rotary_emb(hidden_states, dec_pos_ids)
for decoder_layer in self.action_decoder:
hidden_states = decoder_layer(
hidden_states,
temb=temb,
attention_mask=decoder_attention_mask,
position_embeddings=dec_pos_emb,
position_ids=dec_pos_ids,
)
hidden_states = self.norm(hidden_states, temb)
pred_velocity = self.action_proj_out(hidden_states[:, 1:, :])
return F.mse_loss(pred_velocity, target_velocity)
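    # Usage sketch (illustrative):
    #   emb = model.encode_actions(actions, dataset_ids)       # (B, 1, H), L2-normalized
    #   loss = model.recon_loss_from_embedding(emb, actions)   # scalar FM velocity loss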
def encode_actions(self, actions, dataset_ids: list[int], state=None, **kwargs):
"""
Encode action chunk (and optional state chunk) to a single CLS embedding.
Args:
actions: (B, L, action_dim)
state: optional (B, L, state_dim); if provided and state_proj_in exists,
encoder sees interleaved sequence [state_0, action_0, state_1, action_1, ...].
            dataset_ids: list[int]; used for the dataset soft prompt (prepended regardless of `state`).
"""
action_embeds = self.action_proj_in(actions)
batch_size = action_embeds.shape[0]
use_state = state is not None and self.state_proj_in is not None
if use_state:
state_embeds = self.state_proj_in(state)
L = action_embeds.shape[1]
inputs_embeds = torch.stack(
[state_embeds, action_embeds], dim=2
).reshape(batch_size, 2 * L, -1)
else:
inputs_embeds = action_embeds
cls_token_expanded = self.cls_token.expand(batch_size, -1, -1)
dataset_ids_tensor = torch.tensor(dataset_ids, device=action_embeds.device, dtype=torch.long)
ds_embeds = self.dataset_embed(dataset_ids_tensor).view(
batch_size, self.config.num_data_tokens, self.config.hidden_size
)
inputs_embeds = torch.cat((cls_token_expanded, ds_embeds, inputs_embeds), dim=1)
seq_len = inputs_embeds.shape[1]
encoder_attention_mask = torch.ones(
(batch_size, 1, seq_len, seq_len),
device=inputs_embeds.device,
dtype=torch.bool,
)
encoder_pos_ids = torch.arange(seq_len, device=inputs_embeds.device).unsqueeze(0)
enc_pos_emb = self.rotary_emb(inputs_embeds, encoder_pos_ids)
hidden_states = inputs_embeds
for encoder_layer in self.action_encoder:
hidden_states = encoder_layer(
hidden_states,
attention_mask=encoder_attention_mask,
position_embeddings=enc_pos_emb,
position_ids=encoder_pos_ids,
**kwargs,
)
action_embedding = hidden_states[:, :1, :]
return F.normalize(action_embedding, p=2, dim=-1)
@torch.no_grad()
def decode_actions(self, action_embedding, chunk_size, **kwargs):
"""
FM sampling via simple Euler integration of the learned velocity field.
"""
if chunk_size is None:
chunk_size = self.config.max_action_chunk_size
if action_embedding.dim() == 2:
action_embedding = action_embedding.unsqueeze(1)
batch_size = action_embedding.shape[0]
device = action_embedding.device
dtype = action_embedding.dtype
actions = torch.randn((batch_size, chunk_size, self.config.action_size), device=device, dtype=dtype)
num_steps = max(int(self.fm_num_inference_steps), 1)
dt = -1.0 / float(num_steps)
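        # Euler integration of the velocity field from t = 1 (noise) down to
        # t = 0 (data): x <- x + dt * v with dt = -1/num_steps.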
for step in range(num_steps):
t = torch.full((batch_size,), 1.0 - step / float(num_steps), device=device, dtype=dtype)
temb = self._fm_temb(t)
action_embeds = self.action_proj_in(actions)
hidden_states = torch.cat((action_embedding, action_embeds), dim=1)
dec_seq_len = hidden_states.shape[1]
decoder_attention_mask = torch.ones((batch_size, 1, dec_seq_len, dec_seq_len), device=device, dtype=torch.bool)
dec_pos_ids = torch.arange(dec_seq_len, device=device).unsqueeze(0)
dec_pos_emb = self.rotary_emb(hidden_states, dec_pos_ids)
for decoder_layer in self.action_decoder:
hidden_states = decoder_layer(
hidden_states,
temb=temb,
attention_mask=decoder_attention_mask,
position_embeddings=dec_pos_emb,
position_ids=dec_pos_ids,
)
hidden_states = self.norm(hidden_states, temb)
pred_velocity = self.action_proj_out(hidden_states[:, 1:, :])
actions = actions + dt * pred_velocity
return actions
__all__ = [
"ActionModelFM",
]
if __name__ == "__main__":
config = ActionModelConfig()
action_model = ActionModelFM(config)
print(action_model)
print("Total number of DiT parameters: ",
sum(p.numel() for p in action_model.parameters() if p.requires_grad))
fake_actions = torch.randn(10, 15, 64).to("cuda:7")
sample = {
"action": np.random.uniform(-1, 1, size=(16, 29)).astype(np.float16), # action_chunk, action_dim (unified 29D)
"lang": "put the ball on the table",
}
batch = [sample, sample]
device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu")
action_model = action_model.to(device)
outputs = action_model(batch)
print(outputs)
action_embedding = action_model.encode_actions(fake_actions)
print(f"action_embedding: {action_embedding.shape}")
reconstructed_actions = action_model.decode_actions(action_embedding, chunk_size=15)
print(f"reconstructed_actions: {reconstructed_actions.shape}")