File size: 20,366 Bytes
cd2f2fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
"""
DiffusionQwen3 Model - Converts Qwen3-1.7B AR to Bidirectional Diffusion LLM

This module provides:
1. DiffusionQwen3Config - Configuration for diffusion-adapted Qwen3
2. DiffusionQwen3Model - The main model class with diffusion training/inference

Based on CoDA (Coding LM via Diffusion Adaptation) by Salesforce AI Research
https://arxiv.org/abs/2510.03270

CRITICAL: Loss normalization matches CoDA official implementation exactly:
  loss = (dsigma[:, None] * loss).sum() / (batch_size * seq_len)
NOT dividing by num_masked (which causes gradient explosion)
"""

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union, List, Dict, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers import Qwen2ForCausalLM, Qwen2Config, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast


class DiffusionQwen3Config(PretrainedConfig):
    """Configuration for the diffusion-adapted Qwen3 model.

    Holds the transformer hyper-parameters of the backbone together with the
    settings used by the masked-diffusion training objective.

    This class is deliberately *not* decorated with ``@dataclass``: it declares
    no annotated fields, so the decorator would only replace the ``__repr__``
    and ``__eq__`` inherited from ``PretrainedConfig`` with field-less versions
    (every instance would compare equal and print as
    ``DiffusionQwen3Config()``).

    Args:
        vocab_size, hidden_size, intermediate_size, num_hidden_layers,
        num_attention_heads, num_key_value_heads, head_dim,
        max_position_embeddings, rms_norm_eps, rope_theta, hidden_act,
        attention_dropout, attention_bias: Backbone hyper-parameters, stored
            unchanged as attributes of the same name.
        tie_word_embeddings, pad_token_id, bos_token_id, eos_token_id:
            Forwarded to ``PretrainedConfig.__init__``.
        mask_token_id: Token ID that replaces noised positions.
        sampling_eps: Lower bound of the sampled noise level.
        mask_block_sizes: Candidate block widths for block masking; ``None``
            or an empty list selects ``[2, 4, 8]``.
        block_masking_probability: Probability of using block masking for a
            training step.
        prefix_probability: Per-sample probability of an unmaskable prefix.
        truncate_probability: Per-sample probability of a padded-out suffix.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "diffusion_qwen3"

    def __init__(
        self,
        # Base Qwen3 config
        vocab_size: int = 151936,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_hidden_layers: int = 28,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        head_dim: int = 128,
        max_position_embeddings: int = 40960,
        rms_norm_eps: float = 1e-6,
        rope_theta: float = 1000000.0,
        hidden_act: str = "silu",
        attention_dropout: float = 0.0,
        attention_bias: bool = False,
        tie_word_embeddings: bool = True,

        # Diffusion-specific config
        mask_token_id: int = 151669,
        pad_token_id: int = 151643,
        bos_token_id: int = 151643,
        eos_token_id: int = 151645,

        # Diffusion training parameters
        sampling_eps: float = 0.001,  # CoDA default: creates 1/t in [1, 1000]
        mask_block_sizes: Optional[List[int]] = None,
        block_masking_probability: float = 0.01,
        prefix_probability: float = 0.01,
        truncate_probability: float = 0.01,

        **kwargs
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs
        )

        # Base model config
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.hidden_act = hidden_act
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias

        # Diffusion config
        self.mask_token_id = mask_token_id
        self.sampling_eps = sampling_eps
        # A falsy value (None or an empty list) falls back to the default sizes.
        self.mask_block_sizes = mask_block_sizes or [2, 4, 8]
        self.block_masking_probability = block_masking_probability
        self.prefix_probability = prefix_probability
        self.truncate_probability = truncate_probability


class DiffusionQwen3Model(PreTrainedModel):
    """
    Qwen3 model adapted for discrete diffusion language modeling.
    
    Key modifications from standard Qwen3:
    1. Bidirectional attention (is_causal=False)
    2. Masked diffusion training objective
    3. Loss weighted by 1/t (inverse noise level)
    4. Support for progressive masking (S1/S2/S3)
    
    CRITICAL: Loss normalization follows CoDA exactly (line 524 of modeling.py):
      loss = (dsigma[:, None] * loss).sum() / (batch_size * seq_len)
    """
    
    config_class = DiffusionQwen3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2DecoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    
    def __init__(self, config: DiffusionQwen3Config):
        """Build an empty diffusion wrapper around ``config``.

        No backbone is created here: ``model``, ``lm_head`` and
        ``embed_tokens`` start out as ``None`` and are attached later by
        ``_init_from_qwen``.

        Args:
            config: Diffusion configuration; ``mask_token_id`` and
                ``sampling_eps`` are copied onto the instance.
        """
        super().__init__(config)
        self.config = config

        # Per-token cross entropy (no reduction); forward() masks and
        # normalizes the result itself.
        self.loss_fn = nn.CrossEntropyLoss(reduction='none')

        # Diffusion hyper-parameters cached from the config
        self.sampling_eps = config.sampling_eps
        self.mask_token_id = config.mask_token_id

        # Placeholders, populated by _init_from_qwen()
        self.embed_tokens = None
        self.lm_head = None
        self.model = None
        
    def _init_from_qwen(self, qwen_model: Qwen2ForCausalLM):
        """Initialize from a pretrained Qwen model."""
        # Extract the base model and lm_head
        self.model = qwen_model.model
        self.lm_head = qwen_model.lm_head
        self.embed_tokens = self.model.embed_tokens
        
        # Disable causal masking in all attention layers
        self._disable_causal_masking()
        
    def _disable_causal_masking(self):
        """Disable causal attention masks for bidirectional attention."""
        for layer in self.model.layers:
            if hasattr(layer.self_attn, 'is_causal'):
                layer.self_attn.is_causal = False
    
    def get_input_embeddings(self):
        return self.embed_tokens
    
    def set_input_embeddings(self, value):
        self.embed_tokens = value
        self.model.embed_tokens = value
    
    def get_output_embeddings(self):
        return self.lm_head
    
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings
    
    def get_embeds(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Get token embeddings."""
        return self.embed_tokens(input_ids)
    
    def transition(
        self,
        x_0: torch.LongTensor,
        sigma: torch.Tensor,
        maskable_mask: torch.BoolTensor,
        mask_block_size: int = 1,
    ) -> torch.LongTensor:
        """
        Apply noise transition: mask tokens with probability sigma.
        
        Args:
            x_0: Original token IDs [batch_size, seq_len]
            sigma: Noise level per sample [batch_size, 1] or [batch_size]
            maskable_mask: Boolean mask of which positions can be masked [batch_size, seq_len]
            mask_block_size: Size of contiguous blocks to mask (1 for individual tokens)
            
        Returns:
            x_t: Noisy token IDs with some tokens replaced by mask_token_id
        """
        if sigma.dim() == 1:
            sigma = sigma.unsqueeze(-1)
            
        if mask_block_size == 1:
            # Standard per-token masking
            move_indices = (torch.rand_like(x_0, dtype=torch.float) < sigma) & maskable_mask
            x_t = torch.where(move_indices, self.mask_token_id, x_0)
        else:
            # Block masking
            x_t = self._block_masking(x_0, sigma, maskable_mask, mask_block_size)
            
        return x_t
    
    def _block_masking(
        self,
        x_0: torch.LongTensor,
        sigma: torch.Tensor,
        maskable_mask: torch.BoolTensor,
        mask_block_size: int,
    ) -> torch.LongTensor:
        """Apply block masking for contiguous spans."""
        batch_size, seq_len = x_0.shape
        
        if seq_len < mask_block_size:
            return x_0
        
        # Calculate number of possible block positions
        num_windows = seq_len - mask_block_size + 1
        
        # Create all possible block positions
        window_starts = torch.arange(num_windows, device=x_0.device)
        block_offsets = torch.arange(mask_block_size, device=x_0.device)
        all_positions = window_starts.unsqueeze(1) + block_offsets.unsqueeze(0)
        
        # Check which blocks are fully maskable
        maskable_blocks = maskable_mask.unsqueeze(1).expand(-1, num_windows, -1)
        maskable_blocks = maskable_blocks.gather(2, all_positions.unsqueeze(0).expand(batch_size, -1, -1))
        fully_maskable = maskable_blocks.all(dim=2)
        
        # Scale sigma for block masking (CoDA line 569)
        effective_sigma = 1 - (1 - sigma) ** (1 / mask_block_size)
        
        # Determine which blocks to mask
        should_mask = (torch.rand(batch_size, num_windows, device=x_0.device) < effective_sigma) & fully_maskable
        
        # Create final mask
        position_indices = torch.arange(seq_len, device=x_0.device).unsqueeze(0).unsqueeze(0)
        all_positions_expanded = all_positions.unsqueeze(0)
        should_mask_expanded = should_mask.unsqueeze(2)
        
        position_matches = (position_indices == all_positions_expanded.unsqueeze(3)).any(dim=2)
        should_mask_positions = should_mask_expanded & position_matches
        final_mask = should_mask_positions.any(dim=1)
        
        return torch.where(final_mask, self.mask_token_id, x_0)
    
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        src_mask: Optional[torch.BoolTensor] = None,
        training_mode: str = "pretrain",
        masking_schedule: Optional[Dict[str, Any]] = None,
        epoch: Optional[int] = None,
        return_logits_only: bool = False,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], CausalLMOutputWithPast]:
        """
        Run the backbone, and in training mode also apply noise and compute the diffusion loss.

        In eval mode, or when ``return_logits_only`` is True, the inputs are
        passed to the backbone unchanged and no loss is computed. Otherwise the
        inputs are noised (tokens replaced by ``mask_token_id``), the backbone
        is run on the noised sequence, and a cross-entropy loss weighted by
        1/sigma is computed on the masked positions only.

        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            attention_mask: Attention mask forwarded to the backbone as-is
            labels: Accepted for interface compatibility; not read by this
                method (targets are taken from ``input_ids``)
            src_mask: Source mask for SFT (True = prompt, False = response).
                When given, only positions where it is False are maskable and
                the prefix/truncate augmentations are skipped
            training_mode: Accepted but not read by this method
            masking_schedule: Optional override for masking probabilities.
                When given, missing probability keys default to 0 (not to the
                config values); a missing "mask_block_sizes" falls back to the config
            epoch: Accepted but not read by this method
            return_logits_only: If True, skip diffusion training logic (used by trainer)
            **kwargs: Ignored

        Returns:
            Eval mode or ``return_logits_only``: ``CausalLMOutputWithPast``
            with ``logits`` set and ``loss=None``.
            Training mode: tuple ``(logits, loss)`` where ``logits`` are the
            float32 predictions for the noised input
            [batch_size, seq_len, vocab_size] and ``loss`` is a scalar tensor.
        """
        if not self.training or return_logits_only:
            # Inference mode OR trainer is handling diffusion logic
            hidden_states = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
            logits = self.lm_head(hidden_states)
            return CausalLMOutputWithPast(logits=logits, loss=None)
        
        # Training mode
        batch_size, seq_len = input_ids.shape
        
        # Get masking configuration: an explicit schedule takes precedence over
        # the config, and probabilities missing from the schedule are treated as 0.
        if masking_schedule is not None:
            prefix_prob = masking_schedule.get("prefix_probability", 0)
            truncate_prob = masking_schedule.get("truncate_probability", 0)
            block_prob = masking_schedule.get("block_masking_probability", 0)
            mask_block_sizes = masking_schedule.get("mask_block_sizes", self.config.mask_block_sizes)
        else:
            prefix_prob = self.config.prefix_probability
            truncate_prob = self.config.truncate_probability
            block_prob = self.config.block_masking_probability
            mask_block_sizes = self.config.mask_block_sizes
        
        # Create maskable_mask: True marks positions that may be replaced by the mask token
        if src_mask is not None:
            # SFT mode: only mask response tokens
            maskable_mask = ~src_mask
        else:
            # Pre-training/mid-training: all tokens maskable
            maskable_mask = torch.ones_like(input_ids, dtype=torch.bool)
            
            # Apply S1: Unmaskable prefix
            if prefix_prob > 0:
                maskable_mask = self._apply_prefix_masking(
                    input_ids, maskable_mask, prefix_prob
                )
            
            # Apply S2: Truncated suffix (this also rewrites input_ids, so the
            # padded suffix becomes part of the targets used below)
            if truncate_prob > 0:
                input_ids, maskable_mask = self._apply_truncate_masking(
                    input_ids, maskable_mask, truncate_prob
                )
        
        # Sample one noise level per sequence, uniform in [sampling_eps, 1)
        # CoDA line 475: sigma = (1 - sampling_eps) * rand + sampling_eps
        sampling_eps = self.config.sampling_eps
        t = (1 - sampling_eps) * torch.rand(batch_size, device=input_ids.device) + sampling_eps
        sigma = t
        # CoDA line 476: dsigma = 1 / sigma (for loss weighting)
        dsigma = torch.reciprocal(t)
        
        # Select block masking size: with probability block_prob pick one of the
        # configured sizes uniformly for the whole batch, otherwise mask per token
        if block_prob > 0 and mask_block_sizes and torch.rand(1).item() < block_prob:
            mask_block_size = mask_block_sizes[torch.randint(len(mask_block_sizes), (1,)).item()]
        else:
            mask_block_size = 1
        
        # Apply noise transition
        noisy_input_ids = self.transition(
            input_ids, sigma, maskable_mask, mask_block_size
        )
        
        # Track which positions are masked (for loss computation). This compares
        # against the mask token ID, so any position already holding that ID in
        # input_ids is counted as well.
        loss_mask = (noisy_input_ids == self.mask_token_id)
        
        # Forward pass through model on the noised sequence
        hidden_states = self.model(
            input_ids=noisy_input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        
        logits = self.lm_head(hidden_states)
        # Compute the loss in float32 regardless of the backbone dtype
        logits = logits.float()
        
        # =================================================================
        # LOSS COMPUTATION - MATCHES CODA EXACTLY (modeling.py lines 509-524)
        # =================================================================
        # Shift by one: the logits at position i are scored against the token
        # at position i + 1, and the loss mask is shifted the same way.
        # logits: [batch, seq_len-1, vocab_size]
        # labels: [batch, seq_len-1]
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()
        shift_loss_mask = loss_mask[..., 1:].contiguous()
        
        # Cross-entropy loss per token (loss_fn uses reduction='none')
        loss = self.loss_fn(
            shift_logits.view(-1, self.config.vocab_size),
            shift_labels.view(-1)
        ).view(batch_size, -1)
        
        # Zero out loss for non-masked positions
        loss = loss.masked_fill(~shift_loss_mask, 0)
        
        # =================================================================
        # CRITICAL: CoDA normalization (line 524)
        # Divide by (batch_size * seq_len), NOT by num_masked!
        # This gives stable gradients regardless of mask ratio
        # =================================================================
        # Each sequence's token losses are weighted by its own 1/sigma.
        # loss = (dsigma[:, None] * loss).sum() / (batch_size * seq_len)
        loss = (dsigma.unsqueeze(-1) * loss).sum() / (batch_size * seq_len)
        
        return logits, loss
    
    def _apply_prefix_masking(
        self,
        input_ids: torch.LongTensor,
        maskable_mask: torch.BoolTensor,
        prefix_prob: float,
    ) -> torch.BoolTensor:
        """Apply S1: Random unmaskable prefix."""
        batch_size, seq_len = input_ids.shape
        
        # Randomly decide which samples get prefix
        apply_prefix = torch.rand(batch_size, device=input_ids.device) < prefix_prob
        
        # Generate random prefix lengths
        prefix_lengths = torch.randint(1, seq_len, (batch_size,), device=input_ids.device)
        
        # Create position indices
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        
        # Create prefix mask
        prefix_mask = positions < prefix_lengths.unsqueeze(1)
        
        # Apply: set maskable_mask to False for prefix positions
        maskable_mask = maskable_mask & ~(apply_prefix.unsqueeze(1) & prefix_mask)
        
        return maskable_mask
    
    def _apply_truncate_masking(
        self,
        input_ids: torch.LongTensor,
        maskable_mask: torch.BoolTensor,
        truncate_prob: float,
    ) -> Tuple[torch.LongTensor, torch.BoolTensor]:
        """Apply S2: Random truncated suffix."""
        batch_size, seq_len = input_ids.shape
        
        # Randomly decide which samples get truncated
        apply_truncate = torch.rand(batch_size, device=input_ids.device) < truncate_prob
        
        # Generate random truncation positions
        truncate_positions = torch.randint(1, seq_len, (batch_size,), device=input_ids.device)
        
        # Create position indices
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        
        # Create truncate mask
        truncate_mask = positions >= truncate_positions.unsqueeze(1)
        
        # Apply: replace with pad token and update maskable_mask
        input_ids = torch.where(
            apply_truncate.unsqueeze(1) & truncate_mask,
            self.config.pad_token_id,
            input_ids
        )
        maskable_mask = maskable_mask & (input_ids != self.config.pad_token_id)
        
        return input_ids, maskable_mask
    
    @classmethod
    def from_pretrained_qwen(
        cls,
        pretrained_model_name_or_path: str = "Qwen/Qwen3-1.7B",
        config: Optional[DiffusionQwen3Config] = None,
        **kwargs
    ) -> "DiffusionQwen3Model":
        """
        Load a pretrained Qwen checkpoint and wrap it as a diffusion model.

        Args:
            pretrained_model_name_or_path: HuggingFace model name or path
            config: Optional DiffusionQwen3Config override. When omitted, the
                config is derived from the loaded checkpoint's config and the
                diffusion-specific fields keep their class defaults
            **kwargs: Additional arguments for from_pretrained. ``torch_dtype``
                defaults to ``torch.bfloat16`` and ``attn_implementation`` to
                ``"flash_attention_2"`` when not supplied

        Returns:
            DiffusionQwen3Model ready for diffusion training
        """
        # Load the base Qwen model
        print(f"Loading base model from {pretrained_model_name_or_path}...")

        # Pull the defaulted options out of kwargs before the call, so the
        # result does not depend on argument evaluation order.
        torch_dtype = kwargs.pop("torch_dtype", torch.bfloat16)
        attn_implementation = kwargs.pop("attn_implementation", "flash_attention_2")

        # NOTE(review): the checkpoint is loaded through Qwen2ForCausalLM.
        # Confirm this class matches the checkpoint's architecture for the
        # installed transformers version.
        qwen_model = Qwen2ForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            attn_implementation=attn_implementation,
            **kwargs
        )

        # Create diffusion config if not provided
        if config is None:
            qwen_config = qwen_model.config
            config = DiffusionQwen3Config(
                vocab_size=qwen_config.vocab_size,
                hidden_size=qwen_config.hidden_size,
                intermediate_size=qwen_config.intermediate_size,
                num_hidden_layers=qwen_config.num_hidden_layers,
                num_attention_heads=qwen_config.num_attention_heads,
                num_key_value_heads=qwen_config.num_key_value_heads,
                # The remaining fields are read with a fallback because not
                # every base config exposes them. Copying them keeps the
                # diffusion config in sync with the loaded weights instead of
                # silently using the class defaults.
                head_dim=getattr(
                    qwen_config,
                    "head_dim",
                    qwen_config.hidden_size // qwen_config.num_attention_heads,
                ),
                max_position_embeddings=qwen_config.max_position_embeddings,
                rms_norm_eps=qwen_config.rms_norm_eps,
                rope_theta=qwen_config.rope_theta,
                hidden_act=getattr(qwen_config, "hidden_act", "silu"),
                attention_dropout=getattr(qwen_config, "attention_dropout", 0.0),
                attention_bias=getattr(qwen_config, "attention_bias", False),
                tie_word_embeddings=getattr(qwen_config, "tie_word_embeddings", True),
            )

        # Create diffusion model and initialize from Qwen
        model = cls(config)
        model._init_from_qwen(qwen_model)

        print("Converted to DiffusionQwen3Model with bidirectional attention")
        print(f"  - Mask token ID: {config.mask_token_id}")
        print(f"  - Vocab size: {config.vocab_size}")
        print(f"  - Hidden size: {config.hidden_size}")
        print(f"  - Num layers: {config.num_hidden_layers}")

        return model


def prepare_tokenizer(tokenizer_name: str = "Qwen/Qwen3-1.7B") -> AutoTokenizer:
    """
    Load a tokenizer and make sure it has a mask token.

    If the loaded tokenizer already defines a mask token it is returned
    unchanged; otherwise "<|mask|>" is added and registered as the mask token.
    The outcome is printed in both cases.

    Args:
        tokenizer_name: HuggingFace tokenizer name

    Returns:
        The loaded tokenizer, with a mask token available
    """
    tok = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)

    # Nothing to add when a mask token is already defined.
    if tok.mask_token is not None:
        print(f"Mask token already exists: {tok.mask_token} (ID: {tok.mask_token_id})")
        return tok

    # Add the token to the vocabulary, then register it as the mask token
    # without replacing the existing additional special tokens.
    tok.add_tokens("<|mask|>", special_tokens=True)
    tok.add_special_tokens(
        {"mask_token": "<|mask|>"},
        replace_additional_special_tokens=False
    )
    print(f"Added mask token: {tok.mask_token} (ID: {tok.mask_token_id})")

    return tok