"""
Mixture of Horizons (MoH) version of GROOT ActionHeader.

This is a reference implementation showing how to integrate MoH strategy into GROOT.
Key changes:
1. Support multiple horizons (e.g., [5, 10, 15, 20, 50])
2. Parallel processing via batching (batch_size * num_horizons)
3. Gating network for ensemble
4. Multi-component loss (individual + auxiliary + load balancing)
"""

import torch
import torch.nn.functional as F
from torch import nn
from typing import Optional, List
from dataclasses import dataclass, field

from starVLA.model.modules.action_model.flow_matching_head.action_encoder import (
    SinusoidalPositionalEncoding,
    swish,
)
from starVLA.model.modules.action_model.flow_matching_head.cross_attention_dit import DiT
from starVLA.model.modules.action_model.GR00T_ActionHeader import (
    FlowmatchingActionHeadConfig,
    ActionEncoder,
    MLP,
)


class FlowmatchingActionHeadMoH(nn.Module):
    """
    GROOT ActionHeader with Mixture of Horizons support.
    
    Key differences from original:
    - Supports multiple horizons (e.g., [5, 10, 15, 20, 50])
    - Processes all horizons in parallel via batching
    - Uses gating network to ensemble predictions
    - Multi-component loss function
    """
    
    def __init__(
        self,
        full_config,
        horizons: List[int] = [2, 5, 8],  # Different horizon lengths
        use_gate_noise: bool = True,  # Add learnable noise to gate logits
    ):
        super().__init__()
        config = full_config.framework.action_model
        self.horizons = sorted(horizons)  # Ensure sorted
        self.max_horizon = self.horizons[-1]
        self.num_horizons = len(self.horizons)
        self.use_gate_noise = use_gate_noise
        
        self.hidden_size = config.hidden_size
        self.full_config = full_config
        action_model_type = config.action_model_type
        action_model_cfg = {
            "DiT-B": {"input_embedding_dim": 768, "attention_head_dim": 64, "num_attention_heads": 12},
            "DiT-L": {"input_embedding_dim": 1536, "attention_head_dim": 48, "num_attention_heads": 32},
        }[action_model_type]
        
        self.input_embedding_dim = action_model_cfg["input_embedding_dim"]
        diffusion_model_cfg = config.diffusion_model_cfg
        diffusion_model_cfg = {**action_model_cfg, **diffusion_model_cfg}
        self.model = DiT(**diffusion_model_cfg)
        self.action_dim = config.action_dim
        self.action_horizon = config.future_action_window_size + 1
        self.num_inference_timesteps = config.num_inference_timesteps

        self.state_encoder = MLP(
            input_dim=config.state_dim,
            hidden_dim=self.hidden_size,
            output_dim=self.input_embedding_dim,
        ) if config.state_dim else None

        self.action_encoder = ActionEncoder(
            action_dim=config.action_dim,
            hidden_size=self.input_embedding_dim,
        )
        self.action_decoder = MLP(
            input_dim=self.model.config.output_dim,
            hidden_dim=self.hidden_size,
            output_dim=self.action_dim,
        )
        self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim)
        nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)

        if config.add_pos_embed:
            self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)

        self.beta_dist = torch.distributions.Beta(config.noise_beta_alpha, config.noise_beta_beta)
        self.num_timestep_buckets = config.num_timestep_buckets
        self.config = config
        
        # MoH-specific components
        # Gating network: predicts weights for each horizon at each timestep
        # Input: model output features, Output: gate logits for each horizon
        self.gate_out_proj = nn.Linear(self.model.config.output_dim, 1)
        if self.use_gate_noise:
            self.gate_noise_layer = nn.Linear(self.model.config.output_dim, 1)
        self.softplus = nn.Softplus()
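        # Each horizon replica in the flattened batch gets one gate logit per output token;
        # forward() / predict_action() reassemble these into per-step softmax weights over
        # the horizons when ensembling the predictions.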
        
        print(f"MoH ActionHeader initialized with horizons: {self.horizons}")

    def sample_time(self, batch_size, device, dtype):
        sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype).clamp(max=self.config.noise_s)
        return (self.config.noise_s - sample) / self.config.noise_s

    def cv_squared(self, x):
        """Coefficient of variation squared for load balancing."""
        eps = 1e-10
        if x.shape[0] == 1:
            return torch.tensor(0.0, device=x.device, dtype=x.dtype)
        return x.float().var() / (x.float().mean() ** 2 + eps)

    def forward(
        self, 
        vl_embs: torch.Tensor, 
        actions: torch.Tensor, 
        state: torch.Tensor = None, 
        encoder_attention_mask=None,
        loss_config: dict = None
    ):
        """
        Forward pass with MoH strategy.
        
        Args:
            vl_embs: (B, seq_length, feature_dim) - Vision-language embeddings
            actions: (B, max_horizon, D_action) - Ground truth actions (padded to max_horizon)
            state: (B, state_dim) - Optional state features
            encoder_attention_mask: Attention mask for encoder
            loss_config: Dict with 'aux_weight' and 'balance_weight'
        
        Returns:
            total_loss: Combined loss from all components
        """
        device = vl_embs.device
        batch_size = actions.shape[0]
        num_horizons = len(self.horizons)
        max_horizon = self.max_horizon
        
        # Sample noise and time
        noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
        time_scalar = self.sample_time(batch_size, device, actions.dtype)
        
        # Expand time for each horizon: (num_h, batch_size)
        time = time_scalar.unsqueeze(0).expand(num_horizons, -1)
        # x_t: (num_h, batch_size, max_horizon, action_dim)
        # Flow matching: x_t = (1-t) * noise + t * actions, where t=0 is noise and t=1 is actions
        # Expand noise and actions to (num_h, batch_size, max_horizon, action_dim)
        noise_expanded = noise.unsqueeze(0).expand(num_horizons, -1, -1, -1)
        actions_expanded = actions.unsqueeze(0).expand(num_horizons, -1, -1, -1)
        t_expanded = time[:, :, None, None]  # (num_h, batch_size, 1, 1)
        x_t = (1 - t_expanded) * noise_expanded + t_expanded * actions_expanded
        # u_t (target velocity): (batch_size, max_horizon, action_dim)
        u_t = actions - noise
        
        # ============================================================
        # STAGE 1: Prepare inputs for parallel processing
        # ============================================================
        # Repeat vl_embs and state for each horizon
        batched_vl_embs = vl_embs.repeat_interleave(num_horizons, dim=0)  # (B*H, seq_len, dim)
        batched_state = state.repeat_interleave(num_horizons, dim=0) if state is not None else None
        if encoder_attention_mask is not None:
            batched_encoder_attention_mask = encoder_attention_mask.repeat_interleave(num_horizons, dim=0)
        else:
            batched_encoder_attention_mask = None
        
        # Reshape x_t and time for batched processing
        # x_t: (num_h, batch_size, max_horizon, dim) -> (batch_size * num_h, max_horizon, dim)
        batched_x_t = x_t.permute(1, 0, 2, 3).reshape(batch_size * num_horizons, max_horizon, -1)
        # time: (num_h, batch_size) -> (batch_size * num_h)
        batched_time = time.permute(1, 0).reshape(batch_size * num_horizons)
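        # Layout convention: repeat_interleave above and the (B, H) -> (B*H) reshapes here
        # both flatten batch-major (flat index = b * num_horizons + h); the view() calls in
        # STAGE 3 rely on this same ordering.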
        
        # Create padding masks for each horizon
        # action_pad_mask: (num_h, max_horizon) - True where valid, False where padding
        action_pad_mask = torch.arange(max_horizon, device=device)[None, :] < \
                          torch.tensor(self.horizons, device=device)[:, None]
        # Expand to batch: (num_h, batch_size, max_horizon)
        action_pad_mask = action_pad_mask.unsqueeze(1).expand(-1, batch_size, -1)
        # Reshape: (batch_size * num_h, max_horizon)
        batched_action_pad_mask = action_pad_mask.permute(1, 0, 2).reshape(batch_size * num_horizons, max_horizon)
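        # Note: this padding mask is prepared for completeness but is not passed to the DiT;
        # padded steps are instead handled by slicing the per-horizon losses and by masking
        # the gate logits below.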
        
        # ============================================================
        # STAGE 2: Forward pass through model (parallel for all horizons)
        # ============================================================
        # Convert time to discrete timesteps
        t_discretized = (batched_time * self.num_timestep_buckets).long()
        
        # Encode actions
        action_features = self.action_encoder(batched_x_t, t_discretized)
        
        # Add position embedding if needed
        if self.config.add_pos_embed:
            pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
            pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
            action_features = action_features + pos_embs
        
        # Prepare state and action embeddings
        future_tokens = self.future_tokens.weight.unsqueeze(0).expand(
            batch_size * num_horizons, -1, -1
        )
        if batched_state is not None:
            state_features = self.state_encoder(batched_state)
            sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
        else:
            sa_embs = torch.cat((future_tokens, action_features), dim=1)
        
        # Forward through DiT model
        model_output = self.model(
            hidden_states=sa_embs,
            encoder_hidden_states=batched_vl_embs,
            encoder_attention_mask=batched_encoder_attention_mask,
            timestep=t_discretized,
            return_all_hidden_states=False,
        )
        
        # Decode actions
        pred = self.action_decoder(model_output)
        
        # Extract action predictions (last max_horizon tokens)
        # pred: (B*H, seq_len, action_dim) -> (B*H, max_horizon, action_dim)
        state_offset = 1 if state is not None else 0
        future_tokens_len = self.future_tokens.num_embeddings
        action_start_idx = state_offset + future_tokens_len
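        # Token layout along the sequence: [state (optional) | future tokens | action tokens]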
        pred_actions_padded = pred[:, action_start_idx:action_start_idx + max_horizon, :]
        
        # Reshape to separate predictions for each horizon
        # (B*H, max_horizon, dim) -> (B, H, max_horizon, dim) -> (H, B, max_horizon, dim)
        all_v_t_preds = pred_actions_padded.view(
            batch_size, num_horizons, max_horizon, -1
        ).permute(1, 0, 2, 3)
        
        # ============================================================
        # STAGE 3: Compute losses
        # ============================================================
        # 1. Individual loss: Each horizon's prediction vs target
        all_head_losses = []
        for i, h in enumerate(self.horizons):
            v_t_head = all_v_t_preds[i, :, :h, :]  # (B, h, dim)
            target_v_t = u_t[:, :h, :]  # (B, h, dim)
            head_loss = F.mse_loss(v_t_head, target_v_t)
            all_head_losses.append(head_loss)
        individual_loss = torch.sum(torch.stack(all_head_losses))
        
        # 2. Gating network: Generate weights for ensemble
        gate_logits = self.gate_out_proj(model_output.to(torch.float32))
        gate_logits = gate_logits[:, action_start_idx:action_start_idx + max_horizon, :]  # (B*H, max_horizon, 1)
        
        if self.use_gate_noise:
            # Add learnable noise to gate logits
            noise_epsilon = 1e-2
            raw_noise_stddev = self.gate_noise_layer(model_output.to(torch.float32))
            raw_noise_stddev = raw_noise_stddev[:, action_start_idx:action_start_idx + max_horizon, :]
            noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
            gate_logits = gate_logits + (torch.randn_like(gate_logits) * noise_stddev)
        
        # Reshape gate logits: (B*H, max_horizon, 1) -> (B, H, max_horizon) -> (B, max_horizon, H)
        gate_logits = gate_logits.reshape(batch_size, num_horizons, max_horizon).permute(0, 2, 1)
        
        # Apply mask: invalid horizons (where step >= horizon) get -inf
        valid_heads_mask = torch.tensor(
            [[step < h for h in self.horizons] for step in range(max_horizon)],
            device=device, dtype=torch.bool
        ).unsqueeze(0)  # (1, max_horizon, H)
        
        masked_gate_logits = torch.where(
            valid_heads_mask, 
            gate_logits, 
            torch.finfo(gate_logits.dtype).min
        )
        gate_weights = F.softmax(masked_gate_logits, dim=-1)  # (B, max_horizon, H)
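        # Example: with horizons [2, 5, 8], at step index 3 only the 5- and 8-step heads are
        # valid, so the 2-step head's weight is forced to (numerically) zero by the masking.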
        
        # 3. Ensemble predictions using gate weights
        # all_v_t_preds: (H, B, max_horizon, dim) -> (B, H, max_horizon, dim)
        all_v_t_preds_padded = all_v_t_preds.permute(1, 0, 2, 3)
        # gate_weights: (B, max_horizon, H) -> (B, H, max_horizon, 1)
        # combined: (B, max_horizon, dim)
        v_t_combined = (gate_weights.permute(0, 2, 1).unsqueeze(-1) * all_v_t_preds_padded).sum(dim=1)
        
        # Auxiliary loss: Ensemble prediction vs target
        aux_loss_weight = loss_config.get("aux_weight", 1.0) if loss_config else 1.0
        auxiliary_loss = F.mse_loss(v_t_combined, u_t)
        
        # 4. Load balancing loss: Encourage balanced usage of horizons
        loss_components = []
        boundaries = sorted(list(set([0] + self.horizons)))
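        # Example: with horizons [2, 5, 8], boundaries = [0, 2, 5, 8] and the segments are
        # steps [0, 2), [2, 5), [5, 8); in segment [2, 5) only the 5- and 8-step heads are
        # still active, so balance is encouraged only among those two.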
        for i in range(len(boundaries) - 1):
            start_step, end_step = boundaries[i], boundaries[i + 1]
            active_expert_indices = [idx for idx, h in enumerate(self.horizons) if h > start_step]
            
            if len(active_expert_indices) > 1:
                segment_gate_weights = gate_weights[:, start_step:end_step, :]
                active_expert_weights = segment_gate_weights[:, :, active_expert_indices]
                avg_expert_prob_in_segment = active_expert_weights.mean(dim=(0, 1))
                segment_loss = self.cv_squared(avg_expert_prob_in_segment)
                loss_components.append(segment_loss)
        
        load_balancing_loss = torch.mean(torch.stack(loss_components)) if loss_components else torch.tensor(0.0, device=device)
        balance_loss_weight = loss_config.get("balance_weight", 0.001) if loss_config else 0.001
        
        # Total loss
        total_loss = individual_loss + aux_loss_weight * auxiliary_loss + balance_loss_weight * load_balancing_loss
        
        return total_loss

    @torch.no_grad()
    def predict_action(
        self, 
        vl_embs: torch.Tensor, 
        state: torch.Tensor = None,
        ret_weights: bool = False
    ) -> dict:
        """
        Inference with MoH ensemble.
        
        Args:
            vl_embs: (B, seq_length, feature_dim)
            state: (B, state_dim) - Optional
            ret_weights: Whether to return gate weights
        
        Returns:
            dict with 'actions' and optionally 'gate_weights'
        """
        batch_size = vl_embs.shape[0]
        device = vl_embs.device
        num_horizons = len(self.horizons)
        max_horizon = self.max_horizon
        
        # Initialize actions as noise
        actions = torch.randn(
            size=(batch_size, max_horizon, self.action_dim),
            dtype=vl_embs.dtype,
            device=device,
        )
        
        num_steps = self.num_inference_timesteps
        dt = 1.0 / num_steps
        
        gate_weights_to_log = []
        
        # Prepare batched inputs (same for all denoising steps)
        batched_vl_embs = vl_embs.repeat_interleave(num_horizons, dim=0)
        batched_state = state.repeat_interleave(num_horizons, dim=0) if state is not None else None
        
        # Denoising loop
        for t in range(num_steps):
            t_cont = t / float(num_steps)
            t_discretized = int(t_cont * self.num_timestep_buckets)
            timesteps_tensor = torch.full(
                size=(batch_size * num_horizons,), 
                fill_value=t_discretized, 
                device=device
            )
            
            # Prepare padded actions (and validity masks) for each horizon
            padded_x_t_list, action_pad_mask_list = [], []
            for h in self.horizons:
                padded_x_t = F.pad(actions[:, :h, :], (0, 0, 0, max_horizon - h))
                padded_x_t_list.append(padded_x_t)
                pad_mask = F.pad(
                    torch.ones((batch_size, h), device=device, dtype=torch.bool),
                    (0, max_horizon - h),
                    value=False
                )
                action_pad_mask_list.append(pad_mask)

            # Interleave so the flattened layout is batch-major (flat index = b * num_horizons + h),
            # matching batched_vl_embs (built with repeat_interleave) and the layout used in forward().
            batched_x_t = torch.stack(padded_x_t_list, dim=0).permute(1, 0, 2, 3).reshape(
                batch_size * num_horizons, max_horizon, -1
            )
            action_pad_mask = torch.stack(action_pad_mask_list, dim=0).permute(1, 0, 2).reshape(
                batch_size * num_horizons, max_horizon
            )
            
            # Encode actions
            action_features = self.action_encoder(batched_x_t, timesteps_tensor)
            
            if self.config.add_pos_embed:
                pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
                pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
                action_features = action_features + pos_embs
            
            # Prepare embeddings
            future_tokens = self.future_tokens.weight.unsqueeze(0).expand(
                batch_size * num_horizons, -1, -1
            )
            if batched_state is not None:
                state_features = self.state_encoder(batched_state)
                sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
            else:
                sa_embs = torch.cat((future_tokens, action_features), dim=1)
            
            # Forward through model
            model_output = self.model(
                hidden_states=sa_embs,
                encoder_hidden_states=batched_vl_embs,
                timestep=timesteps_tensor,
            )
            
            # Decode actions
            pred = self.action_decoder(model_output)
            
            # Extract action predictions
            state_offset = 1 if state is not None else 0
            future_tokens_len = self.future_tokens.num_embeddings
            action_start_idx = state_offset + future_tokens_len
            pred_actions_padded = pred[:, action_start_idx:action_start_idx + max_horizon, :]
            
            # Reshape: (B*H, max_horizon, dim) -> (B, H, max_horizon, dim)
            # (batch-major layout, matching the interleaved inputs built above)
            all_v_t_preds_padded = pred_actions_padded.view(
                batch_size, num_horizons, max_horizon, -1
            )
            
            # Gating network
            gate_logits = self.gate_out_proj(model_output.to(torch.float32))
            gate_logits = gate_logits[:, action_start_idx:action_start_idx + max_horizon, :]
            gate_logits = gate_logits.reshape(batch_size, num_horizons, max_horizon).permute(0, 2, 1)
            
            valid_heads_mask = torch.tensor(
                [[step < h for h in self.horizons] for step in range(max_horizon)],
                device=device, dtype=torch.bool
            ).unsqueeze(0)
            
            masked_gate_logits = torch.where(
                valid_heads_mask,
                gate_logits,
                torch.finfo(gate_logits.dtype).min
            )
            gate_weights = F.softmax(masked_gate_logits, dim=-1)
            
            if ret_weights:
                gate_weights_to_log.append(torch.round(gate_weights, decimals=3))
            
            # Ensemble predictions
            v_t = (gate_weights.permute(0, 2, 1).unsqueeze(-1) * all_v_t_preds_padded).sum(dim=1)
            
            # Euler update
            actions = actions + dt * v_t
        
        return_dict = {"actions": actions}
        if ret_weights and len(gate_weights_to_log) > 0:
            return_dict["gate_weights"] = torch.stack(gate_weights_to_log, dim=1).detach().cpu()
        
        return return_dict


def get_action_model(config=None, horizons: List[int] = [2, 5, 8]):
    """
    Factory: build FlowmatchingActionHeadMoH from global framework config.
    
    Args:
        config: Global config (expects config.framework.action_model namespace).
        horizons: List of horizon lengths to use for MoH
    
    Returns:
        FlowmatchingActionHeadMoH: Initialized MoH ActionHeader.
    """
    return FlowmatchingActionHeadMoH(
        full_config=config,
        horizons=horizons,
    )
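

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). `cfg` is a hypothetical config object that
# exposes the `framework.action_model` fields read in __init__ above; tensor
# shapes follow the docstrings of forward() and predict_action().
#
#   head = get_action_model(config=cfg, horizons=[2, 5, 8])
#
#   # Training step: `actions` is padded to max_horizon (= max(horizons)) on the data side.
#   loss = head(vl_embs, actions, state,
#               loss_config={"aux_weight": 1.0, "balance_weight": 0.001})
#
#   # Inference: returns a dict with the ensembled action chunk and, optionally,
#   # the per-step gate weights over the horizons.
#   out = head.predict_action(vl_embs, state, ret_weights=True)
#   out["actions"]       # (B, max_horizon, action_dim)
#   out["gate_weights"]  # (B, num_inference_timesteps, max_horizon, num_horizons)
# ---------------------------------------------------------------------------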