File size: 2,262 Bytes
78c54ec | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 | """GPT-2-PL probe on Apple Silicon (MPS). Usage: python sample_mac.py "Prompt" [n_tok]"""
import os, sys, time, torch, torch.nn.functional as F
from model import GPTConfig, GPT
from tokenizers import Tokenizer
HERE = os.path.dirname(os.path.abspath(__file__))
ROOT = os.path.dirname(HERE)
TEMP, TOP_K, TOP_P, REP_PEN, NGRAM = 0.7, 40, 0.92, 1.15, 3
EOT = 0
dev = "mps" if torch.backends.mps.is_available() else "cpu"
prompt = sys.argv[1] if len(sys.argv) > 1 else "Polska jest"
maxtok = int(sys.argv[2]) if len(sys.argv) > 2 else 90
# Load the checkpoint onto the CPU first; the model is moved to `dev` once its weights are in.
# NOTE(review): torch.load unpickles the file — acceptable for a local, self-produced
# checkpoint, never for untrusted input.
ck = torch.load(os.path.join(ROOT, "model", "ckpt.pt"), map_location="cpu")
m = GPT(GPTConfig(**ck["model_args"]))
sd = ck["model"]
# Strip the "_orig_mod." key prefix (presumably added when the model was saved from a
# torch.compile()-wrapped module) so keys match the plain GPT parameter names.
for k in list(sd):
    if not k.startswith("_orig_mod."):
        continue
    sd[k[len("_orig_mod."):]] = sd.pop(k)
m.load_state_dict(sd)
m.eval()
m.to(dev)
# Context length; gen() crops the running sequence to this many tokens.
block = ck["model_args"]["block_size"]
tok = Tokenizer.from_file(os.path.join(ROOT, "tokenizers", "polish_bpe_32k.json"))
def banned(seq, n):
    """Return the tokens that would complete an n-gram already present in `seq`.

    The last n-1 tokens of `seq` form a prefix. For every earlier position
    where that prefix occurs, the token that followed it is collected.
    Sampling any token in the returned set would repeat an existing n-gram.

    Args:
        seq: token ids seen so far (any sliceable sequence of hashables).
        n: n-gram length; values below 1 disable blocking.

    Returns:
        Set of token ids to forbid. It is empty when nothing can be blocked.
    """
    if n < 1 or len(seq) < n - 1:
        return set()
    k = n - 1
    # Use seq[len(seq) - k:] rather than seq[-k:]. For n == 1, seq[-0:] would be
    # the whole sequence instead of the empty prefix, and nothing would be banned.
    prefix = tuple(seq[len(seq) - k:])
    return {seq[i + k] for i in range(len(seq) - k) if tuple(seq[i:i + k]) == prefix}
def _penalize_repeats(logits, ids):
    """Apply the repetition penalty in place to `logits[0]` for every token id in `ids`.

    Positive logits are divided by REP_PEN and non-positive ones multiplied by
    it, so a previously seen token always becomes less likely. This is done as
    one vectorized gather/scatter rather than a per-token Python loop, whose
    data-dependent branch would synchronize the device once per token.
    """
    if not ids:
        return
    seen = torch.tensor(list(set(ids)), dtype=torch.long, device=logits.device)
    vals = logits[0, seen]
    logits[0, seen] = torch.where(vals > 0, vals / REP_PEN, vals * REP_PEN)


def _filter_top_k_top_p(logits):
    """Set to -inf, in place, every logit outside the top-k set or outside the top-p nucleus."""
    # Clamp k so topk does not fail when TOP_K exceeds the vocabulary size.
    k = min(TOP_K, logits.size(-1))
    if k > 0:
        kth = torch.topk(logits, k)[0][..., -1, None]
        logits[logits < kth] = -float("inf")
    sl, si = torch.sort(logits, descending=True)
    cum = torch.cumsum(F.softmax(sl, dim=-1), dim=-1)
    drop = cum > TOP_P
    # Shift right by one so the token that first crosses TOP_P is kept,
    # and always keep the single most likely token.
    drop[..., 1:] = drop[..., :-1].clone()
    drop[..., 0] = False
    # Map the mask from sorted order back to vocabulary order.
    logits[drop.scatter(-1, si, drop)] = -float("inf")


@torch.no_grad()
def gen(prompt, maxtok):
    """Sample up to `maxtok` tokens continuing `prompt`.

    Each step does the following:
    - run the model on the last `block` tokens;
    - penalize every token seen so far (prompt included);
    - forbid tokens that would repeat an NGRAM-gram;
    - divide by TEMP and restrict to top-k / top-p;
    - sample one token.

    Generation stops early when EOT is sampled. EOT itself is not appended.

    Returns:
        (text, rate): the decoded prompt plus continuation, and the number of
        sampled tokens (including a final EOT, if any) per second of wall time.

    Raises:
        ValueError: if `prompt` encodes to zero tokens.
    """
    # Host-side copy of the token history, kept in sync with `idx`. This avoids
    # copying the whole device tensor back with .tolist() on every step.
    ids = list(tok.encode(prompt).ids)
    if not ids:
        raise ValueError("prompt encodes to zero tokens")
    idx = torch.tensor(ids, dtype=torch.long, device=dev)[None]
    t0 = time.perf_counter()
    n = 0
    for _ in range(maxtok):
        logits, _ = m(idx[:, -block:])
        # Last position's scores, cast to float32 for the sampling math.
        # Assumes the model returns (batch, time, vocab) logits — TODO confirm.
        logits = logits[:, -1, :].float()
        _penalize_repeats(logits, ids)
        bad = banned(ids, NGRAM)
        if bad:
            bad_idx = torch.tensor(list(bad), dtype=torch.long, device=logits.device)
            logits[0, bad_idx] = -float("inf")
        logits /= TEMP
        _filter_top_k_top_p(logits)
        nxt = torch.multinomial(F.softmax(logits, dim=-1), 1)
        n += 1
        nxt_id = nxt.item()
        if nxt_id == EOT:
            break
        ids.append(nxt_id)
        idx = torch.cat([idx, nxt], dim=1)
    dt = time.perf_counter() - t0
    # Guard against a zero interval (e.g. maxtok == 0 on a coarse clock).
    return tok.decode(ids), (n / dt if dt > 0 else 0.0)
# Generate once, then print a throughput header (followed by a blank line) and the decoded text.
txt, tps = gen(prompt, maxtok)
print(f"[device={dev} {tps:.1f} tok/s]\n", txt, sep="\n")
|