Spaces:
Build error
Build error
minor code cleanup
Browse files
- syntaxgym.py +0 -4
- test.py +3 -5
syntaxgym.py
CHANGED
|
@@ -197,14 +197,10 @@ class SyntaxGym(evaluate.EvaluationModule):
|
|
| 197 |
surps_shifted = surprisals[:, :-1, :]
|
| 198 |
expected_ids = input_ids[:, 1:]
|
| 199 |
|
| 200 |
-
# TODO: check this logic
|
| 201 |
-
tt = expected_ids.unsqueeze(2)
|
| 202 |
# reindexed surprisals: B * (T - 1)
|
| 203 |
surprisals = torch.gather(surps_shifted, 2, expected_ids.unsqueeze(2)) \
|
| 204 |
.squeeze(2)
|
| 205 |
|
| 206 |
-
# surprisals is now B * (T - 1)
|
| 207 |
-
|
| 208 |
#### aggregate
|
| 209 |
condition_names = item["conditions"]["condition_name"]
|
| 210 |
region_totals = {condition_name: defaultdict(float)
|
|
|
|
| 197 |
surps_shifted = surprisals[:, :-1, :]
|
| 198 |
expected_ids = input_ids[:, 1:]
|
| 199 |
|
|
|
|
|
|
|
| 200 |
# reindexed surprisals: B * (T - 1)
|
| 201 |
surprisals = torch.gather(surps_shifted, 2, expected_ids.unsqueeze(2)) \
|
| 202 |
.squeeze(2)
|
| 203 |
|
|
|
|
|
|
|
| 204 |
#### aggregate
|
| 205 |
condition_names = item["conditions"]["condition_name"]
|
| 206 |
region_totals = {condition_name: defaultdict(float)
|
test.py
CHANGED
|
@@ -14,6 +14,7 @@ def syntaxgym_dataset():
|
|
| 14 |
|
| 15 |
@pytest.fixture(scope="session")
|
| 16 |
def syntaxgym_metric():
|
|
|
|
| 17 |
return evaluate.load("./syntaxgym.py")
|
| 18 |
|
| 19 |
|
|
@@ -488,17 +489,14 @@ GPT2_SUBORDINATION_SRC_REFERENCE = \
|
|
| 488 |
('sub_no-matrix', 5): 4.819862633503057}]
|
| 489 |
|
| 490 |
|
| 491 |
-
def test_gpt_subordination_region_totals():
|
| 492 |
"""
|
| 493 |
Check region-level surprisals against the original syntaxgym-core
|
| 494 |
implementation, using the same underlying `gpt2` model.
|
| 495 |
"""
|
| 496 |
-
reference = ... # TODO
|
| 497 |
|
| 498 |
-
# TODO work out references
|
| 499 |
dataset = datasets.load_dataset("cpllab/syntaxgym", "subordination_src-src")
|
| 500 |
-
|
| 501 |
-
result = metric.compute(suite=dataset["test"], model_id="gpt2")
|
| 502 |
|
| 503 |
from pprint import pprint
|
| 504 |
pprint(result["region_totals"][0])
|
|
|
|
| 14 |
|
| 15 |
@pytest.fixture(scope="session")
|
| 16 |
def syntaxgym_metric():
|
| 17 |
+
# TODO work out reference
|
| 18 |
return evaluate.load("./syntaxgym.py")
|
| 19 |
|
| 20 |
|
|
|
|
| 489 |
('sub_no-matrix', 5): 4.819862633503057}]
|
| 490 |
|
| 491 |
|
| 492 |
+
def test_gpt_subordination_region_totals(syntaxgym_metric):
|
| 493 |
"""
|
| 494 |
Check region-level surprisals against the original syntaxgym-core
|
| 495 |
implementation, using the same underlying `gpt2` model.
|
| 496 |
"""
|
|
|
|
| 497 |
|
|
|
|
| 498 |
dataset = datasets.load_dataset("cpllab/syntaxgym", "subordination_src-src")
|
| 499 |
+
result = syntaxgym_metric.compute(suite=dataset["test"], model_id="gpt2")
|
|
|
|
| 500 |
|
| 501 |
from pprint import pprint
|
| 502 |
pprint(result["region_totals"][0])
|