cross13tasks / code /model /modules /action_model /LayerwiseFM_ActionHeader.py
Timsty's picture
Upload folder using huggingface_hub
e94400c verified
# Copyright 2025 NVIDIA Corp. and affiliates. All rights reserved.
# Modified by [Junqiu YU/ Fudan University] in [2025].
# Modification: [rm and add some connect adapter to match with starVLA, e.g., "rm "].
from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Beta
from transformers import PretrainedConfig
from transformers.feature_extraction_utils import BatchFeature
from starVLA.model.modules.action_model.flow_matching_head.action_encoder import (
SinusoidalPositionalEncoding,
swish,
)
from starVLA.model.modules.action_model.flow_matching_head.cross_attention_dit import DiT, SelfAttentionTransformer
# TODO: try to merge the DiT modules with flow_matching_head; they share the same architecture and differ only in the loss. Using the diffusers package would make this simpler.
class CategorySpecificLinear(nn.Module):
    """Linear projection with an independent weight matrix and bias per category.

    Every batch item picks its own parameters through an integer category id.
    """

    def __init__(self, num_categories, input_dim, hidden_dim):
        super().__init__()
        self.num_categories = num_categories
        weight_shape = (num_categories, input_dim, hidden_dim)
        # Weights start as Gaussian noise scaled by 0.02; biases start at zero.
        self.W = nn.Parameter(0.02 * torch.randn(weight_shape))
        self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))

    def forward(self, x, cat_ids):
        """Project ``x`` with the parameters selected by ``cat_ids``.

        Args:
            x: 3-D tensor (B, T, input_dim); ``torch.bmm`` requires the batch
                layout.
            cat_ids: integer index tensor of shape (B,), one category per
                batch item.

        Returns:
            Tensor of shape (B, T, hidden_dim).
        """
        weights = self.W[cat_ids]  # (B, input_dim, hidden_dim)
        biases = self.b[cat_ids]  # (B, hidden_dim)
        projected = torch.bmm(x, weights)
        # Broadcast each item's bias over the T axis.
        return projected + biases[:, None, :]
class CategorySpecificMLP(nn.Module):
    """Two-layer MLP in which both layers use per-category parameters."""

    def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.num_categories = num_categories
        self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
        self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)

    def forward(self, x, cat_ids):
        """Apply layer1 -> ReLU -> layer2, passing ``cat_ids`` to both layers.

        Args:
            x: input tensor forwarded to the first category-specific layer.
            cat_ids: category index per batch item.

        Returns:
            Output of the second category-specific layer.
        """
        activated = torch.relu(self.layer1(x, cat_ids))
        output = self.layer2(activated, cat_ids)
        return output
class MLP(nn.Module):
    """Plain two-layer perceptron: Linear -> ReLU -> Linear."""

    def __init__(self, input_dim, hidden_dim=1024, output_dim=2048):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` (..., input_dim) to (..., output_dim)."""
        hidden = self.layer1(x)
        hidden = torch.relu(hidden)
        return self.layer2(hidden)
class ActionEncoder(nn.Module):
    """Encode an action trajectory together with its per-sample timestep."""

    def __init__(self, action_dim, hidden_size=1024):
        super().__init__()
        self.hidden_size = hidden_size
        self.action_dim = action_dim
        self.layer1 = nn.Linear(action_dim, hidden_size)
        # layer2 consumes the concatenation [action embedding, time embedding].
        self.layer2 = nn.Linear(2 * hidden_size, hidden_size)
        self.layer3 = nn.Linear(hidden_size, hidden_size)
        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)

    def forward(self, actions, timesteps):
        """Embed ``actions`` conditioned on ``timesteps``.

        Args:
            actions: tensor of shape (B, T, action_dim).
            timesteps: tensor of shape (B,), one timestep per batch item.

        Returns:
            Tensor of shape (B, T, hidden_size).

        Raises:
            ValueError: if ``timesteps`` is not 1-D with length B.
        """
        batch_size, horizon, _ = actions.shape

        # Reject anything that is not one scalar time per batch item.
        if timesteps.dim() != 1 or timesteps.shape[0] != batch_size:
            raise ValueError(
                "Expected `timesteps` to have shape (B,) so we can replicate across T."
            )

        # Repeat the per-item time over the T steps: (B,) -> (B, T).
        per_step_time = timesteps[:, None].expand(-1, horizon)

        action_emb = self.layer1(actions)
        time_emb = self.pos_encoding(per_step_time).to(dtype=action_emb.dtype)

        # Fuse along the feature axis, project back to hidden_size, then swish.
        fused = self.layer2(torch.cat([action_emb, time_emb], dim=-1))
        return self.layer3(swish(fused))
class MultiEmbodimentActionEncoder(nn.Module):
    """Action/timestep encoder whose three projections are chosen per embodiment."""

    def __init__(self, action_dim, hidden_size=1024, num_embodiments=8):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_embodiments = num_embodiments
        # With d = action_dim and w = hidden_size:
        #   W1 maps d -> w, W2 maps 2w -> w, W3 maps w -> w.
        self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size)
        self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size)
        self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size)
        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)

    def forward(self, actions, timesteps, cat_ids):
        """Embed ``actions`` conditioned on ``timesteps`` and the embodiment id.

        Args:
            actions: tensor of shape (B, T, action_dim).
            timesteps: tensor of shape (B,), one timestep per batch item.
            cat_ids: tensor of shape (B,) selecting the embodiment-specific
                weights for each batch item.

        Returns:
            Tensor of shape (B, T, hidden_size).

        Raises:
            ValueError: if ``timesteps`` is not 1-D with length B.
        """
        batch_size, horizon, _ = actions.shape

        # Reject anything that is not one scalar time per batch item.
        if timesteps.dim() != 1 or timesteps.shape[0] != batch_size:
            raise ValueError(
                "Expected `timesteps` to have shape (B,) so we can replicate across T."
            )

        # Repeat the per-item time over the T steps: (B,) -> (B, T).
        per_step_time = timesteps[:, None].expand(-1, horizon)

        action_emb = self.W1(actions, cat_ids)
        time_emb = self.pos_encoding(per_step_time).to(dtype=action_emb.dtype)

        # Fuse along the feature axis, project back to hidden_size, then swish.
        fused = torch.cat([action_emb, time_emb], dim=-1)
        fused = swish(self.W2(fused, cat_ids))
        return self.W3(fused, cat_ids)
@dataclass
class FlowmatchingActionHeadConfig(PretrainedConfig):
    """Configuration fields for the flow-matching action head.

    NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head

    The fields below declare defaults; the explicit ``__init__`` forwards all
    keyword arguments to ``PretrainedConfig`` and then sets each of them as an
    instance attribute.
    """

    # --- architecture -----------------------------------------------------
    add_pos_embed: bool = field(
        default=True, metadata={"help": "Whether to add positional embedding"}
    )
    diffusion_model_cfg: dict = field(
        default=None, metadata={"help": "Diffusion model configuration."}
    )
    input_embedding_dim: int = field(
        default=1536, metadata={"help": "Input embedding channel dimension."}
    )
    hidden_size: int = field(default=1024, metadata={"help": "Input embedding dimension."})
    max_seq_len: int = field(default=1024, metadata={"help": "Maxium Sequence Length"})
    action_dim: int = field(default=None, metadata={"help": "Action dimension."})
    action_horizon: int = field(default=None, metadata={"help": "Action horizon."})
    # --- flow-matching noise / time sampling --------------------------------
    # Parameters of the Beta distribution used to sample training times.
    noise_beta_alpha: float = field(default=1.5, metadata={"help": ""})
    noise_beta_beta: float = field(default=1.0, metadata={"help": ""})
    noise_s: float = field(
        default=0.999, metadata={"help": "Flow matching noise Beta distribution s."}
    )
    num_timestep_buckets: int = field(
        default=1000, metadata={"help": "Number of timestep discretization buckets."}
    )
    num_inference_timesteps: int = field(
        default=None,
        metadata={"help": "Number of inference steps for noise diffusion."},
    )
    # --- training switches --------------------------------------------------
    max_num_embodiments: int = field(default=32, metadata={"help": "Number of embodiments."})
    tune_projector: bool = field(default=True, metadata={"help": "Whether to tune the projector."})
    tune_diffusion_model: bool = field(
        default=True, metadata={"help": "Whether to tune the diffusion model."}
    )
    load_pretrained_det_decode_layer_path: str = field(
        default=None, metadata={"help": "Path to pretrained detection model."}
    )
    detection_coeff: float = field(default=1.0, metadata={"help": "Detection coefficient."})
    freeze_decode_layer: bool = field(default=False)
    expand_batch: int = field(default=None)
    use_vlln: bool = field(default=True)
    vl_self_attention_cfg: dict = field(default=None)
    num_target_vision_tokens: int = field(
        default=32, metadata={"help": "Number of target vision tokens."}
    )

    # NOTE(review): because __init__ is defined explicitly, @dataclass should
    # not generate its own (per the dataclasses docs) — confirm if this class
    # is ever constructed positionally.
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Mirror every provided keyword onto the instance, overriding the
        # class-level defaults declared above.
        for key, value in kwargs.items():
            setattr(self, key, value)
# Default DiT hyper-parameters. LayerwiseFlowmatchingActionHead.__init__ uses
# attention_head_dim as-is and overrides num_layers, input_embedding_dim and
# num_attention_heads with values derived from the global config before
# building the DiT.
DiTConfig = {"num_layers": 36, "input_embedding_dim": 2048, "attention_head_dim": 64, "num_attention_heads": 32} # default for qwen2.5-vl
class LayerwiseFlowmatchingActionHead(nn.Module):
    """Flow-matching action head with layer-wise cross-attention to VL features.

    The i-th DiT transformer block cross-attends to ``vl_embs_list[i]``. The
    token sequence fed to the blocks is ``[state (optional), future tokens,
    action tokens]``; the trailing action positions of the decoded output are
    read as the predicted velocity.
    """

    def __init__(
        self,
        global_config,
        **kwargs,
    ):
        """Build the DiT, the encoders/decoder and the learned token tables.

        Args:
            global_config: Global config. Reads ``framework.action_model``
                (action head settings) and ``framework.qwenvl``
                (``num_vl_layers``, ``vl_hidden_dim``).
            **kwargs: Accepted for interface compatibility; unused.
        """
        super().__init__()
        action_config = global_config.framework.action_model
        diffusion_model_cfg = action_config.diffusion_model_cfg

        # Derive the DiT overrides on a private copy so that the module-level
        # DiTConfig defaults are not mutated by constructing a head.
        dit_overrides = dict(DiTConfig)
        dit_overrides["num_layers"] = global_config.framework.qwenvl.num_vl_layers
        dit_overrides["input_embedding_dim"] = global_config.framework.qwenvl.vl_hidden_dim
        dit_overrides["num_attention_heads"] = (
            dit_overrides["input_embedding_dim"] // dit_overrides["attention_head_dim"]
        )
        diffusion_model_cfg.update(dit_overrides)
        # diffusion_model_cfg["interleave_self_attention"] = False
        # Should match the VL embedding dim; may be changed for cross + self
        # attention setups.
        diffusion_model_cfg.cross_attention_dim = dit_overrides["input_embedding_dim"]

        self.input_embedding_dim = global_config.framework.qwenvl.vl_hidden_dim
        self.model = DiT(**diffusion_model_cfg)  # TODO better way is copy LLM from VLM
        self.dit_out_hidden_size = self.input_embedding_dim
        self.action_dim = action_config.action_dim
        self.action_horizon = action_config.future_action_window_size + 1
        self.num_inference_timesteps = action_config.num_inference_timesteps

        # State conditioning is optional: no encoder is built when state_dim
        # is falsy.
        self.state_encoder = MLP(
            input_dim=action_config.state_dim,
            output_dim=self.input_embedding_dim,
        ) if action_config.state_dim else None
        self.action_encoder = ActionEncoder(
            action_dim=action_config.action_dim,
            hidden_size=self.input_embedding_dim,
        )
        self.action_decoder = MLP(
            input_dim=self.input_embedding_dim,
            hidden_dim=1024,
            output_dim=self.action_dim,
        )

        self.future_tokens = nn.Embedding(action_config.num_target_vision_tokens, self.input_embedding_dim)
        nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)
        if action_config.add_pos_embed:
            self.position_embedding = nn.Embedding(action_config.max_seq_len, self.input_embedding_dim)
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)

        self.beta_dist = Beta(action_config.noise_beta_alpha, action_config.noise_beta_beta)
        self.num_timestep_buckets = action_config.num_timestep_buckets
        self.config = action_config

    def sample_time(self, batch_size, device, dtype):
        """Sample one training time per batch item.

        Draws ``s ~ Beta(alpha, beta)`` and returns ``(noise_s - s) / noise_s``
        as a tensor of shape (batch_size,) on ``device`` with ``dtype``.
        """
        sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
        return (self.config.noise_s - sample) / self.config.noise_s

    def prepare_input(self, batch: dict) -> "BatchFeature":
        """Wrap a plain dict into a ``BatchFeature``."""
        return BatchFeature(data=batch)

    def _encode_state(self, state):
        """Return state features, or None when no state or no state encoder exists."""
        # Without this guard a provided state would call `None(...)` whenever
        # the head was configured without state conditioning.
        if state is None or self.state_encoder is None:
            return None
        return self.state_encoder(state)

    def _build_sa_embs(self, action_features, state_features, batch_size, device):
        """Assemble the [state?, future tokens, action tokens] sequence along dim 1."""
        if self.config.add_pos_embed:
            pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
            pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
            action_features = action_features + pos_embs
        future_tokens = self.future_tokens.weight.unsqueeze(0).expand(batch_size, -1, -1)
        if state_features is None:
            parts = (future_tokens, action_features)
        else:
            parts = (state_features, future_tokens, action_features)
        return torch.cat(parts, dim=1)

    def _run_layers(self, sa_embs, vl_embs_list, temb, encoder_attention_mask):
        """Run every DiT block against its own VL embedding, then decode to action space."""
        blocks = self.model.transformer_blocks
        # Assumes transformer_blocks supports len() (e.g. nn.ModuleList).
        if len(vl_embs_list) < len(blocks):
            raise ValueError(
                f"vl_embs_list has {len(vl_embs_list)} entries but the DiT has "
                f"{len(blocks)} transformer blocks; one VL embedding per block is required."
            )
        hidden = sa_embs
        # Entries beyond the number of blocks are ignored.
        for layer, vl_embs in zip(blocks, vl_embs_list):
            hidden = layer(
                hidden_states=hidden,
                encoder_hidden_states=vl_embs,
                temb=temb,
                encoder_attention_mask=encoder_attention_mask,
            )
        # TODO miss self att and _process_output, but work well
        return self.action_decoder(hidden)

    def forward(self, vl_embs_list: list, actions: torch.Tensor, state: torch.Tensor = None, encoder_attention_mask: torch.Tensor = None):
        """Compute the flow-matching training loss.

        Args:
            vl_embs_list: list of tensors, each (B, seq_length, feature_dim);
                entry i is the cross-attention context of DiT block i.
            actions: ground-truth action chunk of shape (B, T, D_action).
            state: optional state tensor; must be 3-D because its features are
                concatenated along dim 1. Ignored when the head has no state
                encoder.
            encoder_attention_mask: optional (B, seq_length) mask for VLM
                padding tokens (1 = attend, 0 = padding).

        Returns:
            Scalar tensor: mean squared error between the predicted velocity
            at the action positions and ``actions - noise``.

        Raises:
            ValueError: if ``vl_embs_list`` has fewer entries than DiT blocks.
        """
        device = actions.device
        batch_size = vl_embs_list[0].shape[0]

        # Interpolate between noise (t = 0) and the clean actions (t = 1).
        noise = torch.randn(actions.shape, device=device, dtype=actions.dtype)
        t = self.sample_time(actions.shape[0], device=device, dtype=actions.dtype)
        t = t[:, None, None]  # (B, 1, 1) for broadcasting
        noisy_trajectory = (1 - t) * noise + t * actions
        velocity = actions - noise

        # Discretize the continuous time into integer buckets.
        t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
        action_features = self.action_encoder(noisy_trajectory, t_discretized)
        state_features = self._encode_state(state)
        sa_embs = self._build_sa_embs(action_features, state_features, batch_size, device)

        temb = self.model.timestep_encoder(t_discretized)

        # F.scaled_dot_product_attention accepts bool or float masks, not long.
        if encoder_attention_mask is not None:
            encoder_attention_mask = encoder_attention_mask.bool()

        pred = self._run_layers(sa_embs, vl_embs_list, temb, encoder_attention_mask)
        # Only the trailing action positions carry the velocity prediction.
        pred_actions = pred[:, -actions.shape[1]:]
        return ((pred_actions - velocity) ** 2).mean()

    @torch.no_grad()
    def predict_action(self, vl_embs_list: list, state: torch.Tensor = None, encoder_attention_mask: torch.Tensor = None) -> torch.Tensor:
        """Generate an action chunk by Euler-integrating the predicted velocity.

        Args:
            vl_embs_list: list of tensors, each (B, seq_length, feature_dim);
                entry i is the cross-attention context of DiT block i.
            state: optional state tensor (see ``forward``).
            encoder_attention_mask: optional (B, seq_length) padding mask.

        Returns:
            Tensor of shape (B, action_horizon, action_dim).

        Raises:
            ValueError: if ``num_inference_timesteps`` is unset or < 1, or if
                ``vl_embs_list`` has fewer entries than DiT blocks.
        """
        num_steps = self.num_inference_timesteps
        if num_steps is None or num_steps < 1:
            raise ValueError(
                f"num_inference_timesteps must be a positive integer, got {num_steps!r}"
            )

        reference = vl_embs_list[0]
        batch_size = reference.shape[0]
        device = reference.device

        # Start from pure noise.
        actions = torch.randn(
            size=(batch_size, self.action_horizon, self.action_dim),
            dtype=reference.dtype,
            device=device,
        )
        dt = 1.0 / num_steps

        # Loop-invariant work: state features and mask dtype conversion.
        state_features = self._encode_state(state)
        if encoder_attention_mask is not None:
            encoder_attention_mask = encoder_attention_mask.bool()

        for step in range(num_steps):
            t_cont = step / float(num_steps)
            t_discretized_int = int(t_cont * self.num_timestep_buckets)
            timesteps_tensor = torch.full(
                size=(batch_size,), fill_value=t_discretized_int, device=device, dtype=torch.long
            )

            action_features = self.action_encoder(actions, timesteps_tensor)
            sa_embs = self._build_sa_embs(action_features, state_features, batch_size, device)
            temb = self.model.timestep_encoder(timesteps_tensor)

            pred = self._run_layers(sa_embs, vl_embs_list, temb, encoder_attention_mask)
            pred_velocity = pred[:, -self.action_horizon:]

            # Euler integration step.
            actions = actions + dt * pred_velocity
        return actions

    @property
    def device(self):
        """Device of the first parameter of this module."""
        return next(iter(self.parameters())).device

    @property
    def dtype(self):
        """Dtype of the first parameter of this module."""
        return next(iter(self.parameters())).dtype
def get_action_model(config=None):
    """Factory for the layer-wise flow-matching action head.

    Args:
        config: Global config, handed to the head as ``global_config``
            (the head reads the ``config.framework.action_model`` namespace).

    Returns:
        LayerwiseFlowmatchingActionHead: the constructed action head.
    """
    action_head = LayerwiseFlowmatchingActionHead(global_config=config)
    return action_head
if __name__ == "__main__":
    # TODO: allow each backbone.py to be debugged independently.
    pass