File size: 10,058 Bytes
6bb843b
 
 
 
 
 
 
 
 
 
 
 
 
71222fd
6bb843b
fb49a5d
 
 
71222fd
 
6bb843b
 
 
 
71222fd
 
 
 
 
6bb843b
 
 
 
71222fd
 
 
6bb843b
 
 
 
71222fd
6bb843b
71222fd
 
 
 
6bb843b
71222fd
 
6bb843b
71222fd
 
 
 
6bb843b
71222fd
6bb843b
 
 
 
 
71222fd
6bb843b
 
 
 
 
 
 
 
 
 
 
 
 
5470154
 
6bb843b
 
71222fd
 
 
6bb843b
 
fb49a5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71222fd
 
 
 
fb49a5d
6bb843b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb49a5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bb843b
 
 
 
 
edd0a90
6bb843b
 
fb49a5d
6bb843b
 
 
 
 
 
 
 
 
 
 
 
 
fb49a5d
 
 
 
 
 
 
 
 
6bb843b
fb49a5d
 
 
6bb843b
 
 
 
 
 
 
 
 
 
 
fb49a5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d20720
fb49a5d
 
 
 
1d20720
fb49a5d
 
1d20720
fb49a5d
1d20720
 
fb49a5d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation metric for the TimeBench temporal reasoning benchmark."""

import re
from datetime import datetime

import datasets
import evaluate
from dateutil import parser
from dateutil.parser import ParserError

# BibTeX citation surfaced on the metric's model card / `evaluate.load` info.
_CITATION = """\
@software{abbood2026timebench_eval,
  title={TimeBench Eval},
  author={Abbood, Auss},
  year={2026},
  url={https://huggingface.co/spaces/aauss/timebench_eval}
}
"""

# Short human-readable description shown by `evaluate` tooling.
_DESCRIPTION = """\
Evaluation metric for the TimeBench benchmark, which assesses temporal reasoning
abilities in large language models. Supports multiple task types including TempReason,
TimeQA, MenatQA, Date Arithmetic, and TimeDial.
"""


# Usage documentation prepended to the metric class docstring via
# `add_start_docstrings`; includes a runnable doctest-style example.
_KWARGS_DESCRIPTION = """
Calculates evaluation metrics for temporal reasoning tasks.
Args:
    predictions: list of prediction strings from the model. Each prediction
        should contain the marker "Thus, the correct answer is:" followed by the answer.
    references: list of reference answer strings.
    task: the task type, one of "TempReason", "TimeQA", "MenatQA", "Date Arithmetic", or "TimeDial".
Returns:
    exact_match: list of exact match scores (0 or 1) for each prediction.
    f1: list of F1 scores for each prediction (for applicable tasks).
Examples:
    >>> timebench_eval = evaluate.load("aauss/timebench_eval")
    >>> predictions = ["Let me think... Thus, the correct answer is: Aug, 1987."]
    >>> references = ["Aug, 1987"]
    >>> results = timebench_eval.compute(predictions=predictions, references=references, task="Date Arithmetic")
    >>> print(results)
    {'exact_match': [1]}
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class TimebenchEval(evaluate.Metric):
    """Evaluation metric for TimeBench temporal reasoning tasks."""

    # Precompiled once: matches option letters A-D at a word boundary followed
    # by a typical delimiter (period, comma, whitespace, ampersand, or end of
    # string), which avoids matching letters inside words like "CAD" or "BAD".
    _OPTION_PATTERN = re.compile(r"\b([A-D])(?:\.|,|\s|&|$)")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # SQuAD metric provides token-level Exact Match and F1 scoring.
        self.squad_metric = evaluate.load("squad")

    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
            homepage="https://huggingface.co/spaces/aauss/timebench_eval",
            codebase_urls=["https://huggingface.co/spaces/aauss/timebench_eval/tree/main"],
            reference_urls=["https://huggingface.co/datasets/ulab-ai/Time-Bench"],
        )

    def _compute(
        self, predictions: list[str], references: list[str], task: str
    ) -> dict[str, list[float]]:
        """
        Compute evaluation metrics for the given predictions and references.

        Args:
            predictions: List of prediction strings to evaluate.
            references: List of reference strings to compare against.
            task: Task type, one of: "TempReason", "TimeQA", "MenatQA", "Date Arithmetic", "TimeDial".

        Returns:
            Dictionary containing metric scores (exact_match and/or f1) as lists of floats.

        Raises:
            ValueError: If `task` is not one of the supported task names.
        """
        if task in [
            "TempReason",
            "TimeQA",
            "MenatQA",
        ]:
            return self._call_squad(predictions, references)
        elif task == "Date Arithmetic":
            return self._compare_dates(predictions, references)
        elif task == "TimeDial":
            return self._compute_timedial(predictions, references)
        else:
            raise ValueError(
                f"Unknown task: {task}. Expected one of: TempReason, TimeQA, MenatQA, Date Arithmetic, TimeDial"
            )

    @staticmethod
    def _extract_answer(response: str) -> str | None:
        """
        Extract the final answer from a model response.

        The answer is the text following the marker "Thus, the correct answer
        is:", truncated at the first newline and stripped of trailing
        punctuation. Returns the literal string "unanswerable" when the answer
        contains it (case-insensitive), and None when the marker is absent or
        the extracted answer is empty.
        """
        marker = "Thus, the correct answer is:"

        if marker not in response:
            return None
        answer = response.split(marker)[-1]
        # Take only the first line (stops at newlines if model continues)
        answer = answer.strip().split("\n")[0]
        answer = answer.rstrip(".!?").strip()
        if "unanswerable" in answer.lower():
            return "unanswerable"
        return answer or None

    def _extract_selected_options(self, text: str) -> set[str]:
        """
        Extract selected option letters (A, B, C, D) from various formats:
        - "B, C"
        - "B and C"
        - "B & C"
        - "B && C"
        - "B. No more than ten minutes && C. No more than five minutes"
        - "Options B and C"
        - "The answer is B, C"
        """
        if not text:
            return set()
        # Find all A, B, C, D that look like option selections; pattern is
        # compiled once at class level (see _OPTION_PATTERN).
        return set(self._OPTION_PATTERN.findall(text))

    def _call_squad(
        self, predictions: list[str], references: list[str]
    ) -> dict[str, list[float]]:
        """
        Compute SQuAD metrics (Exact Match and F1) for predictions and references.

        Args:
            predictions: List of prediction strings.
            references: List of reference answer strings.

        Returns:
            Dictionary with "exact_match" and "f1" keys, each containing a list of scores.
        """
        exact_matches = []
        f1_scores = []

        for pred, ref in zip(predictions, references):
            formatted_pred = [
                {"id": "0", "prediction_text": self._extract_answer(pred) or ""}
            ]
            formatted_ref = [
                {"id": "0", "answers": {"text": [ref], "answer_start": [0]}}
            ]

            results = self.squad_metric.compute(
                predictions=formatted_pred, references=formatted_ref
            )
            # SQuAD metric reports percentages; normalize to [0, 1].
            exact_matches.append(results["exact_match"] / 100)
            f1_scores.append(results["f1"] / 100)

        return {
            "exact_match": exact_matches,
            "f1": f1_scores,
        }

    def _compare_dates(
        self, predictions: list[str], references: list[str]
    ) -> dict[str, list[int]]:
        """
        Parses and compares dates in predictions and references for exact match.

        Both sides are normalized to the first day of the month before
        comparison, so matching is month/year granular. Unparseable or missing
        dates become None and can only match another None.

        Args:
            predictions: List of prediction strings containing dates.
            references: List of reference date strings.

        Returns:
            Dictionary with "exact_match" key containing a list of 0/1 scores.
        """
        parsed_preds = [
            self._parse_historical_date(self._extract_answer(pred))
            for pred in predictions
        ]
        parsed_refs = [self._parse_historical_date(ref) for ref in references]
        return {
            "exact_match": [
                1 if pred == ref else 0 for pred, ref in zip(parsed_preds, parsed_refs)
            ],
        }

    def _compute_timedial(
        self, predictions: list[str], references: list[str]
    ) -> dict[str, list[float]]:
        """
        Compute TimeDial metrics (Exact Match and F1) using set-based comparison of selected options.

        Args:
            predictions: List of prediction strings.
            references: List of reference strings containing selected options.

        Returns:
            Dictionary with "exact_match" and "f1" keys, each containing a list of scores.
        """
        exact_matches = []
        f1_scores = []

        for pred, ref in zip(predictions, references):
            pred_answer = self._extract_answer(pred)  # Get text after marker
            pred_options = (
                self._extract_selected_options(pred_answer) if pred_answer else set()
            )
            ref_options = self._extract_selected_options(ref)

            # Exact match: sets must be identical
            em = 1 if pred_options == ref_options else 0
            exact_matches.append(em)

            # F1: set-based
            if not pred_options and not ref_options:
                f1 = 1.0  # Both empty = perfect match
            elif not pred_options or not ref_options:
                f1 = 0.0  # One empty, one not
            else:
                tp = len(pred_options & ref_options)
                precision = tp / len(pred_options)
                recall = tp / len(ref_options)
                f1 = (
                    2 * precision * recall / (precision + recall)
                    if (precision + recall) > 0
                    else 0.0
                )
            f1_scores.append(f1)

        return {"exact_match": exact_matches, "f1": f1_scores}

    @staticmethod
    def _parse_historical_date(date_str: str | None) -> datetime | None:
        """
        Parse a date string and return a datetime object with day set to 1.

        Args:
            date_str: String representation of a date, or None.

        Returns:
            datetime object with day set to 1, or None if parsing fails or input is None.
        """
        if date_str is None:
            return None
        try:
            return parser.parse(date_str).replace(day=1)
        # dateutil raises ParserError for unparseable strings and
        # OverflowError for dates outside the supported range (e.g. a model
        # emitting "year 1000000000"); both should score as a non-match
        # rather than crash the metric.
        except (ParserError, OverflowError):
            return None