import torch
from torch import nn

from vitra.models.action_model.dit import DiT
from vitra.models.action_model import create_diffusion
from vitra.datasets.dataset_utils import ActionFeature
from . import gaussian_diffusion as gd

# DiT backbone variants of increasing capacity.
def DiT_T(**kwargs):
    return DiT(depth=3, hidden_size=256, num_heads=4, **kwargs)

def DiT_S(**kwargs):
    return DiT(depth=6, hidden_size=384, num_heads=4, **kwargs)

def DiT_M(**kwargs):
    return DiT(depth=12, hidden_size=384, num_heads=6, **kwargs)

def DiT_B(**kwargs):
    return DiT(depth=12, hidden_size=768, num_heads=12, **kwargs)

def DiT_L(**kwargs):
    return DiT(depth=24, hidden_size=1024, num_heads=16, **kwargs)

DiT_models = {'DiT-T': DiT_T, 'DiT-S': DiT_S, 'DiT-M': DiT_M, 'DiT-B': DiT_B, 'DiT-L': DiT_L}
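
# Illustrative lookup (the kwargs mirror what DiffusionPolicy passes below;
# the concrete values here are placeholders, not repo defaults):
#   net = DiT_models['DiT-B'](token_size=768, action_dim=192,
#                             class_dropout_prob=0.1, learn_sigma=False,
#                             future_action_window_size=16,
#                             past_action_window_size=0,
#                             use_state=None, state_dim=None)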

class DiffusionPolicy(nn.Module):
    def __init__(
        self, 
        token_size, 
        model_type='DiT-B', 
        in_channels=192, 
        future_action_window_size=16, 
        past_action_window_size=0, 
        use_state=None, 
        action_type='angle',
        diffusion_steps=100,
        state_dim=None,
        loss_type='human',
    ):
        super().__init__()
        self.in_channels = in_channels
        # GaussianDiffusion provides the forward process (q_sample) and the
        # reverse process (p_sample); the DiT backbone predicts the noise
        # given x_t, the timestep, and the condition.
        self.diffusion_steps = diffusion_steps
        self.diffusion = create_diffusion(
            timestep_respacing="",
            noise_schedule='squaredcos_cap_v2',
            diffusion_steps=self.diffusion_steps,
            sigma_small=True,
            learn_sigma=False,
        )
        # Alternative: noise_schedule='linear' with otherwise identical settings.
        self.ddim_diffusion = None
        # Tell the backbone whether the diffusion parameterization learns sigma.
        learn_sigma = self.diffusion.model_var_type in (
            gd.ModelVarType.LEARNED, gd.ModelVarType.LEARNED_RANGE
        )
        self.past_action_window_size = past_action_window_size
        self.future_action_window_size = future_action_window_size
        self.use_state = use_state
        self.action_type = action_type
        
        # Per-component loss slices (start, end, weight) defined by ActionFeature.
        if loss_type == 'human':
            self.loss_components = ActionFeature.get_loss_components(action_type)
        elif loss_type == 'robot':
            self.loss_components = ActionFeature.get_xhand_loss_components()
        else:
            raise ValueError(f"Unknown loss_type: {loss_type}")
        self.net = DiT_models[model_type](
            token_size=token_size,
            action_dim=in_channels,
            class_dropout_prob=0.1,
            learn_sigma=learn_sigma,
            future_action_window_size=future_action_window_size,
            past_action_window_size=past_action_window_size,
            use_state=use_state,
            state_dim=state_dim,
        )
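    # Illustrative construction (values are placeholders, not repo defaults):
    #   policy = DiffusionPolicy(token_size=768, model_type='DiT-B',
    #                            in_channels=192, use_state='DiT', state_dim=192)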

    # Given condition z and ground-truth actions x (with validity mask x_mask),
    # compute the masked noise-prediction loss.
    def loss(self, x, z, x_mask, state=None, state_mask=None):
        # Sample random noise and one timestep per batch element.
        noise = torch.randn_like(x)  # [B, T, C]
        timestep = torch.randint(0, self.diffusion.num_timesteps, (x.size(0),), device=x.device)
        
        # sample x_t from x
        x_t = self.diffusion.q_sample(x, timestep, noise)

        # predict noise from x_t
        noise_pred = self.net(x_t, timestep, z, x_mask, state, state_mask)

        assert noise_pred.shape == noise.shape == x.shape

        # Masked L2 loss on the noise prediction.
        square_delta = (noise_pred - noise) ** 2 * x_mask

        # Masked mean over the channel slice [from_dim, to_dim).
        def mask_loss(from_dim, to_dim):
            s = square_delta[:, :, from_dim:to_dim].sum()
            n = x_mask[:, :, from_dim:to_dim].sum()
            return s / n if n > 0 else 0
        
        # Per-component losses, weighted as specified by ActionFeature.
        component_losses = {}
        component_counts = {}

        for name, (start, end, weight) in self.loss_components.items():
            component_losses[name] = mask_loss(start, end) * weight
            # Count valid (batch, time) entries via the slice's first channel.
            component_counts[name] = x_mask[:, :, start].sum()

        total_count = sum(component_counts.values())

        if total_count == 0:
            # Empty mask: return a zero-valued loss that stays in the graph.
            loss = square_delta[0, 0, 0]
        else:
            # Count-weighted average of the component losses.
            loss = sum(
                component_losses[k] * component_counts[k]
                for k in component_counts.keys()
            ) / total_count

        # Return loss with detailed component losses for logging
        return {
            "loss": loss,
            **component_losses,  # Unpack all component losses
        }
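
    # Example call (shapes as used above; B=batch, T=chunk length, C=in_channels):
    #   out = policy.loss(x=actions,      # [B, T, C] ground-truth action tokens
    #                     z=cond_tokens,  # conditioning features for the DiT
    #                     x_mask=mask)    # [B, T, C] 0/1 validity mask
    #   out['loss'].backward()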
    
    # Given condition tokens and Gaussian noise, sample actions with the
    # reverse diffusion process (optionally with classifier-free guidance).
    def sample(
        self,
        action_features,
        cfg_scale,
        current_state,
        current_state_mask,
        use_ddim,
        num_ddim_steps,
        action_masks,
    ):
        B = action_features.shape[0]
        noise = torch.randn(
            B, self.future_action_window_size + 1, self.in_channels,
            device=action_features.device
        )  # [B, T, D]

        x_mask = action_masks.to(action_features.device)

        using_cfg = cfg_scale > 1.0
        if using_cfg:
            # Duplicate the batch: first half conditioned, second half unconditioned.
            noise = torch.cat([noise, noise], 0)
            uncondition = self.net.z_embedder.uncondition
            uncondition = uncondition.unsqueeze(0)      # [1, D]
            uncondition = uncondition.expand(B, 1, -1)  # [B, 1, D]
            z = torch.cat([action_features, uncondition], 0)

            if self.use_state == 'DiT':
                model_kwargs = dict(
                    z=z, x_mask=x_mask, 
                    cfg_scale=cfg_scale, state=current_state, 
                    state_mask=current_state_mask
                )
            else:
                model_kwargs = dict(z=z, x_mask=x_mask, cfg_scale=cfg_scale)
            sample_fn = self.net.forward_with_cfg
        else:
            if self.use_state == 'DiT':
                model_kwargs = dict(z=action_features, x_mask=x_mask, state=current_state, state_mask=current_state_mask)
            else:
                model_kwargs = dict(z=action_features, x_mask=x_mask)
            sample_fn = self.net.forward

        if use_ddim and num_ddim_steps is not None:
            # Lazily build the DDIM sampler on first use.
            if self.ddim_diffusion is None:
                self.create_ddim(ddim_step=num_ddim_steps)
            samples = self.ddim_diffusion.ddim_sample_loop(
                sample_fn,
                noise.shape,
                noise,
                clip_denoised=False,
                model_kwargs=model_kwargs,
                progress=False,
                device=action_features.device,
                eta=0.0
            )
        else:
            # Full DDPM sampling with the training-time diffusion object.
            samples = self.diffusion.p_sample_loop(
                sample_fn,
                noise.shape,
                noise,
                clip_denoised=False,
                model_kwargs=model_kwargs,
                progress=False,
                device=action_features.device
            )
        if using_cfg:
            samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
        return samples

    # Create a DDIM sampler over the same noise schedule.
    def create_ddim(self, ddim_step=10):
        self.ddim_diffusion = create_diffusion(
            timestep_respacing=f"ddim{ddim_step}",
            noise_schedule='squaredcos_cap_v2',
            diffusion_steps=self.diffusion_steps,
            sigma_small=True,
            learn_sigma=False,
        )
        return self.ddim_diffusion
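

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming the vitra package is importable and
    # that DiT/ActionFeature accept the shapes used here. token_size and the
    # conditioning shape are placeholders, not values from the repo's configs.
    torch.manual_seed(0)
    policy = DiffusionPolicy(token_size=768, model_type='DiT-S',
                             in_channels=192, future_action_window_size=16)
    B = 2
    T = policy.future_action_window_size + 1
    C = policy.in_channels
    x = torch.randn(B, T, C)       # ground-truth action chunk
    z = torch.randn(B, 1, 768)     # conditioning tokens (assumed [B, 1, token_size])
    x_mask = torch.ones(B, T, C)   # mark every entry valid
    out = policy.loss(x, z, x_mask)
    print('loss:', float(out['loss']))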