#!/usr/bin/env python3 import argparse import json import re from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import streamlit as st ROLE_STYLE = { "system": {"label": "SYSTEM", "color": "#4B5563", "bg": "#F3F4F6"}, "user": {"label": "USER", "color": "#1D4ED8", "bg": "#DBEAFE"}, "assistant": {"label": "ASSISTANT", "color": "#065F46", "bg": "#D1FAE5"}, "tool": {"label": "TOOL", "color": "#7C2D12", "bg": "#FFEDD5"}, } def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Streamlit viewer for eval-agent trajectory files.") parser.add_argument("--dir", type=str, default="", help="Directory containing agent run files.") return parser.parse_args() def file_sort_key(path: Path) -> Tuple[int, int, str]: m = re.match(r"gen_(\d+)_(.*)$", path.name) if not m: return (10**9, 10**9, path.name) gen = int(m.group(1)) suffix = m.group(2) order = { "task_message.txt": 0, "result.json": 1, "trajectory_messages.json": 2, }.get(suffix, 99) return (gen, order, path.name) def try_load_json(path: Path) -> Optional[Any]: try: with open(path, "r", encoding="utf-8") as f: return json.load(f) except Exception: return None def extract_text_from_message(message: Dict[str, Any]) -> str: text_parts: List[str] = [] content = message.get("content") if isinstance(content, list): for item in content: if isinstance(item, dict) and item.get("type") == "text": text = item.get("text") if isinstance(text, str) and text: text_parts.append(text) return "\n".join(text_parts).strip() def trajectory_summary(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: rows: List[Dict[str, Any]] = [] for idx, msg in enumerate(messages): role = msg.get("role", "unknown") text = extract_text_from_message(msg) preview = text[:120] + ("..." if len(text) > 120 else "") tool_calls = msg.get("tool_calls") tool_call_count = len(tool_calls) if isinstance(tool_calls, list) else 0 rows.append( { "idx": idx, "role": role, "tool_calls": tool_call_count, "chars": len(text), "preview": preview, } ) return rows def render_trajectory(messages: List[Dict[str, Any]]): st.subheader("Trajectory Overview") rows = trajectory_summary(messages) st.dataframe(rows, width="stretch") st.subheader("Full Message Timeline") show_raw = st.checkbox("Show raw dict under each message", value=False) for idx, msg in enumerate(messages): role = str(msg.get("role", "unknown")) style = ROLE_STYLE.get(role, {"label": role.upper(), "color": "#111827", "bg": "#F9FAFB"}) text = extract_text_from_message(msg) tool_calls = msg.get("tool_calls") tool_call_count = len(tool_calls) if isinstance(tool_calls, list) else 0 title = f"{style['label']} #{idx}" if tool_call_count > 0: title += f" | tool_calls={tool_call_count}" st.markdown( ( f"