# Copyright 2025 CogACT. All rights reserved.
# Modified by [Jinhui YE/ HKUST University] in [2025].
# Modification: [add global config ].
"""
Diffusion-based action prediction head (DiT variant).
Provides:
- Size presets (S/B/L) for transformer-based temporal action diffusion backbone
- ActionModel: wraps diffusion process (training + optional DDIM sampling creation)
"""
from starVLA.model.modules.action_model.DiT_modules.models import DiT
from starVLA.model.modules.action_model import create_diffusion
from .DiT_modules import gaussian_diffusion as gd
import torch
from torch import nn
# Create model sizes of ActionModels
def DiT_S(**kwargs):  # TODO move to config for reproducibility
    """
    Build the small DiT preset (depth 6, token size 384, 4 heads).
    Args:
        **kwargs: Extra keyword arguments forwarded unchanged to the DiT constructor.
    Returns:
        DiT: Backbone constructed with the small preset.
    """
    small_preset = {"depth": 6, "token_size": 384, "num_heads": 4}
    # Double unpacking still raises TypeError if a caller repeats a preset key.
    return DiT(**small_preset, **kwargs)
def DiT_B(**kwargs):
    """
    Build the base DiT preset (depth 12, token size 768, 12 heads).
    Args:
        **kwargs: Extra keyword arguments forwarded unchanged to the DiT constructor.
    Returns:
        DiT: Backbone constructed with the base preset.
    """
    base_preset = {"depth": 12, "token_size": 768, "num_heads": 12}
    # Double unpacking still raises TypeError if a caller repeats a preset key.
    return DiT(**base_preset, **kwargs)
def DiT_L(**kwargs):
    """
    Build the large DiT preset (depth 24, token size 1024, 16 heads).
    Args:
        **kwargs: Extra keyword arguments forwarded unchanged to the DiT constructor.
    Returns:
        DiT: Backbone constructed with the large preset.
    """
    large_preset = {"depth": 24, "token_size": 1024, "num_heads": 16}
    # Double unpacking still raises TypeError if a caller repeats a preset key.
    return DiT(**large_preset, **kwargs)
# Registry: model-type name -> preset constructor defined above.
# ActionModel looks up its backbone here by the `model_type` string.
DiT_models = {
    "DiT-S": DiT_S,
    "DiT-B": DiT_B,
    "DiT-L": DiT_L,
}
# Create ActionModel
class ActionModel(nn.Module):
    """
    Temporal action head trained with Gaussian diffusion.
    Pieces:
    - net: DiT backbone picked by name from DiT_models (the noise predictor)
    - diffusion: object returned by create_diffusion, used for q_sample in training
    - ddim_diffusion: None until create_ddim() is called
    Methods:
    - forward(): noise the ground-truth action and predict that noise
    - loss(): mean squared error between predicted and sampled noise
    - create_ddim(): build and store a respaced DDIM diffusion object
    """

    def __init__(
        self,
        action_hidden_dim,
        model_type,
        in_channels,
        future_action_window_size,
        past_action_window_size,
        diffusion_steps=100,
        noise_schedule="squaredcos_cap_v2",
    ):
        """
        Set up the diffusion process and the DiT backbone.
        Args:
            action_hidden_dim: Stored as self.token_size (annotated upstream as the
                QFormer output size); it is not passed to the backbone here.
            model_type: Key into DiT_models, i.e. 'DiT-S', 'DiT-B' or 'DiT-L'.
            in_channels: Forwarded to the backbone as in_channels (per-step action dim).
            future_action_window_size: Forwarded to the backbone.
            past_action_window_size: Forwarded to the backbone.
            diffusion_steps: Number of diffusion timesteps given to create_diffusion.
            noise_schedule: Schedule name given to create_diffusion.
        """
        super().__init__()

        # Plain configuration values kept on the instance.
        self.in_channels = in_channels
        self.noise_schedule = noise_schedule
        self.diffusion_steps = diffusion_steps
        self.past_action_window_size = past_action_window_size
        self.future_action_window_size = future_action_window_size
        self.token_size = action_hidden_dim  # QFormer output size

        # Training-time diffusion: empty respacing string, sigma not learned.
        self.diffusion = create_diffusion(
            timestep_respacing="",
            noise_schedule=noise_schedule,
            diffusion_steps=self.diffusion_steps,
            sigma_small=True,
            learn_sigma=False,
        )
        # DDIM object is only built on demand by create_ddim().
        self.ddim_diffusion = None

        # Mirror the diffusion object's variance type when configuring the backbone.
        learned_variance_types = (gd.ModelVarType.LEARNED, gd.ModelVarType.LEARNED_RANGE)
        predicts_sigma = self.diffusion.model_var_type in learned_variance_types

        build_backbone = DiT_models[model_type]
        self.net = build_backbone(
            in_channels=in_channels,
            class_dropout_prob=0.1,
            learn_sigma=predicts_sigma,
            future_action_window_size=future_action_window_size,
            past_action_window_size=past_action_window_size,
        )

    def forward(self, gt_action, condition, **kwargs):
        """
        Run one diffusion training step.
        Args:
            gt_action: Ground truth action tensor [B, T, C].
            condition: Conditioning tokens [B, L, D].
            **kwargs: Accepted but not used.
        Returns:
            tuple:
                noise_pred: Backbone output; asserted to match gt_action's shape.
                noise: Gaussian noise that was mixed into gt_action.
                timestep: One random timestep index per batch element.
        """
        batch_size = gt_action.size(0)

        # Draw the noise first, then the timesteps (keeps the RNG consumption order).
        noise = torch.randn_like(gt_action)  # [B, T, C]
        timestep = torch.randint(
            0,
            self.diffusion.num_timesteps,
            (batch_size,),
            device=gt_action.device,
        )

        # Forward diffusion: x_0 -> x_t.
        noisy_action = self.diffusion.q_sample(gt_action, timestep, noise)

        # Backbone predicts the noise that was added.
        noise_pred = self.net(noisy_action, timestep, condition)
        assert noise_pred.shape == noise.shape == gt_action.shape
        return noise_pred, noise, timestep

    def loss(self, noise_pred, noise):
        """
        Mean squared error between predicted and target noise.
        Args:
            noise_pred: Predicted noise tensor.
            noise: Target noise tensor.
        Returns:
            torch.Tensor: Scalar loss.
        """
        squared_error = (noise_pred - noise) ** 2
        # Optional: a VLB term could be added to the returned value here.
        return squared_error.mean()

    def create_ddim(self, ddim_step=10):
        """
        Build a DDIM-respaced diffusion object and store it on the instance.
        Every call builds a new object and overwrites self.ddim_diffusion.
        Args:
            ddim_step: Number of DDIM steps, appended to the "ddim" respacing string.
        Returns:
            The newly created diffusion object (same object as self.ddim_diffusion).
        """
        respacing = f"ddim{ddim_step!s}"
        ddim_sampler = create_diffusion(
            timestep_respacing=respacing,
            noise_schedule=self.noise_schedule,
            diffusion_steps=self.diffusion_steps,
            sigma_small=True,
            learn_sigma=False,
        )
        self.ddim_diffusion = ddim_sampler
        return self.ddim_diffusion
def get_action_model(model_typ="DiT-B", config=None):
    """
    Factory: build an ActionModel from the global framework config.
    Args:
        model_typ: Unused; kept so existing call sites keep working. The model type
            is always read from config.framework.action_model.action_model_type.
        config: Global config exposing a `framework.action_model` namespace with the
            fields action_model_type, action_hidden_dim, action_dim,
            future_action_window_size and past_action_window_size.
    Returns:
        ActionModel: Diffusion action head. diffusion_steps and noise_schedule are
        not read from the config, so ActionModel's own defaults apply.
    Raises:
        ValueError: If config is None (the signature's default), since every
            setting is read from it.
    """
    # `config` has a None default only for signature compatibility; fail early
    # with a clear message instead of an opaque attribute error on None.
    if config is None:
        raise ValueError(
            "get_action_model requires a config providing `framework.action_model`; got config=None"
        )

    action_model_cfg = config.framework.action_model
    return ActionModel(
        model_type=action_model_cfg.action_model_type,  # e.g. 'DiT-B'
        action_hidden_dim=action_model_cfg.action_hidden_dim,  # hidden size of action tokens
        in_channels=action_model_cfg.action_dim,  # per-step action dimensionality
        future_action_window_size=action_model_cfg.future_action_window_size,
        past_action_window_size=action_model_cfg.past_action_window_size,
    )
|