# RecursiveComplete / train_big.py
# Source: Hugging Face repository file view — uploaded by Gentraxyz via
# huggingface_hub ("Upload folder using huggingface_hub"), commit 3c38b94
# (verified), file size 3.27 kB.
import os, time, json, math, numpy as np, torch
from gpt2 import GPT2
torch.manual_seed(1337)
torch.set_num_threads(2)
# ---------- config: "big but safe" ~18M-param GPT-2 (BPE) ----------
# (the exact parameter count is printed once the model is built)
cfg = dict(vocab_size=8192, n_embd=448, n_head=7, n_layer=6,
           block_size=256, dropout=0.1)  # head dim = 448 / 7 = 64
batch_size = 10  # sequences per micro-batch
grad_accum = 2  # micro-batches per optimizer step -> effective batch 10 * 2 = 20 sequences
lr = 6e-4  # peak learning rate, reached at the end of warmup
min_lr = 6e-5  # floor of the cosine decay (also used past max_iters)
warmup = 300  # iterations of linear warmup from 0 to lr
max_iters = 12000  # long run; checkpoints every eval so we can stop anytime
eval_iter = 500  # evaluate + checkpoint every this many iterations
eval_batches = 40  # batches averaged per split when estimating loss
ckpt_path = 'big.pt'
# ---------- data ----------
with open('data/meta.json') as fh:
    meta = json.load(fh)
eot_id = meta['eot']
data = np.memmap('data/train.bin', dtype=np.uint16, mode='r')
# Hold out a fixed interior slice (tokens at 90%..92% of the corpus) for validation.
# Per the author, this lies in the TinyStories region rather than the Alpaca tail.
val_lo = int(0.90 * len(data)); val_hi = int(0.92 * len(data))
train_data = data  # whole corpus; get_batch('train') never draws windows touching [val_lo, val_hi)
val_data = data[val_lo:val_hi]
print(f"corpus tokens: {len(data):,} | vocab {cfg['vocab_size']} | model ready")


def get_batch(split):
    """Sample a random batch of (input, target) token windows.

    Args:
        split: 'train' samples from the corpus outside the validation slice
            [val_lo, val_hi); any other value samples from ``val_data``.

    Returns:
        (x, y): int64 tensors of shape (batch_size, block_size), where y is
        x shifted one token to the right.
    """
    blk = cfg['block_size']
    if split == 'train':
        # A window starting at i reads tokens i .. i+blk (inputs plus the
        # one-token-shifted targets). Valid starts are [0, val_lo - blk), whose
        # windows end before the val slice, and [val_hi, len - blk). Draw
        # uniformly from a compacted range covering both segments, then shift
        # draws that fall in the upper segment past the val slice.
        n_before = max(0, val_lo - blk)
        n_after = max(0, len(train_data) - blk - val_hi)
        ix = torch.randint(n_before + n_after, (batch_size,))
        ix = torch.where(ix < n_before, ix, ix - n_before + val_hi)
        d = train_data
    else:
        ix = torch.randint(len(val_data) - blk, (batch_size,))
        d = val_data
    x = torch.stack([torch.from_numpy(d[i:i+blk].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(d[i+1:i+1+blk].astype(np.int64)) for i in ix])
    return x, y
@torch.no_grad()
def est():
    """Estimate mean loss on the train and val splits.

    Switches the global ``model`` to eval mode and averages ``eval_batches``
    freshly sampled batches per split. It then puts the model back in train
    mode.

    Returns:
        dict mapping 'train' and 'val' to the mean loss as a Python float.
    """
    model.eval()
    results = {}
    for split in ('train', 'val'):
        losses = torch.zeros(eval_batches)
        for idx in range(eval_batches):
            xb, yb = get_batch(split)
            _, batch_loss = model(xb, yb)
            losses[idx] = batch_loss.item()
        results[split] = losses.mean().item()
    model.train()
    return results
def get_lr(it):
    """Return the learning rate for iteration ``it``.

    The rate rises linearly from 0 to ``lr`` over the first ``warmup``
    iterations. It then follows a half-cosine from ``lr`` down to ``min_lr``
    at ``max_iters``, and stays at ``min_lr`` afterwards.
    """
    if it < warmup:
        return lr * it / warmup
    if it > max_iters:
        return min_lr
    progress = (it - warmup) / (max_iters - warmup)  # 0 at end of warmup, 1 at max_iters
    cosine = 0.5 * (1 + math.cos(math.pi * progress))  # decays 1 -> 0
    return min_lr + cosine * (lr - min_lr)
# ---------- model + optimizer ----------
model = GPT2(cfg)
print(f"params: {sum(p.numel() for p in model.parameters())/1e6:.1f}M")
opt = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1)

# ---------- resume from the last checkpoint, if one exists ----------
start_iter = 0; hist = []
resumed = False
if os.path.exists(ckpt_path):
    ck = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(ck['model']); opt.load_state_dict(ck['opt'])
    start_iter = ck['iter']; hist = ck.get('hist', [])
    # Restore the global torch RNG (used for batch sampling and dropout).
    # Otherwise every launch re-seeds to the same value, and a resumed run
    # would replay the batch sequence it already trained on. Older checkpoints
    # without the key just continue from the fresh seed, as before.
    if 'rng' in ck:
        torch.set_rng_state(ck['rng'])
    del ck  # tensors were copied into model/opt; don't keep a second copy alive
    resumed = True
    print(f"RESUMED from iter {start_iter}")

# ---------- training loop ----------
t0 = time.time()  # hist 't' = wall-clock seconds since this launch (resets on resume)
for it in range(start_iter, max_iters + 1):
    cur_lr = get_lr(it)
    for g in opt.param_groups:
        g['lr'] = cur_lr
    is_last = it == max_iters
    # The previous launch already evaluated and checkpointed a resumed run's
    # start_iter. Evaluating it again would append a duplicate hist entry.
    already_saved = resumed and it == start_iter
    if (it % eval_iter == 0 or is_last) and not already_saved:
        losses = est(); elapsed = time.time() - t0
        print(f"iter {it:5d} | train {losses['train']:.3f} | val {losses['val']:.3f} | lr {cur_lr:.1e} | {elapsed/60:.1f}min", flush=True)
        hist.append({"iter": it, **losses, "t": elapsed})
        # Write to a temp file, then atomically swap it in. Stopping the run
        # mid-save then cannot leave a truncated, unloadable checkpoint.
        tmp_path = ckpt_path + '.tmp'
        torch.save({'model': model.state_dict(), 'opt': opt.state_dict(),
                    'cfg': cfg, 'iter': it, 'hist': hist, 'eot': eot_id,
                    'rng': torch.get_rng_state()}, tmp_path)
        os.replace(tmp_path, ckpt_path)
    if is_last:
        # The state at max_iters is checkpointed above; a further step would
        # never be saved.
        break
    # One optimizer step, accumulating gradients over grad_accum micro-batches.
    opt.zero_grad(set_to_none=True)
    for _ in range(grad_accum):
        x, y = get_batch('train'); _, loss = model(x, y)
        (loss / grad_accum).backward()  # scale so the summed grads equal the mean
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
print(f"DONE {(time.time()-t0)/60:.1f}min")