"""STRIDE Applications dashboard — reads HF-hosted catalogs + queue status.""" from __future__ import annotations import json from pathlib import Path import gradio as gr import pandas as pd from huggingface_hub import hf_hub_download, snapshot_download DATASET_REPO = "stride-influence/stride-applications-data" MODEL_REPO = "stride-influence/stride-applications-models" DATA_KIND_OPTIONS = ["All", "benchmark_split", "contamination_manifest", "eval_config", "eval_results", "proxy_corpus", "training_pool"] MODEL_MODE_OPTIONS = ["All", "contaminated", "clean"] def _fetch_catalog_entries(repo_id: str, repo_type: str, folder: str) -> list[dict]: """Download all per-entry JSON files from data_catalog/ or model_catalog/ on HF. Uses HF's native revision-based cache (no local_dir) so deleted files are automatically excluded — each HF commit gets its own snapshot directory. """ try: snapshot_dir = snapshot_download( repo_id=repo_id, repo_type=repo_type, allow_patterns=[f"{folder}/*.json"], ) entries = [] for f in Path(snapshot_dir).glob(f"{folder}/*.json"): try: entries.append(json.loads(f.read_text())) except Exception: pass return entries except Exception: return [] def _try_load(repo_id: str, filename: str, repo_type: str): try: path = hf_hub_download( repo_id=repo_id, filename=filename, repo_type=repo_type, force_download=True, ) with open(path) as f: return json.load(f) except Exception: return None def _hf_dataset_link(path: str) -> str: url = f"https://huggingface.co/datasets/{DATASET_REPO}/blob/main/{path}" short = path.split("/")[-1] if "/" in path else path return f'{short}' def _hf_model_link(name: str) -> str: url = f"https://huggingface.co/{MODEL_REPO}/tree/main/{name}" return f'{name}' def load_data_catalog(kind_filter: str = "All") -> pd.DataFrame: entries = _fetch_catalog_entries(DATASET_REPO, "dataset", "data_catalog") if not entries: return pd.DataFrame( columns=["file", "path", "kind", "version", "n_examples", "n_tokens", "seed", "status", "description"] ) df = pd.DataFrame([{ "file": _hf_dataset_link(e.get("path", "")), "path": e.get("path", ""), "kind": e.get("kind", ""), "version": e.get("version", ""), "n_examples": e.get("n_examples"), "n_tokens": e.get("n_tokens"), "seed": e.get("seed"), "status": e.get("status", ""), "description": e.get("description", ""), } for e in entries]) if kind_filter != "All": df = df[df["kind"] == kind_filter] cols = ["file", "kind", "n_examples", "n_tokens", "seed", "status", "description"] return df[[c for c in cols if c in df.columns]].sort_values("file").reset_index(drop=True) def load_model_catalog(show_deleted: bool = False, show_smoke: bool = False, mode_filter: str = "All") -> pd.DataFrame: entries = _fetch_catalog_entries(MODEL_REPO, "model", "model_catalog") if not entries: return pd.DataFrame( columns=["name", "status", "mode", "contamination_rate", "contamination_seed", "accuracy_leaked", "accuracy_nonleaked", "lr", "epochs", "base_model", "proxy_dataset"] ) rows = [] for e in entries: cfg = e.get("config", {}) mtr = e.get("metrics", {}) rows.append({ "model": _hf_model_link(e.get("name", "")), "name": e.get("name", ""), "status": e.get("status", "VALID"), "mode": cfg.get("mode", ""), "contamination_rate": cfg.get("contamination_rate"), "contamination_seed": cfg.get("contamination_seed"), "accuracy_leaked": cfg.get("accuracy_leaked") or mtr.get("accuracy_leaked") or mtr.get("final_leaked_acc"), "accuracy_nonleaked": cfg.get("accuracy_nonleaked") or mtr.get("accuracy_nonleaked") or mtr.get("final_nonleaked_acc"), "lr": cfg.get("lr"), "epochs": cfg.get("epochs"), "base_model": cfg.get("base_model", ""), "proxy_dataset": cfg.get("proxy_dataset", ""), }) df = pd.DataFrame(rows) if not show_deleted: is_deleted = (df["status"] == "DELETED") | df["name"].str.startswith("deleted/") df = df[~is_deleted] if not show_smoke: df = df[~df["name"].str.startswith("smoke/")] if mode_filter != "All": df = df[df["mode"] == mode_filter] cols = ["model", "status", "mode", "contamination_rate", "contamination_seed", "accuracy_leaked", "accuracy_nonleaked", "lr", "epochs", "base_model", "proxy_dataset"] return df[[c for c in cols if c in df.columns]].sort_values("model").reset_index(drop=True) def load_queue_status(): status = _try_load(MODEL_REPO, "queue_status.json", "model") if not status: return "_No queue status available yet._", pd.DataFrame() summary = ( f"**Updated:** {status.get('timestamp', '?')} \n" f"Total: **{status.get('total', 0)}** · " f"Pending: {status.get('pending', 0)} · " f"Running: **{status.get('running', 0)}** · " f"Done: {status.get('done', 0)} · " f"Failed: {status.get('failed', 0)} · " f"Stale: {status.get('stale', 0)}" ) jobs = status.get("jobs", []) df = pd.DataFrame(jobs) if jobs else pd.DataFrame() return summary, df def refresh_all(show_deleted: bool, show_smoke: bool, kind_filter: str, mode_filter: str): summary, queue_df = load_queue_status() return ( load_data_catalog(kind_filter), load_model_catalog(show_deleted, show_smoke, mode_filter), summary, queue_df, ) LATEX_FILE = "paper/latex_snippets.txt" def load_latex() -> str: """Fetch paper/latex_snippets.txt from the HF dataset repo (always fresh).""" try: path = hf_hub_download( repo_id=DATASET_REPO, filename=LATEX_FILE, repo_type="dataset", force_download=True, ) return Path(path).read_text() except Exception as e: return f"(no content yet — {e})" with gr.Blocks(title="STRIDE Applications") as demo: gr.Markdown( "# STRIDE Applications — status dashboard\n" "Live view of the data catalog, model catalog, and GPU queue for the " "STRIDE training-data attribution + benchmark-leakage experiments.\n\n" f"Data repo: [`{DATASET_REPO}`](https://huggingface.co/datasets/{DATASET_REPO}) · " f"Model repo: [`{MODEL_REPO}`](https://huggingface.co/{MODEL_REPO})" ) with gr.Row(): refresh_btn = gr.Button("Refresh", variant="primary") show_deleted = gr.Checkbox(label="Show deleted models", value=False) show_smoke = gr.Checkbox(label="Show smoke-test models", value=False) with gr.Tab("Data catalog"): kind_filter = gr.Dropdown(choices=DATA_KIND_OPTIONS, value="All", label="Kind") data_tbl = gr.DataFrame(interactive=False, datatype=["html", "str", "number", "number", "number", "str", "str"]) with gr.Tab("Model catalog"): mode_filter = gr.Dropdown(choices=MODEL_MODE_OPTIONS, value="All", label="Mode") model_tbl = gr.DataFrame(interactive=False, datatype=["html", "str", "str", "number", "number", "number", "number", "number", "number", "str", "str"]) with gr.Tab("GPU queue"): queue_md = gr.Markdown() queue_tbl = gr.DataFrame(interactive=False) with gr.Tab("Paper / LaTeX"): gr.Markdown( "LaTeX snippets for copy-pasting into the paper. " f"Backed by [`{DATASET_REPO}/{LATEX_FILE}`]" f"(https://huggingface.co/datasets/{DATASET_REPO}/blob/main/{LATEX_FILE}) — " "push a new version there to update this view." ) latex_refresh_btn = gr.Button("Refresh", variant="secondary") latex_box = gr.Textbox( label="", lines=40, max_lines=200, interactive=False, show_copy_button=True, ) inputs = [show_deleted, show_smoke, kind_filter, mode_filter] outputs = [data_tbl, model_tbl, queue_md, queue_tbl] demo.load(fn=refresh_all, inputs=inputs, outputs=outputs) demo.load(fn=load_latex, outputs=[latex_box]) refresh_btn.click(fn=refresh_all, inputs=inputs, outputs=outputs) latex_refresh_btn.click(fn=load_latex, outputs=[latex_box]) show_deleted.change(fn=refresh_all, inputs=inputs, outputs=outputs) show_smoke.change(fn=refresh_all, inputs=inputs, outputs=outputs) kind_filter.change(fn=refresh_all, inputs=inputs, outputs=outputs) mode_filter.change(fn=refresh_all, inputs=inputs, outputs=outputs) if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860)