File size: 6,287 Bytes
4afc4db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be77d11
 
4afc4db
 
 
 
 
 
be77d11
4afc4db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""
Test inference script for medchain_env simulating runs without LLM calls.
"""

import asyncio
import logging
import sys
from pathlib import Path

# Add project root to sys.path so we can import medchain_env
sys.path.insert(0, str(Path(__file__).parent))

from medchain_env import CallToolAction, MedchainEnv

# Timestamped "HH:MM:SS [LEVEL] message" layout used by the stdout log handler.
_log_fmt = logging.Formatter(
    fmt="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
)

class DualWriter:
    """Tee text written to stdout into a log file as well.

    Meant to be installed as ``sys.stdout`` so that ``print`` output (and any
    logging handler writing to ``sys.stdout``) reaches both the terminal and
    *filepath*.

    Args:
        filepath: Path of the log file; it is truncated on construction.
        encoding: Encoding of the log file. Defaults to UTF-8 so non-ASCII
            output (e.g. box-drawing characters) can always be written,
            regardless of the platform's locale encoding.
    """

    def __init__(self, filepath, encoding="utf-8"):
        self.terminal = sys.stdout
        self.log = open(filepath, "w", encoding=encoding)

    def write(self, message):
        """Write *message* to the terminal and the log file.

        Returns:
            The number of characters written, as ``io.TextIOBase.write`` does.
        """
        self.terminal.write(message)
        self.log.write(message)
        # Flush after every write so the log file is always up to date.
        self.log.flush()
        return len(message)

    def flush(self):
        """Flush both the terminal and the log file."""
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file; the wrapped terminal stream is left open."""
        if not self.log.closed:
            self.log.close()

    def __getattr__(self, name):
        # Only called for attributes not found normally: delegate stream
        # attributes such as ``encoding``, ``isatty`` or ``fileno`` to the
        # wrapped terminal so code inspecting ``sys.stdout`` keeps working.
        if name in ("terminal", "log"):
            # Not yet initialised; avoid infinite recursion via self.terminal.
            raise AttributeError(name)
        return getattr(self.terminal, name)

# Tee everything printed to stdout into test_inference.log.
sys.stdout = DualWriter("test_inference.log")

# Log records are written to the tee'd stdout, so they land in the same file.
_stream_handler = logging.StreamHandler(stream=sys.stdout)
_stream_handler.setFormatter(_log_fmt)
logging.basicConfig(handlers=[_stream_handler], level=logging.INFO)
log = logging.getLogger(__name__)

async def run_test(task_name: str, actions_to_take: list,
                   image: str = "medchain_env-env:latest"):
    """Replay a scripted list of tool calls against one medchain_env task.

    No LLM is involved: the predefined actions stand in for the agent. The
    environment is started from *image*, reset to *task_name*, and each action
    is sent in order until the actions run out or the episode is done. Server
    responses are printed and logged, and the environment is always closed.

    Args:
        task_name: Task id passed to ``env.reset(task=...)``.
        actions_to_take: Dicts with ``"tool_name"`` and ``"arguments"`` keys,
            executed in order.
        image: Docker image to start the environment from. Defaults to the
            image used by inference_medchain_env.py.

    Returns:
        The most recent non-``None`` reward reported by ``env.step``, or
        ``0.0`` if no step reported one. Existing callers may ignore it.
    """
    log.info("Starting task: %s", task_name)

    env = await MedchainEnv.from_docker_image(image)

    try:
        log.info("[%s] Docker env started", task_name)
        mcp_tools = await env.list_tools()
        tool_names = [t.name for t in mcp_tools]
        log.info("Available tools: %s", tool_names)

        reset_result = await env.reset(task=task_name)
        obs = reset_result.observation
        # Treat an explicit None like a missing key so the slice below is safe.
        dashboard = obs.metadata.get("dashboard") or ""
        log.info("[%s] env.reset() complete. done=%s  metadata_keys=%s",
                 task_name, obs.done, list(obs.metadata.keys()))

        print(f"\n{'=' * 60}")
        print(f"TASK: {task_name}")
        print(f"{'=' * 60}")
        print(dashboard[:500])

        step_count = 0
        final_reward = 0.0
        done = obs.done

        for act_dict in actions_to_take:
            tool_name = act_dict["tool_name"]
            if done:
                log.info("[%s] Episode already done before taking action %s", task_name, tool_name)
                break

            step_count += 1
            tool_args = act_dict["arguments"]

            print(f"\n{'─' * 60}")
            print(f"[{task_name}] Step {step_count} — predefined action")
            print(f"{'─' * 60}")

            print(f"\n[{task_name}] Step {step_count} — SIMULATED AGENT RESPONSE:")
            print(f"  TOOL CALL: {tool_name}({tool_args})")

            log.info("[%s] Step %d - calling tool: %s(%s)", task_name, step_count, tool_name, tool_args)
            action = CallToolAction(tool_name=tool_name, arguments=tool_args)
            step_result = await env.step(action)
            obs = step_result.observation
            done = obs.done

            raw_result = obs.metadata.get("tool_result", str(obs.metadata))
            # NOTE(review): the type of "tool_result" is not checked here;
            # coerce to str so the substring test and slicing below are safe.
            result_text = "" if raw_result is None else str(raw_result)

            # Completion may be signalled only in the result text; honour it.
            if "EPISODE COMPLETE" in result_text:
                log.info("[%s] Step %d - 'EPISODE COMPLETE' detected in result text; marking done", task_name, step_count)
                done = True

            print(f"\n[{task_name}] Step {step_count} — SERVER RESPONSE (tool_result):")
            print(f"  {(result_text or 'EMPTY')[:500]}")
            log.info("[%s] Step %d - env.step() returned. done=%s reward=%s result_preview=%r",
                     task_name, step_count, done, obs.reward, result_text[:120])

            if obs.reward is not None:
                # Keep the latest reward even when it is zero or negative, so the
                # summary reports what the environment last returned rather
                # than the last positive value seen.
                final_reward = obs.reward
                print(f"    Reward: {obs.reward:.4f} | Done: {done}")

            # Sleep slightly to replicate inference_medchain_env behaviour and prevent overwhelming
            await asyncio.sleep(0.1)

        if not done:
            log.warning("[%s] Predefined actions exhausted before the episode finished", task_name)

        log.info("[%s] Episode finished. steps=%d done=%s final_reward=%.4f", task_name, step_count, done, final_reward)
        print(f"  Final reward: {final_reward:.4f} | Steps: {step_count} | Done: {done}")
        return final_reward

    finally:
        await env.close()

async def main():
    """Run every scripted scenario in turn, from the simplest task upwards."""

    def end_shifts(count):
        # Build fresh dicts so no two steps share the same arguments object.
        return [{"tool_name": "end_shift", "arguments": {}} for _ in range(count)]

    scenarios = [
        # Intro task orientation_ward (2 days): explore tools, place two orders.
        (
            "orientation_ward",
            [
                {"tool_name": "read_inbox", "arguments": {"filter": "all"}},
                {"tool_name": "query_erp", "arguments": {"table": "inventory"}},
                {"tool_name": "submit_po", "arguments": {"supplier_id": "MEDLINE", "product_id": "GLOVE-001", "destination_id": "ward_general", "quantity": 40}},
                {"tool_name": "submit_po", "arguments": {"supplier_id": "MEDLINE", "product_id": "SYR-10", "destination_id": "ward_general", "quantity": 60}},
                *end_shifts(2),
            ],
        ),
        # single_ward_stable (3 days, 6 products, no events). More end_shift
        # calls than days; run_test stops as soon as the episode is done.
        (
            "single_ward_stable",
            [
                {"tool_name": "read_inbox", "arguments": {"filter": "unread"}},
                {"tool_name": "query_erp", "arguments": {"table": "inventory"}},
                {"tool_name": "submit_po", "arguments": {"supplier_id": "MEDLINE", "product_id": "IV-500", "destination_id": "ward_general", "quantity": 100}},
                *end_shifts(4),
            ],
        ),
        # multi_ward_seasonal (6 days, flu surge + supplier delay): rebalance
        # stock to the ICU, reorder, then end shifts until the episode finishes.
        (
            "multi_ward_seasonal",
            [
                {"tool_name": "read_inbox", "arguments": {"filter": "unread"}},
                {"tool_name": "query_erp", "arguments": {"table": "inventory"}},
                {"tool_name": "transfer", "arguments": {"from_location_id": "central_pharmacy", "to_location_id": "ward_icu", "product_id": "IV-500", "quantity": 50}},
                {"tool_name": "submit_po", "arguments": {"supplier_id": "MEDLINE", "product_id": "IV-500", "destination_id": "central_pharmacy", "quantity": 200}},
                *end_shifts(7),
            ],
        ),
    ]

    for task, scripted_actions in scenarios:
        await run_test(task, scripted_actions)

if __name__ == "__main__":
    # Script entry point: run all scripted scenarios sequentially in a new event loop.
    asyncio.run(main())