# GF-Score / app.py
# (Hugging Face Spaces file-page header preserved as comments:
#  aryashah00 — "Upload app.py with huggingface_hub" — commit 571e661, verified)
"""
GF-Score: Fairness-Aware Robustness Auditing Dashboard
=======================================================
Hugging Face Spaces entry point.
Loads pre-computed evaluation results (no model inference required)
and serves an interactive Gradio dashboard.
"""
import sys
import json
import logging
from pathlib import Path
from datetime import datetime
# Ensure repo root is on the path so gf_score package is importable
ROOT = Path(__file__).parent.resolve()
sys.path.insert(0, str(ROOT))
import gradio as gr
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("gf_score.hf_app")
# ---------------------------------------------------------------------------
# Paths — resolved relative to repo root (works both locally and on HF)
# ---------------------------------------------------------------------------
RESULTS_DIR = ROOT / "outputs" / "results"  # pre-computed evaluation JSON files live here
REPORTS_DIR = RESULTS_DIR / "reports"  # destination for generated reports
REPORTS_DIR.mkdir(parents=True, exist_ok=True)  # create eagerly so later writes can't hit a missing dir
# ---------------------------------------------------------------------------
# Model short-name mappings (copied from config to keep this file standalone)
# ---------------------------------------------------------------------------
# Full model identifier -> compact display name shown in the model dropdown.
CIFAR10_SHORT_NAMES = {
    "Augustin2020Adversarial_34_10_extra": "Augustin_WRN_extra",
    "Augustin2020Adversarial_34_10": "Augustin_WRN",
    "Augustin2020Adversarial": "Augustin2020",
    "Ding2020MMA": "Ding_MMA",
    "Engstrom2019Robustness": "Engstrom2019",
    "Gowal2020Uncovering": "Gowal2020",
    "Gowal2020Uncovering_extra": "Gowal_extra",
    "Rade2021Helper_R18_ddpm": "Rade_R18",
    "Rebuffi2021Fixing_28_10_cutmix_ddpm": "Rebuffi_28_ddpm",
    "Rebuffi2021Fixing_70_16_cutmix_ddpm": "Rebuffi_70_ddpm",
    "Rebuffi2021Fixing_70_16_cutmix_extra":"Rebuffi_extra",
    "Rebuffi2021Fixing_R18_cutmix_ddpm": "Rebuffi_R18",
    "Rice2020Overfitting": "Rice2020",
    "Rony2019Decoupling": "Rony2019",
    "Sehwag2021Proxy": "Sehwag_Proxy",
    "Sehwag2021Proxy_R18": "Sehwag_R18",
    "Wu2020Adversarial": "Wu2020",
}
# Same mapping for the ImageNet result set.
IMAGENET_SHORT_NAMES = {
    "Salman2020Do_50_2": "Salman_WRN50-2",
    "Salman2020Do_R50": "Salman_R50",
    "Engstrom2019Robustness": "Engstrom2019",
    "Wong2020Fast": "Wong2020",
    "Salman2020Do_R18": "Salman_R18",
}
# Dataset key -> its short-name mapping (lookup table used by the helpers below).
SHORT_NAMES = {"cifar10": CIFAR10_SHORT_NAMES, "imagenet": IMAGENET_SHORT_NAMES}
# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
def load_results(dataset: str):
    """Load the pre-computed evaluation results for *dataset*.

    CIFAR-10 results live in ``full_results.json``; any other dataset uses
    a ``full_results_<dataset>.json`` suffix.

    Returns the parsed JSON dict, or ``None`` when the file is missing.
    """
    suffix = f"_{dataset}" if dataset != "cifar10" else ""
    path = RESULTS_DIR / f"full_results{suffix}.json"
    if not path.exists():
        # Lazy %-formatting so the message is only built when emitted.
        logger.warning("Results file not found: %s", path)
        return None
    # Explicit encoding: result files are written as UTF-8 JSON.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def get_available_datasets():
    """Return dataset keys whose results files exist, defaulting to cifar10."""
    candidates = (
        ("cifar10", "full_results.json"),
        ("imagenet", "full_results_imagenet.json"),
    )
    found = [ds for ds, fname in candidates if (RESULTS_DIR / fname).exists()]
    return found if found else ["cifar10"]
def get_model_choices(results, dataset: str):
    """Return display short names for every model present in *results*."""
    if results is None:
        return []
    short = SHORT_NAMES.get(dataset, {})
    # Fall back to the full identifier when no short name is registered.
    return [short.get(full, full) for full in results["model_results"]]
def display_name_to_full(display_name: str, results, dataset: str):
    """Resolve a dropdown short name back to its full model identifier.

    Returns ``None`` when the name cannot be matched against the results.
    """
    model_results = results["model_results"]
    mapping = SHORT_NAMES.get(dataset, {})
    resolved = next(
        (
            full
            for full, short in mapping.items()
            if short == display_name and full in model_results
        ),
        None,
    )
    if resolved is not None:
        return resolved
    # The caller may already hold a full identifier; accept it if present.
    return display_name if display_name in model_results else None
def get_class_names(results):
    """Return the list of class names for a results blob.

    Prefers ``metadata["class_names"]``; otherwise falls back to the
    per-class score keys of the first model entry. Returns ``[]`` when
    nothing is available.
    """
    if results is None:
        return []
    names = results.get("metadata", {}).get("class_names")
    if names:
        return names
    # Fallback: derive the names from the first model's per-class scores.
    for entry in results.get("model_results", {}).values():
        return list(entry.get("per_class_scores", {}))
    return []
# ---------------------------------------------------------------------------
# Analysis
# ---------------------------------------------------------------------------
def analyze_model(model_display_name, lambda_val, dataset, results):
    """Produce the two dashboard outputs for one model.

    Parameters
    ----------
    model_display_name : str
        Short display name chosen in the model dropdown.
    lambda_val : float
        Fairness-penalty weight λ; FP-GREAT = Ω̄ − λ·RDI.
    dataset : str
        "cifar10" or "imagenet" (controls labels and table truncation).
    results : dict | None
        Pre-computed results blob, or ``None`` when the files are missing.

    Returns
    -------
    tuple[str, str]
        (markdown analysis, iframe-wrapped HTML audit report).
    """
    if results is None:
        return (
            "⚠️ **No results found.** The pre-computed evaluation files are missing.\n\n"
            "Please ensure `outputs/results/full_results.json` (and `full_results_imagenet.json`) "
            "are committed to the Space repository.",
            "<p>No data available.</p>",
        )
    full_name = display_name_to_full(model_display_name, results, dataset)
    if full_name is None:
        return f"Model `{model_display_name}` not found in results.", "<p>Not found.</p>"
    r = results["model_results"][full_name]
    class_names = get_class_names(results)
    num_classes = len(class_names)
    # Aggregate metrics (all pre-computed offline).
    agg = r["aggregate_great_score"]
    rdi = r["rdi"]
    nrgc = r["nrgc"]
    wcr = r["wcr"]
    wcr_class = r.get("wcr_class", "—")
    # FP-GREAT at the chosen λ plus the λ=0 / λ=1 extremes for context.
    fp_great = agg - lambda_val * rdi
    fp_at_0 = agg
    fp_at_1 = agg - rdi
    ds_label = "CIFAR-10" if dataset == "cifar10" else "ImageNet"
    threat = "L2 (ε=0.5)" if dataset == "cifar10" else "L∞ (ε=4/255)"
    rdi_icon = "✅ Low" if rdi < 0.1 else ("⚠️ Moderate" if rdi < 0.3 else "❌ High")
    wcr_icon = "✅ Good" if wcr > 0.2 else ("⚠️ Low" if wcr > 0.05 else "❌ Critical")
    fp_interp = (
        "No fairness penalty (= aggregate GREAT Score)" if lambda_val == 0.0 else
        "Mild fairness adjustment" if lambda_val < 0.3 else
        "Balanced robustness-fairness trade-off" if lambda_val < 0.7 else
        "Strong fairness emphasis"
    )
    # ---- Markdown analysis output ----
    # Bug fix: the rendered FP-GREAT formula was missing the minus sign
    # between Ω̄ and λ (it read "{agg}{λ} × {rdi}", e.g. "0.50000.50 × 0.1").
    md = f"""## 🛡️ {model_display_name}
**Dataset:** {ds_label} &nbsp;|&nbsp; **Threat Model:** {threat} &nbsp;|&nbsp; **Classes:** {num_classes}
---
### Aggregate Metrics
| Metric | Value | Status |
|--------|------:|--------|
| **GREAT Score** (Ω̂) | `{agg:.4f}` | Certified robustness lower bound |
| **RDI** (Disparity) | `{rdi:.4f}` | {rdi_icon} |
| **NRGC** (Gini) | `{nrgc:.4f}` | Class inequality index ∈ [0, 1) |
| **WCR** (Worst-Case)| `{wcr:.4f}` | {wcr_icon} — worst class: `{wcr_class}` |
---
### 🎛️ Fairness-Penalized Score (FP-GREAT)
**FP-GREAT = Ω̄ − λ × RDI = {agg:.4f} − {lambda_val:.2f} × {rdi:.4f} = `{fp_great:.4f}`**
*{fp_interp}*
| λ | FP-GREAT | Meaning |
|---|----------:|---------|
| 0.00 | {fp_at_0:.4f} | Pure robustness (no penalty) |
| **{lambda_val:.2f}** | **{fp_great:.4f}** | ← Current |
| 1.00 | {fp_at_1:.4f} | Max fairness penalty |
---
### Per-Class Robustness Scores
"""
    per_class = r.get("per_class_scores", {})
    per_acc = r.get("per_class_accuracy", {})
    max_score = max(per_class.values()) if per_class else 1.0
    table_header = "| Class | Score | Accuracy | Bar |\n|-------|------:|----------:|-----|\n"

    def _rows(cls_list):
        # One markdown row per class; bar length scaled to the best score
        # (the max(..., 0.001) guards against division by zero).
        out = ""
        for cls in cls_list:
            s = per_class.get(cls, 0)
            a = per_acc.get(cls, 0)
            bar = "█" * int(s / max(max_score, 0.001) * 15)
            out += f"| `{cls}` | {s:.4f} | {a:.1%} | {bar} |\n"
        return out

    if num_classes > 30:
        # Large datasets (e.g. ImageNet): show only the 10 weakest and
        # 10 strongest classes instead of all of them.
        sorted_cls = sorted(per_class.keys(), key=lambda c: per_class.get(c, 0))
        bottom10, top10 = sorted_cls[:10], sorted_cls[-10:]
        md += f"*{num_classes} total classes — showing bottom 10 and top 10:*\n\n"
        md += "**🔴 Bottom 10 — Most Vulnerable:**\n\n"
        md += table_header + _rows(bottom10)
        md += "\n**🟢 Top 10 — Most Robust:**\n\n"
        md += table_header + _rows(top10)
    else:
        md += table_header + _rows(class_names)
    vuln = r.get("vulnerability_ranking", [])
    if vuln:
        display_vuln = vuln[:10] if num_classes > 30 else vuln
        suffix_txt = f" (top 10 of {num_classes})" if num_classes > 30 else ""
        md += f"\n### Vulnerability Ranking{suffix_txt}\n"
        for rank, (cls, score) in enumerate(display_vuln, 1):
            icon = "🔴" if rank <= 3 else ("🟡" if rank <= len(display_vuln) - 3 else "🟢")
            md += f"{rank}. {icon} **`{cls}`**: {score:.4f}\n"
    # ---- HTML audit report (isolated in iframe to prevent CSS bleed) ----
    html_doc = _build_html_report(
        model_display_name, r, ds_label, threat, num_classes,
        class_names, per_class, per_acc, max_score, vuln,
        agg, rdi, nrgc, wcr, wcr_class, fp_great, lambda_val,
    )
    # Escape for srcdoc attribute: & first, then " (this order prevents the
    # quote entities from being double-escaped).
    escaped = html_doc.replace("&", "&amp;").replace('"', "&quot;")
    iframe = (
        f'<iframe srcdoc="{escaped}" '
        f'style="width:100%;height:720px;border:none;border-radius:8px;" '
        f'sandbox="allow-same-origin"></iframe>'
    )
    return md, iframe
def _build_html_report(
model_name, r, ds_label, threat, num_classes,
class_names, per_class, per_acc, max_score, vuln,
agg, rdi, nrgc, wcr, wcr_class, fp_great, lambda_val,
):
rdi_css = "pass" if rdi < 0.1 else ("warn" if rdi < 0.3 else "fail")
wcr_css = "pass" if wcr > 0.2 else ("warn" if wcr > 0.05 else "fail")
# Build per-class table rows
if num_classes > 30:
sorted_cls = sorted(per_class.keys(), key=lambda c: per_class.get(c, 0))
display_cls = sorted_cls[:10] + sorted_cls[-10:]
else:
display_cls = class_names
class_rows = ""
for cls in display_cls:
s = per_class.get(cls, 0)
a = per_acc.get(cls, 0)
w = int(s / max(max_score, 0.001) * 200)
class_rows += (
f"<tr><td>{cls}</td><td>{s:.4f}</td><td>{a:.1%}</td>"
f'<td><div class="bar" style="width:{w}px"></div></td></tr>\n'
)
vuln_rows = ""
total_v = len(vuln)
for rank, (cls, score) in enumerate((vuln[:10] if num_classes > 30 else vuln), 1):
if rank <= 3:
status = '<span class="fail">⚠ Vulnerable</span>'
elif rank >= total_v - 2:
status = '<span class="pass">✓ Robust</span>'
else:
status = '<span class="warn">— Average</span>'
vuln_rows += f"<tr><td>{rank}</td><td>{cls}</td><td>{score:.4f}</td><td>{status}</td></tr>\n"
if rdi >= 0.3:
assessment = (
f"<strong class='fail'>High disparity (RDI={rdi:.3f}).</strong> "
f"Class <em>{wcr_class}</em> is significantly more vulnerable. "
)
elif rdi >= 0.1:
assessment = f"<strong class='warn'>Moderate disparity (RDI={rdi:.3f}).</strong> Some classes are noticeably more vulnerable."
else:
assessment = f"<strong class='pass'>Low disparity (RDI={rdi:.3f}).</strong> Robustness is distributed relatively evenly across classes."
if wcr < 0.05:
assessment += f" <strong class='fail'>Critical:</strong> Worst-case class ({wcr_class}) has near-zero robustness (WCR={wcr:.4f})."
elif wcr < 0.2:
assessment += f" Worst-case class ({wcr_class}) has limited robustness (WCR={wcr:.4f})."
return f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>GF-Score Audit — {model_name}</title>
<style>
body{{font-family:'Segoe UI',sans-serif;margin:0;background:#f5f7fa;color:#333}}
.wrap{{max-width:860px;margin:24px auto;background:#fff;padding:36px;border-radius:10px;box-shadow:0 2px 12px rgba(0,0,0,.1)}}
h1{{color:#2c3e50;border-bottom:3px solid #3498db;padding-bottom:8px;font-size:1.4em}}
h2{{color:#34495e;margin-top:28px;font-size:1.1em}}
.cards{{display:flex;flex-wrap:wrap;gap:12px;margin:12px 0}}
.card{{background:#ecf0f1;padding:14px 18px;border-radius:8px;text-align:center;min-width:130px}}
.card .val{{font-size:1.7em;font-weight:700;color:#2c3e50}}
.card .lbl{{font-size:.7em;color:#7f8c8d;text-transform:uppercase;margin-top:2px}}
.pass{{color:#27ae60}}.warn{{color:#e67e22}}.fail{{color:#e74c3c}}
table{{border-collapse:collapse;width:100%;margin:12px 0;font-size:.9em}}
th{{background:#3498db;color:#fff;padding:9px 14px;text-align:left}}
td{{padding:7px 14px;border-bottom:1px solid #eee}}
tr:nth-child(even){{background:#f9f9f9}}
.bar{{height:14px;background:linear-gradient(90deg,#3498db,#2ecc71);border-radius:3px;display:inline-block}}
.footer{{margin-top:24px;padding-top:12px;border-top:1px solid #eee;font-size:.75em;color:#aaa}}
</style>
</head>
<body>
<div class="wrap">
<h1>🛡️ GF-Score Robustness Audit Report</h1>
<p><strong>Model:</strong> {model_name}<br>
<strong>Dataset:</strong> {ds_label} &nbsp;|&nbsp; <strong>Threat Model:</strong> {threat} &nbsp;|&nbsp; <strong>Classes:</strong> {num_classes}<br>
<strong>Generated:</strong> {datetime.now().strftime('%Y-%m-%d %H:%M UTC')}</p>
<h2>Summary Metrics</h2>
<div class="cards">
<div class="card"><div class="val">{agg:.4f}</div><div class="lbl">GREAT Score</div></div>
<div class="card"><div class="val {rdi_css}">{rdi:.4f}</div><div class="lbl">RDI</div></div>
<div class="card"><div class="val">{nrgc:.4f}</div><div class="lbl">NRGC (Gini)</div></div>
<div class="card"><div class="val {wcr_css}">{wcr:.4f}</div><div class="lbl">WCR ({wcr_class})</div></div>
<div class="card"><div class="val">{fp_great:.4f}</div><div class="lbl">FP-GREAT (λ={lambda_val})</div></div>
</div>
<h2>Per-Class Robustness Profile</h2>
<table>
<tr><th>Class</th><th>GREAT Score</th><th>Clean Acc.</th><th>Visual</th></tr>
{class_rows}
</table>
<h2>Vulnerability Ranking</h2>
<table>
<tr><th>Rank</th><th>Class</th><th>Score</th><th>Status</th></tr>
{vuln_rows}
</table>
<h2>Assessment</h2>
<p>{assessment}</p>
<div class="footer">
GF-Score v0.1.0 · Based on GREAT Score (Li et al., NeurIPS 2024) extended with per-class fairness metrics ·
Metrics: RDI (Max Group Disparity), NRGC (Gini), WCR (Rawlsian Maximin), FP-GREAT (IHDI Adaptation)
</div>
</div>
</body>
</html>"""
# ---------------------------------------------------------------------------
# Gradio App
# ---------------------------------------------------------------------------
def build_app():
    """Construct and return the Gradio Blocks dashboard.

    All available result sets are parsed up front (no model inference at
    runtime); the dataset/model/λ controls are wired to the analysis
    callbacks defined inside the Blocks context.
    """
    available = get_available_datasets()
    default_ds = available[0]
    # Cache parsed results per dataset so switching datasets is instant.
    results_cache = {ds: load_results(ds) for ds in available}
    # Human-readable dropdown labels for each dataset key.
    dataset_labels = {
        "cifar10": "CIFAR-10 (10 classes · L2 threat model · 17 models)",
        "imagenet": "ImageNet (1000 classes · L∞ threat model · 5 models)",
    }
    with gr.Blocks(
        title="GF-Score Auditing Dashboard",
        theme=gr.themes.Soft(),
        css=".gr-markdown table { width: 100%; }",
    ) as demo:
        # Per-session state: currently selected dataset and its results blob.
        current_ds = gr.State(default_ds)
        current_results = gr.State(results_cache.get(default_ds))
        gr.Markdown("""
# 🛡️ GF-Score: Fairness-Aware Robustness Auditing Dashboard
Inspect **class-conditional adversarial robustness** of certified models with four fairness metrics
grounded in welfare economics. Based on [GREAT Score (NeurIPS 2024)](https://arxiv.org/abs/2304.09875),
extended with per-class decomposition, disparity analysis, and **attack-free** self-calibration.
| Metric | Meaning |
|--------|---------|
| **RDI** | Range of per-class robustness (Max Group Disparity) |
| **NRGC** | Normalized Gini Coefficient — overall inequality |
| **WCR** | Worst-case class robustness (Rawlsian maximin) |
| **FP-GREAT** | Fairness-penalized aggregate score: Ω̄ − λ·RDI |
""")
        # Top row: dataset + model selectors.
        with gr.Row():
            dataset_dd = gr.Dropdown(
                choices=[(dataset_labels[ds], ds) for ds in available],
                value=default_ds,
                label="Dataset",
                scale=2,
            )
            model_dd = gr.Dropdown(
                choices=get_model_choices(results_cache.get(default_ds), default_ds),
                value=(get_model_choices(results_cache.get(default_ds), default_ds) or [None])[0],
                label="Model",
                scale=2,
            )
        # Second row: λ slider + analyze trigger.
        with gr.Row():
            lambda_sl = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                label="Fairness Penalty λ (FP-GREAT = GREAT Score − λ × RDI)",
                scale=3,
            )
            analyze_btn = gr.Button("🔍 Analyze", variant="primary", scale=1)
        # Output tabs: markdown analysis + iframe-isolated HTML report.
        with gr.Tabs():
            with gr.TabItem("📊 Analysis"):
                analysis_md = gr.Markdown()
            with gr.TabItem("📄 Full HTML Report"):
                report_html = gr.HTML()
        # ---- callbacks ----
        def on_dataset_change(ds_choice):
            # Reload lazily if the cache entry is missing/None, then refresh
            # the model dropdown and both state components.
            res = results_cache.get(ds_choice) or load_results(ds_choice)
            results_cache[ds_choice] = res
            choices = get_model_choices(res, ds_choice)
            default_model = choices[0] if choices else None
            return (
                gr.update(choices=choices, value=default_model),
                ds_choice,
                res,
            )
        def run(model_name, lam, ds, res):
            # Thin adapter between Gradio inputs and analyze_model.
            md, html = analyze_model(model_name, lam, ds, res)
            return md, html
        dataset_dd.change(
            fn=on_dataset_change,
            inputs=[dataset_dd],
            outputs=[model_dd, current_ds, current_results],
        )
        analyze_btn.click(
            fn=run,
            inputs=[model_dd, lambda_sl, current_ds, current_results],
            outputs=[analysis_md, report_html],
        )
        # Re-run the analysis live whenever λ moves.
        lambda_sl.change(
            fn=run,
            inputs=[model_dd, lambda_sl, current_ds, current_results],
            outputs=[analysis_md, report_html],
        )
        gr.Markdown("""---
*GF-Score v0.1.0 · [Paper (NeurIPS 2026, under review)]() · [GitHub](https://github.com/aryashah00/GF-Score)*""")
    return demo
# Build the UI at import time so hosting environments can pick up `demo`.
demo = build_app()
if __name__ == "__main__":
    # Local development entry point.
    demo.launch()