supply-chain-env / inference.py
ragavrida's picture
fix: clamp ALL rewards/scores to strict (0.01, 0.99) β€” every output path
29994af
"""
SupplyChainEnv β€” Inference Script
===================================
MANDATORY
- Before submitting, ensure the following variables are defined in your environment configuration:
API_BASE_URL The API endpoint for the LLM.
MODEL_NAME The model identifier to use for inference.
HF_TOKEN Your Hugging Face / API key.
IMAGE_NAME The name of the local image to use for the environment if using from_docker_image()
- The inference script must be named `inference.py` and placed in the root directory of the project
- Participants must use OpenAI Client for all LLM calls using above variables
STDOUT FORMAT
- The script must emit exactly three line types to stdout, in this order:
[START] task=<task_name> env=<benchmark> model=<model_name>
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
[END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...,rn>
"""
import asyncio
import json
import os
import re
import sys
import textwrap
import time
from typing import Dict, List, Optional
from openai import OpenAI
from models import SupplyChainAction, SupplyChainObservation
from client import SupplyChainEnv
# ─── Configuration (matches reference inference exactly) ─────────────────────
IMAGE_NAME = os.getenv("IMAGE_NAME")
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
BENCHMARK = "supply-chain-env"
TASK_NAMES = ["single_shipment", "disruption_reroute", "crisis_management"]
MAX_STEPS = 50
TEMPERATURE = 0.0
MAX_TOKENS = 500
print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", file=sys.stderr, flush=True)
print(f"[DEBUG] API_KEY=...{(API_KEY or '')[-8:]}", file=sys.stderr, flush=True)
print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", file=sys.stderr, flush=True)
print(f"[DEBUG] IMAGE_NAME={IMAGE_NAME}", file=sys.stderr, flush=True)
# ─── Logging (exact spec format) ────────────────────────────────────────────
def log_start(task: str, env: str, model: str) -> None:
print(f"[START] task={task} env={env} model={model}", flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
error_val = error if error else "null"
done_val = str(done).lower()
print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
# ─── LLM Call (uses platform proxy) ─────────────────────────────────────────
def call_llm(client: OpenAI, system: str, user: str) -> str:
"""Call LLM through the platform's LiteLLM proxy. Retries on failure."""
for attempt in range(3):
try:
r = client.chat.completions.create(
model=MODEL_NAME,
messages=[
{"role": "system", "content": system},
{"role": "user", "content": user},
],
temperature=TEMPERATURE,
max_tokens=MAX_TOKENS,
stream=False,
)
return (r.choices[0].message.content or "").strip()
except Exception as e:
print(f"[DEBUG] LLM attempt {attempt+1} failed: {e}", file=sys.stderr, flush=True)
if attempt < 2:
time.sleep(2 ** attempt)
return ""
def parse_json(text: str) -> Optional[Dict]:
if not text:
return None
text = text.strip()
if text.startswith("```"):
text = "\n".join(l for l in text.split("\n") if not l.strip().startswith("```"))
try:
return json.loads(text)
except json.JSONDecodeError:
pass
m = re.search(r'\{.*\}', text, re.DOTALL)
if m:
try:
return json.loads(m.group(0))
except json.JSONDecodeError:
pass
return None
# ─── System Prompt ───────────────────────────────────────────────────────────
SYSTEM_PROMPT = textwrap.dedent("""
You are a supply chain logistics manager using MCP tools.
Route shipments to destinations before deadlines, avoiding disrupted ports.
Strategy:
1. Call view_network to see port status
2. Call get_disruptions to see what's blocked
3. Call view_shipments to see pending shipments
4. For each: find_path then route_shipment
5. Call advance_day to progress time
6. Call end_simulation when done
Respond with ONE JSON tool call:
{"tool_name": "view_network", "arguments": {}}
{"tool_name": "route_shipment", "arguments": {"shipment_id": "ship_0", "route": ["port_singapore", "port_la"]}}
{"tool_name": "advance_day", "arguments": {}}
{"tool_name": "end_simulation", "arguments": {}}
""").strip()
# ─── Task Runner ─────────────────────────────────────────────────────────────
async def run_task(env: SupplyChainEnv, client: OpenAI, task_name: str) -> tuple:
"""Run one task episode. Returns (score, steps, rewards). Never crashes."""
rewards: List[float] = []
steps_taken = 0
try:
result = await env.reset(seed=42, task=task_name)
obs = result.observation
except Exception as e:
print(f"[ERROR] reset failed for {task_name}: {e}", file=sys.stderr, flush=True)
return 0.0, 0, []
context = ""
try:
context = json.dumps(getattr(obs, 'tool_result', None), default=str)[:2000]
except Exception:
context = ""
for step in range(1, MAX_STEPS + 1):
if getattr(result, 'done', False):
break
# Ask LLM for next tool call
prompt = f"Task: {task_name}. Step {step}. Last result:\n{context}\n\nWhat tool to call next?"
response = call_llm(client, SYSTEM_PROMPT, prompt)
parsed = parse_json(response)
if parsed and parsed.get("tool_name"):
action = SupplyChainAction(
action_type="ToolCallAction",
tool_name=parsed["tool_name"],
arguments=parsed.get("arguments", {}),
)
action_str = parsed["tool_name"]
else:
action = SupplyChainAction(action_type="ToolCallAction", tool_name="advance_day", arguments={})
action_str = "advance_day"
reward, done, error = 0.0, False, None
try:
result = await env.step(action)
obs = result.observation
reward = result.reward or 0.0
done = result.done
error = getattr(obs, 'error_message', None)
context = json.dumps(getattr(obs, 'tool_result', None), default=str)[:2000]
except Exception as e:
reward, done, error = 0.0, True, str(e)
context = ""
print(f"[ERROR] step {step} failed: {e}", file=sys.stderr, flush=True)
rewards.append(reward)
steps_taken = step
log_step(step=step, action=action_str, reward=reward, done=done, error=error)
if done:
break
score = max(0.01, min(0.99, rewards[-1] if rewards else 0.01))
return score, steps_taken, rewards
# ─── Main ────────────────────────────────────────────────────────────────────
async def main() -> None:
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
# Verify LLM proxy is reachable FIRST β€” this ensures at least one API call
print(f"[DEBUG] Testing LLM proxy at {API_BASE_URL}...", file=sys.stderr, flush=True)
try:
test = client.chat.completions.create(
model=MODEL_NAME,
messages=[{"role": "user", "content": "ping"}],
max_tokens=5,
temperature=0.0,
)
print(f"[DEBUG] Proxy OK: {test.choices[0].message.content!r}", file=sys.stderr, flush=True)
except Exception as e:
print(f"[WARNING] Proxy test failed: {e}. Continuing anyway.", file=sys.stderr, flush=True)
# Connect to environment via Docker image (platform provides IMAGE_NAME)
print(f"[DEBUG] Connecting to env via docker image: {IMAGE_NAME}", file=sys.stderr, flush=True)
try:
env = await SupplyChainEnv.from_docker_image(IMAGE_NAME)
except Exception as e:
print(f"[ERROR] from_docker_image failed: {e}", file=sys.stderr, flush=True)
import traceback
traceback.print_exc(file=sys.stderr)
for task_name in TASK_NAMES:
log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
log_end(success=False, steps=0, score=0.01, rewards=[])
return
scores = {}
for task_name in TASK_NAMES:
score = 0.01
steps_taken = 0
rewards: List[float] = []
success = False
log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
try:
score, steps_taken, rewards = await run_task(env, client, task_name)
success = score >= 0.1
except Exception as e:
print(f"[ERROR] Task {task_name} failed: {e}", file=sys.stderr, flush=True)
import traceback
traceback.print_exc(file=sys.stderr)
finally:
score = min(max(score, 0.01), 0.99)
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
scores[task_name] = score
# Cleanup
try:
await env.close()
except Exception as e:
print(f"[DEBUG] env.close() error: {e}", file=sys.stderr, flush=True)
composite = sum(scores.values()) / len(scores) if scores else 0.0
print(f"\n[SUMMARY] composite={composite:.3f} " + " ".join(f"{k}={v:.3f}" for k, v in scores.items()), file=sys.stderr, flush=True)
if __name__ == "__main__":
asyncio.run(main())