| import json, yaml, time | |
| import torch | |
| from pathlib import Path | |
| from tokenizers import Tokenizer | |
| from torch.utils.data import Dataset, DataLoader | |
| from torch.optim import AdamW | |
| from model.tiny_gpt2 import TinyGPT2, GPTConfig | |
class SFTDataset(Dataset):
    """Instruction-tuning dataset for next-token prediction.

    Each JSONL record (keys ``instruction`` and ``output``) is rendered to a
    fixed prompt template, tokenized once up front, and served as an
    (input, one-step-shifted target) pair of long tensors.
    """

    def __init__(self, jsonl_path, tokenizer, block_size):
        """
        Args:
            jsonl_path: path to a JSONL file whose records carry
                'instruction' and 'output' string fields.
            tokenizer: object whose .encode(str) returns a result exposing
                an .ids list of token ids (tokenizers.Tokenizer-compatible).
            block_size: maximum sequence length; longer samples are truncated.
        """
        self.block = block_size
        self.tok = tokenizer
        # Use a context manager so the file handle is closed deterministically
        # (the original left it open until GC). Skip blank lines so a trailing
        # newline in the file does not crash json.loads.
        with open(jsonl_path, "r", encoding="utf-8") as f:
            self.samples = [json.loads(line) for line in f if line.strip()]
        # Pre-tokenize everything once; __getitem__ is then allocation-only.
        self.ids = []
        for s in self.samples:
            text = f"Instruction:\n{s['instruction'].strip()}\nAnswer:\n{s['output'].strip()}\n"
            self.ids.append(self.tok.encode(text).ids)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        # Truncate to block_size, then shift by one position for LM targets.
        ids = self.ids[i][:self.block]
        x = ids[:-1]
        y = ids[1:]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
if __name__ == "__main__":
    # ---- configuration & checkpoint loading -------------------------------
    with open("train/config.yaml") as f:
        cfg = yaml.safe_load(f)
    Path("out/sft").mkdir(parents=True, exist_ok=True)
    tok = Tokenizer.from_file(cfg["tokenizer_path"])
    with open(Path(cfg["save_dir"]) / "gpt_config.json") as f:
        gcfg = GPTConfig(**json.load(f))
    model = TinyGPT2(gcfg)
    model.load_state_dict(torch.load(Path(cfg["save_dir"]) / "model.pt", map_location="cpu"))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    # ---- data -------------------------------------------------------------
    ds = SFTDataset(cfg["sft_jsonl"], tok, gcfg.block_size)

    def _pad_collate(batch):
        """Right-pad variable-length (x, y) samples so they batch together.

        Inputs are padded with token id 0; targets with -100 so the padded
        positions are ignored by cross_entropy (ignore_index=-100). The
        original code used the default collate, which raises as soon as two
        samples in a batch have different lengths.
        """
        xs, ys = zip(*batch)
        x = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
        y = torch.nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-100)
        return x, y

    dl = DataLoader(ds, batch_size=8, shuffle=True, drop_last=True, collate_fn=_pad_collate)
    opt = AdamW(model.parameters(), lr=1e-4)

    # ---- fine-tuning loop (capped at 800 optimizer steps) -----------------
    model.train()
    t0 = time.time()
    for step, (x, y) in enumerate(dl, start=1):
        x, y = x.to(device), y.to(device)
        logits = model(x)  # NOTE(review): assumes model returns raw logits (B, T, vocab) — confirm
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)),
            y.view(-1),
            ignore_index=-100,  # skip padded target positions
        )
        loss.backward()
        opt.step()
        opt.zero_grad(set_to_none=True)
        if step % 50 == 0:
            dt = time.time() - t0
            t0 = time.time()
            print(f"sft step {step:5d} | loss {loss.item():.4f} | {dt:.2f}s")
        if step >= 800:
            break
    torch.save(model.state_dict(), "out/sft/model_sft.pt")
    print("SFT saved.")