metrollm-bench-mac / harness /mock_server.py
Remco Hendriks
Update Mac bench dist
2d05890 verified
"""Mock tool server for MetroLLM-Bench.
Exposes the transit tool endpoints that the benchmark runner forwards LLM
tool calls to, including:
POST /route_planner
POST /fare_calculator
POST /station_info
Run via:
uvicorn harness.mock_server:app --port 8100
or via the project entry-point:
mock-server --system marta --port 8100
"""
import argparse
import dataclasses
import hashlib
import json
import sys
from pathlib import Path
from typing import Optional
import networkx as nx
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
from pydantic import BaseModel, Field
from harness.graph import MetroGraph
from harness.fares import FareCalculator
# ---------------------------------------------------------------------------
# LLM config (set by main(), used by /simulate)
# ---------------------------------------------------------------------------
# These are only defaults: main() overwrites all four from the parsed CLI
# arguments before the server starts.
_llm_base_url: str = "https://api.anthropic.com/v1"
_llm_api_key: str = ""
_llm_model: str = "claude-haiku-4-5-20251001"
# Port this server listens on; /simulate uses it to build the localhost URL
# it hands to the benchmark runner as the mock server address.
_port: int = 8100
# ---------------------------------------------------------------------------
# Application state — populated at startup
# ---------------------------------------------------------------------------
app = FastAPI(title="MetroLLM-Bench Mock Tool Server")
@dataclasses.dataclass
class _SystemData:
    """Per-system data loaded lazily and cached."""

    # Routing graph constructed from the system's data directory.
    metro: MetroGraph
    # Fare engine constructed from the same directory.
    fares: FareCalculator
    # Policy records read from policies.json (empty when the file is absent).
    policies: list[dict]
    # Lower-cased line id / line name → canonical line id.
    line_alias: dict[str, str]
    # route_id → {"origin", "destination", "distance_miles"}; written by
    # /route_planner and read back by /fare_calculator.
    route_cache: dict[str, dict] = dataclasses.field(default_factory=dict)
_systems: dict[str, _SystemData] = {}  # system_name → data (lazy cache)
_case_system: dict[str, str] = {}  # case_id → system_name
_system_name: str = ""  # default system (set at startup)
# case_id → disruption dicts; written by /set_disruptions and /simulate,
# read by /disruption_feed.
_disruptions_by_case: dict[str, list[dict]] = {}
def _build_line_alias(system_dir: Path) -> dict[str, str]:
    """Build an alias → canonical line id map from ``lines.json``.

    Every line contributes its lower-cased id and, when present and
    non-empty, its lower-cased name. For a line with id="1", name="Line 1":
        "1" → "1", "line 1" → "1"
    For id="red", name="Red Line":
        "red" → "red", "red line" → "red"

    Args:
        system_dir: Directory expected to contain ``lines.json``.

    Returns:
        The alias map; empty when ``lines.json`` does not exist.
    """
    lines_path = system_dir / "lines.json"
    if not lines_path.exists():
        return {}
    # JSON text is UTF-8; decode it explicitly so non-ASCII line names do not
    # depend on the platform's locale encoding.
    with open(lines_path, encoding="utf-8") as f:
        lines = json.load(f)
    alias: dict[str, str] = {}
    for line in lines:
        lid = line["id"]
        alias[lid.lower()] = lid
        name = line.get("name")
        if name:
            alias[name.lower()] = lid
    return alias
# <parent of this file's directory>/data/systems — one sub-directory per transit system.
_DATA_ROOT = Path(__file__).resolve().parent.parent / "data" / "systems"
def _load_system(name: str) -> _SystemData:
    """Load system data from disk, caching for subsequent calls.

    Args:
        name: System directory name under ``_DATA_ROOT``. It can originate
            from request payloads (``/simulate``, ``/set_disruptions``), so
            it is validated before it is joined to a path.

    Returns:
        The cached or freshly loaded ``_SystemData``.

    Raises:
        ValueError: If ``name`` is empty, is not a single path component, or
            does not name a directory under ``_DATA_ROOT``.
    """
    cached = _systems.get(name)
    if cached is not None:
        return cached
    # Only accept a plain directory name. An empty name would resolve to
    # _DATA_ROOT itself, and "..", separators or absolute paths would escape it.
    if name in {"", ".."} or "\\" in name or Path(name).name != name:
        raise ValueError(f"Unknown system: {name}")
    sys_dir = _DATA_ROOT / name
    if not sys_dir.is_dir():
        raise ValueError(f"Unknown system: {name}")
    policies: list[dict] = []
    policies_path = sys_dir / "policies.json"
    if policies_path.exists():
        # Decode explicitly as UTF-8 rather than the locale encoding.
        raw = json.loads(policies_path.read_text(encoding="utf-8"))
        # The file is either a bare list or an object wrapping it under "policies".
        policies = raw["policies"] if isinstance(raw, dict) and "policies" in raw else raw
    sd = _SystemData(
        metro=MetroGraph(sys_dir),
        fares=FareCalculator(sys_dir),
        policies=policies,
        line_alias=_build_line_alias(sys_dir),
    )
    _systems[name] = sd
    return sd
def _system_for_case(case_id: str | None) -> _SystemData:
    """Return the system data registered for ``case_id``.

    Cases without a registered system (or a missing ``case_id``) use the
    startup default. Raises ``RuntimeError`` when neither yields a name.
    """
    lookup_key = case_id or ""
    system = _case_system.get(lookup_key, _system_name)
    if system:
        return _load_system(system)
    raise RuntimeError("No system configured")
# ---------------------------------------------------------------------------
# Pydantic models
# ---------------------------------------------------------------------------
# --- /route_planner ----------------------------------------------------------
class StationRestriction(BaseModel):
    """One per-station routing restriction sent to ``/route_planner``."""

    station: str
    restriction: str  # "closed", "skip", "no_transfer"
class LineClosure(BaseModel):
    """A closure on one line, optionally bounded by two stations.

    ``line`` may be an alias; ``/route_planner`` maps it through the system's
    line alias table. The station bounds are forwarded only when set.
    """

    line: str
    from_station: str | None = None
    to_station: str | None = None
class RoutePlannerRequest(BaseModel):
    """Request body for ``POST /route_planner``.

    Restricted routing is used when any of ``station_restrictions``,
    ``segment_closures`` or ``line_closures`` is non-empty; otherwise a plain
    shortest path is computed.
    """

    origin: str
    destination: str
    # Accepted, but not read by the route_planner handler in this file.
    departure_time: str | None = None
    accessibility: list[str] | None = None
    station_restrictions: list[StationRestriction] | None = None
    # Each inner list is converted to a tuple before being passed to the graph.
    segment_closures: list[list[str]] | None = None
    line_closures: list[LineClosure] | None = None
    # Selects the transit system registered for this case.
    case_id: str | None = None
class StopInfo(BaseModel):
    """One stop along a planned route.

    ``line`` and ``transfer_to`` are nullable but carry no default.
    """

    station_id: str
    station_name: str
    line: str | None
    is_transfer: bool
    transfer_to: str | None
class RoutePlannerResponse(BaseModel):
    """Response body for ``POST /route_planner``."""

    # Deterministic id derived from the request's origin and destination;
    # /fare_calculator uses it to look up cached route details.
    route_id: str
    stops: list[StopInfo]
    transfers: int
    estimated_minutes: float
    distance_miles: float
    line_sequence: list[str]
# --- /fare_calculator --------------------------------------------------------
class PassengerCounts(BaseModel):
    """Non-negative head counts per fare category (each defaults to 0)."""

    adults: int = Field(default=0, ge=0)
    children: int = Field(default=0, ge=0)
    seniors: int = Field(default=0, ge=0)
    disabled: int = Field(default=0, ge=0)
class FareCalculatorRequest(BaseModel):
    """Request body for ``POST /fare_calculator``."""

    # Id returned by /route_planner; used to look up the cached route details.
    route_id: str
    passengers: PassengerCounts
    ticket_type: str = "single"
    payment_method: str = "breeze_card"
    # Selects the transit system registered for this case.
    case_id: str | None = None
class LineItem(BaseModel):
    """One charge line of a fare quote (built from the fare engine's items)."""

    label: str
    amount: float
    currency: str
class Discount(BaseModel):
    """One discount line of a fare quote (built from the fare engine's discounts)."""

    label: str
    amount: float
    currency: str
class FareCalculatorResponse(BaseModel):
    """Response body for ``POST /fare_calculator``."""

    # Deterministic id derived from route_id, passenger counts and ticket type.
    fare_id: str
    line_items: list[LineItem]
    subtotal: float
    discounts: list[Discount]
    total: float
    currency: str
# --- /station_info -----------------------------------------------------------
# query_type values accepted by /station_info; anything else is rejected with 422.
VALID_QUERY_TYPES = frozenset(
    {"accessibility", "facilities", "exits", "connections", "real_time_status"}
)
class StationInfoRequest(BaseModel):
    """Request body for ``POST /station_info``.

    Provide ``station_id`` for a single lookup or ``station_ids`` for a
    batch; a non-empty ``station_ids`` takes precedence.
    """

    station_id: str | None = None
    station_ids: list[str] | None = None
    # Must be one of VALID_QUERY_TYPES (checked by the handler).
    query_type: str
    case_id: str | None = None
class StationInfoResponse(BaseModel):
    """Station data subset for one station and one query type."""

    station_id: str
    data: dict
class StationInfoBatchResponse(BaseModel):
    """Batch wrapper returned when ``station_ids`` was supplied."""

    results: list[StationInfoResponse]
# --- /line_info --------------------------------------------------------------
class LineInfoRequest(BaseModel):
    """Request body for ``POST /line_info``.

    Provide ``line`` for a single lookup or ``lines`` for a batch; a
    non-empty ``lines`` takes precedence.
    """

    line: str | None = None
    lines: list[str] | None = None
    case_id: str | None = None
class LineStationEntry(BaseModel):
    """One station on a line, in line order."""

    station_id: str
    station_name: str
    position: int  # zero-based index within the line's station list
    is_terminus: bool
    connections: list[str]  # other line ids at this station (empty if single-line)
class LineInfoResponse(BaseModel):
    """Description of one line and its ordered stations."""

    line_id: str  # canonical id after alias resolution
    line_name: str
    color: str  # empty string when the line record has no color
    station_count: int
    is_loop: bool
    terminals: list[str]
    stations: list[LineStationEntry]
class LineInfoBatchResponse(BaseModel):
    """Batch wrapper returned when ``lines`` was supplied."""

    results: list[LineInfoResponse]
# --- /disruption_feed -------------------------------------------------------
class DisruptionFeedRequest(BaseModel):
    """Request body for ``POST /disruption_feed``; every filter is optional."""

    case_id: str | None = None  # internal: set by runner, not exposed to LLM
    current_time: str | None = None  # ISO 8601 naive timestamp for temporal filtering
    line: str | None = None
    station: str | None = None
    severity_filter: str = "all"  # all, major, minor
class DisruptionEntry(BaseModel):
    """One disruption as returned by ``/disruption_feed``.

    Instances are built from the raw disruption dicts stored per case.
    """

    id: str
    line: str | None = None
    segment: list[str] | None = None
    type: str
    severity: str
    message: str
    alternative: str | None = None
    eta_resolution: str | None = None
    # Validity window; the feed compares valid_until to current_time as strings.
    valid_from: str | None = None
    valid_until: str | None = None
class DisruptionFeedResponse(BaseModel):
    """Response body for ``POST /disruption_feed``."""

    disruptions: list[DisruptionEntry]
# --- /knowledge_base --------------------------------------------------------
class KnowledgeBaseRequest(BaseModel):
    """Request body for ``POST /knowledge_base``.

    A non-empty ``policy_id`` selects the exact-lookup path; otherwise
    ``query`` drives a keyword search.
    """

    policy_id: str = ""
    query: str = ""
    # Accepted, but not read by the knowledge_base handler in this file.
    category: str = "general"
    case_id: str | None = None
class KnowledgeBaseResult(BaseModel):
    """One policy returned by ``/knowledge_base``."""

    title: str
    content: str
    policy_id: str
class KnowledgeBaseResponse(BaseModel):
    """Response body for ``POST /knowledge_base``."""

    results: list[KnowledgeBaseResult]
    found: bool  # False whenever results is empty
# --- /submit_assistant_state ------------------------------------------------
class RouteInfo(BaseModel):
    """Route section of the final assistant state."""

    origin: str
    destination: str
    # Untyped list: the submit handler accepts bare station ids as well as
    # dicts carrying a "station_id" key.
    stops: list
    transfers: int
    estimated_minutes: float
    distance_miles: float
    line_sequence: list[str]
class AdvisoryBanner(BaseModel):
    """One advisory banner included in the final assistant state."""

    severity: str
    title: str
    body: str
class FareQuoteInfo(BaseModel):
    """Fare quote section of the final assistant state."""

    passenger_summary: dict | None = None
    line_items: list[dict] = Field(default_factory=list)
    discounts: list[dict] = Field(default_factory=list)
    total: float
    currency: str
class KioskAction(BaseModel):
    """Kiosk action chosen by the assistant."""

    action: str  # checked against VALID_ACTIONS by the submit handler
    reason_code: str  # checked against VALID_REASON_CODES by the submit handler
# Allowed values for SubmitAssistantStateRequest.outcome.
VALID_OUTCOMES = frozenset({
    "route_and_fare_ready", "advisory_only", "service_unavailable",
    "request_declined", "policy_answer_only",
})
# Allowed values for KioskAction.action.
VALID_ACTIONS = frozenset({
    "display_info", "prompt_purchase", "block_purchase", "refer_to_staff",
})
# Allowed values for KioskAction.reason_code.
VALID_REASON_CODES = frozenset({
    "ok", "no_service", "invalid_request", "unsupported_request",
    "accessibility_issue", "policy_exception",
})
class SubmitAssistantStateRequest(BaseModel):
    """Request body for ``POST /submit_assistant_state``.

    The handler additionally requires ``route`` for the outcomes
    "route_and_fare_ready" and "advisory_only", and ``fare_quote`` for
    "route_and_fare_ready".
    """

    outcome: str  # checked against VALID_OUTCOMES by the handler
    route: RouteInfo | None = None
    fare_quote: FareQuoteInfo | None = None
    kiosk_action: KioskAction
    advisory_banners: list[AdvisoryBanner] = Field(default_factory=list)
    assistant_message: str
    reasoning: str = ""
    case_id: str | None = None
# --- /simulate ---------------------------------------------------------------
class SimulateRequest(BaseModel):
    """Request body for ``POST /simulate`` (one interactive kiosk case).

    Passenger counts are constrained to be non-negative, matching
    ``PassengerCounts``; the defaults are unchanged.
    """

    system: str
    origin: str
    destination: str
    adults: int = Field(default=1, ge=0)
    children: int = Field(default=0, ge=0)
    seniors: int = Field(default=0, ge=0)
    disabled: int = Field(default=0, ge=0)
    current_time: str = ""  # ISO naive local, e.g. "2026-04-06T14:00:00"
    day_of_week: str = ""
    # Raw user-injected disruptions; /simulate normalizes each into a
    # disruption entry before running the case.
    disruptions: list[dict] = Field(default_factory=list)
    freetext: str = ""
    accessibility_mode: bool = False
    # Cat F routing-impact policies (permanent operating patterns like
    # BART Yellow night shuttle, MARTA Green short-turn, CTA State/Lake
    # closure) are announced as prompt-level policy text, not disruptions.
    policy_change: Optional[dict] = None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _route_id(origin: str, destination: str) -> str:
    """Return a deterministic route id for an origin/destination pair.

    Both names are lower-cased first, so the id is case-insensitive. The id
    is "route_" followed by the first 12 hex digits of a SHA-256 digest.
    """
    key = ":".join(("route", origin.lower(), destination.lower()))
    digest = hashlib.sha256(key.encode()).hexdigest()
    return f"route_{digest[:12]}"
def _fare_id(route_id: str, passengers: PassengerCounts, ticket_type: str) -> str:
    """Return a deterministic fare id for a route, passenger mix and ticket type.

    The id is "fare_" followed by the first 12 hex digits of a SHA-256 digest
    over the colon-joined inputs.
    """
    parts = (
        "fare",
        route_id,
        passengers.adults,
        passengers.children,
        passengers.seniors,
        passengers.disabled,
        ticket_type,
    )
    raw = ":".join(str(part) for part in parts)
    digest = hashlib.sha256(raw.encode()).hexdigest()
    return "fare_" + digest[:12]
def _station_subset(station_data: dict, query_type: str) -> dict:
"""Return the relevant subset of station data for the requested query type."""
if query_type == "accessibility":
return {
"name": station_data.get("name"),
"accessibility": station_data.get("accessibility", {}),
}
if query_type == "facilities":
# Return all scalar/non-graph metadata; most systems store extra keys here
return {
k: v
for k, v in station_data.items()
if k not in {"connections"}
}
if query_type == "exits":
# Exits may not be a dedicated field; surface what is available
return {
"name": station_data.get("name"),
"type": station_data.get("type"),
"zone": station_data.get("zone"),
}
if query_type == "connections":
return {
"name": station_data.get("name"),
"lines": station_data.get("lines", []),
"connections": station_data.get("connections", []),
}
if query_type == "real_time_status":
# The mock server has no live data; return a static "operational" status
return {
"name": station_data.get("name"),
"status": "operational",
"alerts": [],
}
# Should not reach here after validation, but return everything as fallback
return station_data
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health")
def health() -> dict:
    """Liveness probe: always reports status "ok"."""
    payload = {"status": "ok"}
    return payload
@app.post("/route_planner", response_model=RoutePlannerResponse)
def route_planner(req: RoutePlannerRequest) -> RoutePlannerResponse:
    """Plan a route between two stations, honouring any supplied restrictions.

    Responds 404 when the graph raises ``ValueError``, ``NetworkXNoPath`` or
    ``NodeNotFound``. The resolved endpoints and distance are cached under the
    returned ``route_id`` for later ``/fare_calculator`` calls.
    """
    sd = _system_for_case(req.case_id)
    graph = sd.metro
    constrained = bool(
        req.station_restrictions or req.segment_closures or req.line_closures
    )
    try:
        if not constrained:
            result = graph.shortest_path(req.origin, req.destination)
        else:
            station_rules = [
                {"station": rule.station, "restriction": rule.restriction}
                for rule in (req.station_restrictions or [])
            ]
            closed_segments = [tuple(seg) for seg in (req.segment_closures or [])]
            if req.line_closures:
                closure_specs = []
                for closure in req.line_closures:
                    # Resolve line aliases ("Red Line" → canonical id).
                    spec = {
                        "line": sd.line_alias.get(closure.line.lower(), closure.line)
                    }
                    # Only forward the bounds that were actually supplied.
                    for bound in ("from_station", "to_station"):
                        value = getattr(closure, bound)
                        if value is not None:
                            spec[bound] = value
                    closure_specs.append(spec)
                closed_segments.extend(graph.expand_line_closures(closure_specs))
            result = graph.shortest_path_with_restrictions(
                req.origin,
                req.destination,
                station_restrictions=station_rules,
                segment_closures=closed_segments,
            )
    except (ValueError, nx.NetworkXNoPath, nx.NodeNotFound) as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc

    stops = []
    for stop in result.stations:
        stops.append(
            StopInfo(
                station_id=stop["station_id"],
                station_name=stop["station_name"],
                line=stop.get("line"),
                is_transfer=stop.get("is_transfer", False),
                transfer_to=stop.get("transfer_to"),
            )
        )

    rid = _route_id(req.origin, req.destination)
    # Cache route details for fare_calculator surcharge lookups; fall back to
    # the raw request values when the result carries no stations.
    if result.stations:
        origin_id = result.stations[0]["station_id"]
        dest_id = result.stations[-1]["station_id"]
    else:
        origin_id = req.origin
        dest_id = req.destination
    sd.route_cache[rid] = {
        "origin": origin_id,
        "destination": dest_id,
        "distance_miles": result.distance_miles,
    }
    return RoutePlannerResponse(
        route_id=rid,
        stops=stops,
        transfers=result.transfers,
        estimated_minutes=result.estimated_minutes,
        distance_miles=result.distance_miles,
        line_sequence=result.line_sequence,
    )
@app.post("/fare_calculator", response_model=FareCalculatorResponse)
def fare_calculator(req: FareCalculatorRequest) -> FareCalculatorResponse:
    """Quote a fare for a previously planned route.

    Responds 422 when the fare engine raises ``ValueError`` and 501 when it
    raises ``NotImplementedError``.
    """
    sd = _system_for_case(req.case_id)
    counts = req.passengers
    head_counts = {
        category: getattr(counts, category)
        for category in ("adults", "children", "seniors", "disabled")
    }
    # Route details cached by /route_planner; an unknown route_id yields an
    # empty dict, so the three route arguments below are passed as None.
    route = sd.route_cache.get(req.route_id, {})
    try:
        quote = sd.fares.calculate(
            passengers=head_counts,
            ticket_type=req.ticket_type,
            payment_method=req.payment_method,
            route_distance_miles=route.get("distance_miles"),
            origin_id=route.get("origin"),
            destination_id=route.get("destination"),
        )
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    except NotImplementedError as exc:
        raise HTTPException(status_code=501, detail=str(exc)) from exc

    line_items = [LineItem(**entry) for entry in quote.items]
    discounts = [Discount(**entry) for entry in quote.discounts]
    return FareCalculatorResponse(
        fare_id=_fare_id(req.route_id, req.passengers, req.ticket_type),
        line_items=line_items,
        subtotal=quote.subtotal,
        discounts=discounts,
        total=quote.total,
        currency=quote.currency,
    )
@app.post("/station_info")
def station_info(req: StationInfoRequest) -> StationInfoResponse | StationInfoBatchResponse:
    """Return station data for one station or a batch of stations.

    Responds 422 for an invalid ``query_type`` or when no station is given,
    and 404 for the first unknown station encountered.
    """
    if req.query_type not in VALID_QUERY_TYPES:
        raise HTTPException(
            status_code=422,
            detail=(
                f"Invalid query_type '{req.query_type}'. "
                f"Must be one of: {sorted(VALID_QUERY_TYPES)}"
            ),
        )

    # A non-empty station_ids list wins over the single station_id.
    if req.station_ids:
        wanted = req.station_ids
    elif req.station_id:
        wanted = [req.station_id]
    else:
        raise HTTPException(status_code=422, detail="Provide station_id or station_ids")

    sd = _system_for_case(req.case_id)
    results = []
    for station_id in wanted:
        record = sd.metro.station_info(station_id)
        if record is None:
            raise HTTPException(
                status_code=404,
                detail=f"Station '{station_id}' not found",
            )
        subset = _station_subset(record, req.query_type)
        results.append(StationInfoResponse(station_id=station_id, data=subset))

    # Single station: return flat response (backwards compatible)
    batch_requested = bool(req.station_ids)
    if not batch_requested and len(results) == 1:
        return results[0]
    return StationInfoBatchResponse(results=results)
def _build_line_info(sd, requested: str) -> LineInfoResponse:
    """Build the ``/line_info`` payload for one requested line id or alias.

    Raises ``HTTPException`` (404) when the resolved id is not a known line.
    """
    graph = sd.metro
    canonical = sd.line_alias.get(requested.lower(), requested)
    if canonical not in graph.lines:
        raise HTTPException(status_code=404, detail=f"Unknown line: {requested}")

    record = graph.lines[canonical]
    station_ids: list[str] = list(record.get("stations", []))
    terminals = graph.line_terminals(canonical)
    is_loop = graph.is_loop_line(canonical)

    def describe(position: int, station_id: str) -> LineStationEntry:
        # Connections are the station's other lines, excluding this one.
        details = graph.stations.get(station_id, {})
        other_lines = sorted(graph.station_lines.get(station_id, set()) - {canonical})
        return LineStationEntry(
            station_id=station_id,
            station_name=details.get("name", station_id),
            position=position,
            is_terminus=station_id in terminals,
            connections=other_lines,
        )

    entries: list[LineStationEntry] = [
        describe(position, station_id)
        for position, station_id in enumerate(station_ids)
    ]
    return LineInfoResponse(
        line_id=canonical,
        line_name=record.get("name", canonical),
        color=record.get("color", ""),
        station_count=len(station_ids),
        is_loop=is_loop,
        terminals=terminals,
        stations=entries,
    )
@app.post("/line_info")
def line_info(req: LineInfoRequest) -> LineInfoResponse | LineInfoBatchResponse:
    """Describe one line or a batch of lines.

    Responds 422 when neither ``line`` nor ``lines`` is provided.
    """
    # A non-empty lines list wins over the single line.
    if req.lines:
        wanted = req.lines
    elif req.line:
        wanted = [req.line]
    else:
        raise HTTPException(status_code=422, detail="Provide line or lines")

    sd = _system_for_case(req.case_id)
    results = []
    for name in wanted:
        results.append(_build_line_info(sd, name))

    # Single line: return the flat response rather than the batch wrapper.
    if not req.lines and len(results) == 1:
        return results[0]
    return LineInfoBatchResponse(results=results)
@app.post("/set_disruptions")
def set_disruptions(payload: dict) -> dict:
    """Register the disruptions (and optionally the system) for a case.

    Payload keys read:
        case_id: Case to register under. Missing, null or empty values map to
            "_default", the same key ``/disruption_feed`` falls back to.
        disruptions: List of disruption dicts. Missing or null is stored as
            an empty list so the feed can always iterate it.
        system: Optional system name to associate with the case.

    Returns:
        ``{"ok": True}``.
    """
    # `or` instead of a .get() default: an explicit JSON null must also land
    # on "_default", otherwise the feed would never find the entry.
    case_id = payload.get("case_id") or "_default"
    _disruptions_by_case[case_id] = payload.get("disruptions") or []
    system = payload.get("system")
    if system:
        _case_system[case_id] = system
    return {"ok": True}
@app.post("/disruption_feed", response_model=DisruptionFeedResponse)
def disruption_feed(req: DisruptionFeedRequest) -> DisruptionFeedResponse:
    """Return the case's disruptions after applying the requested filters.

    Filters are applied in order: line, station, severity, then expiry.
    """
    active = _disruptions_by_case.get(req.case_id or "_default", [])

    if req.line:
        # Normalize the requested line to its canonical id via the alias map;
        # unknown aliases are compared as given.
        sd = _system_for_case(req.case_id)
        wanted_line = (sd.line_alias.get(req.line.lower()) or req.line).lower()
        # Disruptions without a line (station closures affect all lines) are kept.
        active = [
            d for d in active
            if not d.get("line") or d["line"].lower() == wanted_line
        ]

    if req.station:
        station_lower = req.station.lower()

        def mentions_station(d: dict) -> bool:
            # Exact match against the segment, or a case-insensitive
            # substring match against the message.
            if d.get("segment") and req.station in d["segment"]:
                return True
            return station_lower in d.get("message", "").lower()

        active = [d for d in active if mentions_station(d)]

    if req.severity_filter == "major":
        active = [d for d in active if d.get("severity") in ("critical", "warning")]
    elif req.severity_filter == "minor":
        active = [d for d in active if d.get("severity") == "info"]

    # Temporal filtering: remove expired disruptions when current_time is provided.
    #   - Expired: valid_until is set and valid_until < current_time → filter out
    #   - Future: valid_from is set and valid_from > current_time → keep (announced, not yet active)
    #   - No temporal bounds: always active (backwards compatible)
    # Uses lexicographic ISO 8601 string comparison (works for naive timestamps).
    if req.current_time:
        now = req.current_time
        active = [
            d for d in active
            if not d.get("valid_until") or d["valid_until"] >= now
        ]

    return DisruptionFeedResponse(
        disruptions=[DisruptionEntry(**d) for d in active]
    )
@app.post("/knowledge_base", response_model=KnowledgeBaseResponse)
def knowledge_base(req: KnowledgeBaseRequest) -> KnowledgeBaseResponse:
    """Look up policies by exact id, or by keyword search as a fallback.

    Exact lookup (``policy_id`` set) returns the first matching policy, or an
    empty not-found response; it never falls through to keyword search.
    Keyword search scores each policy by how many query words (longer than
    two characters) occur in its title, content and synonyms, and returns the
    top three.
    """
    sd = _system_for_case(req.case_id)
    policies = sd.policies

    def policy_key(policy: dict) -> str:
        # Policies may carry their id under "policy_id" or "id". Use one
        # resolution for both paths so every id handed out by keyword search
        # can be fed back into an exact lookup.
        return policy.get("policy_id", policy.get("id", ""))

    # Exact lookup by policy_id (preferred path)
    if req.policy_id:
        for policy in policies:
            if policy_key(policy) == req.policy_id:
                return KnowledgeBaseResponse(
                    results=[KnowledgeBaseResult(
                        title=policy.get("title", ""),
                        content=policy.get("content", ""),
                        policy_id=req.policy_id,
                    )],
                    found=True,
                )
        return KnowledgeBaseResponse(results=[], found=False)

    # Fallback: keyword search across all policies (no category gate)
    if not req.query:
        return KnowledgeBaseResponse(results=[], found=False)

    query_words = [w.lower() for w in req.query.split() if len(w) > 2]
    scored: list[tuple[int, dict]] = []
    for policy in policies:
        haystack = " ".join((
            policy.get("title", ""),
            policy.get("content", ""),
            " ".join(policy.get("synonyms", [])),
        )).lower()
        hits = sum(1 for w in query_words if w in haystack)
        if hits > 0:
            scored.append((hits, policy))

    # Sort by hit count descending (stable, so ties keep file order), take top 3
    scored.sort(key=lambda pair: pair[0], reverse=True)
    results = [
        KnowledgeBaseResult(
            title=policy.get("title", ""),
            content=policy.get("content", ""),
            policy_id=policy_key(policy),
        )
        for _, policy in scored[:3]
    ]
    return KnowledgeBaseResponse(results=results, found=len(results) > 0)
@app.post("/submit_assistant_state")
def submit_assistant_state(req: SubmitAssistantStateRequest) -> dict:
    """Validate and accept the LLM's final assistant kiosk state.

    Returns {"accepted": True, "response_id": ...} on success. On validation
    failure, Pydantic raises a 422 with field-level error details before this
    handler runs. Additional structural checks return 422 for invalid enum
    values, conditional field violations and unknown stops.
    """
    # Validate enum values
    if req.outcome not in VALID_OUTCOMES:
        raise HTTPException(status_code=422, detail=f"Invalid outcome: {req.outcome}")
    if req.kiosk_action.action not in VALID_ACTIONS:
        raise HTTPException(status_code=422, detail=f"Invalid action: {req.kiosk_action.action}")
    if req.kiosk_action.reason_code not in VALID_REASON_CODES:
        raise HTTPException(status_code=422, detail=f"Invalid reason_code: {req.kiosk_action.reason_code}")

    # Conditional field validation
    if req.outcome in ("route_and_fare_ready", "advisory_only") and req.route is None:
        raise HTTPException(status_code=422, detail=f"route required when outcome={req.outcome}")
    if req.outcome == "route_and_fare_ready" and req.fare_quote is None:
        raise HTTPException(status_code=422, detail="fare_quote required when outcome=route_and_fare_ready")

    # Validate route.stops are known station IDs
    if req.route and req.route.stops:
        sd = _system_for_case(req.case_id)
        known_stations = sd.metro.stations

        def is_known(stop_id) -> bool:
            # route.stops is untyped, so an entry may be a dict without
            # "station_id" or a nested list. Those are unhashable and would
            # make the membership test raise; treat them as unknown so the
            # caller gets a 422 instead of a server error.
            try:
                return stop_id in known_stations
            except TypeError:
                return False

        stop_ids = [
            s.get("station_id", s) if isinstance(s, dict) else s
            for s in req.route.stops
        ]
        invalid = [s for s in stop_ids if not is_known(s)]
        if invalid:
            raise HTTPException(
                status_code=422,
                detail=f"Unknown station IDs in route.stops: {invalid[:5]}. Use station_id values from route_planner (e.g. MARTA-AP).",
            )

    # Deterministic id: first 12 hex digits of the SHA-256 of the request JSON.
    return {"accepted": True, "response_id": hashlib.sha256(
        req.model_dump_json().encode()
    ).hexdigest()[:12]}
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Verify / interactive map endpoints
# ---------------------------------------------------------------------------
_verify_graphs: dict[str, MetroGraph] = {}  # system_name → MetroGraph for all systems


def _get_verify_graph(system: str) -> MetroGraph:
    """Return the graph for ``system``, building and caching it on first use.

    Raises ``ValueError`` when the system has no data directory.
    """
    graph = _verify_graphs.get(system)
    if graph is not None:
        return graph
    system_dir = Path(__file__).resolve().parent.parent / "data" / "systems" / system
    if not system_dir.is_dir():
        raise ValueError(f"Unknown system: {system}")
    graph = MetroGraph(system_dir)
    _verify_graphs[system] = graph
    return graph
@app.get("/verify")
def verify_page() -> HTMLResponse:
    """Serve ``dashboard/verify.html``; 404 when the file is missing."""
    page = Path(__file__).resolve().parent.parent / "dashboard" / "verify.html"
    if page.exists():
        return HTMLResponse(page.read_text())
    raise HTTPException(status_code=404, detail="verify.html not found")
@app.get("/verify_data.json")
def verify_data() -> JSONResponse:
    """Serve ``dashboard/verify_data.json``; 404 when it has not been exported."""
    data_file = Path(__file__).resolve().parent.parent / "dashboard" / "verify_data.json"
    if data_file.exists():
        return JSONResponse(json.loads(data_file.read_text()))
    raise HTTPException(status_code=404, detail="verify_data.json not found — run: uv run python data/verify.py --export-map")
@app.get("/annotate")
def annotate_page() -> HTMLResponse:
    """Serve ``dashboard/annotate.html``; 404 when the file is missing."""
    page = Path(__file__).resolve().parent.parent / "dashboard" / "annotate.html"
    if page.exists():
        return HTMLResponse(page.read_text())
    raise HTTPException(status_code=404, detail="annotate.html not found")
@app.get("/simulator")
def simulator_page() -> HTMLResponse:
    """Serve ``dashboard/simulator.html``; 404 when the file is missing."""
    page = Path(__file__).resolve().parent.parent / "dashboard" / "simulator.html"
    if page.exists():
        return HTMLResponse(page.read_text())
    raise HTTPException(status_code=404, detail="simulator.html not found")
@app.get("/systems")
def list_systems() -> JSONResponse:
    """List system names: sub-directories of data/systems that contain ``framebook.yaml``."""
    root = Path(__file__).resolve().parent.parent / "data" / "systems"
    names = [
        entry.name
        for entry in root.iterdir()
        if entry.is_dir() and (entry / "framebook.yaml").exists()
    ]
    names.sort()
    return JSONResponse(names)
@app.get("/stations/{system}")
def stations_for_system(system: str) -> JSONResponse:
    """Return the raw ``stations.json`` and ``lines.json`` content for a system.

    Responds 404 for an unknown system; a missing file contributes an empty list.
    """
    system_dir = Path(__file__).resolve().parent.parent / "data" / "systems" / system
    if not system_dir.is_dir():
        raise HTTPException(status_code=404, detail=f"Unknown system: {system}")

    def read_json_or_empty(path: Path):
        if path.exists():
            return json.loads(path.read_text())
        return []

    stations = read_json_or_empty(system_dir / "stations.json")
    lines = read_json_or_empty(system_dir / "lines.json")
    return JSONResponse({"stations": stations, "lines": lines})
@app.post("/simulate")
async def simulate(req: SimulateRequest) -> JSONResponse:
    """Run a single interactive kiosk case through the LLM and return the result.

    The case's system and disruptions are registered in module state for the
    duration of the run, so the tool endpoints the runner calls back into
    resolve them, and are removed again afterwards.

    Raises:
        HTTPException: 404 for an unknown system; 500 ("LLM error: ...") when
            the runner fails.
    """
    import httpx as _httpx
    from harness.runner import BenchmarkRunner

    # Validate the system before registering any per-case state, so a request
    # for an unknown system leaves nothing behind.
    try:
        _load_system(req.system)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=f"Unknown system: {req.system}") from exc

    # Build disruption objects for each user-injected disruption
    disruptions = [
        {
            "id": f"sim-disruption-{i}",
            "type": d.get("type", "delay"),
            "severity": d.get("severity", "warning"),
            "message": d.get("message", ""),
            "line": d.get("line") or None,
            "segment": d.get("segment") or None,
            "alternative": d.get("alternative") or None,
            "valid_from": d.get("valid_from") or None,
            "valid_until": d.get("valid_until") or None,
        }
        for i, d in enumerate(req.disruptions)
    ]

    # Build case dict matching the structure used by _run_single_case
    case_id = f"sim-{req.system}-{req.origin[:8]}-{req.destination[:8]}".replace(" ", "_").lower()
    events: list[dict] = [
        {"type": "station_selected", "field": "origin", "value": req.origin},
        {"type": "station_selected", "field": "destination", "value": req.destination},
        {"type": "passenger_count_changed",
         "adults": req.adults, "children": req.children,
         "seniors": req.seniors, "disabled": req.disabled},
    ]
    # Emit one disruption_update event per disruption. The runner's prompt
    # builder appends each as a separate "⚠ DISRUPTION ALERT" line so the
    # model sees every disruption in the user query, not just the first.
    events.extend({"type": "disruption_update", "disruption": d} for d in disruptions)
    if req.freetext:
        events.append({"type": "freetext_input", "text": req.freetext})

    system_context: dict = {}
    if req.accessibility_mode:
        system_context["accessibility_mode"] = True
    if disruptions:
        system_context["active_disruptions"] = disruptions
    if req.current_time or req.day_of_week:
        system_context["temporal_context"] = {
            "current_time": req.current_time or "",
            "day_of_week": req.day_of_week or "",
        }
    if req.policy_change:
        system_context["policy_change"] = req.policy_change

    case = {
        "id": case_id,
        "system": req.system,
        "category": "simulator",
        "events": events,
        "system_context": system_context,
    }

    # GPT-5 family (including Azure deployments) only accepts temperature=1
    is_gpt5_family = (
        "azure.com" in _llm_base_url
        or (_llm_model or "").startswith("gpt-5")
    )
    simulator_temperature = 1.0 if is_gpt5_family else 0.0

    # Register case system + disruptions for tool endpoint routing. Everything
    # after this point sits inside try/finally so the entries are always
    # removed, including when constructing the runner fails.
    _case_system[case_id] = req.system
    _disruptions_by_case[case_id] = disruptions
    try:
        runner = BenchmarkRunner(
            llm_base_url=_llm_base_url,
            llm_api_key=_llm_api_key,
            llm_model=_llm_model,
            mock_server_url=f"http://localhost:{_port}",
            system_name=req.system,
            parallel=1,
            max_tokens=4096,
            thinking=False,  # disable thinking for Haiku / API models
            temperature=simulator_temperature,
            max_tool_rounds=20,
        )
        try:
            async with _httpx.AsyncClient() as client:
                result = await runner._run_single_case(client, case)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"LLM error: {e}") from e
    finally:
        _case_system.pop(case_id, None)
        _disruptions_by_case.pop(case_id, None)

    return JSONResponse({"case": case, **dataclasses.asdict(result)})
@app.get("/calibration_cases_blind.json")
def calibration_cases_blind() -> JSONResponse:
    """Serve ``dashboard/calibration_cases_blind.json``; 404 when missing."""
    source = Path(__file__).resolve().parent.parent / "dashboard" / "calibration_cases_blind.json"
    if source.exists():
        return JSONResponse(json.loads(source.read_text()))
    raise HTTPException(status_code=404, detail="calibration_cases_blind.json not found")
@app.get("/calibration_cases.json")
def calibration_cases_full() -> JSONResponse:
    """Full calibration file with judge scores + reasoning. UI enforces blindness.

    Served from ``results/calibration_cases.json``; 404 when missing.
    """
    source = Path(__file__).resolve().parent.parent / "results" / "calibration_cases.json"
    if source.exists():
        return JSONResponse(json.loads(source.read_text()))
    raise HTTPException(status_code=404, detail="calibration_cases.json not found")
class VerifyRouteRequest(BaseModel):
    """Request body for ``POST /verify/route``."""

    origin: str
    destination: str
    # Empty string means: use the server's default system.
    system: str = ""
@app.post("/verify/route")
def verify_route(req: VerifyRouteRequest):
    """Compute an unrestricted shortest path for the verify dashboard.

    Uses ``req.system`` when given, otherwise the startup default system.
    Errors are returned as ``{"error": ...}`` JSON: 400 when no system is
    available at all, 404 for an unknown system or an unroutable pair.
    """
    sys_name = req.system or _system_name
    if not sys_name:
        # Neither the request nor the server startup named a system.
        return JSONResponse({"error": "No system configured"}, status_code=400)
    try:
        metro = _get_verify_graph(sys_name)
    except ValueError as exc:
        return JSONResponse({"error": str(exc)}, status_code=404)
    try:
        result = metro.shortest_path(req.origin, req.destination)
    except (ValueError, nx.NetworkXNoPath, nx.NodeNotFound) as exc:
        return JSONResponse({"error": str(exc)}, status_code=404)
    return {
        "origin": req.origin,
        "destination": req.destination,
        "path": result.path,
        "stops": [
            {
                "station_id": s["station_id"],
                "station_name": s["station_name"],
                "line": s.get("line"),
                "is_transfer": s.get("is_transfer", False),
            }
            for s in result.stations
        ],
        "transfers": result.transfers,
        "distance_miles": result.distance_miles,
        "estimated_minutes": result.estimated_minutes,
        "line_sequence": result.line_sequence,
        # Report the system the route was actually computed on, which may
        # differ from the startup default when req.system is set.
        "system": sys_name,
    }
def main() -> None:
    """CLI entry point: parse arguments, load the default system, start uvicorn.

    LLM defaults for ``/simulate`` are taken from the environment (after
    loading ``.env`` when python-dotenv is installed): Azure settings when all
    three ``AZURE_*`` variables are present, otherwise the Anthropic defaults
    with ``ANTHROPIC_API_KEY``. Exits with status 1 when the requested
    system's data directory does not exist.
    """
    import os

    # Resolve default LLM config from .env before parsing args.
    # Prefer Azure gpt-5.4-mini when AZURE_* vars are present; fall back to Anthropic Haiku.
    default_llm_url = "https://api.anthropic.com/v1"
    default_llm_model = "claude-haiku-4-5-20251001"
    try:
        from dotenv import load_dotenv
    except ImportError:
        # python-dotenv is optional; without it only the real process
        # environment is consulted below.
        pass
    else:
        load_dotenv()

    # Read the environment whether or not dotenv is installed, so variables
    # exported in the shell are honoured either way.
    azure_endpoint = os.environ.get("AZURE_ENDPOINT", "").rstrip("/")
    azure_key = os.environ.get("AZURE_OPENAI_API_KEY", "")
    azure_mini = os.environ.get("AZURE_MINI_LLM_DEPLOYMENT", "")
    if azure_endpoint and azure_key and azure_mini:
        default_llm_url = f"{azure_endpoint}/openai/deployments/{azure_mini}?api-version=2024-10-21"
        default_llm_key = azure_key
        default_llm_model = azure_mini
    else:
        default_llm_key = os.environ.get("ANTHROPIC_API_KEY", "")

    parser = argparse.ArgumentParser(description="MetroLLM-Bench mock tool server")
    parser.add_argument(
        "--system",
        default="marta",
        help="Transit system name (must exist under data/systems/). Default: marta",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8100,
        help="Port to listen on. Default: 8100",
    )
    parser.add_argument(
        "--llm-url",
        default=default_llm_url,
        help="LLM API base URL for /simulate endpoint. Default: Azure gpt-5.4-mini if AZURE_* env vars set, else https://api.anthropic.com/v1",
    )
    parser.add_argument(
        "--llm-key",
        default=default_llm_key,
        help="LLM API key. Default: AZURE_OPENAI_API_KEY or ANTHROPIC_API_KEY from .env",
    )
    parser.add_argument(
        "--llm-model",
        default=default_llm_model,
        help="LLM model name (Azure deployment name for Azure). Default: AZURE_MINI_LLM_DEPLOYMENT or claude-haiku-4-5-20251001",
    )
    args = parser.parse_args()

    system_dir = (
        Path(__file__).resolve().parent.parent
        / "data"
        / "systems"
        / args.system
    )
    if not system_dir.is_dir():
        print(
            f"Error: system directory not found: {system_dir}",
            file=sys.stderr,
        )
        sys.exit(1)

    # Publish the resolved configuration to the module-level settings that
    # the endpoints read.
    global _system_name
    global _llm_base_url, _llm_api_key, _llm_model, _port
    _system_name = args.system
    _llm_base_url = args.llm_url
    _llm_api_key = args.llm_key
    _llm_model = args.llm_model
    _port = args.port

    # Load the default system eagerly so data problems surface at startup.
    sd = _load_system(args.system)
    print(f"Loaded system '{args.system}' ({len(sd.policies)} policies) from {system_dir}")
    print(f"Simulator LLM: {args.llm_model} @ {args.llm_url}")
    # NOTE: binds to all network interfaces.
    uvicorn.run(app, host="0.0.0.0", port=args.port)


if __name__ == "__main__":
    main()