Spaces:
Sleeping
Sleeping
| # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """TCP Accuracy metric for evaluating temporal constraint-based planning tasks.""" | |
import re
from typing import Literal, TypedDict

import datasets
import evaluate

# Identifiers of the two TCP benchmark subsets this metric supports.
SUBSET_TCP_SHORT = "tcp_short"
SUBSET_TCP_LONG = "tcp_long"
# All recognized subset identifiers, used to validate caller-supplied values.
VALID_SUBSETS = frozenset({SUBSET_TCP_SHORT, SUBSET_TCP_LONG})
# Static type restricting subset arguments to the known subset names.
SubsetType = Literal["tcp_short", "tcp_long"]
# Captures the content of a LaTeX \boxed{...} expression. re.DOTALL lets the
# captured answer span newlines; NOTE(review): [^}]* does not handle nested
# braces inside the box — presumably TCP answers never contain "}".
BOXED_ANSWER_PATTERN = r"\\boxed\{([^}]*)\}"
BOXED_ANSWER_REGEX = re.compile(BOXED_ANSWER_PATTERN, re.DOTALL)
class AccuracyResult(TypedDict):
    """Result mapping returned by ``TCPAccuracy._compute``.

    ``accuracy`` is either the mean accuracy as a float (when
    ``return_average=True``) or a per-sample list of 0/1 scores
    (when ``return_average=False``).
    """

    accuracy: float | list[int]
# BibTeX entry reported by the metric for citation purposes.
_CITATION = """\
@software{abbood2025tcp_accuracy,
title={TCP Accuracy},
author={Abbood, Auss},
year={2025},
url={https://huggingface.co/spaces/aauss/tcp_accuracy}
}
"""

# Short human-readable description surfaced by `evaluate` in MetricInfo.
_DESCRIPTION = """\
This metric evaluates model predictions on the TCP (Temporal Constraint-Based Planning) benchmark
(Ding et al., 2025). It measures accuracy by extracting answers from LaTeX boxed notation
(\\boxed{answer}) and comparing them against reference answers using exact string matching.
"""

# Usage documentation (arguments, return value, doctest-style example)
# surfaced by `evaluate` as MetricInfo.inputs_description.
_KWARGS_DESCRIPTION = """
Calculates accuracy for TCP benchmark predictions.
Args:
predictions: list of prediction strings. Each prediction should contain the
final answer in LaTeX boxed notation: \\boxed{answer}.
references: list of reference answer strings.
subset: either a string or list of strings indicating the subset type
("tcp_short" or "tcp_long"). For "tcp_short", GMT is stripped before comparison.
return_average: if True (default), returns average accuracy as a float.
If False, returns a list of binary scores (0 or 1) for each sample.
Returns:
accuracy: float (if return_average=True) or list of int (if return_average=False)
Examples:
>>> metric = evaluate.load("aauss/tcp_accuracy")
>>> predictions = ["...\\\\boxed{2012-11-05}", "...\\\\boxed{2020-05-28 16:00}"]
>>> references = ["2012-11-05", "2020-05-28 16:00 GMT"]
>>> results = metric.compute(predictions=predictions, references=references, subset=["tcp_long", "tcp_short"])
>>> print(results)
{'accuracy': 1.0}
"""
class TCPAccuracy(evaluate.Metric):
    """Accuracy metric for the TCP (Temporal Constraint-Based Planning) benchmark."""

    def _info(self) -> evaluate.MetricInfo:
        """Return metric metadata (description, citation, input features, URLs)."""
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
            homepage="https://huggingface.co/spaces/aauss/tcp_accuracy",
            codebase_urls=[
                "https://huggingface.co/spaces/aauss/tcp_accuracy/tree/main"
            ],
            reference_urls=["https://aclanthology.org/2025.emnlp-main.1142/"],
        )

    def extract_boxed_answer(self, prediction: str) -> str | None:
        """Extract the answer from the first LaTeX \\boxed{...} in *prediction*.

        Args:
            prediction: raw model output expected to contain ``\\boxed{answer}``.

        Returns:
            The whitespace-stripped boxed content, or ``None`` when no boxed
            expression is present.
        """
        match = BOXED_ANSWER_REGEX.search(prediction)
        if match:
            return match.group(1).strip()
        return None

    def _compute(
        self,
        predictions: list[str],
        references: list[str],
        subset: SubsetType | list[SubsetType],
        return_average: bool = True,
    ) -> AccuracyResult:
        """Score boxed-answer *predictions* against *references* by exact match.

        Args:
            predictions: model outputs, each containing ``\\boxed{answer}``.
            references: gold answer strings, one per prediction.
            subset: subset name applied to all samples, or a per-sample list.
                For ``"tcp_short"`` samples, the literal ``"GMT"`` is removed
                from both sides before comparison.
            return_average: return the mean accuracy (True) or per-sample
                0/1 scores (False).

        Returns:
            ``{"accuracy": float}`` or ``{"accuracy": list[int]}``.

        Raises:
            ValueError: on empty predictions, mismatched input lengths, or an
                unrecognized subset name.
        """
        if not predictions:
            raise ValueError("predictions cannot be empty")
        if len(predictions) != len(references):
            raise ValueError(
                f"predictions and references must have same length, "
                f"got {len(predictions)} and {len(references)}"
            )
        # Broadcast a single subset name to every sample.
        if isinstance(subset, str):
            subset = [subset] * len(predictions)
        # A per-sample subset list must cover every sample; otherwise zip()
        # below would silently truncate the comparison.
        if len(subset) != len(predictions):
            raise ValueError(
                f"subset and predictions must have same length, "
                f"got {len(subset)} and {len(predictions)}"
            )
        # Reject unknown subset names up front instead of silently skipping
        # GMT normalization (VALID_SUBSETS was previously defined but unused).
        invalid_subsets = set(subset) - VALID_SUBSETS
        if invalid_subsets:
            raise ValueError(
                f"invalid subset value(s) {sorted(invalid_subsets)}; "
                f"expected one of {sorted(VALID_SUBSETS)}"
            )
        extracted_predictions = [self.extract_boxed_answer(p) for p in predictions]
        # For tcp_short the "GMT" suffix is not part of the answer: strip it
        # from both prediction and reference before exact matching. A `None`
        # prediction (no boxed answer found) is left as-is and never matches.
        extracted_predictions = [
            p.replace("GMT", "").strip() if p and s == SUBSET_TCP_SHORT else p
            for p, s in zip(extracted_predictions, subset)
        ]
        references = [
            r.replace("GMT", "").strip() if s == SUBSET_TCP_SHORT else r
            for r, s in zip(references, subset)
        ]
        accuracy = [int(i == j) for i, j in zip(extracted_predictions, references)]
        if return_average:
            return {"accuracy": sum(accuracy) / len(accuracy)}
        return {"accuracy": accuracy}