yu-val-weiss
committed on
Commit
·
eb6c7b0
1
Parent(s):
17ddf40
Update blimp.py
Browse files
blimp.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
| 14 |
"""Blimp Metric."""
|
| 15 |
|
| 16 |
from collections import defaultdict
|
|
|
|
| 17 |
|
| 18 |
import datasets
|
| 19 |
import evaluate
|
|
@@ -123,7 +124,7 @@ Args:
|
|
| 123 |
predictions (list[str]): names of metrics to run. pass empty list or ["*"] to run all of them
|
| 124 |
batch_size (int): the batch size to run texts through the model. Defaults to 16.
|
| 125 |
device (str): device to run on, defaults to 'cuda' when available.
|
| 126 |
-
samples_per_set (int): the number of samples per phenomenon, defaults to
|
| 127 |
|
| 128 |
Returns:
|
| 129 |
blimp: dictionary containing the blimp scores for each of the 67 sub-datasets, as well as the overall accuracy.
|
|
@@ -156,7 +157,7 @@ class Blimp(evaluate.Metric):
|
|
| 156 |
predictions=None,
|
| 157 |
batch_size: int = 16,
|
| 158 |
device=None,
|
| 159 |
-
samples_per_set: int =
|
| 160 |
):
|
| 161 |
if device is not None:
|
| 162 |
assert device in ["gpu", "cpu", "cuda", "mps"], (
|
|
@@ -171,6 +172,9 @@ class Blimp(evaluate.Metric):
|
|
| 171 |
else ("mps" if torch.mps.is_available() else "cpu")
|
| 172 |
)
|
| 173 |
|
|
|
|
|
|
|
|
|
|
| 174 |
model = AutoModelForCausalLM.from_pretrained(model_id)
|
| 175 |
model = model.to(device)
|
| 176 |
model.eval()
|
|
|
|
| 14 |
"""Blimp Metric."""
|
| 15 |
|
| 16 |
from collections import defaultdict
|
| 17 |
+
from typing import Optional
|
| 18 |
|
| 19 |
import datasets
|
| 20 |
import evaluate
|
|
|
|
| 124 |
predictions (list[str]): names of metrics to run. pass empty list or ["*"] to run all of them
|
| 125 |
batch_size (int): the batch size to run texts through the model. Defaults to 16.
|
| 126 |
device (str): device to run on, defaults to 'cuda' when available.
|
| 127 |
+
samples_per_set (Optional[int]): the number of samples per phenomenon. Max is 1,000 (but will not error if higher value given.) If None, defaults to 1000.
|
| 128 |
|
| 129 |
Returns:
|
| 130 |
blimp: dictionary containing the blimp scores for each of the 67 sub-datasets, as well as the overall accuracy.
|
|
|
|
| 157 |
predictions=None,
|
| 158 |
batch_size: int = 16,
|
| 159 |
device=None,
|
| 160 |
+
samples_per_set: Optional[int] = None,
|
| 161 |
):
|
| 162 |
if device is not None:
|
| 163 |
assert device in ["gpu", "cpu", "cuda", "mps"], (
|
|
|
|
| 172 |
else ("mps" if torch.mps.is_available() else "cpu")
|
| 173 |
)
|
| 174 |
|
| 175 |
+
if samples_per_set is None or samples_per_set <= 0:
|
| 176 |
+
samples_per_set = 1000
|
| 177 |
+
|
| 178 |
model = AutoModelForCausalLM.from_pretrained(model_id)
|
| 179 |
model = model.to(device)
|
| 180 |
model.eval()
|