aauss commited on
Commit
fb49a5d
·
1 Parent(s): 6bb843b

Implement TimeDial evaluation.

Browse files
tests/conftest.py CHANGED
@@ -107,3 +107,48 @@ PREDICTION_4 = dedent("""\
107
  Since the only team he left in 2002 is Cardiff City, and the timeframe in question is 1998–2000, it is the most likely candidate.
108
 
109
  Thus, the correct answer is: Cardiff City.""")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  Since the only team he left in 2002 is Cardiff City, and the timeframe in question is 1998–2000, it is the most likely candidate.
108
 
109
  Thus, the correct answer is: Cardiff City.""")
110
+
111
+ PREDICTION_5 = dedent("""\
112
+ Let's analyze the dialogue step by step to determine what makes the most sense in the context of the <mask>.
113
+
114
+ Dialogue:
115
+
116
+ Person1: What did you say?
117
+ Person2: I said it's a lovely day. Why don't we go for a walk?
118
+ Person1: Well, I feel a little tired.
119
+ Person2: Come on! A little labor, much health.
120
+ Person1: Then can you wait a few minutes? I want to finish writing this letter.
121
+ Person2: Don't take too long. It would be a shame not to take advantage of such lovely weather.
122
+ Person1: I won't be long. <MASK>. Why don't you go ahead and I'll meet you in the park?
123
+ Person2: I believe I will. Look for me near the lake.
124
+
125
+ We are to choose appropriate options to substitute the <mask>.
126
+
127
+ Now, evaluate the options:
128
+
129
+ A. No more than ten months
130
+ B. No more than ten minutes
131
+ C. No more than five minutes
132
+ D. No more than two years
133
+
134
+ Contextual Clue:
135
+ Person1 says "I won't be long" — implying a short time.
136
+ They are still writing a letter, and the second person is suggesting they go for a walk now.
137
+ The weather is lovely, and the second person is urging them not to delay.
138
+ So, the time frame must be very short — plausible in the context of finishing a letter.
139
+
140
+ Option A: "No more than ten months" — that's a long time. Doesn’t align with "I won't be long."
141
+ Option B: "No more than ten minutes" — reasonable, short time, fits with "won't be long."
142
+ Option C: "No more than five minutes" — even shorter, very plausible and fits better with "won't be long."
143
+ Option D: "No more than two years" — extremely long — totally inconsistent with the context.
144
+
145
+ So, B and C are both reasonable and within the context.
146
+ Both are short durations and reasonable for finishing a letter.
147
+
148
+ Note: The sentence says: "I won't be long. <MASK>. Why don't you go ahead..." — so the <mask> is a time commitment, and the next sentence is an invitation for the other person to go ahead.
149
+
150
+ Therefore, the correct options are those that convey a short time frame — clearly B and C.
151
+
152
+ A and D are implausible — too long.
153
+
154
+ Thus, the correct answer is: B, C.""")
tests/test_answer_extraction.py CHANGED
@@ -1,6 +1,12 @@
1
  import pytest
2
  from timebench_eval.timebench_eval import TimebenchEval
3
- from conftest import PREDICTION_1, PREDICTION_2, PREDICTION_3, PREDICTION_4
 
 
 
 
 
 
4
 
5
 
6
  @pytest.mark.parametrize(
@@ -10,6 +16,7 @@ from conftest import PREDICTION_1, PREDICTION_2, PREDICTION_3, PREDICTION_4
10
  (PREDICTION_2, "August 1804"),
11
  (PREDICTION_3, "unanswerable"),
12
  (PREDICTION_4, "Cardiff City"),
 
13
  ],
14
  )
15
  def test_answer_extraction(prediction, extracted_answer):
 
1
  import pytest
2
  from timebench_eval.timebench_eval import TimebenchEval
3
+ from conftest import (
4
+ PREDICTION_1,
5
+ PREDICTION_2,
6
+ PREDICTION_3,
7
+ PREDICTION_4,
8
+ PREDICTION_5,
9
+ )
10
 
11
 
12
  @pytest.mark.parametrize(
 
16
  (PREDICTION_2, "August 1804"),
17
  (PREDICTION_3, "unanswerable"),
18
  (PREDICTION_4, "Cardiff City"),
19
+ (PREDICTION_5, "B, C"),
20
  ],
21
  )
22
  def test_answer_extraction(prediction, extracted_answer):
tests/test_metrics.py CHANGED
@@ -1,6 +1,12 @@
1
  from timebench_eval.timebench_eval import TimebenchEval
2
  import pytest
3
- from conftest import PREDICTION_1, PREDICTION_2, PREDICTION_3, PREDICTION_4
 
 
 
 
 
 
4
 
5
 
6
  @pytest.mark.parametrize(
@@ -41,6 +47,33 @@ from conftest import PREDICTION_1, PREDICTION_2, PREDICTION_3, PREDICTION_4
41
  "f1": [1],
42
  },
43
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  ],
45
  )
46
  def test_eval(prediction, reference, task, expected_metrics):
 
1
  from timebench_eval.timebench_eval import TimebenchEval
2
  import pytest
3
+ from conftest import (
4
+ PREDICTION_1,
5
+ PREDICTION_2,
6
+ PREDICTION_3,
7
+ PREDICTION_4,
8
+ PREDICTION_5,
9
+ )
10
 
11
 
12
  @pytest.mark.parametrize(
 
47
  "f1": [1],
48
  },
49
  ),
50
+ (
51
+ PREDICTION_5,
52
+ "B. No more than ten minutes && C. No more than five minutes",
53
+ "TimeDial",
54
+ {
55
+ "exact_match": [1],
56
+ "f1": [1],
57
+ },
58
+ ),
59
+ (
60
+ PREDICTION_5,
61
+ "B.",
62
+ "TimeDial",
63
+ {
64
+ "exact_match": [0],
65
+ "f1": [pytest.approx(2 / 3, rel=1e-6)],
66
+ },
67
+ ),
68
+ (
69
+ PREDICTION_5,
70
+ "A.",
71
+ "TimeDial",
72
+ {
73
+ "exact_match": [0],
74
+ "f1": [0],
75
+ },
76
+ ),
77
  ],
78
  )
79
  def test_eval(prediction, reference, task, expected_metrics):
timebench_eval/.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
timebench_eval/timebench_eval.py CHANGED
@@ -13,13 +13,14 @@
13
  # limitations under the License.
14
  """TODO: Add a description here."""
15
 
 
 
 
16
  from dateutil import parser
17
  from dateutil.parser import ParserError
18
 
19
-
20
  import evaluate
21
  import datasets
22
- import numpy as np
23
 
24
 
25
  # TODO: Add BibTeX citation
@@ -92,6 +93,31 @@ class TimebenchEval(evaluate.Metric):
92
  reference_urls=["http://path.to.reference.url/new_module"],
93
  )
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  @staticmethod
96
  def _extract_answer(response: str) -> str | None:
97
  """Extract the answer from the response"""
@@ -107,7 +133,44 @@ class TimebenchEval(evaluate.Metric):
107
  return "unanswerable"
108
  return answer or None
109
 
110
- def _call_squad(self, predictions, references):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  exact_matches = []
112
  f1_scores = []
113
 
@@ -116,7 +179,7 @@ class TimebenchEval(evaluate.Metric):
116
  {"id": "0", "prediction_text": self._extract_answer(pred)}
117
  ]
118
  formatted_ref = [
119
- {"id": "0", "answers": {"text": [self._extract_answer(ref)], "answer_start": [0]}}
120
  ]
121
 
122
  results = self.squad_metric.compute(
@@ -130,14 +193,19 @@ class TimebenchEval(evaluate.Metric):
130
  "f1": f1_scores,
131
  }
132
 
133
- @staticmethod
134
- def _parse_historical_date(date_str):
135
- try:
136
- return parser.parse(date_str).replace(day=1)
137
- except ParserError:
138
- return None
 
 
 
139
 
140
- def _compare_dates(self, predictions, references):
 
 
141
  predictions = [
142
  self._parse_historical_date(self._extract_answer(pred))
143
  for pred in predictions
@@ -149,13 +217,63 @@ class TimebenchEval(evaluate.Metric):
149
  ],
150
  }
151
 
152
- def _compute(self, predictions, references, task: str):
153
- """Returns the scores"""
154
- if task in [
155
- "TempReason",
156
- "TimeQA",
157
- "MenatQA",
158
- ]:
159
- return self._call_squad(predictions, references)
160
- elif task == "Date Arithmetic":
161
- return self._compare_dates(predictions, references)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # limitations under the License.
14
  """TODO: Add a description here."""
15
 
16
+ import re
17
+ from datetime import datetime
18
+
19
  from dateutil import parser
20
  from dateutil.parser import ParserError
21
 
 
22
  import evaluate
23
  import datasets
 
24
 
25
 
26
  # TODO: Add BibTeX citation
 
93
  reference_urls=["http://path.to.reference.url/new_module"],
94
  )
95
 
96
+ def _compute(
97
+ self, predictions: list[str], references: list[str], task: str
98
+ ) -> dict[str, list[float]]:
99
+ """
100
+ Compute evaluation metrics for the given predictions and references.
101
+
102
+ Args:
103
+ predictions: List of prediction strings to evaluate.
104
+ references: List of reference strings to compare against.
105
+ task: Task type, one of: "TempReason", "TimeQA", "MenatQA", "Date Arithmetic", "TimeDial".
106
+
107
+ Returns:
108
+ Dictionary containing metric scores (exact_match and/or f1) as lists of floats.
109
+ """
110
+ if task in [
111
+ "TempReason",
112
+ "TimeQA",
113
+ "MenatQA",
114
+ ]:
115
+ return self._call_squad(predictions, references)
116
+ elif task == "Date Arithmetic":
117
+ return self._compare_dates(predictions, references)
118
+ elif task == "TimeDial":
119
+ return self._compute_timedial(predictions, references)
120
+
121
  @staticmethod
122
  def _extract_answer(response: str) -> str | None:
123
  """Extract the answer from the response"""
 
133
  return "unanswerable"
134
  return answer or None
135
 
136
+ def _extract_selected_options(self, text: str) -> set[str]:
137
+ """
138
+ Extract selected option letters (A, B, C, D) from various formats:
139
+ - "B, C"
140
+ - "B and C"
141
+ - "B & C"
142
+ - "B && C"
143
+ - "B. No more than ten minutes && C. No more than five minutes"
144
+ - "Options B and C"
145
+ - "The answer is B, C"
146
+ """
147
+ if not text:
148
+ return set()
149
+
150
+ # Pattern matches option letters that appear:
151
+ # 1. At word boundary followed by period, comma, space, &, or end: \b[A-D](?=[.\s,&]|$)
152
+ # 2. This avoids matching letters inside words like "CAD" or "BAD"
153
+
154
+ # Find all A, B, C, D that look like option selections
155
+ # They should be at a word boundary and followed by typical delimiters
156
+ pattern = r"\b([A-D])(?:\.|,|\s|&|$)"
157
+
158
+ matches = re.findall(pattern, text)
159
+ return set(matches)
160
+
161
+ def _call_squad(
162
+ self, predictions: list[str], references: list[str]
163
+ ) -> dict[str, list[float]]:
164
+ """
165
+ Compute SQuAD metrics (Exact Matchand F1) for predictions and references.
166
+
167
+ Args:
168
+ predictions: List of prediction strings.
169
+ references: List of reference answer strings.
170
+
171
+ Returns:
172
+ Dictionary with "exact_match" and "f1" keys, each containing a list of scores.
173
+ """
174
  exact_matches = []
175
  f1_scores = []
176
 
 
179
  {"id": "0", "prediction_text": self._extract_answer(pred)}
180
  ]
181
  formatted_ref = [
182
+ {"id": "0", "answers": {"text": [ref], "answer_start": [0]}}
183
  ]
184
 
185
  results = self.squad_metric.compute(
 
193
  "f1": f1_scores,
194
  }
195
 
196
+ def _compare_dates(
197
+ self, predictions: list[str], references: list[str]
198
+ ) -> dict[str, list[int]]:
199
+ """
200
+ Parses and compares dates in predictions and references for exact match.
201
+
202
+ Args:
203
+ predictions: List of prediction strings containing dates.
204
+ references: List of reference date strings.
205
 
206
+ Returns:
207
+ Dictionary with "exact_match" key containing a list of 0/1 scores.
208
+ """
209
  predictions = [
210
  self._parse_historical_date(self._extract_answer(pred))
211
  for pred in predictions
 
217
  ],
218
  }
219
 
220
+ def _compute_timedial(
221
+ self, predictions: list[str], references: list[str]
222
+ ) -> dict[str, list[float]]:
223
+ """
224
+ Compute TimeDial metrics (Exact Match and F1) using set-based comparison of selected options.
225
+
226
+ Args:
227
+ predictions: List of prediction strings.
228
+ references: List of reference strings containing selected options.
229
+
230
+ Returns:
231
+ Dictionary with "exact_match" and "f1" keys, each containing a list of scores.
232
+ """
233
+ exact_matches = []
234
+ f1_scores = []
235
+
236
+ for pred, ref in zip(predictions, references):
237
+ pred_answer = self._extract_answer(pred) # Get text after marker
238
+ pred_options = (
239
+ self._extract_selected_options(pred_answer) if pred_answer else set()
240
+ )
241
+ ref_options = self._extract_selected_options(ref)
242
+
243
+ # Exact match: sets must be identical
244
+ em = 1 if pred_options == ref_options else 0
245
+ exact_matches.append(em)
246
+
247
+ # F1: set-based
248
+ if not pred_options and not ref_options:
249
+ f1 = 1.0 # Both empty = perfect match
250
+ elif not pred_options or not ref_options:
251
+ f1 = 0.0 # One empty, one not
252
+ else:
253
+ tp = len(pred_options & ref_options)
254
+ precision = tp / len(pred_options)
255
+ recall = tp / len(ref_options)
256
+ f1 = (
257
+ 2 * precision * recall / (precision + recall)
258
+ if (precision + recall) > 0
259
+ else 0.0
260
+ )
261
+ f1_scores.append(f1)
262
+
263
+ return {"exact_match": exact_matches, "f1": f1_scores}
264
+
265
+ @staticmethod
266
+ def _parse_historical_date(date_str: str) -> datetime | None:
267
+ """
268
+ Parse a date string and return a datetime object with day set to 1.
269
+
270
+ Args:
271
+ date_str: String representation of a date.
272
+
273
+ Returns:
274
+ datetime object with day set to 1, or None if parsing fails.
275
+ """
276
+ try:
277
+ return parser.parse(date_str).replace(day=1)
278
+ except ParserError:
279
+ return None
uv.lock ADDED
The diff for this file is too large to render. See raw diff