| | import json |
| | import hashlib |
| | import os |
| | from flask import Blueprint, request, jsonify |
| | from datasets import load_dataset, Dataset |
| |
|
# Blueprint for arena-dataset inspection endpoints, mounted at /api/arena/datasets.
bp = Blueprint("arena_datasets", __name__, url_prefix="/api/arena/datasets")


# In-process cache of loaded datasets keyed by _make_id(repo, split).
# Each value holds the Dataset object plus precomputed metadata
# (env_ids, episode_groups, stats, ...); populated by /load, evicted by DELETE.
# NOTE(review): per-worker and unbounded — assumes a single-process dev server.
_cache: dict[str, dict] = {}
| |
|
| |
|
| | def _make_id(repo: str, split: str) -> str: |
| | key = f"{repo}:{split}" |
| | return hashlib.md5(key.encode()).hexdigest()[:12] |
| |
|
| |
|
def _load_hf_dataset(repo: str, split: str) -> Dataset:
    """Load a dataset: a local parquet file if `repo` is an existing path,
    otherwise the given split from the Hugging Face Hub."""
    if not os.path.exists(repo):
        return load_dataset(repo, split=split)
    return Dataset.from_parquet(repo)
| |
|
| |
|
| | def _detect_arena_dataset(columns: list[str]) -> bool: |
| | """Check if this looks like an arena evaluation dataset.""" |
| | required = {"game_id", "env_id", "transcript"} |
| | return required.issubset(set(columns)) |
| |
|
| |
|
| | def _analyze_action(text: str) -> dict: |
| | """Split <think> tags from action text and compute analytics.""" |
| | if not text: |
| | return {"think_text": "", "action_text": "", "think_len": 0, "action_len": 0, |
| | "backtracks": 0, "restarts": 0} |
| |
|
| | think_end = text.find("</think>") |
| | if think_end > 0: |
| | think_text = text[:think_end + 8] |
| | action_text = text[think_end + 8:].strip() |
| | else: |
| | think_text = "" |
| | action_text = text |
| |
|
| | t = text.lower() |
| | backtracks = sum(t.count(w) for w in |
| | ["wait,", "wait ", "hmm", "let me try", "try again", |
| | "another approach", "let me reconsider"]) |
| | restarts = sum(t.count(w) for w in |
| | ["start over", "fresh approach", "different approach", "from scratch"]) |
| |
|
| | return { |
| | "think_text": think_text, |
| | "action_text": action_text, |
| | "think_len": len(think_text), |
| | "action_len": len(action_text), |
| | "backtracks": backtracks, |
| | "restarts": restarts, |
| | } |
| |
|
| |
|
| | def _dedup_observation(text: str, prev_text: str) -> str: |
| | """Remove content duplicated from the previous observation. |
| | |
| | TextArena accumulates the full chat history in each observation, |
| | so turn N's observation repeats everything from turns 0..N-1 |
| | plus echoed [Player] actions. We strip the repeated prefix and |
| | the echoed player actions, keeping only new [GAME]/[Moderator] |
| | content for this turn. |
| | """ |
| | import re |
| |
|
| | if not text: |
| | return "" |
| | if not prev_text: |
| | return text |
| |
|
| | new_part = None |
| |
|
| | |
| | |
| | if text.startswith(prev_text): |
| | new_part = text[len(prev_text):].strip() |
| | else: |
| | |
| | min_len = min(len(text), len(prev_text)) |
| | common = 0 |
| | for i in range(min_len): |
| | if text[i] == prev_text[i]: |
| | common = i + 1 |
| | else: |
| | break |
| |
|
| | if common > len(prev_text) * 0.8: |
| | new_part = text[common:].strip() |
| |
|
| | if not new_part: |
| | return text |
| |
|
| | |
| | |
| | |
| | |
| | game_marker = re.search(r'\[GAME\]|\[Moderator\]', new_part) |
| | if game_marker: |
| | game_content = new_part[game_marker.start():].strip() |
| | return game_content if game_content else new_part |
| |
|
| | return new_part |
| |
|
| |
|
| | def _get_env_ids(ds: Dataset) -> list[str]: |
| | """Get sorted unique env_ids from dataset.""" |
| | return sorted(set(ds["env_id"])) |
| |
|
| |
|
| | def _group_episodes_by_env(ds: Dataset) -> dict[str, list[int]]: |
| | """Group row indices by env_id.""" |
| | groups: dict[str, list[int]] = {} |
| | for i in range(len(ds)): |
| | env_id = ds[i]["env_id"] |
| | if env_id not in groups: |
| | groups[env_id] = [] |
| | groups[env_id].append(i) |
| | return groups |
| |
|
| |
|
@bp.route("/load", methods=["POST"])
def load_dataset_endpoint():
    """Load an arena dataset (HF Hub repo or local parquet path), cache it,
    and return its summary metadata.

    Body (JSON): {"repo": str (required), "split": str (default "train")}.
    Returns 400 on missing repo, load failure, or non-arena schema.
    """
    # silent=True: a missing/invalid JSON body yields None instead of an
    # exception, which we turn into the normal "repo is required" 400.
    data = request.get_json(silent=True) or {}
    repo = str(data.get("repo") or "").strip()
    if not repo:
        return jsonify({"error": "repo is required"}), 400

    split = data.get("split", "train")

    try:
        ds = _load_hf_dataset(repo, split)
    except Exception as e:
        return jsonify({"error": f"Failed to load dataset: {e}"}), 400

    columns = ds.column_names
    if not _detect_arena_dataset(columns):
        return jsonify({
            "error": f"Not an arena dataset. Expected columns: game_id, env_id, transcript. Found: {columns}"
        }), 400

    env_ids = _get_env_ids(ds)
    episode_groups = _group_episodes_by_env(ds)
    ds_id = _make_id(repo, split)
    short_name = repo.rsplit("/", 1)[-1] if "/" in repo else repo

    # First row's model name stands in for the whole run (one model per dump).
    model_name = ds[0].get("model", "unknown") if len(ds) > 0 else "unknown"

    # reward/error are optional columns; only the arena triple is required,
    # so guard against KeyError (previously a 500) and read each column once
    # instead of re-decoding rows per index.
    rewards = ds["reward"] if "reward" in columns else []
    wins = sum(1 for r in rewards if r is not None and r > 0)
    losses = sum(1 for r in rewards if r is not None and r <= 0)
    errors = (sum(1 for e in ds["error"] if e is not None)
              if "error" in columns else 0)

    _cache[ds_id] = {
        "dataset": ds,
        "repo": repo,
        "split": split,
        "n_rows": len(ds),
        "env_ids": env_ids,
        "episode_groups": episode_groups,
        "model_name": model_name,
        "stats": {"wins": wins, "losses": losses, "errors": errors},
    }

    return jsonify({
        "id": ds_id,
        "repo": repo,
        "name": short_name,
        "split": split,
        "columns": columns,
        "n_rows": len(ds),
        "env_ids": env_ids,
        "episodes_per_env": {env: len(idxs) for env, idxs in episode_groups.items()},
        "model_name": model_name,
        "stats": {"wins": wins, "losses": losses, "errors": errors},
    })
| |
|
| |
|
@bp.route("/", methods=["GET"])
def list_datasets():
    """Return summary metadata for every dataset currently in the cache."""
    summaries = []
    for ds_id, info in _cache.items():
        repo = info["repo"]
        summaries.append({
            "id": ds_id,
            "repo": repo,
            "name": repo.rsplit("/", 1)[-1] if "/" in repo else repo,
            "split": info["split"],
            "n_rows": info["n_rows"],
            "env_ids": info["env_ids"],
            "model_name": info["model_name"],
        })
    return jsonify(summaries)
| |
|
| |
|
@bp.route("/<ds_id>/episode/<env_id>/<int:idx>", methods=["GET"])
def get_episode(ds_id, env_id, idx):
    """Get a single episode by env_id and episode index within that env.

    `idx` is the position within the env's episode list (0-based), not a raw
    dataset row index. Returns 404 for unknown dataset/env, 400 for a bad idx.
    """
    if ds_id not in _cache:
        return jsonify({"error": "Dataset not loaded"}), 404

    info = _cache[ds_id]
    ds = info["dataset"]
    episode_groups = info["episode_groups"]

    if env_id not in episode_groups:
        return jsonify({"error": f"env_id '{env_id}' not found"}), 404

    indices = episode_groups[env_id]
    if idx < 0 or idx >= len(indices):
        return jsonify({"error": f"Episode index {idx} out of range (0-{len(indices)-1})"}), 400

    # Map the per-env episode index to the underlying dataset row.
    row_idx = indices[idx]
    row = ds[row_idx]

    # Transcript may be stored either as a JSON string or already decoded;
    # fall back to an empty transcript on malformed JSON rather than 500ing.
    transcript_raw = row.get("transcript", "[]")
    try:
        transcript = json.loads(transcript_raw) if isinstance(transcript_raw, str) else transcript_raw
    except json.JSONDecodeError:
        transcript = []

    # Enrich each turn: split <think> from the action and strip the portion
    # of the observation repeated from the PREVIOUS turn (TextArena resends
    # the full history each turn) — order matters, so prev_obs_raw tracks
    # the raw (un-deduped) previous observation.
    analyzed_turns = []
    prev_obs_raw = ""
    for turn in transcript:
        action_analysis = _analyze_action(turn.get("action", ""))
        obs_raw = turn.get("observation", "")
        obs_deduped = _dedup_observation(obs_raw, prev_obs_raw)
        prev_obs_raw = obs_raw
        analyzed_turns.append({
            "turn": turn.get("turn", 0),
            "player_id": turn.get("player_id", 0),
            "observation": obs_deduped,
            "action": turn.get("action", ""),
            "think_text": action_analysis["think_text"],
            "action_text": action_analysis["action_text"],
            "think_len": action_analysis["think_len"],
            "backtracks": action_analysis["backtracks"],
        })

    # Classify the outcome: any error wins over reward; positive reward is a
    # win, non-positive a loss, missing reward is unknown.
    reward = row.get("reward")
    error = row.get("error")
    if error:
        outcome = "error"
    elif reward is not None and reward > 0:
        outcome = "win"
    elif reward is not None:
        outcome = "loss"
    else:
        outcome = "unknown"

    return jsonify({
        "game_id": row.get("game_id", ""),
        "env_id": row.get("env_id", ""),
        "model": row.get("model", ""),
        "opponent_model": row.get("opponent_model"),
        "player_id": row.get("player_id", 0),
        "reward": reward,
        "num_turns": row.get("num_turns", len(transcript)),
        "error": error,
        "outcome": outcome,
        "transcript": analyzed_turns,
        "system_prompt": row.get("system_prompt", None),
        "episode_idx": idx,
        "total_episodes": len(indices),
    })
| |
|
| |
|
@bp.route("/<ds_id>", methods=["DELETE"])
def unload_dataset(ds_id):
    """Evict a dataset from the cache. Idempotent: unknown ids are a no-op."""
    _cache.pop(ds_id, None)
    return jsonify({"status": "ok"})
| |
|