agg-trace-visualizer / backend/api/arena_datasets.py
Zayne Rea Sprague
Initial deploy: aggregate trace visualizer
8b41737
import json
import hashlib
import os
from flask import Blueprint, request, jsonify
from datasets import load_dataset, Dataset
bp = Blueprint("arena_datasets", __name__, url_prefix="/api/arena/datasets")
# In-memory cache: id -> dataset info
_cache: dict[str, dict] = {}
def _make_id(repo: str, split: str) -> str:
key = f"{repo}:{split}"
return hashlib.md5(key.encode()).hexdigest()[:12]
def _load_hf_dataset(repo: str, split: str) -> Dataset:
    """Load a dataset: a local parquet path if it exists, else the HF Hub."""
    if not os.path.exists(repo):
        return load_dataset(repo, split=split)
    # `repo` is actually a path on disk; `split` is ignored for parquet files.
    return Dataset.from_parquet(repo)
def _detect_arena_dataset(columns: list[str]) -> bool:
"""Check if this looks like an arena evaluation dataset."""
required = {"game_id", "env_id", "transcript"}
return required.issubset(set(columns))
def _analyze_action(text: str) -> dict:
"""Split <think> tags from action text and compute analytics."""
if not text:
return {"think_text": "", "action_text": "", "think_len": 0, "action_len": 0,
"backtracks": 0, "restarts": 0}
think_end = text.find("</think>")
if think_end > 0:
think_text = text[:think_end + 8]
action_text = text[think_end + 8:].strip()
else:
think_text = ""
action_text = text
t = text.lower()
backtracks = sum(t.count(w) for w in
["wait,", "wait ", "hmm", "let me try", "try again",
"another approach", "let me reconsider"])
restarts = sum(t.count(w) for w in
["start over", "fresh approach", "different approach", "from scratch"])
return {
"think_text": think_text,
"action_text": action_text,
"think_len": len(think_text),
"action_len": len(action_text),
"backtracks": backtracks,
"restarts": restarts,
}
def _dedup_observation(text: str, prev_text: str) -> str:
"""Remove content duplicated from the previous observation.
TextArena accumulates the full chat history in each observation,
so turn N's observation repeats everything from turns 0..N-1
plus echoed [Player] actions. We strip the repeated prefix and
the echoed player actions, keeping only new [GAME]/[Moderator]
content for this turn.
"""
import re
if not text:
return ""
if not prev_text:
return text
new_part = None
# The previous observation text should appear as a prefix of the
# current one. Strip it to get only what's new.
if text.startswith(prev_text):
new_part = text[len(prev_text):].strip()
else:
# Fallback: find the longest common prefix
min_len = min(len(text), len(prev_text))
common = 0
for i in range(min_len):
if text[i] == prev_text[i]:
common = i + 1
else:
break
if common > len(prev_text) * 0.8:
new_part = text[common:].strip()
if not new_part:
return text
# After stripping the observation prefix, the remaining text typically
# starts with echoed [Player] actions (already shown in action bubbles),
# followed by new [GAME] or [Moderator] content. Strip the echoed
# player actions to keep only the new game content.
game_marker = re.search(r'\[GAME\]|\[Moderator\]', new_part)
if game_marker:
game_content = new_part[game_marker.start():].strip()
return game_content if game_content else new_part
return new_part
def _get_env_ids(ds: Dataset) -> list[str]:
"""Get sorted unique env_ids from dataset."""
return sorted(set(ds["env_id"]))
def _group_episodes_by_env(ds: Dataset) -> dict[str, list[int]]:
"""Group row indices by env_id."""
groups: dict[str, list[int]] = {}
for i in range(len(ds)):
env_id = ds[i]["env_id"]
if env_id not in groups:
groups[env_id] = []
groups[env_id].append(i)
return groups
@bp.route("/load", methods=["POST"])
def load_dataset_endpoint():
data = request.get_json()
repo = data.get("repo", "").strip()
if not repo:
return jsonify({"error": "repo is required"}), 400
split = data.get("split", "train")
try:
ds = _load_hf_dataset(repo, split)
except Exception as e:
return jsonify({"error": f"Failed to load dataset: {e}"}), 400
columns = ds.column_names
if not _detect_arena_dataset(columns):
return jsonify({
"error": f"Not an arena dataset. Expected columns: game_id, env_id, transcript. Found: {columns}"
}), 400
env_ids = _get_env_ids(ds)
episode_groups = _group_episodes_by_env(ds)
ds_id = _make_id(repo, split)
short_name = repo.rsplit("/", 1)[-1] if "/" in repo else repo
# Extract model name from first row
model_name = ds[0].get("model", "unknown") if len(ds) > 0 else "unknown"
# Compute win/loss/error counts
wins = sum(1 for r in ds["reward"] if r is not None and r > 0)
losses = sum(1 for i in range(len(ds)) if ds[i]["reward"] is not None and ds[i]["reward"] <= 0)
errors = sum(1 for e in ds["error"] if e is not None)
_cache[ds_id] = {
"dataset": ds,
"repo": repo,
"split": split,
"n_rows": len(ds),
"env_ids": env_ids,
"episode_groups": episode_groups,
"model_name": model_name,
"stats": {"wins": wins, "losses": losses, "errors": errors},
}
return jsonify({
"id": ds_id,
"repo": repo,
"name": short_name,
"split": split,
"columns": columns,
"n_rows": len(ds),
"env_ids": env_ids,
"episodes_per_env": {env: len(idxs) for env, idxs in episode_groups.items()},
"model_name": model_name,
"stats": {"wins": wins, "losses": losses, "errors": errors},
})
@bp.route("/", methods=["GET"])
def list_datasets():
result = []
for ds_id, info in _cache.items():
result.append({
"id": ds_id,
"repo": info["repo"],
"name": info["repo"].rsplit("/", 1)[-1] if "/" in info["repo"] else info["repo"],
"split": info["split"],
"n_rows": info["n_rows"],
"env_ids": info["env_ids"],
"model_name": info["model_name"],
})
return jsonify(result)
@bp.route("/<ds_id>/episode/<env_id>/<int:idx>", methods=["GET"])
def get_episode(ds_id, env_id, idx):
"""Get a single episode by env_id and episode index within that env."""
if ds_id not in _cache:
return jsonify({"error": "Dataset not loaded"}), 404
info = _cache[ds_id]
ds = info["dataset"]
episode_groups = info["episode_groups"]
if env_id not in episode_groups:
return jsonify({"error": f"env_id '{env_id}' not found"}), 404
indices = episode_groups[env_id]
if idx < 0 or idx >= len(indices):
return jsonify({"error": f"Episode index {idx} out of range (0-{len(indices)-1})"}), 400
row_idx = indices[idx]
row = ds[row_idx]
# Parse transcript JSON
transcript_raw = row.get("transcript", "[]")
try:
transcript = json.loads(transcript_raw) if isinstance(transcript_raw, str) else transcript_raw
except json.JSONDecodeError:
transcript = []
# Analyze each turn: dedup observations, split think tags from actions
analyzed_turns = []
prev_obs_raw = ""
for turn in transcript:
action_analysis = _analyze_action(turn.get("action", ""))
obs_raw = turn.get("observation", "")
obs_deduped = _dedup_observation(obs_raw, prev_obs_raw)
prev_obs_raw = obs_raw
analyzed_turns.append({
"turn": turn.get("turn", 0),
"player_id": turn.get("player_id", 0),
"observation": obs_deduped,
"action": turn.get("action", ""),
"think_text": action_analysis["think_text"],
"action_text": action_analysis["action_text"],
"think_len": action_analysis["think_len"],
"backtracks": action_analysis["backtracks"],
})
# Determine outcome
reward = row.get("reward")
error = row.get("error")
if error:
outcome = "error"
elif reward is not None and reward > 0:
outcome = "win"
elif reward is not None:
outcome = "loss"
else:
outcome = "unknown"
return jsonify({
"game_id": row.get("game_id", ""),
"env_id": row.get("env_id", ""),
"model": row.get("model", ""),
"opponent_model": row.get("opponent_model"),
"player_id": row.get("player_id", 0),
"reward": reward,
"num_turns": row.get("num_turns", len(transcript)),
"error": error,
"outcome": outcome,
"transcript": analyzed_turns,
"system_prompt": row.get("system_prompt", None),
"episode_idx": idx,
"total_episodes": len(indices),
})
@bp.route("/<ds_id>", methods=["DELETE"])
def unload_dataset(ds_id):
if ds_id in _cache:
del _cache[ds_id]
return jsonify({"status": "ok"})