"""
Phase 3: Masked Diffusion - the dLLM core
-------------------------------------------
This is what makes it a *diffusion* LM, not just a regular transformer.

Concept:
  Forward process (add noise):  text -> gradually mask tokens -> fully masked
  Backward process (denoise):   [MASK]...[MASK] -> gradually unmask -> text

The model learns: given a partially masked sequence, predict what the masked tokens are.

Run:
  python 04_diffusion.py
"""
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
|
|
|
|
| |
class SelfAttention(nn.Module):
    """Multi-head self-attention without a causal mask.

    Every position attends to every other position (bidirectional), so a
    masked token can use context from both its left and its right.

    Args:
        hidden: Width of the last dimension of the input.
        n_heads: Number of attention heads; must be positive and divide
            ``hidden`` evenly.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``hidden``.
    """

    def __init__(self, hidden, n_heads):
        super().__init__()
        # Fail at construction time instead of with an opaque view() shape
        # error on the first forward pass.
        if n_heads <= 0 or hidden % n_heads != 0:
            raise ValueError(
                f"hidden ({hidden}) must be divisible by a positive "
                f"n_heads ({n_heads})")
        self.n_heads = n_heads
        self.head_dim = hidden // n_heads
        # One fused projection produces queries, keys and values together.
        self.qkv = nn.Linear(hidden, 3 * hidden, bias=False)
        self.out = nn.Linear(hidden, hidden, bias=False)

    def forward(self, x):
        """Apply attention to ``x`` of shape [B, T, C]; returns [B, T, C]."""
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        def split(t):
            # [B, T, C] -> [B, n_heads, T, head_dim]
            return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        q, k, v = split(q), split(k), split(v)
        # Scaled dot-product scores; no mask is applied on purpose.
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        weights = F.softmax(scores, dim=-1)
        # Merge heads back: [B, n_heads, T, head_dim] -> [B, T, C]
        merged = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out(merged)
|
|
class FeedForward(nn.Module):
    """Position-wise MLP: widen to 4x ``hidden``, apply GELU, project back."""

    def __init__(self, hidden):
        super().__init__()
        inner = 4 * hidden
        # Layer order (and therefore the net.0 / net.2 parameter names)
        # is expand -> activation -> contract.
        layers = [
            nn.Linear(hidden, inner),
            nn.GELU(),
            nn.Linear(inner, hidden),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the MLP on the last dimension of ``x``; shape is preserved."""
        projected = self.net(x)
        return projected
|
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Computes::

        h   = x + attn(norm1(x))
        out = h + ff(norm2(h))

    Args:
        hidden: Model width.
        n_heads: Number of attention heads passed to ``SelfAttention``.
    """

    def __init__(self, hidden, n_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden)
        self.attn = SelfAttention(hidden, n_heads)
        self.norm2 = nn.LayerNorm(hidden)
        self.ff = FeedForward(hidden)

    def forward(self, x):
        """Apply the block to ``x``; the output has the same shape as ``x``."""
        # The attention output must be added to the residual stream itself.
        # Previously the result was x + ff(norm2(x + attn(...))), so the
        # attention output only reached later layers through the MLP.
        x = x + self.attn(self.norm1(x))
        return x + self.ff(self.norm2(x))
|
|
|
|
| |
class TinyDLLM(nn.Module):
    """Small bidirectional transformer used as the denoiser.

    The input vocabulary has one extra id, ``mask_token_id == vocab_size``,
    for the [MASK] token. The output covers only the ``vocab_size`` real
    tokens, so the model can never predict [MASK].

    Args:
        vocab_size: Number of real (non-mask) tokens.
        hidden: Model width.
        n_layers: Number of transformer blocks.
        n_heads: Attention heads per block.
        max_seq: Longest sequence length the position table supports.
    """

    def __init__(self, vocab_size, hidden=256, n_layers=4, n_heads=4, max_seq=128):
        super().__init__()

        self.vocab_size = vocab_size
        self.mask_token_id = vocab_size
        full_vocab = vocab_size + 1

        self.token_emb = nn.Embedding(full_vocab, hidden)
        self.pos_emb = nn.Embedding(max_seq, hidden)
        self.blocks = nn.Sequential(
            *[TransformerBlock(hidden, n_heads) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(hidden)

        # Weight tying: the head shares the *same* Parameter object as the
        # embedding. Wrapping a slice in a new nn.Parameter only aliases
        # storage; the two are then distinct parameters, and a device or
        # dtype conversion gives each its own copy, silently breaking the
        # tie. The head therefore spans full_vocab rows and forward() drops
        # the [MASK] logit.
        # NOTE: head.weight is now [full_vocab, hidden], so checkpoints that
        # stored a [vocab_size, hidden] head do not load directly.
        self.head = nn.Linear(hidden, full_vocab, bias=False)
        self.head.weight = self.token_emb.weight

        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Embedding)):
                nn.init.normal_(m.weight, std=0.02)

    def forward(self, token_ids):
        """Return logits over the real vocabulary.

        Args:
            token_ids: [B, T] integer ids in ``[0, vocab_size]`` (the top
                id is the mask token).

        Returns:
            Tensor of shape [B, T, vocab_size].

        Raises:
            ValueError: If ``T`` exceeds the size of the position table.
        """
        B, T = token_ids.shape
        max_seq = self.pos_emb.num_embeddings
        if T > max_seq:
            raise ValueError(
                f"sequence length {T} exceeds max_seq {max_seq}")
        pos = torch.arange(T, device=token_ids.device)
        x = self.token_emb(token_ids) + self.pos_emb(pos)
        x = self.norm(self.blocks(x))
        # Drop the last column: the logit of the [MASK] token itself.
        return self.head(x)[..., :self.vocab_size]
|
|
|
|
| |
class MaskedDiffusion:
    """Masked (absorbing-state) diffusion over token sequences.

    Forward process: randomly mask tokens with probability t in [0, 1].
        t=0 -> no masking (original text)
        t=1 -> fully masked

    We sample t uniformly each training step, so the model must learn
    to denoise at ALL noise levels simultaneously.

    Args:
        mask_token_id: Integer id that stands for the [MASK] token.
    """

    def __init__(self, mask_token_id):
        self.mask_id = mask_token_id

    def add_noise(self, tokens, t):
        """Mask each token independently with probability t.

        Args:
            tokens: [B, T] original token ids (not modified).
            t: [B] float noise level per sample (0=clean, 1=all masked).

        Returns:
            ``(noisy_tokens, mask)``, both [B, T]; ``mask`` is True where
            the token was replaced by the mask id.
        """
        B, T = tokens.shape
        # Broadcast each sample's noise level across its sequence.
        mask_prob = t.unsqueeze(1).expand(B, T)
        mask = torch.bernoulli(mask_prob).bool()
        noisy = tokens.masked_fill(mask, self.mask_id)
        return noisy, mask

    def loss(self, model, tokens):
        """Compute the denoising training loss for one batch.

        1. Sample a random noise level t ~ Uniform(0, 1) per sample.
        2. Apply the forward process (mask tokens).
        3. The model predicts the original tokens.
        4. Cross-entropy is taken only on masked positions (there is
           nothing to learn at visible ones).

        Args:
            model: Callable mapping [B, T] ids to [B, T, V] logits.
            tokens: [B, T] clean token ids.

        Returns:
            Scalar loss tensor. If no position happened to be masked, the
            loss is a zero that is still attached to the model's graph.
        """
        B, _ = tokens.shape
        t = torch.rand(B, device=tokens.device)
        noisy_tokens, mask = self.add_noise(tokens, t)

        logits = model(noisy_tokens)

        if not mask.any():
            # A detached torch.tensor(0.0) would make loss.backward() raise
            # ("does not require grad"); keep the zero connected instead.
            return logits.sum() * 0.0

        return F.cross_entropy(logits[mask], tokens[mask])

    @torch.no_grad()
    def sample(self, model, seq_len, n_steps=20, device='cpu'):
        """Generate one sequence from scratch.

        Starts fully masked and iteratively unmasks tokens over
        ``n_steps``. Each step samples a candidate token for every
        position and commits the still-masked positions whose sampled
        token has the highest probability.

        Args:
            model: Callable mapping [B, T] ids to [B, T, V] logits; it is
                switched to eval mode and left there.
            seq_len: Length of the sequence to generate.
            n_steps: Number of unmasking steps.
            device: Device on which the token tensor is created.

        Returns:
            [1, seq_len] tensor of token ids. With ``n_steps >= 1`` every
            position has been unmasked by the final step.
        """
        model.eval()
        B = 1
        tokens = torch.full(
            (B, seq_len), self.mask_id, dtype=torch.long, device=device)

        for step in range(n_steps):
            still_masked = tokens == self.mask_id
            n_masked = int(still_masked.sum().item())

            # Linear schedule in integer arithmetic: after step s a
            # fraction (s + 1) / n_steps of the positions is unmasked.
            # Floor division avoids float round-off (e.g. 0.29 * 100).
            target_unmasked = (step + 1) * seq_len // n_steps
            already_unmasked = B * seq_len - n_masked
            to_unmask = min(target_unmasked - already_unmasked, n_masked)
            if to_unmask <= 0:
                # Nothing to commit this step; skip the forward pass.
                continue

            logits = model(tokens)
            probs = F.softmax(logits, dim=-1)

            # Sample a candidate token for every position.
            predicted = torch.multinomial(
                probs.reshape(B * seq_len, -1), num_samples=1).view(B, seq_len)

            # Confidence is the probability of the token we would actually
            # commit (the sampled one), not of the argmax token.
            confidence = probs.gather(-1, predicted.unsqueeze(-1)).squeeze(-1)
            # Already-revealed positions must never be selected again.
            confidence = confidence.masked_fill(~still_masked, -1.0)

            _, top_idx = confidence.view(-1).topk(to_unmask)
            flat_tokens = tokens.view(-1)  # view: writes land in `tokens`
            flat_tokens[top_idx] = predicted.view(-1)[top_idx]

        return tokens
|
|
|
|
| |
# --- Demo: runs on import/execution of this file ---------------------------
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Hyperparameters for the demo model. VOCAB counts the real tokens only;
# the id equal to VOCAB is used as the mask token below.
VOCAB = 65
HIDDEN = 256
LAYERS = 4
HEADS = 4
MAX_SEQ = 128


model = TinyDLLM(VOCAB, HIDDEN, LAYERS, HEADS, MAX_SEQ).to(device)
diffusion = MaskedDiffusion(mask_token_id=VOCAB)


# 1) Training loss on a batch of 4 random sequences of length MAX_SEQ.
#    The model is untrained; the script prints ln(VOCAB) as the reference.
fake_tokens = torch.randint(0, VOCAB, (4, MAX_SEQ)).to(device)
loss = diffusion.loss(model, fake_tokens)
print(f"Loss on random data (untrained): {loss.item():.4f}")
print(f" (expected ~ln({VOCAB}) = {math.log(VOCAB):.2f} for random model)")


# 2) Forward process on one 10-token sequence at noise level t = 0.5.
tokens = torch.randint(0, VOCAB, (1, 10)).to(device)
t = torch.tensor([0.5]).to(device)
noisy, mask = diffusion.add_noise(tokens, t)
print(f"\nOriginal: {tokens[0].tolist()}")
print(f"Noisy (t=0.5): {noisy[0].tolist()} ({VOCAB}=MASK)")
print(f"Masked positions: {mask[0].tolist()}")


# 3) Reverse process: generate 20 tokens in 10 unmasking steps.
generated = diffusion.sample(model, seq_len=20, n_steps=10, device=device)
print(f"\nGenerated (untrained): {generated[0].tolist()}")


# Closing summary printed to the console.
# NOTE(review): the 'β' characters in this string look like mis-encoded
# punctuation/emoji from the original file - confirm and repair at the source.
print("""
What just happened:
1. add_noise() β masked 50% of tokens randomly (forward process)
2. model() β predicted all token positions (bidirectional attention!)
3. sample() β started fully masked, unmasked most-confident tokens step by step

β
Phase 3 complete β diffusion process built
Next: 05_train.py β actually train this on Shakespeare text
""")
|
|