| import yaml, math, time, json | |
| import torch | |
| from pathlib import Path | |
| from tokenizers import Tokenizer | |
| from torch.utils.data import DataLoader | |
| from torch.optim import AdamW | |
| from model.tiny_gpt2 import TinyGPT2, GPTConfig | |
| from train.data_utils import TextDataset | |
def get_device(name):
    """Resolve a device name: "auto" selects CUDA when available, else CPU.

    Any other value (e.g. "cpu", "cuda:1") is passed through unchanged.
    """
    if name != "auto":
        return name
    return "cuda" if torch.cuda.is_available() else "cpu"
def cosine_lr(step, max_steps, base, min_lr, warmup):
    """Cosine-decay learning-rate schedule with linear warmup.

    Args:
        step: Current optimizer step (0-based counting is fine; step 0 -> lr 0).
        max_steps: Step at which the decay reaches ``min_lr``.
        base: Peak learning rate, reached at the end of warmup.
        min_lr: Floor learning rate after the cosine decay completes.
        warmup: Number of linear-warmup steps (0 is allowed).

    Returns:
        The learning rate for ``step`` as a float.
    """
    if step < warmup:
        # Linear ramp 0 -> base; max(1, warmup) guards against warmup == 0.
        return base * step / max(1, warmup)
    # Clamp progress to 1.0: without the clamp, steps past max_steps push the
    # cosine beyond pi and the LR climbs back toward base (bug in original).
    progress = min(1.0, (step - warmup) / max(1, max_steps - warmup))
    return min_lr + 0.5 * (base - min_lr) * (1 + math.cos(math.pi * progress))
| if __name__ == "__main__": | |
| cfg = yaml.safe_load(open("train/config.yaml")) | |
| device = get_device(cfg["device"]) | |
| Path(cfg["save_dir"]).mkdir(parents=True, exist_ok=True) | |
| tok = Tokenizer.from_file(cfg["tokenizer_path"]) | |
| ids = tok.encode(open(cfg["train_txt"], "r", encoding="utf-8").read()).ids | |
| ds = TextDataset(ids, cfg["block_size"]) | |
| dl = DataLoader(ds, batch_size=cfg["batch_size"], shuffle=True, drop_last=True) | |
| gcfg = GPTConfig( | |
| vocab_size=cfg["vocab_size"], | |
| n_layer=cfg["n_layer"], | |
| n_head=cfg["n_head"], | |
| n_embed=cfg["n_embed"], | |
| block_size=cfg["block_size"], | |
| ) | |
| model = TinyGPT2(gcfg).to(device) | |
| opt = AdamW(model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"]) | |
| step, t0 = 0, time.time() | |
| model.train() | |
| for epoch in range(999999): | |
| for x, y in dl: | |
| step += 1 | |
| x, y = x.to(device), y.to(device) | |
| logits = model(x) | |
| loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1)) | |
| loss.backward() | |
| torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["grad_clip"]) | |
| lr = cosine_lr(step, cfg["max_steps"], cfg["lr"], cfg["min_lr"], cfg["warmup_steps"]) | |
| for g in opt.param_groups: g["lr"] = lr | |
| opt.step(); opt.zero_grad(set_to_none=True) | |
| if step % 100 == 0: | |
| dt = time.time() - t0; t0 = time.time() | |
| print(f"step {step:6d} | loss {loss.item():.4f} | lr {lr:.2e} | {dt:.2f}s") | |
| if step >= cfg["max_steps"]: | |
| torch.save(model.state_dict(), f"{cfg['save_dir']}/model.pt") | |
| with open(f"{cfg['save_dir']}/gpt_config.json", "w") as f: | |
| json.dump(gcfg.__dict__, f, indent=2) | |
| print("saved checkpoint. done.") | |
| raise SystemExit | |