| """Rule-based baseline agent — no LLM, deterministic tool forwarding.""" |
|
|
| import asyncio |
| import hashlib |
| import json |
| import time |
| import argparse |
| from pathlib import Path |
| from dataclasses import asdict |
| from datetime import datetime, timezone |
|
|
| import httpx |
| import yaml |
|
|
| from harness.runner import CaseResult |
|
|
|
|
| class RuleAgent: |
| """Scripted agent that calls mock server tools in fixed order.""" |
|
|
| def __init__(self, mock_server_url: str, system_name: str, parallel: int = 4): |
| self.mock_server_url = mock_server_url.rstrip("/") |
| self.system_name = system_name |
| self.parallel = parallel |
| self.semaphore = asyncio.Semaphore(parallel) |
|
|
| |
| data_dir = Path(__file__).parent.parent / "data" / "systems" / system_name |
| with open(data_dir / "framebook.yaml") as f: |
| self.framebook = yaml.safe_load(f) |
|
|
| def _parse_events(self, events: list[dict]) -> dict: |
| """Extract structured fields from case events.""" |
| origin = destination = origin_id = destination_id = None |
| adults = children = seniors = disabled = 0 |
| freetext = None |
| payment_method = None |
| has_pax = False |
|
|
| for e in events: |
| t = e.get("type", "") |
| if t == "station_selected": |
| if e.get("field") == "origin": |
| origin = e.get("value") |
| origin_id = e.get("station_id") |
| elif e.get("field") == "destination": |
| destination = e.get("value") |
| destination_id = e.get("station_id") |
| elif t == "passenger_count_changed": |
| adults = e.get("adults", 0) |
| children = e.get("children", 0) |
| seniors = e.get("seniors", 0) |
| disabled = e.get("disabled", 0) |
| has_pax = True |
| elif t == "freetext_input": |
| freetext = e.get("text", "") |
| elif t == "payment_method_selected": |
| payment_method = e.get("method") |
|
|
| return { |
| "origin": origin, |
| "destination": destination, |
| "origin_id": origin_id, |
| "destination_id": destination_id, |
| "adults": adults, |
| "children": children, |
| "seniors": seniors, |
| "disabled": disabled, |
| "has_pax": has_pax, |
| "freetext": freetext, |
| "payment_method": payment_method, |
| } |
|
|
| def _check_service_hours(self, case: dict) -> bool: |
| """Check if current time is within service hours. Returns True if service available.""" |
| temporal = case.get("system_context", {}).get("temporal_context") |
| if not temporal: |
| return True |
|
|
| service_available = temporal.get("service_available") |
| if service_available is not None: |
| return service_available |
|
|
| |
| return True |
|
|
| async def _call_tool(self, client: httpx.AsyncClient, tool_name: str, args: dict, case_id: str) -> tuple[dict | None, str | None]: |
| """Call a mock server tool. Returns (result, error).""" |
| payload = dict(args) |
| if tool_name == "disruption_feed": |
| payload["case_id"] = case_id |
| try: |
| resp = await client.post(f"{self.mock_server_url}/{tool_name}", json=payload, timeout=30.0) |
| if resp.status_code >= 400: |
| return None, resp.text[:200] |
| return resp.json(), None |
| except Exception as e: |
| return None, str(e) |
|
|
| def _build_fare_quote(self, fare_result: dict, pax: dict) -> dict: |
| """Build fare_quote from fare_calculator result and passenger counts.""" |
| line_items = [] |
| for item in fare_result.get("line_items", []): |
| |
| label = item.get("label", "") |
| parts = label.lower().split(" x") |
| rider_type = parts[0].strip() if parts else "adult" |
| count = int(parts[1]) if len(parts) > 1 else 1 |
| unit_fare = item["amount"] / count if count > 0 else item["amount"] |
| line_items.append({ |
| "rider_type": rider_type, |
| "count": count, |
| "unit_fare": round(unit_fare, 2), |
| "subtotal": item["amount"], |
| "currency": item.get("currency", fare_result.get("currency", "USD")), |
| }) |
|
|
| |
| total_pax = pax["adults"] + pax["children"] + pax["seniors"] + pax["disabled"] |
| ticketed = sum(i["count"] for i in line_items) |
| free_riders = max(0, total_pax - ticketed) |
|
|
| return { |
| "passenger_summary": { |
| "adults": pax["adults"], |
| "children": pax["children"], |
| "seniors": pax["seniors"], |
| "disabled": pax["disabled"], |
| "free_riders": free_riders, |
| }, |
| "line_items": line_items, |
| "discounts": fare_result.get("discounts", []), |
| "total": fare_result["total"], |
| "currency": fare_result.get("currency", "USD"), |
| } |
|
|
| def _build_route(self, route_result: dict) -> dict: |
| """Build route from route_planner result.""" |
| return { |
| "origin": route_result["stops"][0]["station_name"], |
| "destination": route_result["stops"][-1]["station_name"], |
| "stops": [s["station_name"] for s in route_result["stops"]], |
| "transfers": route_result["transfers"], |
| "estimated_minutes": route_result["estimated_minutes"], |
| "distance_miles": route_result["distance_miles"], |
| "line_sequence": route_result.get("line_sequence", []), |
| } |
|
|
| async def _run_single_case(self, client: httpx.AsyncClient, case: dict) -> CaseResult: |
| """Run a single case with rule-based logic.""" |
| case_id = case["id"] |
| start_time = time.monotonic() |
| tool_calls_made = [] |
|
|
| |
| active_disruptions = case.get("system_context", {}).get("active_disruptions", []) |
| await client.post( |
| f"{self.mock_server_url}/set_disruptions", |
| json={"case_id": case_id, "disruptions": active_disruptions}, |
| timeout=5.0, |
| ) |
|
|
| |
| turn_groups = case.get("multi_turn_events") |
| if turn_groups: |
| all_events = [] |
| for group in turn_groups: |
| all_events.extend(group) |
| else: |
| all_events = case["events"] |
|
|
| parsed = self._parse_events(all_events) |
|
|
| def record_tool(name, args, result, error=None): |
| tool_calls_made.append({"name": name, "arguments": args, "result": result, "error": error}) |
|
|
| |
|
|
| |
| if not parsed["origin"] or not parsed["destination"]: |
| if parsed["freetext"]: |
| |
| kb_args = {"query": parsed["freetext"], "category": "general"} |
| kb_result, kb_err = await self._call_tool(client, "knowledge_base", kb_args, case_id) |
| record_tool("knowledge_base", kb_args, kb_result, kb_err) |
|
|
| if kb_result and kb_result.get("found"): |
| content = kb_result["results"][0]["content"] if kb_result["results"] else "" |
| submit_args = { |
| "outcome": "policy_answer_only", |
| "kiosk_action": {"action": "display_info", "reason_code": "ok"}, |
| "assistant_message": content[:300], |
| } |
| else: |
| submit_args = { |
| "outcome": "request_declined", |
| "kiosk_action": {"action": "block_purchase", "reason_code": "unsupported_request"}, |
| "assistant_message": "This request is outside kiosk capabilities.", |
| } |
| else: |
| submit_args = { |
| "outcome": "request_declined", |
| "kiosk_action": {"action": "block_purchase", "reason_code": "invalid_request"}, |
| "assistant_message": "Please select origin and destination stations.", |
| } |
|
|
| sub_result, sub_err = await self._call_tool(client, "submit_assistant_state", submit_args, case_id) |
| record_tool("submit_assistant_state", submit_args, sub_result, sub_err) |
|
|
| e2e_ms = (time.monotonic() - start_time) * 1000 |
| return self._make_result(case_id, submit_args, tool_calls_made, e2e_ms) |
|
|
| |
| if not self._check_service_hours(case): |
| submit_args = { |
| "outcome": "service_unavailable", |
| "kiosk_action": {"action": "block_purchase", "reason_code": "no_service"}, |
| "assistant_message": "Service is not available at the requested time.", |
| } |
| sub_result, sub_err = await self._call_tool(client, "submit_assistant_state", submit_args, case_id) |
| record_tool("submit_assistant_state", submit_args, sub_result, sub_err) |
|
|
| e2e_ms = (time.monotonic() - start_time) * 1000 |
| return self._make_result(case_id, submit_args, tool_calls_made, e2e_ms) |
|
|
| |
| route_args = {"origin": parsed["origin"], "destination": parsed["destination"]} |
| route_result, route_err = await self._call_tool(client, "route_planner", route_args, case_id) |
| record_tool("route_planner", route_args, route_result, route_err) |
|
|
| if route_err: |
| submit_args = { |
| "outcome": "request_declined", |
| "kiosk_action": {"action": "block_purchase", "reason_code": "invalid_request"}, |
| "assistant_message": f"Could not plan route: {route_err[:100]}", |
| } |
| sub_result, sub_err = await self._call_tool(client, "submit_assistant_state", submit_args, case_id) |
| record_tool("submit_assistant_state", submit_args, sub_result, sub_err) |
|
|
| e2e_ms = (time.monotonic() - start_time) * 1000 |
| return self._make_result(case_id, submit_args, tool_calls_made, e2e_ms) |
|
|
| |
| pax = { |
| "adults": parsed["adults"] if parsed["has_pax"] else 1, |
| "children": parsed["children"], |
| "seniors": parsed["seniors"], |
| "disabled": parsed["disabled"], |
| } |
| fare_args = { |
| "route_id": route_result["route_id"], |
| "passengers": pax, |
| "ticket_type": "single", |
| } |
| if parsed["payment_method"]: |
| fare_args["payment_method"] = parsed["payment_method"] |
| fare_result, fare_err = await self._call_tool(client, "fare_calculator", fare_args, case_id) |
| record_tool("fare_calculator", fare_args, fare_result, fare_err) |
|
|
| if fare_err: |
| submit_args = { |
| "outcome": "request_declined", |
| "kiosk_action": {"action": "block_purchase", "reason_code": "invalid_request"}, |
| "assistant_message": f"Could not calculate fare: {fare_err[:100]}", |
| } |
| sub_result, sub_err = await self._call_tool(client, "submit_assistant_state", submit_args, case_id) |
| record_tool("submit_assistant_state", submit_args, sub_result, sub_err) |
|
|
| e2e_ms = (time.monotonic() - start_time) * 1000 |
| return self._make_result(case_id, submit_args, tool_calls_made, e2e_ms) |
|
|
| route_data = self._build_route(route_result) |
| fare_data = self._build_fare_quote(fare_result, pax) |
|
|
| |
| advisory_banners = [] |
| outcome = "route_and_fare_ready" |
| action = "prompt_purchase" |
| reason_code = "ok" |
|
|
| if active_disruptions: |
| dis_args = {"severity_filter": "all"} |
| dis_result, dis_err = await self._call_tool(client, "disruption_feed", dis_args, case_id) |
| record_tool("disruption_feed", dis_args, dis_result, dis_err) |
|
|
| if dis_result: |
| route_stops = {s["station_id"] for s in route_result["stops"]} |
| route_lines = set(route_result.get("line_sequence", [])) |
| for d in dis_result.get("disruptions", []): |
| affected_stations = set(d.get("segment") or []) |
| affected_line = d.get("line") |
| if affected_stations & route_stops or (affected_line and affected_line in route_lines): |
| advisory_banners.append({ |
| "severity": "critical" if d["severity"] == "critical" else "warning", |
| "title": d["type"].replace("_", " ").title(), |
| "body": d["message"], |
| }) |
| outcome = "advisory_only" |
| action = "display_info" |
|
|
| |
| if case.get("system_context", {}).get("accessibility_mode"): |
| for stop in route_result["stops"]: |
| si_args = {"station_id": stop["station_id"], "query_type": "accessibility"} |
| si_result, si_err = await self._call_tool(client, "station_info", si_args, case_id) |
| record_tool("station_info", si_args, si_result, si_err) |
|
|
| if si_result: |
| acc = si_result.get("accessibility", {}) |
| issues = [] |
| if not acc.get("step_free"): |
| issues.append("not step-free") |
| if not acc.get("elevators"): |
| issues.append("no elevators") |
| if issues: |
| advisory_banners.append({ |
| "severity": "warning", |
| "title": f"Accessibility: {stop['station_name']}", |
| "body": f"{stop['station_name']}: {', '.join(issues)}", |
| }) |
| outcome = "advisory_only" |
| action = "display_info" |
| reason_code = "accessibility_issue" |
|
|
| |
| msg_parts = [f"Route: {route_data['origin']} to {route_data['destination']}"] |
| msg_parts.append(f"{route_data['transfers']} transfer(s), ~{route_data['estimated_minutes']} min") |
| msg_parts.append(f"Fare: {fare_data['total']} {fare_data['currency']}") |
| if advisory_banners: |
| for b in advisory_banners: |
| msg_parts.append(f"{b['severity'].upper()}: {b['body']}") |
| assistant_message = ". ".join(msg_parts) |
|
|
| |
| submit_args = { |
| "outcome": outcome, |
| "route": route_data, |
| "kiosk_action": {"action": action, "reason_code": reason_code}, |
| "assistant_message": assistant_message, |
| } |
| if outcome == "route_and_fare_ready": |
| submit_args["fare_quote"] = fare_data |
| if advisory_banners: |
| submit_args["advisory_banners"] = advisory_banners |
|
|
| sub_result, sub_err = await self._call_tool(client, "submit_assistant_state", submit_args, case_id) |
| record_tool("submit_assistant_state", submit_args, sub_result, sub_err) |
|
|
| e2e_ms = (time.monotonic() - start_time) * 1000 |
| return self._make_result(case_id, submit_args, tool_calls_made, e2e_ms) |
|
|
| def _make_result(self, case_id: str, submit_args: dict, tool_calls_made: list, e2e_ms: float) -> CaseResult: |
| """Build CaseResult matching runner output format.""" |
| parsed = { |
| "outcome": submit_args.get("outcome", ""), |
| "kiosk_action": submit_args.get("kiosk_action", {}), |
| "reasoning": "", |
| "ui_updates": { |
| "route": submit_args.get("route"), |
| "fare_quote": submit_args.get("fare_quote"), |
| "advisory_banners": submit_args.get("advisory_banners", []), |
| "assistant_message": submit_args.get("assistant_message", ""), |
| }, |
| } |
| return CaseResult( |
| case_id=case_id, |
| response=parsed, |
| tool_calls_made=tool_calls_made, |
| raw_content=json.dumps(submit_args), |
| reasoning_content="", |
| messages=[], |
| ttft_ms=0.0, |
| e2e_ms=round(e2e_ms, 1), |
| input_tokens=0, |
| output_tokens=0, |
| api_rounds=0, |
| error=None, |
| ) |
|
|
| async def run(self, cases: list[dict]) -> list[CaseResult]: |
| """Run all cases through the rule-based agent.""" |
| async with httpx.AsyncClient() as client: |
| tasks = [self._run_with_semaphore(client, case) for case in cases] |
| return await asyncio.gather(*tasks) |
|
|
| async def _run_with_semaphore(self, client: httpx.AsyncClient, case: dict) -> CaseResult: |
| async with self.semaphore: |
| try: |
| return await self._run_single_case(client, case) |
| except Exception as e: |
| return CaseResult( |
| case_id=case["id"], |
| response=None, |
| tool_calls_made=[], |
| raw_content="", |
| reasoning_content="", |
| messages=[], |
| ttft_ms=0.0, |
| e2e_ms=0.0, |
| input_tokens=0, |
| output_tokens=0, |
| api_rounds=0, |
| error=str(e), |
| ) |
|
|
|
|
def main():
    """Command-line entry point: run every case and write a JSON report.

    Loads the cases file, runs the cases through ``RuleAgent``, writes the
    results and run metadata to ``--output`` (default
    ``results/<system>_rule_based.json``), and prints a per-case summary.
    """
    parser = argparse.ArgumentParser(description="Rule-based baseline agent")
    parser.add_argument("--cases", required=True, help="Path to cases JSON")
    parser.add_argument("--system", default="marta", help="Transit system name")
    parser.add_argument("--mock-url", default="http://localhost:8100", help="Mock server URL")
    parser.add_argument("--parallel", type=int, default=4, help="Parallel requests")
    parser.add_argument("--output", default=None, help="Output path")
    parser.add_argument("--limit", type=int, default=None, help="Limit number of cases")
    args = parser.parse_args()

    # Read the file once so the checksum describes exactly the bytes parsed.
    cases_bytes = Path(args.cases).read_bytes()
    cases = json.loads(cases_bytes)
    # Compare against None so an explicit "--limit 0" is honoured.
    if args.limit is not None:
        cases = cases[:args.limit]

    print(f"Running {len(cases)} cases with rule-based agent")
    print(f"Mock server: {args.mock_url}, parallel: {args.parallel}")

    agent = RuleAgent(
        mock_server_url=args.mock_url,
        system_name=args.system,
        parallel=args.parallel,
    )

    # Only the first 12 hex digits of the digest are recorded.
    cases_checksum = hashlib.sha256(cases_bytes).hexdigest()[:12]
    started_at = datetime.now(timezone.utc).isoformat()
    results = asyncio.run(agent.run(cases))
    finished_at = datetime.now(timezone.utc).isoformat()

    failed = sum(1 for r in results if r.error is not None)
    output = {
        "metadata": {
            "harness_version": "0.4.0",
            "started_at": started_at,
            "finished_at": finished_at,
            "llm_base_url": "rule-based",
            "llm_model": "rule-based",
            "temperature": 0.0,
            "max_tokens": 0,
            "max_tool_rounds": 1,
            "thinking": False,
            "parallel": args.parallel,
            "system": args.system,
            "cases_file": args.cases,
            "cases_checksum_sha256": cases_checksum,
        },
        "model": "rule-based",
        "system": args.system,
        "thinking": False,
        "cases_total": len(cases),
        "cases_succeeded": len(results) - failed,
        "cases_failed": failed,
        "results": [asdict(r) for r in results],
    }

    output_path = args.output or f"results/{args.system}_rule_based.json"
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)

    print(f"\nResults written to {output_path}")
    print(f" Succeeded: {output['cases_succeeded']}/{output['cases_total']}")
    print(f" Failed: {output['cases_failed']}/{output['cases_total']}")
    for r in results:
        status = "OK" if r.error is None else f"ERR: {r.error[:60]}"
        print(f" {r.case_id}: {status} ({len(r.tool_calls_made)} tool calls, {r.e2e_ms:.0f}ms)")
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|