"""
Test inference script for medchain_env: replays predefined tool calls to
simulate runs without LLM calls.
"""
import asyncio
import logging
import sys
from pathlib import Path
# Add project root to sys.path so we can import medchain_env
sys.path.insert(0, str(Path(__file__).parent))
from medchain_env import CallToolAction, MedchainEnv
# Shared log line layout: HH:MM:SS timestamp, bracketed level name, message.
_log_fmt = logging.Formatter(
    fmt="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
)
class DualWriter:
    """Text stream that mirrors every write to the terminal and to a log file.

    The instance captures whatever ``sys.stdout`` is at construction time and
    opens ``filepath`` for writing (truncating any existing file). It exposes
    ``write``/``flush`` so it can stand in for ``sys.stdout``.

    Args:
        filepath: Path of the log file to create or overwrite.
    """

    def __init__(self, filepath):
        self.terminal = sys.stdout
        # Explicit UTF-8: callers print box-drawing characters and em dashes,
        # which a non-UTF-8 platform default encoding cannot represent.
        self.log = open(filepath, "w", encoding="utf-8")

    def write(self, message):
        """Write ``message`` to both targets and return its length."""
        self.terminal.write(message)
        self.log.write(message)
        # Flush on every write so the log is usable even if the run crashes.
        self.log.flush()
        return len(message)

    def flush(self):
        """Flush the terminal stream and the log file."""
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file; the terminal stream is left open.

        Safe to call more than once.
        """
        if not self.log.closed:
            self.log.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def __getattr__(self, name):
        # Only reached for attributes not found normally. Delegate stream
        # attributes (isatty, encoding, fileno, ...) to the wrapped terminal.
        # Go through __dict__ to avoid infinite recursion when ``terminal``
        # itself has not been set yet.
        terminal = self.__dict__.get("terminal")
        if terminal is None:
            raise AttributeError(name)
        return getattr(terminal, name)
# Mirror everything written to stdout into test_inference.log.
sys.stdout = DualWriter("test_inference.log")

# The handler is created after stdout has been replaced, so log records are
# written through the DualWriter as well.
_stream_handler = logging.StreamHandler(stream=sys.stdout)
_stream_handler.setFormatter(_log_fmt)
logging.basicConfig(handlers=[_stream_handler], level=logging.INFO)
log = logging.getLogger(__name__)
async def run_test(task_name: str, actions_to_take: list):
    """Replay a predefined list of tool calls against one environment episode.

    Starts a ``MedchainEnv`` from the ``medchain_env-env:latest`` Docker image,
    resets it for ``task_name``, then issues each entry of ``actions_to_take``
    as a ``CallToolAction`` until the list is exhausted or the episode is done.
    Progress is printed and logged; the environment is always closed.

    Args:
        task_name: Task identifier passed to ``env.reset(task=...)``.
        actions_to_take: Dicts with a ``"tool_name"`` and an ``"arguments"``
            key, executed in order.
    """
    log.info("Starting task: %s", task_name)
    # Same Docker environment that inference_medchain_env.py uses.
    env = await MedchainEnv.from_docker_image("medchain_env-env:latest")
    try:
        log.info("[%s] Docker env started", task_name)
        tool_names = [tool.name for tool in await env.list_tools()]
        log.info("Available tools: %s", tool_names)

        obs = (await env.reset(task=task_name)).observation
        dashboard = obs.metadata.get("dashboard", "")
        log.info(
            "[%s] env.reset() complete. done=%s metadata_keys=%s",
            task_name,
            obs.done,
            list(obs.metadata.keys()),
        )

        print(f"\n{'=' * 60}")
        print(f"TASK: {task_name}")
        print(f"{'=' * 60}")
        print(dashboard[:500])

        step_count = 0
        final_reward = 0.0
        done = obs.done

        for planned in actions_to_take:
            if done:
                log.info(
                    "[%s] Episode already done before taking action %s",
                    task_name,
                    planned["tool_name"],
                )
                break

            step_count += 1
            tool_name, tool_args = planned["tool_name"], planned["arguments"]

            print(f"\n{'─' * 60}")
            print(f"[{task_name}] Step {step_count} — predefined action")
            print(f"{'─' * 60}")
            print(f"\n[{task_name}] Step {step_count} — SIMULATED AGENT RESPONSE:")
            print(f" TOOL CALL: {tool_name}({tool_args})")
            log.info(
                "[%s] Step %d - calling tool: %s(%s)",
                task_name,
                step_count,
                tool_name,
                tool_args,
            )

            call = CallToolAction(tool_name=tool_name, arguments=tool_args)
            obs = (await env.step(call)).observation
            done = obs.done

            # Fall back to the whole metadata mapping when there is no
            # "tool_result" entry.
            result_text = obs.metadata.get("tool_result", str(obs.metadata))
            if "EPISODE COMPLETE" in (result_text or ""):
                log.info(
                    "[%s] Step %d - 'EPISODE COMPLETE' detected in result text; marking done",
                    task_name,
                    step_count,
                )
                done = True

            print(f"\n[{task_name}] Step {step_count} — SERVER RESPONSE (tool_result):")
            print(f" {(result_text or 'EMPTY')[:500]}")
            log.info(
                "[%s] Step %d - env.step() returned. done=%s reward=%s result_preview=%r",
                task_name,
                step_count,
                done,
                obs.reward,
                (result_text or "")[:120],
            )

            # NOTE(review): the original indentation was lost; the reward
            # print is kept under this guard because it formats obs.reward
            # with ":.4f" — confirm against the upstream file.
            reward = obs.reward
            if reward is not None and reward > 0:
                final_reward = reward
                print(f" Reward: {obs.reward:.4f} | Done: {done}")

            # Sleep slightly to replicate inference_medchain_env behaviour
            # and avoid overwhelming the environment.
            await asyncio.sleep(0.1)

        log.info(
            "[%s] Episode finished. steps=%d done=%s final_reward=%.4f",
            task_name,
            step_count,
            done,
            final_reward,
        )
        print(f" Final reward: {final_reward:.4f} | Steps: {step_count} | Done: {done}")
    finally:
        await env.close()
async def main():
    """Run the three scripted scenarios one after another via ``run_test``."""
    # Intro/easy task: orientation_ward (2 days, explore tools, place 1 order)
    orientation_script = [
        {"tool_name": "read_inbox", "arguments": {"filter": "all"}},
        {"tool_name": "query_erp", "arguments": {"table": "inventory"}},
        {
            "tool_name": "submit_po",
            "arguments": {
                "supplier_id": "MEDLINE",
                "product_id": "GLOVE-001",
                "destination_id": "ward_general",
                "quantity": 40,
            },
        },
        {
            "tool_name": "submit_po",
            "arguments": {
                "supplier_id": "MEDLINE",
                "product_id": "SYR-10",
                "destination_id": "ward_general",
                "quantity": 60,
            },
        },
        *({"tool_name": "end_shift", "arguments": {}} for _ in range(2)),
    ]

    # Medium task: single_ward_stable (3 days, 6 products, no events)
    stable_script = [
        {"tool_name": "read_inbox", "arguments": {"filter": "unread"}},
        {"tool_name": "query_erp", "arguments": {"table": "inventory"}},
        {
            "tool_name": "submit_po",
            "arguments": {
                "supplier_id": "MEDLINE",
                "product_id": "IV-500",
                "destination_id": "ward_general",
                "quantity": 100,
            },
        },
        *({"tool_name": "end_shift", "arguments": {}} for _ in range(4)),
    ]

    # Medium-hard task: multi_ward_seasonal (6 days, flu surge + supplier delay)
    seasonal_script = [
        {"tool_name": "read_inbox", "arguments": {"filter": "unread"}},
        {"tool_name": "query_erp", "arguments": {"table": "inventory"}},
        {
            "tool_name": "transfer",
            "arguments": {
                "from_location_id": "central_pharmacy",
                "to_location_id": "ward_icu",
                "product_id": "IV-500",
                "quantity": 50,
            },
        },
        {
            "tool_name": "submit_po",
            "arguments": {
                "supplier_id": "MEDLINE",
                "product_id": "IV-500",
                "destination_id": "central_pharmacy",
                "quantity": 200,
            },
        },
        # Enough end_shift actions to finish the 6-day episode.
        *({"tool_name": "end_shift", "arguments": {}} for _ in range(7)),
    ]

    scenarios = (
        ("orientation_ward", orientation_script),
        ("single_ward_stable", stable_script),
        ("multi_ward_seasonal", seasonal_script),
    )
    for task, script in scenarios:
        await run_test(task, script)
# Script entry point: run the scripted scenarios on a fresh asyncio event loop.
if __name__ == "__main__":
    asyncio.run(main())
|