|
|
""" |
|
|
DiffusionQwen3 Model - Converts Qwen3-1.7B AR to Bidirectional Diffusion LLM |
|
|
|
|
|
This module provides: |
|
|
1. DiffusionQwen3Config - Configuration for diffusion-adapted Qwen3 |
|
|
2. DiffusionQwen3Model - The main model class with diffusion training/inference |
|
|
|
|
|
Based on CoDA (Coding LM via Diffusion Adaptation) by Salesforce AI Research |
|
|
https://arxiv.org/abs/2510.03270 |
|
|
|
|
|
CRITICAL: Loss normalization matches CoDA official implementation exactly: |
|
|
loss = (dsigma[:, None] * loss).sum() / (batch_size * seq_len) |
|
|
NOT dividing by num_masked (which causes gradient explosion) |
|
|
""" |
|
|
|
|
|
import math |
|
|
from dataclasses import dataclass |
|
|
from typing import Optional, Tuple, Union, List, Dict, Any |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import PreTrainedModel, PretrainedConfig |
|
|
from transformers import Qwen2ForCausalLM, Qwen2Config, AutoTokenizer |
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
|
|
|
|
|
|
class DiffusionQwen3Config(PretrainedConfig):
    """Configuration for Diffusion-adapted Qwen3 model.

    Extends a Qwen-style transformer configuration with the diffusion-specific
    settings consumed by ``DiffusionQwen3Model``: the mask token id, the
    noise-schedule epsilon, and the S1/S2/S3 progressive-masking probabilities.

    NOTE: deliberately NOT decorated with ``@dataclass``. ``PretrainedConfig``
    supplies meaningful ``__repr__``/``__eq__``; applying ``@dataclass`` to a
    subclass with no annotated class fields injects a field-less
    ``__repr__``/``__eq__`` that shadow the inherited ones and make *all*
    config instances compare equal.
    """

    model_type = "diffusion_qwen3"

    def __init__(
        self,
        # --- Transformer architecture (defaults match Qwen3-1.7B) ---
        vocab_size: int = 151936,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_hidden_layers: int = 28,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        head_dim: int = 128,
        max_position_embeddings: int = 40960,
        rms_norm_eps: float = 1e-6,
        rope_theta: float = 1000000.0,
        hidden_act: str = "silu",
        attention_dropout: float = 0.0,
        attention_bias: bool = False,
        tie_word_embeddings: bool = True,
        # --- Special token ids ---
        mask_token_id: int = 151669,
        pad_token_id: int = 151643,
        bos_token_id: int = 151643,
        eos_token_id: int = 151645,
        # --- Diffusion-specific settings ---
        sampling_eps: float = 0.001,
        mask_block_sizes: Optional[List[int]] = None,
        block_masking_probability: float = 0.01,
        prefix_probability: float = 0.01,
        truncate_probability: float = 0.01,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

        # Transformer architecture.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.hidden_act = hidden_act
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias

        # Diffusion parameters.
        self.mask_token_id = mask_token_id
        self.sampling_eps = sampling_eps
        # Default contiguous-span sizes used when block masking is sampled.
        self.mask_block_sizes = mask_block_sizes or [2, 4, 8]
        self.block_masking_probability = block_masking_probability
        self.prefix_probability = prefix_probability
        self.truncate_probability = truncate_probability
|
|
|
class DiffusionQwen3Model(PreTrainedModel):
    """
    Qwen3 model adapted for discrete diffusion language modeling.

    Key modifications from standard Qwen3:
    1. Bidirectional attention (is_causal=False)
    2. Masked diffusion training objective
    3. Loss weighted by 1/t (inverse noise level)
    4. Support for progressive masking (S1/S2/S3)

    CRITICAL: Loss normalization follows CoDA exactly (line 524 of modeling.py):
        loss = (dsigma[:, None] * loss).sum() / (batch_size * seq_len)
    """

    config_class = DiffusionQwen3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2DecoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: DiffusionQwen3Config):
        super().__init__(config)
        self.config = config

        # Backbone, LM head and embedding table are populated later by
        # _init_from_qwen(); the model is not usable until that runs.
        self.model = None
        self.lm_head = None
        self.embed_tokens = None

        # Cached diffusion settings (see DiffusionQwen3Config).
        self.mask_token_id = config.mask_token_id
        self.sampling_eps = config.sampling_eps

        # Per-token loss (reduction='none') so it can be masked and
        # reweighted by dsigma before the final reduction in forward().
        self.loss_fn = nn.CrossEntropyLoss(reduction='none')

    def _init_from_qwen(self, qwen_model: Qwen2ForCausalLM):
        """Initialize from a pretrained Qwen model."""
        # Share the pretrained transformer stack and LM head directly
        # (module references, not weight copies).
        self.model = qwen_model.model
        self.lm_head = qwen_model.lm_head
        self.embed_tokens = self.model.embed_tokens

        # Diffusion denoising needs full bidirectional context.
        self._disable_causal_masking()

    def _disable_causal_masking(self):
        """Disable causal attention masks for bidirectional attention."""
        # NOTE(review): this only flips each layer's `is_causal` flag; it
        # assumes the installed attention implementation honors that
        # attribute — confirm for the attn backend actually in use.
        for layer in self.model.layers:
            if hasattr(layer.self_attn, 'is_causal'):
                layer.self_attn.is_causal = False

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        # Keep our alias and the backbone's embedding table in sync.
        self.embed_tokens = value
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_embeds(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Get token embeddings."""
        return self.embed_tokens(input_ids)

    def transition(
        self,
        x_0: torch.LongTensor,
        sigma: torch.Tensor,
        maskable_mask: torch.BoolTensor,
        mask_block_size: int = 1,
    ) -> torch.LongTensor:
        """
        Apply noise transition: mask tokens with probability sigma.

        Args:
            x_0: Original token IDs [batch_size, seq_len]
            sigma: Noise level per sample [batch_size, 1] or [batch_size]
            maskable_mask: Boolean mask of which positions can be masked [batch_size, seq_len]
            mask_block_size: Size of contiguous blocks to mask (1 for individual tokens)

        Returns:
            x_t: Noisy token IDs with some tokens replaced by mask_token_id
        """
        # Normalize sigma to [batch_size, 1] so it broadcasts over seq_len.
        if sigma.dim() == 1:
            sigma = sigma.unsqueeze(-1)

        if mask_block_size == 1:
            # Independent per-token masking: each maskable position is
            # replaced with probability sigma.
            move_indices = (torch.rand_like(x_0, dtype=torch.float) < sigma) & maskable_mask
            x_t = torch.where(move_indices, self.mask_token_id, x_0)
        else:
            # Contiguous-span masking (see _block_masking).
            x_t = self._block_masking(x_0, sigma, maskable_mask, mask_block_size)

        return x_t

    def _block_masking(
        self,
        x_0: torch.LongTensor,
        sigma: torch.Tensor,
        maskable_mask: torch.BoolTensor,
        mask_block_size: int,
    ) -> torch.LongTensor:
        """Apply block masking for contiguous spans."""
        batch_size, seq_len = x_0.shape

        # Sequence too short to fit a single block: nothing to mask.
        if seq_len < mask_block_size:
            return x_0

        # Every possible window start for a span of mask_block_size tokens.
        num_windows = seq_len - mask_block_size + 1

        # all_positions[w, o] = absolute position of the o-th token of
        # window w; shape [num_windows, mask_block_size].
        window_starts = torch.arange(num_windows, device=x_0.device)
        block_offsets = torch.arange(mask_block_size, device=x_0.device)
        all_positions = window_starts.unsqueeze(1) + block_offsets.unsqueeze(0)

        # A window may be selected only if every one of its positions is
        # maskable; fully_maskable has shape [batch_size, num_windows].
        maskable_blocks = maskable_mask.unsqueeze(1).expand(-1, num_windows, -1)
        maskable_blocks = maskable_blocks.gather(2, all_positions.unsqueeze(0).expand(batch_size, -1, -1))
        fully_maskable = maskable_blocks.all(dim=2)

        # Per-window selection rate, derived so that the chance a token is
        # covered by at least one selected window roughly matches sigma
        # (approximate, since adjacent windows overlap).
        effective_sigma = 1 - (1 - sigma) ** (1 / mask_block_size)

        # Sample which windows to mask; shape [batch_size, num_windows].
        should_mask = (torch.rand(batch_size, num_windows, device=x_0.device) < effective_sigma) & fully_maskable

        # Scatter the selected windows back onto absolute positions:
        # a position is masked if ANY selected window covers it.
        position_indices = torch.arange(seq_len, device=x_0.device).unsqueeze(0).unsqueeze(0)
        all_positions_expanded = all_positions.unsqueeze(0)
        should_mask_expanded = should_mask.unsqueeze(2)

        position_matches = (position_indices == all_positions_expanded.unsqueeze(3)).any(dim=2)
        should_mask_positions = should_mask_expanded & position_matches
        final_mask = should_mask_positions.any(dim=1)

        return torch.where(final_mask, self.mask_token_id, x_0)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        src_mask: Optional[torch.BoolTensor] = None,
        training_mode: str = "pretrain",
        masking_schedule: Optional[Dict[str, Any]] = None,
        epoch: Optional[int] = None,
        return_logits_only: bool = False,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], CausalLMOutputWithPast]:
        """
        Forward pass with diffusion training.

        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            labels: Target labels (same as input_ids for diffusion)
            src_mask: Source mask for SFT (True = prompt, False = response)
            training_mode: "pretrain", "midtrain", or "sft"
            masking_schedule: Optional override for masking probabilities
            epoch: Current epoch for progressive masking
            return_logits_only: If True, skip diffusion training logic (used by trainer)

        Returns:
            logits: Model predictions [batch_size, seq_len, vocab_size]
            loss: Diffusion loss (if training and not return_logits_only)

        NOTE(review): the return type is asymmetric — a
        CausalLMOutputWithPast in eval / logits-only mode, a plain
        (logits, loss) tuple in the training path. Callers rely on this.
        Also note `labels`, `training_mode` and `epoch` are accepted but
        currently unused in this method body.
        """
        # --- Inference / plain-logits path: no noising, no loss. ---
        if not self.training or return_logits_only:
            hidden_states = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
            logits = self.lm_head(hidden_states)
            return CausalLMOutputWithPast(logits=logits, loss=None)

        batch_size, seq_len = input_ids.shape

        # Resolve masking probabilities: an explicit schedule (e.g. from a
        # progressive-masking curriculum) overrides the config defaults.
        if masking_schedule is not None:
            prefix_prob = masking_schedule.get("prefix_probability", 0)
            truncate_prob = masking_schedule.get("truncate_probability", 0)
            block_prob = masking_schedule.get("block_masking_probability", 0)
            mask_block_sizes = masking_schedule.get("mask_block_sizes", self.config.mask_block_sizes)
        else:
            prefix_prob = self.config.prefix_probability
            truncate_prob = self.config.truncate_probability
            block_prob = self.config.block_masking_probability
            mask_block_sizes = self.config.mask_block_sizes

        # SFT: never mask the prompt (src_mask == True marks prompt tokens).
        if src_mask is not None:
            maskable_mask = ~src_mask
        else:
            maskable_mask = torch.ones_like(input_ids, dtype=torch.bool)

        # S1: with probability prefix_prob, protect a random prefix.
        if prefix_prob > 0:
            maskable_mask = self._apply_prefix_masking(
                input_ids, maskable_mask, prefix_prob
            )

        # S2: with probability truncate_prob, pad out a random suffix.
        if truncate_prob > 0:
            input_ids, maskable_mask = self._apply_truncate_masking(
                input_ids, maskable_mask, truncate_prob
            )

        # Sample one noise level t per sequence, uniform on
        # [sampling_eps, 1); dsigma = 1/t is the CoDA loss weight.
        sampling_eps = self.config.sampling_eps
        t = (1 - sampling_eps) * torch.rand(batch_size, device=input_ids.device) + sampling_eps
        sigma = t

        dsigma = torch.reciprocal(t)

        # S3: occasionally mask contiguous blocks instead of single tokens
        # (one Bernoulli draw and one block size per batch).
        if block_prob > 0 and mask_block_sizes and torch.rand(1).item() < block_prob:
            mask_block_size = mask_block_sizes[torch.randint(len(mask_block_sizes), (1,)).item()]
        else:
            mask_block_size = 1

        # Forward diffusion: corrupt the input at noise level sigma.
        noisy_input_ids = self.transition(
            input_ids, sigma, maskable_mask, mask_block_size
        )

        # Loss is only computed at positions that actually got masked.
        loss_mask = (noisy_input_ids == self.mask_token_id)

        # Denoise with bidirectional attention over the corrupted sequence.
        hidden_states = self.model(
            input_ids=noisy_input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state

        logits = self.lm_head(hidden_states)
        # fp32 logits for numerically stable cross-entropy.
        logits = logits.float()

        # Shifted objective: position i predicts token i+1, and the loss
        # mask is taken from the masked positions of the *target* tokens.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()
        shift_loss_mask = loss_mask[..., 1:].contiguous()

        # Per-token cross-entropy, reshaped back to [batch, seq_len - 1].
        loss = self.loss_fn(
            shift_logits.view(-1, self.config.vocab_size),
            shift_labels.view(-1)
        ).view(batch_size, -1)

        # Zero out unmasked positions instead of gathering them, so the
        # fixed (batch_size * seq_len) normalizer below stays valid.
        loss = loss.masked_fill(~shift_loss_mask, 0)

        # CRITICAL: CoDA normalization — weight by dsigma = 1/t and divide
        # by the FULL batch_size * seq_len, NOT by the number of masked
        # tokens (per the module docstring, dividing by num_masked causes
        # gradient explosion).
        loss = (dsigma.unsqueeze(-1) * loss).sum() / (batch_size * seq_len)

        return logits, loss

    def _apply_prefix_masking(
        self,
        input_ids: torch.LongTensor,
        maskable_mask: torch.BoolTensor,
        prefix_prob: float,
    ) -> torch.BoolTensor:
        """Apply S1: Random unmaskable prefix.

        For each sequence, with probability ``prefix_prob``, a random-length
        prefix is marked unmaskable so the model learns to condition on
        clean left context. Returns the updated maskable mask.
        """
        batch_size, seq_len = input_ids.shape

        # Which sequences in the batch receive a protected prefix.
        apply_prefix = torch.rand(batch_size, device=input_ids.device) < prefix_prob

        # Prefix length in [1, seq_len) per sequence (sampled for every
        # row; only used where apply_prefix is True).
        prefix_lengths = torch.randint(1, seq_len, (batch_size,), device=input_ids.device)

        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)

        # prefix_mask[b, i] = True where i falls inside row b's prefix.
        prefix_mask = positions < prefix_lengths.unsqueeze(1)

        # Remove protected prefix positions from the maskable set.
        maskable_mask = maskable_mask & ~(apply_prefix.unsqueeze(1) & prefix_mask)

        return maskable_mask

    def _apply_truncate_masking(
        self,
        input_ids: torch.LongTensor,
        maskable_mask: torch.BoolTensor,
        truncate_prob: float,
    ) -> Tuple[torch.LongTensor, torch.BoolTensor]:
        """Apply S2: Random truncated suffix.

        For each sequence, with probability ``truncate_prob``, everything
        from a random position onward is replaced with the pad token and
        excluded from masking. Returns the (possibly modified) input_ids
        and the updated maskable mask.
        """
        batch_size, seq_len = input_ids.shape

        # Which sequences get truncated.
        apply_truncate = torch.rand(batch_size, device=input_ids.device) < truncate_prob

        # Truncation start in [1, seq_len) per sequence.
        truncate_positions = torch.randint(1, seq_len, (batch_size,), device=input_ids.device)

        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)

        # truncate_mask[b, i] = True for positions at/after the cut point.
        truncate_mask = positions >= truncate_positions.unsqueeze(1)

        # Overwrite the truncated suffix with pad tokens.
        input_ids = torch.where(
            apply_truncate.unsqueeze(1) & truncate_mask,
            self.config.pad_token_id,
            input_ids
        )
        # NOTE(review): this also drops any pre-existing pad tokens from the
        # maskable set, not just the freshly truncated suffix.
        maskable_mask = maskable_mask & (input_ids != self.config.pad_token_id)

        return input_ids, maskable_mask

    @classmethod
    def from_pretrained_qwen(
        cls,
        pretrained_model_name_or_path: str = "Qwen/Qwen3-1.7B",
        config: Optional[DiffusionQwen3Config] = None,
        **kwargs
    ) -> "DiffusionQwen3Model":
        """
        Load from a pretrained Qwen3 model and convert to diffusion.

        Args:
            pretrained_model_name_or_path: HuggingFace model name or path
            config: Optional DiffusionQwen3Config override
            **kwargs: Additional arguments for from_pretrained

        Returns:
            DiffusionQwen3Model ready for diffusion training
        """
        print(f"Loading base model from {pretrained_model_name_or_path}...")

        # Defaults (bf16, flash-attention-2) can be overridden via kwargs.
        qwen_model = Qwen2ForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=kwargs.pop("torch_dtype", torch.bfloat16),
            attn_implementation=kwargs.pop("attn_implementation", "flash_attention_2"),
            **kwargs
        )

        # Derive the diffusion config from the loaded checkpoint's config
        # unless an explicit override was supplied.
        # NOTE(review): head_dim is not forwarded from qwen_config here —
        # this relies on the 128 default matching the checkpoint; confirm
        # before using with model sizes other than 1.7B.
        if config is None:
            qwen_config = qwen_model.config
            config = DiffusionQwen3Config(
                vocab_size=qwen_config.vocab_size,
                hidden_size=qwen_config.hidden_size,
                intermediate_size=qwen_config.intermediate_size,
                num_hidden_layers=qwen_config.num_hidden_layers,
                num_attention_heads=qwen_config.num_attention_heads,
                num_key_value_heads=qwen_config.num_key_value_heads,
                max_position_embeddings=qwen_config.max_position_embeddings,
                rms_norm_eps=qwen_config.rms_norm_eps,
                rope_theta=qwen_config.rope_theta,
            )

        # Wrap the pretrained modules and switch to bidirectional attention.
        model = cls(config)
        model._init_from_qwen(qwen_model)

        print(f"Converted to DiffusionQwen3Model with bidirectional attention")
        print(f"  - Mask token ID: {config.mask_token_id}")
        print(f"  - Vocab size: {config.vocab_size}")
        print(f"  - Hidden size: {config.hidden_size}")
        print(f"  - Num layers: {config.num_hidden_layers}")

        return model
|
|
|
|
|
|
|
|
def prepare_tokenizer(tokenizer_name: str = "Qwen/Qwen3-1.7B") -> AutoTokenizer:
    """
    Prepare tokenizer with mask token for diffusion training.

    Loads the named tokenizer and, if it does not already define a mask
    token, registers ``<|mask|>`` and marks it as the mask token.

    Args:
        tokenizer_name: HuggingFace tokenizer name

    Returns:
        Tokenizer with mask token added
    """
    tok = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)

    # Fast path: the checkpoint already ships with a mask token.
    if tok.mask_token is not None:
        print(f"Mask token already exists: {tok.mask_token} (ID: {tok.mask_token_id})")
        return tok

    # First add the token to the vocabulary, then register it under the
    # `mask_token` special-token slot.
    tok.add_tokens("<|mask|>", special_tokens=True)
    tok.add_special_tokens(
        {"mask_token": "<|mask|>"},
        replace_additional_special_tokens=False,
    )
    print(f"Added mask token: {tok.mask_token} (ID: {tok.mask_token_id})")

    return tok
|
|
|