"""
Batched BLiMP scorer for İvme — fast, GPU-parallel.

Scores all 67 BLiMP subtasks by batching sentence pairs through the model
instead of looping over them one at a time. On a Blackwell GPU this runs the
whole suite in well under a minute.

Method: for each (good, bad) pair, compute the total log-prob of each sentence
and count a win when logprob(good) > logprob(bad). Sentences are right-padded
into batches and scored with a length mask so padding contributes nothing.
Because no BOS token is prepended here, the first token of each sentence is
not itself scored (unless the tokenizer adds a BOS token on its own).

Usage:
    python eval_blimp.py --checkpoint checkpoints/ivme_base_ema.pt
    python eval_blimp.py --checkpoint checkpoints/ivme_base_ema.pt --batch_size 256
"""
from __future__ import annotations
import argparse
import json
import sys
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
from datasets import load_dataset
sys.path.insert(0, ".")
from model import IvmeConfig, IvmeConversate
# Tokenizer definition read by main() via Tokenizer.from_file; the path is
# relative, so it resolves against the current working directory.
TOKENIZER_PATH = "ivme_tokenizer.json"

# The 67 BLiMP subtask identifiers, kept in alphabetical order. main() matches
# these against the dataset's "UID" column and reports results in this order,
# so every entry must spell the UID exactly as the dataset does.
BLIMP_TASKS = [
    "adjunct_island",
    "anaphor_gender_agreement",
    "anaphor_number_agreement",
    "animate_subject_passive",
    "animate_subject_trans",
    "causative",
    "complex_NP_island",
    "coordinate_structure_constraint_complex_left_branch",
    "coordinate_structure_constraint_object_extraction",
    "determiner_noun_agreement_1",
    "determiner_noun_agreement_2",
    "determiner_noun_agreement_irregular_1",
    "determiner_noun_agreement_irregular_2",
    "determiner_noun_agreement_with_adj_2",
    "determiner_noun_agreement_with_adj_irregular_1",
    "determiner_noun_agreement_with_adj_irregular_2",
    "determiner_noun_agreement_with_adjective_1",
    "distractor_agreement_relational_noun",
    "distractor_agreement_relative_clause",
    "drop_argument",
    "ellipsis_n_bar_1",
    "ellipsis_n_bar_2",
    "existential_there_object_raising",
    "existential_there_quantifiers_1",
    "existential_there_quantifiers_2",
    "existential_there_subject_raising",
    "expletive_it_object_raising",
    "inchoative",
    "intransitive",
    "irregular_past_participle_adjectives",
    "irregular_past_participle_verbs",
    "irregular_plural_subject_verb_agreement_1",
    "irregular_plural_subject_verb_agreement_2",
    "left_branch_island_echo_question",
    "left_branch_island_simple_question",
    "matrix_question_npi_licensor_present",
    "npi_present_1",
    "npi_present_2",
    "only_npi_licensor_present",
    "only_npi_scope",
    "passive_1",
    "passive_2",
    "principle_A_c_command",
    "principle_A_case_1",
    "principle_A_case_2",
    "principle_A_domain_1",
    "principle_A_domain_2",
    "principle_A_domain_3",
    "principle_A_reconstruction",
    "regular_plural_subject_verb_agreement_1",
    "regular_plural_subject_verb_agreement_2",
    "sentential_negation_npi_licensor_present",
    "sentential_negation_npi_scope",
    "sentential_subject_island",
    "superlative_quantifiers_1",
    "superlative_quantifiers_2",
    "tough_vs_raising_1",
    "tough_vs_raising_2",
    "transitive",
    "wh_island",
    "wh_questions_object_gap",
    "wh_questions_subject_gap",
    "wh_questions_subject_gap_long_distance",
    "wh_vs_that_no_gap",
    "wh_vs_that_no_gap_long_distance",
    "wh_vs_that_with_gap",
    "wh_vs_that_with_gap_long_distance",
]
@torch.no_grad()
def batch_logprobs(model, token_lists, device, pad_id, max_len):
    """Return the total log-probability of each token sequence, scored as one padded batch.

    Each sequence is truncated to ``max_len`` tokens, right-padded with ``pad_id``
    to the longest (truncated) sequence, and run through ``model`` in a single
    forward pass. Output position ``j`` predicts token ``j + 1``, so a sequence of
    ``n`` tokens contributes ``log p(t_1 .. t_{n-1} | t_0)``: the first token is
    never scored, and sequences of 0 or 1 tokens score 0.0. Padded positions are
    masked out and contribute nothing.

    Args:
        model: Callable mapping a ``(B, L)`` long tensor to ``(logits, _)`` with
            logits of shape ``(B, L, vocab)``. Assumed causal, so right-padding
            cannot influence earlier positions — TODO confirm for new models.
        token_lists: Token-id sequences (``list[list[int]]``); may be empty.
        device: Device the batch is placed on; bfloat16 autocast is enabled only
            when ``device.type == "cuda"``.
        pad_id: Id written into padding slots. Any valid vocab index works,
            because those positions are masked.
        max_len: Maximum number of tokens kept per sequence.

    Returns:
        A float32 tensor of shape ``(len(token_lists),)`` on ``device``.
    """
    batch = len(token_lists)
    if batch == 0:
        return torch.zeros(0, device=device)

    clipped = [list(ids[:max_len]) for ids in token_lists]
    lengths = [len(ids) for ids in clipped]
    seq_len = max(lengths)
    if seq_len <= 1:
        # No token in the batch has any preceding context, so nothing is scored.
        return torch.zeros(batch, device=device)

    # Pad on the host and copy to the device once, instead of once per row.
    padded = [ids + [pad_id] * (seq_len - len(ids)) for ids in clipped]
    inp = torch.tensor(padded, dtype=torch.long).to(device)

    with torch.autocast(device_type=device.type, dtype=torch.bfloat16,
                        enabled=device.type == "cuda"):
        logits, _ = model(inp)

    # The last position would predict a token beyond the sequence; drop it
    # before the (float32) softmax rather than after.
    logp = F.log_softmax(logits[:, :-1].float(), dim=-1)
    tok_lp = logp.gather(-1, inp[:, 1:].unsqueeze(-1)).squeeze(-1)

    # Target position j belongs to a real token iff j + 1 < length.
    n_scored = torch.tensor(lengths, device=device) - 1
    positions = torch.arange(seq_len - 1, device=device)
    mask = positions.unsqueeze(0) < n_scored.unsqueeze(1)
    # masked_fill (rather than multiplying by the mask) so that a -inf log-prob
    # at a padded position cannot turn into NaN (0 * -inf).
    return tok_lp.masked_fill(~mask, 0.0).sum(dim=1)
def _load_model(checkpoint_path, device):
    """Build an IvmeConversate from a checkpoint and put it on *device* in eval mode.

    Returns:
        ``(model, max_len)``, where ``max_len`` is the checkpoint config's ``max_seq_len``.
    """
    # NOTE(security): weights_only=False runs the full unpickler, which is needed
    # because the checkpoint stores a pickled config object. Only load trusted files.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    cfg = ckpt["cfg"]
    # Evaluate with the "sdpa" attention backend regardless of the saved setting.
    cfg.attn_backend = "sdpa"
    model = IvmeConversate(cfg).to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()
    return model, cfg.max_seq_len


def _load_blimp_pairs():
    """Download BLiMP once and bucket its (good, bad) sentence pairs by subtask UID.

    Returns:
        ``(by_task, n_rows)``. ``by_task[task]`` is ``{"good": [...], "bad": [...]}``
        for every task in BLIMP_TASKS; rows with any other UID are ignored.
        ``n_rows`` is the total number of dataset rows.
    """
    print("[blimp] loading full BLiMP dataset (one download)...")
    full_ds = load_dataset("WillHeld/blimp", split="train")
    by_task = {t: {"good": [], "bad": []} for t in BLIMP_TASKS}
    # Column-wise access avoids building a Python dict for every row.
    for uid, good, bad in zip(full_ds["UID"], full_ds["sentence_good"],
                              full_ds["sentence_bad"]):
        bucket = by_task.get(uid)
        if bucket is not None:
            bucket["good"].append(good)
            bucket["bad"].append(bad)
    return by_task, len(full_ds)


def _count_wins(model, tok, goods, bads, device, pad_id, max_len, batch_size):
    """Count pairs where the good sentence gets a strictly higher total log-prob (ties lose)."""
    good_tok = [enc.ids for enc in tok.encode_batch(goods)]
    bad_tok = [enc.ids for enc in tok.encode_batch(bads)]
    correct = 0
    for start in range(0, len(good_tok), batch_size):
        stop = start + batch_size
        g_lp = batch_logprobs(model, good_tok[start:stop], device, pad_id, max_len)
        b_lp = batch_logprobs(model, bad_tok[start:stop], device, pad_id, max_len)
        correct += (g_lp > b_lp).sum().item()
    return correct


def main():
    """CLI entry point: score every BLiMP subtask and write the accuracies to JSON.

    The reported average is pooled over all scored pairs (total wins / total pairs).
    Subtasks with no matching dataset rows are skipped with a warning instead of
    crashing. The program exits with an error if nothing could be scored.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True)
    ap.add_argument("--batch_size", type=int, default=256)
    ap.add_argument("--output", default="blimp_results.json")
    args = ap.parse_args()
    if args.batch_size < 1:
        ap.error("--batch_size must be a positive integer")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = Tokenizer.from_file(TOKENIZER_PATH)
    # Falls back to id 0 when the tokenizer has no "<|pad|>" token; padded
    # positions are masked during scoring, so any valid id works.
    pad_id = tok.token_to_id("<|pad|>") or 0

    model, max_len = _load_model(args.checkpoint, device)
    print(f"[blimp] model loaded: {model.num_params()/1e6:.1f}M on {device}")

    by_task, n_rows = _load_blimp_pairs()
    print(f"[blimp] {n_rows} examples bucketed into {len(BLIMP_TASKS)} subtasks\n")

    results = {}
    skipped = []
    total_correct = total_examples = 0
    for i, task in enumerate(BLIMP_TASKS):
        goods = by_task[task]["good"]
        bads = by_task[task]["bad"]
        if not goods:
            # A UID mismatch would otherwise divide by zero below.
            skipped.append(task)
            print(f"[{i+1:02d}/{len(BLIMP_TASKS)}] {task:<55} skipped (no examples found)")
            continue
        correct = _count_wins(model, tok, goods, bads, device, pad_id, max_len,
                              args.batch_size)
        acc = correct / len(goods)
        results[task] = acc
        total_correct += correct
        total_examples += len(goods)
        running = total_correct / total_examples
        print(f"[{i+1:02d}/{len(BLIMP_TASKS)}] {task:<55} {acc*100:5.1f}% "
              f"(avg: {running*100:.2f}%)")

    if total_examples == 0:
        sys.exit("[blimp] no dataset rows matched BLIMP_TASKS; nothing was scored")

    final = total_correct / total_examples
    print(f"\n{'='*60}")
    print(f" BLiMP average: {final*100:.2f}% (random baseline: 50%)")
    print(f"{'='*60}")
    if skipped:
        print(f"[blimp] WARNING: {len(skipped)} subtask(s) had no examples and were "
              f"excluded from the average: {', '.join(skipped)}")

    with open(args.output, "w", encoding="utf-8") as f:
        json.dump({"tasks": results, "average": final}, f, indent=2)
    print(f"\n[blimp] saved -> {args.output}")
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()