#!/usr/bin/env python3
"""Streamlit viewer for eval-agent run directories.

Lists the files of a run directory (``gen_<N>_<suffix>`` files are grouped by
generation N) and renders the selected file:

* ``*_trajectory_messages.json`` -> summary table plus a per-message timeline,
* any other ``*.json``           -> JSON tree (raw text if it fails to parse),
* anything else                  -> plain text.

Launch with ``streamlit run <this file> -- --dir <run directory>``, or type the
directory into the sidebar.
"""
import argparse
import html
import json
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import streamlit as st

ROLE_STYLE = {
    "system": {"label": "SYSTEM", "color": "#4B5563", "bg": "#F3F4F6"},
    "user": {"label": "USER", "color": "#1D4ED8", "bg": "#DBEAFE"},
    "assistant": {"label": "ASSISTANT", "color": "#065F46", "bg": "#D1FAE5"},
    "tool": {"label": "TOOL", "color": "#7C2D12", "bg": "#FFEDD5"},
}

# Colors used for roles that have no ROLE_STYLE entry.
_FALLBACK_COLOR = "#111827"
_FALLBACK_BG = "#F9FAFB"

# Maximum number of characters shown in the overview table's preview column.
PREVIEW_CHARS = 120

# Run files are named gen_<generation>_<suffix>.
_GEN_FILE_RE = re.compile(r"gen_(\d+)_(.*)$")
# Display order of the known suffixes within one generation; others sort after.
_SUFFIX_ORDER = {
    "task_message.txt": 0,
    "result.json": 1,
    "trajectory_messages.json": 2,
}


def parse_args() -> argparse.Namespace:
    """Parse the command line; ``--dir`` pre-fills the sidebar's run directory."""
    parser = argparse.ArgumentParser(description="Streamlit viewer for eval-agent trajectory files.")
    parser.add_argument("--dir", type=str, default="", help="Directory containing agent run files.")
    return parser.parse_args()


def file_sort_key(path: Path) -> Tuple[int, int, str]:
    """Sort key: generation number, then known-suffix order, then file name.

    Files not matching ``gen_<N>_<suffix>`` sort after all matching files,
    alphabetically among themselves.
    """
    m = _GEN_FILE_RE.match(path.name)
    if not m:
        return (10**9, 10**9, path.name)
    return (int(m.group(1)), _SUFFIX_ORDER.get(m.group(2), 99), path.name)


def try_load_json(path: Path) -> Optional[Any]:
    """Parse ``path`` as UTF-8 JSON.

    Returns the parsed value, or None when the file cannot be read, decoded or
    parsed. Note that a file containing JSON ``null`` also yields None.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
        # RecursionError covers pathologically deep nesting.
        return None


def extract_text_from_message(message: Dict[str, Any]) -> str:
    """Return a message's text content, stripped of surrounding whitespace.

    ``content`` may be a plain string or a list of content parts. For a list,
    the non-empty ``text`` of every ``{"type": "text"}`` part is joined with
    newlines. Any other shape yields an empty string.
    """
    content = message.get("content")
    if isinstance(content, str):
        return content.strip()
    if not isinstance(content, list):
        return ""
    text_parts: List[str] = []
    for item in content:
        if isinstance(item, dict) and item.get("type") == "text":
            text = item.get("text")
            if isinstance(text, str) and text:
                text_parts.append(text)
    return "\n".join(text_parts).strip()


def _tool_call_count(message: Dict[str, Any]) -> int:
    """Return the length of ``message["tool_calls"]``, or 0 if it is not a list."""
    tool_calls = message.get("tool_calls")
    return len(tool_calls) if isinstance(tool_calls, list) else 0


def _preview(text: str, limit: int = PREVIEW_CHARS) -> str:
    """Return ``text`` flattened to one line and cut to ``limit`` chars (+ "..." if cut)."""
    flat = " ".join(text.split())
    return flat[:limit] + ("..." if len(flat) > limit else "")


def trajectory_summary(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Build one overview-table row per message.

    Row keys: ``idx`` (position in ``messages``), ``role``, ``tool_calls``
    (length of the message's ``tool_calls`` list), ``chars`` (length of the
    extracted text) and ``preview`` (shortened one-line text).
    """
    rows: List[Dict[str, Any]] = []
    for idx, msg in enumerate(messages):
        text = extract_text_from_message(msg)
        rows.append(
            {
                "idx": idx,
                "role": msg.get("role", "unknown"),
                "tool_calls": _tool_call_count(msg),
                "chars": len(text),
                "preview": _preview(text),
            }
        )
    return rows


def _style_for_role(role: str) -> Dict[str, str]:
    """Return the ROLE_STYLE entry for ``role``, or a neutral style labelled ``role.upper()``."""
    return ROLE_STYLE.get(role, {"label": role.upper(), "color": _FALLBACK_COLOR, "bg": _FALLBACK_BG})


def _header_html(title: str, style: Dict[str, str]) -> str:
    """Build the colored header bar drawn above each timeline message."""
    # The title may contain an arbitrary role name read from the file: escape it.
    return (
        f"<div style='margin-top:16px;padding:6px 10px;border-radius:6px;"
        f"border-left:5px solid {style['color']};background:{style['bg']};'>"
        f"<span style='font-weight:700;color:{style['color']};'>{html.escape(title)}</span>"
        "</div>"
    )


def _body_html(text: str, style: Dict[str, str]) -> str:
    """Build the HTML box that shows a message's text verbatim."""
    # The text comes from the trajectory file (model output, tool results), so it
    # is escaped before being passed to st.markdown with unsafe_allow_html=True.
    # Newlines become <br> so the snippet has no blank lines and stays a single
    # HTML block for the markdown renderer; pre-wrap keeps runs of spaces.
    normalized = text.replace("\r\n", "\n").replace("\r", "\n")
    safe = html.escape(normalized).replace("\n", "<br>")
    return (
        f"<div style='padding:8px 12px;border-left:5px solid {style['bg']};"
        "white-space:pre-wrap;word-break:break-word;'>"
        f"{safe}</div>"
    )


def render_trajectory(messages: List[Dict[str, Any]]) -> None:
    """Render an overview table and a per-message timeline for ``messages``."""
    st.subheader("Trajectory Overview")
    st.dataframe(trajectory_summary(messages), width="stretch")

    st.subheader("Full Message Timeline")
    show_raw = st.checkbox("Show raw dict under each message", value=False)

    for idx, msg in enumerate(messages):
        role = str(msg.get("role", "unknown"))
        style = _style_for_role(role)
        text = extract_text_from_message(msg)
        tool_call_count = _tool_call_count(msg)

        title = f"{style['label']} #{idx}"
        if tool_call_count > 0:
            title += f" | tool_calls={tool_call_count}"
        st.markdown(_header_html(title, style), unsafe_allow_html=True)

        show_msg = st.toggle(f"Show message #{idx}", value=True, key=f"show_msg_{idx}")
        if not show_msg:
            continue
        if text:
            st.markdown(_body_html(text, style), unsafe_allow_html=True)
        else:
            st.caption("(no text content)")
        if show_raw:
            st.json(msg)


def _read_text(path: Path) -> str:
    """Read ``path`` as UTF-8, replacing undecodable bytes."""
    return path.read_text(encoding="utf-8", errors="replace")


def _render_trajectory_file(path: Path) -> None:
    """Load a trajectory message list from ``path`` and render it."""
    data = try_load_json(path)
    if not isinstance(data, list):
        st.error("Trajectory file is not a JSON list.")
        return
    # Entries that are not dicts are skipped rather than failing the whole view.
    msg_list = [x for x in data if isinstance(x, dict)]
    st.success(f"Loaded {len(msg_list)} message dicts.")
    render_trajectory(msg_list)


def _render_json_file(path: Path) -> None:
    """Render a JSON file as a tree, or as raw text when it does not parse."""
    data = try_load_json(path)
    if data is None:
        st.error("Failed to parse JSON.")
        st.code(_read_text(path), language="json")
    else:
        st.json(data)


def main() -> None:
    """Streamlit entry point: pick a run directory and file, then render the file."""
    args = parse_args()
    st.set_page_config(page_title="Eval Agent Trajectory Viewer", layout="wide")
    st.title("Eval Agent Trajectory Viewer")

    run_dir_input = st.sidebar.text_input("Run directory", value=args.dir or "")
    if not run_dir_input:
        st.info("Pass `--dir` or set the directory in the sidebar.")
        return

    run_dir = Path(run_dir_input).expanduser()
    if not run_dir.is_dir():
        st.error(f"Directory not found: {run_dir_input}")
        return

    try:
        files = sorted((p for p in run_dir.iterdir() if p.is_file()), key=file_sort_key)
    except OSError as exc:
        # e.g. permission denied while listing the directory
        st.error(f"Cannot list directory {run_dir_input}: {exc}")
        return
    if not files:
        st.warning("No files found in this directory.")
        return

    selected_name = st.sidebar.selectbox("Select file", options=[p.name for p in files], index=0)
    selected_path = run_dir / selected_name
    st.caption(f"Selected: `{selected_path}`")
    st.caption(f"Size: {selected_path.stat().st_size:,} bytes")

    if selected_name.endswith("_trajectory_messages.json"):
        _render_trajectory_file(selected_path)
    elif selected_name.endswith(".json"):
        _render_json_file(selected_path)
    else:
        st.code(_read_text(selected_path), language="text")


if __name__ == "__main__":
    main()