amirali1985 committed on
Commit
193a549
·
verified ·
1 Parent(s): baae236

Add kind/mode dropdowns to data and model catalog tabs

Browse files
Files changed (1) hide show
  1. app.py +80 -45
app.py CHANGED
@@ -2,26 +2,21 @@
2
  from __future__ import annotations
3
 
4
  import json
 
 
5
 
6
  import gradio as gr
7
  import pandas as pd
8
  from huggingface_hub import hf_hub_download
9
 
 
 
10
 
11
- DATASET_REPO = "stride-influence/stride-applications-data"
12
- MODEL_REPO = "stride-influence/stride-applications-models"
13
 
14
 
15
- def _parse_contamination_rate(path: str) -> str | None:
16
- """Extract contamination rate from a catalog path, e.g. '1pct' → '1%', '0pt5pct' → '0.5%'."""
17
- import re
18
- m = re.search(r'(\d+)pt(\d+)pct', path)
19
- if m:
20
- return f"{m.group(1)}.{m.group(2)}%"
21
- m = re.search(r'(\d+)pct', path)
22
- if m:
23
- return f"{m.group(1)}%"
24
- return None
25
 
26
 
27
  def _try_load(repo_id: str, filename: str, repo_type: str):
@@ -36,23 +31,65 @@ def _try_load(repo_id: str, filename: str, repo_type: str):
36
  return None
37
 
38
 
39
- def load_data_catalog() -> pd.DataFrame:
40
- entries = _try_load(DATASET_REPO, "data_catalog.json", "dataset") or []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  if not entries:
42
  return pd.DataFrame(
43
  columns=["path", "kind", "version", "n_examples", "n_tokens", "seed", "status", "description"]
44
  )
45
  df = pd.DataFrame(entries)
46
- df["contamination_rate"] = df["path"].apply(_parse_contamination_rate)
47
- df["path"] = df["path"].apply(
48
- lambda p: f'<a href="https://huggingface.co/datasets/{DATASET_REPO}/blob/main/{p}" target="_blank">{p}</a>'
49
- )
50
- cols = ["path", "kind", "contamination_rate", "version", "n_examples", "seed", "status", "description"]
51
  return df[[c for c in cols if c in df.columns]]
52
 
53
 
54
- def load_model_catalog(show_deleted: bool = False, show_smoke: bool = False) -> pd.DataFrame:
55
- entries = _try_load(MODEL_REPO, "model_catalog.json", "model") or []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  if not entries:
57
  return pd.DataFrame(
58
  columns=["name", "status", "mode", "benchmark", "contamination_rate",
@@ -60,29 +97,17 @@ def load_model_catalog(show_deleted: bool = False, show_smoke: bool = False) ->
60
  "proxy_dataset", "base_model"]
61
  )
62
  df = pd.DataFrame(entries)
63
- # Hoist nested config/metrics fields to top-level columns
64
- for nested_col, fields in [
65
- ("config", ["contamination_rate", "contamination_seed", "lr", "epochs", "base_model", "proxy_dataset"]),
66
- ("metrics", ["accuracy_overall", "accuracy_leaked", "accuracy_nonleaked"]),
67
- ]:
68
- if nested_col in df.columns:
69
- nested = df[nested_col].apply(lambda x: x if isinstance(x, dict) else {})
70
- for field in fields:
71
- if field not in df.columns:
72
- df[field] = nested.apply(lambda x: x.get(field))
73
  if not show_deleted:
74
- # Hide both status=DELETED and physically archived models (deleted/ prefix)
75
  is_deleted = (df.get("status", pd.Series(["VALID"] * len(df))) == "DELETED") | \
76
  df["name"].str.startswith("deleted/")
77
  df = df[~is_deleted]
78
  if not show_smoke:
79
  df = df[~df["name"].str.startswith("smoke/")]
80
- df["name"] = df["name"].apply(
81
- lambda n: f'<a href="https://huggingface.co/{MODEL_REPO}/tree/main/{n}" target="_blank">{n.split("/")[-1]}</a>'
82
- )
83
- cols = ["name", "status", "contamination_rate", "contamination_seed",
84
  "accuracy_overall", "accuracy_leaked", "accuracy_nonleaked",
85
- "lr", "epochs", "base_model", "proxy_dataset"]
86
  return df[[c for c in cols if c in df.columns]]
87
 
88
 
@@ -104,9 +129,14 @@ def load_queue_status():
104
  return summary, df
105
 
106
 
107
- def refresh_all(show_deleted: bool, show_smoke: bool):
108
  summary, queue_df = load_queue_status()
109
- return load_data_catalog(), load_model_catalog(show_deleted, show_smoke), summary, queue_df
 
 
 
 
 
110
 
111
 
112
  with gr.Blocks(title="STRIDE Applications") as demo:
@@ -124,20 +154,25 @@ with gr.Blocks(title="STRIDE Applications") as demo:
124
  show_smoke = gr.Checkbox(label="Show smoke-test models", value=False)
125
 
126
  with gr.Tab("Data catalog"):
127
- data_tbl = gr.DataFrame(interactive=False, wrap=True, datatype="html")
 
128
 
129
  with gr.Tab("Model catalog"):
130
- model_tbl = gr.DataFrame(interactive=False, wrap=True, datatype="html")
 
131
 
132
  with gr.Tab("GPU queue"):
133
  queue_md = gr.Markdown()
134
  queue_tbl = gr.DataFrame(interactive=False, wrap=True)
135
 
 
136
  outputs = [data_tbl, model_tbl, queue_md, queue_tbl]
137
- demo.load(fn=refresh_all, inputs=[show_deleted, show_smoke], outputs=outputs)
138
- refresh_btn.click(fn=refresh_all, inputs=[show_deleted, show_smoke], outputs=outputs)
139
- show_deleted.change(fn=refresh_all, inputs=[show_deleted, show_smoke], outputs=outputs)
140
- show_smoke.change(fn=refresh_all, inputs=[show_deleted, show_smoke], outputs=outputs)
 
 
141
 
142
 
143
  if __name__ == "__main__":
 
2
  from __future__ import annotations
3
 
4
  import json
5
+ import sys
6
+ from pathlib import Path
7
 
8
  import gradio as gr
9
  import pandas as pd
10
  from huggingface_hub import hf_hub_download
11
 
12
+ # Make catalog importable when the dashboard is launched from any directory.
13
+ sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
14
 
15
+ from applications.infra.catalog import DataCatalog, ModelCatalog
 
16
 
17
 
18
+ DATASET_REPO = "stride-influence/stride-applications-data"
19
+ MODEL_REPO = "stride-influence/stride-applications-models"
 
 
 
 
 
 
 
 
20
 
21
 
22
  def _try_load(repo_id: str, filename: str, repo_type: str):
 
31
  return None
32
 
33
 
34
+ DATA_KIND_OPTIONS = ["All", "benchmark_split", "contamination_manifest", "eval_config",
35
+ "eval_results", "proxy_corpus", "training_pool"]
36
+ MODEL_MODE_OPTIONS = ["All", "contaminated", "clean"]
37
+
38
+
39
def load_data_catalog(kind_filter: str = "All") -> pd.DataFrame:
    """Return the data catalog as a DataFrame, optionally filtered by kind.

    Args:
        kind_filter: Value of the ``kind`` column to keep. The sentinel
            ``"All"`` disables filtering.

    Returns:
        A DataFrame with the columns ``path``, ``kind``, ``version``,
        ``n_examples``, ``n_tokens``, ``seed``, ``status`` and
        ``description``. When the catalog cannot be fetched, or it has no
        entries, the frame is empty but still carries those columns.
    """
    import logging  # local import keeps this block self-contained

    cols = ["path", "kind", "version", "n_examples", "n_tokens", "seed", "status", "description"]

    try:
        cat = DataCatalog(repo_id=DATASET_REPO).fetch(verbose=False)
        # One row per catalog entry, restricted to the displayed attributes.
        entries = [{col: getattr(e, col) for col in cols} for e in cat.entries]
    except Exception:
        # Best-effort: the dashboard must still render when the catalog is
        # unavailable, but the failure is logged instead of being silently
        # indistinguishable from an empty catalog.
        logging.getLogger(__name__).exception("Failed to load data catalog")
        entries = []

    if not entries:
        return pd.DataFrame(columns=cols)

    df = pd.DataFrame(entries)
    if kind_filter != "All":
        df = df[df["kind"] == kind_filter]
    # Every row was built from `cols`, so all of these columns are present.
    return df[cols]
67
 
68
 
69
+ def load_model_catalog(show_deleted: bool = False, show_smoke: bool = False,
70
+ mode_filter: str = "All") -> pd.DataFrame:
71
+ try:
72
+ cat = ModelCatalog(repo_id=MODEL_REPO).fetch(verbose=False)
73
+ entries = [
74
+ {
75
+ "name": e.name,
76
+ "status": e.status,
77
+ "mode": e.mode,
78
+ "benchmark": e.benchmark,
79
+ "contamination_rate": e.contamination_rate,
80
+ "contamination_seed": e.contamination_seed,
81
+ "accuracy_overall": e.accuracy_overall,
82
+ "accuracy_leaked": e.accuracy_leaked,
83
+ "accuracy_nonleaked": e.accuracy_nonleaked,
84
+ "proxy_dataset": e.proxy_dataset,
85
+ "base_model": e.base_model,
86
+ "epochs": e._cfg("epochs"),
87
+ }
88
+ for e in cat.entries
89
+ ]
90
+ except Exception:
91
+ entries = []
92
+
93
  if not entries:
94
  return pd.DataFrame(
95
  columns=["name", "status", "mode", "benchmark", "contamination_rate",
 
97
  "proxy_dataset", "base_model"]
98
  )
99
  df = pd.DataFrame(entries)
 
 
 
 
 
 
 
 
 
 
100
  if not show_deleted:
 
101
  is_deleted = (df.get("status", pd.Series(["VALID"] * len(df))) == "DELETED") | \
102
  df["name"].str.startswith("deleted/")
103
  df = df[~is_deleted]
104
  if not show_smoke:
105
  df = df[~df["name"].str.startswith("smoke/")]
106
+ if mode_filter != "All":
107
+ df = df[df["mode"] == mode_filter]
108
+ cols = ["name", "status", "mode", "benchmark", "contamination_rate", "contamination_seed",
 
109
  "accuracy_overall", "accuracy_leaked", "accuracy_nonleaked",
110
+ "proxy_dataset", "base_model", "epochs"]
111
  return df[[c for c in cols if c in df.columns]]
112
 
113
 
 
129
  return summary, df
130
 
131
 
132
def refresh_all(show_deleted: bool, show_smoke: bool, kind_filter: str, mode_filter: str):
    """Reload the queue status and both catalogs for the dashboard.

    Args:
        show_deleted: Forwarded to ``load_model_catalog``.
        show_smoke: Forwarded to ``load_model_catalog``.
        kind_filter: Forwarded to ``load_data_catalog``.
        mode_filter: Forwarded to ``load_model_catalog``.

    Returns:
        A 4-tuple ``(data catalog, model catalog, queue summary, queue table)``,
        in the same order as the ``outputs`` list the event handlers use.
    """
    # The queue is queried first, then the data catalog, then the model catalog.
    queue_summary, queue_frame = load_queue_status()
    data_frame = load_data_catalog(kind_filter)
    model_frame = load_model_catalog(show_deleted, show_smoke, mode_filter)
    return data_frame, model_frame, queue_summary, queue_frame
140
 
141
 
142
  with gr.Blocks(title="STRIDE Applications") as demo:
 
154
  show_smoke = gr.Checkbox(label="Show smoke-test models", value=False)
155
 
156
  with gr.Tab("Data catalog"):
157
+ kind_filter = gr.Dropdown(choices=DATA_KIND_OPTIONS, value="All", label="Kind")
158
+ data_tbl = gr.DataFrame(interactive=False, wrap=True)
159
 
160
  with gr.Tab("Model catalog"):
161
+ mode_filter = gr.Dropdown(choices=MODEL_MODE_OPTIONS, value="All", label="Mode")
162
+ model_tbl = gr.DataFrame(interactive=False, wrap=True)
163
 
164
  with gr.Tab("GPU queue"):
165
  queue_md = gr.Markdown()
166
  queue_tbl = gr.DataFrame(interactive=False, wrap=True)
167
 
168
+ inputs = [show_deleted, show_smoke, kind_filter, mode_filter]
169
  outputs = [data_tbl, model_tbl, queue_md, queue_tbl]
170
+ demo.load(fn=refresh_all, inputs=inputs, outputs=outputs)
171
+ refresh_btn.click(fn=refresh_all, inputs=inputs, outputs=outputs)
172
+ show_deleted.change(fn=refresh_all, inputs=inputs, outputs=outputs)
173
+ show_smoke.change(fn=refresh_all, inputs=inputs, outputs=outputs)
174
+ kind_filter.change(fn=refresh_all, inputs=inputs, outputs=outputs)
175
+ mode_filter.change(fn=refresh_all, inputs=inputs, outputs=outputs)
176
 
177
 
178
  if __name__ == "__main__":