Spaces:
Sleeping
Sleeping
| """ | |
| QED Math Environment Implementation. | |
| A math proof environment that presents problems to agents and evaluates | |
| submitted proofs using LLM-based rubric grading (0-7 scale). | |
| """ | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import random | |
| import re | |
| from dataclasses import dataclass, field | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import Any, Optional | |
| from uuid import uuid4 | |
| import math_verify | |
| from datasets import load_dataset | |
| from fastmcp import FastMCP | |
| from openenv.core.env_server.mcp_environment import ( | |
| MCPEnvironment, | |
| ) | |
| from openenv.core.env_server.mcp_types import ( | |
| CallToolAction, | |
| CallToolObservation, | |
| ListToolsAction, | |
| ListToolsObservation, | |
| Tool, | |
| ToolError, | |
| ToolErrorType, | |
| ) | |
| from openenv.core.env_server.types import Action, Observation, State | |
| from models import ProblemObservation, ProofSubmissionObservation | |
| from .math_verify_service import MathVerifierService | |
| from .mcp_server import register_mcp_tools | |
| from .rubric import GradingResult, MathProofRubric, length_penalty, parse_schema | |
# Fallback grader instructions, returned by load_evaluator_prompt when no
# prompt file exists for the requested prompt name.
DEFAULT_EVALUATOR_PROMPT = (
    "You are a strict math proof grader. Score the submission from 0 to 7 based on "
    "mathematical correctness, completeness, and logical rigor."
)
# Module-level logger shared by the loaders and the environment.
logger = logging.getLogger(__name__)
# A dataset source is a single string (local path or hub id), a spec dict,
# a list mixing both, or None when nothing is configured.
DatasetSource = str | dict[str, Any] | list[str | dict[str, Any]] | None
| def _dataset_source_from_env() -> DatasetSource: | |
| raw_spec = (os.environ.get("QED_DATASET_SPEC_JSON") or "").strip() | |
| if raw_spec: | |
| try: | |
| parsed = json.loads(raw_spec) | |
| except json.JSONDecodeError: | |
| logger.warning("Ignoring invalid QED_DATASET_SPEC_JSON value.") | |
| else: | |
| if isinstance(parsed, (str, dict, list)): | |
| return parsed | |
| logger.warning( | |
| "Ignoring QED_DATASET_SPEC_JSON with unsupported type: %s", | |
| type(parsed).__name__, | |
| ) | |
| raw_path = (os.environ.get("QED_DATASET_PATH") or "").strip() | |
| if raw_path: | |
| return raw_path | |
| return None | |
| def _default_verifier_workers() -> int: | |
| cpu_count = os.cpu_count() or 2 | |
| return max(2, min(8, cpu_count // 2 or 1)) | |
def _default_verifier_queue_size() -> int:
    """Return the default verifier queue size: 32 slots per default worker."""
    slots_per_worker = 32
    return _default_verifier_workers() * slots_per_worker
class UnparsableException(Exception):
    """Raised during answer verification when the boxed prediction is too long to parse."""
class NoAnswerException(Exception):
    """Raised during answer verification when the prediction has no boxed answer."""
class EmptyBoxedException(Exception):
    """Raised during answer verification when the boxed answer is empty."""
def _first_boxed_content(value: str) -> str | None:
    r"""Return the contents of the first non-empty, brace-balanced ``\boxed{...}``.

    Braces are counted so nested groups such as ``\boxed{\frac{1}{2}}`` are
    returned whole, and the contents may span several lines. Empty or
    unterminated occurrences are skipped.

    Returns:
        The text between the outer braces, or ``None`` if no usable
        occurrence exists.
    """
    marker = "\\boxed{"
    search_from = 0
    while True:
        start = value.find(marker, search_from)
        if start < 0:
            return None
        body_start = start + len(marker)
        depth = 1
        for position in range(body_start, len(value)):
            char = value[position]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    if position > body_start:
                        return value[body_start:position]
                    break  # empty \boxed{} - try the next occurrence
        # Empty or unterminated: continue scanning after this marker.
        search_from = body_start


def _parse_math_verify_expression(value: str) -> Any:
    r"""Parse ``value`` with math-verify, retrying on the boxed contents.

    The whole string is parsed first. If that yields nothing, the contents of
    the first ``\boxed{...}`` are parsed on their own.

    Args:
        value: Text containing a mathematical expression.

    Returns:
        Whatever ``math_verify.parse`` returns; the (falsy) result of the
        first parse when no boxed fallback is available.
    """
    # Work around Windows multiprocessing issues in math-verify timeout wrappers.
    parsed = math_verify.parse(value, parsing_timeout=0)
    if parsed:
        return parsed
    boxed_content = _first_boxed_content(value)
    if boxed_content is not None:
        return math_verify.parse(boxed_content, parsing_timeout=0)
    return parsed
def remove_reasoning(
    completion: str,
    reasoning_delimiters: list[str] | None = None,
) -> str:
    """Drop the reasoning prefix from a completion.

    Args:
        completion: Full model output.
        reasoning_delimiters: Candidate delimiters, checked in order.

    Returns:
        ``completion`` unchanged when no delimiters are configured. Otherwise
        the stripped text after the last occurrence of the first delimiter
        found, or ``""`` when none of the delimiters occurs.
    """
    if not reasoning_delimiters:
        return completion
    matched = next((d for d in reasoning_delimiters if d in completion), None)
    if matched is None:
        return ""
    _, _, answer = completion.rpartition(matched)
    return answer.strip()
@dataclass
class QEDMathConfig:
    """Configuration for the QED math environment.

    This must be a dataclass: two fields rely on ``field(default_factory=...)``
    and the environment constructs it with keyword arguments.
    """

    # Where problems come from; None means "not configured".
    dataset_path: DatasetSource = None
    # Default judge model name.
    grader_model: str = "gemini-3-pro"
    # Name of the evaluator prompt file (without extension).
    prompt_name: str = "v2"
    # Forwarded to the proof rubric as its custom-threshold switch.
    custom_reward_threshold: bool = False
    # Fallback attempt budget when a problem does not define its own.
    max_attempts: int = 1
    # NOTE(review): the next three are stored by the environment; their use is
    # outside this part of the file - presumably reward shaping, confirm.
    discount_factor: float = 1.0
    buffer_tokens: int = 0
    max_tokens: int = 0
    # Delimiters used to strip reasoning from a submission before grading.
    reasoning_delimiters: list[str] | None = None
    # Settings forwarded to the math verifier service.
    verifier_workers: int = field(default_factory=_default_verifier_workers)
    verifier_queue_size: int = field(default_factory=_default_verifier_queue_size)
    verifier_request_timeout_seconds: float = 5.0
    verifier_max_retries: int = 1
    verifier_strict: bool = True
    verifier_numeric_precision: int = 5
    verifier_float_rounding: int = 10
def load_evaluator_prompt(prompt_name: str) -> str:
    """Load the evaluator prompt called ``prompt_name``.

    The prompt is read from ``prompts/evaluator_prompts/<prompt_name>.md``,
    resolved two directories above this file. When that path does not exist,
    ``DEFAULT_EVALUATOR_PROMPT`` is returned instead.
    """
    package_root = Path(__file__).resolve().parent.parent
    candidate = package_root / "prompts" / "evaluator_prompts" / f"{prompt_name}.md"
    if not candidate.exists():
        return DEFAULT_EVALUATOR_PROMPT
    return candidate.read_text(encoding="utf-8")
| def _bootstrap_problems() -> list[dict]: | |
| return [ | |
| { | |
| "problem": "Prove that the sum of two even integers is even.", | |
| "reference_solution": "Let a=2m and b=2n. Then a+b=2(m+n), so it is even.", | |
| "grading_guidelines": "Award full credit for a correct parity argument.", | |
| "problem_id": "bootstrap_000001", | |
| "dataset_source": "bootstrap", | |
| "problem_type": "proof", | |
| "max_attempts": 1, | |
| } | |
| ] | |
| def _coerce_positive_int(value: Any, default: int) -> int: | |
| try: | |
| parsed = int(value) | |
| except (TypeError, ValueError): | |
| return default | |
| return parsed if parsed > 0 else default | |
def _canonical_problem_type(raw_problem: dict[str, Any]) -> str:
    """Classify a dataset row as ``"proof"``, ``"answer"`` or ``"multi_step"``.

    Precedence: an explicit, recognised type field; then a truthy
    ``multi_step`` flag; then ``evaluation_mode == "answer"``; otherwise
    ``"proof"``.
    """
    declared = _first_present_value(
        raw_problem,
        ("problem_type", "type", "problem_kind", "mode"),
        None,
    )
    if isinstance(declared, str):
        candidate = declared.strip().lower()
        if candidate in {"proof", "answer", "multi_step"}:
            return candidate

    if bool(raw_problem.get("multi_step")):
        return "multi_step"

    mode = _first_present_value(raw_problem, ("evaluation_mode",), None)
    answer_mode = isinstance(mode, str) and mode.strip().lower() == "answer"
    return "answer" if answer_mode else "proof"
| def _coerce_dataset_specs(dataset_path: DatasetSource) -> list[str | dict[str, Any]]: | |
| if dataset_path is None: | |
| return [] | |
| if isinstance(dataset_path, (str, dict)): | |
| return [dataset_path] | |
| if isinstance(dataset_path, list): | |
| return dataset_path | |
| raise TypeError( | |
| "dataset_path must be None, a string path or hub id, a dataset spec dict, or a list of" | |
| " specs." | |
| ) | |
| def _read_local_problem_rows(path: Path) -> list[dict[str, Any]]: | |
| if not path.exists(): | |
| raise FileNotFoundError(f"Dataset path does not exist: {path}") | |
| rows: list[dict[str, Any]] = [] | |
| suffix = path.suffix.lower() | |
| if suffix == ".jsonl": | |
| with path.open("r", encoding="utf-8") as handle: | |
| for line in handle: | |
| line = line.strip() | |
| if not line: | |
| continue | |
| parsed = json.loads(line) | |
| if isinstance(parsed, dict): | |
| rows.append(parsed) | |
| elif suffix == ".json": | |
| parsed = json.loads(path.read_text(encoding="utf-8")) | |
| if isinstance(parsed, list): | |
| rows = [item for item in parsed if isinstance(item, dict)] | |
| elif isinstance(parsed, dict) and isinstance(parsed.get("problems"), list): | |
| rows = [item for item in parsed["problems"] if isinstance(item, dict)] | |
| else: | |
| raise ValueError( | |
| "JSON dataset must be a list of problem objects or contain 'problems'." | |
| ) | |
| else: | |
| raise ValueError("Unsupported dataset format. Expected .jsonl or .json.") | |
| return rows | |
def _read_hub_problem_rows(
    spec: str | dict[str, Any],
) -> tuple[list[dict[str, Any]], str]:
    """Load problem rows from a hub dataset via ``load_dataset``.

    Args:
        spec: A hub id string (split ``"train"``, no config), or a dict with
            ``hub_id``/``dataset`` and optional ``config`` and ``split``.

    Returns:
        ``(rows, hub_id)`` where every row has been converted to a plain dict.

    Raises:
        ValueError: If no hub id can be determined.
    """
    if isinstance(spec, str):
        hub_id, config, split = spec, None, "train"
    else:
        hub_id = str(spec.get("hub_id") or spec.get("dataset") or "").strip()
        config = spec.get("config")
        split = spec.get("split", "train")
    if not hub_id:
        raise ValueError("Hub dataset specs must include 'hub_id' or 'dataset'.")

    # The config name is passed positionally only when one was supplied.
    positional: tuple[Any, ...] = (hub_id,) if config is None else (hub_id, config)
    dataset = load_dataset(*positional, split=split)
    rows = [dict(record) for record in dataset]

    config_suffix = f"/{config}" if config else ""
    logger.info(
        "Loaded QED math hub dataset %s%s split=%s with %d rows",
        hub_id,
        config_suffix,
        split,
        len(rows),
    )
    return rows, hub_id
| def _first_present_value( | |
| raw_problem: dict[str, Any], | |
| keys: tuple[str, ...], | |
| default: Any = None, | |
| ) -> Any: | |
| for key in keys: | |
| if key in raw_problem and raw_problem[key] is not None: | |
| return raw_problem[key] | |
| return default | |
def _normalize_problem(raw_problem: dict[str, Any], index: int, dataset_source: str) -> dict:
    """Map a raw dataset row onto the canonical problem dict.

    Several alternative column names are accepted for each field. In answer
    mode a string reference solution without ``\\boxed{`` is wrapped in one.

    Args:
        raw_problem: The row as read from the dataset.
        index: Position used to synthesise ``problem_id`` when the row has none.
        dataset_source: Source label used when the row does not carry its own.

    Raises:
        ValueError: If the row has no non-empty string problem statement.
    """
    statement = _first_present_value(raw_problem, ("problem", "task", "Problem"))
    if not (isinstance(statement, str) and statement.strip()):
        raise ValueError("Dataset row is missing a non-empty problem statement.")

    solution = _first_present_value(
        raw_problem,
        ("reference_solution", "solution", "answer", "Solution"),
        "",
    )
    guideline_keys = (
        "grading_guidelines",
        "rubrics",
        "schema",
        "schema_0",
        "Grading guidelines",
        "details",
    )
    guidelines = _first_present_value(raw_problem, guideline_keys, "")
    identifier = _first_present_value(
        raw_problem,
        ("problem_id", "id"),
        f"problem_{index:06d}",
    )
    source_label = _first_present_value(
        raw_problem,
        ("dataset_source", "dataset", "data_source"),
        dataset_source,
    )

    kind = _canonical_problem_type(raw_problem)
    # Multi-step problems get three attempts by default, everything else one.
    attempts = _coerce_positive_int(
        _first_present_value(raw_problem, ("max_attempts", "attempts", "num_attempts"), None),
        default=3 if kind == "multi_step" else 1,
    )
    threshold = _coerce_positive_int(
        _first_present_value(raw_problem, ("success_score_threshold",), None),
        default=6,
    )

    mode = _first_present_value(raw_problem, ("evaluation_mode",), None)
    if isinstance(mode, str):
        mode = mode.strip().lower()
    elif kind == "answer":
        mode = "answer"
    else:
        mode = "proof"
    if mode not in {"proof", "answer"}:
        mode = "proof"

    needs_box = (
        mode == "answer"
        and isinstance(solution, str)
        and "\\boxed{" not in solution
    )
    if needs_box:
        solution = f"\\boxed{{{solution}}}"

    return {
        "problem": statement,
        "original_problem": _first_present_value(raw_problem, ("original_problem",), None),
        "reference_solution": str(solution),
        "grading_guidelines": guidelines,
        "problem_id": str(identifier),
        "dataset_source": str(source_label),
        "problem_type": kind,
        "max_attempts": attempts,
        "success_score_threshold": threshold,
        "evaluation_mode": mode,
    }
def _load_problems_from_spec(
    spec: str | dict[str, Any],
    start_index: int,
) -> list[dict[str, Any]]:
    """Load and normalise every problem described by one dataset spec.

    A dict spec selects a local file through ``path`` or a hub dataset through
    ``hub_id``/``dataset``. A string spec is treated as a local file when it
    exists or ends in ``.json``/``.jsonl``, and as a hub id when it contains
    ``/``. Rows that fail normalisation are logged and skipped.

    Args:
        spec: The dataset spec.
        start_index: Offset added to each row's 1-based position when
            generating fallback problem ids.

    Raises:
        ValueError: If the spec matches none of the supported forms.
    """
    if not isinstance(spec, dict):
        local_candidate = Path(spec)
        looks_local = local_candidate.exists() or local_candidate.suffix.lower() in {
            ".json",
            ".jsonl",
        }
        if looks_local:
            rows = _read_local_problem_rows(local_candidate)
            dataset_source = local_candidate.stem
        elif "/" in spec:
            rows, dataset_source = _read_hub_problem_rows(spec)
        else:
            raise ValueError(
                "Dataset source must be a local .json/.jsonl path or a Hugging Face dataset id"
                " like 'owner/name'."
            )
    elif "path" in spec:
        local_path = Path(spec["path"])
        rows = _read_local_problem_rows(local_path)
        dataset_source = str(spec.get("dataset_source") or spec.get("dataset") or local_path.stem)
    elif "hub_id" in spec or "dataset" in spec:
        rows, dataset_source = _read_hub_problem_rows(spec)
    else:
        raise ValueError(
            "Dataset spec dict must include either 'path' for local files or"
            " 'hub_id'/'dataset' for hub datasets."
        )

    normalized: list[dict[str, Any]] = []
    for position, raw_row in enumerate(rows, start=1):
        try:
            entry = _normalize_problem(raw_row, start_index + position, dataset_source)
        except ValueError:
            logger.warning(
                "Skipping invalid QED math dataset row from %s at offset %d",
                dataset_source,
                position,
            )
        else:
            normalized.append(entry)
    return normalized
def load_problems(dataset_path: DatasetSource) -> list[dict]:
    """Load all problems for the configured dataset source.

    With no source configured the built-in bootstrap problem is returned.
    Otherwise each spec is loaded in order and the results are concatenated.

    Raises:
        ValueError: If sources were configured but produced no valid problem.
    """
    specs = _coerce_dataset_specs(dataset_path)
    if not specs:
        return _bootstrap_problems()

    collected: list[dict[str, Any]] = []
    for one_spec in specs:
        # Fallback ids continue numbering across specs.
        collected += _load_problems_from_spec(one_spec, start_index=len(collected))

    if collected:
        return collected
    raise ValueError(
        "No valid QED math problems were loaded from the configured dataset source."
    )
| class QEDMathEnvironment(MCPEnvironment): | |
def __init__(
    self,
    dataset_path: DatasetSource = None,
    grader_model: str | None = None,
    prompt_name: str | None = None,
    custom_reward_threshold: bool | None = None,
    max_attempts: int | None = None,
    config: QEDMathConfig | None = None,
):
    """Build the environment: MCP server, config, problems, rubric, verifier.

    Explicit keyword arguments take precedence over the matching fields of
    ``config``. When neither supplies a dataset path, the environment
    variables are consulted.
    """
    mcp = FastMCP("qed_math_env")
    register_mcp_tools(mcp, self)
    super().__init__(mcp)

    defaults = config or QEDMathConfig()

    def prefer(override: Any, fallback: Any) -> Any:
        # An explicit argument wins unless it was left as None.
        return fallback if override is None else override

    source = prefer(dataset_path, defaults.dataset_path)
    if source is None:
        source = _dataset_source_from_env()

    self._config = QEDMathConfig(
        dataset_path=source,
        grader_model=prefer(grader_model, defaults.grader_model),
        prompt_name=prefer(prompt_name, defaults.prompt_name),
        custom_reward_threshold=prefer(
            custom_reward_threshold, defaults.custom_reward_threshold
        ),
        max_attempts=prefer(max_attempts, defaults.max_attempts),
        discount_factor=defaults.discount_factor,
        buffer_tokens=defaults.buffer_tokens,
        max_tokens=defaults.max_tokens,
        reasoning_delimiters=defaults.reasoning_delimiters,
        verifier_workers=defaults.verifier_workers,
        verifier_queue_size=defaults.verifier_queue_size,
        verifier_request_timeout_seconds=defaults.verifier_request_timeout_seconds,
        verifier_max_retries=defaults.verifier_max_retries,
        verifier_strict=defaults.verifier_strict,
        verifier_numeric_precision=defaults.verifier_numeric_precision,
        verifier_float_rounding=defaults.verifier_float_rounding,
    )
    cfg = self._config

    self._dataset_path = cfg.dataset_path
    self._grader_model = cfg.grader_model
    self._prompt_name = cfg.prompt_name
    self._prompt_template = load_evaluator_prompt(cfg.prompt_name)
    self._problems: list[dict] = load_problems(cfg.dataset_path)
    self._current_problem: dict | None = None

    # Gold-answer cache for answer-mode problems, keyed by problem id plus
    # the verifier settings signature.
    self._gold_cache_signature = self._build_gold_cache_signature()
    self._gold_cache_problem_count = len(self._problems)
    self._gold_answer_cache: dict[tuple[str, tuple[bool, int, int]], str] = {}
    self._gold_cache_hits = 0
    self._gold_cache_misses = 0
    self._build_gold_answer_cache()

    # Judge settings: JUDGE_* variables win over the OPENAI_* ones, and
    # JUDGE_MODEL wins over the configured grader model.
    env = os.environ
    judge_api_base_url = env.get("JUDGE_API_BASE_URL") or env.get("OPENAI_BASE_URL")
    judge_api_key = env.get("JUDGE_API_KEY") or env.get("OPENAI_API_KEY")
    judge_model = env.get("JUDGE_MODEL") or cfg.grader_model
    self._rubric = MathProofRubric(
        grader_model=judge_model,
        prompt_template=self._prompt_template,
        custom_threshold=cfg.custom_reward_threshold,
        api_base_url=judge_api_base_url,
        api_key=judge_api_key,
    )
    self._verifier_service = MathVerifierService(
        max_workers=cfg.verifier_workers,
        queue_size=cfg.verifier_queue_size,
        request_timeout_seconds=cfg.verifier_request_timeout_seconds,
        max_retries=cfg.verifier_max_retries,
        strict=cfg.verifier_strict,
        numeric_precision=cfg.verifier_numeric_precision,
        float_rounding=cfg.verifier_float_rounding,
    )

    self._discount_factor = cfg.discount_factor
    self._buffer_tokens = cfg.buffer_tokens
    self._max_tokens = cfg.max_tokens
    self._reasoning_delimiters = cfg.reasoning_delimiters

    # Per-episode bookkeeping.
    self._state = State(episode_id=str(uuid4()), step_count=0)
    self._reset_count = 0
    self._attempt_count = 0
    self._current_max_attempts = max(1, int(cfg.max_attempts))
    self._pending_output_length_tokens: int = 0
| def _build_gold_cache_signature(self) -> tuple[bool, int, int]: | |
| return ( | |
| bool(self._config.verifier_strict), | |
| int(self._config.verifier_numeric_precision), | |
| int(self._config.verifier_float_rounding), | |
| ) | |
| def _gold_cache_key(self, problem_id: str) -> tuple[str, tuple[bool, int, int]]: | |
| return (problem_id, self._gold_cache_signature) | |
| def _build_gold_answer_cache(self) -> None: | |
| self._gold_answer_cache = {} | |
| for problem in self._problems: | |
| evaluation_mode = problem.get("evaluation_mode", "proof") | |
| if evaluation_mode != "answer": | |
| continue | |
| problem_id = str(problem.get("problem_id", "")).strip() | |
| if not problem_id: | |
| continue | |
| reference_solution = str(problem.get("reference_solution", "")) | |
| try: | |
| _parse_math_verify_expression(reference_solution) | |
| except Exception: | |
| logger.warning( | |
| "Failed to pre-parse answer-mode gold for problem_id=%s;" | |
| " deferring to runtime verifier.", | |
| problem_id, | |
| ) | |
| self._gold_answer_cache[self._gold_cache_key(problem_id)] = reference_solution | |
| def _refresh_gold_cache_if_needed(self) -> None: | |
| current_signature = self._build_gold_cache_signature() | |
| current_problem_count = len(self._problems) | |
| if ( | |
| current_signature != self._gold_cache_signature | |
| or current_problem_count != self._gold_cache_problem_count | |
| ): | |
| self._gold_cache_signature = current_signature | |
| self._gold_cache_problem_count = current_problem_count | |
| self._build_gold_answer_cache() | |
| def _get_cached_gold_answer(self, problem_id: str, fallback: str) -> str: | |
| self._refresh_gold_cache_if_needed() | |
| key = self._gold_cache_key(problem_id) | |
| if key in self._gold_answer_cache: | |
| self._gold_cache_hits += 1 | |
| return self._gold_answer_cache[key] | |
| self._gold_cache_misses += 1 | |
| return fallback | |
| def _gold_cache_hit_rate(self) -> float: | |
| total = self._gold_cache_hits + self._gold_cache_misses | |
| if total == 0: | |
| return 0.0 | |
| return float(self._gold_cache_hits) / float(total) | |
async def step_async(
    self,
    action: Any,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> Any:
    """Handle one action asynchronously.

    ``ListToolsAction`` returns the registered tools; ``CallToolAction`` runs
    the tool under a timeout (``timeout_s``, default 600 seconds). Failures
    are reported inside the returned observation instead of being raised.
    Any other action is passed to ``_step_impl``.

    For ``submit_proof`` calls, ``output_length_tokens`` from ``kwargs`` is
    stored (clamped to >= 0, 0 when not an integer) before the tool runs.
    """
    self._state.step_count += 1
    # Function-level import kept from the original implementation.
    from openenv.core.env_server.mcp_types import (
        CallToolAction,
        CallToolObservation,
    )

    default_timeout = 600.0

    if isinstance(action, ListToolsAction):
        try:
            listed = await self._async_list_tools()
            described = [
                Tool(
                    name=item.name,
                    description=item.description or "",
                    input_schema=getattr(item, "inputSchema", {}),
                )
                for item in listed
            ]
            return ListToolsObservation(tools=described)
        except Exception as exc:
            return ListToolsObservation(
                tools=[],
                metadata={"error": str(exc), "error_type": "list_tools_failed"},
            )

    if not isinstance(action, CallToolAction):
        return self._step_impl(action, timeout_s=timeout_s, **kwargs)

    tool_name = action.tool_name
    if tool_name == "submit_proof":
        reported_tokens = kwargs.get("output_length_tokens", 0)
        try:
            self._pending_output_length_tokens = max(0, int(reported_tokens))
        except (TypeError, ValueError):
            self._pending_output_length_tokens = 0

    timeout = default_timeout if timeout_s is None else timeout_s
    try:
        result = await asyncio.wait_for(
            self._async_call_tool(action.tool_name, action.arguments),
            timeout=timeout,
        )
        return CallToolObservation(tool_name=action.tool_name, result=result)
    except asyncio.TimeoutError:
        failure = ToolError(
            error_type=ToolErrorType.TIMEOUT,
            message=f"Tool '{action.tool_name}' timed out after {timeout}s",
        )
    except Exception as exc:
        failure = ToolError(
            error_type=ToolErrorType.EXECUTION_ERROR,
            message=str(exc),
        )
    return CallToolObservation(
        tool_name=action.tool_name,
        result=None,
        error=failure,
    )
def reset(
    self,
    seed: Optional[int] = None,
    episode_id: Optional[str] = None,
    **kwargs: Any,
) -> Observation:
    """Start a new episode and select its problem.

    Selection order: the ``problem_id`` keyword (falling back to the first
    problem when no id matches), then a draw from ``random.Random(seed)``,
    then an unseeded random choice.

    Returns:
        A ``ProblemObservation`` for the selected problem, or a plain
        ``Observation`` with ``status == "empty"`` when no problems are loaded.
        The reference solution is exposed only for answer-mode problems.
    """
    requested_id = kwargs.pop("problem_id", None)
    self._refresh_gold_cache_if_needed()

    self._state = State(
        episode_id=episode_id or str(uuid4()),
        step_count=0,
    )
    self._attempt_count = 0
    self._reset_count += 1

    # The rubric may or may not expose a reset hook.
    rubric_reset = getattr(self._rubric, "reset", None)
    if callable(rubric_reset):
        rubric_reset()

    if not self._problems:
        self._current_problem = None
    elif requested_id is not None:
        matching = (
            candidate
            for candidate in self._problems
            if candidate.get("problem_id") == requested_id
        )
        self._current_problem = next(matching, None) or self._problems[0]
    elif seed is not None:
        self._current_problem = random.Random(seed).choice(self._problems)
    else:
        self._current_problem = random.choice(self._problems)

    current = self._current_problem
    if current is None:
        return Observation(
            done=False,
            reward=0.0,
            metadata={
                "error": "No problems loaded. Provide a valid dataset_path.",
                "status": "empty",
            },
        )

    self._current_max_attempts = _coerce_positive_int(
        current.get("max_attempts"),
        default=max(1, int(self._config.max_attempts)),
    )

    if current.get("evaluation_mode", "proof") == "answer":
        visible_reference = current.get("reference_solution", "")
    else:
        visible_reference = ""

    return ProblemObservation(
        problem=current.get("problem", ""),
        reference_solution=visible_reference,
        grading_guidelines=parse_schema(current.get("grading_guidelines", "") or ""),
        problem_id=current.get("problem_id", ""),
        dataset_source=current.get("dataset_source", ""),
        problem_type=current.get("problem_type", "proof"),
        max_attempts=self._current_max_attempts,
        done=False,
        reward=0.0,
        metadata={
            "status": "ready",
            "reset_count": self._reset_count,
            "step_count": self._state.step_count,
            "attempt_count": self._attempt_count,
        },
    )
def _step_impl(
    self,
    action: Action,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> Observation:
    """Reject a non-MCP action with an error observation (never raises)."""
    action_type = type(action).__name__
    error_text = (
        f"Unknown action type: {action_type}. "
        "Use MCP tools (CallToolAction) for interactions."
    )
    return Observation(done=False, reward=0.0, metadata={"error": error_text})
def step(
    self,
    action: Action,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> Observation:
    """Run one step through the base class and enrich ``submit_proof`` results.

    For a successful ``submit_proof`` call whose payload is a dict, the
    observation is rebuilt with ``done`` and ``reward`` taken from the payload
    and the structured submission stored under ``metadata["proof_submission"]``.
    Every other observation is returned as-is.
    """
    self._state.step_count += 1
    obs = super().step(action, timeout_s=timeout_s, **kwargs)

    is_submission = isinstance(action, CallToolAction) and action.tool_name == "submit_proof"
    if (
        not is_submission
        or not isinstance(obs, CallToolObservation)
        or obs.error is not None
    ):
        return obs

    payload = self._extract_tool_payload(obs.result)
    if not isinstance(payload, dict):
        return obs

    submission = ProofSubmissionObservation(
        proof=str(payload.get("proof", "")),
        score=int(payload.get("score", 0)),
        feedback=str(payload.get("feedback", "")),
        done=bool(payload.get("done", True)),
        reward=float(payload.get("reward", 0.0)),
        problem_type=str(payload.get("problem_type", "proof")),
        attempt_number=int(payload.get("attempt_number", 1)),
        attempts_remaining=int(payload.get("attempts_remaining", 0)),
        is_correct=bool(payload.get("is_correct", False)),
        metadata=dict(payload.get("metadata", {})),
    )
    enriched_metadata = dict(obs.metadata)
    enriched_metadata["proof_submission"] = submission.model_dump()
    return CallToolObservation(
        tool_name=obs.tool_name,
        result=obs.result,
        error=obs.error,
        done=submission.done,
        reward=submission.reward,
        metadata=enriched_metadata,
    )
| def _extract_tool_payload(tool_result: Any) -> Any: | |
| if tool_result is None: | |
| return None | |
| data = getattr(tool_result, "data", None) | |
| if data is not None: | |
| return data | |
| structured_content = getattr(tool_result, "structured_content", None) | |
| if structured_content is not None: | |
| return structured_content | |
| if isinstance(tool_result, dict): | |
| return tool_result | |
| return None | |
| def _utc_now_iso() -> str: | |
| return datetime.now(timezone.utc).isoformat() | |
| def _chunk_feedback(feedback: str, chunk_size: int = 280) -> list[str]: | |
| if not feedback: | |
| return [] | |
| return [feedback[i : i + chunk_size] for i in range(0, len(feedback), chunk_size)] | |
| def _build_grading_progress( | |
| self, | |
| proof: str, | |
| feedback: str, | |
| is_correct: bool, | |
| done: bool, | |
| ) -> dict[str, Any]: | |
| feedback_chunks = self._chunk_feedback(feedback) | |
| return { | |
| "status": "completed" if done else "in_progress", | |
| "progress": 1.0, | |
| "events": [ | |
| { | |
| "stage": "submission_received", | |
| "progress": 0.2, | |
| "message": "Proof submission received.", | |
| "timestamp": self._utc_now_iso(), | |
| }, | |
| { | |
| "stage": "grading_started", | |
| "progress": 0.6, | |
| "message": "Grading started.", | |
| "timestamp": self._utc_now_iso(), | |
| }, | |
| { | |
| "stage": "grading_completed", | |
| "progress": 0.9, | |
| "message": "Grading completed.", | |
| "timestamp": self._utc_now_iso(), | |
| }, | |
| { | |
| "stage": "result_ready", | |
| "progress": 1.0, | |
| "message": "Result ready for client consumption.", | |
| "timestamp": self._utc_now_iso(), | |
| }, | |
| ], | |
| "realtime": { | |
| "websocket_supported": True, | |
| "submission_type": "proof" if proof.strip() else "empty", | |
| }, | |
| "streaming_feedback": { | |
| "chunks": feedback_chunks, | |
| "chunk_count": len(feedback_chunks), | |
| "is_final": True, | |
| }, | |
| "is_correct": is_correct, | |
| } | |
def get_problem_payload(self) -> dict:
    """Describe the active problem as a plain dict.

    Returns:
        An error payload when no problem is active. Otherwise the problem
        fields plus attempt counters; ``reference_solution`` is included only
        when the problem is evaluated in answer mode.
    """
    current = self._current_problem
    if current is None:
        return {
            "error": "No active problem. Call reset() first.",
            "done": False,
            "reward": 0.0,
        }

    problem_type = current.get("problem_type", "proof")
    mode = current.get("evaluation_mode")
    if isinstance(mode, str):
        mode = mode.strip().lower()
    elif problem_type == "answer":
        mode = "answer"
    else:
        mode = "proof"
    if mode not in {"proof", "answer"}:
        mode = "proof"

    payload = {
        "problem": current.get("problem", ""),
        "grading_guidelines": self._current_grading_guidelines_text(),
        "problem_id": current.get("problem_id", ""),
        "dataset_source": current.get("dataset_source", ""),
        "problem_type": problem_type,
        "max_attempts": self._current_max_attempts,
        "attempt_count": self._attempt_count,
        "done": False,
        "reward": 0.0,
    }
    if mode == "answer":
        payload["reference_solution"] = current.get("reference_solution", "")
    return payload
@staticmethod
def _verify_math(
    prediction: str,
    gold: str,
    strict: bool = True,
    max_prediction_length: int = 1000,
) -> str:
    r"""Compare the last boxed answer in ``prediction`` against ``gold``.

    Declared static because it takes no ``self`` although it lives in the
    class body.

    Args:
        prediction: Model output; text from the last ``\boxed{`` onward is used.
        gold: Reference answer.
        strict: Forwarded to ``math_verify.verify``.
        max_prediction_length: Longest boxed suffix that will be parsed.

    Returns:
        ``"correct"`` or ``"wrong"`` when both sides parse, ``"no_answer"``
        when the prediction has no ``\boxed{``, and ``"unparsable"`` for every
        other failure (this function never raises).
    """
    try:
        if not isinstance(prediction, str) or not isinstance(gold, str):
            raise ValueError("Prediction and gold must be strings")
        boxed_start = prediction.rfind("\\boxed{")
        if boxed_start < 0:
            raise NoAnswerException()
        boxed_prediction = prediction[boxed_start:]
        if "\\boxed{}" in boxed_prediction:
            raise EmptyBoxedException()
        if len(boxed_prediction) > max_prediction_length:
            raise UnparsableException()
        gold_parsed = _parse_math_verify_expression(gold)
        boxed_prediction_parsed = _parse_math_verify_expression(boxed_prediction)
        if not gold_parsed:
            raise ValueError("Failed to parse gold answer.")
        if not boxed_prediction_parsed:
            raise ValueError("Failed to parse prediction.")
        equivalent = math_verify.verify(
            gold_parsed,
            boxed_prediction_parsed,
            strict=strict,
            timeout_seconds=None,
        )
        return "correct" if equivalent else "wrong"
    except NoAnswerException:
        return "no_answer"
    except Exception:
        # Deliberately broad: any parser or verifier failure is reported as a
        # status string rather than propagated to the caller.
        return "unparsable"
async def _grade_answer_submission(
    self,
    submission: str,
    expected_answer: str,
    problem_id: str = "",
) -> GradingResult:
    """Grade an answer-mode submission through the verifier service.

    The score is 7 when the verifier reports ``"correct"`` and 0 otherwise;
    the reward is ``score / 7``. Verifier health and service metrics are
    sampled after the verification and attached to the result.
    """
    cfg = self._config
    service = self._verifier_service

    response = await service.verify_answer(
        prediction=submission,
        gold=expected_answer,
        strict=cfg.verifier_strict,
        # The configured timeout is truncated to whole seconds, minimum 1.
        timeout_seconds=max(1, int(cfg.verifier_request_timeout_seconds)),
        max_prediction_length=1000,
        numeric_precision=cfg.verifier_numeric_precision,
        float_rounding=cfg.verifier_float_rounding,
    )
    answer_status = response.status
    health = await service.health_probe()
    service_metrics = await service.metrics_snapshot()

    logger.info(
        "qed_math_verifier_result request_id=%s problem_id=%s evaluation_mode=answer"
        " status=%s elapsed_ms=%.3f retry_count=%d",
        response.request_id,
        problem_id,
        answer_status,
        float(response.elapsed_ms),
        int(response.retry_count),
    )

    is_correct = answer_status == "correct"
    score = 7 if is_correct else 0
    metrics = {
        "verifier/rollouts/success": int(is_correct),
        "verifier/rollouts/failure": int(not is_correct),
        "verifier/failures/timeout": int(answer_status == "timeout"),
        "verifier/failures/rate_limit": 0,
        "verifier/failures/no_input": 0,
        "verifier/failures/no_score_tag": 0,
        "verifier/failures/all_attempts_failed": int(
            answer_status in {"internal_error", "timeout"}
        ),
        "verifier/failures/num_retries": int(response.retry_count),
        "verifier/runtime/latency_per_request": float(response.elapsed_ms),
        "verifier/workers/restart_count": int(health.get("restart_count", 0)),
        "verifier/workers/worker_restarted": int(response.worker_restarted),
        "verifier/queue/depth": int(health.get("inflight_requests", 0)),
        "verifier/requests/count": int(service_metrics.get("verifier/requests/count", 0)),
        "verifier/requests/latency_ms": float(
            service_metrics.get("verifier/requests/latency_ms", 0.0)
        ),
        "verifier/requests/timeout_count": int(
            service_metrics.get("verifier/requests/timeout_count", 0)
        ),
        "verifier/requests/error_count": int(
            service_metrics.get("verifier/requests/error_count", 0)
        ),
        "verifier/workers/heartbeat_lag_ms": float(
            service_metrics.get("verifier/workers/heartbeat_lag_ms", 0.0)
        ),
        "verifier/cache/hit_rate": self._gold_cache_hit_rate(),
        "verifier/runtime/input_tokens": 0,
        "verifier/runtime/output_tokens": 0,
    }
    return GradingResult(
        score=score,
        feedback=f"answer_status={answer_status}",
        reward=score / 7.0,
        metrics=metrics,
    )
async def shutdown_verifier_service(self) -> None:
    """Stop the verifier service held by this environment.

    Awaits ``stop()`` on ``self._verifier_service``. Nothing is caught
    here, so any exception raised by ``stop()`` propagates to the caller.
    """
    verifier = self._verifier_service
    await verifier.stop()
def _strip_reasoning(self, text: str) -> str:
    """Return ``text`` as processed by ``remove_reasoning``.

    The environment's ``self._reasoning_delimiters`` are forwarded
    unchanged as the second argument.
    """
    delimiters = self._reasoning_delimiters
    return remove_reasoning(text, delimiters)
async def _grade_submission(self, submission: str) -> GradingResult:
    """Grade ``submission`` against the currently active problem.

    Returns a zero-score ``GradingResult`` when no problem is active.
    Otherwise the submission is passed through ``_strip_reasoning`` and
    graded either by the answer path (``evaluation_mode == "answer"``)
    or by ``self._rubric.grade`` (every other mode, default ``"proof"``).
    """
    active = self._current_problem
    if active is None:
        return GradingResult(
            score=0,
            feedback="No active problem. Call reset() first.",
            reward=0.0,
        )

    # Prefer the reasoning-stripped text; if stripping left only whitespace
    # while the raw submission had content, grade the raw submission instead.
    stripped = self._strip_reasoning(submission)
    if stripped.strip() or not submission.strip():
        grading_input = stripped
    else:
        grading_input = submission

    # "original_problem" wins over "problem" when it is present and truthy.
    problem = active.get("original_problem") or active.get("problem", "")
    reference_solution = active.get("reference_solution", "")
    raw_guidelines = active.get("grading_guidelines", "") or ""
    grading_guidelines = parse_schema(raw_guidelines)
    evaluation_mode = active.get("evaluation_mode", "proof")

    if evaluation_mode != "answer":
        return await self._rubric.grade(
            grading_input,
            problem,
            reference_solution,
            grading_guidelines,
        )

    # Answer mode: the reference goes through the gold-answer cache lookup
    # before being handed to the answer grader.
    problem_id = str(active.get("problem_id", "")).strip()
    gold_answer = self._get_cached_gold_answer(problem_id, reference_solution)
    return await self._grade_answer_submission(
        grading_input,
        gold_answer,
        problem_id=problem_id,
    )
| def _apply_reward_shaping( | |
| self, | |
| reward: float, | |
| output_length_tokens: int, | |
| ) -> float: | |
| if output_length_tokens <= 0: | |
| return reward | |
| reward = reward * (self._discount_factor**output_length_tokens) | |
| if self._buffer_tokens > 0 and self._max_tokens > 0: | |
| reward += length_penalty(self._max_tokens, output_length_tokens, self._buffer_tokens) | |
| return reward | |
async def submit_proof_payload(self, proof: str) -> dict:
    """Grade a submitted proof and return the observation as a plain dict.

    With no active problem, a terminal zero-score observation is returned
    and no state is touched. Otherwise the attempt counter is incremented,
    the proof is graded (blank proofs get a fixed zero-score result without
    calling the grader), the pending output-token count is consumed for
    reward shaping, and the episode is marked done unless the problem is
    ``"multi_step"``, the attempt was not correct, and attempts remain.

    Returns:
        ``ProofSubmissionObservation(...).model_dump()``.
    """
    if self._current_problem is None:
        rejected = ProofSubmissionObservation(
            proof=proof,
            score=0,
            feedback="Proof not graded because no problem is active.",
            done=True,
            reward=0.0,
            problem_type="proof",
            attempt_number=1,
            attempts_remaining=0,
            is_correct=False,
            metadata={"error": "No active problem. Call reset() first."},
        )
        return rejected.model_dump()

    # Count the attempt before grading so every later value sees it.
    self._attempt_count += 1
    problem_type = str(self._current_problem.get("problem_type", "proof"))
    is_multi_step = problem_type == "multi_step"

    if proof.strip():
        result = await self._grade_submission(proof)
    else:
        # Whitespace-only submissions are scored locally as a no-input failure.
        result = GradingResult(
            score=0,
            feedback="Empty proof submission.",
            reward=0.0,
            metrics={
                "verifier/rollouts/failure": 1,
                "verifier/failures/no_input": 1,
            },
        )

    # Consume the pending token count so it is only applied to one submission.
    output_length_tokens = self._pending_output_length_tokens
    self._pending_output_length_tokens = 0
    shaped_reward = self._apply_reward_shaping(result.reward, output_length_tokens)

    success_threshold = _coerce_positive_int(
        self._current_problem.get("success_score_threshold"),
        default=6,
    )
    is_correct = result.score >= success_threshold
    attempts_remaining = max(0, self._current_max_attempts - self._attempt_count)
    done = (not is_multi_step) or is_correct or attempts_remaining == 0

    # "not done" can only happen for multi-step problems, so that is the
    # only case where the continuation hint is appended.
    feedback = (
        result.feedback
        if done
        else (
            f"{result.feedback} Continue: "
            f"attempt {self._attempt_count}/{self._current_max_attempts}."
        )
    )

    grading_progress = self._build_grading_progress(
        proof=proof,
        feedback=feedback,
        is_correct=is_correct,
        done=done,
    )
    metadata = {
        "grading_progress": grading_progress,
        "status": grading_progress["status"],
        "base_reward": result.reward,
        "shaped_reward": shaped_reward,
        "output_length_tokens": output_length_tokens,
        "verifier_metrics": self._build_verifier_metrics(
            result, shaped_reward, output_length_tokens, is_correct
        ),
    }
    observation = ProofSubmissionObservation(
        proof=proof,
        score=result.score,
        feedback=feedback,
        done=done,
        reward=shaped_reward,
        problem_type=problem_type,
        attempt_number=self._attempt_count,
        attempts_remaining=attempts_remaining,
        is_correct=is_correct,
        metadata=metadata,
    )
    return observation.model_dump()
| def _build_verifier_metrics( | |
| self, | |
| result: GradingResult, | |
| shaped_reward: float, | |
| output_length_tokens: int, | |
| is_correct: bool, | |
| ) -> dict[str, float | int | str]: | |
| metrics: dict[str, float | int | str] = dict(result.metrics) | |
| metrics["reward/base"] = result.reward | |
| metrics["reward/shaped"] = shaped_reward | |
| metrics["reward/score_raw"] = result.score | |
| if output_length_tokens > 0 and self._buffer_tokens > 0 and self._max_tokens > 0: | |
| from .rubric import length_penalty as _lp | |
| metrics["reward/overlong_penalty"] = _lp( | |
| self._max_tokens, output_length_tokens, self._buffer_tokens | |
| ) | |
| else: | |
| metrics["reward/overlong_penalty"] = 0.0 | |
| metrics["episode/attempt_number"] = self._attempt_count | |
| metrics["episode/is_correct"] = int(is_correct) | |
| if self._current_problem is not None: | |
| metrics["episode/problem_type"] = str( | |
| self._current_problem.get("problem_type", "proof") | |
| ) | |
| metrics["episode/dataset_source"] = str( | |
| self._current_problem.get("dataset_source", "unknown") | |
| ) | |
| return metrics | |
def get_grading_guidelines_payload(self) -> dict:
    """Return the grading guidelines of the active problem as a payload dict.

    The payload always carries ``done=False`` and ``reward=0.0``. With no
    active problem it holds an ``error`` message and an empty
    ``grading_guidelines`` string; otherwise it holds the guidelines text
    and the problem's ``problem_id`` (``""`` when the key is absent).
    """
    problem = self._current_problem
    if problem is not None:
        return {
            "grading_guidelines": self._current_grading_guidelines_text(),
            "problem_id": problem.get("problem_id", ""),
            "done": False,
            "reward": 0.0,
        }
    return {
        "error": "No active problem. Call reset() first.",
        "grading_guidelines": "",
        "done": False,
        "reward": 0.0,
    }
def list_task_ids_payload(self) -> dict:
    """List an identifier for every loaded problem.

    Each problem contributes its stripped ``problem_id`` string; when that
    is empty a positional fallback ``problem_NNNNNN`` (1-based, zero-padded
    to six digits) is used instead.

    Returns:
        A dict with ``task_ids``, ``task_count``, ``done=False`` and
        ``reward=0.0``.
    """
    task_ids: list[str] = [
        str(problem.get("problem_id", "")).strip() or f"problem_{position:06d}"
        for position, problem in enumerate(self._problems, start=1)
    ]
    return {
        "task_ids": task_ids,
        "task_count": len(task_ids),
        "done": False,
        "reward": 0.0,
    }
| def _current_grading_guidelines_text(self) -> str: | |
| if self._current_problem is None: | |
| return "" | |
| return parse_schema(self._current_problem.get("grading_guidelines", "") or "") | |
def state(self) -> State:
    """Return the environment's current ``State`` object (``self._state``)."""
    current_state = self._state
    return current_state