File size: 9,052 Bytes
8b41737 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 | import json
import hashlib
import os
from flask import Blueprint, request, jsonify
from datasets import load_dataset, Dataset
# Flask blueprint for the arena dataset API; the routes below are registered
# on it and served under /api/arena/datasets.
bp = Blueprint("arena_datasets", __name__, url_prefix="/api/arena/datasets")
# In-memory cache: dataset id (from _make_id) -> dict holding the loaded
# Dataset plus precomputed metadata (repo, split, n_rows, env_ids,
# episode_groups, model_name, stats). Entries are written by the /load
# endpoint and removed one id at a time by the DELETE endpoint; nothing else
# in this module evicts them, so every distinct repo/split loaded stays in
# memory until explicitly unloaded.
_cache: dict[str, dict] = {}
def _make_id(repo: str, split: str) -> str:
key = f"{repo}:{split}"
return hashlib.md5(key.encode()).hexdigest()[:12]
def _load_hf_dataset(repo: str, split: str) -> Dataset:
    """Load *repo* as a single ``Dataset``.

    If *repo* names an existing filesystem path it is read with
    ``Dataset.from_parquet`` and *split* is ignored. Otherwise *repo* is
    passed to ``load_dataset`` together with the requested *split*.
    """
    # NOTE(review): *repo* arrives straight from the /load request body, so
    # any existing local path can be opened here — confirm this API is only
    # exposed to trusted users.
    is_local_path = os.path.exists(repo)
    if not is_local_path:
        return load_dataset(repo, split=split)
    return Dataset.from_parquet(repo)
def _detect_arena_dataset(columns: list[str]) -> bool:
"""Check if this looks like an arena evaluation dataset."""
required = {"game_id", "env_id", "transcript"}
return required.issubset(set(columns))
def _analyze_action(text: str) -> dict:
"""Split <think> tags from action text and compute analytics."""
if not text:
return {"think_text": "", "action_text": "", "think_len": 0, "action_len": 0,
"backtracks": 0, "restarts": 0}
think_end = text.find("</think>")
if think_end > 0:
think_text = text[:think_end + 8]
action_text = text[think_end + 8:].strip()
else:
think_text = ""
action_text = text
t = text.lower()
backtracks = sum(t.count(w) for w in
["wait,", "wait ", "hmm", "let me try", "try again",
"another approach", "let me reconsider"])
restarts = sum(t.count(w) for w in
["start over", "fresh approach", "different approach", "from scratch"])
return {
"think_text": think_text,
"action_text": action_text,
"think_len": len(think_text),
"action_len": len(action_text),
"backtracks": backtracks,
"restarts": restarts,
}
def _dedup_observation(text: str, prev_text: str) -> str:
"""Remove content duplicated from the previous observation.
TextArena accumulates the full chat history in each observation,
so turn N's observation repeats everything from turns 0..N-1
plus echoed [Player] actions. We strip the repeated prefix and
the echoed player actions, keeping only new [GAME]/[Moderator]
content for this turn.
"""
import re
if not text:
return ""
if not prev_text:
return text
new_part = None
# The previous observation text should appear as a prefix of the
# current one. Strip it to get only what's new.
if text.startswith(prev_text):
new_part = text[len(prev_text):].strip()
else:
# Fallback: find the longest common prefix
min_len = min(len(text), len(prev_text))
common = 0
for i in range(min_len):
if text[i] == prev_text[i]:
common = i + 1
else:
break
if common > len(prev_text) * 0.8:
new_part = text[common:].strip()
if not new_part:
return text
# After stripping the observation prefix, the remaining text typically
# starts with echoed [Player] actions (already shown in action bubbles),
# followed by new [GAME] or [Moderator] content. Strip the echoed
# player actions to keep only the new game content.
game_marker = re.search(r'\[GAME\]|\[Moderator\]', new_part)
if game_marker:
game_content = new_part[game_marker.start():].strip()
return game_content if game_content else new_part
return new_part
def _get_env_ids(ds: Dataset) -> list[str]:
    """Return each distinct ``env_id`` in *ds* once, in ascending order."""
    distinct = {env_id for env_id in ds["env_id"]}
    return sorted(distinct)
def _group_episodes_by_env(ds: Dataset) -> dict[str, list[int]]:
    """Map each ``env_id`` to the row indices of its episodes.

    Keys appear in first-occurrence order and each index list is ascending,
    so position ``k`` in a list is the ``k``-th episode of that env in
    dataset order.
    """
    groups: dict[str, list[int]] = {}
    # Read the env_id column once instead of indexing ds[i] per row. Row
    # indexing returns every column of the row (including the transcript)
    # just to look at one field.
    for row_idx, env_id in enumerate(ds["env_id"]):
        groups.setdefault(env_id, []).append(row_idx)
    return groups
def _count_outcomes(ds: Dataset, columns: list[str]) -> dict[str, int]:
    """Tally win/loss/error outcomes over every row of *ds*.

    Each row is classified the same way ``get_episode`` classifies a single
    episode:
      * a truthy ``error`` makes the row an error;
      * otherwise a non-None ``reward`` > 0 is a win;
      * any other non-None reward is a loss.

    Rows with neither a truthy error nor a reward are not counted. A missing
    ``reward`` or ``error`` column is treated as all-None, because only
    game_id/env_id/transcript are required.
    """
    n_rows = len(ds)
    # Whole-column reads: ds[i] would fetch every column of row i (including
    # the transcript) just to inspect one field.
    rewards = ds["reward"] if "reward" in columns else [None] * n_rows
    errors = ds["error"] if "error" in columns else [None] * n_rows
    counts = {"wins": 0, "losses": 0, "errors": 0}
    for reward, error in zip(rewards, errors):
        if error:
            counts["errors"] += 1
        elif reward is None:
            continue
        elif reward > 0:
            counts["wins"] += 1
        else:
            counts["losses"] += 1
    return counts


@bp.route("/load", methods=["POST"])
def load_dataset_endpoint():
    """Load an arena dataset and register it in the in-memory cache.

    JSON body:
        repo:  Hugging Face Hub repo id, or a path to a local parquet file
               (required).
        split: split to load (default ``"train"``; ignored for local paths).

    Returns:
        200 with the dataset id and summary metadata: columns, row count,
        env ids, episodes per env, model name and win/loss/error stats.
        400 if ``repo`` is missing, loading fails, or the dataset lacks the
        required ``game_id``/``env_id``/``transcript`` columns.
    """
    # silent=True: a missing or malformed JSON body yields None instead of
    # raising. That case, like a non-object body, is reported as a 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    # `or ""` also covers an explicit JSON null, which would fail on .strip().
    repo = str(data.get("repo") or "").strip()
    if not repo:
        return jsonify({"error": "repo is required"}), 400
    split = data.get("split") or "train"

    try:
        ds = _load_hf_dataset(repo, split)
    except Exception as e:  # API boundary: report any loader failure as a 400
        return jsonify({"error": f"Failed to load dataset: {e}"}), 400

    columns = ds.column_names
    if not _detect_arena_dataset(columns):
        return jsonify({
            "error": f"Not an arena dataset. Expected columns: game_id, env_id, transcript. Found: {columns}"
        }), 400

    env_ids = _get_env_ids(ds)
    episode_groups = _group_episodes_by_env(ds)
    ds_id = _make_id(repo, split)
    # rsplit alone already returns the whole string when there is no "/".
    short_name = repo.rsplit("/", 1)[-1]

    # The model name is read from the first row only.
    model_name = ds[0].get("model", "unknown") if len(ds) > 0 else "unknown"
    stats = _count_outcomes(ds, columns)

    _cache[ds_id] = {
        "dataset": ds,
        "repo": repo,
        "split": split,
        "n_rows": len(ds),
        "env_ids": env_ids,
        "episode_groups": episode_groups,
        "model_name": model_name,
        "stats": stats,
    }

    return jsonify({
        "id": ds_id,
        "repo": repo,
        "name": short_name,
        "split": split,
        "columns": columns,
        "n_rows": len(ds),
        "env_ids": env_ids,
        "episodes_per_env": {env: len(idxs) for env, idxs in episode_groups.items()},
        "model_name": model_name,
        "stats": stats,
    })
@bp.route("/", methods=["GET"])
def list_datasets():
    """List every dataset currently held in the in-memory cache.

    Responds with a JSON array containing one summary per cached dataset
    (id, repo, short name, split, row count, env ids, model name), in cache
    insertion order.
    """
    def summarize(cache_key, entry):
        repo_name = entry["repo"]
        return {
            "id": cache_key,
            "repo": repo_name,
            "name": repo_name.rsplit("/", 1)[-1] if "/" in repo_name else repo_name,
            "split": entry["split"],
            "n_rows": entry["n_rows"],
            "env_ids": entry["env_ids"],
            "model_name": entry["model_name"],
        }

    summaries = [summarize(cache_key, entry) for cache_key, entry in _cache.items()]
    return jsonify(summaries)
def _parse_transcript(raw) -> list[dict]:
    """Decode a row's ``transcript`` value into a list of turn dicts.

    Accepts a JSON string or an already-decoded value. Anything that does not
    yield a list (malformed JSON, null, a bare object) becomes ``[]``.
    Non-dict entries are dropped so per-turn ``.get`` calls are safe.
    """
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except json.JSONDecodeError:
            return []
    if not isinstance(raw, list):
        return []
    return [turn for turn in raw if isinstance(turn, dict)]


def _analyze_turns(transcript: list[dict]) -> list[dict]:
    """Build the per-turn payload: de-duplicated observation plus the
    think/action split and reasoning analytics of each turn's action."""
    analyzed = []
    prev_obs_raw = ""
    for turn in transcript:
        action = turn.get("action", "")
        obs_raw = turn.get("observation", "")
        analysis = _analyze_action(action)
        analyzed.append({
            "turn": turn.get("turn", 0),
            "player_id": turn.get("player_id", 0),
            # Compare against the previous *raw* observation: that is what
            # the current raw observation repeats as its prefix.
            "observation": _dedup_observation(obs_raw, prev_obs_raw),
            "action": action,
            "think_text": analysis["think_text"],
            "action_text": analysis["action_text"],
            "think_len": analysis["think_len"],
            "backtracks": analysis["backtracks"],
        })
        prev_obs_raw = obs_raw
    return analyzed


def _episode_outcome(reward, error) -> str:
    """Classify an episode as "error", "win", "loss" or "unknown".

    A truthy *error* wins over any reward; otherwise reward > 0 is a win,
    any other non-None reward is a loss, and a None reward is unknown.
    """
    if error:
        return "error"
    if reward is None:
        return "unknown"
    return "win" if reward > 0 else "loss"


@bp.route("/<ds_id>/episode/<env_id>/<int:idx>", methods=["GET"])
def get_episode(ds_id, env_id, idx):
    """Get a single episode by env_id and episode index within that env.

    Returns 404 if the dataset is not loaded or *env_id* is unknown, and 400
    if *idx* is outside the env's episode range. Otherwise returns the row's
    metadata, an ``outcome`` label, and the analyzed transcript. A missing or
    malformed transcript is returned as an empty list rather than failing.
    """
    info = _cache.get(ds_id)
    if info is None:
        return jsonify({"error": "Dataset not loaded"}), 404

    indices = info["episode_groups"].get(env_id)
    if indices is None:
        return jsonify({"error": f"env_id '{env_id}' not found"}), 404
    if idx < 0 or idx >= len(indices):
        return jsonify({"error": f"Episode index {idx} out of range (0-{len(indices)-1})"}), 400

    row = info["dataset"][indices[idx]]
    transcript = _parse_transcript(row.get("transcript", "[]"))
    reward = row.get("reward")
    error = row.get("error")

    return jsonify({
        "game_id": row.get("game_id", ""),
        "env_id": row.get("env_id", ""),
        "model": row.get("model", ""),
        "opponent_model": row.get("opponent_model"),
        "player_id": row.get("player_id", 0),
        "reward": reward,
        "num_turns": row.get("num_turns", len(transcript)),
        "error": error,
        "outcome": _episode_outcome(reward, error),
        "transcript": _analyze_turns(transcript),
        "system_prompt": row.get("system_prompt", None),
        "episode_idx": idx,
        "total_episodes": len(indices),
    })
@bp.route("/<ds_id>", methods=["DELETE"])
def unload_dataset(ds_id):
    """Remove *ds_id* from the in-memory cache.

    Always responds ``{"status": "ok"}``, whether or not the id was loaded.
    """
    _cache.pop(ds_id, None)
    return jsonify({"status": "ok"})
|