""" Test inference script for medchain_env simulating runs without LLM calls. """ import asyncio import logging import sys from pathlib import Path # Add project root to sys.path so we can import medchain_env sys.path.insert(0, str(Path(__file__).parent)) from medchain_env import CallToolAction, MedchainEnv _log_fmt = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S") class DualWriter: def __init__(self, filepath): self.terminal = sys.stdout self.log = open(filepath, "w") def write(self, message): self.terminal.write(message) self.log.write(message) self.log.flush() def flush(self): self.terminal.flush() self.log.flush() sys.stdout = DualWriter("test_inference.log") _stream_handler = logging.StreamHandler(sys.stdout) _stream_handler.setFormatter(_log_fmt) logging.basicConfig(level=logging.INFO, handlers=[_stream_handler]) log = logging.getLogger(__name__) async def run_test(task_name: str, actions_to_take: list): log.info("Starting task: %s", task_name) # We use the same Docker environment as in inference_medchain_env.py env = await MedchainEnv.from_docker_image("medchain_env-env:latest") try: log.info("[%s] Docker env started", task_name) mcp_tools = await env.list_tools() tool_names = [t.name for t in mcp_tools] log.info("Available tools: %s", tool_names) obs = await env.reset(task=task_name) obs = obs.observation dashboard = obs.metadata.get("dashboard", "") log.info("[%s] env.reset() complete. done=%s metadata_keys=%s", task_name, obs.done, list(obs.metadata.keys())) print(f"\n{'=' * 60}") print(f"TASK: {task_name}") print(f"{'=' * 60}") print(dashboard[:500]) step_count = 0 final_reward = 0.0 done = obs.done for act_dict in actions_to_take: if done: log.info("[%s] Episode already done before taking action %s", task_name, act_dict["tool_name"]) break step_count += 1 tool_name = act_dict["tool_name"] tool_args = act_dict["arguments"] print(f"\n{'─' * 60}") print(f"[{task_name}] Step {step_count} — predefined action") print(f"{'─' * 60}") print(f"\n[{task_name}] Step {step_count} — SIMULATED AGENT RESPONSE:") print(f" TOOL CALL: {tool_name}({tool_args})") log.info("[%s] Step %d - calling tool: %s(%s)", task_name, step_count, tool_name, tool_args) action = CallToolAction(tool_name=tool_name, arguments=tool_args) step_result = await env.step(action) obs = step_result.observation done = obs.done result_text = obs.metadata.get("tool_result", str(obs.metadata)) if "EPISODE COMPLETE" in (result_text or ""): log.info("[%s] Step %d - 'EPISODE COMPLETE' detected in result text; marking done", task_name, step_count) done = True print(f"\n[{task_name}] Step {step_count} — SERVER RESPONSE (tool_result):") print(f" {(result_text or 'EMPTY')[:500]}") log.info("[%s] Step %d - env.step() returned. done=%s reward=%s result_preview=%r", task_name, step_count, done, obs.reward, (result_text or "")[:120]) if obs.reward is not None and obs.reward > 0: final_reward = obs.reward print(f" Reward: {obs.reward:.4f} | Done: {done}") # Sleep slightly to replicate inference_medchain_env behaviour and prevent overwhelming await asyncio.sleep(0.1) log.info("[%s] Episode finished. steps=%d done=%s final_reward=%.4f", task_name, step_count, done, final_reward) print(f" Final reward: {final_reward:.4f} | Steps: {step_count} | Done: {done}") finally: await env.close() async def main(): # Intro/easy task: orientation_ward (2 days, explore tools, place 1 order) intro_actions = [ {"tool_name": "read_inbox", "arguments": {"filter": "all"}}, {"tool_name": "query_erp", "arguments": {"table": "inventory"}}, {"tool_name": "submit_po", "arguments": {"supplier_id": "MEDLINE", "product_id": "GLOVE-001", "destination_id": "ward_general", "quantity": 40}}, {"tool_name": "submit_po", "arguments": {"supplier_id": "MEDLINE", "product_id": "SYR-10", "destination_id": "ward_general", "quantity": 60}}, {"tool_name": "end_shift", "arguments": {}}, {"tool_name": "end_shift", "arguments": {}}, ] await run_test("orientation_ward", intro_actions) # Medium task: single_ward_stable (3 days, 6 products, no events) easy_actions = [ {"tool_name": "read_inbox", "arguments": {"filter": "unread"}}, {"tool_name": "query_erp", "arguments": {"table": "inventory"}}, {"tool_name": "submit_po", "arguments": {"supplier_id": "MEDLINE", "product_id": "IV-500", "destination_id": "ward_general", "quantity": 100}}, {"tool_name": "end_shift", "arguments": {}}, {"tool_name": "end_shift", "arguments": {}}, {"tool_name": "end_shift", "arguments": {}}, {"tool_name": "end_shift", "arguments": {}}, ] await run_test("single_ward_stable", easy_actions) # Medium-hard task: multi_ward_seasonal (6 days, flu surge + supplier delay) medium_actions = [ {"tool_name": "read_inbox", "arguments": {"filter": "unread"}}, {"tool_name": "query_erp", "arguments": {"table": "inventory"}}, {"tool_name": "transfer", "arguments": {"from_location_id": "central_pharmacy", "to_location_id": "ward_icu", "product_id": "IV-500", "quantity": 50}}, {"tool_name": "submit_po", "arguments": {"supplier_id": "MEDLINE", "product_id": "IV-500", "destination_id": "central_pharmacy", "quantity": 200}}, ] # Add enough end_shift actions to finish the 6-day episode for _ in range(7): medium_actions.append({"tool_name": "end_shift", "arguments": {}}) await run_test("multi_ward_seasonal", medium_actions) if __name__ == "__main__": asyncio.run(main())