Varshithdharmajv committed on
Commit
b2df04b
·
verified ·
1 Parent(s): 693764f

Upload math_verify/tasks.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. math_verify/tasks.py +324 -0
math_verify/tasks.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from textwrap import dedent
3
+ from typing import Callable, Optional
4
+
5
+ import numpy as np
6
+ from lighteval.metrics.dynamic_metrics import SampleLevelMetric
7
+ from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase
8
+ from lighteval.tasks.lighteval_task import LightevalTaskConfig
9
+ from lighteval.tasks.requests import Doc
10
+
11
+ from math_verify.few_shots import GSM8K_FEW_SHOTS, MATH_HARD_FEW_SHOTS
12
+ from math_verify.metric import math_metric
13
+ from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
def as_lighteval_metric(
    metric: Callable[
        [list[str], list[str]], tuple[float, Optional[tuple[list[str], list[str]]]]
    ],
) -> SampleLevelMetric:
    """Adapt a math-verify scoring callable into a lighteval SampleLevelMetric.

    The wrapped callable receives ``(golds, predictions)`` and returns a score
    plus, optionally, the extracted gold/prediction strings. When extractions
    are present they are stored under ``formatted_doc.specific["extracted_predictions"]``
    so they can be inspected after evaluation.
    """

    def _score(formatted_doc: Doc, golds: list[str], predictions: list[str]) -> float:
        score, extracted = metric(golds, predictions)
        if extracted is not None:
            # Lazily create the per-doc scratch dict before stashing extractions.
            if not formatted_doc.specific:
                formatted_doc.specific = {}
            formatted_doc.specific["extracted_predictions"] = extracted
        return score

    return SampleLevelMetric(
        metric_name="extractive_match",
        sample_level_fn=_score,
        category=MetricCategory.GENERATIVE,
        use_case=MetricUseCase.ACCURACY,
        corpus_level_fn=np.mean,
        higher_is_better=True,
    )
41
+
42
+
43
def math_hard_prompt_function(x: dict, task_name: str) -> Doc:
    """Format a MATH-Hard row as a "Question / Step-by-Step Answer" prompt.

    Few-shot rows (flagged by ``__few_shots``) are drawn from the curated
    MATH_HARD_FEW_SHOTS pool, reusing the last example if ``__index`` runs
    past the end; real rows use the dataset's ``problem``/``solution`` fields.
    """
    if x.get("__few_shots"):
        # Clamp so an out-of-range index falls back to the last curated shot.
        shot = MATH_HARD_FEW_SHOTS[min(x["__index"], len(MATH_HARD_FEW_SHOTS) - 1)]
        question, answer = shot["question"], shot["answer"]
    else:
        question = x["problem"]
        answer = str(x["solution"])

    query = dedent(
        f"""\
        Question: {question}
        Step-by-Step Answer:\
        """
    ).strip()

    return Doc(query=query, choices=[answer], gold_index=0)
66
+
67
+
68
def math_prompt_function(x: dict, task_name: str) -> Doc:
    """Format a MATH-style row (``problem``/``answer`` fields) as a prompt Doc.

    Few-shot examples come from the shared MATH_HARD_FEW_SHOTS pool, with the
    index clamped to the last available example.
    """
    if x.get("__few_shots"):
        # Reuse the final curated shot when the requested index overflows.
        shot = MATH_HARD_FEW_SHOTS[min(x["__index"], len(MATH_HARD_FEW_SHOTS) - 1)]
        question, answer = shot["question"], shot["answer"]
    else:
        question = x["problem"]
        answer = str(x["answer"])

    query = dedent(
        f"""\
        Question: {question}
        Step-by-Step Answer:\
        """
    ).strip()

    return Doc(query=query, choices=[answer], gold_index=0)
91
+
92
+
93
def math_aime24_prompt_function(x: dict, task_name: str) -> Doc:
    """Format an AIME-24 row as a prompt Doc.

    The gold choice is the dataset's full ``reference_solution`` (prefixed
    with a single space); few-shot examples reuse MATH_HARD_FEW_SHOTS with
    the index clamped to the pool size.
    """
    if x.get("__few_shots"):
        # Clamp the few-shot index so overflow reuses the last example.
        shot = MATH_HARD_FEW_SHOTS[min(x["__index"], len(MATH_HARD_FEW_SHOTS) - 1)]
        question, answer = shot["question"], shot["answer"]
    else:
        question = x["problem"]
        answer = str(x["reference_solution"])

    query = dedent(
        f"""\
        Question: {question}
        Step-by-Step Answer:\
        """
    ).strip()

    return Doc(query=query, choices=[f" {answer}"], gold_index=0)
116
+
117
+
118
def math_amc23_prompt_function(x: dict, task_name: str) -> Doc:
    """Format an AMC-23 row (``question``/``answer`` fields) as a prompt Doc.

    Few-shot examples come from MATH_HARD_FEW_SHOTS, clamped to the last
    entry; the gold choice carries a leading space.
    """
    if x.get("__few_shots"):
        # Out-of-range few-shot indices fall back to the last curated shot.
        shot = MATH_HARD_FEW_SHOTS[min(x["__index"], len(MATH_HARD_FEW_SHOTS) - 1)]
        question, answer = shot["question"], shot["answer"]
    else:
        question = x["question"]
        answer = str(x["answer"])

    query = dedent(
        f"""\
        Question: {question}
        Step-by-Step Answer:\
        """
    ).strip()

    return Doc(query=query, choices=[f" {answer}"], gold_index=0)
140
+
141
+
142
def gsm8k_prompt_function(x: dict, task_name: str) -> Doc:
    """Format a GSM8K row as a prompt Doc.

    Few-shot examples come from GSM8K_FEW_SHOTS (index clamped to the pool);
    for real rows the gold is the final answer after GSM8K's ``####`` marker.
    """
    if x.get("__few_shots"):
        # Clamp so an overflowing index reuses the last curated example.
        shot = GSM8K_FEW_SHOTS[min(x["__index"], len(GSM8K_FEW_SHOTS) - 1)]
        question, answer = shot["question"], shot["answer"]
    else:
        question = x["question"]
        # GSM8K solutions end with "#### <final answer>"; keep only that part.
        answer = x["answer"].split("####")[-1].strip()

    query = dedent(
        f"""\
        Question: {question}
        Step-by-Step Answer:\
        """
    ).strip()

    return Doc(query=query, choices=[f" {answer}"], gold_index=0)
165
+
166
+
167
# The seven topic subsets published under lighteval/MATH-Hard.
_MATH_HARD_SUBSETS = (
    "algebra",
    "counting_and_probability",
    "geometry",
    "intermediate_algebra",
    "number_theory",
    "prealgebra",
    "precalculus",
)

# One lighteval task per MATH-Hard subset. Gold answers are extracted from
# LaTeX with boxed expressions tried first; predictions are parsed as LaTeX
# or plain expressions.
math_hard_lighteval = [
    LightevalTaskConfig(
        name=f"math_hard:{subset}",
        suite=["lighteval", "math"],
        prompt_function=math_hard_prompt_function,
        hf_repo="lighteval/MATH-Hard",
        hf_subset=subset,
        evaluation_splits=["test"],
        few_shots_split="train",
        generation_size=1024,
        stop_sequence=["\nQuestion:", "\nProblem:", "\nquestion:", "\nproblem:"],
        trust_dataset=True,
        version=0,
        metric=[
            as_lighteval_metric(
                math_metric(
                    gold_extraction_target=(
                        LatexExtractionConfig(boxed_match_priority=0),
                    ),
                    pred_extraction_target=(
                        LatexExtractionConfig(),
                        ExprExtractionConfig(),
                    ),
                )
            ),
        ],
    )
    for subset in _MATH_HARD_SUBSETS
]
204
+
205
# MATH-500 (HuggingFaceH4/MATH-500): single "default" subset evaluated on the
# test split. Gold answers are parsed as LaTeX with boxed expressions tried
# first; predictions are parsed as LaTeX or plain expressions.
# NOTE(review): few_shots_split is the eval split itself — confirm intended.
math_500_lighteval = [
    LightevalTaskConfig(
        name="math_500",
        suite=["lighteval", "math"],
        prompt_function=math_prompt_function,
        hf_repo="HuggingFaceH4/MATH-500",
        hf_subset="default",
        evaluation_splits=["test"],
        few_shots_split="test",
        generation_size=1024,
        metric=[
            as_lighteval_metric(
                math_metric(
                    gold_extraction_target=(
                        LatexExtractionConfig(boxed_match_priority=0),
                    ),
                    pred_extraction_target=(
                        LatexExtractionConfig(),
                        ExprExtractionConfig(),
                    ),
                )
            ),
        ],
        stop_sequence=["\nQuestion:", "\nProblem:", "\nquestion:", "\nproblem:"],
        trust_dataset=True,
        version=0,
    )
]
233
+
234
+
235
# AIME 2024 (zwhe99/aime24): gold answers come from the dataset's full
# reference solutions, parsed as LaTeX (no boxed priority here, unlike the
# MATH configs); predictions are parsed as LaTeX or plain expressions.
aime24_lighteval = [
    LightevalTaskConfig(
        name="aime24",
        suite=["lighteval", "math"],
        prompt_function=math_aime24_prompt_function,
        hf_repo="zwhe99/aime24",
        hf_subset="default",
        evaluation_splits=["test"],
        few_shots_split="test",
        generation_size=1024,
        metric=[
            as_lighteval_metric(
                math_metric(
                    gold_extraction_target=(LatexExtractionConfig(),),
                    pred_extraction_target=(
                        LatexExtractionConfig(),
                        ExprExtractionConfig(),
                    ),
                )
            ),
        ],
        stop_sequence=["\nQuestion:", "\nProblem:", "\nquestion:", "\nproblem:"],
        trust_dataset=True,
        version=0,
    )
]
261
+
262
# AMC 2023 (zwhe99/amc23): golds are plain numeric/expression answers, so the
# gold side uses expression extraction only. Rows with an empty question are
# filtered out via hf_filter.
amc23_lighteval = [
    LightevalTaskConfig(
        name="amc23",
        suite=["lighteval", "math"],
        prompt_function=math_amc23_prompt_function,
        hf_repo="zwhe99/amc23",
        hf_subset="default",
        hf_filter=lambda x: len(x["question"].strip()) > 0,
        evaluation_splits=["test"],
        few_shots_split="test",
        generation_size=1024,
        metric=[
            as_lighteval_metric(
                math_metric(
                    gold_extraction_target=(ExprExtractionConfig(),),
                    pred_extraction_target=(
                        LatexExtractionConfig(),
                        ExprExtractionConfig(),
                    ),
                )
            ),
        ],
        stop_sequence=["\nQuestion:", "\nProblem:", "\nquestion:", "\nproblem:"],
        trust_dataset=True,
        version=0,
    )
]
289
+
290
# GSM8K (openai/gsm8k, "main" subset): golds are the plain numbers after the
# "####" marker, so the gold side uses expression extraction only, with
# first_match fallback when no prediction parses cleanly.
# NOTE(review): unlike the other configs, trust_dataset and version are not
# set here — confirm whether that is intentional.
gsm8k_lighteval = [
    LightevalTaskConfig(
        name="gsm8k",
        suite=["lighteval", "math"],
        prompt_function=gsm8k_prompt_function,
        hf_repo="openai/gsm8k",
        hf_subset="main",
        hf_filter=lambda x: len(x["question"].strip()) > 0,
        evaluation_splits=["test"],
        few_shots_split="test",
        generation_size=1024,
        stop_sequence=["\nQuestion:", "\nProblem:", "\nquestion:", "\nproblem:"],
        metric=[
            as_lighteval_metric(
                math_metric(
                    gold_extraction_target=(ExprExtractionConfig(),),
                    pred_extraction_target=(
                        LatexExtractionConfig(),
                        ExprExtractionConfig(),
                    ),
                    fallback_mode="first_match",
                )
            ),
        ],
    )
]
316
+
317
+
318
# Flat registry of every task config defined above; lighteval's custom-task
# loader discovers tasks through this module-level name.
TASKS_TABLE = (
    gsm8k_lighteval
    + math_hard_lighteval
    + math_500_lighteval
    + aime24_lighteval
    + amc23_lighteval
)