# Source: cross13tasks/code/model/modules/action_model/DiTActionHeader.py
# Uploaded by Timsty using huggingface_hub (commit e94400c, verified)
# Copyright 2025 CogACT. All rights reserved.
# Modified by [Jinhui YE/ HKUST University] in [2025].
# Modification: [add global config ].
"""
Diffusion-based action prediction head (DiT variant).
Provides:
- Size presets (S/B/L) for transformer-based temporal action diffusion backbone
- ActionModel: wraps diffusion process (training + optional DDIM sampling creation)
"""
from starVLA.model.modules.action_model.DiT_modules.models import DiT
from starVLA.model.modules.action_model import create_diffusion
from .DiT_modules import gaussian_diffusion as gd
import torch
from torch import nn
# Size presets (S/B/L) for the DiT backbone used by ActionModel
def DiT_S(**kwargs):  # TODO: move preset sizes to config for reproducibility
    """
    Build the small DiT preset: 6 layers, 384-wide tokens, 4 attention heads.

    Args:
        **kwargs: Forwarded unchanged to the ``DiT`` constructor. Repeating
            ``depth``, ``token_size`` or ``num_heads`` raises ``TypeError``.

    Returns:
        DiT: Freshly constructed small model.
    """
    preset = dict(depth=6, token_size=384, num_heads=4)
    return DiT(**preset, **kwargs)
def DiT_B(**kwargs):
    """
    Build the base DiT preset: 12 layers, 768-wide tokens, 12 attention heads.

    Args:
        **kwargs: Forwarded unchanged to the ``DiT`` constructor. Repeating
            ``depth``, ``token_size`` or ``num_heads`` raises ``TypeError``.

    Returns:
        DiT: Freshly constructed base model.
    """
    preset = dict(depth=12, token_size=768, num_heads=12)
    return DiT(**preset, **kwargs)
def DiT_L(**kwargs):
    """
    Build the large DiT preset: 24 layers, 1024-wide tokens, 16 attention heads.

    Args:
        **kwargs: Forwarded unchanged to the ``DiT`` constructor. Repeating
            ``depth``, ``token_size`` or ``num_heads`` raises ``TypeError``.

    Returns:
        DiT: Freshly constructed large model.
    """
    preset = dict(depth=24, token_size=1024, num_heads=16)
    return DiT(**preset, **kwargs)
# Registry: backbone size name -> factory function that builds that DiT preset.
DiT_models = {
    "DiT-S": DiT_S,
    "DiT-B": DiT_B,
    "DiT-L": DiT_L,
}
# Create ActionModel
class ActionModel(nn.Module):
    """
    Diffusion temporal action head.

    Components:
    - DiT transformer backbone (token-wise denoiser), chosen by name from ``DiT_models``
    - Gaussian diffusion scheduler (forward noising via ``q_sample``)
    - Optional DDIM-respaced diffusion, built on demand by ``create_ddim``

    Responsibilities:
    - forward(): noise the ground-truth actions at random timesteps and predict the added noise
    - loss(): mean squared error between predicted and sampled noise
    - create_ddim(): build a DDIM-respaced diffusion for fewer-step sampling
    """

    def __init__(
        self,
        action_hidden_dim,
        model_type,
        in_channels,
        future_action_window_size,
        past_action_window_size,
        diffusion_steps=100,
        noise_schedule="squaredcos_cap_v2",
    ):
        """
        Initialize the diffusion scheduler and the DiT backbone.

        Args:
            action_hidden_dim: Hidden size of conditioning tokens (QFormer output dim).
                Stored as ``self.token_size``; it is not forwarded to the backbone.
            model_type: One of the keys of ``DiT_models`` ('DiT-S', 'DiT-B', 'DiT-L').
            in_channels: Action dimensionality (per timestep).
            future_action_window_size: Number of future steps modeled.
            past_action_window_size: Number of past steps possibly encoded (for context).
            diffusion_steps: Total diffusion timesteps.
            noise_schedule: Noise schedule name passed to ``create_diffusion``.

        Raises:
            KeyError: If ``model_type`` is not a registered DiT size preset.
        """
        super().__init__()
        # Validate before building the scheduler so a config typo fails fast,
        # with a message listing the valid choices instead of a bare key.
        if model_type not in DiT_models:
            raise KeyError(
                f"Unknown DiT model_type {model_type!r}; expected one of {sorted(DiT_models)}"
            )
        self.in_channels = in_channels
        self.noise_schedule = noise_schedule
        self.diffusion_steps = diffusion_steps
        # GaussianDiffusion offers the forward-noising q_sample used in forward().
        # Empty respacing here (vs. "ddimN" in create_ddim) — presumably the full schedule.
        self.diffusion = create_diffusion(
            timestep_respacing="",
            noise_schedule=noise_schedule,
            diffusion_steps=self.diffusion_steps,
            sigma_small=True,
            learn_sigma=False,
        )
        # Populated by create_ddim().
        self.ddim_diffusion = None
        # Derive the backbone's learn_sigma from the scheduler's variance type so the
        # two always agree (with learn_sigma=False above this is presumably False).
        learn_sigma = self.diffusion.model_var_type in (
            gd.ModelVarType.LEARNED,
            gd.ModelVarType.LEARNED_RANGE,
        )
        self.past_action_window_size = past_action_window_size
        self.future_action_window_size = future_action_window_size
        # QFormer output size.
        # NOTE(review): not passed to the DiT backbone; the backbone width comes from
        # the size preset. Confirm the conditioning dimension matches downstream.
        self.token_size = action_hidden_dim
        self.net = DiT_models[model_type](
            in_channels=in_channels,
            class_dropout_prob=0.1,
            learn_sigma=learn_sigma,
            future_action_window_size=future_action_window_size,
            past_action_window_size=past_action_window_size,
        )

    def forward(self, gt_action, condition, **kwargs):
        """
        Perform one diffusion training step: noise the targets and predict the noise.

        Args:
            gt_action: Ground truth action tensor [B, T, C].
            condition: Conditioning tokens [B, L, D].
            **kwargs: Ignored (reserved).

        Returns:
            tuple:
                noise_pred: Predicted noise tensor, same shape as ``gt_action``.
                noise: Sampled Gaussian noise, same shape as ``gt_action``.
                timestep: Integer timestep per batch element, shape [B].
        """
        # Standard-normal noise with the same shape/dtype/device as the targets.
        noise = torch.randn_like(gt_action)  # [B, T, C]
        # One timestep per batch element, drawn uniformly from [0, num_timesteps).
        timestep = torch.randint(
            0, self.diffusion.num_timesteps, (gt_action.size(0),), device=gt_action.device
        )
        # Forward process: produce x_t from the clean actions and the sampled noise.
        x_t = self.diffusion.q_sample(gt_action, timestep, noise)
        # The backbone predicts the noise that was added to obtain x_t.
        noise_pred = self.net(x_t, timestep, condition)
        assert noise_pred.shape == noise.shape == gt_action.shape, (
            f"shape mismatch: noise_pred {tuple(noise_pred.shape)}, "
            f"noise {tuple(noise.shape)}, gt_action {tuple(gt_action.shape)}"
        )
        return noise_pred, noise, timestep

    def loss(self, noise_pred, noise):
        """
        Compute the mean squared error between predicted and target noise.

        Args:
            noise_pred: Predicted noise tensor.
            noise: Target noise tensor.

        Returns:
            torch.Tensor: Scalar loss (mean over all elements).
        """
        # Plain L2 objective; a variational-bound term (loss_vlb) could be added here.
        loss = ((noise_pred - noise) ** 2).mean()
        return loss

    def create_ddim(self, ddim_step=10):
        """
        Build a DDIM-respaced diffusion with ``ddim_step`` sampling steps.

        The result is stored on ``self.ddim_diffusion`` and returned. A new object
        is built on every call, using the same noise schedule and total
        ``diffusion_steps`` as the training diffusion.

        Args:
            ddim_step: Number of DDIM steps.

        Returns:
            Diffusion: DDIM diffusion object.
        """
        self.ddim_diffusion = create_diffusion(
            timestep_respacing="ddim" + str(ddim_step),
            noise_schedule=self.noise_schedule,
            diffusion_steps=self.diffusion_steps,
            sigma_small=True,
            learn_sigma=False,
        )
        return self.ddim_diffusion
def get_action_model(model_typ="DiT-B", config=None):
    """
    Factory: build an ActionModel from the global framework config.

    Args:
        model_typ: Unused; kept only so existing call sites keep working. The model
            type is always read from ``config.framework.action_model.action_model_type``.
        config: Global config. Must expose a ``framework.action_model`` namespace with
            ``action_model_type``, ``action_hidden_dim``, ``action_dim``,
            ``future_action_window_size`` and ``past_action_window_size``.

    Returns:
        ActionModel: Initialized diffusion action head. ``diffusion_steps`` and
        ``noise_schedule`` use ActionModel's defaults.

    Raises:
        ValueError: If ``config`` is None.
    """
    if config is None:
        # Without this guard the failure is an opaque
        # "'NoneType' object has no attribute 'framework'".
        raise ValueError(
            "get_action_model requires a config with a 'framework.action_model' section; "
            "got config=None"
        )
    action_model_cfg = config.framework.action_model
    return ActionModel(
        model_type=action_model_cfg.action_model_type,  # e.g. 'DiT-B'
        action_hidden_dim=action_model_cfg.action_hidden_dim,  # hidden size of action tokens
        in_channels=action_model_cfg.action_dim,  # per-timestep action dimensionality
        future_action_window_size=action_model_cfg.future_action_window_size,
        past_action_window_size=action_model_cfg.past_action_window_size,
    )