File size: 34,190 Bytes
d303b8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
"""
v11 model — same trunk as v8 so we can warm-start from v8's final checkpoint.
The architecture differences vs v8 are the prediction heads:

    v8:    reg_head = Linear(d_model, 2)              # mean, log_var
    v8:    cls_head = Linear(d_model, max_classes)
    v11:   reg_head = BarDistributionHead(d_model, n_bins=1024)
    v11:   cls_head = BinClassificationHead(d_model, max_classes=10)

Everything else (feature_weights, y_embed, class_embed, type_embed,
shared_layers, reg_layers, cls_layers, *_norm) keeps the same module
names and parameter shapes, so:

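    v11_model.warm_start_from_v8(v8_ckpt)

will load the trunk and leave only the heads randomly initialized. (A bare
`load_state_dict(v8_ckpt, strict=False)` only works when max_features matches
the checkpoint: strict=False tolerates missing/unexpected keys but still
raises on size mismatches, and warm_start_from_v8 slices v8's 500-row
feature_weights / feature_biases down to cfg.max_features first.)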
The v11 trainer's head-warmup phase trains only the heads + reg_norm /
cls_norm for the first 5k steps, exactly as v10 did.

Tokenization is identical to v8: 2D grid [B, n_rows, n_cols, d_model]
with one token per cell. Each layer alternates feature-attention (within
a row) and datapoint-attention (within a column with the
context-vs-query mask).

For now, v11 SKIPS v8's metadata conditioning (the column-statistics
encoder). The v11 plan defers architectural cleanups to v13; the goal
here is data-prior work, not arch work. Once warm-started, the
metadata-related parameters in the v8 ckpt are simply ignored.
"""
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as grad_checkpoint

from .heads import (
    BarDistributionHead,
    BinClassificationHead,
    bar_distribution_loss,
    cls_masked_loss,
    standardize_y_per_task,
    decode_bar_distribution,
    cls_predict,
)


# ─── config ──────────────────────────────────────────────────────────────────


@dataclass
class V11Config:
    """Hyper-parameters consumed by ``PredictLMv11``.

    The defaults describe the v11.0 layout. Field order and default values
    are part of the interface (the dataclass can be constructed
    positionally), so only comments should be edited here.
    """

    # Token width used by every embedding, attention block and head.
    d_model: int = 256
    # Total depth per task path: the model builds (n_layers - 4) shared
    # layers plus 4 layers in each task branch (12 → 8 shared + 4 per branch).
    n_layers: int = 12      # 8 shared + 4 task-specific per branch
    n_heads: int = 8
    # Hidden width of the GELU FFN; the SwiGLU variant derives its own
    # hidden width (about 2/3 of this) from it.
    d_ffn: int = 1024
    # Forwarded to the trunk layers and to both heads. The trunk attention
    # block accepts it but calls SDPA with dropout_p=0.0.
    dropout: float = 0.0

    # Number of rows in the per-feature projection tables.
    max_features: int = 128   # warm-start slices v8's feature_weights[500] → [128] in warm_start_from_v8
    # Size of the class embedding table and of the classification head.
    max_classes: int = 10
    # NOTE(review): max_context / max_query are not read anywhere in this
    # module — presumably enforced by the trainer / data pipeline; confirm.
    max_context: int = 1024
    max_query: int = 256

    # Number of sin/cos frequencies used by the y-value embedding.
    n_periodic_freqs: int = 8

    # Number of bins predicted by the regression (bar-distribution) head.
    n_bins: int = 1024
    # NOTE(review): not read in this module — presumably consumed by the
    # classification loss in the trainer; confirm.
    cls_label_smoothing: float = 0.05

    # v11.0.6-tiny architecture toggles. Defaults preserve v11.0 behavior so
    # existing ckpts load unchanged via warm_start_from_v8 / strict=False.
    mlp_variant: str = "gelu"      # "gelu" (legacy) or "swiglu"
    norm_variant: str = "layernorm"  # "layernorm" (legacy) or "rmsnorm"
    # ALBERT-style cross-layer parameter sharing. share_factor>1 means the
    # `n_layers`-deep stack uses only `n_layers // share_factor` UNIQUE
    # modules; each unique block is applied `share_factor` times via index
    # cycling. share_factor=1 = legacy (no sharing).
    share_factor: int = 1


def v11_default_config() -> V11Config:
    """Return a fresh ``V11Config`` carrying only its declared defaults."""
    default_cfg = V11Config()
    return default_cfg


# ─── v11.0.6-tiny blocks (drop-in upgrades behind config flag) ──────────────


class SwiGLUFFN(nn.Module):
    """Gated feed-forward block with SiLU gating (SwiGLU, Shazeer 2020,
    arXiv 2002.05202).

    Computes ``w_out(silu(w_gate(x)) * w_value(x))`` with three bias-free
    projections. The hidden width is ``round(2/3 * d_ffn)`` so that the three
    matrices hold roughly as many weights as the two matrices of the legacy
    ``Linear(d, d_ffn) → GELU → Linear(d_ffn, d)`` FFN.
    """

    def __init__(self, d_model: int, d_ffn: int):
        super().__init__()
        # Legacy FFN: 2 * d_model * d_ffn weights. SwiGLU: 3 * d_model * hidden.
        # Equating the two gives hidden = (2/3) * d_ffn.
        hidden_width = int(round(d_ffn * 2 / 3))
        # Creation order (gate, value, out) is kept so seeded initialisation
        # matches existing runs.
        self.w_gate = nn.Linear(d_model, hidden_width, bias=False)
        self.w_value = nn.Linear(d_model, hidden_width, bias=False)
        self.w_out = nn.Linear(hidden_width, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP over the last dimension of ``x``."""
        gate = F.silu(self.w_gate(x))
        value = self.w_value(x)
        return self.w_out(gate * value)


class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation (Zhang & Sennrich 2019).

    Rescales the last dimension by the reciprocal of its RMS and applies a
    learned per-channel gain. There is no mean subtraction and no bias term.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Per-channel gain, initialised to 1 so the module starts as a pure
        # RMS rescale.
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalised = x * inv_rms
        return self.weight * normalised


def _build_ffn(d_model: int, d_ffn: int, variant: str = "gelu") -> nn.Module:
    """Factory for the position-wise feed-forward sub-block.

    Args:
        d_model: Input and output width.
        d_ffn: Hidden width of the GELU MLP (the SwiGLU variant derives its
            own hidden width from it).
        variant: ``"gelu"`` for the legacy ``Linear → GELU → Linear`` stack
            or ``"swiglu"`` for ``SwiGLUFFN``.

    Returns:
        The freshly constructed FFN module.

    Raises:
        ValueError: If ``variant`` is not one of the supported names. A typo
            such as ``"SwiGLU"`` used to fall back to the GELU MLP silently,
            producing a different architecture than the config asked for.
    """
    if variant == "swiglu":
        return SwiGLUFFN(d_model, d_ffn)
    if variant == "gelu":
        return nn.Sequential(
            nn.Linear(d_model, d_ffn),
            nn.GELU(),
            nn.Linear(d_ffn, d_model),
        )
    raise ValueError(
        f"unknown mlp_variant {variant!r}; expected 'gelu' or 'swiglu'"
    )


def _build_norm(d_model: int, variant: str = "layernorm") -> nn.Module:
    """Factory for the normalisation layer.

    Args:
        d_model: Width of the normalised (last) dimension.
        variant: ``"layernorm"`` for the legacy ``nn.LayerNorm`` or
            ``"rmsnorm"`` for ``RMSNorm``.

    Returns:
        The freshly constructed normalisation module.

    Raises:
        ValueError: If ``variant`` is not one of the supported names.
            Unknown names used to fall back to LayerNorm silently.
    """
    if variant == "rmsnorm":
        return RMSNorm(d_model)
    if variant == "layernorm":
        return nn.LayerNorm(d_model)
    raise ValueError(
        f"unknown norm_variant {variant!r}; expected 'layernorm' or 'rmsnorm'"
    )


# ─── blocks (verbatim from v8 so state_dict keys match) ───────────────────────


class FlashPreLNAttention(nn.Module):
    """Pre-LN self-attention + FFN block using F.scaled_dot_product_attention.

    Layout: ``x + o_proj(attn(norm1(x)))`` followed by
    ``x + ffn(norm2(x))``. Sub-module names (norm1, q_proj, k_proj, v_proj,
    o_proj, norm2, ffn) are kept so state-dict keys stay stable.

    Args:
        d_model: Token width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        d_ffn: Hidden width passed to the FFN factory.
        dropout: Accepted for signature compatibility only; this block calls
            SDPA with ``dropout_p=0.0`` and never applies dropout.
        mlp_variant: Forwarded to ``_build_ffn``.
        norm_variant: Forwarded to ``_build_norm``.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``d_model`` (previously this surfaced later as an opaque
            ``view`` shape error in ``_heads``).
    """

    def __init__(self, d_model: int, n_heads: int, d_ffn: int, dropout: float = 0.0,
                 mlp_variant: str = "gelu", norm_variant: str = "layernorm"):
        super().__init__()
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model={d_model} must be divisible by a positive n_heads (got n_heads={n_heads})"
            )
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.d_model = d_model

        self.norm1 = _build_norm(d_model, norm_variant)
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)

        self.norm2 = _build_norm(d_model, norm_variant)
        self.ffn = _build_ffn(d_model, d_ffn, mlp_variant)

    def _heads(self, x: torch.Tensor) -> torch.Tensor:
        """Split ``[B, S, d_model]`` into ``[B, n_heads, S, head_dim]``."""
        B, S, _ = x.shape
        return x.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention and FFN with residual connections.

        Args:
            x: ``[B, S, d_model]`` token sequence.
            key_padding_mask: Optional ``[B, S]`` bool mask; True marks keys
                that no query may attend to.
            attn_mask: Optional bool mask, ``[S, S]`` (shared across the
                batch) or ``[B, S, S]``; True marks blocked (query, key)
                pairs.

        Returns:
            ``[B, S, d_model]`` tensor.
        """
        residual = x
        x = self.norm1(x)
        q = self._heads(self.q_proj(x))
        k = self._heads(self.k_proj(x))
        v = self._heads(self.v_proj(x))

        # Both boolean masks are converted to additive float masks
        # (0 = keep, -inf = blocked) so they can simply be summed.
        sdpa_mask = None
        if attn_mask is not None:
            additive = torch.zeros_like(attn_mask, dtype=q.dtype)
            additive.masked_fill_(attn_mask, float("-inf"))
            if attn_mask.dim() == 2:
                sdpa_mask = additive.unsqueeze(0).unsqueeze(0)            # [1,1,seq,seq]
            else:
                sdpa_mask = additive.unsqueeze(1)                         # [B,1,seq,seq]
        if key_padding_mask is not None:
            pad_mask = torch.zeros(
                key_padding_mask.shape[0], 1, 1, key_padding_mask.shape[1],
                dtype=q.dtype, device=q.device,
            )
            pad_mask.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"))
            sdpa_mask = pad_mask if sdpa_mask is None else sdpa_mask + pad_mask

        attn_out = F.scaled_dot_product_attention(q, k, v, attn_mask=sdpa_mask, dropout_p=0.0)
        # Merge heads back: [B, H, S, hd] → [B, S, d_model].
        attn_out = attn_out.transpose(1, 2).contiguous().view(x.shape[0], x.shape[1], self.d_model)
        x = self.o_proj(attn_out) + residual

        residual = x
        x = self.norm2(x)
        x = self.ffn(x) + residual
        return x


class AlternatingLayerV8(nn.Module):
    """One grid layer: attention across the cells of each row, then
    attention across the cells of each column.

    The class and attribute names (``feature_attn``, ``datapoint_attn``)
    mirror v8 so that state_dict keys line up for warm-starting.
    """

    def __init__(self, d_model: int, n_heads: int, d_ffn: int, dropout: float = 0.0,
                 mlp_variant: str = "gelu", norm_variant: str = "layernorm"):
        super().__init__()
        variants = dict(mlp_variant=mlp_variant, norm_variant=norm_variant)
        self.feature_attn = FlashPreLNAttention(d_model, n_heads, d_ffn, dropout, **variants)
        self.datapoint_attn = FlashPreLNAttention(d_model, n_heads, d_ffn, dropout, **variants)

    def forward(
        self,
        x: torch.Tensor,                # [B, n_rows, n_cols, d_model]
        feature_pad_mask: torch.Tensor,
        datapoint_mask: torch.Tensor,   # [n_rows, n_rows] OR [B, n_rows, n_rows]
    ) -> torch.Tensor:
        """Apply the row pass followed by the column pass.

        ``feature_pad_mask`` is broadcast over rows and handed to the row
        pass as ``key_padding_mask``; ``datapoint_mask`` is handed to the
        column pass as ``attn_mask`` (a per-batch 3D mask is first repeated
        for every column). The result has the same shape as ``x``.
        """
        batch, n_rows, n_cols, width = x.shape

        # Row pass: every (batch, row) pair becomes one sequence of columns.
        row_seqs = x.reshape(batch * n_rows, n_cols, width)
        row_pad = (
            feature_pad_mask[:, None, :]
            .expand(batch, n_rows, n_cols)
            .reshape(batch * n_rows, n_cols)
        )
        row_seqs = self.feature_attn(row_seqs, key_padding_mask=row_pad)
        x = row_seqs.reshape(batch, n_rows, n_cols, width)

        # Column pass: every (batch, column) pair becomes one sequence of rows.
        col_seqs = x.permute(0, 2, 1, 3).reshape(batch * n_cols, n_rows, width)
        if datapoint_mask.dim() != 3:
            col_mask = datapoint_mask
        else:
            # [B, n_rows, n_rows] → [B*n_cols, n_rows, n_rows]
            col_mask = (
                datapoint_mask[:, None]
                .expand(batch, n_cols, n_rows, n_rows)
                .reshape(batch * n_cols, n_rows, n_rows)
            )
        col_seqs = self.datapoint_attn(col_seqs, attn_mask=col_mask)

        # Restore the [B, n_rows, n_cols, d_model] layout.
        return col_seqs.reshape(batch, n_cols, n_rows, width).permute(0, 2, 1, 3)


# ─── numerical-value embedding (matches v8's NumericalFeatureEmbedding) ──────


class NumericalFeatureEmbedding(nn.Module):
    """Map scalar values to ``d_model``-wide vectors through Fourier features.

    Each value is expanded into ``[sign, log1p(|v|), sin(v*f_k)..., cos(v*f_k)...]``
    with frequencies ``f_k = 2**k`` for ``k < n_freqs`` and then passed through
    a two-layer GELU MLP. Positions flagged in ``mask`` receive the learned
    ``missing_token`` instead.
    """

    def __init__(self, d_model: int = 256, n_freqs: int = 8):
        super().__init__()
        self.d_model = d_model
        self.n_freqs = n_freqs
        # Fixed (non-trainable) geometric frequency ladder 1, 2, 4, ...
        self.register_buffer("freqs", 2.0 ** torch.arange(n_freqs, dtype=torch.float32))
        # sign + log-magnitude + one sin and one cos per frequency
        n_inputs = 2 + 2 * n_freqs
        self.mlp = nn.Sequential(
            nn.Linear(n_inputs, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )
        self.missing_token = nn.Parameter(torch.randn(d_model) * 0.02)

    def forward(self, values: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Embed ``values``; ``mask`` (True = missing) has the same shape as ``values``."""
        # Reshape the frequency vector so it broadcasts against values[..., None].
        lead_ones = [1] * (values.dim() - 1)
        freq_view = self.freqs.to(values.device).view(*lead_ones, self.n_freqs)
        phase = values.unsqueeze(-1) * freq_view

        pieces = [
            torch.sign(values).unsqueeze(-1),
            torch.log1p(torch.abs(values)).unsqueeze(-1),
            torch.sin(phase),
            torch.cos(phase),
        ]
        emb = self.mlp(torch.cat(pieces, dim=-1))

        if mask is None:
            return emb
        return torch.where(mask.unsqueeze(-1), self.missing_token.expand_as(emb), emb)


# ─── main v11 model ──────────────────────────────────────────────────────────


@dataclass
class V11Output:
    """Single forward pass output.

    NOTE(review): nothing in the visible part of this module constructs this
    container — ``PredictLMv11.forward`` returns the raw logits tensor — so
    it is presumably used by the trainer or inference helpers; confirm before
    changing or removing it.
    """
    # Exactly one of reg_logits / cls_logits is expected to be set, matching
    # the task type of the pass — assumption based on the field comments.
    reg_logits: Optional[torch.Tensor] = None     # [B, n_query, n_bins] for reg
    cls_logits: Optional[torch.Tensor] = None     # [B, n_query, max_classes] for cls
    y_mean: Optional[torch.Tensor] = None         # [B] context y mean (reg only)
    y_std: Optional[torch.Tensor] = None          # [B] context y std (reg only)


class PredictLMv11(nn.Module):
    """
    v11 model: same trunk as v8, new heads.

    Forward returns either reg_logits (for regression) or cls_logits (for
    classification). For mixed-batch joint training, the trainer should
    call the model twice — once with task_type='regression' and once with
    task_type='classification' — sharing the trunk pass via gradient
    accumulation. (Per-batch-element task_type would require padding to
    a max-class shape and we keep it simple.)

    State-dict keys match v8's PredictLMv8 exactly EXCEPT:
      - reg_head (Linear → BarDistributionHead.mlp)
      - cls_head (Linear → BinClassificationHead.mlp)
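    All other keys load via `warm_start_from_v8`, which also slices the
    feature tables to `cfg.max_features` rows and drops v8's `log_var_*`
    keys before calling load_state_dict(strict=False).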
    """

    def __init__(self, cfg: V11Config = None):
        super().__init__()
        cfg = cfg or v11_default_config()
        self.cfg = cfg
        # Toggle gradient checkpointing. Default True (memory-conservative,
        # for H100/T4 sized batches). On A100 80GB we can disable for ~2-3×
        # throughput when memory permits. Set via `model.use_grad_checkpoint = False`.
        self.use_grad_checkpoint = True

        # Per-feature projection (same as v8)
        self.feature_weights = nn.Parameter(torch.randn(cfg.max_features, cfg.d_model) * 0.02)
        self.feature_biases = nn.Parameter(torch.zeros(cfg.max_features, cfg.d_model))

        # y embeddings
        self.y_embed = NumericalFeatureEmbedding(cfg.d_model, n_freqs=cfg.n_periodic_freqs)
        self.class_embed = nn.Embedding(cfg.max_classes, cfg.d_model)
        nn.init.normal_(self.class_embed.weight, std=0.02)

        # tokens
        self.query_token = nn.Parameter(torch.randn(cfg.d_model) * 0.02)
        self.type_embed = nn.Embedding(2, cfg.d_model)
        nn.init.normal_(self.type_embed.weight, std=0.02)
        self.col_type_embed = nn.Embedding(2, cfg.d_model)
        nn.init.normal_(self.col_type_embed.weight, std=0.02)

        # trunk: 8 shared + 4 reg + 4 cls
        # v11.0.6-tiny: variant flags flow through to FFN/norm choice; defaults
        # preserve v11.0 layout for backward-compat with existing ckpts.
        mv = getattr(cfg, "mlp_variant", "gelu")
        nv = getattr(cfg, "norm_variant", "layernorm")
        share = max(1, int(getattr(cfg, "share_factor", 1)))
        _layer = lambda: AlternatingLayerV8(
            cfg.d_model, cfg.n_heads, cfg.d_ffn, cfg.dropout,
            mlp_variant=mv, norm_variant=nv,
        )
        n_shared = cfg.n_layers - 4
        # Under share_factor>1, build only n//share unique blocks; the
        # forward pass cycles through them. n_shared and n_branch (=4) must
        # both be divisible by share_factor.
        if n_shared % share != 0 or 4 % share != 0:
            raise ValueError(
                f"share_factor={share} must divide both n_shared={n_shared} and 4 (branch layers)"
            )
        n_shared_unique = n_shared // share
        n_branch_unique = 4 // share
        self.shared_layers = nn.ModuleList([_layer() for _ in range(n_shared_unique)])
        self.reg_layers = nn.ModuleList([_layer() for _ in range(n_branch_unique)])
        self.cls_layers = nn.ModuleList([_layer() for _ in range(n_branch_unique)])
        self.shared_norm = _build_norm(cfg.d_model, nv)
        self.reg_norm = _build_norm(cfg.d_model, nv)
        self.cls_norm = _build_norm(cfg.d_model, nv)
        # Stored for forward to know how many depth-passes to do.
        self.effective_n_shared = n_shared
        self.effective_n_branch = 4

        # v11 heads
        self.reg_head = BarDistributionHead(
            d_model=cfg.d_model, n_bins=cfg.n_bins, dropout=cfg.dropout,
        )
        self.cls_head = BinClassificationHead(
            d_model=cfg.d_model, max_classes=cfg.max_classes, dropout=cfg.dropout,
        )

        # NOTE: v8's `log_var_reg` / `log_var_cls` Kendall-style task weights
        # are intentionally NOT instantiated here. They were declared but
        # never read in the v11 trainer, and ratio-balancing reg/cls via
        # alternation + curriculum bias is sufficient at this scale per
        # Expert 4. If they appear in a v8 checkpoint, `warm_start_from_v8`
        # filters them out via `strict=False` (they land in `unexpected_keys`).
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    # ──────────────────────────────────────────────────────────────
    # Internal: build the [B, n_rows, n_cols, d_model] grid
    # ──────────────────────────────────────────────────────────────
    def _build_grid(
        self,
        X_ctx: torch.Tensor,                # [B, n_ctx, n_features]
        y_ctx: torch.Tensor,                # [B, n_ctx]
        X_query: torch.Tensor,              # [B, n_query, n_features]
        feature_mask: torch.Tensor,         # [B, n_features] bool, True=padded
        task_type: str,
        ctx_row_mask: Optional[torch.Tensor] = None,   # [B, n_ctx] bool, True=padded
        query_row_mask: Optional[torch.Tensor] = None, # [B, n_query] bool, True=padded
    ):
        B, n_ctx, n_features = X_ctx.shape
        n_query = X_query.shape[1]
        n_rows = n_ctx + n_query
        max_f = self.cfg.max_features
        device = X_ctx.device

        # Effective feature count
        if feature_mask.any():
            real_per_item = (~feature_mask).sum(dim=1)
            n_real = min(int(real_per_item.max().item()), max_f)
        else:
            n_real = min(n_features, max_f)
        n_real = max(n_real, 2)
        n_cols = n_real + 1

        X_all = torch.cat([X_ctx, X_query], dim=1)            # [B, n_rows, n_features]
        X_real = X_all[:, :, :n_real]                          # [B, n_rows, n_real]

        # Per-feature projection
        feat_grid = (
            X_real.unsqueeze(-1) * self.feature_weights[:n_real]
            + self.feature_biases[:n_real]
        )                                                       # [B, n_rows, n_real, d_model]

        # Target column embedding
        if task_type == "classification":
            y_clamped = y_ctx.long().clamp(0, self.cfg.max_classes - 1)
            y_emb_ctx = self.class_embed(y_clamped)             # [B, n_ctx, d_model]
        else:
            y_emb_ctx = self.y_embed(y_ctx.float())             # [B, n_ctx, d_model]

        y_emb_q = self.query_token.unsqueeze(0).unsqueeze(0).expand(B, n_query, -1)
        y_emb = torch.cat([y_emb_ctx, y_emb_q], dim=1).unsqueeze(2)   # [B, n_rows, 1, d_model]

        grid = torch.cat([feat_grid, y_emb], dim=2)             # [B, n_rows, n_cols, d_model]

        # Type (ctx vs query) and column-type (feature vs target) embeds
        type_ids = torch.zeros(B, n_rows, dtype=torch.long, device=device)
        type_ids[:, n_ctx:] = 1
        grid = grid + self.type_embed(type_ids).unsqueeze(2)

        col_types = torch.zeros(n_cols, dtype=torch.long, device=device)
        col_types[-1] = 1
        grid = grid + self.col_type_embed(col_types).unsqueeze(0).unsqueeze(0)

        # Feature-pad mask
        feature_pad_mask = torch.zeros(B, n_cols, dtype=torch.bool, device=device)
        if feature_mask.shape[1] >= n_real:
            feature_pad_mask[:, :n_real] = feature_mask[:, :n_real]

        # Datapoint mask: query rows can't attend to other query rows (they each
        # predict independently). If ctx_row_mask / query_row_mask are provided,
        # padded rows are also blocked from being keys (per-batch [B, n_rows, n_rows]).
        # Without row-pad masks, build the simple [n_rows, n_rows] shared mask.
        if ctx_row_mask is None and query_row_mask is None:
            datapoint_mask = torch.zeros(n_rows, n_rows, dtype=torch.bool, device=device)
            datapoint_mask[n_ctx:, n_ctx:] = True
            for i in range(n_query):
                datapoint_mask[n_ctx + i, n_ctx + i] = False
        else:
            row_pad = torch.zeros(B, n_rows, dtype=torch.bool, device=device)
            if ctx_row_mask is not None:
                row_pad[:, :n_ctx] = ctx_row_mask
            if query_row_mask is not None:
                row_pad[:, n_ctx:] = query_row_mask
            # base [n_rows, n_rows] block-mask: query↔query disallowed except diag
            base = torch.zeros(n_rows, n_rows, dtype=torch.bool, device=device)
            base[n_ctx:, n_ctx:] = True
            for i in range(n_query):
                base[n_ctx + i, n_ctx + i] = False
            base = base.unsqueeze(0).expand(B, n_rows, n_rows).clone()
            # block any KEY row that is padded (broadcast over queries)
            base = base | row_pad.unsqueeze(1).expand(B, n_rows, n_rows)
            datapoint_mask = base

        return grid, feature_pad_mask, datapoint_mask, n_ctx

    # ──────────────────────────────────────────────────────────────
    # Forward
    # ──────────────────────────────────────────────────────────────
    def forward(
        self,
        X_ctx: torch.Tensor,
        y_ctx: torch.Tensor,
        X_query: torch.Tensor,
        feature_mask: torch.Tensor,
        task_type: str = "regression",
        ctx_row_mask: Optional[torch.Tensor] = None,
        query_row_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Returns logits over bins (reg) or classes (cls).

        For regression, the trainer is responsible for calling
        `standardize_y_per_task(y_ctx_orig)` BEFORE this forward to obtain
        the standardized y_ctx (and stash mean/std for un-standardization).

        Optional ctx_row_mask / query_row_mask (bool, True=padded row)
        block padded rows from attention as keys, preventing
        zero-padded fake-context contamination.
        """
        grid, feat_pad, dp_mask, n_ctx = self._build_grid(
            X_ctx, y_ctx, X_query, feature_mask, task_type,
            ctx_row_mask=ctx_row_mask, query_row_mask=query_row_mask,
        )

        # Shared trunk. Under share_factor>1, len(self.shared_layers) may be
        # < effective_n_shared; cycle via modulo index (ALBERT pattern).
        n_uniq_shared = len(self.shared_layers)
        for i in range(self.effective_n_shared):
            layer = self.shared_layers[i % n_uniq_shared]
            if self.training and torch.is_grad_enabled() and self.use_grad_checkpoint:
                grid = grad_checkpoint(layer, grid, feat_pad, dp_mask, use_reentrant=False)
            else:
                grid = layer(grid, feat_pad, dp_mask)
        grid = self.shared_norm(grid)

        # Task-specific layers
        if task_type == "regression":
            h = grid
            n_uniq_branch = len(self.reg_layers)
            for i in range(self.effective_n_branch):
                layer = self.reg_layers[i % n_uniq_branch]
                if self.training and torch.is_grad_enabled() and self.use_grad_checkpoint:
                    h = grad_checkpoint(layer, h, feat_pad, dp_mask, use_reentrant=False)
                else:
                    h = layer(h, feat_pad, dp_mask)
            h = self.reg_norm(h)
            query_target = h[:, n_ctx:, -1, :]                # [B, n_query, d_model]
            return self.reg_head(query_target)                 # [B, n_query, n_bins]

        # classification — symmetric grad flow with reg path. Earlier
        # versions had `h = 0.5*grid + 0.5*grid.detach()` here, which
        # halved the cls branch's gradient into the shared trunk while
        # the reg branch passed full gradient. Combined with bar-dist
        # reg loss being ~3× larger by magnitude than cls (ln(1024) vs
        # ln(10)) and 50/50 step alternation, the trunk was receiving
        # ~6× more reg signal than cls signal per step. Removed.
        h = grid
        n_uniq_branch = len(self.cls_layers)
        for i in range(self.effective_n_branch):
            layer = self.cls_layers[i % n_uniq_branch]
            if self.training and torch.is_grad_enabled() and self.use_grad_checkpoint:
                h = grad_checkpoint(layer, h, feat_pad, dp_mask, use_reentrant=False)
            else:
                h = layer(h, feat_pad, dp_mask)
        h = self.cls_norm(h)
        query_target = h[:, n_ctx:, -1, :]
        return self.cls_head(query_target)                     # [B, n_query, max_classes]

    # ──────────────────────────────────────────────────────────────
    # Convenience: warm-start from v8 checkpoint
    # ──────────────────────────────────────────────────────────────
    @torch.no_grad()
    def warm_start_from_v8(self, v8_state_dict: dict, verbose: bool = True) -> dict:
        """Load v8 trunk weights, leave heads at random init.

        Args:
            v8_state_dict: a v8 checkpoint's state_dict
        Returns:
            dict with `loaded`, `missing`, `unexpected` key counts
        """
        # Filter out v8's old reg_head / cls_head (shape-incompatible) and
        # the dead log_var weights (removed in v11).
        skip_prefixes = ("reg_head.", "cls_head.", "log_var_reg", "log_var_cls")
        filtered = {
            k: v for k, v in v8_state_dict.items()
            if not k.startswith(skip_prefixes)
        }
        # Slice feature_weights / feature_biases if v8 ckpt has more features
        # than v11's max_features (v8 used 500, v11 default 128 for VRAM).
        # Keep the first N rows (v8 trained on tasks that primarily used the
        # earliest column slots).
        target_max = self.cfg.max_features
        for k in ("feature_weights", "feature_biases"):
            if k in filtered and filtered[k].shape[0] > target_max:
                filtered[k] = filtered[k][:target_max]
        result = self.load_state_dict(filtered, strict=False)
        if verbose:
            print(f"[v11.warm_start_from_v8] loaded {len(filtered)} keys")
            if result.missing_keys:
                print(f"  missing  ({len(result.missing_keys)}): {result.missing_keys[:5]}…")
            if result.unexpected_keys:
                print(f"  unexpected ({len(result.unexpected_keys)}): {result.unexpected_keys[:5]}…")
        return {
            "loaded": len(filtered),
            "missing": len(result.missing_keys),
            "unexpected": len(result.unexpected_keys),
        }


    @torch.no_grad()
    def warm_start_slice_from_v11(self, v11_state_dict: dict, verbose: bool = True) -> dict:
        """Initialize this (smaller) model from a v11.0 ckpt by SLICING layers.

        Used when this model has `share_factor > 1`: the v11.0 trunk has
        `n_layers` unique blocks, but this model has only `n_layers /
        share_factor` unique blocks (each used `share_factor` times via
        cycling). We copy every-`share_factor`-th v11.0 block into the
        student's unique-blocks list.

        Non-layer modules (feature_weights, y_embed, class_embed, query_token,
        col_type_embed, shared_norm/reg_norm/cls_norm, reg_head, cls_head)
        copy verbatim — they're share-factor-independent.

        Requires this model use legacy (gelu + layernorm) MLP/norm variants
        for the layer slicing to be shape-compatible.
        """
        if self.cfg.mlp_variant != "gelu" or self.cfg.norm_variant != "layernorm":
            raise ValueError(
                "warm_start_slice_from_v11 requires mlp_variant=gelu, "
                "norm_variant=layernorm for shape compatibility with v11.0 ckpt. "
                f"Got mlp_variant={self.cfg.mlp_variant}, norm_variant={self.cfg.norm_variant}."
            )
        share = max(1, int(self.cfg.share_factor))

        # Build the source→target index map for layer slicing.
        # v11.0 trunk: 8 shared + 4 reg + 4 cls
        v11_n_shared = self.cfg.n_layers - 4  # 8 typically
        v11_n_branch = 4
        # Student unique counts
        s_n_shared = v11_n_shared // share
        s_n_branch = v11_n_branch // share
        # Pick every share-th index from v11.0
        shared_src = list(range(0, v11_n_shared, share))[:s_n_shared]
        branch_src = list(range(0, v11_n_branch, share))[:s_n_branch]

        new_state = {}
        layer_keys_copied = 0
        non_layer_keys_copied = 0

        for k, v in v11_state_dict.items():
            # Layer-keyed weights: rewrite the layer index per the slicing map.
            if k.startswith("shared_layers."):
                # k = "shared_layers.<idx>.<rest>"
                parts = k.split(".", 2)
                src_idx = int(parts[1])
                if src_idx in shared_src:
                    tgt_idx = shared_src.index(src_idx)
                    new_state[f"shared_layers.{tgt_idx}.{parts[2]}"] = v
                    layer_keys_copied += 1
            elif k.startswith("reg_layers."):
                parts = k.split(".", 2)
                src_idx = int(parts[1])
                if src_idx in branch_src:
                    tgt_idx = branch_src.index(src_idx)
                    new_state[f"reg_layers.{tgt_idx}.{parts[2]}"] = v
                    layer_keys_copied += 1
            elif k.startswith("cls_layers."):
                parts = k.split(".", 2)
                src_idx = int(parts[1])
                if src_idx in branch_src:
                    tgt_idx = branch_src.index(src_idx)
                    new_state[f"cls_layers.{tgt_idx}.{parts[2]}"] = v
                    layer_keys_copied += 1
            else:
                # Non-layer weights copy verbatim.
                new_state[k] = v
                non_layer_keys_copied += 1

        result = self.load_state_dict(new_state, strict=False)
        param_names = {n for n, _ in self.named_parameters()}
        missing_params = [k for k in result.missing_keys if k in param_names]

        if verbose:
            print(f"[v11.warm_start_slice] share_factor={share}, slice indices: "
                  f"shared={shared_src}, branch={branch_src}")
            print(f"  copied {layer_keys_copied} layer-keys + {non_layer_keys_copied} non-layer keys")
            if missing_params:
                print(f"  WARN: {len(missing_params)} trainable params unmatched: "
                      f"{missing_params[:5]}{'...' if len(missing_params) > 5 else ''}")
            if result.unexpected_keys:
                print(f"  ignored {len(result.unexpected_keys)} unexpected keys (e.g., v11.0 layers we didn't slice)")
        return {
            "share_factor": share,
            "layer_keys_copied": layer_keys_copied,
            "non_layer_keys_copied": non_layer_keys_copied,
            "missing_params": len(missing_params),
            "unexpected": len(result.unexpected_keys),
        }


def count_params(model: nn.Module) -> int:
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters whose ``requires_grad`` flag is set contribute; frozen
    parameters are left out of the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# ─── self-test: forward pass shapes + warm-start sanity ───────────────────────


if __name__ == "__main__":
    # Smoke test, executed only when this file is run as a script. It builds a
    # default-config model, checks the logits shape of the regression and
    # classification paths, evaluates each loss once, and dry-runs the v8
    # warm start on a synthetic checkpoint.
    import math  # local import: only this self-test needs it

    torch.manual_seed(0)
    cfg = V11Config()
    model = PredictLMv11(cfg)
    print(f"v11 model: {count_params(model)/1e6:.1f}M params  (cfg={cfg})")

    B, n_ctx, n_q, n_f = 2, 64, 16, 8
    X_ctx = torch.randn(B, n_ctx, n_f)
    y_ctx = torch.randn(B, n_ctx)
    X_q = torch.randn(B, n_q, n_f)
    feat_mask = torch.zeros(B, n_f, dtype=torch.bool)

    # Regression path. The expected shape is derived from the same values the
    # assertion uses, so the printed hint cannot drift from the real check.
    expected_reg_shape = (B, n_q, cfg.n_bins)
    reg_logits = model(X_ctx, y_ctx, X_q, feat_mask, task_type="regression")
    print(f"[reg] logits shape: {tuple(reg_logits.shape)}  (expected {expected_reg_shape})")
    assert reg_logits.shape == expected_reg_shape, (
        f"regression logits shape {tuple(reg_logits.shape)} != {expected_reg_shape}"
    )

    # The first n_q context targets stand in for query targets here; only the
    # magnitude of the loss is inspected, not its meaning.
    loss = bar_distribution_loss(reg_logits, y_ctx[:, :n_q], model.reg_head)
    print(f"[reg] uniform-prior loss: {loss.item():.3f}  "
          f"(≈ ln({cfg.n_bins}) = {math.log(cfg.n_bins):.2f})")

    # Classification path. One class count per task in the batch (len == B);
    # this list is the single source for every label range below.
    class_counts = [3, 5]
    # NOTE(review): context labels are drawn from the largest class count for
    # every task, so task 0 can see labels outside its own 3 classes —
    # confirm this mismatch is intended.
    y_ctx_cls = torch.randint(0, max(class_counts), (B, n_ctx))
    expected_cls_shape = (B, n_q, cfg.max_classes)
    cls_logits = model(X_ctx, y_ctx_cls, X_q, feat_mask, task_type="classification")
    print(f"[cls] logits shape: {tuple(cls_logits.shape)}  (expected {expected_cls_shape})")
    assert cls_logits.shape == expected_cls_shape, (
        f"classification logits shape {tuple(cls_logits.shape)} != {expected_cls_shape}"
    )

    n_classes_per_task = torch.tensor(class_counts)
    # Query labels are sampled per task, in task order, within that task's
    # own class range.
    y_q_cls = torch.stack([torch.randint(0, k, (n_q,)) for k in class_counts])
    loss_c = cls_masked_loss(cls_logits, y_q_cls, n_classes_per_task)
    print(f"[cls] masked loss: {loss_c.item():.3f}")

    # Warm-start dry run: simulate a v8 ckpt with wrong-shape heads.
    # Everything except the two heads is cloned from the live model; the heads
    # are replaced by zero-filled stand-ins (reg_head has 2 output rows).
    head_prefixes = ("reg_head.", "cls_head.")
    fake_v8_ckpt = {k: v.clone() for k, v in model.state_dict().items()
                    if not k.startswith(head_prefixes)}
    fake_v8_ckpt["reg_head.weight"] = torch.zeros(2, cfg.d_model)   # v8 shape
    fake_v8_ckpt["reg_head.bias"] = torch.zeros(2)
    fake_v8_ckpt["cls_head.weight"] = torch.zeros(cfg.max_classes, cfg.d_model)
    fake_v8_ckpt["cls_head.bias"] = torch.zeros(cfg.max_classes)
    fresh = PredictLMv11(cfg)
    info = fresh.warm_start_from_v8(fake_v8_ckpt)
    print(f"[warm-start] loaded={info['loaded']}, missing={info['missing']}, unexpected={info['unexpected']}")
    assert info['unexpected'] == 0, "v8 reg/cls heads should be filtered, got unexpected"

    print("[OK] v11 model self-test passed")