"""GPT-2-PL probe on Apple Silicon (MPS). Usage: python sample_mac.py "Prompt" [n_tok]""" import os, sys, time, torch, torch.nn.functional as F from model import GPTConfig, GPT from tokenizers import Tokenizer HERE = os.path.dirname(os.path.abspath(__file__)) ROOT = os.path.dirname(HERE) TEMP, TOP_K, TOP_P, REP_PEN, NGRAM = 0.7, 40, 0.92, 1.15, 3 EOT = 0 dev = "mps" if torch.backends.mps.is_available() else "cpu" prompt = sys.argv[1] if len(sys.argv) > 1 else "Polska jest" maxtok = int(sys.argv[2]) if len(sys.argv) > 2 else 90 ck = torch.load(os.path.join(ROOT, "model", "ckpt.pt"), map_location="cpu") m = GPT(GPTConfig(**ck["model_args"])) sd = ck["model"] for k in list(sd): if k.startswith("_orig_mod."): sd[k[len("_orig_mod."):]] = sd.pop(k) m.load_state_dict(sd); m.eval().to(dev) block = ck["model_args"]["block_size"] tok = Tokenizer.from_file(os.path.join(ROOT, "tokenizers", "polish_bpe_32k.json")) def banned(seq, n): if len(seq) < n - 1: return set() pre = tuple(seq[-(n-1):]); bad = set() for i in range(len(seq)-n+1): if tuple(seq[i:i+n-1]) == pre: bad.add(seq[i+n-1]) return bad @torch.no_grad() def gen(prompt, maxtok): idx = torch.tensor(tok.encode(prompt).ids, dtype=torch.long, device=dev)[None] t0 = time.time(); n = 0 for _ in range(maxtok): logits, _ = m(idx[:, -block:]); logits = logits[:, -1, :].float() for t in set(idx[0].tolist()): logits[0, t] /= REP_PEN if logits[0, t] > 0 else 1/REP_PEN for t in banned(idx[0].tolist(), NGRAM): logits[0, t] = -float("inf") logits /= TEMP kth = torch.topk(logits, TOP_K)[0][..., -1, None]; logits[logits < kth] = -float("inf") sl, si = torch.sort(logits, descending=True) cum = torch.cumsum(F.softmax(sl, dim=-1), dim=-1) rm = cum > TOP_P; rm[..., 1:] = rm[..., :-1].clone(); rm[..., 0] = False logits[0, si[0][rm[0]]] = -float("inf") nxt = torch.multinomial(F.softmax(logits, dim=-1), 1); n += 1 if nxt.item() == EOT: break idx = torch.cat([idx, nxt], dim=1) dt = time.time() - t0 return tok.decode(idx[0].tolist()), n/dt txt, tps = gen(prompt, maxtok) print(f"[device={dev} {tps:.1f} tok/s]\n") print(txt)