ChilleD's picture
Upload folder using huggingface_hub
d57737f verified
"""
Agent World Model Environment Client.
Provides a client for connecting to an AWM Environment server.
AWMEnv extends MCPToolClient with AWM-specific helpers.
Example:
>>> with AWMEnv(base_url="http://localhost:8000") as env:
... # List all available scenarios
... result = env.call_tool("__list_scenarios__")
...
... # Start a scenario
... env.reset(scenario="e_commerce_33", task_idx=0)
...
... # Discover tools
... tools = env.list_tools()
... print([t.name for t in tools])
...
... # Call tools
... result = env.call_tool("search_products", query="headphones")
...
... # End episode and get verification result
... result = env.call_tool("done")
"""
from typing import Any
from openenv.core.client_types import StepResult
from openenv.core.mcp_client import MCPToolClient
from .models import AWMListToolsObservation, AWMObservation
class AWMEnv(MCPToolClient):
    """
    Environment client for the Agent World Model (AWM).

    Everything is inherited from MCPToolClient; the only override is
    ``_parse_result``, which builds AWM observation models from the
    server payload.

    Reset parameters specific to AWM:
        scenario (str): Required. Name of the AWM scenario to load.
        task_idx (int): Optional. Index of the pre-defined task/verifier pair.
        task (str): Optional. Custom task description (no verifier).
        verifier_mode (str): "sql" or "code". Default: "sql".
        llm_base_url (str): LLM endpoint for the sql verifier; any
            OpenAI-compatible API service can be used.
        llm_api_key (str): LLM API key.
        llm_model (str): LLM model name.

    Hidden tools (not returned by list_tools):
        done: End the episode and trigger verification.
        __list_scenarios__: List all 1,000 available scenarios.

    Reading AWM-specific fields:
        After reset/step the fields are exposed as observation attributes:

        >>> result = await env.reset(scenario="marketplace_1", task_idx=0)
        >>> print(result.observation.reward_type)  # "reset_ok"
        >>> print(result.observation.scenario)  # "marketplace_1"
        >>> print(result.observation.task)  # "Update my volunteer profile..."
    """

    def _parse_result(self, payload: dict[str, Any]) -> StepResult:
        """
        Convert a raw server payload into a StepResult with an AWM observation.

        Overriding the parent parser keeps the AWM fields (reward_type,
        scenario, task, etc.) reachable as observation attributes instead
        of being lost.

        Args:
            payload: Decoded response dict. The keys read here are
                "observation", "reward" and "done".

        Returns:
            StepResult wrapping an AWMListToolsObservation when the
            observation data contains a "tools" key, otherwise an
            AWMObservation. ``reward`` is taken from the payload (None when
            absent) and ``done`` defaults to False.
        """
        raw_observation = payload.get("observation", {})

        # A "tools" key is what marks a list-tools response; any other
        # payload is treated as a regular AWM observation.
        observation_cls = (
            AWMListToolsObservation
            if "tools" in raw_observation
            else AWMObservation
        )

        return StepResult(
            observation=observation_cls(**raw_observation),
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )