# Uploaded by adityashisharma ("Upload 6 files", commit 04e4b39).
import yaml, math, time, json
import torch
from pathlib import Path
from tokenizers import Tokenizer
from torch.utils.data import DataLoader
from torch.optim import AdamW
from model.tiny_gpt2 import TinyGPT2, GPTConfig
from train.data_utils import TextDataset
def get_device(name):
    """Resolve a device spec string.

    ``"auto"`` becomes ``"cuda"`` when a CUDA device is available and
    ``"cpu"`` otherwise; any other value is passed through unchanged.
    """
    if name != "auto":
        return name
    has_cuda = torch.cuda.is_available()
    return "cuda" if has_cuda else "cpu"
def cosine_lr(step, max_steps, base, min_lr, warmup):
    """Return the learning rate for *step* under warmup + cosine decay.

    For ``step < warmup`` the rate rises linearly from 0 to *base*. After
    warmup it follows a half cosine from *base* down to *min_lr*, reaching
    *min_lr* at ``step == max_steps`` and staying there for later steps.

    Args:
        step: current optimizer step.
        max_steps: step at which the schedule reaches *min_lr*.
        base: peak learning rate reached at the end of warmup.
        min_lr: floor learning rate at the end of the schedule.
        warmup: number of linear-warmup steps (0 disables warmup).
    """
    if step < warmup:
        return base * step / max(1, warmup)
    progress = (step - warmup) / max(1, max_steps - warmup)
    # Clamp so steps past max_steps stay at min_lr; unclamped, cos() would
    # start rising again and the LR would climb back toward `base`.
    progress = min(progress, 1.0)
    return min_lr + 0.5 * (base - min_lr) * (1 + math.cos(math.pi * progress))
def _save_checkpoint(model, gcfg, save_dir):
    """Save *model*'s state_dict and *gcfg* into *save_dir*.

    Writes ``model.pt`` via ``torch.save`` and ``gpt_config.json`` from
    ``gcfg.__dict__`` so the architecture can be rebuilt when loading.
    """
    save_dir = Path(save_dir)
    torch.save(model.state_dict(), save_dir / "model.pt")
    with open(save_dir / "gpt_config.json", "w", encoding="utf-8") as f:
        json.dump(gcfg.__dict__, f, indent=2)


def main(config_path="train/config.yaml"):
    """Train TinyGPT2 on a text file as described by a YAML config.

    Reads ``device``, ``save_dir``, ``tokenizer_path``, ``train_txt``,
    ``block_size``, ``batch_size``, ``vocab_size``, ``n_layer``, ``n_head``,
    ``n_embed``, ``lr``, ``min_lr``, ``weight_decay``, ``warmup_steps``,
    ``grad_clip`` and ``max_steps`` from *config_path*. Trains for
    ``max_steps`` optimizer steps, cycling over the data for as many epochs
    as needed, then writes ``model.pt`` and ``gpt_config.json`` to
    ``save_dir``.

    Raises:
        ValueError: if the corpus is too small to form a single batch, or a
            token id does not fit in ``vocab_size``.
    """
    with open(config_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    device = get_device(cfg["device"])
    Path(cfg["save_dir"]).mkdir(parents=True, exist_ok=True)

    tok = Tokenizer.from_file(cfg["tokenizer_path"])
    with open(cfg["train_txt"], "r", encoding="utf-8") as f:
        ids = tok.encode(f.read()).ids
    # Ids >= vocab_size would make cross_entropy (and presumably the model's
    # embedding lookup) fail with an opaque index / device-side assert error.
    if ids and max(ids) >= cfg["vocab_size"]:
        raise ValueError(
            f"tokenizer produced id {max(ids)} but vocab_size is {cfg['vocab_size']}"
        )

    ds = TextDataset(ids, cfg["block_size"])
    dl = DataLoader(ds, batch_size=cfg["batch_size"], shuffle=True, drop_last=True)
    # With drop_last=True a corpus smaller than one batch yields no batches,
    # so the step counter below would never advance.
    if len(dl) == 0:
        raise ValueError(
            f"training data yields no full batch (dataset size {len(ds)}, "
            f"batch_size {cfg['batch_size']})"
        )

    gcfg = GPTConfig(
        vocab_size=cfg["vocab_size"],
        n_layer=cfg["n_layer"],
        n_head=cfg["n_head"],
        n_embed=cfg["n_embed"],
        block_size=cfg["block_size"],
    )
    model = TinyGPT2(gcfg).to(device)
    opt = AdamW(model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"])

    step, t0 = 0, time.perf_counter()
    model.train()
    while True:  # one pass over `dl` per epoch; returns once max_steps is hit
        for x, y in dl:
            step += 1
            x, y = x.to(device), y.to(device)
            logits = model(x)
            # Flatten (batch, seq) into one axis for token-level cross-entropy.
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), y.view(-1)
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["grad_clip"])
            # The schedule is applied manually to every param group each step.
            lr = cosine_lr(
                step, cfg["max_steps"], cfg["lr"], cfg["min_lr"], cfg["warmup_steps"]
            )
            for g in opt.param_groups:
                g["lr"] = lr
            opt.step()
            opt.zero_grad(set_to_none=True)
            if step % 100 == 0:
                now = time.perf_counter()
                dt, t0 = now - t0, now  # wall time for the last 100 steps
                print(f"step {step:6d} | loss {loss.item():.4f} | lr {lr:.2e} | {dt:.2f}s")
            if step >= cfg["max_steps"]:
                _save_checkpoint(model, gcfg, cfg["save_dir"])
                print("saved checkpoint. done.")
                return


if __name__ == "__main__":
    main()