File size: 4,310 Bytes
78c54ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Mini knowledge bench (likelihood MCQ) for GPT-2-PL. Run: python knowbench_mac.py [-v]
Scores each item by length-normalized log-prob of each answer; picks argmax. No generation."""
import os, sys, math, torch, torch.nn.functional as F
from model import GPTConfig, GPT
from tokenizers import Tokenizer

HERE = os.path.dirname(os.path.abspath(__file__))
ROOT = os.path.dirname(HERE)
VERBOSE = "-v" in sys.argv
dev = "mps" if torch.backends.mps.is_available() else "cpu"

# (prompt, correct, [distractors])  — correct is index 0; choices shuffled deterministically by pos
ITEMS = [
 ("Stolicą Polski jest", "Warszawa", ["Kraków", "Gdańsk", "Poznań"]),
 ("Najdłuższą rzeką w Polsce jest", "Wisła", ["Odra", "Warta", "Bug"]),
 ("Autorem „Pana Tadeusza” jest", "Adam Mickiewicz", ["Juliusz Słowacki", "Henryk Sienkiewicz", "Bolesław Prus"]),
 ("Pierwszym koronowanym królem Polski był", "Bolesław Chrobry", ["Kazimierz Wielki", "Władysław Łokietek", "Jan Sobieski"]),
 ("Polska leży w", "Europie", ["Azji", "Afryce", "Ameryce"]),
 ("Walutą Polski jest", "złoty", ["euro", "dolar", "frank"]),
 ("Tatry to", "góry", ["rzeka", "jezioro", "miasto"]),
 ("Mikołaj Kopernik był", "astronomem", ["malarzem", "kompozytorem", "pisarzem"]),
 ("Fryderyk Chopin komponował", "muzykę", ["obrazy", "powieści", "rzeźby"]),
 ("Druga wojna światowa wybuchła w roku", "1939", ["1914", "1945", "1918"]),
 ("Polska wstąpiła do Unii Europejskiej w roku", "2004", ["1999", "2010", "1989"]),
 ("Słońce jest", "gwiazdą", ["planetą", "księżycem", "kometą"]),
 ("Woda składa się z wodoru i", "tlenu", ["azotu", "węgla", "żelaza"]),
 ("Największym oceanem jest Ocean", "Spokojny", ["Atlantycki", "Indyjski", "Arktyczny"]),
 ("Pszczoły produkują", "miód", ["mleko", "jedwab", "wełnę"]),
 ("Stolicą Francji jest", "Paryż", ["Londyn", "Berlin", "Madryt"]),
 ("Wisła wpada do Morza", "Bałtyckiego", ["Czarnego", "Śródziemnego", "Czerwonego"]),
 ("Lech Wałęsa był przywódcą", "Solidarności", ["PZPR", "Sejmu", "wojska"]),
 ("Kraków leży nad", "Wisłą", ["Odrą", "Wartą", "Bugiem"]),
 ("Zakopane leży w", "Tatrach", ["Bieszczadach", "Sudetach", "Karkonoszach"]),
 ("Bursztyn powstaje z", "żywicy", ["kamienia", "metalu", "piasku"]),
 ("Człowiek do oddychania potrzebuje", "tlenu", ["azotu", "wodoru", "helu"]),
 ("Księżyc krąży wokół", "Ziemi", ["Słońca", "Marsa", "Wenus"]),
 ("Orzeł biały znajduje się w godle", "Polski", ["Niemiec", "Francji", "Czech"]),
]

ck = torch.load(os.path.join(ROOT, "model", "ckpt.pt"), map_location="cpu")
m = GPT(GPTConfig(**ck["model_args"]))
sd = ck["model"]
for k in list(sd):
    if k.startswith("_orig_mod."): sd[k[len("_orig_mod."):]] = sd.pop(k)
m.load_state_dict(sd); m.eval().to(dev)
block = ck["model_args"]["block_size"]
tok = Tokenizer.from_file(os.path.join(ROOT, "tokenizers", "polish_bpe_32k.json"))

@torch.no_grad()
def score(prompt, answer):
    """length-normalized log-prob of answer tokens given prompt."""
    pids = tok.encode(prompt).ids
    fids = tok.encode(prompt + " " + answer).ids
    # common prefix length (BPE boundary safety)
    plen = 0
    for a, b in zip(pids, fids):
        if a != b: break
        plen += 1
    ans = fids[plen:]
    if not ans: return -1e9
    x = torch.tensor(fids, dtype=torch.long, device=dev)[None]
    logits, _ = m(x, x)                      # targets triggers full-position logits
    logp = F.log_softmax(logits[0].float(), dim=-1)
    total = sum(logp[plen + i - 1, ans[i]].item() for i in range(len(ans)))
    return total / len(ans)

correct = 0
rows = []
for prompt, ans, dist in ITEMS:
    cands = [ans] + dist
    scores = [(c, score(prompt, c)) for c in cands]
    pred = max(scores, key=lambda s: s[1])[0]
    ok = pred == ans
    correct += ok
    rows.append((prompt, ans, pred, ok))

n = len(ITEMS)
base = sum(1/(1+len(d)) for _,_,d in ITEMS) / n
print(f"\nMini-bench wiedzy PL  (likelihood MCQ, {n} pytań, 4 opcje)")
print(f"checkpoint iter {ck.get('iter')}  ·  device {dev}")
print("=" * 50)
print(f"ACCURACY: {correct}/{n} = {100*correct/n:.1f}%   (losowy baseline {100*base:.1f}%)")
if VERBOSE:
    print("-" * 50)
    for p, a, pred, ok in rows:
        print(f"{'✓' if ok else '✗'}  {p}{pred}" + ("" if ok else f"  [poprawnie: {a}]"))