Upload 6 files
Browse files- train/config.yaml +19 -0
- train/data_utils.py +13 -0
- train/gen_sample.py +23 -0
- train/prepare_corpus.py +8 -0
- train/pretrain.py +64 -0
- train/sft.py +52 -0
train/config.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --- model architecture ---
vocab_size: 16000            # tokenizer vocabulary size
block_size: 256              # context window length, in tokens
n_layer: 6                   # number of transformer blocks
n_head: 6                    # attention heads per block
n_embed: 384                 # embedding / hidden width

# --- optimization ---
batch_size: 32               # sequences per optimizer step
micro_batches: 4             # NOTE(review): not referenced by train/pretrain.py — confirm intent
lr: 3.0e-4                   # peak learning rate (reached after warmup)
min_lr: 3.0e-5               # cosine-decay floor
warmup_steps: 200            # linear warmup length, in steps
max_steps: 1000              # total pretraining steps
weight_decay: 0.01           # AdamW weight decay
grad_clip: 1.0               # global gradient-norm clip

# --- runtime & paths ---
dtype: "float32"             # NOTE(review): not referenced by train/pretrain.py — confirm intent
device: "auto"               # "auto" = cuda if available, else cpu
save_dir: "out/pretrain"     # pretraining checkpoint directory
tokenizer_path: "out/tokenizer.json"
train_txt: "data/corpus_raw.txt"     # raw pretraining corpus (plain text)
sft_jsonl: "data/sft_train.jsonl"    # instruction-tuning samples (jsonl)
|
train/data_utils.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
|
| 4 |
+
class TextDataset(Dataset):
    """Sliding-window next-token-prediction dataset over a flat token stream.

    Sample ``i`` pairs the window ``ids[i : i+block_size]`` with the same
    window shifted one position to the right (the prediction targets).
    """

    def __init__(self, ids, block_size):
        self.ids = ids
        self.block = block_size

    def __len__(self):
        # At least 1, so an undersized corpus still yields one (short) sample.
        return max(1, len(self.ids) - self.block)

    def __getitem__(self, i):
        window = self.ids[i : i + self.block]
        shifted = self.ids[i + 1 : i + self.block + 1]
        return (
            torch.tensor(window, dtype=torch.long),
            torch.tensor(shifted, dtype=torch.long),
        )
|
train/gen_sample.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CLI: generate a text completion from the SFT checkpoint for a given prompt."""
import torch, argparse, json
from tokenizers import Tokenizer
from model.tiny_gpt2 import TinyGPT2, GPTConfig

parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str, required=True)
parser.add_argument("--ckpt", type=str, default="out/sft/model_sft.pt")
parser.add_argument("--cfg", type=str, default="out/pretrain/gpt_config.json")
parser.add_argument("--tok", type=str, default="out/tokenizer.json")
args = parser.parse_args()

tok = Tokenizer.from_file(args.tok)
# FIX: close the config file instead of leaking the handle via json.load(open(...)).
with open(args.cfg) as f:
    cfg = GPTConfig(**json.load(f))
m = TinyGPT2(cfg)
m.load_state_dict(torch.load(args.ckpt, map_location="cpu"))
m.eval()

# Prepend the "[BOS] " marker before encoding, matching the training format.
ids = tok.encode("[BOS] " + args.prompt).ids
x = torch.tensor([ids], dtype=torch.long)
with torch.no_grad():  # inference only — no autograd graph needed
    y = m.generate(x, max_new_tokens=80)
text = tok.decode(y[0].tolist())
print(text)
|
train/prepare_corpus.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path

# Raw pretraining corpus; cleaned in place when run as a script.
SRC = Path("data/corpus_raw.txt")


def clean_text(text: str) -> str:
    """Normalize Windows line endings to \\n and strip surrounding whitespace."""
    return text.replace("\r\n", "\n").strip()


if __name__ == "__main__":
    # errors="ignore" drops undecodable bytes rather than aborting the cleanup.
    text = SRC.read_text(encoding="utf-8", errors="ignore")
    SRC.write_text(clean_text(text), encoding="utf-8")
    print("cleaned corpus in-place.")
|
train/pretrain.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml, math, time, json
|
| 2 |
+
import torch
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from tokenizers import Tokenizer
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from torch.optim import AdamW
|
| 7 |
+
from model.tiny_gpt2 import TinyGPT2, GPTConfig
|
| 8 |
+
from train.data_utils import TextDataset
|
| 9 |
+
|
| 10 |
+
def get_device(name):
    """Resolve a device string; "auto" picks CUDA when available, else CPU."""
    if name != "auto":
        return name
    return "cuda" if torch.cuda.is_available() else "cpu"
|
| 14 |
+
|
| 15 |
+
def cosine_lr(step, max_steps, base, min_lr, warmup):
    """Learning-rate schedule: linear warmup to `base`, then cosine decay to `min_lr`.

    Args:
        step: current optimizer step.
        max_steps: total planned steps; the decay reaches `min_lr` here.
        base: peak learning rate, reached at the end of warmup.
        min_lr: floor learning rate.
        warmup: number of linear warmup steps.
    """
    if step < warmup:
        # Linear ramp from 0 up to `base` over the warmup window.
        return base * step / max(1, warmup)
    # FIX: clamp progress to 1.0 so the LR never rises again if the caller
    # keeps stepping past max_steps (the unclamped cosine would turn upward).
    progress = min(1.0, (step - warmup) / max(1, max_steps - warmup))
    return min_lr + 0.5 * (base - min_lr) * (1 + math.cos(math.pi * progress))
|
| 20 |
+
|
| 21 |
+
if __name__ == "__main__":
    # FIX: close the config file instead of leaking the handle.
    with open("train/config.yaml") as f:
        cfg = yaml.safe_load(f)
    device = get_device(cfg["device"])
    Path(cfg["save_dir"]).mkdir(parents=True, exist_ok=True)

    # Tokenize the entire training corpus into one flat id stream.
    tok = Tokenizer.from_file(cfg["tokenizer_path"])
    # FIX: close the corpus file instead of leaking the handle.
    with open(cfg["train_txt"], "r", encoding="utf-8") as f:
        ids = tok.encode(f.read()).ids
    ds = TextDataset(ids, cfg["block_size"])
    dl = DataLoader(ds, batch_size=cfg["batch_size"], shuffle=True, drop_last=True)

    gcfg = GPTConfig(
        vocab_size=cfg["vocab_size"],
        n_layer=cfg["n_layer"],
        n_head=cfg["n_head"],
        n_embed=cfg["n_embed"],
        block_size=cfg["block_size"],
    )
    model = TinyGPT2(gcfg).to(device)

    opt = AdamW(model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"])
    step, t0 = 0, time.time()
    model.train()
    # NOTE(review): cfg["micro_batches"] and cfg["dtype"] are not used in this
    # loop — confirm whether gradient accumulation / mixed precision were intended.
    for epoch in range(999999):  # effectively "loop until max_steps"
        for x, y in dl:
            step += 1
            x, y = x.to(device), y.to(device)
            logits = model(x)
            # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for token-level CE.
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), y.view(-1)
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["grad_clip"])
            # Warmup + cosine schedule, applied manually to every param group.
            lr = cosine_lr(step, cfg["max_steps"], cfg["lr"], cfg["min_lr"], cfg["warmup_steps"])
            for g in opt.param_groups:
                g["lr"] = lr
            opt.step()
            opt.zero_grad(set_to_none=True)

            if step % 100 == 0:
                dt = time.time() - t0
                t0 = time.time()
                print(f"step {step:6d} | loss {loss.item():.4f} | lr {lr:.2e} | {dt:.2f}s")

            if step >= cfg["max_steps"]:
                # Save weights plus the architecture config needed to rebuild the model.
                torch.save(model.state_dict(), f"{cfg['save_dir']}/model.pt")
                with open(f"{cfg['save_dir']}/gpt_config.json", "w") as f:
                    json.dump(gcfg.__dict__, f, indent=2)
                print("saved checkpoint. done.")
                raise SystemExit  # only clean exit from the doubly-nested loop
|
train/sft.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json, yaml, time
|
| 2 |
+
import torch
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from tokenizers import Tokenizer
|
| 5 |
+
from torch.utils.data import Dataset, DataLoader
|
| 6 |
+
from torch.optim import AdamW
|
| 7 |
+
from model.tiny_gpt2 import TinyGPT2, GPTConfig
|
| 8 |
+
|
| 9 |
+
class SFTDataset(Dataset):
    """Instruction-tuning dataset: one tokenized prompt/answer pair per jsonl line.

    Each sample is rendered as "Instruction:\n...\nAnswer:\n...\n", tokenized
    once up front, and served as (input ids, next-token target ids).
    Note: samples are variable-length (capped at `block_size`), so a padding
    collate function is required when batching.
    """

    def __init__(self, jsonl_path, tokenizer, block_size):
        self.block = block_size
        self.tok = tokenizer
        # FIX: close the jsonl file instead of leaking the handle.
        with open(jsonl_path, "r", encoding="utf-8") as f:
            self.samples = [json.loads(line) for line in f]
        self.ids = []
        for s in self.samples:
            text = f"Instruction:\n{s['instruction'].strip()}\nAnswer:\n{s['output'].strip()}\n"
            self.ids.append(self.tok.encode(text).ids)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        # Truncate to the model's context, then shift by one for LM targets.
        ids = self.ids[i][:self.block]
        x = ids[:-1]
        y = ids[1:]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
|
| 23 |
+
|
| 24 |
+
if __name__ == "__main__":
|
| 25 |
+
cfg = yaml.safe_load(open("train/config.yaml"))
|
| 26 |
+
Path("out/sft").mkdir(parents=True, exist_ok=True)
|
| 27 |
+
tok = Tokenizer.from_file(cfg["tokenizer_path"])
|
| 28 |
+
|
| 29 |
+
gcfg = GPTConfig(**json.load(open(Path(cfg["save_dir"]) / "gpt_config.json")))
|
| 30 |
+
model = TinyGPT2(gcfg)
|
| 31 |
+
model.load_state_dict(torch.load(Path(cfg["save_dir"])/"model.pt", map_location="cpu"))
|
| 32 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 33 |
+
model = model.to(device)
|
| 34 |
+
|
| 35 |
+
ds = SFTDataset(cfg["sft_jsonl"], tok, gcfg.block_size)
|
| 36 |
+
dl = DataLoader(ds, batch_size=8, shuffle=True, drop_last=True)
|
| 37 |
+
opt = AdamW(model.parameters(), lr=1e-4)
|
| 38 |
+
|
| 39 |
+
model.train()
|
| 40 |
+
t0 = time.time()
|
| 41 |
+
for step, (x,y) in enumerate(dl, start=1):
|
| 42 |
+
x,y = x.to(device), y.to(device)
|
| 43 |
+
logits = model(x)
|
| 44 |
+
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
|
| 45 |
+
loss.backward(); opt.step(); opt.zero_grad(set_to_none=True)
|
| 46 |
+
if step % 50 == 0:
|
| 47 |
+
dt = time.time()-t0; t0=time.time()
|
| 48 |
+
print(f"sft step {step:5d} | loss {loss.item():.4f} | {dt:.2f}s")
|
| 49 |
+
if step >= 800: break
|
| 50 |
+
|
| 51 |
+
torch.save(model.state_dict(), "out/sft/model_sft.pt")
|
| 52 |
+
print("SFT saved.")
|