# lean-migrate/tests/test_env_grader.py
# Uploaded by Hrushi via huggingface_hub (commit bf9c466, verified).
from __future__ import annotations
import subprocess
from lean_migrate.env import grader
from lean_migrate.env.grader import (
build_lean_sample_checks,
clamp_open_unit,
run_candidate_tests,
)
from lean_migrate.env.tasks import get_task, list_tasks
def test_list_tasks() -> None:
    """The task catalogue lists exactly the expected ids, each with >= 1 function."""
    expected_ids = {
        "rbac_auth",
        "pricing_engine",
        "payment_saga",
        "path_canonicalizer",
        "expression_eval",
        "lru_cache",
        "shortest_path",
        "interval_scheduler",
    }
    listed = list_tasks()
    listed_ids = {entry["task_id"] for entry in listed}
    assert listed_ids == expected_ids
    # Every listed task must report a positive function count.
    assert all(entry["num_functions"] > 0 for entry in listed)
def test_sample_checks_for_rbac_find_role() -> None:
    """``findRole`` in ``rbac_auth`` produces ten Lean sample checks."""
    rbac = get_task("rbac_auth")
    find_role = rbac.get_function("findRole")
    assert find_role is not None

    lean_checks = build_lean_sample_checks(rbac, find_role)
    assert len(lean_checks) == 10
    # Each generated check is expected to contain an `example :` statement.
    assert all("example :" in snippet for snippet in lean_checks)
def test_proof_only_function_skips_sample_checks() -> None:
    """``no_double_charge_proof`` (a proof-only function) gets no sample checks."""
    saga = get_task("payment_saga")
    proof_fn = saga.get_function("no_double_charge_proof")
    assert proof_fn is not None
    # Compare against an empty list explicitly (not just falsiness).
    assert build_lean_sample_checks(saga, proof_fn) == []
def test_run_tests_for_python_candidate() -> None:
    """A correct Python ``subtotal`` passes all ten ``pricing_engine`` samples.

    The candidate is passed as source text; after ``.strip()`` it must be
    valid Python, so the function body keeps its indentation.
    """
    task = get_task("pricing_engine")
    function_spec = task.get_function("subtotal")
    assert function_spec is not None
    # The `return` line must be indented: without it the candidate is a
    # SyntaxError and the test would fail for the wrong reason.
    candidate = """
def subtotal(order):
    return sum(item["unitPrice"] * item["quantity"] for item in order["items"])
""".strip()
    result = run_candidate_tests(task, function_spec, candidate)
    assert result.passed
    assert result.tests_passed == 10
    assert result.tests_total == 10
    assert result.case_results is not None
    assert len(result.case_results) == 10
    assert "10/10" in result.feedback
def test_run_tests_reports_failures() -> None:
    """A ``subtotal`` candidate that raises at call time fails every sample.

    The candidate body is indented so the source parses and the failure
    comes from the ``RuntimeError`` at runtime, not from a syntax error.
    """
    task = get_task("pricing_engine")
    function_spec = task.get_function("subtotal")
    assert function_spec is not None
    candidate = """
def subtotal(order):
    raise RuntimeError("boom")
""".strip()
    result = run_candidate_tests(task, function_spec, candidate)
    assert not result.passed
    assert result.tests_passed == 0
    assert result.tests_total == 10
    # Accept either a whole-run failure message or a zero-pass tally.
    assert "run_tests failed" in result.feedback or "0/10" in result.feedback
def test_run_tests_reports_rust_build_failures(monkeypatch) -> None:
    """A Rust compiler error is reported as a build failure, not an interpreter exit.

    The real Rust runner is swapped out via ``monkeypatch`` for a stub that
    always reports a failed build, so no toolchain is needed.
    """
    task = get_task("lru_cache")
    function_spec = task.get_function("lruEvict")
    assert function_spec is not None

    class CompileErrorRunner:
        """Stub runner whose ``run`` always returns a failed (exit 1) process."""

        # Parameter names mirror the original fake in case the grader
        # passes them by keyword.
        def run(self, function_spec, candidate_code, cases):
            failed_build = subprocess.CompletedProcess(
                args=[],
                returncode=1,
                stdout="",
                stderr="error[E0308]: mismatched types\n",
            )
            # Tuple shape copied from the grader's runner contract:
            # presumably (outputs, process, ok) — confirm in grader.
            return None, failed_build, False

    monkeypatch.setitem(grader._RUNNERS, "rust", CompileErrorRunner())
    candidate = "pub fn lru_evict(cache: Vec<(u64, u64)>, cap: usize) -> Vec<(u64, u64)> { cache }"
    result = run_candidate_tests(task, function_spec, candidate)

    assert not result.passed
    assert result.tests_passed == 0
    assert result.tests_total == len(task.sample_inputs[function_spec.name])
    assert "Rust build failed (compiler error)" in result.feedback
    assert "Interpreter exit code" not in result.feedback
def test_run_tests_for_expression_eval_rust_candidate() -> None:
    """A hand-written Rust ``eval_expr`` passes all ten ``expression_eval`` samples.

    No runner is patched here, so this test presumably needs a working Rust
    toolchain on the test machine — confirm against the grader's Rust runner.
    """
    task = get_task("expression_eval")
    function_spec = task.get_function("evalExpr")
    assert function_spec is not None

    # Reference solution: division by zero (or any failing sub-expression)
    # yields None; all other operators wrap the i64 result in Some.
    rust_source = """
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Op {
Add,
Sub,
Mul,
Div,
}
#[derive(Debug, Clone)]
pub enum Expr {
Lit(i64),
BinOp(Op, Box<Expr>, Box<Expr>),
}
pub fn eval_bin_op(op: Op, a: i64, b: i64) -> Option<i64> {
match op {
Op::Add => Some(a + b),
Op::Sub => Some(a - b),
Op::Mul => Some(a * b),
Op::Div => {
if b == 0 {
None
} else {
Some(a / b)
}
}
}
}
pub fn eval_expr(expr: &Expr) -> Option<i64> {
match expr {
Expr::Lit(n) => Some(*n),
Expr::BinOp(op, l, r) => {
let left = eval_expr(l);
let right = eval_expr(r);
match (left, right) {
(Some(a), Some(b)) => eval_bin_op(*op, a, b),
_ => None,
}
}
}
}
""".strip()
    result = run_candidate_tests(task, function_spec, rust_source)

    assert result.passed
    assert result.tests_passed == 10
    assert result.tests_total == 10
    assert "10/10" in result.feedback
def test_clamp_open_unit_keeps_scores_in_open_interval() -> None:
    """Scores of 0.0 and 1.0 map to 0.01 and 0.99; interior values are unchanged."""
    cases = (
        (0.0, 0.01),
        (1.0, 0.99),
        (0.42, 0.42),
    )
    for raw, expected in cases:
        assert clamp_open_unit(raw) == expected