Spaces:
Runtime error
Runtime error
| from contextlib import nullcontext | |
| import os | |
| import pickle | |
| import torch | |
| from model import ModelConfig, Namegen | |
# -----------------------------------------------------------------------------
# Sampling configuration (module-level settings read by the code below).
out_dir = "out"  # directory that holds the checkpoint file (ckpt.pt)
num_samples = 10  # number of samples to draw
max_new_tokens = 100  # number of tokens generated in each sample
# 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
temperature = 0.8
# retain only the top_k most likely tokens, clamp others to have 0 probability
top_k = 200
seed = 24
device = "cpu"  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
# -----------------------------------------------------------------------------
# Seed torch's RNG once at import time.
torch.manual_seed(seed)
# Coarse device family derived from `device`; intended for torch.autocast.
device_type = "cpu" if "cuda" not in device else "cuda"
# No-op context manager placeholder.
ctx = nullcontext()
def sample_names(num_samples):
    """Generate names with the checkpointed model and return them as a list.

    Loads ``ckpt.pt`` from ``out_dir`` and the tokenizer metadata from
    ``data/meta.pkl``, then repeatedly generates ``max_new_tokens`` tokens
    starting from the ``"!"`` token. Each generated text is split on ``"!"``
    and the pieces between the first and last separator are collected.

    Args:
        num_samples: Number of names to return. Zero or a negative value
            yields an empty list without loading the model or metadata.

    Returns:
        A list of exactly ``num_samples`` decoded strings (empty when
        ``num_samples <= 0``).

    Note:
        There is no cap on generation attempts: the loop keeps sampling
        until enough ``"!"``-delimited pieces have been produced.
    """
    # Nothing requested: return before any file I/O or generation happens.
    if num_samples <= 0:
        return []

    ckpt_path = os.path.join(out_dir, "ckpt.pt")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location=device)
    model = Namegen(ModelConfig(**checkpoint["model_args"]))

    # Strip the "_orig_mod." prefix from parameter names before loading
    # (presumably left behind by torch.compile -- TODO confirm).
    state_dict = checkpoint["model"]
    unwanted_prefix = "_orig_mod."
    for key in list(state_dict):  # snapshot keys: the dict is mutated below
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)

    meta_path = os.path.join("data", "meta.pkl")
    print(f"Loading meta from {meta_path}...")
    # NOTE(review): pickle.load executes arbitrary code from the file; the
    # metadata file must come from a trusted source.
    with open(meta_path, "rb") as f:
        meta = pickle.load(f)
    merges = meta["merges"]  # maps a token pair to its merged token id
    stoi, itos = meta["stoi"], meta["itos"]

    def decode1(ids):
        """Map base token ids to their strings and concatenate them."""
        return "".join([itos[i] for i in ids])

    def unmerge(ids, pair, idx):
        """Replace every occurrence of ``idx`` in ``ids`` with the two ids in ``pair``."""
        expanded = []
        for token in ids:
            if token == idx:
                expanded.append(pair[0])
                expanded.append(pair[1])
            else:
                expanded.append(token)
        return expanded

    def decode(ids):
        """Undo all merges, in reverse insertion order, then decode to text."""
        tokens = list(ids)
        for pair, idx in reversed(merges.items()):
            tokens = unmerge(tokens, pair, idx)
        return decode1(tokens)

    names = []
    start_token = stoi["!"]
    with torch.no_grad():
        while len(names) < num_samples:
            x = torch.full((1, 1), start_token, dtype=torch.long, device=device)
            # NOTE(review): the module-level `temperature` and `top_k` settings
            # are not passed here; generate() is called with its own defaults.
            y = model.generate(x, max_new_tokens)
            segments = decode(y[0].tolist()).split("!")
            # Skip the text before the first "!" and the piece after the last
            # one (which may have been cut off by the token limit).
            names.extend(segments[1:-1])
    # A single generation can yield more pieces than still needed.
    return names[:num_samples]
def main():
    """Sample ``num_samples`` names and print each one under a dashed separator."""
    separator = "---------------"
    for generated in sample_names(num_samples):
        print(separator)
        print(generated)


# Run the sampler only when this file is executed as a script.
if __name__ == "__main__":
    main()