""" Mixture of Horizons (MoH) version of GROOT ActionHeader. This is a reference implementation showing how to integrate MoH strategy into GROOT. Key changes: 1. Support multiple horizons (e.g., [5, 10, 15, 20, 50]) 2. Parallel processing via batching (batch_size * num_horizons) 3. Gating network for ensemble 4. Multi-component loss (individual + auxiliary + load balancing) """ import torch import torch.nn.functional as F from torch import nn from typing import Optional, List from dataclasses import dataclass, field from starVLA.model.modules.action_model.flow_matching_head.action_encoder import ( SinusoidalPositionalEncoding, swish, ) from starVLA.model.modules.action_model.flow_matching_head.cross_attention_dit import DiT from starVLA.model.modules.action_model.GR00T_ActionHeader import ( FlowmatchingActionHeadConfig, ActionEncoder, MLP, ) class FlowmatchingActionHeadMoH(nn.Module): """ GROOT ActionHeader with Mixture of Horizons support. Key differences from original: - Supports multiple horizons (e.g., [5, 10, 15, 20, 50]) - Processes all horizons in parallel via batching - Uses gating network to ensemble predictions - Multi-component loss function """ def __init__( self, full_config, horizons: List[int] = [2,5,8], # Different horizon lengths use_gate_noise: bool = True, # Add learnable noise to gate logits ): super().__init__() config = full_config.framework.action_model self.horizons = sorted(horizons) # Ensure sorted self.max_horizon = self.horizons[-1] self.num_horizons = len(self.horizons) self.use_gate_noise = use_gate_noise self.hidden_size = config.hidden_size self.full_config = full_config action_model_type = config.action_model_type action_model_cfg = { "DiT-B": {"input_embedding_dim": 768, "attention_head_dim": 64, "num_attention_heads": 12}, "DiT-L": {"input_embedding_dim": 1536, "attention_head_dim": 48, "num_attention_heads": 32}, }[action_model_type] self.input_embedding_dim = action_model_cfg["input_embedding_dim"] diffusion_model_cfg = config.diffusion_model_cfg diffusion_model_cfg = {**action_model_cfg, **diffusion_model_cfg} self.model = DiT(**diffusion_model_cfg) self.action_dim = config.action_dim self.action_horizon = config.future_action_window_size + 1 self.num_inference_timesteps = config.num_inference_timesteps self.state_encoder = MLP( input_dim=config.state_dim, hidden_dim=self.hidden_size, output_dim=self.input_embedding_dim, ) if config.state_dim else None self.action_encoder = ActionEncoder( action_dim=config.action_dim, hidden_size=self.input_embedding_dim, ) self.action_decoder = MLP( input_dim=self.model.config.output_dim, hidden_dim=self.hidden_size, output_dim=self.action_dim, ) self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim) nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02) if config.add_pos_embed: self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim) nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02) self.beta_dist = torch.distributions.Beta(config.noise_beta_alpha, config.noise_beta_beta) self.num_timestep_buckets = config.num_timestep_buckets self.config = config # MoH-specific components # Gating network: predicts weights for each horizon at each timestep # Input: model output features, Output: gate logits for each horizon self.gate_out_proj = nn.Linear(self.model.config.output_dim, 1) if self.use_gate_noise: self.gate_noise_layer = nn.Linear(self.model.config.output_dim, 1) self.softplus = nn.Softplus() print(f"MoH ActionHeader initialized 

    def sample_time(self, batch_size, device, dtype):
        sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype).clamp(max=self.config.noise_s)
        return (self.config.noise_s - sample) / self.config.noise_s

    def cv_squared(self, x):
        """Squared coefficient of variation, used for the load-balancing loss."""
        eps = 1e-10
        if x.shape[0] == 1:
            return torch.tensor(0.0, device=x.device, dtype=x.dtype)
        return x.float().var() / (x.float().mean() ** 2 + eps)

    def forward(
        self,
        vl_embs: torch.Tensor,
        actions: torch.Tensor,
        state: torch.Tensor = None,
        encoder_attention_mask=None,
        loss_config: dict = None,
    ):
        """
        Forward pass with the MoH strategy.

        Args:
            vl_embs: (B, seq_length, feature_dim) - Vision-language embeddings
            actions: (B, max_horizon, D_action) - Ground-truth actions (padded to max_horizon)
            state: (B, state_dim) - Optional state features
            encoder_attention_mask: Attention mask for the encoder
            loss_config: Dict with 'aux_weight' and 'balance_weight'

        Returns:
            total_loss: Combined loss from all components
        """
        device = vl_embs.device
        batch_size = actions.shape[0]
        num_horizons = len(self.horizons)
        max_horizon = self.max_horizon

        # Sample noise and time.
        noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
        time_scalar = self.sample_time(batch_size, device, actions.dtype)
        # Expand time for each horizon: (num_h, batch_size)
        time = time_scalar.unsqueeze(0).expand(num_horizons, -1)

        # Flow matching: x_t = (1 - t) * noise + t * actions, where t=0 is pure
        # noise and t=1 is the data. Expand noise and actions to
        # (num_h, batch_size, max_horizon, action_dim).
        noise_expanded = noise.unsqueeze(0).expand(num_horizons, -1, -1, -1)
        actions_expanded = actions.unsqueeze(0).expand(num_horizons, -1, -1, -1)
        t_expanded = time[:, :, None, None]  # (num_h, batch_size, 1, 1)
        x_t = (1 - t_expanded) * noise_expanded + t_expanded * actions_expanded

        # u_t (target velocity): (batch_size, max_horizon, action_dim)
        u_t = actions - noise

        # ============================================================
        # STAGE 1: Prepare inputs for parallel processing
        # ============================================================
        # Repeat vl_embs and state for each horizon (sample-major, see note above).
        batched_vl_embs = vl_embs.repeat_interleave(num_horizons, dim=0)  # (B*H, seq_len, dim)
        batched_state = state.repeat_interleave(num_horizons, dim=0) if state is not None else None
        if encoder_attention_mask is not None:
            batched_encoder_attention_mask = encoder_attention_mask.repeat_interleave(num_horizons, dim=0)
        else:
            batched_encoder_attention_mask = None

        # Reshape x_t and time for batched processing.
        # x_t: (num_h, batch_size, max_horizon, dim) -> (batch_size * num_h, max_horizon, dim)
        batched_x_t = x_t.permute(1, 0, 2, 3).reshape(batch_size * num_horizons, max_horizon, -1)
        # time: (num_h, batch_size) -> (batch_size * num_h)
        batched_time = time.permute(1, 0).reshape(batch_size * num_horizons)

        # Create padding masks for each horizon.
        # action_pad_mask: (num_h, max_horizon) - True where valid, False where padding
        action_pad_mask = torch.arange(max_horizon, device=device)[None, :] < \
            torch.tensor(self.horizons, device=device)[:, None]
        # Expand to batch: (num_h, batch_size, max_horizon)
        action_pad_mask = action_pad_mask.unsqueeze(1).expand(-1, batch_size, -1)
        # Reshape: (batch_size * num_h, max_horizon). Not consumed below; kept so
        # callers can switch to per-step masked losses if needed.
        batched_action_pad_mask = action_pad_mask.permute(1, 0, 2).reshape(batch_size * num_horizons, max_horizon)
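
        # Illustrative example (added for clarity): with horizons=[2, 5, 8] and
        # max_horizon=8, the per-horizon validity mask is
        #   h=2 -> [T, T, F, F, F, F, F, F]
        #   h=5 -> [T, T, T, T, T, F, F, F]
        #   h=8 -> [T, T, T, T, T, T, T, T]
        # i.e., each shorter-horizon expert is only ever supervised on a prefix
        # of the action chunk.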

        # ============================================================
        # STAGE 2: Forward pass through the model (parallel over all horizons)
        # ============================================================
        # Convert continuous time to discrete timestep buckets.
        t_discretized = (batched_time * self.num_timestep_buckets).long()

        # Encode actions.
        action_features = self.action_encoder(batched_x_t, t_discretized)

        # Add position embeddings if configured.
        if self.config.add_pos_embed:
            pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
            pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
            action_features = action_features + pos_embs

        # Prepare state and action embeddings.
        future_tokens = self.future_tokens.weight.unsqueeze(0).expand(
            batch_size * num_horizons, -1, -1
        )
        if batched_state is not None:
            state_features = self.state_encoder(batched_state)
            sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
        else:
            sa_embs = torch.cat((future_tokens, action_features), dim=1)

        # Forward through the DiT model.
        model_output = self.model(
            hidden_states=sa_embs,
            encoder_hidden_states=batched_vl_embs,
            encoder_attention_mask=batched_encoder_attention_mask,
            timestep=t_discretized,
            return_all_hidden_states=False,
        )

        # Decode actions.
        pred = self.action_decoder(model_output)

        # Extract the action predictions (the last max_horizon tokens).
        # pred: (B*H, seq_len, action_dim) -> (B*H, max_horizon, action_dim)
        state_offset = 1 if state is not None else 0
        future_tokens_len = self.future_tokens.num_embeddings
        action_start_idx = state_offset + future_tokens_len
        pred_actions_padded = pred[:, action_start_idx:action_start_idx + max_horizon, :]

        # Reshape to separate the predictions of each horizon.
        # (B*H, max_horizon, dim) -> (B, H, max_horizon, dim) -> (H, B, max_horizon, dim)
        all_v_t_preds = pred_actions_padded.view(
            batch_size, num_horizons, max_horizon, -1
        ).permute(1, 0, 2, 3)

        # ============================================================
        # STAGE 3: Compute losses
        # ============================================================
        # 1. Individual loss: each horizon's prediction vs. its truncated target.
        all_head_losses = []
        for i, h in enumerate(self.horizons):
            v_t_head = all_v_t_preds[i, :, :h, :]  # (B, h, dim)
            target_v_t = u_t[:, :h, :]  # (B, h, dim)
            head_loss = F.mse_loss(v_t_head, target_v_t)
            all_head_losses.append(head_loss)
        individual_loss = torch.sum(torch.stack(all_head_losses))

        # 2. Gating network: produce per-step ensemble weights over the horizons.
        gate_logits = self.gate_out_proj(model_output.to(torch.float32))
        gate_logits = gate_logits[:, action_start_idx:action_start_idx + max_horizon, :]  # (B*H, max_horizon, 1)

        if self.use_gate_noise:
            # Add learnable noise to the gate logits (noisy gating).
            noise_epsilon = 1e-2
            raw_noise_stddev = self.gate_noise_layer(model_output.to(torch.float32))
            raw_noise_stddev = raw_noise_stddev[:, action_start_idx:action_start_idx + max_horizon, :]
            noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
            gate_logits = gate_logits + torch.randn_like(gate_logits) * noise_stddev

        # Reshape gate logits: (B*H, max_horizon, 1) -> (B, H, max_horizon) -> (B, max_horizon, H)
        gate_logits = gate_logits.reshape(batch_size, num_horizons, max_horizon).permute(0, 2, 1)

        # Mask invalid horizons: a head with horizon h cannot predict steps >= h,
        # so its logit is set to the dtype minimum before the softmax.
        valid_heads_mask = torch.tensor(
            [[step < h for h in self.horizons] for step in range(max_horizon)],
            device=device, dtype=torch.bool,
        ).unsqueeze(0)  # (1, max_horizon, H)
        masked_gate_logits = torch.where(
            valid_heads_mask, gate_logits, torch.finfo(gate_logits.dtype).min
        )
        gate_weights = F.softmax(masked_gate_logits, dim=-1)  # (B, max_horizon, H)
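
        # Worked example (added for clarity): at step 3 with horizons=[2, 5, 8],
        # only the h=5 and h=8 heads are valid. The h=2 logit becomes dtype-min,
        # so softmax([min, 0.2, 1.0]) ~= [0.00, 0.31, 0.69]: all ensemble weight
        # goes to the heads that actually predict this step.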

        # 3. Ensemble the predictions using the gate weights.
        # all_v_t_preds: (H, B, max_horizon, dim) -> (B, H, max_horizon, dim)
        all_v_t_preds_padded = all_v_t_preds.permute(1, 0, 2, 3)
        # gate_weights: (B, max_horizon, H) -> (B, H, max_horizon, 1)
        # combined: (B, max_horizon, dim)
        v_t_combined = (gate_weights.permute(0, 2, 1).unsqueeze(-1) * all_v_t_preds_padded).sum(dim=1)

        # Auxiliary loss: ensemble prediction vs. target.
        aux_loss_weight = loss_config.get("aux_weight", 1.0) if loss_config else 1.0
        auxiliary_loss = F.mse_loss(v_t_combined, u_t)

        # 4. Load-balancing loss: encourage balanced usage of the horizons.
        # Between consecutive horizon boundaries the set of valid heads is constant,
        # so the cv^2 penalty is computed per segment over the active heads only.
        loss_components = []
        boundaries = sorted(set([0] + self.horizons))
        for i in range(len(boundaries) - 1):
            start_step, end_step = boundaries[i], boundaries[i + 1]
            active_expert_indices = [idx for idx, h in enumerate(self.horizons) if h > start_step]
            if len(active_expert_indices) > 1:
                segment_gate_weights = gate_weights[:, start_step:end_step, :]
                active_expert_weights = segment_gate_weights[:, :, active_expert_indices]
                avg_expert_prob_in_segment = active_expert_weights.mean(dim=(0, 1))
                segment_loss = self.cv_squared(avg_expert_prob_in_segment)
                loss_components.append(segment_loss)
        load_balancing_loss = (
            torch.mean(torch.stack(loss_components))
            if loss_components
            else torch.tensor(0.0, device=device)
        )
        balance_loss_weight = loss_config.get("balance_weight", 0.001) if loss_config else 0.001

        # Total loss.
        total_loss = individual_loss + aux_loss_weight * auxiliary_loss + balance_loss_weight * load_balancing_loss
        return total_loss
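
    # Inference note (added for clarity): with x_t = (1 - t) * noise + t * actions,
    # the target velocity u_t = actions - noise is constant along the probability
    # path, so predict_action() integrates x <- x + dt * v_t from t=0 (pure noise)
    # to t=1 with num_inference_timesteps Euler steps, ensembling v_t across the
    # horizon heads at every step.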

    @torch.no_grad()
    def predict_action(
        self,
        vl_embs: torch.Tensor,
        state: torch.Tensor = None,
        ret_weights: bool = False,
    ) -> dict:
        """
        Inference with the MoH ensemble.

        Args:
            vl_embs: (B, seq_length, feature_dim)
            state: (B, state_dim) - Optional
            ret_weights: Whether to also return the gate weights

        Returns:
            dict with 'actions' and, if ret_weights is set, 'gate_weights'
        """
        batch_size = vl_embs.shape[0]
        device = vl_embs.device
        num_horizons = len(self.horizons)
        max_horizon = self.max_horizon

        # Initialize actions as noise.
        actions = torch.randn(
            size=(batch_size, max_horizon, self.action_dim),
            dtype=vl_embs.dtype,
            device=device,
        )

        num_steps = self.num_inference_timesteps
        dt = 1.0 / num_steps
        gate_weights_to_log = []

        # Prepare batched inputs (identical across denoising steps; sample-major).
        batched_vl_embs = vl_embs.repeat_interleave(num_horizons, dim=0)
        batched_state = state.repeat_interleave(num_horizons, dim=0) if state is not None else None

        # Denoising loop.
        for t in range(num_steps):
            t_cont = t / float(num_steps)
            t_discretized = int(t_cont * self.num_timestep_buckets)
            timesteps_tensor = torch.full(
                size=(batch_size * num_horizons,), fill_value=t_discretized, device=device
            )

            # Prepare padded action prefixes for each horizon.
            padded_x_t_list = []
            for h in self.horizons:
                padded_x_t = F.pad(actions[:, :h, :], (0, 0, 0, max_horizon - h))
                padded_x_t_list.append(padded_x_t)
            # Stack on a new horizon dim and flatten sample-major, matching the
            # repeat_interleave layout of batched_vl_embs. (torch.cat along dim 0
            # would be horizon-major and misalign samples with their contexts.)
            batched_x_t = torch.stack(padded_x_t_list, dim=1).reshape(
                batch_size * num_horizons, max_horizon, -1
            )

            # Encode actions.
            action_features = self.action_encoder(batched_x_t, timesteps_tensor)
            if self.config.add_pos_embed:
                pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
                pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
                action_features = action_features + pos_embs

            # Prepare embeddings.
            future_tokens = self.future_tokens.weight.unsqueeze(0).expand(
                batch_size * num_horizons, -1, -1
            )
            if batched_state is not None:
                state_features = self.state_encoder(batched_state)
                sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
            else:
                sa_embs = torch.cat((future_tokens, action_features), dim=1)

            # Forward through the model.
            model_output = self.model(
                hidden_states=sa_embs,
                encoder_hidden_states=batched_vl_embs,
                timestep=timesteps_tensor,
            )

            # Decode actions.
            pred = self.action_decoder(model_output)

            # Extract the action predictions.
            state_offset = 1 if state is not None else 0
            future_tokens_len = self.future_tokens.num_embeddings
            action_start_idx = state_offset + future_tokens_len
            pred_actions_padded = pred[:, action_start_idx:action_start_idx + max_horizon, :]

            # Reshape: (B*H, max_horizon, dim) -> (B, H, max_horizon, dim),
            # sample-major as in forward().
            all_v_t_preds_padded = pred_actions_padded.view(
                batch_size, num_horizons, max_horizon, -1
            )

            # Gating network.
            gate_logits = self.gate_out_proj(model_output.to(torch.float32))
            gate_logits = gate_logits[:, action_start_idx:action_start_idx + max_horizon, :]
            gate_logits = gate_logits.reshape(batch_size, num_horizons, max_horizon).permute(0, 2, 1)

            valid_heads_mask = torch.tensor(
                [[step < h for h in self.horizons] for step in range(max_horizon)],
                device=device, dtype=torch.bool,
            ).unsqueeze(0)
            masked_gate_logits = torch.where(
                valid_heads_mask, gate_logits, torch.finfo(gate_logits.dtype).min
            )
            gate_weights = F.softmax(masked_gate_logits, dim=-1)
            if ret_weights:
                gate_weights_to_log.append(torch.round(gate_weights, decimals=3))

            # Ensemble the predictions.
            v_t = (gate_weights.permute(0, 2, 1).unsqueeze(-1) * all_v_t_preds_padded).sum(dim=1)

            # Euler update.
            actions = actions + dt * v_t

        return_dict = {"actions": actions}
        if ret_weights and len(gate_weights_to_log) > 0:
            return_dict["gate_weights"] = torch.stack(gate_weights_to_log, dim=1).detach().cpu()
        return return_dict


def get_action_model(config=None, horizons: List[int] = [2, 5, 8]):
    """
    Factory: build a FlowmatchingActionHeadMoH from the global framework config.

    Args:
        config: Global config (expects a config.framework.action_model namespace).
        horizons: List of horizon lengths to use for MoH.

    Returns:
        FlowmatchingActionHeadMoH: Initialized MoH ActionHeader.
    """
    return FlowmatchingActionHeadMoH(
        full_config=config,
        horizons=horizons,
    )
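

# ---------------------------------------------------------------------------
# Standalone sanity check of the MoH gating math. This demo is an illustrative
# sketch added alongside the reference implementation, not part of it: it
# replays the horizon-masked softmax and weighted ensemble from forward() /
# predict_action() on random stand-in tensors (plain torch ops only, though
# importing this module still requires starVLA), so the shapes and masking
# behavior can be checked in isolation.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    B, horizons, action_dim = 2, [2, 5, 8], 7
    H, max_h = len(horizons), max(horizons)

    # Stand-ins for the per-horizon velocity predictions and the gate logits.
    all_v_t_preds = torch.randn(B, H, max_h, action_dim)  # (B, H, max_horizon, dim)
    gate_logits = torch.randn(B, max_h, H)                # (B, max_horizon, H)

    # A head with horizon h is only valid for steps < h.
    valid_heads_mask = torch.tensor(
        [[step < h for h in horizons] for step in range(max_h)], dtype=torch.bool
    ).unsqueeze(0)                                        # (1, max_horizon, H)
    masked_logits = torch.where(
        valid_heads_mask, gate_logits, torch.finfo(gate_logits.dtype).min
    )
    gate_weights = F.softmax(masked_logits, dim=-1)       # (B, max_horizon, H)

    # Weighted ensemble over the horizon axis, as in the class above.
    v_t = (gate_weights.permute(0, 2, 1).unsqueeze(-1) * all_v_t_preds).sum(dim=1)

    print("gate weights at step 0 (all heads valid):", gate_weights[0, 0])
    print("gate weights at step 6 (only h=8 valid):", gate_weights[0, 6])
    print("ensembled v_t shape:", tuple(v_t.shape))       # (B, max_horizon, action_dim)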