Spaces:
Sleeping
Sleeping
Add early input checks, improve type hints and format code.
Browse files- app.py +1 -1
- tcp_accuracy.py +29 -10
- tests/test_metric.py +1 -1
app.py
CHANGED
|
@@ -3,4 +3,4 @@ from evaluate.utils import launch_gradio_widget
|
|
| 3 |
|
| 4 |
|
| 5 |
module = evaluate.load("aauss/tcp_accuracy")
|
| 6 |
-
launch_gradio_widget(module)
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
module = evaluate.load("aauss/tcp_accuracy")
|
| 6 |
+
launch_gradio_widget(module)
|
tcp_accuracy.py
CHANGED
|
@@ -14,9 +14,23 @@
|
|
| 14 |
"""TCP Accuracy metric for evaluating temporal constraint-based planning tasks."""
|
| 15 |
|
| 16 |
import re
|
|
|
|
| 17 |
|
| 18 |
-
import evaluate
|
| 19 |
import datasets
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
_CITATION = """\
|
|
@@ -61,9 +75,7 @@ Examples:
|
|
| 61 |
class TCPAccuracy(evaluate.Metric):
|
| 62 |
"""Accuracy metric for the TCP (Temporal Constraint-Based Planning) benchmark."""
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
def _info(self):
|
| 67 |
return evaluate.MetricInfo(
|
| 68 |
module_type="metric",
|
| 69 |
description=_DESCRIPTION,
|
|
@@ -76,12 +88,14 @@ class TCPAccuracy(evaluate.Metric):
|
|
| 76 |
}
|
| 77 |
),
|
| 78 |
homepage="https://huggingface.co/spaces/aauss/tcp_accuracy",
|
| 79 |
-
codebase_urls=[
|
|
|
|
|
|
|
| 80 |
reference_urls=["https://aclanthology.org/2025.emnlp-main.1142/"],
|
| 81 |
)
|
| 82 |
|
| 83 |
def extract_boxed_answer(self, prediction: str) -> str | None:
|
| 84 |
-
match =
|
| 85 |
if match:
|
| 86 |
return match.group(1).strip()
|
| 87 |
return None
|
|
@@ -90,21 +104,26 @@ class TCPAccuracy(evaluate.Metric):
|
|
| 90 |
self,
|
| 91 |
predictions: list[str],
|
| 92 |
references: list[str],
|
| 93 |
-
subset:
|
| 94 |
return_average: bool = True,
|
| 95 |
-
) ->
|
| 96 |
"""Returns the scores"""
|
| 97 |
if not predictions:
|
| 98 |
raise ValueError("predictions cannot be empty")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
if isinstance(subset, str):
|
| 100 |
subset = [subset] * len(predictions)
|
| 101 |
extracted_predictions = [self.extract_boxed_answer(p) for p in predictions]
|
| 102 |
extracted_predictions = [
|
| 103 |
-
p.replace("GMT", "").strip() if p and s ==
|
| 104 |
for p, s in zip(extracted_predictions, subset)
|
| 105 |
]
|
| 106 |
references = [
|
| 107 |
-
r.replace("GMT", "").strip() if s ==
|
| 108 |
for r, s in zip(references, subset)
|
| 109 |
]
|
| 110 |
accuracy = [int(i == j) for i, j in zip(extracted_predictions, references)]
|
|
|
|
| 14 |
"""TCP Accuracy metric for evaluating temporal constraint-based planning tasks."""
|
| 15 |
|
| 16 |
import re
|
| 17 |
+
from typing import Literal, TypedDict
|
| 18 |
|
|
|
|
| 19 |
import datasets
|
| 20 |
+
import evaluate
|
| 21 |
+
|
| 22 |
+
SUBSET_TCP_SHORT = "tcp_short"
|
| 23 |
+
SUBSET_TCP_LONG = "tcp_long"
|
| 24 |
+
VALID_SUBSETS = frozenset({SUBSET_TCP_SHORT, SUBSET_TCP_LONG})
|
| 25 |
+
|
| 26 |
+
SubsetType = Literal["tcp_short", "tcp_long"]
|
| 27 |
+
|
| 28 |
+
BOXED_ANSWER_PATTERN = r"\\boxed\{([^}]*)\}"
|
| 29 |
+
BOXED_ANSWER_REGEX = re.compile(BOXED_ANSWER_PATTERN, re.DOTALL)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class AccuracyResult(TypedDict):
|
| 33 |
+
accuracy: float | list[int]
|
| 34 |
|
| 35 |
|
| 36 |
_CITATION = """\
|
|
|
|
| 75 |
class TCPAccuracy(evaluate.Metric):
|
| 76 |
"""Accuracy metric for the TCP (Temporal Constraint-Based Planning) benchmark."""
|
| 77 |
|
| 78 |
+
def _info(self) -> evaluate.MetricInfo:
|
|
|
|
|
|
|
| 79 |
return evaluate.MetricInfo(
|
| 80 |
module_type="metric",
|
| 81 |
description=_DESCRIPTION,
|
|
|
|
| 88 |
}
|
| 89 |
),
|
| 90 |
homepage="https://huggingface.co/spaces/aauss/tcp_accuracy",
|
| 91 |
+
codebase_urls=[
|
| 92 |
+
"https://huggingface.co/spaces/aauss/tcp_accuracy/tree/main"
|
| 93 |
+
],
|
| 94 |
reference_urls=["https://aclanthology.org/2025.emnlp-main.1142/"],
|
| 95 |
)
|
| 96 |
|
| 97 |
def extract_boxed_answer(self, prediction: str) -> str | None:
|
| 98 |
+
match = BOXED_ANSWER_REGEX.search(prediction)
|
| 99 |
if match:
|
| 100 |
return match.group(1).strip()
|
| 101 |
return None
|
|
|
|
| 104 |
self,
|
| 105 |
predictions: list[str],
|
| 106 |
references: list[str],
|
| 107 |
+
subset: SubsetType | list[SubsetType],
|
| 108 |
return_average: bool = True,
|
| 109 |
+
) -> AccuracyResult:
|
| 110 |
"""Returns the scores"""
|
| 111 |
if not predictions:
|
| 112 |
raise ValueError("predictions cannot be empty")
|
| 113 |
+
if len(predictions) != len(references):
|
| 114 |
+
raise ValueError(
|
| 115 |
+
f"predictions and references must have same length, "
|
| 116 |
+
f"got {len(predictions)} and {len(references)}"
|
| 117 |
+
)
|
| 118 |
if isinstance(subset, str):
|
| 119 |
subset = [subset] * len(predictions)
|
| 120 |
extracted_predictions = [self.extract_boxed_answer(p) for p in predictions]
|
| 121 |
extracted_predictions = [
|
| 122 |
+
p.replace("GMT", "").strip() if p and s == SUBSET_TCP_SHORT else p
|
| 123 |
for p, s in zip(extracted_predictions, subset)
|
| 124 |
]
|
| 125 |
references = [
|
| 126 |
+
r.replace("GMT", "").strip() if s == SUBSET_TCP_SHORT else r
|
| 127 |
for r, s in zip(references, subset)
|
| 128 |
]
|
| 129 |
accuracy = [int(i == j) for i, j in zip(extracted_predictions, references)]
|
tests/test_metric.py
CHANGED
|
@@ -30,4 +30,4 @@ def test_metric():
|
|
| 30 |
references=references,
|
| 31 |
subset=subsets,
|
| 32 |
)
|
| 33 |
-
assert results["accuracy"] == 2/3
|
|
|
|
| 30 |
references=references,
|
| 31 |
subset=subsets,
|
| 32 |
)
|
| 33 |
+
assert results["accuracy"] == 2 / 3
|