| | import json |
| | import hashlib |
| | from flask import Blueprint, request, jsonify |
| | from datasets import load_dataset |
| |
|
# Blueprint for the RLM dataset-inspection API, mounted at /api/rlm/datasets.
bp = Blueprint("rlm_datasets", __name__, url_prefix="/api/rlm/datasets")

# Process-local cache of loaded datasets, keyed by the 12-char id produced by
# _make_id(). Entries persist until removed via the DELETE endpoint, so the
# cache is unbounded; it is also per-process (not shared across workers).
_cache: dict[str, dict] = {}
| |
|
| |
|
| | def _make_id(repo: str, config: str, split: str) -> str: |
| | key = f"{repo}:{config}:{split}" |
| | return hashlib.md5(key.encode()).hexdigest()[:12] |
| |
|
| |
|
| | def _build_hierarchy(rows: list[dict]) -> dict: |
| | """Reconstruct hierarchy from flat rlm_call_traces rows.""" |
| | gepa_iters: dict[int, dict] = {} |
| |
|
| | for row in rows: |
| | gi = row.get("gepa_iter", 0) |
| | rci = row.get("rlm_call_idx", 0) |
| | ri = row.get("rlm_iter", 0) |
| |
|
| | if gi not in gepa_iters: |
| | gepa_iters[gi] = { |
| | "gepa_iter": gi, |
| | "rlm_calls": {}, |
| | "total_input_tokens": 0, |
| | "total_output_tokens": 0, |
| | "total_execution_time": 0.0, |
| | "final_answer": None, |
| | } |
| |
|
| | gi_data = gepa_iters[gi] |
| | if rci not in gi_data["rlm_calls"]: |
| | gi_data["rlm_calls"][rci] = { |
| | "rlm_call_idx": rci, |
| | "iterations": [], |
| | } |
| |
|
| | |
| | code_blocks = [] |
| | cbj = row.get("code_blocks_json", "") |
| | if cbj and cbj != "[]": |
| | try: |
| | code_blocks = json.loads(cbj) if isinstance(cbj, str) else cbj |
| | except (json.JSONDecodeError, TypeError): |
| | code_blocks = [] |
| |
|
| | iteration = { |
| | "rlm_iter": ri, |
| | "prompt": row.get("prompt", ""), |
| | "response": row.get("response", ""), |
| | "model": row.get("model", ""), |
| | "input_tokens": row.get("input_tokens", 0), |
| | "output_tokens": row.get("output_tokens", 0), |
| | "execution_time": row.get("execution_time", 0.0), |
| | "has_code_blocks": row.get("has_code_blocks", False), |
| | "code_blocks": code_blocks, |
| | "final_answer": row.get("final_answer"), |
| | "subcall_id": row.get("subcall_id"), |
| | "parent_id": row.get("parent_id"), |
| | "timestamp": row.get("timestamp", ""), |
| | } |
| |
|
| | gi_data["rlm_calls"][rci]["iterations"].append(iteration) |
| | gi_data["total_input_tokens"] += iteration["input_tokens"] or 0 |
| | gi_data["total_output_tokens"] += iteration["output_tokens"] or 0 |
| | gi_data["total_execution_time"] += iteration["execution_time"] or 0.0 |
| |
|
| | if iteration["final_answer"]: |
| | gi_data["final_answer"] = iteration["final_answer"] |
| |
|
| | |
| | result = [] |
| | for gi_key in sorted(gepa_iters.keys()): |
| | gi_data = gepa_iters[gi_key] |
| | rlm_calls = [] |
| | for rci_key in sorted(gi_data["rlm_calls"].keys()): |
| | call = gi_data["rlm_calls"][rci_key] |
| | call["iterations"].sort(key=lambda x: x["rlm_iter"]) |
| | rlm_calls.append(call) |
| | gi_data["rlm_calls"] = rlm_calls |
| | result.append(gi_data) |
| |
|
| | return {"gepa_iterations": result} |
| |
|
| |
|
@bp.route("/load", methods=["POST"])
def load_dataset_endpoint():
    """Load a HuggingFace dataset split, build the RLM hierarchy, and cache it.

    Expects a JSON body: {"repo": "<org/name>", "config"?: str, "split"?: str}.
    Returns a summary payload including the cache id, or a 400 on bad input.
    """
    # BUG FIX: request.get_json() returns None (or raises) when the body is
    # missing or not JSON, which made data.get(...) blow up with a 500.
    # silent=True plus the `or {}` fallback turns that into a clean 400.
    data = request.get_json(silent=True) or {}
    # `or ""` also covers an explicit JSON null for "repo".
    repo = (data.get("repo") or "").strip()
    if not repo:
        return jsonify({"error": "repo is required"}), 400

    config = data.get("config", "rlm_call_traces")
    split = data.get("split", "train")

    try:
        ds = load_dataset(repo, config, split=split)
    except Exception as e:
        # Broad catch is deliberate: `datasets` raises many exception types
        # (network, auth, missing config/split) and all map to a client 400.
        return jsonify({"error": f"Failed to load dataset: {e}"}), 400

    ds_id = _make_id(repo, config, split)
    rows = [ds[i] for i in range(len(ds))]
    hierarchy = _build_hierarchy(rows)

    # Run-level metadata is taken from the first row; assumes these fields
    # are constant across the run — TODO(review): confirm against producer.
    first_row = rows[0] if rows else {}
    metadata = {
        "run_id": first_row.get("run_id", ""),
        "method": first_row.get("method", ""),
        "k": first_row.get("k", 0),
        "model": first_row.get("model", ""),
    }

    _cache[ds_id] = {
        "repo": repo,
        "config": config,
        "split": split,
        "hierarchy": hierarchy,
        "metadata": metadata,
        "n_rows": len(rows),
    }

    short_name = repo.rsplit("/", 1)[-1] if "/" in repo else repo

    return jsonify({
        "id": ds_id,
        "repo": repo,
        "name": short_name,
        "config": config,
        "split": split,
        "metadata": metadata,
        "n_gepa_iters": len(hierarchy["gepa_iterations"]),
        "n_rows": len(rows),
    })
| |
|
| |
|
@bp.route("/", methods=["GET"])
def list_datasets():
    """List summaries of every dataset currently held in the cache."""
    return jsonify([
        {
            "id": ds_id,
            "repo": info["repo"],
            # rsplit returns the whole string when there is no "/", so this
            # matches the "name after the last slash, else repo" convention.
            "name": info["repo"].rsplit("/", 1)[-1],
            "config": info["config"],
            "split": info["split"],
            "metadata": info["metadata"],
            "n_rows": info["n_rows"],
            "n_gepa_iters": len(info["hierarchy"]["gepa_iterations"]),
        }
        for ds_id, info in _cache.items()
    ])
| |
|
| |
|
@bp.route("/<ds_id>/overview", methods=["GET"])
def get_overview(ds_id):
    """Level 1: Summary of all GEPA iterations."""
    info = _cache.get(ds_id)
    if info is None:
        return jsonify({"error": "Dataset not loaded"}), 404

    summaries = []
    for gi in info["hierarchy"]["gepa_iterations"]:
        answer = gi["final_answer"]
        summaries.append({
            "gepa_iter": gi["gepa_iter"],
            "n_rlm_calls": len(gi["rlm_calls"]),
            "n_rlm_iters": sum(len(call["iterations"]) for call in gi["rlm_calls"]),
            "total_input_tokens": gi["total_input_tokens"],
            "total_output_tokens": gi["total_output_tokens"],
            "total_execution_time": gi["total_execution_time"],
            "has_final_answer": answer is not None,
            # Preview is truncated to keep the overview payload small.
            "final_answer_preview": (answer or "")[:200],
        })

    return jsonify({
        "metadata": info["metadata"],
        "gepa_iterations": summaries,
    })
| |
|
| |
|
@bp.route("/<ds_id>/gepa/<int:gepa_iter>", methods=["GET"])
def get_gepa_iteration(ds_id, gepa_iter):
    """Level 2: RLM timeline for a specific GEPA iteration."""
    info = _cache.get(ds_id)
    if info is None:
        return jsonify({"error": "Dataset not loaded"}), 404

    gi_data = next(
        (gi for gi in info["hierarchy"]["gepa_iterations"]
         if gi["gepa_iter"] == gepa_iter),
        None,
    )
    if gi_data is None:
        return jsonify({"error": f"GEPA iteration {gepa_iter} not found"}), 404

    # Project each iteration down to timeline-level fields: full prompt /
    # response bodies are omitted; only a 300-char response preview is sent.
    rlm_calls = [
        {
            "rlm_call_idx": call["rlm_call_idx"],
            "iterations": [
                {
                    "rlm_iter": it["rlm_iter"],
                    "model": it["model"],
                    "input_tokens": it["input_tokens"],
                    "output_tokens": it["output_tokens"],
                    "execution_time": it["execution_time"],
                    "has_code_blocks": it["has_code_blocks"],
                    "n_code_blocks": len(it["code_blocks"]),
                    "response_preview": (it["response"] or "")[:300],
                    "has_final_answer": it["final_answer"] is not None,
                    "timestamp": it["timestamp"],
                }
                for it in call["iterations"]
            ],
        }
        for call in gi_data["rlm_calls"]
    ]

    return jsonify({
        "gepa_iter": gepa_iter,
        "total_input_tokens": gi_data["total_input_tokens"],
        "total_output_tokens": gi_data["total_output_tokens"],
        "total_execution_time": gi_data["total_execution_time"],
        "final_answer": gi_data["final_answer"],
        "rlm_calls": rlm_calls,
    })
| |
|
| |
|
@bp.route("/<ds_id>/gepa/<int:gepa_iter>/rlm/<int:rlm_call_idx>/<int:rlm_iter>", methods=["GET"])
def get_rlm_iteration(ds_id, gepa_iter, rlm_call_idx, rlm_iter):
    """Level 3: Full detail for a specific RLM iteration."""
    if ds_id not in _cache:
        return jsonify({"error": "Dataset not loaded"}), 404

    # Single generator search over the hierarchy; yields in the same order as
    # the original nested loops, so the first match is identical.
    match = next(
        (it
         for gi in _cache[ds_id]["hierarchy"]["gepa_iterations"]
         if gi["gepa_iter"] == gepa_iter
         for call in gi["rlm_calls"]
         if call["rlm_call_idx"] == rlm_call_idx
         for it in call["iterations"]
         if it["rlm_iter"] == rlm_iter),
        None,
    )
    if match is not None:
        return jsonify(match)

    return jsonify({"error": "RLM iteration not found"}), 404
| |
|
| |
|
@bp.route("/<ds_id>", methods=["DELETE"])
def unload_dataset(ds_id):
    """Drop a dataset from the in-memory cache (idempotent: ok if absent)."""
    _cache.pop(ds_id, None)
    return jsonify({"status": "ok"})
| |
|