| | import json |
| | import hashlib |
| | from flask import Blueprint, request, jsonify |
| | from datasets import load_dataset |
| |
|
# Blueprint that groups every dataset-browsing route under one URL prefix.
bp = Blueprint(
    "rlm_eval_datasets",
    __name__,
    url_prefix="/api/rlm-eval/datasets",
)

# Process-local registry of loaded datasets, keyed by the id from _make_id().
# Entries are added/overwritten by POST /load and removed only by DELETE /<ds_id>.
_cache: dict[str, dict] = {}
| |
|
| |
|
| | def _make_id(repo: str, config: str, split: str) -> str: |
| | key = f"{repo}:{config}:{split}" |
| | return hashlib.md5(key.encode()).hexdigest()[:12] |
| |
|
| |
|
| | def _build_hierarchy(rows: list[dict]) -> dict: |
| | """Reconstruct hierarchy from flat rows: examples -> iterations.""" |
| | examples: dict[int, dict] = {} |
| |
|
| | for row in rows: |
| | ei = row.get("example_idx", 0) |
| | ri = row.get("rlm_iter", 0) |
| |
|
| | if ei not in examples: |
| | examples[ei] = { |
| | "example_idx": ei, |
| | "question_text": row.get("question_text", ""), |
| | "eval_correct": row.get("eval_correct"), |
| | "iterations": {}, |
| | "total_input_tokens": 0, |
| | "total_output_tokens": 0, |
| | "total_execution_time": 0.0, |
| | "final_answer": None, |
| | "final_answer_preview": "", |
| | } |
| |
|
| | ex = examples[ei] |
| |
|
| | |
| | code_blocks = [] |
| | cbj = row.get("code_blocks_json", "") |
| | if cbj and cbj != "[]": |
| | try: |
| | raw_blocks = json.loads(cbj) if isinstance(cbj, str) else cbj |
| | for cb in raw_blocks: |
| | block = {"code": cb.get("code", "")} |
| | result = cb.get("result", {}) |
| | if isinstance(result, dict) and result.get("stdout"): |
| | block["stdout"] = result["stdout"] |
| | elif cb.get("stdout"): |
| | block["stdout"] = cb["stdout"] |
| | code_blocks.append(block) |
| | except (json.JSONDecodeError, TypeError): |
| | code_blocks = [] |
| |
|
| | iteration = { |
| | "rlm_iter": ri, |
| | "prompt": row.get("prompt", ""), |
| | "response": row.get("response", ""), |
| | "model": row.get("model", ""), |
| | "input_tokens": row.get("input_tokens", 0), |
| | "output_tokens": row.get("output_tokens", 0), |
| | "execution_time": row.get("execution_time", 0.0), |
| | "has_code_blocks": row.get("has_code_blocks", False), |
| | "code_blocks": code_blocks, |
| | "final_answer": row.get("final_answer"), |
| | "timestamp": row.get("timestamp", ""), |
| | } |
| |
|
| | ex["iterations"][ri] = iteration |
| | ex["total_input_tokens"] += iteration["input_tokens"] or 0 |
| | ex["total_output_tokens"] += iteration["output_tokens"] or 0 |
| | ex["total_execution_time"] += iteration["execution_time"] or 0.0 |
| |
|
| | if iteration["final_answer"]: |
| | ex["final_answer"] = iteration["final_answer"] |
| | ex["final_answer_preview"] = (iteration["final_answer"] or "")[:200] |
| |
|
| | |
| | result = [] |
| | for ei_key in sorted(examples.keys()): |
| | ex = examples[ei_key] |
| | iters_list = [] |
| | for ri_key in sorted(ex["iterations"].keys()): |
| | iters_list.append(ex["iterations"][ri_key]) |
| | ex["iterations"] = iters_list |
| | result.append(ex) |
| |
|
| | return {"examples": result} |
| |
|
| |
|
@bp.route("/load", methods=["POST"])
def load_dataset_endpoint():
    """Load a dataset split, index it into a hierarchy, and cache it in memory.

    Expects a JSON object body with:
      * ``repo``   (str, required) - dataset repository id passed to ``load_dataset``.
      * ``config`` (str, optional, default ``"rlm_call_traces"``).
      * ``split``  (str, optional, default ``"train"``).

    An explicit ``null`` for ``config`` or ``split`` is treated like an
    omitted key.

    Returns the cache id, display name, metadata taken from the first row,
    and example/row counts. Responds 400 in these cases:
      * the body is not a JSON object;
      * ``repo`` is missing or blank;
      * ``config`` or ``split`` is not a string;
      * ``load_dataset`` raises.
    """
    # silent=True: a missing or malformed JSON body yields None instead of
    # raising, so it can be reported as a clean 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    repo = data.get("repo", "")
    if not isinstance(repo, str) or not repo.strip():
        return jsonify({"error": "repo is required"}), 400
    repo = repo.strip()

    config = data.get("config")
    if config is None:
        config = "rlm_call_traces"
    split = data.get("split")
    if split is None:
        split = "train"
    if not isinstance(config, str) or not isinstance(split, str):
        return jsonify({"error": "config and split must be strings"}), 400

    try:
        ds = load_dataset(repo, config, split=split)
    except Exception as e:
        # HTTP boundary: report whatever the loader raised to the client.
        return jsonify({"error": f"Failed to load dataset: {e}"}), 400

    ds_id = _make_id(repo, config, split)
    # Iterating the Dataset yields one dict per row, same as ds[i].
    rows = list(ds)
    hierarchy = _build_hierarchy(rows)

    first_row = rows[0] if rows else {}
    metadata = {
        "run_id": first_row.get("run_id", ""),
        "method": first_row.get("method", ""),
        "model": first_row.get("model", ""),
    }

    _cache[ds_id] = {
        "repo": repo,
        "config": config,
        "split": split,
        "hierarchy": hierarchy,
        "metadata": metadata,
        "n_rows": len(rows),
    }

    # rsplit returns [repo] when there is no "/", so no guard is needed.
    short_name = repo.rsplit("/", 1)[-1]

    return jsonify({
        "id": ds_id,
        "repo": repo,
        "name": short_name,
        "config": config,
        "split": split,
        "metadata": metadata,
        "n_examples": len(hierarchy["examples"]),
        "n_rows": len(rows),
    })
| |
|
| |
|
@bp.route("/", methods=["GET"])
def list_datasets():
    """Return a JSON array with one summary object per dataset currently in ``_cache``."""

    def _describe(cache_key, entry):
        # rsplit yields [repo] when there is no "/", so this is the last path segment.
        repo_id = entry["repo"]
        return {
            "id": cache_key,
            "repo": repo_id,
            "name": repo_id.rsplit("/", 1)[-1],
            "config": entry["config"],
            "split": entry["split"],
            "metadata": entry["metadata"],
            "n_rows": entry["n_rows"],
            "n_examples": len(entry["hierarchy"]["examples"]),
        }

    return jsonify([_describe(key, entry) for key, entry in _cache.items()])
| |
|
| |
|
@bp.route("/<ds_id>/overview", methods=["GET"])
def get_overview(ds_id):
    """Level 1: one summary row per example of a loaded dataset.

    Question text is truncated to 300 characters. Responds 404 when *ds_id*
    is not in the cache.
    """
    entry = _cache.get(ds_id)
    if entry is None:
        return jsonify({"error": "Dataset not loaded"}), 404

    summaries = [
        {
            "example_idx": example["example_idx"],
            "question_text": (example["question_text"] or "")[:300],
            "eval_correct": example["eval_correct"],
            "n_iterations": len(example["iterations"]),
            "total_input_tokens": example["total_input_tokens"],
            "total_output_tokens": example["total_output_tokens"],
            "total_execution_time": example["total_execution_time"],
            "final_answer_preview": example["final_answer_preview"],
        }
        for example in entry["hierarchy"]["examples"]
    ]

    return jsonify({"metadata": entry["metadata"], "examples": summaries})
| |
|
| |
|
@bp.route("/<ds_id>/example/<int:example_idx>", methods=["GET"])
def get_example_detail(ds_id, example_idx):
    """Level 2: iteration timeline for a single example.

    Each iteration is reduced to counters plus a 300-character response
    preview. Responds 404 when the dataset is not cached or the example
    index is absent.
    """
    entry = _cache.get(ds_id)
    if entry is None:
        return jsonify({"error": "Dataset not loaded"}), 404

    target = next(
        (
            candidate
            for candidate in entry["hierarchy"]["examples"]
            if candidate["example_idx"] == example_idx
        ),
        None,
    )
    if target is None:
        return jsonify({"error": f"Example {example_idx} not found"}), 404

    timeline = [
        {
            "rlm_iter": step["rlm_iter"],
            "model": step["model"],
            "input_tokens": step["input_tokens"],
            "output_tokens": step["output_tokens"],
            "execution_time": step["execution_time"],
            "has_code_blocks": step["has_code_blocks"],
            "n_code_blocks": len(step["code_blocks"]),
            "response_preview": (step["response"] or "")[:300],
            "has_final_answer": step["final_answer"] is not None,
            "timestamp": step["timestamp"],
        }
        for step in target["iterations"]
    ]

    return jsonify({
        "example_idx": example_idx,
        "question_text": target["question_text"],
        "eval_correct": target["eval_correct"],
        "total_input_tokens": target["total_input_tokens"],
        "total_output_tokens": target["total_output_tokens"],
        "total_execution_time": target["total_execution_time"],
        "final_answer": target["final_answer"],
        "iterations": timeline,
    })
| |
|
| |
|
@bp.route("/<ds_id>/example/<int:example_idx>/iter/<int:rlm_iter>", methods=["GET"])
def get_iter_detail(ds_id, example_idx, rlm_iter):
    """Return the complete stored record for one RLM iteration of one example.

    Responds 404 with "Dataset not loaded" for an unknown *ds_id*. Responds
    404 with "Iteration not found" when either the example or the iteration
    is absent.
    """
    entry = _cache.get(ds_id)
    if entry is None:
        return jsonify({"error": "Dataset not loaded"}), 404

    matches = (
        step
        for example in entry["hierarchy"]["examples"]
        if example["example_idx"] == example_idx
        for step in example["iterations"]
        if step["rlm_iter"] == rlm_iter
    )
    found = next(matches, None)
    if found is None:
        return jsonify({"error": "Iteration not found"}), 404
    return jsonify(found)
| |
|
| |
|
@bp.route("/<ds_id>", methods=["DELETE"])
def unload_dataset(ds_id):
    """Evict *ds_id* from the in-memory cache; reports success even if it was absent."""
    _cache.pop(ds_id, None)
    return jsonify({"status": "ok"})
| |
|