File size: 6,589 Bytes
4afc4db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be77d11
4afc4db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""MedChain Env Environment Client."""

import logging
import re
from typing import Any, Dict, List, Optional

from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.mcp_types import (
    CallToolAction,
    ListToolsAction,
    ListToolsObservation,
    Tool,
)
from openenv.core.env_server.types import Observation, State

from .models import MedchainState

# Module-level logger; MedchainEnv.close() uses it to report shutdown errors it suppresses.
_log = logging.getLogger(__name__)


class MedchainEnv(EnvClient[CallToolAction, Observation, MedchainState]):
    """
    Client for the MedChain Env hospital supply chain environment.

    Inherits from EnvClient and communicates via the standard OpenEnv
    WebSocket protocol (simulation mode).

    Example:
        >>> async with MedchainEnv(base_url="http://localhost:8000") as env:
        ...     obs = await env.reset()
        ...     print(obs.observation.metadata["dashboard"])
        ...     tools = await env.list_tools()
        ...     result = await env.step(CallToolAction(tool_name="read_inbox", arguments={}))

    Example with Docker:
        >>> env = await MedchainEnv.from_docker_image("medchain_env-env:latest")
        >>> obs = await env.reset()
    """

    # Extracts the number after "Final Score:" in tool output text.
    # The number may be signed and written as "1", "1.", "0.85" or ".5".
    # A trailing sentence period ("Final Score: 0.85.") and a bare "." are
    # never captured, so float() on the group cannot raise ValueError.
    _FINAL_SCORE_RE = re.compile(
        r"Final Score:\s*([-+]?(?:\d+(?:\.\d*)?|\.\d+))"
    )

    def __init__(self, **kwargs: Any) -> None:
        """Create the client.

        Args:
            **kwargs: Forwarded unchanged to ``EnvClient.__init__``, except that
                ``message_timeout_s`` defaults to 1500.0 when not supplied.
        """
        kwargs.setdefault("message_timeout_s", 1500.0)
        super().__init__(**kwargs)
        # Filled in by list_tools(); None means "not fetched successfully yet".
        self._tools_cache: Optional[List[Tool]] = None

    # ── EnvClient abstract methods ─────────────────────────────────────────

    def _step_payload(self, action: Any) -> Dict[str, Any]:
        """Serialize an action into the wire payload sent to the server.

        Args:
            action: A ``ListToolsAction`` or ``CallToolAction``.

        Returns:
            ``{"type": "list_tools"}`` for tool discovery, or a ``call_tool``
            dict carrying the tool name and its arguments.

        Raises:
            ValueError: If ``action`` is any other type.
        """
        if isinstance(action, ListToolsAction):
            return {"type": "list_tools"}
        if isinstance(action, CallToolAction):
            return {
                "type": "call_tool",
                "tool_name": action.tool_name,
                "arguments": action.arguments,
            }
        raise ValueError(f"Unsupported action type: {type(action).__name__}")

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]:
        """Convert a raw server payload into a ``StepResult``.

        The payload's ``observation`` dict is dispatched on the keys it contains:

        * ``tools``      -> ``ListToolsObservation`` built from the tool list.
        * ``dashboard``  -> reset response; the whole dict becomes ``metadata``.
        * ``tool_name``  -> tool-call response; ``metadata`` holds only
          ``tool_result``.
        * anything else  -> generic ``Observation`` with the dict as ``metadata``.

        ``done`` is true if either the top-level payload or the observation
        says so.
        """
        # ``or {}`` also covers an explicit ``"observation": null``, which would
        # otherwise make the ``.get`` calls below raise AttributeError.
        obs_data = payload.get("observation") or {}
        reward = payload.get("reward")
        # bool() so an explicit ``"done": null`` becomes False rather than None.
        done = bool(payload.get("done", False) or obs_data.get("done", False))

        # ── List-tools response ──────────────────────────────────────────
        if "tools" in obs_data:
            tools = [
                Tool(
                    name=t.get("name", ""),
                    description=t.get("description", ""),
                    # Accept both snake_case and the camelCase spelling.
                    input_schema=t.get("input_schema", t.get("inputSchema", {})),
                )
                # ``or []`` tolerates an explicit ``"tools": null``.
                for t in obs_data.get("tools") or []
            ]
            observation = ListToolsObservation(
                tools=tools,
                done=done,
                reward=reward,
            )
            return StepResult(observation=observation, reward=reward, done=done)

        # ── Reset response (has "dashboard" field) ───────────────────────
        if "dashboard" in obs_data:
            observation = Observation(done=done, reward=reward, metadata=obs_data)
            return StepResult(observation=observation, reward=reward, done=done)

        # ── Tool-call response (has "tool_name" and "tool_result") ───────
        if "tool_name" in obs_data:
            result_text = obs_data.get("tool_result", "")

            # Safety net: if reward is still None (should not happen after the
            # serialization fix), fall back to parsing the Final Score from text.
            # Only attempt this on string results; re.search rejects other types.
            if reward is None and isinstance(result_text, str) and result_text:
                match = self._FINAL_SCORE_RE.search(result_text)
                if match:
                    reward = float(match.group(1))

            observation = Observation(
                done=done,
                reward=reward,
                metadata={"tool_result": result_text},
            )
            return StepResult(observation=observation, reward=reward, done=done)

        # ── Generic fallback ─────────────────────────────────────────────
        observation = Observation(done=done, reward=reward, metadata=obs_data)
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> MedchainState:
        """Build a ``MedchainState`` from a state payload.

        Missing fields fall back to zero / empty defaults; ``episode_id`` falls
        back to None.
        """
        return MedchainState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            task=payload.get("task", ""),
            day=payload.get("day", 0),
            max_days=payload.get("max_days", 0),
            actions_remaining=payload.get("actions_remaining", 0),
            budget_used=payload.get("budget_used", 0.0),
            budget_limit=payload.get("budget_limit", 0.0),
            unread_messages=payload.get("unread_messages", 0),
            orders_in_transit=payload.get("orders_in_transit", 0),
        )

    # ── Tool discovery ─────────────────────────────────────────────────────

    async def list_tools(self, use_cache: bool = True) -> List[Tool]:
        """
        Discover the ERP tools available in this environment.

        Args:
            use_cache: Return cached tools if a previous call succeeded
                (default True).

        Returns:
            List of Tool objects with name, description, and input_schema.
            An empty list is returned if the server's reply was not a tool
            listing.
        """
        if use_cache and self._tools_cache is not None:
            return self._tools_cache

        result = await self.step(ListToolsAction())
        if isinstance(result.observation, ListToolsObservation):
            self._tools_cache = result.observation.tools
            return self._tools_cache

        # Unexpected reply: report no tools, but do not cache the failure.
        # A later call can then retry discovery instead of returning [] forever.
        return []

    # ── Resource cleanup ───────────────────────────────────────────────────

    async def close(self) -> None:
        """Close client, tolerating Docker stop timeouts gracefully."""
        try:
            await super().close()
        except Exception as e:
            # docker stop can time out (10 s) when the container is slow to exit.
            # Log and swallow so the inference script doesn't crash.
            _log.warning("MedchainEnv.close() suppressed error during shutdown: %s", e)