"""Flask blueprint for loading RLM eval trace datasets and browsing them.

A dataset is loaded with ``datasets.load_dataset``, its flat rows (one per RLM
iteration) are regrouped into examples -> iterations, and the result is kept in
an in-process cache that the read-only routes below serve from.
"""
import json
import hashlib

from flask import Blueprint, request, jsonify
from datasets import load_dataset

bp = Blueprint("rlm_eval_datasets", __name__, url_prefix="/api/rlm-eval/datasets")

# In-process cache of loaded datasets, keyed by the id from _make_id().
# Entries stay until removed through the DELETE route or the process exits.
_cache: dict[str, dict] = {}


def _make_id(repo: str, config: str, split: str) -> str:
    """Return a short, deterministic id for a (repo, config, split) triple.

    MD5 is used purely as a stable fingerprint, not for security, so it is
    flagged ``usedforsecurity=False`` to keep working on FIPS-restricted builds.
    """
    key = f"{repo}:{config}:{split}"
    return hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()[:12]


def _short_name(repo: str) -> str:
    """Return the last ``/``-separated component of ``repo`` (the whole string if none)."""
    return repo.rsplit("/", 1)[-1]


def _parse_code_blocks(raw) -> list[dict]:
    """Normalize a row's ``code_blocks_json`` value into ``{"code", "stdout"?}`` dicts.

    ``raw`` may be a JSON string or an already-decoded list. ``stdout`` is taken
    from ``result.stdout`` when present, otherwise from a top-level ``stdout``
    key, and is omitted when empty. Entries that are not dicts are skipped;
    input that cannot be decoded or iterated yields an empty list.
    """
    if not raw or raw == "[]":
        return []
    try:
        raw_blocks = json.loads(raw) if isinstance(raw, str) else raw
        blocks = []
        for cb in raw_blocks:
            if not isinstance(cb, dict):
                # A stray string/number entry would otherwise raise
                # AttributeError and abort loading the whole dataset.
                continue
            block = {"code": cb.get("code", "")}
            result = cb.get("result", {})
            if isinstance(result, dict) and result.get("stdout"):
                block["stdout"] = result["stdout"]
            elif cb.get("stdout"):
                block["stdout"] = cb["stdout"]
            blocks.append(block)
        return blocks
    except (json.JSONDecodeError, TypeError):
        return []


def _new_example(ei, row: dict) -> dict:
    """Create the per-example accumulator, seeded from the example's first row."""
    return {
        "example_idx": ei,
        "question_text": row.get("question_text", ""),
        "eval_correct": row.get("eval_correct"),
        "iterations": {},
        "total_input_tokens": 0,
        "total_output_tokens": 0,
        "total_execution_time": 0.0,
        "final_answer": None,
        "final_answer_preview": "",
    }


def _build_hierarchy(rows: list[dict]) -> dict:
    """Reconstruct hierarchy from flat rows: examples -> iterations.

    Each row is one iteration, keyed by ``example_idx`` and ``rlm_iter`` (both
    default to 0). Token counts and execution time are summed per example,
    with ``None`` counted as 0. The last iteration with a truthy
    ``final_answer`` sets the example's ``final_answer`` and its first 200
    characters as ``final_answer_preview``.

    Returns ``{"examples": [...]}``. Examples are sorted by ``example_idx``,
    and each example's ``iterations`` is a list sorted by ``rlm_iter``.
    """
    examples: dict[int, dict] = {}
    for row in rows:
        ei = row.get("example_idx", 0)
        ri = row.get("rlm_iter", 0)
        ex = examples.get(ei)
        if ex is None:
            ex = examples[ei] = _new_example(ei, row)

        iteration = {
            "rlm_iter": ri,
            "prompt": row.get("prompt", ""),
            "response": row.get("response", ""),
            "model": row.get("model", ""),
            "input_tokens": row.get("input_tokens", 0),
            "output_tokens": row.get("output_tokens", 0),
            "execution_time": row.get("execution_time", 0.0),
            "has_code_blocks": row.get("has_code_blocks", False),
            "code_blocks": _parse_code_blocks(row.get("code_blocks_json", "")),
            "final_answer": row.get("final_answer"),
            "timestamp": row.get("timestamp", ""),
        }
        ex["iterations"][ri] = iteration
        ex["total_input_tokens"] += iteration["input_tokens"] or 0
        ex["total_output_tokens"] += iteration["output_tokens"] or 0
        ex["total_execution_time"] += iteration["execution_time"] or 0.0

        answer = iteration["final_answer"]
        if answer:
            ex["final_answer"] = answer
            # Stringify non-str answers so the preview is always text and
            # slicing cannot fail on e.g. an int.
            text = answer if isinstance(answer, str) else str(answer)
            ex["final_answer_preview"] = text[:200]

    # Convert the keyed dicts into sorted lists for the JSON responses.
    result = []
    for ei in sorted(examples):
        ex = examples[ei]
        ex["iterations"] = [ex["iterations"][ri] for ri in sorted(ex["iterations"])]
        result.append(ex)
    return {"examples": result}


def _find_example(hierarchy: dict, example_idx):
    """Return the example with the given ``example_idx`` from ``hierarchy``, or ``None``."""
    return next(
        (ex for ex in hierarchy["examples"] if ex["example_idx"] == example_idx),
        None,
    )


@bp.route("/load", methods=["POST"])
def load_dataset_endpoint():
    """Load a dataset, build its example/iteration hierarchy, and cache it.

    JSON body: ``repo`` (required), ``config`` (default ``"rlm_call_traces"``)
    and ``split`` (default ``"train"``). Responds 400 if ``repo`` is missing or
    ``load_dataset`` raises; otherwise returns the cache id and summary counts.
    """
    # silent=True plus the dict check: a missing, malformed or non-object body
    # is treated as empty, so the client gets "repo is required" (400)
    # instead of a 500 from calling .get on None.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    repo = str(data.get("repo") or "").strip()
    if not repo:
        return jsonify({"error": "repo is required"}), 400
    config = data.get("config", "rlm_call_traces")
    split = data.get("split", "train")

    try:
        ds = load_dataset(repo, config, split=split)
    except Exception as e:  # any loader failure is reported back as a 400
        return jsonify({"error": f"Failed to load dataset: {e}"}), 400

    ds_id = _make_id(repo, config, split)
    rows = list(ds)
    hierarchy = _build_hierarchy(rows)

    # Run-level metadata is read from the first row only.
    first_row = rows[0] if rows else {}
    metadata = {
        "run_id": first_row.get("run_id", ""),
        "method": first_row.get("method", ""),
        "model": first_row.get("model", ""),
    }

    _cache[ds_id] = {
        "repo": repo,
        "config": config,
        "split": split,
        "hierarchy": hierarchy,
        "metadata": metadata,
        "n_rows": len(rows),
    }

    return jsonify({
        "id": ds_id,
        "repo": repo,
        "name": _short_name(repo),
        "config": config,
        "split": split,
        "metadata": metadata,
        "n_examples": len(hierarchy["examples"]),
        "n_rows": len(rows),
    })


@bp.route("/", methods=["GET"])
def list_datasets():
    """List every cached dataset with its metadata and row/example counts."""
    # Iterate a snapshot so a concurrent load/unload cannot change the dict's
    # size mid-iteration.
    return jsonify([
        {
            "id": ds_id,
            "repo": info["repo"],
            "name": _short_name(info["repo"]),
            "config": info["config"],
            "split": info["split"],
            "metadata": info["metadata"],
            "n_rows": info["n_rows"],
            "n_examples": len(info["hierarchy"]["examples"]),
        }
        for ds_id, info in list(_cache.items())
    ])


@bp.route("/<ds_id>/overview", methods=["GET"])
def get_overview(ds_id):
    """Level 1: Summary of all examples in a cached dataset (404 if not loaded)."""
    info = _cache.get(ds_id)
    if info is None:
        return jsonify({"error": "Dataset not loaded"}), 404

    summaries = [
        {
            "example_idx": ex["example_idx"],
            "question_text": (ex["question_text"] or "")[:300],
            "eval_correct": ex["eval_correct"],
            "n_iterations": len(ex["iterations"]),
            "total_input_tokens": ex["total_input_tokens"],
            "total_output_tokens": ex["total_output_tokens"],
            "total_execution_time": ex["total_execution_time"],
            "final_answer_preview": ex["final_answer_preview"],
        }
        for ex in info["hierarchy"]["examples"]
    ]
    return jsonify({
        "metadata": info["metadata"],
        "examples": summaries,
    })


@bp.route("/<ds_id>/example/<int:example_idx>", methods=["GET"])
def get_example_detail(ds_id, example_idx):
    """Level 2: Iteration timeline for one example (404 if dataset/example missing)."""
    info = _cache.get(ds_id)
    if info is None:
        return jsonify({"error": "Dataset not loaded"}), 404

    ex_data = _find_example(info["hierarchy"], example_idx)
    if ex_data is None:
        return jsonify({"error": f"Example {example_idx} not found"}), 404

    iters = [
        {
            "rlm_iter": it["rlm_iter"],
            "model": it["model"],
            "input_tokens": it["input_tokens"],
            "output_tokens": it["output_tokens"],
            "execution_time": it["execution_time"],
            "has_code_blocks": it["has_code_blocks"],
            "n_code_blocks": len(it["code_blocks"]),
            "response_preview": (it["response"] or "")[:300],
            "has_final_answer": it["final_answer"] is not None,
            "timestamp": it["timestamp"],
        }
        for it in ex_data["iterations"]
    ]
    return jsonify({
        "example_idx": example_idx,
        "question_text": ex_data["question_text"],
        "eval_correct": ex_data["eval_correct"],
        "total_input_tokens": ex_data["total_input_tokens"],
        "total_output_tokens": ex_data["total_output_tokens"],
        "total_execution_time": ex_data["total_execution_time"],
        "final_answer": ex_data["final_answer"],
        "iterations": iters,
    })


@bp.route("/<ds_id>/example/<int:example_idx>/iter/<int:rlm_iter>", methods=["GET"])
def get_iter_detail(ds_id, example_idx, rlm_iter):
    """Full detail for a specific RLM iteration within an example.

    Responds 404 if the dataset is not loaded, the example does not exist, or
    the example has no iteration with that ``rlm_iter``.
    """
    info = _cache.get(ds_id)
    if info is None:
        return jsonify({"error": "Dataset not loaded"}), 404

    ex_data = _find_example(info["hierarchy"], example_idx)
    if ex_data is None:
        return jsonify({"error": f"Example {example_idx} not found"}), 404

    for it in ex_data["iterations"]:
        if it["rlm_iter"] == rlm_iter:
            return jsonify(it)
    return jsonify({"error": "Iteration not found"}), 404


@bp.route("/<ds_id>", methods=["DELETE"])
def unload_dataset(ds_id):
    """Drop a dataset from the cache; responds ok even if it was not loaded."""
    _cache.pop(ds_id, None)
    return jsonify({"status": "ok"})