amanoaki committed on
Commit
b09aed5
·
verified ·
1 Parent(s): 5be7967

Upload 10 files

Browse files
README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ # domain-rlhf
dataset_construction/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ Code to evaluate a model on the `PRIME-RL/Eurus-2-RL-Data` dataset and assign a difficulty to each sample based on the results. The output is written as JSON.
2
+
3
+ Evaluation metric: Pass@k, 1/Pass@k
dataset_construction/evaluate_on_eurus_2.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from datasets import load_dataset
3
+ from tqdm import tqdm
4
+ from vllm import LLM, SamplingParams
5
+ from transformers import AutoTokenizer
6
+
7
+ from evaluation_utils.code_util import evaluate_code
8
+ from evaluation_utils.math_util import evaluate_math
9
+ import argparse
10
+ from tqdm import tqdm
11
+
12
def read_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace: one attribute per option listed below.
    """
    # (flag, type, default, help), in the order shown by ``--help``.
    option_specs = (
        ("--model", str, "Qwen/Qwen2.5-7B-Instruct", "The model to use for evaluation"),
        ("--batch_size", int, 4, "Batch size for evaluation"),
        ("--k", int, 10, "Number of completions to generate for each prompt"),
        ("--tensor_parallel_size", int, 1, "Number of parallel tensors to use for generation"),
        ("--temperature", float, 0.8, "Temperature for sampling"),
        ("--top_p", float, 0.95, "Top-p for sampling"),
        ("--num_samples", int, -1, "Number of samples to evaluate"),
        ("--output_file", str, "evaluation_results.json", "File to save evaluation results"),
        # NOTE: The dataset asks for CoT output. This is very token-consuming.
        # If you limit it to a small number, the model will not be able to
        # generate the answer.
        ("--max_tokens", int, 1024, "Maximum number of tokens to generate."),
        ("--log_per_step", int, 1000, "Log results per step"),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, default_value, description in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=description)
    return cli.parse_args()
26
+
27
def generate_batch_completions(llm, prompts, sampling_params):
    """Generate completions for a batch of prompts.

    Args:
        llm: object exposing ``generate(prompts, sampling_params)``.
        prompts: the prompts to complete.
        sampling_params: sampling configuration forwarded to ``llm.generate``.

    Returns:
        One list per generated request, holding the ``text`` of each of its
        candidate outputs.
    """
    return [
        [candidate.text for candidate in request.outputs]
        for request in llm.generate(prompts, sampling_params)
    ]
34
+
35
def evaluate(k_completions, ability, ground_truth):
    """Grade every completion of one sample against its ground truth.

    Args:
        k_completions: the generated completions for a single prompt.
        ability: ``"math"`` or ``"coding"``; selects the grader.
        ground_truth: reference passed through to the grader.

    Returns:
        A list with ``1`` for each passing completion and ``0`` for each
        failing one. Completions of any other ``ability`` contribute no entry,
        so the list is empty in that case.
    """
    scores = []
    for text in k_completions:
        if ability == "math":
            verdict, _, _ = evaluate_math(text, ground_truth)
        elif ability == "coding":
            verdict, _ = evaluate_code(text, ground_truth)
        else:
            # Unknown ability: nothing is recorded for this completion.
            continue
        scores.append(1 if verdict else 0)
    return scores
51
+
52
def main():
    """Sample k completions per prompt of Eurus-2-RL-Data and grade them.

    For the ``train`` and ``validation`` splits, every prompt is rendered with
    the model's chat template, ``--k`` completions are sampled with vLLM and
    each completion is graded with ``evaluate``. Two JSON files are written:

    * ``<output_file>``: per-sample metadata plus the ``correctness`` list;
    * ``complete_<output_file>``: the same plus the raw completions.

    Intermediate ``step<N>_...`` checkpoints of both files are written each
    time the number of graded samples crosses a multiple of ``--log_per_step``
    (a value <= 0 disables checkpoints).

    ``--num_samples`` (> 0) caps the number of samples evaluated in *each*
    split; it is clamped to the split size.
    """
    args = read_args()
    ds = load_dataset("PRIME-RL/Eurus-2-RL-Data")

    llm = LLM(model=args.model, tensor_parallel_size=args.tensor_parallel_size)
    # Use the --max_tokens option; it used to be silently ignored in favour of
    # a hard-coded 4096.
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        n=args.k,
        max_tokens=args.max_tokens,
    )

    batch_size = args.batch_size
    results = {}
    complete_results = {}

    tokenizer = AutoTokenizer.from_pretrained(args.model)

    def dump_json(payload, path):
        # Single place that fixes the on-disk format of every result file.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)

    num = 0  # samples graded so far, across both splits
    checkpoints_written = 0  # highest multiple of log_per_step already dumped
    for split in ["train", "validation"]:
        results[split] = []
        complete_results[split] = []
        print(f"Running evaluation on {split} split")

        split_size = len(ds[split])
        if args.num_samples > 0:
            num_samples = min(args.num_samples, split_size)
        else:
            num_samples = split_size

        for i in tqdm(range(0, num_samples, batch_size)):
            # Clamp the last batch so no completions are generated (and then
            # thrown away) for samples beyond the requested limit.
            batch_samples = ds[split][i: min(i + batch_size, num_samples)]

            # add_generation_prompt=True appends the assistant-turn header so
            # the model answers the conversation instead of continuing it.
            prompts = [
                tokenizer.apply_chat_template(
                    sample,
                    tokenize=False,
                    add_generation_prompt=True
                )
                for sample in batch_samples["prompt"]
            ]

            batch_outputs = generate_batch_completions(llm, prompts, sampling_params)

            for data_source, prompt, ability, reward_model, extra_info, completions in zip(
                batch_samples["data_source"],
                batch_samples["prompt"],
                batch_samples["ability"],
                batch_samples["reward_model"],
                batch_samples["extra_info"],
                batch_outputs
            ):
                sample = {
                    "data_source": data_source,
                    "prompt": prompt,
                    "ability": ability,
                    "reward_model": reward_model,
                    "extra_info": extra_info,
                }

                ground_truth = (reward_model or {}).get("ground_truth", "")
                correctness = evaluate(completions, ability, ground_truth)
                sample_result = {
                    **sample,
                    "correctness": correctness,
                }
                results[split].append(sample_result)
                complete_sample_result = sample_result.copy()
                complete_sample_result["completions"] = completions
                complete_results[split].append(complete_sample_result)

                num += 1

            # Checkpoint whenever a new multiple of log_per_step has been
            # crossed; an exact `num % log_per_step == 0` test is never true
            # when batch_size does not divide log_per_step.
            if args.log_per_step > 0 and num // args.log_per_step > checkpoints_written:
                checkpoints_written = num // args.log_per_step
                dump_json(results, f"step{num}_" + args.output_file)
                dump_json(complete_results, f"step{num}_complete_" + args.output_file)

    dump_json(results, args.output_file)
    dump_json(complete_results, "complete_" + args.output_file)
131
+
132
+ if __name__ == "__main__":
133
+ main()
dataset_construction/evaluation_utils/code_util/__init__.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .utils import check_correctness as apps_check_correctness
2
+ import json
3
+ import re
4
+ import traceback
5
+
6
def evaluate_code(completion, test_cases):
    """Run the code found in ``completion`` against ``test_cases``.

    The solution is the content of the last ```` ```python ```` fenced block
    of ``completion`` (or the whole text when there is no fence).

    All test cases are first executed in a single run. If that run does not
    pass, the first ten test cases are re-run one at a time and the verdict is
    taken from those individual runs.

    Args:
        completion: model output containing the candidate solution.
        test_cases: dict (or its JSON string) with ``"inputs"``, ``"outputs"``
            and, for call-based problems, ``"fn_name"``.

    Returns:
        ``(success, metadata)``. ``metadata`` is the metadata dict of the
        single run when that run passes, the list of per-test-case metadata
        dicts after the one-by-one fallback, or ``None`` when evaluation
        itself failed (including test cases that cannot be parsed).
    """
    # Try to get the code solution from the completion. If the completion is
    # pure code, this will not take effect.
    solution = completion.split('```python')[-1].split('```')[0]
    try:
        if not isinstance(test_cases, dict):
            try:
                test_cases = json.loads(test_cases)
            except Exception as e:
                print(f"Error:{e}")
                # Nothing can be executed without test cases; fail right away
                # instead of spawning a subprocess that is bound to crash.
                return False, None

        # First check all test cases in one run; only when that does not pass
        # are the test cases checked one by one.
        try:
            res, metadata = apps_check_correctness(
                in_outs=test_cases,
                generation=solution,
                timeout=5,
                debug=False
            )
            metadata = dict(enumerate(metadata))[0]
            # Compare with True explicitly: error codes such as -1/-2 are
            # truthy and must not count as a pass.
            success = all(map(lambda x: x == True, res))
            if success:
                return success, metadata
        except Exception:
            # Deliberate best effort: fall through to the one-by-one check.
            pass

        inputs = test_cases["inputs"]
        outputs = test_cases["outputs"]
        # Keep fn_name so call-based problems are still executed as function
        # calls rather than as stdin scripts.
        fn_name = test_cases.get("fn_name")
        test_cases_list = []
        for i in range(len(inputs)):
            single_case = {
                "inputs": [inputs[i]],
                "outputs": [outputs[i]]
            }
            if fn_name is not None:
                single_case["fn_name"] = fn_name
            test_cases_list.append(single_case)

        metadata_list = []
        res_list = []
        for test_case_id, test_case in enumerate(test_cases_list):
            res, metadata = apps_check_correctness(
                in_outs=test_case,
                generation=solution,
                timeout=5,
                debug=False
            )
            try:
                # metadata may be empty when the run failed.
                metadata = dict(enumerate(metadata))[0]
            except Exception:
                metadata = {}
            metadata["test_case"] = {}
            metadata["test_case"]["input"] = str(test_case["inputs"][0])
            metadata["test_case"]["output"] = str(test_case["outputs"][0])
            metadata["test_case"]["res"] = str(res)
            metadata_list.append(metadata)
            res_list.extend(res)

            # Only the first ten test cases are checked individually.
            if test_case_id >= 9:
                break

        success = all(map(lambda x: x == True, res_list))
    except Exception:
        traceback.print_exc(10)
        success = False
        metadata_list = None
    return success, metadata_list
69
+
dataset_construction/evaluation_utils/code_util/testing_util.py ADDED
@@ -0,0 +1,752 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import json
3
+ import sys
4
+ import faulthandler
5
+ import platform
6
+
7
+ # used for debugging to time steps
8
+ from datetime import datetime
9
+
10
+ # to run the solution files we're using a timing based approach
11
+ import signal
12
+
13
+ import numpy as np
14
+
15
+ # for capturing the stdout
16
+ from io import StringIO
17
+
18
+ # used for testing the code that reads from input
19
+ from unittest.mock import patch, mock_open
20
+
21
+ from pyext import RuntimeModule
22
+
23
+ from enum import Enum
24
+
25
+ import traceback
26
+
27
def truncatefn(s, length=300):
    """Shorten ``s`` to roughly ``length`` characters, keeping both ends.

    Strings of at most ``length`` characters are returned unchanged. Longer
    ones keep their first ``length // 2`` characters and their tail starting
    at index ``-length // 2``, joined by ``"...(truncated) ..."``.
    """
    assert isinstance(s, str)
    if len(s) > length:
        head = s[: length // 2]
        tail = s[-length // 2 :]
        return f"{head}...(truncated) ...{tail}"
    return s
33
+
34
+
35
class CODE_TYPE(Enum):
    """How ``run_test`` drives a generated solution."""

    # Selected when the test cases carry an "fn_name": that attribute of the
    # loaded solution is called with the parsed inputs.
    call_based = 0
    # Selected when there is no "fn_name": the solution is wrapped in a
    # function and run with the inputs supplied on stdin.
    standard_input = 1
38
+
39
+
40
+ # stuff for setting up signal timer
41
class TimeoutException(Exception):
    """Exception type for a timed section that ran too long.

    NOTE(review): nothing visible in this file raises it -- the ``raise`` in
    ``timeout_handler`` is commented out. ``run_test`` only looks for the name
    in the ``repr`` of caught exceptions.
    """

    pass
43
+
44
+
45
def timeout_handler(signum, frame):
    """SIGALRM handler: report that the alarm fired, then return.

    The handler returns normally instead of raising, so the code that was
    running when the alarm fired is not interrupted by it.
    """
    print("alarm went off")
    return
    # raise TimeoutException # this is an unhandled exception. just return None is OK
49
+
50
+
51
+ signal.signal(signal.SIGALRM, timeout_handler)
52
+ # timeout = 6 # seconds
53
+
54
+
55
+ # used to capture stdout as a list
56
+ # from https://stackoverflow.com/a/16571630/6416660
57
+ # alternative use redirect_stdout() from contextlib
58
class Capturing(list):
    """Context manager that captures everything written to ``sys.stdout``.

    While active, ``sys.stdout`` is a fresh ``StringIO``. On exit the captured
    text is appended to this list as one string and the previous
    ``sys.stdout`` is restored.
    """

    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        # Make closing the StringIO a no-op. The replacement is an instance
        # attribute, so it is called without ``self`` and must accept a
        # zero-argument call; a one-argument lambda would raise TypeError as
        # soon as the code under test calls ``sys.stdout.close()``.
        self._stringio.close = lambda *args: None
        return self

    def __exit__(self, *args):
        self.append(self._stringio.getvalue())
        del self._stringio  # free up some memory
        sys.stdout = self._stdout
70
+
71
+
72
def only_int_check(val):
    """Tell whether ``val`` is an ``int`` instance (``bool`` values included)."""
    is_integer = isinstance(val, int)
    return is_integer
74
+
75
+
76
def string_int_check(val):
    """Tell whether ``val`` is a string made only of digits.

    Signs, decimal points and the empty string do not qualify.
    """
    if not isinstance(val, str):
        return False
    return val.isdigit()
78
+
79
+
80
def combined_int_check(val):
    """Tell whether ``val`` is an ``int`` or a digits-only string."""
    if isinstance(val, int):
        return True
    return isinstance(val, str) and val.isdigit()
82
+
83
def clean_traceback(error_traceback):
    """Strip harness frames from a formatted traceback.

    Everything before the first ``File "<string>"`` frame is replaced by a
    fresh ``Traceback (most recent call last):`` header.

    Args:
        error_traceback: text as returned by ``traceback.format_exc()``.

    Returns:
        The trimmed traceback, or ``error_traceback`` unchanged when it has no
        ``File "<string>"`` frame.
    """
    file_start = error_traceback.find('File \"<string>\"')
    if file_start == -1:
        # Without this guard the slice below would start at index -1 and keep
        # only the last character, discarding the whole traceback.
        return error_traceback
    return "Traceback (most recent call last):\n " + error_traceback[file_start:]
88
+
89
def run_test(in_outs, test=None, debug=False, timeout=15):
    """
    Run the generated code ``test`` against the test cases in ``in_outs``.

    Args:
        in_outs: dict with "inputs" and "outputs" lists and an optional
            "fn_name". With "fn_name" the solution is call-based (that
            attribute is called with the parsed inputs); without it the
            solution is wrapped in a ``code()`` function and fed via stdin.
            Entries of ``in_outs["outputs"]`` are rewritten in place while
            comparing.
        test: source code of the generated solution; must not be None.
        debug: print progress and comparison details.
        timeout: seconds passed to ``signal.alarm`` around loading and around
            each test case.

    Returns:
        ``(results, metadata)``. ``results`` holds one entry per executed test
        case: True for a pass, a non-True comparison result for a wrong
        answer, -1 for a runtime error, -2 when the solution could not be
        loaded or the entry point could not be found. Execution stops at the
        first failing test case; ``metadata`` then describes that failure and
        is ``{}`` otherwise.
    """
    # Disable functionalities that can make destructive changes to the test.
    reliability_guard()

    if debug:
        print(f"start = {datetime.now().time()}")

    # try:
    # in_outs = json.loads(sample["input_output"])
    # except ValueError:
    # in_outs = None
    # NOTE(review): when in_outs is falsy, which_type/method_name are never
    # assigned and the code below raises NameError -- confirm callers always
    # pass a non-empty dict.
    if in_outs:
        if in_outs.get("fn_name") is None:
            which_type = CODE_TYPE.standard_input  # Standard input
            method_name = None
        else:
            which_type = CODE_TYPE.call_based  # Call-based
            method_name = in_outs["fn_name"]

    if debug:
        print(f"loaded input_output = {datetime.now().time()}")

    if test is None:
        # The return below is only reached when assertions are disabled.
        assert False, "should not happen: test code is none"
        return in_outs, {"error": "no test code provided"}
    elif test is not None:
        results = []
        # Prelude of common imports prepended to every solution.
        sol = "from string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\nsys.setrecursionlimit(6*10**5)\n"
        if debug:
            print(f"loading test code = {datetime.now().time()}")

        if which_type == CODE_TYPE.call_based:

            sol += test
            if debug:
                print(f"sol = {sol}")
            signal.alarm(timeout)
            try:
                tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
                # The entry point lives either at module level or on an
                # instance of a `Solution` class.
                if "class Solution" not in test:
                    tmp = tmp_sol
                else:
                    tmp = tmp_sol.Solution()
                signal.alarm(0)
            except Exception as e:
                signal.alarm(0)
                error_traceback = traceback.format_exc()
                if debug:
                    print(f"type 0 compilation error = {e}")
                results.append(-2)
                return results, {
                    "error": repr(e),
                    # "error_code": -1,
                    # "error_message": "Compilation Error",
                    "traceback": clean_traceback(error_traceback),
                }
            signal.alarm(0)

        elif which_type == CODE_TYPE.standard_input:
            # sol
            # if code has if __name__ == "__main__": then remove it
            # (the guard's body is kept and appended to the rest of the code)
            try:
                astree = ast.parse(test)
                last_block = astree.body[-1]
                if isinstance(last_block, ast.If):
                    condition = last_block.test
                    if ast.unparse(condition).strip() == "__name__ == '__main__'":
                        test = (
                            ast.unparse(astree.body[:-1])
                            + "\n"
                            + ast.unparse(last_block.body)
                        )
            except:
                pass

            tmp_test = test.split("\n")

            # Tab-indent every line that is not an import so it can become the
            # body of a function.
            new_test = []
            for x in tmp_test:
                if (not x.startswith("from ")) and (not x.startswith("import ")):
                    new_test.append("\t" + x + "\n")
                else:
                    new_test.append(x + "\n")
            tmp_test = new_test

            # Open `def code():` at the first indented line; imports seen
            # before it stay at module level, later ones are indented into
            # the function.
            new_test = ""
            started = False
            for i in tmp_test:
                if i.startswith("\t") and not started:
                    new_test += "stdin = sys.stdin\nstdout = sys.stdout\n"
                    new_test += "def code():\n"
                    new_test += i
                    started = True
                elif started and ((i.startswith("from ")) or (i.startswith("import "))):
                    new_test += "\t" + i
                else:
                    new_test += i
            tmp_test = new_test

            sol += tmp_test
            if debug:
                print(f"sol = {sol}")
            method_name = "code"
            signal.alarm(timeout)
            try:
                tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
                tmp = tmp_sol
                signal.alarm(0)
            except Exception as e:
                signal.alarm(0)
                error_traceback = traceback.format_exc()
                if debug:
                    print(f"type 1 compilation error = {e}")
                results.append(-2)
                return results, {
                    "error": repr(e),
                    # "error_code": -1,
                    # "error_message": "Compilation Error",
                    "traceback": clean_traceback(error_traceback),
                }
            signal.alarm(0)
        if debug:
            print(f"get method = {datetime.now().time()}")

        try:
            method = getattr(tmp, method_name)  # get_attr second arg must be str
        except:
            signal.alarm(0)
            error_traceback = traceback.format_exc()
            e = sys.exc_info()
            print(f"unable to get function error = {e}")
            results.append(-2)
            return results, {
                "error": repr(e),
                # "error_code": -1,
                # "error_message": "Unable to extract code",
                "traceback": clean_traceback(error_traceback),
            }

        for index, inputs in enumerate(in_outs["inputs"]):
            # raw_* are truncated copies kept only for the failure metadata.
            raw_inputs = inputs
            raw_outputs = in_outs["outputs"][index]
            if which_type == CODE_TYPE.call_based:
                # One JSON-encoded argument per input line.
                inputs = [json.loads(line) for line in inputs.split("\n")]
                in_outs["outputs"][index] = json.loads(in_outs["outputs"][index])

                truncate_line_size = 300 // (raw_inputs.count("\n") + 1)
                raw_inputs = "\n".join(
                    [
                        truncatefn(line, truncate_line_size)
                        for line in raw_inputs.strip().split("\n")
                    ]
                )
                raw_outputs = truncatefn(raw_outputs, 200)
            else:
                raw_inputs = truncatefn(raw_inputs)
                raw_outputs = truncatefn(raw_outputs, 200)
            # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list)
            try:
                if isinstance(inputs[0], dict):
                    inputs = [{int(k): v for k, v in inputs[0].items()}]
            except:
                True
            try:
                if isinstance(in_outs["outputs"][index], dict):
                    in_outs["outputs"][index] = [
                        {int(k): v for k, v in in_outs["outputs"][index].items()}
                    ]
            except:
                True
            try:
                if isinstance(in_outs["outputs"][index][0], dict):
                    in_outs["outputs"][index] = [
                        {int(k): v for k, v in in_outs["outputs"][index][0].items()}
                    ]
            except:
                True

            if debug:
                print(
                    f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. type = {which_type}"
                )
            if which_type == CODE_TYPE.call_based:  # Call-based
                signal.alarm(timeout)
                faulthandler.enable()
                try:
                    output = method(*inputs)
                    raw_true_output = output

                    raw_true_output_copy = json.dumps(output)
                    raw_true_output_copy = truncatefn(raw_true_output_copy, 200)

                    # ground truth sequences are not tuples
                    if isinstance(output, tuple):
                        output = list(output)

                    # Accept a match against the expected value itself or,
                    # when it is a non-empty list, against its first element.
                    tmp_result = output == in_outs["outputs"][index]
                    if (
                        isinstance(in_outs["outputs"][index], list)
                        and in_outs["outputs"][index]
                    ):
                        tmp_result = tmp_result or (
                            output == in_outs["outputs"][index][0]
                        )

                    # ground truth sequences are not tuples
                    try:
                        if isinstance(output[0], tuple):
                            tmp_result = tmp_result or (
                                [list(x) for x in output]
                                == in_outs["outputs"][index][0]
                            )
                    except:
                        True
                    results.append(tmp_result)
                    if tmp_result != True:
                        # NOTE(review): this early return leaves the alarm
                        # armed and faulthandler enabled.
                        return results, {
                            "output": raw_true_output_copy,
                            "expected": raw_outputs,
                            "inputs": raw_inputs,
                            # "error_code": -2,
                            "error_message": "Wrong Answer",
                        }
                    # reset the alarm
                    signal.alarm(0)
                except Exception as e:
                    signal.alarm(0)
                    error_traceback = traceback.format_exc()
                    faulthandler.disable()
                    if debug:
                        # NOTE(review): the label says "Standard input" but
                        # this is the call-based branch.
                        print(
                            f"Standard input runtime error or time limit exceeded error = {e}"
                        )
                    results.append(-1)
                    # Both branches currently return the same metadata shape.
                    if "timeoutexception" in repr(e).lower():
                        return results, {
                            "error": repr(e),
                            # "error_code": -3,
                            # "error_message": "Time Limit Exceeded",
                            # "inputs": raw_inputs,
                            # "expected": raw_outputs,
                            "traceback": clean_traceback(error_traceback),
                        }
                    else:
                        return results, {
                            "error": repr(e),
                            # "error_code": -4,
                            # "error_message": "Runtime Error",
                            # "inputs": raw_inputs,
                            # "expected": raw_outputs,
                            "traceback": clean_traceback(error_traceback),
                        }
                faulthandler.disable()
                signal.alarm(0)
                if debug:
                    print(
                        f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
                    )
            elif which_type == CODE_TYPE.standard_input:  # Standard input
                faulthandler.enable()
                passed = False

                if isinstance(inputs, list):
                    inputs = "\n".join(inputs)
                if isinstance(in_outs["outputs"][index], list):
                    in_outs["outputs"][index] = "\n".join(in_outs["outputs"][index])

                signal.alarm(timeout)
                with Capturing() as output:
                    try:
                        call_method(method, inputs)
                        # reset the alarm
                        signal.alarm(0)
                        passed = True
                    except Exception as e:
                        # runtime error or took too long
                        signal.alarm(0)
                        error_traceback = traceback.format_exc()
                        # NOTE(review): the label says "Call-based" but this
                        # is the standard-input branch; the text is printed
                        # while stdout is being captured.
                        print(
                            f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}"
                        )
                        results.append(-1)
                        if "timeoutexception" in repr(e).lower():
                            return results, {
                                "error": repr(e),
                                # "error_code": -3,
                                # "error_message": "Time Limit Exceeded",
                                # "inputs": raw_inputs,
                                # "expected": raw_outputs,
                                "traceback": clean_traceback(error_traceback),
                            }
                        else:
                            return results, {
                                "error": repr(e),
                                # "error_code": -4,
                                # "error_message": "Runtime Error",
                                # "inputs": raw_inputs,
                                # "expected": raw_outputs,
                                "traceback": clean_traceback(error_traceback),
                            }
                    signal.alarm(0)
                # Capturing stores the captured stdout as a single string.
                raw_true_output = output[0]
                raw_true_output_copy = truncatefn(raw_true_output, 200)
                output = raw_true_output.splitlines()
                # NOTE(review): every `except Exception` path above returns,
                # so this branch looks unreachable -- confirm.
                if not passed:
                    if debug:
                        nl = "\n"
                        if not isinstance(inputs, list):
                            print(
                                f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
                            )
                        else:
                            print(
                                f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
                            )
                    continue

                if passed and debug:
                    print(
                        f"==> output = {output}, test outputs = {in_outs['outputs'][index]}"
                    )

                # Check 0: whitespace-insensitive string comparison.
                if custom_compare_(output, in_outs["outputs"][index]):
                    tmp_result = True
                    results.append(tmp_result)
                    continue

                # ground truth sequences are expressed as lists not tuples
                if isinstance(output, tuple):
                    output = list(output)

                # Check 1: direct comparison against the expected value.
                tmp_result = False
                try:
                    tmp_result = output == [in_outs["outputs"][index]]
                    if isinstance(in_outs["outputs"][index], list):
                        tmp_result = tmp_result or (output == in_outs["outputs"][index])
                        if isinstance(output[0], str):
                            tmp_result = tmp_result or (
                                [e.strip() for e in output] == in_outs["outputs"][index]
                            )
                except Exception as e:
                    if debug:
                        print(f"Failed check1 exception = {e}")
                    pass

                if tmp_result == True:
                    results.append(tmp_result)
                    continue

                # try one more time without \n
                # (the expected output is rewritten in place as a list of
                # stripped lines)
                if isinstance(in_outs["outputs"][index], list):
                    for tmp_index, i in enumerate(in_outs["outputs"][index]):
                        in_outs["outputs"][index][tmp_index] = i.split("\n")
                        in_outs["outputs"][index][tmp_index] = [
                            x.strip() for x in in_outs["outputs"][index][tmp_index] if x
                        ]
                else:
                    in_outs["outputs"][index] = in_outs["outputs"][index].split("\n")
                    in_outs["outputs"][index] = list(
                        filter(len, in_outs["outputs"][index])
                    )
                    in_outs["outputs"][index] = list(
                        map(lambda x: x.strip(), in_outs["outputs"][index])
                    )

                # Check 2: compare against the line-split expected output.
                try:
                    tmp_result = output == [in_outs["outputs"][index]]
                    if isinstance(in_outs["outputs"][index], list):
                        tmp_result = tmp_result or (output == in_outs["outputs"][index])
                except Exception as e:
                    if debug:
                        print(f"Failed check2 exception = {e}")
                    pass

                if tmp_result == True:
                    results.append(tmp_result)
                    continue

                # try by converting the output into a split up list too
                # (drops empty lines from the produced output)
                if isinstance(output, list):
                    output = list(filter(len, output))

                if debug:
                    nl = "\n"
                    if not isinstance(inputs, list):
                        print(
                            f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}"
                        )
                    else:
                        print(
                            f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}"
                        )

                if tmp_result == True:
                    results.append(tmp_result)
                    continue

                if debug:
                    print(f"{tmp_result=} @a")

                # Check 3: same comparison, now without empty output lines.
                try:
                    tmp_result = output == [in_outs["outputs"][index]]
                    if isinstance(in_outs["outputs"][index], list):
                        tmp_result = tmp_result or (output == in_outs["outputs"][index])
                except Exception as e:
                    if debug:
                        print(f"Failed check3 exception = {e}")
                    pass

                if debug:
                    print(f"{tmp_result=} @b")

                # Numeric check: when the values are not all integers, compare
                # them as floats with np.allclose.
                try:
                    all_ints = all(
                        combined_int_check(e1) and combined_int_check(e2)
                        for e1, e2 in zip(output, in_outs["outputs"][index])
                    )
                    if not all_ints:
                        if debug:
                            print(
                                [
                                    combined_int_check(e1) and combined_int_check(e2)
                                    for e1, e2 in zip(output, in_outs["outputs"][index])
                                ]
                            )
                        output_float = [float(e) for e in output]
                        gt_float = [float(e) for e in in_outs["outputs"][index]]
                        tmp_result = tmp_result or (
                            (len(output_float) == len(gt_float))
                            and np.allclose(output_float, gt_float)
                        )
                except Exception as e:
                    pass

                if debug:
                    print(f"{tmp_result=} @c")

                # Same numeric check for an output whose first element is a list.
                try:
                    if isinstance(output[0], list):
                        all_ints = all(
                            combined_int_check(e1) and combined_int_check(e2)
                            for e1, e2 in zip(output[0], in_outs["outputs"][index])
                        )
                        if not all_ints:
                            output_float = [float(e) for e in output[0]]
                            gt_float = [float(e) for e in in_outs["outputs"][index][0]]
                            tmp_result = tmp_result or (
                                (len(output_float) == len(gt_float))
                                and np.allclose(output_float, gt_float)
                            )
                except Exception as e:
                    pass

                if tmp_result == True:
                    results.append(tmp_result)
                    continue

                if debug:
                    print(f"{tmp_result=} @d")
                # try by converting the stuff into split up list
                # (each expected line becomes the set of its tokens)
                if isinstance(in_outs["outputs"][index], list):
                    for tmp_index, i in enumerate(in_outs["outputs"][index]):
                        in_outs["outputs"][index][tmp_index] = set(i.split())
                else:
                    in_outs["outputs"][index] = set(in_outs["outputs"][index].split())

                if debug:
                    print(f"{tmp_result=} @e")

                # Check 4: compare the output with the token sets.
                try:
                    tmp_result = output == in_outs["outputs"][index]
                except Exception as e:
                    if debug:
                        print(f"Failed check4 exception = {e}")
                    continue

                if tmp_result == True:
                    results.append(tmp_result)
                    continue

                if debug:
                    print(f"{tmp_result=} @f")

                # try by converting the output into a split up list too
                # NOTE(review): tmp_result is not recomputed after this
                # conversion, so it does not affect the verdict below.
                if isinstance(output, list):
                    for tmp_index, i in enumerate(output):
                        output[tmp_index] = i.split()
                    output = list(filter(len, output))
                    for tmp_index, i in enumerate(output):
                        output[tmp_index] = set(i)
                else:
                    output = output.split()
                    output = list(filter(len, output))
                    output = set(output)

                if debug:
                    print(f"{tmp_result=} @g")

                if tmp_result == True and debug:
                    print("PASSED")

                results.append(tmp_result)
                if tmp_result != True:
                    return results, {
                        "output": raw_true_output_copy,
                        "expected": raw_outputs,
                        "inputs": raw_inputs,
                        # "error_code": -2,
                        "error_message": "Wrong Answer",
                    }

                if debug:
                    nl = "\n"
                    if not isinstance(inputs, list):
                        print(
                            f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
                        )
                    else:
                        print(
                            f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
                        )

                    print(f"results = {results}")

    return results, {}
618
+
619
+
620
def custom_compare_(output, ground_truth):
    """Compare produced output lines with the expected text, ignoring padding.

    Args:
        output: the produced output; only a list of lines can match.
        ground_truth: the expected output as one string.

    Returns:
        True when the newline-joined lines equal ``ground_truth`` after
        stripping both ends, either as they are or with every line stripped
        first; False otherwise (including when ``output`` is not a list).
    """
    if not isinstance(output, list):
        return False

    joined_as_is = "\n".join(output)
    if joined_as_is.strip() == ground_truth.strip():
        return True

    joined_stripped = "\n".join(line.strip() for line in output)
    return joined_stripped.strip() == ground_truth.strip()
634
+
635
+
636
def stripped_string_compare(s1, s2):
    """Tell whether two strings are equal once surrounding whitespace is removed."""
    return s1.strip() == s2.strip()
640
+
641
+
642
def call_method(method, inputs):
    """Call ``method()`` with ``inputs`` available as its standard input.

    Args:
        method: zero-argument callable to run.
        inputs: the input text, or a list of lines that is joined with
            newlines first.

    Returns:
        Whatever ``method()`` returns, or None when it raises ``SystemExit``.
        Any other exception propagates to the caller.
    """

    if isinstance(inputs, list):
        inputs = "\n".join(inputs)

    # Shared by every patched ``readline`` call, so successive calls yield
    # successive lines; it raises StopIteration once the input is exhausted.
    inputs_line_iterator = iter(inputs.split("\n"))

    # sys.setrecursionlimit(10000)

    # While the wrapped function runs, ``open`` and ``sys.stdin`` (including
    # its read/readline/readlines) are replaced so that they serve ``inputs``.
    # NOTE(review): the stacking order of these patches may matter -- do not
    # reorder without checking unittest.mock semantics.
    # @patch('builtins.input', side_effect=inputs.split("\n"))
    @patch("builtins.open", mock_open(read_data=inputs))
    @patch("sys.stdin", StringIO(inputs))
    @patch("sys.stdin.readline", lambda *args: next(inputs_line_iterator))
    @patch("sys.stdin.readlines", lambda *args: inputs.split("\n"))
    @patch("sys.stdin.read", lambda *args: inputs)
    # @patch('sys.stdout.write', print)
    def _inner_call_method(_method):
        try:
            return _method()
        except SystemExit as e:
            # A solution calling sys.exit() is treated as finishing normally.
            pass
        finally:
            pass

    return _inner_call_method(method)
667
+
668
+
669
def reliability_guard(maximum_memory_bytes=None):
    """
    Disable destructive functions for the rest of the process so that the
    generated code cannot interfere with the test (e.g. fork bomb, killing
    other processes, removing filesystem files, etc.).

    If ``maximum_memory_bytes`` is given, it is also applied as the address
    space, data and (except on Darwin) stack resource limit.

    WARNING
    This function is NOT a security sandbox. Untrusted code, including, model-
    generated code, should not be blindly executed outside of one. See the
    Codex paper for more information about OpenAI's code sandbox, and proceed
    with caution.
    """

    if maximum_memory_bytes is not None:
        import resource

        memory_cap = (maximum_memory_bytes, maximum_memory_bytes)
        resource.setrlimit(resource.RLIMIT_AS, memory_cap)
        resource.setrlimit(resource.RLIMIT_DATA, memory_cap)
        if platform.uname().system != "Darwin":
            resource.setrlimit(resource.RLIMIT_STACK, memory_cap)

    faulthandler.disable()

    import builtins

    for blocked in ("exit", "quit"):
        setattr(builtins, blocked, None)

    import os

    os.environ["OMP_NUM_THREADS"] = "1"

    # "system" is blocked to prevent interference with the REPL evaluation.
    blocked_os_functions = (
        "kill", "system", "putenv", "remove", "removedirs", "rmdir",
        "fchdir", "setuid", "fork", "forkpty", "killpg", "rename",
        "renames", "truncate", "replace", "unlink", "fchmod", "fchown",
        "chmod", "chown", "chroot", "lchflags", "lchmod", "lchown",
        "getcwd", "chdir",
    )
    for blocked in blocked_os_functions:
        setattr(os, blocked, None)

    import shutil

    for blocked in ("rmtree", "move", "chown"):
        setattr(shutil, blocked, None)

    import subprocess

    setattr(subprocess, "Popen", None)

    __builtins__["help"] = None

    import sys

    # A None entry in sys.modules makes importing that module fail.
    for blocked_module in ("ipdb", "joblib", "resource", "psutil", "tkinter"):
        sys.modules[blocked_module] = None
dataset_construction/evaluation_utils/code_util/utils.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Borrowed from: https://huggingface.co/spaces/codeparrot/apps_metric/blob/main/utils.py
2
+
3
+
4
+ import multiprocessing
5
+ from typing import Dict, Optional
6
+ from datasets import load_dataset
7
+ from .testing_util import run_test
8
+ import traceback
9
+ import os,sys
10
+
11
def _temp_run(sample, generation, debug, result, metadata_list, timeout):
    """Run ``run_test`` silently and append its outcome to the shared lists.

    Intended to be the target of a child process: whatever happens, one entry
    is appended to ``result`` and one to ``metadata_list``.

    Args:
        sample: Test-case dict; its ``"inputs"`` entry is used to size the
            failure result.
        generation: Candidate solution passed to ``run_test`` as ``test``.
        debug: Forwarded to ``run_test``.
        result: Shared list receiving the per-test results.
        metadata_list: Shared list receiving the metadata dict.
        timeout: Forwarded to ``run_test``.
    """
    # Local import so this block does not depend on a new file-level import.
    import contextlib

    # The test is run silently. If it is killed, nothing will be printed.
    # redirect_stdout/redirect_stderr restore the real streams on exit, so
    # sys.stdout/sys.stderr are never left pointing at the closed devnull file.
    with open(os.devnull, "w") as devnull:
        with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
            try:
                res, metadata = run_test(
                    in_outs=sample, test=generation, debug=debug, timeout=timeout
                )
                result.append(res)
                metadata_list.append(metadata)
            except Exception:
                # Some tracebacks are extremely long, so cap the depth.
                traceback.print_exc(10)
                # Mark every test case as failed (-1).
                result.append([-1] * len(sample["inputs"]))
                metadata_list.append({})
25
+
26
def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True):
    """Check correctness of code generation with a global timeout.

    The global timeout is to catch some extreme/rare cases not handled by the
    timeouts inside ``run_test``.

    Args:
        in_outs: Test-case dict; its ``"inputs"`` entry sizes the all-failed
            fallback result.
        generation: Candidate solution to evaluate.
        timeout: Per-run timeout forwarded to the worker; the worker process
            itself is given ``timeout + 1`` seconds before being killed.
        debug: Forwarded to the worker; also enables the "global timeout"
            message.

    Returns:
        ``(results, metadata)`` where ``results`` is the first result list the
        worker produced (or all ``-1`` if it produced none) and ``metadata`` is
        a plain list of the metadata entries it appended.
    """
    # Using the manager as a context manager shuts its server process down
    # instead of leaving one alive per call.
    with multiprocessing.Manager() as manager:
        result = manager.list()
        metadata_list = manager.list()
        p = multiprocessing.Process(
            target=_temp_run,
            args=(in_outs, generation, debug, result, metadata_list, timeout),
        )
        p.start()
        p.join(timeout=timeout + 1)
        if p.is_alive():
            p.kill()
            p.join()  # reap the killed child so it does not linger as a zombie
        # Copy out of the proxies while the manager is still running.
        collected = list(result)
        metadata = list(metadata_list)

    if not collected:
        # Consider that all tests failed.
        collected = [[-1] * len(in_outs["inputs"])]
        if debug:
            print("global timeout")
    return collected[0], metadata
dataset_construction/evaluation_utils/math_util/__init__.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Answer checker API that uses sympy to simplify expressions and check for equality.
3
+
4
+ Call grade_answer(given_answer: str, ground_truth: str).
5
+
6
+ FROM: https://github.com/openai/prm800k/blob/main/prm800k/grading/grader.py
7
+ """
8
+ import re
9
+ import sympy
10
+ from pylatexenc import latex2text
11
+ from sympy.parsing import sympy_parser
12
+
13
+ from . import math_normalize
14
+ from .grader import math_equal
15
+ # import math_normalize
16
+ # from grader import math_equal
17
+
18
# sympy might hang -- we don't care about trying to be lenient in these cases
BAD_SUBSTRINGS = ["^{", "^("]
# Raw strings: "\^" is not a valid Python string escape and triggers a
# warning on recent interpreters; the regex text itself is unchanged.
BAD_REGEXES = [r"\^[0-9]+\^", r"\^[0-9][0-9]+"]
TUPLE_CHARS = "()[]"
22
+
23
+
24
def _sympy_parse(expr: str):
    """Parse ``expr`` with sympy, reading ``^`` as exponentiation.

    Implicit multiplication/application is enabled on top of sympy's standard
    transformations.
    """
    transformations = sympy_parser.standard_transformations + (
        sympy_parser.implicit_multiplication_application,
    )
    python_style = expr.replace("^", "**")
    return sympy_parser.parse_expr(python_style, transformations=transformations)
34
+
35
+
36
def _parse_latex(expr: str) -> str:
    """Convert LaTeX to plain text that the sympy-based parser can read."""
    for frac_variant in ("\\tfrac", "\\dfrac"):
        expr = expr.replace(frac_variant, "\\frac")
    # A leading space keeps mixed numbers such as ``7\frac{3}{4}`` separable.
    expr = expr.replace("\\frac", " \\frac")
    expr = latex2text.LatexNodes2Text().latex_to_text(expr)

    # Map the Unicode symbols produced by the converter to ASCII spellings.
    symbol_map = (
        ("√", "sqrt"),
        ("π", "pi"),
        ("∞", "inf"),
        ("∪", "U"),
        ("·", "*"),
        ("×", "*"),
    )
    for symbol, ascii_form in symbol_map:
        expr = expr.replace(symbol, ascii_form)

    return expr.strip()
52
+
53
+
54
+ def _is_float(num: str) -> bool:
55
+ try:
56
+ float(num)
57
+ return True
58
+ except ValueError:
59
+ return False
60
+
61
+
62
+ def _is_int(x: float) -> bool:
63
+ try:
64
+ return abs(x - int(round(x))) <= 1e-7
65
+ except:
66
+ return False
67
+
68
+
69
+ def _is_frac(expr: str) -> bool:
70
+ return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr))
71
+
72
+
73
def _str_is_int(x: str) -> bool:
    """Return True when the string ``x`` denotes a value within 1e-7 of an integer.

    Thousands separators are stripped first; anything that cannot be parsed
    as a finite number yields False.
    """
    try:
        x = _strip_properly_formatted_commas(x)
        x = float(x)
        return abs(x - int(round(x))) <= 1e-7
    except (TypeError, ValueError, OverflowError):
        # TypeError: non-string input; ValueError: unparsable text or NaN;
        # OverflowError: infinity cannot be converted to int.
        return False
80
+
81
+
82
+ def _str_to_int(x: str) -> bool:
83
+ x = x.replace(",", "")
84
+ x = float(x)
85
+ return int(x)
86
+
87
+
88
+ def _inject_implicit_mixed_number(step: str):
89
+ """
90
+ Automatically make a mixed number evalable
91
+ e.g. 7 3/4 => 7+3/4
92
+ """
93
+ p1 = re.compile("([0-9]) +([0-9])")
94
+ step = p1.sub("\\1+\\2", step) ## implicit mults
95
+ return step
96
+
97
+
98
+ def _strip_properly_formatted_commas(expr: str):
99
+ # We want to be careful because we don't want to strip tuple commas
100
+ p1 = re.compile("(\d)(,)(\d\d\d)($|\D)")
101
+ while True:
102
+ next_expr = p1.sub("\\1\\3\\4", expr)
103
+ if next_expr == expr:
104
+ break
105
+ expr = next_expr
106
+ return next_expr
107
+
108
+
109
def _normalize(expr: str) -> str:
    """Normalize answer expressions.

    Strips an enclosing ``\\text{}``, currency/percent signs, unit words and
    degree marks, rewrites number words, converts LaTeX to plain text when a
    backslash is present, lowercases the result and canonicalizes integers.

    Returns:
        The normalized string, or None when ``expr`` is None.
    """
    if expr is None:
        return None

    # Remove enclosing `\text{}`.
    m = re.search(r"^\\text\{(?P<text>.+?)\}$", expr)
    if m is not None:
        expr = m.group("text")

    expr = expr.replace("\\%", "%")
    expr = expr.replace("\\$", "$")
    expr = expr.replace("$", "")
    expr = expr.replace("%", "")
    expr = expr.replace(" or ", " , ")
    expr = expr.replace(" and ", " , ")

    expr = expr.replace("million", "*10^6")
    expr = expr.replace("billion", "*10^9")
    expr = expr.replace("trillion", "*10^12")

    # Drop unit words, their plural suffixes and an optional "^<digits>".
    for unit in [
        "degree",
        "cm",
        "centimeter",
        "meter",
        "mile",
        "second",
        "minute",
        "hour",
        "day",
        "week",
        "month",
        "year",
        "foot",
        "feet",
        "inch",
        "yard",
        "liter",
    ]:
        expr = re.sub(rf"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr)
    # Drop degree marks written as "^\circ".
    expr = re.sub(r"\^ *\\circ", "", expr)

    if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}":
        expr = expr[1:-1]

    # Drop LaTeX thin-space thousands separators (",\!").
    expr = re.sub(r",\\! *", "", expr)
    if _is_float(expr) and _is_int(float(expr)):
        expr = str(int(round(float(expr))))
    if "\\" in expr:
        try:
            expr = _parse_latex(expr)
        except Exception:
            # Best effort: keep the unconverted text if LaTeX parsing fails.
            pass

    # edge case with mixed numbers and negative signs
    expr = re.sub("- *", "-", expr)

    expr = _inject_implicit_mixed_number(expr)

    # don't be case sensitive for text answers
    expr = expr.lower()

    if _str_is_int(expr):
        expr = str(_str_to_int(expr))

    return expr
182
+
183
+
184
def count_unknown_letters_in_expr(expr: str):
    """Count the distinct alphabetic characters left after removing ``sqrt`` and ``frac``."""
    remainder = expr.replace("sqrt", "").replace("frac", "")
    return len({ch for ch in remainder if ch.isalpha()})
189
+
190
+
191
def should_allow_eval(expr: str):
    """Decide whether ``expr`` is safe enough to hand to the sympy parser."""
    # we don't want to try parsing unknown text or functions of more than two variables
    if count_unknown_letters_in_expr(expr) > 2:
        return False

    if any(fragment in expr for fragment in BAD_SUBSTRINGS):
        return False

    return not any(re.search(pattern, expr) is not None for pattern in BAD_REGEXES)
205
+
206
+
207
def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
    """Return True when sympy simplifies the difference of the two expressions to 0.

    Expressions rejected by ``should_allow_eval`` and any parsing or
    simplification failure count as "not equal".
    """
    are_equal = False
    try:
        expr = f"({ground_truth_normalized})-({given_normalized})"
        if should_allow_eval(expr):
            sympy_diff = _sympy_parse(expr)
            simplified = sympy.simplify(sympy_diff)
            if simplified == 0:
                are_equal = True
    except Exception:
        # Best effort: unparsable input is simply not equal. Unlike a bare
        # `except:`, this lets KeyboardInterrupt/SystemExit propagate.
        pass
    return are_equal
219
+
220
+
221
def split_tuple(expr: str):
    """
    Split a tuple/interval into its elements, after removing well-formatted
    thousands commas from large numbers. Anything that is not a single
    bracketed group is returned as a one-element list.
    """
    expr = _strip_properly_formatted_commas(expr)
    if not expr:
        return []

    inner = expr[1:-1]
    is_bracketed_group = (
        len(expr) > 2
        and expr[0] in TUPLE_CHARS
        and expr[-1] in TUPLE_CHARS
        and not any(ch in inner for ch in TUPLE_CHARS)
    )
    if not is_bracketed_group:
        return [expr]
    return [piece.strip() for piece in inner.split(",")]
238
+
239
+
240
def grade_answer(given_answer: str, ground_truth: str) -> bool:
    """
    Grade ``given_answer`` against ``ground_truth``. The answer is correct if:
    (a) it normalizes to the same string as the ground truth answer
    OR
    (b) sympy can simplify the difference between the expressions to 0
    """
    if given_answer is None:
        return False

    # be at least as lenient as mathd
    if math_normalize.normalize_answer(ground_truth) == math_normalize.normalize_answer(given_answer):
        return True

    ground_truth_normalized = _normalize(ground_truth)
    given_normalized = _normalize(given_answer)

    if ground_truth_normalized is None:
        return False
    if ground_truth_normalized == given_normalized:
        return True
    if len(given_normalized) == 0:
        return False

    ground_truth_elems = split_tuple(ground_truth_normalized)
    given_elems = split_tuple(given_normalized)

    # For tuples/intervals the enclosing brackets must agree.
    brackets_differ = (
        ground_truth_normalized[0] != given_normalized[0]
        or ground_truth_normalized[-1] != given_normalized[-1]
    ) if len(ground_truth_elems) > 1 else False
    if brackets_differ:
        return False
    if len(ground_truth_elems) != len(given_elems):
        return False

    for expected, actual in zip(ground_truth_elems, given_elems):
        if _is_frac(expected) and _is_frac(actual):
            # if fractions aren't reduced, then shouldn't be marked as correct
            # so, we don't want to allow sympy.simplify in this case
            elem_ok = expected == actual
        elif _str_is_int(expected) != _str_is_int(actual):
            # if the ground truth answer is an integer, we require the given
            # answer to be a strict match (no sympy.simplify)
            elem_ok = False
        else:
            elem_ok = are_equal_under_sympy(expected, actual)
        if not elem_ok:
            return False

    return True
294
+
295
def remove_boxed(s):
    """Return the content of a string of the exact form ``\\boxed{...}``.

    Returns None when ``s`` is not a string or is not wrapped that way.
    """
    left = "\\boxed{"
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not isinstance(s, str):
        return None
    if not (s.startswith(left) and s.endswith("}")):
        return None
    return s[len(left):-1]
303
+
304
+ def _last_boxed_only_string(string):
305
+ idx = string.rfind("\\boxed")
306
+ if idx < 0:
307
+ idx = string.rfind("\\fbox")
308
+ if idx < 0:
309
+ return None
310
+
311
+ i = idx
312
+ left_brace_idx = None
313
+ right_brace_idx = None
314
+ num_left_braces_open = 0
315
+ while i < len(string):
316
+ if string[i] == "{":
317
+ num_left_braces_open += 1
318
+ if left_brace_idx is None:
319
+ left_brace_idx = i
320
+ elif string[i] == "}":
321
+ num_left_braces_open -= 1
322
+ if num_left_braces_open == 0:
323
+ right_brace_idx = i
324
+ break
325
+
326
+ i += 1
327
+
328
+ if left_brace_idx is None or right_brace_idx is None:
329
+ return None
330
+
331
+ return string[left_brace_idx + 1: right_brace_idx].strip()
332
+
333
def match_answer(response):
    """Extract the final answer text from a model response.

    The response is narrowed step by step: text after "answer:"-style
    markers, text before "is the answer"-style markers, the last boxed
    expression, text before the last ". ", and finally text after markers
    such as "is " or "=".

    Returns:
        ``(is_matched, answer)``: ``is_matched`` is True only when some marker
        or box was found and the extracted text contains a digit.
    """
    is_matched = False
    # Every slice below is followed by .strip(), so the result can never end
    # with a newline; the former `endswith("\n")` trimming was unreachable.
    for ans_marker in ["answer:", "answer is", "answers are"]:
        ans_idx = response.lower().rfind(ans_marker)
        if ans_idx != -1:
            is_matched = True
            response = response[ans_idx + len(ans_marker):].strip()

    for ans_marker in ["is answer", "is the answer", "are answers", "are the answers"]:
        ans_idx = response.lower().rfind(ans_marker)
        if ans_idx != -1:
            is_matched = True
            response = response[:ans_idx].strip()

    # Find boxed
    ans_boxed = _last_boxed_only_string(response)
    if ans_boxed:
        is_matched = True
        response = ans_boxed

    # Cut at the last sentence break.
    if ". " in response:
        dot_idx = response.lower().rfind(". ")
        if dot_idx != -1:
            response = response[:dot_idx].strip()

    for ans_marker in ["be ", "is ", "are ", "=", ": ", "get ", "be\n", "is\n", "are\n", ":\n", "get\n"]:
        ans_idx = response.lower().rfind(ans_marker)
        if ans_idx != -1:
            is_matched = True
            response = response[ans_idx + len(ans_marker):].strip()

    # answer must have a digit
    is_matched = is_matched and any(c.isdigit() for c in response)
    return is_matched, response
373
+
374
+
375
+ import math
376
def evaluate_math(model_output: str, ground_truth: str) -> tuple:
    """Grade a model response against the ground-truth answer.

    Returns:
        ``(is_correct, format_correctness, extracted_model_output)``.
        ``format_correctness`` is True when the response contains both
        "Step 2:" and "\\box"; when the simple grader already accepts the
        answer, it is reported as True unconditionally.
    """
    model_output = str(model_output)
    ground_truth = str(ground_truth)

    _, extracted_model_output = match_answer(model_output)
    format_correctness = "Step 2:" in model_output and "\\box" in model_output

    # grade simple algebra questions. if succeed, return; otherwise, proceed to more complex grading
    if grade_answer(extracted_model_output, ground_truth):
        return True, True, extracted_model_output

    # NOTE(review): `timeout=True` is passed where math_equal declares a float;
    # as a number True equals 1 -- confirm a 1-second limit is intended.
    try:
        if "\\pi" in extracted_model_output or "\\pi" in ground_truth:
            # Accept a match under either the exact value of pi or 3.14.
            equivs = [
                math_equal(extracted_model_output, ground_truth, timeout=True, pi=pi)
                for pi in [math.pi, 3.14]
            ]
            is_correct = any(equivs)
        else:
            is_correct = math_equal(extracted_model_output, ground_truth, timeout=True)
    except Exception:
        # Any grading failure counts as an incorrect answer.
        is_correct = False

    return is_correct, format_correctness, extracted_model_output
409
+
410
+
dataset_construction/evaluation_utils/math_util/grader.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Copyright (c) Microsoft Corporation.
16
+ #
17
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
18
+ # of this software and associated documentation files (the "Software"), to deal
19
+ # in the Software without restriction, including without limitation the rights
20
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
21
+ # copies of the Software, and to permit persons to whom the Software is
22
+ # furnished to do so, subject to the following conditions:
23
+ #
24
+ # The above copyright notice and this permission notice shall be included in all
25
+ # copies or substantial portions of the Software.
26
+ #
27
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
28
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
29
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
30
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
31
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
32
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
33
+ # SOFTWARE
34
+
35
+ # Copyright (c) 2023 OpenAI
36
+ #
37
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
38
+ # of this software and associated documentation files (the "Software"), to deal
39
+ # in the Software without restriction, including without limitation the rights
40
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
41
+ # copies of the Software, and to permit persons to whom the Software is
42
+ # furnished to do so, subject to the following conditions:
43
+
44
+ # The above copyright notice and this permission notice shall be included in all
45
+ # copies or substantial portions of the Software.
46
+ #
47
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
48
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
49
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
50
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
51
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
52
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
53
+ # SOFTWARE.
54
+
55
+ # Copyright (c) 2021 Dan Hendrycks
56
+ #
57
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
58
+ # of this software and associated documentation files (the "Software"), to deal
59
+ # in the Software without restriction, including without limitation the rights
60
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
61
+ # copies of the Software, and to permit persons to whom the Software is
62
+ # furnished to do so, subject to the following conditions:
63
+ #
64
+ # The above copyright notice and this permission notice shall be included in all
65
+ # copies or substantial portions of the Software.
66
+ #
67
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
68
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
69
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
70
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
71
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
72
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
73
+ # SOFTWARE.
74
+
75
+
76
+ """
77
+ This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
78
+ - https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
79
+ - https://github.com/microsoft/ProphetNet/tree/master/CRITIC
80
+ - https://github.com/openai/prm800k
81
+ """
82
+
83
+
84
+ import contextlib
85
+ import re
86
+ import signal
87
+ import math
88
+ from math import isclose
89
+ from typing import Union
90
+
91
+ from sympy import N, simplify
92
+ from sympy.parsing.latex import parse_latex
93
+ from sympy.parsing.sympy_parser import parse_expr
94
+
95
+
96
def is_digit(s):
    """Try to read ``s`` as a number, ignoring thousands separators.

    Returns ``(True, value)`` on success and ``(False, None)`` otherwise.
    LaTeX-style ``{,}`` separators are removed when present; otherwise plain
    commas are removed.
    """
    text = str(s)
    separator = "{,}" if "{,}" in text else ","
    try:
        return True, float(text.replace(separator, ""))
    except ValueError:
        return False, None
106
+
107
+
108
def normalize(answer, pi) -> str:
    """Strip simple currency/percent decoration, otherwise resolve base suffixes and pi."""
    if isinstance(answer, str):
        # "$<number>": drop the dollar sign so plain numbers compare equal.
        if re.match(r'\$\d+(\.\d+)?', answer):
            return answer[1:]

        # "<number>%" or "<number>\%": drop the percent sign.
        is_percent = re.match(r'^\d+(\.\d+)?%$', answer) or re.match(r'^\d+(\.\d+)?\\%$', answer)
        if is_percent:
            return answer.replace("\\%", "").replace("%", "")

    # handle base, then pi
    return handle_pi(handle_base(answer), pi)
126
+
127
def handle_base(x):
    """Drop a base suffix such as ``_2`` from a numeric string.

    For a string containing ``_``, the text before the first underscore is
    parsed as a number and returned as an int. Strings whose prefix is not
    numeric (e.g. ``"a_n"``) and non-string inputs are returned unchanged
    instead of raising.
    """
    if isinstance(x, str) and "_" in x:
        # Due to base
        prefix = x.split("_")[0]
        try:
            return int(float(prefix))
        except (ValueError, OverflowError):
            # Not a base-annotated number (e.g. a subscripted variable).
            return x
    return x
134
+
135
+
136
def handle_pi(string, pi):
    """Substitute a numeric value for ``\\pi`` and evaluate the result.

    Each ``\\pi`` becomes ``*<pi>`` when preceded by a digit and ``1*<pi>``
    otherwise. The rewritten text is then evaluated; if evaluation fails the
    rewritten text is returned. Inputs that are not strings or contain no
    ``\\pi`` are returned unchanged.
    """
    if isinstance(string, str) and "\\pi" in string:
        # Find the first occurrence of "\pi"
        idx = string.find("\\pi")

        # Replace every occurrence, choosing the form by the previous character
        while idx != -1:
            if idx > 0 and string[idx - 1].isdigit():
                # "2\pi" -> "2*<pi>"
                string = string[:idx] + f"*{pi}" + string[idx + 3:]
            else:
                # "\pi" -> "1*<pi>"
                string = string[:idx] + f"1*{pi}" + string[idx + 3:]

            # Find the next occurrence of "\pi"
            idx = string.find("\\pi", idx + 1)

        # SECURITY: eval() runs arbitrary Python and `string` may be
        # model-generated text. Kept for compatibility; only call this on
        # trusted input or inside a sandbox.
        try:
            string = eval(string)
        except Exception:
            # Not a valid Python expression: keep the rewritten text.
            pass

    return string
162
+
163
def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    tolerance: float = 1e-4,
    timeout: float = 10.0,
    pi: float = math.pi,
) -> bool:
    """
    Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal

    Args:
        prediction: Candidate answer.
        reference: Ground-truth answer.
        include_percentage: Also accept the reference scaled by 1/100 or 100.
        tolerance: Relative tolerance for numeric comparison.
        timeout: Seconds allowed for each symbolic step; forwarded to
            recursive element-wise comparisons.
        pi: Numeric value substituted for ``\\pi``; forwarded to recursive
            element-wise comparisons.
    """
    # Local import so this block does not depend on a new file-level import.
    import ast

    def _all_equal(pred_items, ref_items):
        # Element-wise comparison using the same settings as this call.
        return all(
            [
                math_equal(pred_item, ref_item, include_percentage, tolerance, timeout, pi)
                for pred_item, ref_item in zip(pred_items, ref_items)
            ]
        )

    prediction = normalize(prediction, pi)
    reference = normalize(reference, pi)

    if isinstance(prediction, str) and len(prediction) > 1000:  # handling weird corner-cases
        prediction = prediction[:1000]

    # 0. string comparison
    if isinstance(prediction, str) and isinstance(reference, str):
        if prediction.strip().lower() == reference.strip().lower():
            return True
        if prediction.replace(" ", "") == reference.replace(" ", ""):
            return True

    try:  # 1. numerical equal
        if is_digit(prediction)[0] and is_digit(reference)[0]:
            prediction = is_digit(prediction)[1]
            reference = is_digit(reference)[1]
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if isclose(item, prediction, rel_tol=tolerance):
                        return True
                except Exception:
                    continue
            return False
    except Exception:
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    ## deal with [], (), {}
    prediction = format_intervals(prediction)

    pred_str, ref_str = prediction, reference
    if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or (
        prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")
    ):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ["{", "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str == ref_str:
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    if (
        prediction
        and reference
        and prediction[0] in "(["
        and prediction[-1] in ")]"
        and prediction[0] == reference[0]
        and prediction[-1] == reference[-1]
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts) and _all_equal(pred_parts, ref_parts):
            return True

    if "," in prediction and "," in reference:
        pred_parts = [item.strip() for item in prediction.split(",")]
        ref_parts = [item.strip() for item in reference.split(",")]

        # Comma lists of equal length are decided here, element by element.
        if len(pred_parts) == len(ref_parts):
            return _all_equal(pred_parts, ref_parts)

    # if we have point == tuple of values
    # (startswith/endswith instead of indexing: an empty reference must not raise)
    if prediction.startswith("Point") and reference.startswith("(") and reference.endswith(")"):
        pred_parts = prediction[prediction.find("(") + 1 : -1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts) and _all_equal(pred_parts, ref_parts):
            return True

    # if reference is a matrix
    # NOTE: the marker must be spelled "\\begin"; a bare "\b" is a backspace
    # character and never matches real LaTeX.
    begin_marker = "\\begin{pmatrix}"
    end_marker = "\\end{pmatrix}"
    if begin_marker in reference and prediction.startswith("Matrix"):
        try:
            pred_matrix = parse_expr(prediction)
            # Whitespace-separated tokens between the markers, skipping the
            # separators ("&", "\\") that sit at every other position.
            ref_matrix_items = reference.split()[1:-1:2]
            if len(pred_matrix) == len(ref_matrix_items) and _all_equal(pred_matrix, ref_matrix_items):
                return True
        except Exception:
            pass
    elif begin_marker in reference and prediction.startswith("[") and prediction.endswith("]"):
        try:
            # literal_eval only accepts Python literals, so model-generated
            # text cannot execute code here.
            pred_matrix = ast.literal_eval(prediction)
            if isinstance(pred_matrix, list):
                body_start = reference.find(begin_marker) + len(begin_marker)
                body_end = reference.rfind(end_marker)
                body = reference[body_start:body_end] if body_end >= body_start else reference[body_start:]
                # LaTeX rows are separated by a double backslash, cells by "&".
                rows = [row.strip() for row in body.split("\\\\")]
                ref_matrix_items = [
                    "[" + ", ".join(cell.strip() for cell in row.split("&")) + "]" if "&" in row else row
                    for row in rows
                    if row
                ]
                if len(pred_matrix) == len(ref_matrix_items) and _all_equal(pred_matrix, ref_matrix_items):
                    return True
        except Exception:
            pass

    return symbolic_equal(prediction, reference, tolerance, timeout)
314
+
315
+
316
def symbolic_equal(a, b, tolerance, timeout=10.0):
    """Compare two expressions symbolically, then numerically.

    Each operand is parsed with ``parse_expr`` and, failing that,
    ``parse_latex``; an operand that neither parser accepts is used as-is.
    Every step runs under ``time_limit(timeout)`` and any failure in a step
    is ignored.
    """

    def _parse(text):
        for parser in (parse_expr, parse_latex):
            with contextlib.suppress(Exception), time_limit(timeout):
                return parser(text)
        return text

    lhs = _parse(a)
    rhs = _parse(b)

    # Exact check: the difference simplifies to zero.
    with contextlib.suppress(Exception), time_limit(timeout):
        if simplify(lhs - rhs) == 0:
            return True

    # Numeric check: the evaluated values agree within the tolerance.
    with contextlib.suppress(Exception), time_limit(timeout):
        if isclose(N(lhs), N(rhs), rel_tol=tolerance):
            return True

    return False
343
+
344
+
345
class TimeoutException(Exception):
    """Raised inside a ``time_limit`` block when the timer expires."""


@contextlib.contextmanager
def time_limit(seconds: float):
    """Raise ``TimeoutException`` if the ``with`` body runs longer than ``seconds``.

    Uses ``SIGALRM`` with the real-time interval timer. On exit the timer is
    cancelled and the SIGALRM handler that was installed before entry is put
    back.
    """

    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    # Install the handler BEFORE arming the timer; otherwise a very short
    # timeout could fire while the default action is still in place.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    try:
        signal.setitimer(signal.ITIMER_REAL, seconds)
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)
        # getsignal/signal report None for handlers not installed from
        # Python; those cannot be re-installed, so leave ours in place.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
360
+
361
+
362
def format_intervals(prediction):
    """Rewrite a sympy-style ``Interval...(...)`` string in bracket notation.

    Strings that match none of the interval forms are returned unchanged.
    """
    # (pattern, opening bracket, closing bracket), checked in this order.
    interval_forms = (
        (r"^Interval\((.*)\)$", "[", "]"),  # Interval(a, b) == [a, b]
        (r"^Interval\.Ropen\((.*)\)$", "[", ")"),  # Interval.Ropen(a, b) == [a, b)
        (r"^Interval\.Lopen\((.*)\)$", "(", "]"),  # Interval.Lopen(a, b) == (a, b]
        (r"^Interval\.open\((.*)\)$", "(", ")"),  # Interval.open(a, b) == (a, b)
    )

    for pattern, opening, closing in interval_forms:
        found = re.match(pattern, prediction)
        if found:
            return f"{opening}{found.group(1)}{closing}"

    return prediction
dataset_construction/evaluation_utils/math_util/math_normalize.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This logic is largely copied from the Hendrycks' MATH release (math_equivalence).
3
+
4
+ From: https://github.com/openai/prm800k/blob/main/prm800k/grading/math_normalize.py
5
+ """
6
+ import re
7
+ from typing import Optional
8
+
9
+
10
def normalize_answer(answer: Optional[str]) -> Optional[str]:
    """Normalize an answer string for comparison.

    Leading/trailing whitespace and an enclosing ``\\text{}`` are removed,
    then ``_strip_string`` is applied. If normalization fails, the answer as
    it stood at the point of failure is returned. None is passed through.
    """
    if answer is None:
        return None
    answer = answer.strip()
    try:
        # Remove enclosing `\text{}`.
        m = re.search(r"^\\text\{(?P<text>.+?)\}$", answer)
        if m is not None:
            answer = m.group("text").strip()
        return _strip_string(answer)
    except Exception:
        # Best effort: fall back to the un-normalized text.
        return answer
22
+
23
+
24
+ def _fix_fracs(string):
25
+ substrs = string.split("\\frac")
26
+ new_str = substrs[0]
27
+ if len(substrs) > 1:
28
+ substrs = substrs[1:]
29
+ for substr in substrs:
30
+ new_str += "\\frac"
31
+ if substr[0] == "{":
32
+ new_str += substr
33
+ else:
34
+ try:
35
+ assert len(substr) >= 2
36
+ except:
37
+ return string
38
+ a = substr[0]
39
+ b = substr[1]
40
+ if b != "{":
41
+ if len(substr) > 2:
42
+ post_substr = substr[2:]
43
+ new_str += "{" + a + "}{" + b + "}" + post_substr
44
+ else:
45
+ new_str += "{" + a + "}{" + b + "}"
46
+ else:
47
+ if len(substr) > 2:
48
+ post_substr = substr[2:]
49
+ new_str += "{" + a + "}" + b + post_substr
50
+ else:
51
+ new_str += "{" + a + "}" + b
52
+ string = new_str
53
+ return string
54
+
55
+
56
+ def _fix_a_slash_b(string):
57
+ if len(string.split("/")) != 2:
58
+ return string
59
+ a = string.split("/")[0]
60
+ b = string.split("/")[1]
61
+ try:
62
+ a = int(a)
63
+ b = int(b)
64
+ assert string == "{}/{}".format(a, b)
65
+ new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
66
+ return new_string
67
+ except:
68
+ return string
69
+
70
+
71
+ def _remove_right_units(string):
72
+ # "\\text{ " only ever occurs (at least in the val set) when describing units
73
+ if "\\text{ " in string:
74
+ splits = string.split("\\text{ ")
75
+ assert len(splits) == 2
76
+ return splits[0]
77
+ else:
78
+ return string
79
+
80
+
81
+ def _fix_sqrt(string):
82
+ if "\\sqrt" not in string:
83
+ return string
84
+ splits = string.split("\\sqrt")
85
+ new_string = splits[0]
86
+ for split in splits[1:]:
87
+ if split[0] != "{":
88
+ a = split[0]
89
+ new_substr = "\\sqrt{" + a + "}" + split[1:]
90
+ else:
91
+ new_substr = "\\sqrt" + split
92
+ new_string += new_substr
93
+ return new_string
94
+
95
+
96
def _strip_string(string):
    """Canonicalize a LaTeX answer string for exact string comparison.

    Removes line breaks, spacing commands, ``\\left``/``\\right``, degree
    marks, dollar and percent signs and right-hand units; fixes leading
    decimal points; drops a short ``x =`` prefix; and braces ``\\sqrt`` and
    ``\\frac`` arguments.
    """
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage (a second replace of "\%" denoted the same two
    # characters via an invalid escape, so one call suffices)
    string = string.replace("\\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string
dataset_construction/run_evaluate_on_slurm.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH --job-name=evaluate_on_eurus_2
3
+ #SBATCH --time=4-00:00:00
4
+ #SBATCH --gpus-per-node=2
5
+ #SBATCH --mem=128G
6
+ #SBATCH --account=lusu-k
7
+
8
+ conda activate domain_rlhf
9
+
10
+ python3 evaluate_on_eurus_2.py --tensor_parallel_size 2