|
|
| """ |
| losses.py |
| """ |
|
|
| from dataclasses import dataclass |
| from typing import Dict, Tuple |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
|
|
@dataclass
class GeneralistLossConfig:
    """Hyperparameters for `compute_generalist_loss`."""
    # Relative weights of the three loss terms in the total loss.
    w_action: float = 1.0
    w_physics: float = 20.0
    w_value: float = 100.0
    # Intended label-smoothing factor for the action cross-entropy.
    # NOTE(review): verify it is actually forwarded to F.cross_entropy.
    label_smoothing: float = 0.0
    # If True, per-token action losses are importance-weighted by
    # returns-to-go (see _rtg_to_weights).
    use_rtg_weighting: bool = True
    # How standardized RTG maps to weights: "none", "clamp01" (sigmoid),
    # "softplus", or "exp".
    rtg_weight_mode: str = "exp"
    # Temperature applied to the standardized RTG before the mapping above.
    rtg_weight_beta: float = 2.0
    # Intended floor for per-token importance weights.
    # NOTE(review): verify this is actually applied downstream.
    min_token_weight: float = 0.05
|
|
|
|
| |
| |
| |
|
|
| def _expand_rtg_to_tokens(rtg_bt: torch.Tensor, K: int) -> torch.Tensor: |
| return rtg_bt.unsqueeze(-1).expand(-1, -1, K) |
|
|
|
|
| def _rtg_to_weights(rtg_input: torch.Tensor, mode: str, beta: float) -> torch.Tensor: |
| if mode == "none": |
| return torch.ones(rtg_input.shape[:2], device=rtg_input.device) |
| if rtg_input.dim() == 3: |
| mu = rtg_input.mean(dim=1, keepdim=True) |
| sig = rtg_input.std(dim=1, keepdim=True, unbiased=False).clamp_min(1e-5) |
| rtg_norm = (rtg_input - mu) / sig |
| scalar_rtg = rtg_norm.sum(dim=-1) |
| else: |
| scalar_rtg = rtg_input |
| mu_s = scalar_rtg.mean(dim=1, keepdim=True) |
| sig_s = scalar_rtg.std(dim=1, keepdim=True, unbiased=False).clamp_min(1e-5) |
| |
| z = (scalar_rtg - mu_s) / sig_s |
| z = torch.clamp(z, -5.0, 5.0) |
| if mode == "clamp01": |
| w = torch.sigmoid(beta * z) |
| elif mode == "softplus": |
| w = F.softplus(beta * z) |
| elif mode == "exp": |
| w = torch.exp(beta * z) |
| else: |
| raise ValueError(f"Unknown rtg_weight_mode={mode}") |
| w = torch.clamp(w, min=0.01, max=50.0) |
| return w |
|
|
| |
| def compute_generalist_loss( |
| model_out: Dict[str, torch.Tensor], |
| batch: Dict[str, torch.Tensor], |
| config: GeneralistLossConfig |
| ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: |
| """ |
| Computes Physics loss and Rescaled Value loss. |
| """ |
| action_logits = model_out["action_logits"] |
| state_preds = model_out["state_preds"] |
| state_preds_4h = model_out["state_preds_4h"] |
| return_preds = model_out["return_preds"] |
|
|
| target_tokens = batch["target_action_tokens"] |
| target_mask = batch["target_mask"].float() |
| attn_mask = batch["attention_mask"].float() |
| target_rtg = batch["rtg"].float() |
| time_mask = batch.get("time_mask", torch.ones(target_rtg.shape[:2], device=target_rtg.device)).float() |
| |
| B, T, K, n_bins = action_logits.shape |
| is_state = (1.0 - target_mask) |
| valid_phys = attn_mask * is_state |
|
|
| |
| if config.use_rtg_weighting: |
| w_bt = _rtg_to_weights(target_rtg, config.rtg_weight_mode, config.rtg_weight_beta) |
| w_btk = _expand_rtg_to_tokens(w_bt, K) |
| norm_factor = (target_mask * attn_mask).sum().clamp_min(1e-6) / (w_btk * target_mask * attn_mask).sum().clamp_min(1e-6) |
| token_importance = w_btk * norm_factor |
| else: |
| w_bt = torch.ones((B, T), device=action_logits.device) |
| token_importance = torch.ones((B, T, K), device=action_logits.device) |
|
|
| |
| flat_logits = action_logits.reshape(-1, n_bins) |
| flat_targets = target_tokens.reshape(-1) |
| flat_mask = (target_mask * attn_mask).reshape(-1) |
| flat_importance = token_importance.reshape(-1) |
|
|
| with torch.no_grad(): |
| valid_t = flat_targets[flat_mask > 0.5] |
| if valid_t.numel() > 0: |
| counts = torch.bincount(valid_t, minlength=n_bins).float() |
| class_weights = (1.0 / (counts + 10.0)) / (1.0 / (counts + 10.0)).mean() |
| else: |
| class_weights = torch.ones(n_bins, device=flat_logits.device) |
|
|
| ce_per_token = F.cross_entropy(flat_logits, flat_targets, weight=class_weights, reduction="none", ignore_index=-100) |
| loss_action = (ce_per_token * flat_mask * flat_importance).sum() / flat_mask.sum().clamp_min(1e-6) |
|
|
| |
| |
| |
| |
| |
| |
| true_next = batch["next_obs"].float() |
| target_delta_4h = batch["target_4h_delta"].float() |
| K_limit = true_next.shape[2] |
| true_vals_sliced = batch["feature_values"].float().narrow(2, 0, K_limit) |
| s_pred_valid = state_preds.narrow(2, 0, K_limit) |
| s_pred_4h_valid = state_preds_4h.narrow(2, 0, K_limit) |
| v_phys_mask = valid_phys.narrow(2, 0, K_limit) |
| target_delta_1s = true_next - true_vals_sliced |
| mse_1s = (s_pred_valid - target_delta_1s) ** 2 |
| mse_4h = (s_pred_4h_valid - target_delta_4h) ** 2 |
| with torch.no_grad(): |
| act_diff = torch.zeros((B, T), device=true_next.device) |
| if T > 1: |
| act_diff[:, 1:] = torch.abs(true_vals_sliced[:, 1:] - true_vals_sliced[:, :-1]).sum(dim=-1) |
| excitation = (1.0 + 5.0 * act_diff).unsqueeze(-1) |
| denom = (v_phys_mask * excitation).sum().clamp_min(1e-6) |
| |
| loss_phys_1s = (mse_1s * v_phys_mask * excitation).sum() / denom |
| loss_phys_4h = (mse_4h * v_phys_mask * excitation).sum() / denom |
| |
| loss_physics = loss_phys_1s + 0.5 * loss_phys_4h |
| val_mse = ((return_preds - target_rtg) ** 2).sum(dim=-1) |
| loss_value = (val_mse * w_bt * time_mask).sum() / time_mask.sum().clamp_min(1e-6) |
| loss_value = loss_value * 500.0 |
| total = (config.w_action * loss_action) + \ |
| (config.w_physics * loss_physics) + \ |
| (config.w_value * loss_value) |
| with torch.no_grad(): |
| acc = ((torch.argmax(flat_logits, -1) == flat_targets).float() * flat_mask).sum() / flat_mask.sum().clamp_min(1e-6) |
| if torch.rand(1) < 0.001: |
| print(f"[Loss Debug] Action: {loss_action.item():.3f} | Phys: {loss_physics.item():.3f} | Val: {loss_value.item():.3f}") |
|
|
| metrics = { |
| "loss_action": loss_action.detach(), |
| "loss_physics": loss_physics.detach(), |
| "loss_value": loss_value.detach(), |
| "accuracy": acc.detach(), |
| "total_loss": total.detach(), |
| } |
| return total, metrics |