| """Mock tool server for MetroLLM-Bench. |
| |
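Exposes the transit tool endpoints that the benchmark runner forwards LLM
tool calls to:
    POST /route_planner
    POST /fare_calculator
    POST /station_info
    POST /line_info
    POST /disruption_feed
    POST /knowledge_base
    POST /submit_assistant_state
It also serves harness and dashboard helpers (/set_disruptions, /simulate,
/verify, /simulator, /annotate, ...).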
| |
| Run via: |
| uvicorn harness.mock_server:app --port 8100 |
| or via the project entry-point: |
| mock-server --system marta --port 8100 |
| """ |
|
|
| import argparse |
| import dataclasses |
| import hashlib |
| import json |
| import sys |
| from pathlib import Path |
| from typing import Optional |
|
|
| import networkx as nx |
| import uvicorn |
| from fastapi import FastAPI, HTTPException |
| from fastapi.responses import HTMLResponse, JSONResponse, FileResponse |
| from pydantic import BaseModel, Field |
|
|
| from harness.graph import MetroGraph |
| from harness.fares import FareCalculator |
|
|
| |
| |
| |
|
|
# Settings for the LLM driven by /simulate.  main() replaces these defaults
# with the parsed CLI flags (whose own defaults may come from the environment).
_llm_base_url: str = "https://api.anthropic.com/v1"
_llm_api_key: str = ""
_llm_model: str = "claude-haiku-4-5-20251001"
# Port this server listens on; /simulate points the benchmark runner's tool
# calls back at http://localhost:<_port>.
_port: int = 8100


# The FastAPI application serving every tool and dashboard endpoint below.
app = FastAPI(title="MetroLLM-Bench Mock Tool Server")
|
|
@dataclasses.dataclass
class _SystemData:
    """Per-system data loaded lazily and cached."""
    # Route graph built from the system directory.
    metro: MetroGraph
    # Fare engine built from the same directory.
    fares: FareCalculator
    # Policy records from policies.json, searched by /knowledge_base.
    policies: list[dict]
    # Lower-cased line id/name -> canonical line id (see _build_line_alias).
    line_alias: dict[str, str]
    # route_id -> {"origin", "destination", "distance_miles"}; written by
    # /route_planner and read by /fare_calculator.
    route_cache: dict[str, dict] = dataclasses.field(default_factory=dict)
|
|
|
|
# System name -> loaded data; filled by _load_system.
_systems: dict[str, _SystemData] = {}
# case_id -> system name, registered by /set_disruptions and /simulate.
_case_system: dict[str, str] = {}
# Startup default system, set in main(); used when a case has no override.
_system_name: str = ""
# case_id -> disruption dicts served by /disruption_feed ("_default" is used
# for requests without a case_id).
_disruptions_by_case: dict[str, list[dict]] = {}
|
|
|
|
def _build_line_alias(system_dir: Path) -> dict[str, str]:
    """Map lower-cased line ids and display names to the canonical line id.

    Reads ``lines.json`` from *system_dir*.  Every line contributes its id
    (lower-cased) and, when it has a non-empty ``name``, that name
    (lower-cased); both point back at the id exactly as written in the file.
    For ``{"id": "red", "name": "Red Line"}`` the result contains
    ``"red" -> "red"`` and ``"red line" -> "red"``.

    Returns an empty mapping when the system has no ``lines.json``.
    """
    lines_file = system_dir / "lines.json"
    if not lines_file.exists():
        return {}

    with open(lines_file) as fh:
        line_records = json.load(fh)

    mapping: dict[str, str] = {}
    for record in line_records:
        canonical = record["id"]
        mapping[canonical.lower()] = canonical
        display_name = record.get("name")
        if display_name:
            mapping[display_name.lower()] = canonical
    return mapping
|
|
|
|
# Directory holding one sub-directory per transit system, resolved relative
# to this file: <repo>/data/systems/<name>.
_DATA_ROOT = Path(__file__).resolve().parents[1].joinpath("data", "systems")
|
|
|
|
def _load_system(name: str) -> _SystemData:
    """Load a transit system from ``data/systems/<name>``, caching the result.

    Args:
        name: System directory name (e.g. ``"marta"``).  It can arrive from
            request bodies (for example /simulate's ``system`` field), so it
            must be a single plain path component.

    Returns:
        The cached :class:`_SystemData` for *name*, built on first use.

    Raises:
        ValueError: if *name* is empty, ``"."``/``".."``, contains a path
            separator, or does not name an existing system directory.
    """
    if name in _systems:
        return _systems[name]
    # Reject anything that is not a single plain path component: an empty
    # name would resolve to the data root itself, and ".." or "a/b" could
    # escape it.
    if not name or name in {".", ".."} or "\\" in name or Path(name).name != name:
        raise ValueError(f"Unknown system: {name}")
    sys_dir = _DATA_ROOT / name
    if not sys_dir.is_dir():
        raise ValueError(f"Unknown system: {name}")

    policies: list[dict] = []
    policies_path = sys_dir / "policies.json"
    if policies_path.exists():
        # json.loads on bytes detects the UTF encoding instead of relying on
        # the platform locale.
        raw = json.loads(policies_path.read_bytes())
        # policies.json is either a bare list or {"policies": [...]}; a dict
        # without that key yields no policies rather than the dict itself.
        policies = raw.get("policies", []) if isinstance(raw, dict) else raw

    sd = _SystemData(
        metro=MetroGraph(sys_dir),
        fares=FareCalculator(sys_dir),
        policies=policies,
        line_alias=_build_line_alias(sys_dir),
    )
    _systems[name] = sd
    return sd
|
|
|
|
def _system_for_case(case_id: str | None) -> _SystemData:
    """Return the system data bound to *case_id*.

    Cases registered through /set_disruptions or /simulate carry their own
    system name; any other case (including a missing ``case_id``) uses the
    system chosen at startup.

    Raises:
        RuntimeError: when neither a per-case nor a startup system is set.
    """
    lookup_key = case_id if case_id else ""
    system_name = _case_system.get(lookup_key, _system_name)
    if system_name:
        return _load_system(system_name)
    raise RuntimeError("No system configured")
|
|
|
|
| |
| |
| |
|
|
| |
|
|
class StationRestriction(BaseModel):
    # Per-station restriction for /route_planner.  route_planner forwards it
    # as {"station": ..., "restriction": ...} to
    # MetroGraph.shortest_path_with_restrictions; which restriction values
    # are honoured is decided there (not visible in this module).
    station: str
    restriction: str


class LineClosure(BaseModel):
    # Whole or partial line closure for /route_planner.  `line` may be an id
    # or an alias (resolved via the system's line_alias table).
    # from_station/to_station are only forwarded when set — presumably they
    # bound the closed stretch; see MetroGraph.expand_line_closures.
    line: str
    from_station: Optional[str] = None
    to_station: Optional[str] = None


class RoutePlannerRequest(BaseModel):
    # Request body for POST /route_planner.
    origin: str
    destination: str
    # Accepted but not read by route_planner in this module.
    departure_time: Optional[str] = None
    accessibility: Optional[list[str]] = None
    station_restrictions: Optional[list[StationRestriction]] = None
    # Each inner list is converted to a tuple and treated as a closed
    # segment — presumably a [station_a, station_b] pair.
    segment_closures: Optional[list[list[str]]] = None
    line_closures: Optional[list[LineClosure]] = None
    # Selects the per-case system (falls back to the startup system).
    case_id: Optional[str] = None


class StopInfo(BaseModel):
    # One stop of a planned route, copied from the graph's station entries.
    station_id: str
    station_name: str
    line: Optional[str]
    is_transfer: bool
    transfer_to: Optional[str]


class RoutePlannerResponse(BaseModel):
    # Response of POST /route_planner.  route_id is deterministic for an
    # origin/destination pair and keys the fare calculator's route cache.
    route_id: str
    stops: list[StopInfo]
    transfers: int
    estimated_minutes: float
    distance_miles: float
    line_sequence: list[str]
|
|
|
|
| |
|
|
class PassengerCounts(BaseModel):
    # Passenger mix for /fare_calculator; the Field constraints reject
    # negative counts.
    adults: int = Field(default=0, ge=0)
    children: int = Field(default=0, ge=0)
    seniors: int = Field(default=0, ge=0)
    disabled: int = Field(default=0, ge=0)


class FareCalculatorRequest(BaseModel):
    # Request body for POST /fare_calculator.  route_id should come from
    # /route_planner: its cached distance and endpoint ids are passed to the
    # fare engine.
    route_id: str
    passengers: PassengerCounts
    ticket_type: str = "single"
    payment_method: str = "breeze_card"
    case_id: Optional[str] = None


class LineItem(BaseModel):
    # One priced component of a fare quote.
    label: str
    amount: float
    currency: str


class Discount(BaseModel):
    # One discount applied to a fare quote (same shape as LineItem).
    label: str
    amount: float
    currency: str


class FareCalculatorResponse(BaseModel):
    # Response of POST /fare_calculator.  fare_id is a deterministic hash of
    # route_id, passenger counts and ticket type.
    fare_id: str
    line_items: list[LineItem]
    subtotal: float
    discounts: list[Discount]
    total: float
    currency: str
|
|
|
|
| |
|
|
# query_type values accepted by /station_info; each selects a projection in
# _station_subset.
VALID_QUERY_TYPES = frozenset(
    ("accessibility", "connections", "exits", "facilities", "real_time_status")
)
|
|
|
|
class StationInfoRequest(BaseModel):
    # Request body for POST /station_info.  Give station_id for a single
    # lookup or station_ids for a batch (a non-empty station_ids takes
    # precedence).  query_type must be one of VALID_QUERY_TYPES.
    station_id: Optional[str] = None
    station_ids: Optional[list[str]] = None
    query_type: str
    case_id: Optional[str] = None


class StationInfoResponse(BaseModel):
    # One station's data, already projected to the requested query_type.
    station_id: str
    data: dict


class StationInfoBatchResponse(BaseModel):
    # Returned whenever station_ids was supplied (even with a single entry).
    results: list[StationInfoResponse]
|
|
|
|
| |
|
|
class LineInfoRequest(BaseModel):
    # Request body for POST /line_info.  Give `line` for one line or `lines`
    # for a batch (a non-empty `lines` takes precedence); values may be line
    # ids or aliases known to the system's line_alias table.
    line: Optional[str] = None
    lines: Optional[list[str]] = None
    case_id: Optional[str] = None


class LineStationEntry(BaseModel):
    # One station on a line, in the line's stored station order.
    station_id: str
    station_name: str
    position: int  # 0-based index in the line's station list
    is_terminus: bool
    connections: list[str]  # other line ids serving this station, sorted


class LineInfoResponse(BaseModel):
    # Response of POST /line_info for a single line.
    line_id: str
    line_name: str
    color: str  # "" when the line record has no color
    station_count: int
    is_loop: bool
    terminals: list[str]
    stations: list[LineStationEntry]


class LineInfoBatchResponse(BaseModel):
    # Returned whenever `lines` was supplied (even with a single entry).
    results: list[LineInfoResponse]
|
|
|
|
| |
|
|
class DisruptionFeedRequest(BaseModel):
    # Request body for POST /disruption_feed; every filter is optional.
    case_id: Optional[str] = None  # missing/empty selects the "_default" set
    # Drops disruptions whose valid_until sorts before this value (plain
    # string comparison).
    current_time: Optional[str] = None
    line: Optional[str] = None  # line id or alias
    station: Optional[str] = None
    # "major" keeps critical/warning, "minor" keeps info; any other value
    # (including the default "all") applies no severity filter.
    severity_filter: str = "all"


class DisruptionEntry(BaseModel):
    # One service disruption as returned by /disruption_feed.
    id: str
    line: Optional[str] = None  # entries without a line pass every line filter
    segment: Optional[list[str]] = None
    type: str
    severity: str  # the feed filters on "critical", "warning" and "info"
    message: str
    alternative: Optional[str] = None
    eta_resolution: Optional[str] = None
    valid_from: Optional[str] = None
    valid_until: Optional[str] = None


class DisruptionFeedResponse(BaseModel):
    # Response of POST /disruption_feed.
    disruptions: list[DisruptionEntry]
|
|
|
|
| |
|
|
class KnowledgeBaseRequest(BaseModel):
    # Request body for POST /knowledge_base: a non-empty policy_id performs
    # an exact lookup; otherwise `query` is keyword-searched.
    policy_id: str = ""
    query: str = ""
    category: str = "general"  # accepted but not used by knowledge_base
    case_id: Optional[str] = None


class KnowledgeBaseResult(BaseModel):
    # One policy returned by /knowledge_base.
    title: str
    content: str
    policy_id: str


class KnowledgeBaseResponse(BaseModel):
    # Response of POST /knowledge_base; a keyword search returns at most
    # three results.
    results: list[KnowledgeBaseResult]
    found: bool
|
|
|
|
| |
|
|
class RouteInfo(BaseModel):
    # Route shown on the kiosk in the LLM's final assistant state.
    origin: str
    destination: str
    # Station ids, or dicts carrying "station_id"; submit_assistant_state
    # checks them against the system's stations.
    stops: list
    transfers: int
    estimated_minutes: float
    distance_miles: float
    line_sequence: list[str]


class AdvisoryBanner(BaseModel):
    # Advisory shown alongside the kiosk state.
    severity: str
    title: str
    body: str


class FareQuoteInfo(BaseModel):
    # Fare summary in the final assistant state (free-form dict entries).
    passenger_summary: Optional[dict] = None
    line_items: list[dict] = Field(default_factory=list)
    discounts: list[dict] = Field(default_factory=list)
    total: float
    currency: str


class KioskAction(BaseModel):
    # What the kiosk should do next.
    action: str  # must be one of VALID_ACTIONS
    reason_code: str  # must be one of VALID_REASON_CODES


# Allowed values checked by submit_assistant_state (HTTP 422 otherwise).
VALID_OUTCOMES = frozenset({
    "route_and_fare_ready", "advisory_only", "service_unavailable",
    "request_declined", "policy_answer_only",
})
VALID_ACTIONS = frozenset({
    "display_info", "prompt_purchase", "block_purchase", "refer_to_staff",
})
VALID_REASON_CODES = frozenset({
    "ok", "no_service", "invalid_request", "unsupported_request",
    "accessibility_issue", "policy_exception",
})


class SubmitAssistantStateRequest(BaseModel):
    # Request body for POST /submit_assistant_state: the LLM's final state.
    outcome: str  # must be one of VALID_OUTCOMES
    # Required when outcome is "route_and_fare_ready" or "advisory_only".
    route: Optional[RouteInfo] = None
    # Required when outcome is "route_and_fare_ready".
    fare_quote: Optional[FareQuoteInfo] = None
    kiosk_action: KioskAction
    advisory_banners: list[AdvisoryBanner] = Field(default_factory=list)
    assistant_message: str
    reasoning: str = ""
    case_id: Optional[str] = None
|
|
|
|
| |
|
|
class SimulateRequest(BaseModel):
    # Form submitted by the simulator dashboard to POST /simulate.
    system: str  # directory name under data/systems/
    origin: str
    destination: str
    # Passenger counts, forwarded as a passenger_count_changed event.
    adults: int = 1
    children: int = 0
    seniors: int = 0
    disabled: int = 0
    # Copied into system_context["temporal_context"] when either is set.
    current_time: str = ""
    day_of_week: str = ""
    # Raw disruption dicts; /simulate assigns ids and fills defaults
    # (type "delay", severity "warning") before use.
    disruptions: list[dict] = Field(default_factory=list)
    freetext: str = ""  # sent as a freetext_input event when non-empty
    accessibility_mode: bool = False
    # Optional policy override, copied into system_context["policy_change"].
    policy_change: Optional[dict] = None
|
|
|
|
| |
| |
| |
|
|
def _route_id(origin: str, destination: str) -> str:
    """Return a stable, case-insensitive id for an origin/destination pair.

    The id is ``"route_"`` plus the first 12 hex digits of the SHA-256 of
    ``"route:<origin>:<destination>"`` (both lower-cased), so the same pair
    always yields the same id across runs.
    """
    key = ":".join(("route", origin.lower(), destination.lower()))
    digest = hashlib.sha256(key.encode()).hexdigest()
    return f"route_{digest[:12]}"
|
|
|
|
def _fare_id(route_id: str, passengers: PassengerCounts, ticket_type: str) -> str:
    """Return a stable id for a fare quote.

    The id hashes the route id, the four passenger counts (adults, children,
    seniors, disabled) and the ticket type; the payment method is not part
    of it.  Format: ``"fare_"`` plus the first 12 hex digits of a SHA-256.
    """
    parts = (
        "fare",
        route_id,
        passengers.adults,
        passengers.children,
        passengers.seniors,
        passengers.disabled,
        ticket_type,
    )
    key = ":".join(f"{part}" for part in parts)
    return "fare_" + hashlib.sha256(key.encode()).hexdigest()[:12]
|
|
|
|
| def _station_subset(station_data: dict, query_type: str) -> dict: |
| """Return the relevant subset of station data for the requested query type.""" |
| if query_type == "accessibility": |
| return { |
| "name": station_data.get("name"), |
| "accessibility": station_data.get("accessibility", {}), |
| } |
| if query_type == "facilities": |
| |
| return { |
| k: v |
| for k, v in station_data.items() |
| if k not in {"connections"} |
| } |
| if query_type == "exits": |
| |
| return { |
| "name": station_data.get("name"), |
| "type": station_data.get("type"), |
| "zone": station_data.get("zone"), |
| } |
| if query_type == "connections": |
| return { |
| "name": station_data.get("name"), |
| "lines": station_data.get("lines", []), |
| "connections": station_data.get("connections", []), |
| } |
| if query_type == "real_time_status": |
| |
| return { |
| "name": station_data.get("name"), |
| "status": "operational", |
| "alerts": [], |
| } |
| |
| return station_data |
|
|
|
|
| |
| |
| |
|
|
@app.get("/health")
def health() -> dict:
    # Liveness probe: always answers {"status": "ok"}.
    return dict(status="ok")
|
|
|
|
@app.post("/route_planner", response_model=RoutePlannerResponse)
def route_planner(req: RoutePlannerRequest) -> RoutePlannerResponse:
    # Plans the shortest route for the case's system.  With no restrictions
    # or closures a plain shortest path is used; otherwise station
    # restrictions, explicit segment closures and line closures (line names
    # resolved through the alias table, then expanded into segments by the
    # graph) are applied.  The route's endpoints and distance are cached
    # under the returned route_id for /fare_calculator.  Graph errors
    # (ValueError, no path, unknown node) become HTTP 404.
    sd = _system_for_case(req.case_id)
    metro = sd.metro
    has_constraints = bool(
        req.station_restrictions or req.segment_closures or req.line_closures
    )

    try:
        if not has_constraints:
            result = metro.shortest_path(req.origin, req.destination)
        else:
            restrictions = [
                {"station": item.station, "restriction": item.restriction}
                for item in (req.station_restrictions or [])
            ]
            segments = [tuple(pair) for pair in (req.segment_closures or [])]
            if req.line_closures:
                closure_dicts = []
                for closure in req.line_closures:
                    entry = {"line": sd.line_alias.get(closure.line.lower(), closure.line)}
                    if closure.from_station is not None:
                        entry["from_station"] = closure.from_station
                    if closure.to_station is not None:
                        entry["to_station"] = closure.to_station
                    closure_dicts.append(entry)
                segments.extend(metro.expand_line_closures(closure_dicts))
            result = metro.shortest_path_with_restrictions(
                req.origin,
                req.destination,
                station_restrictions=restrictions,
                segment_closures=segments,
            )
    except (ValueError, nx.NetworkXNoPath, nx.NodeNotFound) as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc

    stops = [
        StopInfo(
            station_id=station["station_id"],
            station_name=station["station_name"],
            line=station.get("line"),
            is_transfer=station.get("is_transfer", False),
            transfer_to=station.get("transfer_to"),
        )
        for station in result.stations
    ]

    route_id = _route_id(req.origin, req.destination)

    # Cache canonical endpoint ids (not the raw request text) for fare
    # lookups; fall back to the request text if no stations came back.
    if result.stations:
        origin_id = result.stations[0]["station_id"]
        dest_id = result.stations[-1]["station_id"]
    else:
        origin_id, dest_id = req.origin, req.destination
    sd.route_cache[route_id] = {
        "origin": origin_id,
        "destination": dest_id,
        "distance_miles": result.distance_miles,
    }

    return RoutePlannerResponse(
        route_id=route_id,
        stops=stops,
        transfers=result.transfers,
        estimated_minutes=result.estimated_minutes,
        distance_miles=result.distance_miles,
        line_sequence=result.line_sequence,
    )
|
|
|
|
@app.post("/fare_calculator", response_model=FareCalculatorResponse)
def fare_calculator(req: FareCalculatorRequest) -> FareCalculatorResponse:
    # Prices a route for a passenger mix.  Distance and endpoint ids come
    # from the route_cache entry written by /route_planner for req.route_id;
    # for an unknown route_id each is passed to the fare engine as None.
    # Fare-engine ValueError -> 422, NotImplementedError -> 501.
    sd = _system_for_case(req.case_id)
    pax = req.passengers
    counts = {
        "adults": pax.adults,
        "children": pax.children,
        "seniors": pax.seniors,
        "disabled": pax.disabled,
    }
    route = sd.route_cache.get(req.route_id, {})

    try:
        quote = sd.fares.calculate(
            passengers=counts,
            ticket_type=req.ticket_type,
            payment_method=req.payment_method,
            route_distance_miles=route.get("distance_miles"),
            origin_id=route.get("origin"),
            destination_id=route.get("destination"),
        )
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    except NotImplementedError as exc:
        raise HTTPException(status_code=501, detail=str(exc)) from exc

    line_items = [LineItem(**item) for item in quote.items]
    discounts = [Discount(**entry) for entry in quote.discounts]
    return FareCalculatorResponse(
        fare_id=_fare_id(req.route_id, req.passengers, req.ticket_type),
        line_items=line_items,
        subtotal=quote.subtotal,
        discounts=discounts,
        total=quote.total,
        currency=quote.currency,
    )
|
|
|
|
@app.post("/station_info")
def station_info(req: StationInfoRequest) -> StationInfoResponse | StationInfoBatchResponse:
    # Looks up one station (station_id) or several (station_ids) and projects
    # each record by query_type.  A lone station_id yields a bare
    # StationInfoResponse; supplying station_ids always yields a batch.
    # Invalid query_type or no ids -> 422; unknown station -> 404.
    if req.query_type not in VALID_QUERY_TYPES:
        allowed = sorted(VALID_QUERY_TYPES)
        raise HTTPException(
            status_code=422,
            detail=f"Invalid query_type '{req.query_type}'. Must be one of: {allowed}",
        )

    if req.station_ids:
        wanted = req.station_ids
    elif req.station_id:
        wanted = [req.station_id]
    else:
        raise HTTPException(status_code=422, detail="Provide station_id or station_ids")

    sd = _system_for_case(req.case_id)
    responses: list[StationInfoResponse] = []
    for sid in wanted:
        record = sd.metro.station_info(sid)
        if record is None:
            raise HTTPException(status_code=404, detail=f"Station '{sid}' not found")
        responses.append(
            StationInfoResponse(station_id=sid, data=_station_subset(record, req.query_type))
        )

    if req.station_ids or len(responses) != 1:
        return StationInfoBatchResponse(results=responses)
    return responses[0]
|
|
|
|
def _build_line_info(sd, requested: str) -> LineInfoResponse:
    """Describe one line of *sd*'s system as a :class:`LineInfoResponse`.

    *requested* may be a line id or any alias in ``sd.line_alias``
    (matched case-insensitively); unknown aliases are tried verbatim as ids.
    Stations are listed in the line's stored order, with every other line
    serving each station reported as a sorted connection list.

    Raises:
        HTTPException(404): if the resolved id is not a known line.
    """
    metro = sd.metro
    line_id = sd.line_alias.get(requested.lower(), requested)
    if line_id not in metro.lines:
        raise HTTPException(status_code=404, detail=f"Unknown line: {requested}")

    line = metro.lines[line_id]
    station_ids: list[str] = list(line.get("stations", []))
    terminals = metro.line_terminals(line_id)
    is_loop = metro.is_loop_line(line_id)

    entries = [
        LineStationEntry(
            station_id=sid,
            station_name=metro.stations.get(sid, {}).get("name", sid),
            position=index,
            is_terminus=sid in terminals,
            connections=sorted(metro.station_lines.get(sid, set()) - {line_id}),
        )
        for index, sid in enumerate(station_ids)
    ]

    return LineInfoResponse(
        line_id=line_id,
        line_name=line.get("name", line_id),
        color=line.get("color", ""),
        station_count=len(station_ids),
        is_loop=is_loop,
        terminals=terminals,
        stations=entries,
    )
|
|
|
|
@app.post("/line_info")
def line_info(req: LineInfoRequest) -> LineInfoResponse | LineInfoBatchResponse:
    # Describes one line (`line`) or several (`lines`).  A lone `line`
    # yields a bare LineInfoResponse; supplying `lines` always yields a batch.
    if req.lines:
        names = req.lines
    elif req.line:
        names = [req.line]
    else:
        raise HTTPException(status_code=422, detail="Provide line or lines")

    sd = _system_for_case(req.case_id)
    infos = [_build_line_info(sd, name) for name in names]
    if not req.lines and len(infos) == 1:
        return infos[0]
    return LineInfoBatchResponse(results=infos)
|
|
|
|
@app.post("/set_disruptions")
def set_disruptions(payload: dict) -> dict:
    # Registers the disruptions (and optionally the system) for a case.
    # Payload keys: "case_id" (missing, null or empty -> "_default"),
    # "disruptions" (list of disruption dicts; missing or null -> []) and an
    # optional "system" name.
    #
    # Falsy ids are normalised to "_default", which is the key
    # /disruption_feed reads for requests without a case_id; storing them
    # under None or "" would make them unreachable.
    case_id = payload.get("case_id") or "_default"
    # A JSON null here would otherwise be stored and later fail when
    # /disruption_feed iterates it.
    _disruptions_by_case[case_id] = payload.get("disruptions") or []
    if payload.get("system"):
        _case_system[case_id] = payload["system"]
    return {"ok": True}
|
|
|
|
@app.post("/disruption_feed", response_model=DisruptionFeedResponse)
def disruption_feed(req: DisruptionFeedRequest) -> DisruptionFeedResponse:
    # Returns the disruptions registered for the case ("_default" when no
    # case_id), narrowed by the optional line, station, severity and
    # current_time filters.
    filtered = _disruptions_by_case.get(req.case_id or "_default", [])

    if req.line:
        sd = _system_for_case(req.case_id)

        def _canonical(line: str) -> str:
            # Resolve aliases ("Red Line" -> "red") on BOTH sides so a stored
            # disruption written with a display name still matches a query
            # using the id, and vice versa.
            return sd.line_alias.get(line.lower(), line).lower()

        wanted_line = _canonical(req.line)
        # Disruptions without a line apply system-wide and always pass.
        filtered = [
            d for d in filtered
            if not d.get("line") or _canonical(d["line"]) == wanted_line
        ]

    if req.station:
        # Case-insensitive on both the segment ids and the message text.
        needle = req.station.lower()
        filtered = [
            d for d in filtered
            if any(str(stop).lower() == needle for stop in (d.get("segment") or []))
            or needle in (d.get("message") or "").lower()
        ]

    if req.severity_filter == "major":
        filtered = [d for d in filtered if d.get("severity") in ("critical", "warning")]
    elif req.severity_filter == "minor":
        filtered = [d for d in filtered if d.get("severity") == "info"]

    if req.current_time:
        # Expiry check is a plain string comparison, so valid_until and
        # current_time must share one sortable format (e.g. ISO-8601).
        # NOTE(review): valid_from is not checked, so not-yet-active
        # disruptions are still returned — confirm this is intended.
        now = req.current_time
        filtered = [
            d for d in filtered
            if not (d.get("valid_until") and d["valid_until"] < now)
        ]

    entries = [DisruptionEntry(**d) for d in filtered]
    return DisruptionFeedResponse(disruptions=entries)
|
|
|
|
@app.post("/knowledge_base", response_model=KnowledgeBaseResponse)
def knowledge_base(req: KnowledgeBaseRequest) -> KnowledgeBaseResponse:
    # Exact lookup by policy_id, otherwise a keyword search over each
    # policy's title, content and synonyms.  A search returns at most three
    # policies ranked by how many query words they contain (ties keep file
    # order).  req.category is not used for filtering.
    import string

    sd = _system_for_case(req.case_id)
    policies = sd.policies

    if req.policy_id:
        for p in policies:
            # Match the same identifier that search results report:
            # "policy_id", falling back to "id".
            if p.get("policy_id", p.get("id")) == req.policy_id:
                return KnowledgeBaseResponse(
                    results=[KnowledgeBaseResult(
                        title=p.get("title", ""),
                        content=p.get("content", ""),
                        policy_id=req.policy_id,
                    )],
                    found=True,
                )
        return KnowledgeBaseResponse(results=[], found=False)

    if not req.query:
        return KnowledgeBaseResponse(results=[], found=False)

    # Strip surrounding punctuation so "bike?" still matches "bike"; words
    # of two characters or fewer are ignored as too unspecific.
    stripped = (w.lower().strip(string.punctuation) for w in req.query.split())
    query_words = [w for w in stripped if len(w) > 2]

    scored: list[tuple[int, dict]] = []
    for policy in policies:
        text = (policy.get("title", "") + " " + policy.get("content", "")).lower()
        text += " " + " ".join(policy.get("synonyms", [])).lower()
        # Substring match, so "fare" also hits "fares".
        hits = sum(1 for w in query_words if w in text)
        if hits > 0:
            scored.append((hits, policy))

    # list.sort is stable, so equally scored policies keep their file order.
    scored.sort(key=lambda x: x[0], reverse=True)
    top = [p for _, p in scored[:3]]

    results = [
        KnowledgeBaseResult(
            title=p.get("title", ""),
            content=p.get("content", ""),
            policy_id=p.get("policy_id", p.get("id", "")),
        )
        for p in top
    ]
    return KnowledgeBaseResponse(results=results, found=len(results) > 0)
|
|
|
|
@app.post("/submit_assistant_state")
def submit_assistant_state(req: SubmitAssistantStateRequest) -> dict:
    """Validate and accept the LLM's final assistant kiosk state.

    Returns {"accepted": True, "response_id": <12 hex chars>} on success,
    where response_id hashes the submitted state. On validation failure,
    Pydantic raises a 422 with field-level error details before this handler
    runs. Additional structural checks return 422 for invalid enum values,
    conditional field violations and unknown station ids in route.stops.
    """
    if req.outcome not in VALID_OUTCOMES:
        raise HTTPException(status_code=422, detail=f"Invalid outcome: {req.outcome}")
    if req.kiosk_action.action not in VALID_ACTIONS:
        raise HTTPException(status_code=422, detail=f"Invalid action: {req.kiosk_action.action}")
    if req.kiosk_action.reason_code not in VALID_REASON_CODES:
        raise HTTPException(status_code=422, detail=f"Invalid reason_code: {req.kiosk_action.reason_code}")

    # Outcome-dependent required fields.
    if req.outcome in ("route_and_fare_ready", "advisory_only") and req.route is None:
        raise HTTPException(status_code=422, detail=f"route required when outcome={req.outcome}")
    if req.outcome == "route_and_fare_ready" and req.fare_quote is None:
        raise HTTPException(status_code=422, detail="fare_quote required when outcome=route_and_fare_ready")

    if req.route and req.route.stops:
        sd = _system_for_case(req.case_id)
        known = sd.metro.stations
        stop_ids = [s.get("station_id", s) if isinstance(s, dict) else s for s in req.route.stops]
        # Only strings can be station ids.  Anything else (a dict without
        # "station_id", a nested list, ...) is reported as invalid instead of
        # raising TypeError on the membership test (which surfaced as a 500).
        invalid = [s for s in stop_ids if not isinstance(s, str) or s not in known]
        if invalid:
            raise HTTPException(
                status_code=422,
                detail=f"Unknown station IDs in route.stops: {invalid[:5]}. Use station_id values from route_planner (e.g. MARTA-AP).",
            )

    return {"accepted": True, "response_id": hashlib.sha256(
        req.model_dump_json().encode()
    ).hexdigest()[:12]}
|
|
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
# Route graphs for the /verify dashboard, cached per system name.  Only the
# MetroGraph is built here (no fare engine or policies).
_verify_graphs: dict[str, MetroGraph] = {}


def _get_verify_graph(system: str) -> MetroGraph:
    """Return the cached route graph for *system*, building it on first use.

    Raises:
        ValueError: if *system* is not a single plain directory name (it
            comes straight from the /verify/route request body) or no such
            system directory exists under data/systems.
    """
    graph = _verify_graphs.get(system)
    if graph is None:
        # Reject empty names, "."/".." and anything with a path separator so
        # a request cannot point the graph loader outside data/systems.
        if not system or system in {".", ".."} or "\\" in system or Path(system).name != system:
            raise ValueError(f"Unknown system: {system}")
        system_dir = _DATA_ROOT / system
        if not system_dir.is_dir():
            raise ValueError(f"Unknown system: {system}")
        graph = _verify_graphs[system] = MetroGraph(system_dir)
    return graph
|
|
@app.get("/verify")
def verify_page() -> HTMLResponse:
    # Serves dashboard/verify.html; 404 when the file is missing.
    verify_html = Path(__file__).resolve().parent.parent / "dashboard" / "verify.html"
    if not verify_html.exists():
        raise HTTPException(status_code=404, detail="verify.html not found")
    # Decode as UTF-8 explicitly instead of the platform's locale encoding.
    return HTMLResponse(verify_html.read_text(encoding="utf-8"))
|
|
|
|
@app.get("/verify_data.json")
def verify_data() -> JSONResponse:
    # Serves the exported verification map data; 404 with a hint on how to
    # generate it when missing.
    verify_json = Path(__file__).resolve().parent.parent / "dashboard" / "verify_data.json"
    if not verify_json.exists():
        raise HTTPException(status_code=404, detail="verify_data.json not found — run: uv run python data/verify.py --export-map")
    # json.loads on bytes detects the UTF encoding rather than using the
    # platform's locale default.
    return JSONResponse(json.loads(verify_json.read_bytes()))
|
|
|
|
@app.get("/annotate")
def annotate_page() -> HTMLResponse:
    # Serves dashboard/annotate.html; 404 when the file is missing.
    annotate_html = Path(__file__).resolve().parent.parent / "dashboard" / "annotate.html"
    if not annotate_html.exists():
        raise HTTPException(status_code=404, detail="annotate.html not found")
    # Decode as UTF-8 explicitly instead of the platform's locale encoding.
    return HTMLResponse(annotate_html.read_text(encoding="utf-8"))
|
|
|
|
@app.get("/simulator")
def simulator_page() -> HTMLResponse:
    # Serves dashboard/simulator.html; 404 when the file is missing.
    simulator_html = Path(__file__).resolve().parent.parent / "dashboard" / "simulator.html"
    if not simulator_html.exists():
        raise HTTPException(status_code=404, detail="simulator.html not found")
    # Decode as UTF-8 explicitly instead of the platform's locale encoding.
    return HTMLResponse(simulator_html.read_text(encoding="utf-8"))
|
|
|
|
@app.get("/systems")
def list_systems() -> JSONResponse:
    # Sorted names of the system directories under data/systems that
    # contain a framebook.yaml.
    root = Path(__file__).resolve().parent.parent / "data" / "systems"
    names = [
        entry.name
        for entry in root.iterdir()
        if entry.is_dir() and (entry / "framebook.yaml").exists()
    ]
    names.sort()
    return JSONResponse(names)
|
|
|
|
@app.get("/stations/{system}")
def stations_for_system(system: str) -> JSONResponse:
    # Returns the raw stations.json and lines.json of a system ([] for a
    # missing file); 404 for an unknown system.
    #
    # The path parameter can be "." or "..", which would resolve outside
    # data/systems; only plain directory names are accepted.
    if system in {"", ".", ".."} or "\\" in system or Path(system).name != system:
        raise HTTPException(status_code=404, detail=f"Unknown system: {system}")
    system_dir = Path(__file__).resolve().parent.parent / "data" / "systems" / system
    if not system_dir.is_dir():
        raise HTTPException(status_code=404, detail=f"Unknown system: {system}")
    stations_path = system_dir / "stations.json"
    lines_path = system_dir / "lines.json"
    # json.loads on bytes detects the UTF encoding instead of the locale default.
    stations = json.loads(stations_path.read_bytes()) if stations_path.exists() else []
    lines = json.loads(lines_path.read_bytes()) if lines_path.exists() else []
    return JSONResponse({"stations": stations, "lines": lines})
|
|
|
|
@app.post("/simulate")
async def simulate(req: SimulateRequest) -> JSONResponse:
    """Run a single interactive kiosk case through the LLM and return the result."""
    # Builds a synthetic benchmark case from the simulator form and registers
    # its system and disruptions under the case id, so tool calls carrying
    # that id resolve correctly.  It then runs the case through
    # BenchmarkRunner against this server and always unregisters the case
    # afterwards.  Unknown system -> 404; any runner failure -> 500.
    import httpx as _httpx
    from harness.runner import BenchmarkRunner

    # Validate the system before registering any per-case state.  Otherwise
    # an unknown system would leave entries behind that the cleanup in the
    # `finally` below never removes.
    try:
        _load_system(req.system)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=f"Unknown system: {req.system}") from exc

    # Normalise form disruptions into DisruptionEntry-shaped dicts with
    # generated ids; blank optional fields become None.
    disruptions = []
    for i, d in enumerate(req.disruptions):
        disruptions.append({
            "id": f"sim-disruption-{i}",
            "type": d.get("type", "delay"),
            "severity": d.get("severity", "warning"),
            "message": d.get("message", ""),
            "line": d.get("line") or None,
            "segment": d.get("segment") or None,
            "alternative": d.get("alternative") or None,
            "valid_from": d.get("valid_from") or None,
            "valid_until": d.get("valid_until") or None,
        })

    # NOTE(review): concurrent simulations with the same system and the same
    # 8-character origin/destination prefixes share this id and can clobber
    # each other's registered state.
    case_id = f"sim-{req.system}-{req.origin[:8]}-{req.destination[:8]}".replace(" ", "_").lower()

    # Replay the form as kiosk events: station picks, passenger counts, then
    # one event per disruption and finally any free-text input.
    events: list[dict] = [
        {"type": "station_selected", "field": "origin", "value": req.origin},
        {"type": "station_selected", "field": "destination", "value": req.destination},
        {"type": "passenger_count_changed",
         "adults": req.adults, "children": req.children,
         "seniors": req.seniors, "disabled": req.disabled},
    ]
    for d in disruptions:
        events.append({
            "type": "disruption_update",
            "disruption": d,
        })
    if req.freetext:
        events.append({"type": "freetext_input", "text": req.freetext})

    system_context: dict = {}
    if req.accessibility_mode:
        system_context["accessibility_mode"] = True
    if disruptions:
        system_context["active_disruptions"] = disruptions
    if req.current_time or req.day_of_week:
        system_context["temporal_context"] = {
            "current_time": req.current_time or "",
            "day_of_week": req.day_of_week or "",
        }
    if req.policy_change:
        system_context["policy_change"] = req.policy_change

    case = {
        "id": case_id,
        "system": req.system,
        "category": "simulator",
        "events": events,
        "system_context": system_context,
    }

    # Azure deployments and gpt-5 models are driven at temperature 1.0; all
    # other models run deterministically at 0.0.
    is_gpt5_family = (
        "azure.com" in _llm_base_url
        or (_llm_model or "").startswith("gpt-5")
    )
    simulator_temperature = 1.0 if is_gpt5_family else 0.0

    runner = BenchmarkRunner(
        llm_base_url=_llm_base_url,
        llm_api_key=_llm_api_key,
        llm_model=_llm_model,
        mock_server_url=f"http://localhost:{_port}",
        system_name=req.system,
        parallel=1,
        max_tokens=4096,
        thinking=False,
        temperature=simulator_temperature,
        max_tool_rounds=20,
    )

    # Register per-case state only now, immediately before the try whose
    # `finally` removes it.
    _case_system[case_id] = req.system
    _disruptions_by_case[case_id] = disruptions
    try:
        async with _httpx.AsyncClient() as client:
            result = await runner._run_single_case(client, case)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"LLM error: {e}") from e
    finally:
        _case_system.pop(case_id, None)
        _disruptions_by_case.pop(case_id, None)

    return JSONResponse({"case": case, **dataclasses.asdict(result)})
|
|
|
|
@app.get("/calibration_cases_blind.json")
def calibration_cases_blind() -> JSONResponse:
    # Serves dashboard/calibration_cases_blind.json; 404 when missing.
    cal_json = Path(__file__).resolve().parent.parent / "dashboard" / "calibration_cases_blind.json"
    if not cal_json.exists():
        raise HTTPException(status_code=404, detail="calibration_cases_blind.json not found")
    # json.loads on bytes detects the UTF encoding instead of the locale default.
    return JSONResponse(json.loads(cal_json.read_bytes()))
|
|
|
|
@app.get("/calibration_cases.json")
def calibration_cases_full() -> JSONResponse:
    """Full calibration file with judge scores + reasoning. UI enforces blindness."""
    # Served from results/calibration_cases.json; 404 when missing.
    cal_json = Path(__file__).resolve().parent.parent / "results" / "calibration_cases.json"
    if not cal_json.exists():
        raise HTTPException(status_code=404, detail="calibration_cases.json not found")
    # json.loads on bytes detects the UTF encoding instead of the locale default.
    return JSONResponse(json.loads(cal_json.read_bytes()))
|
|
|
|
class VerifyRouteRequest(BaseModel):
    # Request body for POST /verify/route (verification dashboard).
    origin: str
    destination: str
    system: str = ""  # empty -> the system chosen at startup
|
|
|
|
@app.post("/verify/route")
def verify_route(req: VerifyRouteRequest):
    # Shortest route for the /verify dashboard, using req.system or else the
    # startup system.  Failures are returned as a 404 JSON body
    # {"error": ...} rather than raised as HTTPException.
    sys_name = req.system or _system_name
    if not sys_name:
        # Nothing to load: an empty name would otherwise resolve to the
        # data/systems root itself.
        return JSONResponse({"error": "No system configured"}, status_code=404)
    try:
        metro = _get_verify_graph(sys_name)
    except ValueError as exc:
        return JSONResponse({"error": str(exc)}, status_code=404)
    try:
        result = metro.shortest_path(req.origin, req.destination)
    except (ValueError, nx.NetworkXNoPath, nx.NodeNotFound) as exc:
        return JSONResponse({"error": str(exc)}, status_code=404)

    return {
        "origin": req.origin,
        "destination": req.destination,
        "path": result.path,
        "stops": [
            {
                "station_id": s["station_id"],
                "station_name": s["station_name"],
                "line": s.get("line"),
                "is_transfer": s.get("is_transfer", False),
            }
            for s in result.stations
        ],
        "transfers": result.transfers,
        "distance_miles": result.distance_miles,
        "estimated_minutes": result.estimated_minutes,
        "line_sequence": result.line_sequence,
        # Report the system actually queried, not just the startup default.
        "system": sys_name,
    }
|
|
|
|
def main() -> None:
    """Parse CLI options, configure module state and start the server.

    Simulator-LLM defaults come from environment variables, optionally
    loaded from a ``.env`` file when python-dotenv is installed. An Azure
    deployment is used when AZURE_ENDPOINT, AZURE_OPENAI_API_KEY and
    AZURE_MINI_LLM_DEPLOYMENT are all set; otherwise the Anthropic API is
    used with ANTHROPIC_API_KEY.

    Exits with status 1 if the requested system directory is missing or the
    system cannot be loaded.
    """
    import os

    global _system_name
    global _llm_base_url, _llm_api_key, _llm_model, _port

    try:
        from dotenv import load_dotenv
    except ImportError:
        # python-dotenv is optional: without it, variables already present in
        # the process environment are still honoured below.
        pass
    else:
        load_dotenv()

    default_llm_url = "https://api.anthropic.com/v1"
    default_llm_model = "claude-haiku-4-5-20251001"
    azure_endpoint = os.environ.get("AZURE_ENDPOINT", "").rstrip("/")
    azure_key = os.environ.get("AZURE_OPENAI_API_KEY", "")
    azure_mini = os.environ.get("AZURE_MINI_LLM_DEPLOYMENT", "")
    if azure_endpoint and azure_key and azure_mini:
        default_llm_url = f"{azure_endpoint}/openai/deployments/{azure_mini}?api-version=2024-10-21"
        default_llm_key = azure_key
        default_llm_model = azure_mini
    else:
        default_llm_key = os.environ.get("ANTHROPIC_API_KEY", "")

    parser = argparse.ArgumentParser(description="MetroLLM-Bench mock tool server")
    parser.add_argument(
        "--system",
        default="marta",
        help="Transit system name (must exist under data/systems/). Default: marta",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8100,
        help="Port to listen on. Default: 8100",
    )
    parser.add_argument(
        "--llm-url",
        default=default_llm_url,
        help="LLM API base URL for /simulate endpoint. Default: Azure gpt-5.4-mini if AZURE_* env vars set, else https://api.anthropic.com/v1",
    )
    parser.add_argument(
        "--llm-key",
        default=default_llm_key,
        help="LLM API key. Default: AZURE_OPENAI_API_KEY or ANTHROPIC_API_KEY from .env",
    )
    parser.add_argument(
        "--llm-model",
        default=default_llm_model,
        help="LLM model name (Azure deployment name for Azure). Default: AZURE_MINI_LLM_DEPLOYMENT or claude-haiku-4-5-20251001",
    )
    args = parser.parse_args()

    system_dir = _DATA_ROOT / args.system
    if not system_dir.is_dir():
        print(
            f"Error: system directory not found: {system_dir}",
            file=sys.stderr,
        )
        sys.exit(1)

    _system_name = args.system
    _llm_base_url = args.llm_url
    _llm_api_key = args.llm_key
    _llm_model = args.llm_model
    _port = args.port

    # Load eagerly so data errors surface at startup rather than on the
    # first request.
    try:
        sd = _load_system(args.system)
    except ValueError as exc:
        print(f"Error: {exc}", file=sys.stderr)
        sys.exit(1)
    print(f"Loaded system '{args.system}' ({len(sd.policies)} policies) from {system_dir}")
    print(f"Simulator LLM: {args.llm_model} @ {args.llm_url}")
    uvicorn.run(app, host="0.0.0.0", port=args.port)
|
|
|
|
# Direct execution entry point; the same main() backs the `mock-server`
# console command mentioned in the module docstring.
if __name__ == "__main__":
    main()
|
|