File size: 2,133 Bytes
04e4b39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import json, yaml, time
import torch
from pathlib import Path
from tokenizers import Tokenizer
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from model.tiny_gpt2 import TinyGPT2, GPTConfig

class SFTDataset(Dataset):
    """Supervised fine-tuning dataset over an instruction/output JSONL file.

    Each JSONL record must contain ``instruction`` and ``output`` string
    fields.  Records are rendered into a fixed prompt template, tokenized
    once up front, and served as ``(input, target)`` tensor pairs shifted
    by one position for next-token prediction.

    Fixes over the previous version:
      * the JSONL file handle is closed deterministically (``with`` block)
        instead of being leaked to the garbage collector;
      * every sample is padded to a fixed length of ``block_size - 1`` so
        the default DataLoader collate can stack variable-length samples
        into a batch; target padding uses -100, which
        ``torch.nn.functional.cross_entropy`` ignores by default.
    """

    # Label value skipped by F.cross_entropy's default ignore_index.
    IGNORE_INDEX = -100

    def __init__(self, jsonl_path, tokenizer, block_size):
        self.block = block_size
        self.tok = tokenizer
        # Context manager so the file is closed even if a line fails to parse.
        with open(jsonl_path, "r", encoding="utf-8") as f:
            self.samples = [json.loads(line) for line in f]
        self.ids = []
        for s in self.samples:
            text = f"Instruction:\n{s['instruction'].strip()}\nAnswer:\n{s['output'].strip()}\n"
            self.ids.append(self.tok.encode(text).ids)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        ids = self.ids[i][:self.block]
        x = ids[:-1]
        y = ids[1:]
        # Pad to a fixed length so default_collate can torch.stack a batch of
        # mixed-length samples.  Input positions are padded with 0; the pad id
        # itself never contributes to the loss because the matching target is
        # IGNORE_INDEX (assumes the model is causal, so trailing pads do not
        # affect earlier positions -- TODO confirm for TinyGPT2).
        pad = (self.block - 1) - len(x)
        if pad > 0:
            x = x + [0] * pad
            y = y + [self.IGNORE_INDEX] * pad
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

if __name__ == "__main__":
    # --- configuration & checkpoint loading --------------------------------
    # Context managers close the config files deterministically instead of
    # leaking the handles to the garbage collector.
    with open("train/config.yaml", "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    Path("out/sft").mkdir(parents=True, exist_ok=True)
    tok = Tokenizer.from_file(cfg["tokenizer_path"])

    # Rebuild the pretrained model from its saved config, then load weights.
    with open(Path(cfg["save_dir"]) / "gpt_config.json", "r", encoding="utf-8") as f:
        gcfg = GPTConfig(**json.load(f))
    model = TinyGPT2(gcfg)
    model.load_state_dict(torch.load(Path(cfg["save_dir"]) / "model.pt", map_location="cpu"))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    ds = SFTDataset(cfg["sft_jsonl"], tok, gcfg.block_size)
    dl = DataLoader(ds, batch_size=8, shuffle=True, drop_last=True)
    opt = AdamW(model.parameters(), lr=1e-4)

    # --- short supervised fine-tuning loop ---------------------------------
    model.train()
    t0 = time.time()
    for step, (x, y) in enumerate(dl, start=1):
        x, y = x.to(device), y.to(device)
        logits = model(x)  # assumes TinyGPT2 returns raw logits -- TODO confirm
        # ignore_index=-100 is the F.cross_entropy default, stated explicitly:
        # any padded targets (marked -100) drop out of the loss.
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-100)
        loss.backward()
        opt.step()
        opt.zero_grad(set_to_none=True)
        if step % 50 == 0:
            # Report loss plus wall-clock time per 50-step window.
            dt = time.time() - t0
            t0 = time.time()
            print(f"sft step {step:5d} | loss {loss.item():.4f} | {dt:.2f}s")
        if step >= 800:  # cap the run at 800 optimizer steps
            break

    torch.save(model.state_dict(), "out/sft/model_sft.pt")
    print("SFT saved.")