# mdlm-tiny-stories / model.py
# Author: youraveragedev (uploaded with huggingface_hub, commit 26b231d, verified)
"""
Small MDLM (Masked Diffusion Language Model) for text generation.
Based on: "Simple and Effective Masked Diffusion Language Models" (Sahoo et al., NeurIPS 2024)
Architecture: DiT backbone with adaLN-zero conditioning, RoPE, bidirectional attention.
No flash_attn dependency — uses PyTorch's native scaled_dot_product_attention.
"""
import math
import typing
import json
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import MaskedLMOutput
class MDLMConfig(PretrainedConfig):
    """Hyper-parameters of the small MDLM text diffusion model.

    Every named constructor argument is stored verbatim as an attribute of the
    same name; any other keyword arguments are handed to ``PretrainedConfig``.
    The default vocabulary of 50258 with mask id 50257 presumably means a GPT-2
    tokenizer plus one appended [MASK] token — confirm against the tokenizer.
    """

    model_type = "mdlm"

    def __init__(
        self,
        vocab_size: int = 50258,
        model_length: int = 256,
        hidden_dim: int = 512,
        cond_dim: int = 128,
        n_blocks: int = 6,
        n_heads: int = 8,
        dropout: float = 0.1,
        time_conditioning: bool = True,
        mlp_ratio: int = 4,
        mask_token_id: int = 50257,
        **kwargs
    ):
        # Let the base class consume the generic HF options first.
        super().__init__(**kwargs)
        # Record the model-specific hyper-parameters, in signature order.
        model_settings = (
            ("vocab_size", vocab_size),
            ("model_length", model_length),
            ("hidden_dim", hidden_dim),
            ("cond_dim", cond_dim),
            ("n_blocks", n_blocks),
            ("n_heads", n_heads),
            ("dropout", dropout),
            ("time_conditioning", time_conditioning),
            ("mlp_ratio", mlp_ratio),
            ("mask_token_id", mask_token_id),
        )
        for field_name, field_value in model_settings:
            setattr(self, field_name, field_value)
# ─── Rotary Position Embeddings ───────────────────────────
class RotaryEmbedding(nn.Module):
    """Produce rotary-position angle tables for a per-head width of ``dim``.

    ``inv_freq`` holds one inverse frequency per channel pair and is registered
    as a buffer, so it follows the module across ``.to()`` calls and is saved in
    the state dict.
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        pair_exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** pair_exponents))

    def forward(self, seq_len, device):
        """Return the ``(seq_len, 2 * len(inv_freq))`` table of angles ``pos * inv_freq``.

        The angle row is duplicated along the last axis to match the
        ``rotate_half`` channel layout.
        """
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = positions[:, None] * self.inv_freq[None, :]
        return angles.repeat(1, 2)
def rotate_half(x):
    """Split the last axis into two halves ``(a, b)`` and return ``(-b, a)``."""
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((second_half.neg(), first_half), dim=-1)
def apply_rotary_pos_emb(q, k, freqs):
    """Rotate query and key tensors by the rotary angles in ``freqs``.

    ``freqs`` is expanded with a leading and a third singleton axis, so a
    ``(seq, dim)`` table broadcasts against ``(batch, seq, heads, dim)`` inputs.
    """
    angles = freqs.unsqueeze(0).unsqueeze(2)  # (1, seq, 1, dim)
    cos, sin = angles.cos(), angles.sin()

    def rotate(tensor):
        return tensor * cos + rotate_half(tensor) * sin

    return rotate(q), rotate(k)
# ─── Timestep Embedding ──────────────────────────────────
class TimestepEmbedder(nn.Module):
    """Map scalar diffusion times to ``hidden_size``-dim vectors.

    Each time is first expanded into fixed sinusoidal features of width
    ``frequency_embedding_size``, then passed through a Linear-SiLU-Linear MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return sinusoidal features of a 1-D time tensor ``t``.

        The result is ``(len(t), dim)``: ``dim // 2`` cosines followed by
        ``dim // 2`` sines at geometrically spaced frequencies, plus one zero
        column when ``dim`` is odd.
        """
        half = dim // 2
        scaled = -math.log(max_period) * torch.arange(0, half, dtype=torch.float32, device=t.device)
        freqs = torch.exp(scaled / half)
        angles = t.float()[:, None] * freqs.unsqueeze(0)
        features = torch.cat([angles.cos(), angles.sin()], dim=-1)
        if dim % 2 == 1:
            filler = torch.zeros_like(features[:, :1])
            features = torch.cat([features, filler], dim=-1)
        return features

    def forward(self, t):
        """Embed the 1-D time tensor ``t`` into ``(len(t), hidden_size)``."""
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
# ─── LayerNorm ────────────────────────────────────────────
class LayerNorm(nn.Module):
    """Bias-free LayerNorm with a learnable per-channel scale, computed in float32.

    Normalisation runs on a float32 copy of the input with CUDA autocast
    disabled, so the output is float32 (or wider, if ``weight`` is) regardless
    of the input dtype.
    """

    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.dim = dim

    def forward(self, x):
        """Normalise ``x`` over its last axis (of size ``dim``) and apply ``weight``.

        Accepts inputs of any rank ``(..., dim)``; the output has the same shape.
        """
        with torch.amp.autocast("cuda", enabled=False):
            x = F.layer_norm(x.float(), [self.dim])
        # Broadcast over the trailing axis only. Indexing with [None, None, :]
        # forced a 3-D result, e.g. turning a (N, dim) input into (1, N, dim).
        return x * self.weight
# ─── DiT Block with adaLN-zero ───────────────────────────
class DDiTBlock(nn.Module):
    """Transformer block with adaLN-zero conditioning and bidirectional rotary attention.

    Args:
        dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads. The per-head width ``dim // n_heads``
            must be even, because rotary embeddings rotate channel halves.
        cond_dim: width of the conditioning vector ``c`` passed to ``forward``.
        mlp_ratio: hidden-width multiplier of the feed-forward sub-layer.
        dropout: probability used for attention-weight dropout (training only)
            and for the dropout applied to both residual branches.

    Raises:
        ValueError: if ``dim`` is not divisible by ``n_heads`` (or ``n_heads`` is
            not positive), or if the resulting head width is odd.
    """

    def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
        super().__init__()
        # Fail fast with a clear message. Otherwise forward() dies in qkv.view()
        # (heads do not tile dim) or in the rotary broadcast (odd head width).
        if n_heads <= 0 or dim % n_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by n_heads ({n_heads})")
        if (dim // n_heads) % 2 != 0:
            raise ValueError(
                f"head dimension dim // n_heads = {dim // n_heads} must be even "
                "for rotary position embeddings"
            )
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.norm1 = LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)
        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_ratio * dim, dim),
        )
        self.dropout = nn.Dropout(dropout)
        self.drop_p = dropout
        # adaLN-zero: 6 modulation params (shift, scale, gate for attn & mlp).
        # Zero init makes every gate 0, so the block starts as the identity map.
        self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
        nn.init.zeros_(self.adaLN_modulation.weight)
        nn.init.zeros_(self.adaLN_modulation.bias)

    def forward(self, x, rotary_freqs, c):
        """Apply one conditioned attention + MLP block.

        Args:
            x: ``(B, S, D)`` hidden states.
            rotary_freqs: rotary angle table; presumably ``(S, head_dim)`` as
                produced by ``RotaryEmbedding`` — confirm for other callers.
            c: ``(B, cond_dim)`` conditioning vectors.

        Returns:
            Updated hidden states with the same shape as ``x``.
        """
        B, S, D = x.shape
        # adaLN modulation, broadcast over the sequence axis
        mod = self.adaLN_modulation(c)[:, None, :]  # (B, 1, 6*D)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=-1)
        # ── Self-Attention ──
        h = self.norm1(x)
        h = h * (1 + scale_msa) + shift_msa
        qkv = self.attn_qkv(h)
        qkv = qkv.view(B, S, 3, self.n_heads, self.head_dim)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        # q, k, v: (B, S, n_heads, head_dim)
        q, k = apply_rotary_pos_emb(q, k, rotary_freqs)
        # SDPA expects (B, n_heads, S, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # Bidirectional attention (no causal mask); attention dropout only in training
        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.drop_p if self.training else 0.0,
            is_causal=False,
        )
        attn_out = attn_out.transpose(1, 2).reshape(B, S, D)
        attn_out = self.attn_out(attn_out)
        x = x + gate_msa * self.dropout(attn_out)
        # ── MLP ──
        h = self.norm2(x)
        h = h * (1 + scale_mlp) + shift_mlp
        x = x + gate_mlp * self.dropout(self.mlp(h))
        return x
# ─── Final Layer ──────────────────────────────────────────
class DDitFinalLayer(nn.Module):
    """Output head: adaLN-modulated LayerNorm followed by a linear projection.

    Both the projection and its modulation layer start at exactly zero, so an
    untrained head outputs all-zero logits.
    """

    def __init__(self, hidden_size, out_channels, cond_dim):
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)
        for zero_layer in (self.linear, self.adaLN_modulation):
            nn.init.zeros_(zero_layer.weight)
            nn.init.zeros_(zero_layer.bias)

    def forward(self, x, c):
        """Project hidden states ``x`` to ``out_channels``, modulated by ``c``."""
        modulation = self.adaLN_modulation(c).unsqueeze(1)
        shift, scale = modulation.chunk(2, dim=-1)
        normed = self.norm_final(x)
        return self.linear(normed * (1 + scale) + shift)
# ─── Full Model ──────────────────────────────────────────
class MDLM(PreTrainedModel):
    """
    Small Masked Diffusion Language Model.

    Forward pass: given noisy ``input_ids`` and per-sequence diffusion times
    ``timesteps`` (the sampler in this file uses t ∈ [0, 1]), predicts logits
    over the vocabulary for every position.
    """

    config_class = MDLMConfig

    def __init__(self, config: MDLMConfig):
        super().__init__(config)
        self.config = config
        # Token embedding, initialised the same way nn.Linear initialises weights.
        self.vocab_embed = nn.Embedding(config.vocab_size, config.hidden_dim)
        nn.init.kaiming_uniform_(self.vocab_embed.weight, a=math.sqrt(5))
        # Maps each timestep to a cond_dim conditioning vector.
        self.sigma_map = TimestepEmbedder(config.cond_dim)
        # Rotary tables are per attention head, hence hidden_dim // n_heads.
        self.rotary_emb = RotaryEmbedding(config.hidden_dim // config.n_heads)
        self.blocks = nn.ModuleList([
            DDiTBlock(
                config.hidden_dim,
                config.n_heads,
                config.cond_dim,
                mlp_ratio=config.mlp_ratio,
                dropout=config.dropout,
            )
            for _ in range(config.n_blocks)
        ])
        # Separate output projection (no weight tying with embeddings)
        self.output_layer = DDitFinalLayer(
            config.hidden_dim, config.vocab_size, config.cond_dim
        )
        # NOTE(review): post_init() runs transformers' weight-init hooks. Confirm
        # with the installed transformers version that they do not overwrite the
        # zero-initialised adaLN / output layers.
        self.post_init()

    def get_num_params(self):
        """Return the total number of parameters (trainable or not)."""
        return sum(p.numel() for p in self.parameters())

    def forward(
        self,
        input_ids: torch.LongTensor,
        timesteps: torch.FloatTensor,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Predict per-position vocabulary logits.

        Args:
            input_ids: ``(batch, seq_len)`` token ids.
            timesteps: one diffusion time per sequence, shape ``(batch,)``. A 0-d
                tensor or a Python number is broadcast to the whole batch. If
                ``config.time_conditioning`` is False the values are replaced by 0.
            output_hidden_states: also return the embedding output followed by
                the output of every block.
            return_dict: return a ``MaskedLMOutput`` (``loss`` is always None)
                instead of the bare logits tensor.

        Returns:
            ``MaskedLMOutput`` when ``return_dict`` is True, otherwise the
            ``(batch, seq_len, vocab_size)`` logits tensor.
        """
        B, S = input_ids.shape
        x = self.vocab_embed(input_ids)
        # Accept scalar times and keep them on the same device as the inputs;
        # timestep_embedding indexes t[:, None], which requires a 1-D tensor.
        timesteps = torch.as_tensor(timesteps, device=x.device)
        if timesteps.dim() == 0:
            timesteps = timesteps.expand(B)
        if not self.config.time_conditioning:
            timesteps = torch.zeros_like(timesteps)
        c = F.silu(self.sigma_map(timesteps))
        rotary_freqs = self.rotary_emb(S, device=x.device)
        all_hidden = [x] if output_hidden_states else None
        # Mixed precision: let the outer training loop handle autocast
        for block in self.blocks:
            x = block(x, rotary_freqs, c)
            if output_hidden_states:
                all_hidden.append(x)
        logits = self.output_layer(x, c)
        if return_dict:
            # MaskedLMOutput documents hidden_states as a tuple of tensors.
            hidden_states = tuple(all_hidden) if all_hidden is not None else None
            return MaskedLMOutput(logits=logits, hidden_states=hidden_states, loss=None)
        return logits
# ─── Sampling ─────────────────────────────────────────────
@torch.no_grad()
def sample(
    model: "MDLM",
    seq_len: int,
    batch_size: int = 1,
    num_steps: int = 100,
    temperature: float = 0.7,
    device: typing.Optional[str] = "cuda",
):
    """
    Ancestral sampling from MDLM.

    Starts from an all-[MASK] sequence and walks time from 1 to 0 in
    ``num_steps`` equal steps. At each step s→t (t < s), every still-masked
    position is revealed with probability ``1 - t/s``, using a token drawn from
    the model's prediction. Revealed tokens are never changed again, and the
    final step reveals every remaining mask.

    Args:
        model: called as ``model(x, t_batch, return_dict=True)`` and expected to
            return an object with ``.logits`` of shape ``(batch, seq_len, vocab)``;
            must expose ``config.mask_token_id``. This function does not switch
            the model to eval mode, so call ``model.eval()`` first if dropout
            should be off.
        seq_len: tokens per generated sequence.
        batch_size: number of sequences to generate.
        num_steps: number of denoising steps; must be >= 1.
        temperature: softmax temperature; a value <= 0 selects greedy (argmax)
            decoding.
        device: device for the generated tensors; ``None`` uses the device of the
            model's first parameter.

    Returns:
        ``(batch_size, seq_len)`` LongTensor of token ids. When the mask id lies
        inside the logits' vocabulary, it never appears in the result.

    Raises:
        ValueError: if ``num_steps < 1``.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    if device is None:
        device = next(model.parameters()).device
    mask_id = model.config.mask_token_id
    # Start with all masked
    x = torch.full((batch_size, seq_len), mask_id, dtype=torch.long, device=device)
    # Discretize time from 1→0
    timesteps = torch.linspace(1.0, 0.0, num_steps + 1, device=device)
    for i in range(num_steps):
        t_now = timesteps[i]
        t_next = timesteps[i + 1]
        # Get model predictions
        t_batch = torch.full((batch_size,), t_now.item(), device=device)
        output = model(x, t_batch, return_dict=True)
        # Sample in float32 (logits may be half precision under autocast). The
        # explicit copy keeps the in-place edit below off the model's output.
        logits = output.logits.to(torch.float32, copy=True)
        # Nothing stops the model from predicting [MASK] itself. Without this, a
        # position could be "unmasked" to [MASK] and survive the final step.
        if 0 <= mask_id < logits.shape[-1]:
            logits[..., mask_id] = float("-inf")
        if temperature > 0:
            probs = F.softmax(logits / temperature, dim=-1)
            predicted = torch.multinomial(
                probs.reshape(-1, probs.shape[-1]), 1
            ).view(batch_size, seq_len)
        else:
            # Greedy decoding; also avoids dividing by a zero temperature.
            predicted = logits.argmax(dim=-1)
        # Determine which masked positions to unmask
        is_masked = x == mask_id
        if t_next <= 0:
            # Last step: unmask everything
            x = torch.where(is_masked, predicted, x)
        else:
            # Linear schedule: P(reveal in s→t | masked at s) = 1 - t/s
            unmask_prob = float(1.0 - t_next / t_now)
            unmask = torch.bernoulli(
                torch.full_like(x, unmask_prob, dtype=torch.float)
            ).bool() & is_masked
            x = torch.where(unmask, predicted, x)
    return x