import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer
from dataclasses import dataclass
import os
import math


# ============== Model Architecture ==============

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        var = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(var + self.eps)
        return self.weight * x


class RotaryEmbedding(nn.Module):
    """Rotary Position Embeddings (RoPE) with NTK extrapolation."""

    def __init__(self, dim, max_position_embeddings=16384, base=100000, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.base = base
        self.max_position_embeddings = max_position_embeddings
        self.inv_freq = None
        self._cache = {}

    def _update_freqs(self, device):
        base = self.base * (self.scaling_factor ** (self.dim / (self.dim - 2)))
        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.inv_freq = inv_freq

    def forward(self, x, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[-2]

        if self.inv_freq is None or self.inv_freq.device != x.device:
            self._update_freqs(x.device)

        cache_key = (seq_len, x.device, x.dtype)
        if cache_key in self._cache:
            return self._cache[cache_key]

        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)

        cos = emb.cos()[None, None, :, :]
        sin = emb.sin()[None, None, :, :]

        self._cache[cache_key] = (cos, sin)
        if len(self._cache) > 10:
            self._cache.pop(next(iter(self._cache)))

        return cos, sin


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary embeddings to Q and K."""
    def rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

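
# A minimal, self-contained sanity check for the rotary-embedding helpers above.
# It is not part of the original pipeline: the tensor shapes are illustrative and
# `_rope_shape_check` is a hypothetical helper added only to show how
# RotaryEmbedding and apply_rotary_pos_emb fit together.
def _rope_shape_check():
    head_dim, seq_len = 64, 8
    rope = RotaryEmbedding(head_dim)
    q = torch.randn(1, 4, seq_len, head_dim)  # (batch, heads, seq, head_dim)
    k = torch.randn(1, 4, seq_len, head_dim)
    cos, sin = rope(q, seq_len=seq_len)       # each has shape (1, 1, seq, head_dim)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
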

class DiffusionAttention(nn.Module):
    """Multi-head attention with GQA and Flash Attention support."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.use_flash_attn = config.use_flash_attn

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(self, hidden_states, freqs_cis, attention_mask=None, past_kv=None):
        bsz, q_len, _ = hidden_states.size()

        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = freqs_cis
        cos = cos[:, :, :q_len, :]
        sin = sin[:, :, :q_len, :]
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        if past_kv is not None:
            cache_k, cache_v = past_kv
            k = torch.cat([cache_k, k], dim=2)
            v = torch.cat([cache_v, v], dim=2)

        current_kv = (k, v)

        k = k.repeat_interleave(self.num_key_value_groups, dim=1)
        v = v.repeat_interleave(self.num_key_value_groups, dim=1)

        attn_mask = None
        if attention_mask is not None:
            attn_mask = attention_mask[:, None, None, :].to(dtype=q.dtype)
            attn_mask = (1.0 - attn_mask) * torch.finfo(q.dtype).min

        output = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
        )

        output = output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
        return self.o_proj(output), current_kv


class MLP(nn.Module):
    """Gated MLP with SiLU activation."""

    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class BlockDiffusionBlock(nn.Module):
    """Transformer block with pre-norm, attention, and MLP."""

    def __init__(self, config):
        super().__init__()
        self.self_attn = DiffusionAttention(config)
        self.mlp = MLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.use_activation_checkpointing = config.use_activation_checkpointing

    def forward(self, hidden_states, freqs_cis, attention_mask, past_kv):
        # At inference this is a plain pass-through; _forward is kept separate so a
        # training script could wrap it (e.g. with activation checkpointing).
        return self._forward(hidden_states, freqs_cis, attention_mask, past_kv)

    def _forward(self, hidden_states, freqs_cis, attention_mask, past_kv):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_out, new_kv = self.self_attn(hidden_states, freqs_cis, attention_mask, past_kv)
        hidden_states = residual + attn_out

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + self.mlp(hidden_states)
        return hidden_states, new_kv


@dataclass
class ModelConfig:
    """Model architecture configuration."""
    vocab_size: int = 151936
    hidden_size: int = 1024
    intermediate_size: int = 2816
    num_hidden_layers: int = 16
    num_attention_heads: int = 16
    num_key_value_heads: int = 4
    max_position_embeddings: int = 16384
    rms_norm_eps: float = 1e-6
    rope_theta: float = 100000.0
    pad_token_id: int = 0
    mask_token_id: int = 1
    use_flash_attn: bool = True
    use_activation_checkpointing: bool = False
    attention_dropout: float = 0.0
    hidden_dropout: float = 0.0


class DiffusionLLM(nn.Module):
    """Complete diffusion language model."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        pad_idx = config.pad_token_id if config.pad_token_id < config.vocab_size else None
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=pad_idx)

        self.layers = nn.ModuleList([BlockDiffusionBlock(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.rotary_emb = RotaryEmbedding(
            config.hidden_size // config.num_attention_heads,
            config.max_position_embeddings
        )

        self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids, attention_mask=None, past_key_values=None):
        bsz, seqlen = input_ids.shape
        hidden_states = self.embed_tokens(input_ids)
        freqs_cis = self.rotary_emb(hidden_states, seq_len=seqlen)

        if past_key_values is None:
            past_key_values = [None] * len(self.layers)

        new_kvs = []
        for i, layer in enumerate(self.layers):
            hidden_states, kv = layer(hidden_states, freqs_cis, attention_mask, past_key_values[i])
            new_kvs.append(kv)

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)
        return logits, new_kvs

    def get_num_params(self, trainable_only=True):
        if trainable_only:
            return sum(p.numel() for p in self.parameters() if p.requires_grad)
        else:
            return sum(p.numel() for p in self.parameters())

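
# A tiny end-to-end sketch of the classes above, assuming nothing beyond this file.
# The miniature config sizes are illustrative only, and `_tiny_forward_check` is a
# hypothetical helper, not part of the original training or inference code.
def _tiny_forward_check():
    cfg = ModelConfig(vocab_size=128, hidden_size=64, intermediate_size=128,
                      num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2)
    model = DiffusionLLM(cfg)
    ids = torch.randint(0, cfg.vocab_size, (1, 16))
    logits, kvs = model(ids)
    assert logits.shape == (1, 16, cfg.vocab_size)
    assert len(kvs) == cfg.num_hidden_layers
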

# ============== Inference Functions ==============

def load_model(model_path: str, device: str = 'cuda'):
    """Load a saved model (fp16 or fp32) for inference."""
    print(f"Loading model from {model_path}...")

    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    config = checkpoint['config']

    model = DiffusionLLM(config)

    state_dict = checkpoint['model_state']
    state_dict = {k: v.float() for k, v in state_dict.items()}
    model.load_state_dict(state_dict)

    model = model.to(device)
    model.eval()

    num_params = model.get_num_params() / 1e6
    file_size = os.path.getsize(model_path) / 1e6
    print(f"βœ“ Model loaded: {num_params:.1f}M params from {file_size:.1f} MB file")

    return model, config


def visualize_diffusion_state(tokenizer, context_ids, mask_blocks, is_masked_list, config, clear=True, block_colors=None):
    """Visualize the current state of diffusion generation with multiple blocks.
    
    Args:
        mask_blocks: Either a single block tensor (1, block_size) or list of block tensors
        is_masked_list: Either a single mask tensor (1, block_size) or list of mask tensors
        block_colors: List of ANSI color codes for each block. If None, uses defaults.
    """
    import sys
    import os

    # Default colors for different blocks (green, cyan, yellow, magenta)
    DEFAULT_COLORS = ['\033[92m', '\033[96m', '\033[93m', '\033[95m']
    MASK_COLOR = '\033[90m'  # Gray for masked tokens
    RESET = '\033[0m'

    # Normalize inputs to lists
    if not isinstance(mask_blocks, list):
        mask_blocks = [mask_blocks]
        is_masked_list = [is_masked_list]

    if block_colors is None:
        block_colors = DEFAULT_COLORS

    # Decode context (prompt + previously generated blocks) and replace newlines
    context_text = tokenizer.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')

    # Build visualization for all blocks
    all_blocks_text = []
    for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
        color = block_colors[block_idx % len(block_colors)]
        block_tokens = mask_block[0].tolist()
        block_color_tokens = []

        for i, token_id in enumerate(block_tokens):
            if is_masked[0, i]:
                # Render masked positions as gray blocks
                block_color_tokens.append(f'{MASK_COLOR}██{RESET}')
            else:
                # Decode individual token; use block color for revealed tokens
                token_text = tokenizer.decode([token_id], skip_special_tokens=False)
                block_color_tokens.append(f'{color}{token_text}{RESET}')

        all_blocks_text.append(''.join(block_color_tokens))

    # Concatenate the rendered blocks (no separator between them)
    blocks_combined = ''.join(all_blocks_text)

    # Clear entire terminal
    if clear:
        clear_cmd = 'cls' if os.name == 'nt' else 'clear'
        try:
            os.system(clear_cmd)
        except Exception:
            sys.stdout.write('\r\033[K')

    # Print legend for parallel blocks
    if len(mask_blocks) > 1:
        legend_parts = []
        for i in range(len(mask_blocks)):
            color = block_colors[i % len(block_colors)]
            legend_parts.append(f'{color}Block {i+1}{RESET}')
        print(f"Generating: {' | '.join(legend_parts)}\n")

    # Print the full context with colored blocks
    print(f"{context_text}{blocks_combined}", flush=True)


def demo_visualize_truncation():
    """Demo for visualize_diffusion_state without a full model.
    Simulates streaming output and verifies there is no line duplication when content exceeds terminal width.
    """
    class MockTokenizer:
        def __init__(self):
            # Map token id to token text (simple ASCII characters and spaces)
            self.vocab = {i: chr(65 + (i % 26)) for i in range(256)}
            self.vocab[32] = ' '
            self.eos_token = '\n'
            self.pad_token = ' '

        def decode(self, ids, skip_special_tokens=True):
            # ids can be tensor or list
            if isinstance(ids, torch.Tensor):
                ids = ids.tolist()
            if isinstance(ids, (list, tuple)):
                return ''.join(self.vocab.get(int(i) % 256, '?') for i in ids)
            return str(ids)

    tok = MockTokenizer()
    # Create a long context and a long block so the rendered line exceeds a
    # typical 80-column terminal width
    long_context_ids = torch.tensor([[i % 26 + 65 for i in range(120)]], dtype=torch.long)
    block_size = 32
    mask_block = torch.full((1, block_size), 32, dtype=torch.long)  # spaces
    is_masked = torch.ones(1, block_size, dtype=torch.bool)
    for i in range(0, block_size, 3):
        is_masked[0, i] = False
        mask_block[0, i] = 65 + (i % 26)

    print('\nRunning demo: long prompt + block to test truncation\n')
    for i in range(8):
        visualize_diffusion_state(tok, long_context_ids, [mask_block], [is_masked], ModelConfig(), clear=(i > 0))
        # rotate some tokens to simulate diffusion
        mask_block = torch.roll(mask_block, shifts=1, dims=1)
        import time
        time.sleep(0.08)  # brief pause so each diffusion step stays visible
    print('\n\nDemo completed.')


@torch.no_grad()
def generate_block_diffusion(
    model,
    tokenizer,
    prompt: str,
    steps: int = 16,
    block_size: int = 64,
    max_new_tokens: int = 256,
    device: str = 'cuda',
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.2,
    no_repeat_ngram_size: int = 3,
    visualize: bool = False,
    parallel_blocks: int = 1,  # Number of blocks to generate in parallel
):
    """Generate text using block diffusion with proper sampling and repetition control.
    
    Args:
        visualize: If True, stream output in real-time showing the diffusion effect.
        parallel_blocks: Number of blocks to generate in parallel (1-4 recommended).
    """
    model.eval()

    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    config = model.module.config if hasattr(model, 'module') else model.config
    if hasattr(model, '_orig_mod'):
        config = model._orig_mod.config

    # Generate at least one block, even if max_new_tokens < block_size
    num_blocks = max(1, max_new_tokens // block_size)
    parallel_blocks = min(parallel_blocks, num_blocks)  # Can't parallelize more than total blocks

    if not visualize:
        if parallel_blocks > 1:
            print(f"Generating {num_blocks} blocks of {block_size} tokens each ({parallel_blocks} blocks in parallel)...")
        else:
            print(f"Generating {num_blocks} blocks of {block_size} tokens each...")
    else:
        print(f"\n\033[94mStarting diffusion generation...\033[0m\n")
        print(prompt, end='', flush=True)

    context_ids = prompt_ids
    all_generated_tokens = set(prompt_ids[0].tolist())

    # Process blocks in batches of parallel_blocks
    blocks_generated = 0
    while blocks_generated < num_blocks:
        # Determine how many blocks to generate this iteration
        current_parallel = min(parallel_blocks, num_blocks - blocks_generated)

        if current_parallel > 1:
            # Parallel block generation
            generated_blocks = _generate_parallel_blocks(
                model, tokenizer, context_ids, config, device,
                current_parallel, block_size, steps, temperature,
                top_k, top_p, repetition_penalty, no_repeat_ngram_size,
                all_generated_tokens, visualize
            )

            # Concatenate all generated blocks to context
            for block in generated_blocks:
                context_ids = torch.cat([context_ids, block], dim=1)
                all_generated_tokens.update(block[0].tolist())

            if not visualize:
                print(f"  Blocks {blocks_generated + 1}-{blocks_generated + current_parallel}/{num_blocks} complete")
            blocks_generated += current_parallel
        else:
            # Single block generation (original logic)
            mask_block, block_token_history = _generate_single_block(
                model, tokenizer, context_ids, config, device,
                block_size, steps, temperature, top_k, top_p,
                repetition_penalty, no_repeat_ngram_size,
                all_generated_tokens, visualize
            )

            context_ids = torch.cat([context_ids, mask_block], dim=1)
            all_generated_tokens.update(mask_block[0].tolist())

            if not visualize:
                print(f"  Block {blocks_generated + 1}/{num_blocks} complete")
            blocks_generated += 1

    if visualize:
        # Final newline after visualization
        print("\n")

    generated_ids = context_ids[0].tolist()
    return tokenizer.decode(generated_ids, skip_special_tokens=True)


def _generate_single_block(
    model, tokenizer, context_ids, config, device,
    block_size, steps, temperature, top_k, top_p,
    repetition_penalty, no_repeat_ngram_size,
    all_generated_tokens, visualize
):
    """Generate a single block using diffusion."""
    mask_block = torch.full((1, block_size), config.mask_token_id, device=device)
    is_masked = torch.ones(1, block_size, dtype=torch.bool, device=device)
    block_token_history = []

    for step_idx in range(steps):
        full_input = torch.cat([context_ids, mask_block], dim=1)
        attention_mask = torch.ones_like(full_input, dtype=torch.float32)

        logits, _ = model(full_input, attention_mask=attention_mask)
        block_logits = logits[:, -block_size:, :]

        block_logits = _apply_sampling_controls(
            block_logits, context_ids, mask_block, is_masked,
            repetition_penalty, temperature, top_k, top_p,
            no_repeat_ngram_size, block_token_history
        )

        probs = F.softmax(block_logits, dim=-1)
        probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
        probs = probs.clamp(min=1e-10)
        probs = probs / probs.sum(dim=-1, keepdim=True)

        sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
        sampled_tokens = sampled_tokens.view(1, block_size)

        confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)

        tokens_to_unmask = max(1, block_size // steps)
        if step_idx == steps - 1:
            tokens_to_unmask = is_masked.sum().item()

        if tokens_to_unmask > 0 and is_masked.sum() > 0:
            masked_confidence = confidence.clone()
            masked_confidence[~is_masked] = -1.0

            num_to_unmask = min(tokens_to_unmask, is_masked.sum().item())
            _, top_indices = torch.topk(masked_confidence.view(-1), num_to_unmask)

            for idx in top_indices:
                mask_block[0, idx] = sampled_tokens[0, idx]
                is_masked[0, idx] = False
                block_token_history.append(sampled_tokens[0, idx].item())
                all_generated_tokens.add(sampled_tokens[0, idx].item())

        if visualize:
            visualize_diffusion_state(tokenizer, context_ids, [mask_block], [is_masked], config, clear=(step_idx > 0))

    return mask_block, block_token_history

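
# A model-free sketch of the confidence-based unmasking schedule used in
# _generate_single_block: each step reveals the highest-confidence masked positions,
# and the final step reveals everything that is left. Random values stand in for
# real model probabilities; `_unmasking_schedule_demo` is a hypothetical helper.
def _unmasking_schedule_demo(block_size=16, steps=4):
    is_masked = torch.ones(block_size, dtype=torch.bool)
    for step_idx in range(steps):
        confidence = torch.rand(block_size)
        confidence[~is_masked] = -1.0  # never re-reveal an already unmasked position
        n = max(1, block_size // steps) if step_idx < steps - 1 else int(is_masked.sum())
        _, top_idx = torch.topk(confidence, min(n, int(is_masked.sum())))
        is_masked[top_idx] = False
        print(f"step {step_idx}: {int(is_masked.sum())} positions still masked")
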

def _generate_parallel_blocks(
    model, tokenizer, context_ids, config, device,
    num_parallel, block_size, steps, temperature,
    top_k, top_p, repetition_penalty, no_repeat_ngram_size,
    all_generated_tokens, visualize
):
    """Generate multiple blocks in parallel using batched computation.
    
    Each block sees all previous blocks in the sequence, maintaining proper order:
    - Block 0: context + [block0]
    - Block 1: context + [block0] + [block1]
    - Block 2: context + [block0] + [block1] + [block2]
    - etc.
    
    This ensures sequential coherence while still benefiting from batched computation.
    """
    batch_size = num_parallel
    context_len = context_ids.shape[1]

    # Initialize mask blocks for all parallel blocks
    # Shape: (num_parallel, block_size)
    mask_blocks = torch.full((batch_size, block_size), config.mask_token_id, device=device)
    is_masked = torch.ones(batch_size, block_size, dtype=torch.bool, device=device)
    block_token_histories = [[] for _ in range(batch_size)]

    for step_idx in range(steps):
        # Build inputs with proper sequential structure
        # Each batch item has context + all blocks up to and including its own position
        # Block i sees: context + block_0 + block_1 + ... + block_i

        # Create padded inputs - each batch item has different length
        # We'll pad to the longest sequence (which is the last block)
        max_seq_len = context_len + (num_parallel * block_size)

        # Build full input for each batch item
        full_inputs = []
        attention_masks = []

        for b in range(batch_size):
            # This block sees: context + all previous blocks + its own block
            seq_parts = [context_ids[0]]  # Start with context

            # Add all blocks from 0 to b (inclusive)
            for prev_b in range(b + 1):
                seq_parts.append(mask_blocks[prev_b])

            # Concatenate to form this batch item's input
            batch_input = torch.cat(seq_parts, dim=0)  # (seq_len,)
            current_len = batch_input.shape[0]

            # Pad to max_seq_len
            padding_needed = max_seq_len - current_len
            if padding_needed > 0:
                padding = torch.full((padding_needed,), config.pad_token_id, device=device)
                batch_input = torch.cat([batch_input, padding], dim=0)

            full_inputs.append(batch_input)

            # Create attention mask (1 for real tokens, 0 for padding)
            attn_mask = torch.zeros(max_seq_len, device=device)
            attn_mask[:current_len] = 1.0
            attention_masks.append(attn_mask)

        # Stack into batched tensors
        full_input = torch.stack(full_inputs, dim=0)  # (batch, max_seq_len)
        attention_mask = torch.stack(attention_masks, dim=0)  # (batch, max_seq_len)

        # Single forward pass for all blocks
        logits, _ = model(full_input, attention_mask=attention_mask)

        # Extract logits for each block's position
        # Block b's logits are at positions [context_len + b*block_size : context_len + (b+1)*block_size]
        block_logits_list = []
        for b in range(batch_size):
            start_pos = context_len + (b * block_size)
            end_pos = start_pos + block_size
            block_logits_list.append(logits[b, start_pos:end_pos, :])

        block_logits = torch.stack(block_logits_list, dim=0)  # (batch, block_size, vocab)

        # Apply sampling controls per batch item
        for b in range(batch_size):
            # Build context that includes previous blocks for repetition penalty
            extended_context = context_ids
            if b > 0:
                prev_blocks = torch.cat([mask_blocks[pb:pb+1] for pb in range(b)], dim=1)
                extended_context = torch.cat([context_ids, prev_blocks], dim=1)

            block_logits[b:b+1] = _apply_sampling_controls(
                block_logits[b:b+1],
                extended_context,
                mask_blocks[b:b+1],
                is_masked[b:b+1],
                repetition_penalty, temperature, top_k, top_p,
                no_repeat_ngram_size, block_token_histories[b]
            )

        probs = F.softmax(block_logits, dim=-1)
        probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
        probs = probs.clamp(min=1e-10)
        probs = probs / probs.sum(dim=-1, keepdim=True)

        # Sample for all batches
        sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
        sampled_tokens = sampled_tokens.view(batch_size, block_size)

        confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)

        tokens_to_unmask = max(1, block_size // steps)
        if step_idx == steps - 1:
            tokens_to_unmask = block_size  # Unmask all remaining

        # Unmask for each batch item
        for b in range(batch_size):
            if is_masked[b].sum() > 0:
                masked_confidence = confidence[b].clone()
                masked_confidence[~is_masked[b]] = -1.0

                num_to_unmask = min(tokens_to_unmask, is_masked[b].sum().item())
                if num_to_unmask > 0:
                    _, top_indices = torch.topk(masked_confidence, num_to_unmask)

                    for idx in top_indices:
                        mask_blocks[b, idx] = sampled_tokens[b, idx]
                        is_masked[b, idx] = False
                        block_token_histories[b].append(sampled_tokens[b, idx].item())

        if visualize:
            # Visualize all blocks with different colors
            block_list = [mask_blocks[b:b+1] for b in range(batch_size)]
            is_masked_list = [is_masked[b:b+1] for b in range(batch_size)]
            visualize_diffusion_state(
                tokenizer, context_ids, block_list, is_masked_list,
                config, clear=(step_idx > 0)
            )

    # Return list of generated blocks
    return [mask_blocks[b:b+1] for b in range(batch_size)]


def _apply_sampling_controls(
    block_logits, context_ids, mask_block, is_masked,
    repetition_penalty, temperature, top_k, top_p,
    no_repeat_ngram_size, block_token_history
):
    """Apply repetition penalty, temperature, top-k, top-p, and n-gram blocking."""
    if repetition_penalty != 1.0:
        seen_tokens = set(context_ids[0].tolist())
        for i in range(mask_block.shape[1]):
            if not is_masked[0, i]:
                seen_tokens.add(mask_block[0, i].item())

        for token_id in seen_tokens:
            if token_id < block_logits.shape[-1]:
                if block_logits[0, :, token_id].mean() > 0:
                    block_logits[:, :, token_id] /= repetition_penalty
                else:
                    block_logits[:, :, token_id] *= repetition_penalty

    block_logits = block_logits / temperature

    if top_k > 0:
        top_k_logits, top_k_indices = torch.topk(block_logits, top_k, dim=-1)
        block_logits = torch.full_like(block_logits, float('-inf'))
        block_logits.scatter_(-1, top_k_indices, top_k_logits)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(block_logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        block_logits[indices_to_remove] = float('-inf')

    if no_repeat_ngram_size > 0 and len(block_token_history) >= no_repeat_ngram_size - 1:
        recent_ngram = tuple(block_token_history[-(no_repeat_ngram_size-1):])
        full_history = context_ids[0].tolist() + block_token_history
        for i in range(len(full_history) - no_repeat_ngram_size + 1):
            if tuple(full_history[i:i+no_repeat_ngram_size-1]) == recent_ngram:
                blocked_token = full_history[i + no_repeat_ngram_size - 1]
                if blocked_token < block_logits.shape[-1]:
                    block_logits[:, :, blocked_token] = float('-inf')

    # Safety check: if all logits are -inf, reset to uniform distribution
    all_inf_mask = torch.isinf(block_logits).all(dim=-1)
    if all_inf_mask.any():
        block_logits[all_inf_mask] = 0.0

    return block_logits

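
# A small standalone illustration of the top-k / top-p filtering performed by
# _apply_sampling_controls, run on a hand-made distribution instead of model logits.
# `_nucleus_filter_demo` is a hypothetical helper added for exposition only.
def _nucleus_filter_demo():
    logits = torch.tensor([[[2.0, 1.5, 0.5, -1.0, -3.0]]])  # (batch=1, positions=1, vocab=5)
    filtered = _apply_sampling_controls(
        logits.clone(),
        context_ids=torch.zeros(1, 1, dtype=torch.long),
        mask_block=torch.zeros(1, 1, dtype=torch.long),
        is_masked=torch.ones(1, 1, dtype=torch.bool),
        repetition_penalty=1.0, temperature=1.0, top_k=3, top_p=0.9,
        no_repeat_ngram_size=0, block_token_history=[],
    )
    # Positions filtered out are set to -inf and get zero probability after softmax.
    print(F.softmax(filtered, dim=-1))
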

# ============== Main Entry Point ==============

def main():
    """Main inference function."""
    # Configuration
    model_path = "../extra-final-boss/checkpoints/model_fp32.pt"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Using device: {device}")

    # Allow a quick demo mode to test visualization without loading the model
    import sys
    if len(sys.argv) > 1 and sys.argv[1] == 'demo':
        demo_visualize_truncation()
        return

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load model
    model, config = load_model(model_path, device)

    # Generate text
    print("\n" + "=" * 50)
    print("Text Generation")
    print("=" * 50)

    prompt = "Barrack Obama was born in "
    print(f"Prompt: {prompt}\n")

    # Set visualize=True to see real-time diffusion effect
    visualize = True
    parallel_blocks = 4  # Generate 2-4 blocks in parallel for speedup

    generated = generate_block_diffusion(
        model,
        tokenizer,
        prompt=prompt,
        steps=64,
        block_size=64,
        max_new_tokens=512,
        device=device,
        temperature=1,
        top_k=40,
        top_p=0.9,
        repetition_penalty=1.3,
        no_repeat_ngram_size=3,
        visualize=visualize,
        parallel_blocks=parallel_blocks,
    )

    print(f"\nGenerated text:\n{generated}")


if __name__ == "__main__":
    main()