"""
1-Bit Transformer LM on TinyStories
< 1M params | < 200 vocab | BitNet b1.58 ternary weights {-1, 0, +1}

Architecture: RoPE, RMSNorm, SwiGLU, tied embeddings
Tokenizer: SentencePiece unigram (192 vocab)
"""

import os, json, math, time, random, argparse
from pathlib import Path
from dataclasses import dataclass, asdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import sentencepiece as spm


@dataclass
class Config:
    # model
    vocab_size: int = 192
    d_model: int = 128
    n_heads: int = 4
    n_layers: int = 5
    d_ff: int = 336
    max_seq_len: int = 512

    # optimization
    batch_size: int = 96
    grad_accum: int = 4
    lr: float = 1.5e-3
    min_lr: float = 1e-5
    warmup_steps: int = 800
    max_steps: int = 100_000
    weight_decay: float = 0.1
    grad_clip: float = 1.0

    # evaluation / logging
    eval_interval: int = 1000
    eval_steps: int = 50
    log_interval: int = 100
    gen_interval: int = 5000
    save_interval: int = 5000

    # runtime
    seed: int = 42
    device: str = "cuda:0"
    compile: bool = False
    num_workers: int = 0


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.w = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.w


class BitLinear(nn.Module):
    """Linear layer with ternary {-1, 0, +1} weight quantization (BitNet b1.58).
    Full-precision latent weights are kept for optimizer updates.
    Forward uses quantized weights via straight-through estimator."""

    def __init__(self, in_f, out_f):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_f, in_f))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        alpha = self.weight.abs().mean().clamp(min=1e-5)
        wq = torch.clamp(torch.round(self.weight / alpha), -1, 1) * alpha
        w = self.weight + (wq - self.weight).detach()
        return F.linear(x, w)
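

# Worked example of the quantizer above (illustrative numbers): for a latent row
# w = [0.03, -0.12, 0.26], alpha = mean(|w|) ~= 0.137, round(w / alpha) = [0, -1, 2],
# which clamps to the ternary codes [0, -1, 1], so wq = alpha * [0, -1, 1]. The
# "w = weight + (wq - weight).detach()" line is a straight-through estimator:
# gradients flow to the full-precision latent weights as if no rounding happened.
# The helper below is a self-contained sanity check of that step; it is never
# called during training.
def _ternary_quant_demo():
    w = torch.tensor([0.03, -0.12, 0.26])
    alpha = w.abs().mean().clamp(min=1e-5)
    codes = torch.clamp(torch.round(w / alpha), -1, 1)
    print(f"alpha={alpha.item():.4f}, ternary codes={codes.tolist()}")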


def _rope_freqs(dim, max_len, base=10000.0):
    f = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    t = torch.arange(max_len, dtype=torch.float32)
    ang = torch.outer(t, f)
    return torch.cos(ang), torch.sin(ang)


def _apply_rope(x, c, s):
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return torch.cat([x1 * c - x2 * s, x2 * c + x1 * s], dim=-1)
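

# RoPE note: _rope_freqs precomputes cos/sin of t * base**(-2i/dim), shape
# (max_len, dim // 2); _apply_rope is the "rotate-half" form, rotating each pair
# (x1_i, x2_i) taken from the first and second halves of the head dimension.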


class Block(nn.Module):
    def __init__(self, d, h, ff):
        super().__init__()
        self.n1 = RMSNorm(d)
        self.n2 = RMSNorm(d)
        # attention projections
        self.q = BitLinear(d, d)
        self.k = BitLinear(d, d)
        self.v = BitLinear(d, d)
        self.o = BitLinear(d, d)
        # SwiGLU feed-forward
        self.gate = BitLinear(d, ff)
        self.up = BitLinear(d, ff)
        self.down = BitLinear(ff, d)
        self.nh = h
        self.hd = d // h

    def forward(self, x, cos, sin):
        B, T, C = x.shape
        h = self.n1(x)
        q = self.q(h).view(B, T, self.nh, self.hd).transpose(1, 2)
        k = self.k(h).view(B, T, self.nh, self.hd).transpose(1, 2)
        v = self.v(h).view(B, T, self.nh, self.hd).transpose(1, 2)
        q = _apply_rope(q, cos, sin)
        k = _apply_rope(k, cos, sin)
        a = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        x = x + self.o(a.transpose(1, 2).contiguous().view(B, T, C))
        h = self.n2(x)
        x = x + self.down(F.silu(self.gate(h)) * self.up(h))
        return x


class BitLM(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.layers = nn.ModuleList(
            [Block(cfg.d_model, cfg.n_heads, cfg.d_ff) for _ in range(cfg.n_layers)]
        )
        self.norm = RMSNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.head.weight = self.emb.weight

        hd = cfg.d_model // cfg.n_heads
        c, s = _rope_freqs(hd, cfg.max_seq_len)
        self.register_buffer("rc", c)
        self.register_buffer("rs", s)
        nn.init.normal_(self.emb.weight, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        x = self.emb(idx)
        c = self.rc[:T].unsqueeze(0).unsqueeze(0)
        s = self.rs[:T].unsqueeze(0).unsqueeze(0)
        for layer in self.layers:
            x = layer(x, c, s)
        logits = self.head(self.norm(x))
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=0
            )
        return logits, loss

    def param_count(self):
        seen = set()
        total = 0
        for p in self.parameters():
            pid = id(p)
            if pid not in seen:
                seen.add(pid)
                total += p.numel()
        return total

    @torch.no_grad()
    def generate(self, idx, max_new=200, temp=0.8, top_k=40, eos_id=2):
        for _ in range(max_new):
            ic = idx[:, -self.cfg.max_seq_len:]
            logits, _ = self(ic)
            logits = logits[:, -1] / temp
            if top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, 1)
            idx = torch.cat([idx, nxt], dim=1)
            if nxt.item() == eos_id:
                break
        return idx


class ChunkedDataset(Dataset):
    """Flat token tensor split into fixed-length chunks."""
    def __init__(self, tokens, seq_len):
        self.tokens = tokens
        self.seq_len = seq_len
        self.n = (len(tokens) - 1) // seq_len

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        s = i * self.seq_len
        c = self.tokens[s : s + self.seq_len + 1]
        return c[:-1], c[1:]
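

# Example (illustrative): with seq_len = 4 and a stream [t0 .. t9], len() is
# (10 - 1) // 4 = 2; item 0 yields x = [t0, t1, t2, t3], y = [t1, t2, t3, t4],
# item 1 starts at t4, and the leftover tail (t9) is dropped.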


def train_tokenizer(texts, exp_dir, vocab_size=192, n_train=100_000):
    """Train SentencePiece unigram tokenizer with <200 vocab."""
    data_file = exp_dir / "sp_train.txt"
    prefix = str(exp_dir / "tokenizer")

    print(f"Writing {min(n_train, len(texts))} texts for tokenizer training...")
    with open(data_file, "w", encoding="utf-8") as f:
        for t in texts[:n_train]:
            f.write(t.strip().replace("\n", " ") + "\n")

    print("Training SentencePiece unigram tokenizer...")
    spm.SentencePieceTrainer.train(
        input=str(data_file),
        model_prefix=prefix,
        vocab_size=vocab_size,
        model_type="unigram",
        character_coverage=1.0,
        pad_id=0, bos_id=1, eos_id=2, unk_id=3,
        byte_fallback=False,
        normalization_rule_name="identity",
        max_sentence_length=8192,
        num_threads=os.cpu_count() or 4,
        train_extremely_large_corpus=False,
    )
    data_file.unlink(missing_ok=True)

    sp = spm.SentencePieceProcessor(model_file=prefix + ".model")
    print(f"Tokenizer ready: {sp.get_piece_size()} tokens")
    return sp
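

# The special-token ids fixed above (pad=0, bos=1, eos=2, unk=3) are what the rest
# of the script assumes: encode_texts() frames every story as BOS ... EOS,
# BitLM.generate() defaults to eos_id=2, and the loss's ignore_index=0 matches
# pad_id (no padding ever enters the flat token stream, so it is effectively a no-op).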


def encode_texts(sp, texts, desc="data"):
    """Encode all texts into a single flat token tensor (BOS story EOS ...)."""
    bos, eos = sp.bos_id(), sp.eos_id()
    all_ids = []
    t0 = time.time()
    for i, t in enumerate(texts):
        all_ids.append(bos)
        all_ids.extend(sp.encode(t))
        all_ids.append(eos)
        if (i + 1) % 500_000 == 0:
            print(f" {desc}: {i+1}/{len(texts)} ({len(all_ids)/1e6:.1f}M tok)")
    elapsed = time.time() - t0
    print(f" {desc}: {len(all_ids)/1e6:.2f}M tokens, {elapsed:.1f}s")
    return torch.tensor(all_ids, dtype=torch.long)


def get_lr(step, cfg):
    if step < cfg.warmup_steps:
        return cfg.lr * step / cfg.warmup_steps
    if step >= cfg.max_steps:
        return cfg.min_lr
    r = (step - cfg.warmup_steps) / (cfg.max_steps - cfg.warmup_steps)
    return cfg.min_lr + 0.5 * (cfg.lr - cfg.min_lr) * (1 + math.cos(math.pi * r))
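

# With the defaults (lr=1.5e-3, min_lr=1e-5, warmup_steps=800, max_steps=100_000)
# this gives roughly: step 400 -> 7.5e-4 (linear warmup), step 800 -> 1.5e-3 (peak),
# step 50_400 -> ~7.55e-4 (cosine midpoint), step >= 100_000 -> 1e-5.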


@torch.no_grad()
def evaluate(model, loader, device, steps=50):
    model.eval()
    total, n = 0.0, 0
    for x, y in loader:
        if n >= steps:
            break
        x, y = x.to(device), y.to(device)
        with torch.amp.autocast("cuda", dtype=torch.float16):
            _, loss = model(x, y)
        total += loss.item()
        n += 1
    model.train()
    return total / max(n, 1)


def main():
    parser = argparse.ArgumentParser(description="Train 1-bit Transformer LM")
    parser.add_argument("--exp-dir", default="/root/experiments/1m-model")
    parser.add_argument("--max-steps", type=int, default=100_000)
    parser.add_argument("--batch-size", type=int, default=96)
    parser.add_argument("--lr", type=float, default=1.5e-3)
    parser.add_argument("--device", default="cuda:0")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--generate", action="store_true")
    parser.add_argument("--prompt", default="Once upon a time")
    args = parser.parse_args()

    cfg = Config()
    cfg.batch_size = args.batch_size
    cfg.max_steps = args.max_steps
    cfg.lr = args.lr
    cfg.device = args.device
    cfg.compile = args.compile

    exp = Path(args.exp_dir)
    exp.mkdir(parents=True, exist_ok=True)

    torch.manual_seed(cfg.seed)
    random.seed(cfg.seed)
    torch.backends.cudnn.benchmark = True

    # tokenizer: load a cached model or train one on a TinyStories subset
    tok_model = exp / "tokenizer.model"
    if tok_model.exists():
        print("Loading tokenizer...")
        sp = spm.SentencePieceProcessor(model_file=str(tok_model))
    else:
        from datasets import load_dataset
        print("Loading TinyStories for tokenizer training...")
        ds = load_dataset("roneneldan/TinyStories", split="train")
        subset = [ds[i]["text"] for i in range(min(100_000, len(ds)))]
        sp = train_tokenizer(subset, exp, vocab_size=cfg.vocab_size)
        del subset, ds

    cfg.vocab_size = sp.get_piece_size()
    print(f"Vocab size: {cfg.vocab_size}")
    assert cfg.vocab_size < 200, f"Tokenizer too large: {cfg.vocab_size}"

    # generation-only mode: load the best checkpoint, sample, and exit
    if args.generate:
        model = BitLM(cfg).to(cfg.device)
        ckpt = torch.load(exp / "best.pt", map_location=cfg.device, weights_only=True)
        # strip a torch.compile prefix if the checkpoint was saved from a compiled model
        state = ckpt["model"]
        if any(k.startswith("_orig_mod.") for k in state):
            state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
        model.load_state_dict(state)
        model.eval()
        print(f"Loaded best model (step {ckpt['step']}, val_loss={ckpt['val_loss']:.4f})")

        ids = [sp.bos_id()] + sp.encode(args.prompt)
        idx = torch.tensor([ids], device=cfg.device)
        out = model.generate(idx, max_new=500, temp=0.8, top_k=40, eos_id=sp.eos_id())
        text = sp.decode(out[0].tolist())
        print(f"\n--- Generated ---\n{text}\n")
        return

    # data: tokenize TinyStories once and cache the flat token tensors
    train_cache = exp / "train_tokens.pt"
    val_cache = exp / "val_tokens.pt"

    if train_cache.exists() and val_cache.exists():
        print("Loading cached tokens...")
        train_tok = torch.load(train_cache, weights_only=True)
        val_tok = torch.load(val_cache, weights_only=True)
    else:
        from datasets import load_dataset
        print("Loading TinyStories...")
        train_ds = load_dataset("roneneldan/TinyStories", split="train")
        val_ds = load_dataset("roneneldan/TinyStories", split="validation")

        train_texts = [ex["text"] for ex in train_ds]
        val_texts = [ex["text"] for ex in val_ds]
        print(f"Train: {len(train_texts):,} stories, Val: {len(val_texts):,} stories")

        train_tok = encode_texts(sp, train_texts, "train")
        val_tok = encode_texts(sp, val_texts, "val")

        print("Saving cached tokens...")
        torch.save(train_tok, train_cache)
        torch.save(val_tok, val_cache)
        del train_texts, val_texts

    train_data = ChunkedDataset(train_tok, cfg.max_seq_len)
    val_data = ChunkedDataset(val_tok, cfg.max_seq_len)
    print(f"Train: {len(train_data):,} chunks, Val: {len(val_data):,} chunks")

    train_loader = DataLoader(
        train_data, batch_size=cfg.batch_size, shuffle=True,
        num_workers=cfg.num_workers, pin_memory=True, drop_last=True,
    )
    val_loader = DataLoader(
        val_data, batch_size=cfg.batch_size, shuffle=False,
        num_workers=cfg.num_workers, pin_memory=True, drop_last=True,
    )

    # model
    model = BitLM(cfg).to(cfg.device)
    n_params = model.param_count()
    print(f"\nModel: {n_params:,} parameters ({n_params/1e6:.3f}M)")
    print(f" d_model={cfg.d_model}, n_heads={cfg.n_heads}, n_layers={cfg.n_layers}, "
          f"d_ff={cfg.d_ff}, max_seq_len={cfg.max_seq_len}")
    assert n_params < 1_000_000, f"Model too large: {n_params:,} params >= 1M"

    if cfg.compile:
        print("Compiling model with torch.compile...")
        model = torch.compile(model)

    # optimizer: no weight decay on normalization gains or the (tied) embedding;
    # the RMSNorm gains are the only 1-D parameters here, so filter by ndim
    decay_params, nodecay_params = [], []
    for name, p in model.named_parameters():
        if p.requires_grad:
            if p.ndim < 2 or "emb" in name:
                nodecay_params.append(p)
            else:
                decay_params.append(p)

    opt = torch.optim.AdamW(
        [
            {"params": decay_params, "weight_decay": cfg.weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ],
        lr=cfg.lr, betas=(0.9, 0.95),
    )
    scaler = torch.amp.GradScaler("cuda")

    # resume from the latest checkpoint if one exists
    step = 0
    best_val = float("inf")
    ckpt_path = exp / "latest.pt"
    if ckpt_path.exists():
        print(f"Resuming from {ckpt_path}...")
        ck = torch.load(ckpt_path, map_location=cfg.device)
        # strip a torch.compile prefix if the checkpoint was saved from a compiled model
        state = ck["model"]
        if any(k.startswith("_orig_mod.") for k in state):
            state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
        model.load_state_dict(state)
        opt.load_state_dict(ck["optimizer"])
        scaler.load_state_dict(ck["scaler"])
        step = ck["step"]
        best_val = ck.get("best_val", float("inf"))
        print(f"Resumed at step {step}, best_val={best_val:.4f}")

    print(f"\nTraining for {cfg.max_steps:,} steps "
          f"(batch={cfg.batch_size}, accum={cfg.grad_accum}, "
          f"eff_batch={cfg.batch_size * cfg.grad_accum})\n")
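
    # With the defaults that is 96 * 4 = 384 sequences of 512 tokens, i.e. roughly
    # 197K tokens per optimizer update (illustrative arithmetic).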

    model.train()
    train_iter = iter(train_loader)
    running_loss = 0.0
    t0 = time.time()
    tokens_since_log = 0

    while step < cfg.max_steps:
        # fetch the next batch, restarting the iterator at the end of an epoch
        try:
            x, y = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            x, y = next(train_iter)

        x, y = x.to(cfg.device, non_blocking=True), y.to(cfg.device, non_blocking=True)

        # learning-rate schedule
        lr = get_lr(step, cfg)
        for pg in opt.param_groups:
            pg["lr"] = lr

        # forward / backward under fp16 autocast
        with torch.amp.autocast("cuda", dtype=torch.float16):
            _, loss = model(x, y)
            scaled_loss = loss / cfg.grad_accum

        scaler.scale(scaled_loss).backward()
        running_loss += loss.item()
        tokens_since_log += x.numel()

        # optimizer step every grad_accum micro-batches
        if (step + 1) % cfg.grad_accum == 0:
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        step += 1

        # logging
        if step % cfg.log_interval == 0:
            avg = running_loss / cfg.log_interval
            elapsed = time.time() - t0
            tps = tokens_since_log / elapsed
            ppl = math.exp(min(avg, 20))
            print(
                f"step {step:>6d}/{cfg.max_steps} | "
                f"loss {avg:.4f} | ppl {ppl:>8.2f} | "
                f"lr {lr:.2e} | {tps/1e3:.0f}K tok/s"
            )
            running_loss = 0.0
            tokens_since_log = 0
            t0 = time.time()

        # periodic evaluation; keep the best checkpoint
        if step % cfg.eval_interval == 0:
            vl = evaluate(model, val_loader, cfg.device, cfg.eval_steps)
            vppl = math.exp(min(vl, 20))
            improved = vl < best_val
            tag = " ** NEW BEST **" if improved else ""
            print(f" >>> val_loss={vl:.4f} val_ppl={vppl:.2f}{tag}")
            if improved:
                best_val = vl
                save_dict = {"model": model.state_dict(), "step": step,
                             "val_loss": vl, "config": asdict(cfg)}
                torch.save(save_dict, exp / "best.pt")
            model.train()

        # periodic sample generations
        if step % cfg.gen_interval == 0:
            model.eval()
            for prompt in ["Once upon a time", "The little dog", "She was very happy"]:
                ids = [sp.bos_id()] + sp.encode(prompt)
                idx = torch.tensor([ids], device=cfg.device)
                out = model.generate(idx, max_new=150, temp=0.8, top_k=40,
                                     eos_id=sp.eos_id())
                text = sp.decode(out[0].tolist())
                print(f" GEN [{prompt[:20]}] → {text[:250]}")
            model.train()

        # periodic full checkpoint for resuming
        if step % cfg.save_interval == 0:
            torch.save(
                {
                    "model": model.state_dict(),
                    "optimizer": opt.state_dict(),
                    "scaler": scaler.state_dict(),
                    "step": step,
                    "best_val": best_val,
                    "config": asdict(cfg),
                },
                ckpt_path,
            )

    # final save
    torch.save(
        {"model": model.state_dict(), "step": step, "config": asdict(cfg)},
        exp / "final.pt",
    )
    print(f"\nTraining complete! Best val loss: {best_val:.4f} (ppl {math.exp(best_val):.2f})")


if __name__ == "__main__":
    main()