# Diffusion LM — TinyStories

A masked-diffusion language model trained from scratch on the
[TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) dataset.

## Demo

![inference](inference.gif)
## Architecture

| Property | Value |
|---|---|
| Parameters | ~45M |
| Hidden dim | 512 |
| Layers | 10 |
| Heads | 8 |
| FFN dim | 2048 |
| Diffusion steps T | 128 |
| Sequence length | 256 |
| Vocab size | 26,000 |
## How it works

This is a **masked diffusion** language model. Instead of generating
tokens left-to-right like a standard LM, it starts with a fully masked
sequence and progressively unmasks tokens over T diffusion steps.
At each step the model predicts all masked tokens simultaneously, then
re-masks the least confident predictions and repeats — gradually
refining the output until the sequence is fully unmasked.
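The unmasking loop above can be sketched as follows. This is an illustrative sketch, not the repo's exact sampler: the `mask_id` value, the linear unmasking schedule, and the confidence criterion (max softmax probability) are all assumptions.

```python
import torch

def sample(model, seq_len=256, steps=128, mask_id=0, device="cpu"):
    """Iteratively unmask a fully masked sequence (illustrative sketch).

    model: callable mapping token ids (B, L) -> logits (B, L, V).
    """
    x = torch.full((1, seq_len), mask_id, dtype=torch.long, device=device)
    for step in range(steps):
        logits = model(x)                     # (1, seq_len, vocab)
        conf, pred = logits.softmax(-1).max(-1)  # per-token confidence + argmax
        masked = x == mask_id
        # Commit predictions at masked positions...
        x = torch.where(masked, pred, x)
        # ...then re-mask the least confident ones. Linear schedule (assumed):
        # the number of still-masked tokens shrinks to 0 over `steps`.
        n_keep_masked = int(seq_len * (1 - (step + 1) / steps))
        if n_keep_masked > 0:
            # Already-unmasked positions get infinite confidence so they
            # are never re-masked.
            conf = conf.masked_fill(~masked, float("inf"))
            remask = conf.topk(n_keep_masked, largest=False).indices
            x[0, remask[0]] = mask_id
    return x
```

With `steps` equal to the table's T = 128 and `seq_len` = 256, roughly two tokens are committed per step on average.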
## Training

- Dataset: 1M TinyStories examples
- Train steps: 60,000
- Effective batch size: 64 (batch 32 × grad accum 2)
- Optimizer: AdamW
- Learning rate: 2e-4 with cosine schedule and 1,000 warmup steps
- Weight decay: 0.1
- Mixed precision: bf16
- Hardware: NVIDIA RTX 3090 (24GB)
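The optimizer setup can be sketched like this. It matches the hyperparameters listed above; the linear-warmup shape and decay-to-zero floor are assumptions, since the README only says "cosine schedule with 1,000 warmup steps".

```python
import math

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

def make_optimizer(model, lr=2e-4, weight_decay=0.1,
                   warmup_steps=1_000, total_steps=60_000):
    """AdamW + linear warmup then cosine decay (sketch of the setup above)."""
    opt = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)           # linear warmup
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1 + math.cos(math.pi * progress))  # cosine decay to 0

    return opt, LambdaLR(opt, lr_lambda)
```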
## Evaluation

Validation loss (cross-entropy on masked tokens, 20 batches of held-out TinyStories):

| Step | Val loss |
|------|----------|
| 5,000 | 6.0313 |
| 10,000 | 5.9045 |
| 15,000 | 5.6092 |
| 20,000 | 4.4481 |
| 25,000 | 3.8447 |
| 30,000 | 3.6634 |
| 35,000 | 3.5419 |
| 40,000 | 3.3554 |
| 45,000 | 3.2779 |
| 50,000 | 3.1767 |
| 55,000 | 3.1012 |
| 60,000 | 3.1067 |
The sharp loss drop between steps 15,000 and 25,000 reflects the model
learning basic language structure; the curve converges to roughly 3.10
by step 55,000.
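The metric above — cross-entropy averaged over masked positions only — can be sketched as below. This is a plausible reading of the evaluation described, not the repo's verbatim code.

```python
import torch
import torch.nn.functional as F

def masked_ce(logits, targets, mask):
    """Cross-entropy averaged over masked positions only.

    logits:  (B, L, V) model outputs
    targets: (B, L)    ground-truth token ids
    mask:    (B, L)    bool, True where the token was masked in the input
    """
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none"
    ).view_as(targets.float())
    # Average only over masked positions; clamp guards against an empty mask.
    return (loss * mask).sum() / mask.sum().clamp(min=1)
```

Unmasked positions are excluded because the model sees their true tokens in the input, so predicting them carries no signal.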
## Files

| File | Description |
|---|---|
| `model.pt` | Model weights (PyTorch state dict) |
| `config.json` | Architecture hyperparameters |
| `tokenizer/` | Byte-level BPE tokenizer |
| `val_loss_history.json` | Validation loss curve |
| `inference.gif` | Visualisation of progressive unmasking |