nik-55's picture
Upload folder using huggingface_hub
be77d11 verified
"""
MedChain Env — Inference Script
================================
Runs all tasks sequentially and reports scores.
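MANDATORY environment variables:
API_BASE_URL The API endpoint for the LLM (defaults to https://router.huggingface.co/v1 if unset)
MODEL_NAME / MODEL The model identifier for inference (defaults to openai/gpt-oss-120b:groq if unset)
HF_TOKEN / API_KEY Your Hugging Face / API key (required; the script exits if neither is set)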
OPTIONAL environment variables:
LOCAL_IMAGE_NAME Docker image tag; if set, Docker is used (highest priority).
BASE_URL URL of a running MedChain server (e.g. an HF Space).
TASK_NAMES Comma-separated list of tasks to run.
Default: orientation_ward,single_ward_stable,multi_ward_seasonal
LOG_LEVEL INFO (default) or DEBUG (writes a timestamped log to logs/)
Environment connection priority (per task):
1. LOCAL_IMAGE_NAME → spin up a Docker container
2. BASE_URL → connect directly to that server URL
3. Default HF Space → https://nik-55-medchain-openenv-hackathon.hf.space
4. Default image → nik-55_medchain-openenv (last-resort Docker fallback)
STDOUT FORMAT
- The script emits exactly three line types to stdout, in this order:
[START] task=<task_name> env=medchain model=<model_name>
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
[END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...,rn>
Rules:
- One [START] line at episode begin.
- One [STEP] line per step, immediately after env.step() returns.
- One [END] line after env.close(), always emitted (even on exception).
- reward and rewards are formatted to 2 decimal places; score to 3.
- done and success are lowercase booleans: true or false.
- error is the raw error string, or null if none.
- All fields on a single line with no newlines within a line.
"""
import asyncio
import json
import logging
import os
import sys
import time
import urllib.request
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from openai import BadRequestError, OpenAI, RateLimitError
sys.path.insert(0, str(Path(__file__).parent.parent))
from medchain_env import CallToolAction, MedchainEnv
# Log verbosity for this script's logger: INFO (default) or DEBUG.
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
_log_fmt = logging.Formatter(
    "[%(levelname)s] %(asctime)s %(message)s", datefmt="%H:%M:%S"
)
# Diagnostics go to stderr: stdout is reserved for the [START]/[STEP]/[END]
# records described in the module docstring, and interleaved log lines would
# break anything parsing that stream.
_stream_handler = logging.StreamHandler(sys.stderr)
_stream_handler.setFormatter(_log_fmt)
_handlers: list = [_stream_handler]
if LOG_LEVEL == "DEBUG":
    # DEBUG runs additionally keep a timestamped log file under logs/.
    os.makedirs("logs", exist_ok=True)
    _log_filename = datetime.now().strftime("logs/inference_%Y%m%d_%H%M%S.log")
    # Explicit UTF-8: log messages contain non-ASCII characters such as "—".
    _file_handler = logging.FileHandler(_log_filename, encoding="utf-8")
    _file_handler.setFormatter(_log_fmt)
    _handlers.append(_file_handler)
    print(f"[DEBUG] Logging to file: {_log_filename}", file=sys.stderr, flush=True)
# Root stays at WARNING (quiets third-party libraries); only this module's
# logger is raised to LOG_LEVEL.
logging.basicConfig(level=logging.WARNING, handlers=_handlers)
log = logging.getLogger(__name__)
log.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
# ---------------------------------------------------------------------------
# LLM configuration
# ---------------------------------------------------------------------------
# OpenAI-compatible endpoint; defaults to the Hugging Face router.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
# HF_TOKEN takes precedence over API_KEY; async_main exits if neither is set.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
MODEL_NAME = os.getenv("MODEL_NAME") or os.getenv("MODEL", "openai/gpt-oss-120b:groq")
# Model run_task_episode switches to after repeated 429 (rate-limit) errors.
SMALL_MODEL = "openai/gpt-oss-20b:groq"
# ---------------------------------------------------------------------------
# Environment connection (create_env tries these in priority order)
# ---------------------------------------------------------------------------
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
BASE_URL = os.getenv("BASE_URL")
DEFAULT_BASE_URL = "https://nik-55-medchain-openenv-hackathon.hf.space"
DEFAULT_IMAGE_NAME = "nik-55_medchain-openenv"
# All available tasks:
# "orientation_ward", "single_ward_stable", "multi_ward_seasonal", "hospital_network_crisis"
_task_names_env = os.getenv(
    "TASK_NAMES",
    "orientation_ward,single_ward_stable,multi_ward_seasonal",
)
# Comma-separated list; surrounding whitespace and empty entries are dropped.
TASKS = [t.strip() for t in _task_names_env.split(",") if t.strip()]
# Per-task caps on agent-loop iterations (LLM attempts) in run_task_episode;
# tasks missing from this dict fall back to 160 there.
# Intended sizing: actions_per_shift × max_days + generous error headroom.
# NOTE(review): by the figures in the comments below, the last two caps are
# SMALLER than actions × days (75 < 14×6=84, 180 < 18×12=216), so those tasks
# have no headroom at all — confirm whether this is deliberate (e.g. to fit
# the run-time budget) or the numbers should be raised.
MAX_STEPS_PER_TASK = {
    "orientation_ward": 30, # 8 actions × 2 days + headroom
    "single_ward_stable": 45, # 10 actions × 3 days + headroom
    "multi_ward_seasonal": 75, # 14 actions × 6 days + headroom
    "hospital_network_crisis": 180, # 18 actions × 12 days + headroom
}
# Sent as max_completion_tokens on every chat completion request.
MAX_TOKENS = 6000
TEMPERATURE = 0.1
# BadRequestErrors in a row before run_task_episode aborts the episode.
MAX_CONSECUTIVE_ERRORS = 5
# Seconds slept after each env step (skipped right after a budget-exhausted result).
SLEEP_BETWEEN_STEPS = 2
# Max messages from a just-finished shift carried into the next shift's context.
SHIFT_HISTORY_KEEP = 6
# 429 rate-limit handling
_429_WINDOW = 60 # seconds to track 429 count
_429_DOWNGRADE_THRESHOLD = 3 # downgrade model after this many 429s in window
_429_BASE_BACKOFF = 5 # initial backoff seconds
_429_MAX_BACKOFF = 30 # cap backoff to stay within 20-min budget
# Value printed as env=<...> in the [START] line.
BENCHMARK = "medchain"
# System message sent on every LLM call; re-inserted when context is pruned
# at the end of each shift.
SYSTEM_PROMPT = """You are an experienced hospital supply chain manager operating a legacy ERP system.
Your goal is to maintain adequate medical supplies across all locations while controlling costs.
CRITICAL — ACTION BUDGET: You have a strictly limited number of actions per shift.
Budget does NOT roll over. Unspent actions are lost at end_shift().
Recommended budget allocation (highest priority first):
1. read_inbox() — ALWAYS do this first to catch urgent alerts
2. query_erp(table='inventory') — check current stock levels across all locations
3. submit_po(...) — place orders for items below safety stock (PRIORITY)
4. end_shift() — call this when budget is exhausted OR tasks are done
Query tools (query_erp expiry/pipeline, query_forecast, query_supplier) are LOW PRIORITY.
Only use them if you have budget remaining AFTER placing critical orders.
MANDATORY RULES:
- If you receive "Action budget exhausted" → call end_shift() as your VERY NEXT action.
Do NOT call any other tool. The budget cannot be restored until end_shift() is called.
- Order early: factor in lead times. If lead time is 2 days, order today to avoid stockout in 2 days.
- Expedited orders require file_justification(ticket_id=...) with a real clinical reason.
- FEFO: oldest stock consumed first — check expiry and rotate perishables proactively.
- Recalls: quarantine the recalled lot immediately, then order a replacement.
- MCI events: pre-emptive ordering beats reactive ordering. Order extra blood/critical supplies NOW.
Safety stock target: aim for at least (lead_time + 1) × daily_demand units on hand.
When calling tools, use the EXACT parameter names shown in the tool descriptions.
"""
def log_start(task: str, model: str) -> None:
    """Print the [START] record that opens an episode on stdout."""
    header_fields = f"task={task} env={BENCHMARK} model={model}"
    print("[START] " + header_fields, flush=True)
def log_step(
    step: int, action: str, reward: float, done: bool, error: Optional[str]
) -> None:
    """Print one [STEP] record to stdout.

    The stdout protocol requires every record to sit on a single line, so any
    line breaks inside ``action`` or ``error`` are replaced by single spaces.
    ``reward`` is shown with 2 decimals and ``done`` as lowercase
    ``true``/``false``. A None/empty ``error`` (or one made only of line
    breaks) is printed as ``null``.
    """

    def _single_line(text: str) -> str:
        # splitlines() covers \n, \r\n, \r and the other Unicode line breaks.
        return " ".join(text.splitlines())

    action_val = _single_line(str(action))
    error_val = _single_line(error or "") or "null"
    print(
        f"[STEP] step={step} action={action_val} reward={reward:.2f} done={str(done).lower()} error={error_val}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the closing [END] record of an episode on stdout.

    Per-step rewards are comma-joined with 2 decimals; the score uses 3.
    """
    reward_list = ",".join(map("{:.2f}".format, rewards))
    record = "[END] success={} steps={} score={:.3f} rewards={}".format(
        str(success).lower(), steps, score, reward_list
    )
    print(record, flush=True)
def _tools_to_openai_format(tools) -> List[dict]:
    """Convert MCP tool descriptors to OpenAI function-calling tool specs.

    Each tool must expose ``name``, ``description`` and ``input_schema`` (a
    JSON-Schema object, or None for a tool without parameters). For every
    property the ``type`` (default ``"string"``) and ``description`` are
    copied, plus ``items`` and ``enum`` when present: OpenAI rejects
    ``array`` parameters whose schema lacks ``items``, and ``enum`` tells the
    model which values are valid.

    Returns:
        A list of ``{"type": "function", "function": {...}}`` dicts.
    """
    passthrough_keys = ("items", "enum")
    openai_tools = []
    for tool in tools:
        schema = tool.input_schema or {}
        properties = {}
        required: List[str] = []
        if "properties" in schema:
            for name, prop_schema in schema["properties"].items():
                spec = {
                    "type": prop_schema.get("type", "string"),
                    "description": prop_schema.get("description", ""),
                }
                for key in passthrough_keys:
                    if key in prop_schema:
                        spec[key] = prop_schema[key]
                properties[name] = spec
            # Copy so later edits to the spec never alias the MCP schema.
            required = list(schema.get("required", []))
        openai_tools.append(
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description or "",
                    "parameters": {
                        "type": "object",
                        "properties": properties,
                        "required": required,
                    },
                },
            }
        )
        log.debug("Tool registered: %s (required=%s)", tool.name, required)
    return openai_tools
def _make_shift_summary(shift_day: int, end_shift_result: str) -> str:
    """Condense an end_shift() result into a labelled summary for the context window.

    Keeps the stripped lines that mention a key metric, at most 35 of them.
    If no line matches, the first 900 characters of the raw result are used
    instead. The summary is headed ``[SHIFT DAY <shift_day> SUMMARY]``.
    """
    markers = (
        "DEMAND:",
        "FULFILLED:",
        "DELIVERIES:",
        "EXPIRED:",
        "Spend:",
        "Waste",
        "Service Level",
        "END OF SHIFT",
        "Day ",
        "Score:",
    )
    raw_text = end_shift_result or ""
    picked: List[str] = []
    for raw_line in raw_text.splitlines():
        candidate = raw_line.strip()
        if not candidate:
            continue
        if not any(marker in candidate for marker in markers):
            continue
        picked.append(candidate)
        if len(picked) >= 35:
            break
    body = "\n".join(picked) if picked else raw_text[:900]
    return f"[SHIFT DAY {shift_day} SUMMARY]\n{body}"
async def _is_url_reachable(url: str, timeout: float = 30.0) -> bool:
    """Return True if ``POST {url}/reset`` with an empty JSON body answers HTTP 200.

    The blocking urllib request runs in the default thread-pool executor so
    the event loop is not blocked. Any exception (malformed URL, connection
    error, timeout, HTTP error status) yields False instead of propagating.

    NOTE(review): the probe hits /reset, which may itself reset server-side
    state — confirm that is acceptable for shared deployments.
    """

    def _check() -> bool:
        try:
            request = urllib.request.Request(
                url.rstrip("/") + "/reset",
                data=b"{}",
                headers={"Content-Type": "application/json"},
                method="POST",
            )
            with urllib.request.urlopen(request, timeout=timeout) as resp:
                return resp.status == 200
        except Exception:
            return False

    # get_running_loop() is the documented way to reach the loop from inside
    # a coroutine; get_event_loop() has legacy, policy-dependent behaviour.
    loop = asyncio.get_running_loop()
    try:
        return await loop.run_in_executor(None, _check)
    except Exception:
        return False
async def create_env(task_name: str) -> MedchainEnv:
    """Open a MedChain environment for ``task_name``.

    Connection sources are tried in this order:
      1. LOCAL_IMAGE_NAME — start that Docker image;
      2. BASE_URL — connect to that server;
      3. DEFAULT_BASE_URL — only if a probe of its /reset endpoint succeeds;
      4. DEFAULT_IMAGE_NAME — last-resort Docker image.

    Raises:
        RuntimeError: if the last-resort Docker start fails (chained to the
            underlying error). Failures in options 1-3 propagate unchanged.
    """
    if LOCAL_IMAGE_NAME:
        log.info("Using Docker image '%s' for task '%s'", LOCAL_IMAGE_NAME, task_name)
        return await MedchainEnv.from_docker_image(LOCAL_IMAGE_NAME)

    server_url = None
    if BASE_URL:
        log.info("Using BASE_URL '%s' for task '%s'", BASE_URL, task_name)
        server_url = BASE_URL
    else:
        log.info("Probing default URL: %s", DEFAULT_BASE_URL)
        if await _is_url_reachable(DEFAULT_BASE_URL):
            log.info("Default URL reachable; connecting: %s", DEFAULT_BASE_URL)
            server_url = DEFAULT_BASE_URL

    if server_url is not None:
        remote_env = MedchainEnv(base_url=server_url)
        await remote_env.connect()
        return remote_env

    log.warning(
        "Default URL '%s' not reachable. Falling back to Docker image: %s",
        DEFAULT_BASE_URL,
        DEFAULT_IMAGE_NAME,
    )
    try:
        return await MedchainEnv.from_docker_image(DEFAULT_IMAGE_NAME)
    except Exception as docker_err:
        failure_report = (
            f"All environment connection methods failed for task '{task_name}'.\n"
            f" 1. LOCAL_IMAGE_NAME: not set\n"
            f" 2. BASE_URL: not set\n"
            f" 3. Default URL ({DEFAULT_BASE_URL}): not reachable\n"
            f" 4. Default Docker image ({DEFAULT_IMAGE_NAME}): {docker_err}\n"
            "\nFix: set LOCAL_IMAGE_NAME (Docker image name) or BASE_URL (running server URL), "
            "or ensure Docker is running with the image available."
        )
        raise RuntimeError(failure_report) from docker_err
# Substring the env puts in a tool result once the per-shift action budget is
# spent; run_task_episode reacts to it by forcing end_shift().
_BUDGET_EXHAUSTED_MARKER = "Action budget exhausted"
# User message injected right after a budget-exhausted tool result.
_BUDGET_DIRECTIVE = (
    "SYSTEM ALERT: Your action budget for this shift is fully exhausted. "
    "You MUST call end_shift() as your very next action. "
    "Every other tool call will fail until you do."
)


def _prune_shift_messages(
    messages: List[dict], keep: int, task_name: str
) -> List[dict]:
    """Return a short, API-valid tail of a finished shift's messages.

    Keeps at most the last ``keep`` messages, then removes:
    - injected budget-exhausted directives, so they don't bleed into the
      next shift;
    - budget-exhausted tool results *together with* the assistant tool call
      that produced them (dropping only the result would leave an assistant
      ``tool_calls`` message without its ``tool`` response, which
      OpenAI-compatible chat APIs reject);
    - leading ``tool`` messages whose assistant call fell outside the window.

    Relies on run_task_episode appending each assistant tool-call message
    immediately before its tool result.
    """
    pruned: List[dict] = []
    for msg in messages[-keep:]:
        role = msg.get("role")
        content = msg.get("content") or ""
        if role == "user" and content == _BUDGET_DIRECTIVE:
            continue
        if role == "tool" and _BUDGET_EXHAUSTED_MARKER in content:
            if pruned and pruned[-1].get("role") == "assistant":
                pruned.pop()
            continue
        pruned.append(msg)
    while pruned and pruned[0].get("role") == "tool":
        log.debug(
            "[%s] Dropping orphaned leading tool msg (tool_call_id=%s)",
            task_name,
            pruned[0].get("tool_call_id"),
        )
        pruned.pop(0)
    return pruned


async def run_task_episode(
    env: MedchainEnv,
    client: OpenAI,
    tools: List[dict],
    task_name: str,
) -> Dict[str, Any]:
    """Run one episode of ``task_name`` with the LLM agent and return its result.

    Prints the [START] line and one [STEP] line per ``env.step()`` call. The
    loop ends when the env reports ``done`` (or returns an "EPISODE COMPLETE"
    tool result), after ``MAX_STEPS_PER_TASK`` LLM attempts (160 for
    unlisted tasks), or after ``MAX_CONSECUTIVE_ERRORS`` BadRequestErrors in a
    row. Rate-limited (429) and rejected (400) LLM calls count as attempts
    but produce no [STEP] line.

    At the end of each shift the chat history is rebuilt from the system
    prompt, compact summaries of all finished shifts, and a pruned tail of
    the finished shift.

    Returns:
        dict with ``task``, ``reward`` (most recent positive step reward, 0.0
        if none), ``steps`` (env steps taken, equal to ``len(rewards)``),
        ``done``, ``rewards`` (per-step rewards) and ``duration`` (seconds).

    Raises:
        Exceptions from ``env.reset``/``env.step`` and OpenAI errors other
        than RateLimitError/BadRequestError propagate to the caller.
    """
    tool_names = [t["function"]["name"] for t in tools]
    max_steps = MAX_STEPS_PER_TASK.get(task_name, 160)
    # [START] goes out before reset so that a failing reset still yields a
    # [START]/[END] pair (the caller emits [END] even on exceptions).
    log_start(task=task_name, model=MODEL_NAME)
    obs = (await env.reset(task=task_name)).observation
    dashboard = obs.metadata.get("dashboard", "")
    log.debug(
        "[%s] Episode started. Tools: %s max_steps=%d", task_name, tool_names, max_steps
    )
    chat_history: List[dict] = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": f"Your shift has started. Current dashboard:\n\n{dashboard}",
        },
    ]
    loop = asyncio.get_running_loop()
    step_count = 0  # LLM attempts; bounded by max_steps
    env_step_count = 0  # env.step() calls; numbers the [STEP] lines
    final_reward = 0.0
    done = obs.done
    consecutive_errors = 0
    rewards: List[float] = []
    past_shift_summaries: List[str] = []
    current_shift_messages: List[dict] = []
    # 429 rate-limit tracking
    active_model = MODEL_NAME
    rate_limit_times: List[float] = []
    backoff_count = 0
    episode_start = time.monotonic()
    while not done and step_count < max_steps:
        step_count += 1
        log.debug(
            "[%s] Step %d/%d — %d messages in context",
            task_name,
            step_count,
            max_steps,
            len(chat_history),
        )
        request = {
            "model": active_model,
            "messages": chat_history,
            "tools": tools,
            "tool_choice": "required",
            "max_completion_tokens": MAX_TOKENS,
            "temperature": TEMPERATURE,
        }
        try:
            # The OpenAI client is synchronous; run it in a worker thread so the
            # event loop is not frozen for the whole completion.
            # NOTE(review): matters if MedchainEnv keeps a live connection
            # (e.g. keep-alive pings) on this loop — confirm against the client.
            response = await loop.run_in_executor(
                None, lambda req=request: client.chat.completions.create(**req)
            )
            consecutive_errors = 0
            backoff_count = 0
        except RateLimitError as e:
            now = time.monotonic()
            rate_limit_times.append(now)
            # Purge timestamps outside the tracking window
            rate_limit_times[:] = [
                t for t in rate_limit_times if now - t <= _429_WINDOW
            ]
            backoff_count += 1
            backoff_secs = min(
                _429_BASE_BACKOFF * (2 ** (backoff_count - 1)), _429_MAX_BACKOFF
            )
            log.warning(
                "[%s] Step %d — 429 RateLimitError (count=%d in last %ds, backoff=%.1fs): %s",
                task_name,
                step_count,
                len(rate_limit_times),
                _429_WINDOW,
                backoff_secs,
                e,
            )
            if (
                len(rate_limit_times) >= _429_DOWNGRADE_THRESHOLD
                and active_model != SMALL_MODEL
            ):
                active_model = SMALL_MODEL
                log.warning(
                    "[%s] Step %d — Downgrading model to %s due to repeated 429 errors",
                    task_name,
                    step_count,
                    SMALL_MODEL,
                )
            await asyncio.sleep(backoff_secs)
            continue
        except BadRequestError as e:
            consecutive_errors += 1
            log.warning(
                "[%s] Step %d — BadRequestError (%d/%d): %s",
                task_name,
                step_count,
                consecutive_errors,
                MAX_CONSECUTIVE_ERRORS,
                e,
            )
            if consecutive_errors >= MAX_CONSECUTIVE_ERRORS:
                log.error(
                    "[%s] Aborting after %d consecutive errors",
                    task_name,
                    MAX_CONSECUTIVE_ERRORS,
                )
                break
            err_msg = (
                f"Your previous tool call was rejected with an error:\n{e}\n\n"
                "Please retry with a valid tool call. If your budget is exhausted, call end_shift()."
            )
            chat_history.append({"role": "user", "content": err_msg})
            current_shift_messages.append({"role": "user", "content": err_msg})
            continue

        message = response.choices[0].message
        log.debug(
            "[%s] Step %d — finish_reason=%s tool_calls=%d",
            task_name,
            step_count,
            response.choices[0].finish_reason,
            len(message.tool_calls) if message.tool_calls else 0,
        )
        if not message.tool_calls:
            log.warning(
                "[%s] Step %d — no tool_calls in response; falling back to end_shift",
                task_name,
                step_count,
            )
            tool_name = "end_shift"
            tool_args = {}
            tool_call_id = "fallback"
        else:
            # Only the first tool call is executed; the assistant message sent
            # back below is rebuilt to contain just that call.
            tc = message.tool_calls[0]
            tool_name = tc.function.name
            tool_call_id = tc.id
            try:
                tool_args = json.loads(tc.function.arguments)
            except (json.JSONDecodeError, AttributeError, TypeError):
                # TypeError covers arguments=None.
                log.warning(
                    "[%s] Step %d — failed to parse tool arguments: %r",
                    task_name,
                    step_count,
                    tc.function.arguments,
                )
                tool_args = {}
            if not isinstance(tool_args, dict):
                # Valid JSON but not an object (e.g. "null" or a list).
                tool_args = {}
            if tool_name not in tool_names:
                log.warning(
                    "[%s] Step %d — unknown tool %r; falling back to end_shift",
                    task_name,
                    step_count,
                    tool_name,
                )
                tool_name = "end_shift"
                tool_args = {}
        log.debug(
            "[%s] Step %d — calling %s(%s)", task_name, step_count, tool_name, tool_args
        )
        assistant_msg = {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "id": tool_call_id,
                    "type": "function",
                    "function": {
                        "name": tool_name,
                        "arguments": json.dumps(tool_args),
                    },
                }
            ],
        }
        chat_history.append(assistant_msg)
        current_shift_messages.append(assistant_msg)

        action = CallToolAction(tool_name=tool_name, arguments=tool_args)
        step_result = await env.step(action)
        env_step_count += 1
        obs = step_result.observation
        done = obs.done
        result_text = obs.metadata.get("tool_result", str(obs.metadata))
        step_reward = obs.reward or 0.0
        step_error: Optional[str] = None
        if "EPISODE COMPLETE" in (result_text or ""):
            log.info("[%s] Step %d — episode complete detected", task_name, step_count)
            done = True
        # NOTE(review): tracks the most recent positive reward of any step —
        # confirm this (rather than only the completing step's reward) is the
        # intended score.
        if obs.reward is not None and obs.reward > 0:
            final_reward = obs.reward
        rewards.append(step_reward)
        action_str = f"{tool_name}({json.dumps(tool_args)})"
        log_step(
            step=env_step_count,
            action=action_str,
            reward=step_reward,
            done=done,
            error=step_error,
        )
        tool_result_msg = {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "content": result_text[:2700] if result_text else "OK",
        }
        chat_history.append(tool_result_msg)
        current_shift_messages.append(tool_result_msg)

        # Budget exhausted — inject directive and skip sleep
        if _BUDGET_EXHAUSTED_MARKER in (result_text or ""):
            log.info(
                "[%s] Step %d — budget exhausted; injecting end_shift directive",
                task_name,
                step_count,
            )
            chat_history.append({"role": "user", "content": _BUDGET_DIRECTIVE})
            current_shift_messages.append(
                {"role": "user", "content": _BUDGET_DIRECTIVE}
            )
            continue
        await asyncio.sleep(SLEEP_BETWEEN_STEPS)

        # Shift ended — summarise and prune context, then set up next shift
        if (
            tool_name == "end_shift"
            and "END OF SHIFT" in (result_text or "")
            and not done
        ):
            # First purely numeric token in the result is taken as the day.
            shift_day = "?"
            for part in (result_text or "").split():
                if part.isdigit():
                    shift_day = part
                    break
            shift_summary = _make_shift_summary(shift_day, result_text or "")
            log.debug("[%s] Shift %s summary:\n%s", task_name, shift_day, shift_summary)
            past_shift_summaries.append(shift_summary)
            log.info(
                "[%s] Step %d — shift %s ended; pruning context (%d summaries)",
                task_name,
                step_count,
                shift_day,
                len(past_shift_summaries),
            )
            summaries_msg = {
                "role": "user",
                "content": "COMPLETED SHIFT SUMMARIES:\n\n"
                + "\n\n".join(past_shift_summaries),
            }
            trimmed = _prune_shift_messages(
                current_shift_messages, SHIFT_HISTORY_KEEP, task_name
            )
            chat_history = (
                [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    summaries_msg,
                ]
                + trimmed
                + [
                    {
                        "role": "user",
                        "content": "Your next shift has begun. The dashboard is shown above in the last tool result. "
                        "Continue managing the supply chain.",
                    },
                ]
            )
            current_shift_messages = []

    episode_duration = time.monotonic() - episode_start
    log.info(
        "[%s] Episode finished. steps=%d done=%s final_reward=%.4f",
        task_name,
        env_step_count,
        done,
        final_reward,
    )
    log.debug("[%s] Episode duration: %.1fs", task_name, episode_duration)
    return {
        "task": task_name,
        "reward": final_reward,
        "steps": env_step_count,
        "done": done,
        "rewards": rewards,
        "duration": episode_duration,
    }
async def async_main() -> None:
    """Run every task in TASKS sequentially, emitting one [END] line per task.

    A failure in one task, including failing to create its environment, is
    logged and reported as ``success=false`` in that task's [END] line. The
    remaining tasks still run.

    Raises:
        SystemExit: if no API key (HF_TOKEN / API_KEY) or model name is set.
    """
    if not API_KEY:
        raise SystemExit("HF_TOKEN or API_KEY must be set.")
    if not MODEL_NAME:
        raise SystemExit("MODEL_NAME or MODEL must be set.")
    log.info("Starting. API_BASE_URL=%s MODEL_NAME=%s", API_BASE_URL, MODEL_NAME)
    log.info("Tasks: %s", TASKS)
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    results = []
    script_start = time.monotonic()
    for task_name in TASKS:
        log.info("Launching task: %s", task_name)
        task_start = time.monotonic()
        env: Optional[MedchainEnv] = None
        final_reward = 0.0
        success = False
        steps = 0
        step_rewards: List[float] = []
        try:
            # Created inside the try so a connection failure still reaches the
            # finally (and its [END] line) instead of aborting every later task.
            env = await create_env(task_name)
            mcp_tools = await env.list_tools()
            tools = _tools_to_openai_format(mcp_tools)
            log.info("[%s] %d tools discovered", task_name, len(tools))
            result = await run_task_episode(env, client, tools, task_name)
            results.append(result)
            final_reward = result["reward"]
            steps = result["steps"]
            success = result["done"]
            step_rewards = result["rewards"]
            log.info(
                "[%s] Task complete: reward=%.4f steps=%d",
                task_name,
                final_reward,
                steps,
            )
        except Exception as e:
            log.error("[%s] Task failed with exception: %s", task_name, e)
        finally:
            if env is not None:
                try:
                    await env.close()
                except Exception as e:
                    log.error("[%s] env.close() failed: %s", task_name, e)
            log_end(
                success=success, steps=steps, score=final_reward, rewards=step_rewards
            )
            log.debug(
                "[%s] Total task wall time: %.1fs",
                task_name,
                time.monotonic() - task_start,
            )
    total_duration = time.monotonic() - script_start
    if results:
        avg_reward = sum(r["reward"] for r in results) / len(results)
        log.info("All tasks complete. avg_reward=%.4f", avg_reward)
    log.debug("Overall script duration: %.1fs", total_duration)
def main() -> None:
    """Synchronous entry point: drive async_main() to completion on a new event loop."""
    all_tasks = async_main()
    asyncio.run(all_tasks)


if __name__ == "__main__":
    main()