File size: 2,832 Bytes
d57737f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
Agent World Model Environment Client.

Provides a client for connecting to an AWM Environment server.
AWMEnv extends MCPToolClient with AWM-specific helpers.

Example:
    >>> with AWMEnv(base_url="http://localhost:8000") as env:
    ...     # List all available scenarios
    ...     result = env.call_tool("__list_scenarios__")
    ...
    ...     # Start a scenario
    ...     env.reset(scenario="e_commerce_33", task_idx=0)
    ...
    ...     # Discover tools
    ...     tools = env.list_tools()
    ...     print([t.name for t in tools])
    ...
    ...     # Call tools
    ...     result = env.call_tool("search_products", query="headphones")
    ...
    ...     # End episode and get verification result
    ...     result = env.call_tool("done")
"""

from typing import Any

from openenv.core.client_types import StepResult
from openenv.core.mcp_client import MCPToolClient

from .models import AWMListToolsObservation, AWMObservation


class AWMEnv(MCPToolClient):
    """
    Client for the Agent World Model Environment.

    Inherits all functionality from MCPToolClient.

    AWM-specific reset parameters:
        scenario (str): Required. Name of the AWM scenario to load.
        task_idx (int): Optional. Index of the pre-defined task/verifier pair.
        task (str): Optional. Custom task description (no verifier).
        verifier_mode (str): "sql" or "code". Default: "sql".
        llm_base_url (str): LLM endpoint for the sql verifier; any
            OpenAI-compatible API service can be used.
        llm_api_key (str): LLM API key.
        llm_model (str): LLM model name.

    Hidden tools (not in list_tools):
        done: End the episode and trigger verification.
        __list_scenarios__: List all 1,000 available scenarios.

    Accessing AWM-specific fields:
        After reset/step, access fields via observation attributes:
        >>> result = await env.reset(scenario="marketplace_1", task_idx=0)
        >>> print(result.observation.reward_type)  # "reset_ok"
        >>> print(result.observation.scenario)     # "marketplace_1"
        >>> print(result.observation.task)         # "Update my volunteer profile..."
    """

    def _parse_result(self, payload: dict[str, Any]) -> StepResult:
        """
        Convert a response payload into a StepResult with an AWM observation.

        This override ensures AWM fields (reward_type, scenario, task, etc.)
        are available as observation attributes instead of being lost.

        Args:
            payload: Response dictionary. The keys read here are
                "observation", "reward" and "done"; all are optional.

        Returns:
            A StepResult whose observation is an AWMListToolsObservation when
            the observation data contains a "tools" key, and an
            AWMObservation otherwise. ``reward`` is None when the payload has
            no "reward" key, and ``done`` is False when it has no "done" key.
        """
        obs_data = payload.get("observation")
        if obs_data is None:
            # dict.get's default only applies when the key is absent. A key
            # that is present but set to None would otherwise make the
            # membership test below raise TypeError, so treat it like a
            # missing observation.
            obs_data = {}

        # Check if this is a ListToolsObservation (has "tools" key)
        if "tools" in obs_data:
            observation = AWMListToolsObservation(**obs_data)
        else:
            observation = AWMObservation(**obs_data)

        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )