Spaces:
Sleeping
Sleeping
| """Tests for the QED Math Environment. | |
| Unit tests (no network/Docker required): | |
| - parse_schema, length_penalty, apply_score_threshold | |
| - MathProofRubric: normalize_reward, _parse_response, _build_prompt, | |
| grade() with mocked LLM, empty-proof fast path, retry/failure paths | |
| - QEDMathEnvironment: reset, get_problem_payload, _verify_math, | |
| _grade_answer_submission, _apply_reward_shaping, _strip_reasoning, | |
| submit_proof_payload (mocked grader), multi-step problems, | |
| original_problem field, verifier metrics, dataset loading | |
| - MCP tool registration (get_problem, submit_proof, get_grading_guidelines) | |
| - Models: ProblemObservation, ProofSubmissionObservation | |
| Integration tests (require running server, marked @pytest.mark.integration): | |
| - Full HTTP/WebSocket stack smoke test | |
| Run from repo root: | |
| uv run pytest tests/test_environment.py -v | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import sys | |
| from pathlib import Path | |
| from unittest.mock import AsyncMock, MagicMock, patch | |
| import pytest | |
| # Path setup — ensure repository root is importable | |
| REPO_ROOT = Path(__file__).resolve().parent.parent | |
| sys.path.insert(0, str(REPO_ROOT)) | |
| # Autouse fixture: patch openai.AsyncOpenAI so tests that instantiate | |
| # QEDMathEnvironment (which eagerly creates a MathProofRubric / AsyncOpenAI | |
| # client) work without a real OPENAI_API_KEY in the environment. | |
| def _patch_openai_client(monkeypatch): | |
| """Replace openai.AsyncOpenAI with a MagicMock for the duration of each test.""" | |
| mock_client = MagicMock() | |
| # Ensure async methods on the mock return coroutines | |
| mock_client.responses.create = AsyncMock() | |
| mock_client.chat.completions.create = AsyncMock() | |
| monkeypatch.setattr("openai.AsyncOpenAI", lambda **kwargs: mock_client) | |
| yield mock_client | |
| # Guard: skip entire module if required packages are missing | |
| pytest.importorskip("math_verify", reason="math_verify is not installed") | |
| pytest.importorskip("fastmcp", reason="fastmcp is not installed") | |
| from openenv.core.env_server.mcp_environment import get_server_tools # noqa: E402 | |
| from models import ( # noqa: E402 | |
| ProblemObservation, | |
| ProofSubmissionObservation, | |
| ) | |
| from server.math_verify_service import ( # noqa: E402 | |
| MathVerifierService, | |
| VerifyRequest, | |
| VerifyResponse, | |
| _verify_answer_worker, | |
| ) | |
| from server.qed_math_environment import ( # noqa: E402 | |
| QEDMathEnvironment, | |
| _normalize_problem, | |
| load_problems, | |
| remove_reasoning, | |
| ) | |
| from server.rubric import ( # noqa: E402 | |
| MAX_SCORE, | |
| GradingResult, | |
| MathProofRubric, | |
| apply_score_threshold, | |
| length_penalty, | |
| parse_schema, | |
| ) | |
| # Helpers | |
| BOOTSTRAP_PROBLEM = { | |
| "problem": "Prove that the sum of two even integers is even.", | |
| "reference_solution": "Let a=2m and b=2n. Then a+b=2(m+n), so it is even.", | |
| "grading_guidelines": "Award full credit for a correct parity argument.", | |
| "problem_id": "bootstrap_000001", | |
| "dataset_source": "bootstrap", | |
| "problem_type": "proof", | |
| "max_attempts": 1, | |
| } | |
| ANSWER_PROBLEM = { | |
| "problem": "What is 2+2?", | |
| "reference_solution": "4", | |
| "grading_guidelines": "", | |
| "problem_id": "answer_001", | |
| "dataset_source": "test", | |
| "problem_type": "answer", | |
| "evaluation_mode": "answer", | |
| "max_attempts": 1, | |
| } | |
| MULTI_STEP_PROBLEM = { | |
| "problem": "Prove Fermat's Last Theorem.", | |
| "reference_solution": "Wiles 1995.", | |
| "grading_guidelines": "", | |
| "problem_id": "multi_001", | |
| "dataset_source": "test", | |
| "problem_type": "multi_step", | |
| "max_attempts": 3, | |
| } | |
| def _make_env(**kwargs) -> QEDMathEnvironment: | |
| """Create an environment with bootstrap problems (no dataset needed).""" | |
| return QEDMathEnvironment(**kwargs) | |
| def _make_env_with_problem(raw_problem: dict) -> QEDMathEnvironment: | |
| """Create an environment pre-loaded with a single synthetic problem.""" | |
| env = _make_env() | |
| normalized = _normalize_problem(raw_problem, 0, raw_problem.get("dataset_source", "test")) | |
| env._problems = [normalized] | |
| env._gold_cache_problem_count = len(env._problems) | |
| env._build_gold_answer_cache() | |
| return env | |
| # parse_schema | |
| class TestParseSchema: | |
| def test_string_passthrough(self): | |
| assert parse_schema("plain text") == "plain text" | |
| def test_empty_string(self): | |
| assert parse_schema("") == "" | |
| def test_list_of_criteria(self): | |
| schema = [ | |
| {"title": "Correctness", "points": 4, "desc": "Is it correct?"}, | |
| {"title": "Clarity", "points": 3, "description": "Is it clear?"}, | |
| ] | |
| result = parse_schema(schema) | |
| assert "# Correctness (4 points)" in result | |
| assert "# Clarity (3 points)" in result | |
| assert "Is it correct?" in result | |
| assert "Is it clear?" in result | |
| def test_invalid_type_raises(self): | |
| with pytest.raises(TypeError): | |
| parse_schema(42) | |
| def test_list_missing_keys_raises(self): | |
| with pytest.raises(ValueError): | |
| parse_schema([{"title": "X"}]) # missing points/desc | |
| def test_list_non_dict_entry_raises(self): | |
| with pytest.raises(ValueError): | |
| parse_schema(["not a dict"]) | |
| # length_penalty | |
| class TestLengthPenalty: | |
| def test_zero_buffer_always_zero(self): | |
| assert length_penalty(1000, 999, 0) == 0.0 | |
| assert length_penalty(1000, 1, 0) == 0.0 | |
| def test_inside_buffer_zone(self): | |
| # buffer=100, max=1000 → penalty zone is 900-1000 | |
| penalty = length_penalty(1000, 950, 100) | |
| assert penalty < 0.0 | |
| def test_outside_buffer_zone(self): | |
| # sequence_length <= max_length - buffer_tokens → no penalty | |
| assert length_penalty(1000, 800, 100) == 0.0 | |
| def test_at_boundary(self): | |
| # sequence_length == max_length - buffer_tokens → no penalty (boundary inclusive at 0) | |
| assert length_penalty(1000, 900, 100) == 0.0 | |
| def test_penalty_increases_toward_max(self): | |
| p1 = length_penalty(1000, 910, 100) | |
| p2 = length_penalty(1000, 960, 100) | |
| assert p2 < p1 # closer to max → more negative | |
| # apply_score_threshold | |
| class TestApplyScoreThreshold: | |
| def test_zero(self): | |
| assert apply_score_threshold(0) == 0.0 | |
| def test_partial_credit_collapses_to_one(self): | |
| for s in [1, 2, 3, 4, 5]: | |
| assert apply_score_threshold(float(s)) == 1.0 | |
| def test_high_scores_pass_through(self): | |
| assert apply_score_threshold(6.0) == 6.0 | |
| assert apply_score_threshold(7.0) == 7.0 | |
| # MathProofRubric unit tests | |
| class TestMathProofRubricNormalizeReward: | |
| def test_score_zero(self): | |
| rubric = MathProofRubric(grader_model="test", api_key="x") | |
| assert rubric.normalize_reward(0) == pytest.approx(0.0) | |
| def test_score_max(self): | |
| rubric = MathProofRubric(grader_model="test", api_key="x") | |
| assert rubric.normalize_reward(MAX_SCORE) == pytest.approx(1.0) | |
| def test_score_mid(self): | |
| rubric = MathProofRubric(grader_model="test", api_key="x") | |
| assert rubric.normalize_reward(3) == pytest.approx(3 / 7) | |
| def test_custom_threshold_collapses_partial(self): | |
| rubric = MathProofRubric(grader_model="test", api_key="x", custom_threshold=True) | |
| # scores 1-5 collapse to 1, normalized to 1/7 | |
| for score in [1, 2, 3, 4, 5]: | |
| assert rubric.normalize_reward(score) == pytest.approx(1 / 7) | |
| def test_custom_threshold_preserves_high(self): | |
| rubric = MathProofRubric(grader_model="test", api_key="x", custom_threshold=True) | |
| assert rubric.normalize_reward(6) == pytest.approx(6 / 7) | |
| assert rubric.normalize_reward(7) == pytest.approx(1.0) | |
| class TestMathProofRubricParseResponse: | |
| def _rubric(self): | |
| return MathProofRubric(grader_model="test", api_key="x") | |
| def test_parses_score_tag(self): | |
| rubric = self._rubric() | |
| score, feedback = rubric._parse_response("Good proof. <score>5</score>") | |
| assert score == 5 | |
| assert "<score>5</score>" in feedback | |
| def test_clamps_above_max(self): | |
| rubric = self._rubric() | |
| score, _ = rubric._parse_response("<score>99</score>") | |
| assert score == MAX_SCORE | |
| def test_clamps_below_zero(self): | |
| rubric = self._rubric() | |
| # Negative values are clamped by max(0, ...) — regex only matches digits | |
| # so a missing tag falls through to 0. | |
| score, _ = rubric._parse_response("<score>0</score>") | |
| assert score == 0 | |
| def test_missing_tag_returns_zero(self): | |
| rubric = self._rubric() | |
| score, feedback = rubric._parse_response("No tag here.") | |
| assert score == 0 | |
| assert feedback == "No tag here." | |
| class TestMathProofRubricBuildPrompt: | |
| def _rubric(self, template: str = "") -> MathProofRubric: | |
| return MathProofRubric(grader_model="test", api_key="x", prompt_template=template) | |
| def test_fallback_prompt_contains_proof(self): | |
| rubric = self._rubric() | |
| prompt = rubric._build_prompt("My proof", "Problem X", "Ref sol", "Rubric") | |
| assert "My proof" in prompt | |
| assert "Problem X" in prompt | |
| def test_template_substitution(self): | |
| template = "P:{problem} S:{solution} R:{marking_scheme} H:{human_solution}" | |
| rubric = self._rubric(template) | |
| prompt = rubric._build_prompt("my_proof", "my_problem", "my_ref", "my_rubric") | |
| assert "my_proof" in prompt | |
| assert "my_problem" in prompt | |
| assert "my_rubric" in prompt | |
| assert "my_ref" in prompt | |
| class TestMathProofRubricGrade: | |
| """Tests for MathProofRubric.grade() — LLM calls are mocked.""" | |
| def _rubric(self, **kwargs) -> MathProofRubric: | |
| return MathProofRubric(grader_model="test", api_key="x", **kwargs) | |
| async def test_successful_grade(self): | |
| rubric = self._rubric() | |
| response_text = "The proof is correct. <score>7</score>" | |
| with patch.object(rubric, "_call_llm", new=AsyncMock(return_value=response_text)): | |
| result = await rubric.grade("Good proof.", "Problem", "Ref", "Guidelines") | |
| assert result.score == 7 | |
| assert result.reward == pytest.approx(1.0) | |
| assert result.metrics["verifier/rollouts/success"] == 1 | |
| async def test_score_normalized(self): | |
| rubric = self._rubric() | |
| response_text = "<score>3</score>" | |
| with patch.object(rubric, "_call_llm", new=AsyncMock(return_value=response_text)): | |
| result = await rubric.grade("Partial proof.", "P", "R", "G") | |
| assert result.score == 3 | |
| assert result.reward == pytest.approx(3 / 7) | |
| async def test_empty_proof_fast_path(self): | |
| rubric = self._rubric() | |
| result = await rubric.grade(" ", "Problem", "Ref", "Guidelines") | |
| assert result.score == 0 | |
| assert result.reward == 0.0 | |
| assert result.metrics["verifier/rollouts/failure"] == 1 | |
| assert result.metrics["verifier/failures/no_input"] == 1 | |
| async def test_missing_score_tag_records_metric(self): | |
| rubric = self._rubric() | |
| response_text = "Looks good but I forgot the tag." | |
| with patch.object(rubric, "_call_llm", new=AsyncMock(return_value=response_text)): | |
| result = await rubric.grade("Some proof.", "P", "R", "G") | |
| assert result.metrics["verifier/failures/no_score_tag"] == 1 | |
| assert result.metrics["verifier/rollouts/failure"] == 1 | |
| async def test_all_attempts_exhausted(self): | |
| import openai as _openai | |
| rubric = self._rubric(max_retries=2, retry_backoff=[0, 0]) | |
| with patch.object( | |
| rubric, | |
| "_call_llm", | |
| new=AsyncMock( | |
| side_effect=_openai.RateLimitError("rate", response=MagicMock(), body={}) | |
| ), | |
| ): | |
| result = await rubric.grade("My proof.", "P", "R", "G") | |
| assert result.score == 0 | |
| assert result.reward == 0.0 | |
| assert result.metrics["verifier/failures/all_attempts_failed"] == 1 | |
| assert result.metrics["verifier/failures/rate_limit"] >= 1 | |
| async def test_custom_threshold_applied(self): | |
| rubric = self._rubric(custom_threshold=True) | |
| response_text = "<score>4</score>" | |
| with patch.object(rubric, "_call_llm", new=AsyncMock(return_value=response_text)): | |
| result = await rubric.grade("Partial proof.", "P", "R", "G") | |
| # Score 4 → thresholded to 1 → reward 1/7 | |
| assert result.reward == pytest.approx(1 / 7) | |
| async def test_verifier_metrics_present(self): | |
| rubric = self._rubric() | |
| with patch.object(rubric, "_call_llm", new=AsyncMock(return_value="<score>5</score>")): | |
| result = await rubric.grade("proof", "P", "R", "G") | |
| expected_keys = [ | |
| "verifier/rollouts/success", | |
| "verifier/rollouts/failure", | |
| "verifier/failures/timeout", | |
| "verifier/failures/rate_limit", | |
| "verifier/failures/no_score_tag", | |
| "verifier/failures/num_retries", | |
| "verifier/runtime/latency_per_request", | |
| "verifier/runtime/input_tokens", | |
| "verifier/runtime/output_tokens", | |
| ] | |
| for key in expected_keys: | |
| assert key in result.metrics, f"Missing metric: {key}" | |
| # remove_reasoning | |
| class TestRemoveReasoning: | |
| def test_no_delimiter_returns_original(self): | |
| assert remove_reasoning("hello world") == "hello world" | |
| def test_empty_delimiters_returns_original(self): | |
| assert remove_reasoning("hello", []) == "hello" | |
| def test_delimiter_present_strips_before(self): | |
| text = "<think>step1 step2</think>Final answer." | |
| result = remove_reasoning(text, ["</think>"]) | |
| assert result == "Final answer." | |
| def test_delimiter_absent_returns_empty_string(self): | |
| # When delimiters are configured but none present → empty (signal to grade full text) | |
| result = remove_reasoning("No think tag here.", ["</think>"]) | |
| assert result == "" | |
| def test_multiple_delimiters_first_match_wins(self): | |
| text = "<think>reasoning</think>answer" | |
| result = remove_reasoning(text, ["</think>", "</reason>"]) | |
| assert result == "answer" | |
| # QEDMathEnvironment: dataset loading | |
| class TestLoadProblems: | |
| def test_bootstrap_when_no_dataset(self): | |
| problems = load_problems(None) | |
| assert len(problems) >= 1 | |
| assert all("problem" in p for p in problems) | |
| def test_load_local_jsonl(self, tmp_path): | |
| jsonl_file = tmp_path / "problems.jsonl" | |
| rows = [ | |
| { | |
| "problem": "Prove 1+1=2.", | |
| "reference_solution": "By definition.", | |
| "problem_id": "p001", | |
| }, | |
| { | |
| "problem": "Prove 2+2=4.", | |
| "reference_solution": "By arithmetic.", | |
| "problem_id": "p002", | |
| }, | |
| ] | |
| with jsonl_file.open("w") as f: | |
| for row in rows: | |
| f.write(json.dumps(row) + "\n") | |
| problems = load_problems(str(jsonl_file)) | |
| assert len(problems) == 2 | |
| assert problems[0]["problem"] == "Prove 1+1=2." | |
| assert problems[1]["problem_id"] == "p002" | |
| def test_load_local_json_list(self, tmp_path): | |
| json_file = tmp_path / "problems.json" | |
| rows = [ | |
| {"problem": "P1", "reference_solution": "S1"}, | |
| {"problem": "P2", "reference_solution": "S2"}, | |
| ] | |
| json_file.write_text(json.dumps(rows)) | |
| problems = load_problems(str(json_file)) | |
| assert len(problems) == 2 | |
| def test_answer_mode_boxed_wrapping(self, tmp_path): | |
| jsonl_file = tmp_path / "answers.jsonl" | |
| row = { | |
| "problem": "2+2=?", | |
| "reference_solution": "4", | |
| "evaluation_mode": "answer", | |
| } | |
| jsonl_file.write_text(json.dumps(row) + "\n") | |
| problems = load_problems(str(jsonl_file)) | |
| # Answer-mode reference solution should be wrapped in \boxed{} | |
| assert "\\boxed{4}" in problems[0]["reference_solution"] | |
| def test_select_by_problem_id(self, tmp_path): | |
| jsonl_file = tmp_path / "probs.jsonl" | |
| rows = [ | |
| {"problem": "P1", "problem_id": "id_alpha"}, | |
| {"problem": "P2", "problem_id": "id_beta"}, | |
| ] | |
| jsonl_file.write_text("\n".join(json.dumps(r) for r in rows)) | |
| env = QEDMathEnvironment(dataset_path=str(jsonl_file)) | |
| obs = env.reset(problem_id="id_beta") | |
| assert isinstance(obs, ProblemObservation) | |
| assert obs.problem_id == "id_beta" | |
| assert obs.problem == "P2" | |
| # QEDMathEnvironment: reset | |
| class TestQEDMathEnvironmentReset: | |
| def test_reset_returns_problem_observation(self): | |
| env = _make_env() | |
| obs = env.reset() | |
| assert isinstance(obs, ProblemObservation) | |
| assert obs.problem != "" | |
| assert obs.done is False | |
| assert obs.reward == 0.0 | |
| def test_reset_fields_populated(self): | |
| env = _make_env() | |
| obs = env.reset() | |
| assert isinstance(obs, ProblemObservation) | |
| assert obs.problem_id != "" | |
| assert obs.dataset_source != "" | |
| assert obs.max_attempts >= 1 | |
| def test_reset_increments_counter(self): | |
| env = _make_env() | |
| env.reset() | |
| env.reset() | |
| assert env._reset_count == 2 | |
| def test_reset_clears_attempt_count(self): | |
| env = _make_env() | |
| env.reset() | |
| env._attempt_count = 99 | |
| env.reset() | |
| assert env._attempt_count == 0 | |
| # QEDMathEnvironment: get_problem_payload / get_grading_guidelines_payload | |
| class TestGetProblemPayload: | |
| def test_without_reset_returns_error(self): | |
| env = _make_env() | |
| payload = env.get_problem_payload() | |
| assert "error" in payload | |
| def test_after_reset_returns_problem(self): | |
| env = _make_env() | |
| env.reset() | |
| payload = env.get_problem_payload() | |
| assert "problem" in payload | |
| assert "error" not in payload | |
| assert "problem_id" in payload | |
| assert "reference_solution" not in payload | |
| def test_answer_mode_includes_reference_solution(self): | |
| env = _make_env_with_problem(ANSWER_PROBLEM) | |
| env.reset() | |
| payload = env.get_problem_payload() | |
| assert payload["reference_solution"] == r"\boxed{4}" | |
| def test_grading_guidelines_payload_without_reset(self): | |
| env = _make_env() | |
| payload = env.get_grading_guidelines_payload() | |
| assert "error" in payload | |
| def test_grading_guidelines_payload_after_reset(self): | |
| env = _make_env() | |
| env.reset() | |
| payload = env.get_grading_guidelines_payload() | |
| assert "grading_guidelines" in payload | |
| assert "error" not in payload | |
| # QEDMathEnvironment: _verify_math (answer-mode verification) | |
| class TestVerifyMath: | |
| def test_correct_answer(self): | |
| result = QEDMathEnvironment._verify_math(r"\boxed{4}", r"\boxed{4}") | |
| assert result == "correct" | |
| def test_wrong_answer(self): | |
| result = QEDMathEnvironment._verify_math(r"\boxed{5}", r"\boxed{4}") | |
| assert result == "wrong" | |
| def test_no_boxed_returns_no_answer(self): | |
| result = QEDMathEnvironment._verify_math("The answer is 4.", r"\boxed{4}") | |
| assert result == "no_answer" | |
| def test_empty_boxed_returns_unparsable(self): | |
| result = QEDMathEnvironment._verify_math(r"\boxed{}", r"\boxed{4}") | |
| assert result == "unparsable" | |
| # QEDMathEnvironment: _grade_answer_submission | |
| class TestGradeAnswerSubmission: | |
| async def test_correct_gives_max_reward(self): | |
| env = _make_env() | |
| try: | |
| result = await env._grade_answer_submission(r"\boxed{4}", r"\boxed{4}") | |
| assert result.score == 7 | |
| assert result.reward == pytest.approx(1.0) | |
| assert result.metrics["verifier/rollouts/success"] == 1 | |
| finally: | |
| await env._verifier_service.stop() | |
| async def test_wrong_gives_zero(self): | |
| env = _make_env() | |
| try: | |
| result = await env._grade_answer_submission(r"\boxed{5}", r"\boxed{4}") | |
| assert result.score == 0 | |
| assert result.reward == pytest.approx(0.0) | |
| assert result.metrics["verifier/rollouts/failure"] == 1 | |
| finally: | |
| await env._verifier_service.stop() | |
| async def test_missing_boxed_gives_zero(self): | |
| env = _make_env() | |
| try: | |
| result = await env._grade_answer_submission("The answer is 4.", r"\boxed{4}") | |
| assert result.score == 0 | |
| assert result.reward == pytest.approx(0.0) | |
| finally: | |
| await env._verifier_service.stop() | |
| async def test_answer_grading_uses_verifier_service(self): | |
| env = _make_env() | |
| try: | |
| result = await env._grade_answer_submission(r"\boxed{4}", r"\boxed{4}") | |
| assert result.score == 7 | |
| assert result.reward == pytest.approx(1.0) | |
| assert env._verifier_service._executor is not None | |
| finally: | |
| await env._verifier_service.stop() | |
| async def test_answer_grading_wrong_answer_via_service(self): | |
| env = _make_env() | |
| try: | |
| result = await env._grade_answer_submission(r"\boxed{5}", r"\boxed{4}") | |
| assert result.score == 0 | |
| assert result.reward == pytest.approx(0.0) | |
| finally: | |
| await env._verifier_service.stop() | |
| async def test_timeout_metrics_from_service(self): | |
| env = _make_env() | |
| try: | |
| timeout_response = VerifyResponse( | |
| request_id="req-timeout", | |
| status="timeout", | |
| elapsed_ms=12.5, | |
| retry_count=1, | |
| error_type="ClientTimeout", | |
| error_message="simulated timeout", | |
| ) | |
| with patch.object( | |
| env._verifier_service, | |
| "verify_answer", | |
| new=AsyncMock(return_value=timeout_response), | |
| ): | |
| result = await env._grade_answer_submission(r"\boxed{4}", r"\boxed{4}") | |
| assert result.score == 0 | |
| assert result.metrics["verifier/failures/timeout"] == 1 | |
| assert result.metrics["verifier/failures/num_retries"] == 1 | |
| assert result.metrics["verifier/runtime/latency_per_request"] == pytest.approx(12.5) | |
| assert "verifier/workers/restart_count" in result.metrics | |
| assert "verifier/workers/worker_restarted" in result.metrics | |
| assert "verifier/queue/depth" in result.metrics | |
| finally: | |
| await env._verifier_service.stop() | |
| async def test_shutdown_verifier_service_hook(self): | |
| env = _make_env() | |
| await env._grade_answer_submission(r"\boxed{4}", r"\boxed{4}") | |
| assert env._verifier_service._executor is not None | |
| await env.shutdown_verifier_service() | |
| assert env._verifier_service._executor is None | |
| # QEDMathEnvironment: reward shaping | |
| class TestRewardShaping: | |
| def test_no_tokens_no_shaping(self): | |
| env = _make_env() | |
| env._discount_factor = 0.99 | |
| assert env._apply_reward_shaping(1.0, 0) == pytest.approx(1.0) | |
| def test_discount_factor_applied(self): | |
| env = _make_env() | |
| env._discount_factor = 0.99 | |
| shaped = env._apply_reward_shaping(1.0, 10) | |
| assert shaped == pytest.approx(0.99**10) | |
| def test_discount_decreases_reward_with_tokens(self): | |
| env = _make_env() | |
| env._discount_factor = 0.999 | |
| r1 = env._apply_reward_shaping(1.0, 100) | |
| r2 = env._apply_reward_shaping(1.0, 500) | |
| assert r2 < r1 | |
| def test_length_penalty_applied_inside_buffer(self): | |
| env = _make_env() | |
| env._discount_factor = 1.0 | |
| env._max_tokens = 1000 | |
| env._buffer_tokens = 100 | |
| # Token count inside penalty zone (950 > 900) | |
| shaped = env._apply_reward_shaping(1.0, 950) | |
| assert shaped < 1.0 | |
| def test_length_penalty_zero_outside_buffer(self): | |
| env = _make_env() | |
| env._discount_factor = 1.0 | |
| env._max_tokens = 1000 | |
| env._buffer_tokens = 100 | |
| # Token count outside penalty zone (800 <= 900) | |
| shaped = env._apply_reward_shaping(1.0, 800) | |
| assert shaped == pytest.approx(1.0) | |
| def test_score_thresholding_via_rubric(self): | |
| env = _make_env() | |
| # Mimic custom_threshold=True: scores 1-5 → reward = 1/7 | |
| rubric = MathProofRubric(grader_model="test", api_key="x", custom_threshold=True) | |
| env._rubric = rubric | |
| for score in [1, 2, 3, 4, 5]: | |
| assert rubric.normalize_reward(score) == pytest.approx(1 / 7) | |
| # QEDMathEnvironment: reasoning stripping | |
| class TestReasoningStripping: | |
| def test_strip_think_tag(self): | |
| env = _make_env() | |
| env._reasoning_delimiters = ["</think>"] | |
| result = env._strip_reasoning("<think>step by step</think>Final answer.") | |
| assert result == "Final answer." | |
| def test_no_delimiter_configured_passthrough(self): | |
| env = _make_env() | |
| env._reasoning_delimiters = None | |
| result = env._strip_reasoning("Some text.") | |
| assert result == "Some text." | |
| def test_delimiter_absent_returns_empty(self): | |
| env = _make_env() | |
| env._reasoning_delimiters = ["</think>"] | |
| result = env._strip_reasoning("No think tag here.") | |
| assert result == "" | |
| async def test_grading_falls_back_to_full_text_when_strip_empty(self): | |
| """When stripping produces empty output, the full proof text is graded.""" | |
| env = _make_env_with_problem(BOOTSTRAP_PROBLEM) | |
| env.reset() | |
| env._reasoning_delimiters = ["</think>"] | |
| graded_input: list[str] = [] | |
| async def mock_grade(proof, *args, **kwargs): | |
| graded_input.append(proof) | |
| return GradingResult(score=7, feedback="ok", reward=1.0) | |
| with patch.object(env._rubric, "grade", new=mock_grade): | |
| await env._grade_submission("proof without think tag") | |
| # Stripping produced "" → falls back to original text | |
| assert graded_input[0] == "proof without think tag" | |
| # QEDMathEnvironment: submit_proof_payload (mocked grader) | |
| class TestSubmitProofPayload: | |
| async def test_no_active_problem(self): | |
| env = _make_env() | |
| payload = await env.submit_proof_payload("Some proof.") | |
| assert payload["score"] == 0 | |
| assert payload["done"] is True | |
| assert "error" in payload["metadata"] | |
| async def test_successful_submission(self): | |
| env = _make_env_with_problem(BOOTSTRAP_PROBLEM) | |
| env.reset() | |
| mock_result = GradingResult(score=7, feedback="Correct.", reward=1.0) | |
| with patch.object(env._rubric, "grade", new=AsyncMock(return_value=mock_result)): | |
| payload = await env.submit_proof_payload("Let a=2m, b=2n, a+b=2(m+n).") | |
| assert payload["score"] == 7 | |
| assert payload["reward"] == pytest.approx(1.0) | |
| assert payload["done"] is True | |
| assert payload["is_correct"] is True | |
| async def test_empty_proof_returns_zero(self): | |
| env = _make_env_with_problem(BOOTSTRAP_PROBLEM) | |
| env.reset() | |
| payload = await env.submit_proof_payload(" ") | |
| assert payload["score"] == 0 | |
| assert payload["reward"] == 0.0 | |
| async def test_attempt_number_increments(self): | |
| env = _make_env_with_problem(MULTI_STEP_PROBLEM) | |
| env.reset() | |
| mock_result = GradingResult(score=3, feedback="Partial.", reward=3 / 7) | |
| with patch.object(env._rubric, "grade", new=AsyncMock(return_value=mock_result)): | |
| p1 = await env.submit_proof_payload("attempt 1") | |
| p2 = await env.submit_proof_payload("attempt 2") | |
| assert p1["attempt_number"] == 1 | |
| assert p2["attempt_number"] == 2 | |
| async def test_verifier_metrics_in_metadata(self): | |
| env = _make_env_with_problem(BOOTSTRAP_PROBLEM) | |
| env.reset() | |
| metrics = { | |
| "verifier/rollouts/success": 1, | |
| "verifier/rollouts/failure": 0, | |
| "verifier/failures/timeout": 0, | |
| "verifier/failures/rate_limit": 0, | |
| "verifier/failures/no_input": 0, | |
| "verifier/failures/no_score_tag": 0, | |
| "verifier/failures/all_attempts_failed": 0, | |
| "verifier/failures/num_retries": 0, | |
| "verifier/runtime/latency_per_request": 0.01, | |
| "verifier/runtime/input_tokens": 50, | |
| "verifier/runtime/output_tokens": 20, | |
| } | |
| mock_result = GradingResult(score=6, feedback="Good.", reward=6 / 7, metrics=metrics) | |
| with patch.object(env._rubric, "grade", new=AsyncMock(return_value=mock_result)): | |
| payload = await env.submit_proof_payload("Good proof here.") | |
| vm = payload["metadata"]["verifier_metrics"] | |
| assert vm["verifier/rollouts/success"] == 1 | |
| assert "verifier/runtime/latency_per_request" in vm | |
| assert "verifier/runtime/input_tokens" in vm | |
| assert "verifier/runtime/output_tokens" in vm | |
| # Multi-step problems | |
| class TestMultiStepProblems: | |
| async def test_first_attempt_not_done(self): | |
| env = _make_env_with_problem(MULTI_STEP_PROBLEM) | |
| env.reset() | |
| # Score below success threshold → not done yet | |
| mock_result = GradingResult(score=3, feedback="Keep trying.", reward=3 / 7) | |
| with patch.object(env._rubric, "grade", new=AsyncMock(return_value=mock_result)): | |
| payload = await env.submit_proof_payload("First attempt.") | |
| assert payload["done"] is False | |
| assert payload["attempts_remaining"] > 0 | |
| async def test_exhausting_attempts_sets_done(self): | |
| env = _make_env_with_problem(MULTI_STEP_PROBLEM) | |
| env.reset() | |
| mock_result = GradingResult(score=1, feedback="Not enough.", reward=1 / 7) | |
| # 3 attempts configured | |
| with patch.object(env._rubric, "grade", new=AsyncMock(return_value=mock_result)): | |
| await env.submit_proof_payload("attempt 1") | |
| await env.submit_proof_payload("attempt 2") | |
| payload = await env.submit_proof_payload("attempt 3") | |
| assert payload["done"] is True | |
| assert payload["attempts_remaining"] == 0 | |
| async def test_correct_submission_ends_episode(self): | |
| env = _make_env_with_problem(MULTI_STEP_PROBLEM) | |
| env.reset() | |
| # Score >= 6 (success_score_threshold default) → done immediately | |
| mock_result = GradingResult(score=7, feedback="Correct!", reward=1.0) | |
| with patch.object(env._rubric, "grade", new=AsyncMock(return_value=mock_result)): | |
| payload = await env.submit_proof_payload("Full proof.") | |
| assert payload["done"] is True | |
| assert payload["is_correct"] is True | |
| # Original problem field (QED-Nano RC stream compatibility) | |
| class TestOriginalProblemField: | |
| async def test_original_problem_used_for_grading(self): | |
| """When original_problem is set it should be passed to the rubric.""" | |
| raw = dict(BOOTSTRAP_PROBLEM) | |
| raw["original_problem"] = "The ORIGINAL phrasing of the problem." | |
| env = _make_env_with_problem(raw) | |
| env.reset() | |
| graded_problems: list[str] = [] | |
| async def mock_grade(proof, problem, *args, **kwargs): | |
| graded_problems.append(problem) | |
| return GradingResult(score=5, feedback="ok", reward=5 / 7) | |
| with patch.object(env._rubric, "grade", new=mock_grade): | |
| await env.submit_proof_payload("Some proof.") | |
| assert graded_problems[0] == "The ORIGINAL phrasing of the problem." | |
| def test_normalize_problem_preserves_original_problem(self): | |
| raw = dict(BOOTSTRAP_PROBLEM) | |
| raw["original_problem"] = "Original phrasing." | |
| normalized = _normalize_problem(raw, 0, "test") | |
| assert normalized["original_problem"] == "Original phrasing." | |
| def test_normalize_problem_original_problem_none_when_absent(self): | |
| normalized = _normalize_problem(BOOTSTRAP_PROBLEM, 0, "test") | |
| assert normalized["original_problem"] is None | |
| # Models | |
| class TestModels: | |
| def test_problem_observation_defaults(self): | |
| obs = ProblemObservation() | |
| assert obs.problem == "" | |
| assert obs.problem_type == "proof" | |
| assert obs.max_attempts == 1 | |
| def test_proof_submission_observation_defaults(self): | |
| obs = ProofSubmissionObservation() | |
| assert obs.score == 0 | |
| assert obs.reward == 0.0 | |
| assert obs.done is True | |
| assert obs.is_correct is False | |
| def test_problem_observation_round_trips(self): | |
| obs = ProblemObservation( | |
| problem="P", | |
| reference_solution="R", | |
| problem_id="x", | |
| problem_type="answer", | |
| max_attempts=2, | |
| ) | |
| data = obs.model_dump() | |
| restored = ProblemObservation(**data) | |
| assert restored.problem == "P" | |
| assert restored.problem_type == "answer" | |
| assert restored.max_attempts == 2 | |
| def test_proof_submission_observation_round_trips(self): | |
| obs = ProofSubmissionObservation( | |
| proof="proof text", | |
| score=5, | |
| feedback="nice", | |
| reward=5 / 7, | |
| done=False, | |
| attempts_remaining=2, | |
| is_correct=False, | |
| ) | |
| data = obs.model_dump() | |
| restored = ProofSubmissionObservation(**data) | |
| assert restored.score == 5 | |
| assert restored.attempts_remaining == 2 | |
| # MCP tool registration | |
| class TestMCPToolRegistration: | |
| def test_tools_registered_on_init(self): | |
| env = _make_env() | |
| tools = get_server_tools(env.mcp_server) | |
| assert "get_problem" in tools | |
| assert "submit_proof" in tools | |
| assert "get_grading_guidelines" in tools | |
| def test_get_problem_tool_callable_after_reset(self): | |
| env = _make_env() | |
| env.reset() | |
| tools = get_server_tools(env.mcp_server) | |
| assert "get_problem" in tools | |
| def test_get_grading_guidelines_tool_callable_after_reset(self): | |
| env = _make_env() | |
| env.reset() | |
| payload = env.get_grading_guidelines_payload() | |
| assert "grading_guidelines" in payload | |
| # MathVerifierService | |
| class TestVerifyAnswerWorker: | |
| """Tests for the _verify_answer_worker function (process worker entry point).""" | |
| def test_correct_answer(self): | |
| request = VerifyRequest( | |
| request_id="req-1", | |
| prediction=r"\boxed{4}", | |
| gold="4", | |
| strict=True, | |
| timeout_seconds=1, | |
| ) | |
| response = _verify_answer_worker(request) | |
| assert response.status == "correct" | |
| assert response.request_id == "req-1" | |
| assert response.elapsed_ms >= 0 | |
| def test_wrong_answer(self): | |
| request = VerifyRequest( | |
| request_id="req-2", | |
| prediction=r"\boxed{5}", | |
| gold="4", | |
| strict=True, | |
| timeout_seconds=1, | |
| ) | |
| response = _verify_answer_worker(request) | |
| assert response.status == "wrong" | |
| def test_no_boxed_answer(self): | |
| request = VerifyRequest( | |
| request_id="req-3", | |
| prediction="The answer is 4", | |
| gold="4", | |
| strict=True, | |
| timeout_seconds=1, | |
| ) | |
| response = _verify_answer_worker(request) | |
| assert response.status == "no_answer" | |
| def test_empty_boxed(self): | |
| request = VerifyRequest( | |
| request_id="req-4", | |
| prediction=r"\boxed{}", | |
| gold="4", | |
| strict=True, | |
| timeout_seconds=1, | |
| ) | |
| response = _verify_answer_worker(request) | |
| assert response.status == "unparsable" | |
| def test_unparsable_prediction(self): | |
| request = VerifyRequest( | |
| request_id="req-5", | |
| prediction=r"\boxed{@#$%^&*()}", | |
| gold="4", | |
| strict=True, | |
| timeout_seconds=1, | |
| ) | |
| response = _verify_answer_worker(request) | |
| # Should be unparsable or internal_error depending on what math_verify does | |
| assert response.status in {"unparsable", "internal_error"} | |
| def test_unparsable_gold(self): | |
| request = VerifyRequest( | |
| request_id="req-6", | |
| prediction=r"\boxed{4}", | |
| gold="@#$%^&*()", | |
| strict=True, | |
| timeout_seconds=1, | |
| ) | |
| response = _verify_answer_worker(request) | |
| assert response.status == "unparsable" | |
| def test_prediction_too_long(self): | |
| request = VerifyRequest( | |
| request_id="req-7", | |
| prediction=r"\boxed{" + "x" * 2000 + "}", | |
| gold="4", | |
| strict=True, | |
| timeout_seconds=1, | |
| max_prediction_length=1000, | |
| ) | |
| response = _verify_answer_worker(request) | |
| assert response.status == "unparsable" | |
| def test_nested_braces(self): | |
| """Test that nested braces in \\boxed{...} are handled correctly.""" | |
| request = VerifyRequest( | |
| request_id="req-8", | |
| prediction=r"\boxed{\frac{1}{2}}", | |
| gold=r"\frac{1}{2}", | |
| strict=True, | |
| timeout_seconds=1, | |
| ) | |
| response = _verify_answer_worker(request) | |
| # Should parse and verify correctly | |
| assert response.status in {"correct", "wrong"} | |
| class TestMathVerifierService: | |
| """Tests for the MathVerifierService async interface.""" | |
| async def test_service_initialization(self): | |
| service = MathVerifierService( | |
| max_workers=2, | |
| queue_size=50, | |
| request_timeout_seconds=3.0, | |
| ) | |
| assert service.max_workers == 2 | |
| assert service.queue_size == 50 | |
| assert service.request_timeout_seconds == 3.0 | |
| async def test_start_and_stop(self): | |
| service = MathVerifierService(max_workers=1) | |
| # Start should be idempotent | |
| await service.start() | |
| assert service._executor is not None | |
| await service.start() # Should not fail if already started | |
| await service.stop() | |
| assert service._executor is None | |
| async def test_verify_answer_without_start(self): | |
| """verify_answer should auto-start the process pool when needed.""" | |
| service = MathVerifierService(max_workers=1) | |
| response = await service.verify_answer( | |
| prediction=r"\boxed{4}", | |
| gold="4", | |
| strict=True, | |
| ) | |
| assert response.status == "correct" | |
| assert service._executor is not None | |
| await service.stop() | |
| async def test_verify_answer_correct(self): | |
| service = MathVerifierService(max_workers=1) | |
| await service.start() | |
| try: | |
| response = await service.verify_answer( | |
| prediction=r"\boxed{4}", | |
| gold="4", | |
| strict=True, | |
| ) | |
| assert response.status == "correct" | |
| assert response.request_id.startswith("req-") | |
| assert response.elapsed_ms >= 0 | |
| finally: | |
| await service.stop() | |
| async def test_verify_answer_wrong(self): | |
| service = MathVerifierService(max_workers=1) | |
| await service.start() | |
| try: | |
| response = await service.verify_answer( | |
| prediction=r"\boxed{5}", | |
| gold="4", | |
| strict=True, | |
| ) | |
| assert response.status == "wrong" | |
| finally: | |
| await service.stop() | |
| async def test_verify_answer_no_answer(self): | |
| service = MathVerifierService(max_workers=1) | |
| await service.start() | |
| try: | |
| response = await service.verify_answer( | |
| prediction="The answer is 4", | |
| gold="4", | |
| strict=True, | |
| ) | |
| assert response.status == "no_answer" | |
| finally: | |
| await service.stop() | |
| async def test_verify_answer_timeout(self): | |
| """Test that client-side timeout is triggered.""" | |
| service = MathVerifierService( | |
| max_workers=1, | |
| request_timeout_seconds=0.001, # Very short timeout to force timeout | |
| ) | |
| await service.start() | |
| try: | |
| response = await service.verify_answer( | |
| prediction=r"\boxed{4}", | |
| gold="4", | |
| strict=True, | |
| ) | |
| # Should either timeout or succeed quickly; mostly just testing the path | |
| assert response.request_id.startswith("req-") | |
| finally: | |
| await service.stop() | |
| async def test_concurrent_requests(self): | |
| """Test multiple concurrent verification requests.""" | |
| import asyncio | |
| service = MathVerifierService(max_workers=2) | |
| await service.start() | |
| try: | |
| # Fire multiple requests concurrently | |
| tasks = [ | |
| service.verify_answer(r"\boxed{1}", "1"), | |
| service.verify_answer(r"\boxed{2}", "2"), | |
| service.verify_answer(r"\boxed{3}", "3"), | |
| service.verify_answer(r"\boxed{4}", "4"), | |
| ] | |
| responses = await asyncio.gather(*tasks) | |
| assert len(responses) == 4 | |
| assert all(r.status == "correct" for r in responses) | |
| assert len(set(r.request_id for r in responses)) == 4 # All unique IDs | |
| finally: | |
| await service.stop() | |
| async def test_default_strict_mode(self): | |
| """Test that default strict mode is applied.""" | |
| service = MathVerifierService(max_workers=1, strict=True) | |
| assert service.strict is True | |
| await service.start() | |
| try: | |
| response = await service.verify_answer( | |
| prediction=r"\boxed{4.0}", | |
| gold="4", | |
| strict=True, # Explicitly strict | |
| ) | |
| # Should handle the request | |
| assert response.request_id.startswith("req-") | |
| finally: | |
| await service.stop() | |
| async def test_queue_backpressure_rejects_when_saturated(self): | |
| import asyncio | |
| service = MathVerifierService(max_workers=1, queue_size=1) | |
| await service.start() | |
| gate_started = asyncio.Event() | |
| gate_release = asyncio.Event() | |
| async def slow_once(request): | |
| gate_started.set() | |
| await gate_release.wait() | |
| return _verify_answer_worker(request) | |
| try: | |
| with patch.object(service, "_run_request_once", side_effect=slow_once): | |
| task1 = asyncio.create_task(service.verify_answer(r"\boxed{1}", "1")) | |
| await gate_started.wait() | |
| response2 = await service.verify_answer(r"\boxed{2}", "2") | |
| assert response2.status == "internal_error" | |
| assert response2.error_type == "QueueFull" | |
| gate_release.set() | |
| response1 = await task1 | |
| assert response1.status in {"correct", "wrong"} | |
| finally: | |
| await service.stop() | |
| async def test_retry_policy_retries_transient_timeout(self): | |
| service = MathVerifierService(max_workers=1, max_retries=1) | |
| await service.start() | |
| call_count = 0 | |
| async def flaky_once(request): | |
| nonlocal call_count | |
| call_count += 1 | |
| if call_count == 1: | |
| return VerifyResponse( | |
| request_id=request.request_id, | |
| status="timeout", | |
| elapsed_ms=1.0, | |
| retry_count=0, | |
| worker_id=None, | |
| worker_restarted=False, | |
| error_type="ClientTimeout", | |
| error_message="simulated timeout", | |
| ) | |
| return _verify_answer_worker(request) | |
| try: | |
| with patch.object(service, "_run_request_once", side_effect=flaky_once): | |
| response = await service.verify_answer(r"\boxed{4}", "4") | |
| assert response.status == "correct" | |
| assert response.retry_count == 1 | |
| assert call_count == 2 | |
| finally: | |
| await service.stop() | |
| async def test_worker_crash_restarts_pool_and_redispatches(self): | |
| service = MathVerifierService(max_workers=1, max_retries=1) | |
| await service.start() | |
| call_count = 0 | |
| restart_calls = {"count": 0} | |
| original_restart = service._restart_pool | |
| async def crashing_once(request): | |
| nonlocal call_count | |
| call_count += 1 | |
| if call_count == 1: | |
| return VerifyResponse( | |
| request_id=request.request_id, | |
| status="internal_error", | |
| elapsed_ms=1.0, | |
| error_type="BrokenProcessPool", | |
| error_message="broken process pool simulated", | |
| ) | |
| return _verify_answer_worker(request) | |
| async def restart_wrapper(): | |
| restart_calls["count"] += 1 | |
| await original_restart() | |
| try: | |
| with patch.object(service, "_run_request_once", side_effect=crashing_once): | |
| with patch.object(service, "_restart_pool", side_effect=restart_wrapper): | |
| response = await service.verify_answer(r"\boxed{4}", "4") | |
| assert response.status == "correct" | |
| assert response.retry_count == 1 | |
| assert response.worker_restarted is True | |
| assert restart_calls["count"] == 1 | |
| health = await service.health_probe() | |
| restart_count = health["restart_count"] | |
| assert isinstance(restart_count, (int, float)) | |
| assert restart_count >= 1 | |
| finally: | |
| await service.stop() | |
| async def test_no_retry_for_unparsable_outputs(self): | |
| service = MathVerifierService(max_workers=1, max_retries=3) | |
| await service.start() | |
| call_count = 0 | |
| async def unparsable_once(request): | |
| nonlocal call_count | |
| call_count += 1 | |
| return VerifyResponse( | |
| request_id=request.request_id, | |
| status="unparsable", | |
| elapsed_ms=1.0, | |
| error_type=None, | |
| error_message=None, | |
| ) | |
| try: | |
| with patch.object(service, "_run_request_once", side_effect=unparsable_once): | |
| response = await service.verify_answer("not boxed", "4") | |
| assert response.status == "unparsable" | |
| assert response.retry_count == 0 | |
| assert call_count == 1 | |
| finally: | |
| await service.stop() | |
| async def test_health_probe_reports_runtime_state(self): | |
| service = MathVerifierService(max_workers=1, queue_size=3) | |
| before_start = await service.health_probe() | |
| assert before_start["status"] == "stopped" | |
| assert before_start["executor_running"] is False | |
| await service.start() | |
| try: | |
| after_start = await service.health_probe() | |
| assert after_start["status"] == "healthy" | |
| assert after_start["executor_running"] is True | |
| assert after_start["queue_size"] == 3 | |
| assert after_start["max_workers"] == 1 | |
| finally: | |
| await service.stop() | |
| async def test_metrics_snapshot_tracks_rollout_counters(self): | |
| service = MathVerifierService(max_workers=1, queue_size=4) | |
| await service.start() | |
| try: | |
| await service.verify_answer(r"\boxed{4}", "4") | |
| await service.verify_answer(r"\boxed{5}", "4") | |
| await service.verify_answer("no boxed answer", "4") | |
| metrics = await service.metrics_snapshot() | |
| assert metrics["verifier/requests/count"] >= 3 | |
| assert metrics["verifier/requests/latency_ms"] >= 0 | |
| assert metrics["verifier/requests/timeout_count"] >= 0 | |
| assert metrics["verifier/requests/error_count"] >= 0 | |
| assert metrics["verifier/workers/heartbeat_lag_ms"] >= 0 | |
| finally: | |
| await service.stop() | |
| class TestGoldCache: | |
| async def test_answer_mode_grading_uses_cached_gold(self): | |
| env = _make_env_with_problem(ANSWER_PROBLEM) | |
| env.reset() | |
| assert env._current_problem is not None | |
| problem_id = env._current_problem["problem_id"] | |
| env._gold_answer_cache[env._gold_cache_key(problem_id)] = r"\boxed{42}" | |
| captured_expected: dict[str, str] = {} | |
| async def fake_grade_answer_submission( | |
| submission: str, | |
| expected_answer: str, | |
| problem_id: str = "", | |
| ): | |
| captured_expected["value"] = expected_answer | |
| return GradingResult(score=7, feedback="ok", reward=1.0) | |
| with patch.object( | |
| env, | |
| "_grade_answer_submission", | |
| new=fake_grade_answer_submission, | |
| ): | |
| await env._grade_submission(r"\boxed{42}") | |
| assert captured_expected["value"] == r"\boxed{42}" | |
| await env._verifier_service.stop() | |
| def test_gold_cache_invalidation_on_config_change(self): | |
| env = _make_env_with_problem(ANSWER_PROBLEM) | |
| old_signature = env._gold_cache_signature | |
| env._config.verifier_numeric_precision += 1 | |
| env._refresh_gold_cache_if_needed() | |
| assert env._gold_cache_signature != old_signature | |
| env.reset() | |
| assert env._current_problem is not None | |
| problem_id = env._current_problem["problem_id"] | |
| assert env._gold_cache_key(problem_id) in env._gold_answer_cache | |
| async def test_cache_hit_rate_surfaces_in_verifier_metrics(self): | |
| env = _make_env_with_problem(ANSWER_PROBLEM) | |
| env.reset() | |
| await env._grade_submission(r"\boxed{4}") | |
| payload = await env.submit_proof_payload(r"\boxed{4}") | |
| verifier_metrics = payload["metadata"]["verifier_metrics"] | |
| assert "verifier/cache/hit_rate" in verifier_metrics | |
| assert verifier_metrics["verifier/cache/hit_rate"] >= 0.0 | |
| assert verifier_metrics["verifier/cache/hit_rate"] <= 1.0 | |
| await env._verifier_service.stop() | |
| # Integration tests (require running server) | |
| class TestQEDMathServerIntegration: | |
| """Full HTTP/WebSocket stack smoke test. | |
| Requires a running QED Math server at QED_MATH_URL (default: http://localhost:8000). | |
| Start it with: | |
| docker run -p 8000:8000 \\ | |
| -e JUDGE_API_BASE_URL=... \\ | |
| -e JUDGE_API_KEY=... \\ | |
| qed-math-env:latest | |
| """ | |
| BASE_URL = "http://localhost:8000" | |
| def env(self): | |
| from client import QEDMathEnv | |
| with QEDMathEnv(base_url=self.BASE_URL).sync() as e: | |
| yield e | |
| def test_reset_returns_problem(self, env): | |
| obs = env.reset() | |
| assert obs is not None | |
| def test_list_tools_returns_three_tools(self, env): | |
| tools = env.list_tools() | |
| names = {t.name for t in tools} | |
| assert {"get_problem", "submit_proof", "get_grading_guidelines"}.issubset(names) | |
| def test_get_problem_tool(self, env): | |
| env.reset() | |
| result = env.call_tool("get_problem") | |
| assert result is not None | |
| def test_get_grading_guidelines_tool(self, env): | |
| env.reset() | |
| result = env.call_tool("get_grading_guidelines") | |
| assert result is not None | |
| def test_submit_proof_tool_returns_observation(self, env): | |
| env.reset() | |
| result = env.call_tool("submit_proof", proof="Let a=2m and b=2n, so a+b=2(m+n).") | |
| assert result is not None | |