"""MedChain Env Environment Client."""
import logging
import re
from typing import Any, Dict, List, Optional
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.mcp_types import (
CallToolAction,
ListToolsAction,
ListToolsObservation,
Tool,
)
from openenv.core.env_server.types import Observation, State
from .models import MedchainState
# Logger for this module, named after the module itself.
_log: logging.Logger = logging.getLogger(name=__name__)
# Extracts the numeric score from tool output such as "Final Score: 0.85".
# The capture must begin with a digit (or "." followed by digits) and allows at
# most one decimal point, so a sentence-ending period ("Final Score: 0.85.") or
# a bare "." is never handed to float(), which would raise ValueError.
_FINAL_SCORE_RE = re.compile(r"Final Score:\s*(-?(?:\d+(?:\.\d*)?|\.\d+))")


class MedchainEnv(EnvClient[CallToolAction, Observation, MedchainState]):
    """
    Client for the MedChain Env hospital supply chain environment.

    Inherits from EnvClient and communicates via the standard OpenEnv
    WebSocket protocol (simulation mode).

    Example:
        >>> async with MedchainEnv(base_url="http://localhost:8000") as env:
        ...     obs = await env.reset()
        ...     print(obs.observation.metadata["dashboard"])
        ...     tools = await env.list_tools()
        ...     result = await env.step(CallToolAction(tool_name="read_inbox", arguments={}))

    Example with Docker:
        >>> env = await MedchainEnv.from_docker_image("medchain_env-env:latest")
        >>> obs = await env.reset()
    """

    def __init__(self, **kwargs: Any) -> None:
        """Create the client; all keyword arguments are forwarded to EnvClient.

        ``message_timeout_s`` defaults to 1500.0 unless the caller supplies it
        (presumably seconds, per the ``_s`` suffix).
        """
        kwargs.setdefault("message_timeout_s", 1500.0)
        super().__init__(**kwargs)
        # Filled lazily by list_tools(); None means "not fetched yet".
        self._tools_cache: Optional[List[Tool]] = None

    # --- EnvClient abstract methods ---------------------------------------

    def _step_payload(self, action: Any) -> Dict[str, Any]:
        """Serialize an action into the message dict sent to the server.

        Raises:
            ValueError: if ``action`` is neither a ListToolsAction nor a
                CallToolAction.
        """
        if isinstance(action, ListToolsAction):
            return {"type": "list_tools"}
        if isinstance(action, CallToolAction):
            return {
                "type": "call_tool",
                "tool_name": action.tool_name,
                "arguments": action.arguments,
            }
        raise ValueError(f"Unsupported action type: {type(action).__name__}")

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]:
        """Convert a raw server step payload into a StepResult.

        The response kind is detected from keys in ``payload["observation"]``:

        * ``"tools"``      -> ListToolsObservation built from the tool list.
        * ``"dashboard"``  -> Observation whose metadata is the whole
          observation dict (reset response).
        * ``"tool_name"``  -> Observation whose metadata holds only
          ``tool_result``; if the payload carries no reward, one is recovered
          from a "Final Score: <number>" line in the result text.
        * anything else    -> Observation with the raw observation dict as
          metadata.

        ``done`` is true if either the payload or the observation says so.
        """
        # A present-but-null "observation" must not break the .get() calls below.
        obs_data = payload.get("observation") or {}
        reward = payload.get("reward")
        done = bool(payload.get("done", False) or obs_data.get("done", False))

        # --- List-tools response -------------------------------------------
        if "tools" in obs_data:
            tools = [
                Tool(
                    name=t.get("name", ""),
                    description=t.get("description", ""),
                    # Accept both snake_case and the camelCase MCP spelling.
                    input_schema=t.get("input_schema", t.get("inputSchema", {})),
                )
                for t in obs_data.get("tools") or []
            ]
            observation = ListToolsObservation(
                tools=tools,
                done=done,
                reward=reward,
            )
            return StepResult(observation=observation, reward=reward, done=done)

        # --- Reset response (has "dashboard" field) ------------------------
        if "dashboard" in obs_data:
            observation = Observation(done=done, reward=reward, metadata=obs_data)
            return StepResult(observation=observation, reward=reward, done=done)

        # --- Tool-call response (has "tool_name" and "tool_result") --------
        if "tool_name" in obs_data:
            result_text = obs_data.get("tool_result", "")
            # Safety net: if the server sent no reward, fall back to parsing
            # the Final Score from the text. Only strings are searched; any
            # other result type would make re.search raise TypeError.
            if reward is None and isinstance(result_text, str):
                match = _FINAL_SCORE_RE.search(result_text)
                if match:
                    reward = float(match.group(1))
            observation = Observation(
                done=done,
                reward=reward,
                metadata={"tool_result": result_text},
            )
            return StepResult(observation=observation, reward=reward, done=done)

        # --- Generic fallback ----------------------------------------------
        observation = Observation(done=done, reward=reward, metadata=obs_data)
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> MedchainState:
        """Build a MedchainState from a state payload.

        Missing fields fall back to zero / empty defaults; ``episode_id``
        falls back to None.
        """
        return MedchainState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            task=payload.get("task", ""),
            day=payload.get("day", 0),
            max_days=payload.get("max_days", 0),
            actions_remaining=payload.get("actions_remaining", 0),
            budget_used=payload.get("budget_used", 0.0),
            budget_limit=payload.get("budget_limit", 0.0),
            unread_messages=payload.get("unread_messages", 0),
            orders_in_transit=payload.get("orders_in_transit", 0),
        )

    # --- Tool discovery ----------------------------------------------------

    async def list_tools(self, use_cache: bool = True) -> List[Tool]:
        """
        Discover the ERP tools available in this environment.

        Args:
            use_cache: Return cached tools if available (default True).

        Returns:
            A new list of Tool objects with name, description, and
            input_schema. If the server reply is not a tool listing, an empty
            list is returned and nothing is cached, so a later call retries
            discovery instead of returning a stale empty result forever.
        """
        if use_cache and self._tools_cache is not None:
            # Hand out a copy so callers cannot mutate the cache.
            return list(self._tools_cache)

        result = await self.step(ListToolsAction())
        if not isinstance(result.observation, ListToolsObservation):
            _log.warning(
                "MedchainEnv.list_tools() got %s instead of a tool listing; "
                "not caching",
                type(result.observation).__name__,
            )
            return []

        self._tools_cache = list(result.observation.tools)
        return list(self._tools_cache)

    # --- Resource cleanup --------------------------------------------------

    async def close(self) -> None:
        """Close client, tolerating Docker stop timeouts gracefully."""
        try:
            await super().close()
        except Exception as e:
            # docker stop can time out (10 s) when the container is slow to exit.
            # Log and swallow so the inference script doesn't crash. This is a
            # deliberate best-effort shutdown boundary; CancelledError is a
            # BaseException and still propagates.
            _log.warning("MedchainEnv.close() suppressed error during shutdown: %s", e)