"""Flask blueprint for browsing arena-evaluation datasets.

Loads HuggingFace (or local parquet) datasets that contain game transcripts,
caches them in memory, and serves per-episode views with deduplicated
observations and think/action analytics.
"""
import hashlib
import json
import os
import re

from flask import Blueprint, request, jsonify
from datasets import load_dataset, Dataset

bp = Blueprint("arena_datasets", __name__, url_prefix="/api/arena/datasets")

# In-memory cache: id -> dataset info (dataset handle, groupings, stats).
_cache: dict[str, dict] = {}

# Closing tag that separates chain-of-thought from the action text.
# NOTE(review): the original tag was stripped by HTML sanitization; the
# `+ 8` offset in the original code matches len("</think>") exactly.
_THINK_CLOSE = "</think>"


def _make_id(repo: str, split: str) -> str:
    """Return a short, stable identifier for a (repo, split) pair."""
    key = f"{repo}:{split}"
    return hashlib.md5(key.encode()).hexdigest()[:12]


def _load_hf_dataset(repo: str, split: str) -> Dataset:
    """Load `repo` as a local parquet file if the path exists, else from the HF hub."""
    if os.path.exists(repo):
        return Dataset.from_parquet(repo)
    return load_dataset(repo, split=split)


def _detect_arena_dataset(columns: list[str]) -> bool:
    """Check if this looks like an arena evaluation dataset."""
    required = {"game_id", "env_id", "transcript"}
    return required.issubset(set(columns))


def _analyze_action(text: str) -> dict:
    """Split the </think> tag from action text and compute analytics.

    Returns a dict with the think/action halves, their lengths, and counts
    of backtracking / restart phrases found anywhere in the text.
    """
    if not text:
        return {"think_text": "", "action_text": "", "think_len": 0,
                "action_len": 0, "backtracks": 0, "restarts": 0}
    think_end = text.find(_THINK_CLOSE)
    if think_end > 0:
        # Keep the closing tag with the think half; action is what follows.
        cut = think_end + len(_THINK_CLOSE)
        think_text = text[:cut]
        action_text = text[cut:].strip()
    else:
        think_text = ""
        action_text = text
    t = text.lower()
    backtracks = sum(t.count(w) for w in [
        "wait,", "wait ", "hmm", "let me try", "try again",
        "another approach", "let me reconsider",
    ])
    restarts = sum(t.count(w) for w in [
        "start over", "fresh approach", "different approach", "from scratch",
    ])
    return {
        "think_text": think_text,
        "action_text": action_text,
        "think_len": len(think_text),
        "action_len": len(action_text),
        "backtracks": backtracks,
        "restarts": restarts,
    }


def _dedup_observation(text: str, prev_text: str) -> str:
    """Remove content duplicated from the previous observation.

    TextArena accumulates the full chat history in each observation, so
    turn N's observation repeats everything from turns 0..N-1 plus echoed
    [Player] actions. We strip the repeated prefix and the echoed player
    actions, keeping only new [GAME]/[Moderator] content for this turn.
    """
    if not text:
        return ""
    if not prev_text:
        return text
    new_part = None
    # The previous observation text should appear as a prefix of the
    # current one. Strip it to get only what's new.
    if text.startswith(prev_text):
        new_part = text[len(prev_text):].strip()
    else:
        # Fallback: find the longest common prefix
        min_len = min(len(text), len(prev_text))
        common = 0
        for i in range(min_len):
            if text[i] == prev_text[i]:
                common = i + 1
            else:
                break
        # Only trust the common prefix if it covers most of the previous
        # observation; otherwise the texts are unrelated and we keep all.
        if common > len(prev_text) * 0.8:
            new_part = text[common:].strip()
    if not new_part:
        return text
    # After stripping the observation prefix, the remaining text typically
    # starts with echoed [Player] actions (already shown in action bubbles),
    # followed by new [GAME] or [Moderator] content. Strip the echoed
    # player actions to keep only the new game content.
    game_marker = re.search(r'\[GAME\]|\[Moderator\]', new_part)
    if game_marker:
        game_content = new_part[game_marker.start():].strip()
        return game_content if game_content else new_part
    return new_part


def _get_env_ids(ds: Dataset) -> list[str]:
    """Get sorted unique env_ids from dataset."""
    return sorted(set(ds["env_id"]))


def _group_episodes_by_env(ds: Dataset) -> dict[str, list[int]]:
    """Group row indices by env_id, preserving row order within each group."""
    groups: dict[str, list[int]] = {}
    for i in range(len(ds)):
        env_id = ds[i]["env_id"]
        if env_id not in groups:
            groups[env_id] = []
        groups[env_id].append(i)
    return groups


@bp.route("/load", methods=["POST"])
def load_dataset_endpoint():
    """Load a dataset by repo/split, validate its schema, and cache it."""
    data = request.get_json()
    repo = data.get("repo", "").strip()
    if not repo:
        return jsonify({"error": "repo is required"}), 400
    split = data.get("split", "train")
    try:
        ds = _load_hf_dataset(repo, split)
    except Exception as e:
        return jsonify({"error": f"Failed to load dataset: {e}"}), 400
    columns = ds.column_names
    if not _detect_arena_dataset(columns):
        return jsonify({
            "error": (
                "Not an arena dataset. Expected columns: "
                f"game_id, env_id, transcript. Found: {columns}"
            )
        }), 400
    env_ids = _get_env_ids(ds)
    episode_groups = _group_episodes_by_env(ds)
    ds_id = _make_id(repo, split)
    short_name = repo.rsplit("/", 1)[-1] if "/" in repo else repo
    # Extract model name from first row
    model_name = ds[0].get("model", "unknown") if len(ds) > 0 else "unknown"
    # Compute win/loss/error counts (positive reward = win, else loss).
    wins = sum(1 for r in ds["reward"] if r is not None and r > 0)
    losses = sum(1 for r in ds["reward"] if r is not None and r <= 0)
    errors = sum(1 for e in ds["error"] if e is not None)
    _cache[ds_id] = {
        "dataset": ds,
        "repo": repo,
        "split": split,
        "n_rows": len(ds),
        "env_ids": env_ids,
        "episode_groups": episode_groups,
        "model_name": model_name,
        "stats": {"wins": wins, "losses": losses, "errors": errors},
    }
    return jsonify({
        "id": ds_id,
        "repo": repo,
        "name": short_name,
        "split": split,
        "columns": columns,
        "n_rows": len(ds),
        "env_ids": env_ids,
        "episodes_per_env": {env: len(idxs) for env, idxs in episode_groups.items()},
        "model_name": model_name,
        "stats": {"wins": wins, "losses": losses, "errors": errors},
    })


@bp.route("/", methods=["GET"])
def list_datasets():
    """List summaries of all currently cached datasets."""
    result = []
    for ds_id, info in _cache.items():
        result.append({
            "id": ds_id,
            "repo": info["repo"],
            "name": info["repo"].rsplit("/", 1)[-1] if "/" in info["repo"] else info["repo"],
            "split": info["split"],
            "n_rows": info["n_rows"],
            "env_ids": info["env_ids"],
            "model_name": info["model_name"],
        })
    return jsonify(result)


# NOTE(review): the URL converters (<ds_id>, <env_id>, <int:idx>) were
# stripped by tag sanitization; reconstructed from the handler signature
# (idx is compared as an int below, hence <int:idx>).
@bp.route("/<ds_id>/episode/<env_id>/<int:idx>", methods=["GET"])
def get_episode(ds_id, env_id, idx):
    """Get a single episode by env_id and episode index within that env."""
    if ds_id not in _cache:
        return jsonify({"error": "Dataset not loaded"}), 404
    info = _cache[ds_id]
    ds = info["dataset"]
    episode_groups = info["episode_groups"]
    if env_id not in episode_groups:
        return jsonify({"error": f"env_id '{env_id}' not found"}), 404
    indices = episode_groups[env_id]
    if idx < 0 or idx >= len(indices):
        return jsonify({
            "error": f"Episode index {idx} out of range (0-{len(indices)-1})"
        }), 400
    row_idx = indices[idx]
    row = ds[row_idx]
    # Parse transcript JSON (stored either as a JSON string or a list).
    transcript_raw = row.get("transcript", "[]")
    try:
        transcript = json.loads(transcript_raw) if isinstance(transcript_raw, str) else transcript_raw
    except json.JSONDecodeError:
        transcript = []
    # Analyze each turn: dedup observations, split think tags from actions
    analyzed_turns = []
    prev_obs_raw = ""
    for turn in transcript:
        action_analysis = _analyze_action(turn.get("action", ""))
        obs_raw = turn.get("observation", "")
        obs_deduped = _dedup_observation(obs_raw, prev_obs_raw)
        prev_obs_raw = obs_raw
        analyzed_turns.append({
            "turn": turn.get("turn", 0),
            "player_id": turn.get("player_id", 0),
            "observation": obs_deduped,
            "action": turn.get("action", ""),
            "think_text": action_analysis["think_text"],
            "action_text": action_analysis["action_text"],
            "think_len": action_analysis["think_len"],
            "backtracks": action_analysis["backtracks"],
        })
    # Determine outcome: error takes precedence, then reward sign.
    reward = row.get("reward")
    error = row.get("error")
    if error:
        outcome = "error"
    elif reward is not None and reward > 0:
        outcome = "win"
    elif reward is not None:
        outcome = "loss"
    else:
        outcome = "unknown"
    return jsonify({
        "game_id": row.get("game_id", ""),
        "env_id": row.get("env_id", ""),
        "model": row.get("model", ""),
        "opponent_model": row.get("opponent_model"),
        "player_id": row.get("player_id", 0),
        "reward": reward,
        "num_turns": row.get("num_turns", len(transcript)),
        "error": error,
        "outcome": outcome,
        "transcript": analyzed_turns,
        "system_prompt": row.get("system_prompt", None),
        "episode_idx": idx,
        "total_episodes": len(indices),
    })


@bp.route("/<ds_id>", methods=["DELETE"])
def unload_dataset(ds_id):
    """Evict a dataset from the cache; idempotent."""
    if ds_id in _cache:
        del _cache[ds_id]
    return jsonify({"status": "ok"})