aauss committed on
Commit
52c7752
·
1 Parent(s): a031da6

Add early input checks, improve type hints and format code.

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. tcp_accuracy.py +29 -10
  3. tests/test_metric.py +1 -1
app.py CHANGED
@@ -3,4 +3,4 @@ from evaluate.utils import launch_gradio_widget
3
 
4
 
5
  module = evaluate.load("aauss/tcp_accuracy")
6
- launch_gradio_widget(module)
 
3
 
4
 
5
  module = evaluate.load("aauss/tcp_accuracy")
6
+ launch_gradio_widget(module)
tcp_accuracy.py CHANGED
@@ -14,9 +14,23 @@
14
  """TCP Accuracy metric for evaluating temporal constraint-based planning tasks."""
15
 
16
  import re
 
17
 
18
- import evaluate
19
  import datasets
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
 
22
  _CITATION = """\
@@ -61,9 +75,7 @@ Examples:
61
  class TCPAccuracy(evaluate.Metric):
62
  """Accuracy metric for the TCP (Temporal Constraint-Based Planning) benchmark."""
63
 
64
- BOXED_ANSWER_PATTERN = r"\\boxed\{([^}]*)\}"
65
-
66
- def _info(self):
67
  return evaluate.MetricInfo(
68
  module_type="metric",
69
  description=_DESCRIPTION,
@@ -76,12 +88,14 @@ class TCPAccuracy(evaluate.Metric):
76
  }
77
  ),
78
  homepage="https://huggingface.co/spaces/aauss/tcp_accuracy",
79
- codebase_urls=["https://huggingface.co/spaces/aauss/tcp_accuracy/tree/main"],
 
 
80
  reference_urls=["https://aclanthology.org/2025.emnlp-main.1142/"],
81
  )
82
 
83
  def extract_boxed_answer(self, prediction: str) -> str | None:
84
- match = re.search(self.BOXED_ANSWER_PATTERN, prediction, re.DOTALL)
85
  if match:
86
  return match.group(1).strip()
87
  return None
@@ -90,21 +104,26 @@ class TCPAccuracy(evaluate.Metric):
90
  self,
91
  predictions: list[str],
92
  references: list[str],
93
- subset: str | list[str],
94
  return_average: bool = True,
95
- ) -> dict[str, float | list[int]]:
96
  """Returns the scores"""
97
  if not predictions:
98
  raise ValueError("predictions cannot be empty")
 
 
 
 
 
99
  if isinstance(subset, str):
100
  subset = [subset] * len(predictions)
101
  extracted_predictions = [self.extract_boxed_answer(p) for p in predictions]
102
  extracted_predictions = [
103
- p.replace("GMT", "").strip() if p and s == "tcp_short" else p
104
  for p, s in zip(extracted_predictions, subset)
105
  ]
106
  references = [
107
- r.replace("GMT", "").strip() if s == "tcp_short" else r
108
  for r, s in zip(references, subset)
109
  ]
110
  accuracy = [int(i == j) for i, j in zip(extracted_predictions, references)]
 
14
  """TCP Accuracy metric for evaluating temporal constraint-based planning tasks."""
15
 
16
  import re
17
+ from typing import Literal, TypedDict
18
 
 
19
  import datasets
20
+ import evaluate
21
+
22
+ SUBSET_TCP_SHORT = "tcp_short"
23
+ SUBSET_TCP_LONG = "tcp_long"
24
+ VALID_SUBSETS = frozenset({SUBSET_TCP_SHORT, SUBSET_TCP_LONG})
25
+
26
+ SubsetType = Literal["tcp_short", "tcp_long"]
27
+
28
+ BOXED_ANSWER_PATTERN = r"\\boxed\{([^}]*)\}"
29
+ BOXED_ANSWER_REGEX = re.compile(BOXED_ANSWER_PATTERN, re.DOTALL)
30
+
31
+
32
+ class AccuracyResult(TypedDict):
33
+ accuracy: float | list[int]
34
 
35
 
36
  _CITATION = """\
 
75
  class TCPAccuracy(evaluate.Metric):
76
  """Accuracy metric for the TCP (Temporal Constraint-Based Planning) benchmark."""
77
 
78
+ def _info(self) -> evaluate.MetricInfo:
 
 
79
  return evaluate.MetricInfo(
80
  module_type="metric",
81
  description=_DESCRIPTION,
 
88
  }
89
  ),
90
  homepage="https://huggingface.co/spaces/aauss/tcp_accuracy",
91
+ codebase_urls=[
92
+ "https://huggingface.co/spaces/aauss/tcp_accuracy/tree/main"
93
+ ],
94
  reference_urls=["https://aclanthology.org/2025.emnlp-main.1142/"],
95
  )
96
 
97
  def extract_boxed_answer(self, prediction: str) -> str | None:
98
+ match = BOXED_ANSWER_REGEX.search(prediction)
99
  if match:
100
  return match.group(1).strip()
101
  return None
 
104
  self,
105
  predictions: list[str],
106
  references: list[str],
107
+ subset: SubsetType | list[SubsetType],
108
  return_average: bool = True,
109
+ ) -> AccuracyResult:
110
  """Returns the scores"""
111
  if not predictions:
112
  raise ValueError("predictions cannot be empty")
113
+ if len(predictions) != len(references):
114
+ raise ValueError(
115
+ f"predictions and references must have same length, "
116
+ f"got {len(predictions)} and {len(references)}"
117
+ )
118
  if isinstance(subset, str):
119
  subset = [subset] * len(predictions)
120
  extracted_predictions = [self.extract_boxed_answer(p) for p in predictions]
121
  extracted_predictions = [
122
+ p.replace("GMT", "").strip() if p and s == SUBSET_TCP_SHORT else p
123
  for p, s in zip(extracted_predictions, subset)
124
  ]
125
  references = [
126
+ r.replace("GMT", "").strip() if s == SUBSET_TCP_SHORT else r
127
  for r, s in zip(references, subset)
128
  ]
129
  accuracy = [int(i == j) for i, j in zip(extracted_predictions, references)]
tests/test_metric.py CHANGED
@@ -30,4 +30,4 @@ def test_metric():
30
  references=references,
31
  subset=subsets,
32
  )
33
- assert results["accuracy"] == 2/3
 
30
  references=references,
31
  subset=subsets,
32
  )
33
+ assert results["accuracy"] == 2 / 3