# Source: uploaded by adityashisharma ("Upload 6 files", commit 04e4b39, 2.53 kB)
import yaml, math, time, json
import torch
from pathlib import Path
from tokenizers import Tokenizer
from torch.utils.data import DataLoader
from torch.optim import AdamW
from model.tiny_gpt2 import TinyGPT2, GPTConfig
from train.data_utils import TextDataset
def get_device(name):
    """Resolve a device string; "auto" picks CUDA when available, else CPU."""
    if name != "auto":
        return name
    return "cuda" if torch.cuda.is_available() else "cpu"
def cosine_lr(step, max_steps, base, min_lr, warmup):
    """Learning rate at `step`: linear warmup to `base`, then cosine decay to `min_lr`.

    Args:
        step: current optimizer step (1-indexed by the caller).
        max_steps: step at which the schedule bottoms out at `min_lr`.
        base: peak learning rate reached at the end of warmup.
        min_lr: floor learning rate at `max_steps` and beyond.
        warmup: number of linear warmup steps.

    Returns:
        The scheduled learning rate as a float.
    """
    if step < warmup:
        # Linear ramp 0 -> base; max(1, warmup) guards warmup == 0.
        return base * step / max(1, warmup)
    # Fraction of the decay phase completed; max(1, ...) guards max_steps == warmup.
    progress = (step - warmup) / max(1, max_steps - warmup)
    # Clamp so steps past max_steps stay at min_lr instead of climbing
    # back up the cosine (cos is periodic for progress > 1).
    progress = min(progress, 1.0)
    return min_lr + 0.5 * (base - min_lr) * (1 + math.cos(math.pi * progress))
if __name__ == "__main__":
    # Load hyperparameters, paths, and device choice from the YAML config.
    with open("train/config.yaml") as f:
        cfg = yaml.safe_load(f)
    device = get_device(cfg["device"])
    Path(cfg["save_dir"]).mkdir(parents=True, exist_ok=True)

    # Tokenize the entire training corpus into one flat list of token ids.
    tok = Tokenizer.from_file(cfg["tokenizer_path"])
    with open(cfg["train_txt"], "r", encoding="utf-8") as f:
        ids = tok.encode(f.read()).ids
    ds = TextDataset(ids, cfg["block_size"])
    dl = DataLoader(ds, batch_size=cfg["batch_size"], shuffle=True, drop_last=True)

    # Build the model and optimizer from config.
    gcfg = GPTConfig(
        vocab_size=cfg["vocab_size"],
        n_layer=cfg["n_layer"],
        n_head=cfg["n_head"],
        n_embed=cfg["n_embed"],
        block_size=cfg["block_size"],
    )
    model = TinyGPT2(gcfg).to(device)
    opt = AdamW(model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"])

    step, t0 = 0, time.time()
    model.train()
    # Epoch loop is effectively unbounded; training length is governed by max_steps.
    for epoch in range(999999):
        for x, y in dl:
            step += 1
            x, y = x.to(device), y.to(device)
            logits = model(x)
            # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for token-level CE.
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), y.view(-1)
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["grad_clip"])
            # Recompute and apply the scheduled LR every step (warmup + cosine decay).
            lr = cosine_lr(step, cfg["max_steps"], cfg["lr"], cfg["min_lr"], cfg["warmup_steps"])
            for g in opt.param_groups:
                g["lr"] = lr
            opt.step()
            opt.zero_grad(set_to_none=True)
            if step % 100 == 0:
                dt = time.time() - t0
                t0 = time.time()
                print(f"step {step:6d} | loss {loss.item():.4f} | lr {lr:.2e} | {dt:.2f}s")
            if step >= cfg["max_steps"]:
                # Save final weights and the architecture config, then exit both loops.
                torch.save(model.state_dict(), f"{cfg['save_dir']}/model.pt")
                with open(f"{cfg['save_dir']}/gpt_config.json", "w") as f:
                    json.dump(gcfg.__dict__, f, indent=2)
                print("saved checkpoint. done.")
                raise SystemExit