"""GPT-2-PL probe on Apple Silicon (MPS).

Usage: python sample_mac.py "Prompt" [n_tok]

Loads ROOT/model/ckpt.pt and ROOT/tokenizers/polish_bpe_32k.json, samples up
to n_tok tokens (default 90) continuing Prompt (default "Polska jest"), then
prints the measured throughput followed by the generated text.
"""
import os
import sys
import time

import torch
import torch.nn.functional as F
from tokenizers import Tokenizer

from model import GPTConfig, GPT

# ROOT is the parent of the directory that holds this script; the checkpoint
# and tokenizer paths are built relative to it.
HERE = os.path.dirname(os.path.abspath(__file__))
ROOT = os.path.dirname(HERE)

# Sampling knobs read by gen().
TEMP = 0.7      # logits are divided by this temperature before filtering
TOP_K = 40      # only the TOP_K highest logits survive top-k filtering
TOP_P = 0.92    # nucleus (top-p) cumulative-probability threshold
REP_PEN = 1.15  # seen tokens: positive logits divided by it, negative multiplied
NGRAM = 3       # n-gram size that may not be repeated
EOT = 0         # sampling this id stops generation; assumed to be the tokenizer's
                # end-of-text id — TODO confirm against polish_bpe_32k.json

# Prefer the Apple-Silicon GPU backend, fall back to CPU.
if torch.backends.mps.is_available():
    dev = "mps"
else:
    dev = "cpu"

# Positional CLI arguments: [prompt] [max new tokens].
_cli = sys.argv[1:]
prompt = _cli[0] if _cli else "Polska jest"
maxtok = int(_cli[1]) if len(_cli) > 1 else 90
# Load the checkpoint onto the CPU first; the model is moved to `dev` only
# after its weights are in place.
ck = torch.load(os.path.join(ROOT, "model", "ckpt.pt"), map_location="cpu")
m = GPT(GPTConfig(**ck["model_args"]))
sd = ck["model"]

# Some checkpoints prefix every key with "_orig_mod." (presumably written from a
# torch.compile()-wrapped model — confirm). Strip it so the keys match plain GPT.
_PREFIX = "_orig_mod."
for k in [key for key in sd if key.startswith(_PREFIX)]:
    sd[k[len(_PREFIX):]] = sd.pop(k)

m.load_state_dict(sd)
m.eval()
m.to(dev)

# Maximum context length the model was built with; gen() crops its input to it.
block = ck["model_args"]["block_size"]
tok = Tokenizer.from_file(os.path.join(ROOT, "tokenizers", "polish_bpe_32k.json"))
def banned(seq, n):
    """Return the tokens that would complete an n-gram already present in *seq*.

    Takes the last ``n - 1`` tokens of *seq* as a prefix and collects every
    token that followed that same prefix earlier in *seq*. Sampling any of them
    next would repeat an existing n-gram.

    Args:
        seq: token ids seen so far (list or tuple).
        n: n-gram size. ``n == 1`` bans every token already in *seq*;
            ``n <= 0`` disables the check.

    Returns:
        Set of token ids to ban (empty when no n-gram can be completed).
    """
    if n <= 0 or len(seq) < n - 1:
        return set()
    k = n - 1
    # Slice from an explicit start index: seq[-k:] would return the WHOLE
    # sequence when k == 0, which silently disabled the ban for n == 1.
    prefix = tuple(seq[len(seq) - k:])
    return {
        seq[i + k]
        for i in range(len(seq) - k)
        if tuple(seq[i:i + k]) == prefix
    }
def gen(prompt, maxtok):
    """Sample a continuation of *prompt* from the loaded model ``m``.

    Each step feeds at most the last ``block`` tokens to the model and shapes
    the last-position logits in this order:
    1. Repetition penalty (REP_PEN) on every token already in the sequence.
    2. A ban on tokens that would repeat an NGRAM-gram.
    3. Temperature (TEMP).
    4. Top-k (TOP_K) filtering.
    5. Nucleus / top-p (TOP_P) filtering.

    A token is then sampled from the resulting distribution. Generation stops
    after *maxtok* samples or when EOT is sampled. EOT is counted in the rate
    but not appended to the text.

    Args:
        prompt: text to continue; must encode to at least one token.
        maxtok: maximum number of tokens to sample.

    Returns:
        ``(text, tokens_per_second)``, where *text* is the decoded prompt plus
        the continuation.

    Raises:
        ValueError: if *prompt* encodes to zero tokens (the model needs
            a non-empty context).
    """
    history = list(tok.encode(prompt).ids)
    if not history:
        raise ValueError("prompt encodes to zero tokens")
    idx = torch.tensor(history, dtype=torch.long, device=dev)[None]
    neg_inf = -float("inf")
    t0 = time.perf_counter()
    n = 0
    # Pure inference: without no_grad every forward pass builds an autograd graph.
    with torch.no_grad():
        for _ in range(maxtok):
            logits, _ = m(idx[:, -block:])
            logits = logits[:, -1, :].float()  # (1, vocab) row for the last position

            # Repetition penalty, vectorised: one device op instead of a Python
            # loop with a host sync per token. Positive logits shrink, negative
            # ones are pushed further down.
            seen = torch.tensor(sorted(set(history)), device=logits.device)
            vals = logits[0, seen]
            logits[0, seen] = torch.where(vals > 0, vals / REP_PEN, vals * REP_PEN)

            bad = banned(history, NGRAM)
            if bad:
                logits[0, torch.tensor(sorted(bad), device=logits.device)] = neg_inf

            logits /= TEMP

            # Top-k; clamp k so a vocabulary smaller than TOP_K cannot make topk raise.
            k = min(TOP_K, logits.size(-1))
            kth = torch.topk(logits, k)[0][..., -1, None]
            logits[logits < kth] = neg_inf

            # Top-p: drop the tail once cumulative probability exceeds TOP_P.
            # The mask is shifted right so the token that crosses the threshold
            # is kept, and the most likely token is always kept.
            sl, si = torch.sort(logits, descending=True)
            cum = torch.cumsum(F.softmax(sl, dim=-1), dim=-1)
            rm = cum > TOP_P
            rm[..., 1:] = rm[..., :-1].clone()
            rm[..., 0] = False
            logits[0, si[0][rm[0]]] = neg_inf

            nxt = torch.multinomial(F.softmax(logits, dim=-1), 1)
            n += 1
            t = nxt.item()
            if t == EOT:
                break
            # Host-side mirror of idx[0], so the penalty and ban need no tolist() sync.
            history.append(t)
            idx = torch.cat([idx, nxt], dim=1)
    dt = time.perf_counter() - t0
    # Guard a zero interval (e.g. maxtok == 0 on a coarse clock).
    return tok.decode(history), (n / dt if dt > 0 else 0.0)
# Run one generation with the CLI arguments, then print the throughput line,
# a blank line and the generated text.
text, rate = gen(prompt, maxtok)
status = f"[device={dev} {rate:.1f} tok/s]\n"
print(status, text, sep="\n")