
"""
losses.py
"""

from dataclasses import dataclass
from typing import Dict, Tuple

import torch
import torch.nn.functional as F


# ============================================================
# 1) CONFIG
# ============================================================

@dataclass
class GeneralistLossConfig:
    """Term weights and RTG-weighting options for compute_generalist_loss."""

    # Multiplier on the action cross-entropy term.
    w_action: float = 1.0
    # Multiplier on the physics (state-delta MSE) term.
    w_physics: float = 20.0
    # Multiplier on the value (RTG regression) term.
    w_value: float = 100.0
    # NOTE(review): not referenced by the loss code visible in this file.
    label_smoothing: float = 0.0
    # If True, per-token CE is importance-weighted by return-to-go.
    use_rtg_weighting: bool = True
    # One of "none", "clamp01", "softplus", "exp" — see _rtg_to_weights.
    rtg_weight_mode: str = "exp"
    # Temperature applied to the z-scored RTG before the weight transform.
    rtg_weight_beta: float = 2.0
    # NOTE(review): not referenced by the loss code visible in this file
    # (weights are clamped to [0.01, 50.0] inside _rtg_to_weights instead).
    min_token_weight: float = 0.05


# ============================================================
# 2) HELPERS
# ============================================================

def _expand_rtg_to_tokens(rtg_bt: torch.Tensor, K: int) -> torch.Tensor:
    return rtg_bt.unsqueeze(-1).expand(-1, -1, K)


def _rtg_to_weights(rtg_input: torch.Tensor, mode: str, beta: float) -> torch.Tensor:
    if mode == "none": 
        return torch.ones(rtg_input.shape[:2], device=rtg_input.device)
    if rtg_input.dim() == 3:
        mu = rtg_input.mean(dim=1, keepdim=True)
        sig = rtg_input.std(dim=1, keepdim=True, unbiased=False).clamp_min(1e-5)
        rtg_norm = (rtg_input - mu) / sig
        scalar_rtg = rtg_norm.sum(dim=-1)
    else:
        scalar_rtg = rtg_input
    mu_s = scalar_rtg.mean(dim=1, keepdim=True)
    sig_s = scalar_rtg.std(dim=1, keepdim=True, unbiased=False).clamp_min(1e-5)
    
    z = (scalar_rtg - mu_s) / sig_s
    z = torch.clamp(z, -5.0, 5.0)
    if mode == "clamp01": 
        w = torch.sigmoid(beta * z)
    elif mode == "softplus": 
        w = F.softplus(beta * z)
    elif mode == "exp": 
        w = torch.exp(beta * z)
    else: 
        raise ValueError(f"Unknown rtg_weight_mode={mode}")
    w = torch.clamp(w, min=0.01, max=50.0)
    return w

def compute_generalist_loss(
    model_out: Dict[str, torch.Tensor],
    batch: Dict[str, torch.Tensor],
    config: GeneralistLossConfig
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Computes Physics loss and Rescaled Value loss.

    Combines three terms into one scalar training loss:
      * action loss  -- class-weighted cross-entropy over discretized action
        tokens, optionally importance-weighted by return-to-go (RTG);
      * physics loss -- MSE on predicted 1-step and 4h state deltas, with
        extra weight on timesteps whose features changed sharply;
      * value loss   -- MSE between return_preds and the RTG targets.

    Args:
        model_out: forward-pass tensors; expects keys "action_logits",
            "state_preds", "state_preds_4h", "return_preds" (shapes per the
            inline comments below).
        batch: dataloader tensors; expects "target_action_tokens",
            "target_mask", "attention_mask", "rtg", "next_obs",
            "target_4h_delta", "feature_values". "time_mask" is optional
            and defaults to all-ones over [B, T].
        config: term weights and RTG-weighting options.

    Returns:
        (total, metrics): the scalar loss and a dict of detached scalars
        for logging (loss_action, loss_physics, loss_value, accuracy,
        total_loss).
    """
    action_logits = model_out["action_logits"] # [B, T, K, n_bins]
    state_preds = model_out["state_preds"]     # [B, T, K]
    state_preds_4h = model_out["state_preds_4h"] # [B, T, K]
    return_preds = model_out["return_preds"]   # [B, T, 2]

    target_tokens = batch["target_action_tokens"]
    target_mask = batch["target_mask"].float()
    attn_mask = batch["attention_mask"].float()
    target_rtg = batch["rtg"].float() 
    time_mask = batch.get("time_mask", torch.ones(target_rtg.shape[:2], device=target_rtg.device)).float()
    
    B, T, K, n_bins = action_logits.shape
    # target_mask presumably flags action tokens (1) vs state tokens (0) —
    # TODO confirm against the dataloader. Its complement selects the
    # positions the physics loss supervises.
    is_state = (1.0 - target_mask) 
    valid_phys = attn_mask * is_state

    # 1) Stitching
    if config.use_rtg_weighting:
        w_bt = _rtg_to_weights(target_rtg, config.rtg_weight_mode, config.rtg_weight_beta)
        w_btk = _expand_rtg_to_tokens(w_bt, K)
        # Rescale so the weighted sum over valid action tokens matches the
        # unweighted mask sum, keeping the CE magnitude comparable to the
        # unweighted case regardless of beta/mode.
        norm_factor = (target_mask * attn_mask).sum().clamp_min(1e-6) / (w_btk * target_mask * attn_mask).sum().clamp_min(1e-6)
        token_importance = w_btk * norm_factor
    else:
        w_bt = torch.ones((B, T), device=action_logits.device)
        token_importance = torch.ones((B, T, K), device=action_logits.device)

    # 2) ACTION LOSS (CE)
    flat_logits = action_logits.reshape(-1, n_bins)
    flat_targets = target_tokens.reshape(-1)
    flat_mask = (target_mask * attn_mask).reshape(-1)
    flat_importance = token_importance.reshape(-1)

    # Per-batch inverse-frequency class weights computed from the valid
    # targets; the +10 additive smoothing tames bins with very few samples,
    # and the division by the mean keeps the average weight at 1.
    with torch.no_grad():
        valid_t = flat_targets[flat_mask > 0.5]
        if valid_t.numel() > 0:
            counts = torch.bincount(valid_t, minlength=n_bins).float()
            class_weights = (1.0 / (counts + 10.0)) / (1.0 / (counts + 10.0)).mean()
        else:
            class_weights = torch.ones(n_bins, device=flat_logits.device)

    ce_per_token = F.cross_entropy(flat_logits, flat_targets, weight=class_weights, reduction="none", ignore_index=-100)
    loss_action = (ce_per_token * flat_mask * flat_importance).sum() / flat_mask.sum().clamp_min(1e-6)

    # ============================================================
    # 3) PHYSICS LOSS (The Delta Fix)
    # ============================================================
    # Ground Truth from Dataloader
    # next_obs is [B, T, 21]
    # feature_values is [B, T, 64] (Padded tokens)
    true_next = batch["next_obs"].float()              
    target_delta_4h = batch["target_4h_delta"].float() 
    # Only the first K_limit feature tokens correspond to real observations;
    # slice predictions/masks down to that prefix before comparing.
    K_limit = true_next.shape[2] 
    true_vals_sliced = batch["feature_values"].float().narrow(2, 0, K_limit)
    s_pred_valid = state_preds.narrow(2, 0, K_limit)
    s_pred_4h_valid = state_preds_4h.narrow(2, 0, K_limit)
    v_phys_mask = valid_phys.narrow(2, 0, K_limit)
    # Supervise deltas rather than absolute values: 1-step delta is derived
    # here, the 4h delta comes precomputed from the dataloader.
    target_delta_1s = true_next - true_vals_sliced
    mse_1s = (s_pred_valid - target_delta_1s) ** 2
    mse_4h = (s_pred_4h_valid - target_delta_4h) ** 2
    # "Excitation" upweights timesteps whose features moved sharply vs the
    # previous step (weight 1 + 5*|delta|); the first step gets weight 1.
    with torch.no_grad():
        act_diff = torch.zeros((B, T), device=true_next.device)
        if T > 1:
            act_diff[:, 1:] = torch.abs(true_vals_sliced[:, 1:] - true_vals_sliced[:, :-1]).sum(dim=-1)
        excitation = (1.0 + 5.0 * act_diff).unsqueeze(-1) 
    denom = (v_phys_mask * excitation).sum().clamp_min(1e-6)
    
    loss_phys_1s = (mse_1s * v_phys_mask * excitation).sum() / denom
    loss_phys_4h = (mse_4h * v_phys_mask * excitation).sum() / denom
    
    # The 4h horizon contributes at half the weight of the 1-step term.
    loss_physics = loss_phys_1s + 0.5 * loss_phys_4h
    # 4) VALUE LOSS: squared error summed over the 2 return heads, weighted
    # by the RTG weights and averaged over unmasked timesteps.
    val_mse = ((return_preds - target_rtg) ** 2).sum(dim=-1)
    loss_value = (val_mse * w_bt * time_mask).sum() / time_mask.sum().clamp_min(1e-6)
    # NOTE(review): hard-coded x500 rescale stacks multiplicatively with
    # config.w_value (default 100) — effective value weight is 50,000.
    # Confirm this is intentional tuning rather than a leftover.
    loss_value = loss_value * 500.0 
    total = (config.w_action * loss_action) + \
            (config.w_physics * loss_physics) + \
            (config.w_value * loss_value)
    with torch.no_grad():
        acc = ((torch.argmax(flat_logits, -1) == flat_targets).float() * flat_mask).sum() / flat_mask.sum().clamp_min(1e-6)
        # Sampled debug print: fires on ~0.1% of calls.
        if torch.rand(1) < 0.001:
            print(f"[Loss Debug] Action: {loss_action.item():.3f} | Phys: {loss_physics.item():.3f} | Val: {loss_value.item():.3f}")

    metrics = {
        "loss_action": loss_action.detach(),
        "loss_physics": loss_physics.detach(),
        "loss_value": loss_value.detach(),
        "accuracy": acc.detach(),
        "total_loss": total.detach(),
    }
    return total, metrics