"""
Small MDLM (Masked Diffusion Language Model) for text generation.

Based on: "Simple and Effective Masked Diffusion Language Models"
(Sahoo et al., NeurIPS 2024)

Architecture: DiT backbone with adaLN-zero conditioning, RoPE,
bidirectional attention. No flash_attn dependency — uses PyTorch native
scaled_dot_product_attention.
"""

import math
import typing
import json
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import MaskedLMOutput


class MDLMConfig(PretrainedConfig):
    """Configuration for a small MDLM text diffusion model.

    Args:
        vocab_size: Size of the embedding table and of the output logits.
        model_length: Stored on the config; not read by the model code in
            this file.
        hidden_dim: Width of the token embeddings and transformer blocks.
        cond_dim: Width of the timestep conditioning vector.
        n_blocks: Number of ``DDiTBlock`` layers.
        n_heads: Attention heads per block; must divide ``hidden_dim``.
        dropout: Dropout probability used inside every block.
        time_conditioning: When False, timesteps are replaced by zeros
            before being embedded.
        mlp_ratio: Expansion factor of the block MLP.
        mask_token_id: Token id that represents a masked position.
    """

    model_type = "mdlm"

    def __init__(
        self,
        vocab_size: int = 50258,
        model_length: int = 256,
        hidden_dim: int = 512,
        cond_dim: int = 128,
        n_blocks: int = 6,
        n_heads: int = 8,
        dropout: float = 0.1,
        time_conditioning: bool = True,
        mlp_ratio: int = 4,
        mask_token_id: int = 50257,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.model_length = model_length
        self.hidden_dim = hidden_dim
        self.cond_dim = cond_dim
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.dropout = dropout
        self.time_conditioning = time_conditioning
        self.mlp_ratio = mlp_ratio
        self.mask_token_id = mask_token_id


# ─── Rotary Position Embeddings ───────────────────────────

class RotaryEmbedding(nn.Module):
    """Produces the rotary-embedding angle table for a given sequence length."""

    def __init__(self, dim, base=10000):
        """
        Args:
            dim: Per-head feature size the angles are generated for. Must be
                even, because the table is built from ``dim / 2`` frequencies
                concatenated with themselves.
            base: Base of the geometric frequency progression.

        Raises:
            ValueError: If ``dim`` is odd.
        """
        super().__init__()
        if dim % 2 != 0:
            # An odd dim would yield a (seq_len, dim + 1) table that cannot
            # be broadcast against (..., dim) query/key tensors.
            raise ValueError(f"RotaryEmbedding requires an even dim, got {dim}")
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len, device):
        """Return the angle table of shape (seq_len, dim) on ``device``."""
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        return torch.cat([freqs, freqs], dim=-1)  # (seq_len, dim)


def rotate_half(x):
    """Split the last dimension in two halves (x1, x2) and return (-x2, x1)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, freqs):
    """Apply RoPE to query and key tensors.

    Args:
        q: Queries laid out as (batch, seq, heads, head_dim).
        k: Keys with the same layout as ``q``.
        freqs: Angle table of shape (seq, head_dim), as produced by
            ``RotaryEmbedding.forward``.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    cos = freqs.cos().unsqueeze(0).unsqueeze(2)  # (1, seq, 1, dim)
    sin = freqs.sin().unsqueeze(0).unsqueeze(2)  # (1, seq, 1, dim)
    q = q * cos + rotate_half(q) * sin
    k = k * cos + rotate_half(k) * sin
    return q, k


# ─── Timestep Embedding ──────────────────────────────────

class TimestepEmbedder(nn.Module):
    """Maps scalar timesteps to vectors: sinusoidal features followed by an MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        """
        Args:
            hidden_size: Output width of the MLP.
            frequency_embedding_size: Width of the sinusoidal features fed
                into the MLP.
        """
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Build sinusoidal features for a 1-D tensor of timesteps.

        Args:
            t: 1-D tensor of timesteps, one per batch element.
            dim: Width of the returned features.
            max_period: Controls the lowest frequency of the progression.

        Returns:
            Float tensor of shape (len(t), dim): cosines, then sines, plus
            one zero column when ``dim`` is odd.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(0, half, dtype=torch.float32, device=t.device)
            / half
        )
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad with a zero column so the output width equals ``dim``.
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        """Embed timesteps ``t`` (1-D) into (len(t), hidden_size) vectors."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(t_freq)


# ─── LayerNorm ────────────────────────────────────────────

class LayerNorm(nn.Module):
    """Bias-free layer norm over the last dimension with a learned scale."""

    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.dim = dim

    def forward(self, x):
        """Normalize ``x`` (batch, seq, dim) and apply the learned scale."""
        # The normalization itself is done on a float32 copy with CUDA
        # autocast disabled; the scale is applied outside that context.
        with torch.amp.autocast("cuda", enabled=False):
            x = F.layer_norm(x.float(), [self.dim])
        return x * self.weight[None, None, :]


# ─── DiT Block with adaLN-zero ───────────────────────────

class DDiTBlock(nn.Module):
    """Transformer block with bidirectional attention and adaLN-zero modulation."""

    def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
        """
        Args:
            dim: Hidden width of the block.
            n_heads: Number of attention heads; must divide ``dim``.
            cond_dim: Width of the conditioning vector.
            mlp_ratio: Expansion factor of the MLP hidden layer.
            dropout: Dropout probability for attention weights and for the
                attention / MLP branch outputs.

        Raises:
            ValueError: If ``dim`` is not divisible by ``n_heads``.
        """
        super().__init__()
        if n_heads <= 0 or dim % n_heads != 0:
            # Fail here rather than with an opaque ``view`` error in forward.
            raise ValueError(
                f"dim ({dim}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        self.norm1 = LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)

        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_ratio * dim, dim),
        )
        self.dropout = nn.Dropout(dropout)
        self.drop_p = dropout

        # adaLN-zero: 6 modulation params (shift, scale, gate for attn & mlp).
        # Zero init makes both gates zero, so the block starts as identity.
        self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
        nn.init.zeros_(self.adaLN_modulation.weight)
        nn.init.zeros_(self.adaLN_modulation.bias)

    def forward(self, x, rotary_freqs, c):
        """
        Args:
            x: Hidden states of shape (B, S, D).
            rotary_freqs: Angle table of shape (S, head_dim).
            c: Conditioning vectors of shape (B, cond_dim).

        Returns:
            Updated hidden states of shape (B, S, D).
        """
        B, S, D = x.shape

        # adaLN modulation
        mod = self.adaLN_modulation(c)[:, None, :]  # (B, 1, 6*D)
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = mod.chunk(6, dim=-1)

        # ── Self-Attention ──
        h = self.norm1(x)
        h = h * (1 + scale_msa) + shift_msa

        qkv = self.attn_qkv(h)
        qkv = qkv.view(B, S, 3, self.n_heads, self.head_dim)
        # q, k, v: (B, S, n_heads, head_dim)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        # Apply RoPE
        q, k = apply_rotary_pos_emb(q, k, rotary_freqs)

        # Transpose to (B, n_heads, S, head_dim) for SDPA
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Bidirectional attention (no causal mask)
        attn_out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.drop_p if self.training else 0.0,
            is_causal=False,
        )
        attn_out = attn_out.transpose(1, 2).reshape(B, S, D)
        attn_out = self.attn_out(attn_out)
        x = x + gate_msa * self.dropout(attn_out)

        # ── MLP ──
        h = self.norm2(x)
        h = h * (1 + scale_mlp) + shift_mlp
        x = x + gate_mlp * self.dropout(self.mlp(h))

        return x


# ─── Final Layer ──────────────────────────────────────────

class DDitFinalLayer(nn.Module):
    """Modulated layer norm followed by a zero-initialized output projection."""

    def __init__(self, hidden_size, out_channels, cond_dim):
        """
        Args:
            hidden_size: Width of the incoming hidden states.
            out_channels: Width of the output (the vocab size in ``MDLM``).
            cond_dim: Width of the conditioning vector.
        """
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)
        nn.init.zeros_(self.adaLN_modulation.weight)
        nn.init.zeros_(self.adaLN_modulation.bias)

    def forward(self, x, c):
        """Project hidden states ``x`` (B, S, hidden) to (B, S, out_channels)."""
        shift, scale = self.adaLN_modulation(c)[:, None, :].chunk(2, dim=-1)
        x = self.norm_final(x)
        x = x * (1 + scale) + shift
        return self.linear(x)


# ─── Full Model ──────────────────────────────────────────

class MDLM(PreTrainedModel):
    """
    Small Masked Diffusion Language Model.

    Forward pass: given noisy input_ids and timesteps t ∈ [0,1],
    predicts logits over vocab for each position.
    """

    config_class = MDLMConfig

    def __init__(self, config: MDLMConfig):
        super().__init__(config)
        self.config = config

        self.vocab_embed = nn.Embedding(config.vocab_size, config.hidden_dim)
        nn.init.kaiming_uniform_(self.vocab_embed.weight, a=math.sqrt(5))

        self.sigma_map = TimestepEmbedder(config.cond_dim)
        self.rotary_emb = RotaryEmbedding(config.hidden_dim // config.n_heads)

        self.blocks = nn.ModuleList([
            DDiTBlock(
                config.hidden_dim,
                config.n_heads,
                config.cond_dim,
                mlp_ratio=config.mlp_ratio,
                dropout=config.dropout,
            )
            for _ in range(config.n_blocks)
        ])

        # Separate output projection (no weight tying with embeddings)
        self.output_layer = DDitFinalLayer(
            config.hidden_dim, config.vocab_size, config.cond_dim
        )

        self.post_init()

    def _init_weights(self, module):
        """Leave the initialization performed by the sub-module constructors.

        ``post_init()`` invokes this hook on every sub-module. The modules
        above already initialize themselves (zero-initialized adaLN and
        output projections, Kaiming-initialized embedding), so the hook
        deliberately does nothing.

        NOTE(review): recent transformers releases ship a generic default
        ``_init_weights`` that re-draws ``nn.Linear`` / ``nn.Embedding``
        weights, which would destroy the adaLN-zero init — confirm against
        the installed version. The explicit no-op keeps behavior the same
        across versions.
        """
        return

    def get_num_params(self):
        """Return the total number of parameter elements in the model."""
        return sum(p.numel() for p in self.parameters())

    def forward(
        self,
        input_ids: torch.LongTensor,
        timesteps: torch.FloatTensor,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Predict per-position vocabulary logits for noisy ``input_ids``.

        Args:
            input_ids: Token ids of shape (B, S).
            timesteps: Diffusion times of shape (B,).
            output_hidden_states: When True, also collect the embedding
                output and the output of every block.
            return_dict: When True return a ``MaskedLMOutput``; otherwise
                return the logits tensor alone.

        Returns:
            ``MaskedLMOutput`` (``loss`` is always None) or the logits
            tensor of shape (B, S, vocab_size).
        """
        B, S = input_ids.shape

        x = self.vocab_embed(input_ids)

        if not self.config.time_conditioning:
            timesteps = torch.zeros_like(timesteps)
        c = F.silu(self.sigma_map(timesteps))

        rotary_freqs = self.rotary_emb(S, device=x.device)

        all_hidden = [x] if output_hidden_states else None

        # Mixed precision: let the outer training loop handle autocast
        for block in self.blocks:
            x = block(x, rotary_freqs, c)
            if output_hidden_states:
                all_hidden.append(x)

        logits = self.output_layer(x, c)

        if return_dict:
            return MaskedLMOutput(
                logits=logits, hidden_states=all_hidden, loss=None
            )
        return logits


# ─── Sampling ─────────────────────────────────────────────

@torch.no_grad()
def sample(
    model: MDLM,
    seq_len: int,
    batch_size: int = 1,
    num_steps: int = 100,
    temperature: float = 0.7,
    device: str = "cuda",
):
    """
    Ancestral sampling from MDLM.

    Start from all [MASK] tokens. At each step s→t (t < s):
    unmask tokens with probability (1 - t/s), using model predictions.
    The final step unmasks every remaining position.

    Args:
        model: The denoising model; it must already live on ``device``.
        seq_len: Number of tokens to generate per sequence.
        batch_size: Number of sequences to generate.
        num_steps: Number of denoising steps (>= 1).
        temperature: Softmax temperature. ``0`` selects the arg-max token
            instead of sampling.
        device: Device on which the token tensor and timesteps are created.

    Returns:
        Long tensor of shape (batch_size, seq_len) with generated token ids.
        The mask token is never emitted as a prediction.

    Raises:
        ValueError: If ``num_steps < 1`` or ``temperature < 0``.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    mask_id = model.config.mask_token_id

    # Start with all masked
    x = torch.full(
        (batch_size, seq_len), mask_id, dtype=torch.long, device=device
    )

    # Discretize time from 1→0. Reading the schedule to the host once avoids
    # a device synchronization on every step for the scalar comparisons.
    schedule = torch.linspace(1.0, 0.0, num_steps + 1, device=device).tolist()

    # Dropout must be off while sampling; the caller's mode is restored after.
    was_training = model.training
    model.eval()
    try:
        for t_now, t_next in zip(schedule, schedule[1:]):
            # Get model predictions
            t_batch = torch.full((batch_size,), t_now, device=device)
            logits = model(x, t_batch, return_dict=True).logits.float()

            # Never predict [MASK] itself: otherwise a position could be
            # "unmasked" to the mask token and stay masked in the output.
            if 0 <= mask_id < logits.shape[-1]:
                logits[..., mask_id] = float("-inf")

            if temperature == 0:
                # Zero temperature is the greedy limit of the softmax.
                predicted = logits.argmax(dim=-1)
            else:
                # Sample from predicted distribution
                probs = F.softmax(logits / temperature, dim=-1)
                predicted = torch.multinomial(
                    probs.reshape(-1, probs.shape[-1]), 1
                ).view(batch_size, seq_len)

            # Determine which masked positions to unmask
            is_masked = x == mask_id

            if t_next <= 0.0:
                # Last step: unmask everything
                x = torch.where(is_masked, predicted, x)
            else:
                # Unmask with probability (1 - t_next/t_now)
                unmask_prob = 1.0 - (t_next / t_now)
                unmask = torch.bernoulli(
                    torch.full_like(x, unmask_prob, dtype=torch.float)
                ).bool() & is_masked
                x = torch.where(unmask, predicted, x)
    finally:
        model.train(was_training)

    return x