Add kind/mode dropdowns to data and model catalog tabs
Browse files
app.py
CHANGED
|
@@ -2,26 +2,21 @@
|
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
import json
|
|
|
|
|
|
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
import pandas as pd
|
| 8 |
from huggingface_hub import hf_hub_download
|
| 9 |
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
MODEL_REPO = "stride-influence/stride-applications-models"
|
| 13 |
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
import re
|
| 18 |
-
m = re.search(r'(\d+)pt(\d+)pct', path)
|
| 19 |
-
if m:
|
| 20 |
-
return f"{m.group(1)}.{m.group(2)}%"
|
| 21 |
-
m = re.search(r'(\d+)pct', path)
|
| 22 |
-
if m:
|
| 23 |
-
return f"{m.group(1)}%"
|
| 24 |
-
return None
|
| 25 |
|
| 26 |
|
| 27 |
def _try_load(repo_id: str, filename: str, repo_type: str):
|
|
@@ -36,23 +31,65 @@ def _try_load(repo_id: str, filename: str, repo_type: str):
|
|
| 36 |
return None
|
| 37 |
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
if not entries:
|
| 42 |
return pd.DataFrame(
|
| 43 |
columns=["path", "kind", "version", "n_examples", "n_tokens", "seed", "status", "description"]
|
| 44 |
)
|
| 45 |
df = pd.DataFrame(entries)
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
)
|
| 50 |
-
cols = ["path", "kind", "contamination_rate", "version", "n_examples", "seed", "status", "description"]
|
| 51 |
return df[[c for c in cols if c in df.columns]]
|
| 52 |
|
| 53 |
|
| 54 |
-
def load_model_catalog(show_deleted: bool = False, show_smoke: bool = False
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
if not entries:
|
| 57 |
return pd.DataFrame(
|
| 58 |
columns=["name", "status", "mode", "benchmark", "contamination_rate",
|
|
@@ -60,29 +97,17 @@ def load_model_catalog(show_deleted: bool = False, show_smoke: bool = False) ->
|
|
| 60 |
"proxy_dataset", "base_model"]
|
| 61 |
)
|
| 62 |
df = pd.DataFrame(entries)
|
| 63 |
-
# Hoist nested config/metrics fields to top-level columns
|
| 64 |
-
for nested_col, fields in [
|
| 65 |
-
("config", ["contamination_rate", "contamination_seed", "lr", "epochs", "base_model", "proxy_dataset"]),
|
| 66 |
-
("metrics", ["accuracy_overall", "accuracy_leaked", "accuracy_nonleaked"]),
|
| 67 |
-
]:
|
| 68 |
-
if nested_col in df.columns:
|
| 69 |
-
nested = df[nested_col].apply(lambda x: x if isinstance(x, dict) else {})
|
| 70 |
-
for field in fields:
|
| 71 |
-
if field not in df.columns:
|
| 72 |
-
df[field] = nested.apply(lambda x: x.get(field))
|
| 73 |
if not show_deleted:
|
| 74 |
-
# Hide both status=DELETED and physically archived models (deleted/ prefix)
|
| 75 |
is_deleted = (df.get("status", pd.Series(["VALID"] * len(df))) == "DELETED") | \
|
| 76 |
df["name"].str.startswith("deleted/")
|
| 77 |
df = df[~is_deleted]
|
| 78 |
if not show_smoke:
|
| 79 |
df = df[~df["name"].str.startswith("smoke/")]
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
cols = ["name", "status", "contamination_rate", "contamination_seed",
|
| 84 |
"accuracy_overall", "accuracy_leaked", "accuracy_nonleaked",
|
| 85 |
-
"
|
| 86 |
return df[[c for c in cols if c in df.columns]]
|
| 87 |
|
| 88 |
|
|
@@ -104,9 +129,14 @@ def load_queue_status():
|
|
| 104 |
return summary, df
|
| 105 |
|
| 106 |
|
| 107 |
-
def refresh_all(show_deleted: bool, show_smoke: bool):
|
| 108 |
summary, queue_df = load_queue_status()
|
| 109 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
|
| 112 |
with gr.Blocks(title="STRIDE Applications") as demo:
|
|
@@ -124,20 +154,25 @@ with gr.Blocks(title="STRIDE Applications") as demo:
|
|
| 124 |
show_smoke = gr.Checkbox(label="Show smoke-test models", value=False)
|
| 125 |
|
| 126 |
with gr.Tab("Data catalog"):
|
| 127 |
-
|
|
|
|
| 128 |
|
| 129 |
with gr.Tab("Model catalog"):
|
| 130 |
-
|
|
|
|
| 131 |
|
| 132 |
with gr.Tab("GPU queue"):
|
| 133 |
queue_md = gr.Markdown()
|
| 134 |
queue_tbl = gr.DataFrame(interactive=False, wrap=True)
|
| 135 |
|
|
|
|
| 136 |
outputs = [data_tbl, model_tbl, queue_md, queue_tbl]
|
| 137 |
-
demo.load(fn=refresh_all, inputs=
|
| 138 |
-
refresh_btn.click(fn=refresh_all, inputs=
|
| 139 |
-
show_deleted.change(fn=refresh_all, inputs=
|
| 140 |
-
show_smoke.change(fn=refresh_all, inputs=
|
|
|
|
|
|
|
| 141 |
|
| 142 |
|
| 143 |
if __name__ == "__main__":
|
|
|
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
import json
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
|
| 8 |
import gradio as gr
|
| 9 |
import pandas as pd
|
| 10 |
from huggingface_hub import hf_hub_download
|
| 11 |
|
| 12 |
+
# Make catalog importable when the dashboard is launched from any directory.
|
| 13 |
+
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
| 14 |
|
| 15 |
+
from applications.infra.catalog import DataCatalog, ModelCatalog
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
+
DATASET_REPO = "stride-influence/stride-applications-data"
|
| 19 |
+
MODEL_REPO = "stride-influence/stride-applications-models"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
def _try_load(repo_id: str, filename: str, repo_type: str):
|
|
|
|
| 31 |
return None
|
| 32 |
|
| 33 |
|
| 34 |
+
# Choices offered by the "Kind" dropdown on the data catalog tab.
# "All" is a sentinel meaning "do not filter"; load_data_catalog compares its
# kind_filter argument against it before filtering on the "kind" column.
DATA_KIND_OPTIONS = ["All", "benchmark_split", "contamination_manifest", "eval_config",
|
| 35 |
+
"eval_results", "proxy_corpus", "training_pool"]
|
| 36 |
+
# Choices offered by the "Mode" dropdown on the model catalog tab.
# "All" is likewise the "no filter" sentinel checked by load_model_catalog.
MODEL_MODE_OPTIONS = ["All", "contaminated", "clean"]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def load_data_catalog(kind_filter: str = "All") -> pd.DataFrame:
    """Return the data catalog as a table, optionally restricted to one kind.

    The catalog is fetched through ``DataCatalog(repo_id=DATASET_REPO)`` and
    each entry is flattened into one row.

    Args:
        kind_filter: Value of the "kind" column to keep. ``"All"`` (the
            default) keeps every row; ``None`` is treated the same way so a
            cleared dropdown does not blank the table.

    Returns:
        A DataFrame with the columns ``path``, ``kind``, ``version``,
        ``n_examples``, ``n_tokens``, ``seed``, ``status`` and
        ``description``. The frame is empty (but still has these columns)
        when the catalog has no entries or could not be loaded.
    """
    import logging

    columns = ["path", "kind", "version", "n_examples", "n_tokens", "seed", "status", "description"]

    try:
        cat = DataCatalog(repo_id=DATASET_REPO).fetch(verbose=False)
        entries = [{col: getattr(e, col) for col in columns} for e in cat.entries]
    except Exception:
        # Deliberate best-effort: the dashboard must still render when the
        # catalog is unreachable, but the failure is logged instead of being
        # silently discarded.
        logging.getLogger(__name__).exception(
            "Could not load the data catalog; showing an empty table"
        )
        entries = []

    if not entries:
        return pd.DataFrame(columns=columns)

    # Passing ``columns`` fixes both the set and the order of the columns.
    df = pd.DataFrame(entries, columns=columns)
    if kind_filter is not None and kind_filter != "All":
        df = df[df["kind"] == kind_filter]
    return df
|
| 67 |
|
| 68 |
|
| 69 |
+
def load_model_catalog(show_deleted: bool = False, show_smoke: bool = False,
|
| 70 |
+
mode_filter: str = "All") -> pd.DataFrame:
|
| 71 |
+
try:
|
| 72 |
+
cat = ModelCatalog(repo_id=MODEL_REPO).fetch(verbose=False)
|
| 73 |
+
entries = [
|
| 74 |
+
{
|
| 75 |
+
"name": e.name,
|
| 76 |
+
"status": e.status,
|
| 77 |
+
"mode": e.mode,
|
| 78 |
+
"benchmark": e.benchmark,
|
| 79 |
+
"contamination_rate": e.contamination_rate,
|
| 80 |
+
"contamination_seed": e.contamination_seed,
|
| 81 |
+
"accuracy_overall": e.accuracy_overall,
|
| 82 |
+
"accuracy_leaked": e.accuracy_leaked,
|
| 83 |
+
"accuracy_nonleaked": e.accuracy_nonleaked,
|
| 84 |
+
"proxy_dataset": e.proxy_dataset,
|
| 85 |
+
"base_model": e.base_model,
|
| 86 |
+
"epochs": e._cfg("epochs"),
|
| 87 |
+
}
|
| 88 |
+
for e in cat.entries
|
| 89 |
+
]
|
| 90 |
+
except Exception:
|
| 91 |
+
entries = []
|
| 92 |
+
|
| 93 |
if not entries:
|
| 94 |
return pd.DataFrame(
|
| 95 |
columns=["name", "status", "mode", "benchmark", "contamination_rate",
|
|
|
|
| 97 |
"proxy_dataset", "base_model"]
|
| 98 |
)
|
| 99 |
df = pd.DataFrame(entries)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
if not show_deleted:
|
|
|
|
| 101 |
is_deleted = (df.get("status", pd.Series(["VALID"] * len(df))) == "DELETED") | \
|
| 102 |
df["name"].str.startswith("deleted/")
|
| 103 |
df = df[~is_deleted]
|
| 104 |
if not show_smoke:
|
| 105 |
df = df[~df["name"].str.startswith("smoke/")]
|
| 106 |
+
if mode_filter != "All":
|
| 107 |
+
df = df[df["mode"] == mode_filter]
|
| 108 |
+
cols = ["name", "status", "mode", "benchmark", "contamination_rate", "contamination_seed",
|
|
|
|
| 109 |
"accuracy_overall", "accuracy_leaked", "accuracy_nonleaked",
|
| 110 |
+
"proxy_dataset", "base_model", "epochs"]
|
| 111 |
return df[[c for c in cols if c in df.columns]]
|
| 112 |
|
| 113 |
|
|
|
|
| 129 |
return summary, df
|
| 130 |
|
| 131 |
|
| 132 |
+
def refresh_all(show_deleted: bool, show_smoke: bool, kind_filter: str, mode_filter: str):
    """Recompute every dashboard output from the current filter settings.

    Args:
        show_deleted: Forwarded to ``load_model_catalog``.
        show_smoke: Forwarded to ``load_model_catalog``.
        kind_filter: Forwarded to ``load_data_catalog``.
        mode_filter: Forwarded to ``load_model_catalog``.

    Returns:
        A 4-tuple ``(data table, model table, queue summary, queue table)``.
    """
    # The queue status is loaded first, then the two catalogs.
    queue_summary, queue_table = load_queue_status()
    data_table = load_data_catalog(kind_filter)
    model_table = load_model_catalog(show_deleted, show_smoke, mode_filter)
    return data_table, model_table, queue_summary, queue_table
|
| 140 |
|
| 141 |
|
| 142 |
with gr.Blocks(title="STRIDE Applications") as demo:
|
|
|
|
| 154 |
show_smoke = gr.Checkbox(label="Show smoke-test models", value=False)
|
| 155 |
|
| 156 |
with gr.Tab("Data catalog"):
|
| 157 |
+
kind_filter = gr.Dropdown(choices=DATA_KIND_OPTIONS, value="All", label="Kind")
|
| 158 |
+
data_tbl = gr.DataFrame(interactive=False, wrap=True)
|
| 159 |
|
| 160 |
with gr.Tab("Model catalog"):
|
| 161 |
+
mode_filter = gr.Dropdown(choices=MODEL_MODE_OPTIONS, value="All", label="Mode")
|
| 162 |
+
model_tbl = gr.DataFrame(interactive=False, wrap=True)
|
| 163 |
|
| 164 |
with gr.Tab("GPU queue"):
|
| 165 |
queue_md = gr.Markdown()
|
| 166 |
queue_tbl = gr.DataFrame(interactive=False, wrap=True)
|
| 167 |
|
| 168 |
+
inputs = [show_deleted, show_smoke, kind_filter, mode_filter]
|
| 169 |
outputs = [data_tbl, model_tbl, queue_md, queue_tbl]
|
| 170 |
+
demo.load(fn=refresh_all, inputs=inputs, outputs=outputs)
|
| 171 |
+
refresh_btn.click(fn=refresh_all, inputs=inputs, outputs=outputs)
|
| 172 |
+
show_deleted.change(fn=refresh_all, inputs=inputs, outputs=outputs)
|
| 173 |
+
show_smoke.change(fn=refresh_all, inputs=inputs, outputs=outputs)
|
| 174 |
+
kind_filter.change(fn=refresh_all, inputs=inputs, outputs=outputs)
|
| 175 |
+
mode_filter.change(fn=refresh_all, inputs=inputs, outputs=outputs)
|
| 176 |
|
| 177 |
|
| 178 |
if __name__ == "__main__":
|