"""
v11 model — same trunk as v8 so we can warm-start from v8's final checkpoint.

The architecture differences vs v8 are the prediction heads:

    v8:  reg_head = Linear(d_model, 2)            # mean, log_var
    v8:  cls_head = Linear(d_model, max_classes)
    v11: reg_head = BarDistributionHead(d_model, n_bins=1024)
    v11: cls_head = BinClassificationHead(d_model, max_classes=10)

Everything else (feature_weights, y_embed, class_embed, type_embed,
shared_layers, reg_layers, cls_layers, *_norm) keeps the same module names,
so the trunk can be loaded from a v8 checkpoint while only the heads stay
randomly initialized. Use

    v11_model.warm_start_from_v8(v8_ckpt)

rather than a bare ``load_state_dict(v8_ckpt, strict=False)``: ``strict=False``
only tolerates missing/unexpected keys, not shape mismatches, and the
per-feature tables differ in row count (v8: 500, v11 default: 128). The helper
slices them and drops the old head / log_var keys first.

The v11 trainer's head-warmup phase trains only the heads + reg_norm /
cls_norm for the first 5k steps, exactly as v10 did.

Tokenization is identical to v8: 2D grid [B, n_rows, n_cols, d_model] with one
token per cell. Each layer alternates feature-attention (within a row) and
datapoint-attention (within a column with the context-vs-query mask).

For now, v11 SKIPS v8's metadata conditioning (the column-statistics encoder).
The v11 plan defers architectural cleanups to v13; the goal here is data-prior
work, not arch work. Once warm-started, the metadata-related parameters in the
v8 ckpt are simply ignored.
"""
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as grad_checkpoint

from .heads import (
    BarDistributionHead,
    BinClassificationHead,
    bar_distribution_loss,
    cls_masked_loss,
    standardize_y_per_task,
    decode_bar_distribution,
    cls_predict,
)

# Number of task-specific layers in each branch (reg / cls). The remaining
# `n_layers - _N_BRANCH_LAYERS` layers form the shared trunk.
_N_BRANCH_LAYERS = 4


# ─── config ──────────────────────────────────────────────────────────────────

@dataclass
class V11Config:
    """Hyper-parameters for :class:`PredictLMv11`."""

    d_model: int = 256
    n_layers: int = 12  # 8 shared + 4 task-specific per branch
    n_heads: int = 8
    d_ffn: int = 1024
    dropout: float = 0.0
    max_features: int = 128  # warm-start slices v8's feature_weights[500] → [128] in warm_start_from_v8
    max_classes: int = 10
    max_context: int = 1024
    max_query: int = 256
    n_periodic_freqs: int = 8
    n_bins: int = 1024
    cls_label_smoothing: float = 0.05
    # v11.0.6-tiny architecture toggles. Defaults preserve v11.0 behavior so
    # existing ckpts load unchanged via warm_start_from_v8 / strict=False.
    mlp_variant: str = "gelu"  # "gelu" (legacy) or "swiglu"
    norm_variant: str = "layernorm"  # "layernorm" (legacy) or "rmsnorm"
    # ALBERT-style cross-layer parameter sharing. share_factor>1 means the
    # `n_layers`-deep stack uses only `n_layers // share_factor` UNIQUE
    # modules; each unique block is applied `share_factor` times via index
    # cycling. share_factor=1 = legacy (no sharing).
    share_factor: int = 1


def v11_default_config() -> V11Config:
    """Return a config with all defaults."""
    return V11Config()


# ─── v11.0.6-tiny blocks (drop-in upgrades behind config flag) ──────────────

class SwiGLUFFN(nn.Module):
    """SwiGLU MLP (Shazeer 2020, arXiv 2002.05202). Default in PaLM/LLaMA.

    Pattern: Linear(d, h) gate + Linear(d, h) value, silu(gate) * value,
    Linear(h, d). The hidden width h is (2/3) * d_ffn so that the three
    weight matrices together match the weight count of the legacy GELU FFN
    (Linear(d, d_ffn), GELU, Linear(d_ffn, d)). All three linears are
    bias-free.
    """

    def __init__(self, d_model: int, d_ffn: int):
        super().__init__()
        # Legacy FFN holds 2 * d_model * d_ffn weights; SwiGLU has three
        # d_model * d_hidden matrices, so d_hidden = (2/3) * d_ffn for parity.
        d_hidden = int(round(d_ffn * 2 / 3))
        self.w_gate = nn.Linear(d_model, d_hidden, bias=False)
        self.w_value = nn.Linear(d_model, d_hidden, bias=False)
        self.w_out = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_out(F.silu(self.w_gate(x)) * self.w_value(x))


class RMSNorm(nn.Module):
    """Root Mean Square Layer Norm (Zhang & Sennrich 2019). LLaMA default.

    No mean subtraction, no learned bias. Cheaper than LayerNorm; works as a
    drop-in for transformer pre-norm.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)


def _build_ffn(d_model: int, d_ffn: int, variant: str = "gelu") -> nn.Module:
    """Factory: return SwiGLU MLP for "swiglu", else the legacy GELU MLP."""
    if variant == "swiglu":
        return SwiGLUFFN(d_model, d_ffn)
    return nn.Sequential(
        nn.Linear(d_model, d_ffn),
        nn.GELU(),
        nn.Linear(d_ffn, d_model),
    )


def _build_norm(d_model: int, variant: str = "layernorm") -> nn.Module:
    """Factory: return RMSNorm for "rmsnorm", else the legacy LayerNorm."""
    if variant == "rmsnorm":
        return RMSNorm(d_model)
    return nn.LayerNorm(d_model)


# ─── blocks (verbatim from v8 so state_dict keys match) ───────────────────────

class FlashPreLNAttention(nn.Module):
    """Pre-LN attention + FFN using F.scaled_dot_product_attention (Flash).

    NOTE: `dropout` is accepted for signature compatibility but is not
    applied anywhere in this block (attention runs with dropout_p=0.0).
    """

    def __init__(self, d_model: int, n_heads: int, d_ffn: int, dropout: float = 0.0,
                 mlp_variant: str = "gelu", norm_variant: str = "layernorm"):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.d_model = d_model

        self.norm1 = _build_norm(d_model, norm_variant)
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)
        self.norm2 = _build_norm(d_model, norm_variant)
        self.ffn = _build_ffn(d_model, d_ffn, mlp_variant)

    def _heads(self, x: torch.Tensor) -> torch.Tensor:
        """[B, S, d_model] → [B, n_heads, S, head_dim]."""
        B, S, _ = x.shape
        return x.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply attention then FFN, each with a residual connection.

        Args:
            x: [B, S, d_model] token sequence.
            key_padding_mask: optional bool [B, S]; True marks keys to block.
            attn_mask: optional bool [S, S] (shared across the batch) or
                [B, S, S]; True marks (query, key) pairs to block.
        """
        residual = x
        x = self.norm1(x)
        q = self._heads(self.q_proj(x))
        k = self._heads(self.k_proj(x))
        v = self._heads(self.v_proj(x))

        # Both bool masks are converted to additive float masks (0 / -inf)
        # so they can simply be summed before being handed to SDPA.
        sdpa_mask = None
        if attn_mask is not None:
            additive = torch.zeros_like(attn_mask, dtype=q.dtype)
            additive.masked_fill_(attn_mask, float("-inf"))
            if attn_mask.dim() == 2:
                sdpa_mask = additive.unsqueeze(0).unsqueeze(0)  # [1,1,seq,seq]
            else:
                sdpa_mask = additive.unsqueeze(1)  # [B,1,seq,seq]

        if key_padding_mask is not None:
            pad_mask = torch.zeros(
                key_padding_mask.shape[0], 1, 1, key_padding_mask.shape[1],
                dtype=q.dtype, device=q.device,
            )
            pad_mask.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"))
            sdpa_mask = pad_mask if sdpa_mask is None else sdpa_mask + pad_mask

        attn_out = F.scaled_dot_product_attention(q, k, v, attn_mask=sdpa_mask, dropout_p=0.0)
        attn_out = attn_out.transpose(1, 2).contiguous().view(x.shape[0], x.shape[1], self.d_model)
        x = self.o_proj(attn_out) + residual

        residual = x
        x = self.norm2(x)
        x = self.ffn(x) + residual
        return x


class AlternatingLayerV8(nn.Module):
    """Feature attention (within rows) → Datapoint attention (within cols).

    Name matches v8 verbatim so state_dict keys align for warm-start.
    """

    def __init__(self, d_model: int, n_heads: int, d_ffn: int, dropout: float = 0.0,
                 mlp_variant: str = "gelu", norm_variant: str = "layernorm"):
        super().__init__()
        self.feature_attn = FlashPreLNAttention(
            d_model, n_heads, d_ffn, dropout,
            mlp_variant=mlp_variant, norm_variant=norm_variant,
        )
        self.datapoint_attn = FlashPreLNAttention(
            d_model, n_heads, d_ffn, dropout,
            mlp_variant=mlp_variant, norm_variant=norm_variant,
        )

    def forward(
        self,
        x: torch.Tensor,  # [B, n_rows, n_cols, d_model]
        feature_pad_mask: torch.Tensor,  # [B, n_cols] bool
        datapoint_mask: torch.Tensor,  # [n_rows, n_rows] OR [B, n_rows, n_rows]
    ) -> torch.Tensor:
        B, n_rows, n_cols, d_model = x.shape

        # within-row feature attn: every row becomes its own sequence of cols
        x_feat = x.reshape(B * n_rows, n_cols, d_model)
        feat_pad = (
            feature_pad_mask.unsqueeze(1)
            .expand(B, n_rows, n_cols)
            .reshape(B * n_rows, n_cols)
        )
        x_feat = self.feature_attn(x_feat, key_padding_mask=feat_pad)
        x = x_feat.reshape(B, n_rows, n_cols, d_model)

        # within-col datapoint attn — expand per-batch mask along n_cols if needed
        x_data = x.permute(0, 2, 1, 3).reshape(B * n_cols, n_rows, d_model)
        if datapoint_mask.dim() == 3:
            # [B, n_rows, n_rows] → [B*n_cols, n_rows, n_rows]
            dp_mask = (
                datapoint_mask.unsqueeze(1)
                .expand(B, n_cols, n_rows, n_rows)
                .reshape(B * n_cols, n_rows, n_rows)
            )
        else:
            dp_mask = datapoint_mask
        x_data = self.datapoint_attn(x_data, attn_mask=dp_mask)
        x = x_data.reshape(B, n_cols, n_rows, d_model).permute(0, 2, 1, 3)
        return x


# ─── numerical-value embedding (matches v8's NumericalFeatureEmbedding) ──────

class NumericalFeatureEmbedding(nn.Module):
    """Embed a scalar numerical value into a d_model vector via Fourier features."""

    def __init__(self, d_model: int = 256, n_freqs: int = 8):
        super().__init__()
        self.d_model = d_model
        self.n_freqs = n_freqs
        freqs = 2.0 ** torch.arange(n_freqs, dtype=torch.float32)
        self.register_buffer("freqs", freqs)
        in_dim = 1 + 1 + 2 * n_freqs  # sign + log_mag + sin/cos at each freq
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )
        self.missing_token = nn.Parameter(torch.randn(d_model) * 0.02)

    def forward(self, values: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Embed `values` ([...]) to [..., d_model].

        Where `mask` (bool, same shape as `values`) is True, the learned
        `missing_token` replaces the computed embedding.
        """
        sign = torch.sign(values)
        log_mag = torch.log1p(torch.abs(values))
        # Sinusoidal features at multiple frequencies
        f = self.freqs.to(values.device).view(*([1] * (values.dim() - 1)), self.n_freqs)
        scaled = values.unsqueeze(-1) * f
        sins = torch.sin(scaled)
        coss = torch.cos(scaled)
        feats = torch.cat([sign.unsqueeze(-1), log_mag.unsqueeze(-1), sins, coss], dim=-1)
        emb = self.mlp(feats)
        if mask is not None:
            emb = torch.where(mask.unsqueeze(-1), self.missing_token.expand_as(emb), emb)
        return emb


# ─── main v11 model ──────────────────────────────────────────────────────────

@dataclass
class V11Output:
    """Single forward pass output."""

    reg_logits: Optional[torch.Tensor] = None  # [B, n_query, n_bins] for reg
    cls_logits: Optional[torch.Tensor] = None  # [B, n_query, max_classes] for cls
    y_mean: Optional[torch.Tensor] = None  # [B] context y mean (reg only)
    y_std: Optional[torch.Tensor] = None  # [B] context y std (reg only)


class PredictLMv11(nn.Module):
    """
    v11 model: same trunk as v8, new heads.

    Forward returns either reg_logits (for regression) or cls_logits (for
    classification). For mixed-batch joint training, the trainer should call
    the model twice — once with task_type='regression' and once with
    task_type='classification' — sharing the trunk pass via gradient
    accumulation. (Per-batch-element task_type would require padding to a
    max-class shape and we keep it simple.)

    State-dict keys match v8's PredictLMv8 exactly EXCEPT:
      - reg_head (Linear → BarDistributionHead.mlp)
      - cls_head (Linear → BinClassificationHead.mlp)
    All other keys load via load_state_dict(strict=False).
    """

    def __init__(self, cfg: Optional[V11Config] = None):
        super().__init__()
        cfg = cfg or v11_default_config()
        self.cfg = cfg

        # Toggle gradient checkpointing. Default True (memory-conservative,
        # for H100/T4 sized batches). On A100 80GB we can disable for ~2-3×
        # throughput when memory permits. Set via `model.use_grad_checkpoint = False`.
        self.use_grad_checkpoint = True

        # Per-feature projection (same as v8)
        self.feature_weights = nn.Parameter(torch.randn(cfg.max_features, cfg.d_model) * 0.02)
        self.feature_biases = nn.Parameter(torch.zeros(cfg.max_features, cfg.d_model))

        # y embeddings
        self.y_embed = NumericalFeatureEmbedding(cfg.d_model, n_freqs=cfg.n_periodic_freqs)
        self.class_embed = nn.Embedding(cfg.max_classes, cfg.d_model)
        nn.init.normal_(self.class_embed.weight, std=0.02)

        # tokens
        self.query_token = nn.Parameter(torch.randn(cfg.d_model) * 0.02)
        self.type_embed = nn.Embedding(2, cfg.d_model)
        nn.init.normal_(self.type_embed.weight, std=0.02)
        self.col_type_embed = nn.Embedding(2, cfg.d_model)
        nn.init.normal_(self.col_type_embed.weight, std=0.02)

        # trunk: 8 shared + 4 reg + 4 cls
        # v11.0.6-tiny: variant flags flow through to FFN/norm choice; defaults
        # preserve v11.0 layout for backward-compat with existing ckpts.
        # getattr keeps older config objects (without these fields) working.
        mv = getattr(cfg, "mlp_variant", "gelu")
        nv = getattr(cfg, "norm_variant", "layernorm")
        share = max(1, int(getattr(cfg, "share_factor", 1)))

        def _make_layer() -> AlternatingLayerV8:
            return AlternatingLayerV8(
                cfg.d_model, cfg.n_heads, cfg.d_ffn, cfg.dropout,
                mlp_variant=mv, norm_variant=nv,
            )

        n_shared = cfg.n_layers - _N_BRANCH_LAYERS
        # Under share_factor>1, build only n//share unique blocks; the
        # forward pass cycles through them. n_shared and the branch depth
        # must both be divisible by share_factor.
        if n_shared % share != 0 or _N_BRANCH_LAYERS % share != 0:
            raise ValueError(
                f"share_factor={share} must divide both n_shared={n_shared} "
                f"and {_N_BRANCH_LAYERS} (branch layers)"
            )
        n_shared_unique = n_shared // share
        n_branch_unique = _N_BRANCH_LAYERS // share
        self.shared_layers = nn.ModuleList([_make_layer() for _ in range(n_shared_unique)])
        self.reg_layers = nn.ModuleList([_make_layer() for _ in range(n_branch_unique)])
        self.cls_layers = nn.ModuleList([_make_layer() for _ in range(n_branch_unique)])
        self.shared_norm = _build_norm(cfg.d_model, nv)
        self.reg_norm = _build_norm(cfg.d_model, nv)
        self.cls_norm = _build_norm(cfg.d_model, nv)
        # Stored for forward to know how many depth-passes to do.
        self.effective_n_shared = n_shared
        self.effective_n_branch = _N_BRANCH_LAYERS

        # v11 heads
        self.reg_head = BarDistributionHead(
            d_model=cfg.d_model, n_bins=cfg.n_bins, dropout=cfg.dropout,
        )
        self.cls_head = BinClassificationHead(
            d_model=cfg.d_model, max_classes=cfg.max_classes, dropout=cfg.dropout,
        )

        # NOTE: v8's `log_var_reg` / `log_var_cls` Kendall-style task weights
        # are intentionally NOT instantiated here. They were declared but
        # never read in the v11 trainer, and ratio-balancing reg/cls via
        # alternation + curriculum bias is sufficient at this scale per
        # Expert 4. If they appear in a v8 checkpoint, `warm_start_from_v8`
        # drops them explicitly (see its `skip_prefixes`) before loading.

        self._init_weights()

    def _init_weights(self):
        """Xavier-init every nn.Linear weight in the model and zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    # ──────────────────────────────────────────────────────────────
    # Internal: build the [B, n_rows, n_cols, d_model] grid
    # ──────────────────────────────────────────────────────────────
    @staticmethod
    def _query_block_mask(n_ctx: int, n_query: int, device: torch.device) -> torch.Tensor:
        """Return the bool [n_rows, n_rows] mask (True = blocked).

        Query rows may attend to every context row and to themselves, but
        not to other query rows. Context rows are unrestricted.
        """
        n_rows = n_ctx + n_query
        mask = torch.zeros(n_rows, n_rows, dtype=torch.bool, device=device)
        # Block the whole query↔query square except its diagonal, in one op.
        mask[n_ctx:, n_ctx:] = ~torch.eye(n_query, dtype=torch.bool, device=device)
        return mask

    def _build_grid(
        self,
        X_ctx: torch.Tensor,  # [B, n_ctx, n_features]
        y_ctx: torch.Tensor,  # [B, n_ctx]
        X_query: torch.Tensor,  # [B, n_query, n_features]
        feature_mask: torch.Tensor,  # [B, n_features] bool, True=padded
        task_type: str,
        ctx_row_mask: Optional[torch.Tensor] = None,  # [B, n_ctx] bool, True=padded
        query_row_mask: Optional[torch.Tensor] = None,  # [B, n_query] bool, True=padded
    ):
        """Tokenize one batch into the cell grid plus its attention masks.

        Returns:
            (grid, feature_pad_mask, datapoint_mask, n_ctx) where grid is
            [B, n_rows, n_cols, d_model] (last column = target),
            feature_pad_mask is bool [B, n_cols], and datapoint_mask is bool
            [n_rows, n_rows] or, when row-pad masks are given,
            [B, n_rows, n_rows]. True always means "blocked".
        """
        B, n_ctx, n_features = X_ctx.shape
        n_query = X_query.shape[1]
        n_rows = n_ctx + n_query
        max_f = self.cfg.max_features
        device = X_ctx.device

        # Effective feature count: the largest number of real (unpadded)
        # features of any batch item, capped at max_features. Only the first
        # n_real columns are kept below.
        if feature_mask.any():
            real_per_item = (~feature_mask).sum(dim=1)
            n_real = min(int(real_per_item.max().item()), max_f)
        else:
            n_real = min(n_features, max_f)
        # NOTE(review): with a single input feature the slice below keeps one
        # column and the projection broadcasts it over two feature slots.
        n_real = max(n_real, 2)
        n_cols = n_real + 1

        X_all = torch.cat([X_ctx, X_query], dim=1)  # [B, n_rows, n_features]
        X_real = X_all[:, :, :n_real]  # [B, n_rows, n_real]

        # Per-feature projection
        feat_grid = (
            X_real.unsqueeze(-1) * self.feature_weights[:n_real]
            + self.feature_biases[:n_real]
        )  # [B, n_rows, n_real, d_model]

        # Target column embedding
        if task_type == "classification":
            y_clamped = y_ctx.long().clamp(0, self.cfg.max_classes - 1)
            y_emb_ctx = self.class_embed(y_clamped)  # [B, n_ctx, d_model]
        else:
            y_emb_ctx = self.y_embed(y_ctx.float())  # [B, n_ctx, d_model]
        y_emb_q = self.query_token.unsqueeze(0).unsqueeze(0).expand(B, n_query, -1)
        y_emb = torch.cat([y_emb_ctx, y_emb_q], dim=1).unsqueeze(2)  # [B, n_rows, 1, d_model]

        grid = torch.cat([feat_grid, y_emb], dim=2)  # [B, n_rows, n_cols, d_model]

        # Type (ctx vs query) and column-type (feature vs target) embeds
        type_ids = torch.zeros(B, n_rows, dtype=torch.long, device=device)
        type_ids[:, n_ctx:] = 1
        grid = grid + self.type_embed(type_ids).unsqueeze(2)
        col_types = torch.zeros(n_cols, dtype=torch.long, device=device)
        col_types[-1] = 1
        grid = grid + self.col_type_embed(col_types).unsqueeze(0).unsqueeze(0)

        # Feature-pad mask (the target column is never padded)
        feature_pad_mask = torch.zeros(B, n_cols, dtype=torch.bool, device=device)
        if feature_mask.shape[1] >= n_real:
            feature_pad_mask[:, :n_real] = feature_mask[:, :n_real]

        # Datapoint mask: query rows can't attend to other query rows (they each
        # predict independently). If ctx_row_mask / query_row_mask are provided,
        # padded rows are also blocked from being keys (per-batch [B, n_rows, n_rows]).
        # Without row-pad masks, use the simple [n_rows, n_rows] shared mask.
        datapoint_mask = self._query_block_mask(n_ctx, n_query, device)
        if ctx_row_mask is not None or query_row_mask is not None:
            row_pad = torch.zeros(B, n_rows, dtype=torch.bool, device=device)
            if ctx_row_mask is not None:
                row_pad[:, :n_ctx] = ctx_row_mask
            if query_row_mask is not None:
                row_pad[:, n_ctx:] = query_row_mask
            # Block any KEY row that is padded (broadcast over queries):
            # [1, n_rows, n_rows] | [B, 1, n_rows] → [B, n_rows, n_rows].
            datapoint_mask = datapoint_mask.unsqueeze(0) | row_pad.unsqueeze(1)

        return grid, feature_pad_mask, datapoint_mask, n_ctx

    # ──────────────────────────────────────────────────────────────
    # Forward
    # ──────────────────────────────────────────────────────────────
    def _run_stack(
        self,
        layers: nn.ModuleList,
        depth: int,
        h: torch.Tensor,
        feat_pad: torch.Tensor,
        dp_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Apply `depth` layer passes, cycling through the unique `layers`.

        Under share_factor>1, len(layers) may be < depth; blocks are reused
        via modulo index (ALBERT pattern). Gradient checkpointing is used
        only while training with grad enabled and `use_grad_checkpoint` set.
        """
        n_unique = len(layers)
        for i in range(depth):
            layer = layers[i % n_unique]
            if self.training and torch.is_grad_enabled() and self.use_grad_checkpoint:
                h = grad_checkpoint(layer, h, feat_pad, dp_mask, use_reentrant=False)
            else:
                h = layer(h, feat_pad, dp_mask)
        return h

    def forward(
        self,
        X_ctx: torch.Tensor,
        y_ctx: torch.Tensor,
        X_query: torch.Tensor,
        feature_mask: torch.Tensor,
        task_type: str = "regression",
        ctx_row_mask: Optional[torch.Tensor] = None,
        query_row_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Returns logits over bins (reg) or classes (cls).

        For regression, the trainer is responsible for calling
        `standardize_y_per_task(y_ctx_orig)` BEFORE this forward to obtain the
        standardized y_ctx (and stash mean/std for un-standardization).

        Optional ctx_row_mask / query_row_mask (bool, True=padded row) block
        padded rows from attention as keys, preventing zero-padded
        fake-context contamination.

        Any `task_type` other than "regression" takes the classification
        branch.
        """
        grid, feat_pad, dp_mask, n_ctx = self._build_grid(
            X_ctx, y_ctx, X_query, feature_mask, task_type,
            ctx_row_mask=ctx_row_mask, query_row_mask=query_row_mask,
        )

        # Shared trunk.
        grid = self._run_stack(
            self.shared_layers, self.effective_n_shared, grid, feat_pad, dp_mask,
        )
        grid = self.shared_norm(grid)

        # Task-specific layers. The classification branch gets symmetric grad
        # flow with the reg path. Earlier versions had
        # `h = 0.5*grid + 0.5*grid.detach()` for cls, which halved the cls
        # branch's gradient into the shared trunk while the reg branch passed
        # full gradient. Combined with bar-dist reg loss being ~3× larger by
        # magnitude than cls (ln(1024) vs ln(10)) and 50/50 step alternation,
        # the trunk was receiving ~6× more reg signal than cls signal per
        # step. Removed.
        if task_type == "regression":
            branch_layers, branch_norm, head = self.reg_layers, self.reg_norm, self.reg_head
        else:
            branch_layers, branch_norm, head = self.cls_layers, self.cls_norm, self.cls_head

        h = self._run_stack(
            branch_layers, self.effective_n_branch, grid, feat_pad, dp_mask,
        )
        h = branch_norm(h)
        # Read out the target-column token of every query row.
        query_target = h[:, n_ctx:, -1, :]  # [B, n_query, d_model]
        return head(query_target)  # [B, n_query, n_bins] or [B, n_query, max_classes]

    # ──────────────────────────────────────────────────────────────
    # Convenience: warm-start from v8 checkpoint
    # ──────────────────────────────────────────────────────────────
    @torch.no_grad()
    def warm_start_from_v8(self, v8_state_dict: dict, verbose: bool = True) -> dict:
        """Load v8 trunk weights, leave heads at random init.

        Args:
            v8_state_dict: a v8 checkpoint's state_dict
            verbose: print a short load summary

        Returns:
            dict with `loaded`, `missing`, `unexpected` key counts. `loaded`
            counts only keys that were actually copied into this model, i.e.
            it excludes checkpoint keys reported as unexpected.
        """
        # Filter out v8's old reg_head / cls_head (shape-incompatible) and
        # the dead log_var weights (removed in v11).
        skip_prefixes = ("reg_head.", "cls_head.", "log_var_reg", "log_var_cls")
        filtered = {
            k: v for k, v in v8_state_dict.items()
            if not k.startswith(skip_prefixes)
        }

        # Slice feature_weights / feature_biases if v8 ckpt has more features
        # than v11's max_features (v8 used 500, v11 default 128 for VRAM).
        # Keep the first N rows (v8 trained on tasks that primarily used the
        # earliest column slots).
        target_max = self.cfg.max_features
        for k in ("feature_weights", "feature_biases"):
            if k in filtered and filtered[k].shape[0] > target_max:
                filtered[k] = filtered[k][:target_max]

        result = self.load_state_dict(filtered, strict=False)
        # Keys the model does not know (e.g. v8-only modules) were offered
        # but not loaded, so they must not be counted as loaded.
        n_loaded = len(filtered) - len(result.unexpected_keys)
        if verbose:
            print(f"[v11.warm_start_from_v8] loaded {n_loaded} keys")
            if result.missing_keys:
                print(f" missing ({len(result.missing_keys)}): {result.missing_keys[:5]}…")
            if result.unexpected_keys:
                print(f" unexpected ({len(result.unexpected_keys)}): {result.unexpected_keys[:5]}…")
        return {
            "loaded": n_loaded,
            "missing": len(result.missing_keys),
            "unexpected": len(result.unexpected_keys),
        }

    @torch.no_grad()
    def warm_start_slice_from_v11(self, v11_state_dict: dict, verbose: bool = True) -> dict:
        """Initialize this (smaller) model from a v11.0 ckpt by SLICING layers.

        Used when this model has `share_factor > 1`: the v11.0 trunk has
        `n_layers` unique blocks, but this model has only
        `n_layers / share_factor` unique blocks (each used `share_factor`
        times via cycling). We copy every-`share_factor`-th v11.0 block into
        the student's unique-blocks list.

        Non-layer modules (feature_weights, y_embed, class_embed, query_token,
        col_type_embed, shared_norm/reg_norm/cls_norm, reg_head, cls_head)
        copy verbatim — they're share-factor-independent.

        Requires this model use legacy (gelu + layernorm) MLP/norm variants
        for the layer slicing to be shape-compatible.

        Raises:
            ValueError: if this model's mlp_variant / norm_variant are not
                the legacy ones.

        Returns:
            dict with `share_factor`, `layer_keys_copied`,
            `non_layer_keys_copied`, `missing_params`, `unexpected` counts.
        """
        if self.cfg.mlp_variant != "gelu" or self.cfg.norm_variant != "layernorm":
            raise ValueError(
                "warm_start_slice_from_v11 requires mlp_variant=gelu, "
                "norm_variant=layernorm for shape compatibility with v11.0 ckpt. "
                f"Got mlp_variant={self.cfg.mlp_variant}, norm_variant={self.cfg.norm_variant}."
            )
        share = max(1, int(self.cfg.share_factor))

        # Build the source→target index map for layer slicing.
        # v11.0 trunk: 8 shared + 4 reg + 4 cls
        v11_n_shared = self.cfg.n_layers - _N_BRANCH_LAYERS  # 8 typically
        v11_n_branch = _N_BRANCH_LAYERS
        # Student unique counts
        s_n_shared = v11_n_shared // share
        s_n_branch = v11_n_branch // share
        # Pick every share-th index from v11.0
        shared_src = list(range(0, v11_n_shared, share))[:s_n_shared]
        branch_src = list(range(0, v11_n_branch, share))[:s_n_branch]

        # Per-stack lookup: source layer index → student layer index.
        branch_remap = {src: tgt for tgt, src in enumerate(branch_src)}
        remap_by_stack = {
            "shared_layers": {src: tgt for tgt, src in enumerate(shared_src)},
            "reg_layers": branch_remap,
            "cls_layers": branch_remap,
        }

        new_state = {}
        layer_keys_copied = 0
        non_layer_keys_copied = 0
        for k, v in v11_state_dict.items():
            stack, dot, tail = k.partition(".")
            remap = remap_by_stack.get(stack) if dot else None
            if remap is None:
                # Non-layer weights copy verbatim.
                new_state[k] = v
                non_layer_keys_copied += 1
                continue
            # Layer-keyed weights, k = "<stack>.<idx>.<rest>": rewrite the
            # layer index per the slicing map; unselected layers are dropped.
            idx_str, _, rest = tail.partition(".")
            tgt_idx = remap.get(int(idx_str))
            if tgt_idx is not None:
                new_state[f"{stack}.{tgt_idx}.{rest}"] = v
                layer_keys_copied += 1

        result = self.load_state_dict(new_state, strict=False)
        # Only report missing keys that are parameters (buffers are ignored).
        param_names = {n for n, _ in self.named_parameters()}
        missing_params = [k for k in result.missing_keys if k in param_names]
        if verbose:
            print(f"[v11.warm_start_slice] share_factor={share}, slice indices: "
                  f"shared={shared_src}, branch={branch_src}")
            print(f" copied {layer_keys_copied} layer-keys + {non_layer_keys_copied} non-layer keys")
            if missing_params:
                print(f" WARN: {len(missing_params)} trainable params unmatched: "
                      f"{missing_params[:5]}{'...' if len(missing_params) > 5 else ''}")
            if result.unexpected_keys:
                print(f" ignored {len(result.unexpected_keys)} unexpected keys (e.g., v11.0 layers we didn't slice)")
        return {
            "share_factor": share,
            "layer_keys_copied": layer_keys_copied,
            "non_layer_keys_copied": non_layer_keys_copied,
            "missing_params": len(missing_params),
            "unexpected": len(result.unexpected_keys),
        }


def count_params(model: nn.Module) -> int:
    """Return the number of trainable (requires_grad) scalar parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# ─── self-test: forward pass shapes + warm-start sanity ───────────────────────

def _self_test() -> None:
    """Smoke-test both forward paths and the v8 warm-start filter."""
    torch.manual_seed(0)
    cfg = V11Config()
    model = PredictLMv11(cfg)
    print(f"v11 model: {count_params(model)/1e6:.1f}M params (cfg={cfg})")

    B, n_ctx, n_q, n_f = 2, 64, 16, 8
    X_ctx = torch.randn(B, n_ctx, n_f)
    y_ctx = torch.randn(B, n_ctx)
    X_q = torch.randn(B, n_q, n_f)
    feat_mask = torch.zeros(B, n_f, dtype=torch.bool)

    # Regression path
    reg_logits = model(X_ctx, y_ctx, X_q, feat_mask, task_type="regression")
    print(f"[reg] logits shape: {tuple(reg_logits.shape)} (expected (2,16,1024))")
    assert reg_logits.shape == (B, n_q, cfg.n_bins)
    loss = bar_distribution_loss(reg_logits, y_ctx[:, :n_q], model.reg_head)
    print(f"[reg] uniform-prior loss: {loss.item():.3f} (≈ ln(1024) = 6.93)")

    # Classification path
    y_ctx_cls = torch.randint(0, 5, (B, n_ctx))
    cls_logits = model(X_ctx, y_ctx_cls, X_q, feat_mask, task_type="classification")
    print(f"[cls] logits shape: {tuple(cls_logits.shape)} (expected (2,16,10))")
    assert cls_logits.shape == (B, n_q, cfg.max_classes)
    n_classes_per_task = torch.tensor([3, 5])
    y_q_cls = torch.stack([
        torch.randint(0, 3, (n_q,)),
        torch.randint(0, 5, (n_q,)),
    ])
    loss_c = cls_masked_loss(cls_logits, y_q_cls, n_classes_per_task)
    print(f"[cls] masked loss: {loss_c.item():.3f}")

    # Warm-start dry run: simulate a v8 ckpt with wrong-shape heads
    fake_v8_ckpt = {
        k: v.clone() for k, v in model.state_dict().items()
        if not k.startswith("reg_head.") and not k.startswith("cls_head.")
    }
    fake_v8_ckpt["reg_head.weight"] = torch.zeros(2, cfg.d_model)  # v8 shape
    fake_v8_ckpt["reg_head.bias"] = torch.zeros(2)
    fake_v8_ckpt["cls_head.weight"] = torch.zeros(cfg.max_classes, cfg.d_model)
    fake_v8_ckpt["cls_head.bias"] = torch.zeros(cfg.max_classes)
    fresh = PredictLMv11(cfg)
    info = fresh.warm_start_from_v8(fake_v8_ckpt)
    print(f"[warm-start] loaded={info['loaded']}, missing={info['missing']}, unexpected={info['unexpected']}")
    assert info['unexpected'] == 0, "v8 reg/cls heads should be filtered, got unexpected"

    print("[OK] v11 model self-test passed")


if __name__ == "__main__":
    _self_test()