# metrollm / render.py — author: Remco Hendriks
# Commit d23bf95: "landing replays for all 6 systems + render helpers" (8.75 kB)
"""Pure-function rendering helpers for the kiosk chat + map UI.
Lives separately from app.py so capture/replay tooling can import these
without triggering the (heavy) model load at app.py module level.
"""
from __future__ import annotations
import html
import json
import re
# Tool names treated as registered. tool_status() labels a call to any other
# name as "hallucinated".
KNOWN_TOOLS = {
    "disruption_feed",
    "fare_calculator",
    "knowledge_base",
    "line_info",
    "route_planner",
    "station_info",
    "submit_assistant_state",
}
# --- narrative cleanup -----------------------------------------------------
_THINK_RE = re.compile(r"<think>.*?</think>\s*", re.DOTALL)
_TOOL_CALL_RE = re.compile(r"<tool_call>.*?</tool_call>\s*", re.DOTALL)
def strip_think_block(text: str) -> str:
    """Drop reasoning markup from raw model output.

    Every complete ``<think>...</think>`` span (together with the whitespace
    after it) is removed first. Then any lone ``</think>`` close-tag left
    behind by the Qwen3.5 chat template is deleted, and the result is
    whitespace-trimmed.
    """
    without_blocks = _THINK_RE.sub("", text)
    without_stray_close = without_blocks.replace("</think>", "")
    return without_stray_close.strip()
def strip_tool_calls(text: str) -> str:
    """Return *text* with every ``<tool_call>...</tool_call>`` block removed.

    What remains is the model's narrative for the chat bubble. Only complete
    (closed) blocks are matched. Unlike strip_think_block, the result is not
    whitespace-trimmed.
    """
    narrative = _TOOL_CALL_RE.sub("", text)
    return narrative
def to_narrative(text: str) -> str:
    """Reduce raw model output to the text shown in the chat bubble.

    Tool-call XML is stripped first, then think blocks (that step also trims
    surrounding whitespace). Neither belongs in the visible narrative.
    """
    without_calls = strip_tool_calls(text)
    return strip_think_block(without_calls)
# --- tool card -------------------------------------------------------------
def tool_status(name: str, result) -> tuple[str, str]:
    """Classify one tool round-trip as ``(status_label, glyph)`` for the UI.

    - ``("hallucinated", "✗")``: *name* is not one of KNOWN_TOOLS.
    - ``("rejected", "⚠")``: a known tool whose *result* is a dict carrying
      a truthy ``"error"`` (validation/domain error).
    - ``("done", "✓")``: everything else.
    """
    if name not in KNOWN_TOOLS:
        return "hallucinated", "✗"
    if isinstance(result, dict) and result.get("error"):
        return "rejected", "⚠"
    return "done", "✓"
def summarise_result(name: str, result) -> str:
    """One-line preview of a tool result for the tool card's "out" row.

    Args:
        name: the tool name the model called.
        result: the tool's return value; only dicts are summarised.

    Returns:
        A short human-readable summary, or "" when there is nothing useful to
        show (non-dict result, or a tool/result shape not handled here).

    A truthy ``result["error"]`` takes precedence over everything else:

    - A dict error with a ``detail`` list is rendered as up to four
      ``"field.path: message"`` entries joined by "; ".
    - Any other error is shown as its string form, capped at 240 characters.
    """
    if not isinstance(result, dict):
        return ""

    err = result.get("error")
    if err:
        if isinstance(err, dict) and isinstance(err.get("detail"), list):
            parts = []
            for issue in err["detail"][:4]:
                if not isinstance(issue, dict):
                    # Unexpected entry shape: show it verbatim rather than crash.
                    parts.append(str(issue))
                    continue
                # The first loc segment is skipped; presumably a request-part
                # prefix such as "body" (TODO confirm against mock_server).
                loc = ".".join(str(x) for x in (issue.get("loc") or [])[1:])
                msg = issue.get("msg", "invalid")
                parts.append(f"{loc}: {msg}" if loc else str(msg))
            return "; ".join(parts)
        return str(err)[:240]

    if name == "route_planner" and isinstance(result.get("stops"), list):
        n_stops = len(result["stops"])
        # `or 0` also covers keys that are present but null; round(None)
        # would raise TypeError.
        transfers = result.get("transfers") or 0
        minutes = round(result.get("estimated_minutes") or 0)
        return (
            f"{n_stops} stop{'' if n_stops == 1 else 's'} · "
            f"{transfers} transfer{'' if transfers == 1 else 's'} · "
            f"{minutes} min"
        )
    if name == "fare_calculator":
        return f"total {result.get('currency', '')} {result.get('total', '?')}"
    if name == "submit_assistant_state":
        return "accepted" if result.get("accepted") else "rejected"
    if name == "disruption_feed" and isinstance(result.get("disruptions"), list):
        return f"{len(result['disruptions'])} active"
    if name == "station_info" and isinstance(result.get("stations"), list):
        return f"{len(result['stations'])} station(s)"
    return ""
def _station_label(station_by_id: dict, sid, fallback: str = "?") -> str:
    """Display label for a station reference taken from tool args.

    Returns the station's truthy ``"name"`` when *sid* resolves to a dict in
    *station_by_id*. Otherwise returns ``str(sid)`` when *sid* is truthy, and
    *fallback* when it is not. Unhashable ids (e.g. a list emitted by the
    model) skip the lookup instead of raising TypeError.
    """
    try:
        entry = station_by_id.get(sid)
    except TypeError:
        entry = None
    name = entry.get("name") if isinstance(entry, dict) else None
    if name:
        return str(name)
    return str(sid) if sid else fallback


def _count(value) -> int:
    """Entry count for a list-valued arg; a lone non-list truthy value counts as one."""
    if isinstance(value, (list, tuple)):
        return len(value)
    return 1 if value else 0


def summarise_args(name: str, args: dict, bundle: dict) -> str:
    """One-line preview of a tool call's arguments for the tool card's "in" row.

    Station ids are resolved to names via ``bundle["station_by_id"]`` where
    possible. Field names mirror the Pydantic schemas in
    harness/mock_server.py. Paired with summarise_result to drive the card.

    Args:
        name: tool name the model called; unhandled names yield "—".
        args: the call's arguments as emitted by the model. A non-dict value
            is treated as an empty argument set.
        bundle: must contain ``"station_by_id"`` (station id -> station dict,
            whose optional ``"name"`` is used for display).

    Returns:
        Always a ``str``, because callers HTML-escape it. Returns "—" when
        there is nothing to show.

    Raises:
        KeyError: if *bundle* has no ``"station_by_id"`` entry.
    """
    sb = bundle["station_by_id"]
    if not isinstance(args, dict):
        args = {}

    if name == "route_planner":
        preview = (
            f"{_station_label(sb, args.get('origin'))} → "
            f"{_station_label(sb, args.get('destination'))}"
        )
        extras = []
        for key, noun in (
            ("station_restrictions", "restriction"),
            ("segment_closures", "closure"),
            ("line_closures", "line closure"),
        ):
            n = _count(args.get(key))
            if n:
                extras.append(f"{n} {noun}{'' if n == 1 else 's'}")
        if extras:
            preview += " · " + ", ".join(extras)
        return preview

    if name == "fare_calculator":
        passengers = args.get("passengers") or {}
        parts = []
        if isinstance(passengers, dict):
            # bool is an int subclass; True must not render as "True× adult".
            parts = [
                f"{count}× {kind}"
                for kind, count in passengers.items()
                if isinstance(count, int) and not isinstance(count, bool) and count > 0
            ]
        base = ", ".join(parts) or "—"
        route_id = args.get("route_id")
        return f"{base} · {route_id}" if route_id else base

    if name == "station_info":
        ids = args.get("station_ids")
        if isinstance(ids, list):
            if len(ids) == 1:
                return _station_label(sb, ids[0], "—")
            return f"{len(ids)} stations"
        return _station_label(sb, args.get("station_id"), "—")

    if name == "line_info":
        lines = args.get("lines")
        if isinstance(lines, list) and lines:
            shown = ", ".join(str(line) for line in lines[:4])
            return shown + ("…" if len(lines) > 4 else "")
        return str(args.get("line") or "—")

    if name == "disruption_feed":
        parts = []
        line = args.get("line")
        if line:
            parts.append(f"line {line}")
        station = args.get("station")
        if station:
            parts.append(f"@{_station_label(sb, station)}")
        severity = args.get("severity_filter")
        if severity and severity != "all":
            parts.append(str(severity))
        return " · ".join(parts) or "all disruptions"

    if name == "knowledge_base":
        if args.get("policy_id"):
            return f"policy: {args['policy_id']}"
        query = args.get("query")
        query = str(query).strip() if query else ""
        return (query[:80] + "…") if len(query) > 80 else (query or "—")

    if name == "submit_assistant_state":
        outcome = args.get("outcome") or "—"
        kiosk = args.get("kiosk_action") or {}
        action = kiosk.get("action") if isinstance(kiosk, dict) else kiosk
        return f"{outcome} / {action or '—'}"

    return "—"
def format_tool_card(name: str, args: dict, result, elapsed_ms: float,
                     status: str, glyph: str, bundle: dict, *,
                     max_raw_chars: int = 1200) -> str:
    """Render one tool round-trip as a compact HTML card.

    The card has three parts:

    - Header: glyph, tool name and elapsed time.
    - Body: "in"/"out" rows holding the one-line previews from
      summarise_args / summarise_result.
    - Footer: a collapsed ``<details>`` with the raw JSON of args and result,
      for power users.

    Every interpolated value is HTML-escaped. The card renders inside a
    chatbot bot bubble whose sanitize_html is disabled, so the
    ``<div>``/``<details>`` markup survives.

    Args:
        name: tool name the model called.
        args: call arguments; also dumped verbatim in the raw section.
        result: tool return value; dumped (possibly truncated) in the raw
            section.
        elapsed_ms: round-trip time, shown rounded to whole milliseconds.
        status: label appended to the ``metro-tool-`` CSS class.
        glyph: status glyph shown in the header.
        bundle: passed through to summarise_args for station-name lookup.
        max_raw_chars: cap on the raw result JSON length. A longer dump is
            cut so that it ends with "…" and stays within the cap.

    Returns:
        The card's HTML as a single string.
    """

    def esc(value) -> str:
        # str() first: html.escape only accepts str, and names/previews that
        # originate from model output are not guaranteed to be strings.
        return html.escape(str(value))

    args_preview = summarise_args(name, args, bundle)
    result_preview = summarise_result(name, result) or "—"
    # default=str on both dumps, so a non-JSON value degrades to its string
    # form instead of raising TypeError mid-render.
    args_json = json.dumps(args, indent=2, default=str, ensure_ascii=False)
    result_json = json.dumps(result, indent=2, default=str, ensure_ascii=False)
    if len(result_json) > max_raw_chars:
        result_json = result_json[:max(max_raw_chars - 1, 0)] + "…"

    return (
        f'<div class="metro-tool-card metro-tool-{esc(status)}">'
        f'<div class="metro-tool-head">'
        f'<span class="metro-tool-glyph">{esc(glyph)}</span>'
        f'<span class="metro-tool-name">{esc(name)}</span>'
        f'<span class="metro-tool-time">{elapsed_ms:.0f}ms</span>'
        f'</div>'
        f'<div class="metro-tool-row">'
        f'<span class="metro-tool-key">in</span>'
        f'<span class="metro-tool-val">{esc(args_preview)}</span>'
        f'</div>'
        f'<div class="metro-tool-row">'
        f'<span class="metro-tool-key">out</span>'
        f'<span class="metro-tool-val">{esc(result_preview)}</span>'
        f'</div>'
        f'<details class="metro-tool-raw">'
        f'<summary>raw</summary>'
        f'<div class="metro-tool-raw-label">args</div>'
        f'<pre>{esc(args_json)}</pre>'
        f'<div class="metro-tool-raw-label">result</div>'
        f'<pre>{esc(result_json)}</pre>'
        f'</details>'
        f'</div>'
    )
# --- route geometry --------------------------------------------------------
def route_with_coords(stops, station_by_id: dict) -> list:
    """Annotate route_planner stops with map data for the JS layer.

    Each dict in *stops* becomes a record with the keys ``"station_id"``,
    ``"name"``, ``"lat"``, ``"lon"``, ``"line"`` and ``"is_transfer"``.

    - The stop's line and transfer flag are carried through, so the JS can
      draw per-line segments and mark transfer stations (matching the
      /simulator dashboard's drawRoute behaviour).
    - lat/lon come from *station_by_id*, as does the name when the stop has
      no ``station_name``.

    Non-dict entries are skipped and a falsy *stops* yields ``[]``. Stations
    missing from *station_by_id*, or mapped to a falsy value, keep ``None``
    coordinates.
    """
    annotated = []
    for stop in stops or []:
        if not isinstance(stop, dict):
            continue
        sid = stop.get("station_id")
        # `or {}` rather than a .get() default, so an id mapped to None also
        # falls back instead of crashing on geo.get().
        geo = (station_by_id.get(sid) if sid else None) or {}
        annotated.append({
            "station_id": sid,
            "name": stop.get("station_name") or geo.get("name"),
            "lat": geo.get("lat"),
            "lon": geo.get("lon"),
            "line": stop.get("line"),
            "is_transfer": bool(stop.get("is_transfer")),
        })
    return annotated