Upload 10 files
Browse files- README.md +1 -0
- dataset_construction/README.md +3 -0
- dataset_construction/evaluate_on_eurus_2.py +133 -0
- dataset_construction/evaluation_utils/code_util/__init__.py +69 -0
- dataset_construction/evaluation_utils/code_util/testing_util.py +752 -0
- dataset_construction/evaluation_utils/code_util/utils.py +45 -0
- dataset_construction/evaluation_utils/math_util/__init__.py +410 -0
- dataset_construction/evaluation_utils/math_util/grader.py +384 -0
- dataset_construction/evaluation_utils/math_util/math_normalize.py +163 -0
- dataset_construction/run_evaluate_on_slurm.sh +10 -0
README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# domain-rlhf
|
dataset_construction/README.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Code to evaluate a model on the `PRIME-RL/Eurus-2-RL-Data` dataset and assign a difficulty based on the result. Output is written as JSON.
|
| 2 |
+
|
| 3 |
+
Evaluation metric: Pass@k, 1/Pass@k
|
dataset_construction/evaluate_on_eurus_2.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from datasets import load_dataset
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from vllm import LLM, SamplingParams
|
| 5 |
+
from transformers import AutoTokenizer
|
| 6 |
+
|
| 7 |
+
from evaluation_utils.code_util import evaluate_code
|
| 8 |
+
from evaluation_utils.math_util import evaluate_math
|
| 9 |
+
import argparse
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
def read_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, so existing
            ``read_args()`` callers are unaffected.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-7B-Instruct", help="The model to use for evaluation")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size for evaluation")
    parser.add_argument("--k", type=int, default=10, help="Number of completions to generate for each prompt")
    parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Number of GPUs to use for tensor-parallel generation")
    parser.add_argument("--temperature", type=float, default=0.8, help="Temperature for sampling")
    parser.add_argument("--top_p", type=float, default=0.95, help="Top-p for sampling")
    parser.add_argument("--num_samples", type=int, default=-1, help="Number of samples to evaluate")
    parser.add_argument("--output_file", type=str, default="evaluation_results.json", help="File to save evaluation results")
    # NOTE: The dataset asks for CoT output. This is very token-consuming. If you limit it to a small number, the model will not be able to generate the answer.
    # The default is 4096, the generation budget this script has been running with.
    parser.add_argument("--max_tokens", type=int, default=4096, help="Maximum number of tokens to generate.")
    parser.add_argument("--log_per_step", type=int, default=1000, help="Log results per step")
    return parser.parse_args(argv)
|
| 26 |
+
|
| 27 |
+
def generate_batch_completions(llm, prompts, sampling_params):
    """Generate completions for a batch of prompts.

    Calls ``llm.generate(prompts, sampling_params)`` once and, for every
    returned request output, collects the ``text`` of each entry in its
    ``outputs`` sequence.

    Returns:
        A list with one inner list of completion strings per request output,
        in the order ``llm.generate`` returned them.
    """
    return [
        [candidate.text for candidate in request_output.outputs]
        for request_output in llm.generate(prompts, sampling_params)
    ]
|
| 34 |
+
|
| 35 |
+
def evaluate(k_completions, ability, ground_truth):
    """Score every sampled completion as 1 (correct) or 0 (incorrect).

    Args:
        k_completions: Iterable of completion strings for one prompt.
        ability: ``"math"`` routes to ``evaluate_math``; ``"coding"`` routes
            to ``evaluate_code``.
        ground_truth: Reference answer / test cases handed to the evaluator.

    Returns:
        A list of ints (1 or 0), one per completion that was scored.
    """
    scores = []
    for candidate in k_completions:
        if ability == "math":
            passed, _, _ = evaluate_math(candidate, ground_truth)
        elif ability == "coding":
            passed, _ = evaluate_code(candidate, ground_truth)
        else:
            # NOTE(review): any other ability value contributes no entry, so
            # the returned list can be shorter than k_completions (empty when
            # the ability is unknown) - confirm this is intended.
            continue
        scores.append(1 if passed else 0)
    return scores
|
| 51 |
+
|
| 52 |
+
def _prefixed_path(prefix, path):
    """Return *path* with *prefix* prepended to its file name, keeping its directory."""
    import os  # local import so this block does not depend on file-level imports

    directory, filename = os.path.split(path)
    return os.path.join(directory, prefix + filename)


def _dump_json(obj, path):
    """Write *obj* to *path* as indented UTF-8 JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)


def main():
    """Sample k completions per prompt of Eurus-2-RL-Data and record their correctness.

    For the "train" and "validation" splits, prompts are rendered with the
    model's chat template, ``args.k`` completions are sampled per prompt, and
    each completion is scored with ``evaluate``. Two JSON files are written:
    ``args.output_file`` (samples + correctness) and a ``complete_``-prefixed
    sibling that also contains the raw completions. Intermediate snapshots of
    both are written roughly every ``args.log_per_step`` samples.

    ``--num_samples`` (> 0) caps the total number of evaluated samples across
    splits; once the cap is reached, later splits are not visited.
    """
    args = read_args()
    ds = load_dataset("PRIME-RL/Eurus-2-RL-Data")

    llm = LLM(model=args.model, tensor_parallel_size=args.tensor_parallel_size)
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        n=args.k,
        # Honour --max_tokens instead of a hard-coded budget.
        max_tokens=args.max_tokens,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model)

    batch_size = args.batch_size
    limit = args.num_samples if args.num_samples > 0 else None

    results = {}
    complete_results = {}
    num = 0  # samples evaluated so far, across all splits
    checkpoints_written = 0  # value of num // log_per_step at the last snapshot

    for split in ["train", "validation"]:
        results[split] = []
        complete_results[split] = []
        print(f"Running evaluation on {split} split")

        # Never index past the split, and never generate for samples that the
        # global cap would discard afterwards.
        split_total = len(ds[split])
        if limit is not None:
            split_total = min(split_total, limit - num)

        for start in tqdm(range(0, split_total, batch_size)):
            stop = min(start + batch_size, split_total)
            batch_samples = ds[split][start:stop]

            # add_generation_prompt=True appends the assistant-turn header so
            # the model answers the conversation rather than continuing it.
            prompts = [
                tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
                for messages in batch_samples["prompt"]
            ]

            batch_outputs = generate_batch_completions(llm, prompts, sampling_params)

            for data_source, prompt, ability, reward_model, extra_info, completions in zip(
                batch_samples["data_source"],
                batch_samples["prompt"],
                batch_samples["ability"],
                batch_samples["reward_model"],
                batch_samples["extra_info"],
                batch_outputs,
            ):
                sample = {
                    "data_source": data_source,
                    "prompt": prompt,
                    "ability": ability,
                    "reward_model": reward_model,
                    "extra_info": extra_info,
                }
                ground_truth = (reward_model or {}).get("ground_truth", "")
                correctness = evaluate(completions, ability, ground_truth)

                sample_result = {**sample, "correctness": correctness}
                results[split].append(sample_result)
                complete_results[split].append({**sample_result, "completions": completions})
                num += 1

            # Snapshot whenever num crosses a multiple of log_per_step; an
            # exact-multiple test would miss it when batch_size does not
            # divide log_per_step, and would divide by zero for a step of 0.
            if args.log_per_step > 0 and num // args.log_per_step > checkpoints_written:
                checkpoints_written = num // args.log_per_step
                _dump_json(results, _prefixed_path(f"step{num}_", args.output_file))
                # The "complete" snapshot prefix has no trailing underscore;
                # kept as-is so existing file names do not change.
                _dump_json(complete_results, _prefixed_path(f"step{num}_complete", args.output_file))

        if limit is not None and num >= limit:
            break

    _dump_json(results, args.output_file)
    _dump_json(complete_results, _prefixed_path("complete_", args.output_file))


if __name__ == "__main__":
    main()
|
dataset_construction/evaluation_utils/code_util/__init__.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .utils import check_correctness as apps_check_correctness
|
| 2 |
+
import json
|
| 3 |
+
import re
|
| 4 |
+
import traceback
|
| 5 |
+
|
| 6 |
+
def evaluate_code(completion, test_cases):
    """Run the code found in *completion* against *test_cases*.

    Args:
        completion: Model output. The text between the last ```` ```python ````
            marker and the following ```` ``` ```` is taken as the solution.
        test_cases: A dict, or a JSON string decoding to one, with "inputs"
            and "outputs" lists and optionally "fn_name".

    Returns:
        ``(success, metadata)``. When the full test suite passes in a single
        run, ``metadata`` is that run's first metadata entry. Otherwise up to
        the first ten test cases are run one by one and ``metadata`` is a list
        with one dict per executed case. On an unexpected error, or when the
        test cases cannot be parsed, the result is ``(False, None)``.
    """
    # Try to get the code solution from the completion. If the completion
    # contains no fence at all, the whole text is used.
    solution = completion.split('```python')[-1].split('```')[0]

    if not isinstance(test_cases, dict):
        try:
            test_cases = json.loads(test_cases)
        except Exception as e:
            # Without usable test cases nothing can be executed.
            print(f"Error:{e}")
            return False, None

    try:
        # First check the whole suite at once; only if that does not fully
        # pass, fall back to checking the test cases one by one.
        try:
            res, metadata = apps_check_correctness(
                in_outs=test_cases,
                generation=solution,
                timeout=5,
                debug=False
            )
            metadata = dict(enumerate(metadata))[0]
            # Compare against True explicitly: failure codes such as -1/-2
            # are truthy and must not count as a pass.
            if all(x == True for x in res):
                return True, metadata
        except Exception:
            # Best effort: any failure of the bulk run is handled by the
            # per-case pass below.
            pass

        inputs = test_cases["inputs"]
        outputs = test_cases["outputs"]
        # Carry fn_name into every single-case dict; without it call-based
        # problems would be executed as stdin programs.
        fn_name = test_cases.get("fn_name")

        metadata_list = []
        res_list = []
        max_cases = 10  # only the first ten cases are checked individually
        for i in range(min(len(inputs), max_cases)):
            test_case = {
                "inputs": [inputs[i]],
                "outputs": [outputs[i]]
            }
            if fn_name is not None:
                test_case["fn_name"] = fn_name

            res, metadata = apps_check_correctness(
                in_outs=test_case,
                generation=solution,
                timeout=5,
                debug=False
            )
            try:
                metadata = dict(enumerate(metadata))[0]  # metadata may be empty when execution fails
            except Exception:
                metadata = {}
            metadata["test_case"] = {}
            metadata["test_case"]["input"] = str(test_case["inputs"][0])
            metadata["test_case"]["output"] = str(test_case["outputs"][0])
            metadata["test_case"]["res"] = str(res)
            metadata_list.append(metadata)
            res_list.extend(res)

        success = all(x == True for x in res_list)
    except Exception:
        traceback.print_exc(10)
        success = False
        metadata_list = None
    return success, metadata_list
|
| 69 |
+
|
dataset_construction/evaluation_utils/code_util/testing_util.py
ADDED
|
@@ -0,0 +1,752 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ast
|
| 2 |
+
import json
|
| 3 |
+
import sys
|
| 4 |
+
import faulthandler
|
| 5 |
+
import platform
|
| 6 |
+
|
| 7 |
+
# used for debugging to time steps
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
|
| 10 |
+
# to run the solution files we're using a timing based approach
|
| 11 |
+
import signal
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
# for capturing the stdout
|
| 16 |
+
from io import StringIO
|
| 17 |
+
|
| 18 |
+
# used for testing the code that reads from input
|
| 19 |
+
from unittest.mock import patch, mock_open
|
| 20 |
+
|
| 21 |
+
from pyext import RuntimeModule
|
| 22 |
+
|
| 23 |
+
from enum import Enum
|
| 24 |
+
|
| 25 |
+
import traceback
|
| 26 |
+
|
| 27 |
+
def truncatefn(s, length=300):
    """Shorten *s* to about *length* characters by cutting out its middle.

    Strings no longer than *length* are returned unchanged. Longer strings
    keep their first ``length // 2`` characters and their last
    ``length - length // 2`` characters, joined by ``"...(truncated) ..."``.

    Args:
        s: The string to shorten (must be a ``str``).
        length: Number of original characters to keep; values below zero are
            treated as zero.

    Returns:
        The original or shortened string.
    """
    assert isinstance(s, str)
    if len(s) <= length:
        return s

    keep = max(length, 0)
    head = keep // 2
    tail = keep - head
    # s[-0:] is the whole string, so a zero-length tail must be spelled out;
    # otherwise a budget of 0 would return the marker plus all of s.
    suffix = s[-tail:] if tail else ""
    return s[:head] + "...(truncated) ..." + suffix
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class CODE_TYPE(Enum):
    """How a solution is exercised by the test harness."""

    # The solution exposes a named function/method that is called directly.
    call_based = 0
    # The solution reads its input from stdin and writes its answer to stdout.
    standard_input = 1
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# stuff for setting up signal timer
|
| 41 |
+
class TimeoutException(Exception):
    """Exception type reserved for signalling that a timed step ran too long."""
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def timeout_handler(signum, frame):
    """SIGALRM handler: report that the alarm fired and return normally.

    It deliberately does not raise TimeoutException - raising from the
    handler produced an unhandled exception, so the alarm is only logged and
    the interrupted code simply continues.
    """
    print("alarm went off")


# Install the handler at import time so every signal.alarm() call in this
# module is routed to it.
signal.signal(signal.SIGALRM, timeout_handler)
|
| 52 |
+
# timeout = 6 # seconds
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# used to capture stdout as a list
|
| 56 |
+
# from https://stackoverflow.com/a/16571630/6416660
|
| 57 |
+
# alternative use redirect_stdout() from contextlib
|
| 58 |
+
class Capturing(list):
    """Context manager that captures everything written to ``sys.stdout``.

    Inside the ``with`` block ``sys.stdout`` is replaced by an in-memory
    buffer. On exit the buffer's full contents are appended to this list as a
    single string and the previous ``sys.stdout`` is restored.
    """

    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        # Make closing the StringIO a no-op so code under test that calls
        # sys.stdout.close() cannot discard the captured text. The
        # replacement is stored on the instance and therefore called without
        # arguments; it must accept that.
        self._stringio.close = lambda *_args: None
        return self

    def __exit__(self, *args):
        self.append(self._stringio.getvalue())
        del self._stringio  # free up some memory
        sys.stdout = self._stdout
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def only_int_check(val: object) -> bool:
    """Return True when *val* is an ``int`` instance (``bool`` included)."""
    is_int = isinstance(val, int)
    return is_int
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def string_int_check(val: object) -> bool:
    """Return True when *val* is a non-empty ``str`` for which ``isdigit()`` holds.

    Signs, decimal points and surrounding whitespace make the check fail.
    """
    if not isinstance(val, str):
        return False
    return val.isdigit()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def combined_int_check(val):
    """Return True when *val* is an ``int`` or an all-digit ``str``."""
    if isinstance(val, int):
        return True
    return isinstance(val, str) and val.isdigit()
|
| 82 |
+
|
| 83 |
+
def clean_traceback(error_traceback):
    """Drop the frames that precede the first ``File "<string>"`` entry.

    Args:
        error_traceback: A formatted traceback string.

    Returns:
        A traceback consisting of the standard header followed by everything
        from the first ``File "<string>"`` frame onwards. If the traceback
        contains no such frame it is returned unchanged.
    """
    file_start = error_traceback.find('File "<string>"')
    if file_start == -1:
        # find() returned -1: slicing from -1 would keep only the final
        # character and discard the whole traceback.
        return error_traceback
    return "Traceback (most recent call last):\n " + error_traceback[file_start:]
|
| 88 |
+
|
| 89 |
+
def run_test(in_outs, test=None, debug=False, timeout=15):
|
| 90 |
+
"""
|
| 91 |
+
if test(generated_code) is not None it'll try to run the code.
|
| 92 |
+
otherwise it'll just return an input and output pair.
|
| 93 |
+
"""
|
| 94 |
+
# Disable functionalities that can make destructive changes to the test.
|
| 95 |
+
reliability_guard()
|
| 96 |
+
|
| 97 |
+
if debug:
|
| 98 |
+
print(f"start = {datetime.now().time()}")
|
| 99 |
+
|
| 100 |
+
# try:
|
| 101 |
+
# in_outs = json.loads(sample["input_output"])
|
| 102 |
+
# except ValueError:
|
| 103 |
+
# in_outs = None
|
| 104 |
+
if in_outs:
|
| 105 |
+
if in_outs.get("fn_name") is None:
|
| 106 |
+
which_type = CODE_TYPE.standard_input # Standard input
|
| 107 |
+
method_name = None
|
| 108 |
+
else:
|
| 109 |
+
which_type = CODE_TYPE.call_based # Call-based
|
| 110 |
+
method_name = in_outs["fn_name"]
|
| 111 |
+
|
| 112 |
+
if debug:
|
| 113 |
+
print(f"loaded input_output = {datetime.now().time()}")
|
| 114 |
+
|
| 115 |
+
if test is None:
|
| 116 |
+
assert False, "should not happen: test code is none"
|
| 117 |
+
return in_outs, {"error": "no test code provided"}
|
| 118 |
+
elif test is not None:
|
| 119 |
+
results = []
|
| 120 |
+
sol = "from string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\nsys.setrecursionlimit(6*10**5)\n"
|
| 121 |
+
if debug:
|
| 122 |
+
print(f"loading test code = {datetime.now().time()}")
|
| 123 |
+
|
| 124 |
+
if which_type == CODE_TYPE.call_based:
|
| 125 |
+
|
| 126 |
+
sol += test
|
| 127 |
+
if debug:
|
| 128 |
+
print(f"sol = {sol}")
|
| 129 |
+
signal.alarm(timeout)
|
| 130 |
+
try:
|
| 131 |
+
tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
|
| 132 |
+
if "class Solution" not in test:
|
| 133 |
+
tmp = tmp_sol
|
| 134 |
+
else:
|
| 135 |
+
tmp = tmp_sol.Solution()
|
| 136 |
+
signal.alarm(0)
|
| 137 |
+
except Exception as e:
|
| 138 |
+
signal.alarm(0)
|
| 139 |
+
error_traceback = traceback.format_exc()
|
| 140 |
+
if debug:
|
| 141 |
+
print(f"type 0 compilation error = {e}")
|
| 142 |
+
results.append(-2)
|
| 143 |
+
return results, {
|
| 144 |
+
"error": repr(e),
|
| 145 |
+
# "error_code": -1,
|
| 146 |
+
# "error_message": "Compilation Error",
|
| 147 |
+
"traceback": clean_traceback(error_traceback),
|
| 148 |
+
}
|
| 149 |
+
signal.alarm(0)
|
| 150 |
+
|
| 151 |
+
elif which_type == CODE_TYPE.standard_input:
|
| 152 |
+
# sol
|
| 153 |
+
# if code has if __name__ == "__main__": then remove it
|
| 154 |
+
try:
|
| 155 |
+
astree = ast.parse(test)
|
| 156 |
+
last_block = astree.body[-1]
|
| 157 |
+
if isinstance(last_block, ast.If):
|
| 158 |
+
condition = last_block.test
|
| 159 |
+
if ast.unparse(condition).strip() == "__name__ == '__main__'":
|
| 160 |
+
test = (
|
| 161 |
+
ast.unparse(astree.body[:-1])
|
| 162 |
+
+ "\n"
|
| 163 |
+
+ ast.unparse(last_block.body)
|
| 164 |
+
)
|
| 165 |
+
except:
|
| 166 |
+
pass
|
| 167 |
+
|
| 168 |
+
tmp_test = test.split("\n")
|
| 169 |
+
|
| 170 |
+
new_test = []
|
| 171 |
+
for x in tmp_test:
|
| 172 |
+
if (not x.startswith("from ")) and (not x.startswith("import ")):
|
| 173 |
+
new_test.append("\t" + x + "\n")
|
| 174 |
+
else:
|
| 175 |
+
new_test.append(x + "\n")
|
| 176 |
+
tmp_test = new_test
|
| 177 |
+
|
| 178 |
+
new_test = ""
|
| 179 |
+
started = False
|
| 180 |
+
for i in tmp_test:
|
| 181 |
+
if i.startswith("\t") and not started:
|
| 182 |
+
new_test += "stdin = sys.stdin\nstdout = sys.stdout\n"
|
| 183 |
+
new_test += "def code():\n"
|
| 184 |
+
new_test += i
|
| 185 |
+
started = True
|
| 186 |
+
elif started and ((i.startswith("from ")) or (i.startswith("import "))):
|
| 187 |
+
new_test += "\t" + i
|
| 188 |
+
else:
|
| 189 |
+
new_test += i
|
| 190 |
+
tmp_test = new_test
|
| 191 |
+
|
| 192 |
+
sol += tmp_test
|
| 193 |
+
if debug:
|
| 194 |
+
print(f"sol = {sol}")
|
| 195 |
+
method_name = "code"
|
| 196 |
+
signal.alarm(timeout)
|
| 197 |
+
try:
|
| 198 |
+
tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
|
| 199 |
+
tmp = tmp_sol
|
| 200 |
+
signal.alarm(0)
|
| 201 |
+
except Exception as e:
|
| 202 |
+
signal.alarm(0)
|
| 203 |
+
error_traceback = traceback.format_exc()
|
| 204 |
+
if debug:
|
| 205 |
+
print(f"type 1 compilation error = {e}")
|
| 206 |
+
results.append(-2)
|
| 207 |
+
return results, {
|
| 208 |
+
"error": repr(e),
|
| 209 |
+
# "error_code": -1,
|
| 210 |
+
# "error_message": "Compilation Error",
|
| 211 |
+
"traceback": clean_traceback(error_traceback),
|
| 212 |
+
}
|
| 213 |
+
signal.alarm(0)
|
| 214 |
+
if debug:
|
| 215 |
+
print(f"get method = {datetime.now().time()}")
|
| 216 |
+
|
| 217 |
+
try:
|
| 218 |
+
method = getattr(tmp, method_name) # get_attr second arg must be str
|
| 219 |
+
except:
|
| 220 |
+
signal.alarm(0)
|
| 221 |
+
error_traceback = traceback.format_exc()
|
| 222 |
+
e = sys.exc_info()
|
| 223 |
+
print(f"unable to get function error = {e}")
|
| 224 |
+
results.append(-2)
|
| 225 |
+
return results, {
|
| 226 |
+
"error": repr(e),
|
| 227 |
+
# "error_code": -1,
|
| 228 |
+
# "error_message": "Unable to extract code",
|
| 229 |
+
"traceback": clean_traceback(error_traceback),
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
for index, inputs in enumerate(in_outs["inputs"]):
|
| 233 |
+
raw_inputs = inputs
|
| 234 |
+
raw_outputs = in_outs["outputs"][index]
|
| 235 |
+
if which_type == CODE_TYPE.call_based:
|
| 236 |
+
inputs = [json.loads(line) for line in inputs.split("\n")]
|
| 237 |
+
in_outs["outputs"][index] = json.loads(in_outs["outputs"][index])
|
| 238 |
+
|
| 239 |
+
truncate_line_size = 300 // (raw_inputs.count("\n") + 1)
|
| 240 |
+
raw_inputs = "\n".join(
|
| 241 |
+
[
|
| 242 |
+
truncatefn(line, truncate_line_size)
|
| 243 |
+
for line in raw_inputs.strip().split("\n")
|
| 244 |
+
]
|
| 245 |
+
)
|
| 246 |
+
raw_outputs = truncatefn(raw_outputs, 200)
|
| 247 |
+
else:
|
| 248 |
+
raw_inputs = truncatefn(raw_inputs)
|
| 249 |
+
raw_outputs = truncatefn(raw_outputs, 200)
|
| 250 |
+
# JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list)
|
| 251 |
+
try:
|
| 252 |
+
if isinstance(inputs[0], dict):
|
| 253 |
+
inputs = [{int(k): v for k, v in inputs[0].items()}]
|
| 254 |
+
except:
|
| 255 |
+
True
|
| 256 |
+
try:
|
| 257 |
+
if isinstance(in_outs["outputs"][index], dict):
|
| 258 |
+
in_outs["outputs"][index] = [
|
| 259 |
+
{int(k): v for k, v in in_outs["outputs"][index].items()}
|
| 260 |
+
]
|
| 261 |
+
except:
|
| 262 |
+
True
|
| 263 |
+
try:
|
| 264 |
+
if isinstance(in_outs["outputs"][index][0], dict):
|
| 265 |
+
in_outs["outputs"][index] = [
|
| 266 |
+
{int(k): v for k, v in in_outs["outputs"][index][0].items()}
|
| 267 |
+
]
|
| 268 |
+
except:
|
| 269 |
+
True
|
| 270 |
+
|
| 271 |
+
if debug:
|
| 272 |
+
print(
|
| 273 |
+
f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. type = {which_type}"
|
| 274 |
+
)
|
| 275 |
+
if which_type == CODE_TYPE.call_based: # Call-based
|
| 276 |
+
signal.alarm(timeout)
|
| 277 |
+
faulthandler.enable()
|
| 278 |
+
try:
|
| 279 |
+
output = method(*inputs)
|
| 280 |
+
raw_true_output = output
|
| 281 |
+
|
| 282 |
+
raw_true_output_copy = json.dumps(output)
|
| 283 |
+
raw_true_output_copy = truncatefn(raw_true_output_copy, 200)
|
| 284 |
+
|
| 285 |
+
# ground truth sequences are not tuples
|
| 286 |
+
if isinstance(output, tuple):
|
| 287 |
+
output = list(output)
|
| 288 |
+
|
| 289 |
+
tmp_result = output == in_outs["outputs"][index]
|
| 290 |
+
if (
|
| 291 |
+
isinstance(in_outs["outputs"][index], list)
|
| 292 |
+
and in_outs["outputs"][index]
|
| 293 |
+
):
|
| 294 |
+
tmp_result = tmp_result or (
|
| 295 |
+
output == in_outs["outputs"][index][0]
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
# ground truth sequences are not tuples
|
| 299 |
+
try:
|
| 300 |
+
if isinstance(output[0], tuple):
|
| 301 |
+
tmp_result = tmp_result or (
|
| 302 |
+
[list(x) for x in output]
|
| 303 |
+
== in_outs["outputs"][index][0]
|
| 304 |
+
)
|
| 305 |
+
except:
|
| 306 |
+
True
|
| 307 |
+
results.append(tmp_result)
|
| 308 |
+
if tmp_result != True:
|
| 309 |
+
return results, {
|
| 310 |
+
"output": raw_true_output_copy,
|
| 311 |
+
"expected": raw_outputs,
|
| 312 |
+
"inputs": raw_inputs,
|
| 313 |
+
# "error_code": -2,
|
| 314 |
+
"error_message": "Wrong Answer",
|
| 315 |
+
}
|
| 316 |
+
# reset the alarm
|
| 317 |
+
signal.alarm(0)
|
| 318 |
+
except Exception as e:
|
| 319 |
+
signal.alarm(0)
|
| 320 |
+
error_traceback = traceback.format_exc()
|
| 321 |
+
faulthandler.disable()
|
| 322 |
+
if debug:
|
| 323 |
+
print(
|
| 324 |
+
f"Standard input runtime error or time limit exceeded error = {e}"
|
| 325 |
+
)
|
| 326 |
+
results.append(-1)
|
| 327 |
+
if "timeoutexception" in repr(e).lower():
|
| 328 |
+
return results, {
|
| 329 |
+
"error": repr(e),
|
| 330 |
+
# "error_code": -3,
|
| 331 |
+
# "error_message": "Time Limit Exceeded",
|
| 332 |
+
# "inputs": raw_inputs,
|
| 333 |
+
# "expected": raw_outputs,
|
| 334 |
+
"traceback": clean_traceback(error_traceback),
|
| 335 |
+
}
|
| 336 |
+
else:
|
| 337 |
+
return results, {
|
| 338 |
+
"error": repr(e),
|
| 339 |
+
# "error_code": -4,
|
| 340 |
+
# "error_message": "Runtime Error",
|
| 341 |
+
# "inputs": raw_inputs,
|
| 342 |
+
# "expected": raw_outputs,
|
| 343 |
+
"traceback": clean_traceback(error_traceback),
|
| 344 |
+
}
|
| 345 |
+
faulthandler.disable()
|
| 346 |
+
signal.alarm(0)
|
| 347 |
+
if debug:
|
| 348 |
+
print(
|
| 349 |
+
f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
|
| 350 |
+
)
|
| 351 |
+
elif which_type == CODE_TYPE.standard_input: # Standard input
|
| 352 |
+
faulthandler.enable()
|
| 353 |
+
passed = False
|
| 354 |
+
|
| 355 |
+
if isinstance(inputs, list):
|
| 356 |
+
inputs = "\n".join(inputs)
|
| 357 |
+
if isinstance(in_outs["outputs"][index], list):
|
| 358 |
+
in_outs["outputs"][index] = "\n".join(in_outs["outputs"][index])
|
| 359 |
+
|
| 360 |
+
signal.alarm(timeout)
|
| 361 |
+
with Capturing() as output:
|
| 362 |
+
try:
|
| 363 |
+
call_method(method, inputs)
|
| 364 |
+
# reset the alarm
|
| 365 |
+
signal.alarm(0)
|
| 366 |
+
passed = True
|
| 367 |
+
except Exception as e:
|
| 368 |
+
# runtime error or took too long
|
| 369 |
+
signal.alarm(0)
|
| 370 |
+
error_traceback = traceback.format_exc()
|
| 371 |
+
print(
|
| 372 |
+
f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}"
|
| 373 |
+
)
|
| 374 |
+
results.append(-1)
|
| 375 |
+
if "timeoutexception" in repr(e).lower():
|
| 376 |
+
return results, {
|
| 377 |
+
"error": repr(e),
|
| 378 |
+
# "error_code": -3,
|
| 379 |
+
# "error_message": "Time Limit Exceeded",
|
| 380 |
+
# "inputs": raw_inputs,
|
| 381 |
+
# "expected": raw_outputs,
|
| 382 |
+
"traceback": clean_traceback(error_traceback),
|
| 383 |
+
}
|
| 384 |
+
else:
|
| 385 |
+
return results, {
|
| 386 |
+
"error": repr(e),
|
| 387 |
+
# "error_code": -4,
|
| 388 |
+
# "error_message": "Runtime Error",
|
| 389 |
+
# "inputs": raw_inputs,
|
| 390 |
+
# "expected": raw_outputs,
|
| 391 |
+
"traceback": clean_traceback(error_traceback),
|
| 392 |
+
}
|
| 393 |
+
signal.alarm(0)
|
| 394 |
+
raw_true_output = output[0]
|
| 395 |
+
raw_true_output_copy = truncatefn(raw_true_output, 200)
|
| 396 |
+
output = raw_true_output.splitlines()
|
| 397 |
+
if not passed:
|
| 398 |
+
if debug:
|
| 399 |
+
nl = "\n"
|
| 400 |
+
if not isinstance(inputs, list):
|
| 401 |
+
print(
|
| 402 |
+
f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
|
| 403 |
+
)
|
| 404 |
+
else:
|
| 405 |
+
print(
|
| 406 |
+
f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
|
| 407 |
+
)
|
| 408 |
+
continue
|
| 409 |
+
|
| 410 |
+
if passed and debug:
|
| 411 |
+
print(
|
| 412 |
+
f"==> output = {output}, test outputs = {in_outs['outputs'][index]}"
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
if custom_compare_(output, in_outs["outputs"][index]):
|
| 416 |
+
tmp_result = True
|
| 417 |
+
results.append(tmp_result)
|
| 418 |
+
continue
|
| 419 |
+
|
| 420 |
+
# ground truth sequences are expressed as lists not tuples
|
| 421 |
+
if isinstance(output, tuple):
|
| 422 |
+
output = list(output)
|
| 423 |
+
|
| 424 |
+
tmp_result = False
|
| 425 |
+
try:
|
| 426 |
+
tmp_result = output == [in_outs["outputs"][index]]
|
| 427 |
+
if isinstance(in_outs["outputs"][index], list):
|
| 428 |
+
tmp_result = tmp_result or (output == in_outs["outputs"][index])
|
| 429 |
+
if isinstance(output[0], str):
|
| 430 |
+
tmp_result = tmp_result or (
|
| 431 |
+
[e.strip() for e in output] == in_outs["outputs"][index]
|
| 432 |
+
)
|
| 433 |
+
except Exception as e:
|
| 434 |
+
if debug:
|
| 435 |
+
print(f"Failed check1 exception = {e}")
|
| 436 |
+
pass
|
| 437 |
+
|
| 438 |
+
if tmp_result == True:
|
| 439 |
+
results.append(tmp_result)
|
| 440 |
+
continue
|
| 441 |
+
|
| 442 |
+
# try one more time without \n
|
| 443 |
+
if isinstance(in_outs["outputs"][index], list):
|
| 444 |
+
for tmp_index, i in enumerate(in_outs["outputs"][index]):
|
| 445 |
+
in_outs["outputs"][index][tmp_index] = i.split("\n")
|
| 446 |
+
in_outs["outputs"][index][tmp_index] = [
|
| 447 |
+
x.strip() for x in in_outs["outputs"][index][tmp_index] if x
|
| 448 |
+
]
|
| 449 |
+
else:
|
| 450 |
+
in_outs["outputs"][index] = in_outs["outputs"][index].split("\n")
|
| 451 |
+
in_outs["outputs"][index] = list(
|
| 452 |
+
filter(len, in_outs["outputs"][index])
|
| 453 |
+
)
|
| 454 |
+
in_outs["outputs"][index] = list(
|
| 455 |
+
map(lambda x: x.strip(), in_outs["outputs"][index])
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
try:
|
| 459 |
+
tmp_result = output == [in_outs["outputs"][index]]
|
| 460 |
+
if isinstance(in_outs["outputs"][index], list):
|
| 461 |
+
tmp_result = tmp_result or (output == in_outs["outputs"][index])
|
| 462 |
+
except Exception as e:
|
| 463 |
+
if debug:
|
| 464 |
+
print(f"Failed check2 exception = {e}")
|
| 465 |
+
pass
|
| 466 |
+
|
| 467 |
+
if tmp_result == True:
|
| 468 |
+
results.append(tmp_result)
|
| 469 |
+
continue
|
| 470 |
+
|
| 471 |
+
# try by converting the output into a split up list too
|
| 472 |
+
if isinstance(output, list):
|
| 473 |
+
output = list(filter(len, output))
|
| 474 |
+
|
| 475 |
+
if debug:
|
| 476 |
+
nl = "\n"
|
| 477 |
+
if not isinstance(inputs, list):
|
| 478 |
+
print(
|
| 479 |
+
f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}"
|
| 480 |
+
)
|
| 481 |
+
else:
|
| 482 |
+
print(
|
| 483 |
+
f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}"
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
if tmp_result == True:
|
| 487 |
+
results.append(tmp_result)
|
| 488 |
+
continue
|
| 489 |
+
|
| 490 |
+
if debug:
|
| 491 |
+
print(f"{tmp_result=} @a")
|
| 492 |
+
|
| 493 |
+
try:
|
| 494 |
+
tmp_result = output == [in_outs["outputs"][index]]
|
| 495 |
+
if isinstance(in_outs["outputs"][index], list):
|
| 496 |
+
tmp_result = tmp_result or (output == in_outs["outputs"][index])
|
| 497 |
+
except Exception as e:
|
| 498 |
+
if debug:
|
| 499 |
+
print(f"Failed check3 exception = {e}")
|
| 500 |
+
pass
|
| 501 |
+
|
| 502 |
+
if debug:
|
| 503 |
+
print(f"{tmp_result=} @b")
|
| 504 |
+
|
| 505 |
+
try:
|
| 506 |
+
all_ints = all(
|
| 507 |
+
combined_int_check(e1) and combined_int_check(e2)
|
| 508 |
+
for e1, e2 in zip(output, in_outs["outputs"][index])
|
| 509 |
+
)
|
| 510 |
+
if not all_ints:
|
| 511 |
+
if debug:
|
| 512 |
+
print(
|
| 513 |
+
[
|
| 514 |
+
combined_int_check(e1) and combined_int_check(e2)
|
| 515 |
+
for e1, e2 in zip(output, in_outs["outputs"][index])
|
| 516 |
+
]
|
| 517 |
+
)
|
| 518 |
+
output_float = [float(e) for e in output]
|
| 519 |
+
gt_float = [float(e) for e in in_outs["outputs"][index]]
|
| 520 |
+
tmp_result = tmp_result or (
|
| 521 |
+
(len(output_float) == len(gt_float))
|
| 522 |
+
and np.allclose(output_float, gt_float)
|
| 523 |
+
)
|
| 524 |
+
except Exception as e:
|
| 525 |
+
pass
|
| 526 |
+
|
| 527 |
+
if debug:
|
| 528 |
+
print(f"{tmp_result=} @c")
|
| 529 |
+
|
| 530 |
+
try:
|
| 531 |
+
if isinstance(output[0], list):
|
| 532 |
+
all_ints = all(
|
| 533 |
+
combined_int_check(e1) and combined_int_check(e2)
|
| 534 |
+
for e1, e2 in zip(output[0], in_outs["outputs"][index])
|
| 535 |
+
)
|
| 536 |
+
if not all_ints:
|
| 537 |
+
output_float = [float(e) for e in output[0]]
|
| 538 |
+
gt_float = [float(e) for e in in_outs["outputs"][index][0]]
|
| 539 |
+
tmp_result = tmp_result or (
|
| 540 |
+
(len(output_float) == len(gt_float))
|
| 541 |
+
and np.allclose(output_float, gt_float)
|
| 542 |
+
)
|
| 543 |
+
except Exception as e:
|
| 544 |
+
pass
|
| 545 |
+
|
| 546 |
+
if tmp_result == True:
|
| 547 |
+
results.append(tmp_result)
|
| 548 |
+
continue
|
| 549 |
+
|
| 550 |
+
if debug:
|
| 551 |
+
print(f"{tmp_result=} @d")
|
| 552 |
+
# try by converting the stuff into split up list
|
| 553 |
+
if isinstance(in_outs["outputs"][index], list):
|
| 554 |
+
for tmp_index, i in enumerate(in_outs["outputs"][index]):
|
| 555 |
+
in_outs["outputs"][index][tmp_index] = set(i.split())
|
| 556 |
+
else:
|
| 557 |
+
in_outs["outputs"][index] = set(in_outs["outputs"][index].split())
|
| 558 |
+
|
| 559 |
+
if debug:
|
| 560 |
+
print(f"{tmp_result=} @e")
|
| 561 |
+
|
| 562 |
+
try:
|
| 563 |
+
tmp_result = output == in_outs["outputs"][index]
|
| 564 |
+
except Exception as e:
|
| 565 |
+
if debug:
|
| 566 |
+
print(f"Failed check4 exception = {e}")
|
| 567 |
+
continue
|
| 568 |
+
|
| 569 |
+
if tmp_result == True:
|
| 570 |
+
results.append(tmp_result)
|
| 571 |
+
continue
|
| 572 |
+
|
| 573 |
+
if debug:
|
| 574 |
+
print(f"{tmp_result=} @f")
|
| 575 |
+
|
| 576 |
+
# try by converting the output into a split up list too
|
| 577 |
+
if isinstance(output, list):
|
| 578 |
+
for tmp_index, i in enumerate(output):
|
| 579 |
+
output[tmp_index] = i.split()
|
| 580 |
+
output = list(filter(len, output))
|
| 581 |
+
for tmp_index, i in enumerate(output):
|
| 582 |
+
output[tmp_index] = set(i)
|
| 583 |
+
else:
|
| 584 |
+
output = output.split()
|
| 585 |
+
output = list(filter(len, output))
|
| 586 |
+
output = set(output)
|
| 587 |
+
|
| 588 |
+
if debug:
|
| 589 |
+
print(f"{tmp_result=} @g")
|
| 590 |
+
|
| 591 |
+
if tmp_result == True and debug:
|
| 592 |
+
print("PASSED")
|
| 593 |
+
|
| 594 |
+
results.append(tmp_result)
|
| 595 |
+
if tmp_result != True:
|
| 596 |
+
return results, {
|
| 597 |
+
"output": raw_true_output_copy,
|
| 598 |
+
"expected": raw_outputs,
|
| 599 |
+
"inputs": raw_inputs,
|
| 600 |
+
# "error_code": -2,
|
| 601 |
+
"error_message": "Wrong Answer",
|
| 602 |
+
}
|
| 603 |
+
|
| 604 |
+
if debug:
|
| 605 |
+
nl = "\n"
|
| 606 |
+
if not isinstance(inputs, list):
|
| 607 |
+
print(
|
| 608 |
+
f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
|
| 609 |
+
)
|
| 610 |
+
else:
|
| 611 |
+
print(
|
| 612 |
+
f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
print(f"results = {results}")
|
| 616 |
+
|
| 617 |
+
return results, {}
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
def custom_compare_(output, ground_truth):
    """Check whether captured output lines equal the expected text.

    Only list outputs are compared; any other type yields ``False``. The
    lines are joined with newlines and compared to ``ground_truth`` through
    ``stripped_string_compare``: first exactly as captured, then with every
    line individually trimmed of surrounding whitespace.
    """
    if not isinstance(output, list):
        return False

    joined_raw = "\n".join(output)
    if stripped_string_compare(joined_raw, ground_truth):
        return True

    joined_trimmed = "\n".join([line.lstrip().rstrip() for line in output])
    if stripped_string_compare(joined_trimmed, ground_truth):
        return True

    return False
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
def stripped_string_compare(s1, s2):
    """Return True when both strings are equal once outer whitespace is trimmed."""
    return s1.strip() == s2.strip()
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
def call_method(method, inputs):
    """Call ``method()`` with ``inputs`` exposed as stdin and as file content.

    ``inputs`` may be a string or a list of strings (joined with newlines).
    While ``method`` runs, ``builtins.open`` is replaced by a ``mock_open``
    that reads ``inputs``, ``sys.stdin`` is replaced by a ``StringIO`` over
    ``inputs``, and the ``read``/``readline``/``readlines`` attributes of
    ``sys.stdin`` are patched as well.

    Returns whatever ``method`` returns, or ``None`` when it raises
    ``SystemExit``.
    """
    if isinstance(inputs, list):
        inputs = "\n".join(inputs)

    pending_lines = iter(inputs.split("\n"))

    # NOTE: the decorator stack is kept in the exact original order; the order
    # in which the patches are applied depends on it.
    @patch("builtins.open", mock_open(read_data=inputs))
    @patch("sys.stdin", StringIO(inputs))
    @patch("sys.stdin.readline", lambda *args: next(pending_lines))
    @patch("sys.stdin.readlines", lambda *args: inputs.split("\n"))
    @patch("sys.stdin.read", lambda *args: inputs)
    def _run_patched(target):
        try:
            return target()
        except SystemExit:
            # A program that calls exit() must not terminate the caller.
            return None

    return _run_patched(method)
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
def reliability_guard(maximum_memory_bytes=None):
    """
    This disables various destructive functions and prevents the generated code
    from interfering with the test (e.g. fork bomb, killing other processes,
    removing filesystem files, etc.)

    WARNING
    This function is NOT a security sandbox. Untrusted code, including, model-
    generated code, should not be blindly executed outside of one. See the
    Codex paper for more information about OpenAI's code sandbox, and proceed
    with caution.

    Args:
        maximum_memory_bytes: when not None, used as both the soft and hard
            limit for RLIMIT_AS and RLIMIT_DATA (and RLIMIT_STACK except on
            Darwin).

    The changes are process-wide and are not undone by this function: the
    listed attributes of ``builtins``, ``os``, ``shutil`` and ``subprocess``
    are set to ``None`` and several entries of ``sys.modules`` are set to
    ``None`` so that importing those modules fails.
    """

    if maximum_memory_bytes is not None:
        import resource

        limits = (maximum_memory_bytes, maximum_memory_bytes)
        resource.setrlimit(resource.RLIMIT_AS, limits)
        resource.setrlimit(resource.RLIMIT_DATA, limits)
        if not platform.uname().system == "Darwin":
            resource.setrlimit(resource.RLIMIT_STACK, limits)

    faulthandler.disable()

    import builtins

    builtins.exit = None
    builtins.quit = None

    import os

    os.environ["OMP_NUM_THREADS"] = "1"

    # "system" is disabled to prevent interference with the REPL evaluation.
    for os_attr in (
        "kill",
        "system",
        "putenv",
        "remove",
        "removedirs",
        "rmdir",
        "fchdir",
        "setuid",
        "fork",
        "forkpty",
        "killpg",
        "rename",
        "renames",
        "truncate",
        "replace",
        "unlink",
        "fchmod",
        "fchown",
        "chmod",
        "chown",
        "chroot",
        "lchflags",
        "lchmod",
        "lchown",
        "getcwd",
        "chdir",
    ):
        setattr(os, os_attr, None)

    import shutil

    shutil.rmtree = None
    shutil.move = None
    shutil.chown = None

    import subprocess

    subprocess.Popen = None  # type: ignore

    # `__builtins__` is a dict in imported modules but the `builtins` module
    # itself in `__main__`, where item assignment raises TypeError. Going
    # through the `builtins` module works in both situations.
    builtins.help = None

    import sys

    for blocked_module in ("ipdb", "joblib", "resource", "psutil", "tkinter"):
        sys.modules[blocked_module] = None
|
dataset_construction/evaluation_utils/code_util/utils.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Borrowed from: https://huggingface.co/spaces/codeparrot/apps_metric/blob/main/utils.py
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import multiprocessing
|
| 5 |
+
from typing import Dict, Optional
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
+
from .testing_util import run_test
|
| 8 |
+
import traceback
|
| 9 |
+
import os,sys
|
| 10 |
+
|
| 11 |
+
def _temp_run(sample, generation, debug, result, metadata_list, timeout):
    """Run ``run_test`` silently and append its outcome to the shared lists.

    Args:
        sample: test specification passed to ``run_test`` as ``in_outs``;
            ``sample['inputs']`` is used to size the failure result.
        generation: candidate code passed to ``run_test`` as ``test``.
        debug: forwarded to ``run_test``.
        result: list-like object receiving the per-test results.
        metadata_list: list-like object receiving the metadata dict.
        timeout: forwarded to ``run_test``.

    On any ``Exception`` from ``run_test`` a list of ``-1`` (one per input)
    and an empty metadata dict are appended instead.
    """
    saved_stdout, saved_stderr = sys.stdout, sys.stderr
    # test is run silently. If it is killed, nothing will be printed
    with open(os.devnull, "w") as devnull:
        sys.stdout = devnull
        sys.stderr = devnull
        try:
            res, metadata = run_test(
                in_outs=sample, test=generation, debug=debug, timeout=timeout
            )
            result.append(res)
            metadata_list.append(metadata)
        except Exception:
            # Tracebacks are limited to 10 frames; some are extremely long.
            traceback.print_exc(10)
            result.append([-1 for _ in range(len(sample["inputs"]))])
            metadata_list.append({})
        finally:
            # Do not leave the process-wide streams bound to the devnull
            # handle, which is closed as soon as the `with` block exits.
            sys.stdout = saved_stdout
            sys.stderr = saved_stderr
|
| 25 |
+
|
| 26 |
+
def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True):
    """Check correctness of code generation with a global timeout.

    The global timeout is to catch some extreme/rare cases not handled by the
    timeouts inside `run_test`.

    Args:
        in_outs: test specification; ``in_outs["inputs"]`` is used to size the
            failure result.
        generation: candidate code handed to ``_temp_run``.
        timeout: forwarded to ``_temp_run``; the worker process is given
            ``timeout + 1`` seconds before it is killed.
        debug: forwarded to ``_temp_run``; also enables the "global timeout"
            message.

    Returns:
        ``(results, metadata_list)`` where ``results`` is the first result
        recorded by the worker (or a list of ``-1``, one per input, if the
        worker recorded nothing) and ``metadata_list`` is a plain list with
        the recorded metadata dicts.
    """
    # The manager owns a helper process; the context manager shuts it down
    # instead of leaking one such process per call.
    with multiprocessing.Manager() as manager:
        shared_results = manager.list()
        shared_metadata = manager.list()
        worker = multiprocessing.Process(
            target=_temp_run,
            args=(in_outs, generation, debug, shared_results, shared_metadata, timeout),
        )
        worker.start()
        worker.join(timeout=timeout + 1)
        if worker.is_alive():
            worker.kill()
            worker.join()  # reap the killed child so it does not stay a zombie
        # Copy the data out of the proxies while the manager is still alive.
        result = list(shared_results)
        metadata_list = list(shared_metadata)

    if not result:
        # consider that all tests failed
        result = [[-1 for _ in range(len(in_outs["inputs"]))]]
        if debug:
            print("global timeout")
    return result[0], metadata_list
|
dataset_construction/evaluation_utils/math_util/__init__.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Answer checker API that uses sympy to simplify expressions and check for equality.
|
| 3 |
+
|
| 4 |
+
Call grade_answer(given_answer: str, ground_truth: str).
|
| 5 |
+
|
| 6 |
+
FROM: https://github.com/openai/prm800k/blob/main/prm800k/grading/grader.py
|
| 7 |
+
"""
|
| 8 |
+
import re
|
| 9 |
+
import sympy
|
| 10 |
+
from pylatexenc import latex2text
|
| 11 |
+
from sympy.parsing import sympy_parser
|
| 12 |
+
|
| 13 |
+
from . import math_normalize
|
| 14 |
+
from .grader import math_equal
|
| 15 |
+
# import math_normalize
|
| 16 |
+
# from grader import math_equal
|
| 17 |
+
|
| 18 |
+
# sympy might hang -- we don't care about trying to be lenient in these cases
BAD_SUBSTRINGS = ["^{", "^("]
# Raw strings: "\^" is not a valid escape sequence in a normal string literal.
BAD_REGEXES = [r"\^[0-9]+\^", r"\^[0-9][0-9]+"]
TUPLE_CHARS = "()[]"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _sympy_parse(expr: str):
    """Parse ``expr`` with sympy, reading ``^`` as exponentiation.

    The standard transformations are extended with implicit multiplication
    application.
    """
    python_source = expr.replace("^", "**")
    transformations = sympy_parser.standard_transformations + (
        sympy_parser.implicit_multiplication_application,
    )
    return sympy_parser.parse_expr(python_source, transformations=transformations)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _parse_latex(expr: str) -> str:
    """Convert LaTeX to text via pylatexenc, then map its symbols to ASCII names."""
    for frac_variant in ("\\tfrac", "\\dfrac"):
        expr = expr.replace(frac_variant, "\\frac")
    # A leading space keeps mixed numbers such as "7\frac{3}{4}" separable.
    expr = expr.replace("\\frac", " \\frac")
    expr = latex2text.LatexNodes2Text().latex_to_text(expr)

    # Replace the specific characters that this parser uses.
    symbol_replacements = (
        ("√", "sqrt"),
        ("π", "pi"),
        ("∞", "inf"),
        ("∪", "U"),
        ("·", "*"),
        ("×", "*"),
    )
    for symbol, ascii_form in symbol_replacements:
        expr = expr.replace(symbol, ascii_form)

    return expr.strip()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _is_float(num: str) -> bool:
|
| 55 |
+
try:
|
| 56 |
+
float(num)
|
| 57 |
+
return True
|
| 58 |
+
except ValueError:
|
| 59 |
+
return False
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _is_int(x: float) -> bool:
|
| 63 |
+
try:
|
| 64 |
+
return abs(x - int(round(x))) <= 1e-7
|
| 65 |
+
except:
|
| 66 |
+
return False
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _is_frac(expr: str) -> bool:
|
| 70 |
+
return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr))
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _str_is_int(x: str) -> bool:
    """Return True when ``x`` parses to a number within 1e-7 of an integer.

    Thousands separators are removed first with
    ``_strip_properly_formatted_commas``. Anything that cannot be parsed or
    rounded yields False.
    """
    try:
        x = _strip_properly_formatted_commas(x)
        x = float(x)
        return abs(x - int(round(x))) <= 1e-7
    except (TypeError, ValueError, OverflowError):
        # ValueError: not a number (or NaN when rounding); OverflowError:
        # infinity when rounding; TypeError: non-string input.
        return False
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _str_to_int(x: str) -> bool:
|
| 83 |
+
x = x.replace(",", "")
|
| 84 |
+
x = float(x)
|
| 85 |
+
return int(x)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _inject_implicit_mixed_number(step: str):
|
| 89 |
+
"""
|
| 90 |
+
Automatically make a mixed number evalable
|
| 91 |
+
e.g. 7 3/4 => 7+3/4
|
| 92 |
+
"""
|
| 93 |
+
p1 = re.compile("([0-9]) +([0-9])")
|
| 94 |
+
step = p1.sub("\\1+\\2", step) ## implicit mults
|
| 95 |
+
return step
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _strip_properly_formatted_commas(expr: str):
|
| 99 |
+
# We want to be careful because we don't want to strip tuple commas
|
| 100 |
+
p1 = re.compile("(\d)(,)(\d\d\d)($|\D)")
|
| 101 |
+
while True:
|
| 102 |
+
next_expr = p1.sub("\\1\\3\\4", expr)
|
| 103 |
+
if next_expr == expr:
|
| 104 |
+
break
|
| 105 |
+
expr = next_expr
|
| 106 |
+
return next_expr
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _normalize(expr: str) -> str:
    """Normalize answer expressions.

    Steps, in order: unwrap an enclosing text macro, drop dollar and percent
    signs, turn " or "/" and " into commas, expand million/billion/trillion,
    remove the listed unit words and degree markers, strip one pair of
    enclosing braces, collapse integral floats, convert remaining LaTeX to
    text, tighten minus signs, join mixed numbers, lower-case, and finally
    canonicalize integer strings.

    Returns None when ``expr`` is None.
    """
    if expr is None:
        return None

    # Remove enclosing `\text{}`.
    m = re.search(r"^\\text\{(?P<text>.+?)\}$", expr)
    if m is not None:
        expr = m.group("text")

    expr = expr.replace("\\%", "%")
    expr = expr.replace("\\$", "$")
    expr = expr.replace("$", "")
    expr = expr.replace("%", "")
    expr = expr.replace(" or ", " , ")
    expr = expr.replace(" and ", " , ")

    expr = expr.replace("million", "*10^6")
    expr = expr.replace("billion", "*10^9")
    expr = expr.replace("trillion", "*10^12")

    for unit in [
        "degree",
        "cm",
        "centimeter",
        "meter",
        "mile",
        "second",
        "minute",
        "hour",
        "day",
        "week",
        "month",
        "year",
        "foot",
        "feet",
        "inch",
        "yard",
        "liter",
    ]:
        # Also removes a plural suffix and a trailing "^<digits>" exponent.
        expr = re.sub(rf"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr)
    expr = re.sub(r"\^ *\\circ", "", expr)

    if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}":
        expr = expr[1:-1]

    expr = re.sub(r",\\! *", "", expr)
    if _is_float(expr) and _is_int(float(expr)):
        expr = str(int(round(float(expr))))
    if "\\" in expr:
        try:
            expr = _parse_latex(expr)
        except Exception:
            # Best effort: keep the unparsed expression if conversion fails.
            pass

    # edge case with mixed numbers and negative signs
    expr = re.sub("- *", "-", expr)

    expr = _inject_implicit_mixed_number(expr)

    # don't be case sensitive for text answers
    expr = expr.lower()

    if _str_is_int(expr):
        expr = str(_str_to_int(expr))

    return expr
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def count_unknown_letters_in_expr(expr: str):
    """Count the distinct alphabetic characters left after removing "sqrt" and "frac"."""
    remainder = expr.replace("sqrt", "").replace("frac", "")
    return len({char for char in remainder if char.isalpha()})
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def should_allow_eval(expr: str):
    """Decide whether ``expr`` may be handed to the sympy parser.

    Rejected are expressions with more than two distinct unknown letters
    (unknown text or functions of more than two variables), and expressions
    containing any of ``BAD_SUBSTRINGS`` or matching any of ``BAD_REGEXES``.
    """
    if count_unknown_letters_in_expr(expr) > 2:
        return False

    if any(fragment in expr for fragment in BAD_SUBSTRINGS):
        return False

    if any(re.search(pattern, expr) is not None for pattern in BAD_REGEXES):
        return False

    return True
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
    """Return True when sympy simplifies the difference of both expressions to 0.

    The difference is only parsed when ``should_allow_eval`` accepts it. Any
    ordinary exception raised while parsing or simplifying counts as "not
    equal"; KeyboardInterrupt and SystemExit are no longer swallowed.
    """
    are_equal = False
    try:
        expr = f"({ground_truth_normalized})-({given_normalized})"
        if should_allow_eval(expr):
            sympy_diff = _sympy_parse(expr)
            simplified = sympy.simplify(sympy_diff)
            if simplified == 0:
                are_equal = True
    except Exception:
        # Best effort: unparsable or unsimplifiable input is simply unequal.
        pass
    return are_equal
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def split_tuple(expr: str):
    """
    Split the elements in a tuple/interval, while handling well-formatted commas in large numbers

    An empty expression gives an empty list. An expression that is not
    wrapped in tuple/interval brackets, or that has brackets inside, is
    returned as a single-element list.
    """
    expr = _strip_properly_formatted_commas(expr)
    if len(expr) == 0:
        return []

    interior = expr[1:-1]
    looks_like_tuple = (
        len(expr) > 2
        and expr[0] in TUPLE_CHARS
        and expr[-1] in TUPLE_CHARS
        and all([bracket not in interior for bracket in TUPLE_CHARS])
    )
    if not looks_like_tuple:
        return [expr]

    return [part.strip() for part in interior.split(",")]
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def grade_answer(given_answer: str, ground_truth: str) -> bool:
    """
    The answer will be considered correct if:
    (a) it normalizes to the same string as the ground truth answer
    OR
    (b) sympy can simplify the difference between the expressions to 0

    Tuples/intervals are compared element by element and must use the same
    opening and closing bracket as the ground truth.
    """
    if given_answer is None:
        return False

    # be at least as lenient as mathd
    mathd_truth = math_normalize.normalize_answer(ground_truth)
    mathd_given = math_normalize.normalize_answer(given_answer)
    if mathd_truth == mathd_given:
        return True

    ground_truth_normalized = _normalize(ground_truth)
    given_normalized = _normalize(given_answer)

    if ground_truth_normalized is None:
        return False
    if ground_truth_normalized == given_normalized:
        return True
    if len(given_normalized) == 0:
        return False

    truth_parts = split_tuple(ground_truth_normalized)
    given_parts = split_tuple(given_normalized)

    brackets_differ = (
        ground_truth_normalized[0] != given_normalized[0]
        or ground_truth_normalized[-1] != given_normalized[-1]
    ) if len(truth_parts) > 1 else False
    if brackets_differ:
        return False
    if len(truth_parts) != len(given_parts):
        return False

    for truth_part, given_part in zip(truth_parts, given_parts):
        if _is_frac(truth_part) and _is_frac(given_part):
            # if fractions aren't reduced, then shouldn't be marked as correct
            # so, we don't want to allow sympy.simplify in this case
            is_correct = truth_part == given_part
        elif _str_is_int(truth_part) != _str_is_int(given_part):
            # if the ground truth answer is an integer, we require the given
            # answer to be a strict match (no sympy.simplify)
            is_correct = False
        else:
            is_correct = are_equal_under_sympy(truth_part, given_part)
        if not is_correct:
            break

    return is_correct
|
| 294 |
+
|
| 295 |
+
def remove_boxed(s):
    """Return the content of a string of the form ``\\boxed{...}``.

    Returns None when ``s`` is not a string, does not start with the
    ``\\boxed{`` prefix, or does not end with a closing brace. The check uses
    explicit conditions rather than ``assert`` so it also works when Python
    runs with ``-O``.
    """
    left = "\\boxed{"
    if not isinstance(s, str):
        return None
    if s.startswith(left) and s.endswith("}"):
        return s[len(left):-1]
    return None
|
| 303 |
+
|
| 304 |
+
def _last_boxed_only_string(string):
|
| 305 |
+
idx = string.rfind("\\boxed")
|
| 306 |
+
if idx < 0:
|
| 307 |
+
idx = string.rfind("\\fbox")
|
| 308 |
+
if idx < 0:
|
| 309 |
+
return None
|
| 310 |
+
|
| 311 |
+
i = idx
|
| 312 |
+
left_brace_idx = None
|
| 313 |
+
right_brace_idx = None
|
| 314 |
+
num_left_braces_open = 0
|
| 315 |
+
while i < len(string):
|
| 316 |
+
if string[i] == "{":
|
| 317 |
+
num_left_braces_open += 1
|
| 318 |
+
if left_brace_idx is None:
|
| 319 |
+
left_brace_idx = i
|
| 320 |
+
elif string[i] == "}":
|
| 321 |
+
num_left_braces_open -= 1
|
| 322 |
+
if num_left_braces_open == 0:
|
| 323 |
+
right_brace_idx = i
|
| 324 |
+
break
|
| 325 |
+
|
| 326 |
+
i += 1
|
| 327 |
+
|
| 328 |
+
if left_brace_idx is None or right_brace_idx is None:
|
| 329 |
+
return None
|
| 330 |
+
|
| 331 |
+
return string[left_brace_idx + 1: right_brace_idx].strip()
|
| 332 |
+
|
| 333 |
+
def match_answer(response):
    """Extract the final answer text from a model response.

    The response is narrowed in stages: text after the last "answer:"-style
    marker, text before an "is the answer"-style marker, the content of the
    last boxed group, text before the last ". ", and finally text after the
    last "is "/"="-style marker.

    Returns:
        ``(is_matched, response)``: ``is_matched`` is True when at least one
        marker or a boxed group was found and the extracted text contains a
        digit; ``response`` is the extracted text.
    """
    is_matched = False
    for ans_marker in ["answer:", "answer is", "answers are"]:
        ans_idx = response.lower().rfind(ans_marker)
        if ans_idx != -1:
            is_matched = True
            # strip() already removes any trailing newline.
            response = response[ans_idx + len(ans_marker):].strip()

    for ans_marker in ["is answer", "is the answer", "are answers", "are the answers"]:
        ans_idx = response.lower().rfind(ans_marker)
        if ans_idx != -1:
            is_matched = True
            response = response[:ans_idx].strip()

    # Find boxed
    ans_boxed = _last_boxed_only_string(response)
    if ans_boxed:
        is_matched = True
        response = ans_boxed

    # Case does not matter for ". ", so search the response itself; this keeps
    # the index valid even when lower-casing would change the string length.
    dot_idx = response.rfind(". ")
    if dot_idx != -1:
        response = response[:dot_idx].strip()

    for ans_marker in ["be ", "is ", "are ", "=", ": ", "get ", "be\n", "is\n", "are\n", ":\n", "get\n"]:
        ans_idx = response.lower().rfind(ans_marker)
        if ans_idx != -1:
            is_matched = True
            response = response[ans_idx + len(ans_marker):].strip()

    # answer must have a digit
    is_matched = is_matched and any(c.isdigit() for c in response)
    return is_matched, response
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
import math
|
| 376 |
+
def evaluate_math(model_output: str, ground_truth: str) -> "tuple[bool, bool, str]":
    """Grade a model response against the ground-truth answer.

    The answer is extracted with ``match_answer`` and first graded with
    ``grade_answer``; if that fails, ``math_equal`` is tried (with pi taken
    both as ``math.pi`` and as 3.14 when either side mentions pi).

    Returns:
        ``(is_correct, format_correctness, extracted_model_output)``.
        ``format_correctness`` is reported as True whenever ``grade_answer``
        succeeds; otherwise it tells whether the raw output contains both
        "Step 2:" and a box command.
    """
    model_output = str(model_output)
    ground_truth = str(ground_truth)

    _, extracted_model_output = match_answer(model_output)
    format_correctness = "Step 2:" in model_output and "\\box" in model_output

    # grade simple algebra questions. if succeed, return; otherwise, proceed to more complex grading
    if grade_answer(extracted_model_output, ground_truth):
        return True, True, extracted_model_output

    try:
        if "\\pi" in extracted_model_output or "\\pi" in ground_truth:
            is_correct = any(
                [
                    math_equal(extracted_model_output, ground_truth, timeout=True, pi=pi)
                    for pi in [math.pi, 3.14]
                ]
            )
        else:
            is_correct = math_equal(extracted_model_output, ground_truth, timeout=True)
    except Exception:
        # Any grading failure counts as an incorrect answer.
        is_correct = False

    return is_correct, format_correctness, extracted_model_output
|
| 409 |
+
|
| 410 |
+
|
dataset_construction/evaluation_utils/math_util/grader.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# Copyright (c) Microsoft Corporation.
|
| 16 |
+
#
|
| 17 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 18 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 19 |
+
# in the Software without restriction, including without limitation the rights
|
| 20 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 21 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 22 |
+
# furnished to do so, subject to the following conditions:
|
| 23 |
+
#
|
| 24 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 25 |
+
# copies or substantial portions of the Software.
|
| 26 |
+
#
|
| 27 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 28 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 29 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 30 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 31 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 32 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 33 |
+
# SOFTWARE
|
| 34 |
+
|
| 35 |
+
# Copyright (c) 2023 OpenAI
|
| 36 |
+
#
|
| 37 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 38 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 39 |
+
# in the Software without restriction, including without limitation the rights
|
| 40 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 41 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 42 |
+
# furnished to do so, subject to the following conditions:
|
| 43 |
+
|
| 44 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 45 |
+
# copies or substantial portions of the Software.
|
| 46 |
+
#
|
| 47 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 48 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 49 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 50 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 51 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 52 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 53 |
+
# SOFTWARE.
|
| 54 |
+
|
| 55 |
+
# Copyright (c) 2021 Dan Hendrycks
|
| 56 |
+
#
|
| 57 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 58 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 59 |
+
# in the Software without restriction, including without limitation the rights
|
| 60 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 61 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 62 |
+
# furnished to do so, subject to the following conditions:
|
| 63 |
+
#
|
| 64 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 65 |
+
# copies or substantial portions of the Software.
|
| 66 |
+
#
|
| 67 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 68 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 69 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 70 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 71 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 72 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 73 |
+
# SOFTWARE.
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
"""
|
| 77 |
+
This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
|
| 78 |
+
- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
|
| 79 |
+
- https://github.com/microsoft/ProphetNet/tree/master/CRITIC
|
| 80 |
+
- https://github.com/openai/prm800k
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
import contextlib
|
| 85 |
+
import re
|
| 86 |
+
import signal
|
| 87 |
+
import math
|
| 88 |
+
from math import isclose
|
| 89 |
+
from typing import Union
|
| 90 |
+
|
| 91 |
+
from sympy import N, simplify
|
| 92 |
+
from sympy.parsing.latex import parse_latex
|
| 93 |
+
from sympy.parsing.sympy_parser import parse_expr
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def is_digit(s):
    """Try to read ``s`` as a float, ignoring thousands separators.

    Returns:
        ``(True, value)`` when ``str(s)`` parses as a float once the
        separators are removed, otherwise ``(False, None)``.
    """
    try:
        text = str(s)
        # A LaTeX-style "{,}" separator takes precedence over plain commas.
        separator = "{,}" if "{,}" in text else ","
        return True, float(text.replace(separator, ""))
    except ValueError:
        return False, None
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def normalize(answer, pi) -> str:
    """Normalize an answer before comparison.

    String answers that start with a dollar amount lose the leading dollar
    sign, and string answers that are a bare percentage lose the percent sign.
    Everything else is passed through ``handle_base`` and then ``handle_pi``,
    so the result is not necessarily a string.

    Args:
        answer: The answer to normalize (any type).
        pi: Numeric value substituted for the LaTeX pi token.
    """
    if isinstance(answer, str):
        # "$<number>..." -> drop the leading "$" so the number can be compared.
        if re.match(r'\$\d+(\.\d+)?', answer):
            return answer[1:]

        # "<number>%" or "<number>\%" -> drop the percent sign.
        is_plain_percent = re.match(r'^\d+(\.\d+)?%$', answer)
        is_latex_percent = re.match(r'^\d+(\.\d+)?\\%$', answer)
        if is_plain_percent or is_latex_percent:
            return answer.replace("\\%", "").replace("%", "")

    # Handle base suffixes first, then substitute pi.
    return handle_pi(handle_base(answer), pi)
|
| 126 |
+
|
| 127 |
+
def handle_base(x):
    """Drop a trailing base suffix such as the "_2" in "101_2".

    For a string containing an underscore, the text before the first
    underscore is parsed as a number and returned as an ``int``; the base
    itself is ignored. If that text is not numeric (for example a LaTeX
    subscript such as "x_1"), the input is returned unchanged instead of
    raising, so later grading stages can still compare it.

    Args:
        x: The answer; non-strings and strings without "_" pass through.

    Returns:
        An ``int`` when a numeric prefix was found, otherwise ``x`` itself.
    """
    if isinstance(x, str) and "_" in x:
        # Due to base
        digits = x.split("_")[0]
        try:
            return int(float(digits))
        except (ValueError, OverflowError):
            # Not a number (or inf/nan): this underscore is not a base marker.
            return x
    return x
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def handle_pi(string, pi):
    r"""Substitute a numeric value for every LaTeX ``\pi`` and try to evaluate.

    Each ``\pi`` preceded by a digit becomes ``*<pi>`` and every other one
    becomes ``1*<pi>``. The resulting text is then evaluated as a Python
    expression; if evaluation fails, the substituted text is returned.

    Args:
        string: The answer; non-strings and strings without ``\pi`` pass
            through untouched.
        pi: Numeric value to substitute (e.g. ``math.pi`` or ``3.14``).

    Returns:
        The evaluated value when the substituted text is a valid expression,
        otherwise the substituted string (or the unchanged input).
    """
    token = "\\pi"

    if isinstance(string, str) and token in string:
        idx = string.find(token)

        while idx != -1:
            # "2\pi" -> "2*<pi>"; a "\pi" not preceded by a digit -> "1*<pi>".
            replacement = f"*{pi}" if idx > 0 and string[idx - 1].isdigit() else f"1*{pi}"
            string = string[:idx] + replacement + string[idx + len(token):]

            # Find the next occurrence of the token.
            idx = string.find(token, idx + 1)

        # SECURITY NOTE(review): eval() runs on text that may originate from
        # model output. It is kept to preserve behavior, but should be replaced
        # by a restricted arithmetic evaluator if inputs are untrusted.
        try:
            string = eval(string)
        except Exception:
            # Best effort: leave the substituted text as a string when it is
            # not a valid Python expression (e.g. it still contains LaTeX).
            pass

    return string
|
| 162 |
+
|
| 163 |
+
def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    tolerance: float = 1e-4,
    timeout: float = 10.0,
    pi: float = math.pi
) -> bool:
    """
    Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal

    Before that, both sides are normalized, compared as plain strings, and
    compared element-wise when they look like intervals, tuples, comma lists,
    points or matrices.

    Args:
        prediction: The answer to grade.
        reference: The ground-truth answer.
        include_percentage: When True, a numeric prediction also matches the
            reference divided or multiplied by 100.
        tolerance: Relative tolerance for numeric comparison.
        timeout: Time budget (seconds) for each symbolic step; it is also
            forwarded to the element-wise recursive comparisons.
        pi: Numeric value substituted for the LaTeX pi token.

    Returns:
        True when the two answers are considered equal, otherwise False.
    """
    prediction = normalize(prediction, pi)
    reference = normalize(reference, pi)

    if isinstance(prediction, str) and len(prediction) > 1000:  # handling weird corner-cases
        prediction = prediction[:1000]

    # 0. string comparison
    if isinstance(prediction, str) and isinstance(reference, str):
        if prediction.strip().lower() == reference.strip().lower():
            return True
        if prediction.replace(" ", "") == reference.replace(" ", ""):
            return True

    try:  # 1. numerical equal
        pred_is_number, pred_number = is_digit(prediction)
        ref_is_number, ref_number = is_digit(reference)
        if pred_is_number and ref_is_number:
            # number questions
            if include_percentage:
                gt_result = [ref_number / 100, ref_number, ref_number * 100]
            else:
                gt_result = [ref_number]
            for item in gt_result:
                try:
                    if isclose(item, pred_number, rel_tol=tolerance):
                        return True
                except Exception:
                    continue
            return False
    except Exception:
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    ## deal with [], (), {}
    prediction = format_intervals(prediction)

    pred_str, ref_str = prediction, reference
    if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or (
        prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")
    ):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ["{", "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str == ref_str:
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    if (
        prediction
        and reference
        and prediction[0] in "(["
        and prediction[-1] in ")]"
        and prediction[0] == reference[0]
        and prediction[-1] == reference[-1]
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            if all(
                [
                    math_equal(pred_pt, ref_pt, include_percentage, tolerance, timeout, pi)
                    for pred_pt, ref_pt in zip(pred_parts, ref_parts)
                ]
            ):
                return True

    if "," in prediction and "," in reference:
        pred_parts = [item.strip() for item in prediction.split(",")]
        ref_parts = [item.strip() for item in reference.split(",")]

        if len(pred_parts) == len(ref_parts):
            # Same number of comma-separated items: the element-wise verdict
            # is final. Different lengths fall through to the later checks.
            if all(
                [
                    math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance, timeout, pi)
                    for i in range(len(pred_parts))
                ]
            ):
                return True
            else:
                return False

    # if we have point == tuple of values
    # (startswith/endswith instead of indexing, so an empty reference is safe)
    if prediction.startswith("Point") and reference.startswith("(") and reference.endswith(")"):
        pred_parts = prediction[prediction.find("(") + 1 : -1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            if all(
                [
                    math_equal(pred_pt, ref_pt, include_percentage, tolerance, timeout, pi)
                    for pred_pt, ref_pt in zip(pred_parts, ref_parts)
                ]
            ):
                return True

    # if reference is a matrix
    # The tags are spelled with an escaped backslash: a bare "\b" inside a
    # normal string literal is a backspace character and never matches LaTeX.
    begin_tag = "\\begin{pmatrix}"
    end_tag = "\\end{pmatrix}"
    if begin_tag in reference and prediction.startswith("Matrix"):
        try:
            pred_matrix = parse_expr(prediction)
            # Whitespace-separated reference: every other token between the
            # two environment tags is taken as an entry.
            ref_matrix_items = reference.split()[1:-1:2]
            if len(pred_matrix) == len(ref_matrix_items):
                if all(
                    [
                        math_equal(pred, ref, include_percentage, tolerance, timeout, pi)
                        for ref, pred in zip(ref_matrix_items, pred_matrix)
                    ]
                ):
                    return True
        except Exception:
            pass
    elif begin_tag in reference and prediction.startswith("[") and prediction.endswith("]"):
        try:
            # SECURITY NOTE(review): eval() runs on text that may originate
            # from model output. Kept to preserve behavior; it now sits inside
            # the try so a malformed prediction cannot abort grading.
            pred_matrix = eval(prediction)
            if isinstance(pred_matrix, list):
                # Cut the text between the environment tags. str.lstrip/rstrip
                # strip character sets, not prefixes, and could eat entries.
                start = reference.find(begin_tag) + len(begin_tag)
                stop = reference.rfind(end_tag)
                matrix_body = reference[start:stop] if stop >= start else reference[start:]
                # Rows are separated by backslashes; blank pieces (from a
                # double-backslash separator) are dropped.
                # NOTE(review): entries that themselves contain LaTeX commands
                # are split as well - unchanged from the original logic.
                rows = [row for row in matrix_body.split("\\") if row.strip()]
                ref_matrix_items = [row.split("&") if "&" in row else row for row in rows]
                if len(pred_matrix) == len(ref_matrix_items):
                    if all(
                        [
                            math_equal(pred, ref, include_percentage, tolerance, timeout, pi)
                            for ref, pred in zip(ref_matrix_items, pred_matrix)
                        ]
                    ):
                        return True
        except Exception:
            pass

    return symbolic_equal(prediction, reference, tolerance, timeout)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def symbolic_equal(a, b, tolerance, timeout=10.0):
    """Compare two answers symbolically, then numerically.

    Each input is parsed with ``parse_expr`` and, if that fails, with
    ``parse_latex``; an input that neither parser accepts is kept as is.
    The answers are equal when their difference simplifies to zero or when
    their numeric values are close within ``tolerance``. Every parsing and
    comparison step runs under ``time_limit(timeout)`` and any exception in a
    step counts as "no match" for that step.
    """

    def _to_expr(text):
        # First parser that succeeds wins; fall back to the raw input.
        for parser in (parse_expr, parse_latex):
            try:
                with time_limit(timeout):
                    return parser(text)
            except Exception:
                continue
        return text

    def _holds(check):
        # Run one comparison under the time limit, treating errors as False.
        try:
            with time_limit(timeout):
                return bool(check())
        except Exception:
            return False

    lhs = _to_expr(a)
    rhs = _to_expr(b)

    if _holds(lambda: simplify(lhs - rhs) == 0):
        return True
    return _holds(lambda: isclose(N(lhs), N(rhs), rel_tol=tolerance))
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
class TimeoutException(Exception):
    """Raised by the ``time_limit`` signal handler when the timer expires."""
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
@contextlib.contextmanager
def time_limit(seconds: float):
    """Context manager that raises ``TimeoutException`` after ``seconds``.

    A SIGALRM handler is installed and a real-time interval timer is armed
    for the duration of the ``with`` body. On exit the timer is cancelled and
    the previously installed SIGALRM handler is restored.

    Limitations: relies on SIGALRM / ``setitimer`` (Unix only) and on signal
    handlers, which Python only allows to be set from the main thread. There
    is a single real-time timer per process, so nesting cancels the outer
    timer when the inner block exits.

    Args:
        seconds: Time budget for the ``with`` body.

    Raises:
        TimeoutException: Inside the ``with`` body when the budget expires.
    """

    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    # Install the handler BEFORE arming the timer, so an alarm that fires
    # immediately cannot hit the default action (process termination).
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)
        # getsignal/signal report None for handlers not installed from
        # Python; those cannot be re-installed, so leave ours in place.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def format_intervals(prediction):
    """Rewrite a sympy-style interval string in bracket notation.

    Interval(a, b)       -> [a, b]
    Interval.Ropen(a, b) -> [a, b)
    Interval.Lopen(a, b) -> (a, b]
    Interval.open(a, b)  -> (a, b)

    Any other input is returned unchanged.
    """
    # (pattern, opening bracket, closing bracket), tried in this order.
    shapes = (
        (r"^Interval\((.*)\)$", "[", "]"),
        (r"^Interval\.Ropen\((.*)\)$", "[", ")"),
        (r"^Interval\.Lopen\((.*)\)$", "(", "]"),
        (r"^Interval\.open\((.*)\)$", "(", ")"),
    )

    for pattern, opener, closer in shapes:
        found = re.match(pattern, prediction)
        if found:
            return f"{opener}{found.group(1)}{closer}"

    return prediction
|
dataset_construction/evaluation_utils/math_util/math_normalize.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This logic is largely copied from the Hendrycks' MATH release (math_equivalence).
|
| 3 |
+
|
| 4 |
+
From: https://github.com/openai/prm800k/blob/main/prm800k/grading/math_normalize.py
|
| 5 |
+
"""
|
| 6 |
+
import re
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def normalize_answer(answer: Optional[str]) -> Optional[str]:
    """Normalize an answer string for comparison.

    Surrounding whitespace is removed, an enclosing LaTeX text wrapper is
    unwrapped, and the result is passed through ``_strip_string``. If
    normalization fails, the answer as it stood at the point of failure is
    returned instead of raising.

    Args:
        answer: The answer to normalize, or None.

    Returns:
        The normalized answer, or None when ``answer`` is None.
    """
    if answer is None:
        return None
    answer = answer.strip()
    try:
        # Remove enclosing `\text{}`.
        m = re.search(r"^\\text\{(?P<text>.+?)\}$", answer)
        if m is not None:
            answer = m.group("text").strip()
        return _strip_string(answer)
    except Exception:
        # Best effort: fall back to the (partially processed) answer.
        return answer
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _fix_fracs(string):
|
| 25 |
+
substrs = string.split("\\frac")
|
| 26 |
+
new_str = substrs[0]
|
| 27 |
+
if len(substrs) > 1:
|
| 28 |
+
substrs = substrs[1:]
|
| 29 |
+
for substr in substrs:
|
| 30 |
+
new_str += "\\frac"
|
| 31 |
+
if substr[0] == "{":
|
| 32 |
+
new_str += substr
|
| 33 |
+
else:
|
| 34 |
+
try:
|
| 35 |
+
assert len(substr) >= 2
|
| 36 |
+
except:
|
| 37 |
+
return string
|
| 38 |
+
a = substr[0]
|
| 39 |
+
b = substr[1]
|
| 40 |
+
if b != "{":
|
| 41 |
+
if len(substr) > 2:
|
| 42 |
+
post_substr = substr[2:]
|
| 43 |
+
new_str += "{" + a + "}{" + b + "}" + post_substr
|
| 44 |
+
else:
|
| 45 |
+
new_str += "{" + a + "}{" + b + "}"
|
| 46 |
+
else:
|
| 47 |
+
if len(substr) > 2:
|
| 48 |
+
post_substr = substr[2:]
|
| 49 |
+
new_str += "{" + a + "}" + b + post_substr
|
| 50 |
+
else:
|
| 51 |
+
new_str += "{" + a + "}" + b
|
| 52 |
+
string = new_str
|
| 53 |
+
return string
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _fix_a_slash_b(string):
|
| 57 |
+
if len(string.split("/")) != 2:
|
| 58 |
+
return string
|
| 59 |
+
a = string.split("/")[0]
|
| 60 |
+
b = string.split("/")[1]
|
| 61 |
+
try:
|
| 62 |
+
a = int(a)
|
| 63 |
+
b = int(b)
|
| 64 |
+
assert string == "{}/{}".format(a, b)
|
| 65 |
+
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
| 66 |
+
return new_string
|
| 67 |
+
except:
|
| 68 |
+
return string
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _remove_right_units(string):
|
| 72 |
+
# "\\text{ " only ever occurs (at least in the val set) when describing units
|
| 73 |
+
if "\\text{ " in string:
|
| 74 |
+
splits = string.split("\\text{ ")
|
| 75 |
+
assert len(splits) == 2
|
| 76 |
+
return splits[0]
|
| 77 |
+
else:
|
| 78 |
+
return string
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _fix_sqrt(string):
|
| 82 |
+
if "\\sqrt" not in string:
|
| 83 |
+
return string
|
| 84 |
+
splits = string.split("\\sqrt")
|
| 85 |
+
new_string = splits[0]
|
| 86 |
+
for split in splits[1:]:
|
| 87 |
+
if split[0] != "{":
|
| 88 |
+
a = split[0]
|
| 89 |
+
new_substr = "\\sqrt{" + a + "}" + split[1:]
|
| 90 |
+
else:
|
| 91 |
+
new_substr = "\\sqrt" + split
|
| 92 |
+
new_string += new_substr
|
| 93 |
+
return new_string
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _strip_string(string):
    """Canonicalize a LaTeX answer string so equivalent answers compare equal.

    Applies, in order: removal of line breaks, negative spaces, doubled
    backslashes, tfrac/dfrac variants, left/right sizing commands, degree
    marks, escaped dollar signs, right-hand units and escaped percent signs;
    leading-zero fixes for decimals; removal of a short "k = " prefix;
    square-root and fraction brace fixes; space removal; and conversion of
    simple a/b ratios to fractions.
    """
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage (the escaped form only)
    string = string.replace("\\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string
|
dataset_construction/run_evaluate_on_slurm.sh
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#SBATCH --job-name=evaluate_on_eurus_2
#SBATCH --time=4-00:00:00
#SBATCH --gpus-per-node=2
#SBATCH --mem=128G
#SBATCH --account=lusu-k

# Batch shells are non-interactive, so the `conda` shell function set up by
# `conda init` is normally not defined here. Load conda's shell hook first so
# that `conda activate` works.
eval "$(conda shell.bash hook)"
conda activate domain_rlhf

# Tensor parallel size matches the two GPUs requested above.
python3 evaluate_on_eurus_2.py --tensor_parallel_size 2
|