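"""Train TinyGPT2 on a plain-text corpus.

Hyperparameters are read from train/config.yaml. Training uses AdamW with
gradient clipping and a linear-warmup + cosine-decay learning-rate schedule;
a final checkpoint is written to the configured save_dir.
"""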
import itertools
import json
import math
import time
from pathlib import Path

import torch
import torch.nn.functional as F
import yaml
from tokenizers import Tokenizer
from torch.optim import AdamW
from torch.utils.data import DataLoader

from model.tiny_gpt2 import TinyGPT2, GPTConfig
from train.data_utils import TextDataset

def get_device(name):
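    """Resolve the device string; "auto" picks CUDA when available, else CPU."""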
    if name == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return name

def cosine_lr(step, max_steps, base, min_lr, warmup):
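    """Linear warmup from 0 to `base` over `warmup` steps, then cosine decay to `min_lr`."""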
    if step < warmup:
        return base * step / max(1, warmup)
    progress = (step - warmup)/max(1, max_steps - warmup)
    return min_lr + 0.5*(base-min_lr)*(1+math.cos(math.pi*progress))

if __name__ == "__main__":
    with open("train/config.yaml", "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    device = get_device(cfg["device"])
    Path(cfg["save_dir"]).mkdir(parents=True, exist_ok=True)

    # Tokenize the whole training corpus once; TextDataset slices it into blocks.
    tok = Tokenizer.from_file(cfg["tokenizer_path"])
    ids = tok.encode(Path(cfg["train_txt"]).read_text(encoding="utf-8")).ids
    ds = TextDataset(ids, cfg["block_size"])
    dl = DataLoader(ds, batch_size=cfg["batch_size"], shuffle=True, drop_last=True)

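    # Model hyperparameters come straight from the YAML config.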
    gcfg = GPTConfig(
        vocab_size=cfg["vocab_size"],
        n_layer=cfg["n_layer"],
        n_head=cfg["n_head"],
        n_embed=cfg["n_embed"],
        block_size=cfg["block_size"],
    )
    model = TinyGPT2(gcfg).to(device)

    opt = AdamW(model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"])
    step, t0 = 0, time.time()
    model.train()
    for _ in itertools.count():  # cycle through the data until max_steps is reached
        for x, y in dl:
            step += 1
            x, y = x.to(device), y.to(device)
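            # Forward pass; next-token cross-entropy over the flattened (batch*time) logits.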
            logits = model(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["grad_clip"])
            # Set the scheduled learning rate before stepping so this update uses it.
            lr = cosine_lr(step, cfg["max_steps"], cfg["lr"], cfg["min_lr"], cfg["warmup_steps"])
            for g in opt.param_groups:
                g["lr"] = lr
            opt.step()
            opt.zero_grad(set_to_none=True)

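            # Lightweight console log every 100 steps: loss, current lr, wall time.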
            if step % 100 == 0:
                dt = time.time() - t0
                t0 = time.time()
                print(f"step {step:6d} | loss {loss.item():.4f} | lr {lr:.2e} | {dt:.2f}s")

            if step >= cfg["max_steps"]:
                # Save the weights plus the architecture config needed to rebuild the model.
                torch.save(model.state_dict(), f"{cfg['save_dir']}/model.pt")
                with open(f"{cfg['save_dir']}/gpt_config.json", "w") as f:
                    json.dump(gcfg.__dict__, f, indent=2)
                print("saved checkpoint. done.")
                raise SystemExit  # exits both loops cleanly
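
# Example train/config.yaml covering every key this script reads
# (illustrative values only, not taken from the repo):
#
#   device: auto
#   save_dir: checkpoints
#   tokenizer_path: tokenizer/tokenizer.json
#   train_txt: data/train.txt
#   vocab_size: 8192
#   block_size: 128
#   n_layer: 4
#   n_head: 4
#   n_embed: 256
#   batch_size: 32
#   lr: 3.0e-4
#   min_lr: 3.0e-5
#   weight_decay: 0.1
#   grad_clip: 1.0
#   warmup_steps: 100
#   max_steps: 5000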