# tiny-dllm / 04_diffusion.py
# Author: sutharsan-311 — uploaded with huggingface_hub (commit c9d2062, 9.19 kB)
"""
Phase 3: Masked Diffusion — the dLLM core
-------------------------------------------
This is what makes it a *diffusion* LM, not just a regular transformer.
Concept:
Forward process (add noise): text → gradually mask tokens → fully masked
Backward process (denoise): [MASK]...[MASK] → gradually unmask → text
The model learns: given a partially masked sequence, predict what the masked tokens are.
Usage: python 04_diffusion.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# ── Minimal model (copy from 03) ─────────────────────────────────────────────
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention without any mask.

    No causal mask is applied, so every position attends to every other
    position in both directions. ``hidden`` must be divisible by ``n_heads``
    (the per-head reshape in ``forward`` fails otherwise).

    Args:
        hidden: model width C.
        n_heads: number of attention heads.
    """

    def __init__(self, hidden, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = hidden // n_heads
        # One fused projection yields queries, keys and values side by side.
        self.qkv = nn.Linear(hidden, 3 * hidden, bias=False)
        self.out = nn.Linear(hidden, hidden, bias=False)

    def forward(self, x):
        """Attend over ``x`` of shape [B, T, C]; returns a tensor of the same shape."""
        batch, seq, width = x.shape
        heads, dim = self.n_heads, self.head_dim
        # [B, T, 3C] -> [B, T, 3, heads, dim] -> [3, B, heads, T, dim]
        fused = self.qkv(x).reshape(batch, seq, 3, heads, dim).permute(2, 0, 3, 1, 4)
        query, key, value = fused.unbind(0)
        weights = (query @ key.transpose(-2, -1) / math.sqrt(dim)).softmax(dim=-1)
        # Merge the heads back: [B, heads, T, dim] -> [B, T, C]
        merged = (weights @ value).transpose(1, 2).reshape(batch, seq, width)
        return self.out(merged)
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(C -> 4C) -> GELU -> Linear(4C -> C).

    Args:
        hidden: model width C.
    """

    def __init__(self, hidden):
        super().__init__()
        inner = 4 * hidden
        layers = [nn.Linear(hidden, inner), nn.GELU(), nn.Linear(inner, hidden)]
        # Kept as a Sequential named `net` so parameter names stay net.0 / net.2.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP independently to every position of ``x``."""
        mlp = self.net
        return mlp(x)
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, each with its own residual.

    Args:
        hidden: model width C.
        n_heads: number of attention heads.
    """

    def __init__(self, hidden, n_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden)
        self.attn = SelfAttention(hidden, n_heads)
        self.norm2 = nn.LayerNorm(hidden)
        self.ff = FeedForward(hidden)

    def forward(self, x):
        """Map ``x`` [B, T, C] to a tensor of the same shape."""
        # Both sub-layers must add into the residual stream. Returning only
        # `x + ff(norm2(x + attn(...)))` would drop the attention output from
        # the residual path, letting it reach the output only through the MLP.
        x = x + self.attn(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x
# ── The dLLM model ────────────────────────────────────────────────────────────
class TinyDLLM(nn.Module):
    """Bidirectional transformer that predicts a real token at every position.

    The token embedding has ``vocab_size + 1`` rows. Ids ``0..vocab_size-1`` are
    real tokens, and id ``vocab_size`` is the special [MASK] token. The output
    only scores the ``vocab_size`` real tokens, so [MASK] is never predicted.

    Args:
        vocab_size: number of real tokens (excluding [MASK]).
        hidden: model width.
        n_layers: number of transformer blocks.
        n_heads: attention heads per block (must divide ``hidden``).
        max_seq: longest sequence the positional embedding supports.
    """

    def __init__(self, vocab_size, hidden=256, n_layers=4, n_heads=4, max_seq=128):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq = max_seq
        # [MASK] lives at index vocab_size, just past the real tokens.
        self.mask_token_id = vocab_size
        full_vocab = vocab_size + 1
        self.token_emb = nn.Embedding(full_vocab, hidden)
        self.pos_emb = nn.Embedding(max_seq, hidden)
        self.blocks = nn.Sequential(*[TransformerBlock(hidden, n_heads)
                                      for _ in range(n_layers)])
        self.norm = nn.LayerNorm(hidden)
        # The output projection reuses token_emb.weight[:vocab_size] in forward()
        # (weight tying). The alternative would be a separate Linear whose weight
        # is `nn.Parameter(token_emb.weight[:vocab_size])`. That only aliases
        # storage, and the optimizer would update the two copies independently.
        # `.to(device)` / `.double()` would also copy them separately, silently
        # breaking the tie.
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Embedding)):
                nn.init.normal_(m.weight, std=0.02)

    def forward(self, token_ids):
        """Return logits [B, T, vocab_size] for ids [B, T] (ids may include [MASK]).

        Raises:
            ValueError: if T exceeds ``max_seq`` (no positional embedding exists for it).
        """
        B, T = token_ids.shape
        if T > self.max_seq:
            raise ValueError(
                f"sequence length {T} exceeds max_seq={self.max_seq}")
        pos = torch.arange(T, device=token_ids.device)
        x = self.token_emb(token_ids) + self.pos_emb(pos)
        x = self.norm(self.blocks(x))
        # Tied head: score against the real-token rows only, never [MASK].
        return F.linear(x, self.token_emb.weight[:self.vocab_size])
# ── Masked Diffusion ──────────────────────────────────────────────────────────
class MaskedDiffusion:
    """
    Forward process: randomly mask tokens with probability t ∈ [0, 1].
      t=0 → no masking (original text)
      t=1 → fully masked
    We sample t uniformly each training step — the model must learn
    to denoise at ALL noise levels simultaneously.

    Args:
        mask_token_id: id written into masked positions. The model must accept it
            as input but never predict it, because ``sample`` treats any position
            holding this id as still masked.
    """

    def __init__(self, mask_token_id):
        self.mask_id = mask_token_id

    def add_noise(self, tokens, t):
        """
        Mask each token independently with probability t.

        tokens: [B, T] original token ids
        t:      [B] noise level per sample (0=clean, 1=all masked)
        Returns: noisy_tokens [B, T], mask [B, T] (True = was masked)
        """
        B, T = tokens.shape
        # t is per-sample; broadcast to [B, T]
        mask_prob = t.unsqueeze(1).expand(B, T)
        mask = torch.bernoulli(mask_prob).bool()  # True = mask this token
        noisy = tokens.clone()
        noisy[mask] = self.mask_id
        return noisy, mask

    def loss(self, model, tokens):
        """
        Training loss:
        1. Sample random noise level t ~ Uniform(0, 1) per sample
        2. Apply forward process (mask tokens)
        3. Model predicts original tokens at masked positions
        4. Cross-entropy only on masked positions (nothing to learn at visible ones)

        Returns a scalar tensor. If no position happens to be masked (all
        sampled t small, or T == 0), the result is a zero that is still
        attached to the model's output, so ``loss.backward()`` does not raise.

        NOTE(review): this averages CE uniformly over masked tokens. The
        MDLM/LLaDA ELBO instead weights each sample's masked-token CE by 1/t.
        It is kept as-is to preserve the existing training objective.
        """
        B = tokens.size(0)
        t = torch.rand(B, device=tokens.device)  # random noise level per sample
        noisy_tokens, mask = self.add_noise(tokens, t)
        logits = model(noisy_tokens)  # [B, T, vocab]
        # only compute loss where we masked
        logits_masked = logits[mask]  # [n_masked, vocab]
        targets = tokens[mask]        # [n_masked]
        if logits_masked.numel() == 0:
            # A fresh torch.tensor(0.0) has no grad_fn and would make
            # .backward() fail. Multiplying the logits by zero keeps the graph.
            return logits.sum() * 0.0
        return F.cross_entropy(logits_masked, targets)

    @torch.no_grad()
    def sample(self, model, seq_len, n_steps=20, device='cpu'):
        """
        Generate one sequence from scratch.

        Start fully masked and iteratively unmask over ``n_steps``. After step
        ``s`` (0-based), ``(s + 1) * seq_len // n_steps`` positions are revealed,
        so the last step always reveals everything. At each step a token is
        sampled at every position. Among the still-masked positions, those whose
        sampled token has the highest probability are committed.

        Args:
            model: module mapping [1, seq_len] ids to [1, seq_len, vocab] logits.
            seq_len: length of the sequence to generate.
            n_steps: number of denoising steps (must be >= 1).
            device: device on which the token tensor is created.

        Returns:
            LongTensor [1, seq_len]. Every position is filled in, provided the
            model never predicts the mask id.

        Raises:
            ValueError: if ``n_steps < 1``.

        The model is put in eval mode while sampling. Its previous train/eval
        mode is restored afterwards, even if sampling raises.
        """
        if n_steps < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")
        was_training = model.training
        model.eval()
        try:
            return self._unmask_loop(model, seq_len, n_steps, device)
        finally:
            model.train(was_training)

    @torch.no_grad()
    def _unmask_loop(self, model, seq_len, n_steps, device):
        """Run the confidence-ordered unmasking schedule for a batch of one."""
        B = 1
        # start: everything masked
        tokens = torch.full((B, seq_len), self.mask_id, dtype=torch.long, device=device)
        total = B * seq_len
        for step in range(n_steps):
            # Cumulative reveal target. Integer maths avoids float round-down
            # (e.g. 0.29 * 100 == 28.999...), and the final step reaches `total`.
            target_unmasked = (step + 1) * total // n_steps
            still_masked = tokens == self.mask_id
            n_masked = int(still_masked.sum())
            to_unmask = min(n_masked, target_unmasked - (total - n_masked))
            if to_unmask <= 0:
                continue  # nothing scheduled this step; skip the forward pass
            logits = model(tokens)             # [1, T, vocab]
            probs = F.softmax(logits, dim=-1)  # [1, T, vocab]
            predicted = torch.multinomial(
                probs.view(-1, probs.size(-1)), num_samples=1).view(B, seq_len)
            # Confidence of the token that will actually be committed, not of the
            # argmax (which may differ from the sampled token).
            confidence = probs.gather(-1, predicted.unsqueeze(-1)).squeeze(-1)
            # Already-revealed positions can never be selected again.
            confidence = confidence.masked_fill(~still_masked, float('-inf'))
            top_idx = confidence.view(-1).topk(to_unmask).indices
            flat_tokens = tokens.view(-1)  # view: writes land in `tokens`
            flat_tokens[top_idx] = predicted.view(-1)[top_idx]
        return tokens
# ── Demo ──────────────────────────────────────────────────────────────────────
device = 'cuda' if torch.cuda.is_available() else 'cpu'
VOCAB = 65    # number of real tokens; id VOCAB is reserved for [MASK]
HIDDEN = 256
LAYERS = 4
HEADS = 4
MAX_SEQ = 128
model = TinyDLLM(VOCAB, HIDDEN, LAYERS, HEADS, MAX_SEQ).to(device)
diffusion = MaskedDiffusion(mask_token_id=VOCAB)

# test loss on fake data (forward only, so no autograd graph is needed)
fake_tokens = torch.randint(0, VOCAB, (4, MAX_SEQ), device=device)
with torch.no_grad():
    loss = diffusion.loss(model, fake_tokens)
print(f"Loss on random data (untrained): {loss.item():.4f}")
print(f" (expected ~ln({VOCAB}) = {math.log(VOCAB):.2f} for random model)")

# test forward process: each token is masked independently with p = 0.5
tokens = torch.randint(0, VOCAB, (1, 10), device=device)
t = torch.tensor([0.5], device=device)
noisy, mask = diffusion.add_noise(tokens, t)
print(f"\nOriginal: {tokens[0].tolist()}")
print(f"Noisy (t=0.5): {noisy[0].tolist()} ({VOCAB}=MASK)")
print(f"Masked positions: {mask[0].tolist()}")

# test sampling (untrained = random output, but pipeline works)
generated = diffusion.sample(model, seq_len=20, n_steps=10, device=device)
print(f"\nGenerated (untrained): {generated[0].tolist()}")
print("""
What just happened:
1. add_noise() — masked ~50% of tokens at random (forward process)
2. model() — predicted all token positions (bidirectional attention!)
3. sample() — started fully masked, unmasked most-confident tokens step by step
✅ Phase 3 complete — diffusion process built
Next: 05_train.py — actually train this on Shakespeare text
""")