File size: 19,383 Bytes
7a2fc07 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 | #!/usr/bin/env python3
"""
1-Bit Transformer LM on TinyStories
< 1M params | < 200 vocab | BitNet b1.58 ternary weights {-1, 0, +1}
Architecture: RoPE, RMSNorm, SwiGLU, tied embeddings
Tokenizer: SentencePiece unigram (192 vocab)
"""
import os, json, math, time, random, argparse
from pathlib import Path
from dataclasses import dataclass, asdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import sentencepiece as spm
# ================================================================
# Config
# ================================================================
@dataclass
class Config:
    """All hyper-parameters for the model, the training loop, logging and runtime.

    ``__post_init__`` rejects shapes the model cannot be built with, so a bad
    configuration fails at construction time instead of as a tensor-shape
    error deep inside ``forward``. Note that ``main()`` overwrites several
    fields after construction (batch size, steps, lr, device, compile, and
    ``vocab_size`` from the trained tokenizer); those assignments are not
    re-validated.
    """

    # Model
    vocab_size: int = 192  # < 200
    d_model: int = 128
    n_heads: int = 4  # head_dim = 32
    n_layers: int = 5
    d_ff: int = 336  # SwiGLU intermediate
    max_seq_len: int = 512  # training chunk length and RoPE table size
    # Training
    batch_size: int = 96
    grad_accum: int = 4  # effective batch = 384
    lr: float = 1.5e-3
    min_lr: float = 1e-5
    warmup_steps: int = 800
    max_steps: int = 100_000  # the training loop in main() counts micro-batches
    weight_decay: float = 0.1
    grad_clip: float = 1.0
    # Logging / eval (intervals are in the same step units as max_steps)
    eval_interval: int = 1000
    eval_steps: int = 50  # validation batches averaged per evaluation
    log_interval: int = 100
    gen_interval: int = 5000
    save_interval: int = 5000
    # Misc
    seed: int = 42
    device: str = "cuda:0"
    compile: bool = False  # wrap the model with torch.compile
    num_workers: int = 0

    def __post_init__(self):
        """Validate the architectural invariants the model relies on.

        Raises:
            ValueError: if ``d_model`` is not divisible by a positive
                ``n_heads``, if the resulting head dimension is odd, or if
                ``grad_accum`` is smaller than 1.
        """
        if self.n_heads <= 0 or self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by a positive "
                f"n_heads ({self.n_heads})"
            )
        head_dim = self.d_model // self.n_heads
        if head_dim % 2 != 0:
            # RoPE rotates the two halves of every head, so they must be equal.
            raise ValueError(
                f"head_dim = d_model // n_heads = {head_dim} must be even for RoPE"
            )
        if self.grad_accum < 1:
            # The training loop takes `step % grad_accum`.
            raise ValueError(f"grad_accum must be >= 1, got {self.grad_accum}")
# ================================================================
# Model
# ================================================================
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-channel gain.

    Each vector along the last axis is divided by its RMS (plus ``eps`` inside
    the square root) and then scaled by ``w``. There is no mean-centering and
    no bias.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.w = nn.Parameter(torch.ones(dim))  # gain, initialised to 1

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.w
class BitLinear(nn.Module):
    """Bias-free linear layer whose forward pass uses ternary weights (BitNet b1.58 style).

    A full-precision latent ``weight`` is what the optimizer updates. On every
    forward pass it is scaled by its mean absolute value, rounded and clipped
    to {-1, 0, +1}, and multiplied back by that scale. A straight-through
    estimator routes the gradient to the latent weight as if no quantization
    had happened. Activations are not quantized.
    """

    def __init__(self, in_f, out_f):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_f, in_f))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        latent = self.weight
        # Absmean scale; the floor keeps the division finite for an all-zero weight.
        scale = latent.abs().mean().clamp(min=1e-5)
        ternary = (latent / scale).round().clamp(-1, 1)
        quantized = ternary * scale
        # STE: the forward value equals `quantized`; the gradient w.r.t. `latent` is identity.
        ste_weight = latent + (quantized - latent).detach()
        return F.linear(x, ste_weight)
def _rope_freqs(dim, max_len, base=10000.0):
f = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(max_len, dtype=torch.float32)
ang = torch.outer(t, f)
return torch.cos(ang), torch.sin(ang)
def _apply_rope(x, c, s):
d = x.shape[-1] // 2
x1, x2 = x[..., :d], x[..., d:]
return torch.cat([x1 * c - x2 * s, x2 * c + x1 * s], dim=-1)
class Block(nn.Module):
    """Pre-norm transformer block built from ternary ``BitLinear`` projections.

    ``x -> x + Attn(RMSNorm(x))`` (causal multi-head self-attention with RoPE
    applied to queries and keys), then ``x -> x + SwiGLU(RMSNorm(x))``.

    Args:
        d: model width (must be divisible by ``h``).
        h: number of attention heads.
        ff: SwiGLU hidden width.
    """

    def __init__(self, d, h, ff):
        super().__init__()
        # Keep this registration order stable: it fixes the named_parameters()
        # order that the optimizer's parameter groups follow.
        self.n1 = RMSNorm(d)
        self.n2 = RMSNorm(d)
        # attention projections
        self.q = BitLinear(d, d)
        self.k = BitLinear(d, d)
        self.v = BitLinear(d, d)
        self.o = BitLinear(d, d)
        # SwiGLU feed-forward
        self.gate = BitLinear(d, ff)
        self.up = BitLinear(d, ff)
        self.down = BitLinear(ff, d)
        self.nh = h
        self.hd = d // h

    def _split_heads(self, t, B, T):
        """(B, T, d) -> (B, heads, T, head_dim)."""
        return t.view(B, T, self.nh, self.hd).transpose(1, 2)

    def forward(self, x, cos, sin):
        B, T, C = x.shape

        # --- attention sub-layer ---
        normed = self.n1(x)
        queries = _apply_rope(self._split_heads(self.q(normed), B, T), cos, sin)
        keys = _apply_rope(self._split_heads(self.k(normed), B, T), cos, sin)
        values = self._split_heads(self.v(normed), B, T)
        attended = F.scaled_dot_product_attention(queries, keys, values, is_causal=True)
        merged = attended.transpose(1, 2).reshape(B, T, C)  # heads back to width C
        x = x + self.o(merged)

        # --- feed-forward sub-layer ---
        normed = self.n2(x)
        hidden = F.silu(self.gate(normed)) * self.up(normed)
        return x + self.down(hidden)
class BitLM(nn.Module):
    """Decoder-only language model made of ternary-weight ``Block`` layers.

    The token embedding is tied to the output projection. Rotary cos/sin
    tables for ``cfg.max_seq_len`` positions are precomputed once and stored
    as the buffers ``rc`` / ``rs``.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.layers = nn.ModuleList(
            [Block(cfg.d_model, cfg.n_heads, cfg.d_ff) for _ in range(cfg.n_layers)]
        )
        self.norm = RMSNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.head.weight = self.emb.weight  # weight tying
        hd = cfg.d_model // cfg.n_heads
        c, s = _rope_freqs(hd, cfg.max_seq_len)
        self.register_buffer("rc", c)
        self.register_buffer("rs", s)
        # Re-initialise after tying so the shared matrix starts at std 0.02.
        nn.init.normal_(self.emb.weight, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on token ids.

        Args:
            idx: ``(B, T)`` token ids with ``T <= cfg.max_seq_len``.
            targets: optional ``(B, T)`` next-token ids. Positions whose target
                is id 0 are ignored by the loss.

        Returns:
            ``(logits, loss)``: logits of shape ``(B, T, vocab_size)`` and the
            mean cross-entropy, or ``None`` when ``targets`` is not given.

        Raises:
            ValueError: if ``T`` exceeds the precomputed RoPE table length.
        """
        B, T = idx.shape
        if T > self.cfg.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len={self.cfg.max_seq_len}"
            )
        x = self.emb(idx)
        c = self.rc[:T].unsqueeze(0).unsqueeze(0)  # (1,1,T,hd/2)
        s = self.rs[:T].unsqueeze(0).unsqueeze(0)
        for layer in self.layers:
            x = layer(x, c, s)
        logits = self.head(self.norm(x))
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=0
            )
        return logits, loss

    def param_count(self):
        """Number of trainable parameter elements, counting the tied embedding/head once.

        ``nn.Module.parameters()`` already yields each shared Parameter only once.
        """
        return sum(p.numel() for p in self.parameters())

    @torch.no_grad()
    def generate(self, idx, max_new=200, temp=0.8, top_k=40, eos_id=2):
        """Autoregressively append up to ``max_new`` tokens to ``idx``.

        Args:
            idx: ``(B, T)`` prompt token ids.
            max_new: maximum number of tokens to append.
            temp: softmax temperature; ``temp <= 0`` selects greedy (argmax) decoding.
            top_k: when > 0, sample only among the ``top_k`` most likely tokens.
            eos_id: stop once every row has produced this id at least once;
                ``None`` disables early stopping. Rows that finish early keep
                receiving tokens until all rows are done, so callers should
                truncate each row at its first ``eos_id``.

        Returns:
            ``idx`` with the generated tokens concatenated along dim 1.
        """
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
        for _ in range(max_new):
            # Keep only the positions covered by the RoPE tables.
            ic = idx[:, -self.cfg.max_seq_len:]
            logits, _ = self(ic)
            logits = logits[:, -1]
            if temp <= 0:
                nxt = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temp  # new tensor, so in-place masking below is safe
                if top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = float("-inf")
                probs = F.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, 1)
            idx = torch.cat([idx, nxt], dim=1)
            if eos_id is not None:
                # Per-row bookkeeping instead of `.item()`, which only works for B == 1.
                finished |= nxt.squeeze(1) == eos_id
                if bool(finished.all()):
                    break
        return idx
# ================================================================
# Dataset
# ================================================================
class ChunkedDataset(Dataset):
    """Non-overlapping fixed-length windows over one flat token sequence.

    Item ``i`` is ``(inputs, targets)`` where ``targets`` is ``inputs``
    shifted one position to the right. Trailing tokens that cannot fill a
    whole window (plus its one-token look-ahead) are dropped.
    """

    def __init__(self, tokens, seq_len):
        self.tokens = tokens
        self.seq_len = seq_len
        # A window needs seq_len + 1 tokens, and consecutive windows share one token.
        self.n = (len(tokens) - 1) // seq_len

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        begin = i * self.seq_len
        window = self.tokens[begin : begin + self.seq_len + 1]
        inputs, targets = window[:-1], window[1:]
        return inputs, targets
# ================================================================
# Tokenizer helpers
# ================================================================
def train_tokenizer(texts, exp_dir, vocab_size=192, n_train=100_000):
    """Train a SentencePiece unigram tokenizer (<200 vocab) on the first ``n_train`` texts.

    Each text is flattened to one line in a scratch file ``exp_dir/sp_train.txt``.
    That file is removed afterwards, even if writing or training fails.

    Args:
        texts: sequence of raw strings (must support ``len`` and slicing).
        exp_dir: output directory (``str`` or ``Path``); the model is saved
            as ``exp_dir/tokenizer.model``.
        vocab_size: target vocabulary size passed to SentencePiece.
        n_train: maximum number of texts used for training.

    Returns:
        The trained ``spm.SentencePieceProcessor``. Special ids are pad=0,
        bos=1, eos=2 and unk=3.
    """
    exp_dir = Path(exp_dir)
    data_file = exp_dir / "sp_train.txt"
    prefix = str(exp_dir / "tokenizer")
    print(f"Writing {min(n_train, len(texts))} texts for tokenizer training...")
    try:
        with open(data_file, "w", encoding="utf-8") as f:
            # SentencePiece treats each input line as one sentence.
            for t in texts[:n_train]:
                f.write(t.strip().replace("\n", " ") + "\n")
        print("Training SentencePiece unigram tokenizer...")
        spm.SentencePieceTrainer.train(
            input=str(data_file),
            model_prefix=prefix,
            vocab_size=vocab_size,
            model_type="unigram",
            character_coverage=1.0,
            pad_id=0, bos_id=1, eos_id=2, unk_id=3,
            byte_fallback=False,
            normalization_rule_name="identity",
            max_sentence_length=8192,
            num_threads=os.cpu_count() or 4,
            train_extremely_large_corpus=False,
        )
    finally:
        # The scratch corpus can be large; never leave it behind.
        data_file.unlink(missing_ok=True)
    sp = spm.SentencePieceProcessor(model_file=prefix + ".model")
    print(f"Tokenizer ready: {sp.get_piece_size()} tokens")
    return sp
def encode_texts(sp, texts, desc="data"):
    """Tokenise ``texts`` into one flat tensor laid out as ``BOS story EOS BOS story EOS ...``.

    Args:
        sp: tokenizer exposing ``bos_id()``, ``eos_id()`` and ``encode(str) -> list[int]``.
        texts: iterable of strings (``len(texts)`` is read for progress messages).
        desc: label used in the progress messages.

    Returns:
        1-D ``torch.long`` tensor of every token id, in order.
    """
    bos_id = sp.bos_id()
    eos_id = sp.eos_id()
    flat_ids = []
    started = time.time()
    for count, story in enumerate(texts, start=1):
        flat_ids += [bos_id, *sp.encode(story), eos_id]
        if count % 500_000 == 0:
            print(f" {desc}: {count}/{len(texts)} ({len(flat_ids)/1e6:.1f}M tok)")
    print(f" {desc}: {len(flat_ids)/1e6:.2f}M tokens, {time.time() - started:.1f}s")
    return torch.tensor(flat_ids, dtype=torch.long)
# ================================================================
# LR schedule
# ================================================================
def get_lr(step, cfg):
    """Learning rate at ``step``: linear warmup, then cosine decay to ``min_lr``.

    * ``step < warmup_steps``: ramps linearly from 0 up to ``cfg.lr``.
    * ``warmup_steps <= step < max_steps``: half-cosine from ``cfg.lr`` down
      to ``cfg.min_lr``.
    * ``step >= max_steps``: held at ``cfg.min_lr``.
    """
    warm, horizon = cfg.warmup_steps, cfg.max_steps
    if step < warm:
        return cfg.lr * step / warm
    if step >= horizon:
        return cfg.min_lr
    progress = (step - warm) / (horizon - warm)  # in [0, 1)
    return cfg.min_lr + 0.5 * (cfg.lr - cfg.min_lr) * (1 + math.cos(math.pi * progress))
# ================================================================
# Eval
# ================================================================
@torch.no_grad()
def evaluate(model, loader, device, steps=50):
    """Mean loss of ``model`` over at most ``steps`` batches from ``loader``.

    Args:
        model: module whose ``forward(x, y)`` returns ``(logits, loss)``.
        loader: iterable of ``(x, y)`` tensor pairs.
        device: device the batches are moved to (e.g. ``"cuda:0"``).
        steps: maximum number of batches to average over.

    Returns:
        The average of the per-batch losses, or ``float("inf")`` when no batch
        was evaluated. An empty validation set therefore never looks like a
        perfect (0.0) loss.

    The model runs in eval mode and is restored to the train/eval mode it had
    on entry, even if evaluation raises.
    """
    was_training = model.training
    model.eval()
    # FP16 autocast only applies on CUDA; disable it for other devices rather
    # than letting torch warn and fall back.
    use_amp = torch.device(device).type == "cuda"
    total, n = 0.0, 0
    try:
        for x, y in loader:
            if n >= steps:
                break
            x, y = x.to(device), y.to(device)
            with torch.amp.autocast("cuda", dtype=torch.float16, enabled=use_amp):
                _, loss = model(x, y)
            total += loss.item()
            n += 1
    finally:
        model.train(was_training)
    return total / n if n else float("inf")
# ================================================================
# Main
# ================================================================
def _strip_compile_prefix(state):
    """Return ``state`` with torch.compile's ``_orig_mod.`` key prefix removed (if present)."""
    if any(k.startswith("_orig_mod.") for k in state):
        return {k.replace("_orig_mod.", ""): v for k, v in state.items()}
    return state


def _build_param_groups(model, weight_decay):
    """Split trainable parameters into AdamW groups with and without weight decay.

    Rank < 2 tensors (the RMSNorm gains) and the embedding table, which is
    also the tied output head, are exempt from decay. Every other matrix is
    decayed.
    """
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # Select by rank rather than by name: the per-block norms are called
        # n1/n2, so a "norm" substring test would miss them.
        if p.dim() < 2 or "emb" in name:
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


def main():
    """Command-line entry point: train the 1-bit LM, or sample from ``best.pt``.

    Flow:
      * parse CLI flags into a ``Config``;
      * load or train the SentencePiece tokenizer under ``--exp-dir``;
      * with ``--generate``, sample one continuation of ``--prompt`` and return;
      * otherwise load (or build and cache) the tokenised TinyStories splits,
        build the model and AdamW optimizer, and resume from ``latest.pt`` if
        it exists;
      * run the training loop with periodic logging, evaluation, sampling and
        checkpointing.

    ``step`` counts micro-batches. The optimizer updates once every
    ``cfg.grad_accum`` micro-batches, and the LR schedule and all intervals
    are expressed in micro-batches.
    """
    parser = argparse.ArgumentParser(description="Train 1-bit Transformer LM")
    parser.add_argument("--exp-dir", default="/root/experiments/1m-model")
    parser.add_argument("--max-steps", type=int, default=100_000)
    parser.add_argument("--batch-size", type=int, default=96)
    parser.add_argument("--lr", type=float, default=1.5e-3)
    parser.add_argument("--device", default="cuda:0")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--generate", action="store_true")
    parser.add_argument("--prompt", default="Once upon a time")
    args = parser.parse_args()

    cfg = Config()
    cfg.batch_size = args.batch_size
    cfg.max_steps = args.max_steps
    cfg.lr = args.lr
    cfg.device = args.device
    cfg.compile = args.compile

    exp = Path(args.exp_dir)
    exp.mkdir(parents=True, exist_ok=True)
    torch.manual_seed(cfg.seed)
    random.seed(cfg.seed)
    torch.backends.cudnn.benchmark = True

    # ---- Tokenizer ----
    tok_model = exp / "tokenizer.model"
    if tok_model.exists():
        print("Loading tokenizer...")
        sp = spm.SentencePieceProcessor(model_file=str(tok_model))
    else:
        from datasets import load_dataset
        print("Loading TinyStories for tokenizer training...")
        ds = load_dataset("roneneldan/TinyStories", split="train")
        subset = [ds[i]["text"] for i in range(min(100_000, len(ds)))]
        sp = train_tokenizer(subset, exp, vocab_size=cfg.vocab_size)
        del subset, ds
    cfg.vocab_size = sp.get_piece_size()
    print(f"Vocab size: {cfg.vocab_size}")
    assert cfg.vocab_size < 200, f"Tokenizer too large: {cfg.vocab_size}"

    # ---- Generate mode ----
    if args.generate:
        model = BitLM(cfg).to(cfg.device)
        ckpt = torch.load(exp / "best.pt", map_location=cfg.device, weights_only=True)
        model.load_state_dict(_strip_compile_prefix(ckpt["model"]))
        model.eval()
        print(f"Loaded best model (step {ckpt['step']}, val_loss={ckpt['val_loss']:.4f})")
        ids = [sp.bos_id()] + sp.encode(args.prompt)
        idx = torch.tensor([ids], device=cfg.device)
        out = model.generate(idx, max_new=500, temp=0.8, top_k=40, eos_id=sp.eos_id())
        text = sp.decode(out[0].tolist())
        print(f"\n--- Generated ---\n{text}\n")
        return

    # ---- Data ----
    train_cache = exp / "train_tokens.pt"
    val_cache = exp / "val_tokens.pt"
    if train_cache.exists() and val_cache.exists():
        print("Loading cached tokens...")
        train_tok = torch.load(train_cache, weights_only=True)
        val_tok = torch.load(val_cache, weights_only=True)
    else:
        from datasets import load_dataset
        print("Loading TinyStories...")
        train_ds = load_dataset("roneneldan/TinyStories", split="train")
        val_ds = load_dataset("roneneldan/TinyStories", split="validation")
        train_texts = [ex["text"] for ex in train_ds]
        val_texts = [ex["text"] for ex in val_ds]
        print(f"Train: {len(train_texts):,} stories, Val: {len(val_texts):,} stories")
        train_tok = encode_texts(sp, train_texts, "train")
        val_tok = encode_texts(sp, val_texts, "val")
        print("Saving cached tokens...")
        torch.save(train_tok, train_cache)
        torch.save(val_tok, val_cache)
        del train_texts, val_texts
    train_data = ChunkedDataset(train_tok, cfg.max_seq_len)
    val_data = ChunkedDataset(val_tok, cfg.max_seq_len)
    print(f"Train: {len(train_data):,} chunks, Val: {len(val_data):,} chunks")
    train_loader = DataLoader(
        train_data, batch_size=cfg.batch_size, shuffle=True,
        num_workers=cfg.num_workers, pin_memory=True, drop_last=True,
    )
    val_loader = DataLoader(
        val_data, batch_size=cfg.batch_size, shuffle=False,
        num_workers=cfg.num_workers, pin_memory=True, drop_last=True,
    )

    # ---- Model ----
    # `raw_model` is always the plain module. Checkpoints are saved from and
    # loaded into it, so state-dict keys never carry the "_orig_mod." prefix
    # that a torch.compile wrapper adds.
    raw_model = BitLM(cfg).to(cfg.device)
    n_params = raw_model.param_count()
    print(f"\nModel: {n_params:,} parameters ({n_params/1e6:.3f}M)")
    print(f" d_model={cfg.d_model}, n_heads={cfg.n_heads}, n_layers={cfg.n_layers}, "
          f"d_ff={cfg.d_ff}, max_seq_len={cfg.max_seq_len}")
    assert n_params < 1_000_000, f"Model too large: {n_params:,} params >= 1M"
    model = raw_model
    if cfg.compile:
        print("Compiling model with torch.compile...")
        model = torch.compile(raw_model)

    # ---- Optimizer ----
    opt = torch.optim.AdamW(
        _build_param_groups(raw_model, cfg.weight_decay),
        lr=cfg.lr, betas=(0.9, 0.95),
    )
    scaler = torch.amp.GradScaler("cuda")

    # ---- Resume ----
    step = 0
    best_val = float("inf")
    ckpt_path = exp / "latest.pt"
    if ckpt_path.exists():
        print(f"Resuming from {ckpt_path}...")
        ck = torch.load(ckpt_path, map_location=cfg.device, weights_only=True)
        # Older checkpoints may carry compile-prefixed keys; the raw module expects none.
        raw_model.load_state_dict(_strip_compile_prefix(ck["model"]))
        try:
            opt.load_state_dict(ck["optimizer"])
        except ValueError as err:
            # A checkpoint whose parameter groups differ from the current
            # grouping (e.g. one that decayed the per-block norm gains) cannot
            # be mapped 1:1; keep the weights and restart the AdamW moments.
            print(f"WARNING: optimizer state not restored ({err}); using fresh optimizer state")
        scaler.load_state_dict(ck["scaler"])
        step = ck["step"]
        best_val = ck.get("best_val", float("inf"))
        print(f"Resumed at step {step}, best_val={best_val:.4f}")

    # ---- Training loop ----
    print(f"\nTraining for {cfg.max_steps:,} steps "
          f"(batch={cfg.batch_size}, accum={cfg.grad_accum}, "
          f"eff_batch={cfg.batch_size * cfg.grad_accum})\n")
    model.train()
    train_iter = iter(train_loader)
    running_loss = 0.0
    t0 = time.time()
    tokens_since_log = 0
    while step < cfg.max_steps:
        # Get batch (auto-restart on epoch boundary)
        try:
            x, y = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            x, y = next(train_iter)
        x, y = x.to(cfg.device, non_blocking=True), y.to(cfg.device, non_blocking=True)

        # LR schedule
        lr = get_lr(step, cfg)
        for pg in opt.param_groups:
            pg["lr"] = lr

        # Forward + backward (mixed precision FP16)
        with torch.amp.autocast("cuda", dtype=torch.float16):
            _, loss = model(x, y)
            scaled_loss = loss / cfg.grad_accum
        scaler.scale(scaled_loss).backward()
        running_loss += loss.item()
        tokens_since_log += x.numel()

        # Optimizer step every grad_accum mini-batches
        if (step + 1) % cfg.grad_accum == 0:
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)
        step += 1

        # ---- Logging ----
        if step % cfg.log_interval == 0:
            avg = running_loss / cfg.log_interval
            elapsed = time.time() - t0
            tps = tokens_since_log / elapsed
            ppl = math.exp(min(avg, 20))  # cap for display
            print(
                f"step {step:>6d}/{cfg.max_steps} | "
                f"loss {avg:.4f} | ppl {ppl:>8.2f} | "
                f"lr {lr:.2e} | {tps/1e3:.0f}K tok/s"
            )
            running_loss = 0.0
            tokens_since_log = 0
            t0 = time.time()

        # ---- Evaluation ----
        if step % cfg.eval_interval == 0:
            vl = evaluate(model, val_loader, cfg.device, cfg.eval_steps)
            vppl = math.exp(min(vl, 20))
            improved = vl < best_val
            tag = " ** NEW BEST **" if improved else ""
            print(f" >>> val_loss={vl:.4f} val_ppl={vppl:.2f}{tag}")
            if improved:
                best_val = vl
                save_dict = {"model": raw_model.state_dict(), "step": step,
                             "val_loss": vl, "config": asdict(cfg)}
                torch.save(save_dict, exp / "best.pt")
            model.train()

        # ---- Generate samples ----
        if step % cfg.gen_interval == 0:
            raw_model.eval()
            for prompt in ["Once upon a time", "The little dog", "She was very happy"]:
                ids = [sp.bos_id()] + sp.encode(prompt)
                idx = torch.tensor([ids], device=cfg.device)
                out = raw_model.generate(idx, max_new=150, temp=0.8, top_k=40,
                                         eos_id=sp.eos_id())
                text = sp.decode(out[0].tolist())
                print(f" GEN [{prompt[:20]}] → {text[:250]}")
            raw_model.train()

        # ---- Checkpoint ----
        if step % cfg.save_interval == 0:
            torch.save(
                {
                    "model": raw_model.state_dict(),
                    "optimizer": opt.state_dict(),
                    "scaler": scaler.state_dict(),
                    "step": step,
                    "best_val": best_val,
                    "config": asdict(cfg),
                },
                ckpt_path,
            )

    # ---- Final save ----
    torch.save(
        {"model": raw_model.state_dict(), "step": step, "config": asdict(cfg)},
        exp / "final.pt",
    )
    # Cap the exponent like the other ppl displays so a huge loss cannot overflow.
    print(f"\nTraining complete! Best val loss: {best_val:.4f} (ppl {math.exp(min(best_val, 20)):.2f})")
if __name__ == "__main__":
    # Run as a script: parse CLI flags, then either train or generate.
    main()
|