dispatch-triage / inference.py
muskanp's picture
Upload folder using huggingface_hub
95cb3fd verified
"""
inference.py — 911 Dispatch Triage Environment
===============================================
Runs one episode per difficulty level (easy → medium → hard).
The LLM agent reads natural-language incident descriptions and available
unit types, then decides which unit to dispatch to which incident each step.
Required environment variables:
API_BASE_URL LLM API endpoint (default: HuggingFace Inference Router)
MODEL_NAME Model identifier (default: Qwen/Qwen2.5-72B-Instruct)
HF_TOKEN HuggingFace API token (REQUIRED — no default)
ENV_BASE_URL Running server URL (e.g. https://<user>-dispatch-triage-env.hf.space)
Optional:
LOCAL_IMAGE_NAME Docker image name (fallback when ENV_BASE_URL is not set)
Stdout format (strictly required by hackathon grader):
[START] task=<str> env=<str> model=<str>
[STEP] step=<int> action=<str> reward=<float> done=<bool> error=<str|null>
[END] success=<bool> steps=<int> rewards=<float,...>
"""
import json
import os
import re
import sys
import textwrap
import traceback
from typing import List, Optional, Tuple
from dotenv import load_dotenv
from openai import OpenAI
load_dotenv(override=True)
# DispatchTriageEnv is the CLIENT — connects to the running server via WebSocket.
# inference.py lives at the repo root, so models/client are absolute imports.
from client import DispatchTriageEnv
from models import DispatchTriageAction
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# LLM endpoint and model.  ``or`` is used instead of a getenv default on
# purpose: a variable that is set but empty (e.g. ``MODEL_NAME=`` in a .env
# file) would otherwise bypass the default and hand "" to the OpenAI client.
API_BASE_URL: str = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME: str = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"

# No defaults for these: they are validated at startup / consulted in main().
HF_TOKEN: Optional[str] = os.getenv("HF_TOKEN")
ENV_BASE_URL: Optional[str] = os.getenv("ENV_BASE_URL")
LOCAL_IMAGE_NAME: Optional[str] = os.getenv("LOCAL_IMAGE_NAME")

# Name reported in the ``env=`` field of every [START] line.
BENCHMARK: str = "dispatch_triage_env"
# Episodes are run in this order, one per difficulty.
DIFFICULTIES: List[str] = ["easy", "medium", "hard"]
MAX_STEPS: int = 8  # safety ceiling (3 valid dispatches per episode max)

# Sampling parameters forwarded to the chat-completions call.
TEMPERATURE: float = 0.4
MAX_TOKENS: int = 512

SUCCESS_THRESHOLD: float = 0.6  # normalised score in [0.0, 1.0]
# --- Startup validation ---
# Runs at import time: each check prints its diagnostic to stderr and then
# terminates the process with exit status 1.
if not HF_TOKEN:
    print(
        "[ERROR] HF_TOKEN environment variable is not set. Export your "
        "Hugging Face token before running inference.py.",
        file=sys.stderr,
        flush=True,
    )
    raise SystemExit(1)

# At least one way of reaching the environment server must be configured.
if not (ENV_BASE_URL or LOCAL_IMAGE_NAME):
    print(
        "\n".join(
            [
                "[ERROR] Neither ENV_BASE_URL nor LOCAL_IMAGE_NAME is set.",
                " • Set ENV_BASE_URL to your deployed HF Space URL, e.g.",
                " export ENV_BASE_URL=https://<user>-dispatch-triage-env.hf.space",
                " • Or set LOCAL_IMAGE_NAME to a local Docker image, e.g.",
                " export LOCAL_IMAGE_NAME=dispatch-triage-env:latest",
            ]
        ),
        file=sys.stderr,
        flush=True,
    )
    raise SystemExit(1)
# ---------------------------------------------------------------------------
# Structured stdout logging — format is graded automatically
# ---------------------------------------------------------------------------
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens an episode's stdout log.

    Args:
        task: Task label written to the ``task=`` field.
        env: Benchmark/environment label written to the ``env=`` field.
        model: Model identifier written to the ``model=`` field.
    """
    record = "[START] task={} env={} model={}".format(task, env, model)
    # Flushed immediately so the line is not held back in a buffer.
    print(record, flush=True)
def log_step(
    step: int,
    action: str,
    reward: float,
    done: bool,
    error: Optional[str],
) -> None:
    """Print one ``[STEP]`` record describing a single environment step.

    Args:
        step: Step number written to ``step=``.
        action: Action text written verbatim to ``action=``.
        reward: Step reward, rendered with four decimal places.
        done: Terminal flag, rendered in lowercase (``true`` / ``false``).
        error: Error text for ``error=``; ``None`` or an empty string is
            rendered as the literal ``null``.
    """
    error_field = error or "null"
    done_field = str(done).lower()
    fields = (
        "[STEP]",
        f"step={step}",
        f"action={action}",
        f"reward={reward:.4f}",
        f"done={done_field}",
        f"error={error_field}",
    )
    print(" ".join(fields), flush=True)
def log_end(success: bool, steps: int, rewards: List[float]) -> None:
    """Print the ``[END]`` record that closes an episode's stdout log.

    Args:
        success: Outcome flag, rendered in lowercase (``true`` / ``false``).
        steps: Number of steps written to ``steps=``.
        rewards: Per-step rewards; each is rendered with four decimals and
            the values are comma-joined (an empty list yields ``rewards=``).
    """
    formatted_rewards = ",".join(map("{:.4f}".format, rewards))
    outcome = str(success).lower()
    print(
        "[END] success={} steps={} rewards={}".format(outcome, steps, formatted_rewards),
        flush=True,
    )
# ---------------------------------------------------------------------------
# System prompt
# ---------------------------------------------------------------------------
# Fixed instructions for the dispatcher role; call_llm() sends this text as
# the "system" message of every chat-completion request.
# textwrap.dedent() removes whatever leading whitespace is common to all lines
# of the literal, and .strip() drops the blank first/last lines that the
# triple-quoted form introduces.
# The final two lines of the prompt define the reply contract that call_llm()
# parses: a JSON object carrying "incident_id" and "unit_id".
SYSTEM_PROMPT: str = textwrap.dedent("""
You are an AI-powered 911 dispatch operator.
Each turn you receive:
• ACTIVE INCIDENTS — each has an id, a plain-English caller description, and a location.
• AVAILABLE UNITS — each has an id and a type: ambulance | fire_truck | police.
Your task is to choose ONE incident and ONE unit to dispatch this turn.
Rules:
1. Infer urgency purely from the description — there are no explicit severity numbers.
2. Match unit type to the nature of the emergency:
ambulance → medical emergencies (cardiac arrest, unresponsive person, injuries)
fire_truck → fires, gas leaks, structural hazards
police → vehicle collisions, traffic control, crowd management
3. Always prioritise the most life-threatening, time-critical incident first.
4. If a description hints that one incident must be resolved before another can be
safely attended (e.g. a gas leak blocking entry to an adjacent building where a
cardiac arrest is occurring), resolve the blocking incident FIRST, even if the
other sounds more urgent on the surface.
5. Only dispatch to incidents and units that appear in the lists below.
Respond with ONLY a single-line JSON object — no markdown, no preamble:
{"incident_id": <int>, "unit_id": <int>, "reasoning": "<one concise sentence>"}
""").strip()
# ---------------------------------------------------------------------------
# Prompt construction helpers
# ---------------------------------------------------------------------------
def _format_incidents(incidents) -> str:
lines = []
for inc in incidents:
if not inc.resolved:
dep_note = f" [depends_on: {inc.depends_on}]" if inc.depends_on else ""
lines.append(
f" [id={inc.id}] {inc.location}{dep_note}\n"
f" {inc.description}"
)
return "\n".join(lines) if lines else " (none remaining)"
def _format_units(units) -> str:
lines = []
for u in units:
if u.available:
lines.append(f" [id={u.id}] {u.type}")
return "\n".join(lines) if lines else " (none remaining)"
def build_user_prompt(obs, step: int, last_message: str) -> str:
    """Assemble the per-step user message sent to the LLM.

    The prompt is built by joining explicit lines rather than running
    ``textwrap.dedent`` over an interpolated f-string: dedent computes its
    margin *after* interpolation, so multi-line incident/unit blocks (whose
    lines carry their own indentation) would otherwise decide how much
    leading whitespace the surrounding template lines keep.

    Args:
        obs: Current observation; ``dispatch_count``, ``score_so_far``,
            ``incidents`` and ``units`` are read from it.
        step: Step number shown in the header line.
        last_message: Text shown on the "Last server message" line.

    Returns:
        The prompt text, one section per line, ending with the JSON reply
        template.
    """
    prompt_lines = [
        f"Step {step} | Dispatches so far: {obs.dispatch_count}",
        f"Current score: {obs.score_so_far:.4f}",
        f"Last server message: {last_message}",
        "ACTIVE INCIDENTS (unresolved):",
        _format_incidents(obs.incidents),
        "AVAILABLE UNITS:",
        _format_units(obs.units),
        "Choose your next dispatch. Reply with exactly:",
        '{"incident_id": <int>, "unit_id": <int>, "reasoning": "<one sentence>"}',
    ]
    return "\n".join(prompt_lines)
# ---------------------------------------------------------------------------
# LLM call + JSON parsing
# ---------------------------------------------------------------------------
def _extract_dispatch_json(text: str) -> Optional[dict]:
    """Return the first JSON object in *text* that names a dispatch, else None.

    Every ``{`` is tried as the start of a JSON value with
    ``JSONDecoder.raw_decode``, which parses one complete value and ignores
    whatever follows it.  Unlike a non-greedy ``\\{.*?\\}`` regex this copes
    with ``}`` inside string values, nested objects, markdown fences and
    prose wrapped around the object.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            candidate = None
        if (
            isinstance(candidate, dict)
            and "incident_id" in candidate
            and "unit_id" in candidate
        ):
            return candidate
        start = text.find("{", start + 1)
    return None


def call_llm(
    client: OpenAI,
    user_prompt: str,
) -> Tuple[Optional[int], Optional[int], str]:
    """
    Call the LLM and parse its JSON response.

    Args:
        client: OpenAI-compatible client used for the chat-completion call.
        user_prompt: Content of the "user" message; SYSTEM_PROMPT is sent as
            the "system" message.

    Returns:
        (incident_id, unit_id, raw_response) on success.
        (None, None, error_description) when the API call itself fails.
        (None, None, raw_response) when the reply holds no usable JSON.
    """
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
        )
        raw: str = (completion.choices[0].message.content or "").strip()
    except Exception as exc:
        # Deliberately broad: any transport/API/shape error must degrade to
        # the caller's fallback action instead of aborting the episode.
        err = f"LLM call failed: {exc}"
        print(f"[DEBUG] {err}", flush=True)
        return None, None, err

    parsed = _extract_dispatch_json(raw)
    if parsed is None:
        reason = "no JSON object with incident_id and unit_id found"
    else:
        try:
            return int(parsed["incident_id"]), int(parsed["unit_id"]), raw
        except (TypeError, ValueError) as exc:
            # Keys are present but the values are not integer-like.
            reason = str(exc)

    print(f"[DEBUG] JSON parse failed: {reason} | raw={raw!r}", flush=True)
    return None, None, raw
# ---------------------------------------------------------------------------
# Fallback action selector
# ---------------------------------------------------------------------------
def _fallback_action(obs) -> Optional[Tuple[int, int]]:
"""
When the LLM response is unparseable, pick the first available
(incident, unit) pair as a safe fallback.
"""
available_incidents = [i for i in obs.incidents if not i.resolved]
available_units = [u for u in obs.units if u.available]
if available_incidents and available_units:
return available_incidents[0].id, available_units[0].id
return None
# ---------------------------------------------------------------------------
# Single episode runner
# ---------------------------------------------------------------------------
def run_episode(
    env,  # SyncEnvClient wrapping DispatchTriageEnv
    llm_client: OpenAI,
    difficulty: str,
) -> None:
    """
    Play a single episode at *difficulty* and log it in grader format.

    Prints one [START] line, one [STEP] line per executed dispatch and,
    always (even after an exception), one [END] line.

    Args:
        env: Synchronous environment client providing ``reset`` and ``step``.
        llm_client: OpenAI-compatible client handed to ``call_llm``.
        difficulty: Difficulty label passed to ``env.reset`` and embedded in
            the task name.
    """
    log_start(task=f"dispatch-{difficulty}", env=BENCHMARK, model=MODEL_NAME)

    reward_history: List[float] = []
    completed_steps: int = 0
    succeeded: bool = False

    try:
        # reset() yields a step result whose observation is the initial state.
        result = env.reset(difficulty=difficulty)
        observation = result.observation
        server_message: str = observation.message

        step = 0
        # Stop at the safety ceiling or as soon as the latest result is terminal.
        while step < MAX_STEPS and not result.done:
            step += 1

            prompt = build_user_prompt(observation, step, server_message)
            incident_id, unit_id, _raw_response = call_llm(llm_client, prompt)

            # No usable LLM answer: fall back to the first available pairing.
            step_error: Optional[str] = None
            if incident_id is None or unit_id is None:
                step_error = "JSON parse failed — using fallback action"
                pair = _fallback_action(observation)
                if pair is None:
                    print("[DEBUG] No valid fallback action — ending episode.", flush=True)
                    break
                incident_id, unit_id = pair

            dispatch = DispatchTriageAction(incident_id=incident_id, unit_id=unit_id)
            action_label = f"dispatch(incident_id={incident_id},unit_id={unit_id})"

            # Send the dispatch to the environment server.
            result = env.step(dispatch)
            observation = result.observation
            # A missing reward is recorded as 0.0.
            reward = 0.0 if result.reward is None else result.reward
            finished = result.done
            server_message = observation.message

            reward_history.append(reward)
            completed_steps = step
            log_step(
                step=step,
                action=action_label,
                reward=reward,
                done=finished,
                error=step_error,
            )

        # Success is judged on the reward of the last executed step.
        last_reward = reward_history[-1] if reward_history else 0.0
        succeeded = last_reward >= SUCCESS_THRESHOLD
    except Exception as exc:
        print(f"[DEBUG] Episode error ({difficulty}): {exc}", flush=True)
        traceback.print_exc(file=sys.stderr)
    finally:
        log_end(success=succeeded, steps=completed_steps, rewards=reward_history)
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main() -> None:
    """Create the LLM and environment clients, then run every difficulty.

    ENV_BASE_URL is preferred when set; otherwise the environment client is
    created from the LOCAL_IMAGE_NAME Docker image.  Any exception raised
    while the environment is open is reported and the process exits with
    status 1.
    """
    llm_client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

    # The environment client is asynchronous; .sync() supplies the blocking
    # adapter that run_episode() drives.
    async_env = (
        DispatchTriageEnv(base_url=ENV_BASE_URL)
        if ENV_BASE_URL
        else DispatchTriageEnv.from_docker_image(LOCAL_IMAGE_NAME)
    )
    sync_env = async_env.sync()

    try:
        with sync_env:
            for level in DIFFICULTIES:
                run_episode(sync_env, llm_client, level)
    except Exception as exc:
        print(f"[DEBUG] Fatal error: {exc}", flush=True)
        traceback.print_exc(file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()