""" losses.py """ from dataclasses import dataclass from typing import Dict, Tuple import torch import torch.nn.functional as F # ============================================================ # 1) CONFIG # ============================================================ @dataclass class GeneralistLossConfig: w_action: float = 1.0 w_physics: float = 20.0 w_value: float = 100.0 label_smoothing: float = 0.0 use_rtg_weighting: bool = True rtg_weight_mode: str = "exp" rtg_weight_beta: float = 2.0 min_token_weight: float = 0.05 # ============================================================ # 2) HELPERS # ============================================================ def _expand_rtg_to_tokens(rtg_bt: torch.Tensor, K: int) -> torch.Tensor: return rtg_bt.unsqueeze(-1).expand(-1, -1, K) def _rtg_to_weights(rtg_input: torch.Tensor, mode: str, beta: float) -> torch.Tensor: if mode == "none": return torch.ones(rtg_input.shape[:2], device=rtg_input.device) if rtg_input.dim() == 3: mu = rtg_input.mean(dim=1, keepdim=True) sig = rtg_input.std(dim=1, keepdim=True, unbiased=False).clamp_min(1e-5) rtg_norm = (rtg_input - mu) / sig scalar_rtg = rtg_norm.sum(dim=-1) else: scalar_rtg = rtg_input mu_s = scalar_rtg.mean(dim=1, keepdim=True) sig_s = scalar_rtg.std(dim=1, keepdim=True, unbiased=False).clamp_min(1e-5) z = (scalar_rtg - mu_s) / sig_s z = torch.clamp(z, -5.0, 5.0) if mode == "clamp01": w = torch.sigmoid(beta * z) elif mode == "softplus": w = F.softplus(beta * z) elif mode == "exp": w = torch.exp(beta * z) else: raise ValueError(f"Unknown rtg_weight_mode={mode}") w = torch.clamp(w, min=0.01, max=50.0) return w # return total, metrics def compute_generalist_loss( model_out: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor], config: GeneralistLossConfig ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Computes Physics loss and Rescaled Value loss. 
""" action_logits = model_out["action_logits"] # [B, T, K, n_bins] state_preds = model_out["state_preds"] # [B, T, K] state_preds_4h = model_out["state_preds_4h"] # [B, T, K] return_preds = model_out["return_preds"] # [B, T, 2] target_tokens = batch["target_action_tokens"] target_mask = batch["target_mask"].float() attn_mask = batch["attention_mask"].float() target_rtg = batch["rtg"].float() time_mask = batch.get("time_mask", torch.ones(target_rtg.shape[:2], device=target_rtg.device)).float() B, T, K, n_bins = action_logits.shape is_state = (1.0 - target_mask) valid_phys = attn_mask * is_state # 1) Stitching if config.use_rtg_weighting: w_bt = _rtg_to_weights(target_rtg, config.rtg_weight_mode, config.rtg_weight_beta) w_btk = _expand_rtg_to_tokens(w_bt, K) norm_factor = (target_mask * attn_mask).sum().clamp_min(1e-6) / (w_btk * target_mask * attn_mask).sum().clamp_min(1e-6) token_importance = w_btk * norm_factor else: w_bt = torch.ones((B, T), device=action_logits.device) token_importance = torch.ones((B, T, K), device=action_logits.device) # 2) ACTION LOSS (CE) flat_logits = action_logits.reshape(-1, n_bins) flat_targets = target_tokens.reshape(-1) flat_mask = (target_mask * attn_mask).reshape(-1) flat_importance = token_importance.reshape(-1) with torch.no_grad(): valid_t = flat_targets[flat_mask > 0.5] if valid_t.numel() > 0: counts = torch.bincount(valid_t, minlength=n_bins).float() class_weights = (1.0 / (counts + 10.0)) / (1.0 / (counts + 10.0)).mean() else: class_weights = torch.ones(n_bins, device=flat_logits.device) ce_per_token = F.cross_entropy(flat_logits, flat_targets, weight=class_weights, reduction="none", ignore_index=-100) loss_action = (ce_per_token * flat_mask * flat_importance).sum() / flat_mask.sum().clamp_min(1e-6) # ============================================================ # 3) PHYSICS LOSS (The Delta Fix) # ============================================================ # Ground Truth from Dataloader # next_obs is [B, T, 21] # feature_values is [B, T, 64] (Padded tokens) true_next = batch["next_obs"].float() target_delta_4h = batch["target_4h_delta"].float() K_limit = true_next.shape[2] true_vals_sliced = batch["feature_values"].float().narrow(2, 0, K_limit) s_pred_valid = state_preds.narrow(2, 0, K_limit) s_pred_4h_valid = state_preds_4h.narrow(2, 0, K_limit) v_phys_mask = valid_phys.narrow(2, 0, K_limit) target_delta_1s = true_next - true_vals_sliced mse_1s = (s_pred_valid - target_delta_1s) ** 2 mse_4h = (s_pred_4h_valid - target_delta_4h) ** 2 with torch.no_grad(): act_diff = torch.zeros((B, T), device=true_next.device) if T > 1: act_diff[:, 1:] = torch.abs(true_vals_sliced[:, 1:] - true_vals_sliced[:, :-1]).sum(dim=-1) excitation = (1.0 + 5.0 * act_diff).unsqueeze(-1) denom = (v_phys_mask * excitation).sum().clamp_min(1e-6) loss_phys_1s = (mse_1s * v_phys_mask * excitation).sum() / denom loss_phys_4h = (mse_4h * v_phys_mask * excitation).sum() / denom loss_physics = loss_phys_1s + 0.5 * loss_phys_4h val_mse = ((return_preds - target_rtg) ** 2).sum(dim=-1) loss_value = (val_mse * w_bt * time_mask).sum() / time_mask.sum().clamp_min(1e-6) loss_value = loss_value * 500.0 total = (config.w_action * loss_action) + \ (config.w_physics * loss_physics) + \ (config.w_value * loss_value) with torch.no_grad(): acc = ((torch.argmax(flat_logits, -1) == flat_targets).float() * flat_mask).sum() / flat_mask.sum().clamp_min(1e-6) if torch.rand(1) < 0.001: print(f"[Loss Debug] Action: {loss_action.item():.3f} | Phys: {loss_physics.item():.3f} | Val: {loss_value.item():.3f}") 
    metrics = {
        "loss_action": loss_action.detach(),
        "loss_physics": loss_physics.detach(),
        "loss_value": loss_value.detach(),
        "accuracy": acc.detach(),
        "total_loss": total.detach(),
    }
    return total, metrics
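

# ------------------------------------------------------------
# Minimal smoke test (illustrative). The shapes below (B=2, T=4,
# K=8 tokens, n_bins=16 action bins, 5 raw observation features)
# are arbitrary assumptions for demonstration, not the real model's.
# ------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    B, T, K, n_bins, n_obs = 2, 4, 8, 16, 5

    model_out = {
        "action_logits": torch.randn(B, T, K, n_bins),
        "state_preds": torch.randn(B, T, K),
        "state_preds_4h": torch.randn(B, T, K),
        "return_preds": torch.randn(B, T, 2),
    }
    batch = {
        "target_action_tokens": torch.randint(0, n_bins, (B, T, K)),
        "target_mask": (torch.rand(B, T, K) > 0.5).float(),  # 1 = action token
        "attention_mask": torch.ones(B, T, K),
        "rtg": torch.randn(B, T, 2),
        "next_obs": torch.randn(B, T, n_obs),
        "target_4h_delta": torch.randn(B, T, n_obs),
        "feature_values": torch.randn(B, T, K),  # padded to K tokens
    }

    total, metrics = compute_generalist_loss(model_out, batch, GeneralistLossConfig())
    print({k: round(v.item(), 4) for k, v in metrics.items()})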