| """ |
| v11 model — same trunk as v8 so we can warm-start from v8's final checkpoint. |
| The architecture differences vs v8 are the prediction heads: |
| |
| v8: reg_head = Linear(d_model, 2) # mean, log_var |
| v8: cls_head = Linear(d_model, max_classes) |
| v11: reg_head = BarDistributionHead(d_model, n_bins=1024) |
| v11: cls_head = BinClassificationHead(d_model, max_classes=10) |
| |
| Everything else (feature_weights, y_embed, class_embed, type_embed, |
| shared_layers, reg_layers, cls_layers, *_norm) keeps the same module |
| names and parameter shapes, so: |
| |
| v11_model.load_state_dict(v8_ckpt, strict=False) |
| |
| will load the trunk and leave only the heads as randomly-initialized. |
| The v11 trainer's head-warmup phase trains only the heads + reg_norm / |
| cls_norm for the first 5k steps, exactly as v10 did. |
| |
| Tokenization is identical to v8: 2D grid [B, n_rows, n_cols, d_model] |
| with one token per cell. Each layer alternates feature-attention (within |
| a row) and datapoint-attention (within a column with the |
| context-vs-query mask). |
| |
| For now, v11 SKIPS v8's metadata conditioning (the column-statistics |
| encoder). The v11 plan defers architectural cleanups to v13; the goal |
| here is data-prior work, not arch work. Once warm-started, the |
| metadata-related parameters in the v8 ckpt are simply ignored. |
| """ |
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.checkpoint import checkpoint as grad_checkpoint |
|
|
| from .heads import ( |
| BarDistributionHead, |
| BinClassificationHead, |
| bar_distribution_loss, |
| cls_masked_loss, |
| standardize_y_per_task, |
| decode_bar_distribution, |
| cls_predict, |
| ) |
|
|
|
|
| |
|
|
|
|
@dataclass
class V11Config:
    """Hyper-parameters for ``PredictLMv11``.

    Field order doubles as the positional-constructor order, so keep it
    stable when adding fields (append new ones at the end).
    """

    # ---- transformer trunk -------------------------------------------
    d_model: int = 256
    n_layers: int = 12  # total depth: (n_layers - 4) shared + 4 per task branch
    n_heads: int = 8
    d_ffn: int = 1024
    dropout: float = 0.0

    # ---- input capacity ----------------------------------------------
    max_features: int = 128  # rows of feature_weights / feature_biases
    max_classes: int = 10    # size of class_embed and the classification head
    max_context: int = 1024
    max_query: int = 256

    # ---- numerical target embedding ----------------------------------
    n_periodic_freqs: int = 8  # Fourier frequencies in NumericalFeatureEmbedding

    # ---- prediction heads --------------------------------------------
    n_bins: int = 1024  # output bins of the regression head
    cls_label_smoothing: float = 0.05

    # ---- block variants ----------------------------------------------
    # "swiglu" selects the SwiGLU MLP, any other value the legacy GELU MLP;
    # "rmsnorm" selects RMSNorm, any other value nn.LayerNorm.
    mlp_variant: str = "gelu"
    norm_variant: str = "layernorm"

    # ---- depth-wise weight sharing -----------------------------------
    # Only n/share_factor unique layers are stored per stack and applied
    # cyclically; must divide both (n_layers - 4) and 4.
    share_factor: int = 1
|
|
|
|
def v11_default_config() -> V11Config:
    """Return a fresh ``V11Config`` with every field at its default value."""
    default_cfg = V11Config()
    return default_cfg
|
|
|
|
| |
|
|
|
|
class SwiGLUFFN(nn.Module):
    """Gated-SiLU feed-forward block (Shazeer 2020, arXiv 2002.05202), as in PaLM/LLaMA.

    out = w_out(silu(w_gate(x)) * w_value(x)), with all three projections bias-free.

    The hidden width is round(2/3 * d_ffn). Three bias-free d_model x hidden
    matrices then hold about as many weights as the legacy GELU MLP's two
    d_model x d_ffn matrices, which keeps the parameter budget comparable.

    Args:
        d_model: input/output width.
        d_ffn: hidden width the legacy GELU MLP would use; scaled by 2/3 here.
    """

    def __init__(self, d_model: int, d_ffn: int):
        super().__init__()
        hidden = int(round(2 * d_ffn / 3))
        # Attribute names define the state_dict keys; keep them stable.
        self.w_gate = nn.Linear(d_model, hidden, bias=False)
        self.w_value = nn.Linear(d_model, hidden, bias=False)
        self.w_out = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w_gate(x))
        value = self.w_value(x)
        return self.w_out(gate * value)
|
|
|
|
class RMSNorm(nn.Module):
    """RMS layer normalization (Zhang & Sennrich 2019), the LLaMA default.

        y = weight * x / sqrt(mean(x ** 2, dim=-1) + eps)

    There is no mean subtraction and no bias, so it is cheaper than LayerNorm
    and can replace it directly in a pre-norm transformer.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))  # per-channel gain, starts at 1
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalized = x * inv_rms
        return self.weight * normalized
|
|
|
|
| def _build_ffn(d_model: int, d_ffn: int, variant: str = "gelu") -> nn.Module: |
| """Factory: return GELU MLP (legacy) or SwiGLU MLP based on variant.""" |
| if variant == "swiglu": |
| return SwiGLUFFN(d_model, d_ffn) |
| return nn.Sequential( |
| nn.Linear(d_model, d_ffn), |
| nn.GELU(), |
| nn.Linear(d_ffn, d_model), |
| ) |
|
|
|
|
| def _build_norm(d_model: int, variant: str = "layernorm") -> nn.Module: |
| """Factory: return LayerNorm (legacy) or RMSNorm based on variant.""" |
| if variant == "rmsnorm": |
| return RMSNorm(d_model) |
| return nn.LayerNorm(d_model) |
|
|
|
|
| |
|
|
|
|
class FlashPreLNAttention(nn.Module):
    """Pre-LN self-attention + FFN block built on F.scaled_dot_product_attention.

        x = x + Drop(o_proj(Attn(norm1(x))))
        x = x + Drop(ffn(norm2(x)))

    Boolean masks use the convention True = "may not attend". They are
    converted to additive 0 / -inf float masks in the activation dtype.

    Args:
        d_model: token width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        d_ffn: FFN hidden width (see ``_build_ffn``).
        dropout: drop probability for the attention weights and for both
            residual branches. Applied only in training mode. This argument
            used to be accepted and silently ignored; with the default 0.0
            the block computes exactly what it did before.
        mlp_variant: forwarded to ``_build_ffn``.
        norm_variant: forwarded to ``_build_norm``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``. The
            per-head ``view`` would otherwise fail at the first forward.
    """

    def __init__(self, d_model: int, n_heads: int, d_ffn: int, dropout: float = 0.0,
                 mlp_variant: str = "gelu", norm_variant: str = "layernorm"):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f"d_model={d_model} must be divisible by n_heads={n_heads}")
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.d_model = d_model
        self.dropout = float(dropout or 0.0)  # plain float: adds no state_dict entry

        # Attribute names (and creation order, for RNG reproducibility) define
        # the state_dict layout; keep them stable for checkpoint compatibility.
        self.norm1 = _build_norm(d_model, norm_variant)
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)

        self.norm2 = _build_norm(d_model, norm_variant)
        self.ffn = _build_ffn(d_model, d_ffn, mlp_variant)

    def _heads(self, x: torch.Tensor) -> torch.Tensor:
        """(B, S, d_model) -> (B, n_heads, S, head_dim)."""
        B, S, _ = x.shape
        return x.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)

    @staticmethod
    def _additive_mask(
        like: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor],
        attn_mask: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Merge boolean masks into one additive float mask for SDPA (or None).

        attn_mask (S, S) becomes (1, 1, S, S); attn_mask (B, S, S) becomes
        (B, 1, S, S); key_padding_mask (B, S) becomes (B, 1, 1, S). Blocked
        entries are -inf and the rest 0, in ``like``'s dtype.
        """
        neg_inf = float("-inf")
        sdpa_mask = None
        if attn_mask is not None:
            additive = torch.zeros_like(attn_mask, dtype=like.dtype)
            additive.masked_fill_(attn_mask, neg_inf)
            if attn_mask.dim() == 2:
                sdpa_mask = additive.unsqueeze(0).unsqueeze(0)
            else:
                sdpa_mask = additive.unsqueeze(1)
        if key_padding_mask is not None:
            pad = torch.zeros(
                key_padding_mask.shape[0], 1, 1, key_padding_mask.shape[1],
                dtype=like.dtype, device=like.device,
            )
            pad.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), neg_inf)
            sdpa_mask = pad if sdpa_mask is None else sdpa_mask + pad
        return sdpa_mask

    def _residual_dropout(self, t: torch.Tensor) -> torch.Tensor:
        """Dropout on a residual branch; identity in eval mode or when p == 0."""
        if self.training and self.dropout > 0.0:
            return F.dropout(t, p=self.dropout, training=True)
        return t

    def forward(
        self,
        x: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the attention sub-block, then the FFN sub-block, each with a residual.

        Args:
            x: (B, S, d_model) tokens.
            key_padding_mask: optional bool (B, S); True hides that key from
                every query.
            attn_mask: optional bool (S, S), shared across the batch, or
                (B, S, S); True blocks query i from attending to key j.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, seq, _ = x.shape
        h = self.norm1(x)
        q = self._heads(self.q_proj(h))
        k = self._heads(self.k_proj(h))
        v = self._heads(self.v_proj(h))

        sdpa_mask = self._additive_mask(q, key_padding_mask, attn_mask)
        attn = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=sdpa_mask,
            dropout_p=self.dropout if self.training else 0.0,
        )
        attn = attn.transpose(1, 2).reshape(batch, seq, self.d_model)
        x = x + self._residual_dropout(self.o_proj(attn))

        x = x + self._residual_dropout(self.ffn(self.norm2(x)))
        return x
|
|
|
|
class AlternatingLayerV8(nn.Module):
    """One trunk layer over the (rows x columns) token grid.

    1. Feature attention: each row attends across its columns. Padded
       feature columns are hidden as keys.
    2. Datapoint attention: each column attends across its rows under the
       datapoint mask.

    The class name and the sub-module names (feature_attn, datapoint_attn)
    match v8 verbatim, so state_dict keys align for warm-starting.
    """

    def __init__(self, d_model: int, n_heads: int, d_ffn: int, dropout: float = 0.0,
                 mlp_variant: str = "gelu", norm_variant: str = "layernorm"):
        super().__init__()
        variant_kwargs = {"mlp_variant": mlp_variant, "norm_variant": norm_variant}
        self.feature_attn = FlashPreLNAttention(d_model, n_heads, d_ffn, dropout, **variant_kwargs)
        self.datapoint_attn = FlashPreLNAttention(d_model, n_heads, d_ffn, dropout, **variant_kwargs)

    def forward(
        self,
        x: torch.Tensor,
        feature_pad_mask: torch.Tensor,
        datapoint_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Run feature attention, then datapoint attention.

        Args:
            x: (B, n_rows, n_cols, d_model) grid.
            feature_pad_mask: bool (B, n_cols); True = padded column.
            datapoint_mask: bool (n_rows, n_rows), shared by all items, or
                (B, n_rows, n_rows) per item; True = blocked.

        Returns:
            Grid with the same shape as ``x``.
        """
        batch, n_rows, n_cols, width = x.shape

        # Feature attention: every row becomes its own sequence over columns.
        rows_as_batch = x.reshape(batch * n_rows, n_cols, width)
        col_pad = (
            feature_pad_mask[:, None, :]
            .expand(batch, n_rows, n_cols)
            .reshape(batch * n_rows, n_cols)
        )
        rows_as_batch = self.feature_attn(rows_as_batch, key_padding_mask=col_pad)
        x = rows_as_batch.reshape(batch, n_rows, n_cols, width)

        # Datapoint attention: every column becomes its own sequence over rows.
        cols_as_batch = x.permute(0, 2, 1, 3).reshape(batch * n_cols, n_rows, width)
        row_mask = datapoint_mask
        if datapoint_mask.dim() == 3:
            # Replicate each item's (n_rows, n_rows) mask for all of its columns.
            row_mask = (
                datapoint_mask[:, None]
                .expand(batch, n_cols, n_rows, n_rows)
                .reshape(batch * n_cols, n_rows, n_rows)
            )
        cols_as_batch = self.datapoint_attn(cols_as_batch, attn_mask=row_mask)
        return cols_as_batch.reshape(batch, n_cols, n_rows, width).permute(0, 2, 1, 3)
|
|
|
|
| |
|
|
|
|
class NumericalFeatureEmbedding(nn.Module):
    """Embed scalar values into d_model vectors via sign / log-magnitude / Fourier features.

    For every scalar v the MLP input is

        [sign(v), log1p(|v|), sin(v * f_0..f_{n-1}), cos(v * f_0..f_{n-1})]

    with f_i = 2**i, i.e. 2 + 2 * n_freqs features. This is mapped through
    Linear -> GELU -> Linear. Positions flagged by ``mask`` (True = missing)
    are replaced by the learned ``missing_token``.
    """

    def __init__(self, d_model: int = 256, n_freqs: int = 8):
        super().__init__()
        self.d_model = d_model
        self.n_freqs = n_freqs
        # Octave-spaced frequencies 1, 2, 4, ...; registered as a buffer so it
        # follows .to(device) and appears in the state_dict.
        freqs = 2.0 ** torch.arange(n_freqs, dtype=torch.float32)
        self.register_buffer("freqs", freqs)
        in_dim = 1 + 1 + 2 * n_freqs  # sign + log-magnitude + sin/cos pairs
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )
        self.missing_token = nn.Parameter(torch.randn(d_model) * 0.02)

    def forward(self, values: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Embed ``values``.

        Args:
            values: float tensor of any shape ``(...)``.
            mask: optional bool tensor broadcastable to ``values``; True marks
                a missing entry whose embedding becomes ``missing_token``.

        Returns:
            Tensor of shape ``(..., d_model)``.
        """
        if mask is not None:
            # Missing entries are often NaN/inf placeholders. torch.where below
            # discards their embedding in the forward pass, but the NaNs would
            # still reach the MLP parameter gradients (0 * NaN = NaN) during
            # backward. Zeroing them first keeps gradients finite without
            # changing the output, since those positions are overwritten anyway.
            values = values.masked_fill(mask, 0.0)
        sign = torch.sign(values)
        log_mag = torch.log1p(torch.abs(values))
        # (..., 1) * (n_freqs,) broadcasts to (..., n_freqs) for any input rank.
        scaled = values.unsqueeze(-1) * self.freqs.to(values.device)
        feats = torch.cat(
            [sign.unsqueeze(-1), log_mag.unsqueeze(-1), torch.sin(scaled), torch.cos(scaled)],
            dim=-1,
        )
        emb = self.mlp(feats)
        if mask is not None:
            emb = torch.where(mask.unsqueeze(-1), self.missing_token.expand_as(emb), emb)
        return emb
|
|
|
|
| |
|
|
|
|
@dataclass
class V11Output:
    """Container for the tensors of a single forward pass; every field defaults to None.

    NOTE(review): y_mean / y_std presumably hold the per-task standardization
    statistics of the regression targets. Confirm against the trainer.
    """

    reg_logits: Optional[torch.Tensor] = None  # regression-head logits
    cls_logits: Optional[torch.Tensor] = None  # classification-head logits
    y_mean: Optional[torch.Tensor] = None
    y_std: Optional[torch.Tensor] = None
|
|
|
|
class PredictLMv11(nn.Module):
    """
    v11 model: same trunk as v8, new heads.

    Forward returns either reg_logits (for regression) or cls_logits (for
    classification). For mixed-batch joint training, the trainer should
    call the model twice — once with task_type='regression' and once with
    task_type='classification' — sharing the trunk pass via gradient
    accumulation. (Per-batch-element task_type would require padding to
    a max-class shape and we keep it simple.)

    State-dict keys match v8's PredictLMv8 exactly EXCEPT:
      - reg_head (Linear → BarDistributionHead.mlp)
      - cls_head (Linear → BinClassificationHead.mlp)
    All other keys load via load_state_dict(strict=False).

    Depth layout: ``n_layers - 4`` shared trunk layers, then ``shared_norm``,
    then 4 task-branch layers (``reg_layers`` or ``cls_layers``) and the
    branch norm. With ``share_factor > 1`` each stack stores only
    ``1/share_factor`` unique layers and applies them cyclically
    (0, 1, ..., k-1, 0, 1, ...).
    """

    def __init__(self, cfg: Optional[V11Config] = None):
        super().__init__()
        cfg = cfg if cfg is not None else v11_default_config()
        self.cfg = cfg

        # Activation checkpointing for every layer while training (see _run_layers).
        self.use_grad_checkpoint = True

        # Per-feature-column affine embedding of the raw scalar:
        # token = x * feature_weights[col] + feature_biases[col].
        self.feature_weights = nn.Parameter(torch.randn(cfg.max_features, cfg.d_model) * 0.02)
        self.feature_biases = nn.Parameter(torch.zeros(cfg.max_features, cfg.d_model))

        # Target-column embeddings: numeric y (regression) or class id (classification).
        self.y_embed = NumericalFeatureEmbedding(cfg.d_model, n_freqs=cfg.n_periodic_freqs)
        self.class_embed = nn.Embedding(cfg.max_classes, cfg.d_model)
        nn.init.normal_(self.class_embed.weight, std=0.02)

        # Placeholder target token for query rows, plus row-type
        # (0 = context, 1 = query) and column-type (0 = feature, 1 = target) embeddings.
        self.query_token = nn.Parameter(torch.randn(cfg.d_model) * 0.02)
        self.type_embed = nn.Embedding(2, cfg.d_model)
        nn.init.normal_(self.type_embed.weight, std=0.02)
        self.col_type_embed = nn.Embedding(2, cfg.d_model)
        nn.init.normal_(self.col_type_embed.weight, std=0.02)

        # getattr defaults presumably keep older config objects (lacking these
        # fields) usable.
        mv = getattr(cfg, "mlp_variant", "gelu")
        nv = getattr(cfg, "norm_variant", "layernorm")
        share = max(1, int(getattr(cfg, "share_factor", 1)))

        n_branch = 4
        n_shared = cfg.n_layers - n_branch
        if n_shared < 0:
            # Previously this silently produced negative layer counts (and an
            # effectively 4-layer model).
            raise ValueError(
                f"n_layers={cfg.n_layers} must be >= {n_branch} (the task-branch depth)"
            )
        if n_shared % share != 0 or n_branch % share != 0:
            raise ValueError(
                f"share_factor={share} must divide both n_shared={n_shared} and 4 (branch layers)"
            )

        def _layer() -> AlternatingLayerV8:
            return AlternatingLayerV8(
                cfg.d_model, cfg.n_heads, cfg.d_ffn, cfg.dropout,
                mlp_variant=mv, norm_variant=nv,
            )

        self.shared_layers = nn.ModuleList([_layer() for _ in range(n_shared // share)])
        self.reg_layers = nn.ModuleList([_layer() for _ in range(n_branch // share)])
        self.cls_layers = nn.ModuleList([_layer() for _ in range(n_branch // share)])
        self.shared_norm = _build_norm(cfg.d_model, nv)
        self.reg_norm = _build_norm(cfg.d_model, nv)
        self.cls_norm = _build_norm(cfg.d_model, nv)

        # Number of layer *applications* per stack (>= number of unique layers stored).
        self.effective_n_shared = n_shared
        self.effective_n_branch = n_branch

        self.reg_head = BarDistributionHead(
            d_model=cfg.d_model, n_bins=cfg.n_bins, dropout=cfg.dropout,
        )
        self.cls_head = BinClassificationHead(
            d_model=cfg.d_model, max_classes=cfg.max_classes, dropout=cfg.dropout,
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform every nn.Linear weight in the model (heads included); zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @staticmethod
    def _datapoint_mask(
        n_ctx: int,
        n_query: int,
        batch: int,
        device: torch.device,
        ctx_row_mask: Optional[torch.Tensor] = None,
        query_row_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Boolean datapoint-attention mask (True = blocked), rows = queries, cols = keys.

        A query row may attend to every context row and to itself, but not
        to other query rows. Context rows are not blocked from any row.
        NOTE(review): because context rows can attend to query rows, across
        layers one query can influence another's prediction through the
        context tokens. Confirm this is intended.

        Without row masks a single (n_rows, n_rows) mask is returned. With
        them, padded rows are additionally hidden as keys from every row,
        giving a per-item (batch, n_rows, n_rows) mask.
        """
        n_rows = n_ctx + n_query
        mask = torch.zeros(n_rows, n_rows, dtype=torch.bool, device=device)
        mask[n_ctx:, n_ctx:] = ~torch.eye(n_query, dtype=torch.bool, device=device)
        if ctx_row_mask is None and query_row_mask is None:
            return mask

        row_pad = torch.zeros(batch, n_rows, dtype=torch.bool, device=device)
        if ctx_row_mask is not None:
            row_pad[:, :n_ctx] = ctx_row_mask
        if query_row_mask is not None:
            row_pad[:, n_ctx:] = query_row_mask
        # (1, R, R) | (B, 1, R): a padded row is blocked as a key for every row.
        return mask.unsqueeze(0) | row_pad.unsqueeze(1)

    def _build_grid(
        self,
        X_ctx: torch.Tensor,
        y_ctx: torch.Tensor,
        X_query: torch.Tensor,
        feature_mask: torch.Tensor,
        task_type: str,
        ctx_row_mask: Optional[torch.Tensor] = None,
        query_row_mask: Optional[torch.Tensor] = None,
    ):
        """Tokenize a batch of tasks into the (rows x columns) grid.

        Rows are the n_ctx context points followed by the n_query query
        points. Columns are n_real feature columns followed by one target column.

        Args:
            X_ctx: (B, n_ctx, n_features) context features.
            y_ctx: (B, n_ctx) context targets: class ids when task_type is
                "classification", numeric values otherwise.
            X_query: (B, n_query, n_features) query features.
            feature_mask: bool (B, n_features); True = padded feature. The
                slicing below keeps the first n_real columns, which assumes
                real features are left-aligned (TODO confirm with the data
                pipeline).
            task_type: "classification" embeds y with class_embed; anything
                else uses y_embed.
            ctx_row_mask / query_row_mask: optional bool (B, n_ctx) /
                (B, n_query); True = padded row.

        Returns:
            grid (B, n_rows, n_cols, d_model), feature_pad_mask (B, n_cols),
            datapoint_mask ((n_rows, n_rows) or (B, n_rows, n_rows)), n_ctx.
        """
        B, n_ctx, n_features = X_ctx.shape
        n_query = X_query.shape[1]
        n_rows = n_ctx + n_query
        max_f = self.cfg.max_features
        device = X_ctx.device

        # Feature columns kept: the widest item's real-feature count (capped
        # at max_features), but never fewer than 2.
        if feature_mask.any():
            n_real = min(int((~feature_mask).sum(dim=1).max().item()), max_f)
        else:
            n_real = min(n_features, max_f)
        n_real = max(n_real, 2)
        n_cols = n_real + 1  # + target column

        X_real = torch.cat([X_ctx, X_query], dim=1)[:, :, :n_real]
        n_present = X_real.shape[2]
        if n_present < n_real:
            # Fewer feature columns than n_real exist (e.g. a 1-feature task).
            # Zero-pad instead of letting broadcasting silently duplicate
            # feature 0 across the missing columns; they are masked below.
            X_real = F.pad(X_real, (0, n_real - n_present))

        feat_grid = (
            X_real.unsqueeze(-1) * self.feature_weights[:n_real]
            + self.feature_biases[:n_real]
        )

        if task_type == "classification":
            y_clamped = y_ctx.long().clamp(0, self.cfg.max_classes - 1)
            y_emb_ctx = self.class_embed(y_clamped)
        else:
            y_emb_ctx = self.y_embed(y_ctx.float())
        # Query rows have no label; they carry the learned placeholder token.
        y_emb_q = self.query_token.expand(B, n_query, -1)
        y_col = torch.cat([y_emb_ctx, y_emb_q], dim=1).unsqueeze(2)
        grid = torch.cat([feat_grid, y_col], dim=2)

        row_types = torch.zeros(B, n_rows, dtype=torch.long, device=device)
        row_types[:, n_ctx:] = 1
        grid = grid + self.type_embed(row_types).unsqueeze(2)

        col_types = torch.zeros(n_cols, dtype=torch.long, device=device)
        col_types[-1] = 1
        grid = grid + self.col_type_embed(col_types)[None, None]

        # True = padded feature column; the target column is never padded.
        feature_pad_mask = torch.zeros(B, n_cols, dtype=torch.bool, device=device)
        n_known = min(feature_mask.shape[1], n_real)
        feature_pad_mask[:, :n_known] = feature_mask[:, :n_known]
        feature_pad_mask[:, n_present:n_real] = True  # zero-padded columns

        datapoint_mask = self._datapoint_mask(
            n_ctx, n_query, B, device, ctx_row_mask, query_row_mask,
        )
        return grid, feature_pad_mask, datapoint_mask, n_ctx

    def _run_layers(
        self,
        layers: nn.ModuleList,
        n_apply: int,
        h: torch.Tensor,
        feat_pad: torch.Tensor,
        dp_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Apply ``n_apply`` layers, cycling through ``layers`` (depth-wise weight sharing).

        Uses non-reentrant activation checkpointing when the model is
        training, autograd is enabled and ``use_grad_checkpoint`` is set.
        """
        use_ckpt = self.training and torch.is_grad_enabled() and self.use_grad_checkpoint
        for i in range(n_apply):
            layer = layers[i % len(layers)]
            if use_ckpt:
                h = grad_checkpoint(layer, h, feat_pad, dp_mask, use_reentrant=False)
            else:
                h = layer(h, feat_pad, dp_mask)
        return h

    def forward(
        self,
        X_ctx: torch.Tensor,
        y_ctx: torch.Tensor,
        X_query: torch.Tensor,
        feature_mask: torch.Tensor,
        task_type: str = "regression",
        ctx_row_mask: Optional[torch.Tensor] = None,
        query_row_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Returns logits over bins (reg) or classes (cls).

        For regression, the trainer is responsible for calling
        `standardize_y_per_task(y_ctx_orig)` BEFORE this forward to obtain
        the standardized y_ctx (and stash mean/std for un-standardization).

        Optional ctx_row_mask / query_row_mask (bool, True=padded row)
        block padded rows from attention as keys, preventing
        zero-padded fake-context contamination.

        Returns:
            Head output for the target-column token of each query row,
            e.g. (B, n_query, n_bins) for regression or
            (B, n_query, max_classes) for classification.

        Raises:
            ValueError: if task_type is not "regression" or "classification".
                Previously any other string built the grid with the
                regression y-embedding but ran the classification branch.
        """
        if task_type not in ("regression", "classification"):
            raise ValueError(
                f"task_type must be 'regression' or 'classification', got {task_type!r}"
            )
        grid, feat_pad, dp_mask, n_ctx = self._build_grid(
            X_ctx, y_ctx, X_query, feature_mask, task_type,
            ctx_row_mask=ctx_row_mask, query_row_mask=query_row_mask,
        )

        grid = self._run_layers(self.shared_layers, self.effective_n_shared, grid, feat_pad, dp_mask)
        grid = self.shared_norm(grid)

        if task_type == "regression":
            layers, norm, head = self.reg_layers, self.reg_norm, self.reg_head
        else:
            layers, norm, head = self.cls_layers, self.cls_norm, self.cls_head
        h = self._run_layers(layers, self.effective_n_branch, grid, feat_pad, dp_mask)
        h = norm(h)
        # Read out the target-column token of every query row.
        query_target = h[:, n_ctx:, -1, :]
        return head(query_target)

    @torch.no_grad()
    def warm_start_from_v8(self, v8_state_dict: dict, verbose: bool = True) -> dict:
        """Load v8 trunk weights, leave heads at random init.

        The v8 reg_head / cls_head entries and log_var_reg / log_var_cls are
        dropped. feature_weights / feature_biases are truncated when the
        checkpoint has more rows than this model's max_features. Everything
        else goes through load_state_dict(strict=False). Note that strict=False
        still raises on shape mismatches for keys present on both sides.

        Args:
            v8_state_dict: a v8 checkpoint's state_dict.
            verbose: print a short summary of the load.

        Returns:
            dict with `loaded` (keys actually copied into this model, i.e.
            excluding unexpected keys), `missing` and `unexpected` key counts.
        """
        skip_prefixes = ("reg_head.", "cls_head.", "log_var_reg", "log_var_cls")
        filtered = {
            k: v for k, v in v8_state_dict.items()
            if not k.startswith(skip_prefixes)
        }

        target_max = self.cfg.max_features
        for k in ("feature_weights", "feature_biases"):
            if k in filtered and filtered[k].shape[0] > target_max:
                filtered[k] = filtered[k][:target_max]
        result = self.load_state_dict(filtered, strict=False)
        # Unexpected keys (e.g. v8 metadata-encoder params) are ignored, not loaded.
        n_loaded = len(filtered) - len(result.unexpected_keys)
        if verbose:
            print(f"[v11.warm_start_from_v8] loaded {n_loaded} keys")
            if result.missing_keys:
                print(f"  missing ({len(result.missing_keys)}): {result.missing_keys[:5]}…")
            if result.unexpected_keys:
                print(f"  unexpected ({len(result.unexpected_keys)}): {result.unexpected_keys[:5]}…")
        return {
            "loaded": n_loaded,
            "missing": len(result.missing_keys),
            "unexpected": len(result.unexpected_keys),
        }

    @torch.no_grad()
    def warm_start_slice_from_v11(self, v11_state_dict: dict, verbose: bool = True) -> dict:
        """Initialize this (smaller) model from a v11.0 ckpt by SLICING layers.

        Used when this model has `share_factor > 1`: the v11.0 trunk has
        `n_layers` unique blocks, but this model has only `n_layers /
        share_factor` unique blocks (each used `share_factor` times via
        cycling). We copy every-`share_factor`-th v11.0 block into the
        student's unique-blocks list. The source layout is derived from this
        model's own `cfg.n_layers`, so the checkpoint must have that depth.

        Non-layer modules (feature_weights, y_embed, class_embed, query_token,
        col_type_embed, shared_norm/reg_norm/cls_norm, reg_head, cls_head)
        copy verbatim — they're share-factor-independent.

        Requires this model use legacy (gelu + layernorm) MLP/norm variants
        for the layer slicing to be shape-compatible.

        Returns:
            dict with share_factor, layer_keys_copied, non_layer_keys_copied,
            missing_params (trainable params left unmatched) and unexpected.

        Raises:
            ValueError: if the MLP/norm variants are not the legacy ones.
        """
        # Same getattr defaults as __init__, so older config objects behave alike.
        mlp_variant = getattr(self.cfg, "mlp_variant", "gelu")
        norm_variant = getattr(self.cfg, "norm_variant", "layernorm")
        if mlp_variant != "gelu" or norm_variant != "layernorm":
            raise ValueError(
                "warm_start_slice_from_v11 requires mlp_variant=gelu, "
                "norm_variant=layernorm for shape compatibility with v11.0 ckpt. "
                f"Got mlp_variant={mlp_variant}, norm_variant={norm_variant}."
            )
        share = max(1, int(getattr(self.cfg, "share_factor", 1)))

        v11_n_shared = self.cfg.n_layers - 4
        v11_n_branch = 4
        shared_src = list(range(0, v11_n_shared, share))[: v11_n_shared // share]
        branch_src = list(range(0, v11_n_branch, share))[: v11_n_branch // share]
        # source block index -> slot in this model's (shorter) ModuleList
        shared_slot = {src: dst for dst, src in enumerate(shared_src)}
        branch_slot = {src: dst for dst, src in enumerate(branch_src)}
        slot_maps = {
            "shared_layers": shared_slot,
            "reg_layers": branch_slot,
            "cls_layers": branch_slot,
        }

        new_state = {}
        layer_keys_copied = 0
        non_layer_keys_copied = 0
        for k, v in v11_state_dict.items():
            stack, sep, rest = k.partition(".")
            slots = slot_maps.get(stack) if sep else None
            if slots is None:
                new_state[k] = v
                non_layer_keys_copied += 1
                continue
            idx, _, suffix = rest.partition(".")
            src_idx = int(idx)
            if src_idx in slots:
                new_state[f"{stack}.{slots[src_idx]}.{suffix}"] = v
                layer_keys_copied += 1

        result = self.load_state_dict(new_state, strict=False)
        param_names = {n for n, _ in self.named_parameters()}
        missing_params = [k for k in result.missing_keys if k in param_names]

        if verbose:
            print(f"[v11.warm_start_slice] share_factor={share}, slice indices: "
                  f"shared={shared_src}, branch={branch_src}")
            print(f"  copied {layer_keys_copied} layer-keys + {non_layer_keys_copied} non-layer keys")
            if missing_params:
                print(f"  WARN: {len(missing_params)} trainable params unmatched: "
                      f"{missing_params[:5]}{'...' if len(missing_params) > 5 else ''}")
            if result.unexpected_keys:
                print(f"  ignored {len(result.unexpected_keys)} unexpected keys (e.g., v11.0 layers we didn't slice)")
        return {
            "share_factor": share,
            "layer_keys_copied": layer_keys_copied,
            "non_layer_keys_copied": non_layer_keys_copied,
            "missing_params": len(missing_params),
            "unexpected": len(result.unexpected_keys),
        }
|
|
|
|
def count_params(model: nn.Module) -> int:
    """Return the number of trainable (``requires_grad``) scalar parameters in ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
|
|
|
| |
|
|
|
|
def _self_test() -> None:
    """CPU smoke test: output shapes, both losses, and warm-starting from a fake v8 ckpt."""
    torch.manual_seed(0)
    cfg = V11Config()
    model = PredictLMv11(cfg)
    print(f"v11 model: {count_params(model)/1e6:.1f}M params (cfg={cfg})")

    B, n_ctx, n_q, n_f = 2, 64, 16, 8
    X_ctx = torch.randn(B, n_ctx, n_f)
    y_ctx = torch.randn(B, n_ctx)
    X_q = torch.randn(B, n_q, n_f)
    feat_mask = torch.zeros(B, n_f, dtype=torch.bool)  # no padded features

    # --- regression path ---
    reg_logits = model(X_ctx, y_ctx, X_q, feat_mask, task_type="regression")
    print(f"[reg] logits shape: {tuple(reg_logits.shape)} (expected (2,16,1024))")
    assert reg_logits.shape == (B, n_q, cfg.n_bins)
    # The first n_q context targets stand in for query targets.
    loss = bar_distribution_loss(reg_logits, y_ctx[:, :n_q], model.reg_head)
    print(f"[reg] uniform-prior loss: {loss.item():.3f} (≈ ln(1024) = 6.93)")

    # --- classification path ---
    y_ctx_cls = torch.randint(0, 5, (B, n_ctx))
    cls_logits = model(X_ctx, y_ctx_cls, X_q, feat_mask, task_type="classification")
    print(f"[cls] logits shape: {tuple(cls_logits.shape)} (expected (2,16,10))")
    assert cls_logits.shape == (B, n_q, cfg.max_classes)
    n_classes_per_task = torch.tensor([3, 5])
    y_q_cls = torch.stack([torch.randint(0, n_cls, (n_q,)) for n_cls in (3, 5)])
    loss_c = cls_masked_loss(cls_logits, y_q_cls, n_classes_per_task)
    print(f"[cls] masked loss: {loss_c.item():.3f}")

    # --- warm start from a fake v8 checkpoint with v8-shaped heads ---
    head_prefixes = ("reg_head.", "cls_head.")
    fake_v8_ckpt = {
        name: tensor.clone()
        for name, tensor in model.state_dict().items()
        if not name.startswith(head_prefixes)
    }
    fake_v8_ckpt.update({
        "reg_head.weight": torch.zeros(2, cfg.d_model),
        "reg_head.bias": torch.zeros(2),
        "cls_head.weight": torch.zeros(cfg.max_classes, cfg.d_model),
        "cls_head.bias": torch.zeros(cfg.max_classes),
    })
    fresh = PredictLMv11(cfg)
    info = fresh.warm_start_from_v8(fake_v8_ckpt)
    print(f"[warm-start] loaded={info['loaded']}, missing={info['missing']}, unexpected={info['unexpected']}")
    assert info['unexpected'] == 0, "v8 reg/cls heads should be filtered, got unexpected"

    print("[OK] v11 model self-test passed")


if __name__ == "__main__":
    _self_test()
|
|