File size: 2,262 Bytes
78c54ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""GPT-2-PL probe on Apple Silicon (MPS). Usage: python sample_mac.py "Prompt" [n_tok]"""
import os, sys, time, torch, torch.nn.functional as F
from model import GPTConfig, GPT
from tokenizers import Tokenizer

HERE = os.path.dirname(os.path.abspath(__file__))  # directory containing this script
ROOT = os.path.dirname(HERE)  # parent dir; model/ and tokenizers/ are resolved under it

# Sampling hyper-parameters used by the generation loop.
TEMP = 0.7      # softmax temperature (logits are divided by this)
TOP_K = 40      # keep only the K highest logits
TOP_P = 0.92    # nucleus cut-off on cumulative probability
REP_PEN = 1.15  # repetition penalty applied to already-seen token ids
NGRAM = 3       # length of n-grams that are not allowed to repeat
EOT = 0         # token id that ends generation early

# Prefer Apple's Metal backend; otherwise run on the CPU.
if torch.backends.mps.is_available():
    dev = "mps"
else:
    dev = "cpu"

# CLI: argv[1] = prompt text, argv[2] = maximum number of new tokens.
args = sys.argv[1:]
prompt = args[0] if args else "Polska jest"
maxtok = int(args[1]) if len(args) > 1 else 90

# Load the checkpoint on the CPU first; the model is moved to `dev` after the weights are in.
ck = torch.load(os.path.join(ROOT, "model", "ckpt.pt"), map_location="cpu")
m = GPT(GPTConfig(**ck["model_args"]))
sd = ck["model"]
# Keys carrying the "_orig_mod." prefix are renamed in place so they match the plain
# module's parameter names.
_COMPILED_PREFIX = "_orig_mod."
for key in [name for name in sd if name.startswith(_COMPILED_PREFIX)]:
    sd[key[len(_COMPILED_PREFIX):]] = sd.pop(key)
m.load_state_dict(sd)
m.eval().to(dev)
block = ck["model_args"]["block_size"]  # context window length stored in the checkpoint
tok = Tokenizer.from_file(os.path.join(ROOT, "tokenizers", "polish_bpe_32k.json"))

def banned(seq, n):
    """Return the tokens that would recreate an n-gram already present in *seq*.

    Takes the last ``n - 1`` tokens of *seq* as the current prefix. Every earlier
    occurrence of that prefix is located, and the token that followed it is collected.
    Emitting any of those tokens next would repeat an existing n-gram.

    Args:
        seq: sequence of token ids generated so far.
        n: n-gram length. ``n == 1`` bans every token already in *seq*.
            ``n < 1`` bans nothing.

    Returns:
        A set of token ids (possibly empty).
    """
    if n < 1 or len(seq) < n - 1:
        return set()
    k = n - 1  # prefix length
    # Explicit start index: seq[-0:] would be the whole list (not empty) when n == 1.
    prefix = tuple(seq[len(seq) - k:])
    return {
        seq[i + k]
        for i in range(len(seq) - k)
        if tuple(seq[i:i + k]) == prefix
    }

def _penalize(logits, history):
    """Apply the repetition penalty and the no-repeat-n-gram ban to row 0 of *logits* in place."""
    seen = torch.tensor(sorted(set(history)), dtype=torch.long, device=logits.device)
    picked = logits[0, seen]
    # Positive logits are divided by REP_PEN and non-positive ones multiplied by it.
    # Done as one on-device op instead of a Python-side sign test (one host sync) per token.
    logits[0, seen] = torch.where(picked > 0, picked / REP_PEN, picked * REP_PEN)
    bad = banned(history, NGRAM)
    if bad:
        logits[0, list(bad)] = -float("inf")


def _filter_top_k_top_p(logits):
    """Set to -inf (in place) every logit outside the top-K set or the top-p nucleus."""
    k = min(TOP_K, logits.size(-1))  # topk raises if k exceeds the vocabulary size
    kth = torch.topk(logits, k)[0][..., -1, None]
    logits[logits < kth] = -float("inf")
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cum = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    drop = cum > TOP_P
    # Shift right so the token that crosses TOP_P is kept; the best token is never dropped.
    drop[..., 1:] = drop[..., :-1].clone()
    drop[..., 0] = False
    logits[0, sorted_idx[0][drop[0]]] = -float("inf")
    return logits


@torch.no_grad()
def gen(prompt, maxtok):
    """Sample up to *maxtok* tokens continuing *prompt* with the module-level model ``m``.

    Each step does the following:
    - crop the context to the last ``block`` tokens;
    - take the last-position logits;
    - apply the repetition penalty and the n-gram ban;
    - apply temperature, top-k and top-p filtering;
    - draw one token.

    Generation stops early when the EOT id is drawn; that token is not appended.

    Args:
        prompt: text to condition on.
        maxtok: maximum number of sampling steps.

    Returns:
        A tuple ``(text, rate)``. ``text`` is the decoded prompt plus generated tokens.
        ``rate`` is sampling steps per second, counting an EOT step, or 0.0 if no time
        elapsed.

    Raises:
        ValueError: if *prompt* encodes to zero tokens. The model needs at least one
            position to produce logits.
    """
    ids = tok.encode(prompt).ids
    if not ids:
        raise ValueError("prompt encodes to zero tokens; nothing to condition on")
    idx = torch.tensor(ids, dtype=torch.long, device=dev)[None]
    t0 = time.perf_counter()
    n = 0
    for _ in range(maxtok):
        logits, _ = m(idx[:, -block:])
        logits = logits[:, -1, :].float()
        _penalize(logits, idx[0].tolist())
        logits /= TEMP
        _filter_top_k_top_p(logits)
        nxt = torch.multinomial(F.softmax(logits, dim=-1), 1)
        n += 1
        # .item() copies the value to the host, so it waits for this step's work to finish.
        if nxt.item() == EOT:
            break
        idx = torch.cat([idx, nxt], dim=1)
    dt = time.perf_counter() - t0
    return tok.decode(idx[0].tolist()), (n / dt if dt > 0 else 0.0)

# Script entry: generate once from the CLI prompt, then print the throughput line,
# a blank line, and the decoded text.
txt, tps = gen(prompt, maxtok)
report = f"[device={dev}  {tps:.1f} tok/s]\n"
print(report, txt, sep="\n")