"""Train a TinyGPT2 model using the settings in ``train/config.yaml``.

Run as a script. The configuration file supplies the device, paths,
model dimensions, optimizer settings and learning-rate schedule. When
``max_steps`` optimizer steps have been taken, the model weights and the
model configuration are written to ``save_dir`` and the script exits.
"""

import json
import math
import time
from pathlib import Path

import torch
import yaml
from tokenizers import Tokenizer
from torch.optim import AdamW
from torch.utils.data import DataLoader

from model.tiny_gpt2 import TinyGPT2, GPTConfig
from train.data_utils import TextDataset

CONFIG_PATH = "train/config.yaml"
LOG_EVERY = 100  # print a progress line every this many steps


def get_device(name: str) -> str:
    """Resolve a device name from the config.

    Args:
        name: A device string, or ``"auto"`` to choose automatically.

    Returns:
        ``"cuda"`` or ``"cpu"`` (depending on CUDA availability) when
        *name* is ``"auto"``; otherwise *name* unchanged.
    """
    if name == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return name


def cosine_lr(
    step: int, max_steps: int, base: float, min_lr: float, warmup: int
) -> float:
    """Return the learning rate for *step*: linear warmup, then cosine decay.

    Args:
        step: Current optimizer step.
        max_steps: Step at which the decay reaches ``min_lr``.
        base: Peak learning rate, reached at the end of warmup.
        min_lr: Floor learning rate at the end of the decay.
        warmup: Number of warmup steps; the rate rises linearly from 0.

    Returns:
        The learning rate to use at *step*.
    """
    if step < warmup:
        return base * step / max(1, warmup)
    # Clamp to 1.0 so the rate stays at min_lr beyond max_steps (or when
    # max_steps <= warmup) instead of following the cosine back upward.
    progress = min(1.0, (step - warmup) / max(1, max_steps - warmup))
    return min_lr + 0.5 * (base - min_lr) * (1 + math.cos(math.pi * progress))


def _save_checkpoint(model: torch.nn.Module, gcfg: GPTConfig, save_dir: Path) -> None:
    """Write the model weights and the model config into *save_dir*."""
    torch.save(model.state_dict(), save_dir / "model.pt")
    with open(save_dir / "gpt_config.json", "w", encoding="utf-8") as f:
        json.dump(gcfg.__dict__, f, indent=2)


def main() -> None:
    """Load the config, train for ``max_steps`` steps, then save and return.

    Raises:
        ValueError: If the data loader yields no batches, since training
            could then never advance.
    """
    with open(CONFIG_PATH, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    device = get_device(cfg["device"])
    save_dir = Path(cfg["save_dir"])
    save_dir.mkdir(parents=True, exist_ok=True)

    # Tokenize the whole training text up front.
    tok = Tokenizer.from_file(cfg["tokenizer_path"])
    text = Path(cfg["train_txt"]).read_text(encoding="utf-8")
    ids = tok.encode(text).ids

    ds = TextDataset(ids, cfg["block_size"])
    dl = DataLoader(ds, batch_size=cfg["batch_size"], shuffle=True, drop_last=True)
    if len(dl) == 0:
        # With drop_last=True a dataset smaller than one batch yields
        # nothing; fail loudly rather than looping over empty epochs.
        raise ValueError(
            "training data yields no full batch; "
            "reduce batch_size/block_size or provide more text"
        )

    gcfg = GPTConfig(
        vocab_size=cfg["vocab_size"],
        n_layer=cfg["n_layer"],
        n_head=cfg["n_head"],
        n_embed=cfg["n_embed"],
        block_size=cfg["block_size"],
    )
    model = TinyGPT2(gcfg).to(device)
    opt = AdamW(model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"])

    step, t0 = 0, time.time()
    model.train()
    # Keep cycling through the data until max_steps is reached.
    while True:
        for x, y in dl:
            step += 1
            x, y = x.to(device), y.to(device)

            # The model output is treated as logits with classes on the
            # last dimension; flatten everything else for cross-entropy.
            logits = model(x)
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), y.view(-1)
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["grad_clip"])

            # Apply the scheduled learning rate before stepping.
            lr = cosine_lr(
                step, cfg["max_steps"], cfg["lr"], cfg["min_lr"], cfg["warmup_steps"]
            )
            for g in opt.param_groups:
                g["lr"] = lr
            opt.step()
            opt.zero_grad(set_to_none=True)

            if step % LOG_EVERY == 0:
                now = time.time()
                dt, t0 = now - t0, now
                print(f"step {step:6d} | loss {loss.item():.4f} | lr {lr:.2e} | {dt:.2f}s")

            if step >= cfg["max_steps"]:
                _save_checkpoint(model, gcfg, save_dir)
                print("saved checkpoint. done.")
                return


if __name__ == "__main__":
    main()