File size: 7,023 Bytes
e94400c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
# Copyright 2025 CogACT. All rights reserved.
# Modified by Jinhui Ye (HKUST) in 2025.
# Modification: added support for building the model from the global config.
"""
Diffusion-based action prediction head (DiT variant).

Provides:
  - Size presets (S/B/L) for transformer-based temporal action diffusion backbone
  - ActionModel: wraps diffusion process (training + optional DDIM sampling creation)
"""

from starVLA.model.modules.action_model.DiT_modules.models import DiT
from starVLA.model.modules.action_model import create_diffusion
from .DiT_modules import gaussian_diffusion as gd

import torch
from torch import nn


# Create model sizes of ActionModels
def DiT_S(**kwargs):  # TODO move to config for reproducibility
    """
    Construct the small DiT variant (depth=6, token_size=384, num_heads=4).

    Args:
        **kwargs: Forwarded verbatim to the DiT constructor.

    Returns:
        DiT: Initialized small model.
    """
    size_cfg = dict(depth=6, token_size=384, num_heads=4)
    return DiT(**size_cfg, **kwargs)


def DiT_B(**kwargs):
    """
    Construct the base DiT variant (depth=12, token_size=768, num_heads=12).

    Args:
        **kwargs: Forwarded verbatim to the DiT constructor.

    Returns:
        DiT: Initialized base model.
    """
    size_cfg = dict(depth=12, token_size=768, num_heads=12)
    return DiT(**size_cfg, **kwargs)


def DiT_L(**kwargs):
    """
    Construct the large DiT variant (depth=24, token_size=1024, num_heads=16).

    Args:
        **kwargs: Forwarded verbatim to the DiT constructor.

    Returns:
        DiT: Initialized large model.
    """
    size_cfg = dict(depth=24, token_size=1024, num_heads=16)
    return DiT(**size_cfg, **kwargs)


# Registry mapping size-preset names to their constructor functions.
DiT_models = {
    "DiT-S": DiT_S,
    "DiT-B": DiT_B,
    "DiT-L": DiT_L,
}


# Create ActionModel
class ActionModel(nn.Module):
    """
    Diffusion-based temporal action head.

    Components:
        - DiT transformer backbone (token-wise denoiser)
        - Gaussian diffusion scheduler (forward noising / reverse denoising)
        - Optional DDIM sampler, created lazily via :meth:`create_ddim`

    Responsibilities:
        - forward(): noise ground-truth actions and predict the noise
        - loss(): plain MSE between predicted and sampled noise
        - create_ddim(): build a deterministic (DDIM) sampler
    """

    def __init__(
        self,
        action_hidden_dim,
        model_type,
        in_channels,
        future_action_window_size,
        past_action_window_size,
        diffusion_steps=100,
        noise_schedule="squaredcos_cap_v2",
    ):
        """
        Build the diffusion process and the DiT backbone.

        Args:
            action_hidden_dim: Hidden size of conditioning tokens (QFormer output dim).
            model_type: One of {'DiT-S', 'DiT-B', 'DiT-L'}.
            in_channels: Action dimensionality per timestep.
            future_action_window_size: Number of future steps modeled.
            past_action_window_size: Number of past steps possibly encoded (context).
            diffusion_steps: Total diffusion timesteps.
            noise_schedule: Noise scheduler type string.
        """
        super().__init__()
        self.in_channels = in_channels
        self.noise_schedule = noise_schedule
        self.diffusion_steps = diffusion_steps
        # The diffusion object provides q_sample (forward) and p_sample (reverse).
        self.diffusion = create_diffusion(
            timestep_respacing="",
            noise_schedule=noise_schedule,
            diffusion_steps=self.diffusion_steps,
            sigma_small=True,
            learn_sigma=False,
        )
        # DDIM sampler is not built until create_ddim() is called.
        self.ddim_diffusion = None
        # Backbone must predict variance only when the scheduler learns it.
        learn_sigma = self.diffusion.model_var_type in (
            gd.ModelVarType.LEARNED,
            gd.ModelVarType.LEARNED_RANGE,
        )
        self.past_action_window_size = past_action_window_size
        self.future_action_window_size = future_action_window_size
        self.token_size = action_hidden_dim  # QFormer output size
        self.net = DiT_models[model_type](
            in_channels=in_channels,
            class_dropout_prob=0.1,
            learn_sigma=learn_sigma,
            future_action_window_size=future_action_window_size,
            past_action_window_size=past_action_window_size,
        )

    def forward(self, gt_action, condition, **kwargs):
        """
        Run one diffusion training step.

        Args:
            gt_action: Ground truth action tensor [B, T, C].
            condition: Conditioning tokens [B, L, D].
            **kwargs: Ignored (reserved).

        Returns:
            tuple:
                noise_pred: Predicted noise tensor [B, T, C].
                noise: Sampled Gaussian noise tensor [B, T, C].
                timestep: Per-sample diffusion timesteps [B].
        """
        batch = gt_action.size(0)
        # Draw Gaussian noise and a uniformly random timestep per sample.
        eps = torch.randn_like(gt_action)  # [B, T, C]
        t = torch.randint(
            low=0,
            high=self.diffusion.num_timesteps,
            size=(batch,),
            device=gt_action.device,
        )

        # Forward process: corrupt x_0 into x_t.
        noisy_action = self.diffusion.q_sample(gt_action, t, eps)

        # Backbone predicts the injected noise from x_t.
        pred = self.net(noisy_action, t, condition)

        assert pred.shape == eps.shape == gt_action.shape

        return pred, eps, t

    def loss(self, noise_pred, noise):
        """
        Compute the MSE noise-prediction loss.

        Args:
            noise_pred: Predicted noise tensor.
            noise: Target noise tensor.

        Returns:
            torch.Tensor: Scalar loss.
        """
        # Mean squared error over all elements.
        # Optional: add a variational lower-bound term (loss_vlb) here.
        return torch.mean((noise_pred - noise) ** 2)

    def create_ddim(self, ddim_step=10):
        """
        Lazily build the DDIM sampler.

        Args:
            ddim_step: Number of DDIM sampling steps.

        Returns:
            Diffusion: DDIM diffusion object (also cached on self.ddim_diffusion).
        """
        self.ddim_diffusion = create_diffusion(
            timestep_respacing=f"ddim{ddim_step}",
            noise_schedule=self.noise_schedule,
            diffusion_steps=self.diffusion_steps,
            sigma_small=True,
            learn_sigma=False,
        )
        return self.ddim_diffusion


def get_action_model(model_typ="DiT-B", config=None):
    """
    Factory: build an ActionModel from the global framework config.

    Args:
        model_typ: Fallback model type used only when the config does not
            declare `action_model_type`.
        config: Global config (expects a `config.framework.action_model`
            namespace). Required.

    Returns:
        ActionModel: Initialized diffusion action head.

    Raises:
        ValueError: If `config` is None.
    """
    # Fail fast with a clear message instead of an opaque AttributeError.
    if config is None:
        raise ValueError(
            "get_action_model requires a config providing the "
            "config.framework.action_model namespace"
        )

    action_model_cfg = config.framework.action_model

    # Prefer the config value; fall back to the model_typ argument so the
    # parameter is a usable override when the config omits the type.
    model_type = getattr(action_model_cfg, "action_model_type", model_typ)
    action_hidden_dim = action_model_cfg.action_hidden_dim
    action_dim = action_model_cfg.action_dim
    future_action_window_size = action_model_cfg.future_action_window_size
    past_action_window_size = action_model_cfg.past_action_window_size

    return ActionModel(
        model_type=model_type,  # Model type, e.g., 'DiT-B'
        action_hidden_dim=action_hidden_dim,  # Hidden size of action tokens
        in_channels=action_dim,  # Input channel size
        future_action_window_size=future_action_window_size,  # Future action window size
        past_action_window_size=past_action_window_size,  # Past action window size
    )