#!/usr/bin/env python3
"""
1-Bit Transformer LM on TinyStories
< 1M params | < 200 vocab | BitNet b1.58 ternary weights {-1, 0, +1}
Architecture: RoPE, RMSNorm, SwiGLU, tied embeddings
Tokenizer: SentencePiece unigram (192 vocab)
"""
import os, json, math, time, random, argparse
from pathlib import Path
from dataclasses import dataclass, asdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import sentencepiece as spm

# ================================================================
# Config
# ================================================================

@dataclass
class Config:
    # Model
    vocab_size: int = 192   # < 200
    d_model: int = 128
    n_heads: int = 4        # head_dim = 32
    n_layers: int = 5
    d_ff: int = 336         # SwiGLU intermediate
    max_seq_len: int = 512
    # Training
    batch_size: int = 96
    grad_accum: int = 4     # effective batch = 384
    lr: float = 1.5e-3
    min_lr: float = 1e-5
    warmup_steps: int = 800
    max_steps: int = 100_000
    weight_decay: float = 0.1
    grad_clip: float = 1.0
    # Logging / eval
    eval_interval: int = 1000
    eval_steps: int = 50
    log_interval: int = 100
    gen_interval: int = 5000
    save_interval: int = 5000
    # Misc
    seed: int = 42
    device: str = "cuda:0"
    compile: bool = False
    num_workers: int = 0

# ================================================================
# Model
# ================================================================

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.w = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.w


class BitLinear(nn.Module):
    """Linear layer with ternary {-1, 0, +1} weight quantization (BitNet b1.58).

    Full-precision latent weights are kept for optimizer updates.
    Forward uses quantized weights via straight-through estimator.
    """

    def __init__(self, in_f, out_f):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_f, in_f))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        alpha = self.weight.abs().mean().clamp(min=1e-5)
        wq = torch.clamp(torch.round(self.weight / alpha), -1, 1) * alpha
        w = self.weight + (wq - self.weight).detach()  # STE
        return F.linear(x, w)
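
# Quantization sketch (illustrative numbers, not taken from a real run): for a weight row
# [0.8, -0.03, -1.2, 0.45], alpha = mean|w| = 0.62, so w/alpha = [1.29, -0.05, -1.94, 0.73],
# round -> [1, 0, -2, 1], clamp -> [1, 0, -1, 1], and the effective weight is
# 0.62 * {-1, 0, +1}. The STE line passes gradients straight through to the latent
# full-precision weights, which are what AdamW actually updates.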

def _rope_freqs(dim, max_len, base=10000.0):
    f = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    t = torch.arange(max_len, dtype=torch.float32)
    ang = torch.outer(t, f)
    return torch.cos(ang), torch.sin(ang)


def _apply_rope(x, c, s):
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return torch.cat([x1 * c - x2 * s, x2 * c + x1 * s], dim=-1)


class Block(nn.Module):
    def __init__(self, d, h, ff):
        super().__init__()
        self.n1 = RMSNorm(d)
        self.n2 = RMSNorm(d)
        # Attention
        self.q = BitLinear(d, d)
        self.k = BitLinear(d, d)
        self.v = BitLinear(d, d)
        self.o = BitLinear(d, d)
        # SwiGLU FFN
        self.gate = BitLinear(d, ff)
        self.up = BitLinear(d, ff)
        self.down = BitLinear(ff, d)
        self.nh = h
        self.hd = d // h

    def forward(self, x, cos, sin):
        B, T, C = x.shape
        h = self.n1(x)
        q = self.q(h).view(B, T, self.nh, self.hd).transpose(1, 2)
        k = self.k(h).view(B, T, self.nh, self.hd).transpose(1, 2)
        v = self.v(h).view(B, T, self.nh, self.hd).transpose(1, 2)
        q = _apply_rope(q, cos, sin)
        k = _apply_rope(k, cos, sin)
        a = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        x = x + self.o(a.transpose(1, 2).contiguous().view(B, T, C))
        h = self.n2(x)
        x = x + self.down(F.silu(self.gate(h)) * self.up(h))
        return x


class BitLM(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.layers = nn.ModuleList(
            [Block(cfg.d_model, cfg.n_heads, cfg.d_ff) for _ in range(cfg.n_layers)]
        )
        self.norm = RMSNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.head.weight = self.emb.weight  # weight tying
        hd = cfg.d_model // cfg.n_heads
        c, s = _rope_freqs(hd, cfg.max_seq_len)
        self.register_buffer("rc", c)
        self.register_buffer("rs", s)
        nn.init.normal_(self.emb.weight, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        x = self.emb(idx)
        c = self.rc[:T].unsqueeze(0).unsqueeze(0)  # (1,1,T,hd/2)
        s = self.rs[:T].unsqueeze(0).unsqueeze(0)
        for layer in self.layers:
            x = layer(x, c, s)
        logits = self.head(self.norm(x))
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=0
            )
        return logits, loss

    def param_count(self):
        seen = set()
        total = 0
        for p in self.parameters():
            pid = id(p)
            if pid not in seen:
                seen.add(pid)
                total += p.numel()
        return total

    @torch.no_grad()
    def generate(self, idx, max_new=200, temp=0.8, top_k=40, eos_id=2):
        for _ in range(max_new):
            ic = idx[:, -self.cfg.max_seq_len:]
            logits, _ = self(ic)
            logits = logits[:, -1] / temp
            if top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, 1)
            idx = torch.cat([idx, nxt], dim=1)
            if nxt.item() == eos_id:
                break
        return idx
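
# Parameter budget (sketch, using the Config defaults above): tied embedding 192*128 = 24,576;
# per block: attention 4*128*128 = 65,536, SwiGLU 3*128*336 = 129,024, two RMSNorms = 256,
# i.e. 194,816 per block; 5 blocks + final norm + embedding = 998,784 < 1,000,000.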

# ================================================================
# Dataset
# ================================================================

class ChunkedDataset(Dataset):
    """Flat token tensor split into fixed-length chunks."""

    def __init__(self, tokens, seq_len):
        self.tokens = tokens
        self.seq_len = seq_len
        self.n = (len(tokens) - 1) // seq_len

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        s = i * self.seq_len
        c = self.tokens[s : s + self.seq_len + 1]
        return c[:-1], c[1:]

# ================================================================
# Tokenizer helpers
# ================================================================

def train_tokenizer(texts, exp_dir, vocab_size=192, n_train=100_000):
    """Train a SentencePiece unigram tokenizer with < 200 vocab."""
    data_file = exp_dir / "sp_train.txt"
    prefix = str(exp_dir / "tokenizer")
    print(f"Writing {min(n_train, len(texts))} texts for tokenizer training...")
    with open(data_file, "w", encoding="utf-8") as f:
        for t in texts[:n_train]:
            f.write(t.strip().replace("\n", " ") + "\n")
    print("Training SentencePiece unigram tokenizer...")
    spm.SentencePieceTrainer.train(
        input=str(data_file),
        model_prefix=prefix,
        vocab_size=vocab_size,
        model_type="unigram",
        character_coverage=1.0,
        pad_id=0,
        bos_id=1,
        eos_id=2,
        unk_id=3,
        byte_fallback=False,
        normalization_rule_name="identity",
        max_sentence_length=8192,
        num_threads=os.cpu_count() or 4,
        train_extremely_large_corpus=False,
    )
    data_file.unlink(missing_ok=True)
    sp = spm.SentencePieceProcessor(model_file=prefix + ".model")
    print(f"Tokenizer ready: {sp.get_piece_size()} tokens")
    return sp


def encode_texts(sp, texts, desc="data"):
    """Encode all texts into a single flat token tensor (BOS story EOS ...)."""
    bos, eos = sp.bos_id(), sp.eos_id()
    all_ids = []
    t0 = time.time()
    for i, t in enumerate(texts):
        all_ids.append(bos)
        all_ids.extend(sp.encode(t))
        all_ids.append(eos)
        if (i + 1) % 500_000 == 0:
            print(f"  {desc}: {i+1}/{len(texts)} ({len(all_ids)/1e6:.1f}M tok)")
    elapsed = time.time() - t0
    print(f"  {desc}: {len(all_ids)/1e6:.2f}M tokens, {elapsed:.1f}s")
    return torch.tensor(all_ids, dtype=torch.long)

# ================================================================
# LR schedule
# ================================================================

def get_lr(step, cfg):
    if step < cfg.warmup_steps:
        return cfg.lr * step / cfg.warmup_steps
    if step >= cfg.max_steps:
        return cfg.min_lr
    r = (step - cfg.warmup_steps) / (cfg.max_steps - cfg.warmup_steps)
    return cfg.min_lr + 0.5 * (cfg.lr - cfg.min_lr) * (1 + math.cos(math.pi * r))

# ================================================================
# Eval
# ================================================================

@torch.no_grad()
def evaluate(model, loader, device, steps=50):
    model.eval()
    total, n = 0.0, 0
    for x, y in loader:
        if n >= steps:
            break
        x, y = x.to(device), y.to(device)
        with torch.amp.autocast("cuda", dtype=torch.float16):
            _, loss = model(x, y)
        total += loss.item()
        n += 1
    model.train()
    return total / max(n, 1)
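
# Usage sketch (assumes this file is saved as train.py; the filename is an assumption,
# nothing in the script fixes it):
#   python train.py --exp-dir ./runs/bitlm-1m                               # train / resume
#   python train.py --exp-dir ./runs/bitlm-1m --generate --prompt "Once upon a time"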

# ================================================================
# Main
# ================================================================

def main():
    parser = argparse.ArgumentParser(description="Train 1-bit Transformer LM")
    parser.add_argument("--exp-dir", default="/root/experiments/1m-model")
    parser.add_argument("--max-steps", type=int, default=100_000)
    parser.add_argument("--batch-size", type=int, default=96)
    parser.add_argument("--lr", type=float, default=1.5e-3)
    parser.add_argument("--device", default="cuda:0")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--generate", action="store_true")
    parser.add_argument("--prompt", default="Once upon a time")
    args = parser.parse_args()

    cfg = Config()
    cfg.batch_size = args.batch_size
    cfg.max_steps = args.max_steps
    cfg.lr = args.lr
    cfg.device = args.device
    cfg.compile = args.compile

    exp = Path(args.exp_dir)
    exp.mkdir(parents=True, exist_ok=True)
    torch.manual_seed(cfg.seed)
    random.seed(cfg.seed)
    torch.backends.cudnn.benchmark = True

    # ---- Tokenizer ----
    tok_model = exp / "tokenizer.model"
    if tok_model.exists():
        print("Loading tokenizer...")
        sp = spm.SentencePieceProcessor(model_file=str(tok_model))
    else:
        from datasets import load_dataset
        print("Loading TinyStories for tokenizer training...")
        ds = load_dataset("roneneldan/TinyStories", split="train")
        subset = [ds[i]["text"] for i in range(min(100_000, len(ds)))]
        sp = train_tokenizer(subset, exp, vocab_size=cfg.vocab_size)
        del subset, ds
    cfg.vocab_size = sp.get_piece_size()
    print(f"Vocab size: {cfg.vocab_size}")
    assert cfg.vocab_size < 200, f"Tokenizer too large: {cfg.vocab_size}"

    # ---- Generate mode ----
    if args.generate:
        model = BitLM(cfg).to(cfg.device)
        ckpt = torch.load(exp / "best.pt", map_location=cfg.device, weights_only=True)
        state = ckpt["model"]
        if any(k.startswith("_orig_mod.") for k in state):
            state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
        model.load_state_dict(state)
        model.eval()
        print(f"Loaded best model (step {ckpt['step']}, val_loss={ckpt['val_loss']:.4f})")
        ids = [sp.bos_id()] + sp.encode(args.prompt)
        idx = torch.tensor([ids], device=cfg.device)
        out = model.generate(idx, max_new=500, temp=0.8, top_k=40, eos_id=sp.eos_id())
        text = sp.decode(out[0].tolist())
        print(f"\n--- Generated ---\n{text}\n")
        return

    # ---- Data ----
    train_cache = exp / "train_tokens.pt"
    val_cache = exp / "val_tokens.pt"
    if train_cache.exists() and val_cache.exists():
        print("Loading cached tokens...")
        train_tok = torch.load(train_cache, weights_only=True)
        val_tok = torch.load(val_cache, weights_only=True)
    else:
        from datasets import load_dataset
        print("Loading TinyStories...")
        train_ds = load_dataset("roneneldan/TinyStories", split="train")
        val_ds = load_dataset("roneneldan/TinyStories", split="validation")
        train_texts = [ex["text"] for ex in train_ds]
        val_texts = [ex["text"] for ex in val_ds]
        print(f"Train: {len(train_texts):,} stories, Val: {len(val_texts):,} stories")
        train_tok = encode_texts(sp, train_texts, "train")
        val_tok = encode_texts(sp, val_texts, "val")
        print("Saving cached tokens...")
        torch.save(train_tok, train_cache)
        torch.save(val_tok, val_cache)
        del train_texts, val_texts

    train_data = ChunkedDataset(train_tok, cfg.max_seq_len)
    val_data = ChunkedDataset(val_tok, cfg.max_seq_len)
    print(f"Train: {len(train_data):,} chunks, Val: {len(val_data):,} chunks")

    train_loader = DataLoader(
        train_data,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_data,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
        drop_last=True,
    )

    # ---- Model ----
    model = BitLM(cfg).to(cfg.device)
    n_params = model.param_count()
    print(f"\nModel: {n_params:,} parameters ({n_params/1e6:.3f}M)")
    print(f"  d_model={cfg.d_model}, n_heads={cfg.n_heads}, n_layers={cfg.n_layers}, "
          f"d_ff={cfg.d_ff}, max_seq_len={cfg.max_seq_len}")
    assert n_params < 1_000_000, f"Model too large: {n_params:,} params >= 1M"

    if cfg.compile:
        print("Compiling model with torch.compile...")
        model = torch.compile(model)

    # ---- Optimizer ----
    # Keep 1-D params (RMSNorm gains) and the tied embedding out of weight decay.
    # (Matching on the substring "norm" would miss the per-block norms, which are
    # registered as n1/n2, so group by dimensionality instead.)
    decay_params, nodecay_params = [], []
    for name, p in model.named_parameters():
        if p.requires_grad:
            if p.dim() < 2 or "emb" in name:
                nodecay_params.append(p)
            else:
                decay_params.append(p)
    opt = torch.optim.AdamW(
        [
            {"params": decay_params, "weight_decay": cfg.weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ],
        lr=cfg.lr,
        betas=(0.9, 0.95),
    )
    scaler = torch.amp.GradScaler("cuda")

    # ---- Resume ----
    step = 0
    best_val = float("inf")
    ckpt_path = exp / "latest.pt"
    if ckpt_path.exists():
        print(f"Resuming from {ckpt_path}...")
        ck = torch.load(ckpt_path, map_location=cfg.device)
        # Handle compiled-model keys: strip the prefix and load into the underlying
        # module so resuming works with and without torch.compile.
        state = ck["model"]
        if any(k.startswith("_orig_mod.") for k in state):
            state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
        getattr(model, "_orig_mod", model).load_state_dict(state)
        opt.load_state_dict(ck["optimizer"])
        scaler.load_state_dict(ck["scaler"])
        step = ck["step"]
        best_val = ck.get("best_val", float("inf"))
        print(f"Resumed at step {step}, best_val={best_val:.4f}")
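
    # Step accounting (note): `step` below counts micro-batches (one forward/backward each),
    # so with the defaults one optimizer update covers grad_accum * batch_size = 4 * 96 = 384
    # sequences, i.e. 384 * 512 = 196,608 tokens, and max_steps=100_000 micro-batches
    # correspond to ~25,000 optimizer updates.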

    # ---- Training loop ----
    print(f"\nTraining for {cfg.max_steps:,} steps "
          f"(batch={cfg.batch_size}, accum={cfg.grad_accum}, "
          f"eff_batch={cfg.batch_size * cfg.grad_accum})\n")
    model.train()
    train_iter = iter(train_loader)
    running_loss = 0.0
    t0 = time.time()
    tokens_since_log = 0

    while step < cfg.max_steps:
        # Get batch (auto-restart on epoch boundary)
        try:
            x, y = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            x, y = next(train_iter)
        x, y = x.to(cfg.device, non_blocking=True), y.to(cfg.device, non_blocking=True)

        # LR schedule
        lr = get_lr(step, cfg)
        for pg in opt.param_groups:
            pg["lr"] = lr

        # Forward + backward (mixed precision FP16)
        with torch.amp.autocast("cuda", dtype=torch.float16):
            _, loss = model(x, y)
        scaled_loss = loss / cfg.grad_accum
        scaler.scale(scaled_loss).backward()
        running_loss += loss.item()
        tokens_since_log += x.numel()

        # Optimizer step every grad_accum mini-batches
        if (step + 1) % cfg.grad_accum == 0:
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        step += 1

        # ---- Logging ----
        if step % cfg.log_interval == 0:
            avg = running_loss / cfg.log_interval
            elapsed = time.time() - t0
            tps = tokens_since_log / elapsed
            ppl = math.exp(min(avg, 20))  # cap for display
            print(
                f"step {step:>6d}/{cfg.max_steps} | "
                f"loss {avg:.4f} | ppl {ppl:>8.2f} | "
                f"lr {lr:.2e} | {tps/1e3:.0f}K tok/s"
            )
            running_loss = 0.0
            tokens_since_log = 0
            t0 = time.time()

        # ---- Evaluation ----
        if step % cfg.eval_interval == 0:
            vl = evaluate(model, val_loader, cfg.device, cfg.eval_steps)
            vppl = math.exp(min(vl, 20))
            improved = vl < best_val
            tag = " ** NEW BEST **" if improved else ""
            print(f"  >>> val_loss={vl:.4f} val_ppl={vppl:.2f}{tag}")
            if improved:
                best_val = vl
                save_dict = {"model": model.state_dict(), "step": step,
                             "val_loss": vl, "config": asdict(cfg)}
                torch.save(save_dict, exp / "best.pt")
            model.train()

        # ---- Generate samples ----
        if step % cfg.gen_interval == 0:
            model.eval()
            for prompt in ["Once upon a time", "The little dog", "She was very happy"]:
                ids = [sp.bos_id()] + sp.encode(prompt)
                idx = torch.tensor([ids], device=cfg.device)
                out = model.generate(idx, max_new=150, temp=0.8, top_k=40, eos_id=sp.eos_id())
                text = sp.decode(out[0].tolist())
                print(f"  GEN [{prompt[:20]}] → {text[:250]}")
            model.train()

        # ---- Checkpoint ----
        if step % cfg.save_interval == 0:
            torch.save(
                {
                    "model": model.state_dict(),
                    "optimizer": opt.state_dict(),
                    "scaler": scaler.state_dict(),
                    "step": step,
                    "best_val": best_val,
                    "config": asdict(cfg),
                },
                ckpt_path,
            )

    # ---- Final save ----
    torch.save(
        {"model": model.state_dict(), "step": step, "config": asdict(cfg)},
        exp / "final.pt",
    )
    print(f"\nTraining complete! Best val loss: {best_val:.4f} (ppl {math.exp(best_val):.2f})")


if __name__ == "__main__":
    main()
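
# Back-of-envelope storage (sketch, not measured): the five blocks hold ~972,800 BitLinear
# weights; packed at 2 bits per ternary value that is ~238 KB (or ~190 KB at the theoretical
# 1.58 bits), plus ~26K full-precision embedding/norm values. The checkpoints written above
# store the fp32 latent weights, so they remain roughly 4 MB.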