| """Mini SYNTAX bench (grammatical minimal pairs, BLiMP-style) for GPT-2-PL. |
| For each pair, model should give higher total log-prob to the grammatical sentence. |
| Run: python syntaxbench_mac.py [-v]""" |
| import os, sys, torch, torch.nn.functional as F |
| from model import GPTConfig, GPT |
| from tokenizers import Tokenizer |
|
|
| HERE = os.path.dirname(os.path.abspath(__file__)) |
| ROOT = os.path.dirname(HERE) |
| VERBOSE = "-v" in sys.argv |
| dev = "mps" if torch.backends.mps.is_available() else "cpu" |
| EOT = 0 |
|
|
| |
| PAIRS = [ |
| ("Czerwony samochód stoi przed domem.", "Czerwona samochód stoi przed domem."), |
| ("Mały pies głośno szczeka.", "Mała pies głośno szczeka."), |
| ("Idę do szkoły.", "Idę do szkoła."), |
| ("Mieszkam w Warszawie.", "Mieszkam w Warszawa."), |
| ("Dzieci bawią się w parku.", "Dzieci bawi się w parku."), |
| ("Kot śpi na kanapie.", "Kot śpią na kanapie."), |
| ("Ona poszła do domu.", "Ona poszedł do domu."), |
| ("On czytał książkę.", "On czytała książkę."), |
| ("Widzę dużego psa.", "Widzę dużego pies."), |
| ("Lubię mocną kawę.", "Lubię mocną kawa."), |
| ("Duże domy stoją w mieście.", "Duży domy stoją w mieście."), |
| ("Piszę długopisem.", "Piszę długopis."), |
| ("Nie mam czasu.", "Nie mam czas."), |
| ("Nie lubię gorzkiej herbaty.", "Nie lubię gorzkiej herbata."), |
| ("Trzy koty śpią na dworze.", "Trzy kot śpią na dworze."), |
| ("Rozmawiam z bratem.", "Rozmawiam z brat."), |
| ("To jest moja siostra.", "To jest mój siostra."), |
| ("To jest mój brat.", "To jest moja brat."), |
| ("Ja idę do domu.", "Ja idzie do domu."), |
| ("Ty masz rację.", "Ty ma rację."), |
| ("Wczoraj byłem w kinie.", "Wczoraj byłem w kino."), |
| ("Ten wysoki mężczyzna śpiewa.", "Ten wysoka mężczyzna śpiewa."), |
| ] |
|
|
| ck = torch.load(os.path.join(ROOT, "model", "ckpt.pt"), map_location="cpu") |
| m = GPT(GPTConfig(**ck["model_args"])) |
| sd = ck["model"] |
| for k in list(sd): |
| if k.startswith("_orig_mod."): sd[k[len("_orig_mod."):]] = sd.pop(k) |
| m.load_state_dict(sd); m.eval().to(dev) |
| tok = Tokenizer.from_file(os.path.join(ROOT, "tokenizers", "polish_bpe_32k.json")) |
|
|
| @torch.no_grad() |
| def logprob(sentence): |
| ids = [EOT] + tok.encode(sentence).ids |
| x = torch.tensor(ids, dtype=torch.long, device=dev)[None] |
| logits, _ = m(x, x) |
| lp = F.log_softmax(logits[0].float(), dim=-1) |
| return sum(lp[i, ids[i+1]].item() for i in range(len(ids)-1)) |
|
|
| correct = 0; rows = [] |
| for good, bad in PAIRS: |
| lg, lb = logprob(good), logprob(bad) |
| ok = lg > lb; correct += ok |
| rows.append((good, bad, ok)) |
|
|
| n = len(PAIRS) |
| print(f"\nMini-bench SKŁADNI PL (pary minimalne, {n} par)") |
| print(f"checkpoint iter {ck.get('iter')} · device {dev}") |
| print("=" * 50) |
| print(f"ACCURACY: {correct}/{n} = {100*correct/n:.1f}% (losowy baseline 50.0%)") |
| if VERBOSE: |
| print("-" * 50) |
| for good, bad, ok in rows: |
| print(f"{'✓' if ok else '✗'} {good} vs {bad}") |
|
|