# agg-trace-visualizer/backend/api/adaevolve_datasets.py
# Zayne Rea Sprague
# feat: add AdaEvolve visualizer tab for reasoning stickiness analysis
# e6cfd0f
import hashlib
from flask import Blueprint, request, jsonify
from datasets import load_dataset
# Flask blueprint for the AdaEvolve dataset endpoints; every route below is
# served under the /api/adaevolve/datasets prefix.
bp = Blueprint("adaevolve_datasets", __name__, url_prefix="/api/adaevolve/datasets")
# Module-level in-memory cache of loaded datasets, keyed by the id returned
# from _make_id(). Each entry is a dict with the keys "repo", "split",
# "dataset" and "iterations" (populated by load_dataset_endpoint).
_cache: dict[str, dict] = {}
def _make_id(repo: str, split: str) -> str:
    """Return a short, deterministic cache id for a (repo, split) pair.

    The id is the first 12 hexadecimal characters of the MD5 digest of the
    string ``"<repo>:<split>"``, so the same pair always maps to the same id.

    Args:
        repo: Dataset repository name as supplied by the client.
        split: Dataset split name.

    Returns:
        A 12-character lowercase hex string.
    """
    key = f"{repo}:{split}"
    # MD5 only derives a lookup key here, it is not a security primitive.
    # Declaring that keeps the call working on FIPS-restricted Python builds,
    # where a plain hashlib.md5() raises ValueError.
    digest = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()
    return digest[:12]
def _build_iteration_summary(row: dict, idx: int) -> dict:
    """Condense one dataset row into its lightweight per-iteration record.

    Args:
        row: Mapping for a single dataset row; any field may be absent.
        idx: Position of the row in the dataset. Stored as ``"index"`` and
            used as the fallback for ``"iteration"``.

    Returns:
        A dict holding the scalar/metadata fields of the row, with a default
        substituted for every field the row does not contain.
    """
    # (field name, value used when the row lacks the field), in output order.
    fallbacks = (
        ("island_id", 0),
        ("score", 0.0),
        ("best_score", 0.0),
        ("delta", 0.0),
        ("adaptation_type", ""),
        ("exploration_intensity", 0.0),
        ("is_valid", False),
        ("task_id", ""),
        ("meta_guidance_tactic", ""),
        ("tactic_approach_type", ""),
    )
    summary = {"index": idx, "iteration": row.get("iteration", idx)}
    summary.update((field, row.get(field, default)) for field, default in fallbacks)
    return summary
def _build_summary_stats(iterations: list[dict]) -> dict:
    """Aggregate statistics over a list of iteration summaries.

    Args:
        iterations: Iteration records such as those produced by
            ``_build_iteration_summary``. The fields read are
            ``"adaptation_type"``, ``"island_id"`` and ``"best_score"``.

    Returns:
        A dict with:
          * ``"adaptation_counts"``: occurrences of each adaptation type.
          * ``"island_best_scores"``: highest ``best_score`` per island,
            keyed by the island id converted to ``str``.
          * ``"global_best"``: highest ``best_score`` over all islands, or
            ``0.0`` when no score is available.
          * ``"n_islands"``: number of entries in ``island_best_scores``.
    """
    adaptation_counts: dict[str, int] = {}
    island_best_scores: dict[int, float] = {}

    for it in iterations:
        atype = it.get("adaptation_type", "unknown")
        adaptation_counts[atype] = adaptation_counts.get(atype, 0) + 1

        score = it.get("best_score", 0.0)
        if score is None:
            # A null score cannot be ordered against numbers; the row still
            # counts toward adaptation_counts but not toward any best score.
            continue

        iid = it.get("island_id", 0)
        if iid not in island_best_scores or score > island_best_scores[iid]:
            island_best_scores[iid] = score

    # Derive the global best from the per-island bests rather than seeding it
    # with 0.0, so that all-negative runs report their true maximum.
    global_best = max(island_best_scores.values(), default=0.0)

    return {
        "adaptation_counts": adaptation_counts,
        "island_best_scores": {str(k): v for k, v in island_best_scores.items()},
        "global_best": global_best,
        "n_islands": len(island_best_scores),
    }
@bp.route("/load", methods=["POST"])
def load_dataset_endpoint():
    """Load a dataset split, cache it, and return its iteration summaries.

    Expects a JSON object body with:
      * ``"repo"`` (required): dataset repository passed to ``load_dataset``.
      * ``"split"`` (optional): split name; defaults to ``"train"``.

    Returns:
        On success, a JSON object with ``id``, ``repo``, ``name``, ``split``,
        ``iterations``, ``n_iterations`` and ``summary_stats``.
        HTTP 400 with an ``error`` message when ``repo`` is missing or blank,
        or when ``load_dataset`` raises.
    """
    # silent=True gives None for a missing or malformed JSON body, so bad
    # requests fall through to the same JSON 400 response below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    repo = data.get("repo", "")
    repo = repo.strip() if isinstance(repo, str) else ""
    if not repo:
        return jsonify({"error": "repo is required"}), 400

    # A null or empty split falls back to the default split.
    split = data.get("split") or "train"

    try:
        ds = load_dataset(repo, split=split)
    except Exception as e:
        # Boundary handler: report any loader failure to the client as a 400.
        return jsonify({"error": f"Failed to load dataset: {e}"}), 400

    ds_id = _make_id(repo, split)

    # Build the lightweight per-row summaries. The full dataset is cached
    # alongside them so the detail endpoint can read the complete row.
    iterations = [_build_iteration_summary(row, i) for i, row in enumerate(ds)]
    summary_stats = _build_summary_stats(iterations)

    _cache[ds_id] = {
        "repo": repo,
        "split": split,
        "dataset": ds,
        "iterations": iterations,
    }

    # The last path component of "org/name" (or the whole string if no "/").
    short_name = repo.rsplit("/", 1)[-1]

    return jsonify({
        "id": ds_id,
        "repo": repo,
        "name": short_name,
        "split": split,
        "iterations": iterations,
        "n_iterations": len(iterations),
        "summary_stats": summary_stats,
    })
@bp.route("/", methods=["GET"])
def list_datasets():
    """Return a JSON list describing every dataset currently in the cache.

    Each element carries the dataset id, repo, short name, split, iteration
    count, the iteration summaries, and freshly computed summary stats.
    """

    def describe(ds_id, info):
        # Assemble the listing record for one cached dataset.
        repo_name = info["repo"]
        rows = info["iterations"]
        stats = _build_summary_stats(rows)
        return {
            "id": ds_id,
            "repo": repo_name,
            # Last path component, or the full repo string when it has no "/".
            "name": repo_name.rsplit("/", 1)[-1],
            "split": info["split"],
            "n_iterations": len(rows),
            "iterations": rows,
            "summary_stats": stats,
        }

    return jsonify([describe(ds_id, info) for ds_id, info in _cache.items()])
@bp.route("/<ds_id>/iterations", methods=["GET"])
def get_iterations(ds_id):
    """Return the cached iteration summaries for ``ds_id`` as JSON.

    Responds with HTTP 404 and an ``error`` message when no dataset with
    that id is in the cache.
    """
    entry = _cache.get(ds_id)
    if entry is None:
        return jsonify({"error": "Dataset not loaded"}), 404
    return jsonify(entry["iterations"])
@bp.route("/<ds_id>/iteration/<int:idx>", methods=["GET"])
def get_iteration(ds_id, idx):
    """Get full detail for one iteration including prompt, reasoning, code.

    Args:
        ds_id: Cache id of a previously loaded dataset.
        idx: Zero-based row index into that dataset.

    Returns:
        A JSON object with every field of the iteration summary plus
        ``prompt_text``, ``reasoning_trace`` and ``program_code``.
        HTTP 404 with an ``error`` message when the dataset is not cached or
        ``idx`` is outside the dataset.
    """
    # Single lookup, so a concurrent unload cannot slip in between a
    # membership check and the read.
    info = _cache.get(ds_id)
    if info is None:
        return jsonify({"error": "Dataset not loaded"}), 404

    dataset = info["dataset"]
    if not 0 <= idx < len(dataset):
        return jsonify({"error": f"Iteration index {idx} out of range"}), 404

    row = dataset[idx]

    # Reuse the summary builder so this view and the list view always agree
    # on the shared fields and their defaults, then add the large text fields.
    detail = _build_iteration_summary(row, idx)
    for field in ("prompt_text", "reasoning_trace", "program_code"):
        detail[field] = row.get(field, "")

    return jsonify(detail)
@bp.route("/<ds_id>", methods=["DELETE"])
def unload_dataset(ds_id):
    """Drop ``ds_id`` from the cache and report success.

    Unknown ids are ignored, so the response is ``{"status": "ok"}`` whether
    or not the dataset was loaded.
    """
    _cache.pop(ds_id, None)
    return jsonify({"status": "ok"})