# Source: Hugging Face upload by adityashisharma — "Upload 6 files", commit 04e4b39 (verified)
import json, yaml, time
import torch
from pathlib import Path
from tokenizers import Tokenizer
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from model.tiny_gpt2 import TinyGPT2, GPTConfig
class SFTDataset(Dataset):
    """Instruction-tuning dataset built from a JSONL file.

    Each JSONL line must be an object with string fields ``instruction``
    and ``output``.  Every sample is rendered into the fixed prompt
    template ``"Instruction:\\n...\\nAnswer:\\n...\\n"`` and tokenized once,
    up front, at construction time.

    Args:
        jsonl_path: path to the JSONL samples file.
        tokenizer: object with ``encode(text).ids`` (e.g. a HF ``Tokenizer``).
        block_size: maximum sequence length; longer samples are truncated.
    """

    def __init__(self, jsonl_path, tokenizer, block_size):
        self.block = block_size
        self.tok = tokenizer
        # Use a context manager so the file handle is closed deterministically
        # (the original left it open); skip blank lines, which would otherwise
        # crash json.loads on e.g. a trailing newline.
        with open(jsonl_path, "r", encoding="utf-8") as f:
            self.samples = [json.loads(line) for line in f if line.strip()]
        self.ids = []
        for s in self.samples:
            text = f"Instruction:\n{s['instruction'].strip()}\nAnswer:\n{s['output'].strip()}\n"
            self.ids.append(self.tok.encode(text).ids)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        """Return the (input, target) pair for sample *i*.

        Truncates to ``block_size`` tokens, then shifts by one position so
        that ``x[t]`` is trained to predict ``y[t] == x[t+1]``.  Note the
        returned length varies per sample (``len(ids) - 1``).
        """
        ids = self.ids[i][:self.block]
        x = ids[:-1]
        y = ids[1:]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
if __name__ == "__main__":
cfg = yaml.safe_load(open("train/config.yaml"))
Path("out/sft").mkdir(parents=True, exist_ok=True)
tok = Tokenizer.from_file(cfg["tokenizer_path"])
gcfg = GPTConfig(**json.load(open(Path(cfg["save_dir"]) / "gpt_config.json")))
model = TinyGPT2(gcfg)
model.load_state_dict(torch.load(Path(cfg["save_dir"])/"model.pt", map_location="cpu"))
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
ds = SFTDataset(cfg["sft_jsonl"], tok, gcfg.block_size)
dl = DataLoader(ds, batch_size=8, shuffle=True, drop_last=True)
opt = AdamW(model.parameters(), lr=1e-4)
model.train()
t0 = time.time()
for step, (x,y) in enumerate(dl, start=1):
x,y = x.to(device), y.to(device)
logits = model(x)
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
loss.backward(); opt.step(); opt.zero_grad(set_to_none=True)
if step % 50 == 0:
dt = time.time()-t0; t0=time.time()
print(f"sft step {step:5d} | loss {loss.item():.4f} | {dt:.2f}s")
if step >= 800: break
torch.save(model.state_dict(), "out/sft/model_sft.pt")
print("SFT saved.")