File size: 6,632 Bytes
337273e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""
Batched BLiMP scorer for İvme — fast, GPU-parallel.

Scores all 67 BLiMP subtasks by batching sentence pairs through the model
instead of scoring them one at a time. On a Blackwell-generation GPU this runs
the whole suite in well under a minute.

Method: for each (good, bad) pair, compute the total log-prob of each sentence
and count a win when logprob(good) > logprob(bad) (ties count as losses).
Sentences are padded into batches and scored with a length mask so padding
contributes nothing.

Usage:
    python eval_blimp.py --checkpoint checkpoints/ivme_base_ema.pt
    python eval_blimp.py --checkpoint checkpoints/ivme_base_ema.pt --batch_size 256
"""

from __future__ import annotations
import argparse
import json
import sys
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
from datasets import load_dataset

sys.path.insert(0, ".")
from model import IvmeConfig, IvmeConversate

# Tokenizer definition file; a relative path, so it resolves against the
# current working directory.
TOKENIZER_PATH = "ivme_tokenizer.json"

# The 67 BLiMP subtask UIDs, one per line, in ASCII-sorted order. This order
# determines the progress-print order and the key order of the saved results.
# NOTE(review): "determiner_noun_agreement_with_adjective_1" is spelled
# differently from its "_with_adj_*" siblings; this appears to mirror the
# dataset's own UID naming — confirm if that subtask comes back empty.
BLIMP_TASKS = [
    "adjunct_island",
    "anaphor_gender_agreement",
    "anaphor_number_agreement",
    "animate_subject_passive",
    "animate_subject_trans",
    "causative",
    "complex_NP_island",
    "coordinate_structure_constraint_complex_left_branch",
    "coordinate_structure_constraint_object_extraction",
    "determiner_noun_agreement_1",
    "determiner_noun_agreement_2",
    "determiner_noun_agreement_irregular_1",
    "determiner_noun_agreement_irregular_2",
    "determiner_noun_agreement_with_adj_2",
    "determiner_noun_agreement_with_adj_irregular_1",
    "determiner_noun_agreement_with_adj_irregular_2",
    "determiner_noun_agreement_with_adjective_1",
    "distractor_agreement_relational_noun",
    "distractor_agreement_relative_clause",
    "drop_argument",
    "ellipsis_n_bar_1",
    "ellipsis_n_bar_2",
    "existential_there_object_raising",
    "existential_there_quantifiers_1",
    "existential_there_quantifiers_2",
    "existential_there_subject_raising",
    "expletive_it_object_raising",
    "inchoative",
    "intransitive",
    "irregular_past_participle_adjectives",
    "irregular_past_participle_verbs",
    "irregular_plural_subject_verb_agreement_1",
    "irregular_plural_subject_verb_agreement_2",
    "left_branch_island_echo_question",
    "left_branch_island_simple_question",
    "matrix_question_npi_licensor_present",
    "npi_present_1",
    "npi_present_2",
    "only_npi_licensor_present",
    "only_npi_scope",
    "passive_1",
    "passive_2",
    "principle_A_c_command",
    "principle_A_case_1",
    "principle_A_case_2",
    "principle_A_domain_1",
    "principle_A_domain_2",
    "principle_A_domain_3",
    "principle_A_reconstruction",
    "regular_plural_subject_verb_agreement_1",
    "regular_plural_subject_verb_agreement_2",
    "sentential_negation_npi_licensor_present",
    "sentential_negation_npi_scope",
    "sentential_subject_island",
    "superlative_quantifiers_1",
    "superlative_quantifiers_2",
    "tough_vs_raising_1",
    "tough_vs_raising_2",
    "transitive",
    "wh_island",
    "wh_questions_object_gap",
    "wh_questions_subject_gap",
    "wh_questions_subject_gap_long_distance",
    "wh_vs_that_no_gap",
    "wh_vs_that_no_gap_long_distance",
    "wh_vs_that_with_gap",
    "wh_vs_that_with_gap_long_distance",
]


@torch.no_grad()
def batch_logprobs(model, token_lists, device, pad_id, max_len):
    """Return the total log-prob of each sequence in a right-padded batch.

    Each sequence is truncated to ``max_len`` tokens, right-padded with
    ``pad_id`` to the longest truncated length, and scored in a single forward
    pass. The logit at position ``j`` is scored against token ``j + 1``, so a
    sequence of ``n`` tokens contributes ``n - 1`` log-probs; the first token
    has no preceding context and is never scored. Padding positions are
    excluded from the sum.

    Args:
        model: Callable taking a ``(B, L)`` long tensor and returning a tuple
            whose first element is the logits (assumed ``(B, L, vocab)``).
        token_lists: ``list[list[int]]`` of token ids, one list per sequence.
        device: ``torch.device`` to run on; bf16 autocast is enabled on CUDA only.
        pad_id: Token id written into padding positions.
        max_len: Maximum number of tokens kept per sequence.

    Returns:
        Float32 tensor of shape ``(len(token_lists),)`` on ``device``. An empty
        ``token_lists`` yields an empty tensor instead of raising.
    """
    n_seq = len(token_lists)
    if n_seq == 0:
        return torch.empty(0, dtype=torch.float32, device=device)

    clipped = [list(t[:max_len]) for t in token_lists]
    lengths = [len(t) for t in clipped]
    width = max(lengths)
    if width < 2:
        # No sequence has a (context, target) pair, so every sum is zero;
        # skip the forward pass entirely.
        return torch.zeros(n_seq, dtype=torch.float32, device=device)

    # Assemble the padded batch on the host and copy it over once, rather than
    # issuing one small host-to-device transfer per sequence.
    inp = torch.full((n_seq, width), pad_id, dtype=torch.long)
    for row, ids in enumerate(clipped):
        if ids:
            inp[row, : len(ids)] = torch.tensor(ids, dtype=torch.long)
    inp = inp.to(device)

    with torch.autocast(device_type=device.type, dtype=torch.bfloat16,
                        enabled=device.type == "cuda"):
        logits, _ = model(inp)

    # Upcast before the softmax so bf16 logits don't lose precision in the sum.
    logp = F.log_softmax(logits.float(), dim=-1)
    tok_lp = logp[:, :-1, :].gather(-1, inp[:, 1:].unsqueeze(-1)).squeeze(-1)

    # Prediction j is a real (non-padding) target iff j < n - 1 for a
    # length-n sequence.
    valid = (torch.arange(width - 1, device=device).unsqueeze(0)
             < torch.tensor(lengths, device=device).unsqueeze(1) - 1)
    # masked_fill rather than multiplying by a 0/1 mask: a -inf log-prob at a
    # padding position would otherwise produce NaN (-inf * 0) and poison the sum.
    return tok_lp.masked_fill(~valid, 0.0).sum(dim=1)


def _load_model(checkpoint, device):
    """Build an ``IvmeConversate`` from ``checkpoint``; return ``(model, max_seq_len)``."""
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # It is needed because ckpt["cfg"] is a pickled config object (presumably
    # an IvmeConfig), so only load checkpoints from a trusted source.
    ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
    cfg = ckpt["cfg"]
    cfg.attn_backend = "sdpa"  # override the attention backend saved with the config
    model = IvmeConversate(cfg).to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()
    return model, cfg.max_seq_len


def _bucket_blimp():
    """Download BLiMP once and group its sentence pairs by subtask UID.

    Returns ``(n_rows, by_task)`` where ``by_task[uid]`` holds parallel
    ``"good"`` / ``"bad"`` sentence lists for every UID in ``BLIMP_TASKS``.
    Rows whose UID is not in ``BLIMP_TASKS`` are ignored.
    """
    full_ds = load_dataset("WillHeld/blimp", split="train")
    by_task = {t: {"good": [], "bad": []} for t in BLIMP_TASKS}
    for row in full_ds:
        bucket = by_task.get(row["UID"])
        if bucket is not None:
            bucket["good"].append(row["sentence_good"])
            bucket["bad"].append(row["sentence_bad"])
    return len(full_ds), by_task


def _encode(tok, sentences, bos_id):
    """Tokenize ``sentences`` in one batched call, optionally prefixing ``bos_id``."""
    ids = [enc.ids for enc in tok.encode_batch(sentences)]
    if bos_id is None:
        return ids
    # Don't double the prefix if the tokenizer's post-processor already adds it.
    return [t if t[:1] == [bos_id] else [bos_id, *t] for t in ids]


def _count_wins(model, good_tok, bad_tok, device, pad_id, max_len, batch_size):
    """Return how many pairs have logprob(good) > logprob(bad); ties count as losses."""
    correct = 0
    for start in range(0, len(good_tok), batch_size):
        stop = start + batch_size
        g_lp = batch_logprobs(model, good_tok[start:stop], device, pad_id, max_len)
        b_lp = batch_logprobs(model, bad_tok[start:stop], device, pad_id, max_len)
        correct += int((g_lp > b_lp).sum().item())
    return correct


def main():
    """Score every BLiMP subtask for one checkpoint and save accuracies as JSON.

    Per-task accuracy is the fraction of pairs won. The reported average is
    pooled over all scored pairs (total wins / total pairs). Subtasks with no
    rows in the dataset are reported and skipped instead of crashing.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True)
    ap.add_argument("--batch_size", type=int, default=256)
    ap.add_argument("--output", default="blimp_results.json")
    ap.add_argument(
        "--bos_token", default=None,
        help="special token prepended to every sentence so its first word is "
             "scored as well (default: no prefix, first token unscored)",
    )
    args = ap.parse_args()
    if args.batch_size < 1:
        ap.error("--batch_size must be a positive integer")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = Tokenizer.from_file(TOKENIZER_PATH)
    pad_id = tok.token_to_id("<|pad|>") or 0
    bos_id = None
    if args.bos_token is not None:
        bos_id = tok.token_to_id(args.bos_token)
        if bos_id is None:
            ap.error(f"--bos_token {args.bos_token!r} is not in {TOKENIZER_PATH}")
        print(f"[blimp] prefixing sentences with {args.bos_token!r} (id {bos_id})")

    model, max_len = _load_model(args.checkpoint, device)
    print(f"[blimp] model loaded: {model.num_params()/1e6:.1f}M on {device}")

    print("[blimp] loading full BLiMP dataset (one download)...")
    n_rows, by_task = _bucket_blimp()
    print(f"[blimp] {n_rows} examples bucketed into {len(BLIMP_TASKS)} subtasks\n")

    results = {}
    missing = []
    total_correct = total_examples = 0
    n_tasks = len(BLIMP_TASKS)

    for i, task in enumerate(BLIMP_TASKS, start=1):
        goods = by_task[task]["good"]
        if not goods:
            # Nothing to score (e.g. UID naming mismatch); also avoids a
            # division by zero in the accuracy below.
            missing.append(task)
            print(f"[{i:02d}/{n_tasks}] {task:<55}   n/a  (no examples in dataset)")
            continue

        correct = _count_wins(
            model,
            _encode(tok, goods, bos_id),
            _encode(tok, by_task[task]["bad"], bos_id),
            device, pad_id, max_len, args.batch_size,
        )
        acc = correct / len(goods)
        results[task] = acc
        total_correct += correct
        total_examples += len(goods)
        running = total_correct / total_examples
        print(f"[{i:02d}/{n_tasks}] {task:<55} {acc*100:5.1f}%  "
              f"(avg: {running*100:.2f}%)")

    if total_examples == 0:
        raise SystemExit("[blimp] no examples matched any BLiMP subtask; nothing scored")

    final = total_correct / total_examples
    print(f"\n{'='*60}")
    print(f"  BLiMP average: {final*100:.2f}%   (random baseline: 50%)")
    if missing:
        print(f"  WARNING: {len(missing)} subtask(s) had no examples: {', '.join(missing)}")
    print(f"{'='*60}")

    payload = {"tasks": results, "average": final}
    if missing:
        payload["missing_tasks"] = missing
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    print(f"\n[blimp] saved -> {args.output}")


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()