Spaces:
Running
Running
p4r5kpftnp-cmd
Speed up proof generation: cache components, smaller max_tokens, two-stage retry
"""
Stress tests for app.solve_proof under concurrent invocation.

These tests mock out LangGraphAgent before importing app, so no Lean/LLM/Groq
infrastructure is required. They focus on:

* 10 simultaneous solve_proof calls returning distinct, non-cross-contaminated
  final_code values (one tempfile per call, paths unique)
* Tempfile cleanup (no leftover .lean files in the system temp directory,
  typically /tmp, afterwards)
* Exception isolation (one failing call must not affect the other 9)
* Detecting the stdout-redirection race condition exhibited by the original
  contextlib.redirect_stdout-based implementation.
"""
| import os | |
| import sys | |
| import tempfile | |
| import threading | |
| import time | |
| import unittest | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from unittest import mock | |
# Make `app` and `src/` importable. Slice-assigning at index 0 prepends both
# entries in one step and leaves "src" first and "." second. That is exactly
# the order produced by inserting "." and then "src" at position 0.
sys.path[0:0] = ["src", "."]

# IMPORTANT: mock LangGraphAgent BEFORE importing app so app.py never tries to
# instantiate the real agent (which would require lean, langchain, faiss, etc.).
# Placeholder module object; its LangGraphAgent attribute is assigned further
# down, once the stand-in class has been defined.
_fake_lg_module = mock.MagicMock()
| class _DummyAgent: | |
| """Stand-in for LangGraphAgent. Configured per-test via class attributes.""" | |
| # Per-test knobs (mutated by tests before each scenario). | |
| sleep_seconds: float = 0.0 | |
| # If a thread's lean_code contains this token, the agent raises. | |
| raise_token: str | None = None | |
| # Optional: tag to print so we can detect cross-thread stdout bleed. | |
| print_tag: bool = False | |
| def __init__(self, model_name="x", max_retries=5, **kwargs): | |
| self.model_name = model_name | |
| self.max_retries = max_retries | |
| def solve_file_detailed(self, file_path: str) -> dict: | |
| # Read what app.solve_proof just wrote so we can react per-call. | |
| with open(file_path, "r") as f: | |
| code = f.read() | |
| if self.print_tag: | |
| # Print enough to be racy if stdout is process-global. | |
| for _ in range(20): | |
| print(f"TAG::{code}") | |
| # Sleep so calls overlap in real wall time. | |
| if self.sleep_seconds: | |
| time.sleep(self.sleep_seconds) | |
| if self.raise_token and self.raise_token in code: | |
| raise RuntimeError(f"injected failure for {self.raise_token}") | |
| return { | |
| "success": True, | |
| "solved_at_attempt": 1, | |
| "total_attempts": 1, | |
| } | |
# Pre-set GROQ_API_KEY so solve_proof doesn't short-circuit.
os.environ.setdefault("GROQ_API_KEY", "test-key")
# Expose the stand-in class on the fake module and register that module under
# the name `langgraph_agent`, so app.py's import of it resolves to the stub.
setattr(_fake_lg_module, "LangGraphAgent", _DummyAgent)
sys.modules["langgraph_agent"] = _fake_lg_module
import app  # noqa: E402 (must come after the mock above)
| def _scan_tmp_lean(prefix_dir: str = "/tmp") -> set[str]: | |
| """Return absolute paths of all .lean files currently in `prefix_dir`.""" | |
| if not os.path.isdir(prefix_dir): | |
| return set() | |
| return { | |
| os.path.join(prefix_dir, f) | |
| for f in os.listdir(prefix_dir) | |
| if f.endswith(".lean") | |
| } | |
class ConcurrentSolveTests(unittest.TestCase):
    """Stress ``app.solve_proof`` with concurrent calls against ``_DummyAgent``."""

    # Model name forwarded to every solve_proof call; the stub agent ignores it.
    _MODEL = "llama-3.3-70b-versatile"

    # ------------------------------------------------------------------
    # Fixtures and shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _reset_agent_knobs():
        """Restore _DummyAgent's class-level knobs to their defaults."""
        _DummyAgent.sleep_seconds = 0.0
        _DummyAgent.raise_token = None
        _DummyAgent.print_tag = False

    def setUp(self):
        # Reset agent state between scenarios. The same reset is registered as
        # a cleanup so the last test's knob values don't outlive this case.
        self._reset_agent_knobs()
        self.addCleanup(self._reset_agent_knobs)
        # Per-test patch: `app.solve_proof` does `from langgraph_agent import
        # LangGraphAgent` at import time, so `app.LangGraphAgent` is bound to
        # whichever module-level mock loaded first. Other test files do the
        # same trick and can win the race, so we patch on `app` directly.
        self._lg_patcher = mock.patch.object(app, "LangGraphAgent", _DummyAgent)
        self._lg_patcher.start()
        self.addCleanup(self._lg_patcher.stop)
        # Don't construct a real LeanEnvironment / FAISS retriever in tests.
        self._components_patcher = mock.patch.object(
            app, "_get_components", return_value=(mock.MagicMock(), mock.MagicMock())
        )
        self._components_patcher.start()
        self.addCleanup(self._components_patcher.stop)
        # Snapshot pre-existing tmp .lean files so we ignore unrelated ones.
        self._tmp_before = _scan_tmp_lean()

    def tearDown(self):
        # Best-effort cleanup of anything the tests created that survived a bug.
        for path in _scan_tmp_lean() - self._tmp_before:
            try:
                os.unlink(path)
            except OSError:
                pass  # already removed, or not ours to remove: ignore

    @classmethod
    def _solve(cls, code):
        """Run one solve_proof call, returning ``(code, result)``."""
        return code, app.solve_proof(code, cls._MODEL, 1)

    def _run_parallel(self, inputs):
        """Run ``_solve`` on every input in a 10-worker thread pool.

        Returns the ``(code, result)`` pairs in completion order. Any exception
        raised inside a worker propagates out of ``f.result()``.
        """
        with ThreadPoolExecutor(max_workers=10) as ex:
            futures = [ex.submit(self._solve, code) for code in inputs]
            return [f.result() for f in as_completed(futures)]

    def _assert_triple(self, code, triple):
        """Assert solve_proof returned a 3-tuple for *code*, and return it."""
        self.assertIsInstance(triple, tuple, f"{code!r} did not return tuple")
        self.assertEqual(len(triple), 3, f"{code!r} returned {triple!r}")
        return triple

    def _assert_no_tempfile_leaks(self):
        """Assert no new .lean files appeared since setUp's snapshot."""
        leaked = _scan_tmp_lean() - self._tmp_before
        self.assertEqual(leaked, set(), f"Tempfiles leaked: {leaked!r}")

    # ------------------------------------------------------------------
    # 1 + 3: 10 parallel calls, agent sleeps so calls overlap.
    # ------------------------------------------------------------------
    def test_ten_parallel_calls_no_cross_contamination(self):
        _DummyAgent.sleep_seconds = 0.1  # force overlap
        inputs = [f"theorem t_{i} : True := trivial" for i in range(10)]
        results = self._run_parallel(inputs)
        self.assertEqual(len(results), 10)
        for code, triple in results:
            status_html, final_code, _logs = self._assert_triple(code, triple)
            # final_code is what the agent wrote/left in the tempfile, which the
            # mock leaves untouched, so it must equal the input verbatim.
            self.assertEqual(
                final_code,
                code,
                f"Cross-contamination: expected {code!r} got {final_code!r}",
            )
            self.assertIn("Verified", status_html, f"Unexpected status: {status_html!r}")
        # All tempfiles must be cleaned up.
        self._assert_no_tempfile_leaks()

    # ------------------------------------------------------------------
    # 2: Tempfile collision check. Patch NamedTemporaryFile to log paths.
    # ------------------------------------------------------------------
    def test_tempfile_paths_unique_under_concurrency(self):
        _DummyAgent.sleep_seconds = 0.05
        seen_paths: list[str] = []
        seen_lock = threading.Lock()
        real_ntf = tempfile.NamedTemporaryFile

        def logging_ntf(*args, **kwargs):
            f = real_ntf(*args, **kwargs)
            with seen_lock:
                seen_paths.append(f.name)
            return f

        inputs = [f"theorem unique_{i} : True := trivial" for i in range(10)]
        with mock.patch.object(app.tempfile, "NamedTemporaryFile", side_effect=logging_ntf):
            results = self._run_parallel(inputs)
        self.assertEqual(len(results), 10)
        # Guard against a vacuous pass. If solve_proof stopped creating its
        # tempfile through NamedTemporaryFile, seen_paths would be empty and
        # trivially "unique".
        self.assertGreaterEqual(
            len(seen_paths),
            len(inputs),
            f"Expected at least one NamedTemporaryFile per call, saw {seen_paths!r}",
        )
        self.assertEqual(
            len(seen_paths),
            len(set(seen_paths)),
            f"Tempfile path collision: {seen_paths!r}",
        )

    # ------------------------------------------------------------------
    # 4: one call raises, the other 9 still succeed.
    # ------------------------------------------------------------------
    def test_one_raising_call_does_not_break_others(self):
        _DummyAgent.sleep_seconds = 0.1
        _DummyAgent.raise_token = "BOOM_3"
        inputs = [
            "theorem t_BOOM_3 : True := trivial" if i == 3 else f"theorem t_{i} : True := trivial"
            for i in range(10)
        ]
        results = self._run_parallel(inputs)
        oks = []
        errs = []
        for code, triple in results:
            status_html, final_code, _logs = self._assert_triple(code, triple)
            bucket = errs if "BOOM_3" in code else oks
            bucket.append((code, status_html, final_code))
        self.assertEqual(len(oks), 9, "Expected 9 successful calls")
        self.assertEqual(len(errs), 1, "Expected exactly 1 erroring call")
        # Erroring call returns the "✗ Error:" status (per app.solve_proof).
        _, err_status, err_final = errs[0]
        self.assertIn("Error", err_status)
        self.assertEqual(err_final, "")
        # The 9 successful calls keep their own inputs intact.
        for code, status, final in oks:
            self.assertEqual(final, code)
            self.assertIn("Verified", status)
        # No tempfile leaks even on the error path.
        self._assert_no_tempfile_leaks()

    # ------------------------------------------------------------------
    # Stdout race: with print() inside the agent, log buffers must NOT
    # see each other's output. This is the bug redirect_stdout causes.
    # ------------------------------------------------------------------
    def test_stdout_logs_are_not_cross_contaminated(self):
        _DummyAgent.sleep_seconds = 0.1
        _DummyAgent.print_tag = True  # agent prints "TAG::<code>" 20 times
        inputs = [f"theorem stdoutcheck_{i} : True := trivial" for i in range(10)]
        results = self._run_parallel(inputs)
        # Each call's `logs` should contain its own TAG line and only its own.
        offenders = []
        for code, triple in results:
            _, _, logs = triple
            if f"TAG::{code}" not in logs:
                offenders.append(("own-tag-missing", code, logs[:200]))
                continue
            # No other call's code should leak into this log.
            leaked_from = next(
                (other for other in inputs if other != code and f"TAG::{other}" in logs),
                None,
            )
            if leaked_from is not None:
                offenders.append(("other-tag-leaked", code, leaked_from))
        self.assertEqual(
            offenders,
            [],
            "Stdout was cross-contaminated between concurrent solve_proof calls "
            "(redirect_stdout is process-global, not thread-local). "
            f"Offenders: {offenders[:5]!r}",
        )
if __name__ == "__main__":
    # Allow running this file directly; discovers the TestCase classes above.
    unittest.main(module="__main__")