import copy
import json
import os
import tempfile
import threading
import uuid

from flask import Blueprint, request, jsonify

# Blueprint carrying every preset endpoint; all routes live under /api/presets.
bp = Blueprint("presets", __name__, url_prefix="/api/presets")

# HuggingFace dataset repo holding one "<vis_type>_presets.json" file per type.
PRESETS_REPO = "reasoning-degeneration-dev/AGG_VIS_PRESETS"
# Visualizer types accepted in the <vis_type> URL segment.
VALID_TYPES = {"model", "arena", "rlm", "rlm-eval", "harbor", "adaevolve"}
# Local mirror directory: "presets" inside the parent of this module's directory.
LOCAL_PRESETS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "presets")

# In-memory cache: vis_type -> list of preset dicts.
_cache: dict[str, list[dict]] = {}
# Types whose cache entry has been populated (by a download or by a write).
_cache_loaded: set[str] = set()
# Guards _cache and _cache_loaded.
_lock = threading.Lock()


def _ensure_local_dir():
    """Create the local presets directory; do nothing if it already exists."""
    os.makedirs(name=LOCAL_PRESETS_DIR, exist_ok=True)


def _local_path(vis_type: str) -> str:
    """Return the path of the local JSON mirror for ``vis_type``.

    The presets directory is created first, so the returned path can be
    opened for writing straight away.
    """
    _ensure_local_dir()
    filename = f"{vis_type}_presets.json"
    return os.path.join(LOCAL_PRESETS_DIR, filename)


def _download_presets(vis_type: str) -> list[dict]:
    """Load the presets for ``vis_type``, preferring HuggingFace over the local mirror.

    The file ``<vis_type>_presets.json`` is fetched from the ``PRESETS_REPO``
    dataset repo. On success the result is also written to the local mirror
    so it is available the next time the download fails.

    Args:
        vis_type: Visualizer type whose preset file should be loaded.

    Returns:
        The list of presets. If the download fails, the contents of the
        local mirror are returned instead; if that is missing, unreadable,
        or not a JSON list, an empty list is returned. This function never
        raises because of download or mirror problems.
    """
    try:
        # Imported here so a missing huggingface_hub package is handled by the
        # same fallback path as a network failure.
        from huggingface_hub import hf_hub_download

        hub_path = hf_hub_download(
            PRESETS_REPO,
            f"{vis_type}_presets.json",
            repo_type="dataset",
        )
        with open(hub_path, encoding="utf-8") as f:
            presets = json.load(f)
        if not isinstance(presets, list):
            raise ValueError("remote presets file is not a JSON list")
    except Exception as e:
        # Best-effort: any download problem falls back to the local mirror.
        print(f"[presets] HF download failed for {vis_type}: {e}")
        try:
            with open(_local_path(vis_type), encoding="utf-8") as f:
                mirrored = json.load(f)
        except FileNotFoundError:
            return []
        except (OSError, ValueError) as local_err:
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            print(f"[presets] local presets unreadable for {vis_type}: {local_err}")
            return []
        return mirrored if isinstance(mirrored, list) else []

    # Refresh the mirror separately: a failed write must not discard the
    # presets that were just downloaded successfully.
    try:
        with open(_local_path(vis_type), "w", encoding="utf-8") as f:
            json.dump(presets, f, indent=2)
    except OSError as e:
        print(f"[presets] could not update local presets for {vis_type}: {e}")
    return presets


def _upload_presets(vis_type: str, presets: list[dict]):
    """Persist ``presets`` locally, then push them to HuggingFace in the background.

    The local mirror is written synchronously. The HuggingFace upload runs in
    a daemon thread and is best-effort: failures are printed, never raised.

    Args:
        vis_type: Visualizer type; selects ``<vis_type>_presets.json``.
        presets: Full list of presets to store (must be JSON-serializable).

    Raises:
        TypeError / ValueError: If ``presets`` cannot be serialized to JSON.
        OSError: If the local mirror cannot be written.
    """
    # Serialize once, before touching any file, so that a serialization error
    # cannot leave a truncated mirror and the upload thread sends exactly the
    # snapshot that was written locally.
    payload = json.dumps(presets, indent=2)

    with open(_local_path(vis_type), "w", encoding="utf-8") as f:
        f.write(payload)

    def _do_upload():
        try:
            from huggingface_hub import HfApi

            api = HfApi()
            try:
                api.create_repo(
                    PRESETS_REPO,
                    repo_type="dataset",
                    exist_ok=True,
                )
            except Exception as e:
                # Not fatal: the upload below is still attempted.
                print(f"[presets] HF create_repo failed for {PRESETS_REPO}: {e}")
            # Upload the bytes directly; no temp file that could be leaked
            # when the upload raises.
            api.upload_file(
                path_or_fileobj=payload.encode("utf-8"),
                path_in_repo=f"{vis_type}_presets.json",
                repo_id=PRESETS_REPO,
                repo_type="dataset",
            )
        except Exception as e:
            print(f"[presets] HF upload failed for {vis_type}: {e}")

    threading.Thread(target=_do_upload, daemon=True).start()


def _get_presets(vis_type: str) -> list[dict]:
    """Return the presets for ``vis_type``, loading them on first access.

    Args:
        vis_type: Visualizer type to look up.

    Returns:
        An independent deep copy of the cached list. Callers may freely
        mutate the list and the preset dicts inside it without affecting the
        cache; changes only take effect through ``_set_presets``.
    """
    with _lock:
        if vis_type not in _cache_loaded:
            _cache[vis_type] = _download_presets(vis_type)
            _cache_loaded.add(vis_type)
        # Deep copy: a shallow copy would hand out the cached dicts
        # themselves, letting callers modify cached state outside the lock.
        return copy.deepcopy(list(_cache.get(vis_type, [])))


def _set_presets(vis_type: str, presets: list[dict]):
    """Replace the cached presets for ``vis_type`` and persist the new list."""
    with _lock:
        _cache_loaded.add(vis_type)
        _cache[vis_type] = presets
    # Persisting (local file + background HF upload) does not need the lock.
    _upload_presets(vis_type, presets)


@bp.route("/<vis_type>", methods=["GET"])
def list_presets(vis_type):
    """Return every preset of the given visualizer type as a JSON array.

    Responds with 400 when ``vis_type`` is not one of ``VALID_TYPES``.
    """
    if vis_type in VALID_TYPES:
        return jsonify(_get_presets(vis_type))
    return jsonify({"error": f"Invalid type. Must be one of: {VALID_TYPES}"}), 400


@bp.route("/<vis_type>", methods=["POST"])
def create_preset(vis_type):
    """Create a preset of the given visualizer type from the JSON request body.

    Body fields:
        name (str, required): Display name; surrounding whitespace is stripped.
        repo (str, required): Repo identifier; surrounding whitespace is stripped.
        split (optional): Defaults to ``"train"``.
        column (optional, ``model`` only): Defaults to ``"model_responses"``.
        config (optional, ``rlm`` / ``rlm-eval`` only): Defaults to
            ``"rlm_call_traces"``.

    Returns:
        201 with the stored preset (including its generated ``id``), or 400
        when the type is invalid, the body is not a JSON object, or a
        required field is missing, empty, or not a string.
    """
    if vis_type not in VALID_TYPES:
        return jsonify({"error": f"Invalid type. Must be one of: {VALID_TYPES}"}), 400

    # silent=True turns unparsable or non-JSON bodies into None so that they
    # get the same JSON 400 response as any other malformed payload.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    name = data.get("name", "")
    if not isinstance(name, str) or not name.strip():
        return jsonify({"error": "name is required"}), 400

    repo = data.get("repo", "")
    if not isinstance(repo, str) or not repo.strip():
        return jsonify({"error": "repo is required"}), 400

    preset = {
        "id": uuid.uuid4().hex[:8],
        "name": name.strip(),
        "repo": repo.strip(),
        "split": data.get("split", "train"),
    }

    # Type-specific extras.
    if vis_type == "model":
        preset["column"] = data.get("column", "model_responses")
    elif vis_type in ("rlm", "rlm-eval"):
        preset["config"] = data.get("config", "rlm_call_traces")

    presets = _get_presets(vis_type)
    presets.append(preset)
    _set_presets(vis_type, presets)

    return jsonify(preset), 201


@bp.route("/<vis_type>/<preset_id>", methods=["PUT"])
def update_preset(vis_type, preset_id):
    """Update selected fields of an existing preset.

    Only ``name``, ``column``, ``split`` and ``config`` are taken from the
    JSON body; other keys are ignored. ``name`` must be a non-empty string
    (the same rule ``create_preset`` applies) and is stored stripped.

    Returns:
        200 with the updated preset, 400 for an invalid type or payload,
        404 when no preset has the given id.
    """
    if vis_type not in VALID_TYPES:
        return jsonify({"error": f"Invalid type. Must be one of: {VALID_TYPES}"}), 400

    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    # Validate and collect all changes before touching any preset, so that a
    # bad payload can never leave a half-applied update behind.
    updates = {}
    if "name" in data:
        name = data["name"]
        if not isinstance(name, str) or not name.strip():
            return jsonify({"error": "name must be a non-empty string"}), 400
        updates["name"] = name.strip()
    for field in ("column", "split", "config"):
        if field in data:
            updates[field] = data[field]

    presets = _get_presets(vis_type)
    for index, current in enumerate(presets):
        if isinstance(current, dict) and current.get("id") == preset_id:
            # Build a new dict instead of mutating the one obtained from
            # _get_presets, which may be shared with the cache.
            updated = {**current, **updates}
            presets[index] = updated
            _set_presets(vis_type, presets)
            return jsonify(updated)

    return jsonify({"error": "not found"}), 404


@bp.route("/<vis_type>/<preset_id>", methods=["DELETE"])
def delete_preset(vis_type, preset_id):
    """Delete the preset with the given id.

    The operation is idempotent: deleting an id that does not exist still
    returns ``{"status": "ok"}``, but in that case nothing is rewritten or
    uploaded.

    Returns:
        200 with ``{"status": "ok"}``, or 400 for an invalid type.
    """
    if vis_type not in VALID_TYPES:
        return jsonify({"error": f"Invalid type. Must be one of: {VALID_TYPES}"}), 400

    presets = _get_presets(vis_type)
    remaining = [
        p for p in presets
        if not (isinstance(p, dict) and p.get("id") == preset_id)
    ]
    # Only persist when something was actually removed; otherwise every
    # no-op delete would rewrite the local file and trigger an HF upload.
    if len(remaining) != len(presets):
        _set_presets(vis_type, remaining)
    return jsonify({"status": "ok"})


@bp.route("/sync", methods=["POST"])
def sync_presets():
    """Drop every cached preset list and reload all visualizer types."""
    with _lock:
        _cache_loaded.clear()
        _cache.clear()
    # _get_presets acquires _lock itself, so the reload has to run after the
    # lock has been released.
    for vis_type in VALID_TYPES:
        _get_presets(vis_type)
    return jsonify({"status": "ok"})