"""Grader: oracle implementations, language runners, and test harness. Oracle functions are the ground truth for every task. run_candidate_tests() dispatches to a language-specific runner (PythonRunner, JSRunner, RustRunner) that executes the agent's submission in a subprocess and compares its output against the oracle. Scoring helpers (score_progress, score_step_reward, clamp_open_unit) are used by EpisodeState to produce per-step rewards and the final episode score. """ from __future__ import annotations import glob import hashlib import json import os import shutil import subprocess import sys import tempfile import textwrap from dataclasses import dataclass from pathlib import Path from typing import Any, Protocol from .tasks import FunctionSpec, SampleCase, Task from .verification_semantics import ( lean_call as _lean_call_impl, lean_value as _lean_value_impl, oracle_result as _oracle_result_impl, ) RUN_TESTS_MARKER = "__LEANMIGRATE_RUN_TESTS__" MIN_DISPLAY_SCORE = 0.01 MAX_DISPLAY_SCORE = 0.99 def _json_equal(a: Any, b: Any) -> bool: """Compare two values treating tuples and lists as interchangeable. JSON deserialisation always produces lists, but Python oracles may return tuples. This normalises both sides so the comparison works regardless of which side uses tuples vs. lists. 
""" if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)): return len(a) == len(b) and all(_json_equal(x, y) for x, y in zip(a, b)) return a == b @dataclass(frozen=True) class TestCaseResult: passed: bool expected: Any actual: Any | None error: str | None = None @dataclass(frozen=True) class TestRunResult: passed: bool tests_passed: int tests_total: int feedback: str stdout: str = "" stderr: str = "" timed_out: bool = False case_results: list[TestCaseResult] | None = None def score_progress(verified_count: int, total_count: int) -> float: if total_count <= 0: return 0.0 return verified_count / total_count def score_step_reward( success: bool, total_count: int, failure_penalty: float = -0.05 ) -> float: if success and total_count > 0: return clamp_open_unit(1.0 / total_count) return failure_penalty def clamp_open_unit(value: float) -> float: if value <= MIN_DISPLAY_SCORE: return MIN_DISPLAY_SCORE if value >= MAX_DISPLAY_SCORE: return MAX_DISPLAY_SCORE return value def build_breakdown( functional: float, property_score: float | None = None, proof: float | None = None ) -> dict[str, float]: breakdown = {"functional": clamp_open_unit(functional)} if property_score is not None: breakdown["property"] = clamp_open_unit(property_score) if proof is not None: breakdown["proof"] = clamp_open_unit(proof) return breakdown def _oracle_rbac_find_role( roles: list[dict[str, Any]], name: str ) -> dict[str, Any] | None: return next((role for role in roles if role["name"] == name), None) def _oracle_rbac_has_direct_permission( role: dict[str, Any], resource: str, action: str ) -> bool: return any( permission["resource"] == resource and permission["action"] == action for permission in role.get("permissions", []) ) def _oracle_rbac_can_access( roles: list[dict[str, Any]], role_name: str, resource: str, action: str, depth: int = 5, ) -> bool: if depth == 0: return False role = _oracle_rbac_find_role(roles, role_name) if role is None: return False if 
_oracle_rbac_has_direct_permission(role, resource, action): return True return any( _oracle_rbac_can_access(roles, parent_name, resource, action, depth - 1) for parent_name in role.get("inherits", []) ) def _oracle_pricing_tax_rate_bps(region_id: str) -> int: return { "US-CA": 875, "US-TX": 625, "US-NY": 800, "UK": 2000, }.get(region_id, 0) def _oracle_pricing_subtotal(order: dict[str, Any]) -> int: return sum( int(item["unitPrice"]) * int(item["quantity"]) for item in order.get("items", []) ) def _oracle_pricing_coupon_discount(order: dict[str, Any]) -> int: subtotal = _oracle_pricing_subtotal(order) raw_discount = sum( (subtotal * int(coupon["discountPercent"])) // 100 for coupon in order.get("coupons", []) ) return min(raw_discount, subtotal // 2) def _oracle_pricing_loyalty_discount(order: dict[str, Any]) -> int: subtotal = _oracle_pricing_subtotal(order) return min(int(order.get("loyaltyPoints", 0)), subtotal // 10) def _oracle_pricing_total_discount(order: dict[str, Any]) -> int: return _oracle_pricing_coupon_discount(order) + _oracle_pricing_loyalty_discount( order ) def _oracle_pricing_tax(order: dict[str, Any]) -> int: subtotal = _oracle_pricing_subtotal(order) total_discount = _oracle_pricing_total_discount(order) after_discount = subtotal - total_discount return ( after_discount * _oracle_pricing_tax_rate_bps(order.get("regionId", "")) ) // 10000 def _oracle_pricing_final_price(order: dict[str, Any]) -> int: subtotal = _oracle_pricing_subtotal(order) total_discount = _oracle_pricing_total_discount(order) return subtotal - total_discount + _oracle_pricing_tax(order) def _oracle_saga_transition(state: str, event: str) -> str: if event == "Fail": return "Failed" transitions = { ("Idle", "Reserve"): "Reserved", ("Reserved", "Authorize"): "Authorized", ("Authorized", "Capture"): "Captured", ("Captured", "Settle"): "Settled", ("Reserved", "CompensateReserve"): "Compensated", ("Authorized", "CompensateAuthorize"): "Compensating", ("Compensating", 
"CompensateReserve"): "Compensated", ("Captured", "CompensateCapture"): "Compensating", } return transitions.get((state, event), state) def _oracle_saga_run(events: list[str]) -> str: state = "Idle" for event in events: state = _oracle_saga_transition(state, event) return state def _oracle_saga_is_charged(state: str) -> bool: return state in {"Captured", "Settled"} def oracle_result(task_id: str, function_name: str, args: tuple[Any, ...]) -> Any: return _oracle_result_impl(task_id, function_name, args) def _lean_string(value: str) -> str: return json.dumps(value) def _lean_bool(value: bool) -> str: return "true" if value else "false" def _lean_int(value: int) -> str: return str(int(value)) def _lean_list(items: list[str]) -> str: return "[" + ", ".join(items) + "]" def _lean_permission(permission: dict[str, Any]) -> str: return ( "{ resource := " + _lean_string(str(permission["resource"])) + ", action := " + _lean_string(str(permission["action"])) + " }" ) def _lean_role(role: dict[str, Any]) -> str: permissions = _lean_list( [_lean_permission(permission) for permission in role.get("permissions", [])] ) inherits = _lean_list([_lean_string(name) for name in role.get("inherits", [])]) return ( "{ name := " + _lean_string(str(role["name"])) + ", permissions := " + permissions + ", inherits := " + inherits + " }" ) def _lean_role_list(roles: list[dict[str, Any]]) -> str: return _lean_list([_lean_role(role) for role in roles]) def _lean_option_role(role: dict[str, Any] | None) -> str: if role is None: return "none" return "some (" + _lean_role(role) + " : AuthSpec.Role)" def _lean_item(item: dict[str, Any]) -> str: return ( "{ sku := " + _lean_string(str(item["sku"])) + ", quantity := " + _lean_int(int(item["quantity"])) + ", unitPrice := " + _lean_int(int(item["unitPrice"])) + " }" ) def _lean_coupon(coupon: dict[str, Any]) -> str: return ( "{ code := " + _lean_string(str(coupon["code"])) + ", discountPercent := " + _lean_int(int(coupon["discountPercent"])) + " }" ) 
def _lean_order(order: dict[str, Any]) -> str:
    """Render an order dict as a Lean structure literal."""
    items = _lean_list([_lean_item(item) for item in order.get("items", [])])
    coupons = _lean_list([_lean_coupon(coupon) for coupon in order.get("coupons", [])])
    return (
        "{ items := "
        + items
        + ", coupons := "
        + coupons
        + ", regionId := "
        + _lean_string(str(order.get("regionId", "")))
        + ", loyaltyPoints := "
        + _lean_int(int(order.get("loyaltyPoints", 0)))
        + " }"
    )


def _lean_saga_state(state: str) -> str:
    """Map a saga state name to its Lean constructor (KeyError if unknown)."""
    mapping = {
        "Idle": ".Idle",
        "Reserved": ".Reserved",
        "Authorized": ".Authorized",
        "Captured": ".Captured",
        "Settled": ".Settled",
        "Compensating": ".Compensating",
        "Compensated": ".Compensated",
        "Failed": ".Failed",
    }
    return mapping[state]


def _lean_saga_event(event: str) -> str:
    """Map a saga event name to its Lean constructor (KeyError if unknown)."""
    mapping = {
        "Reserve": ".Reserve",
        "Authorize": ".Authorize",
        "Capture": ".Capture",
        "Settle": ".Settle",
        "CompensateReserve": ".CompensateReserve",
        "CompensateAuthorize": ".CompensateAuthorize",
        "CompensateCapture": ".CompensateCapture",
        "Fail": ".Fail",
    }
    return mapping[event]


def _lean_value(task_id: str, function_name: str, value: Any) -> str:
    """Render a Python value as a Lean expression (delegates to verification_semantics)."""
    return _lean_value_impl(task_id, function_name, value)


def _lean_call(task_id: str, function_name: str, args: tuple[Any, ...]) -> str:
    """Render a Lean call expression (delegates to verification_semantics)."""
    return _lean_call_impl(task_id, function_name, args)


def _with_call_namespace(call_expr: str, call_namespace: str) -> str:
    """Re-root a ``_root_.``-prefixed call expression under ``call_namespace``.

    With the default ``_root_`` namespace, or when the expression does not
    start with ``_root_.``, the expression is returned unchanged.
    """
    if call_namespace == "_root_":
        return call_expr
    root_prefix = "_root_."
    if call_expr.startswith(root_prefix):
        return f"{call_namespace}.{call_expr[len(root_prefix):]}"
    return call_expr


def build_lean_sample_checks(
    task: Task, function_spec: FunctionSpec, call_namespace: str = "_root_"
) -> list[str]:
    """Render one Lean ``example`` per sample case of ``function_spec``.

    Each example asserts that the Lean call equals the oracle's value for the
    same arguments, discharged by ``native_decide``. Proof-only functions
    produce no checks.
    """
    if function_spec.is_proof_required:
        return []
    cases = task.sample_inputs.get(function_spec.name, [])
    checks: list[str] = []
    for case in cases:
        expected_value = oracle_result(task.task_id, function_spec.name, case.args)
        call_expr = _with_call_namespace(
            _lean_call(task.task_id, function_spec.name, case.args),
            call_namespace,
        )
        expected_expr = _lean_value(task.task_id, function_spec.name, expected_value)
        checks.append(
            textwrap.dedent(
                f"""
                -- Lean turns one runtime sample into a concrete theorem.
                -- native_decide works here because the mirror reduces the goal to a closed equality.
                example : {call_expr} = {expected_expr} := by native_decide
                """
            ).strip()
        )
    return checks


def _parse_runner_output(stdout: str) -> list[dict[str, Any]] | None:
    """Extract the per-case result list printed after RUN_TESTS_MARKER.

    Every harness prints the marker followed by a single-line JSON array as its
    final action. Only the first line after the last marker is parsed, so
    output emitted later (e.g. by exit handlers) cannot corrupt the payload.
    Returns None when the marker is missing or the payload is not a JSON array
    of objects.
    """
    marker_index = stdout.rfind(RUN_TESTS_MARKER)
    if marker_index < 0:
        return None
    payload = stdout[marker_index + len(RUN_TESTS_MARKER) :].strip()
    payload = payload.partition("\n")[0].strip()
    if not payload:
        return None
    try:
        parsed = json.loads(payload)
    except json.JSONDecodeError:
        return None
    if not isinstance(parsed, list) or not all(isinstance(entry, dict) for entry in parsed):
        return None
    return parsed


def _run_python_candidate(
    function_name: str, candidate_code: str, cases: list[SampleCase]
) -> tuple[list[dict[str, Any]] | None, subprocess.CompletedProcess[str] | None, bool]:
    """Run a Python candidate on every case in a subprocess (5 s timeout).

    The harness lives inside a function so its own names (cases, results,
    json, ...) cannot overwrite module-level globals the candidate defines.
    subprocess.TimeoutExpired propagates to the caller.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        script_path = Path(temp_dir) / "candidate.py"
        harness = textwrap.dedent(
            f"""
            def __leanmigrate_run_tests():
                import json as _json
                import os as _os

                cases = _json.loads(_os.environ["CASES_JSON"])
                function_name = _os.environ["FUNCTION_NAME"]
                results = []
                for case_args in cases:
                    try:
                        value = globals()[function_name](*case_args)
                        results.append({{"ok": True, "value": value}})
                    except Exception as error:
                        results.append({{"ok": False, "error": repr(error)}})
                print("{RUN_TESTS_MARKER}" + _json.dumps(results, default=repr))


            __leanmigrate_run_tests()
            """
        ).strip()
        # Python reads source files as UTF-8; don't depend on the locale encoding.
        script_path.write_text(
            candidate_code.rstrip() + "\n\n" + harness + "\n", encoding="utf-8"
        )
        process = subprocess.run(
            [sys.executable, str(script_path)],
            capture_output=True,
            text=True,
            timeout=5,
            env={
                **os.environ,
                "FUNCTION_NAME": function_name,
                "CASES_JSON": json.dumps([list(case.args) for case in cases]),
            },
        )
    return _parse_runner_output(process.stdout), process, False


def _normalize_typescript_candidate(candidate_code: str) -> str:
    """Strip a leading ``export `` keyword from each line so plain node can run it."""
    normalized_lines: list[str] = []
    for line in candidate_code.splitlines():
        if line.lstrip().startswith("export "):
            line = line.replace("export ", "", 1)
        normalized_lines.append(line)
    return "\n".join(normalized_lines)


def _find_cargo() -> str | None:
    """Locate cargo, checking system PATH then rustup default install location."""
    rt = shutil.which("cargo")
    if rt:
        return rt
    # rustup installs outside the system PATH -- check CARGO_HOME and the default
    cargo_home = os.environ.get("CARGO_HOME", os.path.expanduser("~/.cargo"))
    candidate = os.path.join(cargo_home, "bin", "cargo")
    if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
        return candidate
    return None


def _nvm_version_key(bin_dir: str) -> tuple[int, ...]:
    """Sort key turning ``.../node/v20.11.0/bin`` into ``(20, 11, 0)``.

    Plain string sorting would place ``v9`` after ``v20``; comparing numeric
    components orders NVM installs by actual version.
    """
    version = Path(bin_dir).parent.name.lstrip("v")
    return tuple(int(part) if part.isdigit() else 0 for part in version.split("."))


def _is_executable(path: str) -> bool:
    """True if path is an existing file the current user may execute."""
    return os.path.isfile(path) and os.access(path, os.X_OK)


def _find_tsx() -> str | None:
    """Locate tsx or ts-node, checking system PATH then NVM directories.

    Falls back to plain ``node`` (PATH first, then NVM) when no TypeScript
    runner exists; returns None when no JavaScript runtime is found at all.
    """
    rt = shutil.which("tsx") or shutil.which("ts-node")
    if rt:
        return rt
    # NVM installs binaries outside the system PATH -- scan versioned bin dirs,
    # newest Node version first.
    nvm_bin_dirs = sorted(
        glob.glob(os.path.expanduser("~/.nvm/versions/node/*/bin")),
        key=_nvm_version_key,
        reverse=True,
    )
    for bin_dir in nvm_bin_dirs:
        for name in ("tsx", "ts-node"):
            candidate = os.path.join(bin_dir, name)
            if _is_executable(candidate):
                return candidate
    # Plain node as last resort (no TS support).
    node = shutil.which("node")
    if node:
        return node
    for bin_dir in nvm_bin_dirs:
        candidate = os.path.join(bin_dir, "node")
        if _is_executable(candidate):
            return candidate
    return None


def _run_js_candidate(
    function_name: str, candidate_code: str, cases: list[SampleCase]
) -> tuple[list[dict[str, Any]] | None, subprocess.CompletedProcess[str] | None, bool]:
    """Run a TypeScript/JavaScript candidate on every case (5 s timeout).

    Returns (None, None, False) when no runtime is found. The harness runs in
    its own block with prefixed names so it cannot collide with top-level
    declarations in the candidate.
    """
    runtime = _find_tsx()
    if runtime is None:
        return None, None, False
    runtime_name = Path(runtime).name
    supports_typescript = runtime_name in {"tsx", "ts-node"}
    with tempfile.TemporaryDirectory() as temp_dir:
        script_path = Path(temp_dir) / (
            "candidate.ts" if supports_typescript else "candidate.cjs"
        )
        harness = textwrap.dedent(
            f"""
            {{
              const __lmCases = JSON.parse(process.env.CASES_JSON || "[]");
              const __lmFunctionName = process.env.FUNCTION_NAME;
              const __lmResults = [];
              for (const __lmArgs of __lmCases) {{
                try {{
                  const __lmValue = eval(__lmFunctionName)(...__lmArgs);
                  __lmResults.push({{ ok: true, value: __lmValue }});
                }} catch (__lmError) {{
                  __lmResults.push({{ ok: false, error: String(__lmError) }});
                }}
              }}
              console.log("{RUN_TESTS_MARKER}" + JSON.stringify(__lmResults));
            }}
            """
        ).strip()
        code = candidate_code.rstrip()
        if not supports_typescript:
            code = _normalize_typescript_candidate(code)
        script_path.write_text(code + "\n\n" + harness + "\n", encoding="utf-8")
        command = [runtime]
        if runtime_name == "ts-node":
            command.extend(["--transpile-only", str(script_path)])
        else:
            command.append(str(script_path))
        process = subprocess.run(
            command,
            capture_output=True,
            text=True,
            timeout=5,
            env={
                **os.environ,
                "FUNCTION_NAME": function_name,
                "CASES_JSON": json.dumps([list(case.args) for case in cases]),
            },
        )
    return _parse_runner_output(process.stdout), process, False


_RUST_CACHE_DIR = Path("/tmp/lean_migrate_rust_cache")
_RUST_CARGO_TOML = textwrap.dedent(
    """
    [package]
    name = "candidate"
    version = "0.1.0"
    edition = "2021"

    [dependencies]
    serde_json = { version = "1", features = ["preserve_order"] }
    serde = { version = "1", features = ["derive"] }

    [[bin]]
    name = "candidate"
    path = "src/main.rs"
    """
).strip()
# str.format template with {candidate_code}, {dispatch} and {marker} placeholders.
# Not used by _build_rust_binary, which assembles main.rs inline.
_RUST_MAIN_TEMPLATE = textwrap.dedent(
    """
    use std::env;

    {candidate_code}

    fn main() {{
        let cases_json = env::var("CASES_JSON").unwrap_or_default();
        let cases: Vec<Vec<serde_json::Value>> =
            serde_json::from_str(&cases_json).unwrap_or_default();
        let mut results: Vec<serde_json::Value> = Vec::new();
        for args in &cases {{
            let result = std::panic::catch_unwind(|| -> serde_json::Value {{
                {dispatch}
            }});
            match result {{
                Ok(v) => results.push(serde_json::json!({{"ok": true, "value": v}})),
                Err(_) => results.push(serde_json::json!({{"ok": false, "error": "panic"}})),
            }}
        }}
        println!("{marker}{{}}", serde_json::to_string(&results).unwrap());
    }}
    """
).strip()


def _build_rust_binary(
    code_hash: str,
    candidate_code: str,
    rust_dispatch: str,
    cargo: str,
) -> tuple[Path | None, str]:
    """Build the candidate binary. Returns (binary_path, stderr). binary_path is None on failure.

    Builds are cached under _RUST_CACHE_DIR/<code_hash>; an existing binary is
    reused without rebuilding. subprocess.TimeoutExpired (120 s) propagates.
    """
    build_dir = _RUST_CACHE_DIR / code_hash
    binary = build_dir / "target" / "release" / "candidate"
    if binary.exists():
        return binary, ""
    build_dir.mkdir(parents=True, exist_ok=True)
    (build_dir / "Cargo.toml").write_text(_RUST_CARGO_TOML, encoding="utf-8")
    src_dir = build_dir / "src"
    src_dir.mkdir(exist_ok=True)
    dispatch_body = textwrap.indent(rust_dispatch.strip(), " " * 12)
    main_rs = (
        "use std::panic;\n"
        + candidate_code.rstrip()
        + "\n\nfn main() {\n"
        + '    let cases_json = std::env::var("CASES_JSON").unwrap_or_default();\n'
        + "    let cases: Vec<Vec<serde_json::Value>> = serde_json::from_str(&cases_json).unwrap_or_default();\n"
        + "    let mut results: Vec<serde_json::Value> = Vec::new();\n"
        + "    for args in &cases {\n"
        + "        let result = panic::catch_unwind(|| -> serde_json::Value {\n"
        + dispatch_body
        + "\n        });\n"
        + "        match result {\n"
        + '            Ok(v) => results.push(serde_json::json!({"ok": true, "value": v})),\n'
        + '            Err(_) => results.push(serde_json::json!({"ok": false, "error": "panic"})),\n'
        + "        }\n"
        + "    }\n"
        + f'    println!("{RUN_TESTS_MARKER}{{}}", serde_json::to_string(&results).unwrap());\n'
        + "}\n"
    )
    # rustc requires UTF-8 source regardless of the host locale.
    (src_dir / "main.rs").write_text(main_rs, encoding="utf-8")
    proc = subprocess.run(
        [cargo, "build", "--release", "--quiet"],
        cwd=str(build_dir),
        capture_output=True,
        text=True,
        timeout=120,
    )
    if proc.returncode != 0:
        return None, proc.stderr
    return binary, ""


def _run_rust_binary(
    binary: Path, cases: list[SampleCase]
) -> tuple[list[dict[str, Any]] | None, subprocess.CompletedProcess[str] | None, bool]:
    """Execute a built candidate binary on every case (10 s timeout)."""
    process = subprocess.run(
        [str(binary)],
        capture_output=True,
        text=True,
        timeout=10,
        env={
            **os.environ,
            "CASES_JSON": json.dumps([list(case.args) for case in cases]),
        },
    )
    return _parse_runner_output(process.stdout), process, False


def _run_rust_candidate(
    function_spec: FunctionSpec,
    candidate_code: str,
    cases: list[SampleCase],
) -> tuple[list[dict[str, Any]] | None, subprocess.CompletedProcess[str] | None, bool]:
    """Build (cached by code hash) and run a Rust candidate.

    Returns (None, None, False) when cargo is missing and (None, None, True) on
    a build or run timeout. A failed build is reported as a synthetic
    CompletedProcess with ``args=[]`` carrying the compiler's stderr.
    """
    cargo = _find_cargo()
    if cargo is None:
        return None, None, False
    if not function_spec.rust_dispatch:
        # Report the real problem; returning no process here would make the
        # caller blame a missing cargo installation.
        missing_dispatch = subprocess.CompletedProcess(
            args=[cargo],
            returncode=1,
            stdout="",
            stderr=(
                f"No rust_dispatch is defined for '{function_spec.name}'; "
                "the Rust candidate cannot be executed."
            ),
        )
        return None, missing_dispatch, False

    code_hash = hashlib.sha256(
        (candidate_code + function_spec.rust_dispatch).encode()
    ).hexdigest()[:16]

    try:
        binary, build_stderr = _build_rust_binary(
            code_hash, candidate_code, function_spec.rust_dispatch, cargo
        )
    except subprocess.TimeoutExpired:
        return None, None, True

    if binary is None or not binary.exists():
        # Surface the actual compiler error so the agent can act on it
        fake = subprocess.CompletedProcess(
            args=[], returncode=1, stdout="", stderr=build_stderr
        )
        return None, fake, False

    try:
        return _run_rust_binary(binary, cases)
    except subprocess.TimeoutExpired:
        return None, None, True


class LanguageRunner(Protocol):
    """Structural interface shared by the per-language candidate runners.

    ``run`` returns ``(results, process, timed_out)``: ``results`` is the list
    decoded from the harness output (None if missing or unparseable),
    ``process`` is the completed subprocess (None when no runtime was found
    or a timeout was handled by the runner), and ``timed_out`` reports such
    a handled timeout.
    """

    def run(
        self,
        function_spec: FunctionSpec,
        candidate_code: str,
        cases: list[SampleCase],
    ) -> tuple[list[dict[str, Any]] | None, subprocess.CompletedProcess[str] | None, bool]: ...
class PythonRunner:
    """LanguageRunner for Python candidates."""

    def run(
        self,
        function_spec: FunctionSpec,
        candidate_code: str,
        cases: list[SampleCase],
    ) -> tuple[list[dict[str, Any]] | None, subprocess.CompletedProcess[str] | None, bool]:
        return _run_python_candidate(function_spec.name, candidate_code, cases)


class JSRunner:
    """LanguageRunner for TypeScript/JavaScript candidates."""

    def run(
        self,
        function_spec: FunctionSpec,
        candidate_code: str,
        cases: list[SampleCase],
    ) -> tuple[list[dict[str, Any]] | None, subprocess.CompletedProcess[str] | None, bool]:
        return _run_js_candidate(function_spec.name, candidate_code, cases)


class RustRunner:
    """LanguageRunner for Rust candidates (needs the whole FunctionSpec for rust_dispatch)."""

    def run(
        self,
        function_spec: FunctionSpec,
        candidate_code: str,
        cases: list[SampleCase],
    ) -> tuple[list[dict[str, Any]] | None, subprocess.CompletedProcess[str] | None, bool]:
        return _run_rust_candidate(function_spec, candidate_code, cases)


# Runner registry keyed by Task.target_language.
_RUNNERS: dict[str, LanguageRunner] = {
    "python": PythonRunner(),
    "typescript": JSRunner(),
    "rust": RustRunner(),
}


def _failed_run(tests_total: int, feedback: str, **extra: Any) -> TestRunResult:
    """Build a failing TestRunResult for a run in which no case could be scored.

    ``extra`` forwards optional TestRunResult fields (stdout, stderr, timed_out).
    """
    return TestRunResult(
        passed=False,
        tests_passed=0,
        tests_total=tests_total,
        feedback=feedback,
        case_results=[],
        **extra,
    )


def _compare_case(expected_value: Any, actual: Any) -> TestCaseResult:
    """Compare one runner result entry against the oracle's expected value."""
    if not isinstance(actual, dict):
        return TestCaseResult(
            passed=False,
            expected=expected_value,
            actual=None,
            error=f"Malformed runner result: {actual!r}",
        )
    if not actual.get("ok", False):
        return TestCaseResult(
            passed=False,
            expected=expected_value,
            actual=None,
            error=str(actual.get("error", "Unknown error")),
        )
    actual_value = actual.get("value")
    return TestCaseResult(
        passed=_json_equal(actual_value, expected_value),
        expected=expected_value,
        actual=actual_value,
    )


def _format_feedback(
    function_name: str,
    cases: list[SampleCase],
    comparisons: list[TestCaseResult],
    passed_count: int,
) -> str:
    """Summarise a scored run: the pass count plus details of up to three failures."""
    feedback_lines = [
        f"run_tests for '{function_name}': {passed_count}/{len(cases)} cases passed.",
    ]
    failed_pairs = [
        (case, result) for case, result in zip(cases, comparisons) if not result.passed
    ]
    if failed_pairs:
        feedback_lines.append("Failures:")
        for index, (case, result) in enumerate(failed_pairs[:3], start=1):
            args_repr = ", ".join(repr(a) for a in case.args)
            if result.error:
                feedback_lines.append(
                    f" {index}. input=({args_repr}) error: {result.error}"
                )
            else:
                feedback_lines.append(
                    f" {index}. input=({args_repr}) expected={result.expected!r} actual={result.actual!r}"
                )
    return "\n".join(feedback_lines)


def run_candidate_tests(
    task: Task, function_spec: FunctionSpec, candidate_code: str
) -> TestRunResult:
    """Run ``candidate_code`` on the task's sample cases and compare with the oracle.

    Proof-only functions and functions without sample cases are skipped
    (reported as passed with zero tests). Missing runners or runtimes,
    timeouts, non-zero exits, unparseable output and result-count mismatches
    all produce a failing result with explanatory feedback instead of raising.
    """
    name = function_spec.name
    cases = task.sample_inputs.get(name, [])
    if function_spec.is_proof_required:
        return TestRunResult(
            passed=True,
            tests_passed=0,
            tests_total=0,
            feedback=(
                f"run_tests skipped for '{name}': proof-only functions are verified on submit."
            ),
            case_results=[],
        )
    if not cases:
        return TestRunResult(
            passed=True,
            tests_passed=0,
            tests_total=0,
            feedback=f"run_tests skipped for '{name}': no sample cases are defined.",
            case_results=[],
        )
    total = len(cases)

    runner = _RUNNERS.get(task.target_language)
    if runner is None:
        return _failed_run(
            total,
            f"run_tests: no runner registered for language '{task.target_language}'.",
        )

    try:
        runner_output, process, timed_out = runner.run(function_spec, candidate_code, cases)
    except subprocess.TimeoutExpired as error:
        # Runners that let the subprocess timeout propagate land here; report
        # the limit that was actually applied instead of a hard-coded value.
        return _failed_run(
            total,
            f"run_tests timed out for '{name}' after {error.timeout:g} seconds.",
            stderr=str(error),
            timed_out=True,
        )
    if timed_out:
        return _failed_run(
            total,
            f"run_tests timed out for '{name}' (build or execution).",
            timed_out=True,
        )

    if process is None:
        _runtime_hints: dict[str, str] = {
            "rust": "cargo not found — install the Rust toolchain (https://rustup.rs) or check PATH.",
            "typescript": "Install node + tsx or ts-node for TypeScript tasks.",
            "python": "Python subprocess runner failed to start.",
        }
        hint = _runtime_hints.get(
            task.target_language,
            f"No runner found for '{task.target_language}'.",
        )
        return _failed_run(
            total,
            f"run_tests could not find a runtime for '{name}'. {hint}",
        )

    if process.returncode != 0 or runner_output is None:
        stderr = process.stderr.strip()
        stdout = process.stdout.strip()
        # The Rust runner signals a compile failure with a synthetic process
        # whose args are empty.
        if task.target_language == "rust" and process.args == []:
            feedback = (
                f"run_tests failed for '{name}'.\n"
                f"Rust build failed (compiler error).\n"
                f"stderr: {stderr[:300]}"
            )
        else:
            feedback = (
                f"run_tests failed for '{name}'.\n"
                f"Process exit code: {process.returncode}\n"
                f"stderr: {stderr[:300]}"
            )
        return _failed_run(total, feedback, stdout=stdout, stderr=stderr)

    # zip(strict=True) below would raise ValueError on a count mismatch;
    # report it as a failed run instead of crashing the grader.
    if not isinstance(runner_output, list) or len(runner_output) != total:
        if isinstance(runner_output, list):
            detail = f"Runner returned {len(runner_output)} results for {total} cases."
        else:
            detail = "Runner output was not a JSON list of results."
        return _failed_run(
            total,
            f"run_tests failed for '{name}'.\n{detail}",
            stdout=process.stdout.strip(),
            stderr=process.stderr.strip(),
        )

    expected_values = [oracle_result(task.task_id, name, case.args) for case in cases]
    comparisons = [
        _compare_case(expected_value, actual)
        for expected_value, actual in zip(expected_values, runner_output, strict=True)
    ]
    passed_count = sum(1 for result in comparisons if result.passed)
    return TestRunResult(
        passed=passed_count == total,
        tests_passed=passed_count,
        tests_total=total,
        feedback=_format_feedback(name, cases, comparisons, passed_count),
        stdout=process.stdout,
        stderr=process.stderr,
        case_results=comparisons,
    )