Spaces:
Sleeping
Sleeping
| """ | |
| MedChain Env — Inference Script | |
| ================================ | |
| Runs all tasks sequentially and reports scores. | |
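| MANDATORY environment variables: | |
| API_BASE_URL The API endpoint for the LLM (default: https://router.huggingface.co/v1) | |
| MODEL_NAME / MODEL The model identifier for inference (default: openai/gpt-oss-120b:groq) | |
| HF_TOKEN / API_KEY Your Hugging Face / API key (required; the script exits if neither is set) | |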
| OPTIONAL environment variables: | |
| LOCAL_IMAGE_NAME Docker image tag; if set, Docker is used (highest priority). | |
| BASE_URL URL of a running MedChain server (e.g. an HF Space). | |
| TASK_NAMES Comma-separated list of tasks to run. | |
| Default: orientation_ward,single_ward_stable,multi_ward_seasonal | |
| LOG_LEVEL INFO (default) or DEBUG (writes a timestamped log to logs/) | |
| Environment connection priority (per task): | |
| 1. LOCAL_IMAGE_NAME → spin up a Docker container | |
| 2. BASE_URL → connect directly to that server URL | |
| 3. Default HF Space → https://nik-55-medchain-openenv-hackathon.hf.space | |
| 4. Default image → nik-55_medchain-openenv (last-resort Docker fallback) | |
| STDOUT FORMAT | |
| - The script emits exactly three line types to stdout, in this order: | |
| [START] task=<task_name> env=medchain model=<model_name> | |
| [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null> | |
| [END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...,rn> | |
| Rules: | |
| - One [START] line at episode begin. | |
| - One [STEP] line per step, immediately after env.step() returns. | |
| - One [END] line after env.close(), always emitted (even on exception). | |
| - reward and rewards are formatted to 2 decimal places; score to 3. | |
| - done and success are lowercase booleans: true or false. | |
| - error is the raw error string, or null if none. | |
| - All fields on a single line with no newlines within a line. | |
| """ | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import sys | |
| import time | |
| import urllib.request | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| from openai import BadRequestError, OpenAI, RateLimitError | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from medchain_env import CallToolAction, MedchainEnv | |
# ---------------------------------------------------------------------------
# Logging setup
# ---------------------------------------------------------------------------
# stdout is reserved for the machine-readable [START] / [STEP] / [END] lines
# described in the module docstring, so every diagnostic message goes to
# stderr (and, in DEBUG mode, additionally to a timestamped file under logs/).
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()

_log_fmt = logging.Formatter(
    "[%(levelname)s] %(asctime)s %(message)s", datefmt="%H:%M:%S"
)

_stream_handler = logging.StreamHandler(sys.stderr)
_stream_handler.setFormatter(_log_fmt)
_handlers: list = [_stream_handler]

if LOG_LEVEL == "DEBUG":
    os.makedirs("logs", exist_ok=True)
    _log_filename = datetime.now().strftime("logs/inference_%Y%m%d_%H%M%S.log")
    _file_handler = logging.FileHandler(_log_filename)
    _file_handler.setFormatter(_log_fmt)
    _handlers.append(_file_handler)
    # Informational only; kept off stdout for the same reason as the handlers.
    print(f"[DEBUG] Logging to file: {_log_filename}", file=sys.stderr, flush=True)

# The root logger stays at WARNING so third-party libraries remain quiet;
# only this module's logger follows LOG_LEVEL.
logging.basicConfig(level=logging.WARNING, handlers=_handlers)
log = logging.getLogger(__name__)

# getattr() may find a non-level attribute of the logging module (for example
# LOG_LEVEL=BASIC_FORMAT resolves to a format string); fall back to INFO
# unless the resolved value is a real numeric level.
_resolved_level = getattr(logging, LOG_LEVEL, None)
log.setLevel(_resolved_level if isinstance(_resolved_level, int) else logging.INFO)
# ---------------------------------------------------------------------------
# Configuration (environment-driven)
# ---------------------------------------------------------------------------

# LLM endpoint, credentials and model selection.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY")
MODEL_NAME = os.environ.get("MODEL_NAME") or os.environ.get(
    "MODEL", "openai/gpt-oss-120b:groq"
)
# Model the episode loop switches to after repeated 429 responses.
SMALL_MODEL = "openai/gpt-oss-20b:groq"

# Where to find the MedChain environment (see create_env for the priority).
LOCAL_IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME")
BASE_URL = os.environ.get("BASE_URL")
DEFAULT_BASE_URL = "https://nik-55-medchain-openenv-hackathon.hf.space"
DEFAULT_IMAGE_NAME = "nik-55_medchain-openenv"

# All available tasks:
#   "orientation_ward", "single_ward_stable", "multi_ward_seasonal",
#   "hospital_network_crisis"
_task_names_env = os.environ.get(
    "TASK_NAMES",
    "orientation_ward,single_ward_stable,multi_ward_seasonal",
)
# Comma-separated list -> task names, ignoring blanks and stray whitespace.
TASKS = [
    name
    for name in (piece.strip() for piece in _task_names_env.split(","))
    if name
]

# Per-task step limits (actions_per_shift × max_days + generous error headroom)
MAX_STEPS_PER_TASK = {
    "orientation_ward": 30,  # 8 actions × 2 days + headroom
    "single_ward_stable": 45,  # 10 actions × 3 days + headroom
    "multi_ward_seasonal": 75,  # 14 actions × 6 days + headroom
    "hospital_network_crisis": 180,  # 18 actions × 12 days + headroom
}

# Sampling / loop tuning.
MAX_TOKENS = 6000
TEMPERATURE = 0.1
MAX_CONSECUTIVE_ERRORS = 5
SLEEP_BETWEEN_STEPS = 2
SHIFT_HISTORY_KEEP = 6

# 429 rate-limit handling
_429_WINDOW = 60  # seconds to track 429 count
_429_DOWNGRADE_THRESHOLD = 3  # downgrade model after this many 429s in window
_429_BASE_BACKOFF = 5  # initial backoff seconds
_429_MAX_BACKOFF = 30  # cap backoff to stay within 20-min budget

# Value of the ``env=`` field on the [START] line.
BENCHMARK = "medchain"

# System message sent to the model at the start of every (pruned) context.
SYSTEM_PROMPT = """You are an experienced hospital supply chain manager operating a legacy ERP system.
Your goal is to maintain adequate medical supplies across all locations while controlling costs.
CRITICAL — ACTION BUDGET: You have a strictly limited number of actions per shift.
Budget does NOT roll over. Unspent actions are lost at end_shift().
Recommended budget allocation (highest priority first):
1. read_inbox() — ALWAYS do this first to catch urgent alerts
2. query_erp(table='inventory') — check current stock levels across all locations
3. submit_po(...) — place orders for items below safety stock (PRIORITY)
4. end_shift() — call this when budget is exhausted OR tasks are done
Query tools (query_erp expiry/pipeline, query_forecast, query_supplier) are LOW PRIORITY.
Only use them if you have budget remaining AFTER placing critical orders.
MANDATORY RULES:
- If you receive "Action budget exhausted" → call end_shift() as your VERY NEXT action.
Do NOT call any other tool. The budget cannot be restored until end_shift() is called.
- Order early: factor in lead times. If lead time is 2 days, order today to avoid stockout in 2 days.
- Expedited orders require file_justification(ticket_id=...) with a real clinical reason.
- FEFO: oldest stock consumed first — check expiry and rotate perishables proactively.
- Recalls: quarantine the recalled lot immediately, then order a replacement.
- MCI events: pre-emptive ordering beats reactive ordering. Order extra blood/critical supplies NOW.
Safety stock target: aim for at least (lead_time + 1) × daily_demand units on hand.
When calling tools, use the EXACT parameter names shown in the tool descriptions.
"""
def log_start(task: str, model: str) -> None:
    """Print the ``[START]`` line that opens an episode.

    Args:
        task: Name of the task being run.
        model: Identifier of the model used for inference.
    """
    fields = (f"task={task}", f"env={BENCHMARK}", f"model={model}")
    # print() joins its positional arguments with single spaces.
    print("[START]", *fields, flush=True)
def log_step(
    step: int, action: str, reward: float, done: bool, error: Optional[str]
) -> None:
    """Print one ``[STEP]`` line describing an executed environment step.

    The output format requires every record to fit on a single line, so any
    line breaks inside ``action`` or ``error`` are collapsed to single spaces
    before printing. Values without line breaks are emitted unchanged.

    Args:
        step: 1-based step number.
        action: Human-readable rendering of the action taken.
        reward: Reward returned for this step (printed with 2 decimals).
        done: Whether the episode finished with this step.
        error: Error text for the step, or ``None``/empty when there is none
            (printed as ``null``).
    """

    def _one_line(text: str) -> str:
        # splitlines() handles \n, \r\n and the other Unicode line boundaries.
        return " ".join(str(text).splitlines())

    error_val = (_one_line(error) if error else "") or "null"
    print(
        f"[STEP] step={step} action={_one_line(action)} reward={reward:.2f} "
        f"done={str(done).lower()} error={error_val}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` line that closes an episode.

    Args:
        success: Whether the episode is considered successful.
        steps: Number of steps reported for the episode.
        score: Final score (printed with 3 decimals).
        rewards: Per-step rewards (each printed with 2 decimals, comma-joined).
    """
    formatted_rewards = [format(value, ".2f") for value in rewards]
    line = "[END] success={} steps={} score={:.3f} rewards={}".format(
        str(success).lower(),
        steps,
        score,
        ",".join(formatted_rewards),
    )
    print(line, flush=True)
def _tools_to_openai_format(tools) -> List[dict]:
    """Convert MCP tool descriptors to OpenAI function-calling tool specs.

    Each element of ``tools`` must expose ``name``, ``description`` and
    ``input_schema`` attributes. Parameter schemas are reduced to ``type`` and
    ``description``, plus ``enum`` and ``items`` when present, so array and
    enumerated parameters keep the information the model needs to fill them in.
    A parameter whose type is only given inside ``anyOf``/``oneOf`` takes the
    first non-null variant instead of silently defaulting to ``"string"``.

    Args:
        tools: Iterable of tool descriptors.

    Returns:
        One ``{"type": "function", "function": {...}}`` dict per tool, in
        input order.
    """
    passthrough_keys = ("enum", "items")

    def _convert_property(schema) -> dict:
        # Tolerate non-dict schemas (e.g. the JSON Schema boolean form).
        if not isinstance(schema, dict):
            schema = {}
        resolved = schema
        if "type" not in schema:
            for variant in schema.get("anyOf") or schema.get("oneOf") or []:
                if isinstance(variant, dict) and variant.get("type") not in (
                    None,
                    "null",
                ):
                    resolved = variant
                    break
        converted = {
            "type": resolved.get("type", "string"),
            "description": schema.get("description", ""),
        }
        for key in passthrough_keys:
            if key in resolved:
                converted[key] = resolved[key]
        return converted

    openai_tools = []
    for tool in tools:
        properties = {}
        required = []
        input_schema = tool.input_schema
        if input_schema and "properties" in input_schema:
            properties = {
                name: _convert_property(schema)
                for name, schema in input_schema["properties"].items()
            }
            # "required" may be absent or explicitly null.
            required = list(input_schema.get("required") or [])
        openai_tools.append(
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description or "",
                    "parameters": {
                        "type": "object",
                        "properties": properties,
                        "required": required,
                    },
                },
            }
        )
        log.debug("Tool registered: %s (required=%s)", tool.name, required)
    return openai_tools
def _make_shift_summary(shift_day: int, end_shift_result: str) -> str:
    """Condense an ``end_shift`` report into a short block for the context window.

    Only non-empty lines containing one of a fixed set of keywords are kept
    (stripped of surrounding whitespace), up to 35 lines. When no line
    matches, the first 900 characters of the raw report are used instead.

    Args:
        shift_day: Identifier of the shift, interpolated into the header.
        end_shift_result: Raw text returned by the ``end_shift`` tool; a falsy
            value is treated as an empty report.

    Returns:
        ``"[SHIFT DAY <shift_day> SUMMARY]"`` followed by a newline and the
        condensed body.
    """
    keywords = (
        "DEMAND:",
        "FULFILLED:",
        "DELIVERIES:",
        "EXPIRED:",
        "Spend:",
        "Waste",
        "Service Level",
        "END OF SHIFT",
        "Day ",
        "Score:",
    )
    line_cap = 35
    report = end_shift_result or ""

    kept = []
    for raw_line in report.splitlines():
        candidate = raw_line.strip()
        if not candidate:
            continue
        if not any(keyword in candidate for keyword in keywords):
            continue
        kept.append(candidate)
        if len(kept) >= line_cap:
            break

    if kept:
        body = "\n".join(kept)
    else:
        body = report[:900]
    return f"[SHIFT DAY {shift_day} SUMMARY]\n{body}"
async def _is_url_reachable(url: str, timeout: float = 30.0) -> bool:
    """Report whether a MedChain server answers at ``url``.

    The probe sends ``POST <url>/reset`` with an empty JSON object and treats
    an HTTP 200 response as "reachable". The blocking ``urllib`` call runs in
    the default executor so the event loop is not stalled while waiting.

    NOTE(review): the probe hits ``/reset``, which presumably resets
    server-side episode state — confirm that this side effect is acceptable.

    Args:
        url: Base URL of the server (a trailing slash is tolerated).
        timeout: Socket timeout in seconds for the probe request.

    Returns:
        ``True`` on an HTTP 200 response; ``False`` on any other status or on
        any error (invalid URL, connection failure, timeout, ...).
    """

    def _check() -> bool:
        try:
            req = urllib.request.Request(
                url.rstrip("/") + "/reset",
                data=b"{}",
                headers={"Content-Type": "application/json"},
                method="POST",
            )
            with urllib.request.urlopen(req, timeout=timeout) as resp:
                return resp.status == 200
        except Exception:
            # Best-effort probe: every failure simply means "not reachable".
            return False

    # Inside a coroutine the running loop is the one we want;
    # get_running_loop() is the recommended way to obtain it.
    loop = asyncio.get_running_loop()
    try:
        return await loop.run_in_executor(None, _check)
    except Exception:
        return False
async def create_env(task_name: str) -> MedchainEnv:
    """Create a connected MedChain environment client for ``task_name``.

    Sources are tried in this order:
      1. ``LOCAL_IMAGE_NAME`` — start from that Docker image.
      2. ``BASE_URL`` — connect to that server.
      3. ``DEFAULT_BASE_URL`` — used only if a reachability probe succeeds.
      4. ``DEFAULT_IMAGE_NAME`` — last-resort Docker image.

    Args:
        task_name: Task the environment is created for (used in log and
            error messages only).

    Returns:
        The environment client.

    Raises:
        RuntimeError: If the last-resort Docker image also fails to start;
            the original error is chained as the cause.
    """

    async def _connect_to(server_url: str) -> MedchainEnv:
        # Shared by the BASE_URL and default-URL paths.
        remote_env = MedchainEnv(base_url=server_url)
        await remote_env.connect()
        return remote_env

    # 1. Explicit Docker image.
    if LOCAL_IMAGE_NAME:
        log.info("Using Docker image '%s' for task '%s'", LOCAL_IMAGE_NAME, task_name)
        return await MedchainEnv.from_docker_image(LOCAL_IMAGE_NAME)

    # 2. Explicit server URL.
    if BASE_URL:
        log.info("Using BASE_URL '%s' for task '%s'", BASE_URL, task_name)
        return await _connect_to(BASE_URL)

    # 3. Default hosted server, if it responds.
    log.info("Probing default URL: %s", DEFAULT_BASE_URL)
    default_is_up = await _is_url_reachable(DEFAULT_BASE_URL)
    if default_is_up:
        log.info("Default URL reachable; connecting: %s", DEFAULT_BASE_URL)
        return await _connect_to(DEFAULT_BASE_URL)

    # 4. Default Docker image.
    log.warning(
        "Default URL '%s' not reachable. Falling back to Docker image: %s",
        DEFAULT_BASE_URL,
        DEFAULT_IMAGE_NAME,
    )
    try:
        return await MedchainEnv.from_docker_image(DEFAULT_IMAGE_NAME)
    except Exception as docker_err:
        failure_report = (
            f"All environment connection methods failed for task '{task_name}'.\n"
            f" 1. LOCAL_IMAGE_NAME: not set\n"
            f" 2. BASE_URL: not set\n"
            f" 3. Default URL ({DEFAULT_BASE_URL}): not reachable\n"
            f" 4. Default Docker image ({DEFAULT_IMAGE_NAME}): {docker_err}\n"
            "\nFix: set LOCAL_IMAGE_NAME (Docker image name) or BASE_URL (running server URL), "
            "or ensure Docker is running with the image available."
        )
        raise RuntimeError(failure_report) from docker_err
# Injected after the environment reports an exhausted action budget. Kept as
# a constant so the exact same text can be recognised and removed when the
# context is pruned at the end of the shift.
_BUDGET_EXHAUSTED_DIRECTIVE = (
    "SYSTEM ALERT: Your action budget for this shift is fully exhausted. "
    "You MUST call end_shift() as your very next action. "
    "Every other tool call will fail until you do."
)


def _carry_over_messages(
    shift_messages: List[dict], keep: int, task_name: str
) -> List[dict]:
    """Pick the tail of a finished shift's messages to seed the next context.

    Keeps the last ``keep`` messages, removes injected budget-exhausted
    directives, and drops leading ``tool`` messages whose assistant
    ``tool_calls`` message was cut off. Tool responses in the middle are never
    removed, so every retained assistant tool call keeps its response.
    """
    tail = list(shift_messages[-keep:]) if keep > 0 else []
    tail = [
        m
        for m in tail
        if not (
            m.get("role") == "user"
            and m.get("content") == _BUDGET_EXHAUSTED_DIRECTIVE
        )
    ]
    while tail and tail[0].get("role") == "tool":
        log.debug(
            "[%s] Dropping orphaned leading tool msg (tool_call_id=%s)",
            task_name,
            tail[0].get("tool_call_id"),
        )
        tail = tail[1:]
    return tail


async def run_task_episode(
    env: MedchainEnv,
    client: OpenAI,
    tools: List[dict],
    task_name: str,
) -> Dict[str, Any]:
    """Run one episode of ``task_name`` and return its outcome.

    Each loop iteration asks the model for one tool call, executes it with
    ``env.step()`` and prints a ``[STEP]`` line. When a shift ends, the chat
    history is replaced by the system prompt, the accumulated shift
    summaries and the last few messages of the finished shift.

    Model-call attempts, including rate-limited or rejected ones, are bounded
    by the task's step limit. ``[STEP]`` numbers and the returned ``steps``
    count only executed environment steps, so they match ``rewards``.

    Args:
        env: Connected environment client.
        client: OpenAI-compatible client used for chat completions.
        tools: Tool specs in OpenAI function-calling format.
        task_name: Task to run. Names missing from ``MAX_STEPS_PER_TASK``
            get a limit of 160.

    Returns:
        A dict with keys ``task``, ``reward`` (last positive step reward, else
        0.0), ``steps``, ``done``, ``rewards`` (one entry per executed step)
        and ``duration`` (seconds).

    Raises:
        Exception: Errors from ``env.reset``/``env.step`` and model-call
            errors other than rate-limit / bad-request propagate unchanged.
    """
    tool_names = [t["function"]["name"] for t in tools]
    max_steps = MAX_STEPS_PER_TASK.get(task_name, 160)

    # Emit [START] before touching the environment so that a failing reset
    # still yields a START line to pair with the caller's [END] line.
    log_start(task=task_name, model=MODEL_NAME)

    obs = await env.reset(task=task_name)
    obs = obs.observation
    dashboard = obs.metadata.get("dashboard", "")
    log.debug(
        "[%s] Episode started. Tools: %s max_steps=%d", task_name, tool_names, max_steps
    )

    chat_history: List[dict] = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": f"Your shift has started. Current dashboard:\n\n{dashboard}",
        },
    ]

    attempts = 0  # model-call attempts (bounded by max_steps)
    step_count = 0  # executed environment steps (reported in [STEP]/[END])
    final_reward = 0.0
    done = obs.done
    consecutive_errors = 0
    rewards: List[float] = []
    past_shift_summaries: List[str] = []
    current_shift_messages: List[dict] = []

    # 429 rate-limit tracking
    active_model = MODEL_NAME
    rate_limit_times: List[float] = []
    backoff_count = 0

    episode_start = time.monotonic()

    while not done and attempts < max_steps:
        attempts += 1
        step_no = step_count + 1  # number the next executed step will get
        log.debug(
            "[%s] Attempt %d/%d (env step %d) — %d messages in context",
            task_name,
            attempts,
            max_steps,
            step_no,
            len(chat_history),
        )

        try:
            # NOTE(review): this is a synchronous client call and blocks the
            # event loop for the duration of the request.
            response = client.chat.completions.create(
                model=active_model,
                messages=chat_history,
                tools=tools,
                tool_choice="required",
                max_completion_tokens=MAX_TOKENS,
                temperature=TEMPERATURE,
            )
        except RateLimitError as e:
            now = time.monotonic()
            rate_limit_times.append(now)
            # Purge timestamps outside the tracking window
            rate_limit_times[:] = [
                t for t in rate_limit_times if now - t <= _429_WINDOW
            ]
            backoff_count += 1
            backoff_secs = min(
                _429_BASE_BACKOFF * (2 ** (backoff_count - 1)), _429_MAX_BACKOFF
            )
            log.warning(
                "[%s] Step %d — 429 RateLimitError (count=%d in last %ds, backoff=%.1fs): %s",
                task_name,
                step_no,
                len(rate_limit_times),
                _429_WINDOW,
                backoff_secs,
                e,
            )
            if (
                len(rate_limit_times) >= _429_DOWNGRADE_THRESHOLD
                and active_model != SMALL_MODEL
            ):
                active_model = SMALL_MODEL
                log.warning(
                    "[%s] Step %d — Downgrading model to %s due to repeated 429 errors",
                    task_name,
                    step_no,
                    SMALL_MODEL,
                )
            await asyncio.sleep(backoff_secs)
            continue
        except BadRequestError as e:
            consecutive_errors += 1
            log.warning(
                "[%s] Step %d — BadRequestError (%d/%d): %s",
                task_name,
                step_no,
                consecutive_errors,
                MAX_CONSECUTIVE_ERRORS,
                e,
            )
            if consecutive_errors >= MAX_CONSECUTIVE_ERRORS:
                log.error(
                    "[%s] Aborting after %d consecutive errors",
                    task_name,
                    MAX_CONSECUTIVE_ERRORS,
                )
                break
            err_msg = (
                f"Your previous tool call was rejected with an error:\n{e}\n\n"
                "Please retry with a valid tool call. If your budget is exhausted, call end_shift()."
            )
            chat_history.append({"role": "user", "content": err_msg})
            current_shift_messages.append({"role": "user", "content": err_msg})
            continue

        # The model call succeeded: reset the error/backoff streaks.
        consecutive_errors = 0
        backoff_count = 0

        message = response.choices[0].message
        log.debug(
            "[%s] Step %d — finish_reason=%s tool_calls=%d",
            task_name,
            step_no,
            response.choices[0].finish_reason,
            len(message.tool_calls) if message.tool_calls else 0,
        )

        if not message.tool_calls:
            log.warning(
                "[%s] Step %d — no tool_calls in response; falling back to end_shift",
                task_name,
                step_no,
            )
            tool_name = "end_shift"
            tool_args = {}
            tool_call_id = "fallback"
        else:
            # Only the first tool call is executed; the assistant message
            # recorded below mirrors exactly that one call.
            tc = message.tool_calls[0]
            tool_name = tc.function.name
            tool_call_id = tc.id
            raw_arguments = getattr(tc.function, "arguments", None)
            try:
                tool_args = json.loads(raw_arguments)
            except (json.JSONDecodeError, AttributeError, TypeError):
                log.warning(
                    "[%s] Step %d — failed to parse tool arguments: %r",
                    task_name,
                    step_no,
                    raw_arguments,
                )
                tool_args = {}
            if not isinstance(tool_args, dict):
                # Valid JSON that is not an object (null, list, number, ...)
                # cannot be used as keyword arguments.
                log.warning(
                    "[%s] Step %d — tool arguments are not a JSON object: %r",
                    task_name,
                    step_no,
                    raw_arguments,
                )
                tool_args = {}
            if tool_name not in tool_names:
                log.warning(
                    "[%s] Step %d — unknown tool %r; falling back to end_shift",
                    task_name,
                    step_no,
                    tool_name,
                )
                tool_name = "end_shift"
                tool_args = {}

        log.debug(
            "[%s] Step %d — calling %s(%s)", task_name, step_no, tool_name, tool_args
        )

        assistant_msg = {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "id": tool_call_id,
                    "type": "function",
                    "function": {
                        "name": tool_name,
                        "arguments": json.dumps(tool_args),
                    },
                }
            ],
        }
        chat_history.append(assistant_msg)
        current_shift_messages.append(assistant_msg)

        action = CallToolAction(tool_name=tool_name, arguments=tool_args)
        step_result = await env.step(action)
        step_count += 1  # counted only once env.step() has returned

        obs = step_result.observation
        done = obs.done
        raw_result = obs.metadata.get("tool_result", str(obs.metadata))
        if raw_result is None:
            result_text = ""
        elif isinstance(raw_result, str):
            result_text = raw_result
        else:
            result_text = str(raw_result)
        step_reward = obs.reward or 0.0
        step_error: Optional[str] = None

        if "EPISODE COMPLETE" in result_text:
            log.info("[%s] Step %d — episode complete detected", task_name, step_count)
            done = True

        if obs.reward is not None and obs.reward > 0:
            final_reward = obs.reward

        rewards.append(step_reward)
        action_str = f"{tool_name}({json.dumps(tool_args)})"
        log_step(
            step=step_count,
            action=action_str,
            reward=step_reward,
            done=done,
            error=step_error,
        )

        tool_result_msg = {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "content": result_text[:2700] if result_text else "OK",
        }
        chat_history.append(tool_result_msg)
        current_shift_messages.append(tool_result_msg)

        # Budget exhausted — inject directive and skip sleep
        if "Action budget exhausted" in result_text:
            log.info(
                "[%s] Step %d — budget exhausted; injecting end_shift directive",
                task_name,
                step_count,
            )
            directive_msg = {"role": "user", "content": _BUDGET_EXHAUSTED_DIRECTIVE}
            chat_history.append(directive_msg)
            current_shift_messages.append(dict(directive_msg))
            continue

        await asyncio.sleep(SLEEP_BETWEEN_STEPS)

        # Shift ended — summarise and prune context, then set up next shift
        if tool_name == "end_shift" and "END OF SHIFT" in result_text and not done:
            # The first all-digit token of the report is taken as the day.
            shift_day = "?"
            for part in result_text.split():
                if part.isdigit():
                    shift_day = part
                    break

            shift_summary = _make_shift_summary(shift_day, result_text)
            log.debug("[%s] Shift %s summary:\n%s", task_name, shift_day, shift_summary)
            past_shift_summaries.append(shift_summary)
            log.info(
                "[%s] Step %d — shift %s ended; pruning context (%d summaries)",
                task_name,
                step_count,
                shift_day,
                len(past_shift_summaries),
            )

            summaries_msg = {
                "role": "user",
                "content": "COMPLETED SHIFT SUMMARIES:\n\n"
                + "\n\n".join(past_shift_summaries),
            }
            # Remove budget-exhausted directives so they don't bleed into the
            # next shift, while keeping each assistant tool call paired with
            # its tool response.
            trimmed = _carry_over_messages(
                current_shift_messages, SHIFT_HISTORY_KEEP, task_name
            )

            chat_history = (
                [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    summaries_msg,
                ]
                + trimmed
                + [
                    {
                        "role": "user",
                        "content": "Your next shift has begun. The dashboard is shown above in the last tool result. "
                        "Continue managing the supply chain.",
                    },
                ]
            )
            current_shift_messages = []

    episode_duration = time.monotonic() - episode_start
    log.info(
        "[%s] Episode finished. steps=%d done=%s final_reward=%.4f",
        task_name,
        step_count,
        done,
        final_reward,
    )
    log.debug("[%s] Episode duration: %.1fs", task_name, episode_duration)

    return {
        "task": task_name,
        "reward": final_reward,
        "steps": step_count,
        "done": done,
        "rewards": rewards,
        "duration": episode_duration,
    }
async def async_main() -> None:
    """Run every task in ``TASKS`` sequentially and print one ``[END]`` line each.

    For each task an environment is created, its tools are listed and
    converted, and one episode is run. Any failure, including a failure to
    create the environment, is logged and reported as an unsuccessful
    ``[END]`` line, after which the next task is attempted. The environment
    is closed before the ``[END]`` line is printed.

    Raises:
        SystemExit: If no API key or no model name is configured.
    """
    if not API_KEY:
        raise SystemExit("HF_TOKEN or API_KEY must be set.")
    if not MODEL_NAME:
        raise SystemExit("MODEL_NAME or MODEL must be set.")

    log.info("Starting. API_BASE_URL=%s MODEL_NAME=%s", API_BASE_URL, MODEL_NAME)
    log.info("Tasks: %s", TASKS)

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    results = []
    script_start = time.monotonic()

    for task_name in TASKS:
        log.info("Launching task: %s", task_name)
        task_start = time.monotonic()

        # Defaults reported on the [END] line when the task fails early.
        env = None
        final_reward = 0.0
        success = False
        steps = 0
        step_rewards: List[float] = []

        try:
            # Created inside the try block so a connection failure still
            # produces an [END] line and does not abort the remaining tasks.
            env = await create_env(task_name)
            mcp_tools = await env.list_tools()
            tools = _tools_to_openai_format(mcp_tools)
            log.info("[%s] %d tools discovered", task_name, len(tools))

            result = await run_task_episode(env, client, tools, task_name)
            results.append(result)
            final_reward = result["reward"]
            steps = result["steps"]
            success = result["done"]
            step_rewards = result["rewards"]
            log.info(
                "[%s] Task complete: reward=%.4f steps=%d",
                task_name,
                final_reward,
                steps,
            )
        except Exception as e:
            log.error("[%s] Task failed with exception: %s", task_name, e)
        finally:
            if env is not None:
                try:
                    await env.close()
                except Exception as e:
                    log.error("[%s] env.close() failed: %s", task_name, e)
            # Always emitted, after the environment has been closed.
            log_end(
                success=success, steps=steps, score=final_reward, rewards=step_rewards
            )
            log.debug(
                "[%s] Total task wall time: %.1fs",
                task_name,
                time.monotonic() - task_start,
            )

    total_duration = time.monotonic() - script_start
    if results:
        avg_reward = sum(r["reward"] for r in results) / len(results)
        log.info("All tasks complete. avg_reward=%.4f", avg_reward)
    log.debug("Overall script duration: %.1fs", total_duration)
def main() -> None:
    """Synchronous entry point: run ``async_main`` to completion."""
    entry_coroutine = async_main()
    asyncio.run(entry_coroutine)


# Run the script only when executed directly, not when imported.
if __name__ == "__main__":
    main()