# 01RAI's picture
# PredictLM v11.0 + Mini ship-bundle
# 4ea7152 verified
"""
v11 prediction heads: bar-distribution regression + bin-based classification.
Both heads are ICL-friendly: they take trunk output [B, n_query, d_model]
and produce per-query predictions over a fixed-size output space (1024 bins
for regression, MAX_CLASSES=10 logits for classification).
## Why bar-dist for regression
Verified in v10: a single-Gaussian (μ, log_σ²) head can collapse to a
constant when the trunk's output drifts (v9 failure mode). Bar-dist's
1024-bin cross-entropy can't collapse — every bin is independently
supervised. Also matches TabPFN v2's reg head exactly.
## Why bin-based for classification
v8/v10 used Linear(d_model, n_classes_max). For v11 we keep that structure
but add per-task masking: a task with n_classes=3 only computes CE over
the first 3 logits. This avoids the per-task linear-head trick used by
TabPFN (where the head is built from class prototypes inside each task),
which is harder to fit and, per Expert 4's pre-mortem on v11, gives no
measurable gain at this scale.
## Trunk interface contract
The trunk returns one tensor per task type:
reg_out: [B, n_query, d_model] - last column, query rows, after reg trunk layers
cls_out: [B, n_query, d_model] - last column, query rows, after cls trunk layers
Both heads take this shape and produce per-query outputs.
"""
from __future__ import annotations
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# Upper bound on the number of classes a single task may use; this is the
# width of the classification head's output. Must stay equal to
# task_sampler.SCMConfig.max_classes.
MAX_CLASSES: int = 10
# ─── 1. bar-distribution regression head (proven in v10) ─────────────────────
def default_bin_edges(n_bins: int = 1024, tail: float = 0.0001) -> torch.Tensor:
    """Return quantile-based bin edges on N(0, 1), symmetric around 0.

    The ``n_bins + 1`` edges are the standard-normal quantiles of
    ``n_bins + 1`` probabilities spaced evenly on ``[tail, 1 - tail]``, so
    every bin carries the same N(0, 1) probability mass.

    With n_bins=1024 and tail=0.0001 the outer edges sit at
    Φ⁻¹(0.0001) ≈ -3.72 and Φ⁻¹(0.9999) ≈ +3.72. That range is wide enough
    to keep the heavy-tailed targets that v11's `apply_heavy_tail_noise`
    extension is meant to teach from saturating the outermost bins. The
    earlier default (tail=0.001) capped the edges at ±3.09, which forced
    ~0.5% of any heavy-tailed task's targets into the two outermost bins,
    where CE has no resolution.

    Args:
        n_bins: number of bins; must be >= 1.
        tail: probability mass left outside the outer edges on each side.
            Must satisfy 0 < tail < 0.5. ``tail=0`` would place the outer
            edges at ±inf (erfinv(±1)), making the outer bin centers infinite.

    Returns:
        1-D tensor of shape ``[n_bins + 1]`` (default float dtype), increasing.

    Raises:
        ValueError: if ``n_bins`` < 1 or ``tail`` is not in (0, 0.5).
    """
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")
    if not 0.0 < tail < 0.5:
        raise ValueError(f"tail must lie in (0, 0.5), got {tail}")
    probs = torch.linspace(tail, 1.0 - tail, n_bins + 1)
    # Standard-normal inverse CDF: Φ⁻¹(p) = √2 · erfinv(2p − 1)
    return math.sqrt(2) * torch.erfinv(2 * probs - 1)
class BarDistributionHead(nn.Module):
    """
    Bar-distribution (Riemann) regression head.

    Forward: x [..., d_model] → logits [..., n_bins].
    Loss is CE between the predicted bin distribution and the bin containing
    the (per-task standardized) target; see ``predict_bin_ids``.

    Args:
        d_model: feature width of the trunk output.
        n_bins: number of bins / output logits.
        hidden_multiplier: the MLP hidden width is ``d_model * hidden_multiplier``.
        dropout: dropout probability between the two Linear layers. When it is
            0, a parameter-free ``nn.Identity`` takes the slot, so the ``mlp``
            indices (and state_dict keys) are the same either way.
        bin_edges: optional ``n_bins + 1`` strictly increasing edges (tensor or
            array-like); defaults to ``default_bin_edges(n_bins)``.

    Buffers:
        bin_edges: float32 ``[n_bins + 1]``, a private copy of the edges.
        bin_centers: float32 ``[n_bins]``, midpoints of adjacent edges.

    Raises:
        ValueError: if ``bin_edges`` does not have shape ``(n_bins + 1,)`` or
            is not strictly increasing.
    """

    def __init__(
        self,
        d_model: int,
        n_bins: int = 1024,
        hidden_multiplier: int = 2,
        dropout: float = 0.0,
        bin_edges: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_bins = n_bins
        if bin_edges is None:
            bin_edges = default_bin_edges(n_bins)
        edges = torch.as_tensor(bin_edges).detach()
        if edges.shape != (n_bins + 1,):
            raise ValueError(
                f"bin_edges must have shape ({n_bins + 1},), got {tuple(edges.shape)}"
            )
        # torch.bucketize needs strictly increasing boundaries; unsorted edges
        # would silently yield meaningless bin targets.
        if not bool((edges[1:] > edges[:-1]).all()):
            raise ValueError("bin_edges must be strictly increasing")
        centers = 0.5 * (edges[:-1] + edges[1:])
        # .clone(): `.float()` returns the caller's own tensor when it is already
        # float32, so without a copy the buffer would track later in-place edits.
        self.register_buffer("bin_edges", edges.float().clone())
        self.register_buffer("bin_centers", centers.float())
        hidden = d_model * hidden_multiplier
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
            nn.Linear(hidden, n_bins),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map trunk features ``[..., d_model]`` to bin logits ``[..., n_bins]``."""
        return self.mlp(x)

    def predict_bin_ids(self, y_standardized: torch.Tensor) -> torch.Tensor:
        """Return the index of the bin that contains each standardized target.

        Only the interior edges are passed to ``bucketize``. Values below the
        first interior edge therefore map to bin 0, and values above the last
        one map to bin ``n_bins - 1``. Out-of-range targets are absorbed by the
        outer bins rather than rejected.
        """
        interior = self.bin_edges[1:-1].to(y_standardized.device)
        idx = torch.bucketize(y_standardized, interior)
        return torch.clamp(idx, min=0, max=self.n_bins - 1)
def standardize_y_per_task(
    y_ctx: torch.Tensor,
    y_query: Optional[torch.Tensor] = None,
    std_clip: float = 1e-3,
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
    """Per-task z-score using context-only stats, clipping std before division.

    The mean and population std (``unbiased=False``) are taken over the last
    dim of ``y_ctx``. The std is clamped to at least ``std_clip`` so that
    constant-target tasks do not divide by ~0. Query targets are standardized
    with the same context statistics, so no query values enter the stats.

    Args:
        y_ctx: float32 context targets, reduced over the last dim
            (``[B, n_ctx]`` in this module's self-test).
        y_query: optional query targets, broadcastable against the
            ``[..., 1]`` stats (``[B, n_query]`` in the self-test).
        std_clip: lower bound applied to the std.

    Returns:
        ``(y_ctx_std, y_query_std or None, mean, std_clipped)``; ``mean`` and
        ``std_clipped`` have the reduced last dim squeezed (``[B]`` for 2-D input).

    Raises:
        TypeError: if ``y_ctx`` is not float32. This is an explicit check
            rather than an ``assert``, so it survives ``python -O``.
    """
    if y_ctx.dtype != torch.float32:
        raise TypeError(f"y must be float32 for stable z-scoring, got {y_ctx.dtype}")
    mean = y_ctx.mean(dim=-1, keepdim=True)
    std = y_ctx.std(dim=-1, keepdim=True, unbiased=False)
    std_clipped = torch.clamp(std, min=std_clip)
    y_ctx_std = (y_ctx - mean) / std_clipped
    y_q_std = None if y_query is None else (y_query - mean) / std_clipped
    return y_ctx_std, y_q_std, mean.squeeze(-1), std_clipped.squeeze(-1)
def bar_distribution_loss(
    logits: torch.Tensor,
    y_standardized: torch.Tensor,
    head: BarDistributionHead,
    label_smoothing: float = 0.0,
    row_mask: Optional[torch.Tensor] = None,
    reduction: str = "mean",
) -> torch.Tensor:
    """Cross-entropy over n_bins with an optional row mask for padded query rows.

    Each standardized target is mapped to its bin via ``head.predict_bin_ids``
    and scored with per-token cross-entropy against ``logits``.

    Args:
        logits: [..., n_bins]
        y_standardized: [...] standardized targets (per-task z-scored)
        head: BarDistributionHead, which supplies ``n_bins`` and the bin lookup
        label_smoothing: forwarded to ``F.cross_entropy``
        row_mask: optional mask shaped like ``y_standardized``; True / nonzero
            = padded (excluded from loss). Bool and 0/1 integer masks both work.
        reduction: "mean" returns a scalar averaged over kept tokens; "none"
            averages over the last dim only (per-task means [B] for
            [B, n_query] targets).

    Raises:
        ValueError: if ``reduction`` is not "mean" or "none".
    """
    if reduction not in ("mean", "none"):
        raise ValueError(f"reduction must be 'mean' or 'none', got {reduction!r}")
    bin_ids = head.predict_bin_ids(y_standardized)
    per_token = F.cross_entropy(
        logits.reshape(-1, head.n_bins),
        bin_ids.reshape(-1),
        label_smoothing=label_smoothing,
        reduction="none",
    ).reshape(y_standardized.shape)
    if row_mask is None:
        keep = torch.ones_like(per_token)
        kept = per_token * keep
    else:
        # .bool() first: `~` on an integer mask is bitwise NOT (0 -> -1, 1 -> -2),
        # which would yield negative weights instead of excluding padded rows.
        pad = row_mask.bool()
        keep = (~pad).float()
        # Zero padded rows explicitly so a non-finite value on a padded row
        # cannot turn the sum into NaN (inf * 0 == NaN).
        kept = (per_token * keep).masked_fill(pad, 0.0)
    if reduction == "none":
        return kept.sum(dim=-1) / keep.sum(dim=-1).clamp(min=1)
    return kept.sum() / keep.sum().clamp(min=1)
def decode_bar_distribution(
    logits: torch.Tensor,
    head: BarDistributionHead,
    mode: str = "mean",
    quantile: float = 0.5,
    y_mean: Optional[torch.Tensor] = None,
    y_std: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Decode bar-dist logits to point predictions, optionally in original y space.

    Args:
        logits: ``[..., n_bins]`` bar-distribution logits.
        head: supplies ``bin_centers`` and ``n_bins``.
        mode: ``"mean"`` (probability-weighted bin center), ``"median"``, or
            ``"quantile"``. The last two return the center of the first bin
            whose CDF reaches the requested level.
        quantile: level in [0, 1]; only used when ``mode == "quantile"``.
        y_mean, y_std: per-task standardization stats, e.g. from
            ``standardize_y_per_task``. When given, the prediction is mapped
            back via ``pred * y_std + y_mean``. Each stat is unsqueezed on the
            right, independently, until it has as many dims as the prediction,
            so ``[B]`` stats broadcast over ``[B, n_query]`` predictions.

    Returns:
        Point predictions with shape ``logits.shape[:-1]``.

    Raises:
        ValueError: unknown ``mode``, ``quantile`` outside [0, 1], or only one
            of ``y_mean`` / ``y_std`` supplied.
    """
    if (y_mean is None) != (y_std is None):
        raise ValueError("y_mean and y_std must be given together (or both omitted)")
    probs = F.softmax(logits, dim=-1)
    centers = head.bin_centers.to(logits.device)
    if mode == "mean":
        pred_std = (probs * centers).sum(dim=-1)
    elif mode in ("median", "quantile"):
        q = 0.5 if mode == "median" else quantile
        if not 0.0 <= q <= 1.0:
            raise ValueError(f"quantile must be in [0, 1], got {q}")
        cdf = probs.cumsum(dim=-1)
        # First bin whose cumulative mass reaches q. The clamp guards against
        # the final CDF entry rounding to slightly below 1.
        idx = torch.searchsorted(cdf, torch.full_like(cdf[..., :1], q)).squeeze(-1)
        idx = torch.clamp(idx, 0, head.n_bins - 1)
        pred_std = centers[idx]
    else:
        raise ValueError(f"Unknown mode: {mode}")
    if y_mean is None:
        return pred_std
    while y_mean.dim() < pred_std.dim():
        y_mean = y_mean.unsqueeze(-1)
    while y_std.dim() < pred_std.dim():
        y_std = y_std.unsqueeze(-1)
    return pred_std * y_std + y_mean
def predict_variance(
    logits: torch.Tensor,
    head: BarDistributionHead,
    y_std: Optional[torch.Tensor] = None,
    include_bin_width: bool = False,
) -> torch.Tensor:
    """Predictive variance from the bar distribution (for coverage / calibration).

    By default this is the variance of the discrete distribution that puts
    probability ``p_i`` on bin center ``c_i``: ``Σ p_i (c_i - μ)²``. A bar
    distribution is uniform within each bin, so its exact variance adds the
    within-bin term ``Σ p_i w_i² / 12`` (``w_i`` = bin width). Pass
    ``include_bin_width=True`` to include it. The default keeps the
    centers-only value.

    Args:
        logits: ``[..., n_bins]`` logits.
        head: supplies ``bin_centers``, plus ``bin_edges`` when
            ``include_bin_width`` is set.
        y_std: optional per-task std that rescales the variance to original y
            units (applied squared). It is unsqueezed on the right until it has
            as many dims as the variance, so ``[B]`` broadcasts over
            ``[B, n_query, ...]``.
        include_bin_width: add the within-bin uniform-variance term.

    Returns:
        Variance with shape ``logits.shape[:-1]``.
    """
    probs = F.softmax(logits, dim=-1)
    centers = head.bin_centers.to(logits.device)
    mean = (probs * centers).sum(dim=-1, keepdim=True)
    var_std = (probs * (centers - mean) ** 2).sum(dim=-1)
    if include_bin_width:
        edges = head.bin_edges.to(logits.device)
        widths = edges[1:] - edges[:-1]
        var_std = var_std + (probs * widths.pow(2) / 12.0).sum(dim=-1)
    if y_std is None:
        return var_std
    while y_std.dim() < var_std.dim():
        y_std = y_std.unsqueeze(-1)
    return var_std * y_std * y_std
# ─── 2. bin-based classification head (variable n_classes per task) ──────────
class BinClassificationHead(nn.Module):
    """
    Task-agnostic classification head with a fixed number of output logits.

    The head always emits ``max_classes`` logits per query. Restricting them
    to a task's real class count (dropping logits at index >= n_classes) is
    left to the caller, e.g. ``cls_masked_loss`` / ``cls_predict`` /
    ``cls_probs``. Architecture: Linear → GELU → Dropout-or-Identity → Linear,
    the same 2-layer MLP shape the bar-distribution head uses.

    Forward: x [..., d_model] → logits [..., max_classes].
    """

    def __init__(
        self,
        d_model: int,
        max_classes: int = MAX_CLASSES,
        hidden_multiplier: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.max_classes = max_classes
        width = d_model * hidden_multiplier
        # Layers are built in forward order (nn.Linear initialization consumes
        # RNG). A parameter-free Identity fills the dropout slot when dropout
        # is off, so the Linear layers always sit at mlp.0 / mlp.3.
        stages = [nn.Linear(d_model, width), nn.GELU()]
        stages.append(nn.Dropout(dropout) if dropout > 0 else nn.Identity())
        stages.append(nn.Linear(width, max_classes))
        self.mlp = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (unmasked) class logits for every query row."""
        logits = self.mlp(x)
        return logits
def cls_masked_loss(
    logits: torch.Tensor,
    y: torch.Tensor,
    n_classes: torch.Tensor,
    label_smoothing: float = 0.0,
    row_mask: Optional[torch.Tensor] = None,
    reduction: str = "mean",
) -> torch.Tensor:
    """
    Cross-entropy with per-task masking of unused class logits.

    Args:
        logits: [B, n_query, MAX_CLASSES]
        y: [B, n_query] integer labels in [0, n_classes_b). Labels on padded
            rows (row_mask True) are ignored, so placeholder values such as
            -1, -100 or >= n_classes are accepted there.
        n_classes: [B] integer count of valid classes per task in batch;
            moved to ``logits.device`` internally.
        label_smoothing: smoothing distributed ONLY across the valid class
            range per task. Naive `F.cross_entropy(label_smoothing=ls)` over
            logits-with-(-1e9)-on-invalid produces ls/C * 1e9 ≈ 5e6 per row
            for invalid classes; here we smooth only over valid classes so
            the invalid-class contribution is exactly zero.
        row_mask: [B, n_query] bool (or 0/1) mask, True = padded row to skip
            in loss
        reduction: "mean" (scalar) or "none" (per-task tensor [B])

    Each batch entry's unused logits are set to -inf before log-softmax, so
    the softmax only spans that task's classes.

    Raises:
        ValueError: if ``reduction`` is not "mean" or "none".
    """
    if reduction not in ("mean", "none"):
        raise ValueError(f"reduction must be 'mean' or 'none', got {reduction!r}")
    B, N, C = logits.shape
    device = logits.device
    # Per-class validity mask, [B, C] -> [B, N, C]
    arange_C = torch.arange(C, device=device)[None, :]
    valid_mask = arange_C < n_classes.to(device)[:, None]
    valid_mask_full = valid_mask[:, None, :].expand(B, N, C)
    masked_logits = logits.masked_fill(~valid_mask_full, float("-inf"))
    log_probs = F.log_softmax(masked_logits, dim=-1)  # [B, N, C]

    if row_mask is None:
        pad = torch.zeros(B, N, dtype=torch.bool, device=device)
    else:
        pad = row_mask.to(device=device, dtype=torch.bool)
    # Padded rows may carry placeholder labels: a negative label makes gather
    # fail, and one >= n_classes picks a -inf log-prob (-> inf loss, and
    # inf * 0 = NaN). Point them at class 0; their loss is discarded below.
    y_safe = y.long().masked_fill(pad, 0)
    nll = -log_probs.gather(-1, y_safe.unsqueeze(-1)).squeeze(-1)  # [B, N]

    if label_smoothing > 0:
        # Smoothed target over valid classes: (1-ls)*[c==y] + ls/n_valid, so
        # loss = (1-ls) * NLL + ls * mean over valid classes of -log p.
        valid_count = valid_mask.sum(dim=-1, keepdim=True).clamp(min=1).float()  # [B, 1]
        # Invalid entries are -inf; zero them so the sum stays finite.
        log_probs_valid_only = log_probs.masked_fill(~valid_mask_full, 0.0)
        mean_neg_log = -log_probs_valid_only.sum(dim=-1) / valid_count  # [B, N]
        loss_per_row = (1.0 - label_smoothing) * nll + label_smoothing * mean_neg_log
    else:
        loss_per_row = nll

    if row_mask is None:
        keep = torch.ones_like(loss_per_row)
    else:
        keep = (~pad).float()
    # masked_fill on top of the weighting so a padded row can never inject NaN.
    kept = (loss_per_row * keep).masked_fill(pad, 0.0)
    if reduction == "none":
        return kept.sum(dim=-1) / keep.sum(dim=-1).clamp(min=1)  # [B]
    return kept.sum() / keep.sum().clamp(min=1)
def cls_predict(
    logits: torch.Tensor,
    n_classes: torch.Tensor,
) -> torch.Tensor:
    """Argmax over the valid logit range per task. Returns [B, n_query].

    Args:
        logits: [B, n_query, C] float class logits.
        n_classes: [B] number of valid classes per task; moved to
            ``logits.device`` internally.

    Invalid logits are filled with the most negative finite value of
    ``logits.dtype`` rather than a hard-coded -1e9. -1e9 is not representable
    in float16, so ``masked_fill`` raises an overflow error there.
    """
    B, N, C = logits.shape
    device = logits.device
    arange_C = torch.arange(C, device=device)[None, :]
    valid_mask = arange_C < n_classes.to(device)[:, None]
    valid_mask_full = valid_mask[:, None, :].expand(B, N, C)
    masked_logits = logits.masked_fill(~valid_mask_full, torch.finfo(logits.dtype).min)
    return masked_logits.argmax(dim=-1)
def cls_probs(
    logits: torch.Tensor,
    n_classes: torch.Tensor,
) -> torch.Tensor:
    """Softmax over the valid logit range per task. Invalid classes → 0 prob.

    Args:
        logits: [B, n_query, C] float class logits.
        n_classes: [B] number of valid classes per task; moved to
            ``logits.device`` internally.

    Returns:
        [B, n_query, C] probabilities that sum to 1 over the last dim.

    Invalid logits are filled with the most negative finite value of
    ``logits.dtype`` rather than a hard-coded -1e9. -1e9 overflows float16,
    and a finite fill keeps a task with no valid class at a uniform
    distribution instead of NaN.
    """
    B, N, C = logits.shape
    device = logits.device
    arange_C = torch.arange(C, device=device)[None, :]
    valid_mask = arange_C < n_classes.to(device)[:, None]
    valid_mask_full = valid_mask[:, None, :].expand(B, N, C)
    masked_logits = logits.masked_fill(~valid_mask_full, torch.finfo(logits.dtype).min)
    return F.softmax(masked_logits, dim=-1)
# ─── 3. self-test: shapes, masking, decoding all roundtrip ───────────────────
if __name__ == "__main__":
    # Self-test for both heads and their helpers: checks shapes, finiteness,
    # masking and decoding. Statement order matters: the head weights and every
    # random tensor below come from one seeded RNG stream, so reordering lines
    # changes the printed numbers.
    torch.manual_seed(0)
    # Reg head smoke
    head_r = BarDistributionHead(d_model=256, n_bins=1024)
    trunk_out = torch.randn(2, 64, 256)  # [B=2, n_query=64, d_model=256]
    logits_r = head_r(trunk_out)
    assert logits_r.shape == (2, 64, 1024)
    y_ctx = torch.randn(2, 256)  # [B, n_ctx]
    y_q = torch.randn(2, 64)
    # Standardize with context-only stats; mu / sigma come back as [B].
    y_ctx_s, y_q_s, mu, sigma = standardize_y_per_task(y_ctx, y_q)
    assert y_ctx_s.shape == y_ctx.shape and y_q_s.shape == y_q.shape
    loss_r = bar_distribution_loss(logits_r, y_q_s, head_r)
    assert torch.isfinite(loss_r).item()
    # Decode back to original y units; [B] stats broadcast over [B, n_query].
    pred_mean = decode_bar_distribution(logits_r, head_r, mode="mean", y_mean=mu, y_std=sigma)
    assert pred_mean.shape == (2, 64)
    var_pred = predict_variance(logits_r, head_r, y_std=sigma)
    assert var_pred.shape == (2, 64)
    print(f"[reg] logits {tuple(logits_r.shape)} loss={loss_r.item():.4f} pred_mean[0,0]={pred_mean[0,0].item():+.3f}")
    # Cls head smoke
    head_c = BinClassificationHead(d_model=256, max_classes=10)
    logits_c = head_c(trunk_out)
    assert logits_c.shape == (2, 64, 10)
    # Task 0: 3-class, Task 1: 7-class
    n_classes = torch.tensor([3, 7])
    # Labels drawn inside each task's own class range.
    y_c = torch.stack([
        torch.randint(0, 3, (64,)),
        torch.randint(0, 7, (64,)),
    ])
    loss_c = cls_masked_loss(logits_c, y_c, n_classes)
    assert torch.isfinite(loss_c).item()
    preds = cls_predict(logits_c, n_classes)
    probs = cls_probs(logits_c, n_classes)
    # Verify masking: invalid classes have 0 probability
    assert (probs[0, :, 3:] == 0.0).all(), "task 0 should have 0 prob on classes >= 3"
    assert (probs[1, :, 7:] == 0.0).all(), "task 1 should have 0 prob on classes >= 7"
    # Verify predictions stay within valid range
    assert (preds[0] < 3).all() and (preds[1] < 7).all()
    # Verify softmax sums to 1 over valid logits
    sums = probs.sum(dim=-1)
    assert torch.allclose(sums, torch.ones_like(sums), atol=1e-5)
    print(f"[cls] logits {tuple(logits_c.shape)} loss={loss_c.item():.4f} "
          f"preds[0]={preds[0,:5].tolist()} preds[1]={preds[1,:5].tolist()}")
    print("[OK] heads self-test passed")