Spaces:
Sleeping
Sleeping
| """Grader: oracle implementations, language runners, and test harness. | |
| Oracle functions are the ground truth for every task. run_candidate_tests() | |
| dispatches to a language-specific runner (PythonRunner, JSRunner, RustRunner) | |
| that executes the agent's submission in a subprocess and compares its output | |
| against the oracle. | |
| Scoring helpers (score_progress, score_step_reward, clamp_open_unit) are used | |
| by EpisodeState to produce per-step rewards and the final episode score. | |
| """ | |
| from __future__ import annotations | |
| import glob | |
| import hashlib | |
| import json | |
| import os | |
| import shutil | |
| import subprocess | |
| import sys | |
| import tempfile | |
| import textwrap | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, Protocol | |
| from .tasks import FunctionSpec, SampleCase, Task | |
| from .verification_semantics import ( | |
| lean_call as _lean_call_impl, | |
| lean_value as _lean_value_impl, | |
| oracle_result as _oracle_result_impl, | |
| ) | |
# Sentinel each language harness prints immediately before its JSON results.
RUN_TESTS_MARKER = "__LEANMIGRATE_RUN_TESTS__"
# Lower and upper bounds applied by clamp_open_unit().
MIN_DISPLAY_SCORE, MAX_DISPLAY_SCORE = 0.01, 0.99
| def _json_equal(a: Any, b: Any) -> bool: | |
| """Compare two values treating tuples and lists as interchangeable. | |
| JSON deserialisation always produces lists, but Python oracles may return | |
| tuples. This normalises both sides so the comparison works regardless of | |
| which side uses tuples vs. lists. | |
| """ | |
| if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)): | |
| return len(a) == len(b) and all(_json_equal(x, y) for x, y in zip(a, b)) | |
| return a == b | |
@dataclass
class TestCaseResult:
    """Outcome of comparing one sample case against the oracle."""

    # The "Test" name prefix would otherwise make pytest try to collect this class.
    __test__ = False

    # True when the candidate's value matched the expected value.
    passed: bool
    # Value produced by the oracle for this case.
    expected: Any
    # Value produced by the candidate, or None when the call raised.
    actual: Any | None
    # Error text reported by the harness when the call failed.
    error: str | None = None
@dataclass
class TestRunResult:
    """Aggregate outcome of one run_tests invocation for a single function."""

    # The "Test" name prefix would otherwise make pytest try to collect this class.
    __test__ = False

    # True when the run counts as passing overall.
    passed: bool
    # Number of sample cases whose output matched the oracle.
    tests_passed: int
    # Number of sample cases that were attempted.
    tests_total: int
    # Human-readable summary returned to the caller.
    feedback: str
    # Captured output streams of the candidate process.
    stdout: str = ""
    stderr: str = ""
    # True when the build or the execution exceeded its time limit.
    timed_out: bool = False
    # Per-case details; None when not provided.
    case_results: list[TestCaseResult] | None = None
def score_progress(verified_count: int, total_count: int) -> float:
    """Return the verified fraction, or 0.0 when ``total_count`` is not positive."""
    return verified_count / total_count if total_count > 0 else 0.0
def score_step_reward(
    success: bool, total_count: int, failure_penalty: float = -0.05
) -> float:
    """Return the reward for a single step.

    A successful step with a positive ``total_count`` earns
    ``clamp_open_unit(1.0 / total_count)``; every other step earns
    ``failure_penalty``.
    """
    earned = success and total_count > 0
    if not earned:
        return failure_penalty
    return clamp_open_unit(1.0 / total_count)
def clamp_open_unit(value: float) -> float:
    """Clamp ``value`` into ``[MIN_DISPLAY_SCORE, MAX_DISPLAY_SCORE]``."""
    # The two bounds are disjoint, so the order of the checks does not matter.
    if value >= MAX_DISPLAY_SCORE:
        return MAX_DISPLAY_SCORE
    if value <= MIN_DISPLAY_SCORE:
        return MIN_DISPLAY_SCORE
    return value
def build_breakdown(
    functional: float, property_score: float | None = None, proof: float | None = None
) -> dict[str, float]:
    """Build a score breakdown with every component passed through clamp_open_unit.

    ``"functional"`` is always present; ``"property"`` and ``"proof"`` are
    added only when the corresponding argument is not None.
    """
    breakdown = {"functional": clamp_open_unit(functional)}
    optional_components = (("property", property_score), ("proof", proof))
    for label, raw_score in optional_components:
        if raw_score is not None:
            breakdown[label] = clamp_open_unit(raw_score)
    return breakdown
| def _oracle_rbac_find_role( | |
| roles: list[dict[str, Any]], name: str | |
| ) -> dict[str, Any] | None: | |
| return next((role for role in roles if role["name"] == name), None) | |
| def _oracle_rbac_has_direct_permission( | |
| role: dict[str, Any], resource: str, action: str | |
| ) -> bool: | |
| return any( | |
| permission["resource"] == resource and permission["action"] == action | |
| for permission in role.get("permissions", []) | |
| ) | |
def _oracle_rbac_can_access(
    roles: list[dict[str, Any]],
    role_name: str,
    resource: str,
    action: str,
    depth: int = 5,
) -> bool:
    """Return True when ``role_name`` grants the permission directly or by inheritance.

    The search follows each role's ``"inherits"`` list recursively and spends
    one unit of ``depth`` per level; it gives up once ``depth`` reaches 0.
    Unknown role names grant nothing.
    """
    if depth == 0:
        return False
    role = _oracle_rbac_find_role(roles, role_name)
    if role is None:
        return False
    if _oracle_rbac_has_direct_permission(role, resource, action):
        return True
    for parent_name in role.get("inherits", []):
        if _oracle_rbac_can_access(roles, parent_name, resource, action, depth - 1):
            return True
    return False
| def _oracle_pricing_tax_rate_bps(region_id: str) -> int: | |
| return { | |
| "US-CA": 875, | |
| "US-TX": 625, | |
| "US-NY": 800, | |
| "UK": 2000, | |
| }.get(region_id, 0) | |
| def _oracle_pricing_subtotal(order: dict[str, Any]) -> int: | |
| return sum( | |
| int(item["unitPrice"]) * int(item["quantity"]) | |
| for item in order.get("items", []) | |
| ) | |
def _oracle_pricing_coupon_discount(order: dict[str, Any]) -> int:
    """Return the combined coupon discount, capped at half the subtotal.

    Each coupon contributes ``subtotal * discountPercent // 100``.
    """
    subtotal = _oracle_pricing_subtotal(order)
    uncapped = 0
    for coupon in order.get("coupons", []):
        uncapped += (subtotal * int(coupon["discountPercent"])) // 100
    cap = subtotal // 2
    return min(uncapped, cap)
def _oracle_pricing_loyalty_discount(order: dict[str, Any]) -> int:
    """Return the loyalty discount: the order's points, capped at a tenth of the subtotal."""
    subtotal = _oracle_pricing_subtotal(order)
    points = int(order.get("loyaltyPoints", 0))
    cap = subtotal // 10
    return min(points, cap)
def _oracle_pricing_total_discount(order: dict[str, Any]) -> int:
    """Return the coupon discount plus the loyalty discount for ``order``."""
    coupon_part = _oracle_pricing_coupon_discount(order)
    loyalty_part = _oracle_pricing_loyalty_discount(order)
    return coupon_part + loyalty_part
def _oracle_pricing_tax(order: dict[str, Any]) -> int:
    """Return the tax owed on the discounted subtotal.

    The rate is looked up in basis points from the order's ``"regionId"`` and
    the result is floor-divided by 10000.
    """
    subtotal = _oracle_pricing_subtotal(order)
    total_discount = _oracle_pricing_total_discount(order)
    taxable_amount = subtotal - total_discount
    rate_bps = _oracle_pricing_tax_rate_bps(order.get("regionId", ""))
    return (taxable_amount * rate_bps) // 10000
def _oracle_pricing_final_price(order: dict[str, Any]) -> int:
    """Return subtotal minus all discounts plus tax."""
    subtotal = _oracle_pricing_subtotal(order)
    total_discount = _oracle_pricing_total_discount(order)
    discounted = subtotal - total_discount
    return discounted + _oracle_pricing_tax(order)
| def _oracle_saga_transition(state: str, event: str) -> str: | |
| if event == "Fail": | |
| return "Failed" | |
| transitions = { | |
| ("Idle", "Reserve"): "Reserved", | |
| ("Reserved", "Authorize"): "Authorized", | |
| ("Authorized", "Capture"): "Captured", | |
| ("Captured", "Settle"): "Settled", | |
| ("Reserved", "CompensateReserve"): "Compensated", | |
| ("Authorized", "CompensateAuthorize"): "Compensating", | |
| ("Compensating", "CompensateReserve"): "Compensated", | |
| ("Captured", "CompensateCapture"): "Compensating", | |
| } | |
| return transitions.get((state, event), state) | |
def _oracle_saga_run(events: list[str]) -> str:
    """Replay ``events`` in order from ``"Idle"`` and return the final state."""
    current = "Idle"
    for step in events:
        current = _oracle_saga_transition(current, step)
    return current
| def _oracle_saga_is_charged(state: str) -> bool: | |
| return state in {"Captured", "Settled"} | |
def oracle_result(task_id: str, function_name: str, args: tuple[Any, ...]) -> Any:
    """Return the expected value for one call by delegating to ``verification_semantics``."""
    expected = _oracle_result_impl(task_id, function_name, args)
    return expected
| def _lean_string(value: str) -> str: | |
| return json.dumps(value) | |
| def _lean_bool(value: bool) -> str: | |
| return "true" if value else "false" | |
| def _lean_int(value: int) -> str: | |
| return str(int(value)) | |
| def _lean_list(items: list[str]) -> str: | |
| return "[" + ", ".join(items) + "]" | |
def _lean_permission(permission: dict[str, Any]) -> str:
    """Render a permission dict as a ``{ resource := ..., action := ... }`` literal."""
    resource = _lean_string(str(permission["resource"]))
    action = _lean_string(str(permission["action"]))
    return f"{{ resource := {resource}, action := {action} }}"
def _lean_role(role: dict[str, Any]) -> str:
    """Render a role dict as a ``{ name, permissions, inherits }`` structure literal.

    Missing ``"permissions"`` or ``"inherits"`` keys render as empty lists.
    """
    rendered_permissions = [
        _lean_permission(entry) for entry in role.get("permissions", [])
    ]
    permissions = _lean_list(rendered_permissions)
    rendered_parents = [_lean_string(parent) for parent in role.get("inherits", [])]
    inherits = _lean_list(rendered_parents)
    name = _lean_string(str(role["name"]))
    return (
        f"{{ name := {name}, permissions := {permissions}, inherits := {inherits} }}"
    )
def _lean_role_list(roles: list[dict[str, Any]]) -> str:
    """Render a list of role dicts as a list literal of role structures."""
    rendered_roles = [_lean_role(entry) for entry in roles]
    return _lean_list(rendered_roles)
def _lean_option_role(role: dict[str, Any] | None) -> str:
    """Render an optional role: ``none``, or ``some (<role> : AuthSpec.Role)``."""
    if role is not None:
        return f"some ({_lean_role(role)} : AuthSpec.Role)"
    return "none"
def _lean_item(item: dict[str, Any]) -> str:
    """Render an order item as a ``{ sku, quantity, unitPrice }`` structure literal."""
    sku = _lean_string(str(item["sku"]))
    quantity = _lean_int(int(item["quantity"]))
    unit_price = _lean_int(int(item["unitPrice"]))
    return f"{{ sku := {sku}, quantity := {quantity}, unitPrice := {unit_price} }}"
def _lean_coupon(coupon: dict[str, Any]) -> str:
    """Render a coupon as a ``{ code, discountPercent }`` structure literal."""
    code = _lean_string(str(coupon["code"]))
    percent = _lean_int(int(coupon["discountPercent"]))
    return f"{{ code := {code}, discountPercent := {percent} }}"
def _lean_order(order: dict[str, Any]) -> str:
    """Render an order as an ``{ items, coupons, regionId, loyaltyPoints }`` literal.

    Missing keys fall back to an empty list, an empty region id, or zero points.
    """
    items = _lean_list([_lean_item(entry) for entry in order.get("items", [])])
    coupons = _lean_list([_lean_coupon(entry) for entry in order.get("coupons", [])])
    region_id = _lean_string(str(order.get("regionId", "")))
    loyalty_points = _lean_int(int(order.get("loyaltyPoints", 0)))
    fields = (
        f"items := {items}",
        f"coupons := {coupons}",
        f"regionId := {region_id}",
        f"loyaltyPoints := {loyalty_points}",
    )
    return "{ " + ", ".join(fields) + " }"
| def _lean_saga_state(state: str) -> str: | |
| mapping = { | |
| "Idle": ".Idle", | |
| "Reserved": ".Reserved", | |
| "Authorized": ".Authorized", | |
| "Captured": ".Captured", | |
| "Settled": ".Settled", | |
| "Compensating": ".Compensating", | |
| "Compensated": ".Compensated", | |
| "Failed": ".Failed", | |
| } | |
| return mapping[state] | |
| def _lean_saga_event(event: str) -> str: | |
| mapping = { | |
| "Reserve": ".Reserve", | |
| "Authorize": ".Authorize", | |
| "Capture": ".Capture", | |
| "Settle": ".Settle", | |
| "CompensateReserve": ".CompensateReserve", | |
| "CompensateAuthorize": ".CompensateAuthorize", | |
| "CompensateCapture": ".CompensateCapture", | |
| "Fail": ".Fail", | |
| } | |
| return mapping[event] | |
def _lean_value(task_id: str, function_name: str, value: Any) -> str:
    """Render ``value`` as a Lean expression by delegating to ``verification_semantics``."""
    rendered = _lean_value_impl(task_id, function_name, value)
    return rendered
def _lean_call(task_id: str, function_name: str, args: tuple[Any, ...]) -> str:
    """Render a call expression by delegating to ``verification_semantics``."""
    rendered = _lean_call_impl(task_id, function_name, args)
    return rendered
| def _with_call_namespace(call_expr: str, call_namespace: str) -> str: | |
| if call_namespace == "_root_": | |
| return call_expr | |
| root_prefix = "_root_." | |
| if call_expr.startswith(root_prefix): | |
| return f"{call_namespace}.{call_expr[len(root_prefix):]}" | |
| return call_expr | |
def build_lean_sample_checks(
    task: Task, function_spec: FunctionSpec, call_namespace: str = "_root_"
) -> list[str]:
    """Build one Lean ``example`` per sample case of ``function_spec``.

    Each example states that the call, re-rooted under ``call_namespace``,
    equals the oracle's result and is closed with ``native_decide``.
    Proof-required functions and functions without sample inputs yield an
    empty list.
    """
    if function_spec.is_proof_required:
        return []

    function_name = function_spec.name
    checks: list[str] = []
    for case in task.sample_inputs.get(function_name, []):
        expected_value = oracle_result(task.task_id, function_name, case.args)
        raw_call = _lean_call(task.task_id, function_name, case.args)
        call_expr = _with_call_namespace(raw_call, call_namespace)
        expected_expr = _lean_value(task.task_id, function_name, expected_value)
        snippet_lines = (
            "-- Lean turns one runtime sample into a concrete theorem.",
            "-- native_decide works here because the mirror reduces the goal to a closed equality.",
            f"example : {call_expr} = {expected_expr} := by",
            "  native_decide",
        )
        checks.append("\n".join(snippet_lines))
    return checks
| def _parse_runner_output(stdout: str) -> list[dict[str, Any]] | None: | |
| marker_index = stdout.rfind(RUN_TESTS_MARKER) | |
| if marker_index < 0: | |
| return None | |
| payload = stdout[marker_index + len(RUN_TESTS_MARKER) :].strip() | |
| if not payload: | |
| return None | |
| try: | |
| return json.loads(payload) | |
| except json.JSONDecodeError: | |
| return None | |
def _run_python_candidate(
    function_name: str, candidate_code: str, cases: list[SampleCase]
) -> tuple[list[dict[str, Any]] | None, subprocess.CompletedProcess[str] | None, bool]:
    """Run a Python candidate against ``cases`` in a fresh interpreter.

    The candidate source is written to a temporary script followed by a small
    harness. The harness reads the function name and the case arguments from
    the environment, calls the function once per case, and prints
    RUN_TESTS_MARKER followed by a JSON list of ``{"ok", "value"}`` or
    ``{"ok", "error"}`` objects.

    Args:
        function_name: Name of the function to look up in the script's globals.
        candidate_code: Source code submitted for the function.
        cases: Sample cases; each case's ``args`` become the call's positional
            arguments.

    Returns:
        ``(parsed_results, completed_process, False)``; ``parsed_results`` is
        None when the harness output could not be parsed.

    Raises:
        subprocess.TimeoutExpired: If the script runs longer than 5 seconds.
    """
    harness = textwrap.dedent(
        f"""
        import json
        import os
        CASES = json.loads(os.environ["CASES_JSON"])
        FUNCTION_NAME = os.environ["FUNCTION_NAME"]
        results = []
        for case_args in CASES:
            try:
                value = globals()[FUNCTION_NAME](*case_args)
                results.append({{"ok": True, "value": value}})
            except Exception as error:
                results.append({{"ok": False, "error": repr(error)}})
        print("{RUN_TESTS_MARKER}" + json.dumps(results, default=repr))
        """
    ).strip()
    with tempfile.TemporaryDirectory() as temp_dir:
        script_path = Path(temp_dir) / "candidate.py"
        # Python reads source files as UTF-8 by default, so write UTF-8
        # explicitly instead of relying on the locale encoding.
        script_path.write_text(
            candidate_code.rstrip() + "\n\n" + harness + "\n", encoding="utf-8"
        )
        process = subprocess.run(
            [sys.executable, str(script_path)],
            capture_output=True,
            text=True,
            timeout=5,
            env={
                **os.environ,
                "FUNCTION_NAME": function_name,
                "CASES_JSON": json.dumps([list(case.args) for case in cases]),
            },
        )
        return _parse_runner_output(process.stdout), process, False
| def _normalize_typescript_candidate(candidate_code: str) -> str: | |
| normalized_lines: list[str] = [] | |
| for line in candidate_code.splitlines(): | |
| if line.lstrip().startswith("export "): | |
| line = line.replace("export ", "", 1) | |
| normalized_lines.append(line) | |
| return "\n".join(normalized_lines) | |
| def _find_cargo() -> str | None: | |
| """Locate cargo, checking system PATH then rustup default install location.""" | |
| rt = shutil.which("cargo") | |
| if rt: | |
| return rt | |
| # rustup installs outside the system PATH — check CARGO_HOME and the default | |
| cargo_home = os.environ.get("CARGO_HOME", os.path.expanduser("~/.cargo")) | |
| candidate = os.path.join(cargo_home, "bin", "cargo") | |
| if os.path.isfile(candidate) and os.access(candidate, os.X_OK): | |
| return candidate | |
| return None | |
| def _find_tsx() -> str | None: | |
| """Locate tsx or ts-node, checking system PATH then NVM directories.""" | |
| rt = shutil.which("tsx") or shutil.which("ts-node") | |
| if rt: | |
| return rt | |
| # NVM installs binaries outside the system PATH — scan versioned bin dirs. | |
| nvm_bin_dirs = sorted( | |
| glob.glob(os.path.expanduser("~/.nvm/versions/node/*/bin")), | |
| reverse=True, # newest version first | |
| ) | |
| for bin_dir in nvm_bin_dirs: | |
| for name in ("tsx", "ts-node"): | |
| candidate = os.path.join(bin_dir, name) | |
| if os.path.isfile(candidate) and os.access(candidate, os.X_OK): | |
| return candidate | |
| return shutil.which("node") # plain node as last resort (no TS support) | |
def _run_js_candidate(
    function_name: str, candidate_code: str, cases: list[SampleCase]
) -> tuple[list[dict[str, Any]] | None, subprocess.CompletedProcess[str] | None, bool]:
    """Run a JavaScript/TypeScript candidate against ``cases`` in a subprocess.

    The runtime comes from ``_find_tsx``. With ``tsx`` or ``ts-node`` the code
    is written unchanged to ``candidate.ts``; with any other runtime the
    leading ``export`` keywords are stripped and the code is written to
    ``candidate.cjs``. A harness appended to the script calls the function
    once per case and prints RUN_TESTS_MARKER followed by the JSON results.

    Args:
        function_name: Expression the harness evaluates to obtain the function.
        candidate_code: Source code submitted for the function.
        cases: Sample cases; each case's ``args`` are spread into the call.

    Returns:
        ``(None, None, False)`` when no runtime is available; otherwise
        ``(parsed_results, completed_process, False)``.

    Raises:
        subprocess.TimeoutExpired: If the script runs longer than 5 seconds.
    """
    runtime = _find_tsx()
    if runtime is None:
        return None, None, False
    runtime_name = Path(runtime).name
    supports_typescript = runtime_name in {"tsx", "ts-node"}
    harness = textwrap.dedent(
        f"""
        const cases = JSON.parse(process.env.CASES_JSON || "[]");
        const functionName = process.env.FUNCTION_NAME;
        const results = [];
        for (const caseArgs of cases) {{
          try {{
            const value = eval(functionName)(...caseArgs);
            results.push({{ ok: true, value }});
          }} catch (error) {{
            results.push({{ ok: false, error: String(error) }});
          }}
        }}
        console.log("{RUN_TESTS_MARKER}" + JSON.stringify(results));
        """
    ).strip()
    code = candidate_code.rstrip()
    if not supports_typescript:
        code = _normalize_typescript_candidate(code)
    with tempfile.TemporaryDirectory() as temp_dir:
        script_path = Path(temp_dir) / (
            "candidate.ts" if supports_typescript else "candidate.cjs"
        )
        # Node reads script files as UTF-8, so do not depend on the locale
        # encoding when writing non-ASCII candidate code.
        script_path.write_text(code + "\n\n" + harness + "\n", encoding="utf-8")
        command = [runtime]
        if runtime_name == "ts-node":
            command.extend(["--transpile-only", str(script_path)])
        else:
            command.append(str(script_path))
        process = subprocess.run(
            command,
            capture_output=True,
            text=True,
            timeout=5,
            env={
                **os.environ,
                "FUNCTION_NAME": function_name,
                "CASES_JSON": json.dumps([list(case.args) for case in cases]),
            },
        )
        return _parse_runner_output(process.stdout), process, False
| _RUST_CACHE_DIR = Path("/tmp/lean_migrate_rust_cache") | |
| _RUST_CARGO_TOML = textwrap.dedent( | |
| """ | |
| [package] | |
| name = "candidate" | |
| version = "0.1.0" | |
| edition = "2021" | |
| [dependencies] | |
| serde_json = { version = "1", features = ["preserve_order"] } | |
| serde = { version = "1", features = ["derive"] } | |
| [[bin]] | |
| name = "candidate" | |
| path = "src/main.rs" | |
| """ | |
| ).strip() | |
| _RUST_MAIN_TEMPLATE = textwrap.dedent( | |
| """ | |
| use std::env; | |
| {candidate_code} | |
| fn main() {{ | |
| let cases_json = env::var("CASES_JSON").unwrap_or_default(); | |
| let cases: Vec<Vec<serde_json::Value>> = serde_json::from_str(&cases_json).unwrap_or_default(); | |
| let mut results = Vec::new(); | |
| for args in &cases {{ | |
| let result = std::panic::catch_unwind(|| -> serde_json::Value {{ | |
| {dispatch} | |
| }}); | |
| match result {{ | |
| Ok(v) => results.push(serde_json::json!({{"ok": true, "value": v}})), | |
| Err(_) => results.push(serde_json::json!({{"ok": false, "error": "panic"}})), | |
| }} | |
| }} | |
| println!("{marker}" + &serde_json::to_string(&results).unwrap()); | |
| }} | |
| """ | |
| ).strip() | |
def _build_rust_binary(
    code_hash: str,
    candidate_code: str,
    rust_dispatch: str,
    cargo: str,
) -> tuple[Path | None, str]:
    """Build the candidate binary, reusing a cached build when one exists.

    The project lives in ``_RUST_CACHE_DIR / code_hash``. If the release
    binary is already there it is returned without rebuilding. Otherwise
    ``Cargo.toml`` and ``src/main.rs`` are written and
    ``cargo build --release --quiet`` is run. The generated ``main`` reads the
    cases from CASES_JSON, runs ``rust_dispatch`` per case inside
    ``catch_unwind``, and prints RUN_TESTS_MARKER followed by the JSON results.

    Args:
        code_hash: Cache key naming the build directory.
        candidate_code: Rust source submitted by the agent.
        rust_dispatch: Rust expression body evaluated per case; it may refer
            to ``args`` and must produce a ``serde_json::Value``.
        cargo: Path to the cargo executable.

    Returns:
        ``(binary_path, "")`` on success, or ``(None, stderr)`` when cargo
        exits with a non-zero status.

    Raises:
        subprocess.TimeoutExpired: If the build takes longer than 120 seconds.
    """
    build_dir = _RUST_CACHE_DIR / code_hash
    binary = build_dir / "target" / "release" / "candidate"
    if binary.exists():
        return binary, ""

    build_dir.mkdir(parents=True, exist_ok=True)
    # Cargo and rustc require UTF-8 input, so never fall back to the locale
    # encoding when writing the project files.
    (build_dir / "Cargo.toml").write_text(_RUST_CARGO_TOML, encoding="utf-8")
    src_dir = build_dir / "src"
    src_dir.mkdir(exist_ok=True)
    main_rs = (
        "use std::panic;\n"
        + candidate_code.rstrip()
        + "\n\nfn main() {\n"
        + '    let cases_json = std::env::var("CASES_JSON").unwrap_or_default();\n'
        + "    let cases: Vec<Vec<serde_json::Value>> = serde_json::from_str(&cases_json).unwrap_or_default();\n"
        + "    let mut results: Vec<serde_json::Value> = Vec::new();\n"
        + "    for args in &cases {\n"
        + "        let result = panic::catch_unwind(|| -> serde_json::Value {\n"
        + textwrap.indent(rust_dispatch.strip(), " " * 12)
        + "\n        });\n"
        + "        match result {\n"
        + '            Ok(v) => results.push(serde_json::json!({"ok": true, "value": v})),\n'
        + '            Err(_) => results.push(serde_json::json!({"ok": false, "error": "panic"})),\n'
        + "        }\n"
        + "    }\n"
        + f'    println!("{RUN_TESTS_MARKER}{{}}", serde_json::to_string(&results).unwrap());\n'
        + "}\n"
    )
    (src_dir / "main.rs").write_text(main_rs, encoding="utf-8")
    proc = subprocess.run(
        [cargo, "build", "--release", "--quiet"],
        cwd=str(build_dir),
        capture_output=True,
        text=True,
        timeout=120,
    )
    if proc.returncode != 0:
        return None, proc.stderr
    return binary, ""
def _run_rust_binary(
    binary: Path, cases: list[SampleCase]
) -> tuple[list[dict[str, Any]] | None, subprocess.CompletedProcess[str] | None, bool]:
    """Execute a built candidate binary with the cases passed via CASES_JSON.

    Returns ``(parsed_results, completed_process, False)``.

    Raises:
        subprocess.TimeoutExpired: If the binary runs longer than 10 seconds.
    """
    encoded_cases = json.dumps([list(case.args) for case in cases])
    child_env = {**os.environ, "CASES_JSON": encoded_cases}
    process = subprocess.run(
        [str(binary)],
        capture_output=True,
        text=True,
        timeout=10,
        env=child_env,
    )
    return _parse_runner_output(process.stdout), process, False
def _run_rust_candidate(
    function_spec: FunctionSpec,
    candidate_code: str,
    cases: list[SampleCase],
) -> tuple[list[dict[str, Any]] | None, subprocess.CompletedProcess[str] | None, bool]:
    """Build (or reuse) and run a Rust candidate against ``cases``.

    Returns:
        ``(None, None, False)`` when cargo is missing or the spec has no
        ``rust_dispatch``; ``(None, None, True)`` when the build or the run
        times out; ``(None, <process with empty args and the compiler
        stderr>, False)`` when the build fails; otherwise the result of
        ``_run_rust_binary``.
    """
    cargo = _find_cargo()
    if cargo is None or not function_spec.rust_dispatch:
        return None, None, False

    dispatch = function_spec.rust_dispatch
    digest = hashlib.sha256((candidate_code + dispatch).encode()).hexdigest()
    code_hash = digest[:16]

    try:
        binary, build_stderr = _build_rust_binary(
            code_hash, candidate_code, dispatch, cargo
        )
    except subprocess.TimeoutExpired:
        return None, None, True

    build_succeeded = binary is not None and binary.exists()
    if not build_succeeded:
        # Surface the actual compiler error so the agent can act on it.
        build_failure = subprocess.CompletedProcess(
            args=[], returncode=1, stdout="", stderr=build_stderr
        )
        return None, build_failure, False

    try:
        return _run_rust_binary(binary, cases)
    except subprocess.TimeoutExpired:
        return None, None, True
class LanguageRunner(Protocol):
    """Structural interface shared by the per-language runners."""

    def run(
        self,
        function_spec: FunctionSpec,
        candidate_code: str,
        cases: list[SampleCase],
    ) -> tuple[list[dict[str, Any]] | None, subprocess.CompletedProcess[str] | None, bool]:
        """Run the candidate and return ``(parsed_results, process, timed_out)``."""
        ...
class PythonRunner:
    """Runner that executes Python candidates."""

    def run(
        self,
        function_spec: FunctionSpec,
        candidate_code: str,
        cases: list[SampleCase],
    ) -> tuple[list[dict[str, Any]] | None, subprocess.CompletedProcess[str] | None, bool]:
        """Delegate to ``_run_python_candidate`` using the spec's function name."""
        function_name = function_spec.name
        return _run_python_candidate(function_name, candidate_code, cases)
class JSRunner:
    """Runner that executes JavaScript/TypeScript candidates."""

    def run(
        self,
        function_spec: FunctionSpec,
        candidate_code: str,
        cases: list[SampleCase],
    ) -> tuple[list[dict[str, Any]] | None, subprocess.CompletedProcess[str] | None, bool]:
        """Delegate to ``_run_js_candidate`` using the spec's function name."""
        function_name = function_spec.name
        return _run_js_candidate(function_name, candidate_code, cases)
class RustRunner:
    """Runner that builds and executes Rust candidates."""

    def run(
        self,
        function_spec: FunctionSpec,
        candidate_code: str,
        cases: list[SampleCase],
    ) -> tuple[list[dict[str, Any]] | None, subprocess.CompletedProcess[str] | None, bool]:
        """Delegate to ``_run_rust_candidate``, which needs the whole spec."""
        outcome = _run_rust_candidate(function_spec, candidate_code, cases)
        return outcome
# Registry consulted by run_candidate_tests, keyed by a task's target_language.
_RUNNERS: dict[str, LanguageRunner] = dict(
    python=PythonRunner(),
    typescript=JSRunner(),
    rust=RustRunner(),
)
def run_candidate_tests(
    task: Task, function_spec: FunctionSpec, candidate_code: str
) -> TestRunResult:
    """Run the candidate against the task's sample cases and grade it.

    The runner registered for ``task.target_language`` executes the candidate;
    each reported value is then compared with ``oracle_result`` using
    ``_json_equal``.

    Args:
        task: Task providing ``sample_inputs``, ``target_language`` and ``task_id``.
        function_spec: Spec of the function under test.
        candidate_code: Source code submitted for the function.

    Returns:
        A TestRunResult. Proof-required functions and functions without sample
        cases are reported as passed with zero tests. A missing runner or
        runtime, a timeout, a non-zero exit status, and unparseable or
        malformed harness output are all reported as failures with explanatory
        feedback. Otherwise the result carries per-case comparisons and
        feedback listing up to three failing cases.
    """
    cases = task.sample_inputs.get(function_spec.name, [])
    if function_spec.is_proof_required:
        return TestRunResult(
            passed=True,
            tests_passed=0,
            tests_total=0,
            feedback=(
                f"run_tests skipped for '{function_spec.name}': proof-only functions are verified on submit."
            ),
            case_results=[],
        )
    if not cases:
        return TestRunResult(
            passed=True,
            tests_passed=0,
            tests_total=0,
            feedback=f"run_tests skipped for '{function_spec.name}': no sample cases are defined.",
            case_results=[],
        )
    runner = _RUNNERS.get(task.target_language)
    if runner is None:
        return TestRunResult(
            passed=False,
            tests_passed=0,
            tests_total=len(cases),
            feedback=f"run_tests: no runner registered for language '{task.target_language}'.",
            case_results=[],
        )
    try:
        runner_output, process, timed_out = runner.run(function_spec, candidate_code, cases)
    except subprocess.TimeoutExpired as error:
        # The Python and JS runners let subprocess timeouts propagate.
        return TestRunResult(
            passed=False,
            tests_passed=0,
            tests_total=len(cases),
            feedback=f"run_tests timed out for '{function_spec.name}' after 5 seconds.",
            stderr=str(error),
            timed_out=True,
            case_results=[],
        )
    if timed_out:
        # The Rust runner reports build and run timeouts through this flag.
        return TestRunResult(
            passed=False,
            tests_passed=0,
            tests_total=len(cases),
            feedback=f"run_tests timed out for '{function_spec.name}' (build or execution).",
            timed_out=True,
            case_results=[],
        )
    if process is None:
        _runtime_hints: dict[str, str] = {
            "rust": "cargo not found — install the Rust toolchain (https://rustup.rs) or check PATH.",
            "typescript": "Install node + tsx or ts-node for TypeScript tasks.",
            "python": "Python subprocess runner failed to start.",
        }
        hint = _runtime_hints.get(
            task.target_language,
            f"No runner found for '{task.target_language}'.",
        )
        return TestRunResult(
            passed=False,
            tests_passed=0,
            tests_total=len(cases),
            feedback=f"run_tests could not find a runtime for '{function_spec.name}'. {hint}",
            case_results=[],
        )
    if process.returncode != 0 or runner_output is None:
        stderr = process.stderr.strip()
        stdout = process.stdout.strip()
        # The Rust runner marks a failed build with a process whose args are empty.
        if task.target_language == "rust" and process.args == []:
            feedback = (
                f"run_tests failed for '{function_spec.name}'.\n"
                f"Rust build failed (compiler error).\n"
                f"stderr: {stderr[:300]}"
            )
        else:
            feedback = (
                f"run_tests failed for '{function_spec.name}'.\n"
                f"Process exit code: {process.returncode}\n"
                f"stderr: {stderr[:300]}"
            )
        return TestRunResult(
            passed=False,
            tests_passed=0,
            tests_total=len(cases),
            feedback=feedback,
            stdout=stdout,
            stderr=stderr,
            case_results=[],
        )
    # The candidate shares stdout with the harness and could print its own
    # marker, so check the shape before pairing results with cases. Without
    # this, strict zip() or .get() below would raise and crash the grader.
    well_formed = (
        isinstance(runner_output, list)
        and len(runner_output) == len(cases)
        and all(isinstance(entry, dict) for entry in runner_output)
    )
    if not well_formed:
        return TestRunResult(
            passed=False,
            tests_passed=0,
            tests_total=len(cases),
            feedback=(
                f"run_tests failed for '{function_spec.name}'.\n"
                f"The test harness produced malformed results "
                f"(expected {len(cases)} result objects)."
            ),
            stdout=process.stdout.strip(),
            stderr=process.stderr.strip(),
            case_results=[],
        )
    expected_values = [
        oracle_result(task.task_id, function_spec.name, case.args) for case in cases
    ]
    comparisons: list[TestCaseResult] = []
    for expected_value, actual in zip(expected_values, runner_output, strict=True):
        if not actual.get("ok", False):
            comparisons.append(
                TestCaseResult(
                    passed=False,
                    expected=expected_value,
                    actual=None,
                    error=str(actual.get("error", "Unknown error")),
                )
            )
            continue
        actual_value = actual.get("value")
        comparisons.append(
            TestCaseResult(
                passed=_json_equal(actual_value, expected_value),
                expected=expected_value,
                actual=actual_value,
            )
        )
    passed_count = sum(1 for result in comparisons if result.passed)
    cases_with_results = list(zip(cases, comparisons))
    failed_pairs = [(case, r) for case, r in cases_with_results if not r.passed]
    feedback_lines = [
        f"run_tests for '{function_spec.name}': {passed_count}/{len(cases)} cases passed.",
    ]
    if failed_pairs:
        feedback_lines.append("Failures:")
        # Only the first three failures are listed.
        for index, (case, result) in enumerate(failed_pairs[:3], start=1):
            args_repr = ", ".join(repr(a) for a in case.args)
            if result.error:
                feedback_lines.append(
                    f" {index}. input=({args_repr}) error: {result.error}"
                )
            else:
                feedback_lines.append(
                    f" {index}. input=({args_repr}) expected={result.expected!r} actual={result.actual!r}"
                )
    return TestRunResult(
        passed=passed_count == len(cases),
        tests_passed=passed_count,
        tests_total=len(cases),
        feedback="\n".join(feedback_lines),
        stdout=process.stdout,
        stderr=process.stderr,
        case_results=comparisons,
    )