""" Process-based math verification service for high-throughput answer grading. """ import asyncio import logging import multiprocessing as mp import re import time from concurrent.futures import ProcessPoolExecutor from dataclasses import dataclass from functools import partial from typing import Optional import math_verify logger = logging.getLogger(__name__) @dataclass class VerifyRequest: request_id: str prediction: str gold: str strict: bool = True timeout_seconds: int = 1 max_prediction_length: int = 1000 numeric_precision: int = 5 float_rounding: int = 10 @dataclass class VerifyResponse: request_id: str status: str elapsed_ms: float retry_count: int = 0 worker_id: Optional[int] = None worker_restarted: bool = False error_type: Optional[str] = None error_message: Optional[str] = None def _parse_math_verify_expression(value: str): # Work around Windows multiprocessing issues in math-verify timeout wrappers. parsed = math_verify.parse(value, parsing_timeout=0) if parsed: return parsed boxed_match = re.search(r"\\boxed\{(.+?)\}", value) if boxed_match: return math_verify.parse(boxed_match.group(1), parsing_timeout=0) return parsed class VerificationError(Exception): pass class UnparsableException(VerificationError): pass class NoAnswerException(VerificationError): pass class EmptyBoxedException(VerificationError): pass def _extract_boxed_answer(text: str) -> str: start_idx = text.rfind(r"\boxed{") if start_idx == -1: raise NoAnswerException() start_idx += len(r"\boxed{") depth = 1 end_idx = start_idx while end_idx < len(text) and depth > 0: if text[end_idx] == "{": depth += 1 elif text[end_idx] == "}": depth -= 1 end_idx += 1 if depth != 0: raise UnparsableException() answer = text[start_idx : end_idx - 1] if not answer.strip(): raise EmptyBoxedException() return answer def _verify_answer_worker(request: VerifyRequest) -> VerifyResponse: import time start_time = time.time() try: boxed_prediction = _extract_boxed_answer(request.prediction) if len(boxed_prediction) > request.max_prediction_length: status = "unparsable" else: gold_parsed = _parse_math_verify_expression(request.gold) boxed_prediction_parsed = _parse_math_verify_expression(boxed_prediction) if not gold_parsed or not boxed_prediction_parsed: status = "unparsable" else: try: equivalent = math_verify.verify( gold_parsed, boxed_prediction_parsed, float_rounding=request.float_rounding, numeric_precision=request.numeric_precision, strict=request.strict, # We enforce timeout at the service layer via asyncio.wait_for. timeout_seconds=0, ) status = "correct" if equivalent else "wrong" except Exception as exc: if "timeout" in str(exc).lower(): status = "timeout" else: status = "internal_error" return VerifyResponse( request_id=request.request_id, status=status, elapsed_ms=(time.time() - start_time) * 1000, error_type=type(exc).__name__, error_message=str(exc), ) except NoAnswerException: status = "no_answer" except (UnparsableException, EmptyBoxedException): status = "unparsable" except Exception as exc: status = "internal_error" return VerifyResponse( request_id=request.request_id, status=status, elapsed_ms=(time.time() - start_time) * 1000, error_type=type(exc).__name__, error_message=str(exc), ) elapsed_ms = (time.time() - start_time) * 1000 return VerifyResponse( request_id=request.request_id, status=status, elapsed_ms=elapsed_ms, ) class MathVerifierService: def __init__( self, max_workers: int = 2, queue_size: int = 100, request_timeout_seconds: float = 5.0, max_retries: int = 1, strict: bool = True, numeric_precision: int = 5, float_rounding: int = 10, ): self.max_workers = max(1, max_workers) self.queue_size = max(1, queue_size) self.request_timeout_seconds = max(0.001, float(request_timeout_seconds)) self.max_retries = max(0, max_retries) self.strict = strict self.numeric_precision = numeric_precision self.float_rounding = float_rounding self._executor: ProcessPoolExecutor | None = None self._request_counter = 0 self._admission_lock = asyncio.Lock() self._inflight_requests = 0 self._restart_lock = asyncio.Lock() self._restart_count = 0 self._metrics_lock = asyncio.Lock() self._requests_total = 0 self._timeouts_total = 0 self._errors_total = 0 self._latency_total_ms = 0.0 self._last_activity_monotonic = time.perf_counter() async def start(self) -> None: if self._executor is not None: logger.debug("MathVerifierService already started") return self._executor = ProcessPoolExecutor( max_workers=self.max_workers, mp_context=mp.get_context("spawn"), ) logger.info("MathVerifierService started with %d workers", self.max_workers) async def stop(self) -> None: if self._executor is None: logger.debug("MathVerifierService not running") return loop = asyncio.get_running_loop() shutdown_call = partial(self._executor.shutdown, wait=True, cancel_futures=True) await loop.run_in_executor(None, shutdown_call) self._executor = None logger.info("MathVerifierService stopped") async def _restart_pool(self) -> None: async with self._restart_lock: await self.stop() await self.start() self._restart_count += 1 @staticmethod def _requires_restart(response: VerifyResponse) -> bool: if response.status != "internal_error": return False fatal_error_types = { "BrokenProcessPool", "EOFError", "ConnectionResetError", "ConnectionAbortedError", } if response.error_type in fatal_error_types: return True msg = (response.error_message or "").lower() fatal_fragments = [ "broken process pool", "process terminated", "worker process", "connection reset", "connection aborted", "child process", ] return any(fragment in msg for fragment in fatal_fragments) async def health_probe(self) -> dict[str, str | bool | int | float]: executor_running = self._executor is not None status = "healthy" if executor_running else "stopped" heartbeat_lag_ms = max( 0.0, (time.perf_counter() - self._last_activity_monotonic) * 1000, ) return { "status": status, "executor_running": executor_running, "inflight_requests": self._inflight_requests, "queue_size": self.queue_size, "max_workers": self.max_workers, "restart_count": self._restart_count, "heartbeat_lag_ms": heartbeat_lag_ms, } async def metrics_snapshot(self) -> dict[str, float | int]: async with self._metrics_lock: requests_count = self._requests_total latency_avg_ms = self._latency_total_ms / requests_count if requests_count > 0 else 0.0 heartbeat_lag_ms = max( 0.0, (time.perf_counter() - self._last_activity_monotonic) * 1000, ) return { "verifier/requests/count": requests_count, "verifier/requests/latency_ms": latency_avg_ms, "verifier/requests/timeout_count": self._timeouts_total, "verifier/requests/error_count": self._errors_total, "verifier/workers/restart_count": self._restart_count, "verifier/queue/depth": self._inflight_requests, "verifier/workers/heartbeat_lag_ms": heartbeat_lag_ms, } async def _record_result_metrics(self, response: VerifyResponse) -> None: async with self._metrics_lock: self._requests_total += 1 self._latency_total_ms += max(0.0, float(response.elapsed_ms)) if response.status == "timeout": self._timeouts_total += 1 if response.status == "internal_error": self._errors_total += 1 self._last_activity_monotonic = time.perf_counter() async def _try_admit_request(self) -> bool: async with self._admission_lock: if self._inflight_requests >= self.queue_size: return False self._inflight_requests += 1 return True async def _release_request_slot(self) -> None: async with self._admission_lock: self._inflight_requests = max(0, self._inflight_requests - 1) @staticmethod def _is_retryable_response(response: VerifyResponse) -> bool: if response.status == "timeout": return True if response.status != "internal_error": return False retryable_error_types = { "ClientTimeout", "BrokenProcessPool", "CancelledError", "TimeoutError", "EOFError", "ConnectionResetError", "ConnectionAbortedError", } if response.error_type in retryable_error_types: return True error_message = (response.error_message or "").lower() retryable_fragments = [ "broken process pool", "executor", "cancelled", "worker", "connection reset", "connection aborted", ] return any(fragment in error_message for fragment in retryable_fragments) async def _run_request_once( self, request: VerifyRequest, ) -> VerifyResponse: loop = asyncio.get_running_loop() request_id = request.request_id try: future = loop.run_in_executor( self._executor, _verify_answer_worker, request, ) return await asyncio.wait_for( future, timeout=self.request_timeout_seconds, ) except asyncio.TimeoutError: logger.warning( "Verification request %s timed out after %.3fs", request_id, self.request_timeout_seconds, ) return VerifyResponse( request_id=request_id, status="timeout", elapsed_ms=self.request_timeout_seconds * 1000, error_type="ClientTimeout", error_message=f"Request timed out after {self.request_timeout_seconds}s", ) except Exception as exc: logger.exception("Verification request %s failed with exception", request_id) return VerifyResponse( request_id=request_id, status="internal_error", elapsed_ms=0.0, error_type=type(exc).__name__, error_message=str(exc), ) async def verify_answer( self, prediction: str, gold: str, strict: Optional[bool] = None, timeout_seconds: Optional[int] = None, max_prediction_length: int = 1000, numeric_precision: Optional[int] = None, float_rounding: Optional[int] = None, ) -> VerifyResponse: if self._executor is None: await self.start() admitted = await self._try_admit_request() if not admitted: return VerifyResponse( request_id=f"rejected-{self._request_counter + 1}", status="internal_error", elapsed_ms=0.0, error_type="QueueFull", error_message=( f"Verifier queue saturated: in_flight={self._inflight_requests}, " f"queue_size={self.queue_size}" ), ) self._request_counter += 1 request_id = f"req-{self._request_counter}" started_at = time.perf_counter() request = VerifyRequest( request_id=request_id, prediction=prediction, gold=gold, strict=strict if strict is not None else self.strict, timeout_seconds=max(1, int(timeout_seconds or 1)), max_prediction_length=max_prediction_length, numeric_precision=( numeric_precision if numeric_precision is not None else self.numeric_precision ), float_rounding=(float_rounding if float_rounding is not None else self.float_rounding), ) retry_count = 0 worker_restarted = False try: while True: response = await self._run_request_once(request) if retry_count < self.max_retries and self._is_retryable_response(response): if self._requires_restart(response): await self._restart_pool() worker_restarted = True retry_count += 1 continue response.retry_count = retry_count response.worker_restarted = worker_restarted response.elapsed_ms = (time.perf_counter() - started_at) * 1000 await self._record_result_metrics(response) return response finally: await self._release_request_slot()