Lean4-helper / tests /test_concurrent_solve.py
p4r5kpftnp-cmd
Speed up proof generation: cache components, smaller max_tokens, two-stage retry
df04431
Raw
History Blame Contribute Delete
11.1 kB
"""
Stress tests for app.solve_proof under concurrent invocation.

These tests mock out LangGraphAgent before importing app, so no Lean/LLM/Groq
infrastructure is required. They focus on:

  * 10 simultaneous solve_proof calls returning distinct, non-cross-contaminated
    final_code values (one tempfile per call, paths unique)
  * Tempfile cleanup (no leftover .lean files in /tmp afterwards)
  * Exception isolation (one failing call must not affect the other 9)
  * Detecting the stdout-redirection race condition that the original
    contextlib.redirect_stdout based implementation exhibited.
"""
import os
import sys
import tempfile
import threading
import time
import unittest
from concurrent.futures import ThreadPoolExecutor, as_completed
from unittest import mock
# Placeholder for the `langgraph_agent` module. It is created up front because
# it must be in place BEFORE `app` is imported, so app.py never tries to
# instantiate the real agent (which would require lean, langchain, faiss, etc.).
_fake_lg_module = mock.MagicMock()
# Put `src/` and the current directory at the very front of the import path
# (in that order) so that `app` and the modules under `src/` are importable.
sys.path[:0] = ["src", "."]
class _DummyAgent:
"""Stand-in for LangGraphAgent. Configured per-test via class attributes."""
# Per-test knobs (mutated by tests before each scenario).
sleep_seconds: float = 0.0
# If a thread's lean_code contains this token, the agent raises.
raise_token: str | None = None
# Optional: tag to print so we can detect cross-thread stdout bleed.
print_tag: bool = False
def __init__(self, model_name="x", max_retries=5, **kwargs):
self.model_name = model_name
self.max_retries = max_retries
def solve_file_detailed(self, file_path: str) -> dict:
# Read what app.solve_proof just wrote so we can react per-call.
with open(file_path, "r") as f:
code = f.read()
if self.print_tag:
# Print enough to be racy if stdout is process-global.
for _ in range(20):
print(f"TAG::{code}")
# Sleep so calls overlap in real wall time.
if self.sleep_seconds:
time.sleep(self.sleep_seconds)
if self.raise_token and self.raise_token in code:
raise RuntimeError(f"injected failure for {self.raise_token}")
return {
"success": True,
"solved_at_attempt": 1,
"total_attempts": 1,
}
# Expose the stand-in class on the fake module, then register that module in
# sys.modules under the name `langgraph_agent`, so that a later
# `import langgraph_agent` resolves to the fake instead of a real module.
_fake_lg_module.LangGraphAgent = _DummyAgent
sys.modules["langgraph_agent"] = _fake_lg_module
# Pre-set GROQ_API_KEY so solve_proof doesn't short-circuit.
# setdefault leaves the variable untouched when it is already set.
os.environ.setdefault("GROQ_API_KEY", "test-key")
# The import is deliberately placed after the registration above.
import app # noqa: E402 (must come after the mock above)
def _scan_tmp_lean(prefix_dir: str = "/tmp") -> set[str]:
"""Return absolute paths of all .lean files currently in `prefix_dir`."""
if not os.path.isdir(prefix_dir):
return set()
return {
os.path.join(prefix_dir, f)
for f in os.listdir(prefix_dir)
if f.endswith(".lean")
}
class ConcurrentSolveTests(unittest.TestCase):
    """Exercise ``app.solve_proof`` from ten threads at once.

    Every scenario runs against ``_DummyAgent`` (patched onto ``app``), so the
    assertions are about what ``solve_proof`` itself does around the agent
    call: the returned triple, tempfile handling and captured logs.
    """

    # Arguments passed to every solve_proof call in this suite.
    MODEL_NAME = "llama-3.3-70b-versatile"
    MAX_RETRIES = 1

    def setUp(self):
        # Reset agent state between scenarios.
        _DummyAgent.sleep_seconds = 0.0
        _DummyAgent.raise_token = None
        _DummyAgent.print_tag = False

        # `app` binds `LangGraphAgent` when it is imported, so the name points
        # at whichever module-level mock was installed first. Other test files
        # do the same trick and can win the race, so we patch on `app`
        # directly for the duration of each test.
        self._lg_patcher = mock.patch.object(app, "LangGraphAgent", _DummyAgent)
        self._lg_patcher.start()
        self.addCleanup(self._lg_patcher.stop)

        # Don't construct a real LeanEnvironment / FAISS retriever in tests.
        self._components_patcher = mock.patch.object(
            app, "_get_components", return_value=(mock.MagicMock(), mock.MagicMock())
        )
        self._components_patcher.start()
        self.addCleanup(self._components_patcher.stop)

        # Snapshot pre-existing tmp .lean files so we ignore unrelated ones.
        self._tmp_before = _scan_tmp_lean()

    def tearDown(self):
        # Best-effort cleanup of anything the tests created that survived a
        # bug; a file that is already gone (or not removable) is ignored.
        for path in _scan_tmp_lean() - self._tmp_before:
            try:
                os.unlink(path)
            except OSError:
                pass

    # ------------------------------------------------------------------
    # Shared helpers.
    # ------------------------------------------------------------------
    def _solve_concurrently(self, inputs):
        """Call solve_proof once per input, all at the same time.

        Returns a list of ``(code, result)`` pairs in completion order. An
        exception escaping any call propagates to the caller.
        """

        def run(code):
            return code, app.solve_proof(code, self.MODEL_NAME, self.MAX_RETRIES)

        # One worker per input so that every call can be in flight at once.
        with ThreadPoolExecutor(max_workers=len(inputs)) as ex:
            futures = [ex.submit(run, code) for code in inputs]
            return [f.result() for f in as_completed(futures)]

    def _assert_no_leaked_tempfiles(self):
        """Fail if .lean files that were not present at setUp still exist."""
        leaked = _scan_tmp_lean() - self._tmp_before
        self.assertEqual(leaked, set(), f"Tempfiles leaked: {leaked!r}")

    # ------------------------------------------------------------------
    # 1 + 3: 10 parallel calls, agent sleeps so calls overlap.
    # ------------------------------------------------------------------
    def test_ten_parallel_calls_no_cross_contamination(self):
        _DummyAgent.sleep_seconds = 0.1  # force overlap
        inputs = [f"theorem t_{i} : True := trivial" for i in range(10)]

        results = self._solve_concurrently(inputs)

        self.assertEqual(len(results), 10)
        for code, triple in results:
            self.assertIsInstance(triple, tuple, f"{code!r} did not return tuple")
            self.assertEqual(len(triple), 3, f"{code!r} returned {triple!r}")
            status_html, final_code, _logs = triple
            # final_code is what the agent wrote/left in the tempfile, which the
            # mock leaves untouched, so it must equal the input verbatim.
            self.assertEqual(
                final_code,
                code,
                f"Cross-contamination: expected {code!r} got {final_code!r}",
            )
            self.assertIn("Verified", status_html, f"Unexpected status: {status_html!r}")

        # All tempfiles must be cleaned up.
        self._assert_no_leaked_tempfiles()

    # ------------------------------------------------------------------
    # 2: Tempfile collision check. Patch NamedTemporaryFile to log paths.
    # ------------------------------------------------------------------
    def test_tempfile_paths_unique_under_concurrency(self):
        _DummyAgent.sleep_seconds = 0.05
        inputs = [f"theorem unique_{i} : True := trivial" for i in range(10)]

        seen_paths: list[str] = []
        seen_lock = threading.Lock()
        real_ntf = tempfile.NamedTemporaryFile

        def logging_ntf(*args, **kwargs):
            # Delegate to the real factory and record the path it picked.
            f = real_ntf(*args, **kwargs)
            with seen_lock:
                seen_paths.append(f.name)
            return f

        with mock.patch.object(app.tempfile, "NamedTemporaryFile", side_effect=logging_ntf):
            results = self._solve_concurrently(inputs)

        self.assertEqual(len(results), 10)
        # Without this guard the uniqueness check below would pass vacuously
        # if the patched factory were never called (an empty list has no
        # duplicates).
        self.assertGreaterEqual(
            len(seen_paths),
            len(inputs),
            f"Expected at least one tempfile per call, saw: {seen_paths!r}",
        )
        self.assertEqual(
            len(seen_paths),
            len(set(seen_paths)),
            f"Tempfile path collision: {seen_paths!r}",
        )

    # ------------------------------------------------------------------
    # 4: one call raises, the other 9 still succeed.
    # ------------------------------------------------------------------
    def test_one_raising_call_does_not_break_others(self):
        _DummyAgent.sleep_seconds = 0.1
        _DummyAgent.raise_token = "BOOM_3"
        inputs = [
            "theorem t_BOOM_3 : True := trivial"
            if i == 3
            else f"theorem t_{i} : True := trivial"
            for i in range(10)
        ]

        results = self._solve_concurrently(inputs)

        oks = []
        errs = []
        for code, triple in results:
            self.assertIsInstance(triple, tuple)
            self.assertEqual(len(triple), 3)
            status_html, final_code, _logs = triple
            bucket = errs if "BOOM_3" in code else oks
            bucket.append((code, status_html, final_code))

        self.assertEqual(len(oks), 9, "Expected 9 successful calls")
        self.assertEqual(len(errs), 1, "Expected exactly 1 erroring call")

        # Erroring call returns the "✗ Error:" status (per app.solve_proof).
        _, err_status, err_final = errs[0]
        self.assertIn("Error", err_status)
        self.assertEqual(err_final, "")

        # The 9 successful calls keep their own inputs intact.
        for code, status, final in oks:
            self.assertEqual(final, code)
            self.assertIn("Verified", status)

        # No tempfile leaks even on the error path.
        self._assert_no_leaked_tempfiles()

    # ------------------------------------------------------------------
    # Stdout race: with print() inside the agent, log buffers must NOT
    # see each other's output. This is the bug redirect_stdout causes.
    # ------------------------------------------------------------------
    def test_stdout_logs_are_not_cross_contaminated(self):
        _DummyAgent.sleep_seconds = 0.1
        _DummyAgent.print_tag = True  # agent prints "TAG::<code>" 20 times
        inputs = [f"theorem stdoutcheck_{i} : True := trivial" for i in range(10)]

        results = self._solve_concurrently(inputs)

        # Each call's `logs` should contain its own TAG line and only its own.
        offenders = []
        for code, (_, _, logs) in results:
            if f"TAG::{code}" not in logs:
                offenders.append(("own-tag-missing", code, logs[:200]))
                continue
            # No other call's code should leak into this log; report the
            # first foreign tag found, if any.
            leaked_from = next(
                (
                    other
                    for other in inputs
                    if other != code and f"TAG::{other}" in logs
                ),
                None,
            )
            if leaked_from is not None:
                offenders.append(("other-tag-leaked", code, leaked_from))

        self.assertEqual(
            offenders,
            [],
            "Stdout was cross-contaminated between concurrent solve_proof calls "
            "(redirect_stdout is process-global, not thread-local). "
            f"Offenders: {offenders[:5]!r}",
        )
# Allow running this file directly as a script, in addition to test discovery.
if __name__ == "__main__":
    unittest.main()