# Flight-Transit-Agent / transit_agent.py
# (Repository web-viewer header kept for provenance: uploader Quazim0t0,
#  commit "Upload 34 files", revision 41e0c9e verified, 14.8 kB.)
"""BAYLINE — the Bay Area Transit agent (second mode of FLIGHTDECK).
Same agent pattern as the flight agent: the LLM picks a tool, the tool runs a
real 511.org call, the LLM reasons over the result. 511 has no trip-planner, so
"fastest route" is reasoned from real-time departures + a scheduled estimate +
live traffic. Reuses the flight agent's trace + JSON-extraction plumbing.
"""
from __future__ import annotations
import datetime as dt
import re
import time
import liquid
import transit
from agent import _extract_json, _new_trace, _save_trace
# Assumed average BART speed in km/h, including station dwell; rough, and
# labeled as an estimate.
AVG_BART_KMH = 50.0
# Baseline driving speed in km/h (roughly free-flow).
DRIVE_KMH = 60.0
# Multiplier correcting a straight-line distance to an approximate path length.
ROUTE_FACTOR = 1.25
# Lower-case words that mark a query as being about travel; _in_scope()
# intersects the query's tokens with this set.
TRANSIT_KEYWORDS = {
    # services and vehicles
    "bart", "caltrain", "muni", "train", "trains", "bus", "buses", "ferry",
    "subway", "metro", "transit",
    # timing
    "departure", "departures", "depart", "leave", "leaving", "arrive",
    "arrives", "schedule", "next", "delay", "delays",
    # stops and routing
    "station", "stop", "route", "fastest", "quickest", "commute", "travel",
    "trip", "ride", "way", "get", "go",
    # roads
    "traffic", "bridge", "freeway", "highway", "incident", "crash",
    "accident",
}
# Lower-case single-word place tokens (alphabetical); a query containing any
# of them is treated as in scope by _in_scope().
BAY_PLACES = {
    "airport", "alto", "antioch",
    "bay", "berkeley", "bruno",
    "city", "civic", "clara", "coliseum", "colma", "concord",
    "daly", "diridon", "downtown", "dublin",
    "embarcadero",
    "francisco", "fremont",
    "hayward",
    "jose",
    "lake", "leandro",
    "macarthur", "mateo", "merritt", "millbrae", "mission", "montgomery",
    "mountain",
    "oakland",
    "palo", "peninsula", "pittsburg", "pleasanton", "powell",
    "rafael", "redwood", "richmond", "rockridge",
    "san", "santa", "sfo", "sunnyvale",
    "view",
    "walnut",
}
# System message sent to the model by run(). It asks for exactly one JSON
# object naming a tool (plan_trip / next_departures / traffic / none) and its
# arguments. The text is part of runtime behavior: edit it deliberately.
SYSTEM_PROMPT = """You are BAYLINE, a Bay Area public-transit assistant.
You help people find the fastest way around the SF Bay Area using LIVE 511 data.
TOOLS (call exactly one):
1. plan_trip - fastest way between two places. args: {"origin","destination"}
2. next_departures - real-time departures at one place. args: {"place"}
3. traffic - current road incidents. args: {"area"} (area optional, "" = all)
Reply with ONE JSON object only. Shapes:
{"tool":"plan_trip","origin":"Berkeley","destination":"SFO"}
{"tool":"next_departures","place":"Embarcadero"}
{"tool":"traffic","area":"Bay Bridge"}
{"tool":"none","answer":"<refusal>"}
Rules:
- "from X to Y" / "X to Y" / "fastest way to Y from X" => plan_trip.
- "when/next/departures at Z" => next_departures.
- "traffic/incidents/crash" => traffic.
- If it is NOT about Bay Area travel, use tool "none".
Output JSON only.
Examples:
User: fastest way from Berkeley to SFO
{"tool":"plan_trip","origin":"Berkeley","destination":"SFO"}
User: when is the next train from Embarcadero
{"tool":"next_departures","place":"Embarcadero"}
User: any traffic on the Bay Bridge
{"tool":"traffic","area":"Bay Bridge"}
User: tell me a joke
{"tool":"none","answer":"I only help with Bay Area transit and traffic."}"""
# --------------------------------------------------------------------------- #
def _in_scope(query: str) -> bool:
    """Return True when *query* shares a word with the keyword or place sets.

    The query is lower-cased and split into runs of ASCII letters before the
    comparison with TRANSIT_KEYWORDS and BAY_PLACES.
    """
    tokens = set(re.findall(r"[a-z]+", query.lower()))
    if tokens & TRANSIT_KEYWORDS:
        return True
    return bool(tokens & BAY_PLACES)
def _clean(text: str) -> str:
fill = {"the", "a", "an", "to", "from", "at", "in", "of", "me", "please",
"next", "train", "trains", "bart", "departures", "departure", "when",
"is", "are", "whats", "what", "fastest", "quickest", "way", "get",
"how", "do", "i", "go", "leaving", "leave", "near", "around", "for"}
toks = [t for t in re.findall(r"[a-z0-9]+", text.lower()) if t not in fill]
return " ".join(toks).strip()
def _regex_plan(query):
    """Derive a tool call from *query* with regexes only (no model involved).

    Checks, in order:
      1. a traffic/incident keyword -> ``traffic`` (area taken from the text
         after "on/near/around/at/in", or "" when absent);
      2. "to Y from X" (destination first) -> ``plan_trip``;
      3. "X to Y" -> ``plan_trip``;
      4. "from Z" -> ``next_departures`` at Z;
      5. anything left after filler removal -> ``next_departures`` there.

    Returns:
        An action dict with a "tool" key plus that tool's arguments, or None
        when no usable place could be extracted.
    """
    q = query.lower().strip()

    if re.search(r"\b(traffic|incident|incidents|crash|accident|road)\b", q):
        m = re.search(r"\b(?:on|near|around|at|in)\s+(.*)", q)
        return {"tool": "traffic", "area": (_clean(m.group(1)) if m else "")}

    # Destination-first phrasing ("fastest way to SFO from Berkeley"). The
    # plain " to " split below would leave the origin inside the destination
    # text, so this form has to be recognized before it.
    m = re.search(r"\bto\s+(.+?)\s+from\s+(.+)", q)
    if m:
        dest, origin = _clean(m.group(1)), _clean(m.group(2))
        if origin and dest:
            return {"tool": "plan_trip", "origin": origin, "destination": dest}

    if " to " in q:
        left, right = q.split(" to ", 1)
        origin, dest = _clean(left), _clean(right)
        if origin and dest:
            return {"tool": "plan_trip", "origin": origin, "destination": dest}

    m = re.search(r"\bfrom\s+(.*)", q)
    if m:
        place = _clean(m.group(1))
        if place:
            return {"tool": "next_departures", "place": place}

    # Last resort: treat whatever survives filler removal as a place name.
    place = _clean(q)
    if place:
        return {"tool": "next_departures", "place": place}
    return None
# ---- tools ----------------------------------------------------------------- #
def _eta_minutes(km, kmh):
    """Estimate travel time in whole minutes for *km* of straight-line distance.

    The distance is scaled by ROUTE_FACTOR before dividing by the speed
    *kmh*; the result is rounded to the nearest integer.
    """
    hours = km * ROUTE_FACTOR / kmh
    return round(hours * 60)
def _tool_plan_trip(args):
    """Compare an estimated transit trip against driving between two places.

    Args:
        args: Action dict; its "origin" and "destination" values are resolved
            through ``transit.resolve_place``.

    Returns:
        ``(result, markers)``. ``result`` is ``{"error": ...}`` when either
        endpoint cannot be resolved (markers is then empty); otherwise it
        holds the distance, the transit and driving estimates, and a
        recommendation. ``markers`` lists the two endpoints followed by the
        located incidents among the first eight traffic events.
    """
    origin = transit.resolve_place(args.get("origin", ""))
    dest = transit.resolve_place(args.get("destination", ""))
    if not (origin and dest):
        # Report whichever endpoint failed; the origin takes precedence.
        missing = args.get("destination") if origin else args.get("origin")
        return {"error": f"could not find a station for {missing!r}"}, []

    distance_km = transit.haversine_km(origin["lat"], origin["lon"],
                                       dest["lat"], dest["lon"])

    # Transit side: wait for the first listed departure plus the ride estimate.
    departures, _ = transit.station_departures(origin["operator"], origin["name"], limit=6)
    wait_min = departures[0]["minutes"] if departures else None
    ride_min = _eta_minutes(distance_km, AVG_BART_KMH)
    transit_min = (wait_min or 0) + ride_min

    # Driving side: baseline estimate plus a crude congestion penalty of
    # 3 minutes per active event, capped at 25.
    incidents, _ = transit.traffic_events()
    drive_min = _eta_minutes(distance_km, DRIVE_KMH) + min(25, 3 * len(incidents))

    markers = [
        {"lat": origin["lat"], "lon": origin["lon"],
         "label": f"FROM {origin['name']}", "kind": "origin"},
        {"lat": dest["lat"], "lon": dest["lon"],
         "label": f"TO {dest['name']}", "kind": "dest"},
    ]
    markers.extend(
        {"lat": ev["lat"], "lon": ev["lon"],
         "label": f"{ev['type']}: {ev['headline'][:60]}", "kind": "incident"}
        for ev in incidents[:8]
        if ev.get("lat") and ev.get("lon")
    )

    # The recommendation is computed here rather than left to the model
    # (the tiny LLM can't be trusted to compare). Ties go to transit.
    transit_wins = transit_min <= drive_min
    best_mode = "BART" if transit_wins else "Driving"
    saved_min = drive_min - transit_min if transit_wins else transit_min - drive_min

    result = {
        "origin": origin["name"],
        "destination": dest["name"],
        "operator": transit.OPERATOR_NAMES.get(origin["operator"], origin["operator"]),
        "distance_km": round(distance_km, 1),
        "transit": {
            "next_departure_min": wait_min,
            "in_vehicle_min_est": ride_min,
            "total_min_est": transit_min,
            "departures": departures[:4],
        },
        "driving": {"est_min": drive_min, "active_incidents": len(incidents)},
        "recommendation": {
            "mode": best_mode,
            "saves_min": saved_min,
            "transit_min": transit_min,
            "drive_min": drive_min,
        },
    }
    return result, markers
def _tool_next_departures(args):
    """Look up departures at the place named by ``args["place"]``.

    Returns:
        ``(result, markers)``. ``result`` is ``{"error": ...}`` with no
        markers when the place cannot be resolved; otherwise it carries the
        station name, operator display name and departures, and ``markers``
        holds a single marker for the station.
    """
    station = transit.resolve_place(args.get("place", ""))
    if not station:
        return {"error": f"could not find a station for {args.get('place')!r}"}, []

    departures, _ = transit.station_departures(station["operator"], station["name"], limit=8)
    # Fall back to the raw operator value when no display name is registered.
    operator_name = transit.OPERATOR_NAMES.get(station["operator"], station["operator"])

    marker = {
        "lat": station["lat"],
        "lon": station["lon"],
        "label": f"{station['name']} ({operator_name})",
        "kind": "origin",
    }
    result = {
        "station": station["name"],
        "operator": operator_name,
        "departures": departures,
    }
    return result, [marker]
def _tool_traffic(args):
    """Fetch traffic events, optionally narrowed by ``args["area"]``.

    A missing or empty area is passed to ``transit.traffic_events`` as None
    and reported back as "Bay Area".

    Returns:
        ``(result, markers)`` where ``result`` has "area", "count" and
        "events", and ``markers`` has one entry per event with truthy
        "lat" and "lon".
    """
    area = args.get("area") or ""
    events, _ = transit.traffic_events(area_query=area or None)

    markers = []
    for ev in events:
        if not (ev.get("lat") and ev.get("lon")):
            continue  # event has no usable coordinates
        markers.append({
            "lat": ev["lat"],
            "lon": ev["lon"],
            "label": f"{ev['type']}: {ev['headline'][:60]}",
            "kind": "incident",
        })

    summary = {"area": area or "Bay Area", "count": len(events), "events": events}
    return summary, markers
# Dispatch table: tool name (as used in the "tool" field of an action dict)
# -> implementation. Each implementation takes the action dict and returns
# a (result, markers) pair.
TOOL_IMPLS = {
"plan_trip": _tool_plan_trip,
"next_departures": _tool_next_departures,
"traffic": _tool_traffic,
}
def _summarize(tool, result):
if "error" in result:
return f"Lookup problem: {result['error']}"
if tool == "plan_trip":
t, dr, rec = result["transit"], result["driving"], result["recommendation"]
best_min = rec["transit_min"] if rec["mode"] == "BART" else rec["drive_min"]
verdict = (f"**Fastest: {rec['mode']}** (~{best_min} min) — ~{rec['saves_min']} "
"min faster than the alternative." if rec["saves_min"] > 1 else
f"**{rec['mode']} and driving are about the same** (~{best_min} min).")
lines = [
verdict,
f"Trip {result['origin']}{result['destination']} (~{result['distance_km']} km).",
f"• {result['operator']}: next train in {t['next_departure_min']} min, "
f"~{t['in_vehicle_min_est']} min ride, **~{t['total_min_est']} min total** (est).",
f"• Driving: **~{dr['est_min']} min** est ({dr['active_incidents']} active incidents region-wide).",
f"Next departures from {result['origin']} (all directions): " + ("; ".join(
f"{d['line']}{d['destination']} in {d['minutes']}m"
for d in t["departures"]) or "none"),
"_Transit times are straight-line estimates (511 has no trip planner; "
"transfers not modeled)._",
]
return "\n".join(lines)
if tool == "next_departures":
deps = result["departures"]
head = f"{result['station']} ({result['operator']}) next departures:"
body = "; ".join(f"{d['line']}->{d['destination']} in {d['minutes']}m"
for d in deps) or "no real-time departures right now"
return head + " " + body
if tool == "traffic":
evs = result["events"]
head = f"{result['count']} active incident(s) in {result['area']}:"
body = "; ".join(f"{e['type']} on {e['roads'] or '?'}" for e in evs[:6])
return head + " " + (body or "none")
return str(result)
def _validate(action, query):
    """Cross-check the planned *action* against a regex reading of *query*.

    Args:
        action: The action dict from the planner, or None.
        query: The original user text.

    Returns:
        ``(final_action, reason)``. ``reason`` is None when the action was
        left as is, otherwise a short "override: ..." / "repair: ..." note.
        ``action`` may be modified in place (its "area" or "place").
    """
    rx = _regex_plan(query)
    rx_tool = rx.get("tool") if rx else None

    # No action, or the model refused: let the regex reading take over if it
    # found a real tool.
    if not action or action.get("tool") in (None, "none", ""):
        if rx_tool in TOOL_IMPLS:
            return rx, "override: model refused an in-scope query"
        return action, None

    tool = action.get("tool")

    # A strong route signal in the query forces plan_trip.
    if rx_tool == "plan_trip" and tool != "plan_trip":
        return rx, "override: query has explicit origin->destination"

    reason = None

    # Traffic area: trust the query, not the model (it hallucinates roads).
    if tool == "traffic" and rx_tool == "traffic":
        query_area = rx.get("area")
        if query_area and query_area != action.get("area"):
            action["area"] = rx["area"]
            reason = "override: traffic area taken from query"

    # Fill in arguments the model left out.
    if tool == "plan_trip":
        has_endpoints = action.get("origin") and action.get("destination")
        if not has_endpoints:
            if rx_tool == "plan_trip":
                return rx, "repair: filled trip endpoints"
            return ({"tool": "none", "answer": "Tell me both a start and a destination."},
                    "repair: no endpoints")

    if tool == "next_departures" and not action.get("place"):
        if rx and rx.get("place"):
            action["place"] = rx["place"]
            reason = "repair: filled place"

    return action, reason
# --------------------------------------------------------------------------- #
def _finish_run(trace, answer, markers, result, tool_calls):
    """Record *answer* on the trace, save the trace, and build run()'s reply."""
    trace["answer"] = answer
    path = _save_trace(trace)
    return {"answer": answer, "markers": markers, "result": result,
            "trace_path": path, "trace_id": trace["trace_id"],
            "tool_calls": tool_calls, "mode": trace["agent_mode"]}


def run(query: str, max_tokens=380):
    """Answer one transit query: plan a tool call, run it, summarize the result.

    Flow: scope check -> plan (model when ``liquid.available()``, regex
    otherwise or when the model output is unusable) -> ``_validate`` ->
    tool call -> deterministic summary. Every stage is appended to the
    trace, which is saved before returning.

    Args:
        query: The user's text.
        max_tokens: Not used by this function; the planning call uses a fixed
            200-token budget. NOTE(review): presumably kept for signature
            parity with the flight agent — confirm before removing.

    Returns:
        A dict with keys "answer", "markers", "result", "trace_path",
        "trace_id", "tool_calls" (list of tool names called) and "mode".
        Tool failures are reported in "answer" rather than raised.
    """
    trace = _new_trace(query)
    trace["mode_kind"] = "transit"
    use_llm = liquid.available()
    trace["agent_mode"] = "transit-llm" if use_llm else "transit-regex"

    if not _in_scope(query):
        ans = ("I'm the Bay Area transit assistant — try 'fastest way from "
               "Berkeley to SFO', 'next train from Embarcadero', or 'traffic "
               "on the Bay Bridge'.")
        trace["agent_mode"] += "+scope-refused"
        return _finish_run(trace, ans, [], None, [])

    action = None
    if use_llm:
        try:
            raw, latency = liquid.complete(
                [{"role": "system", "content": SYSTEM_PROMPT},
                 {"role": "user", "content": query}],
                max_tokens=200, temperature=0.0)
        except Exception as e:  # noqa: BLE001
            # Model failure is not fatal: record it and plan with regexes.
            raw, latency, use_llm = f"(model error: {e})", 0, False
            trace["agent_mode"] = "transit-regex"
        parsed = _extract_json(raw)
        trace["steps"].append({"step": 1, "phase": "plan", "model_raw": raw,
                               "parsed_action": parsed, "latency_ms": latency})
        # Only a JSON object is a usable action; anything else (e.g. a list)
        # is treated like a failed parse so the regex planner takes over.
        action = parsed if isinstance(parsed, dict) else None

    if action is None:
        action = _regex_plan(query)
        trace["steps"].append({"step": 1, "phase": "plan-fallback",
                               "parsed_action": action})

    action, override = _validate(action, query)
    if override:
        trace["steps"].append({"step": 1, "phase": "validate",
                               "final_action": action, "override_reason": override})

    tool = action.get("tool") if action else None
    is_refusal = tool in (None, "none", "")
    if not is_refusal and not (isinstance(tool, str) and tool in TOOL_IMPLS):
        # The model named a tool that does not exist. Re-plan from the query
        # text instead of calling a missing implementation.
        action = _regex_plan(query)
        trace["steps"].append({"step": 1, "phase": "validate",
                               "final_action": action,
                               "override_reason": f"override: unknown tool {tool!r}"})
        tool = action.get("tool") if action else None

    if not action or tool in (None, "none", ""):
        ans = (action or {}).get("answer", "I can only help with Bay Area transit.")
        return _finish_run(trace, ans, [], None, [])

    impl = TOOL_IMPLS[tool]
    # perf_counter is monotonic, so the latency cannot be skewed by clock
    # adjustments during the call.
    t0 = time.perf_counter()
    try:
        result, markers = impl(action)
        error = result.get("error") if isinstance(result, dict) else None
    except transit.Transit511Error as e:
        result, markers, error = {"error": str(e)}, [], str(e)
    except Exception as e:  # noqa: BLE001
        result, markers, error = {"error": repr(e)}, [], repr(e)
    latency = int((time.perf_counter() - t0) * 1000)

    call = {"tool": tool, "args": {k: v for k, v in action.items() if k != "tool"},
            "latency_ms": latency, "error": error,
            "result_count": len(markers)}
    trace["tool_calls"].append(call)
    trace["steps"].append({"step": 2, "phase": "act", **call})
    # Holds the marker count despite the name.
    trace["flights_returned"] = len(markers)

    # Answer is the deterministic, fact-checked summary — a 350M model flips
    # numeric comparisons, and "fastest route" must be correct. The LLM still
    # drives the agentic part (tool selection) above.
    summary = _summarize(tool, result)
    answer = f"Couldn't complete that: {error}" if error else summary
    return _finish_run(trace, answer, markers, result,
                       [c["tool"] for c in trace["tool_calls"]])