"""
v11 prediction heads: bar-distribution regression + bin-based classification.

Both heads are ICL-friendly: they take trunk output [B, n_query, d_model] and
produce per-query predictions over a fixed-size output space (1024 bins for
regression, MAX_CLASSES=10 logits for classification).

## Why bar-dist for regression
Verified in v10: a single-Gaussian (μ, log_σ²) head can collapse to a constant
when the trunk's output drifts (v9 failure mode). Bar-dist's 1024-bin
cross-entropy can't collapse — every bin is independently supervised. Also
matches TabPFN v2's reg head exactly.

## Why bin-based for classification
v8/v10 used Linear(d_model, n_classes_max). For v11 we keep that structure but
add per-task masking: a task with n_classes=3 only computes CE over the first
3 logits. This avoids the per-task linear-head trick used by TabPFN (where the
head is built from class prototypes inside each task) which is harder to fit
and gives no measurable gain at this scale per Expert 4's pre-mortem on v11.

## Trunk interface contract
The trunk returns one tensor per task type:
    reg_out: [B, n_query, d_model] - last column, query rows, after reg trunk layers
    cls_out: [B, n_query, d_model] - last column, query rows, after cls trunk layers
Both heads take this shape and produce per-query outputs.
"""

from __future__ import annotations

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# Shared constants — keep aligned with task_sampler.SCMConfig.max_classes
MAX_CLASSES: int = 10


# ─── 0. shared private helpers ───────────────────────────────────────────────

def _reduce_masked_rows(
    per_row: torch.Tensor,
    row_mask: Optional[torch.Tensor],
    reduction: str,
) -> torch.Tensor:
    """Average per-row losses while skipping padded rows.

    Padded rows are dropped with ``torch.where`` rather than multiplied by a
    0/1 weight: ``inf * 0`` and ``nan * 0`` are NaN, so a single non-finite
    padded row would otherwise poison the whole reduction.

    Args:
        per_row: per-row loss values, last dim = query rows.
        row_mask: optional mask broadcastable to ``per_row``; truthy = padded.
        reduction: "none" averages over the last dim only (one value per
            task); any other value yields the scalar mean over all kept rows.

    Returns:
        Reduced loss. The denominator is clamped to >= 1, so a fully padded
        task (or batch) yields 0 instead of dividing by zero.
    """
    if row_mask is None:
        keep = torch.ones_like(per_row, dtype=torch.bool)
    else:
        keep = ~row_mask.to(device=per_row.device, dtype=torch.bool)
        keep = keep.expand_as(per_row)
    kept = torch.where(keep, per_row, torch.zeros_like(per_row))

    if reduction == "none":
        denom = keep.sum(dim=-1).clamp(min=1).to(torch.float32)
        return kept.sum(dim=-1) / denom
    denom = keep.sum().clamp(min=1).to(torch.float32)
    return kept.sum() / denom


def _valid_class_mask(logits: torch.Tensor, n_classes: torch.Tensor) -> torch.Tensor:
    """Return a [B, N, C] bool mask that is True where class index < n_classes[b].

    ``n_classes`` is moved to the logits' device so a CPU count tensor can be
    used with accelerator logits.
    """
    B, N, C = logits.shape
    class_ids = torch.arange(C, device=logits.device)[None, :]          # [1, C]
    per_task = class_ids < n_classes.to(logits.device)[:, None]         # [B, C]
    return per_task[:, None, :].expand(B, N, C)


# ─── 1. bar-distribution regression head (proven in v10) ─────────────────────

def default_bin_edges(n_bins: int = 1024, tail: float = 0.0001) -> torch.Tensor:
    """Quantile-based bin edges on N(0,1), symmetric around 0.

    With n_bins=1024 + tail=0.0001, outer bins cover N⁻¹(0.0001) ≈ -3.72 to
    N⁻¹(0.9999) ≈ +3.72 — wide enough to keep the heavy-tailed targets that
    v11's `apply_heavy_tail_noise` extension is supposed to be teaching from
    saturating the outermost bins. Earlier (tail=0.001) capped at ±3.09, which
    forced ~0.5% of any heavy-tailed task's targets into the outermost two
    bins (each ≈3σ wide) where CE has no resolution.

    Args:
        n_bins: number of bins; the result has ``n_bins + 1`` edges.
        tail: probability mass cut off on each side; must satisfy
            ``0 < tail < 0.5``.

    Returns:
        1-D tensor of ``n_bins + 1`` increasing edges.

    Raises:
        ValueError: if ``n_bins < 1`` or ``tail`` is outside (0, 0.5).
            ``tail <= 0`` would give infinite outer edges (erfinv(±1)) and
            ``tail >= 0.5`` would give non-increasing edges.
    """
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")
    if not 0.0 < tail < 0.5:
        raise ValueError(f"tail must be in (0, 0.5), got {tail}")
    probs = torch.linspace(tail, 1.0 - tail, n_bins + 1)
    # Inverse normal CDF: Φ⁻¹(p) = √2 · erfinv(2p − 1)
    edges = math.sqrt(2) * torch.erfinv(2 * probs - 1)
    return edges


class BarDistributionHead(nn.Module):
    """
    Bar-distribution (Riemann) regression head.

    Forward: x [..., d_model] → logits [..., n_bins].
    Loss is CE between predicted bin distribution and the bin containing the
    (per-task standardized) target.

    Buffers:
        bin_edges:   [n_bins + 1] float32 edges.
        bin_centers: [n_bins] float32 midpoints of consecutive edges.
    """

    def __init__(
        self,
        d_model: int,
        n_bins: int = 1024,
        hidden_multiplier: int = 2,
        dropout: float = 0.0,
        bin_edges: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_bins = n_bins

        if bin_edges is None:
            bin_edges = default_bin_edges(n_bins)
        assert bin_edges.shape == (n_bins + 1,)
        self.register_buffer("bin_edges", bin_edges.float())
        centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
        self.register_buffer("bin_centers", centers.float())

        hidden = d_model * hidden_multiplier
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
            nn.Linear(hidden, n_bins),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map trunk features [..., d_model] to bin logits [..., n_bins]."""
        return self.mlp(x)

    def predict_bin_ids(self, y_standardized: torch.Tensor) -> torch.Tensor:
        """Map standardized targets to the index of the bin that contains them.

        Only the interior edges are used as boundaries, so values below the
        first edge land in bin 0 and values above the last edge land in bin
        ``n_bins - 1``; the clamp guarantees the result is a valid CE target.
        """
        interior = self.bin_edges[1:-1].to(y_standardized.device)
        idx = torch.bucketize(y_standardized, interior)
        return torch.clamp(idx, min=0, max=self.n_bins - 1)


def standardize_y_per_task(
    y_ctx: torch.Tensor,
    y_query: Optional[torch.Tensor] = None,
    std_clip: float = 1e-3,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
    """Per-task z-score using context-only stats, clipping std before division.

    Args:
        y_ctx: float32 context targets; statistics are taken over the last dim.
        y_query: optional query targets, standardized with the context stats.
        std_clip: lower bound applied to the (population) std so constant
            targets do not cause a division by ~0.

    Returns:
        ``(y_ctx_std, y_query_std_or_None, mean, std_clipped)`` where ``mean``
        and ``std_clipped`` have the last dim squeezed away.
    """
    assert y_ctx.dtype == torch.float32, "y must be float32 for stable z-scoring"
    mean = y_ctx.mean(dim=-1, keepdim=True)
    std = y_ctx.std(dim=-1, keepdim=True, unbiased=False)
    std_clipped = torch.clamp(std, min=std_clip)
    y_ctx_std = (y_ctx - mean) / std_clipped
    y_q_std = None if y_query is None else (y_query - mean) / std_clipped
    return y_ctx_std, y_q_std, mean.squeeze(-1), std_clipped.squeeze(-1)


def bar_distribution_loss(
    logits: torch.Tensor,
    y_standardized: torch.Tensor,
    head: BarDistributionHead,
    label_smoothing: float = 0.0,
    row_mask: Optional[torch.Tensor] = None,
    reduction: str = "mean",
) -> torch.Tensor:
    """Cross-entropy over n_bins with optional row-mask for padded query rows.

    Args:
        logits:         [..., n_bins]
        y_standardized: [...] standardized targets (per-task z-scored)
        head:           BarDistributionHead — needed for its bin structure
        label_smoothing: forwarded to ``F.cross_entropy``
        row_mask:       optional bool mask, True = padded (excluded from loss)
        reduction:      "mean" returns a scalar; "none" returns per-task means [B].
                        Any other value is treated as "mean".

    Returns:
        Scalar loss, or per-task means when ``reduction == "none"``.
    """
    bin_ids = head.predict_bin_ids(y_standardized)
    per_token = F.cross_entropy(
        logits.reshape(-1, head.n_bins),
        bin_ids.reshape(-1),
        label_smoothing=label_smoothing,
        reduction="none",
    ).reshape(y_standardized.shape)
    return _reduce_masked_rows(per_token, row_mask, reduction)


def decode_bar_distribution(
    logits: torch.Tensor,
    head: BarDistributionHead,
    mode: str = "mean",
    quantile: float = 0.5,
    y_mean: Optional[torch.Tensor] = None,
    y_std: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Decode bar-dist logits to point predictions in original y space.

    Args:
        logits: [..., n_bins]
        head: supplies ``bin_centers`` and ``n_bins``.
        mode: "mean" (probability-weighted bin centre), "median", or
            "quantile" (centre of the first bin whose CDF reaches the level).
        quantile: level used when ``mode == "quantile"``; must be in [0, 1].
        y_mean, y_std: per-task stats used to undo standardization. The
            de-standardization is applied only when BOTH are given; otherwise
            the prediction is returned in standardized space.

    Raises:
        ValueError: on an unknown ``mode`` or a quantile level outside [0, 1].
    """
    probs = F.softmax(logits, dim=-1)
    centers = head.bin_centers.to(logits.device)

    if mode == "mean":
        pred_std = (probs * centers).sum(dim=-1)
    elif mode in ("median", "quantile"):
        q = 0.5 if mode == "median" else quantile
        if not 0.0 <= q <= 1.0:
            raise ValueError(f"quantile must be in [0, 1], got {q}")
        cdf = probs.cumsum(dim=-1)
        idx = torch.searchsorted(cdf, torch.full_like(cdf[..., :1], q)).squeeze(-1)
        # Float round-off can leave cdf[..., -1] slightly below q, which makes
        # searchsorted return n_bins; clamp back onto the last bin.
        idx = torch.clamp(idx, 0, head.n_bins - 1)
        pred_std = centers[idx]
    else:
        raise ValueError(f"Unknown mode: {mode}")

    if y_mean is not None and y_std is not None:
        if y_mean.dim() != pred_std.dim():
            # Per-task stats [B] → [B, 1] so they broadcast over query rows.
            y_mean = y_mean.unsqueeze(-1)
            y_std = y_std.unsqueeze(-1)
        return pred_std * y_std + y_mean
    return pred_std


def predict_variance(
    logits: torch.Tensor,
    head: BarDistributionHead,
    y_std: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Predictive variance from the bar distribution (for coverage / calibration).

    Computes the variance of the discrete distribution over bin centres; when
    ``y_std`` is given the result is rescaled by ``y_std ** 2`` back to the
    original target scale.
    """
    probs = F.softmax(logits, dim=-1)
    centers = head.bin_centers.to(logits.device)
    mean = (probs * centers).sum(dim=-1, keepdim=True)
    var_std = (probs * (centers - mean) ** 2).sum(dim=-1)
    if y_std is not None:
        if y_std.dim() != var_std.dim():
            y_std = y_std.unsqueeze(-1)
        return var_std * y_std * y_std
    return var_std


# ─── 2. bin-based classification head (variable n_classes per task) ──────────

class BinClassificationHead(nn.Module):
    """
    Classification head that emits MAX_CLASSES logits; the trainer masks out
    logits ≥ task.n_classes before computing CE.

    Architecture: same 2-layer MLP as the bar-dist head, but output is over
    MAX_CLASSES (default 10) instead of n_bins.

    Forward: x [..., d_model] → logits [..., MAX_CLASSES].
    """

    def __init__(
        self,
        d_model: int,
        max_classes: int = MAX_CLASSES,
        hidden_multiplier: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.max_classes = max_classes
        hidden = d_model * hidden_multiplier
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
            nn.Linear(hidden, max_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map trunk features [..., d_model] to class logits [..., max_classes]."""
        return self.mlp(x)


def cls_masked_loss(
    logits: torch.Tensor,
    y: torch.Tensor,
    n_classes: torch.Tensor,
    label_smoothing: float = 0.0,
    row_mask: Optional[torch.Tensor] = None,
    reduction: str = "mean",
) -> torch.Tensor:
    """
    Cross-entropy with per-task masking of unused class logits.

    Args:
        logits: [B, n_query, MAX_CLASSES]
        y: [B, n_query] integer labels in [0, n_classes_b). Labels of rows
           flagged in `row_mask` are ignored and may hold any value
           (including negative "ignore" markers).
        n_classes: [B] integer count of valid classes per task in batch
        label_smoothing: smoothing distributed ONLY across the valid class range
            per task. Naive `F.cross_entropy(label_smoothing=ls)` over
            logits-with-(-1e9)-on-invalid produces a huge per-row term from the
            invalid classes; here we smooth only over valid classes so the
            invalid-class contribution is exactly zero.
        row_mask: [B, n_query] bool, True = padded row to skip in loss
        reduction: "mean" (scalar) or "none" (per-task tensor [B]). Any other
            value is treated as "mean".

    Each batch entry's unused logits are set to -inf so softmax respects the
    per-task class count.
    """
    valid = _valid_class_mask(logits, n_classes)                        # [B, N, C]

    # Mask invalid logits and compute log_softmax over the valid range only.
    masked_logits = logits.masked_fill(~valid, float("-inf"))
    log_probs = F.log_softmax(masked_logits, dim=-1)                    # [B, N, C]

    y_long = y.long()
    if row_mask is not None:
        # Padded rows may carry arbitrary labels. An out-of-range one would hit
        # a -inf log-prob (→ inf loss → NaN once masked) and a negative one
        # would make gather fail, so point them at class 0; the row is dropped
        # from the reduction anyway.
        padded = row_mask.to(device=y_long.device, dtype=torch.bool)
        y_long = y_long.masked_fill(padded, 0)
    nll = -log_probs.gather(-1, y_long.unsqueeze(-1)).squeeze(-1)       # [B, N]

    if label_smoothing > 0:
        # Smoothed loss = (1-ls) * NLL + ls * mean_over_valid_classes(-log_probs),
        # i.e. target_dist[c valid] = (1-ls)*[c==y] + ls/n_valid.
        n_valid = valid.sum(dim=-1).clamp(min=1).float()                # [B, N]
        # Invalid entries hold -inf; zero them so the sum stays finite.
        finite_log_probs = log_probs.masked_fill(~valid, 0.0)
        mean_neg_log = -finite_log_probs.sum(dim=-1) / n_valid          # [B, N]
        loss_per_row = (1.0 - label_smoothing) * nll + label_smoothing * mean_neg_log
    else:
        loss_per_row = nll

    return _reduce_masked_rows(loss_per_row, row_mask, reduction)


def cls_predict(
    logits: torch.Tensor,
    n_classes: torch.Tensor,
) -> torch.Tensor:
    """Argmax over the valid logit range per task. Returns [B, n_query]."""
    valid = _valid_class_mask(logits, n_classes)
    # Most negative finite value of the logits' own dtype: always loses the
    # argmax and, unlike a literal -1e9, is representable in fp16.
    fill = torch.finfo(logits.dtype).min
    return logits.masked_fill(~valid, fill).argmax(dim=-1)


def cls_probs(
    logits: torch.Tensor,
    n_classes: torch.Tensor,
) -> torch.Tensor:
    """Softmax over the valid logit range per task. Invalid classes → 0 prob."""
    valid = _valid_class_mask(logits, n_classes)
    # A finite fill (rather than -inf) keeps the softmax NaN-free even if a
    # task has no valid class; dtype-aware so it does not overflow in fp16.
    fill = torch.finfo(logits.dtype).min
    return F.softmax(logits.masked_fill(~valid, fill), dim=-1)


# ─── 3. self-test: shapes, masking, decoding all roundtrip ───────────────────

if __name__ == "__main__":
    torch.manual_seed(0)

    # Reg head smoke
    head_r = BarDistributionHead(d_model=256, n_bins=1024)
    trunk_out = torch.randn(2, 64, 256)  # [B=2, n_query=64, d_model=256]
    logits_r = head_r(trunk_out)
    assert logits_r.shape == (2, 64, 1024)

    y_ctx = torch.randn(2, 256)  # [B, n_ctx]
    y_q = torch.randn(2, 64)
    y_ctx_s, y_q_s, mu, sigma = standardize_y_per_task(y_ctx, y_q)
    assert y_ctx_s.shape == y_ctx.shape and y_q_s.shape == y_q.shape

    loss_r = bar_distribution_loss(logits_r, y_q_s, head_r)
    assert torch.isfinite(loss_r).item()

    pred_mean = decode_bar_distribution(logits_r, head_r, mode="mean", y_mean=mu, y_std=sigma)
    assert pred_mean.shape == (2, 64)
    var_pred = predict_variance(logits_r, head_r, y_std=sigma)
    assert var_pred.shape == (2, 64)
    print(f"[reg] logits {tuple(logits_r.shape)}  loss={loss_r.item():.4f}  pred_mean[0,0]={pred_mean[0,0].item():+.3f}")

    # Cls head smoke
    head_c = BinClassificationHead(d_model=256, max_classes=10)
    logits_c = head_c(trunk_out)
    assert logits_c.shape == (2, 64, 10)

    # Task 0: 3-class, Task 1: 7-class
    n_classes = torch.tensor([3, 7])
    y_c = torch.stack([
        torch.randint(0, 3, (64,)),
        torch.randint(0, 7, (64,)),
    ])
    loss_c = cls_masked_loss(logits_c, y_c, n_classes)
    assert torch.isfinite(loss_c).item()

    # Padded rows with out-of-range labels must not poison the loss
    pad_mask = torch.zeros(2, 64, dtype=torch.bool)
    pad_mask[:, -8:] = True
    y_pad = y_c.clone()
    y_pad[:, -8:] = 9
    loss_pad = cls_masked_loss(logits_c, y_pad, n_classes, label_smoothing=0.1, row_mask=pad_mask)
    assert torch.isfinite(loss_pad).item()

    preds = cls_predict(logits_c, n_classes)
    probs = cls_probs(logits_c, n_classes)
    # Verify masking: invalid classes have 0 probability
    assert (probs[0, :, 3:] == 0.0).all(), "task 0 should have 0 prob on classes >= 3"
    assert (probs[1, :, 7:] == 0.0).all(), "task 1 should have 0 prob on classes >= 7"
    # Verify predictions stay within valid range
    assert (preds[0] < 3).all() and (preds[1] < 7).all()
    # Verify softmax sums to 1 over valid logits
    sums = probs.sum(dim=-1)
    assert torch.allclose(sums, torch.ones_like(sums), atol=1e-5)
    print(f"[cls] logits {tuple(logits_c.shape)}  loss={loss_c.item():.4f}  "
          f"preds[0]={preds[0,:5].tolist()}  preds[1]={preds[1,:5].tolist()}")

    print("[OK] heads self-test passed")