Lean4-helper / tests /test_state_machine_fuzz.py
p4r5kpftnp-cmd
Pin ruff to 0.15.14 and re-sort imports
01cdf65
Raw
History Blame Contribute Delete
13.6 kB
"""
Fuzz tests for the state machine in src/langgraph_agent.py.
Covers:
- should_continue(state): exhaustive combinations of status/attempt/max_retries.
- make_generate_node(chain): behavior with mocked chain output.
- make_verify_node(lean_env): behavior with mocked verifier.
"""
import os
import sys
import unittest
from unittest.mock import MagicMock, patch
# Add src to Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src')))
sys.path.insert(0, 'src')
from langgraph.graph import END
import langgraph_agent
from langgraph_agent import (
make_generate_node,
make_verify_node,
should_continue,
)
def _base_state(**overrides):
"""Build a minimal ProofState dict, allowing field overrides."""
state = {
"file_path": "/tmp/fake.lean",
"lean_code": "",
"goals": [],
"errors": [],
"attempt": 0,
"max_retries": 5,
"status": "pending",
"retrieved_lemmas": [],
"solved_at_attempt": 0,
}
state.update(overrides)
return state
# ---------------------------------------------------------------------------
# Fuzz: should_continue
# ---------------------------------------------------------------------------
class TestShouldContinueFuzz(unittest.TestCase):
    """Exhaustive corpus over status x attempt x max_retries."""

    STATUSES = ["pending", "success", "failed"]
    MAX_RETRIES_VALUES = [1, 5, 10]

    def _expected(self, status, attempt, max_retries):
        """Reference oracle used by the corpus test.

        Returns END when the status is "success" or the attempt budget is
        used up (attempt >= max_retries); otherwise returns "retrieve".
        """
        if status == "success" or attempt >= max_retries:
            return END
        return "retrieve"

    def test_should_continue_fuzz_corpus(self):
        """Compare should_continue against the oracle on boundary attempts."""
        cases = 0
        for status in self.STATUSES:
            for max_retries in self.MAX_RETRIES_VALUES:
                # Boundary values around 0 and max_retries. A set removes the
                # duplicates that appear for small budgets (max_retries=1
                # would otherwise yield 0 and 1 twice), so `cases` counts
                # distinct combinations only. The >= 0 filter guards against
                # a negative boundary should a 0 budget ever be added.
                boundary = {0, 1, max_retries - 1, max_retries, max_retries + 1}
                attempts = sorted(a for a in boundary if a >= 0)
                for attempt in attempts:
                    # subTest reports every failing combination instead of
                    # stopping at the first one.
                    with self.subTest(
                        status=status, attempt=attempt, max_retries=max_retries
                    ):
                        state = _base_state(
                            status=status,
                            attempt=attempt,
                            max_retries=max_retries,
                        )
                        expected = self._expected(status, attempt, max_retries)
                        got = should_continue(state)
                        self.assertEqual(
                            got,
                            expected,
                            msg=(
                                f"should_continue mismatch for "
                                f"status={status!r}, attempt={attempt}, "
                                f"max_retries={max_retries}: "
                                f"expected {expected!r}, got {got!r}"
                            ),
                        )
                    cases += 1
        # Sanity: corpus actually exercised many cases
        # (3 statuses x 13 distinct attempt/budget pairs = 39).
        self.assertGreaterEqual(cases, 30)

    def test_success_always_ends(self):
        """status='success' must map to END for every attempt/budget pair."""
        for attempt in [-1, 0, 1, 5, 100]:
            for max_retries in [1, 5, 10]:
                with self.subTest(attempt=attempt, max_retries=max_retries):
                    state = _base_state(
                        status="success", attempt=attempt, max_retries=max_retries
                    )
                    self.assertEqual(should_continue(state), END)

    def test_attempt_at_max_retries_ends(self):
        """attempt == max_retries must map to END."""
        for max_retries in [1, 5, 10]:
            with self.subTest(max_retries=max_retries):
                state = _base_state(
                    status="pending", attempt=max_retries, max_retries=max_retries
                )
                self.assertEqual(should_continue(state), END)

    def test_attempt_exceeds_max_retries_ends(self):
        """attempt > max_retries must map to END."""
        for max_retries in [1, 5, 10]:
            with self.subTest(max_retries=max_retries):
                state = _base_state(
                    status="pending",
                    attempt=max_retries + 1,
                    max_retries=max_retries,
                )
                self.assertEqual(should_continue(state), END)

    def test_pending_under_max_retries_retrieves(self):
        """A pending state with budget left must map to "retrieve"."""
        for max_retries in [1, 5, 10]:
            for attempt in range(max_retries):
                with self.subTest(attempt=attempt, max_retries=max_retries):
                    state = _base_state(
                        status="pending", attempt=attempt, max_retries=max_retries
                    )
                    self.assertEqual(should_continue(state), "retrieve")

    def test_failed_status_also_retries_if_under_max(self):
        """status='failed' is *not* terminal — only success and budget exhaustion are."""
        state = _base_state(status="failed", attempt=0, max_retries=5)
        self.assertEqual(should_continue(state), "retrieve")
# ---------------------------------------------------------------------------
# Fuzz: make_generate_node
# ---------------------------------------------------------------------------
class TestGenerateNodeFuzz(unittest.TestCase):
    """Exercise the node built by make_generate_node with mocked chain output."""

    def setUp(self):
        # Patch file I/O so generate_node never touches disk.
        # Each cleanup is registered immediately after its start() so that a
        # failure while starting the second patch cannot leave the first one
        # active for the rest of the test run.
        self._write_patcher = patch.object(langgraph_agent, "_write_file")
        self.mock_write = self._write_patcher.start()
        self.addCleanup(self._write_patcher.stop)

        self._read_patcher = patch.object(langgraph_agent, "_read_file")
        self.mock_read = self._read_patcher.start()
        self.addCleanup(self._read_patcher.stop)

    def _make_chain(self, output):
        """Return a mock chain whose generate() always returns ``output``."""
        chain = MagicMock()
        chain.generate.return_value = output
        return chain

    # ---- Identical output ----
    def test_chain_returns_identical_code_marks_failed(self):
        """If chain regurgitates the input verbatim, status -> 'failed', attempt += 1, no write."""
        original = "import Mathlib\n\ntheorem foo : True := by trivial"
        chain = self._make_chain(original)
        node = make_generate_node(chain)
        state = _base_state(lean_code=original, attempt=2)
        out = node(state)
        self.assertEqual(out["status"], "failed")
        self.assertEqual(out["attempt"], 3)
        self.mock_write.assert_not_called()

    # ---- Fewer theorem blocks ----
    def test_chain_drops_theorem_blocks_is_rejected(self):
        """If the LLM produces fewer theorem-like declarations, reject (no write); still increment attempt."""
        original = (
            "import Mathlib\n\n"
            "theorem foo : True := by sorry\n"
            "theorem bar : True := by sorry\n"
        )
        # Output drops 'bar' entirely.
        bad_output = (
            "```lean\n"
            "import Mathlib\n\n"
            "theorem foo : True := by trivial\n"
            "```"
        )
        chain = self._make_chain(bad_output)
        node = make_generate_node(chain)
        state = _base_state(lean_code=original, attempt=1)
        out = node(state)
        self.assertEqual(out["attempt"], 2)
        self.mock_write.assert_not_called()
        # lean_code must not have been replaced with the dropped version.
        self.assertEqual(out["lean_code"], original)

    # ---- More theorem blocks ----
    def test_chain_adds_theorem_blocks_is_accepted(self):
        """More theorem blocks than original — still accepted (count check is one-sided)."""
        original = (
            "import Mathlib\n\n"
            "theorem foo : True := by sorry\n"
        )
        new_output = (
            "```lean\n"
            "import Mathlib\n\n"
            "theorem foo : True := by trivial\n"
            "theorem helper : True := by trivial\n"
            "```"
        )
        chain = self._make_chain(new_output)
        node = make_generate_node(chain)
        state = _base_state(lean_code=original, attempt=0)
        out = node(state)
        self.assertEqual(out["attempt"], 1)
        self.mock_write.assert_called_once()
        # lean_code should be updated.
        self.assertIn("theorem helper", out["lean_code"])

    # ---- Empty string ----
    def test_chain_returns_empty_string_marks_failed(self):
        """Empty output from the chain must terminate with status='failed', not silently write garbage."""
        original = "import Mathlib\n\ntheorem foo : True := by sorry"
        chain = self._make_chain("")
        node = make_generate_node(chain)
        state = _base_state(lean_code=original, attempt=0)
        out = node(state)
        self.assertEqual(out["status"], "failed")
        self.assertEqual(out["attempt"], 1)
        self.mock_write.assert_not_called()

    def test_chain_returns_whitespace_only_marks_failed(self):
        """Whitespace-only output also counts as nothing useful."""
        original = "import Mathlib\n\ntheorem foo : True := by sorry"
        chain = self._make_chain(" \n\t \n")
        node = make_generate_node(chain)
        state = _base_state(lean_code=original, attempt=0)
        out = node(state)
        self.assertEqual(out["status"], "failed")
        self.assertEqual(out["attempt"], 1)
        self.mock_write.assert_not_called()

    # ---- Proper theorem block ----
    def test_chain_returns_valid_proof_writes_file(self):
        """A well-formed replacement gets written, attempt increments, no 'failed'."""
        original = "import Mathlib\n\ntheorem foo : True := by sorry"
        good_output = (
            "```lean\n"
            "import Mathlib\n\n"
            "theorem foo : True := by trivial\n"
            "```"
        )
        chain = self._make_chain(good_output)
        node = make_generate_node(chain)
        state = _base_state(lean_code=original, attempt=2, file_path="/tmp/x.lean")
        out = node(state)
        self.assertEqual(out["attempt"], 3)
        self.assertNotEqual(out.get("status"), "failed")
        self.mock_write.assert_called_once()
        # Path must be passed through. The unpacking expects the mocked
        # writer to have been called with two positional arguments.
        called_path, called_code = self.mock_write.call_args[0]
        self.assertEqual(called_path, "/tmp/x.lean")
        self.assertIn("trivial", called_code)
        # State's lean_code is updated to the new code.
        self.assertEqual(out["lean_code"], called_code)

    def test_chain_output_with_no_code_fences_still_works(self):
        """Raw lean code (no ```lean fence) should be accepted."""
        original = "import Mathlib\n\ntheorem foo : True := by sorry"
        raw = "theorem foo : True := by trivial"
        chain = self._make_chain(raw)
        node = make_generate_node(chain)
        state = _base_state(lean_code=original, attempt=0)
        out = node(state)
        self.assertEqual(out["attempt"], 1)
        self.mock_write.assert_called_once()
# ---------------------------------------------------------------------------
# Fuzz: make_verify_node
# ---------------------------------------------------------------------------
class TestVerifyNodeFuzz(unittest.TestCase):
    """Exercise the node built by make_verify_node with a mocked verifier."""

    def setUp(self):
        # Replace the module's file reader for the duration of each test.
        reader = patch.object(langgraph_agent, "_read_file")
        self._read_patcher = reader
        self.mock_read = reader.start()
        self.addCleanup(reader.stop)
        # Default file contents from disk
        self.mock_read.return_value = "import Mathlib\n\ntheorem foo : True := by trivial"

    def _make_env(self, status, errors, goals):
        """Return a mock env whose verify_proof() yields the given result dict."""
        verdict = {
            "status": status,
            "errors": errors,
            "goals": goals,
        }
        fake_env = MagicMock()
        fake_env.verify_proof.return_value = verdict
        return fake_env

    def test_success_marks_state_success_and_solved_at_attempt(self):
        verify = make_verify_node(self._make_env("success", [], []))
        # attempt=2 means the *current* attempt counter is 2; the run we are
        # about to perform is attempt #3 -> solved_at_attempt should be 3.
        result = verify(_base_state(attempt=2, solved_at_attempt=0))
        self.assertEqual(result["status"], "success")
        self.assertEqual(result["solved_at_attempt"], 3)
        self.assertEqual(result["errors"], [])
        self.assertEqual(result["goals"], [])

    def test_failure_marks_state_pending_and_keeps_errors_goals(self):
        verify = make_verify_node(self._make_env("failure", ["err1"], ["goal1"]))
        result = verify(_base_state(attempt=0, solved_at_attempt=0))
        self.assertEqual(result["status"], "pending")
        self.assertEqual(result["errors"], ["err1"])
        self.assertEqual(result["goals"], ["goal1"])
        # solved_at_attempt stays unchanged on failure.
        self.assertEqual(result["solved_at_attempt"], 0)

    def test_failure_then_success_records_correct_solved_attempt(self):
        """A success on the 5th attempt should record solved_at_attempt=5."""
        failing_env = self._make_env("failure", ["err"], ["goal"])
        passing_env = self._make_env("success", [], [])
        current = _base_state(attempt=0, solved_at_attempt=0)
        # Simulate four failures then a success.
        for _ in range(4):
            failing_node = make_verify_node(failing_env)
            current = failing_node(current)
            self.assertEqual(current["status"], "pending")
            # The agent's generate-step would normally bump 'attempt'; simulate that.
            current["attempt"] += 1
        # Now: attempt=4, next verify should be attempt #5.
        passing_node = make_verify_node(passing_env)
        outcome = passing_node(current)
        self.assertEqual(outcome["status"], "success")
        self.assertEqual(outcome["solved_at_attempt"], 5)

    def test_verify_reads_from_disk(self):
        verify = make_verify_node(self._make_env("success", [], []))
        result = verify(_base_state(file_path="/tmp/specific.lean"))
        self.mock_read.assert_called_once_with("/tmp/specific.lean")
        # lean_code in state should be what was read.
        self.assertEqual(result["lean_code"], self.mock_read.return_value)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()