"""
DiffusionQwen3 Model - Converts Qwen3-1.7B AR to Bidirectional Diffusion LLM
This module provides:
1. DiffusionQwen3Config - Configuration for diffusion-adapted Qwen3
2. DiffusionQwen3Model - The main model class with diffusion training/inference
Based on CoDA (Coding LM via Diffusion Adaptation) by Salesforce AI Research
https://arxiv.org/abs/2510.03270
CRITICAL: Loss normalization matches CoDA official implementation exactly:
loss = (dsigma[:, None] * loss).sum() / (batch_size * seq_len)
NOT dividing by num_masked (which causes gradient explosion)
"""
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union, List, Dict, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers import Qwen2ForCausalLM, Qwen2Config, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast
class DiffusionQwen3Config(PretrainedConfig):
    """Configuration for Diffusion-adapted Qwen3 model.

    Holds the base-architecture hyperparameters (mirrored from the source Qwen
    checkpoint) plus the diffusion-specific settings used by
    ``DiffusionQwen3Model``: the mask token ID, the timestep floor
    ``sampling_eps``, and the progressive-masking probabilities
    (S1 prefix, S2 truncation, S3 block masking).

    Deliberately NOT decorated with ``@dataclass``: with a hand-written
    ``__init__`` and no annotated fields, ``@dataclass`` only injects an
    empty-field ``__eq__`` (making every instance compare equal) and a
    ``__repr__`` that hides all values, overriding the ``PretrainedConfig``
    implementations.
    """

    model_type = "diffusion_qwen3"

    def __init__(
        self,
        # Base Qwen3 config
        vocab_size: int = 151936,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_hidden_layers: int = 28,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        head_dim: int = 128,
        max_position_embeddings: int = 40960,
        rms_norm_eps: float = 1e-6,
        rope_theta: float = 1000000.0,
        hidden_act: str = "silu",
        attention_dropout: float = 0.0,
        attention_bias: bool = False,
        tie_word_embeddings: bool = True,
        # Diffusion-specific config
        mask_token_id: int = 151669,
        pad_token_id: int = 151643,
        bos_token_id: int = 151643,
        eos_token_id: int = 151645,
        # Diffusion training parameters
        sampling_eps: float = 0.001,  # CoDA default: creates 1/t in [1, 1000]
        mask_block_sizes: Optional[List[int]] = None,
        block_masking_probability: float = 0.01,
        prefix_probability: float = 0.01,
        truncate_probability: float = 0.01,
        **kwargs
    ):
        """
        Args:
            vocab_size .. tie_word_embeddings: Base-architecture hyperparameters.
            mask_token_id: Token ID written into masked positions.
            pad_token_id / bos_token_id / eos_token_id: Special token IDs,
                forwarded to ``PretrainedConfig``.
            sampling_eps: Lower bound of the sampled noise level t; the loss
                weight 1/t is therefore at most 1 / sampling_eps.
            mask_block_sizes: Candidate span lengths for block masking.
                ``None`` selects the default ``[2, 4, 8]``; the given list is
                copied so later mutation by the caller has no effect.
            block_masking_probability: Chance per step of using block masking.
            prefix_probability: Per-sample chance of an unmaskable prefix (S1).
            truncate_probability: Per-sample chance of a truncated suffix (S2).
            **kwargs: Forwarded to ``PretrainedConfig``.
        """
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs
        )
        # Base model config
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.hidden_act = hidden_act
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias
        # Diffusion config
        self.mask_token_id = mask_token_id
        self.sampling_eps = sampling_eps
        # `is None` (not `or`) so an explicit empty list disables block sizes
        # instead of being silently replaced by the defaults.
        self.mask_block_sizes = (
            [2, 4, 8] if mask_block_sizes is None else list(mask_block_sizes)
        )
        self.block_masking_probability = block_masking_probability
        self.prefix_probability = prefix_probability
        self.truncate_probability = truncate_probability
class DiffusionQwen3Model(PreTrainedModel):
    """
    Qwen3 model adapted for discrete diffusion language modeling.

    Key modifications from standard Qwen3:
    1. Bidirectional attention (is_causal=False)
    2. Masked diffusion training objective
    3. Loss weighted by 1/t (inverse noise level)
    4. Support for progressive masking (S1/S2/S3)

    CRITICAL: Loss normalization follows CoDA exactly (line 524 of modeling.py):
        loss = (dsigma[:, None] * loss).sum() / (batch_size * seq_len)

    After ``__init__`` the backbone (``self.model``), ``self.lm_head`` and
    ``self.embed_tokens`` are ``None``; they are attached by
    ``_init_from_qwen``, which ``from_pretrained_qwen`` calls.
    """

    config_class = DiffusionQwen3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2DecoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: DiffusionQwen3Config):
        super().__init__(config)
        self.config = config
        # Backbone modules are attached later by `_init_from_qwen`
        # (see `from_pretrained_qwen`).
        self.model = None
        self.lm_head = None
        self.embed_tokens = None
        # Diffusion parameters
        self.mask_token_id = config.mask_token_id
        self.sampling_eps = config.sampling_eps
        # Per-token loss; masking, 1/t weighting and reduction happen in `forward`.
        self.loss_fn = nn.CrossEntropyLoss(reduction='none')

    def _init_from_qwen(self, qwen_model: Qwen2ForCausalLM):
        """Adopt the backbone, LM head and embeddings of a loaded Qwen model,
        then switch its attention layers to non-causal."""
        self.model = qwen_model.model
        self.lm_head = qwen_model.lm_head
        self.embed_tokens = self.model.embed_tokens
        self._disable_causal_masking()

    def _disable_causal_masking(self):
        """Set ``is_causal = False`` on every decoder layer's attention module
        that exposes that attribute.

        NOTE(review): depending on the transformers version and
        ``attn_implementation``, the backbone's own forward may still build a
        causal attention mask and pass it to the layers, in which case this
        flag alone does not make attention bidirectional. Confirm empirically,
        e.g. perturb a late token and check that earlier logits change.
        """
        for layer in self.model.layers:
            if hasattr(layer.self_attn, 'is_causal'):
                layer.self_attn.is_causal = False

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        # Keep the wrapper's alias and the backbone's module in sync.
        self.embed_tokens = value
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_embeds(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Get token embeddings."""
        return self.embed_tokens(input_ids)

    def transition(
        self,
        x_0: torch.LongTensor,
        sigma: torch.Tensor,
        maskable_mask: torch.BoolTensor,
        mask_block_size: int = 1,
    ) -> torch.LongTensor:
        """
        Apply noise transition: mask tokens with probability sigma.

        Args:
            x_0: Original token IDs [batch_size, seq_len]
            sigma: Noise level per sample [batch_size, 1] or [batch_size]
            maskable_mask: Boolean mask of which positions can be masked [batch_size, seq_len]
            mask_block_size: Size of contiguous blocks to mask (1 for individual tokens)

        Returns:
            x_t: Noisy token IDs with some tokens replaced by mask_token_id
        """
        if sigma.dim() == 1:
            sigma = sigma.unsqueeze(-1)
        if mask_block_size == 1:
            # Standard per-token masking
            move_indices = (torch.rand_like(x_0, dtype=torch.float) < sigma) & maskable_mask
            x_t = torch.where(move_indices, self.mask_token_id, x_0)
        else:
            # Block masking
            x_t = self._block_masking(x_0, sigma, maskable_mask, mask_block_size)
        return x_t

    def _block_masking(
        self,
        x_0: torch.LongTensor,
        sigma: torch.Tensor,
        maskable_mask: torch.BoolTensor,
        mask_block_size: int,
    ) -> torch.LongTensor:
        """
        Apply block masking for contiguous spans.

        Every window ``[j, j + mask_block_size)`` whose positions are all
        maskable is selected independently with probability
        ``1 - (1 - sigma) ** (1 / mask_block_size)``. Every position covered
        by at least one selected window is replaced with ``mask_token_id``.
        Sequences shorter than ``mask_block_size`` are returned unchanged.

        Memory is O(batch_size * seq_len). No tensor with a
        [num_windows, seq_len] pair of axes is materialised, so cost stays
        linear in seq_len.

        Args:
            x_0: Original token IDs [batch_size, seq_len]
            sigma: Noise level per sample, broadcastable to [batch_size, 1]
            maskable_mask: Positions that may be masked [batch_size, seq_len]
            mask_block_size: Span length (> 1)
        """
        batch_size, seq_len = x_0.shape
        if seq_len < mask_block_size:
            return x_0
        # Number of possible block start positions
        num_windows = seq_len - mask_block_size + 1
        # fully_maskable[b, j] is True iff maskable_mask[b, j:j + mask_block_size] is all True.
        fully_maskable = maskable_mask.unfold(1, mask_block_size, 1).all(dim=-1)
        # Scale sigma for block masking (CoDA line 569)
        effective_sigma = 1 - (1 - sigma) ** (1 / mask_block_size)
        # Determine which blocks to mask
        should_mask = (
            torch.rand(batch_size, num_windows, device=x_0.device) < effective_sigma
        ) & fully_maskable
        # A position p is masked iff some selected window start j satisfies
        # j <= p < j + mask_block_size: spread each start over its span.
        final_mask = torch.zeros_like(x_0, dtype=torch.bool)
        for offset in range(mask_block_size):
            final_mask[:, offset:offset + num_windows] |= should_mask
        return torch.where(final_mask, self.mask_token_id, x_0)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        src_mask: Optional[torch.BoolTensor] = None,
        training_mode: str = "pretrain",
        masking_schedule: Optional[Dict[str, Any]] = None,
        epoch: Optional[int] = None,
        return_logits_only: bool = False,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], CausalLMOutputWithPast]:
        """
        Forward pass with diffusion training.

        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len], passed to the
                backbone unchanged
            labels: Accepted for API compatibility; not read (the clean
                input_ids are the targets)
            src_mask: Source mask for SFT (True = prompt, False = response);
                prompt positions are never masked
            training_mode: Accepted for API compatibility; not read
            masking_schedule: Optional override for masking probabilities
                (keys: prefix_probability, truncate_probability,
                block_masking_probability, mask_block_sizes). Missing
                probability keys default to 0, not to the config values.
            epoch: Accepted for API compatibility; not read
            return_logits_only: If True, skip diffusion training logic (used by trainer)

        Returns:
            Eval mode or ``return_logits_only``: ``CausalLMOutputWithPast`` with
                ``logits`` and ``loss=None``.
            Training: a ``(logits, loss)`` tuple. ``logits`` are float32
                [batch_size, seq_len, vocab] computed on the noised input, and
                ``loss`` is the scalar 1/t-weighted diffusion loss.
        """
        if not self.training or return_logits_only:
            # Inference mode OR trainer is handling diffusion logic
            hidden_states = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
            logits = self.lm_head(hidden_states)
            return CausalLMOutputWithPast(logits=logits, loss=None)

        # Training mode
        batch_size, seq_len = input_ids.shape

        # Get masking configuration
        if masking_schedule is not None:
            prefix_prob = masking_schedule.get("prefix_probability", 0)
            truncate_prob = masking_schedule.get("truncate_probability", 0)
            block_prob = masking_schedule.get("block_masking_probability", 0)
            mask_block_sizes = masking_schedule.get("mask_block_sizes", self.config.mask_block_sizes)
        else:
            prefix_prob = self.config.prefix_probability
            truncate_prob = self.config.truncate_probability
            block_prob = self.config.block_masking_probability
            mask_block_sizes = self.config.mask_block_sizes

        # Create maskable_mask based on training mode
        if src_mask is not None:
            # SFT mode: only mask response tokens
            maskable_mask = ~src_mask
        else:
            # Pre-training/mid-training: all tokens maskable
            maskable_mask = torch.ones_like(input_ids, dtype=torch.bool)

        # Apply S1: Unmaskable prefix
        if prefix_prob > 0:
            maskable_mask = self._apply_prefix_masking(
                input_ids, maskable_mask, prefix_prob
            )

        # Apply S2: Truncated suffix (may rewrite input_ids with pad tokens)
        if truncate_prob > 0:
            input_ids, maskable_mask = self._apply_truncate_masking(
                input_ids, maskable_mask, truncate_prob
            )

        # Sample timesteps and compute sigma
        # CoDA line 475: sigma = (1 - sampling_eps) * rand + sampling_eps
        sampling_eps = self.config.sampling_eps
        t = (1 - sampling_eps) * torch.rand(batch_size, device=input_ids.device) + sampling_eps
        sigma = t
        # CoDA line 476: dsigma = 1 / sigma (for loss weighting)
        dsigma = torch.reciprocal(t)

        # Select block masking size (one choice for the whole batch)
        if block_prob > 0 and mask_block_sizes and torch.rand(1).item() < block_prob:
            mask_block_size = mask_block_sizes[torch.randint(len(mask_block_sizes), (1,)).item()]
        else:
            mask_block_size = 1

        # Apply noise transition
        noisy_input_ids = self.transition(
            input_ids, sigma, maskable_mask, mask_block_size
        )
        # Track which positions are masked (for loss computation)
        loss_mask = (noisy_input_ids == self.mask_token_id)

        # Forward pass through model
        hidden_states = self.model(
            input_ids=noisy_input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        # =================================================================
        # LOSS COMPUTATION - MATCHES CODA EXACTLY (modeling.py lines 509-524)
        # =================================================================
        # Shifted prediction: the output at position i is scored against the
        # clean token at i + 1, and only where that token was masked.
        # logits: [batch, seq_len-1, vocab_size]
        # labels: [batch, seq_len-1]
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()
        shift_loss_mask = loss_mask[..., 1:].contiguous()

        # Cross-entropy loss per token. Use the actual logits width rather than
        # config.vocab_size, which can be stale after resizing the LM head.
        loss = self.loss_fn(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1)
        ).view(batch_size, -1)
        # Zero out loss for non-masked positions
        loss = loss.masked_fill(~shift_loss_mask, 0)

        # =================================================================
        # CRITICAL: CoDA normalization (line 524)
        # Divide by (batch_size * seq_len), NOT by num_masked!
        # This gives stable gradients regardless of mask ratio
        # =================================================================
        loss = (dsigma.unsqueeze(-1) * loss).sum() / (batch_size * seq_len)
        return logits, loss

    def _apply_prefix_masking(
        self,
        input_ids: torch.LongTensor,
        maskable_mask: torch.BoolTensor,
        prefix_prob: float,
    ) -> torch.BoolTensor:
        """Apply S1: Random unmaskable prefix.

        Each sample independently (probability ``prefix_prob``) gets a prefix
        of random length in [1, seq_len - 1] marked unmaskable. Sequences
        shorter than 2 tokens cannot hold such a prefix and are returned
        unchanged.
        """
        batch_size, seq_len = input_ids.shape
        if seq_len < 2:
            # torch.randint(1, seq_len) would be an empty range here.
            return maskable_mask
        # Randomly decide which samples get prefix
        apply_prefix = torch.rand(batch_size, device=input_ids.device) < prefix_prob
        # Generate random prefix lengths
        prefix_lengths = torch.randint(1, seq_len, (batch_size,), device=input_ids.device)
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        prefix_mask = positions < prefix_lengths.unsqueeze(1)
        # Apply: set maskable_mask to False for prefix positions
        return maskable_mask & ~(apply_prefix.unsqueeze(1) & prefix_mask)

    def _apply_truncate_masking(
        self,
        input_ids: torch.LongTensor,
        maskable_mask: torch.BoolTensor,
        truncate_prob: float,
    ) -> Tuple[torch.LongTensor, torch.BoolTensor]:
        """Apply S2: Random truncated suffix.

        Each sample independently (probability ``truncate_prob``) has every
        position from a random index in [1, seq_len - 1] onward overwritten
        with ``pad_token_id``. Sequences shorter than 2 tokens are never
        truncated.

        NOTE(review): the returned mask excludes EVERY pad token in the batch,
        not only the truncated suffix. Pre-existing pad tokens (same ID as
        BOS here) therefore become unmaskable whenever this helper runs.
        Confirm this is intended.
        """
        batch_size, seq_len = input_ids.shape
        if seq_len >= 2:
            # Randomly decide which samples get truncated
            apply_truncate = torch.rand(batch_size, device=input_ids.device) < truncate_prob
            # Generate random truncation positions
            truncate_positions = torch.randint(1, seq_len, (batch_size,), device=input_ids.device)
            positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
            truncate_mask = positions >= truncate_positions.unsqueeze(1)
            # Replace the truncated suffix with pad tokens
            input_ids = torch.where(
                apply_truncate.unsqueeze(1) & truncate_mask,
                self.config.pad_token_id,
                input_ids
            )
        maskable_mask = maskable_mask & (input_ids != self.config.pad_token_id)
        return input_ids, maskable_mask

    @classmethod
    def from_pretrained_qwen(
        cls,
        pretrained_model_name_or_path: str = "Qwen/Qwen3-1.7B",
        config: Optional[DiffusionQwen3Config] = None,
        **kwargs
    ) -> "DiffusionQwen3Model":
        """
        Load from a pretrained Qwen3 model and convert to diffusion.

        NOTE(review): the checkpoint is loaded with ``Qwen2ForCausalLM``.
        Recent transformers releases ship a dedicated Qwen3 architecture
        (``Qwen3ForCausalLM``, with per-head q/k normalisation). Loading a
        Qwen3 checkpoint into the Qwen2 class may leave weights unused or
        newly initialised, so check the load warnings.

        Args:
            pretrained_model_name_or_path: HuggingFace model name or path
            config: Optional DiffusionQwen3Config override. When omitted, the
                architecture fields are copied from the loaded model's config.
            **kwargs: Additional arguments for from_pretrained (defaults:
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2")

        Returns:
            DiffusionQwen3Model ready for diffusion training
        """
        print(f"Loading base model from {pretrained_model_name_or_path}...")
        qwen_model = Qwen2ForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=kwargs.pop("torch_dtype", torch.bfloat16),
            attn_implementation=kwargs.pop("attn_implementation", "flash_attention_2"),
            **kwargs
        )

        # Create diffusion config if not provided
        if config is None:
            qwen_config = qwen_model.config
            # Optional fields may be absent on some config classes; fall back
            # to the derived / DiffusionQwen3Config defaults.
            head_dim = getattr(qwen_config, "head_dim", None) or (
                qwen_config.hidden_size // qwen_config.num_attention_heads
            )
            config = DiffusionQwen3Config(
                vocab_size=qwen_config.vocab_size,
                hidden_size=qwen_config.hidden_size,
                intermediate_size=qwen_config.intermediate_size,
                num_hidden_layers=qwen_config.num_hidden_layers,
                num_attention_heads=qwen_config.num_attention_heads,
                num_key_value_heads=qwen_config.num_key_value_heads,
                head_dim=head_dim,
                max_position_embeddings=qwen_config.max_position_embeddings,
                rms_norm_eps=qwen_config.rms_norm_eps,
                rope_theta=qwen_config.rope_theta,
                hidden_act=getattr(qwen_config, "hidden_act", "silu"),
                tie_word_embeddings=getattr(qwen_config, "tie_word_embeddings", True),
            )

        # Create diffusion model and initialize from Qwen
        model = cls(config)
        model._init_from_qwen(qwen_model)
        print("Converted to DiffusionQwen3Model with bidirectional attention")
        print(f" - Mask token ID: {config.mask_token_id}")
        print(f" - Vocab size: {config.vocab_size}")
        print(f" - Hidden size: {config.hidden_size}")
        print(f" - Num layers: {config.num_hidden_layers}")
        return model
def prepare_tokenizer(
    tokenizer_name: str = "Qwen/Qwen3-1.7B",
    expected_mask_token_id: Optional[int] = 151669,
) -> AutoTokenizer:
    """
    Prepare tokenizer with mask token for diffusion training.

    If the tokenizer has no mask token, ``<|mask|>`` is added as a special
    token and registered as the mask token.

    Args:
        tokenizer_name: HuggingFace tokenizer name
        expected_mask_token_id: Mask token ID the model side expects
            (``DiffusionQwen3Config.mask_token_id`` defaults to 151669). A
            warning is issued if the tokenizer's mask token ID differs, since
            training would otherwise corrupt inputs with a different token than
            the one the loss treats as "masked". Pass ``None`` to skip the check.

    Returns:
        Tokenizer with mask token added
    """
    import warnings

    # NOTE(review): trust_remote_code=True executes Python shipped with the
    # hub repository; only use it with trusted tokenizer sources.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
    if tokenizer.mask_token is None:
        # Add mask token (CoDA uses ID 151669)
        tokenizer.add_tokens("<|mask|>", special_tokens=True)
        tokenizer.add_special_tokens(
            {"mask_token": "<|mask|>"},
            replace_additional_special_tokens=False
        )
        print(f"Added mask token: {tokenizer.mask_token} (ID: {tokenizer.mask_token_id})")
    else:
        print(f"Mask token already exists: {tokenizer.mask_token} (ID: {tokenizer.mask_token_id})")

    if expected_mask_token_id is not None and tokenizer.mask_token_id != expected_mask_token_id:
        warnings.warn(
            f"Tokenizer mask token ID {tokenizer.mask_token_id} differs from the "
            f"expected ID {expected_mask_token_id}; make sure the model config's "
            f"mask_token_id matches the tokenizer.",
            stacklevel=2,
        )
    return tokenizer