| | |
| | import re |
| |
|
| | |
| | import signal |
| | import sys |
| |
|
| | |
| | from io import StringIO |
| | from typing import List, Tuple |
| |
|
| | |
| | from unittest.mock import patch, mock_open |
| |
|
| | import numpy as np |
| | from pyext import RuntimeModule |
| | from wrapt_timeout_decorator import timeout as wrapt_timeout |
| | import threading |
| |
|
| | from ..datasets.schema import assert_test_format_codeforces |
| |
|
| | from flows import logging |
| |
|
| | log = logging.get_logger(__name__) |
| | lock = threading.Lock() |
| |
|
| |
|
| | def evaluate_solution_for_problem( |
| | candidate_solution, |
| | hidden_tests_io=None, |
| | public_tests_io=None, |
| | timeout=10, |
| | debug=False, |
| | add_extra_imports=False, |
| | allow_truncated_io=False, |
| | ): |
| | with lock: |
| | """See the readme for the output format of this function.""" |
| | if hidden_tests_io is None: |
| | hidden_tests_io = [] |
| | if public_tests_io is None: |
| | public_tests_io = [] |
| |
|
| | if candidate_solution is None: |
| | results_dict = { |
| | "compilation_status": False, |
| | "compilation_error_message": "No code was provided.", |
| | "timeout_error": False, |
| | "hidden_tests_results": [ |
| | { |
| | "status": False, |
| | "error_message": "No code was provided.", |
| | "generated_output": None, |
| | "input": test[0], |
| | "expected_output": test[1], |
| | } |
| | for test in hidden_tests_io |
| | ], |
| | "public_tests_results": [ |
| | { |
| | "status": False, |
| | "error_message": "No code was provided.", |
| | "generated_output": None, |
| | "input": test[0], |
| | "expected_output": test[1], |
| | } |
| | for test in public_tests_io |
| | ], |
| | } |
| | return results_dict |
| |
|
| | @wrapt_timeout(timeout, use_signals=False) |
| | def run_tests(): |
| | hidden_tests_results = check_correctness( |
| | candidate_solution, hidden_tests_io, timeout, debug, add_extra_imports, allow_truncated_io |
| | ) |
| | public_tests_results = check_correctness( |
| | candidate_solution, public_tests_io, timeout, debug, add_extra_imports, allow_truncated_io |
| | ) |
| |
|
| | return hidden_tests_results, public_tests_results |
| |
|
| | try: |
| | hidden_tests_results, public_tests_results = run_tests() |
| | timeout_error_occurred = False |
| | except BaseException as e: |
| | log.info(e) |
| | hidden_tests_results = {} |
| | public_tests_results = {} |
| |
|
| | hidden_tests_results["compilation_status"] = True |
| | public_tests_results["compilation_status"] = True |
| | timeout_error_occurred = True |
| | hidden_tests_results["error_message"] = "Timeout error." |
| |
|
| | hidden_tests_results["results"] = [ |
| | { |
| | "status": False, |
| | "error_message": hidden_tests_results["error_message"], |
| | "generated_output": None, |
| | "input": test[0], |
| | "expected_output": test[1], |
| | } |
| | for test in hidden_tests_io |
| | ] |
| | public_tests_results["results"] = [ |
| | { |
| | "status": False, |
| | "error_message": hidden_tests_results["error_message"], |
| | "generated_output": None, |
| | "input": test[0], |
| | "expected_output": test[1], |
| | } |
| | for test in public_tests_io |
| | ] |
| |
|
| | |
| | assert hidden_tests_results["compilation_status"] == public_tests_results["compilation_status"] |
| |
|
| | results_dict = { |
| | "compilation_status": hidden_tests_results["compilation_status"], |
| | "compilation_error_message": hidden_tests_results["error_message"], |
| | "timeout_error": timeout_error_occurred, |
| | "hidden_tests_results": hidden_tests_results["results"], |
| | "public_tests_results": public_tests_results["results"], |
| | } |
| |
|
| | return results_dict |
| |
|
| |
|
| | def check_correctness( |
| | candidate_solution: str, |
| | tests: List[Tuple[List[str], str]], |
| | timeout: int = 6000, |
| | debug=True, |
| | add_extra_imports=False, |
| | allow_truncated_io=True, |
| | ): |
| | """ |
| | wrapping the testing code in a global timeout, based on huggingface code |
| | """ |
| |
|
| | assert_test_format_codeforces(tests) |
| | inputs, outputs = [], [] |
| | if len(tests) > 0: |
| | inputs, outputs = zip(*tests) |
| |
|
| | compilation_error, results = run_test( |
| | candidate_solution, inputs, outputs, timeout, debug, add_extra_imports, allow_truncated_io |
| | ) |
| |
|
| | assert len(results) == len(inputs) |
| |
|
| | for result in results: |
| | assert isinstance(result["generated_output"], str) or result["generated_output"] is None |
| | assert isinstance(result["status"], bool) |
| | assert isinstance(result["error_message"], str) or result["error_message"] is None |
| | assert isinstance(result["input"], list) |
| | assert isinstance(result["expected_output"], str) |
| |
|
| | compilation_status = compilation_error == "" |
| | if compilation_status: |
| | compilation_error = None |
| |
|
| | return {"compilation_status": compilation_status, "error_message": compilation_error, "results": results} |
| |
|
| |
|
| | class TimeoutException(Exception): |
| | pass |
| |
|
| |
|
| | def timeout_handler(signum, frame): |
| | log.info("alarm went off") |
| | |
| | raise TimeoutException |
| |
|
| |
|
| | signal.signal(signal.SIGALRM, timeout_handler) |
| |
|
| |
|
| | |
| | |
| | |
| | class Capturing(list): |
| | def __enter__(self): |
| | self._stdout = sys.stdout |
| | sys.stdout = self._stringio = StringIO() |
| | |
| | self._stringio.close = lambda x: 1 |
| | return self |
| |
|
| | def __exit__(self, *args): |
| | self.extend(self._stringio.getvalue().splitlines()) |
| | del self._stringio |
| | sys.stdout = self._stdout |
| |
|
| |
|
| | def run_test(code, inputs, outputs, timeout: int = 6000, debug=True, add_extra_imports=False, allow_truncated_io=True): |
| | """ |
| | runs the code and tries to match inputs and outputs |
| | the scraped testcases may be incomplete |
| | if allow_truncated_io==True, then we ignore an EOF exception at the end of the generated output |
| | """ |
| | |
| |
|
| | results = [] |
| |
|
| | if isinstance(code, list): |
| | tmp_test = code |
| | elif isinstance(code, str): |
| | tmp_test = code.split("\n") |
| | else: |
| | raise AssertionError("code must be provided as list of lines or string with \\n linebreaks.") |
| |
|
| | |
| | import_lines = [] |
| | future_import_lines = [] |
| | code_lines = [] |
| | for x in tmp_test: |
| | if (not x.startswith("from ")) and (not x.startswith("import ")): |
| | code_lines.append("\t" + x + "\n") |
| | else: |
| | if "__future__" in x: |
| | future_import_lines.append(x + "\n") |
| | else: |
| | import_lines.append(x + "\n") |
| |
|
| | |
| | new_test = "stdin = sys.stdin\nstdout = sys.stdout\n" |
| | new_test += '__name__="__main__"\n' |
| | new_test += "def code():\n" |
| | for line in code_lines: |
| | new_test += line |
| |
|
| | sol = "\n".join(future_import_lines) |
| | sol += "import sys\n" |
| | if add_extra_imports: |
| | sol += "import time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n" |
| | sol += "\n".join(import_lines) + "\n" + new_test |
| |
|
| | if debug: |
| | log.info(f"sol = {sol}") |
| | method_name = "code" |
| | signal.alarm(timeout) |
| |
|
| | |
| | sol_module = None |
| | try: |
| | sol_module = RuntimeModule.from_string("tmp_sol", "", sol) |
| | signal.alarm(0) |
| | except Exception as e: |
| | signal.alarm(0) |
| | if debug: |
| | log.info(f"type 1 compilation error = {e}") |
| | for inp, out in zip(inputs, outputs): |
| | |
| | results.append( |
| | { |
| | "status": False, |
| | "input": inp, |
| | "expected_output": out, |
| | "generated_output": None, |
| | "error_message": repr(e), |
| | } |
| | ) |
| | return repr(e), results |
| |
|
| | assert sol_module is not None |
| | signal.alarm(0) |
| |
|
| | try: |
| | method = getattr(sol_module, method_name) |
| | except: |
| | signal.alarm(0) |
| | e = sys.exc_info() |
| | log.info(f"unable to get function error = {e}") |
| |
|
| | for inp, out in zip(inputs, outputs): |
| | |
| | results.append( |
| | { |
| | "status": False, |
| | "input": inp, |
| | "expected_output": out, |
| | "generated_output": None, |
| | "error_message": repr(e), |
| | } |
| | ) |
| | return repr(e), results |
| |
|
| | |
| | |
| | for index, (test_input, reference_output) in enumerate(zip(inputs, outputs)): |
| |
|
| | result_object = { |
| | "input": test_input, |
| | "expected_output": reference_output, |
| | } |
| |
|
| | |
| | input_truncated = False |
| | if "".join(test_input).strip().endswith("...") and allow_truncated_io: |
| | test_input = test_input[:-1] |
| | input_truncated = True |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | error_code = None |
| | with Capturing() as generated_output: |
| | try: |
| | call_method(method, test_input) |
| | |
| | signal.alarm(0) |
| | except Exception as e: |
| | |
| | signal.alarm(0) |
| | error_code = e |
| | if debug: |
| | log.info(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}") |
| | signal.alarm(0) |
| |
|
| | |
| | |
| | if ( |
| | (input_truncated or reference_output.strip().endswith("...")) |
| | and allow_truncated_io |
| | and (error_code is None or isinstance(error_code, EOFError) or isinstance(error_code, ValueError)) |
| | ): |
| |
|
| | generated_output = generated_output[:-1] |
| | reference_output = reference_output.rstrip("...") |
| | if len(generated_output) == 0: |
| | |
| | result_object.update( |
| | **{ |
| | "status": True, |
| | "generated_output": "\n".join(generated_output), |
| | "error_message": None, |
| | } |
| | ) |
| | results.append(result_object) |
| | else: |
| | result_object.update( |
| | **{ |
| | "status": string_compare(generated_output, reference_output, True), |
| | "generated_output": "\n".join(generated_output), |
| | "error_message": None, |
| | } |
| | ) |
| | results.append(result_object) |
| |
|
| | |
| | elif error_code is not None: |
| | result_object.update(**{"status": False, "generated_output": None, "error_message": repr(error_code)}) |
| | results.append(result_object) |
| | |
| | else: |
| | |
| | result_object.update( |
| | **{ |
| | "status": string_compare(generated_output, reference_output, False), |
| | "generated_output": "\n".join(generated_output), |
| | "error_message": None, |
| | } |
| | ) |
| | results.append(result_object) |
| |
|
| | return "", results |
| |
|
| |
|
| | def string_compare(candidate, correct, truncate_output=False, floating_point_accuracy=0.01): |
| | candidate = [o.strip().lower() for o in candidate] |
| | correct = correct.strip().lower() |
| |
|
| | |
| | candidate = "\n".join(candidate) |
| | candidate = re.sub("\s+", " ", candidate).strip() |
| | correct = re.sub("\s+", " ", correct).strip() |
| |
|
| | |
| | candidate = candidate.split(" ") |
| | correct = correct.split(" ") |
| |
|
| | |
| | if not truncate_output: |
| | if not len(candidate) == len(correct): |
| | return False |
| |
|
| | |
| | if truncate_output: |
| | correct = correct[:-1] |
| |
|
| | |
| | for left, right in zip(candidate, correct): |
| | if left == right: |
| | continue |
| |
|
| | try: |
| | int_left = int(left) |
| | int_right = int(right) |
| | if int_left == int_right: |
| | continue |
| | except ValueError: |
| | pass |
| |
|
| | try: |
| | float_left = float(left) |
| | float_right = float(right) |
| | if np.abs(float_left - float_right) < floating_point_accuracy: |
| | continue |
| | except ValueError: |
| | pass |
| |
|
| | return False |
| |
|
| | return True |
| |
|
| |
|
| | def call_method(method, inputs): |
| | if isinstance(inputs, list): |
| | inputs = "\n".join(inputs) |
| |
|
| | inputs_line_iterator = iter(inputs.split("\n")) |
| |
|
| | |
| |
|
| | |
| | @patch("builtins.open", mock_open(read_data=inputs)) |
| | @patch("sys.stdin", StringIO(inputs)) |
| | @patch("sys.stdin.readline", lambda *args: next(inputs_line_iterator)) |
| | @patch("sys.stdin.readlines", lambda *args: inputs.split("\n")) |
| | @patch("sys.stdin.read", lambda *args: inputs) |
| | |
| | def _inner_call_method(_method): |
| | try: |
| | return _method() |
| | except SystemExit as e: |
| | pass |
| | finally: |
| | pass |
| |
|
| | return _inner_call_method(method) |
| |
|