yu-val-weiss committed
Commit 8f3cd77 · 1 Parent(s): 803da62
Update blimp.py

blimp.py CHANGED
@@ -15,13 +15,83 @@
 
 import datasets
 import evaluate
-import numpy as np
 import torch
 from evaluate import logging
-from torch.nn import CrossEntropyLoss
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-_CITATION = """\
+datasets.logging.set_verbosity_error()
+
+BLIMP_PHENOMENA = [
+    "adjunct_island",
+    "anaphor_gender_agreement",
+    "anaphor_number_agreement",
+    "animate_subject_passive",
+    "animate_subject_trans",
+    "causative",
+    "complex_NP_island",
+    "coordinate_structure_constraint_complex_left_branch",
+    "coordinate_structure_constraint_object_extraction",
+    "determiner_noun_agreement_1",
+    "determiner_noun_agreement_2",
+    "determiner_noun_agreement_irregular_1",
+    "determiner_noun_agreement_irregular_2",
+    "determiner_noun_agreement_with_adj_2",
+    "determiner_noun_agreement_with_adj_irregular_1",
+    "determiner_noun_agreement_with_adj_irregular_2",
+    "determiner_noun_agreement_with_adjective_1",
+    "distractor_agreement_relational_noun",
+    "distractor_agreement_relative_clause",
+    "drop_argument",
+    "ellipsis_n_bar_1",
+    "ellipsis_n_bar_2",
+    "existential_there_object_raising",
+    "existential_there_quantifiers_1",
+    "existential_there_quantifiers_2",
+    "existential_there_subject_raising",
+    "expletive_it_object_raising",
+    "inchoative",
+    "intransitive",
+    "irregular_past_participle_adjectives",
+    "irregular_past_participle_verbs",
+    "irregular_plural_subject_verb_agreement_1",
+    "irregular_plural_subject_verb_agreement_2",
+    "left_branch_island_echo_question",
+    "left_branch_island_simple_question",
+    "matrix_question_npi_licensor_present",
+    "npi_present_1",
+    "npi_present_2",
+    "only_npi_licensor_present",
+    "only_npi_scope",
+    "passive_1",
+    "passive_2",
+    "principle_A_c_command",
+    "principle_A_case_1",
+    "principle_A_case_2",
+    "principle_A_domain_1",
+    "principle_A_domain_2",
+    "principle_A_domain_3",
+    "principle_A_reconstruction",
+    "regular_plural_subject_verb_agreement_1",
+    "regular_plural_subject_verb_agreement_2",
+    "sentential_negation_npi_licensor_present",
+    "sentential_negation_npi_scope",
+    "sentential_subject_island",
+    "superlative_quantifiers_1",
+    "superlative_quantifiers_2",
+    "tough_vs_raising_1",
+    "tough_vs_raising_2",
+    "transitive",
+    "wh_island",
+    "wh_questions_object_gap",
+    "wh_questions_subject_gap",
+    "wh_questions_subject_gap_long_distance",
+    "wh_vs_that_no_gap",
+    "wh_vs_that_no_gap_long_distance",
+    "wh_vs_that_with_gap",
+    "wh_vs_that_with_gap_long_distance",
+]
+
+_CITATION = r"""
 @article{warstadt2020blimp,
     author = {Warstadt, Alex and Parrish, Alicia and Liu, Haokun and Mohananey, Anhad and Peng, Wei and Wang, Sheng-Fu and Bowman, Samuel R.},
     title = {BLiMP: The Benchmark of Linguistic Minimal Pairs for English},
@@ -37,8 +107,7 @@ _CITATION = """\
 }
 """
 
-_DESCRIPTION = """
-BLiMP is a challenge set for evaluating what language models (LMs) know about major grammatical phenomena in English.
+_DESCRIPTION = """BLiMP is a challenge set for evaluating what language models (LMs) know about major grammatical phenomena in English.
 BLiMP consists of 67 sub-datasets, each containing 1000 minimal pairs isolating specific contrasts in syntax, morphology, or semantics.
 The data is automatically generated according to expert-crafted grammars. Aggregate human agreement with the labels is 96.4%.
 We use BLiMP to evaluate an n-gram LM, LSTM LM, GPT-2, and Transformer-XL.
@@ -48,9 +117,12 @@ For more info see https://github.com/alexwarstadt/blimp.
 
 _KWARGS_DESCRIPTION = """
 Args:
-    model_id (str): model used for calculating Blimp
+    model_id (str): model used for calculating Blimp, NOTE: should be a causal LM model
+    predictions (list[str]): names of metrics to run. pass empty list or ["*"] to run all of them
     batch_size (int): the batch size to run texts through the model. Defaults to 16.
-    device (str): device to run on, defaults to 'cuda' when available
+    device (str): device to run on, defaults to 'cuda' when available.
+    samples_per_set (int): the number of samples per phenomenon, defaults to 1_000.
+
 Returns:
     blimp: dictionary containing the blimp scores for each of the 67 sub-datasets, as well as the overall accuracy.
     An LM’s overall accuracy on BLiMP is simply the proportion of the 67,000 minimal pairs in which the model assigns a higher probability to the acceptable sentence.
@@ -60,7 +132,7 @@ Examples:
 
 
 @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
-class Perplexity(evaluate.Metric):
+class Blimp(evaluate.Metric):
     def _info(self):
         return evaluate.MetricInfo(
             module_type="metric",
@@ -80,12 +152,11 @@ class Perplexity(evaluate.Metric):
 
     def _compute(
         self,
-        predictions,
         model_id,
+        predictions=None,
         batch_size: int = 16,
-        add_start_token: bool = True,
         device=None,
-        max_length=None,
+        samples_per_set: int = 1_000,
     ):
         if device is not None:
             assert device in ["gpu", "cpu", "cuda", "mps"], (
@@ -102,6 +173,7 @@ class Perplexity(evaluate.Metric):
 
         model = AutoModelForCausalLM.from_pretrained(model_id)
         model = model.to(device)
+        model.eval()
 
         tokenizer = AutoTokenizer.from_pretrained(model_id)
 
@@ -119,78 +191,93 @@
         # assign one of the special tokens to also be the pad token
         tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})
 
-        encodings = tokenizer(
-            predictions,
-            add_special_tokens=False,
-            padding=True,
-            truncation=True if max_tokenized_len else False,
-            max_length=max_tokenized_len,
-            return_tensors="pt",
-            return_attention_mask=True,
-        ).to(device)
-        attn_masks = encodings["attention_mask"]
+        print("PAD", tokenizer.pad_token_id)
+
+        run_all = len(predictions) == 0 or predictions[0] == "*"
+        blimp_sets = (
+            BLIMP_PHENOMENA
+            if run_all
+            else [p for p in BLIMP_PHENOMENA if p.lower() in predictions]
+        )
 
+        assert len(blimp_sets) > 0, "no valid phenomena selected"
 
+        results = {}
+
+        for phenomenon in logging.tqdm(blimp_sets, desc="Evaluating phenomena..."):
+            dataset = datasets.load_dataset("nyu-mll/blimp", phenomenon)["train"]
+
+            # Prepare batches of good and bad sentences
+
+            sents = [(x["sentence_good"], x["sentence_bad"]) for x in dataset]
+            good_sents, bad_sents = zip(*sents[: min(1000, samples_per_set)])
+
+            # Get probabilities in batches
+            good_probs = get_batch_probabilities(
+                model, tokenizer, good_sents, device, batch_size, phenomenon
             )
+            bad_probs = get_batch_probabilities(
+                model,
+                tokenizer,
+                bad_sents,
+                device,
+                batch_size,
+                phenomenon,
+                sent_type="bad",
             )
 
+            # Compare probabilities
+            correct = sum(g > b for g, b in zip(good_probs, bad_probs))
+            accuracy = correct / len(good_probs)
+            results[phenomenon] = accuracy
+
+        # Calculate overall accuracy
+        overall_accuracy = sum(results.values()) / len(results)
+
+        return {"phenomenon_accuracies": results, "overall_accuracy": overall_accuracy}
+
+
+def get_batch_probabilities(
+    model,
+    tokenizer,
+    sentences: list[str],
+    device: str,
+    batch_size: int,
+    phenomenon: str,
+    sent_type: str = "good",
+):
+    """Compute log probabilities for a batch of sentences"""
+    probs = []
+
+    for i in logging.tqdm(
+        range(0, len(sentences), batch_size),
+        desc=f"{phenomenon} - {sent_type} sentences...",
+        leave=False,
+    ):
+        batch = sentences[i : i + batch_size]
+        inputs = tokenizer(
+            batch, padding=batch_size > 1, return_tensors="pt", truncation=True
+        ).to(device)
+
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        labels = inputs.input_ids
+
+        # Compute log probabilities
+        log_probs = torch.nn.functional.log_softmax(outputs.logits, dim=-1)
+
+        # Get probability of each actual token
+        token_log_probs = torch.gather(log_probs, 2, labels.unsqueeze(-1)).squeeze(-1)
+
+        if batch_size > 1:
+            # Create attention mask for padding
+            mask = (labels != tokenizer.pad_token_id).float()
+            token_log_probs *= mask
+
+        # sum log probabilities
+        sequence_log_probs = (token_log_probs).sum(dim=1)
 
+        probs.extend(sequence_log_probs.cpu().tolist())
 
+    return probs
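For reference, a minimal usage sketch of the updated metric via the evaluate library. The Hub path "yu-val-weiss/blimp" and the model "gpt2" are placeholders, not confirmed by this commit; the argument names and return keys follow the _compute signature and return statement in the diff above.

    import evaluate

    # Load the BLiMP metric module (hypothetical path; substitute the actual repo path).
    blimp = evaluate.load("yu-val-weiss/blimp", module_type="metric")

    # Score a causal LM on two phenomena; pass [] or ["*"] as predictions to run all 67.
    results = blimp.compute(
        model_id="gpt2",
        predictions=["anaphor_gender_agreement", "causative"],
        batch_size=32,
        samples_per_set=100,
    )

    print(results["overall_accuracy"])       # mean accuracy over the selected phenomena
    print(results["phenomenon_accuracies"])  # dict mapping phenomenon -> accuracy

Each phenomenon's accuracy is the fraction of its minimal pairs for which the model assigns a higher total log-probability to the acceptable sentence; overall_accuracy is the unweighted mean over the selected sub-datasets.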