Delete sample.py
Browse files
sample.py
DELETED
|
@@ -1,85 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Sample from a trained model (London GPT)
|
| 3 |
-
"""
|
| 4 |
-
import os
|
| 5 |
-
import pickle
|
| 6 |
-
import torch
|
| 7 |
-
from contextlib import nullcontext
|
| 8 |
-
from model import GPTConfig, GPT
|
| 9 |
-
from tokenizers import ByteLevelBPETokenizer
|
| 10 |
-
|
| 11 |
-
# --- tokenizer setup -------------------------------------------------
# Byte-level BPE tokenizer trained alongside the model; both vocab and
# merges files must be present or sampling cannot round-trip text.
tok_folder = "tokenizer_london"
vocab_path = os.path.join(tok_folder, "vocab.json")
merges_path = os.path.join(tok_folder, "merges.txt")
if not (os.path.isfile(vocab_path) and os.path.isfile(merges_path)):
    raise FileNotFoundError(f"Cannot find tokenizer files in {tok_folder}: {vocab_path}, {merges_path}")

tokenizer = ByteLevelBPETokenizer(vocab_path, merges_path)


def encode(s):
    """Encode text *s* into a list of BPE token ids."""
    return tokenizer.encode(s).ids


def decode(ids):
    """Decode a list of BPE token *ids* back into text."""
    return tokenizer.decode(ids)
# ---------------------------------------------------------------------
|
| 22 |
-
|
| 23 |
-
# --- experiment settings (you can override via CLI) ------------------
# NOTE: mid-file import kept local to this settings section, matching
# the original file layout.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--out_dir", default="out_london")          # directory holding ckpt.pt
parser.add_argument("--device", default="cpu")                  # e.g. "cpu", "cuda", "cuda:0"
parser.add_argument("--start", default="\n")                    # prompt text, or "FILE:<path>"
parser.add_argument("--num_samples", type=int, default=10)      # independent samples to draw
parser.add_argument("--max_new_tokens", type=int, default=500)  # tokens generated per sample
parser.add_argument("--temperature", type=float, default=0.8)   # <1.0 = less random
parser.add_argument("--top_k", type=int, default=200)           # restrict sampling to top-k logits
parser.add_argument("--seed", type=int, default=1337)           # RNG seed for reproducibility
parser.add_argument("--compile", action="store_true")           # torch.compile the model
args = parser.parse_args()

# Unpack into module-level names used by the rest of the script.
out_dir = args.out_dir
start = args.start
num_samples = args.num_samples
max_new_tokens = args.max_new_tokens
temperature = args.temperature
top_k = args.top_k
seed = args.seed
compile_flag = args.compile
device_str = args.device
# ---------------------------------------------------------------------
|
| 48 |
-
|
| 49 |
-
# reproducibility & device
torch.manual_seed(seed)
device = torch.device(device_str)
# Autocast is only meaningful off-CPU and only in a reduced precision:
# passing dtype=torch.float32 makes PyTorch warn and disable autocast,
# so the original context was a no-op. Prefer bfloat16 where the GPU
# supports it, otherwise fall back to float16.
if device.type == "cpu":
    ctx = nullcontext()
else:
    use_bf16 = device.type == "cuda" and torch.cuda.is_bf16_supported()
    amp_dtype = torch.bfloat16 if use_bf16 else torch.float16
    ctx = torch.amp.autocast(device_type=device.type, dtype=amp_dtype)
|
| 53 |
-
|
| 54 |
-
# --- load model checkpoint -------------------------------------------
ckpt_path = os.path.join(out_dir, "ckpt.pt")
if not os.path.isfile(ckpt_path):
    raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
# NOTE(review): torch.load unpickles arbitrary objects -- only load
# checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location=device)
model = GPT(GPTConfig(**ckpt["model_args"]))
# Checkpoints saved from a torch.compile()'d model carry an "_orig_mod."
# prefix on every key; strip it so the plain module accepts the weights.
prefix = "_orig_mod."
state = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in ckpt["model"].items()
}
model.load_state_dict(state)
model.eval().to(device)
if compile_flag:
    model = torch.compile(model)
# ---------------------------------------------------------------------
|
| 71 |
-
|
| 72 |
-
# prepare prompt tensor
# A prompt of the form "FILE:<path>" means: read the prompt text from <path>.
if start.startswith("FILE:"):
    with open(start[5:], "r", encoding="utf-8") as prompt_file:
        start = prompt_file.read()
prompt_ids = encode(start)
# Shape (1, prompt_len): a single-row batch of token ids on the target device.
x = torch.tensor([prompt_ids], dtype=torch.long, device=device)
|
| 78 |
-
|
| 79 |
-
# --- generation ------------------------------------------------------
# Inference only: no_grad skips autograd bookkeeping, and ctx applies
# mixed precision when running on a non-CPU device.
with torch.no_grad(), ctx:
    # Loop index is unused -- samples are independent draws.
    for _ in range(num_samples):
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        print(decode(y[0].tolist()))
        print("---------------")
# ---------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|