| """ |
SupplyChainEnv — Inference Script
| =================================== |
| MANDATORY |
| - Before submitting, ensure the following variables are defined in your environment configuration: |
| API_BASE_URL The API endpoint for the LLM. |
| MODEL_NAME The model identifier to use for inference. |
| HF_TOKEN Your Hugging Face / API key. |
| IMAGE_NAME The name of the local image to use for the environment if using from_docker_image() |
| |
| - The inference script must be named `inference.py` and placed in the root directory of the project |
| - Participants must use OpenAI Client for all LLM calls using above variables |
| |
| STDOUT FORMAT |
| - The script must emit exactly three line types to stdout, in this order: |
| |
| [START] task=<task_name> env=<benchmark> model=<model_name> |
| [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null> |
| [END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...,rn> |
| """ |
|
|
import asyncio
import json
import os
import re
import sys
import textwrap
import time
import traceback
from typing import Dict, List, Optional

from openai import OpenAI

from client import SupplyChainEnv
from models import SupplyChainAction, SupplyChainObservation
|
|
| |
|
|
| IMAGE_NAME = os.getenv("IMAGE_NAME") |
| API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") |
| API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1" |
| MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct" |
|
|
| BENCHMARK = "supply-chain-env" |
| TASK_NAMES = ["single_shipment", "disruption_reroute", "crisis_management"] |
| MAX_STEPS = 50 |
| TEMPERATURE = 0.0 |
| MAX_TOKENS = 500 |
|
|
| print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", file=sys.stderr, flush=True) |
| print(f"[DEBUG] API_KEY=...{(API_KEY or '')[-8:]}", file=sys.stderr, flush=True) |
| print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", file=sys.stderr, flush=True) |
| print(f"[DEBUG] IMAGE_NAME={IMAGE_NAME}", file=sys.stderr, flush=True) |
|
|
|
|
| |
|
|
| def log_start(task: str, env: str, model: str) -> None: |
| print(f"[START] task={task} env={env} model={model}", flush=True) |
|
|
|
|
| def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None: |
| error_val = error if error else "null" |
| done_val = str(done).lower() |
| print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True) |
|
|
|
|
| def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None: |
| rewards_str = ",".join(f"{r:.2f}" for r in rewards) |
| print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True) |
|
|
|
|
| |
|
|
| def call_llm(client: OpenAI, system: str, user: str) -> str: |
| """Call LLM through the platform's LiteLLM proxy. Retries on failure.""" |
| for attempt in range(3): |
| try: |
| r = client.chat.completions.create( |
| model=MODEL_NAME, |
| messages=[ |
| {"role": "system", "content": system}, |
| {"role": "user", "content": user}, |
| ], |
| temperature=TEMPERATURE, |
| max_tokens=MAX_TOKENS, |
| stream=False, |
| ) |
| return (r.choices[0].message.content or "").strip() |
| except Exception as e: |
| print(f"[DEBUG] LLM attempt {attempt+1} failed: {e}", file=sys.stderr, flush=True) |
| if attempt < 2: |
| time.sleep(2 ** attempt) |
| return "" |
|
|
|
|
| def parse_json(text: str) -> Optional[Dict]: |
| if not text: |
| return None |
| text = text.strip() |
| if text.startswith("```"): |
| text = "\n".join(l for l in text.split("\n") if not l.strip().startswith("```")) |
| try: |
| return json.loads(text) |
| except json.JSONDecodeError: |
| pass |
| m = re.search(r'\{.*\}', text, re.DOTALL) |
| if m: |
| try: |
| return json.loads(m.group(0)) |
| except json.JSONDecodeError: |
| pass |
| return None |
|
|
|
|
| |
|
|
| SYSTEM_PROMPT = textwrap.dedent(""" |
| You are a supply chain logistics manager using MCP tools. |
| Route shipments to destinations before deadlines, avoiding disrupted ports. |
| |
| Strategy: |
| 1. Call view_network to see port status |
| 2. Call get_disruptions to see what's blocked |
| 3. Call view_shipments to see pending shipments |
| 4. For each: find_path then route_shipment |
| 5. Call advance_day to progress time |
| 6. Call end_simulation when done |
| |
| Respond with ONE JSON tool call: |
| {"tool_name": "view_network", "arguments": {}} |
| {"tool_name": "route_shipment", "arguments": {"shipment_id": "ship_0", "route": ["port_singapore", "port_la"]}} |
| {"tool_name": "advance_day", "arguments": {}} |
| {"tool_name": "end_simulation", "arguments": {}} |
| """).strip() |
|
|
|
|
| |
|
|
| async def run_task(env: SupplyChainEnv, client: OpenAI, task_name: str) -> tuple: |
| """Run one task episode. Returns (score, steps, rewards). Never crashes.""" |
| rewards: List[float] = [] |
| steps_taken = 0 |
|
|
| try: |
| result = await env.reset(seed=42, task=task_name) |
| obs = result.observation |
| except Exception as e: |
| print(f"[ERROR] reset failed for {task_name}: {e}", file=sys.stderr, flush=True) |
| return 0.0, 0, [] |
|
|
| context = "" |
| try: |
| context = json.dumps(getattr(obs, 'tool_result', None), default=str)[:2000] |
| except Exception: |
| context = "" |
|
|
| for step in range(1, MAX_STEPS + 1): |
| if getattr(result, 'done', False): |
| break |
|
|
| |
| prompt = f"Task: {task_name}. Step {step}. Last result:\n{context}\n\nWhat tool to call next?" |
| response = call_llm(client, SYSTEM_PROMPT, prompt) |
| parsed = parse_json(response) |
|
|
| if parsed and parsed.get("tool_name"): |
| action = SupplyChainAction( |
| action_type="ToolCallAction", |
| tool_name=parsed["tool_name"], |
| arguments=parsed.get("arguments", {}), |
| ) |
| action_str = parsed["tool_name"] |
| else: |
| action = SupplyChainAction(action_type="ToolCallAction", tool_name="advance_day", arguments={}) |
| action_str = "advance_day" |
|
|
| reward, done, error = 0.0, False, None |
| try: |
| result = await env.step(action) |
| obs = result.observation |
| reward = result.reward or 0.0 |
| done = result.done |
| error = getattr(obs, 'error_message', None) |
| context = json.dumps(getattr(obs, 'tool_result', None), default=str)[:2000] |
| except Exception as e: |
| reward, done, error = 0.0, True, str(e) |
| context = "" |
| print(f"[ERROR] step {step} failed: {e}", file=sys.stderr, flush=True) |
|
|
| rewards.append(reward) |
| steps_taken = step |
|
|
| log_step(step=step, action=action_str, reward=reward, done=done, error=error) |
|
|
| if done: |
| break |
|
|
| score = max(0.01, min(0.99, rewards[-1] if rewards else 0.01)) |
| return score, steps_taken, rewards |
|
|
|
|
| |
|
|
| async def main() -> None: |
| client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) |
|
|
| |
| print(f"[DEBUG] Testing LLM proxy at {API_BASE_URL}...", file=sys.stderr, flush=True) |
| try: |
| test = client.chat.completions.create( |
| model=MODEL_NAME, |
| messages=[{"role": "user", "content": "ping"}], |
| max_tokens=5, |
| temperature=0.0, |
| ) |
| print(f"[DEBUG] Proxy OK: {test.choices[0].message.content!r}", file=sys.stderr, flush=True) |
| except Exception as e: |
| print(f"[WARNING] Proxy test failed: {e}. Continuing anyway.", file=sys.stderr, flush=True) |
|
|
| |
| print(f"[DEBUG] Connecting to env via docker image: {IMAGE_NAME}", file=sys.stderr, flush=True) |
| try: |
| env = await SupplyChainEnv.from_docker_image(IMAGE_NAME) |
| except Exception as e: |
| print(f"[ERROR] from_docker_image failed: {e}", file=sys.stderr, flush=True) |
| import traceback |
| traceback.print_exc(file=sys.stderr) |
| for task_name in TASK_NAMES: |
| log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME) |
| log_end(success=False, steps=0, score=0.01, rewards=[]) |
| return |
|
|
| scores = {} |
|
|
| for task_name in TASK_NAMES: |
| score = 0.01 |
| steps_taken = 0 |
| rewards: List[float] = [] |
| success = False |
|
|
| log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME) |
|
|
| try: |
| score, steps_taken, rewards = await run_task(env, client, task_name) |
| success = score >= 0.1 |
| except Exception as e: |
| print(f"[ERROR] Task {task_name} failed: {e}", file=sys.stderr, flush=True) |
| import traceback |
| traceback.print_exc(file=sys.stderr) |
| finally: |
| score = min(max(score, 0.01), 0.99) |
| log_end(success=success, steps=steps_taken, score=score, rewards=rewards) |
|
|
| scores[task_name] = score |
|
|
| |
| try: |
| await env.close() |
| except Exception as e: |
| print(f"[DEBUG] env.close() error: {e}", file=sys.stderr, flush=True) |
|
|
| composite = sum(scores.values()) / len(scores) if scores else 0.0 |
| print(f"\n[SUMMARY] composite={composite:.3f} " + " ".join(f"{k}={v:.3f}" for k, v in scores.items()), file=sys.stderr, flush=True) |
|
|
|
|
| if __name__ == "__main__": |
| asyncio.run(main()) |
|
|