# adhd-diffusion/modeling_diffusion_qwen3.py
"""
DiffusionQwen3 Model - Converts Qwen3-1.7B AR to Bidirectional Diffusion LLM
This module provides:
1. DiffusionQwen3Config - Configuration for diffusion-adapted Qwen3
2. DiffusionQwen3Model - The main model class with diffusion training/inference
Based on CoDA (Coding LM via Diffusion Adaptation) by Salesforce AI Research
https://arxiv.org/abs/2510.03270
CRITICAL: Loss normalization matches CoDA official implementation exactly:
loss = (dsigma[:, None] * loss).sum() / (batch_size * seq_len)
NOT dividing by num_masked (which causes gradient explosion)
"""
from typing import Optional, Tuple, Union, List, Dict, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers import Qwen3ForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast
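
# Why this normalization is stable (illustrative reasoning): a token is masked
# with probability sigma = t and its loss is weighted by dsigma = 1/t, so the
# expected weight per token is (1/t) * t = 1 at every noise level. Dividing by
# batch_size * seq_len therefore keeps gradient magnitudes roughly constant,
# whereas dividing by num_masked blows up low-noise batches: at t = 0.001 only
# ~0.1% of tokens are masked while each carries weight dsigma = 1000.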
class DiffusionQwen3Config(PretrainedConfig):
"""Configuration for Diffusion-adapted Qwen3 model."""
model_type = "diffusion_qwen3"
def __init__(
self,
# Base Qwen3 config
vocab_size: int = 151936,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_hidden_layers: int = 28,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
head_dim: int = 128,
max_position_embeddings: int = 40960,
rms_norm_eps: float = 1e-6,
rope_theta: float = 1000000.0,
hidden_act: str = "silu",
attention_dropout: float = 0.0,
attention_bias: bool = False,
tie_word_embeddings: bool = True,
# Diffusion-specific config
mask_token_id: int = 151669,
pad_token_id: int = 151643,
bos_token_id: int = 151643,
eos_token_id: int = 151645,
# Diffusion training parameters
sampling_eps: float = 0.001, # CoDA default: creates 1/t in [1, 1000]
        mask_block_sizes: Optional[List[int]] = None,
block_masking_probability: float = 0.01,
prefix_probability: float = 0.01,
truncate_probability: float = 0.01,
**kwargs
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs
)
# Base model config
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.max_position_embeddings = max_position_embeddings
self.rms_norm_eps = rms_norm_eps
self.rope_theta = rope_theta
self.hidden_act = hidden_act
self.attention_dropout = attention_dropout
self.attention_bias = attention_bias
# Diffusion config
self.mask_token_id = mask_token_id
self.sampling_eps = sampling_eps
self.mask_block_sizes = mask_block_sizes or [2, 4, 8]
self.block_masking_probability = block_masking_probability
self.prefix_probability = prefix_probability
self.truncate_probability = truncate_probability
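
# Example (illustrative): the defaults above mirror Qwen3-1.7B, so a plain
# construction suffices for that checkpoint. With sampling_eps=0.001 the
# sampled t lies in [0.001, 1.0], giving loss weights 1/t in [1, 1000]:
#
#   config = DiffusionQwen3Config(mask_block_sizes=[2, 4, 8])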
class DiffusionQwen3Model(PreTrainedModel):
"""
Qwen3 model adapted for discrete diffusion language modeling.
Key modifications from standard Qwen3:
1. Bidirectional attention (is_causal=False)
2. Masked diffusion training objective
3. Loss weighted by 1/t (inverse noise level)
4. Support for progressive masking (S1/S2/S3)
CRITICAL: Loss normalization follows CoDA exactly (line 524 of modeling.py):
loss = (dsigma[:, None] * loss).sum() / (batch_size * seq_len)
"""
config_class = DiffusionQwen3Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Qwen2DecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(self, config: DiffusionQwen3Config):
super().__init__(config)
self.config = config
        # The base Qwen3 backbone, LM head, and embeddings are attached from a
        # pretrained Qwen3ForCausalLM in from_pretrained_qwen via _init_from_qwen.
self.model = None
self.lm_head = None
self.embed_tokens = None
# Diffusion parameters
self.mask_token_id = config.mask_token_id
self.sampling_eps = config.sampling_eps
# Loss function
self.loss_fn = nn.CrossEntropyLoss(reduction='none')
    def _init_from_qwen(self, qwen_model: Qwen3ForCausalLM):
"""Initialize from a pretrained Qwen model."""
# Extract the base model and lm_head
self.model = qwen_model.model
self.lm_head = qwen_model.lm_head
self.embed_tokens = self.model.embed_tokens
# Disable causal masking in all attention layers
self._disable_causal_masking()
def _disable_causal_masking(self):
"""Disable causal attention masks for bidirectional attention."""
for layer in self.model.layers:
if hasattr(layer.self_attn, 'is_causal'):
layer.self_attn.is_causal = False
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_embeds(self, input_ids: torch.LongTensor) -> torch.Tensor:
"""Get token embeddings."""
return self.embed_tokens(input_ids)
def transition(
self,
x_0: torch.LongTensor,
sigma: torch.Tensor,
maskable_mask: torch.BoolTensor,
mask_block_size: int = 1,
) -> torch.LongTensor:
"""
Apply noise transition: mask tokens with probability sigma.
Args:
x_0: Original token IDs [batch_size, seq_len]
sigma: Noise level per sample [batch_size, 1] or [batch_size]
maskable_mask: Boolean mask of which positions can be masked [batch_size, seq_len]
mask_block_size: Size of contiguous blocks to mask (1 for individual tokens)
Returns:
x_t: Noisy token IDs with some tokens replaced by mask_token_id
"""
if sigma.dim() == 1:
sigma = sigma.unsqueeze(-1)
if mask_block_size == 1:
# Standard per-token masking
move_indices = (torch.rand_like(x_0, dtype=torch.float) < sigma) & maskable_mask
x_t = torch.where(move_indices, self.mask_token_id, x_0)
else:
# Block masking
x_t = self._block_masking(x_0, sigma, maskable_mask, mask_block_size)
return x_t
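
    # Usage sketch (shapes illustrative, `model` assumed to be an initialized
    # DiffusionQwen3Model): with sigma = 0.5, roughly half of the maskable
    # positions come back as mask_token_id:
    #   x_0 = torch.randint(0, 1000, (2, 16))
    #   x_t = model.transition(x_0, torch.full((2,), 0.5),
    #                          torch.ones_like(x_0, dtype=torch.bool))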
def _block_masking(
self,
x_0: torch.LongTensor,
sigma: torch.Tensor,
maskable_mask: torch.BoolTensor,
mask_block_size: int,
) -> torch.LongTensor:
"""Apply block masking for contiguous spans."""
batch_size, seq_len = x_0.shape
if seq_len < mask_block_size:
return x_0
# Calculate number of possible block positions
num_windows = seq_len - mask_block_size + 1
# Create all possible block positions
window_starts = torch.arange(num_windows, device=x_0.device)
block_offsets = torch.arange(mask_block_size, device=x_0.device)
all_positions = window_starts.unsqueeze(1) + block_offsets.unsqueeze(0)
# Check which blocks are fully maskable
maskable_blocks = maskable_mask.unsqueeze(1).expand(-1, num_windows, -1)
maskable_blocks = maskable_blocks.gather(2, all_positions.unsqueeze(0).expand(batch_size, -1, -1))
fully_maskable = maskable_blocks.all(dim=2)
# Scale sigma for block masking (CoDA line 569)
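        # Rationale: an interior token is covered by B = mask_block_size
        # windows, so a per-window rate p = 1 - (1 - sigma)**(1/B) makes the
        # union probability 1 - (1 - p)**B equal sigma; e.g. sigma = 0.5 with
        # B = 4 gives p ≈ 0.159 per window but still ≈ 0.5 per token.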
effective_sigma = 1 - (1 - sigma) ** (1 / mask_block_size)
# Determine which blocks to mask
should_mask = (torch.rand(batch_size, num_windows, device=x_0.device) < effective_sigma) & fully_maskable
# Create final mask
position_indices = torch.arange(seq_len, device=x_0.device).unsqueeze(0).unsqueeze(0)
all_positions_expanded = all_positions.unsqueeze(0)
should_mask_expanded = should_mask.unsqueeze(2)
position_matches = (position_indices == all_positions_expanded.unsqueeze(3)).any(dim=2)
should_mask_positions = should_mask_expanded & position_matches
final_mask = should_mask_positions.any(dim=1)
return torch.where(final_mask, self.mask_token_id, x_0)
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
src_mask: Optional[torch.BoolTensor] = None,
training_mode: str = "pretrain",
masking_schedule: Optional[Dict[str, Any]] = None,
epoch: Optional[int] = None,
return_logits_only: bool = False,
**kwargs,
) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], CausalLMOutputWithPast]:
"""
Forward pass with diffusion training.
Args:
input_ids: Input token IDs [batch_size, seq_len]
attention_mask: Attention mask [batch_size, seq_len]
labels: Target labels (same as input_ids for diffusion)
src_mask: Source mask for SFT (True = prompt, False = response)
training_mode: "pretrain", "midtrain", or "sft"
masking_schedule: Optional override for masking probabilities
epoch: Current epoch for progressive masking
return_logits_only: If True, skip diffusion training logic (used by trainer)
Returns:
logits: Model predictions [batch_size, seq_len, vocab_size]
loss: Diffusion loss (if training and not return_logits_only)
"""
if not self.training or return_logits_only:
# Inference mode OR trainer is handling diffusion logic
hidden_states = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
).last_hidden_state
logits = self.lm_head(hidden_states)
return CausalLMOutputWithPast(logits=logits, loss=None)
# Training mode
batch_size, seq_len = input_ids.shape
# Get masking configuration
if masking_schedule is not None:
prefix_prob = masking_schedule.get("prefix_probability", 0)
truncate_prob = masking_schedule.get("truncate_probability", 0)
block_prob = masking_schedule.get("block_masking_probability", 0)
mask_block_sizes = masking_schedule.get("mask_block_sizes", self.config.mask_block_sizes)
else:
prefix_prob = self.config.prefix_probability
truncate_prob = self.config.truncate_probability
block_prob = self.config.block_masking_probability
mask_block_sizes = self.config.mask_block_sizes
# Create maskable_mask based on training mode
if src_mask is not None:
# SFT mode: only mask response tokens
maskable_mask = ~src_mask
else:
# Pre-training/mid-training: all tokens maskable
maskable_mask = torch.ones_like(input_ids, dtype=torch.bool)
# Apply S1: Unmaskable prefix
if prefix_prob > 0:
maskable_mask = self._apply_prefix_masking(
input_ids, maskable_mask, prefix_prob
)
# Apply S2: Truncated suffix
if truncate_prob > 0:
input_ids, maskable_mask = self._apply_truncate_masking(
input_ids, maskable_mask, truncate_prob
)
# Sample timesteps and compute sigma
# CoDA line 475: sigma = (1 - sampling_eps) * rand + sampling_eps
sampling_eps = self.config.sampling_eps
t = (1 - sampling_eps) * torch.rand(batch_size, device=input_ids.device) + sampling_eps
sigma = t
# CoDA line 476: dsigma = 1 / sigma (for loss weighting)
dsigma = torch.reciprocal(t)
# Select block masking size
if block_prob > 0 and mask_block_sizes and torch.rand(1).item() < block_prob:
mask_block_size = mask_block_sizes[torch.randint(len(mask_block_sizes), (1,)).item()]
else:
mask_block_size = 1
# Apply noise transition
noisy_input_ids = self.transition(
input_ids, sigma, maskable_mask, mask_block_size
)
# Track which positions are masked (for loss computation)
loss_mask = (noisy_input_ids == self.mask_token_id)
# Forward pass through model
hidden_states = self.model(
input_ids=noisy_input_ids,
attention_mask=attention_mask,
).last_hidden_state
logits = self.lm_head(hidden_states)
logits = logits.float()
# =================================================================
# LOSS COMPUTATION - MATCHES CODA EXACTLY (modeling.py lines 509-524)
# =================================================================
        # Shift for next-token prediction (the AR-style shift is kept, as in CoDA):
        #   shift_logits: [batch, seq_len-1, vocab_size]
        #   shift_labels: [batch, seq_len-1]
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = input_ids[..., 1:].contiguous()
shift_loss_mask = loss_mask[..., 1:].contiguous()
# Cross-entropy loss per token
loss = self.loss_fn(
shift_logits.view(-1, self.config.vocab_size),
shift_labels.view(-1)
).view(batch_size, -1)
# Zero out loss for non-masked positions
loss = loss.masked_fill(~shift_loss_mask, 0)
# =================================================================
# CRITICAL: CoDA normalization (line 524)
# Divide by (batch_size * seq_len), NOT by num_masked!
# This gives stable gradients regardless of mask ratio
# =================================================================
# loss = (dsigma[:, None] * loss).sum() / (batch_size * seq_len)
loss = (dsigma.unsqueeze(-1) * loss).sum() / (batch_size * seq_len)
return logits, loss
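
    # Training-step sketch (illustrative; `model` and `batch` are assumed to
    # come from from_pretrained_qwen and a tokenizer):
    #   model.train()
    #   logits, loss = model(input_ids=batch["input_ids"],
    #                        attention_mask=batch["attention_mask"])
    #   loss.backward()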
def _apply_prefix_masking(
self,
input_ids: torch.LongTensor,
maskable_mask: torch.BoolTensor,
prefix_prob: float,
) -> torch.BoolTensor:
"""Apply S1: Random unmaskable prefix."""
batch_size, seq_len = input_ids.shape
# Randomly decide which samples get prefix
apply_prefix = torch.rand(batch_size, device=input_ids.device) < prefix_prob
# Generate random prefix lengths
prefix_lengths = torch.randint(1, seq_len, (batch_size,), device=input_ids.device)
# Create position indices
positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
# Create prefix mask
prefix_mask = positions < prefix_lengths.unsqueeze(1)
# Apply: set maskable_mask to False for prefix positions
maskable_mask = maskable_mask & ~(apply_prefix.unsqueeze(1) & prefix_mask)
return maskable_mask
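
    # e.g. a sample that draws prefix_length = 5 keeps positions 0-4 unmaskable,
    # so the model always sees a clean prefix there, mimicking prompt-conditioned
    # generation during pre-training.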
def _apply_truncate_masking(
self,
input_ids: torch.LongTensor,
maskable_mask: torch.BoolTensor,
truncate_prob: float,
) -> Tuple[torch.LongTensor, torch.BoolTensor]:
"""Apply S2: Random truncated suffix."""
batch_size, seq_len = input_ids.shape
# Randomly decide which samples get truncated
apply_truncate = torch.rand(batch_size, device=input_ids.device) < truncate_prob
# Generate random truncation positions
truncate_positions = torch.randint(1, seq_len, (batch_size,), device=input_ids.device)
# Create position indices
positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
# Create truncate mask
truncate_mask = positions >= truncate_positions.unsqueeze(1)
# Apply: replace with pad token and update maskable_mask
input_ids = torch.where(
apply_truncate.unsqueeze(1) & truncate_mask,
self.config.pad_token_id,
input_ids
)
maskable_mask = maskable_mask & (input_ids != self.config.pad_token_id)
return input_ids, maskable_mask
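
    # e.g. a sample that draws truncate_position = 100 has every token from
    # position 100 onward replaced by pad (and excluded from masking), so the
    # model also trains on shorter effective sequence lengths.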
@classmethod
def from_pretrained_qwen(
cls,
pretrained_model_name_or_path: str = "Qwen/Qwen3-1.7B",
config: Optional[DiffusionQwen3Config] = None,
**kwargs
) -> "DiffusionQwen3Model":
"""
Load from a pretrained Qwen3 model and convert to diffusion.
Args:
pretrained_model_name_or_path: HuggingFace model name or path
config: Optional DiffusionQwen3Config override
**kwargs: Additional arguments for from_pretrained
Returns:
DiffusionQwen3Model ready for diffusion training
"""
# Load the base Qwen model
print(f"Loading base model from {pretrained_model_name_or_path}...")
        qwen_model = Qwen3ForCausalLM.from_pretrained(
pretrained_model_name_or_path,
torch_dtype=kwargs.pop("torch_dtype", torch.bfloat16),
attn_implementation=kwargs.pop("attn_implementation", "flash_attention_2"),
**kwargs
)
# Create diffusion config if not provided
if config is None:
qwen_config = qwen_model.config
config = DiffusionQwen3Config(
vocab_size=qwen_config.vocab_size,
hidden_size=qwen_config.hidden_size,
intermediate_size=qwen_config.intermediate_size,
num_hidden_layers=qwen_config.num_hidden_layers,
num_attention_heads=qwen_config.num_attention_heads,
num_key_value_heads=qwen_config.num_key_value_heads,
max_position_embeddings=qwen_config.max_position_embeddings,
rms_norm_eps=qwen_config.rms_norm_eps,
rope_theta=qwen_config.rope_theta,
)
# Create diffusion model and initialize from Qwen
model = cls(config)
model._init_from_qwen(qwen_model)
print(f"Converted to DiffusionQwen3Model with bidirectional attention")
print(f" - Mask token ID: {config.mask_token_id}")
print(f" - Vocab size: {config.vocab_size}")
print(f" - Hidden size: {config.hidden_size}")
print(f" - Num layers: {config.num_hidden_layers}")
return model
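
    # Usage sketch (illustrative):
    #   model = DiffusionQwen3Model.from_pretrained_qwen(
    #       "Qwen/Qwen3-1.7B",
    #       attn_implementation="sdpa",  # override default if flash-attn is unavailable
    #   )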
def prepare_tokenizer(tokenizer_name: str = "Qwen/Qwen3-1.7B") -> AutoTokenizer:
"""
Prepare tokenizer with mask token for diffusion training.
Args:
tokenizer_name: HuggingFace tokenizer name
Returns:
Tokenizer with mask token added
"""
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
# Check if mask token already exists
if tokenizer.mask_token is None:
# Add mask token (CoDA uses ID 151669)
tokenizer.add_tokens("<|mask|>", special_tokens=True)
tokenizer.add_special_tokens(
{"mask_token": "<|mask|>"},
replace_additional_special_tokens=False
)
print(f"Added mask token: {tokenizer.mask_token} (ID: {tokenizer.mask_token_id})")
else:
print(f"Mask token already exists: {tokenizer.mask_token} (ID: {tokenizer.mask_token_id})")
return tokenizer
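

if __name__ == "__main__":
    # Minimal smoke test (illustrative; downloads the 1.7B checkpoint): load
    # the tokenizer and model, then run one diffusion training forward pass.
    # "sdpa" avoids the flash-attn dependency of the default load path.
    tok = prepare_tokenizer("Qwen/Qwen3-1.7B")
    demo_model = DiffusionQwen3Model.from_pretrained_qwen(
        "Qwen/Qwen3-1.7B", attn_implementation="sdpa"
    )
    demo_model.train()
    batch = tok(["def add(a, b):\n    return a + b"], return_tensors="pt")
    _logits, demo_loss = demo_model(
        input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
    )
    print(f"diffusion loss: {demo_loss.item():.4f}")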