slayer-gpt-tokenizer-model / scripts /knowbench_mac.py
kacperwikiel's picture
Upload Slayer GPT tokenizer model archive
78c54ec verified
Raw
History Blame Contribute Delete
4.31 kB
"""Mini knowledge bench (likelihood MCQ) for GPT-2-PL. Run: python knowbench_mac.py [-v]
Scores each item by length-normalized log-prob of each answer; picks argmax. No generation."""
import os, sys, math, torch, torch.nn.functional as F
from model import GPTConfig, GPT
from tokenizers import Tokenizer
HERE = os.path.dirname(os.path.abspath(__file__))
ROOT = os.path.dirname(HERE)
VERBOSE = "-v" in sys.argv
dev = "mps" if torch.backends.mps.is_available() else "cpu"
# (prompt, correct, [distractors]) — correct is index 0; choices shuffled deterministically by pos
ITEMS = [
("Stolicą Polski jest", "Warszawa", ["Kraków", "Gdańsk", "Poznań"]),
("Najdłuższą rzeką w Polsce jest", "Wisła", ["Odra", "Warta", "Bug"]),
("Autorem „Pana Tadeusza” jest", "Adam Mickiewicz", ["Juliusz Słowacki", "Henryk Sienkiewicz", "Bolesław Prus"]),
("Pierwszym koronowanym królem Polski był", "Bolesław Chrobry", ["Kazimierz Wielki", "Władysław Łokietek", "Jan Sobieski"]),
("Polska leży w", "Europie", ["Azji", "Afryce", "Ameryce"]),
("Walutą Polski jest", "złoty", ["euro", "dolar", "frank"]),
("Tatry to", "góry", ["rzeka", "jezioro", "miasto"]),
("Mikołaj Kopernik był", "astronomem", ["malarzem", "kompozytorem", "pisarzem"]),
("Fryderyk Chopin komponował", "muzykę", ["obrazy", "powieści", "rzeźby"]),
("Druga wojna światowa wybuchła w roku", "1939", ["1914", "1945", "1918"]),
("Polska wstąpiła do Unii Europejskiej w roku", "2004", ["1999", "2010", "1989"]),
("Słońce jest", "gwiazdą", ["planetą", "księżycem", "kometą"]),
("Woda składa się z wodoru i", "tlenu", ["azotu", "węgla", "żelaza"]),
("Największym oceanem jest Ocean", "Spokojny", ["Atlantycki", "Indyjski", "Arktyczny"]),
("Pszczoły produkują", "miód", ["mleko", "jedwab", "wełnę"]),
("Stolicą Francji jest", "Paryż", ["Londyn", "Berlin", "Madryt"]),
("Wisła wpada do Morza", "Bałtyckiego", ["Czarnego", "Śródziemnego", "Czerwonego"]),
("Lech Wałęsa był przywódcą", "Solidarności", ["PZPR", "Sejmu", "wojska"]),
("Kraków leży nad", "Wisłą", ["Odrą", "Wartą", "Bugiem"]),
("Zakopane leży w", "Tatrach", ["Bieszczadach", "Sudetach", "Karkonoszach"]),
("Bursztyn powstaje z", "żywicy", ["kamienia", "metalu", "piasku"]),
("Człowiek do oddychania potrzebuje", "tlenu", ["azotu", "wodoru", "helu"]),
("Księżyc krąży wokół", "Ziemi", ["Słońca", "Marsa", "Wenus"]),
("Orzeł biały znajduje się w godle", "Polski", ["Niemiec", "Francji", "Czech"]),
]
ck = torch.load(os.path.join(ROOT, "model", "ckpt.pt"), map_location="cpu")
m = GPT(GPTConfig(**ck["model_args"]))
sd = ck["model"]
for k in list(sd):
if k.startswith("_orig_mod."): sd[k[len("_orig_mod."):]] = sd.pop(k)
m.load_state_dict(sd); m.eval().to(dev)
block = ck["model_args"]["block_size"]
tok = Tokenizer.from_file(os.path.join(ROOT, "tokenizers", "polish_bpe_32k.json"))
@torch.no_grad()
def score(prompt, answer):
"""length-normalized log-prob of answer tokens given prompt."""
pids = tok.encode(prompt).ids
fids = tok.encode(prompt + " " + answer).ids
# common prefix length (BPE boundary safety)
plen = 0
for a, b in zip(pids, fids):
if a != b: break
plen += 1
ans = fids[plen:]
if not ans: return -1e9
x = torch.tensor(fids, dtype=torch.long, device=dev)[None]
logits, _ = m(x, x) # targets triggers full-position logits
logp = F.log_softmax(logits[0].float(), dim=-1)
total = sum(logp[plen + i - 1, ans[i]].item() for i in range(len(ans)))
return total / len(ans)
correct = 0
rows = []
for prompt, ans, dist in ITEMS:
cands = [ans] + dist
scores = [(c, score(prompt, c)) for c in cands]
pred = max(scores, key=lambda s: s[1])[0]
ok = pred == ans
correct += ok
rows.append((prompt, ans, pred, ok))
n = len(ITEMS)
base = sum(1/(1+len(d)) for _,_,d in ITEMS) / n
print(f"\nMini-bench wiedzy PL (likelihood MCQ, {n} pytań, 4 opcje)")
print(f"checkpoint iter {ck.get('iter')} · device {dev}")
print("=" * 50)
print(f"ACCURACY: {correct}/{n} = {100*correct/n:.1f}% (losowy baseline {100*base:.1f}%)")
if VERBOSE:
print("-" * 50)
for p, a, pred, ok in rows:
print(f"{'✓' if ok else '✗'} {p}{pred}" + ("" if ok else f" [poprawnie: {a}]"))