sourasishbasu committed on
Commit
9d852e5
·
1 Parent(s): 9542047

QED Math Environment

Browse files
Files changed (10) hide show
  1. .dockerignore +13 -0
  2. README.md +47 -0
  3. __init__.py +29 -0
  4. client.py +146 -0
  5. models.py +87 -0
  6. openenv.yaml +83 -0
  7. pyproject.toml +35 -0
  8. server/Dockerfile +68 -0
  9. server/__init__.py +12 -0
  10. uv.lock +17 -0
.dockerignore ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .venv
2
+ .git
3
+ .gitignore
4
+ .env
5
+ __pycache__/
6
+ *.pyc
7
+ *.pyo
8
+ *.pyd
9
+ *.pyw
10
+ *.pyz
11
+ *.pywz
12
+ *.pyzw
13
+ *.pyzwz
README.md CHANGED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # QED Math Environment
2
+
3
+ A mathematical proof generation and evaluation environment for OpenEnv.
4
+
5
+ ## Features
6
+
7
+ - **MCP Tools**: Agent interacts via MCP (Model Context Protocol)
8
+ - `get_problem`: Get current problem statement and metadata
9
+ - `submit_proof`: Submit proof for LLM-judge rubric grading (0-7 scale)
10
+ - `get_grading_guidelines`: Get grading rubric for current problem
11
+
12
+ - **LLM-Judge Rubric**: Proofs graded on 0-7 scale with normalized rewards
13
+ - **Answer-mode verification**: Uses `math_verify` for fast \\boxed{} checking
14
+ - **Reward shaping**: Discount factor, length penalty, optional score thresholding
15
+ - **Flexible datasets**: Local JSONL/JSON, Hugging Face Hub, or built-in bootstrap
16
+
17
+ ## Quick Start
18
+
19
+ ```bash
20
+ # Install
21
+ uv sync --all-extras
22
+
23
+ # Run server
24
+ uv run server
25
+
26
+ # Or via Docker
27
+ docker build -t qed-math-env:latest -f server/Dockerfile .
28
+ docker run -p 8000:8000 -e OPENAI_API_KEY=$OPENAI_API_KEY qed-math-env:latest
29
+ ```
30
+
31
+ ## Usage
32
+
33
+ ```python
34
+ from qed_math_env import QEDMathEnv
35
+
36
+
37
+ with QEDMathEnv(base_url="http://localhost:8000").sync() as env:
38
+ env.reset()
39
+ problem = env.call_tool("get_problem")
40
+ result = env.call_tool("submit_proof", proof="Let a=2m...")
41
+ ```
42
+
43
+ ## Testing
44
+
45
+ ```bash
46
+ PYTHONPATH=src:envs uv run pytest tests/envs/test_qed_math_environment.py -v
47
+ ```
__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """QED Math Environment."""
8
+
9
+ from .client import QEDMathEnv
10
+ from .models import (
11
+ GetGradingGuidelines,
12
+ GetProblem,
13
+ ProblemObservation,
14
+ ProofSubmissionObservation,
15
+ QEDMathAction,
16
+ QEDMathObservation,
17
+ SubmitProof,
18
+ )
19
+
20
+ __all__ = [
21
+ "QEDMathAction",
22
+ "QEDMathObservation",
23
+ "QEDMathEnv",
24
+ "SubmitProof",
25
+ "GetProblem",
26
+ "GetGradingGuidelines",
27
+ "ProblemObservation",
28
+ "ProofSubmissionObservation",
29
+ ]
client.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """QED Math Environment Client.
8
+
9
+ Provides tool-calling style interactions with the QED Math environment
10
+ via MCP (Model Context Protocol).
11
+
12
+
13
+ Example:
14
+ >>> with QEDMathEnv(base_url="http://localhost:8000") as env:
15
+ ... env.reset()
16
+ ... tools = env.list_tools()
17
+ ... print([t.name for t in tools])
18
+ ... result = env.call_tool("get_problem")
19
+ ... result = env.call_tool("submit_proof", proof="By induction...")
20
+ """
21
+
22
+ from typing import Any, Mapping, Optional
23
+
24
+ from openenv.core.client_types import StepResult
25
+ from openenv.core.env_server.types import Observation, State
26
+ from openenv.core.mcp_client import MCPToolClient
27
+
28
+ from .models import ProblemObservation, ProofSubmissionObservation
29
+
30
+
31
+ class QEDMathEnv(MCPToolClient):
32
+ """
33
+ Client for the QED Math Environment.
34
+
35
+ Inherits MCP tool-calling interface from MCPToolClient:
36
+ - ``list_tools()``: Discover available MCP tools
37
+ - ``call_tool(name, **kwargs)``: Call a tool by name
38
+ - ``reset(**kwargs)``: Reset the environment
39
+
40
+ Example:
41
+ >>> with QEDMathEnv(base_url="http://localhost:8000") as env:
42
+ ... env.reset()
43
+ ... result = env.call_tool("get_problem")
44
+ ... result = env.call_tool("submit_proof", proof="By induction...")
45
+ """
46
+
47
+ @staticmethod
48
+ def _as_problem_observation(value: Any) -> ProblemObservation:
49
+ """Normalize tool/reset outputs into a ProblemObservation instance."""
50
+ if isinstance(value, ProblemObservation):
51
+ return value
52
+ if isinstance(value, Mapping):
53
+ return ProblemObservation(**dict(value))
54
+ if hasattr(value, "model_dump"):
55
+ return ProblemObservation(**value.model_dump())
56
+ raise TypeError(f"Unsupported problem observation payload type: {type(value).__name__}")
57
+
58
+
59
+ @staticmethod
60
+ def _as_proof_submission_observation(value: Any) -> ProofSubmissionObservation:
61
+ """Normalize tool outputs into a ProofSubmissionObservation instance."""
62
+ if isinstance(value, ProofSubmissionObservation):
63
+ return value
64
+ if isinstance(value, Mapping):
65
+ return ProofSubmissionObservation(**dict(value))
66
+ if hasattr(value, "model_dump"):
67
+ return ProofSubmissionObservation(**value.model_dump())
68
+ raise TypeError(f"Unsupported proof submission payload type: {type(value).__name__}")
69
+
70
+
71
+ async def reset(
72
+ self, problem_id: Optional[str] = None, **kwargs: Any
73
+ ) -> StepResult[Observation]:
74
+ """
75
+ Reset the environment, optionally selecting a specific problem.
76
+
77
+ Args:
78
+ problem_id: Optional problem identifier to load a specific problem.
79
+ If None, a problem is chosen randomly from the dataset.
80
+ **kwargs: Additional reset parameters (e.g., seed).
81
+
82
+ Returns:
83
+ StepResult with a normalized ProblemObservation in `observation`.
84
+ """
85
+ if problem_id is not None:
86
+ kwargs["problem_id"] = problem_id
87
+ result = await super().reset(**kwargs)
88
+ observation = result.observation if isinstance(result, StepResult) else result
89
+ normalized_observation = self._as_problem_observation(observation)
90
+ return StepResult(
91
+ observation=normalized_observation,
92
+ reward=result.reward,
93
+ done=result.done,
94
+ )
95
+
96
+
97
+ async def submit_proof(self, proof: str) -> ProofSubmissionObservation:
98
+ """
99
+ Submit a proof attempt for the current problem.
100
+
101
+ Args:
102
+ proof: The proof text to submit for grading.
103
+
104
+ Returns:
105
+ ProofSubmissionObservation with score (0-7), feedback, and reward.
106
+ """
107
+ result = await self.call_tool("submit_proof", proof=proof)
108
+ return self._as_proof_submission_observation(result)
109
+
110
+ async def get_current_problem(self) -> ProblemObservation:
111
+ """
112
+ Retrieve the current problem statement without resetting.
113
+
114
+ Returns:
115
+ ProblemObservation for the active problem.
116
+ """
117
+ result = await self.call_tool("get_problem")
118
+ return self._as_problem_observation(result)
119
+
120
+
121
+ async def get_problem(self) -> ProblemObservation:
122
+ """Compatibility alias for get_current_problem()."""
123
+ return await self.get_current_problem()
124
+
125
+ async def get_grading_feedback(self) -> dict[str, Any]:
126
+ """
127
+ Retrieve the grading guidelines/rubric for the current problem.
128
+
129
+ Returns:
130
+ Tool payload containing grading_guidelines and problem metadata.
131
+ """
132
+ result = await self.call_tool("get_grading_guidelines")
133
+ if isinstance(result, Mapping):
134
+ return dict(result)
135
+ if hasattr(result, "model_dump"):
136
+ return result.model_dump()
137
+ raise TypeError(f"Unsupported grading feedback payload type: {type(result).__name__}")
138
+
139
+ async def get_state(self) -> State:
140
+ """Return current environment state (episode_id, step_count)."""
141
+ return await super().state()
142
+
143
+ def get_state_sync(self) -> State:
144
+ """Synchronous helper for code paths that do not use async/await."""
145
+ with self.sync() as client:
146
+ return client.state()
models.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Data models for the QED Math Environment.
9
+
10
+ Defines action and observation types for mathematical proof submission
11
+ and grading.
12
+ """
13
+
14
+ from openenv.core.env_server.types import Action, Observation
15
+ from pydantic import Field
16
+
17
+
18
+ RewardValue = bool | int | float | None
19
+
20
+
21
+ class QEDMathAction(Action):
22
+ """Base action for the QED Math environment."""
23
+
24
+
25
+ class SubmitProof(QEDMathAction):
26
+ """Submit a proof attempt for the current problem."""
27
+
28
+ proof: str = Field(..., description="The proof text submitted by agent")
29
+ attempt_number: int = Field(default=1, description="Attempt counter")
30
+
31
+
32
+ class GetProblem(QEDMathAction):
33
+ """Request the current problem statement."""
34
+
35
+
36
+ class GetGradingGuidelines(QEDMathAction):
37
+ """Request the grading guidelines/rubric for current problem."""
38
+
39
+
40
+ class QEDMathObservation(Observation):
41
+ """Base observation for the QED Math environment."""
42
+
43
+
44
+ class ProblemObservation(QEDMathObservation):
45
+ """Observation containing the problem statement."""
46
+
47
+ problem: str = Field(default="", description="The mathematical problem")
48
+ reference_solution: str = Field(default="", description="Ground truth solution")
49
+ grading_guidelines: str = Field(
50
+ default="", description="Rubric for grading (0-7 scale)"
51
+ )
52
+ problem_id: str = Field(default="", description="Unique problem identifier")
53
+ dataset_source: str = Field(default="", description="Source dataset name")
54
+ problem_type: str = Field(
55
+ default="proof",
56
+ description="Problem type: proof, answer, or multi_step",
57
+ )
58
+ max_attempts: int = Field(
59
+ default=1,
60
+ description="Maximum number of allowed submission attempts",
61
+ )
62
+
63
+
64
+ class ProofSubmissionObservation(QEDMathObservation):
65
+ """Observation returned after submitting a proof."""
66
+
67
+ proof: str = Field(default="", description="The submitted proof")
68
+ score: int = Field(default=0, description="Grade from rubric (0-7)")
69
+ feedback: str = Field(default="", description="Grader feedback")
70
+ reward: RewardValue = Field(
71
+ default=0.0,
72
+ description="Normalized reward (score/7)",
73
+ )
74
+ done: bool = Field(default=True, description="Episode ends after proof submission")
75
+ problem_type: str = Field(
76
+ default="proof",
77
+ description="Problem type used to evaluate this submission",
78
+ )
79
+ attempt_number: int = Field(default=1, description="1-based submission attempt index")
80
+ attempts_remaining: int = Field(
81
+ default=0,
82
+ description="Remaining submission attempts in the current episode",
83
+ )
84
+ is_correct: bool = Field(
85
+ default=False,
86
+ description="Whether the submission is considered fully correct",
87
+ )
openenv.yaml ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: qed_math_env
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 8000
7
+
8
+ description: >
9
+ Mathematical proof generation and evaluation environment.
10
+ Agents receive math problems, submit proofs, and receive LLM-based
11
+ rubric grading (0-7 scale) with normalized rewards.
12
+ Answer-mode uses a process-based verifier service for concurrent rollout
13
+ safety (timeouts, retries, backpressure, worker restart).
14
+ version: "0.1.0"
15
+
16
+ # Environment configuration defaults (overridable via QEDMathConfig)
17
+ environment:
18
+ grader_model: gemini-3-pro
19
+ prompt_name: v2
20
+ custom_reward_threshold: false
21
+ discount_factor: 1.0
22
+ buffer_tokens: 0
23
+ max_tokens: 0
24
+ verifier_workers: 4
25
+ verifier_queue_size: 128
26
+ verifier_request_timeout_seconds: 5.0
27
+ verifier_max_retries: 1
28
+ verifier_strict: true
29
+ verifier_numeric_precision: 5
30
+ verifier_float_rounding: 10
31
+
32
+ # Rubric definition
33
+ rubric:
34
+ type: llm_judge
35
+ name: MathProofRubric
36
+ scale: "0-7"
37
+ normalization: "score / 7.0"
38
+ thresholding: "optional (collapses 1-5 -> 1)"
39
+
40
+ # MCP tools exposed by this environment
41
+ tools:
42
+ - name: get_problem
43
+ description: Return current problem statement and metadata
44
+ - name: submit_proof
45
+ description: Submit a proof for LLM-based rubric grading
46
+ parameters:
47
+ - proof (str, required)
48
+ - name: get_grading_guidelines
49
+ description: Return the rubric/marking scheme for the current problem
50
+
51
+ # Verifier metrics emitted per grading call
52
+ metrics:
53
+ verifier:
54
+ - verifier/rollouts/success
55
+ - verifier/rollouts/failure
56
+ - verifier/failures/timeout
57
+ - verifier/failures/rate_limit
58
+ - verifier/failures/no_input
59
+ - verifier/failures/no_score_tag
60
+ - verifier/failures/all_attempts_failed
61
+ - verifier/failures/num_retries
62
+ - verifier/runtime/latency_per_request
63
+ - verifier/requests/count
64
+ - verifier/requests/latency_ms
65
+ - verifier/requests/timeout_count
66
+ - verifier/requests/error_count
67
+ - verifier/queue/depth
68
+ - verifier/cache/hit_rate
69
+ - verifier/workers/restart_count
70
+ - verifier/workers/worker_restarted
71
+ - verifier/workers/heartbeat_lag_ms
72
+ - verifier/runtime/input_tokens
73
+ - verifier/runtime/output_tokens
74
+ reward:
75
+ - reward/base
76
+ - reward/shaped
77
+ - reward/score_raw
78
+ - reward/overlong_penalty
79
+ episode:
80
+ - episode/attempt_number
81
+ - episode/is_correct
82
+ - episode/problem_type
83
+ - episode/dataset_source
pyproject.toml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ [build-system]
8
+ requires = ["setuptools>=45", "wheel"]
9
+ build-backend = "setuptools.build_meta"
10
+
11
+ [project]
12
+ name = "openenv-qed_math_env"
13
+ version = "0.1.0"
14
+ description = "QED Math environment for OpenEnv"
15
+ requires-python = ">=3.10"
16
+ dependencies = [
17
+ "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
18
+ "datasets>=4.7.0",
19
+ "math-verify[antlr4_13_2]>=0.9.0",
20
+ "trackio>=0.19.0"
21
+ ]
22
+
23
+ [project.optional-dependencies]
24
+ dev = [
25
+ "pytest>=8.0.0",
26
+ "pytest-cov>=4.0.0",
27
+ ]
28
+
29
+ [project.scripts]
30
+ server = "qed_math_env.server.app:main"
31
+
32
+ [tool.setuptools]
33
+ include-package-data = true
34
+ packages = ["qed_math_env", "qed_math_env.server"]
35
+ package-dir = { "qed_math_env" = ".", "qed_math_env.server" = "server" }
server/Dockerfile ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Multi-stage build using openenv-base
2
+ # This Dockerfile is flexible and works for both:
3
+ # - In-repo environments (with local OpenEnv sources)
4
+ # - Standalone environments (with openenv from PyPI/Git)
5
+ # The build script (openenv build) handles context detection and sets appropriate build args.
6
+ #
7
+ # Build from the env directory:
8
+ # docker build -t qed-math-env:latest -f server/Dockerfile .
9
+ #
10
+ # Or from the repo root:
11
+ # docker build -t qed-math-env:latest -f envs/qed_math_env/server/Dockerfile envs/qed_math_env
12
+ #
13
+ # Run:
14
+ # docker run -p 8000:8000 -e OPENAI_API_KEY=$OPENAI_API_KEY qed-math-env:latest
15
+
16
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
17
+ FROM ${BASE_IMAGE} AS builder
18
+
19
+ WORKDIR /app
20
+
21
+ ARG BUILD_MODE=in-repo
22
+
23
+ COPY . /app/env
24
+
25
+ WORKDIR /app/env
26
+
27
+ RUN if ! command -v uv >/dev/null 2>&1; then \
28
+ curl -LsSf https://astral.sh/uv/install.sh | sh && \
29
+ mv /root/.local/bin/uv /usr/local/bin/uv && \
30
+ mv /root/.local/bin/uvx /usr/local/bin/uvx; \
31
+ fi
32
+
33
+ RUN apt-get update && apt-get install -y --no-install-recommends \
34
+ git \
35
+ && rm -rf /var/lib/apt/lists/*
36
+
37
+ RUN --mount=type=cache,target=/root/.cache/uv \
38
+ if [ -f uv.lock ]; then \
39
+ uv sync --frozen --no-install-project --no-editable; \
40
+ else \
41
+ uv sync --no-install-project --no-editable; \
42
+ fi
43
+
44
+ RUN --mount=type=cache,target=/root/.cache/uv \
45
+ if [ -f uv.lock ]; then \
46
+ uv sync --frozen --no-editable; \
47
+ else \
48
+ uv sync --no-editable; \
49
+ fi
50
+
51
+ FROM ${BASE_IMAGE}
52
+
53
+ WORKDIR /app
54
+
55
+ COPY --from=builder /app/env/.venv /app/.venv
56
+
57
+ COPY --from=builder /app/env /app/env
58
+
59
+ ENV PATH="/app/.venv/bin:$PATH"
60
+
61
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
62
+
63
+ ENV PYTHONUNBUFFERED=1
64
+
65
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
66
+ CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
67
+
68
+ CMD ["uvicorn", "qed_math_env.server.app:app", "--host", "0.0.0.0", "--port", "8000", "--ws-ping-interval", "120", "--ws-ping-timeout", "600"]
server/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """QED Math Environment server components."""
8
+
9
+ from .qed_math_environment import QEDMathEnvironment
10
+
11
+
12
+ __all__ = ["QEDMathEnvironment"]
uv.lock ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version = 1
2
+ version_hash = ""
3
+
4
+ [[package]]
5
+ name = "openenv-qed_math_env"
6
+ version = "0.1.0"
7
+ source = { git = "https://github.com/meta-pytorch/OpenEnv.git" }
8
+
9
+
10
+ [package.dependencies]
11
+ "datasets" = ">=4.7.0"
12
+ "math-verify" = { extras = ["antlr4_13_2"], version = ">=0.9.0" }
13
+ "openenv-core" = { extras = ["core"], git = "https://github.com/meta-pytorch/OpenEnv.git" }
14
+ "trackio" = ">=0.19.0"
15
+
16
+ [package.optional-dependencies]
17
+ dev = ["pytest-cov>=4.0.0", "pytest>=8.0.0"]