File size: 2,528 Bytes
04e4b39 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import yaml, math, time, json
import torch
from pathlib import Path
from tokenizers import Tokenizer
from torch.utils.data import DataLoader
from torch.optim import AdamW
from model.tiny_gpt2 import TinyGPT2, GPTConfig
from train.data_utils import TextDataset
def get_device(name):
    """Resolve a device spec string.

    "auto" selects "cuda" when a CUDA device is available, otherwise "cpu";
    any other value is passed through unchanged (e.g. "cpu", "cuda:1").
    """
    if name != "auto":
        return name
    return "cuda" if torch.cuda.is_available() else "cpu"
def cosine_lr(step, max_steps, base, min_lr, warmup):
    """Cosine learning-rate schedule with linear warmup.

    Ramps linearly from 0 to `base` over the first `warmup` steps, then
    decays along a half cosine from `base` down to `min_lr` at `max_steps`.

    Args:
        step: Current optimization step (0-based).
        max_steps: Step at which the decay reaches `min_lr`.
        base: Peak learning rate (value at the end of warmup).
        min_lr: Floor learning rate at and beyond `max_steps`.
        warmup: Number of linear warmup steps.

    Returns:
        The learning rate for `step` as a float.
    """
    if step < warmup:
        # max(1, warmup) guards against division by zero when warmup == 0.
        return base * step / max(1, warmup)
    progress = (step - warmup) / max(1, max_steps - warmup)
    # Clamp so a call past max_steps holds at min_lr instead of letting the
    # cosine wrap around and raise the LR again.
    progress = min(progress, 1.0)
    return min_lr + 0.5 * (base - min_lr) * (1.0 + math.cos(math.pi * progress))
if __name__ == "__main__":
    # --- Configuration and data loading -------------------------------------
    # Context managers instead of bare open(): the original leaked both the
    # config and the corpus file handles.
    with open("train/config.yaml") as f:
        cfg = yaml.safe_load(f)
    device = get_device(cfg["device"])
    save_dir = Path(cfg["save_dir"])
    save_dir.mkdir(parents=True, exist_ok=True)

    # Tokenize the entire training corpus once up front into a flat id list.
    tok = Tokenizer.from_file(cfg["tokenizer_path"])
    text = Path(cfg["train_txt"]).read_text(encoding="utf-8")
    ids = tok.encode(text).ids
    ds = TextDataset(ids, cfg["block_size"])
    dl = DataLoader(ds, batch_size=cfg["batch_size"], shuffle=True, drop_last=True)

    # --- Model and optimizer ------------------------------------------------
    gcfg = GPTConfig(
        vocab_size=cfg["vocab_size"],
        n_layer=cfg["n_layer"],
        n_head=cfg["n_head"],
        n_embed=cfg["n_embed"],
        block_size=cfg["block_size"],
    )
    model = TinyGPT2(gcfg).to(device)
    opt = AdamW(model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"])

    # --- Training loop: epochs until max_steps, then save and exit ----------
    step, t0 = 0, time.time()
    model.train()
    while True:  # re-iterate the DataLoader until the step budget is spent
        for x, y in dl:
            step += 1
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), y.view(-1)
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["grad_clip"])
            # Apply the scheduled LR for this step before the optimizer update.
            lr = cosine_lr(
                step, cfg["max_steps"], cfg["lr"], cfg["min_lr"], cfg["warmup_steps"]
            )
            for g in opt.param_groups:
                g["lr"] = lr
            opt.step()
            opt.zero_grad(set_to_none=True)

            if step % 100 == 0:
                # dt is wall time for the last 100 steps.
                dt = time.time() - t0
                t0 = time.time()
                print(f"step {step:6d} | loss {loss.item():.4f} | lr {lr:.2e} | {dt:.2f}s")

            if step >= cfg["max_steps"]:
                # Persist weights and the architecture config side by side so
                # the checkpoint can be reloaded without the YAML config.
                torch.save(model.state_dict(), save_dir / "model.pt")
                with open(save_dir / "gpt_config.json", "w") as f:
                    json.dump(gcfg.__dict__, f, indent=2)
                print("saved checkpoint. done.")
                raise SystemExit
|