import json
import hashlib

from flask import Blueprint, request, jsonify
from datasets import load_dataset

bp = Blueprint("rlm_datasets", __name__, url_prefix="/api/rlm/datasets")

# In-memory cache of loaded datasets, keyed by the short id from _make_id().
_cache: dict[str, dict] = {}


def _make_id(repo: str, config: str, split: str) -> str:
    key = f"{repo}:{config}:{split}"
    return hashlib.md5(key.encode()).hexdigest()[:12]


def _build_hierarchy(rows: list[dict]) -> dict:
    """Reconstruct hierarchy from flat rlm_call_traces rows."""
    gepa_iters: dict[int, dict] = {}

    for row in rows:
        gi = row.get("gepa_iter", 0)
        rci = row.get("rlm_call_idx", 0)
        ri = row.get("rlm_iter", 0)

        if gi not in gepa_iters:
            gepa_iters[gi] = {
                "gepa_iter": gi,
                "rlm_calls": {},
                "total_input_tokens": 0,
                "total_output_tokens": 0,
                "total_execution_time": 0.0,
                "final_answer": None,
            }
        gi_data = gepa_iters[gi]

        if rci not in gi_data["rlm_calls"]:
            gi_data["rlm_calls"][rci] = {
                "rlm_call_idx": rci,
                "iterations": [],
            }

        # Parse code blocks
        code_blocks = []
        cbj = row.get("code_blocks_json", "")
        if cbj and cbj != "[]":
            try:
                code_blocks = json.loads(cbj) if isinstance(cbj, str) else cbj
            except (json.JSONDecodeError, TypeError):
                code_blocks = []

        iteration = {
            "rlm_iter": ri,
            "prompt": row.get("prompt", ""),
            "response": row.get("response", ""),
            "model": row.get("model", ""),
            "input_tokens": row.get("input_tokens", 0),
            "output_tokens": row.get("output_tokens", 0),
            "execution_time": row.get("execution_time", 0.0),
            "has_code_blocks": row.get("has_code_blocks", False),
            "code_blocks": code_blocks,
            "final_answer": row.get("final_answer"),
            "subcall_id": row.get("subcall_id"),
            "parent_id": row.get("parent_id"),
            "timestamp": row.get("timestamp", ""),
        }

        gi_data["rlm_calls"][rci]["iterations"].append(iteration)
        gi_data["total_input_tokens"] += iteration["input_tokens"] or 0
        gi_data["total_output_tokens"] += iteration["output_tokens"] or 0
        gi_data["total_execution_time"] += iteration["execution_time"] or 0.0
        if iteration["final_answer"]:
            gi_data["final_answer"] = iteration["final_answer"]

    # Sort and convert dicts to lists
    result = []
    for gi_key in sorted(gepa_iters.keys()):
        gi_data = gepa_iters[gi_key]
        rlm_calls = []
        for rci_key in sorted(gi_data["rlm_calls"].keys()):
            call = gi_data["rlm_calls"][rci_key]
            call["iterations"].sort(key=lambda x: x["rlm_iter"])
            rlm_calls.append(call)
        gi_data["rlm_calls"] = rlm_calls
        result.append(gi_data)

    return {"gepa_iterations": result}


@bp.route("/load", methods=["POST"])
def load_dataset_endpoint():
    data = request.get_json()
    repo = data.get("repo", "").strip()
    if not repo:
        return jsonify({"error": "repo is required"}), 400
    config = data.get("config", "rlm_call_traces")
    split = data.get("split", "train")

    try:
        ds = load_dataset(repo, config, split=split)
    except Exception as e:
        return jsonify({"error": f"Failed to load dataset: {e}"}), 400

    ds_id = _make_id(repo, config, split)
    rows = [ds[i] for i in range(len(ds))]
    hierarchy = _build_hierarchy(rows)

    # Extract metadata from first row
    first_row = rows[0] if rows else {}
    metadata = {
        "run_id": first_row.get("run_id", ""),
        "method": first_row.get("method", ""),
        "k": first_row.get("k", 0),
        "model": first_row.get("model", ""),
    }

    _cache[ds_id] = {
        "repo": repo,
        "config": config,
        "split": split,
        "hierarchy": hierarchy,
        "metadata": metadata,
        "n_rows": len(rows),
    }

    short_name = repo.rsplit("/", 1)[-1] if "/" in repo else repo
    return jsonify({
        "id": ds_id,
        "repo": repo,
        "name": short_name,
        "config": config,
        "split": split,
        "metadata": metadata,
        "n_gepa_iters": len(hierarchy["gepa_iterations"]),
        "n_rows": len(rows),
    })


@bp.route("/", methods=["GET"])
def list_datasets():
    result = []
    for ds_id, info in _cache.items():
        result.append({
            "id": ds_id,
            "repo": info["repo"],
            "name": info["repo"].rsplit("/", 1)[-1] if "/" in info["repo"] else info["repo"],
            "config": info["config"],
            "split": info["split"],
            "metadata": info["metadata"],
            "n_rows": info["n_rows"],
            "n_gepa_iters": len(info["hierarchy"]["gepa_iterations"]),
        })
    return jsonify(result)


@bp.route("/<ds_id>/overview", methods=["GET"])
def get_overview(ds_id):
    """Level 1: Summary of all GEPA iterations."""
    if ds_id not in _cache:
        return jsonify({"error": "Dataset not loaded"}), 404

    info = _cache[ds_id]
    hierarchy = info["hierarchy"]

    summaries = []
    for gi in hierarchy["gepa_iterations"]:
        total_rlm_iters = sum(len(c["iterations"]) for c in gi["rlm_calls"])
        summaries.append({
            "gepa_iter": gi["gepa_iter"],
            "n_rlm_calls": len(gi["rlm_calls"]),
            "n_rlm_iters": total_rlm_iters,
            "total_input_tokens": gi["total_input_tokens"],
            "total_output_tokens": gi["total_output_tokens"],
            "total_execution_time": gi["total_execution_time"],
            "has_final_answer": gi["final_answer"] is not None,
            "final_answer_preview": (gi["final_answer"] or "")[:200],
        })

    return jsonify({
        "metadata": info["metadata"],
        "gepa_iterations": summaries,
    })


@bp.route("/<ds_id>/gepa/<int:gepa_iter>", methods=["GET"])
def get_gepa_iteration(ds_id, gepa_iter):
    """Level 2: RLM timeline for a specific GEPA iteration."""
    if ds_id not in _cache:
        return jsonify({"error": "Dataset not loaded"}), 404

    info = _cache[ds_id]
    hierarchy = info["hierarchy"]

    gi_data = None
    for gi in hierarchy["gepa_iterations"]:
        if gi["gepa_iter"] == gepa_iter:
            gi_data = gi
            break
    if gi_data is None:
        return jsonify({"error": f"GEPA iteration {gepa_iter} not found"}), 404

    # Return full RLM call data with iterations (truncate prompts for timeline view)
    rlm_calls = []
    for call in gi_data["rlm_calls"]:
        iters = []
        for it in call["iterations"]:
            iters.append({
                "rlm_iter": it["rlm_iter"],
                "model": it["model"],
                "input_tokens": it["input_tokens"],
                "output_tokens": it["output_tokens"],
                "execution_time": it["execution_time"],
                "has_code_blocks": it["has_code_blocks"],
                "n_code_blocks": len(it["code_blocks"]),
                "response_preview": (it["response"] or "")[:300],
                "has_final_answer": it["final_answer"] is not None,
                "timestamp": it["timestamp"],
            })
        rlm_calls.append({
            "rlm_call_idx": call["rlm_call_idx"],
            "iterations": iters,
        })

    return jsonify({
        "gepa_iter": gepa_iter,
        "total_input_tokens": gi_data["total_input_tokens"],
        "total_output_tokens": gi_data["total_output_tokens"],
        "total_execution_time": gi_data["total_execution_time"],
        "final_answer": gi_data["final_answer"],
        "rlm_calls": rlm_calls,
    })


@bp.route("/<ds_id>/gepa/<int:gepa_iter>/rlm/<int:rlm_call_idx>/<int:rlm_iter>", methods=["GET"])
def get_rlm_iteration(ds_id, gepa_iter, rlm_call_idx, rlm_iter):
    """Level 3: Full detail for a specific RLM iteration."""
    if ds_id not in _cache:
        return jsonify({"error": "Dataset not loaded"}), 404

    info = _cache[ds_id]
    hierarchy = info["hierarchy"]

    for gi in hierarchy["gepa_iterations"]:
        if gi["gepa_iter"] != gepa_iter:
            continue
        for call in gi["rlm_calls"]:
            if call["rlm_call_idx"] != rlm_call_idx:
                continue
            for it in call["iterations"]:
                if it["rlm_iter"] == rlm_iter:
                    return jsonify(it)
    return jsonify({"error": "RLM iteration not found"}), 404


@bp.route("/<ds_id>", methods=["DELETE"])
def unload_dataset(ds_id):
    if ds_id in _cache:
        del _cache[ds_id]
    return jsonify({"status": "ok"})
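

# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal example of mounting this blueprint in a Flask app so the endpoints
# above are reachable under /api/rlm/datasets. The app setup, port, and the
# placeholder repo name in the curl example are assumptions for demonstration.
if __name__ == "__main__":
    from flask import Flask

    app = Flask(__name__)
    app.register_blueprint(bp)

    # Example: load a dataset, then drill down through the three levels:
    #   curl -X POST http://localhost:5000/api/rlm/datasets/load \
    #        -H "Content-Type: application/json" \
    #        -d '{"repo": "<user>/<dataset>", "config": "rlm_call_traces", "split": "train"}'
    #   curl http://localhost:5000/api/rlm/datasets/<id>/overview
    #   curl http://localhost:5000/api/rlm/datasets/<id>/gepa/0
    #   curl http://localhost:5000/api/rlm/datasets/<id>/gepa/0/rlm/0/0
    app.run(debug=True)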