aauss committed on
Commit
86953e8
·
1 Parent(s): 28a8c77

Add early input check, improve type hints and format code.

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. tram_accuracy.py +24 -8
app.py CHANGED
@@ -3,4 +3,4 @@ from evaluate.utils import launch_gradio_widget
3
 
4
 
5
  module = evaluate.load("aauss/tram_accuracy")
6
- launch_gradio_widget(module)
 
3
 
4
 
5
  module = evaluate.load("aauss/tram_accuracy")
6
+ launch_gradio_widget(module)
tram_accuracy.py CHANGED
@@ -14,8 +14,18 @@
14
  """Metric to calculate the accuracy for the TRAM benchmark by Wang et al. (2024)."""
15
 
16
  import re
17
- import evaluate
 
18
  import datasets
 
 
 
 
 
 
 
 
 
19
 
20
 
21
  _CITATION = """\
@@ -44,14 +54,14 @@ Returns:
44
  """
45
 
46
 
47
- TRAM_ANSWER_REGEX = re.compile(r"[Tt]he final answer is \(([A-D])\)")
48
 
49
 
50
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
51
  class TRAMAccuracy(evaluate.Metric):
52
  """Calculates the accuracy for the (multiple choice) TRAM datasets by extracting the final answer from the prediction and comparing it to the reference answer."""
53
 
54
- def _info(self):
55
  return evaluate.MetricInfo(
56
  module_type="metric",
57
  description=_DESCRIPTION,
@@ -65,13 +75,20 @@ class TRAMAccuracy(evaluate.Metric):
65
  }
66
  ),
67
  homepage="https://huggingface.co/spaces/aauss/tram_accuracy",
68
- codebase_urls=["https://huggingface.co/spaces/aauss/tram_accuracy/tree/main"],
 
 
69
  reference_urls=["https://huggingface.co/datasets/Warrieryes/TRAM-Temporal"],
70
  )
71
 
72
- def _compute(self, predictions, references, return_average=True):
 
 
 
 
 
73
  """Returns the accuracy for the (multiple choice) TRAM datasets."""
74
- if len(predictions) == 0:
75
  raise ValueError("predictions cannot be empty")
76
  if len(predictions) != len(references):
77
  raise ValueError(
@@ -91,5 +108,4 @@ class TRAMAccuracy(evaluate.Metric):
91
  ]
92
  if return_average:
93
  return {"accuracy": sum(accuracy) / len(accuracy)}
94
- else:
95
- return {"accuracy": accuracy}
 
14
  """Metric to calculate the accuracy for the TRAM benchmark by Wang et al. (2024)."""
15
 
16
  import re
17
+ from typing import TypedDict
18
+
19
  import datasets
20
+ import evaluate
21
+
22
+ VALID_ANSWER_CHOICES = frozenset({"A", "B", "C", "D"})
23
+
24
+ TRAM_ANSWER_PATTERN = r"[Tt]he final answer is \(([A-D])\)"
25
+
26
+
27
+ class AccuracyResult(TypedDict):
28
+ accuracy: float | list[int]
29
 
30
 
31
  _CITATION = """\
 
54
  """
55
 
56
 
57
+ TRAM_ANSWER_REGEX = re.compile(TRAM_ANSWER_PATTERN)
58
 
59
 
60
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
61
  class TRAMAccuracy(evaluate.Metric):
62
  """Calculates the accuracy for the (multiple choice) TRAM datasets by extracting the final answer from the prediction and comparing it to the reference answer."""
63
 
64
+ def _info(self) -> evaluate.MetricInfo:
65
  return evaluate.MetricInfo(
66
  module_type="metric",
67
  description=_DESCRIPTION,
 
75
  }
76
  ),
77
  homepage="https://huggingface.co/spaces/aauss/tram_accuracy",
78
+ codebase_urls=[
79
+ "https://huggingface.co/spaces/aauss/tram_accuracy/tree/main"
80
+ ],
81
  reference_urls=["https://huggingface.co/datasets/Warrieryes/TRAM-Temporal"],
82
  )
83
 
84
+ def _compute(
85
+ self,
86
+ predictions: list[str],
87
+ references: list[str],
88
+ return_average: bool = True,
89
+ ) -> AccuracyResult:
90
  """Returns the accuracy for the (multiple choice) TRAM datasets."""
91
+ if not predictions:
92
  raise ValueError("predictions cannot be empty")
93
  if len(predictions) != len(references):
94
  raise ValueError(
 
108
  ]
109
  if return_average:
110
  return {"accuracy": sum(accuracy) / len(accuracy)}
111
+ return {"accuracy": accuracy}