Zayne Rea Sprague
Initial deploy: aggregate trace visualizer
8b41737

import json
import hashlib

from flask import Blueprint, request, jsonify
from datasets import load_dataset

bp = Blueprint("rlm_datasets", __name__, url_prefix="/api/rlm/datasets")

# In-memory store of loaded datasets, keyed by a short hash id.
_cache: dict[str, dict] = {}


def _make_id(repo: str, config: str, split: str) -> str:
    """Derive a short, stable id for a (repo, config, split) triple."""
    key = f"{repo}:{config}:{split}"
    return hashlib.md5(key.encode()).hexdigest()[:12]


def _build_hierarchy(rows: list[dict]) -> dict:
    """Reconstruct hierarchy from flat rlm_call_traces rows."""
    gepa_iters: dict[int, dict] = {}
    for row in rows:
        gi = row.get("gepa_iter", 0)
        rci = row.get("rlm_call_idx", 0)
        ri = row.get("rlm_iter", 0)
        if gi not in gepa_iters:
            gepa_iters[gi] = {
                "gepa_iter": gi,
                "rlm_calls": {},
                "total_input_tokens": 0,
                "total_output_tokens": 0,
                "total_execution_time": 0.0,
                "final_answer": None,
            }
        gi_data = gepa_iters[gi]
        if rci not in gi_data["rlm_calls"]:
            gi_data["rlm_calls"][rci] = {
                "rlm_call_idx": rci,
                "iterations": [],
            }
        # Parse code blocks
        code_blocks = []
        cbj = row.get("code_blocks_json", "")
        if cbj and cbj != "[]":
            try:
                code_blocks = json.loads(cbj) if isinstance(cbj, str) else cbj
            except (json.JSONDecodeError, TypeError):
                code_blocks = []
        iteration = {
            "rlm_iter": ri,
            "prompt": row.get("prompt", ""),
            "response": row.get("response", ""),
            "model": row.get("model", ""),
            "input_tokens": row.get("input_tokens", 0),
            "output_tokens": row.get("output_tokens", 0),
            "execution_time": row.get("execution_time", 0.0),
            "has_code_blocks": row.get("has_code_blocks", False),
            "code_blocks": code_blocks,
            "final_answer": row.get("final_answer"),
            "subcall_id": row.get("subcall_id"),
            "parent_id": row.get("parent_id"),
            "timestamp": row.get("timestamp", ""),
        }
        gi_data["rlm_calls"][rci]["iterations"].append(iteration)
        gi_data["total_input_tokens"] += iteration["input_tokens"] or 0
        gi_data["total_output_tokens"] += iteration["output_tokens"] or 0
        gi_data["total_execution_time"] += iteration["execution_time"] or 0.0
        if iteration["final_answer"]:
            gi_data["final_answer"] = iteration["final_answer"]
    # Sort and convert dicts to lists
    result = []
    for gi_key in sorted(gepa_iters.keys()):
        gi_data = gepa_iters[gi_key]
        rlm_calls = []
        for rci_key in sorted(gi_data["rlm_calls"].keys()):
            call = gi_data["rlm_calls"][rci_key]
            call["iterations"].sort(key=lambda x: x["rlm_iter"])
            rlm_calls.append(call)
        gi_data["rlm_calls"] = rlm_calls
        result.append(gi_data)
    return {"gepa_iterations": result}
@bp.route("/load", methods=["POST"])
def load_dataset_endpoint():
data = request.get_json()
repo = data.get("repo", "").strip()
if not repo:
return jsonify({"error": "repo is required"}), 400
config = data.get("config", "rlm_call_traces")
split = data.get("split", "train")
try:
ds = load_dataset(repo, config, split=split)
except Exception as e:
return jsonify({"error": f"Failed to load dataset: {e}"}), 400
ds_id = _make_id(repo, config, split)
rows = [ds[i] for i in range(len(ds))]
hierarchy = _build_hierarchy(rows)
# Extract metadata from first row
first_row = rows[0] if rows else {}
metadata = {
"run_id": first_row.get("run_id", ""),
"method": first_row.get("method", ""),
"k": first_row.get("k", 0),
"model": first_row.get("model", ""),
}
_cache[ds_id] = {
"repo": repo,
"config": config,
"split": split,
"hierarchy": hierarchy,
"metadata": metadata,
"n_rows": len(rows),
}
short_name = repo.rsplit("/", 1)[-1] if "/" in repo else repo
return jsonify({
"id": ds_id,
"repo": repo,
"name": short_name,
"config": config,
"split": split,
"metadata": metadata,
"n_gepa_iters": len(hierarchy["gepa_iterations"]),
"n_rows": len(rows),
})
@bp.route("/", methods=["GET"])
def list_datasets():
result = []
for ds_id, info in _cache.items():
result.append({
"id": ds_id,
"repo": info["repo"],
"name": info["repo"].rsplit("/", 1)[-1] if "/" in info["repo"] else info["repo"],
"config": info["config"],
"split": info["split"],
"metadata": info["metadata"],
"n_rows": info["n_rows"],
"n_gepa_iters": len(info["hierarchy"]["gepa_iterations"]),
})
return jsonify(result)
@bp.route("/<ds_id>/overview", methods=["GET"])
def get_overview(ds_id):
"""Level 1: Summary of all GEPA iterations."""
if ds_id not in _cache:
return jsonify({"error": "Dataset not loaded"}), 404
info = _cache[ds_id]
hierarchy = info["hierarchy"]
summaries = []
for gi in hierarchy["gepa_iterations"]:
total_rlm_iters = sum(len(c["iterations"]) for c in gi["rlm_calls"])
summaries.append({
"gepa_iter": gi["gepa_iter"],
"n_rlm_calls": len(gi["rlm_calls"]),
"n_rlm_iters": total_rlm_iters,
"total_input_tokens": gi["total_input_tokens"],
"total_output_tokens": gi["total_output_tokens"],
"total_execution_time": gi["total_execution_time"],
"has_final_answer": gi["final_answer"] is not None,
"final_answer_preview": (gi["final_answer"] or "")[:200],
})
return jsonify({
"metadata": info["metadata"],
"gepa_iterations": summaries,
})
@bp.route("/<ds_id>/gepa/<int:gepa_iter>", methods=["GET"])
def get_gepa_iteration(ds_id, gepa_iter):
"""Level 2: RLM timeline for a specific GEPA iteration."""
if ds_id not in _cache:
return jsonify({"error": "Dataset not loaded"}), 404
info = _cache[ds_id]
hierarchy = info["hierarchy"]
gi_data = None
for gi in hierarchy["gepa_iterations"]:
if gi["gepa_iter"] == gepa_iter:
gi_data = gi
break
if gi_data is None:
return jsonify({"error": f"GEPA iteration {gepa_iter} not found"}), 404
# Return full RLM call data with iterations (truncate prompts for timeline view)
rlm_calls = []
for call in gi_data["rlm_calls"]:
iters = []
for it in call["iterations"]:
iters.append({
"rlm_iter": it["rlm_iter"],
"model": it["model"],
"input_tokens": it["input_tokens"],
"output_tokens": it["output_tokens"],
"execution_time": it["execution_time"],
"has_code_blocks": it["has_code_blocks"],
"n_code_blocks": len(it["code_blocks"]),
"response_preview": (it["response"] or "")[:300],
"has_final_answer": it["final_answer"] is not None,
"timestamp": it["timestamp"],
})
rlm_calls.append({
"rlm_call_idx": call["rlm_call_idx"],
"iterations": iters,
})
return jsonify({
"gepa_iter": gepa_iter,
"total_input_tokens": gi_data["total_input_tokens"],
"total_output_tokens": gi_data["total_output_tokens"],
"total_execution_time": gi_data["total_execution_time"],
"final_answer": gi_data["final_answer"],
"rlm_calls": rlm_calls,
})
@bp.route("/<ds_id>/gepa/<int:gepa_iter>/rlm/<int:rlm_call_idx>/<int:rlm_iter>", methods=["GET"])
def get_rlm_iteration(ds_id, gepa_iter, rlm_call_idx, rlm_iter):
"""Level 3: Full detail for a specific RLM iteration."""
if ds_id not in _cache:
return jsonify({"error": "Dataset not loaded"}), 404
info = _cache[ds_id]
hierarchy = info["hierarchy"]
for gi in hierarchy["gepa_iterations"]:
if gi["gepa_iter"] != gepa_iter:
continue
for call in gi["rlm_calls"]:
if call["rlm_call_idx"] != rlm_call_idx:
continue
for it in call["iterations"]:
if it["rlm_iter"] == rlm_iter:
return jsonify(it)
return jsonify({"error": "RLM iteration not found"}), 404
@bp.route("/<ds_id>", methods=["DELETE"])
def unload_dataset(ds_id):
if ds_id in _cache:
del _cache[ds_id]
return jsonify({"status": "ok"})