File size: 5,914 Bytes
import json
import os
import uuid
import tempfile
import threading
from flask import Blueprint, request, jsonify
# Flask blueprint for the presets API; its routes are mounted under /api/presets.
bp = Blueprint("presets", __name__, url_prefix="/api/presets")
# HuggingFace dataset repo holding one "<vis_type>_presets.json" file per type.
PRESETS_REPO = "reasoning-degeneration-dev/AGG_VIS_PRESETS"
# Visualizer types accepted in the <vis_type> URL segment.
VALID_TYPES = {"model", "arena", "rlm", "rlm-eval", "harbor", "adaevolve"}
# Local fallback directory: "presets" inside the parent of this file's directory.
LOCAL_PRESETS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "presets")
# In-memory cache: vis_type -> list[dict]
_cache: dict[str, list[dict]] = {}
# Types whose presets are already in _cache (an empty list still counts as loaded).
_cache_loaded: set[str] = set()
# Held while _cache and _cache_loaded are read or modified.
_lock = threading.Lock()
def _ensure_local_dir() -> None:
    """Create the local presets directory; do nothing if it already exists."""
    os.makedirs(LOCAL_PRESETS_DIR, exist_ok=True)
def _local_path(vis_type: str) -> str:
    """Return the path of the local JSON file that stores ``vis_type`` presets.

    The presets directory is created first, so the returned path can be
    opened for writing straight away.
    """
    _ensure_local_dir()
    filename = f"{vis_type}_presets.json"
    return os.path.join(LOCAL_PRESETS_DIR, filename)
def _download_presets(vis_type: str) -> list[dict]:
    """Download presets from HuggingFace, falling back to the local file.

    Args:
        vis_type: Visualizer type; selects ``<vis_type>_presets.json``.

    Returns:
        The presets parsed from the HuggingFace copy when the download
        succeeds; otherwise the contents of the local cache file; otherwise
        an empty list (no local file, or it cannot be read/parsed).
    """
    try:
        # Imported lazily so a missing huggingface_hub install only disables
        # the remote copy instead of breaking the whole module.
        from huggingface_hub import hf_hub_download

        path = hf_hub_download(
            PRESETS_REPO,
            f"{vis_type}_presets.json",
            repo_type="dataset",
        )
        with open(path, encoding="utf-8") as f:
            presets = json.load(f)
    except Exception as e:
        # Deliberately broad: offline, missing file in the repo, auth errors
        # and malformed JSON all lead to the same local fallback.
        print(f"[presets] HF download failed for {vis_type}: {e}")
        try:
            with open(_local_path(vis_type), encoding="utf-8") as f:
                return json.load(f)
        except (OSError, ValueError):
            # No local copy yet, or it is unreadable/corrupt.
            return []

    # Cache locally for offline fallback. This is best-effort and kept out of
    # the try block above so that a failed write (read-only disk, permissions)
    # cannot discard the presets that were just downloaded.
    try:
        with open(_local_path(vis_type), "w", encoding="utf-8") as f:
            json.dump(presets, f, indent=2)
    except OSError as e:
        print(f"[presets] could not cache {vis_type} presets locally: {e}")
    return presets
def _upload_presets(vis_type: str, presets: list[dict]) -> None:
    """Save presets locally, then upload them to HuggingFace in the background.

    The local write happens synchronously and its errors propagate to the
    caller. The HuggingFace upload runs on a daemon thread and is
    best-effort: failures are printed, never raised.

    Args:
        vis_type: Visualizer type; selects ``<vis_type>_presets.json``.
        presets: Preset dicts to persist; must be JSON-serializable.
    """
    # Serialize once, up front. The background thread then uploads exactly
    # what was saved locally, even if the caller mutates ``presets`` later,
    # and a serialization error cannot leave a truncated local file behind.
    payload = json.dumps(presets, indent=2)

    # Always save locally first
    with open(_local_path(vis_type), "w", encoding="utf-8") as f:
        f.write(payload)

    def _do_upload():
        tmp = None
        try:
            from huggingface_hub import HfApi

            api = HfApi()
            # Ensure repo exists
            try:
                api.create_repo(
                    PRESETS_REPO,
                    repo_type="dataset",
                    exist_ok=True,
                )
            except Exception:
                # Ignored on purpose: if the repo really is unusable, the
                # upload below fails and is reported by the outer handler.
                pass
            with tempfile.NamedTemporaryFile(
                "w", suffix=".json", delete=False, encoding="utf-8"
            ) as f:
                # Record the name before writing so a failed write is still
                # cleaned up in the finally block.
                tmp = f.name
                f.write(payload)
            api.upload_file(
                path_or_fileobj=tmp,
                path_in_repo=f"{vis_type}_presets.json",
                repo_id=PRESETS_REPO,
                repo_type="dataset",
            )
        except Exception as e:
            print(f"[presets] HF upload failed for {vis_type}: {e}")
        finally:
            # Remove the temp file on failure too; previously it leaked
            # whenever the upload raised.
            if tmp is not None:
                try:
                    os.unlink(tmp)
                except OSError:
                    pass

    threading.Thread(target=_do_upload, daemon=True).start()
def _get_presets(vis_type: str) -> list[dict]:
    """Return the presets for a visualizer type, loading them on first use.

    The first request for a type fills the cache through
    ``_download_presets``; later requests are served from memory. The result
    is a new list object, but its elements are the same dict objects that
    the cache holds.
    """
    with _lock:
        needs_load = vis_type not in _cache_loaded
        if needs_load:
            _cache[vis_type] = _download_presets(vis_type)
            _cache_loaded.add(vis_type)
        cached = _cache.get(vis_type, [])
        return list(cached)
def _set_presets(vis_type: str, presets: list[dict]):
    """Replace the cached presets for ``vis_type`` and persist them.

    The cache entry is swapped while holding the lock; persistence (local
    file plus HuggingFace upload) is delegated to ``_upload_presets`` once
    the lock has been released.
    """
    with _lock:
        _cache_loaded.add(vis_type)
        _cache[vis_type] = presets
    _upload_presets(vis_type, presets)
@bp.route("/<vis_type>", methods=["GET"])
def list_presets(vis_type):
    """Return all presets of ``vis_type`` as JSON; 400 for an unknown type."""
    if vis_type in VALID_TYPES:
        return jsonify(_get_presets(vis_type))
    return jsonify({"error": f"Invalid type. Must be one of: {VALID_TYPES}"}), 400
@bp.route("/<vis_type>", methods=["POST"])
def create_preset(vis_type):
    """Create a preset for ``vis_type`` from the JSON request body.

    Body fields:
        name: Required non-empty string.
        repo: Required non-empty string.
        split: Optional; defaults to ``"train"``.
        column: Optional, ``model`` type only; defaults to ``"model_responses"``.
        config: Optional, ``rlm``/``rlm-eval`` types only; defaults to
            ``"rlm_call_traces"``.

    Returns:
        The new preset with status 201, or a JSON error with status 400 when
        the type is unknown, the body is not a JSON object, or a required
        field is missing or invalid.
    """
    if vis_type not in VALID_TYPES:
        return jsonify({"error": f"Invalid type. Must be one of: {VALID_TYPES}"}), 400

    # silent=True turns an unparsable or missing body into None, so every
    # malformed request takes the same 400 path instead of crashing on
    # ``None.get`` (or on a JSON array/scalar body).
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    # Non-string values (null, numbers, ...) are treated as missing rather
    # than raising AttributeError on .strip().
    name = data.get("name", "")
    if not isinstance(name, str) or not name.strip():
        return jsonify({"error": "name is required"}), 400

    repo = data.get("repo", "")
    if not isinstance(repo, str) or not repo.strip():
        return jsonify({"error": "repo is required"}), 400

    preset = {
        "id": uuid.uuid4().hex[:8],
        "name": name.strip(),
        "repo": repo.strip(),
        "split": data.get("split", "train"),
    }

    # Include type-specific fields
    if vis_type == "model":
        preset["column"] = data.get("column", "model_responses")
    elif vis_type in ("rlm", "rlm-eval"):
        preset["config"] = data.get("config", "rlm_call_traces")

    presets = _get_presets(vis_type)
    presets.append(preset)
    _set_presets(vis_type, presets)
    return jsonify(preset), 201
@bp.route("/<vis_type>/<preset_id>", methods=["PUT"])
def update_preset(vis_type, preset_id):
    """Update selected fields of an existing preset.

    Only ``name``, ``column``, ``split`` and ``config`` are copied from the
    JSON body; any other key is ignored.

    Returns:
        The updated preset, a 400 JSON error (unknown type, body that is not
        a JSON object, or a ``name`` that is not a non-empty string), or a
        404 JSON error when no preset has the given id.
    """
    if vis_type not in VALID_TYPES:
        return jsonify({"error": f"Invalid type. Must be one of: {VALID_TYPES}"}), 400

    # silent=True turns an unparsable or missing body into None so it is
    # rejected below instead of failing on ``"name" in None``.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    # Same rule as creation: a name, when supplied, must be a non-empty string.
    if "name" in data:
        new_name = data["name"]
        if not isinstance(new_name, str) or not new_name.strip():
            return jsonify({"error": "name must be a non-empty string"}), 400

    presets = _get_presets(vis_type)
    for idx, current in enumerate(presets):
        # .get() so an entry without an "id" is skipped instead of raising.
        if current.get("id") != preset_id:
            continue

        # Edit a copy: the dicts in ``presets`` are shared with the in-memory
        # cache, which should only change through _set_presets.
        updated = dict(current)
        if "name" in data:
            updated["name"] = data["name"].strip()
        for field in ("column", "split", "config"):
            if field in data:
                updated[field] = data[field]

        presets[idx] = updated
        _set_presets(vis_type, presets)
        return jsonify(updated)

    return jsonify({"error": "not found"}), 404
@bp.route("/<vis_type>/<preset_id>", methods=["DELETE"])
def delete_preset(vis_type, preset_id):
    """Delete the preset with ``preset_id`` from ``vis_type``.

    The operation is idempotent: deleting an unknown id still answers
    ``{"status": "ok"}``.

    Returns:
        ``{"status": "ok"}``, or a 400 JSON error for an unknown type.
    """
    if vis_type not in VALID_TYPES:
        return jsonify({"error": f"Invalid type. Must be one of: {VALID_TYPES}"}), 400

    presets = _get_presets(vis_type)
    # .get() so an entry without an "id" is kept instead of raising KeyError.
    remaining = [p for p in presets if p.get("id") != preset_id]

    # Persist only when something was removed; a no-op delete should not
    # trigger a local rewrite and a HuggingFace upload.
    if len(remaining) != len(presets):
        _set_presets(vis_type, remaining)
    return jsonify({"status": "ok"})
@bp.route("/sync", methods=["POST"])
def sync_presets():
    """Drop the in-memory cache and reload every visualizer type's presets."""
    with _lock:
        _cache_loaded.clear()
        _cache.clear()
    # Reload outside the lock: _get_presets acquires _lock itself and
    # threading.Lock is not reentrant.
    for vis_type in VALID_TYPES:
        _get_presets(vis_type)
    return jsonify({"status": "ok"})
|