# PredictLM v11.0 + Mini ship-bundle
# Author: 01RAI · commit d303b8b (verified)
"""
v11 model — same trunk as v8 so we can warm-start from v8's final checkpoint.
The architecture differences vs v8 are the prediction heads:
v8: reg_head = Linear(d_model, 2) # mean, log_var
v8: cls_head = Linear(d_model, max_classes)
v11: reg_head = BarDistributionHead(d_model, n_bins=1024)
v11: cls_head = BinClassificationHead(d_model, max_classes=10)
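Everything else (feature_weights, y_embed, class_embed, type_embed,
shared_layers, reg_layers, cls_layers, *_norm) keeps the same module
names. Parameter shapes also match, except feature_weights /
feature_biases: v8 has 500 feature slots, v11 defaults to 128. Load with
    v11_model.warm_start_from_v8(v8_ckpt)
which drops v8's old heads, slices the feature tables to max_features,
and then calls load_state_dict(strict=False), leaving only the heads
randomly initialized. A bare load_state_dict(v8_ckpt, strict=False)
raises on the feature-table size mismatch, because strict=False only
tolerates missing/unexpected keys, not shape mismatches.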
The v11 trainer's head-warmup phase trains only the heads + reg_norm /
cls_norm for the first 5k steps, exactly as v10 did.
Tokenization is identical to v8: 2D grid [B, n_rows, n_cols, d_model]
with one token per cell. Each layer alternates feature-attention (within
a row) and datapoint-attention (within a column with the
context-vs-query mask).
For now, v11 SKIPS v8's metadata conditioning (the column-statistics
encoder). The v11 plan defers architectural cleanups to v13; the goal
here is data-prior work, not arch work. Once warm-started, the
metadata-related parameters in the v8 ckpt are simply ignored.
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as grad_checkpoint
from .heads import (
BarDistributionHead,
BinClassificationHead,
bar_distribution_loss,
cls_masked_loss,
standardize_y_per_task,
decode_bar_distribution,
cls_predict,
)
# ─── config ──────────────────────────────────────────────────────────────────
@dataclass
class V11Config:
    """Hyper-parameters for PredictLMv11.

    Field order is significant (positional construction follows it), so new
    fields should only be appended. The architecture toggles at the bottom
    default to the original v11.0 layout, so existing checkpoints keep
    loading via warm_start_from_v8 / load_state_dict(strict=False).
    """

    # --- trunk width / depth -------------------------------------------------
    d_model: int = 256
    n_layers: int = 12  # total depth: (n_layers - 4) shared + 4 per task branch
    n_heads: int = 8
    d_ffn: int = 1024
    dropout: float = 0.0  # forwarded to every trunk block and to both heads

    # --- input / output capacity ---------------------------------------------
    max_features: int = 128  # v8 had 500 slots; warm_start_from_v8 keeps the first max_features rows
    max_classes: int = 10
    max_context: int = 1024  # not read in this module; presumably used by data/trainer code
    max_query: int = 256  # not read in this module; presumably used by data/trainer code
    n_periodic_freqs: int = 8  # Fourier frequencies in the scalar-y embedding
    n_bins: int = 1024  # bar-distribution bins of the regression head
    cls_label_smoothing: float = 0.05  # not read in this module; presumably used by the loss/trainer

    # --- v11.0.6-tiny architecture toggles (defaults = legacy v11.0) ---------
    mlp_variant: str = "gelu"  # "gelu" (legacy) or "swiglu"
    norm_variant: str = "layernorm"  # "layernorm" (legacy) or "rmsnorm"
    # ALBERT-style cross-layer sharing: the n_layers-deep stack is built from
    # only n_layers // share_factor unique blocks, each applied share_factor
    # times by index cycling. 1 means no sharing (legacy).
    share_factor: int = 1
def v11_default_config() -> V11Config:
    """Build a fresh V11Config with every field left at its default value."""
    default_cfg = V11Config()
    return default_cfg
# ─── v11.0.6-tiny blocks (drop-in upgrades behind config flag) ──────────────
class SwiGLUFFN(nn.Module):
    """Gated feed-forward block with a SiLU gate (SwiGLU; Shazeer 2020, arXiv 2002.05202).

    out = w_out( silu(w_gate(x)) * w_value(x) )

    All three projections are bias-free. The hidden width is round(2/3 * d_ffn),
    so the three d_model x d_hidden matrices hold about as many weights as the
    legacy two-matrix GELU FFN (Linear(d, d_ffn) -> GELU -> Linear(d_ffn, d)).
    The legacy block's biases are not counted, so parity is approximate.
    """

    def __init__(self, d_model: int, d_ffn: int):
        super().__init__()
        # legacy weights: 2 * d_model * d_ffn; SwiGLU weights: 3 * d_model * hidden.
        hidden = int(round(d_ffn * 2 / 3))
        # Creation order (gate, value, out) matches earlier versions so seeded
        # initialisation is reproducible.
        self.w_gate = nn.Linear(d_model, hidden, bias=False)
        self.w_value = nn.Linear(d_model, hidden, bias=False)
        self.w_out = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w_gate(x))
        value = self.w_value(x)
        return self.w_out(gate * value)
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (Zhang & Sennrich 2019), as used in LLaMA.

    y = weight * x / sqrt(mean(x**2, last dim) + eps)

    Unlike LayerNorm there is no mean subtraction and no bias; the only
    learned parameter is the per-channel ``weight`` (initialised to ones).
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return self.weight * (x * inv_rms)
def _build_ffn(d_model: int, d_ffn: int, variant: str = "gelu") -> nn.Module:
"""Factory: return GELU MLP (legacy) or SwiGLU MLP based on variant."""
if variant == "swiglu":
return SwiGLUFFN(d_model, d_ffn)
return nn.Sequential(
nn.Linear(d_model, d_ffn),
nn.GELU(),
nn.Linear(d_ffn, d_model),
)
def _build_norm(d_model: int, variant: str = "layernorm") -> nn.Module:
"""Factory: return LayerNorm (legacy) or RMSNorm based on variant."""
if variant == "rmsnorm":
return RMSNorm(d_model)
return nn.LayerNorm(d_model)
# ─── blocks (verbatim from v8 so state_dict keys match) ───────────────────────
class FlashPreLNAttention(nn.Module):
    """Pre-LN self-attention + FFN block on top of F.scaled_dot_product_attention.

    Computes ``h = x + o_proj(MHA(norm1(x)))`` and then ``out = h + ffn(norm2(h))``.
    Sub-module names (norm1, q_proj, k_proj, v_proj, o_proj, norm2, ffn) are
    unchanged from v8, so state_dict keys still line up.

    Args:
        d_model: token width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        d_ffn: FFN hidden width, forwarded to ``_build_ffn``.
        dropout: attention-probability dropout, passed to SDPA as ``dropout_p``
            while the module is in training mode (0 in eval mode). Earlier
            versions accepted this argument but always used 0.0.
        mlp_variant / norm_variant: forwarded to ``_build_ffn`` / ``_build_norm``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads`` (the head
            reshape in forward would fail otherwise).
    """

    def __init__(self, d_model: int, n_heads: int, d_ffn: int, dropout: float = 0.0,
                 mlp_variant: str = "gelu", norm_variant: str = "layernorm"):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f"d_model={d_model} must be divisible by n_heads={n_heads}")
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.d_model = d_model
        self.dropout = float(dropout)
        self.norm1 = _build_norm(d_model, norm_variant)
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)
        self.norm2 = _build_norm(d_model, norm_variant)
        self.ffn = _build_ffn(d_model, d_ffn, mlp_variant)

    def _heads(self, x: torch.Tensor) -> torch.Tensor:
        """Split [B, S, d_model] into [B, n_heads, S, head_dim]."""
        B, S, _ = x.shape
        return x.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply attention + FFN to a batch of token sequences.

        Args:
            x: [B, S, d_model].
            key_padding_mask: optional [B, S] bool, True = this key position is
                padding and is hidden from every query.
            attn_mask: optional bool, True = blocked; either [S, S] (shared
                across the batch) or [B, S, S].

        Returns:
            [B, S, d_model].

        Both masks become additive float masks (0 / -inf) and are summed. A
        query row whose keys are ALL blocked gets an all -inf score row, which
        can turn into NaN depending on the SDPA backend, so callers should
        leave at least one visible key per query.
        """
        residual = x
        h = self.norm1(x)
        q = self._heads(self.q_proj(h))
        k = self._heads(self.k_proj(h))
        v = self._heads(self.v_proj(h))

        sdpa_mask = None
        if attn_mask is not None:
            blocked = torch.zeros_like(attn_mask, dtype=q.dtype)
            blocked.masked_fill_(attn_mask, float("-inf"))
            # [S, S] -> [1, 1, S, S]; [B, S, S] -> [B, 1, S, S] (broadcast over heads).
            sdpa_mask = blocked[None, None] if attn_mask.dim() == 2 else blocked.unsqueeze(1)
        if key_padding_mask is not None:
            pad_mask = torch.zeros(
                key_padding_mask.shape[0], 1, 1, key_padding_mask.shape[1],
                dtype=q.dtype, device=q.device,
            )
            pad_mask.masked_fill_(key_padding_mask[:, None, None, :], float("-inf"))
            sdpa_mask = pad_mask if sdpa_mask is None else sdpa_mask + pad_mask

        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=sdpa_mask,
            dropout_p=self.dropout if self.training else 0.0,
        )
        attn_out = attn_out.transpose(1, 2).contiguous().view(x.shape[0], x.shape[1], self.d_model)
        x = self.o_proj(attn_out) + residual
        return self.ffn(self.norm2(x)) + x
class AlternatingLayerV8(nn.Module):
    """One trunk layer over the cell grid: row-wise feature attention, then
    column-wise datapoint attention.

    The class name and its two sub-module names (feature_attn,
    datapoint_attn) are kept identical to v8 so warm-start state_dict keys
    align.
    """

    def __init__(self, d_model: int, n_heads: int, d_ffn: int, dropout: float = 0.0,
                 mlp_variant: str = "gelu", norm_variant: str = "layernorm"):
        super().__init__()
        block_kwargs = dict(mlp_variant=mlp_variant, norm_variant=norm_variant)
        self.feature_attn = FlashPreLNAttention(d_model, n_heads, d_ffn, dropout, **block_kwargs)
        self.datapoint_attn = FlashPreLNAttention(d_model, n_heads, d_ffn, dropout, **block_kwargs)

    def forward(
        self,
        x: torch.Tensor,  # [B, n_rows, n_cols, d_model]
        feature_pad_mask: torch.Tensor,
        datapoint_mask: torch.Tensor,  # [n_rows, n_rows] OR [B, n_rows, n_rows]
    ) -> torch.Tensor:
        """Return the updated grid, same shape as ``x``.

        feature_pad_mask is [B, n_cols] bool (True = padded column), shared by
        every row of a batch item. datapoint_mask is bool (True = blocked).
        """
        batch, rows, cols, width = x.shape

        # 1) Feature attention: every row is its own sequence of n_cols tokens.
        row_seqs = x.reshape(batch * rows, cols, width)
        row_pad = feature_pad_mask[:, None, :].expand(batch, rows, cols).reshape(batch * rows, cols)
        row_seqs = self.feature_attn(row_seqs, key_padding_mask=row_pad)
        x = row_seqs.reshape(batch, rows, cols, width)

        # 2) Datapoint attention: every column is its own sequence of n_rows tokens.
        col_seqs = x.permute(0, 2, 1, 3).reshape(batch * cols, rows, width)
        col_mask = datapoint_mask
        if datapoint_mask.dim() == 3:
            # Per-batch mask: repeat it for each column -> [B * n_cols, n_rows, n_rows].
            col_mask = (
                datapoint_mask[:, None]
                .expand(batch, cols, rows, rows)
                .reshape(batch * cols, rows, rows)
            )
        col_seqs = self.datapoint_attn(col_seqs, attn_mask=col_mask)
        return col_seqs.reshape(batch, cols, rows, width).permute(0, 2, 1, 3)
# ─── numerical-value embedding (matches v8's NumericalFeatureEmbedding) ──────
class NumericalFeatureEmbedding(nn.Module):
    """Embed scalar numerical values into d_model vectors via Fourier features.

    Each value v is featurised as
        [sign(v), log1p(|v|), sin(v * 2**k), cos(v * 2**k)  for k = 0..n_freqs-1]
    and passed through a two-layer GELU MLP. Positions flagged in ``mask``
    are replaced by the learned ``missing_token``.
    """

    def __init__(self, d_model: int = 256, n_freqs: int = 8):
        super().__init__()
        self.d_model = d_model
        self.n_freqs = n_freqs
        freqs = 2.0 ** torch.arange(n_freqs, dtype=torch.float32)
        self.register_buffer("freqs", freqs)
        in_dim = 1 + 1 + 2 * n_freqs  # sign + log_mag + sin/cos at each freq
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )
        self.missing_token = nn.Parameter(torch.randn(d_model) * 0.02)

    def forward(self, values: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Embed ``values``.

        Args:
            values: tensor of scalars, shape [...].
            mask: optional bool tensor of the same shape, True = missing.

        Returns:
            [..., d_model] embeddings.
        """
        if mask is not None:
            # Zero masked entries BEFORE featurising. Their embedding is
            # replaced by missing_token below anyway, but a non-finite
            # placeholder (e.g. NaN for "missing") would otherwise pass through
            # the MLP. torch.where gives the discarded branch a zero gradient,
            # yet the MLP weight gradients still pick up 0 * NaN = NaN.
            values = values.masked_fill(mask, 0.0)
        sign = torch.sign(values)
        log_mag = torch.log1p(torch.abs(values))
        # [..., 1] * [n_freqs] broadcasts to [..., n_freqs].
        scaled = values.unsqueeze(-1) * self.freqs.to(values.device)
        feats = torch.cat(
            [sign.unsqueeze(-1), log_mag.unsqueeze(-1), torch.sin(scaled), torch.cos(scaled)],
            dim=-1,
        )
        emb = self.mlp(feats)
        if mask is not None:
            emb = torch.where(mask.unsqueeze(-1), self.missing_token.expand_as(emb), emb)
        return emb
# ─── main v11 model ──────────────────────────────────────────────────────────
@dataclass
class V11Output:
    """Container bundling the results of a single forward pass.

    In this module PredictLMv11.forward returns the head output directly and
    does not construct this class.
    """

    # Regression-head logits, [B, n_query, n_bins].
    reg_logits: Optional[torch.Tensor] = None
    # Classification-head logits, [B, n_query, max_classes].
    cls_logits: Optional[torch.Tensor] = None
    # Context-y mean and std per task, [B] (regression only).
    y_mean: Optional[torch.Tensor] = None
    y_std: Optional[torch.Tensor] = None
class PredictLMv11(nn.Module):
    """
    v11 model: same trunk as v8, new heads.

    ``forward`` returns logits for one task type per call: bar-distribution
    logits for ``task_type="regression"`` or class logits for
    ``task_type="classification"``. For mixed-batch joint training the
    trainer should call the model twice, once per task type. Each call runs
    the shared trunk itself, and gradients from both calls accumulate into
    the same trunk parameters. (Per-batch-element task_type would require
    padding to a max-class shape and we keep it simple.)

    State-dict keys match v8's PredictLMv8 exactly EXCEPT:
    - reg_head (Linear → BarDistributionHead.mlp)
    - cls_head (Linear → BinClassificationHead.mlp)
    Load v8 weights with ``warm_start_from_v8``. It filters the old heads and
    slices v8's wider feature_weights / feature_biases before calling
    ``load_state_dict(strict=False)``.
    """

    # Depth of each task-specific branch (reg and cls), fixed by the v8 layout.
    _N_BRANCH_LAYERS = 4

    def __init__(self, cfg: Optional[V11Config] = None):
        super().__init__()
        if cfg is None:
            cfg = v11_default_config()
        self.cfg = cfg
        # Toggle gradient checkpointing. Default True (memory-conservative,
        # for H100/T4 sized batches). On A100 80GB we can disable for ~2-3×
        # throughput when memory permits. Set via `model.use_grad_checkpoint = False`.
        self.use_grad_checkpoint = True

        # Per-feature projection (same as v8): the value in feature slot j
        # becomes value * feature_weights[j] + feature_biases[j].
        self.feature_weights = nn.Parameter(torch.randn(cfg.max_features, cfg.d_model) * 0.02)
        self.feature_biases = nn.Parameter(torch.zeros(cfg.max_features, cfg.d_model))

        # Target-column embeddings for context rows: scalar y (regression)
        # or class id (classification).
        self.y_embed = NumericalFeatureEmbedding(cfg.d_model, n_freqs=cfg.n_periodic_freqs)
        self.class_embed = nn.Embedding(cfg.max_classes, cfg.d_model)
        nn.init.normal_(self.class_embed.weight, std=0.02)

        # query_token fills the target cell of query rows; type_embed tags rows
        # (0 = context, 1 = query); col_type_embed tags columns (0 = feature,
        # 1 = target). See _build_grid.
        self.query_token = nn.Parameter(torch.randn(cfg.d_model) * 0.02)
        self.type_embed = nn.Embedding(2, cfg.d_model)
        nn.init.normal_(self.type_embed.weight, std=0.02)
        self.col_type_embed = nn.Embedding(2, cfg.d_model)
        nn.init.normal_(self.col_type_embed.weight, std=0.02)

        # Trunk: (n_layers - 4) shared layers, then 4 reg + 4 cls branch layers.
        # The getattr() defaults keep config objects that predate the
        # v11.0.6-tiny toggles usable; they reproduce the v11.0 layout.
        mv = getattr(cfg, "mlp_variant", "gelu")
        nv = getattr(cfg, "norm_variant", "layernorm")
        share = max(1, int(getattr(cfg, "share_factor", 1)))
        n_branch = self._N_BRANCH_LAYERS
        n_shared = cfg.n_layers - n_branch
        if n_shared < 0:
            raise ValueError(
                f"n_layers={cfg.n_layers} must be at least {n_branch} (the per-branch depth)"
            )
        # Under share_factor>1 only n // share unique blocks are built and
        # forward cycles through them, so share must divide both depths.
        if n_shared % share != 0 or n_branch % share != 0:
            raise ValueError(
                f"share_factor={share} must divide both n_shared={n_shared} and 4 (branch layers)"
            )

        def make_layer():
            return AlternatingLayerV8(
                cfg.d_model, cfg.n_heads, cfg.d_ffn, cfg.dropout,
                mlp_variant=mv, norm_variant=nv,
            )

        self.shared_layers = nn.ModuleList([make_layer() for _ in range(n_shared // share)])
        self.reg_layers = nn.ModuleList([make_layer() for _ in range(n_branch // share)])
        self.cls_layers = nn.ModuleList([make_layer() for _ in range(n_branch // share)])
        self.shared_norm = _build_norm(cfg.d_model, nv)
        self.reg_norm = _build_norm(cfg.d_model, nv)
        self.cls_norm = _build_norm(cfg.d_model, nv)
        # Logical depths: forward applies this many layers per stack, cycling
        # through the (possibly fewer) unique modules.
        self.effective_n_shared = n_shared
        self.effective_n_branch = n_branch

        # v11 heads
        self.reg_head = BarDistributionHead(
            d_model=cfg.d_model, n_bins=cfg.n_bins, dropout=cfg.dropout,
        )
        self.cls_head = BinClassificationHead(
            d_model=cfg.d_model, max_classes=cfg.max_classes, dropout=cfg.dropout,
        )
        # NOTE: v8's `log_var_reg` / `log_var_cls` Kendall-style task weights
        # are intentionally NOT instantiated here. They were declared but
        # never read in the v11 trainer, and ratio-balancing reg/cls via
        # alternation + curriculum bias is sufficient at this scale per
        # Expert 4. If they appear in a v8 checkpoint, `warm_start_from_v8`
        # drops them before loading.
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform every nn.Linear weight and zero every nn.Linear bias.

        Runs after the heads and embeddings are built, so it also re-initialises
        every nn.Linear inside the heads, NumericalFeatureEmbedding and the FFNs.
        That overrides any init those modules chose themselves; NOTE(review):
        confirm this is intended for the heads. Non-Linear parameters
        (embeddings, tokens, feature tables, norm weights) are left untouched.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    # ──────────────────────────────────────────────────────────────
    # Internal: build the [B, n_rows, n_cols, d_model] grid
    # ──────────────────────────────────────────────────────────────
    def _build_grid(
        self,
        X_ctx: torch.Tensor,  # [B, n_ctx, n_features]
        y_ctx: torch.Tensor,  # [B, n_ctx]
        X_query: torch.Tensor,  # [B, n_query, n_features]
        feature_mask: torch.Tensor,  # [B, n_features] bool, True=padded
        task_type: str,
        ctx_row_mask: Optional[torch.Tensor] = None,  # [B, n_ctx] bool, True=padded
        query_row_mask: Optional[torch.Tensor] = None,  # [B, n_query] bool, True=padded
    ):
        """Tokenise one episode into the [B, n_rows, n_cols, d_model] cell grid.

        Rows are the n_ctx context rows followed by the n_query query rows.
        Columns are ``n_real`` feature slots followed by one target column
        (context rows: embedded label; query rows: the learned query_token).
        Any task_type other than "classification" embeds y_ctx as a scalar.

        Returns:
            grid: [B, n_rows, n_cols, d_model].
            feature_pad_mask: [B, n_cols] bool, True = padded feature column
                (the target column is never padded).
            datapoint_mask: bool, True = attention blocked. Shape is
                [n_rows, n_rows] without row masks, else [B, n_rows, n_rows].
            n_ctx: number of context rows (query rows start at this index).
        """
        B, n_ctx, n_features = X_ctx.shape
        n_query = X_query.shape[1]
        n_rows = n_ctx + n_query
        max_f = self.cfg.max_features
        device = X_ctx.device

        # Effective feature count = largest number of unpadded features in the
        # batch (capped at max_features, floored at 2). Features are taken as
        # the prefix [:n_real], so padded slots are assumed to be trailing.
        # NOTE(review): confirm the data pipeline guarantees trailing padding.
        if feature_mask.any():
            real_per_item = (~feature_mask).sum(dim=1)
            n_real = min(int(real_per_item.max().item()), max_f)
        else:
            n_real = min(n_features, max_f)
        # NOTE(review): with n_features == 1 the [:2] slice has one column, and
        # broadcasting repeats it across both projected slots without marking
        # either as padded; n_features == 0 fails to broadcast.
        n_real = max(n_real, 2)
        n_cols = n_real + 1

        X_all = torch.cat([X_ctx, X_query], dim=1)  # [B, n_rows, n_features]
        X_real = X_all[:, :, :n_real]  # [B, n_rows, n_real]

        # Per-feature projection
        feat_grid = (
            X_real.unsqueeze(-1) * self.feature_weights[:n_real]
            + self.feature_biases[:n_real]
        )  # [B, n_rows, n_real, d_model]

        # Target column embedding
        if task_type == "classification":
            y_clamped = y_ctx.long().clamp(0, self.cfg.max_classes - 1)
            y_emb_ctx = self.class_embed(y_clamped)  # [B, n_ctx, d_model]
        else:
            y_emb_ctx = self.y_embed(y_ctx.float())  # [B, n_ctx, d_model]
        y_emb_q = self.query_token.unsqueeze(0).unsqueeze(0).expand(B, n_query, -1)
        y_emb = torch.cat([y_emb_ctx, y_emb_q], dim=1).unsqueeze(2)  # [B, n_rows, 1, d_model]
        grid = torch.cat([feat_grid, y_emb], dim=2)  # [B, n_rows, n_cols, d_model]

        # Row-type (ctx vs query) and column-type (feature vs target) embeds
        type_ids = torch.zeros(B, n_rows, dtype=torch.long, device=device)
        type_ids[:, n_ctx:] = 1
        grid = grid + self.type_embed(type_ids).unsqueeze(2)
        col_types = torch.zeros(n_cols, dtype=torch.long, device=device)
        col_types[-1] = 1
        grid = grid + self.col_type_embed(col_types).unsqueeze(0).unsqueeze(0)

        # Feature-pad mask (target column stays False)
        feature_pad_mask = torch.zeros(B, n_cols, dtype=torch.bool, device=device)
        if feature_mask.shape[1] >= n_real:
            feature_pad_mask[:, :n_real] = feature_mask[:, :n_real]

        # Datapoint mask (True = blocked). A query row may attend to context
        # rows and to itself, but not to other query rows. Context rows are
        # NOT blocked from attending to query rows, so across layers query rows
        # can still influence each other indirectly through the context rows.
        datapoint_mask = torch.zeros(n_rows, n_rows, dtype=torch.bool, device=device)
        datapoint_mask[n_ctx:, n_ctx:] = ~torch.eye(n_query, dtype=torch.bool, device=device)
        if ctx_row_mask is not None or query_row_mask is not None:
            row_pad = torch.zeros(B, n_rows, dtype=torch.bool, device=device)
            if ctx_row_mask is not None:
                row_pad[:, :n_ctx] = ctx_row_mask
            if query_row_mask is not None:
                row_pad[:, n_ctx:] = query_row_mask
            # Block every padded row as a KEY for all attending rows -> [B, n_rows, n_rows].
            datapoint_mask = datapoint_mask.unsqueeze(0) | row_pad.unsqueeze(1)

        return grid, feature_pad_mask, datapoint_mask, n_ctx

    def _run_stack(
        self,
        layers: nn.ModuleList,
        n_passes: int,
        h: torch.Tensor,
        feat_pad: torch.Tensor,
        dp_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Apply ``n_passes`` layers to ``h``, cycling through the modules in ``layers``.

        With share_factor > 1, len(layers) < n_passes and pass i uses
        layers[i % len(layers)] (ALBERT-style reuse). Non-reentrant gradient
        checkpointing is used when training with grad enabled and
        ``self.use_grad_checkpoint`` is set.
        """
        n_unique = len(layers)
        use_ckpt = self.training and torch.is_grad_enabled() and self.use_grad_checkpoint
        for i in range(n_passes):
            layer = layers[i % n_unique]
            if use_ckpt:
                h = grad_checkpoint(layer, h, feat_pad, dp_mask, use_reentrant=False)
            else:
                h = layer(h, feat_pad, dp_mask)
        return h

    # ──────────────────────────────────────────────────────────────
    # Forward
    # ──────────────────────────────────────────────────────────────
    def forward(
        self,
        X_ctx: torch.Tensor,
        y_ctx: torch.Tensor,
        X_query: torch.Tensor,
        feature_mask: torch.Tensor,
        task_type: str = "regression",
        ctx_row_mask: Optional[torch.Tensor] = None,
        query_row_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return logits over bins (regression) or classes (classification).

        For regression, the trainer is responsible for calling
        `standardize_y_per_task(y_ctx_orig)` BEFORE this forward to obtain
        the standardized y_ctx (and stash mean/std for un-standardization).

        Optional ctx_row_mask / query_row_mask (bool, True=padded row) block
        padded rows from attention as keys, preventing zero-padded
        fake-context contamination.

        Raises:
            ValueError: if task_type is not "regression" or "classification".
                Other strings used to mix the regression y-embedding with the
                classification branch and head.
        """
        if task_type not in ("regression", "classification"):
            raise ValueError(
                f"task_type must be 'regression' or 'classification', got {task_type!r}"
            )
        grid, feat_pad, dp_mask, n_ctx = self._build_grid(
            X_ctx, y_ctx, X_query, feature_mask, task_type,
            ctx_row_mask=ctx_row_mask, query_row_mask=query_row_mask,
        )

        # Shared trunk.
        grid = self._run_stack(self.shared_layers, self.effective_n_shared, grid, feat_pad, dp_mask)
        grid = self.shared_norm(grid)

        # Task-specific branch. Both branches receive the full (non-detached)
        # trunk output. Earlier versions used `0.5*grid + 0.5*grid.detach()`
        # for cls, which halved its gradient into the trunk while reg passed
        # full gradient. Combined with bar-dist reg loss being ~3× larger by
        # magnitude than cls (ln(1024) vs ln(10)) and 50/50 step alternation,
        # the trunk received ~6× more reg signal than cls signal per step.
        if task_type == "regression":
            layers, norm, head = self.reg_layers, self.reg_norm, self.reg_head
        else:
            layers, norm, head = self.cls_layers, self.cls_norm, self.cls_head
        h = self._run_stack(layers, self.effective_n_branch, grid, feat_pad, dp_mask)
        h = norm(h)
        query_target = h[:, n_ctx:, -1, :]  # target-column token of each query row: [B, n_query, d_model]
        return head(query_target)  # [B, n_query, n_bins] or [B, n_query, max_classes]

    # ──────────────────────────────────────────────────────────────
    # Convenience: warm-start from v8 checkpoint
    # ──────────────────────────────────────────────────────────────
    @torch.no_grad()
    def warm_start_from_v8(self, v8_state_dict: dict, verbose: bool = True) -> dict:
        """Load v8 trunk weights, leaving the heads at random init.

        Drops v8's reg_head.* / cls_head.* (shape-incompatible) and the unused
        log_var_* weights. It slices feature_weights / feature_biases to this
        model's max_features, then loads with strict=False.

        Args:
            v8_state_dict: a v8 checkpoint's state_dict.
            verbose: print a short summary.

        Returns:
            dict with `loaded` (keys actually loaded into this model),
            `missing` and `unexpected` key counts.

        Raises:
            RuntimeError: from load_state_dict if any remaining key has a shape
                mismatch (strict=False does not suppress that).
        """
        # Filter out v8's old reg_head / cls_head (shape-incompatible) and
        # the dead log_var weights (removed in v11).
        skip_prefixes = ("reg_head.", "cls_head.", "log_var_reg", "log_var_cls")
        filtered = {
            k: v for k, v in v8_state_dict.items()
            if not k.startswith(skip_prefixes)
        }
        # Slice feature_weights / feature_biases if v8 ckpt has more features
        # than v11's max_features (v8 used 500, v11 default 128 for VRAM).
        # Keep the first N rows (v8 trained on tasks that primarily used the
        # earliest column slots).
        target_max = self.cfg.max_features
        for k in ("feature_weights", "feature_biases"):
            if k in filtered and filtered[k].shape[0] > target_max:
                filtered[k] = filtered[k][:target_max]
        result = self.load_state_dict(filtered, strict=False)
        # Offered keys with no slot in this model (e.g. v8 metadata-encoder
        # weights) land in unexpected_keys and are NOT loaded.
        n_loaded = len(filtered) - len(result.unexpected_keys)
        if verbose:
            print(f"[v11.warm_start_from_v8] loaded {n_loaded} keys")
            if result.missing_keys:
                print(f"  missing ({len(result.missing_keys)}): {result.missing_keys[:5]}…")
            if result.unexpected_keys:
                print(f"  unexpected ({len(result.unexpected_keys)}): {result.unexpected_keys[:5]}…")
        return {
            "loaded": n_loaded,
            "missing": len(result.missing_keys),
            "unexpected": len(result.unexpected_keys),
        }

    @torch.no_grad()
    def warm_start_slice_from_v11(self, v11_state_dict: dict, verbose: bool = True) -> dict:
        """Initialize this (smaller) model from a v11.0 ckpt by SLICING layers.

        Intended for models built with ``share_factor > 1``. The v11.0 trunk has
        one unique block per layer, while this model has only 1/share_factor
        as many unique blocks (each applied share_factor times by index
        cycling in forward). Every share_factor-th v11.0 block (indices 0, s,
        2s, ...) is copied into the student's unique-block lists, separately
        for shared / reg / cls layers.

        Non-layer keys (feature tables, y_embed, class_embed, query_token,
        type embeds, shared_norm/reg_norm/cls_norm, reg_head, cls_head) are
        copied verbatim; they do not depend on share_factor. Assumes the source
        ckpt was built with the same n_layers, d_model, max_features etc. as
        this model; otherwise load_state_dict raises on shape mismatches.

        Args:
            v11_state_dict: state_dict of a full v11.0 model.
            verbose: print a short summary.

        Returns:
            dict with share_factor, layer_keys_copied, non_layer_keys_copied,
            missing_params (trainable params of this model not found in the
            ckpt) and unexpected (ckpt keys this model has no slot for).

        Raises:
            ValueError: unless this model uses the legacy gelu + layernorm
                variants (needed for shape compatibility with v11.0 blocks).
        """
        mv = getattr(self.cfg, "mlp_variant", "gelu")
        nv = getattr(self.cfg, "norm_variant", "layernorm")
        if mv != "gelu" or nv != "layernorm":
            raise ValueError(
                "warm_start_slice_from_v11 requires mlp_variant=gelu, "
                "norm_variant=layernorm for shape compatibility with v11.0 ckpt. "
                f"Got mlp_variant={mv}, norm_variant={nv}."
            )
        share = max(1, int(getattr(self.cfg, "share_factor", 1)))

        # Source -> target layer-index maps. v11.0 trunk: (n_layers - 4)
        # shared + 4 reg + 4 cls blocks; pick every share-th source block.
        v11_n_branch = self._N_BRANCH_LAYERS
        v11_n_shared = self.cfg.n_layers - v11_n_branch
        shared_src = list(range(0, v11_n_shared, share))[: v11_n_shared // share]
        branch_src = list(range(0, v11_n_branch, share))[: v11_n_branch // share]
        branch_map = {src: tgt for tgt, src in enumerate(branch_src)}
        index_maps = {
            "shared_layers": {src: tgt for tgt, src in enumerate(shared_src)},
            "reg_layers": branch_map,
            "cls_layers": branch_map,
        }

        new_state = {}
        layer_keys_copied = 0
        non_layer_keys_copied = 0
        for k, v in v11_state_dict.items():
            # Layer keys look like "<stack>.<idx>.<rest>".
            prefix, sep, rest = k.partition(".")
            index_map = index_maps.get(prefix) if sep else None
            if index_map is None:
                # Non-layer weights copy verbatim.
                new_state[k] = v
                non_layer_keys_copied += 1
                continue
            idx_str, _, tail = rest.partition(".")
            tgt_idx = index_map.get(int(idx_str))
            if tgt_idx is not None:  # source blocks not in the slice are dropped
                new_state[f"{prefix}.{tgt_idx}.{tail}"] = v
                layer_keys_copied += 1

        result = self.load_state_dict(new_state, strict=False)
        param_names = {n for n, _ in self.named_parameters()}
        missing_params = [k for k in result.missing_keys if k in param_names]
        if verbose:
            print(f"[v11.warm_start_slice] share_factor={share}, slice indices: "
                  f"shared={shared_src}, branch={branch_src}")
            print(f"  copied {layer_keys_copied} layer-keys + {non_layer_keys_copied} non-layer keys")
            if missing_params:
                print(f"  WARN: {len(missing_params)} trainable params unmatched: "
                      f"{missing_params[:5]}{'...' if len(missing_params) > 5 else ''}")
            if result.unexpected_keys:
                print(f"  ignored {len(result.unexpected_keys)} unexpected keys (e.g., v11.0 layers we didn't slice)")
        return {
            "share_factor": share,
            "layer_keys_copied": layer_keys_copied,
            "non_layer_keys_copied": non_layer_keys_copied,
            "missing_params": len(missing_params),
            "unexpected": len(result.unexpected_keys),
        }
def count_params(model: nn.Module) -> int:
    """Return the total number of scalar entries in *model*'s trainable
    (requires_grad=True) parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# ─── self-test: forward pass shapes + warm-start sanity ───────────────────────
if __name__ == "__main__":
    # Smoke test: forward shapes for both task types, v8 warm-start filtering,
    # row-padding invariance, and share_factor slicing from a full model.
    torch.manual_seed(0)
    cfg = V11Config()
    model = PredictLMv11(cfg)
    print(f"v11 model: {count_params(model)/1e6:.1f}M params (cfg={cfg})")
    B, n_ctx, n_q, n_f = 2, 64, 16, 8
    X_ctx = torch.randn(B, n_ctx, n_f)
    y_ctx = torch.randn(B, n_ctx)
    X_q = torch.randn(B, n_q, n_f)
    feat_mask = torch.zeros(B, n_f, dtype=torch.bool)

    # Regression path
    reg_logits = model(X_ctx, y_ctx, X_q, feat_mask, task_type="regression")
    print(f"[reg] logits shape: {tuple(reg_logits.shape)} (expected (2,16,1024))")
    assert reg_logits.shape == (B, n_q, cfg.n_bins)
    loss = bar_distribution_loss(reg_logits, y_ctx[:, :n_q], model.reg_head)
    print(f"[reg] uniform-prior loss: {loss.item():.3f} (≈ ln(1024) = 6.93)")

    # Classification path
    y_ctx_cls = torch.randint(0, 5, (B, n_ctx))
    cls_logits = model(X_ctx, y_ctx_cls, X_q, feat_mask, task_type="classification")
    print(f"[cls] logits shape: {tuple(cls_logits.shape)} (expected (2,16,10))")
    assert cls_logits.shape == (B, n_q, cfg.max_classes)
    n_classes_per_task = torch.tensor([3, 5])
    y_q_cls = torch.stack([
        torch.randint(0, 3, (n_q,)),
        torch.randint(0, 5, (n_q,)),
    ])
    loss_c = cls_masked_loss(cls_logits, y_q_cls, n_classes_per_task)
    print(f"[cls] masked loss: {loss_c.item():.3f}")

    # Warm-start dry run: simulate a v8 ckpt with wrong-shape heads
    fake_v8_ckpt = {k: v.clone() for k, v in model.state_dict().items()
                    if not k.startswith("reg_head.") and not k.startswith("cls_head.")}
    fake_v8_ckpt["reg_head.weight"] = torch.zeros(2, cfg.d_model)  # v8 shape
    fake_v8_ckpt["reg_head.bias"] = torch.zeros(2)
    fake_v8_ckpt["cls_head.weight"] = torch.zeros(cfg.max_classes, cfg.d_model)
    fake_v8_ckpt["cls_head.bias"] = torch.zeros(cfg.max_classes)
    fresh = PredictLMv11(cfg)
    info = fresh.warm_start_from_v8(fake_v8_ckpt)
    print(f"[warm-start] loaded={info['loaded']}, missing={info['missing']}, unexpected={info['unexpected']}")
    assert info['unexpected'] == 0, "v8 reg/cls heads should be filtered, got unexpected"

    # Row-padding path: padded context rows are blocked as attention keys,
    # so rewriting their contents must not change any query logit.
    model.eval()
    n_pad = 8
    ctx_pad = torch.zeros(B, n_ctx, dtype=torch.bool)
    ctx_pad[:, -n_pad:] = True
    X_ctx_alt = X_ctx.clone()
    X_ctx_alt[:, -n_pad:] = torch.randn(B, n_pad, n_f) * 10.0
    y_ctx_alt = y_ctx.clone()
    y_ctx_alt[:, -n_pad:] = torch.randn(B, n_pad) * 10.0
    with torch.no_grad():
        logits_a = model(X_ctx, y_ctx, X_q, feat_mask, task_type="regression",
                         ctx_row_mask=ctx_pad)
        logits_b = model(X_ctx_alt, y_ctx_alt, X_q, feat_mask, task_type="regression",
                         ctx_row_mask=ctx_pad)
    assert torch.allclose(logits_a, logits_b, atol=1e-5), "padded context rows leaked into query logits"
    print("[row-mask] padded context rows do not affect query logits")

    # share_factor=2 student sliced from the full model: every trainable
    # parameter must be covered, and the forward pass must still run.
    student = PredictLMv11(V11Config(share_factor=2))
    s_info = student.warm_start_slice_from_v11(model.state_dict(), verbose=False)
    assert s_info["missing_params"] == 0, f"unmatched student params: {s_info['missing_params']}"
    student.eval()
    with torch.no_grad():
        s_logits = student(X_ctx, y_ctx, X_q, feat_mask, task_type="regression")
    assert s_logits.shape == (B, n_q, cfg.n_bins)
    print(f"[share_factor=2] {count_params(student)/1e6:.1f}M params, "
          f"copied {s_info['layer_keys_copied']} layer keys")

    print("[OK] v11 model self-test passed")