Spaces:
Sleeping
Sleeping
Add `return_average` flag to `TCPAccuracy._compute`.
Browse files
- tcp_accuracy.py +10 -5
- tests/test_metric.py +8 -0
tcp_accuracy.py
CHANGED
|
@@ -93,9 +93,14 @@ class TCPAccuracy(evaluate.Metric):
|
|
| 93 |
return match.group(1).replace("GMT", "").strip()
|
| 94 |
return None
|
| 95 |
|
| 96 |
-
def _compute(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
"""Returns the scores"""
|
| 98 |
-
# TODO: Compute the different scores of the module
|
| 99 |
if isinstance(subset, str):
|
| 100 |
subset = [subset] * len(predictions)
|
| 101 |
predictions = [self.extract_boxed_answer(p) for p in predictions]
|
|
@@ -104,6 +109,6 @@ class TCPAccuracy(evaluate.Metric):
|
|
| 104 |
for r, s in zip(references, subset)
|
| 105 |
]
|
| 106 |
accuracy = [int(i == j) for i, j in zip(predictions, references)]
|
| 107 |
-
|
| 108 |
-
"accuracy": accuracy
|
| 109 |
-
}
|
|
|
|
| 93 |
return match.group(1).replace("GMT", "").strip()
|
| 94 |
return None
|
| 95 |
|
| 96 |
+
def _compute(
|
| 97 |
+
self,
|
| 98 |
+
predictions,
|
| 99 |
+
references,
|
| 100 |
+
subset: str | list[str],
|
| 101 |
+
return_average: bool = True,
|
| 102 |
+
):
|
| 103 |
"""Returns the scores"""
|
|
|
|
| 104 |
if isinstance(subset, str):
|
| 105 |
subset = [subset] * len(predictions)
|
| 106 |
predictions = [self.extract_boxed_answer(p) for p in predictions]
|
|
|
|
| 109 |
for r, s in zip(references, subset)
|
| 110 |
]
|
| 111 |
accuracy = [int(i == j) for i, j in zip(predictions, references)]
|
| 112 |
+
if return_average:
|
| 113 |
+
return {"accuracy": sum(accuracy) / len(accuracy)}
|
| 114 |
+
return {"accuracy": accuracy}
|
tests/test_metric.py
CHANGED
|
@@ -21,5 +21,13 @@ def test_metric():
|
|
| 21 |
predictions=[response_1, response_2, response_3],
|
| 22 |
references=references,
|
| 23 |
subset=subsets,
|
|
|
|
| 24 |
)
|
| 25 |
assert results["accuracy"] == [1, 0, 1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
predictions=[response_1, response_2, response_3],
|
| 22 |
references=references,
|
| 23 |
subset=subsets,
|
| 24 |
+
return_average=False,
|
| 25 |
)
|
| 26 |
assert results["accuracy"] == [1, 0, 1]
|
| 27 |
+
metric = TCPAccuracy()
|
| 28 |
+
results = metric.compute(
|
| 29 |
+
predictions=[response_1, response_2, response_3],
|
| 30 |
+
references=references,
|
| 31 |
+
subset=subsets,
|
| 32 |
+
)
|
| 33 |
+
assert results["accuracy"] == 2/3
|