"""Inference script for QED Math with strict stdout compliance. This script emits only [START], [STEP], and [END] lines to stdout. """ from __future__ import annotations import asyncio import json import os import re import sys import traceback from pathlib import Path from typing import Any, Optional, cast from openai import OpenAI from client import QEDMathEnv def _load_local_dotenv() -> None: """Load .env values if present without overriding exported env vars.""" candidates = [Path.cwd() / ".env", Path(__file__).resolve().parent / ".env"] seen: set[Path] = set() for env_path in candidates: resolved = env_path.resolve() if resolved in seen or not resolved.is_file(): continue seen.add(resolved) for raw_line in resolved.read_text(encoding="utf-8").splitlines(): line = raw_line.strip() if not line or line.startswith("#"): continue if line.startswith("export "): line = line[len("export ") :].strip() key, sep, value = line.partition("=") if not sep: continue key = key.strip() value = value.strip() if len(value) >= 2 and value[0] == value[-1] and value[0] in {'"', "'"}: value = value[1:-1] if key: os.environ.setdefault(key, value) _load_local_dotenv() API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b:novita") HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("API_KEY") _raw_qed_math_url = os.getenv("QED_MATH_URL") QED_MATH_URL = ( _raw_qed_math_url.strip() if _raw_qed_math_url is not None else "https://rycerzes-qed-math-openenv.hf.space" ) TASK_NAME = os.getenv("TASK_NAME", "solve-qed-math") BENCHMARK = os.getenv("BENCHMARK", "qed-math") TASK_COUNT = max(3, int(os.getenv("TASK_COUNT", "3"))) MIN_SUBMISSION_SCORE = float(os.getenv("MIN_SUBMISSION_SCORE", "0.01")) MAX_SUBMISSION_SCORE = float(os.getenv("MAX_SUBMISSION_SCORE", "0.99")) MAX_STEPS = int(os.getenv("MAX_STEPS", "8")) TEMPERATURE = float(os.getenv("TEMPERATURE", "0.2")) MAX_TOKENS = int(os.getenv("MAX_TOKENS", "4096")) SYSTEM_PROMPT = ( "You are an expert mathematician. Always call get_problem first, then reason " "carefully and call submit_proof with a complete solution. Use " "get_grading_guidelines if helpful." ) def _single_line(value: Any) -> str: """Normalize text values so each log record stays on one line.""" return re.sub(r"\s+", " ", str(value)).strip() def _strict_open_interval_score(score: float) -> float: """Clamp score for submission output so it remains strictly in (0, 1).""" lo = max(0.01, min(MIN_SUBMISSION_SCORE, MAX_SUBMISSION_SCORE)) hi = min(0.99, max(MIN_SUBMISSION_SCORE, MAX_SUBMISSION_SCORE)) if hi <= lo: lo, hi = 0.01, 0.99 return min(max(float(score), lo), hi) def log_start(task: str, env: str, model: str) -> None: print( f"[START] task={_single_line(task)} env={_single_line(env)} model={_single_line(model)}", flush=True, ) def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None: error_val = str(error) if error else "null" print( f"[STEP] step={step} action={_single_line(action)} reward={reward:.2f} " f"done={str(done).lower()} error={error_val}", flush=True, ) def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None: rewards_str = ",".join(f"{reward:.2f}" for reward in rewards) score = _strict_open_interval_score(score) end_line = ( f"[END] success={str(success).lower()} steps={steps} " f"score={score:.2f} rewards={rewards_str}" ) print( end_line, flush=True, ) def _coerce_float(value: object) -> float | None: """Best-effort float coercion that never raises.""" if value is None: return None if isinstance(value, bool): return float(value) if isinstance(value, (int, float)): return float(value) if isinstance(value, str): text = value.strip() if not text: return None try: return float(text) except ValueError: return None return None def _normalize_episode_score(result_dict: dict[str, Any], reward: float) -> float: """Return a score in [0, 1] from tool payload fields.""" score = _coerce_float(result_dict.get("reward")) if score is None: parsed_raw = _coerce_float(result_dict.get("score")) if parsed_raw is None: score = reward else: score = parsed_raw / 7.0 if parsed_raw > 1.0 else parsed_raw return min(max(score, 0.0), 1.0) def _extract_task_ids(payload: Any) -> list[str]: result = _as_mapping(payload) task_ids = result.get("task_ids") if not isinstance(task_ids, list): return [] normalized: list[str] = [] for value in task_ids: task_id = str(value).strip() if task_id: normalized.append(task_id) return normalized def _select_task_ids(task_ids: list[str], task_count: int) -> list[str]: if task_count <= 0: return [] if not task_ids: return [TASK_NAME for _ in range(task_count)] if len(task_ids) >= task_count: return task_ids[:task_count] selected: list[str] = [] for idx in range(task_count): selected.append(task_ids[idx % len(task_ids)]) return selected def _tools_to_openai_format(tools: list[Any]) -> list[dict[str, Any]]: openai_tools: list[dict[str, Any]] = [] for tool in tools: properties: dict[str, Any] = {} required: list[str] = [] input_schema = ( getattr(tool, "input_schema", None) or getattr(tool, "inputSchema", None) or {} ) if input_schema and "properties" in input_schema: for name, schema in input_schema["properties"].items(): properties[name] = { "type": schema.get("type", "string"), "description": schema.get("description", ""), } required = input_schema.get("required", []) openai_tools.append( { "type": "function", "function": { "name": tool.name, "description": tool.description or "", "parameters": { "type": "object", "properties": properties, "required": required, }, }, } ) return openai_tools def _extract_tool_call(response: Any) -> tuple[str, dict[str, Any], str]: message = response.choices[0].message if message.tool_calls: tool_call_obj = cast(Any, message.tool_calls[0]) function_payload = getattr(tool_call_obj, "function", None) tool_call_id = str(getattr(tool_call_obj, "id", "fallback")) if function_payload is not None: tool_name = str(getattr(function_payload, "name", "submit_proof")) raw_arguments = str(getattr(function_payload, "arguments", "{}")) try: tool_args = json.loads(raw_arguments) except json.JSONDecodeError: tool_args = {"proof": raw_arguments} else: tool_name = "submit_proof" raw_input = getattr(tool_call_obj, "input", "") tool_args = {"proof": str(raw_input)} else: tool_name = "submit_proof" tool_args = {"proof": message.content or ""} tool_call_id = "fallback" return tool_name, tool_args, tool_call_id def _as_mapping(value: Any) -> dict[str, Any]: if hasattr(value, "model_dump"): return value.model_dump() if isinstance(value, dict): return value return {"result": str(value)} async def run_episode( env: QEDMathEnv, client: OpenAI, tools: list[dict[str, Any]], problem_id: str | None = None, ) -> tuple[bool, int, float, list[float]]: tool_names = {tool["function"]["name"] for tool in tools} if problem_id is not None: await env.reset(problem_id=problem_id) else: await env.reset() chat_history: list[dict[str, Any]] = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": "Solve the current QED math problem."}, ] rewards: list[float] = [] steps_taken = 0 score = 0.0 success = False grader_called = False for step in range(1, MAX_STEPS + 1): response = client.chat.completions.create( model=MODEL_NAME, messages=cast(Any, chat_history), tools=cast(Any, tools), tool_choice="required", temperature=TEMPERATURE, max_tokens=MAX_TOKENS, stream=False, ) tool_name, tool_args, tool_call_id = _extract_tool_call(response) if tool_name not in tool_names: tool_name = "submit_proof" tool_args = {"proof": tool_args.get("proof", str(tool_args))} call_kwargs = dict(tool_args) step_result = await env.call_tool(tool_name, **call_kwargs) if tool_name == "submit_proof": grader_called = True result_dict = _as_mapping(step_result) reward = float(result_dict.get("reward") or 0.0) done = bool(result_dict.get("done", False)) error_raw = result_dict.get("last_action_error") error = str(error_raw) if error_raw is not None else None action_str = json.dumps({"tool": tool_name, "args": tool_args}, ensure_ascii=True) log_step(step=step, action=action_str, reward=reward, done=done, error=error) rewards.append(reward) steps_taken = step if done: score = _normalize_episode_score(result_dict, reward) success = bool(result_dict.get("is_correct", score > 0.0)) break chat_history.append( { "role": "assistant", "content": None, "tool_calls": [ { "id": tool_call_id, "type": "function", "function": { "name": tool_name, "arguments": json.dumps(tool_args), }, } ], } ) chat_history.append( { "role": "tool", "tool_call_id": tool_call_id, "content": json.dumps(result_dict), } ) # Ensure each task reaches the grader at least once for submission validation. if not grader_called: forced_proof = "I attempted a complete proof but could not finish within the step budget." step_result = await env.call_tool("submit_proof", proof=forced_proof) result_dict = _as_mapping(step_result) reward = float(result_dict.get("reward") or 0.0) done = bool(result_dict.get("done", True)) error_raw = result_dict.get("last_action_error") error = str(error_raw) if error_raw is not None else None forced_step = steps_taken + 1 action_str = json.dumps( { "tool": "submit_proof", "args": {"proof": forced_proof}, "forced": True, }, ensure_ascii=True, ) log_step(step=forced_step, action=action_str, reward=reward, done=done, error=error) rewards.append(reward) steps_taken = forced_step score = _normalize_episode_score(result_dict, reward) success = bool(result_dict.get("is_correct", score > 0.0)) if not rewards: score = 0.0 elif score == 0.0: score = min(max(float(rewards[-1]), 0.0), 1.0) return success, steps_taken, score, rewards async def async_main() -> None: if not HF_TOKEN: raise SystemExit("HF_TOKEN must be set.\nOptional fallback: API_KEY.") if not QED_MATH_URL: raise SystemExit("QED_MATH_URL must be set (for example: https://.hf.space/).") client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN) caught_error: Exception | None = None try: async with QEDMathEnv(base_url=QED_MATH_URL) as raw_env: env = cast(QEDMathEnv, raw_env) mcp_tools = await env.list_tools() tools = _tools_to_openai_format(mcp_tools) try: task_payload = await env.call_tool("list_task_ids") except Exception: task_payload = {"task_ids": []} available_task_ids = _extract_task_ids(task_payload) selected_task_ids = _select_task_ids(available_task_ids, TASK_COUNT) for task_index, problem_id in enumerate(selected_task_ids, start=1): task_name = f"{TASK_NAME}:{problem_id}:run{task_index}" log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME) success, steps_taken, score, rewards = await run_episode( env=env, client=client, tools=tools, problem_id=problem_id if problem_id != TASK_NAME else None, ) log_end(success=success, steps=steps_taken, score=score, rewards=rewards) except Exception as exc: caught_error = exc print( f"[ERROR] type={type(exc).__name__} message={exc}", file=sys.stderr, flush=True, ) print( f"[ERROR] QED_MATH_URL={QED_MATH_URL}", file=sys.stderr, flush=True, ) traceback.print_exc(file=sys.stderr) if caught_error is not None: raise SystemExit(1) from caught_error def main() -> None: asyncio.run(async_main()) if __name__ == "__main__": main()