adityashisharma commited on
Commit
04e4b39
·
verified ·
1 Parent(s): 707323a

Upload 6 files

Browse files
train/config.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Shared training configuration read by train/pretrain.py and train/sft.py.

# --- model architecture ---
vocab_size: 16000        # tokenizer vocabulary size
block_size: 256          # context window in tokens
n_layer: 6               # number of transformer blocks
n_head: 6                # attention heads per block
n_embed: 384             # embedding / hidden width

# --- optimization ---
batch_size: 32
micro_batches: 4         # NOTE(review): never read by pretrain.py -- gradient accumulation appears unimplemented; confirm intent
lr: 3.0e-4               # peak learning rate reached after warmup
min_lr: 3.0e-5           # cosine-decay floor
warmup_steps: 200        # linear LR warmup steps
max_steps: 1000          # pretraining stops (and checkpoints) at this step
weight_decay: 0.01
grad_clip: 1.0           # global grad-norm clip

# --- runtime & paths ---
dtype: "float32"
device: "auto"           # "auto" selects cuda when available (see get_device)
save_dir: "out/pretrain"
tokenizer_path: "out/tokenizer.json"
train_txt: "data/corpus_raw.txt"
sft_jsonl: "data/sft_train.jsonl"
train/data_utils.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+
4
class TextDataset(Dataset):
    """Sliding-window next-token dataset over a flat token-id sequence.

    Item ``i`` is the pair ``(x, y)`` where ``x = ids[i:i+block]`` and
    ``y`` is ``x`` shifted right by one token; both are LongTensors of
    equal length (``block_size`` when enough data is available).
    """

    def __init__(self, ids, block_size):
        # ids: flat list of token ids; block_size: context length.
        self.ids = ids
        self.block = block_size

    def __len__(self):
        # max(1, ...) keeps a corpus shorter than block_size from producing
        # an empty dataset -- it yields one (shorter) sample instead.
        return max(1, len(self.ids) - self.block)

    def __getitem__(self, i):
        # Slice block+1 tokens once and split, so x and y are guaranteed the
        # same length. The previous ids[i:i+block] / ids[i+1:i+block+1] pair
        # returned mismatched lengths when the corpus was <= block_size,
        # which would crash cross_entropy downstream.
        chunk = self.ids[i : i + self.block + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y
train/gen_sample.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Generate a short completion for a prompt using the SFT checkpoint."""
import argparse
import json

import torch
from tokenizers import Tokenizer

from model.tiny_gpt2 import TinyGPT2, GPTConfig


def main():
    # CLI: prompt is required; checkpoint/config/tokenizer paths default to
    # the locations written by train/sft.py and train/pretrain.py.
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--ckpt", type=str, default="out/sft/model_sft.pt")
    parser.add_argument("--cfg", type=str, default="out/pretrain/gpt_config.json")
    parser.add_argument("--tok", type=str, default="out/tokenizer.json")
    args = parser.parse_args()

    tok = Tokenizer.from_file(args.tok)
    # Close the config file deterministically (the old json.load(open(...))
    # leaked the handle).
    with open(args.cfg) as f:
        cfg = GPTConfig(**json.load(f))
    m = TinyGPT2(cfg)
    m.load_state_dict(torch.load(args.ckpt, map_location="cpu"))
    m.eval()

    # Prepend the "[BOS] " marker exactly as before tokenizing the prompt.
    ids = tok.encode("[BOS] " + args.prompt).ids
    x = torch.tensor([ids], dtype=torch.long)
    with torch.no_grad():
        y = m.generate(x, max_new_tokens=80)
    text = tok.decode(y[0].tolist())
    print(text)


# Guard added for consistency with pretrain.py / sft.py: previously merely
# importing this module ran argparse and called SystemExit.
if __name__ == "__main__":
    main()
train/prepare_corpus.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
"""Normalize the raw training corpus in place (CRLF -> LF, trim edges)."""
from pathlib import Path

# Corpus file that gets rewritten in place.
SRC = Path("data/corpus_raw.txt")


def _normalize(path):
    # Read tolerantly (ignore undecodable bytes), convert Windows line
    # endings, and strip leading/trailing whitespace before writing back.
    raw = path.read_text(encoding="utf-8", errors="ignore")
    path.write_text(raw.replace("\r\n", "\n").strip(), encoding="utf-8")


if __name__ == "__main__":
    _normalize(SRC)
    print("cleaned corpus in-place.")
train/pretrain.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml, math, time, json
2
+ import torch
3
+ from pathlib import Path
4
+ from tokenizers import Tokenizer
5
+ from torch.utils.data import DataLoader
6
+ from torch.optim import AdamW
7
+ from model.tiny_gpt2 import TinyGPT2, GPTConfig
8
+ from train.data_utils import TextDataset
9
+
10
def get_device(name):
    """Resolve a device string; "auto" means CUDA when available, else CPU."""
    if name != "auto":
        return name
    return "cuda" if torch.cuda.is_available() else "cpu"
14
+
15
def cosine_lr(step, max_steps, base, min_lr, warmup):
    """Linear warmup to *base*, then cosine decay to *min_lr*.

    Args:
        step: current optimizer step.
        max_steps: step at which the schedule bottoms out at *min_lr*.
        base: peak learning rate reached at the end of warmup.
        min_lr: floor learning rate.
        warmup: number of linear warmup steps.
    """
    if step < warmup:
        # Linear ramp from 0; max(1, ...) guards warmup == 0.
        return base * step / max(1, warmup)
    # Clamp progress to 1 so steps beyond max_steps hold at min_lr --
    # without the clamp, cos(pi * p) rises again for p > 1 and the LR
    # would climb back toward base.
    progress = min(1.0, (step - warmup) / max(1, max_steps - warmup))
    return min_lr + 0.5 * (base - min_lr) * (1 + math.cos(math.pi * progress))
20
+
21
if __name__ == "__main__":
    # Pretraining entry point: causal-LM training on the raw corpus,
    # driven entirely by train/config.yaml.
    cfg = yaml.safe_load(open("train/config.yaml"))
    device = get_device(cfg["device"])
    Path(cfg["save_dir"]).mkdir(parents=True, exist_ok=True)

    # Tokenize the whole corpus into one flat id sequence, then window it.
    tok = Tokenizer.from_file(cfg["tokenizer_path"])
    ids = tok.encode(open(cfg["train_txt"], "r", encoding="utf-8").read()).ids
    ds = TextDataset(ids, cfg["block_size"])
    dl = DataLoader(ds, batch_size=cfg["batch_size"], shuffle=True, drop_last=True)

    # Model architecture mirrors the YAML; saved alongside the weights below
    # so sft.py / gen_sample.py can rebuild the same model.
    gcfg = GPTConfig(
        vocab_size=cfg["vocab_size"],
        n_layer=cfg["n_layer"],
        n_head=cfg["n_head"],
        n_embed=cfg["n_embed"],
        block_size=cfg["block_size"],
    )
    model = TinyGPT2(gcfg).to(device)

    opt = AdamW(model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"])
    step, t0 = 0, time.time()
    model.train()
    # Effectively "loop forever"; the only exit is the SystemExit raised at
    # max_steps below.
    # NOTE(review): cfg["micro_batches"] is never read here -- gradient
    # accumulation appears unimplemented; confirm whether it is intended.
    for epoch in range(999999):
        for x, y in dl:
            step += 1
            x, y = x.to(device), y.to(device)
            logits = model(x)
            # Next-token cross-entropy, flattened over (batch * positions).
            loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["grad_clip"])
            # LR is scheduled manually each step (warmup + cosine decay)
            # rather than via a torch scheduler.
            lr = cosine_lr(step, cfg["max_steps"], cfg["lr"], cfg["min_lr"], cfg["warmup_steps"])
            for g in opt.param_groups: g["lr"] = lr
            opt.step(); opt.zero_grad(set_to_none=True)

            if step % 100 == 0:
                # dt = wall time for the last 100 steps.
                dt = time.time() - t0; t0 = time.time()
                print(f"step {step:6d} | loss {loss.item():.4f} | lr {lr:.2e} | {dt:.2f}s")

            if step >= cfg["max_steps"]:
                # Persist weights plus the architecture config, then exit the
                # nested loops via SystemExit.
                torch.save(model.state_dict(), f"{cfg['save_dir']}/model.pt")
                with open(f"{cfg['save_dir']}/gpt_config.json", "w") as f:
                    json.dump(gcfg.__dict__, f, indent=2)
                print("saved checkpoint. done.")
                raise SystemExit
train/sft.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json, yaml, time
2
+ import torch
3
+ from pathlib import Path
4
+ from tokenizers import Tokenizer
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from torch.optim import AdamW
7
+ from model.tiny_gpt2 import TinyGPT2, GPTConfig
8
+
9
class SFTDataset(Dataset):
    """Instruction-tuning dataset over a JSONL file whose records have
    "instruction" and "output" fields.

    Each item is a fixed-length ``(x, y)`` LongTensor pair so the default
    DataLoader collate can stack a batch: inputs are padded with token id
    0 and targets with -100, which ``F.cross_entropy`` ignores by default.
    """

    PAD_ID = 0          # input padding token
    IGNORE_INDEX = -100  # cross_entropy's default ignore_index

    def __init__(self, jsonl_path, tokenizer, block_size):
        self.block = block_size
        self.tok = tokenizer
        self.samples = [json.loads(l) for l in open(jsonl_path, 'r', encoding='utf-8')]
        self.ids = []
        for s in self.samples:
            # Prompt template must match what gen_sample.py expects at
            # inference time.
            text = f"Instruction:\n{s['instruction'].strip()}\nAnswer:\n{s['output'].strip()}\n"
            self.ids.append(self.tok.encode(text).ids)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        ids = self.ids[i][:self.block]
        x = ids[:-1]
        y = ids[1:]
        # Pad every item to the common length (block - 1). Without this,
        # samples of different lengths crash torch's default collate as
        # soon as batch_size > 1.
        pad = (self.block - 1) - len(x)
        if pad > 0:
            x = x + [self.PAD_ID] * pad
            y = y + [self.IGNORE_INDEX] * pad
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
23
+
24
if __name__ == "__main__":
    # Supervised fine-tuning entry point: continues from the pretraining
    # checkpoint using the instruction JSONL named in the shared config.
    cfg = yaml.safe_load(open("train/config.yaml"))
    Path("out/sft").mkdir(parents=True, exist_ok=True)
    tok = Tokenizer.from_file(cfg["tokenizer_path"])

    # Rebuild the exact pretrained architecture from the JSON config that
    # pretrain.py saved next to its weights.
    gcfg = GPTConfig(**json.load(open(Path(cfg["save_dir"]) / "gpt_config.json")))
    model = TinyGPT2(gcfg)
    model.load_state_dict(torch.load(Path(cfg["save_dir"])/"model.pt", map_location="cpu"))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    # NOTE(review): the default collate requires every SFTDataset item to be
    # the same length -- verify the dataset guarantees this.
    ds = SFTDataset(cfg["sft_jsonl"], tok, gcfg.block_size)
    dl = DataLoader(ds, batch_size=8, shuffle=True, drop_last=True)
    # Fixed LR here (no schedule), unlike pretraining.
    opt = AdamW(model.parameters(), lr=1e-4)

    model.train()
    t0 = time.time()
    # Single pass over the DataLoader, capped at 800 optimizer steps --
    # whichever comes first ends training.
    for step, (x,y) in enumerate(dl, start=1):
        x,y = x.to(device), y.to(device)
        logits = model(x)
        # Next-token cross-entropy, flattened over (batch * positions).
        loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
        loss.backward(); opt.step(); opt.zero_grad(set_to_none=True)
        if step % 50 == 0:
            # dt = wall time for the last 50 steps.
            dt = time.time()-t0; t0=time.time()
            print(f"sft step {step:5d} | loss {loss.item():.4f} | {dt:.2f}s")
        if step >= 800: break

    torch.save(model.state_dict(), "out/sft/model_sft.pt")
    print("SFT saved.")