| """STRIDE Applications dashboard — reads HF-hosted catalogs + queue status.""" |
| from __future__ import annotations |
|
|
import html
import json
import logging
from pathlib import Path
from urllib.parse import quote


import gradio as gr
import pandas as pd
from huggingface_hub import hf_hub_download, snapshot_download
|
|
|
|
# Hugging Face repos the dashboard reads from: one dataset repo, one model repo.
DATASET_REPO = "stride-influence/stride-applications-data"
MODEL_REPO = "stride-influence/stride-applications-models"


# Dropdown choices for the two catalog tabs. "All" is the sentinel value the
# loader functions compare against to mean "do not filter".
DATA_KIND_OPTIONS = [
    "All",
    "benchmark_split",
    "contamination_manifest",
    "eval_config",
    "eval_results",
    "proxy_corpus",
    "training_pool",
]
MODEL_MODE_OPTIONS = [
    "All",
    "contaminated",
    "clean",
]
|
|
|
|
def _fetch_catalog_entries(repo_id: str, repo_type: str, folder: str) -> list[dict]:
    """Download and parse every per-entry JSON file under ``folder`` in an HF repo.

    Uses HF's native revision-based cache (no local_dir) so deleted files are
    automatically excluded — each HF commit gets its own snapshot directory.

    Args:
        repo_id: Hugging Face repo identifier.
        repo_type: Repo type passed through to ``snapshot_download``.
        folder: Top-level folder in the repo holding one ``*.json`` per entry.

    Returns:
        The parsed entries, in file-name order. Files that cannot be read or
        parsed, and files whose top-level JSON value is not an object, are
        skipped. Any failure of the download itself yields an empty list; this
        function never raises.
    """
    log = logging.getLogger(__name__)
    try:
        snapshot_dir = snapshot_download(
            repo_id=repo_id,
            repo_type=repo_type,
            allow_patterns=[f"{folder}/*.json"],
        )
        entries: list[dict] = []
        # Sorted so the result does not depend on filesystem enumeration order.
        for entry_file in sorted(Path(snapshot_dir).glob(f"{folder}/*.json")):
            try:
                entry = json.loads(entry_file.read_text(encoding="utf-8"))
            except (OSError, ValueError) as exc:
                # ValueError covers both JSONDecodeError and UnicodeDecodeError.
                log.warning("Skipping unreadable catalog entry %s: %s", entry_file.name, exc)
                continue
            if not isinstance(entry, dict):
                # Callers use ``entry.get(...)``, so only JSON objects are usable.
                log.warning("Skipping non-object catalog entry %s", entry_file.name)
                continue
            entries.append(entry)
        return entries
    except Exception as exc:
        # Deliberate best-effort boundary: the dashboard should render an empty
        # table rather than crash when the repo is unreachable or missing.
        log.warning("Could not fetch %s from %s: %s", folder, repo_id, exc)
        return []
|
|
|
|
def _try_load(repo_id: str, filename: str, repo_type: str):
    """Fetch one JSON file from an HF repo and return its parsed content.

    The file is requested with ``force_download=True`` so a cached copy is not
    reused.

    Args:
        repo_id: Hugging Face repo identifier.
        filename: Path of the JSON file inside the repo.
        repo_type: Repo type passed through to ``hf_hub_download``.

    Returns:
        The decoded JSON value, or ``None`` if the download, read, or parse
        fails for any reason. This function never raises.
    """
    try:
        local_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type=repo_type,
            force_download=True,
        )
        # Explicit encoding: JSON is UTF-8 regardless of the host locale.
        with open(local_path, encoding="utf-8") as fh:
            return json.load(fh)
    except Exception as exc:
        # Best-effort: a missing file is an expected state, so log and move on.
        logging.getLogger(__name__).warning(
            "Could not load %s from %s: %s", filename, repo_id, exc
        )
        return None
|
|
|
|
def _hf_dataset_link(path: str) -> str:
    """Return an HTML anchor pointing at ``path`` in the dataset repo.

    The link text is the last ``/``-separated component of ``path`` (the whole
    path when it contains no slash). ``path`` comes from catalog JSON, so it is
    percent-encoded for the URL and HTML-escaped before being embedded in the
    markup; paths made only of URL-safe characters produce the same output as
    plain interpolation.
    """
    url = f"https://huggingface.co/datasets/{DATASET_REPO}/blob/main/{quote(path)}"
    short = path.rsplit("/", 1)[-1]
    return f'<a href="{html.escape(url, quote=True)}" target="_blank">{html.escape(short)}</a>'
|
|
|
|
def _hf_model_link(name: str) -> str:
    """Return an HTML anchor pointing at the ``name`` folder in the model repo.

    ``name`` comes from catalog JSON, so it is percent-encoded for the URL and
    HTML-escaped for both the ``href`` and the link text; names made only of
    URL-safe characters produce the same output as plain interpolation.
    """
    url = f"https://huggingface.co/{MODEL_REPO}/tree/main/{quote(name)}"
    return f'<a href="{html.escape(url, quote=True)}" target="_blank">{html.escape(name)}</a>'
|
|
|
|
def load_data_catalog(kind_filter: str = "All") -> pd.DataFrame:
    """Build the data-catalog table from the per-entry JSON files on HF.

    Args:
        kind_filter: Keep only rows whose ``kind`` equals this value; the
            sentinel ``"All"`` disables filtering.

    Returns:
        A frame with the columns ``file`` (HTML link), ``kind``,
        ``n_examples``, ``n_tokens``, ``seed``, ``status`` and
        ``description``, sorted by ``file``. The same seven columns are
        returned when the catalog is empty, so the table layout does not
        change between the empty and populated states.
    """
    display_cols = ["file", "kind", "n_examples", "n_tokens", "seed", "status", "description"]
    entries = _fetch_catalog_entries(DATASET_REPO, "dataset", "data_catalog")
    rows = [
        {
            # ``or ""`` also covers an explicit JSON null; str() guards the
            # link helper against a non-string path.
            "file": _hf_dataset_link(str(e.get("path") or "")),
            "kind": e.get("kind", ""),
            "n_examples": e.get("n_examples"),
            "n_tokens": e.get("n_tokens"),
            "seed": e.get("seed"),
            "status": e.get("status", ""),
            "description": e.get("description", ""),
        }
        for e in entries
        if isinstance(e, dict)
    ]
    if not rows:
        return pd.DataFrame(columns=display_cols)
    df = pd.DataFrame(rows, columns=display_cols)
    if kind_filter != "All":
        df = df[df["kind"] == kind_filter]
    return df.sort_values("file").reset_index(drop=True)
|
|
|
|
def _first_not_none(*candidates):
    """Return the first candidate that is not ``None`` (``None`` if all are)."""
    return next((c for c in candidates if c is not None), None)


def _dict_or_empty(value) -> dict:
    """Return ``value`` if it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def load_model_catalog(show_deleted: bool = False, show_smoke: bool = False,
                       mode_filter: str = "All") -> pd.DataFrame:
    """Build the model-catalog table from the per-entry JSON files on HF.

    Args:
        show_deleted: When false, drop rows whose status is ``"DELETED"`` or
            whose name starts with ``"deleted/"``.
        show_smoke: When false, drop rows whose name starts with ``"smoke/"``.
        mode_filter: Keep only rows whose ``mode`` equals this value; the
            sentinel ``"All"`` disables filtering.

    Returns:
        A frame with the columns ``model`` (HTML link), ``status``, ``mode``,
        ``contamination_rate``, ``contamination_seed``, ``accuracy_leaked``,
        ``accuracy_nonleaked``, ``lr``, ``epochs``, ``base_model`` and
        ``proxy_dataset``, sorted by ``model``. The same columns are returned
        when the catalog is empty.
    """
    display_cols = ["model", "status", "mode", "contamination_rate", "contamination_seed",
                    "accuracy_leaked", "accuracy_nonleaked", "lr", "epochs", "base_model",
                    "proxy_dataset"]
    entries = _fetch_catalog_entries(MODEL_REPO, "model", "model_catalog")

    rows = []
    for e in entries:
        if not isinstance(e, dict):
            continue
        # An explicit JSON null (or any non-object) must not break ``.get``.
        cfg = _dict_or_empty(e.get("config"))
        mtr = _dict_or_empty(e.get("metrics"))
        # Always a string so the ``.str.startswith`` filters below are safe.
        name = str(e.get("name") or "")
        rows.append({
            "model": _hf_model_link(name),
            "name": name,
            "status": e.get("status", "VALID"),
            "mode": cfg.get("mode", ""),
            "contamination_rate": cfg.get("contamination_rate"),
            "contamination_seed": cfg.get("contamination_seed"),
            # Explicit None checks: an accuracy of 0.0 is a real value and
            # must not fall through to the next source the way ``or`` would.
            "accuracy_leaked": _first_not_none(
                cfg.get("accuracy_leaked"),
                mtr.get("accuracy_leaked"),
                mtr.get("final_leaked_acc"),
            ),
            "accuracy_nonleaked": _first_not_none(
                cfg.get("accuracy_nonleaked"),
                mtr.get("accuracy_nonleaked"),
                mtr.get("final_nonleaked_acc"),
            ),
            "lr": cfg.get("lr"),
            "epochs": cfg.get("epochs"),
            "base_model": cfg.get("base_model", ""),
            "proxy_dataset": cfg.get("proxy_dataset", ""),
        })

    if not rows:
        return pd.DataFrame(columns=display_cols)

    df = pd.DataFrame(rows)
    if not show_deleted:
        is_deleted = (df["status"] == "DELETED") | df["name"].str.startswith("deleted/")
        df = df[~is_deleted]
    if not show_smoke:
        df = df[~df["name"].str.startswith("smoke/")]
    if mode_filter != "All":
        df = df[df["mode"] == mode_filter]
    # ``name`` is only needed for the filters above; it is not displayed.
    return df[display_cols].sort_values("model").reset_index(drop=True)
|
|
|
|
def load_queue_status():
    """Load ``queue_status.json`` from the model repo for the GPU-queue tab.

    Returns:
        A ``(summary, jobs)`` pair: ``summary`` is a Markdown string with the
        update timestamp and per-state counters, and ``jobs`` is a DataFrame
        built from the ``"jobs"`` value (empty when there are no jobs). When
        the status file is unavailable, empty, or not a JSON object, a
        placeholder message and an empty DataFrame are returned.
    """
    status = _try_load(MODEL_REPO, "queue_status.json", "model")
    # Non-dict JSON (e.g. a list) has no ``.get``; treat it like a missing file.
    if not status or not isinstance(status, dict):
        return "_No queue status available yet._", pd.DataFrame()
    summary = (
        # Two trailing spaces before the newline form a Markdown hard line
        # break; a single space would let the counters run on after the date.
        f"**Updated:** {status.get('timestamp', '?')}  \n"
        f"Total: **{status.get('total', 0)}** · "
        f"Pending: {status.get('pending', 0)} · "
        f"Running: **{status.get('running', 0)}** · "
        f"Done: {status.get('done', 0)} · "
        f"Failed: {status.get('failed', 0)} · "
        f"Stale: {status.get('stale', 0)}"
    )
    jobs = status.get("jobs", [])
    jobs_df = pd.DataFrame(jobs) if jobs else pd.DataFrame()
    return summary, jobs_df
|
|
|
|
def refresh_all(show_deleted: bool, show_smoke: bool, kind_filter: str, mode_filter: str):
    """Reload the catalog and queue views in one call.

    Returns a 4-tuple in the order
    ``(data catalog frame, model catalog frame, queue summary, queue frame)``.
    """
    # The queue is loaded first, then the two catalogs.
    queue_summary, queue_frame = load_queue_status()
    data_frame = load_data_catalog(kind_filter)
    model_frame = load_model_catalog(show_deleted, show_smoke, mode_filter)
    return data_frame, model_frame, queue_summary, queue_frame
|
|
|
|
# Path, inside the dataset repo, of the text file shown on the LaTeX tab.
LATEX_FILE = "paper/latex_snippets.txt"


def load_latex() -> str:
    """Fetch paper/latex_snippets.txt from the HF dataset repo (always fresh).

    The download is forced on every call so a cached copy is not reused, and
    the file is decoded as UTF-8 regardless of the host locale.

    Returns:
        The file's text, or a ``"(no content yet — …)"`` placeholder carrying
        the error message when the download or read fails.
    """
    try:
        path = hf_hub_download(
            repo_id=DATASET_REPO,
            filename=LATEX_FILE,
            repo_type="dataset",
            force_download=True,
        )
        return Path(path).read_text(encoding="utf-8")
    except Exception as e:
        # Best-effort: show the reason in the textbox instead of raising.
        return f"(no content yet — {e})"
|
|
|
|
# Dashboard layout: a header, a row of global controls, and four tabs
# (data catalog, model catalog, GPU queue, LaTeX snippets), followed by the
# event wiring that keeps the tables in sync with the controls.
with gr.Blocks(title="STRIDE Applications") as demo:
    _header_md = (
        "# STRIDE Applications — status dashboard\n"
        "Live view of the data catalog, model catalog, and GPU queue for the "
        "STRIDE training-data attribution + benchmark-leakage experiments.\n\n"
        f"Data repo: [`{DATASET_REPO}`](https://huggingface.co/datasets/{DATASET_REPO}) · "
        f"Model repo: [`{MODEL_REPO}`](https://huggingface.co/{MODEL_REPO})"
    )
    gr.Markdown(_header_md)

    # Global controls shared by all catalog views.
    with gr.Row():
        refresh_btn = gr.Button("Refresh", variant="primary")
        show_deleted = gr.Checkbox(label="Show deleted models", value=False)
        show_smoke = gr.Checkbox(label="Show smoke-test models", value=False)

    with gr.Tab("Data catalog"):
        kind_filter = gr.Dropdown(choices=DATA_KIND_OPTIONS, value="All", label="Kind")
        # One datatype per displayed column; the first column holds an HTML link.
        data_tbl = gr.DataFrame(
            interactive=False,
            datatype=["html", "str"] + ["number"] * 3 + ["str", "str"],
        )

    with gr.Tab("Model catalog"):
        mode_filter = gr.Dropdown(choices=MODEL_MODE_OPTIONS, value="All", label="Mode")
        model_tbl = gr.DataFrame(
            interactive=False,
            datatype=["html", "str", "str"] + ["number"] * 6 + ["str", "str"],
        )

    with gr.Tab("GPU queue"):
        queue_md = gr.Markdown()
        queue_tbl = gr.DataFrame(interactive=False)

    with gr.Tab("Paper / LaTeX"):
        _latex_md = (
            "LaTeX snippets for copy-pasting into the paper. "
            f"Backed by [`{DATASET_REPO}/{LATEX_FILE}`]"
            f"(https://huggingface.co/datasets/{DATASET_REPO}/blob/main/{LATEX_FILE}) — "
            "push a new version there to update this view."
        )
        gr.Markdown(_latex_md)
        latex_refresh_btn = gr.Button("Refresh", variant="secondary")
        latex_box = gr.Textbox(
            label="",
            lines=40,
            max_lines=200,
            interactive=False,
            show_copy_button=True,
        )

    # refresh_all takes the four controls and fills the four catalog/queue outputs.
    inputs = [show_deleted, show_smoke, kind_filter, mode_filter]
    outputs = [data_tbl, model_tbl, queue_md, queue_tbl]

    # Populate everything on page load, then again on either Refresh button.
    demo.load(fn=refresh_all, inputs=inputs, outputs=outputs)
    demo.load(fn=load_latex, outputs=[latex_box])
    refresh_btn.click(fn=refresh_all, inputs=inputs, outputs=outputs)
    latex_refresh_btn.click(fn=load_latex, outputs=[latex_box])

    # Any change to a checkbox or dropdown re-runs the full refresh.
    for _control in (show_deleted, show_smoke, kind_filter, mode_filter):
        _control.change(fn=refresh_all, inputs=inputs, outputs=outputs)
|
|
|
|
# Script entry point: serve the dashboard on all interfaces, port 7860.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
|
|