| """ |
| Small MDLM (Masked Diffusion Language Model) for text generation. |
| |
| Based on: "Simple and Effective Masked Diffusion Language Models" (Sahoo et al., NeurIPS 2024) |
| Architecture: DiT backbone with adaLN-zero conditioning, RoPE, bidirectional attention. |
| No flash_attn dependency — uses PyTorch native scaled_dot_product_attention. |
| """ |
|
|
| import math |
| import typing |
| import json |
| import os |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel, PretrainedConfig |
| from transformers.modeling_outputs import MaskedLMOutput |
|
|
|
|
class MDLMConfig(PretrainedConfig):
    """Hyper-parameters of a small MDLM text-diffusion model.

    Args:
        vocab_size: Number of token ids the embedding / output layer cover.
            With the defaults, the [MASK] id (50257) is the last entry.
        model_length: Nominal sequence length (stored on the config; not read
            by the model code in this file).
        hidden_dim: Width of the transformer residual stream.
        cond_dim: Width of the timestep-conditioning vector fed to adaLN.
        n_blocks: Number of DiT blocks.
        n_heads: Attention heads per block.
        dropout: Dropout probability used inside each block.
        time_conditioning: When False the model is always fed t = 0.
        mlp_ratio: Hidden-width multiplier of each block's MLP.
        mask_token_id: Token id used as the [MASK] symbol.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "mdlm"

    def __init__(
        self,
        vocab_size: int = 50258,
        model_length: int = 256,
        hidden_dim: int = 512,
        cond_dim: int = 128,
        n_blocks: int = 6,
        n_heads: int = 8,
        dropout: float = 0.1,
        time_conditioning: bool = True,
        mlp_ratio: int = 4,
        mask_token_id: int = 50257,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Vocabulary and sequence geometry.
        self.vocab_size = vocab_size
        self.model_length = model_length
        # Backbone width / conditioning width / depth / heads.
        self.hidden_dim = hidden_dim
        self.cond_dim = cond_dim
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        # Regularisation and conditioning switches.
        self.dropout = dropout
        self.time_conditioning = time_conditioning
        self.mlp_ratio = mlp_ratio
        # Id of the absorbing [MASK] state.
        self.mask_token_id = mask_token_id
|
|
|
|
| |
|
|
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) angle table.

    Args:
        dim: Per-head feature size the rotation is applied to.
        base: Frequency base; inverse frequencies are ``base ** (-2i / dim)``.
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Persistent buffer: existing checkpoints contain ``inv_freq``, so its
        # registration must not change.
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len, device):
        """Return the RoPE angles for positions ``0 .. seq_len - 1``.

        Args:
            seq_len: Number of positions.
            device: Device the table is created on.

        Returns:
            float32 tensor of shape ``(seq_len, 2 * len(inv_freq))`` (equal to
            ``(seq_len, dim)`` for even ``dim``) holding ``pos * inv_freq``,
            duplicated across both halves.
        """
        # Always build the table in float32. If the module was cast to
        # fp16/bf16 (e.g. ``model.half()``), the buffer and an ``arange`` in
        # that dtype can no longer represent large positions exactly (bf16 is
        # exact only up to 256), so distinct positions would share an angle.
        inv_freq = self.inv_freq.to(device=device, dtype=torch.float32)
        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        freqs = torch.outer(positions, inv_freq)
        return torch.cat([freqs, freqs], dim=-1)
|
|
|
|
def rotate_half(x):
    """Rotate the last dimension by half: ``[a, b] -> [-b, a]``.

    The split follows ``torch.chunk(x, 2, dim=-1)``, so for an odd size the
    first half is the larger one.
    """
    lower, upper = torch.chunk(x, chunks=2, dim=-1)
    return torch.cat([torch.neg(upper), lower], dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q, k, freqs):
    """Apply RoPE to query and key tensors.

    Args:
        q, k: Tensors expected in ``(batch, seq, heads, head_dim)`` layout; the
            angle table is broadcast as ``(1, seq, 1, head_dim)``.
        freqs: Angle table of shape ``(seq, head_dim)``.

    Returns:
        The rotated ``(q, k)``, each cast back to its own input dtype.
    """
    cos = freqs.cos()[None, :, None, :]
    sin = freqs.sin()[None, :, None, :]

    def _rotate(x):
        # Rotate in (at least) float32, then cast back. Without the cast-back,
        # a bf16/fp16 q/k multiplied by float32 cos/sin is promoted to float32
        # while v stays in half precision, and scaled_dot_product_attention
        # rejects mismatched dtypes outside autocast. For float32 inputs this
        # is exactly the original computation.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        xc = x.to(compute_dtype)
        return (xc * cos + rotate_half(xc) * sin).to(x.dtype)

    return _rotate(q), _rotate(k)
|
|
|
|
| |
|
|
class TimestepEmbedder(nn.Module):
    """Map diffusion times to conditioning vectors.

    Times are expanded into fixed sinusoidal features of width
    ``frequency_embedding_size`` and then projected to ``hidden_size`` by a
    two-layer MLP with a SiLU in between.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        in_proj = nn.Linear(frequency_embedding_size, hidden_size)
        out_proj = nn.Linear(hidden_size, hidden_size)
        self.mlp = nn.Sequential(in_proj, nn.SiLU(), out_proj)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Sinusoidal features ``[cos(t * f_i), sin(t * f_i)]`` for a 1-D ``t``.

        ``f_i = max_period ** (-i / (dim // 2))``. An odd ``dim`` gets one
        trailing zero column. The result has shape ``(len(t), dim)``.
        """
        n_freqs = dim // 2
        idx = torch.arange(0, n_freqs, dtype=torch.float32, device=t.device)
        freqs = torch.exp(-math.log(max_period) * idx / n_freqs)
        angles = t[:, None].float() * freqs[None]
        features = torch.cat([angles.cos(), angles.sin()], dim=-1)
        if dim % 2:
            pad = torch.zeros_like(features[:, :1])
            features = torch.cat([features, pad], dim=-1)
        return features

    def forward(self, t):
        """Embed times ``t`` and project them through the MLP."""
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
|
|
|
|
| |
|
|
class LayerNorm(nn.Module):
    """Bias-free layer normalisation evaluated in float32.

    Normalises over the last dimension (size ``dim``) with CUDA autocast
    disabled, then multiplies by a learnable per-channel weight initialised
    to ones. The weight is broadcast as ``(1, 1, dim)``, so a 3-D
    ``(batch, seq, dim)`` input keeps its shape.
    """

    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.dim = dim

    def forward(self, x):
        # Normalisation statistics are computed in float32 even when the
        # surrounding code runs under CUDA autocast.
        with torch.amp.autocast("cuda", enabled=False):
            normed = F.layer_norm(x.float(), (self.dim,))
        return normed * self.weight.view(1, 1, -1)
|
|
|
|
| |
|
|
class DDiTBlock(nn.Module):
    """One DiT transformer block with adaLN-zero conditioning and RoPE.

    Args:
        dim: Width of the residual stream.
        n_heads: Number of attention heads; must divide ``dim``.
        cond_dim: Width of the conditioning vector ``c``.
        mlp_ratio: Hidden-width multiplier of the MLP.
        dropout: Probability used both for attention-weight dropout inside
            ``scaled_dot_product_attention`` and for dropout on each residual
            branch (both only active in training mode).

    Raises:
        ValueError: if ``dim`` is not a positive multiple of ``n_heads`` or
            the resulting head size is odd (RoPE rotates half-dimensions).
    """

    def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
        super().__init__()
        # Fail at construction with a clear message instead of an opaque
        # shape error from ``view`` / RoPE broadcasting on the first forward.
        if n_heads <= 0 or dim % n_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be a positive multiple of n_heads ({n_heads})"
            )
        if (dim // n_heads) % 2 != 0:
            raise ValueError(
                f"head size dim // n_heads = {dim // n_heads} must be even for rotary embeddings"
            )
        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        self.norm1 = LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)

        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_ratio * dim, dim),
        )
        self.dropout = nn.Dropout(dropout)
        self.drop_p = dropout

        # adaLN-zero: modulation starts at zero, so every gate is 0 and the
        # block is initially the identity on ``x``.
        self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
        nn.init.zeros_(self.adaLN_modulation.weight)
        nn.init.zeros_(self.adaLN_modulation.bias)

    def forward(self, x, rotary_freqs, c):
        """Run one block.

        Args:
            x: Residual stream of shape ``(batch, seq, dim)``.
            rotary_freqs: RoPE angle table, expected ``(seq, head_dim)``.
            c: Conditioning vectors, expected ``(batch, cond_dim)``.

        Returns:
            Updated residual stream, same shape as ``x``.
        """
        B, S, D = x.shape

        # Per-sample shift/scale/gate for both sub-layers, broadcast over seq.
        mod = self.adaLN_modulation(c)[:, None, :]
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=-1)

        # --- attention sub-layer ---
        h = self.norm1(x)
        h = h * (1 + scale_msa) + shift_msa
        qkv = self.attn_qkv(h).view(B, S, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)  # each (B, S, H, head_dim)
        q, k = apply_rotary_pos_emb(q, k, rotary_freqs)

        # SDPA expects (B, H, S, head_dim); no mask => bidirectional attention.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        attn = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.drop_p if self.training else 0.0,
            is_causal=False,
        )
        attn = attn.transpose(1, 2).reshape(B, S, D)
        x = x + gate_msa * self.dropout(self.attn_out(attn))

        # --- MLP sub-layer ---
        h = self.norm2(x)
        h = h * (1 + scale_mlp) + shift_mlp
        x = x + gate_mlp * self.dropout(self.mlp(h))
        return x
|
|
|
|
| |
|
|
class DDitFinalLayer(nn.Module):
    """Output head: adaLN-modulated norm followed by a linear projection.

    Both the projection and the modulation layer are zero-initialised, so the
    freshly constructed head outputs all zeros.

    Args:
        hidden_size: Width of the incoming residual stream.
        out_channels: Output width (the vocabulary size in this file).
        cond_dim: Width of the conditioning vector ``c``.
    """

    def __init__(self, hidden_size, out_channels, cond_dim):
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)
        for layer in (self.linear, self.adaLN_modulation):
            nn.init.zeros_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x, c):
        modulation = self.adaLN_modulation(c).unsqueeze(1)
        shift, scale = torch.chunk(modulation, 2, dim=-1)
        normed = self.norm_final(x)
        return self.linear(normed * (1 + scale) + shift)
|
|
|
|
| |
|
|
class MDLM(PreTrainedModel):
    """
    Small Masked Diffusion Language Model.

    Forward pass: given noisy input_ids and timesteps t in [0, 1], predicts
    logits over the vocabulary for each position.
    """

    config_class = MDLMConfig

    def __init__(self, config: MDLMConfig):
        super().__init__(config)
        self.config = config

        self.vocab_embed = nn.Embedding(config.vocab_size, config.hidden_dim)
        nn.init.kaiming_uniform_(self.vocab_embed.weight, a=math.sqrt(5))

        self.sigma_map = TimestepEmbedder(config.cond_dim)
        self.rotary_emb = RotaryEmbedding(config.hidden_dim // config.n_heads)

        self.blocks = nn.ModuleList([
            DDiTBlock(
                config.hidden_dim,
                config.n_heads,
                config.cond_dim,
                mlp_ratio=config.mlp_ratio,
                dropout=config.dropout,
            )
            for _ in range(config.n_blocks)
        ])

        self.output_layer = DDitFinalLayer(
            config.hidden_dim, config.vocab_size, config.cond_dim
        )

        # transformers hook; it dispatches ``_init_weights`` to sub-modules.
        self.post_init()

    def _init_weights(self, module):
        """Keep this model's own initialisation scheme under ``post_init()``.

        NOTE(review): recent transformers releases ship a generic default
        ``_init_weights`` that re-draws every ``nn.Linear`` from a normal
        distribution and zeroes its bias. Inheriting it would silently undo
        the adaLN-zero / zero-output initialisation done in the constructors.
        This override leaves constructor-time init untouched and only
        re-asserts the layers that must start at exactly zero.
        """
        if isinstance(module, (DDiTBlock, DDitFinalLayer)):
            nn.init.zeros_(module.adaLN_modulation.weight)
            nn.init.zeros_(module.adaLN_modulation.bias)
        if isinstance(module, DDitFinalLayer):
            nn.init.zeros_(module.linear.weight)
            nn.init.zeros_(module.linear.bias)

    def get_num_params(self):
        """Return the total number of parameters (trainable or not)."""
        return sum(p.numel() for p in self.parameters())

    def forward(
        self,
        input_ids: torch.LongTensor,
        timesteps: torch.FloatTensor,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Predict per-position vocabulary logits.

        Args:
            input_ids: ``(batch, seq)`` token ids, possibly containing [MASK].
            timesteps: Diffusion times of shape ``(batch,)``. A Python number
                or 0-d tensor is broadcast to every sample. Replaced by zeros
                when ``config.time_conditioning`` is False.
            output_hidden_states: Also return the embedding output followed by
                every block's output.
            return_dict: Return a ``MaskedLMOutput`` instead of bare logits.

        Returns:
            ``MaskedLMOutput(logits=..., hidden_states=..., loss=None)`` or
            the logits tensor of shape ``(batch, seq, vocab_size)``.
        """
        batch_size, seq_len = input_ids.shape
        x = self.vocab_embed(input_ids)

        # Accept floats / 0-d tensors / tensors on another device.
        timesteps = torch.as_tensor(timesteps, device=x.device)
        if timesteps.dim() == 0:
            timesteps = timesteps.expand(batch_size)
        if not self.config.time_conditioning:
            timesteps = torch.zeros_like(timesteps)

        c = F.silu(self.sigma_map(timesteps))
        rotary_freqs = self.rotary_emb(seq_len, device=x.device)

        all_hidden = [x] if output_hidden_states else None
        for block in self.blocks:
            x = block(x, rotary_freqs, c)
            if all_hidden is not None:
                all_hidden.append(x)
        logits = self.output_layer(x, c)

        if return_dict:
            # transformers convention: hidden_states is a tuple.
            hidden_states = tuple(all_hidden) if all_hidden is not None else None
            return MaskedLMOutput(logits=logits, hidden_states=hidden_states, loss=None)
        return logits
|
|
|
|
| |
|
|
def _draw_tokens(logits, mask_id, temperature):
    """Draw one token per position from ``(batch, seq, vocab)`` logits, never ``mask_id``."""
    # Work on a float32 copy so the model's output tensor is never modified.
    logits = logits.to(torch.float32, copy=True)
    if 0 <= mask_id < logits.shape[-1]:
        # The denoiser must never predict [MASK] itself (MDLM's SUBS
        # parameterisation); otherwise the last step could leave masks behind.
        logits[..., mask_id] = float("-inf")
    if temperature == 0:
        return logits.argmax(dim=-1)
    probs = F.softmax(logits / temperature, dim=-1)
    flat = torch.multinomial(probs.reshape(-1, probs.shape[-1]), 1)
    return flat.view(logits.shape[:-1])


@torch.no_grad()
def sample(
    model: "MDLM",
    seq_len: int,
    batch_size: int = 1,
    num_steps: int = 100,
    temperature: float = 0.7,
    device: str = "cuda",
):
    """
    Ancestral sampling from MDLM.

    Start from all [MASK] tokens. At each step s -> t (t < s), every position
    that is still masked is revealed with probability 1 - t/s, using a token
    drawn from the model's prediction. Revealed tokens are kept. The final step
    (t = 0) fills every remaining mask.

    Args:
        model: Denoiser called as ``model(x, t, return_dict=True)``. It must
            expose ``config.mask_token_id`` and return an object whose
            ``.logits`` has shape ``(batch, seq, vocab)``. It is put in eval
            mode while sampling, and its previous mode is restored afterwards.
        seq_len: Tokens per sample.
        batch_size: Number of samples drawn in parallel.
        num_steps: Number of denoising steps (>= 1).
        temperature: Softmax temperature; 0 selects greedy (argmax) decoding.
        device: Device the token tensor is created on.

    Returns:
        LongTensor of shape ``(batch_size, seq_len)`` without [MASK] tokens,
        provided the mask id indexes the logits' vocabulary dimension.

    Raises:
        ValueError: if ``num_steps < 1`` or ``temperature < 0``.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    mask_id = model.config.mask_token_id
    x = torch.full((batch_size, seq_len), mask_id, dtype=torch.long, device=device)
    timesteps = torch.linspace(1.0, 0.0, num_steps + 1, device=device)

    # Dropout must be off while sampling; restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        for i in range(num_steps):
            is_masked = x == mask_id
            if not is_masked.any():
                break  # everything already revealed; remaining steps are no-ops

            t_now = timesteps[i]
            t_next = timesteps[i + 1]

            t_batch = torch.full((batch_size,), t_now.item(), device=device)
            output = model(x, t_batch, return_dict=True)
            predicted = _draw_tokens(output.logits, mask_id, temperature)

            if t_next <= 0:
                # Final step: fill every remaining mask.
                x = torch.where(is_masked, predicted, x)
            else:
                # Linear schedule: P(reveal in s -> t | masked at s) = 1 - t/s.
                unmask_prob = float(1.0 - t_next / t_now)
                unmask = torch.bernoulli(
                    torch.full_like(x, unmask_prob, dtype=torch.float)
                ).bool() & is_masked
                x = torch.where(unmask, predicted, x)
    finally:
        model.train(was_training)

    return x
|
|