""" GF-Score: Fairness-Aware Robustness Auditing Dashboard ======================================================= Hugging Face Spaces entry point. Loads pre-computed evaluation results (no model inference required) and serves an interactive Gradio dashboard. """ import sys import json import logging from pathlib import Path from datetime import datetime # Ensure repo root is on the path so gf_score package is importable ROOT = Path(__file__).parent.resolve() sys.path.insert(0, str(ROOT)) import gradio as gr logging.basicConfig(level=logging.INFO) logger = logging.getLogger("gf_score.hf_app") # --------------------------------------------------------------------------- # Paths — resolved relative to repo root (works both locally and on HF) # --------------------------------------------------------------------------- RESULTS_DIR = ROOT / "outputs" / "results" REPORTS_DIR = RESULTS_DIR / "reports" REPORTS_DIR.mkdir(parents=True, exist_ok=True) # --------------------------------------------------------------------------- # Model short-name mappings (copied from config to keep this file standalone) # --------------------------------------------------------------------------- CIFAR10_SHORT_NAMES = { "Augustin2020Adversarial_34_10_extra": "Augustin_WRN_extra", "Augustin2020Adversarial_34_10": "Augustin_WRN", "Augustin2020Adversarial": "Augustin2020", "Ding2020MMA": "Ding_MMA", "Engstrom2019Robustness": "Engstrom2019", "Gowal2020Uncovering": "Gowal2020", "Gowal2020Uncovering_extra": "Gowal_extra", "Rade2021Helper_R18_ddpm": "Rade_R18", "Rebuffi2021Fixing_28_10_cutmix_ddpm": "Rebuffi_28_ddpm", "Rebuffi2021Fixing_70_16_cutmix_ddpm": "Rebuffi_70_ddpm", "Rebuffi2021Fixing_70_16_cutmix_extra":"Rebuffi_extra", "Rebuffi2021Fixing_R18_cutmix_ddpm": "Rebuffi_R18", "Rice2020Overfitting": "Rice2020", "Rony2019Decoupling": "Rony2019", "Sehwag2021Proxy": "Sehwag_Proxy", "Sehwag2021Proxy_R18": "Sehwag_R18", "Wu2020Adversarial": "Wu2020", } IMAGENET_SHORT_NAMES = { "Salman2020Do_50_2": "Salman_WRN50-2", "Salman2020Do_R50": "Salman_R50", "Engstrom2019Robustness": "Engstrom2019", "Wong2020Fast": "Wong2020", "Salman2020Do_R18": "Salman_R18", } SHORT_NAMES = {"cifar10": CIFAR10_SHORT_NAMES, "imagenet": IMAGENET_SHORT_NAMES} # --------------------------------------------------------------------------- # Data loading # --------------------------------------------------------------------------- def load_results(dataset: str): suffix = f"_{dataset}" if dataset != "cifar10" else "" path = RESULTS_DIR / f"full_results{suffix}.json" if not path.exists(): logger.warning(f"Results file not found: {path}") return None with open(path, "r") as f: return json.load(f) def get_available_datasets(): available = [] if (RESULTS_DIR / "full_results.json").exists(): available.append("cifar10") if (RESULTS_DIR / "full_results_imagenet.json").exists(): available.append("imagenet") return available or ["cifar10"] def get_model_choices(results, dataset: str): if results is None: return [] names = SHORT_NAMES.get(dataset, {}) return [names.get(m, m) for m in results["model_results"].keys()] def display_name_to_full(display_name: str, results, dataset: str): names = SHORT_NAMES.get(dataset, {}) for full, short in names.items(): if short == display_name and full in results["model_results"]: return full return display_name if display_name in results["model_results"] else None def get_class_names(results): if results is None: return [] meta = results.get("metadata", {}) cls = meta.get("class_names") if cls: return cls model_results = 
results.get("model_results", {}) if model_results: first = next(iter(model_results.values())) return list(first.get("per_class_scores", {}).keys()) return [] # --------------------------------------------------------------------------- # Analysis # --------------------------------------------------------------------------- def analyze_model(model_display_name, lambda_val, dataset, results): if results is None: return ( "⚠️ **No results found.** The pre-computed evaluation files are missing.\n\n" "Please ensure `outputs/results/full_results.json` (and `full_results_imagenet.json`) " "are committed to the Space repository.", "
No data available.
", ) full_name = display_name_to_full(model_display_name, results, dataset) if full_name is None: return f"Model `{model_display_name}` not found in results.", "Not found.
" r = results["model_results"][full_name] class_names = get_class_names(results) num_classes = len(class_names) agg = r["aggregate_great_score"] rdi = r["rdi"] nrgc = r["nrgc"] wcr = r["wcr"] wcr_class = r.get("wcr_class", "—") fp_great = agg - lambda_val * rdi fp_at_0 = agg fp_at_1 = agg - rdi ds_label = "CIFAR-10" if dataset == "cifar10" else "ImageNet" threat = "L2 (ε=0.5)" if dataset == "cifar10" else "L∞ (ε=4/255)" rdi_icon = "✅ Low" if rdi < 0.1 else ("⚠️ Moderate" if rdi < 0.3 else "❌ High") wcr_icon = "✅ Good" if wcr > 0.2 else ("⚠️ Low" if wcr > 0.05 else "❌ Critical") fp_interp = ( "No fairness penalty (= aggregate GREAT Score)" if lambda_val == 0.0 else "Mild fairness adjustment" if lambda_val < 0.3 else "Balanced robustness-fairness trade-off" if lambda_val < 0.7 else "Strong fairness emphasis" ) # ---- Markdown analysis output ---- md = f"""## 🛡️ {model_display_name} **Dataset:** {ds_label} | **Threat Model:** {threat} | **Classes:** {num_classes} --- ### Aggregate Metrics | Metric | Value | Status | |--------|------:|--------| | **GREAT Score** (Ω̂) | `{agg:.4f}` | Certified robustness lower bound | | **RDI** (Disparity) | `{rdi:.4f}` | {rdi_icon} | | **NRGC** (Gini) | `{nrgc:.4f}` | Class inequality index ∈ [0, 1) | | **WCR** (Worst-Case)| `{wcr:.4f}` | {wcr_icon} — worst class: `{wcr_class}` | --- ### 🎛️ Fairness-Penalized Score (FP-GREAT) **FP-GREAT = Ω̄ − λ × RDI = {agg:.4f} − {lambda_val:.2f} × {rdi:.4f} = `{fp_great:.4f}`** *{fp_interp}* | λ | FP-GREAT | Meaning | |---|----------:|---------| | 0.00 | {fp_at_0:.4f} | Pure robustness (no penalty) | | **{lambda_val:.2f}** | **{fp_great:.4f}** | ← Current | | 1.00 | {fp_at_1:.4f} | Max fairness penalty | --- ### Per-Class Robustness Scores """ per_class = r.get("per_class_scores", {}) per_acc = r.get("per_class_accuracy", {}) max_score = max(per_class.values()) if per_class else 1.0 if num_classes > 30: sorted_cls = sorted(per_class.keys(), key=lambda c: per_class.get(c, 0)) bottom10, top10 = sorted_cls[:10], sorted_cls[-10:] md += f"*{num_classes} total classes — showing bottom 10 and top 10:*\n\n" md += "**🔴 Bottom 10 — Most Vulnerable:**\n\n" md += "| Class | Score | Accuracy | Bar |\n|-------|------:|----------:|-----|\n" for cls in bottom10: s = per_class.get(cls, 0) a = per_acc.get(cls, 0) bar = "█" * int(s / max(max_score, 0.001) * 15) md += f"| `{cls}` | {s:.4f} | {a:.1%} | {bar} |\n" md += "\n**🟢 Top 10 — Most Robust:**\n\n" md += "| Class | Score | Accuracy | Bar |\n|-------|------:|----------:|-----|\n" for cls in top10: s = per_class.get(cls, 0) a = per_acc.get(cls, 0) bar = "█" * int(s / max(max_score, 0.001) * 15) md += f"| `{cls}` | {s:.4f} | {a:.1%} | {bar} |\n" else: md += "| Class | Score | Accuracy | Bar |\n|-------|------:|----------:|-----|\n" for cls in class_names: s = per_class.get(cls, 0) a = per_acc.get(cls, 0) bar = "█" * int(s / max(max_score, 0.001) * 15) md += f"| `{cls}` | {s:.4f} | {a:.1%} | {bar} |\n" vuln = r.get("vulnerability_ranking", []) if vuln: display_vuln = vuln[:10] if num_classes > 30 else vuln suffix_txt = f" (top 10 of {num_classes})" if num_classes > 30 else "" md += f"\n### Vulnerability Ranking{suffix_txt}\n" for rank, (cls, score) in enumerate(display_vuln, 1): icon = "🔴" if rank <= 3 else ("🟡" if rank <= len(display_vuln) - 3 else "🟢") md += f"{rank}. 
    # ---- HTML audit report (isolated in iframe to prevent CSS bleed) ----
    html_doc = _build_html_report(
        model_display_name, r, ds_label, threat, num_classes, class_names,
        per_class, per_acc, max_score, vuln, agg, rdi, nrgc, wcr, wcr_class,
        fp_great, lambda_val,
    )

    # Escape for srcdoc attribute: & first, then "
    escaped = html_doc.replace("&", "&amp;").replace('"', "&quot;")
    # Full-width iframe with a fixed height so the report scrolls inside the dashboard
    iframe = (
        f'<iframe srcdoc="{escaped}" '
        'style="width:100%; height:800px; border:none;"></iframe>'
    )
    return md, iframe


def _build_html_report(
    model_name, r, ds_label, threat, num_classes, class_names,
    per_class, per_acc, max_score, vuln, agg, rdi, nrgc, wcr, wcr_class,
    fp_great, lambda_val,
):
    rdi_css = "pass" if rdi < 0.1 else ("warn" if rdi < 0.3 else "fail")
    wcr_css = "pass" if wcr > 0.2 else ("warn" if wcr > 0.05 else "fail")

    # Build per-class table rows
    if num_classes > 30:
        sorted_cls = sorted(per_class.keys(), key=lambda c: per_class.get(c, 0))
        display_cls = sorted_cls[:10] + sorted_cls[-10:]
    else:
        display_cls = class_names

    class_rows = ""
    for cls in display_cls:
        s = per_class.get(cls, 0)
        a = per_acc.get(cls, 0)
        w = int(s / max(max_score, 0.001) * 200)
        class_rows += (
            f"<tr><td><code>{cls}</code></td><td>{s:.4f}</td><td>{a:.1%}</td>"
            f"<td><div class='bar' style='width:{w}px'></div></td></tr>"
        )

    # Build vulnerability-ranking rows (same rank icons as the markdown view)
    vuln_rows = ""
    display_vuln = vuln[:10] if num_classes > 30 else vuln
    for rank, (cls, score) in enumerate(display_vuln, 1):
        status = "🔴" if rank <= 3 else ("🟡" if rank <= len(display_vuln) - 3 else "🟢")
        vuln_rows += (
            f"<tr><td>{rank}</td><td><code>{cls}</code></td>"
            f"<td>{score:.4f}</td><td>{status}</td></tr>"
        )

    # One-line summary interpolated at the bottom of the report
    assessment = (
        f"FP-GREAT (λ={lambda_val:.2f}) = {fp_great:.4f}; "
        f"RDI = {rdi:.4f}, NRGC = {nrgc:.4f}, WCR = {wcr:.4f} (worst class: {wcr_class})."
    )

    # Minimal self-contained report page; lightweight inline CSS only
    html_doc = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
  body {{ font-family: sans-serif; color: #222; margin: 1.5em; }}
  table {{ border-collapse: collapse; margin: 0.8em 0; }}
  th, td {{ border: 1px solid #ccc; padding: 4px 10px; text-align: left; }}
  .bar {{ height: 12px; background: #4c9aff; }}
  .pass {{ color: #1a7f37; }} .warn {{ color: #9a6700; }} .fail {{ color: #cf222e; }}
</style>
</head>
<body>
<h1>GF-Score Audit Report</h1>
<p>
  Model: {model_name}<br>
  Dataset: {ds_label} | Threat Model: {threat} | Classes: {num_classes}<br>
  Generated: {datetime.now().strftime('%Y-%m-%d %H:%M UTC')}
</p>
<p>
  GREAT Score: {agg:.4f} |
  <span class="{rdi_css}">RDI: {rdi:.4f}</span> |
  NRGC: {nrgc:.4f} |
  <span class="{wcr_css}">WCR: {wcr:.4f} ({wcr_class})</span>
</p>
<h2>Per-Class Robustness</h2>
<table>
  <tr><th>Class</th><th>GREAT Score</th><th>Clean Acc.</th><th>Visual</th></tr>
  {class_rows}
</table>
<h2>Vulnerability Ranking</h2>
<table>
  <tr><th>Rank</th><th>Class</th><th>Score</th><th>Status</th></tr>
  {vuln_rows}
</table>
<p>{assessment}</p>
</body>
</html>"""
    return html_doc
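
# ---------------------------------------------------------------------------
# Expected results-file layout (illustrative sketch, not real data)
# ---------------------------------------------------------------------------
# The dashboard only reads the fields accessed in `analyze_model` above, so a
# minimal `outputs/results/full_results.json` looks roughly like the following;
# class names and numeric values here are placeholders only:
#
#   {
#     "metadata": {"class_names": ["airplane", "automobile", "..."]},
#     "model_results": {
#       "Engstrom2019Robustness": {
#         "aggregate_great_score": 0.31,
#         "rdi": 0.12,
#         "nrgc": 0.08,
#         "wcr": 0.19,
#         "wcr_class": "cat",
#         "per_class_scores": {"airplane": 0.35, "cat": 0.19},
#         "per_class_accuracy": {"airplane": 0.91, "cat": 0.78},
#         "vulnerability_ranking": [["cat", 0.19], ["deer", 0.22]]
#       }
#     }
#   }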