"""Flask blueprint that loads datasets via ``datasets.load_dataset`` and serves
per-iteration summaries and details from an in-memory cache."""

import hashlib

from flask import Blueprint, request, jsonify
from datasets import load_dataset

bp = Blueprint("adaevolve_datasets", __name__, url_prefix="/api/adaevolve/datasets")

# Loaded datasets keyed by the id returned from _make_id(). Each entry holds
# "repo", "split", "dataset" (the loaded object) and "iterations" (summaries).
_cache: dict[str, dict] = {}


def _make_id(repo: str, split: str) -> str:
    """Return a 12-hex-character id derived from ``"{repo}:{split}"``."""
    key = f"{repo}:{split}"
    return hashlib.md5(key.encode()).hexdigest()[:12]


def _short_name(repo: str) -> str:
    """Return the part of *repo* after its last ``/`` (all of it if none)."""
    return repo.rsplit("/", 1)[-1]


def _build_iteration_summary(row: dict, idx: int) -> dict:
    """Build the lightweight summary for one iteration row.

    Missing keys fall back to neutral defaults; ``iteration`` falls back to
    the row index *idx*.
    """
    return {
        "index": idx,
        "iteration": row.get("iteration", idx),
        "island_id": row.get("island_id", 0),
        "score": row.get("score", 0.0),
        "best_score": row.get("best_score", 0.0),
        "delta": row.get("delta", 0.0),
        "adaptation_type": row.get("adaptation_type", ""),
        "exploration_intensity": row.get("exploration_intensity", 0.0),
        "is_valid": row.get("is_valid", False),
        "task_id": row.get("task_id", ""),
        "meta_guidance_tactic": row.get("meta_guidance_tactic", ""),
        "tactic_approach_type": row.get("tactic_approach_type", ""),
    }


def _build_summary_stats(iterations: list[dict]) -> dict:
    """Build aggregate stats across all iteration summaries.

    Returns a dict with:
      - ``adaptation_counts``: occurrences of each ``adaptation_type``.
      - ``island_best_scores``: highest ``best_score`` per island, keyed by
        the island id converted to ``str``.
      - ``global_best``: highest ``best_score`` overall (``0.0`` when
        *iterations* is empty).
      - ``n_islands``: number of distinct island ids seen.
    """
    adaptation_counts: dict[str, int] = {}
    island_best_scores: dict[int, float] = {}

    for it in iterations:
        atype = it.get("adaptation_type", "unknown")
        adaptation_counts[atype] = adaptation_counts.get(atype, 0) + 1

        iid = it.get("island_id", 0)
        score = it.get("best_score", 0.0)
        if iid not in island_best_scores or score > island_best_scores[iid]:
            island_best_scores[iid] = score

    # Derive the global best from the per-island maxima so it stays correct
    # when every score is negative; 0.0 is only the empty-input fallback.
    global_best = max(island_best_scores.values(), default=0.0)

    return {
        "adaptation_counts": adaptation_counts,
        "island_best_scores": {str(k): v for k, v in island_best_scores.items()},
        "global_best": global_best,
        "n_islands": len(island_best_scores),
    }


def _describe_dataset(ds_id: str, info: dict) -> dict:
    """Build the JSON-serializable description of one cached dataset entry."""
    iterations = info["iterations"]
    return {
        "id": ds_id,
        "repo": info["repo"],
        "name": _short_name(info["repo"]),
        "split": info["split"],
        "iterations": iterations,
        "n_iterations": len(iterations),
        "summary_stats": _build_summary_stats(iterations),
    }


@bp.route("/load", methods=["POST"])
def load_dataset_endpoint():
    """Load a dataset, cache it, and return its description.

    Expects a JSON object body with ``repo`` (required) and ``split``
    (optional, defaults to ``"train"``). Responds 400 when ``repo`` is missing
    or when ``load_dataset`` raises.
    """
    # silent=True yields None instead of raising on a missing/invalid JSON
    # body; anything that is not a JSON object is treated as empty.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    repo = data.get("repo", "")
    repo = repo.strip() if isinstance(repo, str) else ""
    if not repo:
        return jsonify({"error": "repo is required"}), 400

    # An explicit null/empty split falls back to the default as well.
    split = data.get("split") or "train"

    try:
        ds = load_dataset(repo, split=split)
    except Exception as e:
        # Endpoint boundary: report any loader failure to the client.
        return jsonify({"error": f"Failed to load dataset: {e}"}), 400

    ds_id = _make_id(repo, split)
    iterations = [_build_iteration_summary(row, i) for i, row in enumerate(ds)]

    _cache[ds_id] = {
        "repo": repo,
        "split": split,
        "dataset": ds,
        "iterations": iterations,
    }

    return jsonify(_describe_dataset(ds_id, _cache[ds_id]))


@bp.route("/", methods=["GET"])
def list_datasets():
    """Return the descriptions of all cached datasets."""
    return jsonify([_describe_dataset(ds_id, info) for ds_id, info in _cache.items()])


@bp.route("/<ds_id>/iterations", methods=["GET"])
def get_iterations(ds_id):
    """Return the iteration summaries of one cached dataset (404 if unknown)."""
    if ds_id not in _cache:
        return jsonify({"error": "Dataset not loaded"}), 404
    return jsonify(_cache[ds_id]["iterations"])


@bp.route("/<ds_id>/iteration/<int:idx>", methods=["GET"])
def get_iteration(ds_id, idx):
    """Get full detail for one iteration including prompt, reasoning, code.

    Responds 404 when the dataset is not cached or *idx* is out of range.
    """
    if ds_id not in _cache:
        return jsonify({"error": "Dataset not loaded"}), 404

    info = _cache[ds_id]
    if idx < 0 or idx >= len(info["dataset"]):
        return jsonify({"error": f"Iteration index {idx} out of range"}), 404

    row = info["dataset"][idx]
    # The detail view is the summary plus the large text fields.
    detail = _build_iteration_summary(row, idx)
    detail["prompt_text"] = row.get("prompt_text", "")
    detail["reasoning_trace"] = row.get("reasoning_trace", "")
    detail["program_code"] = row.get("program_code", "")
    return jsonify(detail)


@bp.route("/<ds_id>", methods=["DELETE"])
def unload_dataset(ds_id):
    """Drop a dataset from the cache; succeeds even if it was not cached."""
    _cache.pop(ds_id, None)
    return jsonify({"status": "ok"})