nik-55's picture
Upload folder using huggingface_hub
be77d11 verified
"""
Test inference script for medchain_env that replays predefined tool calls against the environment instead of querying an LLM.
"""
import asyncio
import logging
import sys
from pathlib import Path
# Add project root to sys.path so we can import medchain_env
sys.path.insert(0, str(Path(__file__).parent))
from medchain_env import CallToolAction, MedchainEnv
_log_fmt = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S")
class DualWriter:
    """Text stream that mirrors every write to a terminal stream and a log file.

    This script installs it as ``sys.stdout``, so ``print`` output (and any
    logging handler pointed at ``sys.stdout``) appears on the console and is
    also persisted to *filepath*.

    Args:
        filepath: Path of the log file. It is opened in ``"w"`` mode, so any
            existing content is truncated.
        terminal: Stream to echo writes to. Defaults to whatever
            ``sys.stdout`` is at construction time.
    """

    def __init__(self, filepath, terminal=None):
        self.terminal = sys.stdout if terminal is None else terminal
        # Explicit UTF-8: the script prints non-ASCII characters such as "─"
        # and "—", which the platform default encoding (e.g. cp1252) may not
        # be able to represent.
        self.log = open(filepath, "w", encoding="utf-8")

    def write(self, message):
        """Write *message* to both destinations and return its length."""
        self.terminal.write(message)
        self.log.write(message)
        # Flush eagerly so the log file stays complete even if the run dies.
        self.log.flush()
        return len(message)

    def flush(self):
        """Flush the terminal stream and, if still open, the log file."""
        self.terminal.flush()
        # Tolerate a closed log: the interpreter flushes sys.stdout at exit.
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file; the terminal stream is not owned and stays open."""
        if not self.log.closed:
            self.log.close()

    def __getattr__(self, name):
        # Only reached for attributes DualWriter does not define itself
        # (e.g. ``encoding``, ``isatty``, ``fileno``). Delegate them to the
        # terminal so code that introspects sys.stdout keeps working. Read via
        # __dict__ to avoid infinite recursion before __init__ sets it.
        terminal = self.__dict__.get("terminal")
        if terminal is None:
            raise AttributeError(name)
        return getattr(terminal, name)
# Tee everything printed to the console into test_inference.log.
sys.stdout = DualWriter("test_inference.log")

# Send log records through the tee'd stdout so they reach the log file as well.
_stream_handler = logging.StreamHandler(stream=sys.stdout)
_stream_handler.setFormatter(fmt=_log_fmt)
logging.basicConfig(handlers=[_stream_handler], level=logging.INFO)

log = logging.getLogger(name=__name__)
async def run_test(
    task_name: str,
    actions_to_take: list,
    *,
    image: str = "medchain_env-env:latest",
    step_delay: float = 0.1,
):
    """Replay a scripted sequence of tool calls against one medchain task.

    Starts the environment from a Docker image, resets it to *task_name*,
    then sends each predefined action as a ``CallToolAction``. It stops when
    the actions run out or the episode is reported done. Progress is printed
    and logged, and the environment is always closed, even if a step raises.

    Args:
        task_name: Task identifier passed to ``env.reset(task=...)``.
        actions_to_take: Dicts with ``"tool_name"`` and ``"arguments"`` keys,
            executed in order.
        image: Docker image to launch. Defaults to the image that
            inference_medchain_env.py uses.
        step_delay: Seconds to sleep after each step, to pace requests.
    """
    log.info("Starting task: %s", task_name)
    env = await MedchainEnv.from_docker_image(image)
    # Created outside the try: if startup fails there is nothing to close.
    try:
        log.info("[%s] Docker env started", task_name)
        mcp_tools = await env.list_tools()
        tool_names = [t.name for t in mcp_tools]
        log.info("Available tools: %s", tool_names)

        obs = await env.reset(task=task_name)
        obs = obs.observation  # unwrap the reset result to its observation
        dashboard = obs.metadata.get("dashboard", "")
        log.info("[%s] env.reset() complete. done=%s metadata_keys=%s",
                 task_name, obs.done, list(obs.metadata.keys()))

        print(f"\n{'=' * 60}")
        print(f"TASK: {task_name}")
        print(f"{'=' * 60}")
        print(dashboard[:500])

        step_count = 0
        # Most recent strictly positive reward; None/zero rewards never overwrite it.
        final_reward = 0.0
        done = obs.done

        for act_dict in actions_to_take:
            if done:
                log.info("[%s] Episode already done before taking action %s", task_name, act_dict["tool_name"])
                break

            step_count += 1
            tool_name = act_dict["tool_name"]
            tool_args = act_dict["arguments"]

            print(f"\n{'─' * 60}")
            print(f"[{task_name}] Step {step_count} — predefined action")
            print(f"{'─' * 60}")
            print(f"\n[{task_name}] Step {step_count} — SIMULATED AGENT RESPONSE:")
            print(f" TOOL CALL: {tool_name}({tool_args})")
            log.info("[%s] Step %d - calling tool: %s(%s)", task_name, step_count, tool_name, tool_args)

            action = CallToolAction(tool_name=tool_name, arguments=tool_args)
            step_result = await env.step(action)
            obs = step_result.observation
            done = obs.done

            # Fall back to the whole metadata dict when no tool_result key exists.
            result_text = obs.metadata.get("tool_result", str(obs.metadata))
            # Treat an "EPISODE COMPLETE" marker in the result text as terminal,
            # even if obs.done is still False.
            if "EPISODE COMPLETE" in (result_text or ""):
                log.info("[%s] Step %d - 'EPISODE COMPLETE' detected in result text; marking done", task_name, step_count)
                done = True

            print(f"\n[{task_name}] Step {step_count} — SERVER RESPONSE (tool_result):")
            print(f" {(result_text or 'EMPTY')[:500]}")
            log.info("[%s] Step %d - env.step() returned. done=%s reward=%s result_preview=%r",
                     task_name, step_count, done, obs.reward, (result_text or "")[:120])

            reward = obs.reward
            if reward is not None and reward > 0:
                final_reward = reward
            # reward may be None (see the guard above); formatting None with
            # ".4f" raises TypeError, so print it verbatim instead.
            reward_text = "None" if reward is None else f"{reward:.4f}"
            print(f" Reward: {reward_text} | Done: {done}")

            # Brief pause between steps, mirroring inference_medchain_env.py's pacing.
            await asyncio.sleep(step_delay)

        log.info("[%s] Episode finished. steps=%d done=%s final_reward=%.4f", task_name, step_count, done, final_reward)
        print(f" Final reward: {final_reward:.4f} | Steps: {step_count} | Done: {done}")
    finally:
        await env.close()
async def main():
    """Run the scripted scenarios one after another, shortest episode first.

    Each scenario pairs a task name with the predefined actions that
    ``run_test`` replays. Every action list ends with several ``end_shift``
    calls that advance the simulated days. ``run_test`` stops once the
    episode reports it is done, so any surplus calls are skipped.
    """

    def end_shifts(count):
        # Fresh dicts each time, so no two actions share an ``arguments`` object.
        return [{"tool_name": "end_shift", "arguments": {}} for _ in range(count)]

    scenarios = [
        # orientation_ward: 2 days; explore the tools and place orders.
        ("orientation_ward", [
            {"tool_name": "read_inbox", "arguments": {"filter": "all"}},
            {"tool_name": "query_erp", "arguments": {"table": "inventory"}},
            {"tool_name": "submit_po", "arguments": {"supplier_id": "MEDLINE", "product_id": "GLOVE-001", "destination_id": "ward_general", "quantity": 40}},
            {"tool_name": "submit_po", "arguments": {"supplier_id": "MEDLINE", "product_id": "SYR-10", "destination_id": "ward_general", "quantity": 60}},
            *end_shifts(2),
        ]),
        # single_ward_stable: 3 days, 6 products, no events.
        ("single_ward_stable", [
            {"tool_name": "read_inbox", "arguments": {"filter": "unread"}},
            {"tool_name": "query_erp", "arguments": {"table": "inventory"}},
            {"tool_name": "submit_po", "arguments": {"supplier_id": "MEDLINE", "product_id": "IV-500", "destination_id": "ward_general", "quantity": 100}},
            *end_shifts(4),
        ]),
        # multi_ward_seasonal: 6 days with a flu surge and a supplier delay;
        # 7 end_shift calls leave headroom to finish the episode.
        ("multi_ward_seasonal", [
            {"tool_name": "read_inbox", "arguments": {"filter": "unread"}},
            {"tool_name": "query_erp", "arguments": {"table": "inventory"}},
            {"tool_name": "transfer", "arguments": {"from_location_id": "central_pharmacy", "to_location_id": "ward_icu", "product_id": "IV-500", "quantity": 50}},
            {"tool_name": "submit_po", "arguments": {"supplier_id": "MEDLINE", "product_id": "IV-500", "destination_id": "central_pharmacy", "quantity": 200}},
            *end_shifts(7),
        ]),
    ]

    for task_name, actions in scenarios:
        await run_test(task_name, actions)


if __name__ == "__main__":
    asyncio.run(main())