| """ |
| v11 prediction heads: bar-distribution regression + bin-based classification. |
| |
| Both heads are ICL-friendly: they take trunk output [B, n_query, d_model] |
| and produce per-query predictions over a fixed-size output space (1024 bins |
| for regression, MAX_CLASSES=10 logits for classification). |
| |
| ## Why bar-dist for regression |
| Verified in v10: a single-Gaussian (μ, log_σ²) head can collapse to a |
| constant when the trunk's output drifts (v9 failure mode). Bar-dist's |
| 1024-bin cross-entropy can't collapse — every bin is independently |
| supervised. Also matches TabPFN v2's reg head exactly. |
| |
| ## Why bin-based for classification |
| v8/v10 used Linear(d_model, n_classes_max). For v11 we keep that structure |
| but add per-task masking: a task with n_classes=3 only computes CE over |
| the first 3 logits. This avoids the per-task linear-head trick used by |
| TabPFN (where the head is built from class prototypes inside each task) |
| which is harder to fit and gives no measurable gain at this scale per |
| Expert 4's pre-mortem on v11. |
| |
| ## Trunk interface contract |
| The trunk returns one tensor per task type: |
| reg_out: [B, n_query, d_model] - last column, query rows, after reg trunk layers |
| cls_out: [B, n_query, d_model] - last column, query rows, after cls trunk layers |
| |
| Both heads take this shape and produce per-query outputs. |
| """ |
| from __future__ import annotations |
|
|
| import math |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
# Width of the classification logit vector: the default `max_classes` of
# BinClassificationHead. Tasks with fewer classes mask out the unused logits.
MAX_CLASSES: int = 10
|
|
|
|
| |
|
|
|
|
def default_bin_edges(n_bins: int = 1024, tail: float = 0.0001) -> torch.Tensor:
    """Return ``n_bins + 1`` bin edges placed at evenly spaced N(0, 1) quantiles.

    The edges are the standard-normal inverse CDF evaluated at ``n_bins + 1``
    probabilities spaced uniformly on ``[tail, 1 - tail]``. Every bin therefore
    carries the same probability mass under N(0, 1), and the edges are
    symmetric around 0.

    With the defaults (n_bins=1024, tail=0.0001) the outer edges sit at
    N⁻¹(0.0001) ≈ -3.72 and N⁻¹(0.9999) ≈ +3.72. An earlier setting of
    tail=0.001 capped the range at ≈ ±3.09, which pushed more of a
    heavy-tailed task's targets into the outermost bins, where the
    cross-entropy has no resolution.

    Args:
        n_bins: number of bins; must be >= 1.
        tail: probability mass left outside each outer edge; must satisfy
            ``0 < tail < 0.5``.

    Returns:
        1-D tensor of shape ``[n_bins + 1]`` in the default floating dtype,
        strictly increasing.

    Raises:
        ValueError: if ``n_bins < 1`` or ``tail`` is outside ``(0, 0.5)``.
            ``tail <= 0`` would yield infinite outer edges, and
            ``tail >= 0.5`` would yield degenerate or descending edges.
    """
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")
    if not 0.0 < tail < 0.5:
        raise ValueError(f"tail must be in (0, 0.5), got {tail}")

    # Evaluate in float64 and round once at the end: erfinv is steep near ±1,
    # so float32 rounding of the probabilities is amplified in the tail edges.
    probs = torch.linspace(tail, 1.0 - tail, n_bins + 1, dtype=torch.float64)
    # N⁻¹(p) = sqrt(2) * erfinv(2p - 1)
    edges = math.sqrt(2.0) * torch.erfinv(2.0 * probs - 1.0)
    return edges.to(torch.get_default_dtype())
|
|
|
|
class BarDistributionHead(nn.Module):
    """
    Bar-distribution (Riemann) regression head.

    Forward: x [..., d_model] → logits [..., n_bins].
    Loss is CE between predicted bin distribution and the bin containing
    the (per-task standardized) target.

    Buffers:
        bin_edges:   [n_bins + 1] float32 edges of the bins.
        bin_centers: [n_bins] float32 midpoints of consecutive edges.
    """

    def __init__(
        self,
        d_model: int,
        n_bins: int = 1024,
        hidden_multiplier: int = 2,
        dropout: float = 0.0,
        bin_edges: Optional[torch.Tensor] = None,
    ):
        """Build the 2-layer MLP and register the bin structure.

        Args:
            d_model: width of the incoming trunk features.
            n_bins: number of output bins (logits).
            hidden_multiplier: hidden width is ``d_model * hidden_multiplier``.
            dropout: dropout probability after the activation; 0 disables it.
            bin_edges: optional 1-D tensor of ``n_bins + 1`` edges. When
                omitted, ``default_bin_edges(n_bins)`` is used.

        Raises:
            ValueError: if ``bin_edges`` does not have shape ``(n_bins + 1,)``.
        """
        super().__init__()
        self.d_model = d_model
        self.n_bins = n_bins

        if bin_edges is None:
            bin_edges = default_bin_edges(n_bins)
        # Explicit raise rather than `assert`, so the check survives `python -O`.
        if tuple(bin_edges.shape) != (n_bins + 1,):
            raise ValueError(
                f"bin_edges must have shape ({n_bins + 1},), "
                f"got {tuple(bin_edges.shape)}"
            )
        self.register_buffer("bin_edges", bin_edges.float())
        # Midpoints are computed in the caller's dtype before the float32 cast.
        centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
        self.register_buffer("bin_centers", centers.float())

        hidden = d_model * hidden_multiplier
        # Keep four positional slots (Identity when dropout is off) so the
        # final Linear is always `mlp.3`, whatever the dropout setting.
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
            nn.Linear(hidden, n_bins),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features [..., d_model] to unnormalized bin logits [..., n_bins]."""
        return self.mlp(x)

    def predict_bin_ids(self, y_standardized: torch.Tensor) -> torch.Tensor:
        """Return the index of the bin containing each standardized target.

        Only the interior edges are used as boundaries, so any value at or
        below the first interior edge lands in bin 0 and any value above the
        last interior edge lands in bin ``n_bins - 1``. Targets outside the
        edge range are thus absorbed by the outermost bins. A value exactly on
        an interior edge is assigned to the lower bin (``torch.bucketize``
        default, ``right=False``).

        Args:
            y_standardized: tensor of any shape with standardized targets.

        Returns:
            Integer tensor of the same shape with values in ``[0, n_bins - 1]``.
        """
        idx = torch.bucketize(y_standardized, self.bin_edges[1:-1].to(y_standardized.device))
        # Interior boundaries already bound the result; the clamp is a safety net.
        return torch.clamp(idx, min=0, max=self.n_bins - 1)
|
|
|
|
def standardize_y_per_task(
    y_ctx: torch.Tensor,
    y_query: Optional[torch.Tensor] = None,
    std_clip: float = 1e-3,
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
    """Z-score targets per task using statistics from the context rows only.

    Mean and (population) standard deviation are taken over the last
    dimension of ``y_ctx``. The std is floored at ``std_clip`` before the
    division so (near-)constant contexts do not divide by zero. Query targets
    are transformed with the same context statistics.

    Args:
        y_ctx: float32 context targets; statistics are reduced over dim -1.
        y_query: optional query targets, broadcastable against the per-task
            statistics (same leading dims as ``y_ctx``).
        std_clip: lower bound applied to the standard deviation.

    Returns:
        ``(y_ctx_std, y_query_std, mean, std)`` where ``y_query_std`` is
        ``None`` when ``y_query`` is ``None``, and ``mean`` / ``std`` have the
        reduced last dimension squeezed out. ``std`` is the clipped value
        actually used for the division.

    Raises:
        TypeError: if ``y_ctx`` is not float32.
    """
    # Explicit raise rather than `assert`, so the guard survives `python -O`.
    if y_ctx.dtype != torch.float32:
        raise TypeError(
            f"y must be float32 for stable z-scoring, got {y_ctx.dtype}"
        )
    mean = y_ctx.mean(dim=-1, keepdim=True)
    std = y_ctx.std(dim=-1, keepdim=True, unbiased=False)
    std_clipped = torch.clamp(std, min=std_clip)
    y_ctx_std = (y_ctx - mean) / std_clipped
    y_q_std = None if y_query is None else (y_query - mean) / std_clipped
    return y_ctx_std, y_q_std, mean.squeeze(-1), std_clipped.squeeze(-1)
|
|
|
|
def bar_distribution_loss(
    logits: torch.Tensor,
    y_standardized: torch.Tensor,
    head: BarDistributionHead,
    label_smoothing: float = 0.0,
    row_mask: Optional[torch.Tensor] = None,
    reduction: str = "mean",
) -> torch.Tensor:
    """Cross-entropy over n_bins with optional row-mask for padded query rows.

    Args:
        logits: [..., n_bins]
        y_standardized: [...] standardized targets (per-task z-scored)
        head: BarDistributionHead — supplies ``n_bins`` and the target→bin
            mapping (``predict_bin_ids``)
        label_smoothing: forwarded to ``F.cross_entropy``
        row_mask: optional bool mask shaped like ``y_standardized``,
            True = padded (excluded from loss)
        reduction: "mean" returns a scalar averaged over all kept rows;
            "none" returns the mean over the last dimension (per-task means)

    Returns:
        Scalar tensor for "mean"; tensor of shape ``y_standardized.shape[:-1]``
        for "none". A slice with no kept rows contributes 0 (the denominator is
        clamped to 1).

    Raises:
        ValueError: if ``reduction`` is neither "mean" nor "none".
    """
    if reduction not in ("mean", "none"):
        raise ValueError(f"Unknown reduction: {reduction}")

    bin_ids = head.predict_bin_ids(y_standardized)
    per_token = F.cross_entropy(
        logits.reshape(-1, head.n_bins),
        bin_ids.reshape(-1),
        label_smoothing=label_smoothing,
        reduction="none",
    ).reshape(*y_standardized.shape)

    if row_mask is not None:
        padded = row_mask.bool()
        keep = (~padded).float()
        # Zero padded rows explicitly: multiplying by 0 alone would turn a
        # NaN/inf loss on a padded row into NaN and poison the whole sum.
        per_token = per_token.masked_fill(padded, 0.0)
    else:
        keep = torch.ones_like(per_token)

    weighted = per_token * keep
    if reduction == "none":
        denom = keep.sum(dim=-1).clamp(min=1)
        return weighted.sum(dim=-1) / denom
    return weighted.sum() / keep.sum().clamp(min=1)
|
|
|
|
def decode_bar_distribution(
    logits: torch.Tensor,
    head: BarDistributionHead,
    mode: str = "mean",
    quantile: float = 0.5,
    y_mean: Optional[torch.Tensor] = None,
    y_std: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Turn bar-distribution logits into point predictions.

    Args:
        logits: [..., n_bins] unnormalized bin scores.
        head: provides ``bin_centers`` and ``n_bins``.
        mode: "mean" (probability-weighted bin centers), "median", or
            "quantile" (center of the first bin whose cumulative
            probability reaches the requested level).
        quantile: level used when ``mode == "quantile"``; ignored otherwise.
        y_mean, y_std: per-task statistics. When BOTH are given the
            prediction is mapped back to the original y scale; otherwise it
            stays in standardized units.

    Returns:
        Tensor of shape ``logits.shape[:-1]``.

    Raises:
        ValueError: for an unrecognized ``mode``.
    """
    bin_probs = F.softmax(logits, dim=-1)
    bin_centers = head.bin_centers.to(logits.device)

    if mode == "mean":
        z_pred = (bin_probs * bin_centers).sum(dim=-1)
    elif mode == "median" or mode == "quantile":
        level = quantile if mode == "quantile" else 0.5
        cumulative = bin_probs.cumsum(dim=-1)
        # One query value per distribution, shaped [..., 1] for searchsorted.
        query = torch.full_like(cumulative[..., :1], level)
        bin_idx = torch.searchsorted(cumulative, query).squeeze(-1)
        # Rounding can leave the final CDF value just below `level`.
        bin_idx = torch.clamp(bin_idx, 0, head.n_bins - 1)
        z_pred = bin_centers[bin_idx]
    else:
        raise ValueError(f"Unknown mode: {mode}")

    # De-standardize only when both statistics were supplied.
    if y_mean is None or y_std is None:
        return z_pred
    if y_mean.dim() != z_pred.dim():
        # Per-task stats: add a trailing axis to broadcast over query rows.
        y_mean = y_mean.unsqueeze(-1)
        y_std = y_std.unsqueeze(-1)
    return z_pred * y_std + y_mean
|
|
|
|
def predict_variance(
    logits: torch.Tensor,
    head: BarDistributionHead,
    y_std: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Variance of the predicted bar distribution (for coverage / calibration).

    Each bin's mass is treated as sitting at its center, and the variance is
    the probability-weighted squared deviation of the centers from the
    predicted mean.

    Args:
        logits: [..., n_bins] unnormalized bin scores.
        head: provides ``bin_centers``.
        y_std: optional per-task standard deviation; when given, the variance
            is rescaled to original y units by multiplying with ``y_std**2``.

    Returns:
        Tensor of shape ``logits.shape[:-1]``.
    """
    bin_probs = F.softmax(logits, dim=-1)
    bin_centers = head.bin_centers.to(logits.device)

    expected = (bin_probs * bin_centers).sum(dim=-1, keepdim=True)
    variance = (bin_probs * (bin_centers - expected) ** 2).sum(dim=-1)

    if y_std is None:
        return variance
    if y_std.dim() != variance.dim():
        # Per-task std: add a trailing axis to broadcast over query rows.
        y_std = y_std.unsqueeze(-1)
    return variance * y_std * y_std
|
|
|
|
| |
|
|
|
|
class BinClassificationHead(nn.Module):
    """
    Classification head producing a fixed-width vector of class logits.

    The head always emits ``max_classes`` logits (MAX_CLASSES by default).
    Logits at index >= task.n_classes are meant to be masked out downstream
    before the cross-entropy is computed.

    Architecture: Linear → GELU → (Dropout | Identity) → Linear, the same
    2-layer MLP shape as the bar-distribution head, with ``max_classes``
    outputs instead of ``n_bins``.

    Forward: x [..., d_model] → logits [..., max_classes].
    """

    def __init__(
        self,
        d_model: int,
        max_classes: int = MAX_CLASSES,
        hidden_multiplier: int = 2,
        dropout: float = 0.0,
    ):
        """Create the MLP.

        Args:
            d_model: width of the incoming trunk features.
            max_classes: number of logits emitted per query row.
            hidden_multiplier: hidden width is ``d_model * hidden_multiplier``.
            dropout: dropout probability after the activation; 0 disables it.
        """
        super().__init__()
        self.d_model = d_model
        self.max_classes = max_classes

        width = d_model * hidden_multiplier
        # An Identity stands in for Dropout when disabled, so the layer
        # indices inside the Sequential do not depend on `dropout`.
        regularizer = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        layers = [
            nn.Linear(d_model, width),
            nn.GELU(),
            regularizer,
            nn.Linear(width, max_classes),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features [..., d_model] to class logits [..., max_classes]."""
        return self.mlp(x)
|
|
|
|
def cls_masked_loss(
    logits: torch.Tensor,
    y: torch.Tensor,
    n_classes: torch.Tensor,
    label_smoothing: float = 0.0,
    row_mask: Optional[torch.Tensor] = None,
    reduction: str = "mean",
) -> torch.Tensor:
    """
    Cross-entropy with per-task masking of unused class logits.

    Args:
        logits: [B, n_query, MAX_CLASSES]
        y: [B, n_query] integer labels in [0, n_classes_b). Labels on rows
            flagged by ``row_mask`` are ignored and may hold any value
            (e.g. a -1 padding id).
        n_classes: [B] integer count of valid classes per task in batch.
            NOTE(review): assumed >= 1 for every task; with 0 valid classes
            every logit is -inf and the log-softmax is undefined.
        label_smoothing: smoothing distributed ONLY across the valid class
            range per task. Naive `F.cross_entropy(label_smoothing=ls)` over
            logits-with-(-1e9)-on-invalid produces ls/C * 1e9 ≈ 5e6 per row
            for invalid classes; here we smooth only over valid classes so
            the invalid-class contribution is exactly zero.
        row_mask: [B, n_query] bool, True = padded row to skip in loss
        reduction: "mean" (scalar) or "none" (per-task tensor [B])

    Each batch entry's unused logits are set to -inf so softmax respects
    the per-task class count.

    Returns:
        Scalar tensor for "mean" (average over all kept rows) or a [B] tensor
        of per-task means for "none". A task with no kept rows yields 0.

    Raises:
        ValueError: if ``reduction`` is neither "mean" nor "none".
    """
    if reduction not in ("mean", "none"):
        raise ValueError(f"Unknown reduction: {reduction}")

    B, N, C = logits.shape
    device = logits.device

    # valid_mask[b, c] is True iff class c exists for task b.
    class_ids = torch.arange(C, device=device)[None, :]
    valid_mask = class_ids < n_classes.to(device)[:, None]
    valid_mask_full = valid_mask[:, None, :].expand(B, N, C)

    masked_logits = logits.masked_fill(~valid_mask_full, float("-inf"))
    log_probs = F.log_softmax(masked_logits, dim=-1)

    padded = None if row_mask is None else row_mask.bool()

    y_long = y.long()
    if padded is not None:
        # Padded rows may carry arbitrary labels. A negative id would break
        # gather, and an id >= n_classes would pick a -inf log-prob whose
        # +inf loss turns into NaN once multiplied by a zero weight. Point
        # them at class 0 instead; their loss is discarded below anyway.
        y_long = y_long.masked_fill(padded, 0)
    nll = -log_probs.gather(-1, y_long.unsqueeze(-1)).squeeze(-1)

    if label_smoothing > 0:
        valid_count = valid_mask.sum(dim=-1, keepdim=True).clamp(min=1).float()
        # Replace the -inf log-probs of invalid classes by 0 so they drop out
        # of the uniform-target term instead of producing inf.
        log_probs_valid_only = log_probs.masked_fill(~valid_mask_full, 0.0)
        mean_neg_log = -log_probs_valid_only.sum(dim=-1) / valid_count
        loss_per_row = (1.0 - label_smoothing) * nll + label_smoothing * mean_neg_log
    else:
        loss_per_row = nll

    if padded is not None:
        # Zero padded rows explicitly so a non-finite value there cannot leak
        # into the sum through `nan * 0`.
        loss_per_row = loss_per_row.masked_fill(padded, 0.0)
        keep = (~padded).float()
    else:
        keep = torch.ones_like(loss_per_row)

    weighted = loss_per_row * keep
    if reduction == "none":
        denom = keep.sum(dim=-1).clamp(min=1)
        return weighted.sum(dim=-1) / denom
    return weighted.sum() / keep.sum().clamp(min=1)
|
|
|
|
def cls_predict(
    logits: torch.Tensor,
    n_classes: torch.Tensor,
) -> torch.Tensor:
    """Argmax over the valid logit range per task. Returns [B, n_query].

    Args:
        logits: [B, n_query, C] floating-point class logits.
        n_classes: [B] number of valid classes per task; logits at index
            >= n_classes[b] are excluded for task b.

    Returns:
        Integer tensor [B, n_query] of predicted class ids.
    """
    B, N, C = logits.shape
    device = logits.device
    class_ids = torch.arange(C, device=device)[None, :]
    valid_mask = class_ids < n_classes.to(device)[:, None]
    valid_mask_full = valid_mask[:, None, :].expand(B, N, C)
    # Fill with the dtype's own finite minimum rather than a fixed -1e9:
    # -1e9 is not representable in float16 (masked_fill raises), and a fixed
    # constant could outrank valid logits that happen to lie below it.
    fill_value = torch.finfo(logits.dtype).min
    masked_logits = logits.masked_fill(~valid_mask_full, fill_value)
    return masked_logits.argmax(dim=-1)
|
|
|
|
def cls_probs(
    logits: torch.Tensor,
    n_classes: torch.Tensor,
) -> torch.Tensor:
    """Softmax over the valid logit range per task. Invalid classes → 0 prob.

    Args:
        logits: [B, n_query, C] floating-point class logits.
        n_classes: [B] number of valid classes per task; logits at index
            >= n_classes[b] are excluded for task b.

    Returns:
        Tensor [B, n_query, C] of probabilities summing to 1 over the last
        dimension.
    """
    B, N, C = logits.shape
    device = logits.device
    class_ids = torch.arange(C, device=device)[None, :]
    valid_mask = class_ids < n_classes.to(device)[:, None]
    valid_mask_full = valid_mask[:, None, :].expand(B, N, C)
    # Fill with the dtype's own finite minimum rather than a fixed -1e9:
    # -1e9 is not representable in float16 (masked_fill raises), and a fixed
    # constant would leave mass on invalid classes whenever the valid logits
    # lie below it.
    fill_value = torch.finfo(logits.dtype).min
    masked_logits = logits.masked_fill(~valid_mask_full, fill_value)
    return F.softmax(masked_logits, dim=-1)
|
|
|
|
| |
|
|
|
|
if __name__ == "__main__":
    # Smoke test for both heads on random data. Seeded so the printed numbers
    # are reproducible; the statement order matters because every module
    # construction and randn/randint call consumes RNG state.
    torch.manual_seed(0)

    # --- Regression: bar-distribution head ---------------------------------
    head_r = BarDistributionHead(d_model=256, n_bins=1024)
    trunk_out = torch.randn(2, 64, 256)  # stand-in for trunk output [B, n_query, d_model]
    logits_r = head_r(trunk_out)
    assert logits_r.shape == (2, 64, 1024)

    # Standardize query targets with context-only statistics.
    y_ctx = torch.randn(2, 256)
    y_q = torch.randn(2, 64)
    y_ctx_s, y_q_s, mu, sigma = standardize_y_per_task(y_ctx, y_q)
    assert y_ctx_s.shape == y_ctx.shape and y_q_s.shape == y_q.shape

    # Loss must be finite; decoding maps back to the original y scale.
    loss_r = bar_distribution_loss(logits_r, y_q_s, head_r)
    assert torch.isfinite(loss_r).item()
    pred_mean = decode_bar_distribution(logits_r, head_r, mode="mean", y_mean=mu, y_std=sigma)
    assert pred_mean.shape == (2, 64)

    var_pred = predict_variance(logits_r, head_r, y_std=sigma)
    assert var_pred.shape == (2, 64)
    print(f"[reg] logits {tuple(logits_r.shape)} loss={loss_r.item():.4f} pred_mean[0,0]={pred_mean[0,0].item():+.3f}")

    # --- Classification: masked fixed-width head ----------------------------
    head_c = BinClassificationHead(d_model=256, max_classes=10)
    logits_c = head_c(trunk_out)
    assert logits_c.shape == (2, 64, 10)

    # Two tasks with different class counts (3 and 7) share one batch.
    n_classes = torch.tensor([3, 7])
    y_c = torch.stack([
        torch.randint(0, 3, (64,)),
        torch.randint(0, 7, (64,)),
    ])
    loss_c = cls_masked_loss(logits_c, y_c, n_classes)
    assert torch.isfinite(loss_c).item()
    preds = cls_predict(logits_c, n_classes)
    probs = cls_probs(logits_c, n_classes)
    # Classes beyond each task's count must receive exactly zero probability.
    assert (probs[0, :, 3:] == 0.0).all(), "task 0 should have 0 prob on classes >= 3"
    assert (probs[1, :, 7:] == 0.0).all(), "task 1 should have 0 prob on classes >= 7"
    # Predictions must stay inside each task's valid class range.
    assert (preds[0] < 3).all() and (preds[1] < 7).all()
    # Probabilities over the valid classes still sum to 1.
    sums = probs.sum(dim=-1)
    assert torch.allclose(sums, torch.ones_like(sums), atol=1e-5)
    print(f"[cls] logits {tuple(logits_c.shape)} loss={loss_c.item():.4f} "
          f"preds[0]={preds[0,:5].tolist()} preds[1]={preds[1,:5].tolist()}")

    print("[OK] heads self-test passed")
|
|