Spaces:
Sleeping
Sleeping
"""Offline smoke test for medchain_env.

Replays scripted tool calls against the environment so that runs can be
exercised without making any LLM calls.
"""
import asyncio
import logging
import sys
from pathlib import Path

# Put the directory that contains this script (treated as the project root)
# at the front of sys.path so the ``medchain_env`` import below resolves.
sys.path.insert(0, str(Path(__file__).parent))

from medchain_env import CallToolAction, MedchainEnv  # noqa: E402  (needs sys.path tweak)

# Shared log-line layout: wall-clock time, level name, message.
_log_fmt = logging.Formatter(
    fmt="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
)
class DualWriter:
    """Tee text that is written to stdout into a log file as well.

    The stream that is ``sys.stdout`` at construction time is captured as
    ``terminal``. Every ``write`` goes to that stream and to ``filepath``.
    The file is flushed after each write so it is up to date at every point.
    Attributes not defined here (``encoding``, ``isatty``, ``fileno``, ...)
    are looked up on the captured stream. This lets code that probes
    ``sys.stdout`` keep working after this object replaces it.
    """

    def __init__(self, filepath, *, encoding="utf-8"):
        self.terminal = sys.stdout
        # Explicit encoding: the locale default (e.g. cp1252 on Windows) cannot
        # encode characters such as the box-drawing "─" this script prints.
        self.log = open(filepath, "w", encoding=encoding)

    def write(self, message):
        """Write *message* to both targets; return the terminal's char count."""
        written = self.terminal.write(message)
        if not self.log.closed:
            self.log.write(message)
            self.log.flush()
        return written

    def flush(self):
        """Flush the terminal stream and, if still open, the log file."""
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file. The terminal stream is not owned and stays open."""
        if not self.log.closed:
            self.log.close()

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. ``terminal`` is read
        # through __dict__ so a half-initialised instance cannot recurse here.
        terminal = self.__dict__.get("terminal")
        if terminal is None:
            raise AttributeError(name)
        return getattr(terminal, name)
# From here on, everything printed is also copied into test_inference.log
# (a path relative to the current working directory). Log records are sent
# through the same tee'd stream so they reach both the terminal and the file.
# Note: logging.basicConfig does nothing if the root logger already has handlers.
sys.stdout = DualWriter("test_inference.log")
_stream_handler = logging.StreamHandler(stream=sys.stdout)
_stream_handler.setFormatter(_log_fmt)
logging.basicConfig(handlers=[_stream_handler], level=logging.INFO)
log = logging.getLogger(__name__)
async def run_test(task_name: str, actions_to_take: list, *,
                   image: str = "medchain_env-env:latest"):
    """Replay a fixed list of tool calls against one medchain_env episode.

    Starts the environment from a Docker image and resets it to *task_name*.
    It then sends each predefined action in order. It stops when the list is
    exhausted or the episode reports completion, either through ``obs.done``
    or an "EPISODE COMPLETE" marker in the tool result. The environment is
    always closed afterwards.

    Args:
        task_name: Task passed to ``env.reset(task=...)``.
        actions_to_take: Dicts with a ``"tool_name"`` key and an optional
            ``"arguments"`` dict (a missing entry means no arguments).
        image: Docker image used to start the environment. The default is the
            same image that inference_medchain_env.py uses.

    Returns:
        The most recent positive reward observed, or 0.0 if there was none.
        This is the value shown in the final summary line.
    """
    log.info("Starting task: %s", task_name)
    env = await MedchainEnv.from_docker_image(image)
    try:
        log.info("[%s] Docker env started", task_name)
        mcp_tools = await env.list_tools()
        log.info("Available tools: %s", [t.name for t in mcp_tools])

        reset_result = await env.reset(task=task_name)
        obs = reset_result.observation
        # Tolerate a missing metadata dict or dashboard instead of crashing.
        reset_meta = obs.metadata or {}
        dashboard = str(reset_meta.get("dashboard") or "")
        log.info("[%s] env.reset() complete. done=%s metadata_keys=%s",
                 task_name, obs.done, list(reset_meta.keys()))
        print(f"\n{'=' * 60}")
        print(f"TASK: {task_name}")
        print(f"{'=' * 60}")
        print(dashboard[:500])

        step_count = 0
        final_reward = 0.0
        done = obs.done
        for act_dict in actions_to_take:
            tool_name = act_dict["tool_name"]
            if done:
                log.info("[%s] Episode already done before taking action %s", task_name, tool_name)
                break
            tool_args = act_dict.get("arguments", {})
            step_count += 1

            print(f"\n{'─' * 60}")
            print(f"[{task_name}] Step {step_count} — predefined action")
            print(f"{'─' * 60}")
            print(f"\n[{task_name}] Step {step_count} — SIMULATED AGENT RESPONSE:")
            print(f" TOOL CALL: {tool_name}({tool_args})")
            log.info("[%s] Step %d - calling tool: %s(%s)", task_name, step_count, tool_name, tool_args)

            step_result = await env.step(CallToolAction(tool_name=tool_name, arguments=tool_args))
            obs = step_result.observation
            done = obs.done

            # Normalise the tool result to a string so the substring check and
            # slicing below work even if the server returns None or a non-str.
            step_meta = obs.metadata or {}
            raw_result = step_meta.get("tool_result", str(step_meta))
            result_text = "" if raw_result is None else str(raw_result)
            if "EPISODE COMPLETE" in result_text:
                log.info("[%s] Step %d - 'EPISODE COMPLETE' detected in result text; marking done", task_name, step_count)
                done = True

            print(f"\n[{task_name}] Step {step_count} — SERVER RESPONSE (tool_result):")
            print(f" {(result_text or 'EMPTY')[:500]}")
            log.info("[%s] Step %d - env.step() returned. done=%s reward=%s result_preview=%r",
                     task_name, step_count, done, obs.reward, result_text[:120])

            if obs.reward is not None and obs.reward > 0:
                final_reward = obs.reward
            # obs.reward may be None, and format(None, ".4f") raises TypeError.
            reward_str = "None" if obs.reward is None else f"{obs.reward:.4f}"
            print(f" Reward: {reward_str} | Done: {done}")

            # Short pause, as in inference_medchain_env, so the server is not flooded.
            await asyncio.sleep(0.1)

        if not done:
            log.warning("[%s] Ran out of predefined actions before the episode was done", task_name)
        log.info("[%s] Episode finished. steps=%d done=%s final_reward=%.4f", task_name, step_count, done, final_reward)
        print(f" Final reward: {final_reward:.4f} | Steps: {step_count} | Done: {done}")
        return final_reward
    finally:
        await env.close()
def _end_shifts(count):
    """Return *count* independent ``end_shift`` actions, each a fresh dict."""
    return [{"tool_name": "end_shift", "arguments": {}} for _ in range(count)]


# Scripted scenarios, run in this order by main(), as (task_name, actions) pairs.
# Trailing end_shift calls close out the simulated days. run_test stops sending
# actions once the episode reports done, so any surplus calls are harmless.
SCENARIOS = (
    # Intro/easy task: orientation_ward (2 days; read inbox, check stock, two POs).
    (
        "orientation_ward",
        [
            {"tool_name": "read_inbox", "arguments": {"filter": "all"}},
            {"tool_name": "query_erp", "arguments": {"table": "inventory"}},
            {"tool_name": "submit_po", "arguments": {"supplier_id": "MEDLINE", "product_id": "GLOVE-001", "destination_id": "ward_general", "quantity": 40}},
            {"tool_name": "submit_po", "arguments": {"supplier_id": "MEDLINE", "product_id": "SYR-10", "destination_id": "ward_general", "quantity": 60}},
        ] + _end_shifts(2),
    ),
    # Medium task: single_ward_stable (3 days, 6 products, no events).
    (
        "single_ward_stable",
        [
            {"tool_name": "read_inbox", "arguments": {"filter": "unread"}},
            {"tool_name": "query_erp", "arguments": {"table": "inventory"}},
            {"tool_name": "submit_po", "arguments": {"supplier_id": "MEDLINE", "product_id": "IV-500", "destination_id": "ward_general", "quantity": 100}},
        ] + _end_shifts(4),
    ),
    # Medium-hard task: multi_ward_seasonal (6 days, flu surge + supplier delay).
    (
        "multi_ward_seasonal",
        [
            {"tool_name": "read_inbox", "arguments": {"filter": "unread"}},
            {"tool_name": "query_erp", "arguments": {"table": "inventory"}},
            {"tool_name": "transfer", "arguments": {"from_location_id": "central_pharmacy", "to_location_id": "ward_icu", "product_id": "IV-500", "quantity": 50}},
            {"tool_name": "submit_po", "arguments": {"supplier_id": "MEDLINE", "product_id": "IV-500", "destination_id": "central_pharmacy", "quantity": 200}},
        ] + _end_shifts(7),
    ),
)


async def main():
    """Run every scenario in SCENARIOS sequentially, one run_test call per task."""
    for task_name, actions in SCENARIOS:
        await run_test(task_name, actions)


if __name__ == "__main__":
    asyncio.run(main())