# namegen-1 / sample.py
# hardik-vala's picture
# Add app.
# 17b091d
from contextlib import nullcontext
import os
import pickle
import torch
from model import ModelConfig, Namegen
# -----------------------------------------------------------------------------
# Sampling configuration.

# Directory containing the "ckpt.pt" checkpoint that sample_names() loads.
out_dir = "out"
# How many samples to draw.
num_samples = 10
# How many tokens are generated for each sample.
max_new_tokens = 100
# 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions.
temperature = 0.8
# Retain only the top_k most likely tokens; clamp the others to 0 probability.
top_k = 200
# Seed handed to torch.manual_seed below.
seed = 24
# Examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
device = "cpu"
# -----------------------------------------------------------------------------
# One-time setup performed at import.

torch.manual_seed(seed)
# Coarse device family ("cuda" or "cpu"), intended for torch.autocast.
device_type = "cpu" if "cuda" not in device else "cuda"
# No-op context manager.
ctx = nullcontext()
def sample_names(num_samples):
    """Generate names by sampling from the checkpointed Namegen model.

    Loads ``ckpt.pt`` from ``out_dir`` and the tokenizer metadata from
    ``data/meta.pkl``, then repeatedly asks the model for ``max_new_tokens``
    tokens starting from the ``"!"`` token. Every decoded sample is split on
    ``"!"``; the piece before the first delimiter and the piece after the
    last one are dropped, and the pieces in between are collected as names.

    Args:
        num_samples: Number of names to return.

    Returns:
        A list of exactly ``num_samples`` decoded names, or an empty list
        when ``num_samples`` is zero or negative.
    """
    # Nothing requested: return before touching the checkpoint. Previously
    # the sampling loop always ran once and could hand back one name even
    # though none were asked for.
    if num_samples <= 0:
        return []

    # Rebuild the model from the checkpoint.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ckpt_path = os.path.join(out_dir, "ckpt.pt")
    checkpoint = torch.load(ckpt_path, map_location=device)
    model = Namegen(ModelConfig(**checkpoint["model_args"]))
    state_dict = checkpoint["model"]
    # Strip the "_orig_mod." prefix from parameter names (presumably left by
    # torch.compile -- confirm) so they match the freshly built model.
    # Iterate over a snapshot of the keys because the dict is edited in place.
    unwanted_prefix = "_orig_mod."
    for key in list(state_dict):
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)

    # Load the tokenizer tables.
    meta_path = os.path.join("data", "meta.pkl")
    print(f"Loading meta from {meta_path}...")
    # NOTE: pickle.load can execute arbitrary code; meta.pkl must be trusted.
    with open(meta_path, "rb") as f:
        meta = pickle.load(f)
    merges = meta["merges"]
    stoi, itos = meta["stoi"], meta["itos"]
    # Merges are undone in reverse insertion order; computed once here
    # instead of on every decode call.
    merge_order = list(reversed(merges.items()))

    def decode(ids):
        """Expand merged token ids into their pairs, then map ids to text."""
        tokens = list(ids)
        for pair, merged_id in merge_order:
            expanded = []
            for tok in tokens:
                if tok == merged_id:
                    expanded.extend((pair[0], pair[1]))
                else:
                    expanded.append(tok)
            tokens = expanded
        return "".join(itos[tok] for tok in tokens)

    # NOTE(review): the module-level `temperature` and `top_k` settings are
    # not forwarded to model.generate -- confirm whether they should be.
    # NOTE: there is no upper bound on iterations; if samples never contain
    # two "!" delimiters this loop does not terminate.
    names = []
    with torch.no_grad():
        while len(names) < num_samples:
            x = torch.full((1, 1), stoi["!"], dtype=torch.long, device=device)
            y = model.generate(x, max_new_tokens)
            pieces = decode(y[0].tolist()).split("!")
            # Keep only pieces enclosed by delimiters on both sides, and
            # never take more than are still needed.
            still_needed = num_samples - len(names)
            names.extend(pieces[1:-1][:still_needed])
    return names
def main():
    """Sample ``num_samples`` names and print each one under a dashed rule."""
    for generated in sample_names(num_samples):
        print("---------------", generated, sep="\n")


if __name__ == "__main__":
    main()