#!/usr/bin/env python3
"""
1-Bit Transformer LM on TinyStories
< 1M params | < 200 vocab | BitNet b1.58 ternary weights {-1, 0, +1}
Architecture: RoPE, RMSNorm, SwiGLU, tied embeddings
Tokenizer: SentencePiece unigram (192 vocab)
"""
import os, json, math, time, random, argparse
from pathlib import Path
from dataclasses import dataclass, asdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import sentencepiece as spm
# ================================================================
# Config
# ================================================================
@dataclass
class Config:
    # Model
    vocab_size: int = 192       # < 200
    d_model: int = 128
    n_heads: int = 4            # head_dim = 32
    n_layers: int = 5
    d_ff: int = 336             # SwiGLU intermediate
    max_seq_len: int = 512
    # Training
    batch_size: int = 96
    grad_accum: int = 4         # effective batch = 384
    lr: float = 1.5e-3
    min_lr: float = 1e-5
    warmup_steps: int = 800
    max_steps: int = 100_000
    weight_decay: float = 0.1
    grad_clip: float = 1.0
    # Logging / eval
    eval_interval: int = 1000
    eval_steps: int = 50
    log_interval: int = 100
    gen_interval: int = 5000
    save_interval: int = 5000
    # Misc
    seed: int = 42
    device: str = "cuda:0"
    compile: bool = False
    num_workers: int = 0
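# Rough parameter budget with the defaults above (exact numbers depend on the
# final SentencePiece vocab size): embedding 192*128 = 24,576 (tied with the
# output head, so counted once); per block: q/k/v/o 4*128*128 = 65,536,
# SwiGLU gate+up 2*128*336 = 86,016, down 336*128 = 43,008, two RMSNorms 256
# -> 194,816 per block; 5 blocks + final norm gives ~998.8K parameters, under 1M.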
# ================================================================
# Model
# ================================================================
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.w = nn.Parameter(torch.ones(dim))
        self.eps = eps
    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.w
class BitLinear(nn.Module):
    """Linear layer with ternary {-1, 0, +1} weight quantization (BitNet b1.58).
    Full-precision latent weights are kept for optimizer updates.
    Forward uses quantized weights via straight-through estimator."""
    def __init__(self, in_f, out_f):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_f, in_f))
        nn.init.normal_(self.weight, std=0.02)
    def forward(self, x):
        alpha = self.weight.abs().mean().clamp(min=1e-5)
        wq = torch.clamp(torch.round(self.weight / alpha), -1, 1) * alpha
        w = self.weight + (wq - self.weight).detach()  # STE
        return F.linear(x, w)
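# Worked example of the quantization above (illustrative values, not from a real run):
#   weight = [[0.03, -0.50], [0.24, 0.01]]
#   alpha  = mean(|weight|) = 0.195
#   round(weight / alpha) = [[0, -3], [1, 0]] -> clamped to [[0, -1], [1, 0]]
#   wq = alpha * [[0, -1], [1, 0]] = [[0.0, -0.195], [0.195, 0.0]]
# The detach() trick forwards wq but routes gradients straight to the latent
# full-precision weights (straight-through estimator).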
def _rope_freqs(dim, max_len, base=10000.0):
    f = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    t = torch.arange(max_len, dtype=torch.float32)
    ang = torch.outer(t, f)
    return torch.cos(ang), torch.sin(ang)
def _apply_rope(x, c, s):
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return torch.cat([x1 * c - x2 * s, x2 * c + x1 * s], dim=-1)
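# Rotate-half RoPE (GPT-NeoX convention): each pair (x[..., i], x[..., i + d]) is
# rotated by the angle for its position. c and s have shape (max_len, head_dim/2)
# and are sliced to (1, 1, T, head_dim/2) in BitLM.forward so they broadcast over
# the batch and head dimensions.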
class Block(nn.Module):
    def __init__(self, d, h, ff):
        super().__init__()
        self.n1 = RMSNorm(d)
        self.n2 = RMSNorm(d)
        # Attention
        self.q = BitLinear(d, d)
        self.k = BitLinear(d, d)
        self.v = BitLinear(d, d)
        self.o = BitLinear(d, d)
        # SwiGLU FFN
        self.gate = BitLinear(d, ff)
        self.up = BitLinear(d, ff)
        self.down = BitLinear(ff, d)
        self.nh = h
        self.hd = d // h
    def forward(self, x, cos, sin):
        B, T, C = x.shape
        h = self.n1(x)
        q = self.q(h).view(B, T, self.nh, self.hd).transpose(1, 2)
        k = self.k(h).view(B, T, self.nh, self.hd).transpose(1, 2)
        v = self.v(h).view(B, T, self.nh, self.hd).transpose(1, 2)
        q = _apply_rope(q, cos, sin)
        k = _apply_rope(k, cos, sin)
        a = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        x = x + self.o(a.transpose(1, 2).contiguous().view(B, T, C))
        h = self.n2(x)
        x = x + self.down(F.silu(self.gate(h)) * self.up(h))
        return x
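# Each Block is pre-norm: x = x + Attn(RMSNorm(x)), then x = x + SwiGLU(RMSNorm(x));
# every projection in both sub-layers is a ternary BitLinear.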
class BitLM(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.layers = nn.ModuleList(
            [Block(cfg.d_model, cfg.n_heads, cfg.d_ff) for _ in range(cfg.n_layers)]
        )
        self.norm = RMSNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.head.weight = self.emb.weight  # weight tying
        hd = cfg.d_model // cfg.n_heads
        c, s = _rope_freqs(hd, cfg.max_seq_len)
        self.register_buffer("rc", c)
        self.register_buffer("rs", s)
        nn.init.normal_(self.emb.weight, std=0.02)
    def forward(self, idx, targets=None):
        B, T = idx.shape
        x = self.emb(idx)
        c = self.rc[:T].unsqueeze(0).unsqueeze(0)  # (1,1,T,hd/2)
        s = self.rs[:T].unsqueeze(0).unsqueeze(0)
        for layer in self.layers:
            x = layer(x, c, s)
        logits = self.head(self.norm(x))
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=0
            )
        return logits, loss
    def param_count(self):
        seen = set()
        total = 0
        for p in self.parameters():
            pid = id(p)
            if pid not in seen:
                seen.add(pid)
                total += p.numel()
        return total
    @torch.no_grad()
    def generate(self, idx, max_new=200, temp=0.8, top_k=40, eos_id=2):
        for _ in range(max_new):
            ic = idx[:, -self.cfg.max_seq_len:]
            logits, _ = self(ic)
            logits = logits[:, -1] / temp
            if top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, 1)
            idx = torch.cat([idx, nxt], dim=1)
            if nxt.item() == eos_id:
                break
        return idx
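# Minimal inference sketch (a hypothetical standalone snippet; assumes a trained
# tokenizer.model and loaded checkpoint weights -- the --generate branch in main()
# below is the full version):
#   sp = spm.SentencePieceProcessor(model_file="tokenizer.model")
#   model = BitLM(Config()).eval()  # load state dict from best.pt before sampling
#   idx = torch.tensor([[sp.bos_id()] + sp.encode("Once upon a time")])
#   out = model.generate(idx, max_new=100, eos_id=sp.eos_id())
#   print(sp.decode(out[0].tolist()))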
# ================================================================
# Dataset
# ================================================================
class ChunkedDataset(Dataset):
    """Flat token tensor split into fixed-length chunks."""
    def __init__(self, tokens, seq_len):
        self.tokens = tokens
        self.seq_len = seq_len
        self.n = (len(tokens) - 1) // seq_len
    def __len__(self):
        return self.n
    def __getitem__(self, i):
        s = i * self.seq_len
        c = self.tokens[s : s + self.seq_len + 1]
        return c[:-1], c[1:]
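# Each item is a next-token pair. For example, with seq_len=4 and tokens [t0..t8]:
#   item 0 slices tokens[0:5] -> x = [t0, t1, t2, t3], y = [t1, t2, t3, t4]
#   item 1 slices tokens[4:9] -> x = [t4, t5, t6, t7], y = [t5, t6, t7, t8]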
# ================================================================
# Tokenizer helpers
# ================================================================
def train_tokenizer(texts, exp_dir, vocab_size=192, n_train=100_000):
    """Train SentencePiece unigram tokenizer with <200 vocab."""
    data_file = exp_dir / "sp_train.txt"
    prefix = str(exp_dir / "tokenizer")
    print(f"Writing {min(n_train, len(texts))} texts for tokenizer training...")
    with open(data_file, "w", encoding="utf-8") as f:
        for t in texts[:n_train]:
            f.write(t.strip().replace("\n", " ") + "\n")
    print("Training SentencePiece unigram tokenizer...")
    spm.SentencePieceTrainer.train(
        input=str(data_file),
        model_prefix=prefix,
        vocab_size=vocab_size,
        model_type="unigram",
        character_coverage=1.0,
        pad_id=0, bos_id=1, eos_id=2, unk_id=3,
        byte_fallback=False,
        normalization_rule_name="identity",
        max_sentence_length=8192,
        num_threads=os.cpu_count() or 4,
        train_extremely_large_corpus=False,
    )
    data_file.unlink(missing_ok=True)
    sp = spm.SentencePieceProcessor(model_file=prefix + ".model")
    print(f"Tokenizer ready: {sp.get_piece_size()} tokens")
    return sp
def encode_texts(sp, texts, desc="data"):
    """Encode all texts into a single flat token tensor (BOS story EOS ...)."""
    bos, eos = sp.bos_id(), sp.eos_id()
    all_ids = []
    t0 = time.time()
    for i, t in enumerate(texts):
        all_ids.append(bos)
        all_ids.extend(sp.encode(t))
        all_ids.append(eos)
        if (i + 1) % 500_000 == 0:
            print(f" {desc}: {i+1}/{len(texts)} ({len(all_ids)/1e6:.1f}M tok)")
    elapsed = time.time() - t0
    print(f" {desc}: {len(all_ids)/1e6:.2f}M tokens, {elapsed:.1f}s")
    return torch.tensor(all_ids, dtype=torch.long)
# ================================================================
# LR schedule
# ================================================================
def get_lr(step, cfg):
    if step < cfg.warmup_steps:
        return cfg.lr * step / cfg.warmup_steps
    if step >= cfg.max_steps:
        return cfg.min_lr
    r = (step - cfg.warmup_steps) / (cfg.max_steps - cfg.warmup_steps)
    return cfg.min_lr + 0.5 * (cfg.lr - cfg.min_lr) * (1 + math.cos(math.pi * r))
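# Worked values with the default Config (lr=1.5e-3, min_lr=1e-5, warmup=800, max_steps=100_000):
#   step    400 -> 7.5e-4    (linear warmup, halfway to peak)
#   step    800 -> 1.5e-3    (peak; cosine term starts at 1)
#   step 50_400 -> ~7.55e-4  (cosine midpoint)
#   step 100_000 and beyond -> 1e-5 (floor)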
# ================================================================
# Eval
# ================================================================
@torch.no_grad()
def evaluate(model, loader, device, steps=50):
    model.eval()
    total, n = 0.0, 0
    for x, y in loader:
        if n >= steps:
            break
        x, y = x.to(device), y.to(device)
        with torch.amp.autocast("cuda", dtype=torch.float16):
            _, loss = model(x, y)
        total += loss.item()
        n += 1
    model.train()
    return total / max(n, 1)
# ================================================================
# Main
# ================================================================
def main():
    parser = argparse.ArgumentParser(description="Train 1-bit Transformer LM")
    parser.add_argument("--exp-dir", default="/root/experiments/1m-model")
    parser.add_argument("--max-steps", type=int, default=100_000)
    parser.add_argument("--batch-size", type=int, default=96)
    parser.add_argument("--lr", type=float, default=1.5e-3)
    parser.add_argument("--device", default="cuda:0")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--generate", action="store_true")
    parser.add_argument("--prompt", default="Once upon a time")
    args = parser.parse_args()
    cfg = Config()
    cfg.batch_size = args.batch_size
    cfg.max_steps = args.max_steps
    cfg.lr = args.lr
    cfg.device = args.device
    cfg.compile = args.compile
    exp = Path(args.exp_dir)
    exp.mkdir(parents=True, exist_ok=True)
    torch.manual_seed(cfg.seed)
    random.seed(cfg.seed)
    torch.backends.cudnn.benchmark = True
    # ---- Tokenizer ----
    tok_model = exp / "tokenizer.model"
    if tok_model.exists():
        print("Loading tokenizer...")
        sp = spm.SentencePieceProcessor(model_file=str(tok_model))
    else:
        from datasets import load_dataset
        print("Loading TinyStories for tokenizer training...")
        ds = load_dataset("roneneldan/TinyStories", split="train")
        subset = [ds[i]["text"] for i in range(min(100_000, len(ds)))]
        sp = train_tokenizer(subset, exp, vocab_size=cfg.vocab_size)
        del subset, ds
    cfg.vocab_size = sp.get_piece_size()
    print(f"Vocab size: {cfg.vocab_size}")
    assert cfg.vocab_size < 200, f"Tokenizer too large: {cfg.vocab_size}"
    # ---- Generate mode ----
    if args.generate:
        model = BitLM(cfg).to(cfg.device)
        ckpt = torch.load(exp / "best.pt", map_location=cfg.device, weights_only=True)
        state = ckpt["model"]
        if any(k.startswith("_orig_mod.") for k in state):
            state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
        model.load_state_dict(state)
        model.eval()
        print(f"Loaded best model (step {ckpt['step']}, val_loss={ckpt['val_loss']:.4f})")
        ids = [sp.bos_id()] + sp.encode(args.prompt)
        idx = torch.tensor([ids], device=cfg.device)
        out = model.generate(idx, max_new=500, temp=0.8, top_k=40, eos_id=sp.eos_id())
        text = sp.decode(out[0].tolist())
        print(f"\n--- Generated ---\n{text}\n")
        return
    # ---- Data ----
    train_cache = exp / "train_tokens.pt"
    val_cache = exp / "val_tokens.pt"
    if train_cache.exists() and val_cache.exists():
        print("Loading cached tokens...")
        train_tok = torch.load(train_cache, weights_only=True)
        val_tok = torch.load(val_cache, weights_only=True)
    else:
        from datasets import load_dataset
        print("Loading TinyStories...")
        train_ds = load_dataset("roneneldan/TinyStories", split="train")
        val_ds = load_dataset("roneneldan/TinyStories", split="validation")
        train_texts = [ex["text"] for ex in train_ds]
        val_texts = [ex["text"] for ex in val_ds]
        print(f"Train: {len(train_texts):,} stories, Val: {len(val_texts):,} stories")
        train_tok = encode_texts(sp, train_texts, "train")
        val_tok = encode_texts(sp, val_texts, "val")
        print("Saving cached tokens...")
        torch.save(train_tok, train_cache)
        torch.save(val_tok, val_cache)
        del train_texts, val_texts
    train_data = ChunkedDataset(train_tok, cfg.max_seq_len)
    val_data = ChunkedDataset(val_tok, cfg.max_seq_len)
    print(f"Train: {len(train_data):,} chunks, Val: {len(val_data):,} chunks")
    train_loader = DataLoader(
        train_data, batch_size=cfg.batch_size, shuffle=True,
        num_workers=cfg.num_workers, pin_memory=True, drop_last=True,
    )
    val_loader = DataLoader(
        val_data, batch_size=cfg.batch_size, shuffle=False,
        num_workers=cfg.num_workers, pin_memory=True, drop_last=True,
    )
    # ---- Model ----
    model = BitLM(cfg).to(cfg.device)
    n_params = model.param_count()
    print(f"\nModel: {n_params:,} parameters ({n_params/1e6:.3f}M)")
    print(f" d_model={cfg.d_model}, n_heads={cfg.n_heads}, n_layers={cfg.n_layers}, "
          f"d_ff={cfg.d_ff}, max_seq_len={cfg.max_seq_len}")
    assert n_params < 1_000_000, f"Model too large: {n_params:,} params >= 1M"
    if cfg.compile:
        print("Compiling model with torch.compile...")
        model = torch.compile(model)
    # ---- Optimizer ----
    decay_params, nodecay_params = [], []
    for name, p in model.named_parameters():
        if p.requires_grad:
            if "norm" in name or "emb" in name:
                nodecay_params.append(p)
            else:
                decay_params.append(p)
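    # RMSNorm gains and the (tied) embedding matrix get no weight decay;
    # all BitLinear latent weights are decayed.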
    opt = torch.optim.AdamW(
        [
            {"params": decay_params, "weight_decay": cfg.weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ],
        lr=cfg.lr, betas=(0.9, 0.95),
    )
    scaler = torch.amp.GradScaler("cuda")
    # ---- Resume ----
    step = 0
    best_val = float("inf")
    ckpt_path = exp / "latest.pt"
    if ckpt_path.exists():
        print(f"Resuming from {ckpt_path}...")
        ck = torch.load(ckpt_path, map_location=cfg.device)
        # Handle compiled model keys
        state = ck["model"]
        if any(k.startswith("_orig_mod.") for k in state):
            state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
        model.load_state_dict(state)
        opt.load_state_dict(ck["optimizer"])
        scaler.load_state_dict(ck["scaler"])
        step = ck["step"]
        best_val = ck.get("best_val", float("inf"))
        print(f"Resumed at step {step}, best_val={best_val:.4f}")
    # ---- Training loop ----
    print(f"\nTraining for {cfg.max_steps:,} steps "
          f"(batch={cfg.batch_size}, accum={cfg.grad_accum}, "
          f"eff_batch={cfg.batch_size * cfg.grad_accum})\n")
    model.train()
    train_iter = iter(train_loader)
    running_loss = 0.0
    t0 = time.time()
    tokens_since_log = 0
    while step < cfg.max_steps:
        # Get batch (auto-restart on epoch boundary)
        try:
            x, y = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            x, y = next(train_iter)
        x, y = x.to(cfg.device, non_blocking=True), y.to(cfg.device, non_blocking=True)
        # LR schedule
        lr = get_lr(step, cfg)
        for pg in opt.param_groups:
            pg["lr"] = lr
        # Forward + backward (mixed precision FP16)
        with torch.amp.autocast("cuda", dtype=torch.float16):
            _, loss = model(x, y)
        scaled_loss = loss / cfg.grad_accum
        scaler.scale(scaled_loss).backward()
        running_loss += loss.item()
        tokens_since_log += x.numel()
        # Optimizer step every grad_accum mini-batches
        if (step + 1) % cfg.grad_accum == 0:
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)
        step += 1
        # ---- Logging ----
        if step % cfg.log_interval == 0:
            avg = running_loss / cfg.log_interval
            elapsed = time.time() - t0
            tps = tokens_since_log / elapsed
            ppl = math.exp(min(avg, 20))  # cap for display
            print(
                f"step {step:>6d}/{cfg.max_steps} | "
                f"loss {avg:.4f} | ppl {ppl:>8.2f} | "
                f"lr {lr:.2e} | {tps/1e3:.0f}K tok/s"
            )
            running_loss = 0.0
            tokens_since_log = 0
            t0 = time.time()
        # ---- Evaluation ----
        if step % cfg.eval_interval == 0:
            vl = evaluate(model, val_loader, cfg.device, cfg.eval_steps)
            vppl = math.exp(min(vl, 20))
            improved = vl < best_val
            tag = " ** NEW BEST **" if improved else ""
            print(f" >>> val_loss={vl:.4f} val_ppl={vppl:.2f}{tag}")
            if improved:
                best_val = vl
                save_dict = {"model": model.state_dict(), "step": step,
                             "val_loss": vl, "config": asdict(cfg)}
                torch.save(save_dict, exp / "best.pt")
            model.train()
        # ---- Generate samples ----
        if step % cfg.gen_interval == 0:
            model.eval()
            for prompt in ["Once upon a time", "The little dog", "She was very happy"]:
                ids = [sp.bos_id()] + sp.encode(prompt)
                idx = torch.tensor([ids], device=cfg.device)
                out = model.generate(idx, max_new=150, temp=0.8, top_k=40,
                                     eos_id=sp.eos_id())
                text = sp.decode(out[0].tolist())
                print(f" GEN [{prompt[:20]}] → {text[:250]}")
            model.train()
        # ---- Checkpoint ----
        if step % cfg.save_interval == 0:
            torch.save(
                {
                    "model": model.state_dict(),
                    "optimizer": opt.state_dict(),
                    "scaler": scaler.state_dict(),
                    "step": step,
                    "best_val": best_val,
                    "config": asdict(cfg),
                },
                ckpt_path,
            )
    # ---- Final save ----
    torch.save(
        {"model": model.state_dict(), "step": step, "config": asdict(cfg)},
        exp / "final.pt",
    )
    print(f"\nTraining complete! Best val loss: {best_val:.4f} (ppl {math.exp(best_val):.2f})")
if __name__ == "__main__":
    main()