| import os, time, json, math, numpy as np, torch |
| from gpt2 import GPT2 |
|
|
# Runtime setup: cap torch's intra-op CPU threads and seed its global RNG so that
# weight init and batch sampling are reproducible run to run.
torch.set_num_threads(2)
torch.manual_seed(1337)

# Model architecture; passed unchanged to GPT2(cfg) and stored in checkpoints.
cfg = {
    'vocab_size': 8192,
    'n_embd': 448,
    'n_head': 7,
    'n_layer': 6,
    'block_size': 256,   # tokens per training window
    'dropout': 0.1,
}

# Optimisation.
batch_size = 10          # sequences per micro-batch
grad_accum = 2           # micro-batches accumulated per optimizer step
lr = 6e-4                # peak learning rate, reached at the end of warmup
min_lr = 6e-5            # floor of the cosine decay
warmup = 300             # iterations of linear warmup
max_iters = 12000        # cosine decay ends here; the loop runs through this index

# Evaluation and checkpointing.
eval_iter = 500          # evaluate + checkpoint every this many iterations
eval_batches = 40        # random batches averaged per split in each evaluation
ckpt_path = 'big.pt'
|
|
# Tokenizer metadata; the 'eot' token id is the only field read here (it is
# written into every checkpoint).
with open('data/meta.json') as meta_file:
    meta = json.load(meta_file)
eot_id = meta['eot']
# Token stream stored as uint16, memory-mapped read-only rather than read eagerly.
data = np.memmap('data/train.bin', dtype=np.uint16, mode='r')

# Tokens in [90%, 92%) of the corpus are held out for validation.
val_lo = int(0.90 * len(data)); val_hi = int(0.92 * len(data))
# Train on everything EXCEPT the validation slice. Using the whole corpus here
# let training windows overlap the validation tokens, so val loss was not a
# held-out estimate. np.concatenate copies both pieces into RAM (2 bytes per
# token); only windows crossing the single seam data[val_lo-1] | data[val_hi]
# join non-adjacent text.
train_data = np.concatenate((data[:val_lo], data[val_hi:]))
val_data = data[val_lo:val_hi]
print(f"corpus tokens: {len(data):,} | vocab {cfg['vocab_size']} | model ready")
|
|
def get_batch(split):
    """Sample a random batch of next-token-prediction windows.

    ``split == 'train'`` samples from ``train_data``; any other value samples
    from ``val_data``. Returns ``(x, y)``, two int64 tensors of shape
    ``(batch_size, cfg['block_size'])``, where ``y`` is ``x`` shifted one token
    ahead in the source stream.
    """
    span = cfg['block_size']
    source = train_data if split == 'train' else val_data
    # Highest start is len - span - 1, so the shifted target window stays in range.
    starts = torch.randint(len(source) - span, (batch_size,))
    inputs, targets = [], []
    for s in starts:
        inputs.append(torch.from_numpy(source[s:s + span].astype(np.int64)))
        targets.append(torch.from_numpy(source[s + 1:s + 1 + span].astype(np.int64)))
    return torch.stack(inputs), torch.stack(targets)
|
|
@torch.no_grad()
def est():
    """Estimate the mean loss on the train and val splits.

    For each split, averages the loss returned by ``model(x, y)`` over
    ``eval_batches`` random batches from ``get_batch``, with gradients
    disabled and the model in eval mode. The model is switched back to train
    mode before returning. Returns ``{'train': float, 'val': float}``.
    """
    model.eval()
    result = {}
    for split in ('train', 'val'):
        losses = torch.zeros(eval_batches)
        for k in range(eval_batches):
            xb, yb = get_batch(split)
            _, batch_loss = model(xb, yb)
            losses[k] = batch_loss.item()
        result[split] = losses.mean().item()
    model.train()
    return result
|
|
def get_lr(it):
    """Learning rate for iteration ``it``: linear warmup, then cosine decay.

    Rises linearly from 0 (at ``it == 0``) toward ``lr`` over the first
    ``warmup`` iterations, then follows a half cosine from ``lr`` down to
    ``min_lr`` at ``max_iters``; past ``max_iters`` it stays at ``min_lr``.
    """
    if it < warmup:
        return lr * it / warmup
    if it > max_iters:
        return min_lr
    progress = (it - warmup) / (max_iters - warmup)   # 0.0 -> 1.0 over the decay
    cosine = 0.5 * (1 + math.cos(math.pi * progress))  # 1.0 -> 0.0
    return min_lr + cosine * (lr - min_lr)
|
|
# Build the model from cfg and report its size in millions of parameters.
model = GPT2(cfg)
print(f"params: {sum(p.numel() for p in model.parameters())/1e6:.1f}M")
# AdamW over every parameter with a single shared weight decay. The training
# loop overwrites each param group's lr every iteration, so `lr` here only sets
# the initial value.
opt = torch.optim.AdamW(
    params=model.parameters(),
    lr=lr,
    betas=(0.9, 0.95),
    weight_decay=0.1,
)
|
|
# Resume from ckpt_path if it exists; otherwise start fresh at iteration 0 with
# an empty eval history. Note: torch.load unpickles the file, so only load
# checkpoints you trust.
start_iter, hist = 0, []
if os.path.exists(ckpt_path):
    ck = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(ck['model'])
    opt.load_state_dict(ck['opt'])
    start_iter = ck['iter']
    # Checkpoints without a 'hist' entry start a new history.
    hist = ck['hist'] if 'hist' in ck else []
    print(f"RESUMED from iter {start_iter}")
|
|
# Main training loop. At the top of iteration `it`, exactly `it` optimizer steps
# have been taken. Evaluation/checkpointing happens before that iteration's
# step, so a checkpoint saved at `it` resumes correctly with start_iter = it.
t0 = time.time(); best_val = 1e9  # best_val: lowest val loss seen in this process
for it in range(start_iter, max_iters + 1):
    cur_lr = get_lr(it)
    for g in opt.param_groups:
        g['lr'] = cur_lr
    is_last = it == max_iters
    # On resume, the checkpoint's own evaluation is already the last `hist`
    # entry; repeating it would waste time and append a duplicate record.
    already_logged = bool(hist) and hist[-1].get('iter') == it
    # Always evaluate/checkpoint the final iteration, even if max_iters is not
    # a multiple of eval_iter.
    if (it % eval_iter == 0 or is_last) and not already_logged:
        # NOTE: el / "t" are measured from this process's t0, so they restart at
        # 0 after a resume.
        l = est(); el = time.time() - t0
        best_val = min(best_val, l['val'])
        print(f"iter {it:5d} | train {l['train']:.3f} | val {l['val']:.3f} | lr {cur_lr:.1e} | {el/60:.1f}min", flush=True)
        hist.append({"iter": it, **l, "t": el})
        torch.save({'model': model.state_dict(), 'opt': opt.state_dict(),
                    'cfg': cfg, 'iter': it, 'hist': hist, 'eot': eot_id}, ckpt_path)
    if is_last:
        # The final state has just been checkpointed; another optimizer step
        # here would be neither evaluated nor saved.
        break
    opt.zero_grad(set_to_none=True)
    # Gradient accumulation: each micro-batch loss is scaled by 1/grad_accum so
    # the summed gradients equal the gradient of the mean loss.
    for _ in range(grad_accum):
        x, y = get_batch('train'); _, loss = model(x, y)
        (loss / grad_accum).backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()


print(f"DONE {(time.time()-t0)/60:.1f}min")
|
|