| |
| |
| |
| """ |
| Diffusion-based action prediction head (DiT variant). |
| |
| Provides: |
| - Size presets (S/B/L) for the transformer-based temporal action diffusion backbone |
| - ActionModel: wraps the diffusion process (training, plus on-demand creation of a DDIM sampler) |
| """ |
|
|
| from starVLA.model.modules.action_model.DiT_modules.models import DiT |
| from starVLA.model.modules.action_model import create_diffusion |
| from .DiT_modules import gaussian_diffusion as gd |
|
|
| import torch |
| from torch import nn |
|
|
|
|
| |
def DiT_S(**kwargs):
    """
    Build the small ("S") DiT preset: depth 6, token size 384, 4 attention heads.

    Any extra keyword arguments are forwarded unchanged to ``DiT``. Supplying
    ``depth``, ``token_size`` or ``num_heads`` again raises ``TypeError``
    (duplicate keyword argument).

    Returns:
        DiT: The constructed small model.
    """
    preset = {"depth": 6, "token_size": 384, "num_heads": 4}
    return DiT(**preset, **kwargs)
|
|
|
|
def DiT_B(**kwargs):
    """
    Build the base ("B") DiT preset: depth 12, token size 768, 12 attention heads.

    Any extra keyword arguments are forwarded unchanged to ``DiT``. Supplying
    ``depth``, ``token_size`` or ``num_heads`` again raises ``TypeError``
    (duplicate keyword argument).

    Returns:
        DiT: The constructed base model.
    """
    preset = {"depth": 12, "token_size": 768, "num_heads": 12}
    return DiT(**preset, **kwargs)
|
|
|
|
def DiT_L(**kwargs):
    """
    Build the large ("L") DiT preset: depth 24, token size 1024, 16 attention heads.

    Any extra keyword arguments are forwarded unchanged to ``DiT``. Supplying
    ``depth``, ``token_size`` or ``num_heads`` again raises ``TypeError``
    (duplicate keyword argument).

    Returns:
        DiT: The constructed large model.
    """
    preset = {"depth": 24, "token_size": 1024, "num_heads": 16}
    return DiT(**preset, **kwargs)
|
|
|
|
| |
# Registry of DiT size presets, keyed by the ``model_type`` string that
# ActionModel looks up (one of "DiT-S", "DiT-B", "DiT-L").
DiT_models = {
    "DiT-S": DiT_S,
    "DiT-B": DiT_B,
    "DiT-L": DiT_L,
}
|
|
|
|
| |
class ActionModel(nn.Module):
    """
    Diffusion-based temporal action head.

    Attributes built in ``__init__``:
        - ``net``: DiT transformer (preset selected by ``model_type``) that is
          trained to predict the noise added to an action sequence.
        - ``diffusion``: Gaussian diffusion process used at training time
          (``q_sample`` noises the clean actions).
        - ``ddim_diffusion``: ``None`` until ``create_ddim()`` is called.

    Methods:
        - forward(): noise ground-truth actions at random timesteps and
          predict the added noise.
        - loss(): mean squared error between predicted and true noise.
        - create_ddim(): build a DDIM-respaced diffusion object.
    """

    def __init__(
        self,
        action_hidden_dim,
        model_type,
        in_channels,
        future_action_window_size,
        past_action_window_size,
        diffusion_steps=100,
        noise_schedule="squaredcos_cap_v2",
    ):
        """
        Set up the diffusion process and the DiT backbone.

        Args:
            action_hidden_dim: Hidden size of the conditioning tokens. It is
                stored as ``self.token_size`` only; it is not passed to the DiT.
            model_type: Key into ``DiT_models`` ('DiT-S', 'DiT-B' or 'DiT-L').
            in_channels: Action dimensionality per timestep.
            future_action_window_size: Number of future steps modeled.
            past_action_window_size: Number of past steps forwarded to the DiT.
            diffusion_steps: Total number of training diffusion timesteps.
            noise_schedule: Noise schedule name passed to ``create_diffusion``.
        """
        super().__init__()

        self.in_channels = in_channels
        self.noise_schedule = noise_schedule
        self.diffusion_steps = diffusion_steps

        # Training-time process: full timestep grid (no respacing).
        self.diffusion = create_diffusion(
            timestep_respacing="",
            noise_schedule=noise_schedule,
            diffusion_steps=diffusion_steps,
            sigma_small=True,
            learn_sigma=False,
        )
        self.ddim_diffusion = None

        # The backbone only needs variance outputs when the process learns sigma.
        learned_var_types = (gd.ModelVarType.LEARNED, gd.ModelVarType.LEARNED_RANGE)
        learn_sigma = self.diffusion.model_var_type in learned_var_types

        self.past_action_window_size = past_action_window_size
        self.future_action_window_size = future_action_window_size
        # NOTE(review): not forwarded to the DiT; the preset fixes its own
        # token_size. Presumably the condition dim must match it — confirm.
        self.token_size = action_hidden_dim

        build_backbone = DiT_models[model_type]
        # class_dropout_prob presumably drops the condition for
        # classifier-free guidance — confirm in DiT.
        self.net = build_backbone(
            in_channels=in_channels,
            class_dropout_prob=0.1,
            learn_sigma=learn_sigma,
            future_action_window_size=future_action_window_size,
            past_action_window_size=past_action_window_size,
        )

    def forward(self, gt_action, condition, **kwargs):
        """
        Run one diffusion training step: noise the actions, predict the noise.

        Args:
            gt_action: Ground-truth action tensor (documented as [B, T, C];
                dim 0 is used as the batch size).
            condition: Conditioning tokens handed to the DiT (documented as [B, L, D]).
            **kwargs: Accepted and ignored.

        Returns:
            tuple: ``(noise_pred, noise, timestep)``, where timesteps are drawn
            uniformly from ``[0, self.diffusion.num_timesteps)``.
        """
        # Keep this draw order (noise first, then timesteps) for RNG reproducibility.
        eps = torch.randn_like(gt_action)
        batch_size = gt_action.size(0)
        t = torch.randint(
            0,
            self.diffusion.num_timesteps,
            (batch_size,),
            device=gt_action.device,
        )

        noisy_action = self.diffusion.q_sample(gt_action, t, eps)
        eps_pred = self.net(noisy_action, t, condition)

        assert eps_pred.shape == eps.shape == gt_action.shape
        return eps_pred, eps, t

    def loss(self, noise_pred, noise):
        """
        Mean squared error between predicted and target noise.

        Args:
            noise_pred: Predicted noise tensor.
            noise: Target noise tensor.

        Returns:
            torch.Tensor: Scalar mean of the element-wise squared difference.
        """
        residual = noise_pred - noise
        return residual.pow(2).mean()

    def create_ddim(self, ddim_step=10):
        """
        Build a DDIM-respaced diffusion object and store it on ``self``.

        A fresh object is created on every call (nothing is cached). It is
        assigned to ``self.ddim_diffusion`` and also returned.

        Args:
            ddim_step: Number of DDIM steps (used as the respacing "ddim<N>").

        Returns:
            The diffusion object returned by ``create_diffusion``.
        """
        respacing = f"ddim{ddim_step!s}"
        ddim = create_diffusion(
            timestep_respacing=respacing,
            noise_schedule=self.noise_schedule,
            diffusion_steps=self.diffusion_steps,
            sigma_small=True,
            learn_sigma=False,
        )
        self.ddim_diffusion = ddim
        return ddim
|
|
|
|
def get_action_model(model_typ="DiT-B", config=None):
    """
    Build an ActionModel from the global framework config.

    Args:
        model_typ: Ignored; kept only for backward compatibility. The DiT
            preset is always read from
            ``config.framework.action_model.action_model_type``.
        config: Global config exposing ``framework.action_model`` with the
            attributes ``action_model_type``, ``action_hidden_dim``,
            ``action_dim``, ``future_action_window_size`` and
            ``past_action_window_size``.

    Returns:
        ActionModel: Diffusion action head built with ActionModel's default
        ``diffusion_steps`` and ``noise_schedule``.

    Raises:
        ValueError: If ``config`` is None.
    """
    if config is None:
        # The None default exists only for signature compatibility; without a
        # config there is nothing to build from, so fail with a clear message
        # instead of an opaque AttributeError on ``None.framework``.
        raise ValueError(
            "get_action_model requires a config with a framework.action_model section"
        )

    action_model_cfg = config.framework.action_model

    # Read in the same order as before so missing-key errors surface identically.
    model_type = action_model_cfg.action_model_type
    action_hidden_dim = action_model_cfg.action_hidden_dim
    action_dim = action_model_cfg.action_dim
    future_action_window_size = action_model_cfg.future_action_window_size
    past_action_window_size = action_model_cfg.past_action_window_size

    return ActionModel(
        model_type=model_type,
        action_hidden_dim=action_hidden_dim,
        in_channels=action_dim,
        future_action_window_size=future_action_window_size,
        past_action_window_size=past_action_window_size,
    )
|
|