""" DiffusionQwen3 Model - Converts Qwen3-1.7B AR to Bidirectional Diffusion LLM This module provides: 1. DiffusionQwen3Config - Configuration for diffusion-adapted Qwen3 2. DiffusionQwen3Model - The main model class with diffusion training/inference Based on CoDA (Coding LM via Diffusion Adaptation) by Salesforce AI Research https://arxiv.org/abs/2510.03270 CRITICAL: Loss normalization matches CoDA official implementation exactly: loss = (dsigma[:, None] * loss).sum() / (batch_size * seq_len) NOT dividing by num_masked (which causes gradient explosion) """ import math from dataclasses import dataclass from typing import Optional, Tuple, Union, List, Dict, Any import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel, PretrainedConfig from transformers import Qwen2ForCausalLM, Qwen2Config, AutoTokenizer from transformers.modeling_outputs import CausalLMOutputWithPast @dataclass class DiffusionQwen3Config(PretrainedConfig): """Configuration for Diffusion-adapted Qwen3 model.""" model_type = "diffusion_qwen3" def __init__( self, # Base Qwen3 config vocab_size: int = 151936, hidden_size: int = 2048, intermediate_size: int = 6144, num_hidden_layers: int = 28, num_attention_heads: int = 16, num_key_value_heads: int = 8, head_dim: int = 128, max_position_embeddings: int = 40960, rms_norm_eps: float = 1e-6, rope_theta: float = 1000000.0, hidden_act: str = "silu", attention_dropout: float = 0.0, attention_bias: bool = False, tie_word_embeddings: bool = True, # Diffusion-specific config mask_token_id: int = 151669, pad_token_id: int = 151643, bos_token_id: int = 151643, eos_token_id: int = 151645, # Diffusion training parameters sampling_eps: float = 0.001, # CoDA default: creates 1/t in [1, 1000] mask_block_sizes: List[int] = None, block_masking_probability: float = 0.01, prefix_probability: float = 0.01, truncate_probability: float = 0.01, **kwargs ): super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs ) # Base model config self.vocab_size = vocab_size self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.head_dim = head_dim self.max_position_embeddings = max_position_embeddings self.rms_norm_eps = rms_norm_eps self.rope_theta = rope_theta self.hidden_act = hidden_act self.attention_dropout = attention_dropout self.attention_bias = attention_bias # Diffusion config self.mask_token_id = mask_token_id self.sampling_eps = sampling_eps self.mask_block_sizes = mask_block_sizes or [2, 4, 8] self.block_masking_probability = block_masking_probability self.prefix_probability = prefix_probability self.truncate_probability = truncate_probability class DiffusionQwen3Model(PreTrainedModel): """ Qwen3 model adapted for discrete diffusion language modeling. Key modifications from standard Qwen3: 1. Bidirectional attention (is_causal=False) 2. Masked diffusion training objective 3. Loss weighted by 1/t (inverse noise level) 4. 
class DiffusionQwen3Model(PreTrainedModel):
    """
    Qwen3 model adapted for discrete diffusion language modeling.

    Key modifications from standard Qwen3:
        1. Bidirectional attention (is_causal=False)
        2. Masked diffusion training objective
        3. Loss weighted by 1/t (inverse noise level)
        4. Support for progressive masking (S1/S2/S3)

    CRITICAL: Loss normalization follows CoDA exactly (line 524 of its modeling.py):
        loss = (dsigma[:, None] * loss).sum() / (batch_size * seq_len)
    """

    config_class = DiffusionQwen3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3DecoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: DiffusionQwen3Config):
        super().__init__(config)
        self.config = config

        # The base transformer and heads are populated from a pretrained
        # Qwen3 checkpoint in from_pretrained_qwen / _init_from_qwen.
        self.model = None
        self.lm_head = None
        self.embed_tokens = None

        # Diffusion parameters
        self.mask_token_id = config.mask_token_id
        self.sampling_eps = config.sampling_eps

        # Per-token loss; reduction happens in forward() with dsigma weighting
        self.loss_fn = nn.CrossEntropyLoss(reduction="none")

    def _init_from_qwen(self, qwen_model: PreTrainedModel):
        """Initialize from a pretrained Qwen causal LM (e.g. Qwen3ForCausalLM)."""
        # Extract the base transformer and lm_head
        self.model = qwen_model.model
        self.lm_head = qwen_model.lm_head
        self.embed_tokens = self.model.embed_tokens

        # Disable causal masking in all attention layers
        self._disable_causal_masking()

    def _disable_causal_masking(self):
        """Disable causal attention masks for bidirectional attention.

        Note: this relies on the attention implementation honoring the
        per-layer `is_causal` flag (the flash_attention_2 path does); other
        backends may build a causal mask elsewhere and need extra handling.
        """
        for layer in self.model.layers:
            if hasattr(layer.self_attn, "is_causal"):
                layer.self_attn.is_causal = False

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_embeds(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Get token embeddings."""
        return self.embed_tokens(input_ids)

    def transition(
        self,
        x_0: torch.LongTensor,
        sigma: torch.Tensor,
        maskable_mask: torch.BoolTensor,
        mask_block_size: int = 1,
    ) -> torch.LongTensor:
        """
        Apply the noise transition: mask tokens with probability sigma.

        Args:
            x_0: Original token IDs [batch_size, seq_len]
            sigma: Noise level per sample [batch_size, 1] or [batch_size]
            maskable_mask: Boolean mask of positions that may be masked [batch_size, seq_len]
            mask_block_size: Size of contiguous blocks to mask (1 = individual tokens)

        Returns:
            x_t: Noisy token IDs with some tokens replaced by mask_token_id
        """
        if sigma.dim() == 1:
            sigma = sigma.unsqueeze(-1)

        if mask_block_size == 1:
            # Standard per-token masking
            move_indices = (torch.rand_like(x_0, dtype=torch.float) < sigma) & maskable_mask
            x_t = torch.where(move_indices, self.mask_token_id, x_0)
        else:
            # Block masking over contiguous spans
            x_t = self._block_masking(x_0, sigma, maskable_mask, mask_block_size)

        return x_t
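
    # Illustrative usage of transition() (assumes a converted model instance;
    # values are made up). With sigma = 0.5, roughly half of the maskable
    # positions come back as mask_token_id:
    #
    #   x_0 = torch.randint(0, 1000, (2, 16))
    #   sigma = torch.full((2, 1), 0.5)
    #   maskable = torch.ones_like(x_0, dtype=torch.bool)
    #   x_t = model.transition(x_0, sigma, maskable, mask_block_size=1)
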
    def _block_masking(
        self,
        x_0: torch.LongTensor,
        sigma: torch.Tensor,
        maskable_mask: torch.BoolTensor,
        mask_block_size: int,
    ) -> torch.LongTensor:
        """Apply block masking over contiguous spans of mask_block_size tokens."""
        batch_size, seq_len = x_0.shape

        if seq_len < mask_block_size:
            return x_0

        # Number of possible block start positions
        num_windows = seq_len - mask_block_size + 1

        # Enumerate every candidate block: [num_windows, mask_block_size]
        window_starts = torch.arange(num_windows, device=x_0.device)
        block_offsets = torch.arange(mask_block_size, device=x_0.device)
        all_positions = window_starts.unsqueeze(1) + block_offsets.unsqueeze(0)

        # A block is a candidate only if every position in it is maskable
        maskable_blocks = maskable_mask.unsqueeze(1).expand(-1, num_windows, -1)
        maskable_blocks = maskable_blocks.gather(
            2, all_positions.unsqueeze(0).expand(batch_size, -1, -1)
        )
        fully_maskable = maskable_blocks.all(dim=2)

        # Scale sigma for block masking (CoDA line 569): each token is covered
        # by up to mask_block_size windows, so draw each window with a reduced
        # probability to keep the per-token mask rate near sigma.
        effective_sigma = 1 - (1 - sigma) ** (1 / mask_block_size)

        # Decide which blocks to mask
        should_mask = (
            torch.rand(batch_size, num_windows, device=x_0.device) < effective_sigma
        ) & fully_maskable

        # Scatter block decisions back to per-position masks:
        # position_matches[1, w, l] is True iff position l lies in window w.
        position_indices = torch.arange(seq_len, device=x_0.device).unsqueeze(0).unsqueeze(0)
        all_positions_expanded = all_positions.unsqueeze(0)
        should_mask_expanded = should_mask.unsqueeze(2)

        position_matches = (position_indices == all_positions_expanded.unsqueeze(3)).any(dim=2)
        should_mask_positions = should_mask_expanded & position_matches
        final_mask = should_mask_positions.any(dim=1)

        return torch.where(final_mask, self.mask_token_id, x_0)
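
    # Why effective_sigma = 1 - (1 - sigma)**(1/K) for block size K (sketch):
    # an interior position is covered by K windows, each drawn independently
    # with probability p, so it stays unmasked with probability (1 - p)**K.
    # Solving (1 - p)**K = 1 - sigma gives p = 1 - (1 - sigma)**(1/K), which
    # keeps the expected per-token mask rate near sigma (edge positions are
    # covered by fewer windows, so this is approximate).
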
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        src_mask: Optional[torch.BoolTensor] = None,
        training_mode: str = "pretrain",
        masking_schedule: Optional[Dict[str, Any]] = None,
        epoch: Optional[int] = None,
        return_logits_only: bool = False,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], CausalLMOutputWithPast]:
        """
        Forward pass with diffusion training.

        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            labels: Target labels (same as input_ids for diffusion)
            src_mask: Source mask for SFT (True = prompt, False = response)
            training_mode: "pretrain", "midtrain", or "sft"
            masking_schedule: Optional override for masking probabilities
            epoch: Current epoch for progressive masking
            return_logits_only: If True, skip the diffusion training logic
                (used when the trainer applies the noise itself)

        Returns:
            In inference mode: CausalLMOutputWithPast with logits only.
            In training mode: a (logits, loss) tuple, where loss is the
            dsigma-weighted diffusion loss.
        """
        if not self.training or return_logits_only:
            # Inference mode OR the trainer is handling the diffusion logic
            hidden_states = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
            logits = self.lm_head(hidden_states)
            return CausalLMOutputWithPast(logits=logits, loss=None)

        # Training mode
        batch_size, seq_len = input_ids.shape

        # Get the masking configuration
        if masking_schedule is not None:
            prefix_prob = masking_schedule.get("prefix_probability", 0)
            truncate_prob = masking_schedule.get("truncate_probability", 0)
            block_prob = masking_schedule.get("block_masking_probability", 0)
            mask_block_sizes = masking_schedule.get("mask_block_sizes", self.config.mask_block_sizes)
        else:
            prefix_prob = self.config.prefix_probability
            truncate_prob = self.config.truncate_probability
            block_prob = self.config.block_masking_probability
            mask_block_sizes = self.config.mask_block_sizes

        # Build maskable_mask based on the training mode
        if src_mask is not None:
            # SFT mode: only mask response tokens
            maskable_mask = ~src_mask
        else:
            # Pre-training/mid-training: all tokens are maskable
            maskable_mask = torch.ones_like(input_ids, dtype=torch.bool)

        # Apply S1: unmaskable prefix
        if prefix_prob > 0:
            maskable_mask = self._apply_prefix_masking(input_ids, maskable_mask, prefix_prob)

        # Apply S2: truncated suffix
        if truncate_prob > 0:
            input_ids, maskable_mask = self._apply_truncate_masking(
                input_ids, maskable_mask, truncate_prob
            )

        # Sample timesteps and compute sigma
        # CoDA line 475: sigma = (1 - sampling_eps) * rand + sampling_eps
        sampling_eps = self.config.sampling_eps
        t = (1 - sampling_eps) * torch.rand(batch_size, device=input_ids.device) + sampling_eps
        sigma = t

        # CoDA line 476: dsigma = 1 / sigma (loss weighting)
        dsigma = torch.reciprocal(t)

        # Select the block masking size
        if block_prob > 0 and mask_block_sizes and torch.rand(1).item() < block_prob:
            mask_block_size = mask_block_sizes[torch.randint(len(mask_block_sizes), (1,)).item()]
        else:
            mask_block_size = 1

        # Apply the noise transition
        noisy_input_ids = self.transition(input_ids, sigma, maskable_mask, mask_block_size)

        # Track which positions are masked (for the loss computation)
        loss_mask = noisy_input_ids == self.mask_token_id

        # Forward pass through the model
        hidden_states = self.model(
            input_ids=noisy_input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        # =================================================================
        # LOSS COMPUTATION - MATCHES CODA EXACTLY (modeling.py lines 509-524)
        # =================================================================
        # Shift for next-token prediction:
        #   shift_logits: [batch, seq_len-1, vocab_size]
        #   shift_labels: [batch, seq_len-1]
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()
        shift_loss_mask = loss_mask[..., 1:].contiguous()

        # Cross-entropy loss per token
        loss = self.loss_fn(
            shift_logits.view(-1, self.config.vocab_size),
            shift_labels.view(-1),
        ).view(batch_size, -1)

        # Zero out the loss at non-masked positions
        loss = loss.masked_fill(~shift_loss_mask, 0)

        # =================================================================
        # CRITICAL: CoDA normalization (line 524)
        # Divide by (batch_size * seq_len), NOT by num_masked!
        # This gives stable gradients regardless of the mask ratio:
        #   loss = (dsigma[:, None] * loss).sum() / (batch_size * seq_len)
        # =================================================================
        loss = (dsigma.unsqueeze(-1) * loss).sum() / (batch_size * seq_len)

        return logits, loss
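
    # Sketch of the normalization difference (illustrative numbers): suppose
    # batch_size=1, seq_len=1000, and only 10 tokens are masked, each with
    # per-token loss ~2.0 and dsigma = 100 (i.e. t = 0.01).
    #   CoDA:            100 * 2.0 * 10 / 1000 =   2.0
    #   per-num_masked:  100 * 2.0 * 10 / 10   = 200.0
    # Since the expected masked count is about sigma * seq_len = t * seq_len,
    # CoDA's dsigma * loss / seq_len stays O(1), while dividing by num_masked
    # lets dsigma = 1/t dominate at low mask ratios - the gradient explosion
    # the module docstring warns about.
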
    def _apply_prefix_masking(
        self,
        input_ids: torch.LongTensor,
        maskable_mask: torch.BoolTensor,
        prefix_prob: float,
    ) -> torch.BoolTensor:
        """Apply S1: random unmaskable prefix."""
        batch_size, seq_len = input_ids.shape

        # Randomly decide which samples get a prefix
        apply_prefix = torch.rand(batch_size, device=input_ids.device) < prefix_prob

        # Generate random prefix lengths
        prefix_lengths = torch.randint(1, seq_len, (batch_size,), device=input_ids.device)

        # Position indices [1, seq_len]
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)

        # Prefix mask: True where the position falls inside the prefix
        prefix_mask = positions < prefix_lengths.unsqueeze(1)

        # Set maskable_mask to False at prefix positions of selected samples
        maskable_mask = maskable_mask & ~(apply_prefix.unsqueeze(1) & prefix_mask)

        return maskable_mask

    def _apply_truncate_masking(
        self,
        input_ids: torch.LongTensor,
        maskable_mask: torch.BoolTensor,
        truncate_prob: float,
    ) -> Tuple[torch.LongTensor, torch.BoolTensor]:
        """Apply S2: random truncated suffix."""
        batch_size, seq_len = input_ids.shape

        # Randomly decide which samples get truncated
        apply_truncate = torch.rand(batch_size, device=input_ids.device) < truncate_prob

        # Generate random truncation positions
        truncate_positions = torch.randint(1, seq_len, (batch_size,), device=input_ids.device)

        # Position indices [1, seq_len]
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)

        # Truncate mask: True from the truncation point onward
        truncate_mask = positions >= truncate_positions.unsqueeze(1)

        # Replace the suffix with pad tokens and keep pads unmaskable
        input_ids = torch.where(
            apply_truncate.unsqueeze(1) & truncate_mask,
            self.config.pad_token_id,
            input_ids,
        )
        maskable_mask = maskable_mask & (input_ids != self.config.pad_token_id)

        return input_ids, maskable_mask

    @classmethod
    def from_pretrained_qwen(
        cls,
        pretrained_model_name_or_path: str = "Qwen/Qwen3-1.7B",
        config: Optional[DiffusionQwen3Config] = None,
        **kwargs,
    ) -> "DiffusionQwen3Model":
        """
        Load a pretrained Qwen3 model and convert it to a diffusion model.

        Args:
            pretrained_model_name_or_path: HuggingFace model name or path
            config: Optional DiffusionQwen3Config override
            **kwargs: Additional arguments for from_pretrained

        Returns:
            DiffusionQwen3Model ready for diffusion training
        """
        # Load the base model via AutoModelForCausalLM so Qwen3 checkpoints
        # resolve to the correct architecture (Qwen3ForCausalLM)
        print(f"Loading base model from {pretrained_model_name_or_path}...")
        qwen_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=kwargs.pop("torch_dtype", torch.bfloat16),
            attn_implementation=kwargs.pop("attn_implementation", "flash_attention_2"),
            **kwargs,
        )

        # Create a diffusion config from the base config if not provided
        if config is None:
            qwen_config = qwen_model.config
            config = DiffusionQwen3Config(
                vocab_size=qwen_config.vocab_size,
                hidden_size=qwen_config.hidden_size,
                intermediate_size=qwen_config.intermediate_size,
                num_hidden_layers=qwen_config.num_hidden_layers,
                num_attention_heads=qwen_config.num_attention_heads,
                num_key_value_heads=qwen_config.num_key_value_heads,
                head_dim=getattr(qwen_config, "head_dim", 128),
                max_position_embeddings=qwen_config.max_position_embeddings,
                rms_norm_eps=qwen_config.rms_norm_eps,
                rope_theta=qwen_config.rope_theta,
                tie_word_embeddings=qwen_config.tie_word_embeddings,
            )

        # Create the diffusion model and initialize it from the Qwen weights
        model = cls(config)
        model._init_from_qwen(qwen_model)

        print("Converted to DiffusionQwen3Model with bidirectional attention")
        print(f"  - Mask token ID: {config.mask_token_id}")
        print(f"  - Vocab size:    {config.vocab_size}")
        print(f"  - Hidden size:   {config.hidden_size}")
        print(f"  - Num layers:    {config.num_hidden_layers}")

        return model
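
# ---------------------------------------------------------------------------
# Illustrative end-to-end sketch (assumption: the checkpoint is reachable and
# the model sits on a GPU where flash-attention-2 works; this helper is
# hypothetical and not executed on import).
# ---------------------------------------------------------------------------
def _sketch_convert_and_train_step():
    """One hypothetical training step with the converted model."""
    model = DiffusionQwen3Model.from_pretrained_qwen("Qwen/Qwen3-1.7B")
    model.train()

    # Dummy batch; real training would come from a tokenized corpus
    input_ids = torch.randint(0, model.config.vocab_size, (2, 128), device=model.device)
    logits, loss = model(input_ids=input_ids)
    loss.backward()
    return loss
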
def prepare_tokenizer(tokenizer_name: str = "Qwen/Qwen3-1.7B") -> AutoTokenizer:
    """
    Prepare the tokenizer with a mask token for diffusion training.

    Args:
        tokenizer_name: HuggingFace tokenizer name

    Returns:
        Tokenizer with a mask token added
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)

    if tokenizer.mask_token is None:
        # Register <|mask|> as the mask token; add_special_tokens also adds it
        # to the vocab. CoDA assigns it ID 151669, which fits inside Qwen3's
        # padded vocab of 151936, so no embedding resize is needed.
        tokenizer.add_special_tokens(
            {"mask_token": "<|mask|>"},
            replace_additional_special_tokens=False,
        )
        print(f"Added mask token: {tokenizer.mask_token} (ID: {tokenizer.mask_token_id})")
    else:
        print(f"Mask token already exists: {tokenizer.mask_token} (ID: {tokenizer.mask_token_id})")

    return tokenizer
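
if __name__ == "__main__":
    # Minimal smoke test (illustrative; downloads Qwen/Qwen3-1.7B on first run).
    # The ID check below assumes the CoDA convention that the added mask token
    # lands at ID 151669, matching DiffusionQwen3Config's default.
    tokenizer = prepare_tokenizer()
    model = DiffusionQwen3Model.from_pretrained_qwen()
    assert tokenizer.mask_token_id == model.config.mask_token_id, (
        "Tokenizer mask token ID must match DiffusionQwen3Config.mask_token_id"
    )
    print("Tokenizer and model mask token IDs agree.")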