# agg-trace-visualizer / backend / api / rlm_eval_datasets.py
# Author: Zayne Rea Sprague (commit "bump", 6b7050a)
import json
import hashlib
from flask import Blueprint, request, jsonify
from datasets import load_dataset
# Blueprint registered by the app; all routes below are served under
# /api/rlm-eval/datasets.
bp = Blueprint("rlm_eval_datasets", __name__, url_prefix="/api/rlm-eval/datasets")
# Process-local cache of loaded datasets, keyed by the short id from _make_id().
# Each entry holds: repo, config, split, hierarchy, metadata, n_rows.
# NOTE(review): in-memory only — cache is lost on restart and not shared
# across workers; presumably acceptable for this visualizer backend.
_cache: dict[str, dict] = {}
def _make_id(repo: str, config: str, split: str) -> str:
key = f"{repo}:{config}:{split}"
return hashlib.md5(key.encode()).hexdigest()[:12]
def _build_hierarchy(rows: list[dict]) -> dict:
"""Reconstruct hierarchy from flat rows: examples -> iterations."""
examples: dict[int, dict] = {}
for row in rows:
ei = row.get("example_idx", 0)
ri = row.get("rlm_iter", 0)
if ei not in examples:
examples[ei] = {
"example_idx": ei,
"question_text": row.get("question_text", ""),
"eval_correct": row.get("eval_correct"),
"iterations": {},
"total_input_tokens": 0,
"total_output_tokens": 0,
"total_execution_time": 0.0,
"final_answer": None,
"final_answer_preview": "",
}
ex = examples[ei]
# Parse code blocks, flattening result.stdout -> stdout
code_blocks = []
cbj = row.get("code_blocks_json", "")
if cbj and cbj != "[]":
try:
raw_blocks = json.loads(cbj) if isinstance(cbj, str) else cbj
for cb in raw_blocks:
block = {"code": cb.get("code", "")}
result = cb.get("result", {})
if isinstance(result, dict) and result.get("stdout"):
block["stdout"] = result["stdout"]
elif cb.get("stdout"):
block["stdout"] = cb["stdout"]
code_blocks.append(block)
except (json.JSONDecodeError, TypeError):
code_blocks = []
iteration = {
"rlm_iter": ri,
"prompt": row.get("prompt", ""),
"response": row.get("response", ""),
"model": row.get("model", ""),
"input_tokens": row.get("input_tokens", 0),
"output_tokens": row.get("output_tokens", 0),
"execution_time": row.get("execution_time", 0.0),
"has_code_blocks": row.get("has_code_blocks", False),
"code_blocks": code_blocks,
"final_answer": row.get("final_answer"),
"timestamp": row.get("timestamp", ""),
}
ex["iterations"][ri] = iteration
ex["total_input_tokens"] += iteration["input_tokens"] or 0
ex["total_output_tokens"] += iteration["output_tokens"] or 0
ex["total_execution_time"] += iteration["execution_time"] or 0.0
if iteration["final_answer"]:
ex["final_answer"] = iteration["final_answer"]
ex["final_answer_preview"] = (iteration["final_answer"] or "")[:200]
# Sort and convert dicts to lists
result = []
for ei_key in sorted(examples.keys()):
ex = examples[ei_key]
iters_list = []
for ri_key in sorted(ex["iterations"].keys()):
iters_list.append(ex["iterations"][ri_key])
ex["iterations"] = iters_list
result.append(ex)
return {"examples": result}
@bp.route("/load", methods=["POST"])
def load_dataset_endpoint():
    """Load a HuggingFace dataset split, build its hierarchy, and cache it.

    Expects a JSON body: {"repo": str (required), "config": str, "split": str}.
    Returns a summary (id, name, metadata, counts) on success, or a 400 error
    payload when the body/repo is missing or the dataset fails to load.
    """
    # silent=True: a missing or malformed JSON body yields None instead of a
    # framework-level abort, so we can return a consistent 400 payload below.
    data = request.get_json(silent=True) or {}
    repo = (data.get("repo") or "").strip()
    if not repo:
        return jsonify({"error": "repo is required"}), 400
    config = data.get("config", "rlm_call_traces")
    split = data.get("split", "train")
    try:
        ds = load_dataset(repo, config, split=split)
    except Exception as e:
        # Broad catch is deliberate: datasets raises many exception types for
        # network/auth/missing-config failures; surface them all as a 400.
        return jsonify({"error": f"Failed to load dataset: {e}"}), 400
    ds_id = _make_id(repo, config, split)
    rows = [ds[i] for i in range(len(ds))]
    hierarchy = _build_hierarchy(rows)
    # Run-level metadata is assumed constant across rows; take the first row's.
    first_row = rows[0] if rows else {}
    metadata = {
        "run_id": first_row.get("run_id", ""),
        "method": first_row.get("method", ""),
        "model": first_row.get("model", ""),
    }
    _cache[ds_id] = {
        "repo": repo,
        "config": config,
        "split": split,
        "hierarchy": hierarchy,
        "metadata": metadata,
        "n_rows": len(rows),
    }
    short_name = repo.rsplit("/", 1)[-1] if "/" in repo else repo
    return jsonify({
        "id": ds_id,
        "repo": repo,
        "name": short_name,
        "config": config,
        "split": split,
        "metadata": metadata,
        "n_examples": len(hierarchy["examples"]),
        "n_rows": len(rows),
    })
@bp.route("/", methods=["GET"])
def list_datasets():
    """List every cached dataset as a summary record."""
    def _summary(ds_id: str, info: dict) -> dict:
        repo = info["repo"]
        return {
            "id": ds_id,
            "repo": repo,
            "name": repo.rsplit("/", 1)[-1] if "/" in repo else repo,
            "config": info["config"],
            "split": info["split"],
            "metadata": info["metadata"],
            "n_rows": info["n_rows"],
            "n_examples": len(info["hierarchy"]["examples"]),
        }
    return jsonify([_summary(ds_id, info) for ds_id, info in _cache.items()])
@bp.route("/<ds_id>/overview", methods=["GET"])
def get_overview(ds_id):
    """Level 1: Summary of all examples."""
    info = _cache.get(ds_id)
    if info is None:
        return jsonify({"error": "Dataset not loaded"}), 404
    # One truncated summary record per example, preserving hierarchy order.
    summaries = [
        {
            "example_idx": ex["example_idx"],
            "question_text": (ex["question_text"] or "")[:300],
            "eval_correct": ex["eval_correct"],
            "n_iterations": len(ex["iterations"]),
            "total_input_tokens": ex["total_input_tokens"],
            "total_output_tokens": ex["total_output_tokens"],
            "total_execution_time": ex["total_execution_time"],
            "final_answer_preview": ex["final_answer_preview"],
        }
        for ex in info["hierarchy"]["examples"]
    ]
    return jsonify({"metadata": info["metadata"], "examples": summaries})
@bp.route("/<ds_id>/example/<int:example_idx>", methods=["GET"])
def get_example_detail(ds_id, example_idx):
    """Level 2: Iteration timeline for one example."""
    info = _cache.get(ds_id)
    if info is None:
        return jsonify({"error": "Dataset not loaded"}), 404
    ex_data = next(
        (ex for ex in info["hierarchy"]["examples"] if ex["example_idx"] == example_idx),
        None,
    )
    if ex_data is None:
        return jsonify({"error": f"Example {example_idx} not found"}), 404
    # Lightweight per-iteration summaries; full payloads live at the /iter route.
    iters = [
        {
            "rlm_iter": it["rlm_iter"],
            "model": it["model"],
            "input_tokens": it["input_tokens"],
            "output_tokens": it["output_tokens"],
            "execution_time": it["execution_time"],
            "has_code_blocks": it["has_code_blocks"],
            "n_code_blocks": len(it["code_blocks"]),
            "response_preview": (it["response"] or "")[:300],
            "has_final_answer": it["final_answer"] is not None,
            "timestamp": it["timestamp"],
        }
        for it in ex_data["iterations"]
    ]
    return jsonify({
        "example_idx": example_idx,
        "question_text": ex_data["question_text"],
        "eval_correct": ex_data["eval_correct"],
        "total_input_tokens": ex_data["total_input_tokens"],
        "total_output_tokens": ex_data["total_output_tokens"],
        "total_execution_time": ex_data["total_execution_time"],
        "final_answer": ex_data["final_answer"],
        "iterations": iters,
    })
@bp.route("/<ds_id>/example/<int:example_idx>/iter/<int:rlm_iter>", methods=["GET"])
def get_iter_detail(ds_id, example_idx, rlm_iter):
    """Full detail for a specific RLM iteration within an example."""
    info = _cache.get(ds_id)
    if info is None:
        return jsonify({"error": "Dataset not loaded"}), 404
    for ex in info["hierarchy"]["examples"]:
        if ex["example_idx"] != example_idx:
            continue
        # Found the example; scan its iterations for the requested one.
        for it in ex["iterations"]:
            if it["rlm_iter"] == rlm_iter:
                return jsonify(it)
    return jsonify({"error": "Iteration not found"}), 404
@bp.route("/<ds_id>", methods=["DELETE"])
def unload_dataset(ds_id):
    """Evict a cached dataset. Idempotent: unknown ids still return ok."""
    _cache.pop(ds_id, None)
    return jsonify({"status": "ok"})