""" AWM Environment wraps 1,000 Agent World Model sub-environments into a single OpenEnv environment. Each sub-environment is launched as a subprocess on demand and accessed via MCP tool calls. """ import asyncio import json import logging import os import tempfile from typing import Any from uuid import uuid4 from openenv.core.env_server.interfaces import Environment from openenv.core.env_server.mcp_types import CallToolAction, ListToolsAction, Tool from openenv.core.env_server.types import Action, State from ..models import AWMListToolsObservation, AWMObservation from .data_loader import AWMDataLoader, normalize_scenario_name from .db_manager import cleanup_session_dir, create_database, save_snapshot from .scenario_manager import ScenarioProcess from .session_registry import registry as _registry from .verifier import run_llm_judge, run_verifier from .config import DEFAULT_REWARD_CONFIG logger = logging.getLogger(__name__) # Tools dispatched specially by step() (see _handle_done / _handle_verify / # _handle_list_scenarios) rather than proxied to the sub-env subprocess. # Not currently used to filter list_tools output — the sub-env subprocess # never surfaces these names. Kept commented for documentation. # HIDDEN_TOOLS = frozenset(["done", "verify", "__list_scenarios__"]) VALID_VERIFIER_MODES = {"sql", "code"} # Reward types that map to format_error FORMAT_ERROR_TYPES = {"tool_not_found", "invalid_args", "invalid_action"} _TOOL_NOT_FOUND_KEYWORDS = ["not found", "unknown tool", "no tool"] _INVALID_ARGS_KEYWORDS = [ "invalid", "argument", "parameter", "required property", "validation error", "missing", "schema", ] _TIMEOUT_KEYWORDS = ["timeout", "timed out"] def _classify_tool_error(error_msg: str) -> str: """Classify a tool call error into a reward_type string.""" lower = error_msg.lower() if any(kw in lower for kw in _TOOL_NOT_FOUND_KEYWORDS): return "tool_not_found" if any(kw in lower for kw in _INVALID_ARGS_KEYWORDS): return "invalid_args" if any(kw in lower for kw in _TIMEOUT_KEYWORDS): return "timeout" return "server_error" def _run_async_oneshot(coro: Any) -> Any: """Run an async coroutine from sync context (one-shot, for LLM judge etc.).""" loop = asyncio.new_event_loop() try: return loop.run_until_complete(coro) finally: loop.close() class AWMEnvironment(Environment): """ Lifecycle: 1. reset(scenario="...", task_idx=...) -> starts a sub-env subprocess 2. step(ListToolsAction()) -> lists tools from the sub-env 3. step(CallToolAction(...)) -> proxies tool call to the sub-env 4. step(CallToolAction(tool_name="verify", arguments={verifier_mode: "sql"|"code"})) -> runs verifier 5. step(CallToolAction(tool_name="done")) -> ends episode, destroys environment 6. close() -> kills subprocess, cleans up """ SUPPORTS_CONCURRENT_SESSIONS = True def __init__(self, data_loader: AWMDataLoader | None = None): super().__init__() self._data_loader = data_loader or AWMDataLoader() self._process = ScenarioProcess() self._state = State(episode_id=None, step_count=0) self._scenario: str | None = None self._task: str | None = None self._task_idx: int | None = None self._has_verifier: dict | None = None # {sql: bool, code: bool} self._reset_ok: bool = False self._episode_done: bool = False self._session_dir: str | None = None self._db_path: str | None = None self._initial_db_path: str | None = None # LLM config for sql verifier mode self._llm_base_url: str | None = None self._llm_api_key: str | None = None self._llm_model: str | None = None self._tools_cache: list[dict] | None = None self._trajectory: list[dict] = [] self._keep_session: bool = False # Reward config (customizable at reset) self._reward_config: dict = DEFAULT_REWARD_CONFIG.copy() # Session registry tracking self._registry_id: str | None = None def reset( self, seed: int | None = None, episode_id: str | None = None, scenario: str | None = None, task_idx: int | None = None, task: str | None = None, reward_config: dict | None = None, llm_base_url: str | None = None, llm_api_key: str | None = None, llm_model: str | None = None, **kwargs: Any, ) -> AWMObservation: self._reset_ok = False self._episode_done = False if not scenario: return AWMObservation( done=False, reward=None, reward_type="reset_error", error="Parameter 'scenario' is required", ) scenario_key = normalize_scenario_name(scenario) if not self._data_loader.scenario_exists(scenario_key): return AWMObservation( done=False, reward=None, reward_type="reset_error", error=f"Scenario '{scenario}' not found", ) self._cleanup_session() self._scenario = scenario_key self._task_idx = task_idx self._tools_cache = None self._trajectory = [] # Set custom reward config or use default self._reward_config = ( reward_config.copy() if reward_config else DEFAULT_REWARD_CONFIG.copy() ) self._llm_base_url = llm_base_url or os.environ.get("OPENENV_AWM_LLM_BASE_URL") self._llm_api_key = llm_api_key or os.environ.get("OPENENV_AWM_LLM_API_KEY") self._llm_model = llm_model or os.environ.get("OPENENV_AWM_LLM_MODEL") if task is not None: self._task = task elif task_idx is not None: tasks = self._data_loader.get_tasks(scenario_key) if 0 <= task_idx < len(tasks): self._task = tasks[task_idx] else: return AWMObservation( done=False, reward=None, reward_type="reset_error", error=f"task_idx {task_idx} out of range (0..{len(tasks) - 1})", ) else: self._task = None # Check verifier support for both modes self._has_verifier = None if task_idx is not None: sql_verifier = self._data_loader.get_verifier(scenario_key, task_idx, "sql") code_verifier = self._data_loader.get_verifier( scenario_key, task_idx, "code" ) sql_available = False code_available = False if sql_verifier: sql_code = sql_verifier.get("verification", {}).get("code", "") sql_available = bool( sql_code and isinstance(sql_code, str) and len(sql_code.strip()) > 0 ) if code_verifier: code_code = code_verifier.get("verification", {}).get("code", "") code_available = bool( code_code and isinstance(code_code, str) and len(code_code.strip()) > 0 ) if sql_available or code_available: self._has_verifier = {"sql": sql_available, "code": code_available} self._session_dir = tempfile.mkdtemp(prefix=f"openenv_awm_{scenario_key}_") self._db_path = f"{self._session_dir}/{scenario_key}.db" self._initial_db_path = f"{self._session_dir}/{scenario_key}_initial.db" logger.info( f"[reset] scenario={scenario_key} task_idx={task_idx} " f"session_dir={self._session_dir} " f"db={self._db_path} initial_db={self._initial_db_path}" ) try: db_schema = self._data_loader.get_db_schema(scenario_key) sample_data = self._data_loader.get_sample_data(scenario_key) create_database(self._db_path, db_schema, sample_data) save_snapshot(self._db_path, self._initial_db_path) except Exception as e: logger.error(f"Failed to create database for {scenario_key}: {e}") return AWMObservation( done=False, reward=None, reward_type="reset_error", error=f"Database creation failed: {e}", ) try: full_code = self._data_loader.get_env_code(scenario_key) self._process.start(full_code, self._db_path, self._session_dir) except Exception as e: logger.error(f"Failed to start sub-env for {scenario_key}: {e}") return AWMObservation( done=False, reward=None, reward_type="reset_error", error=f"Sub-environment start failed: {e}", ) self._state = State( episode_id=episode_id or str(uuid4()), step_count=0, ) # Register with session registry for idle tracking self._registry_id = self._state.episode_id _registry.register(self._registry_id, self, scenario=self._scenario) tools: list[dict] = [] tool_error: str | None = None try: tools = self._process.list_tools() self._tools_cache = tools except Exception as e: tool_error = str(e) logger.warning(f"Failed to list tools on startup: {e}") if len(tools) == 0: self._reset_ok = True return AWMObservation( done=False, reward=None, reward_type="reset_warning", scenario=scenario_key, task=self._task, task_idx=self._task_idx, has_verifier=self._has_verifier, num_tools=0, warning=f"Sub-env started but no tools discovered. {tool_error or ''}".strip(), ) self._reset_ok = True return AWMObservation( done=False, reward=None, reward_type="reset_ok", scenario=scenario_key, task=self._task, task_idx=self._task_idx, has_verifier=self._has_verifier, num_tools=len(tools), ) def step( self, action: Action, timeout_s: float | None = None, **kwargs: Any, ) -> AWMObservation | AWMListToolsObservation: if self._episode_done: return AWMObservation( done=True, reward=None, reward_type="episode_already_done", error="Episode has ended. Call reset() to start a new episode.", ) self._state.step_count += 1 # Update idle tracker if self._registry_id: _registry.touch(self._registry_id) if isinstance(action, ListToolsAction): return self._handle_list_tools() elif isinstance(action, CallToolAction): if action.tool_name == "done": return self._handle_done(action) elif action.tool_name == "verify": return self._handle_verify(action) elif action.tool_name == "__list_scenarios__": return self._handle_list_scenarios() else: return self._handle_call_tool(action, timeout_s) else: return AWMObservation( done=False, reward=self._get_reward("invalid_action"), reward_type="invalid_action", error=f"Unknown action type: {type(action).__name__}. " "Use ListToolsAction or CallToolAction.", ) def _handle_list_tools(self) -> AWMListToolsObservation: """Return tools from the sub-environment (cached).""" if not self._process.is_running: obs = AWMListToolsObservation( tools=[], error="Sub-environment is not running. Call reset() first.", ) self._trajectory.append( { "action": "list_tools", "success": False, "error": obs.error, } ) return obs if self._tools_cache is not None: tools = [ Tool( name=t["name"], description=t.get("description", ""), input_schema=t.get("inputSchema", {}), ) for t in self._tools_cache ] tool_names = [t["name"] for t in self._tools_cache] self._trajectory.append( { "action": "list_tools", "success": True, "num_tools": len(tools), "tool_names": tool_names, } ) return AWMListToolsObservation(tools=tools) try: raw_tools = self._process.list_tools() self._tools_cache = raw_tools tools = [ Tool( name=t["name"], description=t.get("description", ""), input_schema=t.get("inputSchema", {}), ) for t in raw_tools ] tool_names = [t["name"] for t in raw_tools] self._trajectory.append( { "action": "list_tools", "success": True, "num_tools": len(tools), "tool_names": tool_names, } ) return AWMListToolsObservation(tools=tools) except Exception as e: self._trajectory.append( { "action": "list_tools", "success": False, "error": str(e), } ) return AWMListToolsObservation( tools=[], error=f"Failed to list tools: {e}", ) def _handle_call_tool( self, action: CallToolAction, timeout_s: float | None = None ) -> AWMObservation: """Proxy a tool call to the sub-environment subprocess.""" if not self._process.is_running: return AWMObservation( done=False, reward=self._get_reward("server_error"), reward_type="server_error", tool_name=action.tool_name, error="Sub-environment is not running. Call reset() first.", ) timeout = timeout_s if timeout_s is not None else 30.0 try: result = self._process.call_tool( action.tool_name, action.arguments, timeout, ) except Exception as e: return AWMObservation( done=False, reward=self._get_reward("server_error"), reward_type="server_error", tool_name=action.tool_name, error=str(e), ) self._trajectory.append( { "action": "call_tool", "tool_name": action.tool_name, "arguments": action.arguments, "success": result["success"], "result": result.get("result"), "error": result.get("error"), } ) if result["success"]: return AWMObservation( done=False, reward=self._get_reward("tool_call_ok"), reward_type="tool_call_ok", tool_name=action.tool_name, tool_result=result["result"], ) error_msg = result.get("error", "Unknown error") error_type = _classify_tool_error(error_msg) return AWMObservation( done=False, reward=self._get_reward(error_type), reward_type=error_type, tool_name=action.tool_name, error=error_msg, ) def _get_reward(self, reward_type: str) -> float: """Get reward value for a reward type using the configured reward config.""" # Map format error types to format_error if reward_type in FORMAT_ERROR_TYPES: return self._reward_config.get("format_error", -1.0) # Return configured reward or 0.0 for unknown types return self._reward_config.get(reward_type, 0.0) def _handle_verify(self, action: CallToolAction) -> AWMObservation: """Handle the `verify` tool — run verifier with specified mode.""" if not self._reset_ok or self._scenario is None: return AWMObservation( done=False, reward=self._get_reward("server_error"), reward_type="server_error", error="Cannot verify: environment not initialized " "(reset failed or not called)", ) if self._task is None or self._task_idx is None: return AWMObservation( done=False, reward=self._get_reward("no_verifier"), reward_type="no_verifier", error="Cannot verify: no task specified at reset", ) # Get verifier_mode from arguments args = action.arguments or {} verifier_mode = args.get("verifier_mode", "code") final_answer = args.get("final_answer") if verifier_mode not in VALID_VERIFIER_MODES: return AWMObservation( done=False, reward=self._get_reward("invalid_args"), reward_type="invalid_args", error=f"Invalid verifier_mode '{verifier_mode}'. " f"Must be one of: {', '.join(sorted(VALID_VERIFIER_MODES))}", ) # Check if verifier is available for the requested mode if self._has_verifier is None or not self._has_verifier.get( verifier_mode, False ): return AWMObservation( done=False, reward=self._get_reward("no_verifier"), reward_type="no_verifier", scenario=self._scenario, task=self._task, task_idx=self._task_idx, error=f"No {verifier_mode} verifier available for this task", ) verifier_entry = self._data_loader.get_verifier( self._scenario, self._task_idx, verifier_mode ) if verifier_entry is None: return AWMObservation( done=False, reward=self._get_reward("no_verifier"), reward_type="no_verifier", scenario=self._scenario, task=self._task, task_idx=self._task_idx, ) reward_type, verify_result = run_verifier( verifier_entry=verifier_entry, verifier_mode=verifier_mode, initial_db_path=self._initial_db_path, final_db_path=self._db_path, final_answer=final_answer, ) # For SQL mode, run LLM judge if verifier_mode == "sql" and reward_type != "judge_error": raw_response_str = verifier_entry.get("verification", {}).get( "raw_response", "{}" ) try: raw_response = json.loads(raw_response_str) except (json.JSONDecodeError, TypeError): raw_response = {} try: reward_type, judge_result = _run_async_oneshot( run_llm_judge( task=self._task, verifier_result=verify_result, llm_base_url=self._llm_base_url, llm_api_key=self._llm_api_key, llm_model=self._llm_model, trajectory=self._trajectory, verifier_reasoning=raw_response.get("reasoning", ""), success_criteria=raw_response.get("success_criteria", ""), failure_criteria=raw_response.get("failure_criteria", ""), ) ) verify_result["llm_judge"] = judge_result except Exception as e: logger.error(f"LLM judge failed: {e}") reward_type = "judge_error" verify_result["llm_judge_error"] = str(e) self._trajectory.append( { "action": "verify", "arguments": args, "success": True, "reward_type": reward_type, "reward": self._get_reward(reward_type), "verify_result": verify_result, } ) return AWMObservation( done=False, reward=self._get_reward(reward_type), reward_type=reward_type, verify_result=verify_result, scenario=self._scenario, task=self._task, task_idx=self._task_idx, steps_taken=self._state.step_count, ) def _handle_done(self, action: CallToolAction) -> AWMObservation: """Handle the `done` tool — end episode and destroy environment (no verification). Accepts optional arguments: keep_session (bool): If True, keep the session tmp folder for debugging. Default False (folder is deleted on cleanup). """ if not self._reset_ok or self._scenario is None: self._episode_done = True return AWMObservation( done=True, reward=self._get_reward("server_error"), reward_type="server_error", error="Cannot call done: environment not initialized " "(reset failed or not called)", ) args = action.arguments or {} keep_session = bool(args.get("keep_session", False)) # Save trajectory to JSON before stopping. # Capture locals for race safety — cleanup thread may null # self._session_dir between the check and the write. session_dir = self._session_dir trajectory = self._trajectory trajectory_path = None if session_dir and trajectory and os.path.isdir(session_dir): trajectory_path = f"{session_dir}/trajectory.json" try: with open(trajectory_path, "w", encoding="utf-8") as f: json.dump( { "scenario": self._scenario, "task": self._task, "task_idx": self._task_idx, "steps": self._state.step_count, "trajectory": trajectory, }, f, indent=2, ensure_ascii=False, ) logger.info(f"[AWM done] trajectory saved: {trajectory_path}") except OSError: # Session dir may have been cleaned up concurrently trajectory_path = None except Exception as e: logger.warning(f"Failed to save trajectory: {e}") trajectory_path = None self._episode_done = True self._process.stop() if keep_session and self._session_dir: logger.info(f"[AWM done] keeping session dir: {self._session_dir}") self._keep_session = True else: self._keep_session = False return AWMObservation( done=True, reward=0.0, # done itself doesn't give reward reward_type="episode_done", scenario=self._scenario, task=self._task, task_idx=self._task_idx, steps_taken=self._state.step_count, trajectory_path=trajectory_path, session_dir=self._session_dir if keep_session else None, ) def _handle_list_scenarios(self) -> AWMObservation: """Handle the `__list_scenarios__` tool — return all scenario info.""" try: all_scenarios = self._data_loader.list_scenarios() return AWMObservation( done=False, reward=None, reward_type="tool_call_ok", scenarios=all_scenarios, total=len(all_scenarios), ) except Exception as e: return AWMObservation( done=False, reward=None, reward_type="server_error", error=f"Failed to list scenarios: {e}", ) @property def state(self) -> State: return self._state def close(self) -> None: self._cleanup_session() def _cleanup_session(self) -> None: """Stop subprocess and clean up session temp files. If ``_keep_session`` is True (set by ``done(keep_session=True)``), the session directory is preserved for manual inspection. """ if self._registry_id: _registry.unregister(self._registry_id) self._registry_id = None self._process.stop() if self._session_dir: if getattr(self, "_keep_session", False): logger.info(f"Keeping session dir: {self._session_dir}") else: cleanup_session_dir(self._session_dir) self._session_dir = None self._db_path = None self._initial_db_path = None self._tools_cache = None self._trajectory = [] self._reset_ok = False self._episode_done = False self._has_verifier = None self._keep_session = False