---
library_name: pytorch
tags:
- diffusion
- language-model
- discrete-diffusion
- absorbing-state
- text-generation
- tinystories
- from-scratch
datasets:
- roneneldan/TinyStories
pipeline_tag: text-generation
---
# diffusionlm-from-scratch: masked diffusion LM (DiT, 142M)
A masked (absorbing-state) **diffusion language model**, built and trained from
scratch on TinyStories. Instead of generating left-to-right one token at a time,
it starts from a sequence of pure `[MASK]` tokens and **denoises the whole
sequence in parallel**, committing the tokens it is most confident about first,
in whatever order the meaning falls into place.
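A minimal sketch of that loop in plain Python, not the repo's sampler: `denoiser` is a hypothetical callable standing in for the DiT forward pass (it returns one probability distribution per position), `MASK` is a placeholder id, and tokens are picked greedily by argmax instead of being sampled with a temperature. The number of tokens committed per step is spread evenly over `steps`, so every position is filled by the last step.

```python
import math
from typing import Callable, List, Sequence

MASK = -1  # hypothetical placeholder for the real [MASK] token id


def unmask_by_confidence(
    denoiser: Callable[[List[int]], Sequence[Sequence[float]]],
    seq_len: int,
    steps: int,
) -> List[int]:
    """Start from all-[MASK] and commit the most confident predictions first.

    `denoiser` (hypothetical) maps the current sequence to one probability
    distribution over the vocabulary per position.
    """
    x = [MASK] * seq_len
    for step in range(steps):
        masked = [i for i, tok in enumerate(x) if tok == MASK]
        if not masked:
            break
        probs = denoiser(x)  # predict every position in parallel
        # Greedy (argmax) token and its probability at each still-masked position.
        best = {}
        for i in masked:
            tok = max(range(len(probs[i])), key=probs[i].__getitem__)
            best[i] = (probs[i][tok], tok)
        # Spread the remaining masks evenly over the remaining steps;
        # on the final step this commits everything that is left.
        k = math.ceil(len(masked) / (steps - step))
        # Commit the k positions the model is most confident about.
        for i in sorted(masked, key=lambda j: best[j][0], reverse=True)[:k]:
            x[i] = best[i][1]
    return x
```

Passing `steps=1` degenerates to one-shot parallel decoding; larger `steps` commits fewer tokens per forward pass, so later predictions can condition on earlier commitments.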
- **Code, training & sampling:** https://github.com/tchauffi/diffusionlm-from-scratch
- **Course / write-up:** [`RESEARCH.md`](https://github.com/tchauffi/diffusionlm-from-scratch/blob/main/RESEARCH.md): a from-scratch course on discrete/text diffusion (D3PM → absorbing-state → sampling).
- **Demo site:** animated real generations live in [`docs/`](https://github.com/tchauffi/diffusionlm-from-scratch/tree/main/docs).
## Model
| | |
|---|---|
| Architecture | DiT (transformer denoiser), bidirectional attention, adaLN-Zero |
| Parameters | ~142M |
| Hidden size / depth / heads | 768 / 12 / 12 |
| MLP ratio | 4.0 |
| Vocab | 8,192 (byte-level BPE, trained on TinyStories) |
| Max sequence length | 256 |
| Diffusion | absorbing-state (masked) discrete diffusion |
| Training data | [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) |
| Eval cross-entropy | **2.18** |
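As a sanity check on the ~142M figure, the arithmetic below counts parameters for an assumed standard DiT layout: fused QKV attention, a 4x MLP, a per-block adaLN-Zero projection to six modulation vectors, learned position embeddings, a 256-dim frequency timestep embedding, and an untied output head. These layout details are assumptions, not read from the checkpoint; the head count does not change the total.

```python
def dit_param_estimate(
    d: int = 768,
    depth: int = 12,
    mlp_ratio: float = 4.0,
    vocab: int = 8192,
    max_len: int = 256,
    freq_dim: int = 256,  # assumed size of the sinusoidal timestep features
) -> int:
    """Rough parameter count for a DiT-style text denoiser (assumed layout)."""

    def linear(n_in: int, n_out: int) -> int:
        return n_in * n_out + n_out  # weight + bias

    hidden = int(d * mlp_ratio)
    block = (
        linear(d, 3 * d) + linear(d, d)          # self-attention: fused QKV + output proj
        + linear(d, hidden) + linear(hidden, d)  # MLP
        + linear(d, 6 * d)                       # adaLN-Zero: shift/scale/gate for attn and MLP
    )
    embeddings = vocab * d + max_len * d         # token + learned position embeddings
    t_embed = linear(freq_dim, d) + linear(d, d) # timestep MLP
    final = linear(d, 2 * d) + linear(d, vocab)  # final adaLN (shift/scale) + untied output head
    return depth * block + embeddings + t_embed + final


print(f"{dit_param_estimate() / 1e6:.1f}M")  # prints 142.3M
```

Under these assumptions the adaLN-Zero projections alone account for about 42.5M parameters (12 x 3.54M), roughly 30% of the total.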
**Key finding:** uniform loss weighting (`w(t) = 1`), *not* the textbook ELBO
weight `1/σ(t)`, is what turned word-salad into coherent stories.
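The sketch below shows where `w(t)` enters a masked-diffusion training loss. It assumes a linear schedule in which `t` is also the masking probability, so that `σ(t) = t` and the ELBO weight is `1/t`; the repo's actual schedule and normalization may differ, and `corrupt`, `weighted_masked_loss`, and `MASK` are illustrative names. Only masked positions contribute to the cross-entropy.

```python
import random
from typing import List, Sequence, Tuple

MASK = -1  # hypothetical placeholder for the [MASK] token id


def corrupt(x: Sequence[int], t: float, rng: random.Random) -> Tuple[List[int], List[bool]]:
    """Absorbing-state forward process: each token independently becomes [MASK] with prob t."""
    is_masked = [rng.random() < t for _ in x]
    return [MASK if m else tok for tok, m in zip(x, is_masked)], is_masked


def weighted_masked_loss(
    target_logprobs: Sequence[float],
    is_masked: Sequence[bool],
    t: float,
    weighting: str = "uniform",
) -> float:
    """Cross-entropy on masked positions only, scaled by w(t), averaged over the sequence.

    target_logprobs[i] is log p_theta(x_i | x_t) for the clean token at position i.
    """
    nll = -sum(lp for lp, m in zip(target_logprobs, is_masked) if m)
    if weighting == "uniform":
        w = 1.0        # w(t) = 1
    elif weighting == "elbo":
        w = 1.0 / t    # w(t) = 1/sigma(t), assuming sigma(t) = t
    else:
        raise ValueError(f"unknown weighting: {weighting}")
    return w * nll / len(target_logprobs)
```

With `1/t`, a sample drawn at `t = 0.05` is weighted 20x relative to uniform weighting, so sequences with only a handful of masked tokens can dominate the gradient; that variance is one plausible reading of the finding above, not something this card states.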
## Files
- `final.pt`: checkpoint with two state dicts, `model` (EMA, preferred) and
  `raw`, plus the `config` used to build the model.
- `tokenizer.json`, `tokenizer_config.json`: the byte-level BPE tokenizer
  (`PreTrainedTokenizerFast`; special tokens `[PAD]` `[UNK]` `[MASK]`
  `<|endoftext|>`).
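A small helper, assuming `final.pt` loads into a plain dict with the three keys listed above: it returns the config together with either the EMA (`model`) or the `raw` state dict. The function name `split_checkpoint` is hypothetical.

```python
from typing import Any, Dict, Mapping, Tuple


def split_checkpoint(ck: Mapping[str, Any], prefer_ema: bool = True) -> Tuple[Any, Dict[str, Any]]:
    """Return (config, state_dict) from a final.pt-style checkpoint dict.

    Uses the EMA weights under "model" by default; pass prefer_ema=False
    to get the non-EMA weights under "raw".
    """
    missing = {"config", "model", "raw"} - set(ck)
    if missing:
        raise KeyError(f"checkpoint is missing keys: {sorted(missing)}")
    key = "model" if prefer_ema else "raw"
    return ck["config"], dict(ck[key])
```

In practice the dict would come from `torch.load("final.pt", map_location="cpu")`. If `config` is stored as a custom object rather than a plain dict, recent PyTorch releases (which default to `weights_only=True`) may need `weights_only=False`, which is only safe for files you trust.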
## Usage
Install the model code from the
[GitHub repo](https://github.com/tchauffi/diffusionlm-from-scratch), then
generate stories in two lines. `DiffusionLM` bundles the model, tokenizer, and
absorbing-state scheduler:
```python
from diffusionlm_from_scratch import DiffusionLM
lm = DiffusionLM.from_pretrained("tchauffi/diffusionlm-from-scratch")
for story in lm.generate(n=4, seq_len=80, temperature=0.9):
    print(story)
```
`generate` exposes the sampler knobs (`order`, `steps`, `corrector_frac`,
`confidence_threshold`, …). For lower-level access, load just the model:
```python
from diffusionlm_from_scratch.model import DiT
model = DiT.from_pretrained("tchauffi/diffusionlm-from-scratch") # downloads final.pt
# final.pt, loaded as a dict ck, holds ck["config"], ck["model"] (EMA weights), and ck["raw"].
```
See [`scripts/capture_trajectories.py`](https://github.com/tchauffi/diffusionlm-from-scratch/blob/main/scripts/capture_trajectories.py)
in the repo for the full parallel-denoising sampling loop.