# saeguardbench / app.py
# (HuggingFace Spaces listing header, preserved as a comment — uploader
# mdarahmanxAI; commit "update paper link to GitHub PDF"; revision 5df4c0f,
# verified. The raw header lines were not valid Python.)
"""SAEGuardBench: HuggingFace Spaces Gradio app.
Interactive benchmark explorer for "Do SAE Features Actually Help Detect
Jailbreaks?" — a systematic comparison of 8 detection methods across 4
paradigms, 6 datasets, and 4 models (2B--70B parameters).
Key finding: SAE features consistently hurt jailbreak detection compared to
simple linear probes on raw activations (the "Detection Gap").
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
APP_TITLE = "SAEGuardBench: Do SAE Features Actually Help Detect Jailbreaks?"

# Filesystem layout: results and figures live alongside this file in the repo.
PROJECT_ROOT = Path(__file__).resolve().parent
RESULTS_DIR = PROJECT_ROOT / "results"
FIGURES_DIR = PROJECT_ROOT / "figures"
PAPER_FIGURES_DIR = PROJECT_ROOT / "paper" / "figures"
# Optional on-disk leaderboard; load_leaderboard() falls back to the
# hardcoded paper numbers below when this file is absent or malformed.
LEADERBOARD_CSV = RESULTS_DIR / "leaderboard.csv"

# Color palette (matches paper)
COLOR_PRIMARY = "#0072B2"
COLOR_SAE = "#E69F00"
COLOR_RAW = "#009E73"
COLOR_GAP_NEG = "#D55E00"
COLOR_GAP_POS = "#009E73"

# Detection Gap data (from paper Table 2 / Section 4).
# gap = sae_auroc - raw_auroc (e.g. 0.712 - 0.949 = -0.237); negative on
# every model listed here.
DETECTION_GAP_DATA: list[dict[str, Any]] = [
    {"model": "Gemma-2-2B", "raw_auroc": 0.949, "sae_auroc": 0.712, "gap": -0.237},
    {"model": "Llama-3.1-8B", "raw_auroc": 0.867, "sae_auroc": 0.477, "gap": -0.391},
    {"model": "Gemma-3-4B", "raw_auroc": 0.922, "sae_auroc": 0.709, "gap": -0.213},
    {"model": "Llama-3.3-70B", "raw_auroc": 1.000, "sae_auroc": 0.949, "gap": -0.051},
]

# Method comparison data — Gemma-2-2B, Layer 12
# SAE methods use SAE features; probes use raw activations (matches paper Table 1)
# Values from results/sae_features_*.json and results/train_*.json
DATASETS = ["JailbreakBench", "HarmBench", "AdvBench", "SORRY-Bench", "WildJailbreak"]
# Per-method AUROC, one value per entry of DATASETS (same order).
METHOD_RESULTS: dict[str, list[float]] = {
    "Linear Probe": [0.949, 1.000, 1.000, 1.000, 1.000],
    "SAE-Classifier": [0.704, 0.984, 1.000, 1.000, 1.000],
    "CC-Delta": [0.712, 0.972, 1.000, 0.997, 0.999],
    "GSAE": [0.707, 0.966, 1.000, 1.000, 0.999],
    "Random SAE": [0.571, 0.626, 1.000, 0.996, 1.000],
    "MLP Probe": [0.942, 1.000, 1.000, 1.000, 0.999],
    "FJD": [0.472, 0.500, 0.500, 0.490, 0.680],
    "LlamaGuard-3": [0.885, 0.940, 0.990, 0.940, 0.850],
}

# Paradigm labels for leaderboard (keys mirror METHOD_RESULTS).
METHOD_PARADIGMS: dict[str, str] = {
    "Linear Probe": "Activation Probe",
    "MLP Probe": "Activation Probe",
    "SAE-Classifier": "SAE Feature",
    "GSAE": "SAE Feature",
    "Random SAE": "SAE Feature",
    "CC-Delta": "SAE Feature",
    "FJD": "Logit-Based",
    "LlamaGuard-3": "External Model",
}

# Figure metadata: (filename, caption); files are looked up in FIGURES_DIR
# first, then PAPER_FIGURES_DIR.
FIGURE_ENTRIES: list[tuple[str, str]] = [
    (
        "fig1_raw_vs_sae.png",
        "Figure 1: Raw activation probes consistently outperform SAE-based "
        "methods across all models and datasets.",
    ),
    (
        "fig_roc_curves.png",
        "ROC curves comparing detection methods on JailbreakBench. The raw "
        "linear probe dominates SAE-based approaches.",
    ),
    (
        "fig_hybrid_recovery.png",
        "Hybrid recovery rates: combining raw detection with SAE explanation "
        "recovers 88-106% of raw-only performance.",
    ),
    (
        "fig_safety_subspace.png",
        "Safety subspace PCA: PC2 carries 13% variance in reconstruction but "
        "58% in residual, explaining the Detection Gap.",
    ),
    (
        "fig_pareto_tradeoff.png",
        "Pareto tradeoff between detection accuracy and interpretability "
        "across all methods.",
    ),
    (
        "fig_crossmodel_hybrid.png",
        "Cross-model hybrid results showing InterpGuard generalises across "
        "model families and scales.",
    ),
]

# Abstract shown verbatim on the About tab (\u2013 is an en dash).
PAPER_ABSTRACT = (
    "Sparse Autoencoders (SAEs) are increasingly proposed as interpretable "
    "safety monitors for large language models. But do their features actually "
    "help detect jailbreaks? We introduce SAEGuardBench, a benchmark comparing "
    "8 detection methods across 4 paradigms on 6 datasets and 4 models "
    "(2B\u201370B parameters). The answer is no. SAE features consistently hurt "
    "detection compared to simple linear probes on raw activations, a gap we "
    "call the Detection Gap, which is negative on every model we test. The gap "
    "persists across layers, transfer settings, wider SAEs, and nonlinear "
    "classifiers. We trace the cause to the reconstruction objective, which "
    "discards low-variance directions carrying safety signal. Yet SAE features "
    "still capture interpretable concept structure that raw activations lack. "
    "To exploit both strengths, we describe InterpGuard, a practical two-stage "
    "recipe that detects with raw activations and explains with SAE features. "
    "An LLM-as-judge evaluation across three frontier models reveals a "
    "bottleneck: current SAE labels identify that a prompt is harmful but not "
    "what kind of harm. We also show the gap is not fundamental: fine-tuning "
    "the SAE encoder with a classification-aware objective nearly closes it, "
    "confirming the problem lies in the training objective, not the architecture."
)

# BibTeX entry rendered in a gr.Code widget on the About tab.
BIBTEX_CITATION = r"""@article{rahman2026saeguardbench,
title = {Do SAE Features Actually Help Detect Jailbreaks?
A Systematic Benchmark of Interpretability-Based Safety Methods},
author = {Rahman, Md A},
year = {2026},
url = {https://github.com/ronyrahmaan/saeguardbench}
}"""
# ---------------------------------------------------------------------------
# Data helpers
# ---------------------------------------------------------------------------
def _build_hardcoded_leaderboard() -> pd.DataFrame:
    """Build the leaderboard from hardcoded paper results.

    Returns a DataFrame with columns:
        Method, Paradigm, Model, Dataset, AUROC

    (The original docstring also advertised ``F1`` and ``FPR@95TPR`` columns,
    but no such values are built here — and ``load_leaderboard`` only requires
    the five columns above — so the claim was wrong and has been removed.)

    All hardcoded scores are for Gemma-2-2B (see METHOD_RESULTS), so the
    ``Model`` column is constant; one row is emitted per (method, dataset).
    """
    rows: list[dict[str, Any]] = []
    model = "Gemma-2-2B"
    for method, scores in METHOD_RESULTS.items():
        paradigm = METHOD_PARADIGMS[method]
        # Scores are ordered to match DATASETS, so zip pairs them correctly.
        for dataset, auroc in zip(DATASETS, scores):
            rows.append(
                {
                    "Method": method,
                    "Paradigm": paradigm,
                    "Model": model,
                    "Dataset": dataset,
                    "AUROC": round(auroc, 3),
                }
            )
    return pd.DataFrame(rows)
def load_leaderboard() -> pd.DataFrame:
    """Load the leaderboard CSV, falling back to hardcoded data."""
    # Best-effort load: any read failure or schema mismatch silently falls
    # through to the hardcoded paper numbers.
    if LEADERBOARD_CSV.exists():
        frame: pd.DataFrame | None
        try:
            frame = pd.read_csv(LEADERBOARD_CSV)
        except Exception:
            frame = None
        expected_cols = {"Method", "Paradigm", "Model", "Dataset", "AUROC"}
        if frame is not None and expected_cols <= set(frame.columns):
            return frame
    return _build_hardcoded_leaderboard()
def _color_auroc(val: float) -> str:
"""Return an HTML-styled AUROC value with color coding."""
if val >= 0.9:
color = "#15803d" # green-700
elif val >= 0.7:
color = "#a16207" # yellow-700
else:
color = "#b91c1c" # red-700
return f"<span style='color:{color};font-weight:600'>{val:.3f}</span>"
# ---------------------------------------------------------------------------
# Tab builders
# ---------------------------------------------------------------------------
def build_leaderboard_tab() -> None:
    """Tab 1: Sortable leaderboard with filters."""
    full_df = load_leaderboard()

    # Headline banner: mean Detection Gap over the models in DETECTION_GAP_DATA.
    avg_gap = np.mean([d["gap"] for d in DETECTION_GAP_DATA])
    gr.Markdown(
        f"### Detection Gap (average across models): "
        f"**{avg_gap:+.3f}** AUROC\n"
        "*SAE features hurt detection on every model tested.*"
    )

    def _dropdown_choices(column: str) -> list[str]:
        # "All" (no filtering) first, then the column's distinct values sorted.
        return ["All", *sorted(full_df[column].unique().tolist())]

    with gr.Row():
        model_filter = gr.Dropdown(
            choices=_dropdown_choices("Model"), value="All", label="Model"
        )
        dataset_filter = gr.Dropdown(
            choices=_dropdown_choices("Dataset"), value="All", label="Dataset"
        )
        paradigm_filter = gr.Dropdown(
            choices=_dropdown_choices("Paradigm"), value="All", label="Paradigm"
        )

    table = gr.Dataframe(
        value=full_df,
        label="Benchmark Results",
        interactive=False,
    )

    def _filter_table(model: str, dataset: str, paradigm: str) -> pd.DataFrame:
        """Filter the leaderboard by the selected criteria."""
        filtered = full_df.copy()
        for column, choice in (
            ("Model", model),
            ("Dataset", dataset),
            ("Paradigm", paradigm),
        ):
            if choice != "All":
                filtered = filtered[filtered[column] == choice]
        # Best methods first; reindex so row numbers read 0..n-1.
        return filtered.sort_values("AUROC", ascending=False).reset_index(drop=True)

    # Any dropdown change re-filters using the current value of all three.
    for dropdown in (model_filter, dataset_filter, paradigm_filter):
        dropdown.change(
            fn=_filter_table,
            inputs=[model_filter, dataset_filter, paradigm_filter],
            outputs=table,
        )
def build_detection_gap_tab() -> None:
    """Tab 2: Detection Gap bar charts."""
    df = pd.DataFrame(DETECTION_GAP_DATA)

    # Cosmetic settings shared by both charts on this tab.
    shared_layout = {
        "xaxis_title": "Model",
        "template": "plotly_white",
        "height": 450,
        "font": {"family": "Inter, system-ui, sans-serif"},
    }

    # --- Gap bar chart: one signed bar per model, colored by sign ---
    gaps = df["gap"]
    fig_gap = go.Figure(
        go.Bar(
            x=df["model"],
            y=gaps,
            marker_color=[COLOR_GAP_POS if g >= 0 else COLOR_GAP_NEG for g in gaps],
            text=[f"{g:+.3f}" for g in gaps],
            textposition="outside",
        )
    )
    fig_gap.update_layout(
        title="Detection Gap per Model (SAE AUROC minus Raw AUROC)",
        yaxis_title="Detection Gap (AUROC)",
        # Pad the y-range so the "outside" bar labels are not clipped.
        yaxis_range=[
            min(gaps) - 0.08,
            max(0.05, max(gaps) + 0.08),
        ],
        **shared_layout,
    )
    fig_gap.add_hline(y=0, line_dash="dash", line_color="gray")
    gr.Markdown(
        "### Detection Gap\n"
        "Negative values mean SAE features *hurt* detection vs. raw probes."
    )
    gr.Plot(fig_gap)

    # --- Grouped bar chart: raw-probe vs SAE AUROC side by side ---
    fig_grouped = go.Figure()
    for trace_name, column, bar_color in (
        ("Raw Probe AUROC", "raw_auroc", COLOR_RAW),
        ("SAE AUROC", "sae_auroc", COLOR_SAE),
    ):
        fig_grouped.add_trace(
            go.Bar(
                name=trace_name,
                x=df["model"],
                y=df[column],
                marker_color=bar_color,
                text=[f"{v:.3f}" for v in df[column]],
                textposition="outside",
            )
        )
    fig_grouped.update_layout(
        title="Raw Probe vs SAE Detection AUROC",
        yaxis_title="AUROC",
        yaxis_range=[0, 1.12],
        barmode="group",
        legend={"orientation": "h", "yanchor": "bottom", "y": 1.02},
        **shared_layout,
    )
    gr.Plot(fig_grouped)
def build_method_comparison_tab() -> None:
    """Tab 3: Radar/spider chart comparing methods across datasets.

    Fix: the original created the ``gr.Plot`` empty and then assigned
    ``plot_output.value = _make_radar(...)`` after construction (the comment
    mentioned ``app.load`` but no such hook existed). Mutating ``.value``
    post-construction is not the documented way to set a component's initial
    content, so the chart could start blank. The figure is now passed as
    ``value=`` when the component is built.
    """
    method_names = list(METHOD_RESULTS.keys())
    default_methods = ["Linear Probe", "SAE-Classifier", "CC-Delta"]

    def _make_radar(selected: list[str]) -> go.Figure:
        """Build a radar chart for the selected methods.

        Falls back to the default selection when the checkbox group is
        emptied, so the plot never goes blank.
        """
        if not selected:
            selected = default_methods
        fig = go.Figure()
        # Plotly radar needs the first category repeated to close the polygon
        categories = [*DATASETS, DATASETS[0]]
        # Use a qualitative color scale
        colors = [
            "#0072B2", "#E69F00", "#009E73", "#CC79A7",
            "#D55E00", "#56B4E9", "#F0E442", "#999999",
        ]
        for i, method in enumerate(selected):
            # Ignore unknown names defensively; gradio should only pass choices.
            if method not in METHOD_RESULTS:
                continue
            # Repeat the first value to close the polygon (matches categories).
            values = METHOD_RESULTS[method] + [METHOD_RESULTS[method][0]]
            color = colors[i % len(colors)]
            fig.add_trace(
                go.Scatterpolar(
                    r=values,
                    theta=categories,
                    fill="toself",
                    name=method,
                    line_color=color,
                    opacity=0.7,
                )
            )
        fig.update_layout(
            polar={
                "radialaxis": {
                    "visible": True,
                    "range": [0.3, 1.05],
                    "tickvals": [0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
                }
            },
            title="Method Comparison (Gemma-2-2B, Layer 12, AUROC)",
            template="plotly_white",
            height=550,
            font={"family": "Inter, system-ui, sans-serif"},
            legend={"orientation": "h", "yanchor": "bottom", "y": -0.15},
        )
        return fig

    method_selector = gr.CheckboxGroup(
        choices=method_names,
        value=default_methods,
        label="Select methods to compare",
    )
    # Supply the initial figure at construction time so the tab renders a
    # populated chart before any user interaction.
    plot_output = gr.Plot(value=_make_radar(default_methods))
    method_selector.change(
        fn=_make_radar, inputs=method_selector, outputs=plot_output
    )
def build_figures_tab() -> None:
    """Tab 4: Gallery of paper figures with captions."""
    shown = 0
    for filename, caption in FIGURE_ENTRIES:
        # A figure may live in figures/ or paper/figures/; prefer the former.
        candidates = (FIGURES_DIR / filename, PAPER_FIGURES_DIR / filename)
        path = next((p for p in candidates if p.exists()), None)
        if path is None:
            continue
        shown += 1
        gr.Markdown(f"**{caption}**")
        gr.Image(str(path), label=filename, show_label=False)
        gr.Markdown("---")
    if shown == 0:
        # Placeholder so the tab is never completely empty.
        gr.Markdown(
            "*No figures found. Place PNG files in `figures/` or "
            "`paper/figures/` to display them here.*"
        )
def build_about_tab() -> None:
    """Tab 5: Paper abstract, links, and citation."""
    # Leading markdown sections, rendered in order ("---" is a divider).
    markdown_sections = (
        "## Abstract",
        PAPER_ABSTRACT,
        "---",
        "## Links",
        (
            "- **Code**: [github.com/ronyrahmaan/saeguardbench]"
            "(https://github.com/ronyrahmaan/saeguardbench)\n"
            "- **Dataset**: [HuggingFace](https://huggingface.co/datasets/mdarahmanxAI/SAEGuardBench)\n"
            "- **Paper**: [PDF](https://github.com/ronyrahmaan/saeguardbench/blob/main/paper.pdf)"
        ),
        "---",
        "## Citation",
    )
    for section in markdown_sections:
        gr.Markdown(section)
    # BibTeX goes in a code widget so it can be copied verbatim.
    gr.Code(BIBTEX_CITATION, language=None, label="BibTeX")
    gr.Markdown("---")
    gr.Markdown(
        "## Authors\n\n"
        "**Md A Rahman** \n"
        "Department of Computer Science, Texas Tech University \n"
        "[ara02434@ttu.edu](mailto:ara02434@ttu.edu)"
    )
# ---------------------------------------------------------------------------
# App assembly
# ---------------------------------------------------------------------------
def create_app() -> gr.Blocks:
    """Create and return the Gradio Blocks app."""
    # Tab label -> builder, rendered in declaration order.
    tab_builders = (
        ("Leaderboard", build_leaderboard_tab),
        ("Detection Gap", build_detection_gap_tab),
        ("Method Comparison", build_method_comparison_tab),
        ("Figures", build_figures_tab),
        ("About", build_about_tab),
    )
    with gr.Blocks(title=APP_TITLE) as app:
        gr.Markdown(f"# {APP_TITLE}")
        gr.Markdown(
            "A systematic benchmark of interpretability-based safety methods "
            "for LLM jailbreak detection. **Key finding:** SAE features "
            "consistently *hurt* detection compared to raw activation probes."
        )
        with gr.Tabs():
            for label, builder in tab_builders:
                with gr.TabItem(label):
                    builder()
    return app
if __name__ == "__main__":
    # Script entry point: build the Blocks UI and start the Gradio server.
    demo = create_app()
    demo.launch()