File size: 9,286 Bytes
58be37b
 
 
 
193a549
58be37b
 
 
d4178aa
58be37b
 
d4178aa
 
58be37b
d4178aa
 
 
58be37b
d4178aa
 
2fe7eec
 
 
 
 
d4178aa
2fe7eec
d4178aa
 
 
 
 
2fe7eec
d4178aa
 
 
 
 
 
 
7cf19ee
 
58be37b
 
 
 
 
 
 
 
 
 
 
 
8ced1c3
 
 
 
 
 
 
 
 
 
 
ac91a98
 
58be37b
 
8ced1c3
58be37b
d4178aa
8ced1c3
d4178aa
 
 
 
 
 
 
 
 
193a549
 
8ced1c3
 
d4178aa
 
193a549
ac91a98
 
58be37b
 
d4178aa
 
58be37b
d4178aa
 
 
 
 
 
8ced1c3
d4178aa
 
 
 
 
 
 
 
 
 
 
 
 
 
e53ab8d
d4178aa
e53ab8d
 
 
193a549
 
8ced1c3
d4178aa
8ced1c3
58be37b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fe7eec
ac91a98
193a549
ac91a98
 
193a549
ac91a98
193a549
58be37b
 
c184220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58be37b
 
 
 
 
 
 
 
 
e53ab8d
 
 
d4178aa
58be37b
 
193a549
8ced1c3
58be37b
 
193a549
8ced1c3
58be37b
 
d4178aa
ac91a98
 
c184220
 
 
 
 
 
 
 
2efaa1b
c184220
 
2efaa1b
c184220
2efaa1b
c184220
 
ac91a98
 
 
 
2efaa1b
ac91a98
2efaa1b
ac91a98
 
 
 
58be37b
 
 
84bcd44
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
"""STRIDE Applications dashboard — reads HF-hosted catalogs + queue status."""
from __future__ import annotations

import html
import json
from pathlib import Path
from urllib.parse import quote

import gradio as gr
import pandas as pd
from huggingface_hub import hf_hub_download, snapshot_download


# Hugging Face repos the dashboard reads from. The dataset repo holds
# data_catalog/*.json and the paper LaTeX snippets; the model repo holds
# model_catalog/*.json and queue_status.json.
DATASET_REPO = "stride-influence/stride-applications-data"
MODEL_REPO   = "stride-influence/stride-applications-models"

# Choices for the "Kind" (data) and "Mode" (model) filter dropdowns; "All"
# disables the corresponding filter.
# NOTE(review): these must match the `kind` / config `mode` values written into
# the catalog JSON by whatever produces it — confirm against the producer.
DATA_KIND_OPTIONS  = ["All", "benchmark_split", "contamination_manifest", "eval_config",
                      "eval_results", "proxy_corpus", "training_pool"]
MODEL_MODE_OPTIONS = ["All", "contaminated", "clean"]


def _fetch_catalog_entries(repo_id: str, repo_type: str, folder: str) -> list[dict]:
    """Download all per-entry JSON files from data_catalog/ or model_catalog/ on HF.

    Uses HF's native revision-based cache (no local_dir) so deleted files are
    automatically excluded — each HF commit gets its own snapshot directory.
    """
    try:
        snapshot_dir = snapshot_download(
            repo_id=repo_id,
            repo_type=repo_type,
            allow_patterns=[f"{folder}/*.json"],
        )
        entries = []
        for f in Path(snapshot_dir).glob(f"{folder}/*.json"):
            try:
                entries.append(json.loads(f.read_text()))
            except Exception:
                pass
        return entries
    except Exception:
        return []


def _try_load(repo_id: str, filename: str, repo_type: str):
    try:
        path = hf_hub_download(
            repo_id=repo_id, filename=filename,
            repo_type=repo_type, force_download=True,
        )
        with open(path) as f:
            return json.load(f)
    except Exception:
        return None


def _hf_dataset_link(path: str) -> str:
    """Return an HTML ``<a>`` tag linking to ``path`` in the dataset repo.

    The link text is the basename of ``path``. Because ``path`` comes from
    catalog JSON, it is URL-quoted in the href and HTML-escaped in the markup.
    Characters such as ``&``, ``<``, quotes or spaces therefore cannot break
    the rendered table.
    """
    url = f"https://huggingface.co/datasets/{DATASET_REPO}/blob/main/{quote(path)}"
    # rsplit on a string without "/" yields the whole string: no special case needed.
    short = path.rsplit("/", 1)[-1]
    return f'<a href="{html.escape(url)}" target="_blank">{html.escape(short)}</a>'


def _hf_model_link(name: str) -> str:
    """Return an HTML ``<a>`` tag linking to folder ``name`` in the model repo.

    ``name`` comes from catalog JSON, so it is URL-quoted in the href and
    HTML-escaped in the link text to keep the rendered table well-formed.
    """
    url = f"https://huggingface.co/{MODEL_REPO}/tree/main/{quote(name)}"
    return f'<a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a>'


def load_data_catalog(kind_filter: str = "All") -> pd.DataFrame:
    """Build the data-catalog table from ``data_catalog/*.json`` in the dataset repo.

    Args:
        kind_filter: Keep only rows whose ``kind`` equals this value;
            ``"All"`` disables the filter.

    Returns:
        A DataFrame with columns file (HTML link), kind, n_examples, n_tokens,
        seed, status and description, sorted by the entry's repo path. The
        columns are identical whether or not any entries were found, so they
        always line up with the UI table's ``datatype`` list.
    """
    cols = ["file", "kind", "n_examples", "n_tokens", "seed", "status", "description"]
    entries = [e for e in _fetch_catalog_entries(DATASET_REPO, "dataset", "data_catalog")
               if isinstance(e, dict)]
    if not entries:
        return pd.DataFrame(columns=cols)

    rows = []
    for e in entries:
        # `or ""` also covers an explicit JSON null, which .get(key, "") returns as None.
        path = str(e.get("path") or "")
        rows.append({
            "file":        _hf_dataset_link(path),
            "path":        path,  # kept only as the sort key
            "kind":        e.get("kind", ""),
            "n_examples":  e.get("n_examples"),
            "n_tokens":    e.get("n_tokens"),
            "seed":        e.get("seed"),
            "status":      e.get("status", ""),
            "description": e.get("description", ""),
        })

    df = pd.DataFrame(rows)
    if kind_filter != "All":
        df = df[df["kind"] == kind_filter]
    # Sort on the plain path rather than the rendered <a> markup.
    return df.sort_values("path", kind="stable")[cols].reset_index(drop=True)


def _first_present(*values):
    """Return the first argument that is not ``None``, or ``None`` if all are.

    Unlike an ``a or b or c`` chain, this keeps legitimate falsy values such
    as an accuracy of ``0.0`` instead of skipping past them.
    """
    return next((v for v in values if v is not None), None)


def load_model_catalog(show_deleted: bool = False, show_smoke: bool = False,
                       mode_filter: str = "All") -> pd.DataFrame:
    """Build the model-catalog table from ``model_catalog/*.json`` in the model repo.

    Args:
        show_deleted: Include entries whose status is ``"DELETED"`` or whose
            name starts with ``"deleted/"``.
        show_smoke: Include entries whose name starts with ``"smoke/"``.
        mode_filter: Keep only rows whose ``config["mode"]`` equals this value;
            ``"All"`` disables the filter.

    Returns:
        A DataFrame sorted by model name, with these columns:
        model (HTML link), status, mode, contamination_rate,
        contamination_seed, accuracy_leaked, accuracy_nonleaked, lr, epochs,
        base_model and proxy_dataset. The columns are the same even when no
        entries exist. Accuracies come from ``config`` first, then from
        ``metrics`` (``accuracy_*``, falling back to ``final_*_acc``).
    """
    cols = ["model", "status", "mode", "contamination_rate", "contamination_seed",
            "accuracy_leaked", "accuracy_nonleaked", "lr", "epochs", "base_model", "proxy_dataset"]
    entries = [e for e in _fetch_catalog_entries(MODEL_REPO, "model", "model_catalog")
               if isinstance(e, dict)]
    if not entries:
        # Same columns as the populated table so the UI's datatype list lines up.
        return pd.DataFrame(columns=cols)

    rows = []
    for e in entries:
        # `or {}` / `or ""` also cover explicit JSON nulls.
        cfg = e.get("config") or {}
        mtr = e.get("metrics") or {}
        name = str(e.get("name") or "")
        rows.append({
            "model":              _hf_model_link(name),
            "name":               name,  # plain name for prefix filters and sorting
            "status":             e.get("status", "VALID"),
            "mode":               cfg.get("mode", ""),
            "contamination_rate": cfg.get("contamination_rate"),
            "contamination_seed": cfg.get("contamination_seed"),
            "accuracy_leaked":    _first_present(cfg.get("accuracy_leaked"),
                                                 mtr.get("accuracy_leaked"),
                                                 mtr.get("final_leaked_acc")),
            "accuracy_nonleaked": _first_present(cfg.get("accuracy_nonleaked"),
                                                 mtr.get("accuracy_nonleaked"),
                                                 mtr.get("final_nonleaked_acc")),
            "lr":                 cfg.get("lr"),
            "epochs":             cfg.get("epochs"),
            "base_model":         cfg.get("base_model", ""),
            "proxy_dataset":      cfg.get("proxy_dataset", ""),
        })

    df = pd.DataFrame(rows)
    if not show_deleted:
        is_deleted = (df["status"] == "DELETED") | df["name"].str.startswith("deleted/")
        df = df[~is_deleted]
    if not show_smoke:
        df = df[~df["name"].str.startswith("smoke/")]
    if mode_filter != "All":
        df = df[df["mode"] == mode_filter]
    # Sort on the plain name rather than the rendered <a> markup.
    return df.sort_values("name", kind="stable")[cols].reset_index(drop=True)


def load_queue_status():
    """Load ``queue_status.json`` from the model repo for the GPU-queue tab.

    Returns:
        A ``(summary_markdown, jobs_df)`` tuple. If the file is missing,
        unreadable, empty, or not a JSON object, returns a placeholder message
        and an empty DataFrame instead of raising.
    """
    status = _try_load(MODEL_REPO, "queue_status.json", "model")
    # A non-object payload (list, string, number) has no .get(): treat it like a missing file.
    if not isinstance(status, dict) or not status:
        return "_No queue status available yet._", pd.DataFrame()
    summary = (
        f"**Updated:** {status.get('timestamp', '?')}  \n"
        f"Total: **{status.get('total', 0)}**  ·  "
        f"Pending: {status.get('pending', 0)}  ·  "
        f"Running: **{status.get('running', 0)}**  ·  "
        f"Done: {status.get('done', 0)}  ·  "
        f"Failed: {status.get('failed', 0)}  ·  "
        f"Stale: {status.get('stale', 0)}"
    )
    jobs = status.get("jobs", [])
    # Only tabulate a non-empty list; e.g. pd.DataFrame(dict_of_scalars) raises.
    df = pd.DataFrame(jobs) if isinstance(jobs, list) and jobs else pd.DataFrame()
    return summary, df


def refresh_all(show_deleted: bool, show_smoke: bool, kind_filter: str, mode_filter: str):
    """Recompute every dashboard output for the current filter settings.

    Returns:
        ``(data_df, model_df, queue_summary_md, queue_df)``, in the order the
        UI's ``outputs`` list expects.
    """
    queue_summary, queue_jobs = load_queue_status()
    data_df = load_data_catalog(kind_filter)
    model_df = load_model_catalog(show_deleted, show_smoke, mode_filter)
    return data_df, model_df, queue_summary, queue_jobs


# Path (inside DATASET_REPO) of the text shown on the "Paper / LaTeX" tab.
LATEX_FILE = "paper/latex_snippets.txt"


def load_latex() -> str:
    """Fetch ``LATEX_FILE`` from the HF dataset repo, bypassing the local cache.

    Returns:
        The file's text decoded as UTF-8. If the download or read fails, a
        ``"(no content yet — <error>)"`` placeholder is returned instead.
    """
    try:
        path = hf_hub_download(
            repo_id=DATASET_REPO,
            filename=LATEX_FILE,
            repo_type="dataset",
            force_download=True,  # always fresh: re-download on every call
        )
        # Explicit encoding: LaTeX snippets routinely contain non-ASCII text.
        return Path(path).read_text(encoding="utf-8")
    except Exception as e:
        return f"(no content yet — {e})"


# Dashboard UI, built at module import time. It has a shared toolbar, one tab
# per view, and event wiring at the bottom; it is launched by the __main__
# guard when run as a script.
with gr.Blocks(title="STRIDE Applications") as demo:
    gr.Markdown(
        "# STRIDE Applications — status dashboard\n"
        "Live view of the data catalog, model catalog, and GPU queue for the "
        "STRIDE training-data attribution + benchmark-leakage experiments.\n\n"
        f"Data repo: [`{DATASET_REPO}`](https://huggingface.co/datasets/{DATASET_REPO})  ·  "
        f"Model repo: [`{MODEL_REPO}`](https://huggingface.co/{MODEL_REPO})"
    )

    # Global controls. The checkboxes feed load_model_catalog's deleted/smoke filters.
    with gr.Row():
        refresh_btn = gr.Button("Refresh", variant="primary")
        show_deleted = gr.Checkbox(label="Show deleted models", value=False)
        show_smoke   = gr.Checkbox(label="Show smoke-test models", value=False)

    # The datatype entries line up positionally with the 7 columns returned by
    # load_data_catalog; "html" renders the file-link column as markup.
    with gr.Tab("Data catalog"):
        kind_filter = gr.Dropdown(choices=DATA_KIND_OPTIONS, value="All", label="Kind")
        data_tbl = gr.DataFrame(interactive=False, datatype=["html", "str", "number", "number", "number", "str", "str"])

    # The 11 datatypes line up positionally with load_model_catalog's columns
    # (the first is the HTML model link).
    with gr.Tab("Model catalog"):
        mode_filter = gr.Dropdown(choices=MODEL_MODE_OPTIONS, value="All", label="Mode")
        model_tbl = gr.DataFrame(interactive=False, datatype=["html", "str", "str", "number", "number", "number", "number", "number", "number", "str", "str"])

    # Filled from load_queue_status: a summary markdown line plus the per-job table.
    with gr.Tab("GPU queue"):
        queue_md  = gr.Markdown()
        queue_tbl = gr.DataFrame(interactive=False)

    with gr.Tab("Paper / LaTeX"):
        gr.Markdown(
            "LaTeX snippets for copy-pasting into the paper. "
            f"Backed by [`{DATASET_REPO}/{LATEX_FILE}`]"
            f"(https://huggingface.co/datasets/{DATASET_REPO}/blob/main/{LATEX_FILE}) — "
            "push a new version there to update this view."
        )
        latex_refresh_btn = gr.Button("Refresh", variant="secondary")
        # NOTE(review): `show_copy_button` may be deprecated/renamed in newer
        # Gradio releases — confirm against the pinned gradio version.
        latex_box = gr.Textbox(
            label="",
            lines=40,
            max_lines=200,
            interactive=False,
            show_copy_button=True,
        )

    # `inputs` order must match refresh_all's parameters, and `outputs` order
    # must match its 4-tuple return value.
    inputs  = [show_deleted, show_smoke, kind_filter, mode_filter]
    outputs = [data_tbl, model_tbl, queue_md, queue_tbl]

    # Populate everything on page load. Catalogs and queue also refresh on the
    # Refresh button and whenever any filter changes; the LaTeX view has its
    # own independent Refresh button.
    demo.load(fn=refresh_all, inputs=inputs, outputs=outputs)
    demo.load(fn=load_latex, outputs=[latex_box])
    refresh_btn.click(fn=refresh_all, inputs=inputs, outputs=outputs)
    latex_refresh_btn.click(fn=load_latex, outputs=[latex_box])
    show_deleted.change(fn=refresh_all, inputs=inputs, outputs=outputs)
    show_smoke.change(fn=refresh_all, inputs=inputs, outputs=outputs)
    kind_filter.change(fn=refresh_all, inputs=inputs, outputs=outputs)
    mode_filter.change(fn=refresh_all, inputs=inputs, outputs=outputs)


# When run directly, serve on all network interfaces, port 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)