ereniko commited on
Commit
337273e
·
verified ·
1 Parent(s): 44217ec

Upload eval_blimp.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. eval_blimp.py +157 -0
eval_blimp.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Batched BLiMP scorer for İvme — fast, GPU-parallel.
3
+
4
+ Scores all 67 BLiMP subtasks by batching sentence pairs through the model
5
+ instead of looping one at a time. On a Blackwell this runs the whole suite
6
+ in well under a minute.
7
+
8
+ Method: for each (good, bad) pair, compute total log-prob of each sentence
9
+ and count a win when logprob(good) > logprob(bad). Sentences are padded into
10
+ batches and scored with a length mask so padding contributes nothing.
11
+
12
+ Usage:
13
+ python eval_blimp.py --checkpoint checkpoints/ivme_base_ema.pt
14
+ python eval_blimp.py --checkpoint checkpoints/ivme_base_ema.pt --batch_size 256
15
+ """
16
+
17
+ from __future__ import annotations
18
+ import argparse
19
+ import json
20
+ import sys
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from tokenizers import Tokenizer
24
+ from datasets import load_dataset
25
+
26
+ sys.path.insert(0, ".")
27
+ from model import IvmeConfig, IvmeConversate
28
+
29
+ TOKENIZER_PATH = "ivme_tokenizer.json"
30
+
31
+ BLIMP_TASKS = [
32
+ "adjunct_island", "anaphor_gender_agreement", "anaphor_number_agreement",
33
+ "animate_subject_passive", "animate_subject_trans", "causative",
34
+ "complex_NP_island", "coordinate_structure_constraint_complex_left_branch",
35
+ "coordinate_structure_constraint_object_extraction", "determiner_noun_agreement_1",
36
+ "determiner_noun_agreement_2", "determiner_noun_agreement_irregular_1",
37
+ "determiner_noun_agreement_irregular_2", "determiner_noun_agreement_with_adj_2",
38
+ "determiner_noun_agreement_with_adj_irregular_1",
39
+ "determiner_noun_agreement_with_adj_irregular_2",
40
+ "determiner_noun_agreement_with_adjective_1", "distractor_agreement_relational_noun",
41
+ "distractor_agreement_relative_clause", "drop_argument", "ellipsis_n_bar_1",
42
+ "ellipsis_n_bar_2", "existential_there_object_raising",
43
+ "existential_there_quantifiers_1", "existential_there_quantifiers_2",
44
+ "existential_there_subject_raising", "expletive_it_object_raising", "inchoative",
45
+ "intransitive", "irregular_past_participle_adjectives",
46
+ "irregular_past_participle_verbs", "irregular_plural_subject_verb_agreement_1",
47
+ "irregular_plural_subject_verb_agreement_2", "left_branch_island_echo_question",
48
+ "left_branch_island_simple_question", "matrix_question_npi_licensor_present",
49
+ "npi_present_1", "npi_present_2", "only_npi_licensor_present", "only_npi_scope",
50
+ "passive_1", "passive_2", "principle_A_c_command", "principle_A_case_1",
51
+ "principle_A_case_2", "principle_A_domain_1", "principle_A_domain_2",
52
+ "principle_A_domain_3", "principle_A_reconstruction",
53
+ "regular_plural_subject_verb_agreement_1", "regular_plural_subject_verb_agreement_2",
54
+ "sentential_negation_npi_licensor_present", "sentential_negation_npi_scope",
55
+ "sentential_subject_island", "superlative_quantifiers_1", "superlative_quantifiers_2",
56
+ "tough_vs_raising_1", "tough_vs_raising_2", "transitive", "wh_island",
57
+ "wh_questions_object_gap", "wh_questions_subject_gap",
58
+ "wh_questions_subject_gap_long_distance", "wh_vs_that_no_gap",
59
+ "wh_vs_that_no_gap_long_distance", "wh_vs_that_with_gap",
60
+ "wh_vs_that_with_gap_long_distance",
61
+ ]
62
+
63
+
64
+ @torch.no_grad()
65
+ def batch_logprobs(model, token_lists, device, pad_id, max_len):
66
+ """Total log-prob of each sequence in a padded batch. token_lists: list[list[int]]."""
67
+ B = len(token_lists)
68
+ L = min(max(len(t) for t in token_lists), max_len)
69
+ inp = torch.full((B, L), pad_id, dtype=torch.long, device=device)
70
+ lengths = []
71
+ for i, t in enumerate(token_lists):
72
+ t = t[:L]
73
+ inp[i, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
74
+ lengths.append(len(t))
75
+
76
+ with torch.autocast(device_type=device.type, dtype=torch.bfloat16,
77
+ enabled=device.type == "cuda"):
78
+ logits, _ = model(inp)
79
+
80
+ logp = F.log_softmax(logits.float(), dim=-1)
81
+ targets = inp[:, 1:]
82
+ pred = logp[:, :-1, :]
83
+ tok_lp = pred.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
84
+
85
+ mask = torch.zeros_like(tok_lp)
86
+ for i, n in enumerate(lengths):
87
+ mask[i, : max(0, n - 1)] = 1.0
88
+ return (tok_lp * mask).sum(dim=1)
89
+
90
+
91
+ def main():
92
+ ap = argparse.ArgumentParser()
93
+ ap.add_argument("--checkpoint", required=True)
94
+ ap.add_argument("--batch_size", type=int, default=256)
95
+ ap.add_argument("--output", default="blimp_results.json")
96
+ args = ap.parse_args()
97
+
98
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
99
+ tok = Tokenizer.from_file(TOKENIZER_PATH)
100
+ pad_id = tok.token_to_id("<|pad|>") or 0
101
+
102
+ ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
103
+ cfg = ckpt["cfg"]
104
+ cfg.attn_backend = "sdpa"
105
+ max_len = cfg.max_seq_len
106
+ model = IvmeConversate(cfg).to(device)
107
+ model.load_state_dict(ckpt["model"])
108
+ model.eval()
109
+ print(f"[blimp] model loaded: {model.num_params()/1e6:.1f}M on {device}")
110
+
111
+ print("[blimp] loading full BLiMP dataset (one download)...")
112
+ full_ds = load_dataset("WillHeld/blimp", split="train")
113
+ by_task = {t: {"good": [], "bad": []} for t in BLIMP_TASKS}
114
+ for row in full_ds:
115
+ uid = row["UID"]
116
+ if uid in by_task:
117
+ by_task[uid]["good"].append(row["sentence_good"])
118
+ by_task[uid]["bad"].append(row["sentence_bad"])
119
+ print(f"[blimp] {len(full_ds)} examples bucketed into {len(BLIMP_TASKS)} subtasks\n")
120
+
121
+ results = {}
122
+ total_correct = total_examples = 0
123
+
124
+ for i, task in enumerate(BLIMP_TASKS):
125
+ goods = by_task[task]["good"]
126
+ bads = by_task[task]["bad"]
127
+ good_tok = [tok.encode(s).ids for s in goods]
128
+ bad_tok = [tok.encode(s).ids for s in bads]
129
+
130
+ correct = 0
131
+ for start in range(0, len(good_tok), args.batch_size):
132
+ gb = good_tok[start : start + args.batch_size]
133
+ bb = bad_tok[start : start + args.batch_size]
134
+ g_lp = batch_logprobs(model, gb, device, pad_id, max_len)
135
+ b_lp = batch_logprobs(model, bb, device, pad_id, max_len)
136
+ correct += (g_lp > b_lp).sum().item()
137
+
138
+ acc = correct / len(goods)
139
+ results[task] = acc
140
+ total_correct += correct
141
+ total_examples += len(goods)
142
+ running = total_correct / total_examples
143
+ print(f"[{i+1:02d}/{len(BLIMP_TASKS)}] {task:<55} {acc*100:5.1f}% "
144
+ f"(avg: {running*100:.2f}%)")
145
+
146
+ final = total_correct / total_examples
147
+ print(f"\n{'='*60}")
148
+ print(f" BLiMP average: {final*100:.2f}% (random baseline: 50%)")
149
+ print(f"{'='*60}")
150
+
151
+ with open(args.output, "w") as f:
152
+ json.dump({"tasks": results, "average": final}, f, indent=2)
153
+ print(f"\n[blimp] saved -> {args.output}")
154
+
155
+
156
+ if __name__ == "__main__":
157
+ main()