| """Mini knowledge bench (likelihood MCQ) for GPT-2-PL. Run: python knowbench_mac.py [-v] |
| Scores each item by length-normalized log-prob of each answer; picks argmax. No generation.""" |
| import os, sys, math, torch, torch.nn.functional as F |
| from model import GPTConfig, GPT |
| from tokenizers import Tokenizer |
|
|
| HERE = os.path.dirname(os.path.abspath(__file__)) |
| ROOT = os.path.dirname(HERE) |
| VERBOSE = "-v" in sys.argv |
| dev = "mps" if torch.backends.mps.is_available() else "cpu" |
|
|
| |
| ITEMS = [ |
| ("Stolicą Polski jest", "Warszawa", ["Kraków", "Gdańsk", "Poznań"]), |
| ("Najdłuższą rzeką w Polsce jest", "Wisła", ["Odra", "Warta", "Bug"]), |
| ("Autorem „Pana Tadeusza” jest", "Adam Mickiewicz", ["Juliusz Słowacki", "Henryk Sienkiewicz", "Bolesław Prus"]), |
| ("Pierwszym koronowanym królem Polski był", "Bolesław Chrobry", ["Kazimierz Wielki", "Władysław Łokietek", "Jan Sobieski"]), |
| ("Polska leży w", "Europie", ["Azji", "Afryce", "Ameryce"]), |
| ("Walutą Polski jest", "złoty", ["euro", "dolar", "frank"]), |
| ("Tatry to", "góry", ["rzeka", "jezioro", "miasto"]), |
| ("Mikołaj Kopernik był", "astronomem", ["malarzem", "kompozytorem", "pisarzem"]), |
| ("Fryderyk Chopin komponował", "muzykę", ["obrazy", "powieści", "rzeźby"]), |
| ("Druga wojna światowa wybuchła w roku", "1939", ["1914", "1945", "1918"]), |
| ("Polska wstąpiła do Unii Europejskiej w roku", "2004", ["1999", "2010", "1989"]), |
| ("Słońce jest", "gwiazdą", ["planetą", "księżycem", "kometą"]), |
| ("Woda składa się z wodoru i", "tlenu", ["azotu", "węgla", "żelaza"]), |
| ("Największym oceanem jest Ocean", "Spokojny", ["Atlantycki", "Indyjski", "Arktyczny"]), |
| ("Pszczoły produkują", "miód", ["mleko", "jedwab", "wełnę"]), |
| ("Stolicą Francji jest", "Paryż", ["Londyn", "Berlin", "Madryt"]), |
| ("Wisła wpada do Morza", "Bałtyckiego", ["Czarnego", "Śródziemnego", "Czerwonego"]), |
| ("Lech Wałęsa był przywódcą", "Solidarności", ["PZPR", "Sejmu", "wojska"]), |
| ("Kraków leży nad", "Wisłą", ["Odrą", "Wartą", "Bugiem"]), |
| ("Zakopane leży w", "Tatrach", ["Bieszczadach", "Sudetach", "Karkonoszach"]), |
| ("Bursztyn powstaje z", "żywicy", ["kamienia", "metalu", "piasku"]), |
| ("Człowiek do oddychania potrzebuje", "tlenu", ["azotu", "wodoru", "helu"]), |
| ("Księżyc krąży wokół", "Ziemi", ["Słońca", "Marsa", "Wenus"]), |
| ("Orzeł biały znajduje się w godle", "Polski", ["Niemiec", "Francji", "Czech"]), |
| ] |
|
|
| ck = torch.load(os.path.join(ROOT, "model", "ckpt.pt"), map_location="cpu") |
| m = GPT(GPTConfig(**ck["model_args"])) |
| sd = ck["model"] |
| for k in list(sd): |
| if k.startswith("_orig_mod."): sd[k[len("_orig_mod."):]] = sd.pop(k) |
| m.load_state_dict(sd); m.eval().to(dev) |
| block = ck["model_args"]["block_size"] |
| tok = Tokenizer.from_file(os.path.join(ROOT, "tokenizers", "polish_bpe_32k.json")) |
|
|
| @torch.no_grad() |
| def score(prompt, answer): |
| """length-normalized log-prob of answer tokens given prompt.""" |
| pids = tok.encode(prompt).ids |
| fids = tok.encode(prompt + " " + answer).ids |
| |
| plen = 0 |
| for a, b in zip(pids, fids): |
| if a != b: break |
| plen += 1 |
| ans = fids[plen:] |
| if not ans: return -1e9 |
| x = torch.tensor(fids, dtype=torch.long, device=dev)[None] |
| logits, _ = m(x, x) |
| logp = F.log_softmax(logits[0].float(), dim=-1) |
| total = sum(logp[plen + i - 1, ans[i]].item() for i in range(len(ans))) |
| return total / len(ans) |
|
|
| correct = 0 |
| rows = [] |
| for prompt, ans, dist in ITEMS: |
| cands = [ans] + dist |
| scores = [(c, score(prompt, c)) for c in cands] |
| pred = max(scores, key=lambda s: s[1])[0] |
| ok = pred == ans |
| correct += ok |
| rows.append((prompt, ans, pred, ok)) |
|
|
| n = len(ITEMS) |
| base = sum(1/(1+len(d)) for _,_,d in ITEMS) / n |
| print(f"\nMini-bench wiedzy PL (likelihood MCQ, {n} pytań, 4 opcje)") |
| print(f"checkpoint iter {ck.get('iter')} · device {dev}") |
| print("=" * 50) |
| print(f"ACCURACY: {correct}/{n} = {100*correct/n:.1f}% (losowy baseline {100*base:.1f}%)") |
| if VERBOSE: |
| print("-" * 50) |
| for p, a, pred, ok in rows: |
| print(f"{'✓' if ok else '✗'} {p} → {pred}" + ("" if ok else f" [poprawnie: {a}]")) |
|
|