metadata
library_name: pytorch
tags:
  - diffusion
  - language-model
  - discrete-diffusion
  - absorbing-state
  - text-generation
  - tinystories
  - from-scratch
datasets:
  - roneneldan/TinyStories
pipeline_tag: text-generation

diffusionlm-from-scratch — masked diffusion LM (DiT, 142M)

A masked (absorbing-state) diffusion language model, built and trained from scratch on TinyStories. Instead of generating left-to-right one token at a time, it starts from a sequence of pure [MASK] tokens and denoises the whole sequence in parallel — committing the tokens it is most confident about first, in whatever order the meaning falls into place.
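The confidence-first unmasking described above can be sketched in a few lines of plain Python. This is only a toy illustration, not the repo's sampler: `toy_denoiser` is a hypothetical stand-in that returns random tokens and random confidences, and the even split of positions across steps is an assumed schedule.

```python
import math
import random

MASK = "[MASK]"

def toy_denoiser(tokens, vocab, rng):
    """Hypothetical stand-in for the trained model.

    Returns one (predicted_token, confidence) pair per position. A real
    denoiser would look at `tokens` and output a distribution over the vocabulary.
    """
    return [(rng.choice(vocab), rng.random()) for _ in tokens]

def unmask_by_confidence(seq_len, steps, vocab, seed=0):
    """Start from all-[MASK] and commit the most confident predictions first."""
    rng = random.Random(seed)
    tokens = [MASK] * seq_len
    for step in range(steps):
        masked = [i for i, tok in enumerate(tokens) if tok == MASK]
        if not masked:
            break
        # Predict every position in parallel.
        preds = toy_denoiser(tokens, vocab, rng)
        # Assumed schedule: spread the remaining masked positions evenly
        # over the remaining steps, so the final step clears every mask.
        k = math.ceil(len(masked) / (steps - step))
        # Commit only the k masked positions with the highest confidence.
        best = sorted(masked, key=lambda i: preds[i][1], reverse=True)[:k]
        for i in best:
            tokens[i] = preds[i][0]
    return tokens

story = unmask_by_confidence(seq_len=8, steps=4, vocab=["once", "upon", "a", "time"])
print(story)
```

Positions are filled in order of confidence rather than left to right, which is the main difference from autoregressive decoding.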

Model

Architecture: DiT (transformer denoiser), bidirectional attention, adaLN-Zero
Parameters: ~142M
Hidden size / depth / heads: 768 / 12 / 12
MLP ratio: 4.0
Vocab: 8,192 (byte-level BPE, trained on TinyStories)
Max sequence length: 256
Diffusion: absorbing-state (masked) discrete diffusion
Training data: TinyStories
Eval cross-entropy: 2.18

Key finding: uniform loss weighting (w(t) = 1), not the textbook ELBO weight 1/σ(t), is what turned word-salad into coherent stories.
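To make the two weightings concrete, here is a minimal sketch of a per-sequence masked-diffusion loss. The card does not define σ(t), so the code assumes σ(t) = t, where t is the masking probability in (0, 1]; the function names and the normalization by sequence length are likewise assumptions, not the repo's training code.

```python
def sigma(t):
    """Assumed noise level: the masking probability itself (linear schedule)."""
    return t

def masked_diffusion_loss(token_nll, is_masked, t, weighting="uniform"):
    """Weighted cross-entropy over masked positions for one sequence.

    token_nll: negative log-likelihood of the true token at each position
    is_masked: True where the input token was replaced by [MASK]
    t:         masking probability used to corrupt this sequence, in (0, 1]
    """
    if weighting == "uniform":
        w = 1.0                 # w(t) = 1
    elif weighting == "elbo":
        w = 1.0 / sigma(t)      # w(t) = 1/sigma(t), large when few tokens are masked
    else:
        raise ValueError(f"unknown weighting: {weighting}")
    # Only masked positions contribute; normalize by the full sequence length.
    masked_sum = sum(nll for nll, m in zip(token_nll, is_masked) if m)
    return w * masked_sum / len(token_nll)

nll = [2.0, 1.0, 3.0, 4.0]
mask = [True, False, True, False]
print(masked_diffusion_loss(nll, mask, t=0.5, weighting="uniform"))
print(masked_diffusion_loss(nll, mask, t=0.5, weighting="elbo"))
```

Under this assumption the 1/σ(t) weight grows without bound as t approaches 0, so lightly masked sequences dominate the gradient, whereas w(t) = 1 treats every noise level equally.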

Files

  • final.pt — checkpoint with two state dicts, model (EMA, preferred) and raw, plus the config used to build the model.
  • tokenizer.json, tokenizer_config.json — the byte-level BPE tokenizer (PreTrainedTokenizerFast; special tokens [PAD] [UNK] [MASK] <|endoftext|>).
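If you load final.pt by hand, choosing between the two state dicts might look like the sketch below. The keys "config", "model", and "raw" come from this card; `select_state_dict` is a hypothetical helper, and the checkpoint is assumed to be a plain dictionary as produced by `torch.load`.

```python
def select_state_dict(checkpoint, prefer_ema=True):
    """Return (state_dict, config) from a checkpoint dictionary.

    "model" holds the EMA weights (preferred); "raw" holds the non-averaged weights.
    """
    key = "model" if prefer_ema else "raw"
    if key not in checkpoint:
        raise KeyError(f"checkpoint has no '{key}' entry")
    return checkpoint[key], checkpoint.get("config")

# With PyTorch installed, the dictionary would typically come from:
#   checkpoint = torch.load("final.pt", map_location="cpu")
# Here a tiny fake checkpoint stands in so the sketch runs on its own.
fake_checkpoint = {
    "config": {"hidden_size": 768},
    "model": {"weight": "ema"},
    "raw": {"weight": "raw"},
}
state_dict, config = select_state_dict(fake_checkpoint)
print(state_dict, config)
```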

Usage

Install the model code from the GitHub repo, then generate stories in a few lines. DiffusionLM bundles the model, tokenizer, and absorbing-state scheduler:

from diffusionlm_from_scratch import DiffusionLM

lm = DiffusionLM.from_pretrained("tchauffi/diffusionlm-from-scratch")
for story in lm.generate(n=4, seq_len=80, temperature=0.9):
    print(story)

generate exposes the sampler knobs (order, steps, corrector_frac, confidence_threshold, …). For lower-level access, load just the model:

from diffusionlm_from_scratch.model import DiT

model = DiT.from_pretrained("tchauffi/diffusionlm-from-scratch")  # downloads final.pt
# final.pt itself is a dict with three entries: "config", "model" (EMA weights), and "raw".

See scripts/capture_trajectories.py in the repo for the full parallel-denoising sampling loop.