haykgrigorian committed on
Commit
e136ab4
·
verified ·
1 Parent(s): dc7b558

Delete sample.py

Browse files
Files changed (1) hide show
  1. sample.py +0 -85
sample.py DELETED
@@ -1,85 +0,0 @@
"""
Sample from a trained model (London GPT)

Loads a ByteLevelBPE tokenizer and a GPT checkpoint, then prints
``num_samples`` generations seeded from a text prompt. A prompt of the
form ``FILE:path`` reads the prompt text from that file instead.
"""
import argparse
import os
import pickle  # noqa: F401 -- kept from the original script; may be used by callers
import sys  # noqa: F401 -- kept from the original script
from contextlib import nullcontext

import torch

from model import GPTConfig, GPT
from tokenizers import ByteLevelBPETokenizer

# ─── tokenizer setup ────────────────────────────────────────────────
tok_folder = "tokenizer_london"
vocab_path = os.path.join(tok_folder, "vocab.json")
merges_path = os.path.join(tok_folder, "merges.txt")
if not (os.path.isfile(vocab_path) and os.path.isfile(merges_path)):
    raise FileNotFoundError(f"Cannot find tokenizer files in {tok_folder}: {vocab_path}, {merges_path}")

tokenizer = ByteLevelBPETokenizer(vocab_path, merges_path)


def encode(s):
    """Encode text *s* to a list of BPE token ids."""
    return tokenizer.encode(s).ids


def decode(ids):
    """Decode a list of BPE token ids back to text."""
    return tokenizer.decode(ids)
# ────────────────────────────────────────────────────────────────────

# ─── experiment settings (you can override via CLI) ────────────────
parser = argparse.ArgumentParser()
parser.add_argument("--out_dir", default="out_london")
parser.add_argument("--device", default="cpu")
parser.add_argument("--start", default="\n")
parser.add_argument("--num_samples", type=int, default=10)
parser.add_argument("--max_new_tokens", type=int, default=500)
parser.add_argument("--temperature", type=float, default=0.8)
parser.add_argument("--top_k", type=int, default=200)
parser.add_argument("--seed", type=int, default=1337)
parser.add_argument("--compile", action="store_true")
args = parser.parse_args()

out_dir = args.out_dir
start = args.start
num_samples = args.num_samples
max_new_tokens = args.max_new_tokens
temperature = args.temperature
top_k = args.top_k
seed = args.seed
compile_flag = args.compile
device_str = args.device
# ────────────────────────────────────────────────────────────────────

# reproducibility & device
torch.manual_seed(seed)
device = torch.device(device_str)
# BUG FIX: the original passed dtype=torch.float32 to torch.amp.autocast,
# which CUDA autocast does not support (it raises at runtime). Use bfloat16
# when the GPU supports it, otherwise float16; on CPU skip autocast entirely.
if device.type == "cpu":
    ctx = nullcontext()
else:
    amp_dtype = (
        torch.bfloat16
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        else torch.float16
    )
    ctx = torch.amp.autocast(device_type=device.type, dtype=amp_dtype)

# ─── load model checkpoint ──────────────────────────────────────────
ckpt_path = os.path.join(out_dir, "ckpt.pt")
if not os.path.isfile(ckpt_path):
    raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**ckpt["model_args"])
model = GPT(gptconf)
sd = ckpt["model"]
# strip the "_orig_mod." prefix torch.compile adds to state-dict keys,
# so a checkpoint saved from a compiled model still loads here
prefix = "_orig_mod."
for k in list(sd.keys()):
    if k.startswith(prefix):
        sd[k[len(prefix):]] = sd.pop(k)
model.load_state_dict(sd)
model.eval().to(device)
if compile_flag:
    model = torch.compile(model)
# ────────────────────────────────────────────────────────────────────

# prepare prompt tensor; "FILE:path" reads the prompt from a file
if start.startswith("FILE:"):
    with open(start[5:], "r", encoding="utf-8") as f:
        start = f.read()
ids = encode(start)
x = torch.tensor([ids], dtype=torch.long, device=device)

# ─── generation ─────────────────────────────────────────────────────
with torch.no_grad(), ctx:
    for _ in range(num_samples):
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        print(decode(y[0].tolist()))
        print("---------------")
# ────────────────────────────────────────────────────────────────────