#!/usr/bin/env python3
"""
inference_sad.py – Block-wise hierarchical diffusion sampling from a trained
SADModel.

Generation proceeds block by block, left to right. Within each block, each
round advances a small random subset of non-leaf positions to some strictly
finer level in the hierarchy:
    mask (level K+1)  >  ancestors (K, …, 1)  >  leaf (level 0)
A transition may jump any number of levels (e.g. mask → leaf directly,
ancestor l → ancestor l' with l' < l, or ancestor → leaf) as long as the new
level is strictly finer than the current one: a position never stays at its
level and never reverts to a coarser one. Rounds repeat until every position
in the block is leaf; then the next block begins.

Each denoising round:
  1. One forward pass on the current block (the K/V cache holds earlier
     blocks).
  2. Softmax the leaf logits and project them through the fixed LUT
     (`AncestorTable.projection_matrix`) into every ancestor level; the max
     of each distribution is that level's confidence, used only to rank a
     position's candidate levels. Ancestor confidences are multiplied by a
     per-level scalar λ_l ∈ [0, 1] before the cross-level comparison (a
     smaller λ_l biases the schedule away from that ancestor level, λ_l = 0
     disables it, and the default λ_l = 1 reproduces the original, unscaled
     behavior). The leaf level (l=0) is never scaled. The target id is then
     drawn per level:
       - leaf level (l=0):     multinomial sample from the leaf distribution
                               after `leaf_temperature` scaling
       - ancestor level (l≥1): multinomial sample from the cluster
                               distribution
     The leaf confidence also comes from the temperature-scaled distribution,
     whereas ancestor confidences and draws always use the unscaled
     (temperature=1) projection; with the default leaf_temperature=1.0 every
     level is scored from the same softmax.
  3. Randomly pick `positions_per_step` non-leaf positions per sample and
     transition each to its best strictly-finer level.

Finalized blocks' K/V are cached so forwards only recompute the current block.

Usage:
    python scripts/inference_sad.py \\
        --config configs/sad_owt.yaml \\
        --checkpoint outputs/sad/latest.pt \\
        --num_samples 4
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
import yaml
from einops import rearrange

ROOT = Path(__file__).resolve().parents[1]  # sad/
sys.path.insert(0, str(ROOT))

from src.models.sad_model import SADModel
from src.models.dit_components import apply_rotary_pos_emb, modulate_fused
from src.diffusion.ancestor_table import AncestorTable
from src.data import build_owt_dataloader


# ─────────────────────────────────────────────────────────────────────────────
# Sampler
# ─────────────────────────────────────────────────────────────────────────────
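
# ── Reference sketch (hypothetical helper; not called by the sampler) ───────
# A minimal pure-Python restatement of step 2 of the denoising round for a
# single position, assuming leaf_temperature = 1.0 so one leaf distribution
# serves both leaf and ancestor scoring. W[l] is assumed to be a V x K_l nested
# list whose row v holds leaf token v's weight on each level-l cluster (e.g.
# one-hot cluster membership), matching how generate() right-multiplies the
# leaf probs by W_l. Only the winning level and its λ-scaled confidence are
# returned; the id draw is left out.
# e.g. _reference_pick_level([0.3, 0.3, 0.4], [None, [[1, 0], [1, 0], [0, 1]]],
#                            [None, 1.0], cur_level=2) returns (1, 0.6)
def _reference_pick_level(
    leaf_prob: list,
    W: list,
    level_lambdas: list,
    cur_level: int,
) -> tuple:
    """Return (best_level, best_conf) over levels strictly finer than cur_level.

    W and level_lambdas are 1-indexed like the sampler's attributes
    (index 0 is unused), so len(W) == K + 1.
    """
    assert cur_level > 0, "leaf positions have no strictly-finer level"
    # Leaf (l=0) is always eligible for a non-leaf position and is never scaled.
    best_level, best_conf = 0, max(leaf_prob)
    for l in range(1, len(W)):
        if l >= cur_level:  # only levels strictly finer than cur_level qualify
            continue
        # cluster_prob[c] = sum_v leaf_prob[v] * W[l][v][c]; zip stops at the
        # shorter sequence, mirroring generate()'s leaf_prob[..., :V_anc] slice.
        num_clusters = len(W[l][0])
        cluster_prob = [
            sum(p * row[c] for p, row in zip(leaf_prob, W[l]))
            for c in range(num_clusters)
        ]
        conf_l = max(cluster_prob) * level_lambdas[l]
        # Strict ">" keeps the finer level on ties, as in generate().
        if conf_l > best_conf:
            best_level, best_conf = l, conf_l
    return best_level, best_conf
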

class BlockDiffusionSampler:
    """
    Block-wise hierarchical diffusion sampler for SADModel.

    State per position is (level, value):
      level = 0          → leaf token; value = token id
      level ∈ [1, K]     → ancestor at level l; value = cluster id in [0, K_l)
      level = K + 1      → mask

    Per-block denoising loop (random position selection, strict-descent schedule):
      Until every position in the block is leaf:
        1. Forward pass on the current block (cache holds earlier blocks).
        2. Vectorized over all block positions, project the leaf softmax
           through the LUT (T = leaf_temperature):
             leaf target     (l=0): prob = softmax(logits / T)
             ancestor target (l≥1): prob = softmax(logits) @ W_l   [V, K_l]
           Each candidate level contributes (conf, id): conf is its max-prob,
           times λ_l for ancestor levels, and is used only to compare levels;
           id is a multinomial draw from that level's distribution, for leaf
           and ancestor targets alike. Only levels strictly finer than the
           position's current level are eligible, so mask → leaf (skipping
           every ancestor) is a legal transition, as is any multi-level jump.
           The eligible level with the highest confidence wins; ties go to
           the finer level.
        3. Randomly pick `positions_per_step` non-leaf positions per sample
           and apply the selected transition at those positions only.
    """

    def __init__(
        self,
        model: SADModel,
        ancestor_table: AncestorTable,
        tokenizer,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        level_lambdas: Optional[list] = None,
        leaf_temperature: float = 1.0,
    ):
        """
        level_lambdas: length-K list of floats in [0, 1]. λ_l (for ancestor
            level l = 1..K) multiplies that level's max-prob conf before the
            cross-level argmax that picks the winning target. Leaf (l=0) is
            never scaled. None → all ones (original behavior).
        leaf_temperature: temperature applied to the leaf logits before the
            leaf softmax. Values < 1.0 sharpen the leaf distribution (higher
            leaf confidence). The scaled distribution drives the leaf
            confidence and the leaf token draw only; ancestor projection and
            sampling always use the unscaled softmax. Default 1.0 (no
            temperature scaling).
        """
        self.model = model
        self.ancestor_table = ancestor_table
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = dtype
        assert leaf_temperature > 0, (
            f"leaf_temperature must be > 0, got {leaf_temperature}"
        )
        self.leaf_temperature = float(leaf_temperature)

        self.block_size: int = model.block_size
        self.max_seq_len: int = model.max_seq_len
        self.vocab_size: int = model.vocab_size
        self.mask_id: int = tokenizer.mask_token_id
        assert self.mask_id is not None, "tokenizer must have mask_token_id"

        self.K: int = ancestor_table.num_levels      # number of ancestor levels
        self.mask_level: int = self.K + 1

        if level_lambdas is None:
            level_lambdas = [1.0] * self.K
        assert len(level_lambdas) == self.K, (
            f"level_lambdas must have length K={self.K}, got {len(level_lambdas)}"
        )
        for x in level_lambdas:
            assert 0.0 <= float(x) <= 1.0, f"each λ must be in [0, 1], got {x}"
        # 1-indexed: self.level_lambdas[l] is λ_l for ancestor level l ∈ [1, K]
        self.level_lambdas = [None] + [float(x) for x in level_lambdas]

        # Leaf embedding table (tied with the output head; read-only view).
        self.leaf_emb = model.get_leaf_embeddings().to(device=device, dtype=dtype).detach()
        self.mask_emb = self.leaf_emb[self.mask_id]  # [d]

        # Ancestor embeddings per level: fed into the model, so keep them in
        # self.dtype to match model weights.
        self.anc_embs = [None] + [
            ancestor_table.ancestor_embeddings(l).to(device=device, dtype=dtype).detach()
            for l in range(1, self.K + 1)
        ]
        # LUT projection matrices W_l: used only on the scoring side (fp32).
        # Fixed buffers, no grad, so fp32 storage is cheap.
        self.W = [None] + [
            ancestor_table.projection_matrix(l).to(device=device, dtype=torch.float32).detach()
            for l in range(1, self.K + 1)
        ]

    # ───────────────────────────────────────────────────────────────────────
    def _build_mixed_embeddings(
        self, level_ids: torch.Tensor, value_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        Build [B, S, d] input embeddings from per-position (level, value).

        Mirrors NoisyStateBuilder.build_noisy_embeddings so inference-time
        inputs match the training distribution.
        """
        B, S = level_ids.shape
        d = self.leaf_emb.shape[-1]
        embs = torch.empty(B, S, d, device=self.device, dtype=self.dtype)

        # leaf (level 0): leaf_emb[value]
        m0 = (level_ids == 0)
        if m0.any():
            embs[m0] = self.leaf_emb[value_ids[m0]]

        # mask (level K+1): leaf_emb[mask_id]
        mM = (level_ids == self.mask_level)
        if mM.any():
            embs[mM] = self.mask_emb

        # ancestor levels 1..K: anc_embs[l][value]
        for l in range(1, self.K + 1):
            ml = (level_ids == l)
            if ml.any():
                embs[ml] = self.anc_embs[l][value_ids[ml]]

        return embs

    # ───────────────────────────────────────────────────────────────────────
    # KV-cache–aware forward. The key observation: under the block-causal mask,
    # positions in finalized (all-leaf) earlier blocks never attend to later
    # blocks, so their K/V are fixed once the block is resolved. We therefore
    # compute them once per block and reuse them across all denoising rounds
    # of every later block.
    #
    # This method inlines DDiTBlockWithMask.forward so we can (a) accept a K/V
    # prefix cache and (b) avoid recomputing Q/K/V for earlier blocks. When
    # k_prefix is None (block 0, before anything is cached) it reduces to a
    # plain single-block pass.
    # ───────────────────────────────────────────────────────────────────────
    def _run_layer_cached(
        self,
        layer_idx: int,
        x: torch.Tensor,
        rotary_cos_sin,
        c: torch.Tensor,
        k_prefix: Optional[torch.Tensor] = None,
        v_prefix: Optional[torch.Tensor] = None,
    ):
        """
        Run one DiT block on `x` (current block positions only) with an
        optional cached K/V prefix.

        Args:
            layer_idx:         index into self.model.blocks
            x:                 [B, bs, d] current block hidden state
            rotary_cos_sin:    rotary cos/sin for positions block_start..block_end-1
            c:                 [B, cond_dim] conditioning
            k_prefix, v_prefix: [B, H, S_prefix, d_head] post-rotary cached K/V
                                (from earlier blocks). None means no prefix.

        Returns:
            x_out:  [B, bs, d]
            k_new:  [B, H, bs, d_head] post-rotary K for current block
            v_new:  [B, H, bs, d_head] post-rotary V for current block
        """
        layer = self.model.blocks[layer_idx]
        B = x.shape[0]
        H = layer.n_heads
        dropout = layer.dropout
        bds_fn = layer._bias_dropout_scale_fn()

        (shift_msa, scale_msa, gate_msa,
         shift_mlp, scale_mlp, gate_mlp) = layer.adaLN_modulation(c)[:, None].chunk(6, dim=2)

        x_skip = x
        x_normed = modulate_fused(layer.norm1(x), shift_msa, scale_msa)
        qkv = layer.attn_qkv(x_normed)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=H)
        cos, sin = rotary_cos_sin
        qkv = apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))

        q = qkv[:, :, 0].transpose(1, 2)      # [B, H, bs, d_h]
        k_new = qkv[:, :, 1].transpose(1, 2)  # [B, H, bs, d_h]
        v_new = qkv[:, :, 2].transpose(1, 2)

        if k_prefix is not None:
            k = torch.cat([k_prefix, k_new], dim=2)
            v = torch.cat([v_prefix, v_new], dim=2)
        else:
            k = k_new
            v = v_new

        # No mask: current block may attend to all prefix (block-causal lookback)
        # and to itself (bidirectional within block).
        attn_out = F.scaled_dot_product_attention(q, k, v)
        attn_out = rearrange(attn_out, "b h s d -> b s (h d)", b=B)

        x = bds_fn(layer.attn_out(attn_out), None, gate_msa, x_skip, dropout)
        x = bds_fn(
            layer.mlp(modulate_fused(layer.norm2(x), shift_mlp, scale_mlp)),
            None, gate_mlp, x, dropout,
        )
        return x, k_new, v_new

    def _forward_block_cached(
        self,
        level_ids_cur: torch.Tensor,
        value_ids_cur: torch.Tensor,
        block_idx: int,
        kv_cache: list,
        is_clean: bool = False,
    ):
        """
        Forward pass over a single block using cached prefix K/V.

        Args:
            level_ids_cur, value_ids_cur: [B, bs] current block state
            block_idx: int, absolute block index (for pos/rotary)
            kv_cache: list[(k_prefix, v_prefix) or (None, None)] per layer
            is_clean: if True, use segment_embed(1) (clean half) to match
                      training's clean context. Used when capturing K/V for
                      finalized blocks and prompt warm-up.

        Returns:
            logits_cur: [B, bs, V] (mask column already set to -inf)
            new_kv:     list[(k_cur, v_cur)] per layer; the caller appends it to the cache
        """
        model = self.model
        B, bs = level_ids_cur.shape
        block_start = block_idx * self.block_size
        block_end = block_start + bs
        device = self.device

        embs = self._build_mixed_embeddings(level_ids_cur, value_ids_cur)  # [B, bs, d]

        # Input projection (weights are self.dtype; embs already self.dtype).
        x = model.input_proj(embs)

        # Position embeddings for this block only.
        block_idx_t = torch.full(
            (bs,), block_idx, dtype=torch.long, device=device,
        )
        intra_pos = torch.arange(bs, device=device)  # same length as block_idx_t
        # segment=0 for noisy (denoising rounds), segment=1 for clean (cache capture)
        seg_val = 1 if is_clean else 0
        seg_id = torch.full((bs,), seg_val, dtype=torch.long, device=device)
        pos_emb = (
            model.block_idx_embed(block_idx_t)
            + model.intra_pos_embed(intra_pos)
            + model.segment_embed(seg_id)
        ).unsqueeze(0).to(x.dtype)
        x = x + pos_emb

        c = model.cond_bias.unsqueeze(0).expand(B, -1).to(x.dtype)

        # Rotary for absolute positions of this block.
        position_ids = torch.arange(block_start, block_end, device=device)
        rotary_cos_sin = model.rotary_emb(x, position_ids=position_ids)

        new_kv = []
        autocast_device = "cuda" if device.type == "cuda" else "cpu"
        with torch.autocast(device_type=autocast_device, dtype=self.dtype):
            for layer_idx in range(len(model.blocks)):
                k_prefix, v_prefix = kv_cache[layer_idx]
                x, k_cur, v_cur = self._run_layer_cached(
                    layer_idx, x, rotary_cos_sin, c,
                    k_prefix=k_prefix, v_prefix=v_prefix,
                )
                new_kv.append((k_cur, v_cur))
            logits = model.output_layer(x, c)  # [B, bs, rounded_leaf]

        logits = logits[..., :self.vocab_size]
        logits[..., self.mask_id] = float("-inf")
        return logits, new_kv

    @staticmethod
    def _append_kv(kv_cache: list, new_kv: list) -> list:
        """Append per-layer new_kv to kv_cache along the sequence dim."""
        out = []
        for (kp, vp), (kn, vn) in zip(kv_cache, new_kv):
            if kp is None:
                out.append((kn, vn))
            else:
                out.append((torch.cat([kp, kn], dim=2),
                            torch.cat([vp, vn], dim=2)))
        return out

    # ───────────────────────────────────────────────────────────────────────
    @torch.no_grad()
    def generate(
        self,
        batch_size: Optional[int] = None,
        prompt_ids: Optional[torch.Tensor] = None,
        positions_per_step: int = 1,
        return_intermediate: bool = False,
        stop_on_eos: bool = True,
    ) -> dict:
        """
        Block-by-block generation with KV cache and random per-round position
        selection.

        Within each block, rounds repeat until every position is leaf. Each
        round runs one forward, computes the best strictly-finer target
        (level, id) for every non-leaf position, then picks
        `positions_per_step` random non-leaf positions per sample and applies
        their transitions. The target level follows the strict-descent
        schedule: among the levels strictly finer than the position's current
        level, the one with the highest (λ-scaled) LUT-projected max-prob wins.

        Unconditional: pass `batch_size` (and leave `prompt_ids=None`); starts
        from an all-mask sequence of length `self.max_seq_len`.

        Conditional: pass `prompt_ids` with shape [B, P] where P is a multiple
        of `block_size`; the first P positions are fixed as leaf tokens, the
        remaining positions are generated block by block.
        """
        assert positions_per_step >= 1, (
            f"positions_per_step must be >= 1, got {positions_per_step}"
        )
        block_size = self.block_size
        device = self.device

        total_len = self.max_seq_len
        assert total_len % block_size == 0, (
            f"max_seq_len ({total_len}) must be divisible by block_size "
            f"({block_size})"
        )

        if prompt_ids is not None:
            prompt_ids = prompt_ids.to(device=device, dtype=torch.long)
            B, P = prompt_ids.shape
            assert P % block_size == 0, (
                f"prompt length P={P} must be a multiple of block_size={block_size}"
            )
            assert P < total_len, (
                f"prompt length P={P} must be < total_len={total_len}"
            )
            start_block = P // block_size
        else:
            assert batch_size is not None, (
                "Either batch_size (unconditional) or prompt_ids (conditional) "
                "must be provided."
            )
            B = batch_size
            P = 0
            start_block = 0

        # ── Initialize state: every position is mask; prompt positions set as leaf.
        level_ids = torch.full(
            (B, total_len), self.mask_level, dtype=torch.long, device=device,
        )
        value_ids = torch.zeros((B, total_len), dtype=torch.long, device=device)
        if P > 0:
            level_ids[:, :P] = 0
            value_ids[:, :P] = prompt_ids

        num_blocks = total_len // block_size

        intermediate = [] if return_intermediate else None
        finished = torch.zeros(B, dtype=torch.bool, device=device)
        eos_id = getattr(self.tokenizer, "eos_token_id", None)

        # ── KV cache: per-layer (k_prefix, v_prefix) for finalized blocks.
        # Starts empty; we append block b's K/V after b is fully resolved,
        # so when block b+1 starts the cache covers blocks 0..b.
        num_layers = len(self.model.blocks)
        kv_cache = [(None, None) for _ in range(num_layers)]

        # ── Warm up KV cache over prompt blocks (all leaf, deterministic).
        # Use is_clean=True: prompt blocks act as clean context for later blocks,
        # matching training's clean half (segment=1).
        for b in range(start_block):
            bs0 = b * block_size
            be0 = (b + 1) * block_size
            _, new_kv = self._forward_block_cached(
                level_ids[:, bs0:be0], value_ids[:, bs0:be0], b, kv_cache,
                is_clean=True,
            )
            kv_cache = self._append_kv(kv_cache, new_kv)

        # ── Block loop (skips prompt blocks). ──────────────────────────────
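        # Each round advances up to `positions_per_step` non-leaf positions by
        # ≥1 level each (strict descent), and at least one position moves
        # whenever any non-leaf position remains. A position needs at most K+1
        # transitions (mask → … → leaf), so block_size * (K+1) rounds always
        # suffice; the loop normally exits early via the all-leaf check.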
        rounds_cap_per_block = block_size * (self.K + 1)

        total_steps = 0  # total denoising rounds across all generated blocks

        for b in range(start_block, num_blocks):
            block_start = b * block_size
            block_end = (b + 1) * block_size

            for _ in range(rounds_cap_per_block):
                cur_level_block = level_ids[:, block_start:block_end]        # [B, bs]
                non_leaf_block = (cur_level_block > 0)                       # [B, bs]
                if not non_leaf_block.any():
                    break

                # 1) Forward pass on current block (cache holds blocks 0..b-1).
                block_logits, _ = self._forward_block_cached(
                    level_ids[:, block_start:block_end],
                    value_ids[:, block_start:block_end],
                    b, kv_cache,
                )                                                            # [B, bs, V]
                # Unscaled (temperature=1) and temperature-scaled leaf probs.
                # The scaled dist gives the leaf conf and the leaf draw; the
                # unscaled dist feeds the ancestor projection (conf and draw).
                leaf_logits_fp = block_logits.float()
                leaf_prob_raw = F.softmax(leaf_logits_fp, dim=-1)            # [B, bs, V]
                if self.leaf_temperature != 1.0:
                    leaf_prob_temp = F.softmax(
                        leaf_logits_fp / self.leaf_temperature, dim=-1,
                    )                                                        # [B, bs, V]
                else:
                    leaf_prob_temp = leaf_prob_raw

                # 2) Best strictly-finer target for every block position.
                best_conf = torch.full(
                    (B, block_size), float("-inf"),
                    device=device, dtype=torch.float32,
                )
                best_level = torch.full(
                    (B, block_size), -1, device=device, dtype=torch.long,
                )
                best_id = torch.zeros(
                    (B, block_size), device=device, dtype=torch.long,
                )

                # Leaf target (l = 0): conf from temp-sharpened dist, sample
                # from temp-sharpened dist.
                leaf_conf = leaf_prob_temp.max(dim=-1).values                # [B, bs]
                leaf_id = torch.multinomial(
                    leaf_prob_temp.reshape(-1, leaf_prob_temp.shape[-1]),
                    num_samples=1,
                ).squeeze(-1).reshape(B, block_size)                         # [B, bs]
                elig = cur_level_block > 0
                upd = elig & (leaf_conf > best_conf)
                best_conf = torch.where(upd, leaf_conf, best_conf)
                best_level = torch.where(upd, torch.zeros_like(best_level), best_level)
                best_id = torch.where(upd, leaf_id, best_id)

                # Ancestor targets l = 1..K.
                # Conf is max-prob over RAW cluster probs times λ_l.
                # Sample is drawn from RAW cluster probs.
                for l in range(1, self.K + 1):
                    V_anc = self.W[l].shape[0]
                    cluster_prob_raw = leaf_prob_raw[..., :V_anc] @ self.W[l]  # [B, bs, K_l]
                    conf_l = cluster_prob_raw.max(dim=-1).values               # [B, bs]
                    conf_l = conf_l * self.level_lambdas[l]
                    id_l = torch.multinomial(
                        cluster_prob_raw.reshape(-1, cluster_prob_raw.shape[-1]),
                        num_samples=1,
                    ).squeeze(-1).reshape(B, block_size)                       # [B, bs]
                    elig_l = cur_level_block > l
                    upd = elig_l & (conf_l > best_conf)
                    best_conf = torch.where(upd, conf_l, best_conf)
                    best_level = torch.where(
                        upd, torch.full_like(best_level, l), best_level,
                    )
                    best_id = torch.where(upd, id_l, best_id)

                # 3) Randomly pick `positions_per_step` non-leaf positions per
                # sample. Leaf positions get score -1, below any U[0, 1) draw,
                # so they rank last in top-k; samples with fewer than k
                # non-leaf positions drop the extra slots via the explicit
                # non_leaf_block mask.
                k = min(positions_per_step, block_size)
                scores = torch.rand(B, block_size, device=device)
                scores = torch.where(
                    non_leaf_block, scores, torch.full_like(scores, -1.0),
                )
                _, topk_idx = scores.topk(k, dim=-1)                         # [B, k]
                selected = torch.zeros_like(non_leaf_block)
                selected.scatter_(1, topk_idx, True)
                apply_mask = selected & non_leaf_block                       # [B, bs]

                level_ids[:, block_start:block_end] = torch.where(
                    apply_mask, best_level, cur_level_block,
                )
                value_ids[:, block_start:block_end] = torch.where(
                    apply_mask, best_id, value_ids[:, block_start:block_end],
                )

                if return_intermediate:
                    intermediate.append(
                        (level_ids.clone().cpu(), value_ids.clone().cpu())
                    )

                total_steps += 1

            # Safety net: force any lingering non-leaf positions to leaf.
            # Use the same temperature-sharpened distribution for consistency.
            block_level = level_ids[:, block_start:block_end]
            non_leaf = (block_level > 0)
            if non_leaf.any():
                block_logits, _ = self._forward_block_cached(
                    level_ids[:, block_start:block_end],
                    value_ids[:, block_start:block_end],
                    b, kv_cache,
                )
                leaf_logits_fp = block_logits.float()
                if self.leaf_temperature != 1.0:
                    leaf_logits_fp = leaf_logits_fp / self.leaf_temperature
                leaf_prob_fallback = F.softmax(leaf_logits_fp, dim=-1)
                leaf_id_fallback = torch.multinomial(
                    leaf_prob_fallback.reshape(-1, leaf_prob_fallback.shape[-1]),
                    num_samples=1,
                ).squeeze(-1).reshape(B, block_size)
                level_ids[:, block_start:block_end] = torch.where(
                    non_leaf, torch.zeros_like(block_level), block_level,
                )
                value_ids[:, block_start:block_end] = torch.where(
                    non_leaf, leaf_id_fallback, value_ids[:, block_start:block_end],
                )

            # ── Finalize block b in the KV cache ───────────────────────────
            # Run one more forward on the block's final (all-leaf) state to
            # grab K/V that are consistent with the resolved tokens, then
            # append to the cache so block b+1 can see block b.
            # Use is_clean=True: finalized blocks serve as clean context for
            # later blocks, matching training's clean half (segment=1).
            _, new_kv = self._forward_block_cached(
                level_ids[:, block_start:block_end],
                value_ids[:, block_start:block_end],
                b, kv_cache,
                is_clean=True,
            )
            kv_cache = self._append_kv(kv_cache, new_kv)

            if stop_on_eos and eos_id is not None:
                block_vals = value_ids[:, block_start:block_end]
                block_lvls = level_ids[:, block_start:block_end]
                has_eos = ((block_lvls == 0) & (block_vals == eos_id)).any(dim=-1)
                finished = finished | has_eos
                if finished.all():
                    break

        # ── Package output ──────────────────────────────────────────────────
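        # Every generated position is now leaf (level 0), so value_ids holds
        # token ids there. If stop_on_eos ended generation early, the blocks
        # after the stopping point were never generated: they remain at the
        # mask level with placeholder value 0, so truncate at the first EOS
        # before decoding.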
        result = {
            "tokens": value_ids.cpu(),
            "prompt_len": P,
            "num_steps": total_steps,
        }
        if return_intermediate:
            result["intermediate"] = intermediate
        return result
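

# Illustrative sketch (not called by the sampler above): the leaf fallback from
# generate(), shown in isolation. Any block position that is still non-leaf
# (level > 0) after the denoising rounds is forced to level 0, and its value is
# replaced by a token sampled from temperature-scaled leaf logits. The function
# name and argument layout are assumptions for this example: block_level and
# value_block are (B, S) int64 tensors and leaf_logits is (B, S, V).
def _sketch_leaf_fallback(block_level, value_block, leaf_logits,
                          temperature=1.0, generator=None):
    import torch
    import torch.nn.functional as F

    B, S, V = leaf_logits.shape
    non_leaf = block_level != 0
    # Cast to fp32 before scaling: dividing fp16 logits by a small temperature
    # can overflow. temperature < 1.0 sharpens the distribution.
    probs = F.softmax(leaf_logits.float() / temperature, dim=-1)
    # torch.multinomial takes 1-D or 2-D input, so flatten (B, S) into rows.
    sampled = torch.multinomial(
        probs.reshape(-1, V), num_samples=1, generator=generator,
    ).squeeze(-1).reshape(B, S)
    # Non-leaf positions become level 0 with the sampled token id; positions
    # that were already leaf keep their existing value.
    new_level = torch.where(non_leaf, torch.zeros_like(block_level), block_level)
    new_value = torch.where(non_leaf, sampled, value_block)
    return new_level, new_value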


# ─────────────────────────────────────────────────────────────────────────────
# Checkpoint / model plumbing
# ─────────────────────────────────────────────────────────────────────────────

def _unwrap(model):
    """Peel DDP (.module) and torch.compile (._orig_mod) wrappers."""
    while True:
        if hasattr(model, "_orig_mod"):
            model = model._orig_mod
        elif hasattr(model, "module"):
            model = model.module
        else:
            return model
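

# Illustrative sketch (hypothetical helper, not used by main()): _unwrap()
# peels wrappers off the *module*, but a checkpoint saved from a wrapped model
# carries the wrapper in its *keys*: DDP prefixes "module." and torch.compile
# prefixes "_orig_mod.". Whether this project's checkpoints were saved that way
# is an assumption; if they were, load_state_dict(strict=False) would silently
# skip those keys. This strips the prefixes, repeatedly for nested wrappers.
def _sketch_strip_wrapper_prefixes(state_dict):
    prefixes = ("_orig_mod.", "module.")
    cleaned = {}
    for key, value in state_dict.items():
        changed = True
        while changed:
            changed = False
            for prefix in prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    changed = True
        cleaned[key] = value
    return cleaned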


def load_config(path: str) -> dict:
    with open(path) as f:
        return yaml.safe_load(f)


def build_tokenizer(config: dict):
    from transformers import AutoTokenizer
    tok = AutoTokenizer.from_pretrained(
        ROOT / "tokenizers" / "gpt2",
        local_files_only=True,
    )
    if tok.eos_token is None:
        tok.add_special_tokens({"eos_token": "<|endoftext|>"})
    if tok.bos_token is None:
        tok.bos_token = tok.eos_token
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    if tok.mask_token_id is None:
        tok.add_special_tokens({"mask_token": "[MASK]"})
    config["model"]["vocab_size"] = len(tok)
    if "level_sizes" in config["model"]:
        config["model"]["level_sizes"][0] = len(tok)
    return tok


def build_ancestor_table(config: dict, device, embed_dim: int) -> AncestorTable:
    """Mirror of train_sad.build_ancestor_table: load the fixed LUT (and proto
    file, if configured) so the returned module has the right shape for loading
    the checkpoint's state_dict."""
    ancestor_cfg = config.get("ancestor", {})
    script_dir = ROOT
    lut_path = ancestor_cfg.get("lut_path", None)

    if lut_path is None:
        # Debug path: random LUT. Uses the training seed so the random LUT
        # lines up across train/infer; the checkpoint's state_dict will
        # overwrite the learnable embeddings anyway.
        vocab_size = config["model"]["vocab_size"]
        K = ancestor_cfg.get("num_clusters", 64)
        top_k = ancestor_cfg.get("top_k", 3)
        seed = config.get("training", {}).get("seed", 42)
        g = torch.Generator().manual_seed(seed)
        indices = torch.randint(0, K, (vocab_size, top_k), generator=g)
        raw_w = torch.rand(vocab_size, top_k, generator=g)
        probs = raw_w / raw_w.sum(dim=-1, keepdim=True)
        init_emb = torch.randn(K, embed_dim, generator=g) * 0.02
        return AncestorTable(
            lut_indices=[indices],
            lut_probs=[probs],
            init_embeddings=[init_emb],
        ).to(device)

    lut_path = Path(lut_path) if Path(lut_path).is_absolute() else script_dir / lut_path
    proto_path = ancestor_cfg.get("proto_path", None)
    if proto_path is not None:
        proto_path = Path(proto_path) if Path(proto_path).is_absolute() else script_dir / proto_path
    table = AncestorTable.from_files(
        lut_path=lut_path, proto_path=proto_path,
        embed_dim=embed_dim, device=device,
    )
    return table.to(device)
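

# Illustrative sketch of the debug-path LUT built in build_ancestor_table()
# when no lut_path is configured: each of vocab_size tokens points at top_k of
# num_clusters clusters with row-normalized weights, all drawn from one seeded
# torch.Generator in a fixed order, so train and inference rebuild identical
# tensors. The function name is hypothetical; it returns raw tensors instead
# of an AncestorTable.
def _sketch_random_lut(vocab_size, num_clusters, top_k, embed_dim, seed=42):
    import torch

    g = torch.Generator().manual_seed(seed)
    # randint samples with replacement, so a row may repeat a cluster id.
    indices = torch.randint(0, num_clusters, (vocab_size, top_k), generator=g)
    raw_w = torch.rand(vocab_size, top_k, generator=g)
    probs = raw_w / raw_w.sum(dim=-1, keepdim=True)  # each row sums to 1
    init_emb = torch.randn(num_clusters, embed_dim, generator=g) * 0.02
    return indices, probs, init_emb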


def build_model(config: dict, device: torch.device) -> SADModel:
    mc = config["model"]
    model = SADModel(
        vocab_size=mc["vocab_size"],
        hidden_size=mc["hidden_size"],
        n_blocks=mc["n_blocks"],
        n_heads=mc["n_heads"],
        cond_dim=mc["cond_dim"],
        max_seq_len=mc["max_seq_len"],
        block_size=mc.get("block_size", 8),
        dropout=mc.get("dropout", 0.0),
        num_levels=mc.get("num_levels", 2),
        level_sizes=mc.get("level_sizes"),
        tie_weights=mc.get("tie_weights", False),
    ).to(device)
    return model


# ─────────────────────────────────────────────────────────────────────────────
# CLI
# ─────────────────────────────────────────────────────────────────────────────

def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", type=str, required=True)
    p.add_argument("--config", type=str, default="configs/sad_owt.yaml")
    p.add_argument("--num_samples", type=int, default=1)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", type=str,
                   default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
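    # store_true with default=True could never be switched off;
    # BooleanOptionalAction (Python 3.9+) also adds --no-stop_on_eos.
    p.add_argument("--stop_on_eos", action=argparse.BooleanOptionalAction, default=True,
                   help="Stop generating further blocks once every sample has "
                        "emitted EOS (disable with --no-stop_on_eos).")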
    p.add_argument("--mode", type=str, default="unconditional",
                   choices=["unconditional", "conditional"],
                   help="unconditional: start from all-mask. "
                        "conditional: take a block from the training set as the first block(s).")
    p.add_argument("--prompt_blocks", type=int, default=1,
                   help="(conditional) number of leading blocks taken from the training data.")
    p.add_argument("--data_seed", type=int, default=0,
                   help="(conditional) seed for shuffling the training split when picking a sample.")
    p.add_argument("--positions_per_step", type=int, default=1,
                   help="Number of random non-leaf positions to advance per "
                        "denoising round within a block.")
    p.add_argument("--level_lambdas", type=str, default=None,
                   help="Comma-separated K floats in [0, 1], one per ancestor "
                        "level l = 1..K (e.g. '1.0,0.8,0.5'). Each multiplies "
                        "that level's max-prob confidence before the cross-level "
                        "argmax. λ_l < 1 biases the schedule away from level l; "
                        "λ_l = 0 disables it. Default: all 1.0 (no change).")
    p.add_argument("--leaf_temperature", type=float, default=1.0,
                   help="Temperature applied to leaf logits before softmax. "
                        "Values < 1.0 sharpen p_leaf, which is then used for "
                        "both leaf multinomial sampling and ancestor projection. "
                        "Default 1.0 (no sharpening).")
    return p.parse_args()
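

# Illustrative sketch of how --level_lambdas could be parsed and applied; the
# real logic lives in BlockDiffusionSampler and may differ. Assumptions: index
# 0 is the leaf level with a fixed weight of 1.0 (main() prints
# level_lambdas[1:] as the per-ancestor values), and `conf` holds each level's
# max probability per position with shape (K + 1, N). The function name is
# hypothetical.
def _sketch_apply_level_lambdas(spec, conf):
    import torch

    lambdas = [float(x) for x in spec.split(",")] if spec else []
    if any(not 0.0 <= lam <= 1.0 for lam in lambdas):
        raise ValueError(f"level_lambdas must lie in [0, 1], got {lambdas}")
    num_ancestors = conf.shape[0] - 1
    if lambdas and len(lambdas) != num_ancestors:
        raise ValueError(f"expected {num_ancestors} lambdas, got {len(lambdas)}")
    # Leaf weight 1.0, then one weight per ancestor level (default all 1.0).
    weights = torch.tensor([1.0] + (lambdas or [1.0] * num_ancestors),
                           dtype=conf.dtype)
    scores = weights[:, None] * conf  # (K + 1, N) weighted confidences
    return scores.argmax(dim=0)       # chosen level index per position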


def resolve_dtype(name: str) -> torch.dtype:
    return {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[name]


def main():
    args = parse_args()
    torch.manual_seed(args.seed)

    device = torch.device(args.device)
    dtype = resolve_dtype(args.dtype)

    config = load_config(args.config)
    tokenizer = build_tokenizer(config)

    # ── Build + load model ─────────────────────────────────────────────────
    model = build_model(config, device).to(dtype)
    ckpt = torch.load(args.checkpoint, map_location=device)
    raw_state = ckpt.get("model", ckpt)
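    incompatible = _unwrap(model).load_state_dict(raw_state, strict=False)
    # strict=False tolerates key mismatches; report them instead of hiding them.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"Warning: {len(incompatible.missing_keys)} missing / "
              f"{len(incompatible.unexpected_keys)} unexpected keys in model state_dict")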
    model.eval()
    print(f"Loaded checkpoint: {args.checkpoint}  (step={ckpt.get('step', '?')})")

    # ── Build + load ancestor table ────────────────────────────────────────
    # Fixed LUT comes from config (same file as training); learnable ancestor
    # embeddings come from the checkpoint. load_state_dict overwrites both
    # buffers (LUT, W_l) and parameters (ancestor_embeddings) to match training
    # exactly.
    ancestor_table = build_ancestor_table(
        config, device, embed_dim=config["model"]["hidden_size"],
    )
    assert "ancestor_table" in ckpt, (
        "Checkpoint has no 'ancestor_table' entry; cannot run hierarchical "
        "inference. Re-train with train_sad.py or use an older inference "
        "script that ignores ancestors."
    )
    ancestor_table.load_state_dict(ckpt["ancestor_table"])
    ancestor_table.to(device=device, dtype=dtype).eval()
    print(f"Loaded ancestor table: {ancestor_table.num_levels} ancestor level(s)")

    level_lambdas = None
    if args.level_lambdas:
        level_lambdas = [float(x) for x in args.level_lambdas.split(",")]

    sampler = BlockDiffusionSampler(
        model=_unwrap(model),
        ancestor_table=ancestor_table,
        tokenizer=tokenizer,
        device=device,
        dtype=dtype,
        level_lambdas=level_lambdas,
        leaf_temperature=args.leaf_temperature,
    )
    print(f"level_lambdas (per ancestor level l=1..K) = "
          f"{sampler.level_lambdas[1:]}")
    print(f"leaf_temperature = {sampler.leaf_temperature}")

    # ── Optionally load a prompt from the training data ────────────────────
    prompt_ids = None
    if args.mode == "conditional":
        data_cfg = config.get("data", {})
        seq_len = config["model"]["max_seq_len"]
        # Same default as build_model(), so a config without block_size works.
        block_size = config["model"].get("block_size", 8)
        prompt_len = args.prompt_blocks * block_size
        assert prompt_len < seq_len, (
            f"prompt_blocks * block_size = {prompt_len} must be < max_seq_len = {seq_len}"
        )
        # Resolve relative cache_dir against the sad/ repo root (scripts/..), so
        # the script works regardless of cwd (training ran from sad/).
        cache_dir = data_cfg.get("cache_dir", None)
        if cache_dir is not None and not Path(cache_dir).is_absolute():
            repo_root = ROOT
            candidate = repo_root / cache_dir
            if candidate.exists():
                cache_dir = str(candidate)
        loader = build_owt_dataloader(
            tokenizer,
            split="train[:-100000]",
            seq_len=seq_len,
            batch_size=args.num_samples,
            num_workers=0,
            cache_dir=cache_dir,
            seed=args.data_seed,
            mode=data_cfg.get("mode", "subsample"),
            shard_across_ranks=False,
        )
        batch = next(iter(loader))
        prompt_ids = batch["input_ids"][:args.num_samples, :prompt_len].to(device)
        print(f"Loaded conditional prompt from training data: "
              f"shape={tuple(prompt_ids.shape)} (prompt_blocks={args.prompt_blocks})")

    print(f"Sampling {args.num_samples} sequences ({args.mode}) "
          f"length={config['model']['max_seq_len']}, "
          f"random positions_per_step={args.positions_per_step}")

    out = sampler.generate(
        batch_size=args.num_samples if prompt_ids is None else None,
        prompt_ids=prompt_ids,
        positions_per_step=args.positions_per_step,
        stop_on_eos=args.stop_on_eos,
    )

    # ── Decode & print ─────────────────────────────────────────────────────
    P = out.get("prompt_len", 0)
    print("\n" + "=" * 72)
    for i, ids in enumerate(out["tokens"]):
        ids_list = ids.tolist()
        print(f"[Sample {i + 1}]")
        if P > 0:
            prompt_text = tokenizer.decode(ids_list[:P], skip_special_tokens=True)
            gen_text = tokenizer.decode(ids_list[P:], skip_special_tokens=True)
            print(f"<prompt ({P} tok)> {prompt_text}")
            print(f"<generated> {gen_text}")
        else:
            print(tokenizer.decode(ids_list, skip_special_tokens=True))
        print()


if __name__ == "__main__":
    main()