aauss commited on
Commit
98a831a
·
1 Parent(s): 78587a7

Add return average flag.

Browse files
Files changed (2) hide show
  1. tcp_accuracy.py +10 -5
  2. tests/test_metric.py +8 -0
tcp_accuracy.py CHANGED
@@ -93,9 +93,14 @@ class TCPAccuracy(evaluate.Metric):
93
  return match.group(1).replace("GMT", "").strip()
94
  return None
95
 
96
- def _compute(self, predictions, references, subset: str | list[str]):
 
 
 
 
 
 
97
  """Returns the scores"""
98
- # TODO: Compute the different scores of the module
99
  if isinstance(subset, str):
100
  subset = [subset] * len(predictions)
101
  predictions = [self.extract_boxed_answer(p) for p in predictions]
@@ -104,6 +109,6 @@ class TCPAccuracy(evaluate.Metric):
104
  for r, s in zip(references, subset)
105
  ]
106
  accuracy = [int(i == j) for i, j in zip(predictions, references)]
107
- return {
108
- "accuracy": accuracy,
109
- }
 
93
  return match.group(1).replace("GMT", "").strip()
94
  return None
95
 
96
+ def _compute(
97
+ self,
98
+ predictions,
99
+ references,
100
+ subset: str | list[str],
101
+ return_average: bool = True,
102
+ ):
103
  """Returns the scores"""
 
104
  if isinstance(subset, str):
105
  subset = [subset] * len(predictions)
106
  predictions = [self.extract_boxed_answer(p) for p in predictions]
 
109
  for r, s in zip(references, subset)
110
  ]
111
  accuracy = [int(i == j) for i, j in zip(predictions, references)]
112
+ if return_average:
113
+ return {"accuracy": sum(accuracy) / len(accuracy)}
114
+ return {"accuracy": accuracy}
tests/test_metric.py CHANGED
@@ -21,5 +21,13 @@ def test_metric():
21
  predictions=[response_1, response_2, response_3],
22
  references=references,
23
  subset=subsets,
 
24
  )
25
  assert results["accuracy"] == [1, 0, 1]
 
 
 
 
 
 
 
 
21
  predictions=[response_1, response_2, response_3],
22
  references=references,
23
  subset=subsets,
24
+ return_average=False,
25
  )
26
  assert results["accuracy"] == [1, 0, 1]
27
+ metric = TCPAccuracy()
28
+ results = metric.compute(
29
+ predictions=[response_1, response_2, response_3],
30
+ references=references,
31
+ subset=subsets,
32
+ )
33
+ assert results["accuracy"] == 2/3