# File: qed-math-openenv/server/qed_math_environment.py
# Last commit author: sourasishbasu
# Last commit message: fix ruff lint errors in qed_math_environment.py
# Last commit hash: 79e2b1e
"""
QED Math Environment Implementation.
A math proof environment that presents problems to agents and evaluates
submitted proofs using LLM-based rubric grading (0-7 scale).
"""
import asyncio
import json
import logging
import os
import random
import re
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional
from uuid import uuid4
import math_verify
from datasets import load_dataset
from fastmcp import FastMCP
from openenv.core.env_server.mcp_environment import (
MCPEnvironment,
)
from openenv.core.env_server.mcp_types import (
CallToolAction,
CallToolObservation,
ListToolsAction,
ListToolsObservation,
Tool,
ToolError,
ToolErrorType,
)
from openenv.core.env_server.types import Action, Observation, State
from models import ProblemObservation, ProofSubmissionObservation
from .math_verify_service import MathVerifierService
from .mcp_server import register_mcp_tools
from .rubric import GradingResult, MathProofRubric, length_penalty, parse_schema
# Fallback grader instructions; load_evaluator_prompt() returns this text when
# the requested prompt file does not exist on disk.
DEFAULT_EVALUATOR_PROMPT = (
    "You are a strict math proof grader. Score the submission from 0 to 7 based on "
    "mathematical correctness, completeness, and logical rigor."
)

# Module-level logger used by the dataset loaders and the environment class.
logger = logging.getLogger(__name__)

# Accepted shapes for a dataset source: one string, one spec dict, a list
# mixing strings and spec dicts, or None when nothing is configured.
DatasetSource = str | dict[str, Any] | list[str | dict[str, Any]] | None
def _dataset_source_from_env() -> DatasetSource:
    """Resolve the dataset source from environment variables.

    ``QED_DATASET_SPEC_JSON`` is tried first: it must decode to a JSON string,
    object, or array. If it is unset, blank, malformed, or of another JSON
    type, ``QED_DATASET_PATH`` is used instead. Returns ``None`` when neither
    variable yields a value.
    """
    spec_json = (os.environ.get("QED_DATASET_SPEC_JSON") or "").strip()
    if spec_json:
        decoded_ok = True
        decoded = None
        try:
            decoded = json.loads(spec_json)
        except json.JSONDecodeError:
            decoded_ok = False
            logger.warning("Ignoring invalid QED_DATASET_SPEC_JSON value.")
        if decoded_ok:
            if isinstance(decoded, (str, dict, list)):
                return decoded
            logger.warning(
                "Ignoring QED_DATASET_SPEC_JSON with unsupported type: %s",
                type(decoded).__name__,
            )

    path_text = (os.environ.get("QED_DATASET_PATH") or "").strip()
    return path_text or None
def _default_verifier_workers() -> int:
    """Return the default verifier worker count: half the CPUs, clamped to 2..8."""
    cores = os.cpu_count() or 2  # cpu_count() may return None
    half_cores = cores // 2 or 1
    capped = min(8, half_cores)
    return max(2, capped)
def _default_verifier_queue_size() -> int:
    """Return the default verifier queue size: 32 slots per default worker."""
    slots_per_worker = 32
    return _default_verifier_workers() * slots_per_worker
class UnparsableException(Exception):
    """Raised by ``_verify_math`` when the boxed prediction exceeds the length limit."""
class NoAnswerException(Exception):
    """Raised by ``_verify_math`` when the prediction has no ``\\boxed{`` marker."""
class EmptyBoxedException(Exception):
    """Raised by ``_verify_math`` when the prediction contains an empty ``\\boxed{}``."""
def _first_boxed_content(value: str) -> str | None:
    """Return the text inside the first ``\\boxed{...}`` in *value*, honouring nested braces.

    Returns ``None`` when there is no ``\\boxed{`` marker. If the braces never
    balance, falls back to the shortest-match regex so malformed input still
    gets a best-effort candidate.
    """
    marker = "\\boxed{"
    start = value.find(marker)
    if start < 0:
        return None
    begin = start + len(marker)
    depth = 1
    pos = begin
    while pos < len(value):
        char = value[pos]
        if char == "\\":
            # Skip the escaped character (e.g. "\{" or "\}") so it is not
            # counted as a grouping brace.
            pos += 2
            continue
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return value[begin:pos]
        pos += 1
    legacy = re.search(r"\\boxed\{(.+?)\}", value)
    return legacy.group(1) if legacy else None


def _parse_math_verify_expression(value: str) -> Any:
    """Parse *value* with math-verify, retrying on the content of its first ``\\boxed{...}``.

    The whole string is parsed first. If that yields nothing, the content of
    the first boxed group is parsed instead. The group is extracted with
    brace balancing so nested LaTeX such as ``\\boxed{\\frac{1}{2}}`` is not
    cut off at the first closing brace.

    Returns whatever ``math_verify.parse`` returns, which may be empty/falsy.
    """
    # Work around Windows multiprocessing issues in math-verify timeout wrappers.
    parsed = math_verify.parse(value, parsing_timeout=0)
    if parsed:
        return parsed
    inner = _first_boxed_content(value)
    if inner:
        return math_verify.parse(inner, parsing_timeout=0)
    return parsed
def remove_reasoning(
completion: str,
reasoning_delimiters: list[str] | None = None,
) -> str:
if not reasoning_delimiters:
return completion
for delim in reasoning_delimiters:
if delim in completion:
completion = completion.split(delim)[-1]
return completion.strip()
return ""
@dataclass
class QEDMathConfig:
    """Configuration bundle consumed by ``QEDMathEnvironment``."""

    # Dataset source handed to load_problems(); None means "not configured".
    dataset_path: DatasetSource = None
    # Grader model name passed to MathProofRubric unless JUDGE_MODEL is set in the environment.
    grader_model: str = "gemini-3-pro"
    # File stem of the evaluator prompt looked up by load_evaluator_prompt().
    prompt_name: str = "v2"
    # Forwarded to MathProofRubric as custom_threshold.
    custom_reward_threshold: bool = False
    # Fallback attempt budget when the active problem carries no usable max_attempts.
    max_attempts: int = 1
    # Base of the per-token multiplicative discount applied in _apply_reward_shaping.
    discount_factor: float = 1.0
    # Passed to length_penalty(); the penalty is only applied when both values are > 0.
    buffer_tokens: int = 0
    max_tokens: int = 0
    # Delimiters passed to remove_reasoning() before grading.
    reasoning_delimiters: list[str] | None = None
    # The verifier_* values below are forwarded to MathVerifierService.
    verifier_workers: int = field(default_factory=_default_verifier_workers)
    verifier_queue_size: int = field(default_factory=_default_verifier_queue_size)
    verifier_request_timeout_seconds: float = 5.0
    verifier_max_retries: int = 1
    verifier_strict: bool = True
    verifier_numeric_precision: int = 5
    verifier_float_rounding: int = 10
def load_evaluator_prompt(prompt_name: str) -> str:
    """Load the evaluator prompt named *prompt_name*.

    The prompt is read from ``prompts/evaluator_prompts/<prompt_name>.md``
    two directories above this file. ``DEFAULT_EVALUATOR_PROMPT`` is returned
    when that path does not exist.
    """
    package_root = Path(__file__).resolve().parent.parent
    candidate = package_root / "prompts" / "evaluator_prompts" / f"{prompt_name}.md"
    if not candidate.exists():
        return DEFAULT_EVALUATOR_PROMPT
    return candidate.read_text(encoding="utf-8")
def _bootstrap_problems() -> list[dict]:
    """Return the single built-in problem used when no dataset is configured."""
    parity_problem = dict(
        problem="Prove that the sum of two even integers is even.",
        reference_solution="Let a=2m and b=2n. Then a+b=2(m+n), so it is even.",
        grading_guidelines="Award full credit for a correct parity argument.",
        problem_id="bootstrap_000001",
        dataset_source="bootstrap",
        problem_type="proof",
        max_attempts=1,
    )
    return [parity_problem]
def _coerce_positive_int(value: Any, default: int) -> int:
    """Convert *value* to a strictly positive int, or return *default*.

    *default* is returned when *value* cannot be converted by ``int()`` or
    converts to zero or a negative number. This includes infinite floats:
    ``int(float("inf"))`` raises ``OverflowError``, which JSON inputs can
    produce via the ``Infinity`` literal.
    """
    try:
        parsed = int(value)
    except (TypeError, ValueError, OverflowError):
        return default
    return parsed if parsed > 0 else default
def _canonical_problem_type(raw_problem: dict[str, Any]) -> str:
    """Classify a raw dataset row as ``"proof"``, ``"answer"`` or ``"multi_step"``.

    Precedence: an explicit recognised type field, then a truthy
    ``multi_step`` flag, then ``evaluation_mode == "answer"``; anything else
    is treated as a proof problem.
    """
    declared = _first_present_value(
        raw_problem,
        ("problem_type", "type", "problem_kind", "mode"),
        None,
    )
    if isinstance(declared, str):
        candidate = declared.strip().lower()
        if candidate in ("proof", "answer", "multi_step"):
            return candidate

    if raw_problem.get("multi_step"):
        return "multi_step"

    mode = _first_present_value(raw_problem, ("evaluation_mode",), None)
    wants_answer = isinstance(mode, str) and mode.strip().lower() == "answer"
    return "answer" if wants_answer else "proof"
def _coerce_dataset_specs(dataset_path: DatasetSource) -> list[str | dict[str, Any]]:
    """Normalise a dataset source into a list of specs.

    ``None`` becomes an empty list, a single string or dict is wrapped in a
    one-element list, and a list is returned as-is (the same object).

    Raises:
        TypeError: if *dataset_path* is of any other type.
    """
    if dataset_path is None:
        return []
    if isinstance(dataset_path, list):
        return dataset_path
    if isinstance(dataset_path, (str, dict)):
        return [dataset_path]
    raise TypeError(
        "dataset_path must be None, a string path or hub id, a dataset spec dict, or a list of"
        " specs."
    )
def _read_local_problem_rows(path: Path) -> list[dict[str, Any]]:
    """Read raw problem rows from a local ``.jsonl`` or ``.json`` file.

    JSONL: blank lines are skipped and only lines that decode to objects are
    kept. JSON: the document must be a list, or an object with a ``problems``
    list; non-object items are dropped.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: for an unsupported extension or JSON layout (JSON decode
            errors from ``json.loads`` propagate as well).
    """
    if not path.exists():
        raise FileNotFoundError(f"Dataset path does not exist: {path}")

    extension = path.suffix.lower()

    if extension == ".jsonl":
        records: list[dict[str, Any]] = []
        with path.open("r", encoding="utf-8") as stream:
            for raw_line in stream:
                text = raw_line.strip()
                if text:
                    entry = json.loads(text)
                    if isinstance(entry, dict):
                        records.append(entry)
        return records

    if extension == ".json":
        document = json.loads(path.read_text(encoding="utf-8"))
        if isinstance(document, list):
            candidates = document
        elif isinstance(document, dict) and isinstance(document.get("problems"), list):
            candidates = document["problems"]
        else:
            raise ValueError(
                "JSON dataset must be a list of problem objects or contain 'problems'."
            )
        return [item for item in candidates if isinstance(item, dict)]

    raise ValueError("Unsupported dataset format. Expected .jsonl or .json.")
def _read_hub_problem_rows(
    spec: str | dict[str, Any],
) -> tuple[list[dict[str, Any]], str]:
    """Load raw problem rows from a hub dataset via ``load_dataset``.

    Args:
        spec: Either a hub id string (loaded with the ``train`` split and no
            config), or a dict with ``hub_id``/``dataset`` and optional
            ``config`` and ``split`` keys.

    Returns:
        ``(rows, hub_id)`` where each row is a plain ``dict`` copy of a record.

    Raises:
        ValueError: if no hub id can be determined from *spec*.
    """
    if isinstance(spec, str):
        hub_id = spec
        config = None
        split = "train"
    else:
        hub_id = str(spec.get("hub_id") or spec.get("dataset") or "").strip()
        config = spec.get("config")
        # An explicit null/empty split must fall back to "train" as well:
        # with split=None load_dataset returns a mapping of splits, and
        # iterating that yields split names rather than rows.
        split = spec.get("split") or "train"

    if not hub_id:
        raise ValueError("Hub dataset specs must include 'hub_id' or 'dataset'.")

    load_args: tuple[Any, ...] = (hub_id,)
    if config is not None:
        load_args += (config,)
    dataset = load_dataset(
        *load_args,
        split=split,
    )
    rows = [dict(row) for row in dataset]
    logger.info(
        "Loaded QED math hub dataset %s%s split=%s with %d rows",
        hub_id,
        f"/{config}" if config else "",
        split,
        len(rows),
    )
    return rows, hub_id
def _first_present_value(
    raw_problem: dict[str, Any],
    keys: tuple[str, ...],
    default: Any = None,
) -> Any:
    """Return the value of the first key in *keys* that is present and not ``None``.

    Falsy values other than ``None`` (``0``, ``""``, ``False``) count as
    present. *default* is returned when no key qualifies.
    """
    return next(
        (
            raw_problem[name]
            for name in keys
            if name in raw_problem and raw_problem[name] is not None
        ),
        default,
    )
def _normalize_problem(raw_problem: dict[str, Any], index: int, dataset_source: str) -> dict:
    """Convert a raw dataset row into the canonical problem dict used by the environment.

    Args:
        raw_problem: Row as read from a local file or hub dataset; several
            alternative key spellings are accepted for each field.
        index: Number used to build the fallback ``problem_NNNNNN`` id.
        dataset_source: Source label used when the row does not name one.

    Returns:
        A dict with the keys ``problem``, ``original_problem``,
        ``reference_solution``, ``grading_guidelines``, ``problem_id``,
        ``dataset_source``, ``problem_type``, ``max_attempts``,
        ``success_score_threshold`` and ``evaluation_mode``.

    Raises:
        ValueError: if the row has no non-empty string problem statement.
    """
    problem = _first_present_value(raw_problem, ("problem", "task", "Problem"))
    if not isinstance(problem, str) or not problem.strip():
        raise ValueError("Dataset row is missing a non-empty problem statement.")

    reference_solution = _first_present_value(
        raw_problem,
        ("reference_solution", "solution", "answer", "Solution"),
        "",
    )
    grading_guidelines = _first_present_value(
        raw_problem,
        (
            "grading_guidelines",
            "rubrics",
            "schema",
            "schema_0",
            "Grading guidelines",
            "details",
        ),
        "",
    )
    problem_id = _first_present_value(
        raw_problem,
        ("problem_id", "id"),
        f"problem_{index:06d}",
    )
    resolved_dataset_source = _first_present_value(
        raw_problem,
        ("dataset_source", "dataset", "data_source"),
        dataset_source,
    )

    problem_type = _canonical_problem_type(raw_problem)
    # Only multi-step problems get more than one attempt by default.
    default_max_attempts = 3 if problem_type == "multi_step" else 1
    max_attempts = _coerce_positive_int(
        _first_present_value(raw_problem, ("max_attempts", "attempts", "num_attempts"), None),
        default=default_max_attempts,
    )
    success_score_threshold = _coerce_positive_int(
        _first_present_value(raw_problem, ("success_score_threshold",), None),
        default=6,
    )

    evaluation_mode = _first_present_value(raw_problem, ("evaluation_mode",), None)
    if isinstance(evaluation_mode, str):
        evaluation_mode = evaluation_mode.strip().lower()
    else:
        evaluation_mode = "answer" if problem_type == "answer" else "proof"
    if evaluation_mode not in {"proof", "answer"}:
        evaluation_mode = "proof"

    if evaluation_mode == "answer":
        # Numeric golds (e.g. "answer": 42) are stringified first so they get
        # the same \boxed{} wrapping as string golds.
        if isinstance(reference_solution, (int, float)) and not isinstance(
            reference_solution, bool
        ):
            reference_solution = str(reference_solution)
        if isinstance(reference_solution, str) and "\\boxed{" not in reference_solution:
            reference_solution = f"\\boxed{{{reference_solution}}}"

    original_problem = _first_present_value(raw_problem, ("original_problem",), None)

    return {
        "problem": problem,
        "original_problem": original_problem,
        "reference_solution": str(reference_solution),
        "grading_guidelines": grading_guidelines,
        "problem_id": str(problem_id),
        "dataset_source": str(resolved_dataset_source),
        "problem_type": problem_type,
        "max_attempts": max_attempts,
        "success_score_threshold": success_score_threshold,
        "evaluation_mode": evaluation_mode,
    }
def _load_problems_from_spec(
    spec: str | dict[str, Any],
    start_index: int,
) -> list[dict[str, Any]]:
    """Load and normalise all problems described by one dataset spec.

    Args:
        spec: A local ``.json``/``.jsonl`` path, a hub id such as
            ``owner/name``, or a dict with ``path`` (local) or
            ``hub_id``/``dataset`` (hub).
        start_index: Number of problems already loaded from earlier specs;
            fallback problem ids continue counting from here.

    Returns:
        The normalised problems. Rows rejected by ``_normalize_problem`` are
        skipped with a warning.

    Raises:
        ValueError: if the spec is neither a recognisable local path nor a hub id.
    """
    if isinstance(spec, dict):
        if "path" in spec:
            path = Path(spec["path"])
            rows = _read_local_problem_rows(path)
            dataset_source = str(spec.get("dataset_source") or spec.get("dataset") or path.stem)
        elif "hub_id" in spec or "dataset" in spec:
            rows, dataset_source = _read_hub_problem_rows(spec)
        else:
            raise ValueError(
                "Dataset spec dict must include either 'path' for local files or"
                " 'hub_id'/'dataset' for hub datasets."
            )
    else:
        candidate_path = Path(spec)
        if candidate_path.exists() or candidate_path.suffix.lower() in {
            ".json",
            ".jsonl",
        }:
            rows = _read_local_problem_rows(candidate_path)
            dataset_source = candidate_path.stem
        elif "/" in spec:
            rows, dataset_source = _read_hub_problem_rows(spec)
        else:
            raise ValueError(
                "Dataset source must be a local .json/.jsonl path or a Hugging Face dataset id"
                " like 'owner/name'."
            )

    problems: list[dict[str, Any]] = []
    for offset, row in enumerate(rows, start=1):
        # Number fallback ids by accepted rows rather than raw offsets.
        # load_problems() passes the running count of accepted problems as
        # start_index, so offset-based numbering would reuse ids across specs
        # whenever an earlier spec had skipped rows.
        next_index = start_index + len(problems) + 1
        try:
            problems.append(_normalize_problem(row, next_index, dataset_source))
        except ValueError:
            logger.warning(
                "Skipping invalid QED math dataset row from %s at offset %d",
                dataset_source,
                offset,
            )
    return problems
def load_problems(dataset_path: DatasetSource) -> list[dict]:
    """Load every problem from the configured dataset source(s).

    Returns the built-in bootstrap problem when no source is configured.

    Raises:
        ValueError: if sources were configured but produced no valid problem.
    """
    specs = _coerce_dataset_specs(dataset_path)
    if not specs:
        return _bootstrap_problems()

    loaded: list[dict[str, Any]] = []
    for spec in specs:
        loaded.extend(_load_problems_from_spec(spec, start_index=len(loaded)))

    if loaded:
        return loaded
    raise ValueError(
        "No valid QED math problems were loaded from the configured dataset source."
    )
class QEDMathEnvironment(MCPEnvironment):
def __init__(
    self,
    dataset_path: DatasetSource = None,
    grader_model: str | None = None,
    prompt_name: str | None = None,
    custom_reward_threshold: bool | None = None,
    max_attempts: int | None = None,
    config: QEDMathConfig | None = None,
) -> None:
    """Build the environment: MCP server, config, problems, gold cache, graders.

    Explicit keyword arguments override the matching fields of *config*;
    every other field is copied from *config* (or from a default
    ``QEDMathConfig``). When no dataset is given by argument or config, the
    source is read from the environment via ``_dataset_source_from_env``.
    """
    # The MCP server is created and handed to register_mcp_tools before the
    # base class is initialised with it.
    mcp = FastMCP("qed_math_env")
    register_mcp_tools(mcp, self)
    super().__init__(mcp)
    base_config = config or QEDMathConfig()
    # Dataset precedence: explicit argument, then config, then environment variables.
    resolved_dataset_path = (
        dataset_path if dataset_path is not None else base_config.dataset_path
    )
    if resolved_dataset_path is None:
        resolved_dataset_path = _dataset_source_from_env()
    # Effective configuration: per-argument overrides merged onto base_config.
    self._config = QEDMathConfig(
        dataset_path=resolved_dataset_path,
        grader_model=(grader_model if grader_model is not None else base_config.grader_model),
        prompt_name=(prompt_name if prompt_name is not None else base_config.prompt_name),
        custom_reward_threshold=(
            custom_reward_threshold
            if custom_reward_threshold is not None
            else base_config.custom_reward_threshold
        ),
        max_attempts=(max_attempts if max_attempts is not None else base_config.max_attempts),
        discount_factor=base_config.discount_factor,
        buffer_tokens=base_config.buffer_tokens,
        max_tokens=base_config.max_tokens,
        reasoning_delimiters=base_config.reasoning_delimiters,
        verifier_workers=base_config.verifier_workers,
        verifier_queue_size=base_config.verifier_queue_size,
        verifier_request_timeout_seconds=base_config.verifier_request_timeout_seconds,
        verifier_max_retries=base_config.verifier_max_retries,
        verifier_strict=base_config.verifier_strict,
        verifier_numeric_precision=base_config.verifier_numeric_precision,
        verifier_float_rounding=base_config.verifier_float_rounding,
    )
    self._dataset_path = self._config.dataset_path
    self._grader_model = self._config.grader_model
    self._prompt_name = self._config.prompt_name
    self._prompt_template = load_evaluator_prompt(self._config.prompt_name)
    # load_problems raises ValueError if a configured source yields no valid problem.
    self._problems: list[dict] = load_problems(self._config.dataset_path)
    self._current_problem: dict | None = None
    # Gold-answer cache state; the signature and problem count are what
    # _refresh_gold_cache_if_needed compares to decide on a rebuild.
    self._gold_cache_signature = self._build_gold_cache_signature()
    self._gold_cache_problem_count = len(self._problems)
    self._gold_answer_cache: dict[tuple[str, tuple[bool, int, int]], str] = {}
    self._gold_cache_hits = 0
    self._gold_cache_misses = 0
    self._build_gold_answer_cache()
    # Judge settings: JUDGE_* variables take precedence over the OPENAI_*
    # fallbacks, and JUDGE_MODEL over the configured grader model.
    judge_api_base_url = os.environ.get("JUDGE_API_BASE_URL") or os.environ.get(
        "OPENAI_BASE_URL"
    )
    judge_api_key = os.environ.get("JUDGE_API_KEY") or os.environ.get("OPENAI_API_KEY")
    judge_model = os.environ.get("JUDGE_MODEL") or self._config.grader_model
    self._rubric = MathProofRubric(
        grader_model=judge_model,
        prompt_template=self._prompt_template,
        custom_threshold=self._config.custom_reward_threshold,
        api_base_url=judge_api_base_url,
        api_key=judge_api_key,
    )
    self._verifier_service = MathVerifierService(
        max_workers=self._config.verifier_workers,
        queue_size=self._config.verifier_queue_size,
        request_timeout_seconds=self._config.verifier_request_timeout_seconds,
        max_retries=self._config.verifier_max_retries,
        strict=self._config.verifier_strict,
        numeric_precision=self._config.verifier_numeric_precision,
        float_rounding=self._config.verifier_float_rounding,
    )
    # Reward-shaping parameters read by _apply_reward_shaping.
    self._discount_factor = self._config.discount_factor
    self._buffer_tokens = self._config.buffer_tokens
    self._max_tokens = self._config.max_tokens
    self._reasoning_delimiters = self._config.reasoning_delimiters
    # Episode bookkeeping.
    self._state = State(episode_id=str(uuid4()), step_count=0)
    self._reset_count = 0
    self._attempt_count = 0
    self._current_max_attempts = max(1, int(self._config.max_attempts))
    # Token count staged by step_async for the next submit_proof grading.
    self._pending_output_length_tokens: int = 0
def _build_gold_cache_signature(self) -> tuple[bool, int, int]:
    """Return the verifier settings (strict, numeric precision, float rounding) as a tuple."""
    cfg = self._config
    strict = bool(cfg.verifier_strict)
    precision = int(cfg.verifier_numeric_precision)
    rounding = int(cfg.verifier_float_rounding)
    return strict, precision, rounding
def _gold_cache_key(self, problem_id: str) -> tuple[str, tuple[bool, int, int]]:
    """Return the cache key for *problem_id*: the id paired with the stored signature."""
    signature = self._gold_cache_signature
    return problem_id, signature
def _build_gold_answer_cache(self) -> None:
    """Rebuild the gold-answer cache from the loaded answer-mode problems.

    Only problems in ``"answer"`` evaluation mode with a non-blank id are
    cached. Each gold is trial-parsed; a parse failure is logged, and the raw
    reference string is cached either way.
    """
    self._gold_answer_cache = {}
    for entry in self._problems:
        if entry.get("evaluation_mode", "proof") != "answer":
            continue
        pid = str(entry.get("problem_id", "")).strip()
        if not pid:
            continue
        gold = str(entry.get("reference_solution", ""))
        try:
            # The parse result is discarded: this only checks parseability.
            _parse_math_verify_expression(gold)
        except Exception:
            logger.warning(
                "Failed to pre-parse answer-mode gold for problem_id=%s;"
                " deferring to runtime verifier.",
                pid,
            )
        cache_key = self._gold_cache_key(pid)
        self._gold_answer_cache[cache_key] = gold
def _refresh_gold_cache_if_needed(self) -> None:
    """Rebuild the gold cache when the verifier signature or the problem count has changed."""
    signature_now = self._build_gold_cache_signature()
    count_now = len(self._problems)
    unchanged = (
        signature_now == self._gold_cache_signature
        and count_now == self._gold_cache_problem_count
    )
    if unchanged:
        return
    self._gold_cache_signature = signature_now
    self._gold_cache_problem_count = count_now
    self._build_gold_answer_cache()
def _get_cached_gold_answer(self, problem_id: str, fallback: str) -> str:
    """Return the cached gold for *problem_id*, or *fallback* on a miss.

    The cache is refreshed first if needed, and the hit/miss counters are
    updated.
    """
    self._refresh_gold_cache_if_needed()
    lookup = self._gold_cache_key(problem_id)
    try:
        cached = self._gold_answer_cache[lookup]
    except KeyError:
        self._gold_cache_misses += 1
        return fallback
    self._gold_cache_hits += 1
    return cached
def _gold_cache_hit_rate(self) -> float:
    """Return hits / (hits + misses) for the gold cache, or 0.0 before any lookup."""
    hits = self._gold_cache_hits
    lookups = hits + self._gold_cache_misses
    return float(hits) / float(lookups) if lookups else 0.0
async def step_async(
    self,
    action: Any,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> Any:
    """Asynchronously handle one action and return its observation.

    ``ListToolsAction`` lists the MCP tools and ``CallToolAction`` invokes one
    under a timeout (600 seconds when *timeout_s* is ``None``). Failures are
    reported inside the returned observation instead of being raised. Any
    other action is delegated to ``_step_impl``. For ``submit_proof`` calls
    the ``output_length_tokens`` keyword is staged for reward shaping.
    """
    self._state.step_count += 1

    if isinstance(action, ListToolsAction):
        try:
            listed = await self._async_list_tools()
            return ListToolsObservation(
                tools=[
                    Tool(
                        name=entry.name,
                        description=entry.description or "",
                        input_schema=entry.inputSchema if hasattr(entry, "inputSchema") else {},
                    )
                    for entry in listed
                ]
            )
        except Exception as exc:
            return ListToolsObservation(
                tools=[],
                metadata={"error": str(exc), "error_type": "list_tools_failed"},
            )

    if not isinstance(action, CallToolAction):
        return self._step_impl(action, timeout_s=timeout_s, **kwargs)

    if action.tool_name == "submit_proof":
        # Stage the caller-reported completion length; unusable values count as 0.
        reported = kwargs.get("output_length_tokens", 0)
        try:
            self._pending_output_length_tokens = max(0, int(reported))
        except (TypeError, ValueError):
            self._pending_output_length_tokens = 0

    timeout = 600.0 if timeout_s is None else timeout_s
    try:
        outcome = await asyncio.wait_for(
            self._async_call_tool(action.tool_name, action.arguments),
            timeout=timeout,
        )
        return CallToolObservation(tool_name=action.tool_name, result=outcome)
    except asyncio.TimeoutError:
        return CallToolObservation(
            tool_name=action.tool_name,
            result=None,
            error=ToolError(
                error_type=ToolErrorType.TIMEOUT,
                message=f"Tool '{action.tool_name}' timed out after {timeout}s",
            ),
        )
    except Exception as exc:
        return CallToolObservation(
            tool_name=action.tool_name,
            result=None,
            error=ToolError(
                error_type=ToolErrorType.EXECUTION_ERROR,
                message=str(exc),
            ),
        )
def reset(
    self,
    seed: Optional[int] = None,
    episode_id: Optional[str] = None,
    **kwargs: Any,
) -> Observation:
    """Start a new episode and return the selected problem.

    Args:
        seed: When given (and no ``problem_id`` is requested), selects the
            problem deterministically via ``random.Random(seed)``.
        episode_id: Episode identifier; a fresh UUID is generated when omitted.
        **kwargs: ``problem_id`` may be passed to request a specific problem.
            It is compared as a string. An unknown id falls back to the first
            problem and is logged.

    Returns:
        A ``ProblemObservation`` for the chosen problem, or a plain
        ``Observation`` with error metadata when no problems are loaded. The
        reference solution is only exposed for answer-mode problems.
    """
    selected_problem_id = kwargs.pop("problem_id", None)
    self._refresh_gold_cache_if_needed()
    self._state = State(
        episode_id=episode_id or str(uuid4()),
        step_count=0,
    )
    self._attempt_count = 0
    self._reset_count += 1
    # Drop any token count staged by step_async for a submission that was
    # never graded, so it cannot shape a reward in the new episode.
    self._pending_output_length_tokens = 0

    rubric_reset = getattr(self._rubric, "reset", None)
    if callable(rubric_reset):
        rubric_reset()

    if not self._problems:
        self._current_problem = None
    elif selected_problem_id is not None:
        # Loaded ids are strings, so compare as strings; otherwise an integer
        # id from the caller would never match.
        wanted = str(selected_problem_id)
        match = next(
            (
                problem
                for problem in self._problems
                if "problem_id" in problem and str(problem["problem_id"]) == wanted
            ),
            None,
        )
        if match is None:
            logger.warning(
                "Requested problem_id=%s not found; falling back to the first problem.",
                wanted,
            )
            match = self._problems[0]
        self._current_problem = match
    elif seed is not None:
        rng = random.Random(seed)
        self._current_problem = rng.choice(self._problems)
    else:
        self._current_problem = random.choice(self._problems)

    if self._current_problem is None:
        return Observation(
            done=False,
            reward=0.0,
            metadata={
                "error": "No problems loaded. Provide a valid dataset_path.",
                "status": "empty",
            },
        )

    self._current_max_attempts = _coerce_positive_int(
        self._current_problem.get("max_attempts"),
        default=max(1, int(self._config.max_attempts)),
    )

    return ProblemObservation(
        problem=self._current_problem.get("problem", ""),
        reference_solution=(
            self._current_problem.get("reference_solution", "")
            if self._current_problem.get("evaluation_mode", "proof") == "answer"
            else ""
        ),
        grading_guidelines=parse_schema(
            self._current_problem.get("grading_guidelines", "") or ""
        ),
        problem_id=self._current_problem.get("problem_id", ""),
        dataset_source=self._current_problem.get("dataset_source", ""),
        problem_type=self._current_problem.get("problem_type", "proof"),
        max_attempts=self._current_max_attempts,
        done=False,
        reward=0.0,
        metadata={
            "status": "ready",
            "reset_count": self._reset_count,
            "step_count": self._state.step_count,
            "attempt_count": self._attempt_count,
        },
    )
def _step_impl(
    self,
    action: Action,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> Observation:
    """Fallback for non-MCP actions: return an error observation naming the action type."""
    message = (
        f"Unknown action type: {type(action).__name__}. "
        "Use MCP tools (CallToolAction) for interactions."
    )
    return Observation(done=False, reward=0.0, metadata={"error": message})
def step(
    self,
    action: Action,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> Observation:
    """Run one synchronous step via the base class.

    For an error-free ``submit_proof`` call whose payload is a dict, the
    returned observation is rebuilt with the grading's ``done`` and ``reward``
    and a ``proof_submission`` entry in its metadata. Every other observation
    is returned unchanged.
    """
    self._state.step_count += 1
    obs = super().step(action, timeout_s=timeout_s, **kwargs)

    clean_submission = (
        isinstance(action, CallToolAction)
        and action.tool_name == "submit_proof"
        and isinstance(obs, CallToolObservation)
        and obs.error is None
    )
    if not clean_submission:
        return obs

    payload = self._extract_tool_payload(obs.result)
    if not isinstance(payload, dict):
        return obs

    # Re-validate the raw payload through the observation model.
    proof_obs = ProofSubmissionObservation(
        proof=str(payload.get("proof", "")),
        score=int(payload.get("score", 0)),
        feedback=str(payload.get("feedback", "")),
        done=bool(payload.get("done", True)),
        reward=float(payload.get("reward", 0.0)),
        problem_type=str(payload.get("problem_type", "proof")),
        attempt_number=int(payload.get("attempt_number", 1)),
        attempts_remaining=int(payload.get("attempts_remaining", 0)),
        is_correct=bool(payload.get("is_correct", False)),
        metadata=dict(payload.get("metadata", {})),
    )
    enriched = dict(obs.metadata)
    enriched["proof_submission"] = proof_obs.model_dump()
    return CallToolObservation(
        tool_name=obs.tool_name,
        result=obs.result,
        error=obs.error,
        done=proof_obs.done,
        reward=proof_obs.reward,
        metadata=enriched,
    )
@staticmethod
def _extract_tool_payload(tool_result: Any) -> Any:
    """Pull the payload out of a tool result.

    Returns the first non-``None`` of ``tool_result.data`` and
    ``tool_result.structured_content``, then the result itself if it is a
    dict, and ``None`` otherwise.
    """
    if tool_result is None:
        return None
    for attribute in ("data", "structured_content"):
        found = getattr(tool_result, attribute, None)
        if found is not None:
            return found
    return tool_result if isinstance(tool_result, dict) else None
@staticmethod
def _utc_now_iso() -> str:
    """Return the current UTC time as a timezone-aware ISO-8601 string."""
    now = datetime.now(tz=timezone.utc)
    return now.isoformat()
@staticmethod
def _chunk_feedback(feedback: str, chunk_size: int = 280) -> list[str]:
    """Split *feedback* into consecutive pieces of at most *chunk_size* characters."""
    pieces: list[str] = []
    if feedback:
        for start in range(0, len(feedback), chunk_size):
            pieces.append(feedback[start : start + chunk_size])
    return pieces
def _build_grading_progress(
    self,
    proof: str,
    feedback: str,
    is_correct: bool,
    done: bool,
) -> dict[str, Any]:
    """Assemble the progress report attached to a graded submission.

    The report lists four fixed stages, each with its own timestamp, the
    feedback split into chunks, and whether the submission was correct.
    """
    feedback_chunks = self._chunk_feedback(feedback)
    stages = (
        ("submission_received", 0.2, "Proof submission received."),
        ("grading_started", 0.6, "Grading started."),
        ("grading_completed", 0.9, "Grading completed."),
        ("result_ready", 1.0, "Result ready for client consumption."),
    )
    events = [
        {
            "stage": stage,
            "progress": fraction,
            "message": message,
            "timestamp": self._utc_now_iso(),
        }
        for stage, fraction, message in stages
    ]
    submission_type = "proof" if proof.strip() else "empty"
    return {
        "status": "completed" if done else "in_progress",
        "progress": 1.0,
        "events": events,
        "realtime": {
            "websocket_supported": True,
            "submission_type": submission_type,
        },
        "streaming_feedback": {
            "chunks": feedback_chunks,
            "chunk_count": len(feedback_chunks),
            "is_final": True,
        },
        "is_correct": is_correct,
    }
def get_problem_payload(self) -> dict:
    """Return the active problem as a plain dict.

    Without an active problem an error payload is returned. The reference
    solution is included only when the problem is evaluated in answer mode.
    """
    current = self._current_problem
    if current is None:
        return {
            "error": "No active problem. Call reset() first.",
            "done": False,
            "reward": 0.0,
        }

    problem_type = current.get("problem_type", "proof")
    mode = current.get("evaluation_mode")
    if isinstance(mode, str):
        mode = mode.strip().lower()
    elif problem_type == "answer":
        mode = "answer"
    else:
        mode = "proof"
    # Any unrecognised mode is treated as proof mode, so only "answer" matters here.
    reveal_reference = mode == "answer"

    payload = {
        "problem": current.get("problem", ""),
        "grading_guidelines": self._current_grading_guidelines_text(),
        "problem_id": current.get("problem_id", ""),
        "dataset_source": current.get("dataset_source", ""),
        "problem_type": problem_type,
        "max_attempts": self._current_max_attempts,
        "attempt_count": self._attempt_count,
        "done": False,
        "reward": 0.0,
    }
    if reveal_reference:
        payload["reference_solution"] = current.get("reference_solution", "")
    return payload
@staticmethod
def _verify_math(
    prediction: str,
    gold: str,
    strict: bool = True,
    max_prediction_length: int = 1000,
) -> str:
    """Compare the last boxed answer in *prediction* against *gold*.

    Returns one of ``"correct"``, ``"wrong"``, ``"no_answer"`` (no
    ``\\boxed{`` in the prediction) or ``"unparsable"`` (any other failure,
    including non-string input, an empty box, or an over-long boxed span).
    This function never raises.
    """
    try:
        if not (isinstance(prediction, str) and isinstance(gold, str)):
            raise ValueError("Prediction and gold must be strings")

        last_box = prediction.rfind("\\boxed{")
        if last_box < 0:
            raise NoAnswerException()

        candidate = prediction[last_box:]
        if "\\boxed{}" in candidate:
            raise EmptyBoxedException()
        if len(candidate) > max_prediction_length:
            raise UnparsableException()

        gold_expr = _parse_math_verify_expression(gold)
        candidate_expr = _parse_math_verify_expression(candidate)
        if not gold_expr:
            raise ValueError("Failed to parse gold answer.")
        if not candidate_expr:
            raise ValueError("Failed to parse prediction.")

        matches = math_verify.verify(
            gold_expr,
            candidate_expr,
            strict=strict,
            timeout_seconds=None,
        )
        return "correct" if matches else "wrong"
    except NoAnswerException:
        return "no_answer"
    except Exception:
        return "unparsable"
async def _grade_answer_submission(
    self,
    submission: str,
    expected_answer: str,
    problem_id: str = "",
) -> GradingResult:
    """Grade an answer-mode submission through the verifier service.

    A ``"correct"`` verifier status scores 7 (reward 1.0); every other status
    scores 0. The result's metrics combine the per-request response with the
    service's health and metrics snapshots.
    """
    cfg = self._config
    service = self._verifier_service

    response = await service.verify_answer(
        prediction=submission,
        gold=expected_answer,
        strict=cfg.verifier_strict,
        timeout_seconds=max(1, int(cfg.verifier_request_timeout_seconds)),
        max_prediction_length=1000,
        numeric_precision=cfg.verifier_numeric_precision,
        float_rounding=cfg.verifier_float_rounding,
    )
    status = response.status
    health = await service.health_probe()
    snapshot = await service.metrics_snapshot()

    logger.info(
        "qed_math_verifier_result request_id=%s problem_id=%s evaluation_mode=answer"
        " status=%s elapsed_ms=%.3f retry_count=%d",
        response.request_id,
        problem_id,
        status,
        float(response.elapsed_ms),
        int(response.retry_count),
    )

    succeeded = status == "correct"
    score = 7 if succeeded else 0

    metrics = {
        "verifier/rollouts/success": int(succeeded),
        "verifier/rollouts/failure": int(not succeeded),
        "verifier/failures/timeout": int(status == "timeout"),
        "verifier/failures/rate_limit": 0,
        "verifier/failures/no_input": 0,
        "verifier/failures/no_score_tag": 0,
        "verifier/failures/all_attempts_failed": int(status in {"internal_error", "timeout"}),
        "verifier/failures/num_retries": int(response.retry_count),
        "verifier/runtime/latency_per_request": float(response.elapsed_ms),
        "verifier/workers/restart_count": int(health.get("restart_count", 0)),
        "verifier/workers/worker_restarted": int(response.worker_restarted),
        "verifier/queue/depth": int(health.get("inflight_requests", 0)),
        "verifier/requests/count": int(snapshot.get("verifier/requests/count", 0)),
        "verifier/requests/latency_ms": float(snapshot.get("verifier/requests/latency_ms", 0.0)),
        "verifier/requests/timeout_count": int(
            snapshot.get("verifier/requests/timeout_count", 0)
        ),
        "verifier/requests/error_count": int(snapshot.get("verifier/requests/error_count", 0)),
        "verifier/workers/heartbeat_lag_ms": float(
            snapshot.get("verifier/workers/heartbeat_lag_ms", 0.0)
        ),
        "verifier/cache/hit_rate": self._gold_cache_hit_rate(),
        "verifier/runtime/input_tokens": 0,
        "verifier/runtime/output_tokens": 0,
    }
    return GradingResult(
        score=score,
        feedback=f"answer_status={status}",
        reward=score / 7.0,
        metrics=metrics,
    )
async def shutdown_verifier_service(self) -> None:
    """Stop the verifier service by awaiting its ``stop()`` coroutine."""
    service = self._verifier_service
    await service.stop()
def _strip_reasoning(self, text: str) -> str:
    """Apply ``remove_reasoning`` to *text* using the configured delimiters."""
    delimiters = self._reasoning_delimiters
    return remove_reasoning(text, delimiters)
async def _grade_submission(self, submission: str) -> GradingResult:
    """Grade *submission* against the active problem.

    Reasoning is stripped first; if that leaves nothing although the
    submission had content, the raw submission is graded instead.
    Answer-mode problems go to the verifier service with the cached gold
    answer, all others to the LLM rubric.
    """
    current = self._current_problem
    if current is None:
        return GradingResult(
            score=0,
            feedback="No active problem. Call reset() first.",
            reward=0.0,
        )

    grading_input = self._strip_reasoning(submission)
    if submission.strip() and not grading_input.strip():
        grading_input = submission

    statement = current.get("original_problem") or current.get("problem", "")
    reference = current.get("reference_solution", "")
    guidelines = parse_schema(current.get("grading_guidelines", "") or "")

    if current.get("evaluation_mode", "proof") != "answer":
        return await self._rubric.grade(
            grading_input,
            statement,
            reference,
            guidelines,
        )

    pid = str(current.get("problem_id", "")).strip()
    gold = self._get_cached_gold_answer(pid, reference)
    return await self._grade_answer_submission(
        grading_input,
        gold,
        problem_id=pid,
    )
def _apply_reward_shaping(
    self,
    reward: float,
    output_length_tokens: int,
) -> float:
    """Apply the length discount and optional length penalty to *reward*.

    With a non-positive token count the reward is returned unchanged.
    Otherwise it is multiplied by ``discount_factor ** tokens``, and
    ``length_penalty`` is added when both ``buffer_tokens`` and
    ``max_tokens`` are positive.
    """
    if output_length_tokens <= 0:
        return reward

    shaped = reward * (self._discount_factor**output_length_tokens)
    penalty_enabled = self._buffer_tokens > 0 and self._max_tokens > 0
    if penalty_enabled:
        shaped += length_penalty(self._max_tokens, output_length_tokens, self._buffer_tokens)
    return shaped
async def submit_proof_payload(self, proof: str) -> dict:
    """Grade a submission for the active problem and return the dumped observation.

    Blank proofs score 0 without calling a grader. The episode ends unless
    the problem is multi-step, the attempt was not correct, and attempts
    remain. The token count staged by ``step_async`` is consumed here to
    shape the reward.
    """
    if self._current_problem is None:
        return ProofSubmissionObservation(
            proof=proof,
            score=0,
            feedback="Proof not graded because no problem is active.",
            done=True,
            reward=0.0,
            problem_type="proof",
            attempt_number=1,
            attempts_remaining=0,
            is_correct=False,
            metadata={"error": "No active problem. Call reset() first."},
        ).model_dump()

    self._attempt_count += 1
    problem_type = str(self._current_problem.get("problem_type", "proof"))
    multi_step = problem_type == "multi_step"

    if proof.strip():
        result = await self._grade_submission(proof)
    else:
        result = GradingResult(
            score=0,
            feedback="Empty proof submission.",
            reward=0.0,
            metrics={
                "verifier/rollouts/failure": 1,
                "verifier/failures/no_input": 1,
            },
        )

    # Consume the staged token count so it cannot leak into a later submission.
    output_length_tokens = self._pending_output_length_tokens
    self._pending_output_length_tokens = 0
    shaped_reward = self._apply_reward_shaping(result.reward, output_length_tokens)

    success_threshold = _coerce_positive_int(
        self._current_problem.get("success_score_threshold"),
        default=6,
    )
    is_correct = result.score >= success_threshold
    attempts_remaining = max(0, self._current_max_attempts - self._attempt_count)
    keep_going = multi_step and not is_correct and attempts_remaining != 0
    done = not keep_going

    feedback = result.feedback
    if keep_going:
        feedback = (
            f"{result.feedback} Continue: "
            f"attempt {self._attempt_count}/{self._current_max_attempts}."
        )

    grading_progress = self._build_grading_progress(
        proof=proof,
        feedback=feedback,
        is_correct=is_correct,
        done=done,
    )
    metadata = {
        "grading_progress": grading_progress,
        "status": grading_progress["status"],
        "base_reward": result.reward,
        "shaped_reward": shaped_reward,
        "output_length_tokens": output_length_tokens,
        "verifier_metrics": self._build_verifier_metrics(
            result, shaped_reward, output_length_tokens, is_correct
        ),
    }
    observation = ProofSubmissionObservation(
        proof=proof,
        score=result.score,
        feedback=feedback,
        done=done,
        reward=shaped_reward,
        problem_type=problem_type,
        attempt_number=self._attempt_count,
        attempts_remaining=attempts_remaining,
        is_correct=is_correct,
        metadata=metadata,
    )
    return observation.model_dump()
def _build_verifier_metrics(
    self,
    result: GradingResult,
    shaped_reward: float,
    output_length_tokens: int,
    is_correct: bool,
) -> dict[str, float | int | str]:
    """Merge the grader's metrics with reward and episode bookkeeping.

    Works on a copy of ``result.metrics``; the grading result itself is not
    modified.
    """
    metrics: dict[str, float | int | str] = dict(result.metrics)
    metrics["reward/base"] = result.reward
    metrics["reward/shaped"] = shaped_reward
    metrics["reward/score_raw"] = result.score

    # Same condition under which _apply_reward_shaping adds the penalty.
    penalty_applies = (
        output_length_tokens > 0 and self._buffer_tokens > 0 and self._max_tokens > 0
    )
    metrics["reward/overlong_penalty"] = (
        length_penalty(self._max_tokens, output_length_tokens, self._buffer_tokens)
        if penalty_applies
        else 0.0
    )

    metrics["episode/attempt_number"] = self._attempt_count
    metrics["episode/is_correct"] = int(is_correct)

    current = self._current_problem
    if current is not None:
        metrics["episode/problem_type"] = str(current.get("problem_type", "proof"))
        metrics["episode/dataset_source"] = str(current.get("dataset_source", "unknown"))
    return metrics
def get_grading_guidelines_payload(self) -> dict:
    """Return the active problem's grading guidelines, or an error payload if none is active."""
    current = self._current_problem
    if current is None:
        return {
            "error": "No active problem. Call reset() first.",
            "grading_guidelines": "",
            "done": False,
            "reward": 0.0,
        }
    payload = {"grading_guidelines": self._current_grading_guidelines_text()}
    payload["problem_id"] = current.get("problem_id", "")
    payload["done"] = False
    payload["reward"] = 0.0
    return payload
def list_task_ids_payload(self) -> dict:
    """List the ids of all loaded problems.

    Problems with a blank or missing id get a positional
    ``problem_NNNNNN`` placeholder (1-based).
    """
    task_ids: list[str] = [
        str(entry.get("problem_id", "")).strip() or f"problem_{position:06d}"
        for position, entry in enumerate(self._problems, start=1)
    ]
    return {
        "task_ids": task_ids,
        "task_count": len(task_ids),
        "done": False,
        "reward": 0.0,
    }
def _current_grading_guidelines_text(self) -> str:
    """Return the active problem's guidelines passed through ``parse_schema``, or ``""`` if none is active."""
    current = self._current_problem
    if current is not None:
        return parse_schema(current.get("grading_guidelines", "") or "")
    return ""
@property
def state(self) -> State:
    """The current episode state object (replaced on every ``reset``)."""
    current_state = self._state
    return current_state