| """Benchmark runner — sends cases to LLM, handles tool calls via mock server.""" |
|
|
| import asyncio |
| import hashlib |
| import json |
| import subprocess |
| import time |
| import argparse |
| from pathlib import Path |
| from dataclasses import dataclass, field, asdict |
| from datetime import datetime, timezone |
|
|
| import httpx |
| import yaml |
|
|
|
|
| @dataclass |
| class CaseResult: |
| case_id: str |
| response: dict | None |
| tool_calls_made: list[dict] |
| raw_content: str |
| reasoning_content: str |
| messages: list[dict] |
| ttft_ms: float |
| e2e_ms: float |
| input_tokens: int |
| output_tokens: int |
| api_rounds: int |
| error: str | None |
|
|
|
|
# OpenAI-style function-tool schemas sent as the "tools" field of every
# chat/completions request. Each call the model makes is POSTed to the mock
# server at "{mock_server_url}/{name}"; a submit_assistant_state call whose
# result has "accepted" set is what completes a case. The descriptions below
# are part of what the model reads, so editing them changes benchmark behaviour.
TOOL_DEFINITIONS = [
    # Disruption-aware routing between two stations.
    {
        "type": "function",
        "function": {
            "name": "route_planner",
            "description": "Find optimal route between two stations. Supports station restrictions for disruption-aware routing.",
            "parameters": {
                "type": "object",
                "properties": {
                    "origin": {"type": "string", "description": "Origin station name or ID"},
                    "destination": {"type": "string", "description": "Destination station name or ID"},
                    "departure_time": {"type": "string", "description": "ISO 8601 departure time (optional)"},
                    "accessibility": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Accessibility requirements (optional)"
                    },
                    "station_restrictions": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "station": {"type": "string", "description": "Station name to restrict"},
                                "restriction": {
                                    "type": "string",
                                    "enum": ["closed", "skip", "no_transfer"],
                                    "description": "closed: no service. skip: trains pass without stopping. no_transfer: cannot change lines."
                                }
                            },
                            "required": ["station", "restriction"]
                        },
                        "description": "Stations with operational restrictions from disruption info"
                    },
                    # Each closure is exactly one [station_a, station_b] pair.
                    "segment_closures": {
                        "type": "array",
                        "items": {
                            "type": "array",
                            "items": {"type": "string"},
                            "minItems": 2,
                            "maxItems": 2
                        },
                        "description": "Pairs of adjacent stations where track is closed"
                    },
                    "line_closures": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "line": {"type": "string", "description": "Line id or name"},
                                "from_station": {"type": "string", "description": "Inclusive start of the closed range (omit both endpoints for whole-line closure)"},
                                "to_station": {"type": "string", "description": "Inclusive end of the closed range"}
                            },
                            "required": ["line"]
                        },
                        "description": "Line-level closures. Omit from_station/to_station to close the entire line. Prefer this over listing individual stations in station_restrictions."
                    }
                },
                "required": ["origin", "destination"]
            }
        }
    },
    # Fare quote for a previously planned route.
    {
        "type": "function",
        "function": {
            "name": "fare_calculator",
            "description": "Calculate fare for a journey",
            "parameters": {
                "type": "object",
                "properties": {
                    "route_id": {"type": "string", "description": "Route ID from route_planner"},
                    "passengers": {
                        "type": "object",
                        "properties": {
                            "adults": {"type": "integer"},
                            "children": {"type": "integer"},
                            "seniors": {"type": "integer"},
                            "disabled": {"type": "integer"}
                        }
                    },
                    "ticket_type": {"type": "string", "enum": ["single", "return", "day_pass", "weekly", "monthly"]},
                    # Union of payment methods across all supported transit systems.
                    "payment_method": {"type": "string", "enum": ["smartcard", "contactless", "cash", "mobile", "gold_travel_card", "clipper_card", "easycard", "ventra", "disposable_ticket"]}
                },
                "required": ["route_id", "passengers"]
            }
        }
    },
    # Station facilities / accessibility; station_id or station_ids (neither is
    # schema-required, only query_type is).
    {
        "type": "function",
        "function": {
            "name": "station_info",
            "description": "Get station facility and accessibility information. Use station_ids to check multiple stations in one call (e.g. all stops on a route).",
            "parameters": {
                "type": "object",
                "properties": {
                    "station_id": {"type": "string", "description": "Single station ID or name"},
                    "station_ids": {"type": "array", "items": {"type": "string"}, "description": "Multiple station IDs to check at once"},
                    "query_type": {
                        "type": "string",
                        "enum": ["accessibility", "facilities", "exits", "connections", "real_time_status"]
                    }
                },
                "required": ["query_type"]
            }
        }
    },
    # Line topology lookup; accepts one line or a batch via "lines".
    {
        "type": "function",
        "function": {
            "name": "line_info",
            "description": "Get a line's station sequence, loop/terminal metadata, and per-station transfers (other lines at each station). Use before encoding line-level disruptions so station IDs come from the tool, not from memory. Use lines to look up multiple lines in one call (e.g. when several lines are disrupted).",
            "parameters": {
                "type": "object",
                "properties": {
                    "line": {"type": "string", "description": "Single line id or natural-language name (e.g. \"10\" or \"Line 10\")"},
                    "lines": {"type": "array", "items": {"type": "string"}, "description": "Multiple line ids or names to look up at once (preferred when several lines are impacted)"}
                }
            }
        }
    },
    # Current disruptions; the runner also injects the case's current_time
    # into this tool's payload when the case provides one.
    {
        "type": "function",
        "function": {
            "name": "disruption_feed",
            "description": "Get current service disruptions and advisories. Call this when a disruption alert is reported to get detailed status information.",
            "parameters": {
                "type": "object",
                "properties": {
                    "line": {"type": "string", "description": "Filter by line name (optional)"},
                    "station": {"type": "string", "description": "Filter by station name or ID (optional)"},
                    "severity_filter": {
                        "type": "string",
                        "enum": ["all", "major", "minor"],
                        "description": "Filter by severity level (default: all)"
                    }
                }
            }
        }
    },
    # Policy / FAQ lookup by exact policy_id or free-text query.
    {
        "type": "function",
        "function": {
            "name": "knowledge_base",
            "description": "Look up transit policies, FAQ, and service information. Use policy_id for exact lookup (preferred) or query for keyword search.",
            "parameters": {
                "type": "object",
                "properties": {
                    "policy_id": {"type": "string", "description": "Exact policy ID from the available policies list"},
                    "query": {"type": "string", "description": "Keyword search query (when policy_id is not known)"},
                    "category": {"type": "string", "description": "Optional category filter"}
                },
                "required": []
            }
        }
    },
    # Terminal tool: the final kiosk state. The conditional requirements for
    # "route" and "fare_quote" are stated only in their descriptions; the
    # schema itself requires just outcome, kiosk_action and assistant_message.
    {
        "type": "function",
        "function": {
            "name": "submit_assistant_state",
            "description": "Submit the final assistant kiosk state for rendering. You MUST call this tool as your last action.",
            "parameters": {
                "type": "object",
                "properties": {
                    "outcome": {
                        "type": "string",
                        "enum": ["route_and_fare_ready", "advisory_only", "service_unavailable", "request_declined", "policy_answer_only"],
                        "description": "The outcome state of this interaction"
                    },
                    "route": {
                        "type": "object",
                        "description": "Route information. Required when outcome is route_and_fare_ready or advisory_only.",
                        "properties": {
                            "origin": {"type": "string"},
                            "destination": {"type": "string"},
                            "stops": {"type": "array", "items": {
                                "type": "object",
                                "properties": {
                                    "station_id": {"type": "string"},
                                    "station_name": {"type": "string"},
                                    "line": {"type": "string"},
                                    "is_transfer": {"type": "boolean"}
                                },
                                "required": ["station_id"]
                            }, "description": "Stop objects from route_planner result"},
                            "transfers": {"type": "integer"},
                            "estimated_minutes": {"type": "integer"},
                            "distance_miles": {"type": "number"},
                            "line_sequence": {"type": "array", "items": {"type": "string"}, "description": "Line names used in order"}
                        },
                        "required": ["origin", "destination", "stops", "transfers", "estimated_minutes", "distance_miles", "line_sequence"]
                    },
                    "fare_quote": {
                        "type": "object",
                        "description": "Fare breakdown. Required when outcome is route_and_fare_ready.",
                        "properties": {
                            "passenger_summary": {
                                "type": "object",
                                "properties": {
                                    "adults": {"type": "integer", "default": 0},
                                    "children": {"type": "integer", "default": 0},
                                    "seniors": {"type": "integer", "default": 0},
                                    "disabled": {"type": "integer", "default": 0},
                                    "free_riders": {"type": "integer", "default": 0}
                                }
                            },
                            "line_items": {
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "rider_type": {"type": "string"},
                                        "count": {"type": "integer"},
                                        "unit_fare": {"type": "number"},
                                        "subtotal": {"type": "number"},
                                        "currency": {"type": "string"}
                                    },
                                    "required": ["rider_type", "count", "unit_fare", "subtotal", "currency"]
                                }
                            },
                            "discounts": {
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "label": {"type": "string"},
                                        "amount": {"type": "number"},
                                        "currency": {"type": "string"}
                                    }
                                }
                            },
                            "total": {"type": "number", "description": "Total fare as a number (e.g. 2.50, NOT '$2.50')"},
                            "currency": {"type": "string"}
                        },
                        "required": ["total", "currency"]
                    },
                    "kiosk_action": {
                        "type": "object",
                        "description": "What the kiosk should do with this state",
                        "properties": {
                            "action": {
                                "type": "string",
                                "enum": ["display_info", "prompt_purchase", "block_purchase", "refer_to_staff"]
                            },
                            "reason_code": {
                                "type": "string",
                                "enum": ["ok", "no_service", "invalid_request", "unsupported_request", "accessibility_issue", "policy_exception"]
                            }
                        },
                        "required": ["action", "reason_code"]
                    },
                    "advisory_banners": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "severity": {"type": "string", "enum": ["info", "warning", "critical", "positive"]},
                                "title": {"type": "string"},
                                "body": {"type": "string"}
                            },
                            "required": ["severity", "title", "body"]
                        }
                    },
                    "assistant_message": {
                        "type": "string",
                        "description": "Human-readable message for the kiosk screen"
                    },
                    "reasoning": {
                        "type": "string",
                        "description": "Internal analysis of the query"
                    }
                },
                "required": ["outcome", "kiosk_action", "assistant_message"]
            }
        }
    }
]
|
|
|
|
| class BenchmarkRunner: |
| def __init__( |
| self, |
| llm_base_url: str, |
| llm_api_key: str, |
| llm_model: str, |
| mock_server_url: str, |
| system_name: str, |
| parallel: int = 2, |
| max_tokens: int = 4096, |
| thinking: bool = True, |
| temperature: float = 0.0, |
| max_tool_rounds: int = 20, |
| extra_body: dict | None = None, |
| ): |
| self.llm_base_url = llm_base_url.rstrip("/") |
| self.llm_api_key = llm_api_key |
| self.llm_model = llm_model |
| self.mock_server_url = mock_server_url.rstrip("/") |
| self.system_name = system_name |
| self.parallel = parallel |
| self.max_tokens = max_tokens |
| self.thinking = thinking |
| self.temperature = temperature |
| self.max_tool_rounds = max_tool_rounds |
| self.extra_body = extra_body or {} |
| self.semaphore = asyncio.Semaphore(parallel) |
|
|
| def _build_system_prompt(self, case: dict | None = None) -> str: |
| """Build system prompt from framebook + high-level rules. |
| |
| If case is provided and has active disruptions, disruption handling |
| instructions are appended. Otherwise they are omitted to avoid |
| the model defensively calling disruption_feed on normal cases. |
| """ |
| system_dir = Path(__file__).resolve().parent.parent / "data" / "systems" / self.system_name |
| with open(system_dir / "framebook.yaml") as f: |
| framebook = yaml.safe_load(f)["framebook"] |
|
|
| with open(system_dir / "fares.json") as f: |
| fares = json.load(f) |
|
|
| with open(system_dir / "lines.json") as f: |
| lines = json.load(f) |
|
|
| currency_symbol = framebook["currency_symbol"] |
| currency_code = framebook["currency_code"] |
| terminology = framebook["terminology"] |
|
|
| |
| line_names = ", ".join(l["name"] for l in lines) |
|
|
| base_fare = fares["base_fare"] |
| fare_display = framebook["fare_display_format"] |
| fare_model = fares.get("model", "flat") |
|
|
| prompt = f"""You are a transit kiosk assistant for {framebook['org_name']} ({framebook['full_name']}). |
| |
| ## System Information |
| - Lines: {line_names} |
| """ |
|
|
| |
| fare_rules = { |
| "model": fare_model, |
| "base_fare": f"{currency_symbol}{base_fare}", |
| "currency": currency_code, |
| "format": fare_display, |
| "payment": [terminology["smartcard"], terminology["contactless"]], |
| } |
| if fares.get("discounts"): |
| fare_rules["discounts"] = fares["discounts"] |
| if fares.get("fare_brackets"): |
| fare_rules["fare_brackets"] = fares["fare_brackets"] |
| if fares.get("surcharges"): |
| fare_rules["surcharges"] = fares["surcharges"] |
| if fares.get("station_overrides"): |
| fare_rules["station_overrides"] = fares["station_overrides"] |
| if fares.get("payment_methods"): |
| fare_rules["payment_methods"] = fares["payment_methods"] |
| if "gold_fare" in fares: |
| fare_rules["gold_class"] = { |
| "fare": f"{currency_symbol}{fares['gold_fare']}", |
| "card": terminology.get("smartcard_premium", "Gold Card"), |
| } |
| prompt += f"- Fare rules: {json.dumps(fare_rules)}\n" |
| prompt += f"- Respond in English (the local language is {framebook['primary_language']})\n" |
|
|
| |
| cultural_notes = framebook.get("cultural_notes", []) |
| if cultural_notes: |
| prompt += "\n## Cultural Notes\n" |
| for note in cultural_notes: |
| prompt += f"- {note}\n" |
|
|
| |
| operating_hours = framebook.get("operating_hours", {}) |
| if operating_hours: |
| prompt += f"\n## Service Hours\n{json.dumps(operating_hours)}\n" |
|
|
| |
| temporal_ctx = ( |
| case and case.get("system_context", {}).get("temporal_context") |
| ) |
| if temporal_ctx: |
| prompt += "\n## Current Time & Service Hours\n" |
| prompt += f"- Current time: {temporal_ctx['current_time']}\n" |
| if temporal_ctx.get("day_of_week"): |
| prompt += f"- Day: {temporal_ctx['day_of_week']}\n" |
| if temporal_ctx.get("notes"): |
| prompt += f"- {temporal_ctx['notes']}\n" |
| prompt += "- Check whether the requested journey falls within service hours and warn the passenger if not\n" |
| prompt += "- Consider headway frequency at the requested time\n" |
|
|
| prompt += f""" |
| ## Your Role |
| You help passengers plan trips, calculate fares, and provide station information. |
| Use the available tools to look up routes, calculate fares, and get station details. |
| Always use tools rather than guessing — do not fabricate route or fare information. |
| |
| ## Workflow |
| 1. Use route_planner, fare_calculator, station_info to gather information |
| 2. When you have all the information needed, call submit_assistant_state with your final kiosk state |
| 3. You MUST always finish by calling submit_assistant_state — never respond with plain text |
| 4. Set the outcome field to indicate the result: route_and_fare_ready (normal trip), advisory_only (disrupted but route shown), service_unavailable (no service), request_declined (invalid request), or policy_answer_only (info-only) |
| 5. Set kiosk_action to indicate what the kiosk should do: prompt_purchase (ready to buy), display_info (information only), block_purchase (cannot proceed), or refer_to_staff (need human help) |
| |
| ## Reason Code Semantics |
| - Use `ok` when the kiosk can complete the request normally |
| - Use `no_service` when service is unavailable for the requested trip or time |
| - Use `invalid_request` when the request is contradictory or impossible as asked |
| - Use `unsupported_request` when the question is outside kiosk capabilities |
| - Use `accessibility_issue` when the route does not satisfy the passenger's stated accessibility requirement |
| - Use `policy_exception` when a special policy changes the normal fare or purchase flow and that exception should be surfaced |
| |
| ## Advisory Banners |
| advisory_banners is a primary passenger-facing information channel. Use it to surface important context alongside the route and fare. Severity levels: |
| - `critical`: service unavailable, block_purchase required, safety issue |
| - `warning`: disruption affecting the route, accessibility concern, approaching last train |
| - `info`: security/ID rules, payment requirements, operating-hour reminders, policy context, station-specific notes, late-night service info |
| - `positive`: a discount, exception, or pass applied in the passenger's favor |
| |
| Write banners that are specific to this trip — reference affected stations, specific times, or exact policy items from the system prompt. Avoid generic boilerplate. Multiple banners are fine when they address distinct concerns. |
| |
| ## Rules |
| - Use {terminology['smartcard']} (not "metro card" or other names) |
| - Fare totals must be numbers (2.50), not strings ("{currency_symbol}2.50") |
| - Line names in line_sequence must be lowercase (e.g. "red", not "Red") |
| - Pass route_planner stop objects directly into route.stops (each with station_id, station_name, line, is_transfer) |
| - If submit_assistant_state returns an error, fix the issues and call it again |
| - Include fare_quote with passenger_summary and line_items when outcome is route_and_fare_ready |
| """ |
|
|
| |
| has_disruptions = bool( |
| case |
| and case.get("system_context", {}).get("active_disruptions") |
| ) |
| if has_disruptions: |
| prompt += """ |
| ## Disruption Handling |
| - A DISRUPTION ALERT is included in the passenger query — use the disruption_feed tool to get current service status |
| - Check if the planned route passes through any affected segments or stations |
| - Include advisory_banners in your submit_assistant_state with the appropriate severity (critical, warning, or info) |
| - If the route is affected, warn the passenger and suggest alternatives if available |
| - If the disruption makes the route unusable, set outcome to service_unavailable and kiosk_action to block_purchase |
| - When a disruption describes an entire line or a named segment between two stations, call line_info to resolve the topology and encode the closure via route_planner's line_closures parameter (do not enumerate individual stations in station_restrictions) |
| - If multiple lines are disrupted, pass all of them to line_info's `lines` array in a single call rather than issuing one request per line |
| """ |
|
|
| |
| has_accessibility = bool( |
| case |
| and case.get("system_context", {}).get("accessibility_mode") |
| ) |
| if has_accessibility: |
| prompt += """ |
| ## Accessibility |
| - The passenger has indicated an accessibility requirement |
| - Use the station_info tool with query_type "accessibility" to check stations along the route |
| - Check EACH station on the route for elevator and step-free access |
| - If any station has an accessibility issue (e.g. elevator out of service), warn the passenger in your advisory_banners |
| - Include the affected station name and the specific issue in the advisory |
| """ |
|
|
| |
| policy_change = ( |
| case and case.get("system_context", {}).get("policy_change") |
| ) |
| if policy_change: |
| prompt += "\n## Policy Update\n" |
| prompt += "IMPORTANT: The following policy is in effect and supersedes standard fare rules.\n\n" |
| prompt += policy_change["text"] + "\n\n" |
| prompt += "Apply this policy when calculating fares. If fare_calculator returns a fare based on old rules, adjust the total in submit_assistant_state.\n" |
|
|
| |
| policies_path = system_dir / "policies.json" |
| if policies_path.exists(): |
| with open(policies_path) as f: |
| policies_data = json.load(f) |
| policy_list = policies_data.get("policies", policies_data) if isinstance(policies_data, dict) else policies_data |
| if policy_list: |
| prompt += "\n## Available Policies\n" |
| for p in policy_list: |
| prompt += f"- [{p['policy_id']}] {p['title']}\n" |
| prompt += "Use knowledge_base with policy_id for exact lookup.\n" |
|
|
| |
| has_knowledge_query = bool( |
| case |
| and case.get("system_context", {}).get("knowledge_query") |
| ) |
| if has_knowledge_query: |
| prompt += """ |
| ## Knowledge Base |
| - The passenger has a question about transit policies or service information |
| - Use the knowledge_base tool with the appropriate policy_id to look up relevant policies |
| - If the passenger asks about multiple topics, make separate knowledge_base calls for each |
| - If you are unsure which policy applies, use the query parameter to search |
| - Include the relevant policy information in your submit_assistant_state |
| - If no matching policies are found, provide a helpful general response |
| """ |
|
|
| return prompt |
|
|
| def _build_user_message(self, case: dict) -> str: |
| """Convert case events into a user message.""" |
| events = case["events"] |
| parts = [] |
| for event in events: |
| if event["type"] == "station_selected": |
| parts.append(f"{event['field'].title()}: {event['value']}") |
| elif event["type"] == "passenger_count_changed": |
| pax_parts = [] |
| for key in ["adults", "children", "seniors", "disabled"]: |
| if key in event and event[key] != 0: |
| pax_parts.append(f"{event[key]} {key}") |
| parts.append(f"Passengers: {', '.join(pax_parts)}") |
| elif event["type"] == "freetext_input": |
| parts.append(event["text"]) |
| elif event["type"] == "payment_method_selected": |
| parts.append(f"Payment method: {event['method'].replace('_', ' ').title()}") |
| elif event["type"] == "disruption_update": |
| disruption = event.get("disruption", {}) |
| msg = disruption.get("message", "Service disruption in effect") |
| parts.append(f"⚠ DISRUPTION ALERT: {msg}") |
| return "\n".join(parts) |
|
|
| async def _call_mock_tool(self, client: httpx.AsyncClient, tool_name: str, arguments: dict, case_id: str | None = None, case: dict | None = None) -> dict: |
| """Forward a tool call to the mock server.""" |
| url = f"{self.mock_server_url}/{tool_name}" |
| payload = dict(arguments) |
| |
| if case_id: |
| payload["case_id"] = case_id |
| |
| if tool_name == "disruption_feed" and case is not None: |
| current_time = ( |
| case.get("system_context", {}) |
| .get("temporal_context", {}) |
| .get("current_time") |
| or case.get("system_context", {}).get("current_time") |
| ) |
| if current_time: |
| payload["current_time"] = current_time |
| resp = await client.post(url, json=payload, timeout=30.0) |
| resp.raise_for_status() |
| return resp.json() |
|
|
| async def _run_single_case(self, client: httpx.AsyncClient, case: dict) -> CaseResult: |
| """Run a single test case against the LLM.""" |
| case_id = case["id"] |
| system_prompt = self._build_system_prompt(case) |
| user_message = self._build_user_message(case) |
|
|
| |
| turn_groups = case.get("multi_turn_events") |
| if turn_groups: |
| first_msg = self._build_user_message({"events": turn_groups[0]}) |
| messages = [ |
| {"role": "system", "content": system_prompt}, |
| {"role": "user", "content": first_msg}, |
| ] |
| remaining_turns = list(turn_groups[1:]) |
| else: |
| messages = [ |
| {"role": "system", "content": system_prompt}, |
| {"role": "user", "content": user_message}, |
| ] |
| remaining_turns = [] |
|
|
| |
| active_disruptions = case.get("system_context", {}).get("active_disruptions", []) |
| await client.post( |
| f"{self.mock_server_url}/set_disruptions", |
| json={"case_id": case_id, "system": self.system_name, "disruptions": active_disruptions}, |
| timeout=5.0, |
| ) |
|
|
| tool_calls_made = [] |
| total_input_tokens = 0 |
| total_output_tokens = 0 |
| api_rounds = 0 |
| first_token_ms = 0.0 |
|
|
| start_time = time.monotonic() |
|
|
| |
| |
| is_azure = "azure.com" in self.llm_base_url |
| if is_azure: |
| from urllib.parse import urlparse, urlunparse |
| parsed = urlparse(self.llm_base_url) |
| new_path = parsed.path.rstrip("/") + "/chat/completions" |
| chat_endpoint = urlunparse(parsed._replace(path=new_path)) |
| request_headers = {"api-key": self.llm_api_key} |
| else: |
| chat_endpoint = f"{self.llm_base_url}/chat/completions" |
| request_headers = {"Authorization": f"Bearer {self.llm_api_key}"} |
|
|
| try: |
| for round_num in range(self.max_tool_rounds): |
| |
| use_completion = ( |
| "192.168.1.5" in self.llm_base_url |
| or "api.openai.com" in self.llm_base_url |
| or is_azure |
| ) |
| token_limit_key = "max_completion_tokens" if use_completion else "max_tokens" |
| request_body = { |
| "model": self.llm_model, |
| "messages": messages, |
| "tools": TOOL_DEFINITIONS, |
| token_limit_key: self.max_tokens, |
| } |
| if self.temperature is not None: |
| request_body["temperature"] = self.temperature |
| |
| |
| if is_azure or (self.llm_model or "").startswith("gpt-5"): |
| request_body["reasoning_effort"] = "medium" |
| |
| if not self.thinking and self.llm_base_url == "http://192.168.1.5:8080/v1": |
| request_body["chat_template_kwargs"] = {"enable_thinking": False} |
|
|
| |
| if self.extra_body: |
| request_body.update(self.extra_body) |
|
|
| |
| for attempt in range(5): |
| resp = await client.post( |
| chat_endpoint, |
| headers=request_headers, |
| json=request_body, |
| timeout=240.0, |
| ) |
| if resp.status_code == 429 and attempt < 4: |
| wait = 2 ** attempt |
| await asyncio.sleep(wait) |
| continue |
| break |
| if resp.status_code >= 400: |
| error_detail = resp.text[:500] |
| raise httpx.HTTPStatusError( |
| f"{resp.status_code}: {error_detail}", |
| request=resp.request, |
| response=resp, |
| ) |
| result = resp.json() |
|
|
| if api_rounds == 0: |
| first_token_ms = resp.elapsed.total_seconds() * 1000 |
|
|
| choice = result["choices"][0] |
| message = choice["message"] |
| finish_reason = choice.get("finish_reason", "") |
|
|
| usage = result.get("usage", {}) |
| total_input_tokens += usage.get("prompt_tokens", 0) |
| total_output_tokens += usage.get("completion_tokens", 0) |
| api_rounds += 1 |
|
|
| |
| if message.get("tool_calls"): |
| messages.append(message) |
|
|
| submitted = None |
| for tc in message["tool_calls"]: |
| fn_name = tc["function"]["name"] |
| fn_args = json.loads(tc["function"]["arguments"]) |
|
|
| try: |
| tool_result = await self._call_mock_tool(client, fn_name, fn_args, case_id=case_id, case=case) |
| tool_calls_made.append({ |
| "name": fn_name, |
| "arguments": fn_args, |
| "result": tool_result, |
| "error": None, |
| }) |
| |
| if fn_name == "submit_assistant_state" and tool_result.get("accepted"): |
| submitted = fn_args |
| except httpx.HTTPStatusError as e: |
| |
| error_body = e.response.text |
| tool_result = {"error": error_body} |
| tool_calls_made.append({ |
| "name": fn_name, |
| "arguments": fn_args, |
| "result": None, |
| "error": error_body, |
| }) |
| except Exception as e: |
| tool_result = {"error": str(e)} |
| tool_calls_made.append({ |
| "name": fn_name, |
| "arguments": fn_args, |
| "result": None, |
| "error": str(e), |
| }) |
|
|
| messages.append({ |
| "role": "tool", |
| "tool_call_id": tc["id"], |
| "content": json.dumps(tool_result), |
| }) |
|
|
| |
| if submitted is not None: |
| if remaining_turns: |
| |
| next_events = remaining_turns.pop(0) |
| next_msg = self._build_user_message({"events": next_events}) |
| messages.append({"role": "user", "content": next_msg}) |
| submitted = None |
| continue |
|
|
| e2e_ms = (time.monotonic() - start_time) * 1000 |
| reasoning = message.get("reasoning_content", "") |
| |
| parsed = { |
| "outcome": submitted.get("outcome", ""), |
| "kiosk_action": submitted.get("kiosk_action", {}), |
| "reasoning": submitted.get("reasoning", ""), |
| "ui_updates": { |
| "route": submitted.get("route"), |
| "fare_quote": submitted.get("fare_quote"), |
| "advisory_banners": submitted.get("advisory_banners", []), |
| "assistant_message": submitted.get("assistant_message", ""), |
| }, |
| } |
| return CaseResult( |
| case_id=case_id, |
| response=parsed, |
| tool_calls_made=tool_calls_made, |
| raw_content=json.dumps(submitted), |
| reasoning_content=reasoning, |
| messages=messages, |
| ttft_ms=round(first_token_ms, 1), |
| e2e_ms=round(e2e_ms, 1), |
| input_tokens=total_input_tokens, |
| output_tokens=total_output_tokens, |
| api_rounds=api_rounds, |
| error=None, |
| ) |
|
|
| continue |
|
|
| |
| raw_content = message.get("content", "") or "" |
| reasoning = message.get("reasoning_content", "") |
|
|
| |
| |
| if remaining_turns and (raw_content.strip() or reasoning): |
| messages.append(message) |
| next_events = remaining_turns.pop(0) |
| next_msg = self._build_user_message({"events": next_events}) |
| messages.append({"role": "user", "content": next_msg}) |
| continue |
|
|
| |
| if not raw_content.strip() and not reasoning and round_num < self.max_tool_rounds - 1: |
| |
| continue |
|
|
| e2e_ms = (time.monotonic() - start_time) * 1000 |
| parsed = None |
| try: |
| parsed = json.loads(raw_content) |
| except (json.JSONDecodeError, TypeError): |
| pass |
|
|
| return CaseResult( |
| case_id=case_id, |
| response=parsed, |
| tool_calls_made=tool_calls_made, |
| raw_content=raw_content, |
| reasoning_content=reasoning, |
| messages=messages, |
| ttft_ms=round(first_token_ms, 1), |
| e2e_ms=round(e2e_ms, 1), |
| input_tokens=total_input_tokens, |
| output_tokens=total_output_tokens, |
| api_rounds=api_rounds, |
| error=None, |
| ) |
|
|
| |
| e2e_ms = (time.monotonic() - start_time) * 1000 |
| return CaseResult( |
| case_id=case_id, response=None, tool_calls_made=tool_calls_made, |
| raw_content="", reasoning_content="", messages=messages, |
| ttft_ms=round(first_token_ms, 1), e2e_ms=round(e2e_ms, 1), |
| input_tokens=total_input_tokens, output_tokens=total_output_tokens, |
| api_rounds=api_rounds, |
| error=f"Exhausted {self.max_tool_rounds} tool call rounds", |
| ) |
|
|
| except Exception as e: |
| e2e_ms = (time.monotonic() - start_time) * 1000 |
| return CaseResult( |
| case_id=case_id, response=None, tool_calls_made=tool_calls_made, |
| raw_content="", reasoning_content="", messages=messages, |
| ttft_ms=round(first_token_ms, 1), e2e_ms=round(e2e_ms, 1), |
| input_tokens=total_input_tokens, output_tokens=total_output_tokens, |
| api_rounds=api_rounds, |
| error=str(e), |
| ) |
|
|
| async def _run_with_semaphore(self, client: httpx.AsyncClient, case: dict) -> CaseResult: |
| async with self.semaphore: |
| return await self._run_single_case(client, case) |
|
|
| async def run(self, cases: list[dict]) -> list[CaseResult]: |
| """Run all cases with controlled parallelism.""" |
| async with httpx.AsyncClient() as client: |
| tasks = [self._run_with_semaphore(client, case) for case in cases] |
| results = await asyncio.gather(*tasks) |
| return list(results) |
|
|
|
|
def model_file_stem(model_name: str) -> str:
    """Return *model_name* reduced to characters that are safe in a file name.

    Model ids like ``Qwen/Qwen3-32B`` contain a path separator, and ids like
    ``qwen3:32b`` contain a character Windows rejects; either would break the
    default output path. Every character that is not alphanumeric, ``.``, ``_``
    or ``-`` becomes ``_``. An empty result falls back to ``"model"``.
    """
    stem = "".join(ch if ch.isalnum() or ch in "._-" else "_" for ch in model_name)
    return stem or "model"


def main():
    """Command-line entry point: run a cases file and write a results JSON.

    Loads and optionally filters the cases, then runs them through
    ``BenchmarkRunner``. Run metadata plus every ``CaseResult`` is written to
    ``--output`` (default ``results/<model>_<timestamp>.json``), followed by a
    per-case summary on stdout.
    """
    parser = argparse.ArgumentParser(description="MetroLLM-Bench Runner")
    parser.add_argument("--cases", required=True, help="Path to cases JSON (e.g., cases/marta_cases.json)")
    parser.add_argument("--output", default=None, help="Output path (default: results/{model}_{timestamp}.json)")
    parser.add_argument("--llm-url", default="http://192.168.1.5:8080/v1", help="LLM API base URL")
    parser.add_argument("--llm-key", default="sk-local-test", help="LLM API key")
    parser.add_argument("--llm-model", default="qwen3.5", help="Model name")
    parser.add_argument("--mock-url", default="http://localhost:8100", help="Mock server URL")
    parser.add_argument("--system", default="marta", help="Transit system name")
    parser.add_argument("--parallel", type=int, default=2, help="Parallel requests")
    parser.add_argument("--max-tokens", type=int, default=4096, help="Max tokens per response")
    parser.add_argument("--limit", type=int, default=None, help="Limit number of cases (for testing)")
    parser.add_argument("--case-ids", default=None, help="Comma-separated case IDs to run (filters cases file)")
    parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature (default: 0.0 for reproducibility)")
    parser.add_argument("--max-tool-rounds", type=int, default=20, help="Max tool call rounds per case")
    parser.add_argument("--thinking", dest="thinking", action="store_true", default=True, help="Enable thinking mode (default)")
    parser.add_argument("--no-thinking", dest="thinking", action="store_false", help="Disable thinking mode")
    parser.add_argument("--extra-body-json", default=None, help="JSON string shallow-merged into each chat/completions request body")
    args = parser.parse_args()

    # Validate before any case is sent. The value is dict.update()-merged into
    # every request body, so anything other than a JSON object would fail
    # every single case.
    extra_body = None
    if args.extra_body_json:
        try:
            extra_body = json.loads(args.extra_body_json)
        except json.JSONDecodeError as e:
            parser.error(f"--extra-body-json is not valid JSON: {e}")
        if not isinstance(extra_body, dict):
            parser.error("--extra-body-json must be a JSON object")

    with open(args.cases, encoding="utf-8") as f:
        cases = json.load(f)

    if args.case_ids:
        wanted = {cid.strip() for cid in args.case_ids.split(",") if cid.strip()}
        cases = [c for c in cases if c["id"] in wanted]
        missing = wanted - {c["id"] for c in cases}
        if missing:
            print(f"Warning: case IDs not found: {sorted(missing)}")

    # A falsy --limit (None or 0) means "no limit".
    if args.limit:
        cases = cases[:args.limit]

    thinking_label = "thinking" if args.thinking else "non-thinking"
    print(f"Running {len(cases)} cases against {args.llm_model} ({thinking_label}) at {args.llm_url}")
    print(f"Mock server: {args.mock_url}, parallel: {args.parallel}")

    runner = BenchmarkRunner(
        llm_base_url=args.llm_url,
        llm_api_key=args.llm_key,
        llm_model=args.llm_model,
        mock_server_url=args.mock_url,
        system_name=args.system,
        parallel=args.parallel,
        max_tokens=args.max_tokens,
        thinking=args.thinking,
        temperature=args.temperature,
        max_tool_rounds=args.max_tool_rounds,
        extra_body=extra_body,
    )

    # Short content hash so a results file can be matched to the exact cases file.
    cases_checksum = hashlib.sha256(Path(args.cases).read_bytes()).hexdigest()[:12]

    # Describe the harness checkout (not the caller's cwd). Stays None when
    # git is missing or the directory is not a repository with commits.
    try:
        git_hash = subprocess.check_output(
            ["git", "describe", "--always", "--dirty"],
            stderr=subprocess.DEVNULL,
            cwd=Path(__file__).resolve().parent,
        ).decode().strip()
    except (OSError, subprocess.CalledProcessError):
        git_hash = None

    started_at = datetime.now(timezone.utc).isoformat()
    results = asyncio.run(runner.run(cases))
    finished_at = datetime.now(timezone.utc).isoformat()

    output = {
        "metadata": {
            "harness_version": "0.4.0",
            "started_at": started_at,
            "finished_at": finished_at,
            "git_hash": git_hash,
            "llm_base_url": args.llm_url,
            "llm_model": args.llm_model,
            "temperature": args.temperature,
            "max_tokens": args.max_tokens,
            "max_tool_rounds": args.max_tool_rounds,
            "thinking": args.thinking,
            "parallel": args.parallel,
            "system": args.system,
            "cases_file": args.cases,
            "cases_checksum_sha256": cases_checksum,
        },
        "model": args.llm_model,
        "system": args.system,
        "thinking": args.thinking,
        "cases_total": len(cases),
        "cases_succeeded": sum(1 for r in results if r.error is None),
        "cases_failed": sum(1 for r in results if r.error is not None),
        "results": [asdict(r) for r in results],
    }

    if args.output is None:
        ts = time.strftime("%Y%m%d_%H%M%S")
        output_path = Path("results") / f"{model_file_stem(args.llm_model)}_{ts}.json"
    else:
        output_path = Path(args.output)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)

    print(f"\nResults written to {output_path}")
    print(f"  Succeeded: {output['cases_succeeded']}/{output['cases_total']}")
    print(f"  Failed: {output['cases_failed']}/{output['cases_total']}")

    for r in results:
        status = "OK" if r.error is None else f"ERR: {r.error[:60]}"
        tools = len(r.tool_calls_made)
        print(f"  {r.case_id}: {status} ({tools} tool calls, {r.e2e_ms:.0f}ms)")


if __name__ == "__main__":
    main()
|
|