# GuardLLM / app.py
# Uploaded with huggingface_hub by AlephBeth-AI (commit 1569836, verified).
"""
GuardLLM - Interactive Prompt Security Visualizer
Combines t-SNE embedding visualization with real-time prompt risk analysis.
Powered by Llama Prompt Guard 2 (86M) and neuralchemy/Prompt-injection-dataset.
"""
import html
import json
import logging
import sys
import threading
import traceback
from pathlib import Path

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
import torch
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"
# Send everything to stdout so HF Spaces / container logs capture it.
logging.basicConfig(
    format=_LOG_FORMAT,
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("GuardLLM")
# ---------------------------------------------------------------------------
# Color palette for categories
# ---------------------------------------------------------------------------
# Hex color per dataset category; consumed by build_tsne_figure() traces and
# build_stats_html(). Call sites use .get(cat, CATEGORY_COLORS["unknown"]),
# so "unknown" doubles as the fallback for any unlisted category.
CATEGORY_COLORS = {
    "benign": "#22c55e",
    "direct_injection": "#ef4444",
    "jailbreak": "#f97316",
    "system_extraction": "#a855f7",
    "encoding_obfuscation": "#ec4899",
    "persona_replacement": "#f59e0b",
    "indirect_injection": "#e11d48",
    "token_smuggling": "#7c3aed",
    "many_shot": "#06b6d4",
    "crescendo": "#14b8a6",
    "context_overflow": "#8b5cf6",
    "prompt_leaking": "#d946ef",
    "unknown": "#64748b",
}
# Human-readable display name per category, used in the legend, hover text
# and stats panel. Call sites use .get(cat, cat), so a category missing here
# falls back to its raw key.
CATEGORY_LABELS = {
    "benign": "Benign",
    "direct_injection": "Direct Injection",
    "jailbreak": "Jailbreak",
    "system_extraction": "System Extraction",
    "encoding_obfuscation": "Encoding / Obfuscation",
    "persona_replacement": "Persona Replacement",
    "indirect_injection": "Indirect Injection",
    "token_smuggling": "Token Smuggling",
    "many_shot": "Many-Shot",
    "crescendo": "Crescendo",
    "context_overflow": "Context Overflow",
    "prompt_leaking": "Prompt Leaking",
    "unknown": "Unknown",
}
# ---------------------------------------------------------------------------
# Lazy-loaded risk classifier (Llama Prompt Guard 2)
# ---------------------------------------------------------------------------
MODEL_ID = "meta-llama/Llama-Prompt-Guard-2-86M"
# Index order matches the model's output logits: 0 = Benign, 1 = Malicious.
LABELS = ["Benign", "Malicious"]
# Lazily populated singleton; guarded by _classifier_lock during first load.
_classifier = {"tokenizer": None, "model": None, "device": None}
_classifier_lock = threading.Lock()
def get_classifier():
    """Return (tokenizer, model, device), loading the model on first call.

    Gradio executes callbacks on worker threads, so the first-load path is
    protected by a lock: without it, two concurrent first requests could
    each download/load the model, or a reader could observe a partially
    populated ``_classifier``.
    """
    if _classifier["model"] is None:
        with _classifier_lock:
            # Double-checked: another thread may have finished loading
            # while we waited on the lock.
            if _classifier["model"] is None:
                logger.info("Lazy-loading Llama Prompt Guard 2...")
                # Imported here so the app can start (and render the map)
                # before transformers is needed.
                from transformers import AutoTokenizer, AutoModelForSequenceClassification
                tok = AutoTokenizer.from_pretrained(MODEL_ID)
                mdl = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
                mdl.eval()  # inference only
                dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                mdl.to(dev)
                _classifier["tokenizer"] = tok
                _classifier["device"] = dev
                # Set "model" last: unlocked readers gate on this key, so
                # they never see a half-initialized classifier.
                _classifier["model"] = mdl
                logger.info("Classifier loaded on %s", dev)
    return _classifier["tokenizer"], _classifier["model"], _classifier["device"]
# ---------------------------------------------------------------------------
# Load precomputed t-SNE data
# ---------------------------------------------------------------------------
CACHE_DIR = Path(__file__).parent / "cache"
CACHE_FILE = CACHE_DIR / "embeddings_tsne.npz"
META_FILE = CACHE_DIR / "metadata.json"
logger.info("Loading precomputed t-SNE cache from %s", CACHE_DIR)
# Both artifacts are produced offline by precompute.py; refuse to start
# without them rather than recomputing t-SNE at boot.
if not (CACHE_FILE.exists() and META_FILE.exists()):
    raise RuntimeError(
        "Cache files not found in %s. Run precompute.py first." % CACHE_DIR
    )
_npz = np.load(CACHE_FILE)
TSNE_COORDS = _npz["tsne_2d"]
with open(META_FILE, "r", encoding="utf-8") as f:
    METADATA = json.load(f)
logger.info("Loaded %d points for visualization", len(METADATA))
# Column-oriented views over the metadata records, indexed in lockstep.
ALL_TEXTS = [m["text"] for m in METADATA]
ALL_CATEGORIES = [m["category"] for m in METADATA]
ALL_SEVERITIES = [m["severity"] for m in METADATA]
ALL_LABELS_DS = [m["label"] for m in METADATA]
UNIQUE_CATEGORIES = sorted(set(ALL_CATEGORIES))
def _dropdown_choice(position, record):
    """Format one searchable dropdown entry: '<idx> | <category> | <preview>'."""
    snippet = record["text"][:70].replace("\n", " ")
    if len(record["text"]) > 70:
        snippet += "..."
    return f"{position} | {record['category']} | {snippet}"
DROPDOWN_CHOICES = [_dropdown_choice(i, m) for i, m in enumerate(METADATA)]
# ---------------------------------------------------------------------------
# Analysis function
# ---------------------------------------------------------------------------
def analyze_prompt(text):
    """Classify *text* with Prompt Guard.

    Returns a tuple ``(prob_dict, safety)``:
      - prob_dict: {"Benign": p0, "Malicious": p1} softmax probabilities
      - safety: P(Benign), a 0..1 "safety score"

    Empty/whitespace-only input short-circuits to ``({}, 0.0)`` without
    touching the model; callers are expected to guard against empty input
    before consuming prob_dict.
    """
    if not text or not text.strip():
        return {}, 0.0
    tokenizer, model, device = get_classifier()
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=512, padding=True
    ).to(device)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
    # (The original computed an argmax here but never used it — removed.)
    prob_dict = {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
    safety = float(probs[0])  # index 0 == "Benign"
    return prob_dict, safety
# ---------------------------------------------------------------------------
# Build the t-SNE Plotly figure
# ---------------------------------------------------------------------------
def build_tsne_figure(selected_categories=None):
    """Build the 2-D t-SNE scatter plot, one trace per category.

    selected_categories: iterable of category names to display, or None to
    show everything. Each point carries its dataset index in ``customdata``
    so the front-end click handler can report which prompt was clicked.
    """
    fig = go.Figure()
    # Normalize the filter once, outside the loops: the original re-evaluated
    # `cat in selected_categories` for every single point of every category.
    allowed = None if selected_categories is None else set(selected_categories)
    for cat in UNIQUE_CATEGORIES:
        if allowed is not None and cat not in allowed:
            continue
        indices = [i for i, c in enumerate(ALL_CATEGORIES) if c == cat]
        if not indices:
            continue
        x = TSNE_COORDS[indices, 0].tolist()
        y = TSNE_COORDS[indices, 1].tolist()
        texts_preview = [
            ALL_TEXTS[i][:80].replace("\n", " ") + ("..." if len(ALL_TEXTS[i]) > 80 else "")
            for i in indices
        ]
        # Benign rows carry no severity; fall back to "benign" for hover text.
        severities = [ALL_SEVERITIES[i] or "benign" for i in indices]
        hover_texts = [
            f"<b>{CATEGORY_LABELS.get(cat, cat)}</b><br>"
            f"Severity: {sev}<br>"
            f"Index: {idx}<br>"
            f"<i>{txt}</i>"
            for idx, txt, sev in zip(indices, texts_preview, severities)
        ]
        color = CATEGORY_COLORS.get(cat, CATEGORY_COLORS["unknown"])
        label = CATEGORY_LABELS.get(cat, cat)
        fig.add_trace(go.Scatter(
            x=x, y=y,
            mode="markers",
            name=label,
            marker=dict(
                # Shrink markers for dense categories to reduce overdraw.
                size=5 if len(indices) > 500 else 7,
                color=color,
                opacity=0.7,
                line=dict(width=0.5, color="rgba(255,255,255,0.2)"),
            ),
            text=hover_texts,
            hoverinfo="text",
            # Dataset index per point, read back by the JS click bridge.
            customdata=[str(i) for i in indices],
        ))
    fig.update_layout(
        template="plotly_dark",
        paper_bgcolor="#0f172a",
        plot_bgcolor="#1e293b",
        title=dict(
            text="t-SNE Embedding Space - Prompt Security Landscape",
            font=dict(size=16, color="#e2e8f0"),
            x=0.5,
        ),
        legend=dict(
            title=dict(text="Category", font=dict(color="#94a3b8")),
            bgcolor="rgba(15,23,42,0.9)",
            bordercolor="#334155",
            borderwidth=1,
            font=dict(color="#cbd5e1", size=10),
            itemsizing="constant",
        ),
        xaxis=dict(
            title="t-SNE 1", showgrid=True, gridcolor="#334155",
            zeroline=False, color="#94a3b8",
        ),
        yaxis=dict(
            title="t-SNE 2", showgrid=True, gridcolor="#334155",
            zeroline=False, color="#94a3b8",
        ),
        margin=dict(l=40, r=40, t=50, b=40),
        height=600,
        dragmode="pan",
    )
    return fig
# ---------------------------------------------------------------------------
# Callbacks
# ---------------------------------------------------------------------------
def on_filter_change(categories):
    """Redraw the map for the checked categories (empty selection -> show all)."""
    return build_tsne_figure(categories or None)
def select_all_categories():
    """Check every category box and redraw the full map."""
    checkbox_update = gr.update(value=UNIQUE_CATEGORIES)
    figure = build_tsne_figure(UNIQUE_CATEGORIES)
    return checkbox_update, figure
def deselect_all_categories():
    """Uncheck everything; the rebuilt figure then contains no visible traces."""
    checkbox_update = gr.update(value=[])
    figure = build_tsne_figure([])
    return checkbox_update, figure
def _analyze_dataset_entry(idx):
    """Run the classifier on dataset entry *idx* and build the three UI outputs.

    Returns (result_html, risk_markdown, full_text). Shared by the dropdown
    and chart-click callbacks, which previously duplicated this logic.
    """
    text = ALL_TEXTS[idx]
    category = ALL_CATEGORIES[idx]
    severity = ALL_SEVERITIES[idx] or "N/A"
    ground_truth = "Malicious" if ALL_LABELS_DS[idx] == 1 else "Benign"
    prob_dict, safety = analyze_prompt(text)
    pred_label = max(prob_dict, key=prob_dict.get)
    confidence = prob_dict[pred_label]
    result_html = build_result_html(pred_label, confidence, prob_dict, text)
    risk_text = build_risk_assessment(pred_label, confidence, prob_dict)
    # Append the dataset's own labeling so users can compare model vs truth.
    risk_text += (
        f"\n\n---\n**Dataset metadata:**\n"
        f"- Category: **{CATEGORY_LABELS.get(category, category)}**\n"
        f"- Severity: **{severity}**\n"
        f"- Ground truth: **{ground_truth}**\n"
    )
    return result_html, risk_text, text
def on_dropdown_select(choice):
    """Handle dropdown selection; `choice` looks like '<idx> | <category> | <preview>'."""
    if not choice:
        return empty_analysis_html(), "*Select a prompt.*", ""
    try:
        idx = int(choice.split(" | ")[0])
        return _analyze_dataset_entry(idx)
    except Exception as e:
        # logger.exception keeps the traceback (the bare message was dropped before).
        logger.exception("Error: %s", e)
        return empty_analysis_html(), f"Error: {e}", ""
def on_index_input(idx_str):
    """Handle a chart click relayed through the hidden index textbox."""
    if not idx_str or not idx_str.strip():
        return empty_analysis_html(), "*Click a point on the chart.*", ""
    try:
        idx = int(idx_str.strip())
        if idx < 0 or idx >= len(ALL_TEXTS):
            return empty_analysis_html(), f"Invalid index: {idx}", ""
        return _analyze_dataset_entry(idx)
    except Exception as e:
        logger.exception("Error: %s", e)
        return empty_analysis_html(), f"Error: {e}", ""
def on_manual_analyze(text):
    """Classify a user-typed prompt; returns (result_html, risk_markdown)."""
    if not text or not text.strip():
        return empty_analysis_html(), ""
    probs, _safety = analyze_prompt(text)
    top_label = max(probs, key=probs.get)
    top_conf = probs[top_label]
    card = build_result_html(top_label, top_conf, probs, text)
    risk = build_risk_assessment(top_label, top_conf, probs)
    return card, risk
# ---------------------------------------------------------------------------
# UI builders
# ---------------------------------------------------------------------------
def empty_analysis_html():
    """Return the placeholder card shown before anything has been analyzed."""
    placeholder = """
    <div style="text-align:center; padding:30px; color:#94a3b8;">
    <p style="font-size:1em;">Click a point on the chart,<br>
    select a prompt from the list,<br>
    or enter a custom prompt below.</p>
    </div>
    """
    return placeholder
def build_result_html(label, confidence, probs, text):
    """Render the classification result card as HTML.

    label:      predicted class name ("Benign" or "Malicious")
    confidence: probability of the predicted class, 0..1
    probs:      {"Benign": float, "Malicious": float} softmax probabilities
                (insertion order is Benign-first everywhere it is built)
    text:       the analyzed prompt; previewed truncated and HTML-escaped
    """
    color = "#22c55e" if label == "Benign" else "#ef4444"
    emoji = "\u2705" if label == "Benign" else "\u26a0\ufe0f"
    pct = confidence * 100
    safety_score = probs["Benign"] * 100
    # Traffic-light color for the overall safety score.
    safety_color = (
        "#22c55e" if safety_score >= 70
        else "#f59e0b" if safety_score >= 40
        else "#ef4444"
    )
    bars_html = ""
    # One probability bar per class, in the dict's (Benign-first) order.
    for lbl, frac in probs.items():
        p = frac * 100
        c = "#22c55e" if lbl == "Benign" else "#ef4444"
        bars_html += f"""
        <div style="margin-bottom:8px;">
        <div style="display:flex; justify-content:space-between; margin-bottom:2px;">
        <span style="font-weight:600; color:#e2e8f0;">{lbl}</span>
        <span style="color:#cbd5e1; font-weight:600;">{p:.1f}%</span>
        </div>
        <div style="background:#1e293b; border-radius:8px; height:18px; overflow:hidden;">
        <div style="background:{c}; height:100%; width:{p}%; border-radius:8px;"></div>
        </div>
        </div>
        """
    # html.escape also handles '&' (the manual <,> replace() missed it), so
    # user-typed "&lt;" no longer renders back as "<". quote=False matches
    # the original behavior of leaving quotes alone in text content.
    preview = html.escape(text[:150], quote=False)
    if len(text) > 150:
        preview += "..."
    return f"""
    <div style="background:#0f172a; border-radius:12px; padding:18px; font-family:system-ui,sans-serif;">
    <div style="text-align:center; margin-bottom:14px;">
    <div style="font-size:2em;">{emoji}</div>
    <div style="font-size:1.2em; font-weight:700; color:{color};">{label}</div>
    <div style="color:#94a3b8; font-size:0.85em;">Confidence: {pct:.1f}%</div>
    </div>
    <div style="background:#1e293b; border-radius:10px; padding:12px; margin-bottom:10px;">
    <div style="display:flex; justify-content:space-between; margin-bottom:4px;">
    <span style="color:#e2e8f0; font-weight:600;">Safety Score</span>
    <span style="color:{safety_color}; font-weight:700; font-size:1.1em;">{safety_score:.0f}/100</span>
    </div>
    <div style="background:#334155; border-radius:8px; height:12px; overflow:hidden;">
    <div style="background:linear-gradient(90deg, #ef4444, #f59e0b, #22c55e);
    height:100%; width:{safety_score}%; border-radius:8px;"></div>
    </div>
    </div>
    <div style="background:#1e293b; border-radius:10px; padding:12px; margin-bottom:10px;">
    {bars_html}
    </div>
    <div style="background:#1e293b; border-radius:10px; padding:12px;">
    <div style="color:#94a3b8; font-size:0.8em; margin-bottom:3px;">Analyzed prompt:</div>
    <div style="color:#cbd5e1; font-style:italic; word-break:break-word; font-size:0.85em;">"{preview}"</div>
    </div>
    </div>
    """
def build_risk_assessment(label, confidence, probs):
    """Render the markdown risk summary for one classification result.

    The level is derived from the predicted class and whether the model's
    confidence clears the 0.85 threshold.
    """
    safety_score = probs["Benign"] * 100
    malicious_score = probs["Malicious"] * 100
    is_benign = label == "Benign"
    high_confidence = confidence > 0.85
    if is_benign and high_confidence:
        level = "Low"
        desc = "This prompt appears **safe**. No injection or jailbreak patterns detected."
    elif is_benign:
        level = "Moderate"
        desc = "Likely benign, but moderate confidence. Potentially ambiguous wording."
    elif high_confidence:
        level = "Critical"
        desc = "**Malicious prompt detected** with high confidence. Likely injection or jailbreak attempt."
    else:
        level = "High"
        desc = "**Malicious prompt detected.** Possible injection or jailbreak. Review recommended."
    parts = [
        f"### Risk Level: {level}\n\n{desc}\n\n",
        "**Details:**\n",
        f"- Safety score: **{safety_score:.0f}/100**\n",
        f"- Predicted class: **{label}** ({confidence*100:.1f}%)\n",
        f"- P(Benign) = {probs['Benign']*100:.1f}% | P(Malicious) = {malicious_score:.1f}%\n",
    ]
    return "".join(parts)
def build_stats_html():
    """Render the static dataset-statistics panel (computed once at startup)."""
    total = len(METADATA)
    n_benign = sum(1 for m in METADATA if m["label"] == 0)
    n_malicious = total - n_benign
    # Per-category counts (label == 0 means benign in this dataset).
    cat_counts = {}
    for m in METADATA:
        cat_counts[m["category"]] = cat_counts.get(m["category"], 0) + 1
    rows = []
    # Largest categories first; ties keep first-seen order (stable sort).
    for cat, count in sorted(cat_counts.items(), key=lambda kv: -kv[1]):
        color = CATEGORY_COLORS.get(cat, CATEGORY_COLORS["unknown"])
        pct = count / total * 100
        label = CATEGORY_LABELS.get(cat, cat)
        rows.append(
            f'<div style="display:flex; justify-content:space-between; padding:2px 0;">'
            f'<span style="color:{color}; font-weight:500; font-size:0.85em;">{label}</span>'
            f'<span style="color:#94a3b8; font-size:0.85em;">{count} ({pct:.1f}%)</span>'
            f'</div>'
        )
    cats_html = "".join(rows)
    return f"""
    <div style="background:#0f172a; border-radius:12px; padding:14px; font-family:system-ui,sans-serif;">
    <div style="color:#e2e8f0; font-weight:700; margin-bottom:8px;">Dataset Statistics</div>
    <div style="display:flex; gap:10px; margin-bottom:10px;">
    <div style="flex:1; background:#1e293b; border-radius:8px; padding:8px; text-align:center;">
    <div style="color:#94a3b8; font-size:0.75em;">Total</div>
    <div style="color:#e2e8f0; font-weight:700; font-size:1.2em;">{total:,}</div>
    </div>
    <div style="flex:1; background:#1e293b; border-radius:8px; padding:8px; text-align:center;">
    <div style="color:#22c55e; font-size:0.75em;">Benign</div>
    <div style="color:#22c55e; font-weight:700; font-size:1.2em;">{n_benign:,}</div>
    </div>
    <div style="flex:1; background:#1e293b; border-radius:8px; padding:8px; text-align:center;">
    <div style="color:#ef4444; font-size:0.75em;">Malicious</div>
    <div style="color:#ef4444; font-weight:700; font-size:1.2em;">{n_malicious:,}</div>
    </div>
    </div>
    <div style="background:#1e293b; border-radius:8px; padding:8px;">
    {cats_html}
    </div>
    </div>
    """
# ---------------------------------------------------------------------------
# JavaScript to bridge Plotly clicks -> Gradio
# ---------------------------------------------------------------------------
# Injected once via demo.load(js=...). On a 'plotly_click' event the handler
# copies the clicked point's customdata (its dataset index) into the hidden
# #click-index-input textbox, then fires native 'input'/'change' events
# through the prototype's value setter so Gradio's reactive .change()
# callback observes the write. A MutationObserver re-attaches the handler
# whenever Gradio re-renders the plot (e.g. after a category-filter change).
# NOTE(review): the initial setup is retried on a timer because the plot
# element may not exist yet when the page loads.
PLOTLY_CLICK_JS = """
() => {
function setupClickHandler() {
const plotEl = document.querySelector('#tsne-chart .js-plotly-plot');
if (!plotEl) {
setTimeout(setupClickHandler, 500);
return;
}
function handleClick(data) {
if (data && data.points && data.points.length > 0) {
const idx = data.points[0].customdata;
if (idx !== undefined && idx !== null) {
const inputEl = document.querySelector('#click-index-input textarea') || document.querySelector('#click-index-input input');
if (inputEl) {
const proto = inputEl.tagName === 'TEXTAREA'
? window.HTMLTextAreaElement.prototype
: window.HTMLInputElement.prototype;
const nativeSetter = Object.getOwnPropertyDescriptor(proto, 'value').set;
nativeSetter.call(inputEl, String(idx));
inputEl.dispatchEvent(new Event('input', { bubbles: true }));
setTimeout(() => {
inputEl.dispatchEvent(new Event('change', { bubbles: true }));
}, 50);
}
}
}
}
plotEl.on('plotly_click', handleClick);
const observer = new MutationObserver(() => {
const newPlot = document.querySelector('#tsne-chart .js-plotly-plot');
if (newPlot && !newPlot._hasClickHandler) {
newPlot._hasClickHandler = true;
newPlot.on('plotly_click', handleClick);
}
});
observer.observe(document.querySelector('#tsne-chart') || document.body, {
childList: true, subtree: true
});
}
setTimeout(setupClickHandler, 1000);
}
"""
# ---------------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------------
# Static page header: app title plus links to the model and dataset cards.
TITLE_HTML = """
<div style="text-align:center; padding:10px 0 4px 0;">
<h1 style="font-size:1.8em; margin:0;">GuardLLM - Prompt Security Visualizer</h1>
<p style="color:#94a3b8; font-size:0.95em; margin-top:4px;">
Interactive t-SNE embedding space &bull;
<a href="https://huggingface.co/meta-llama/Llama-Prompt-Guard-2-86M" target="_blank" style="color:#60a5fa;">
Llama Prompt Guard 2</a> &bull;
<a href="https://huggingface.co/datasets/neuralchemy/Prompt-injection-dataset" target="_blank" style="color:#60a5fa;">
neuralchemy dataset</a>
</p>
</div>
"""
# Static three-step usage guide rendered directly below the title.
HOW_TO_HTML = """
<div style="background:linear-gradient(135deg, #0f172a 0%, #1e293b 100%); border:1px solid #334155; border-radius:12px; padding:16px 20px; margin:0 0 8px 0; font-family:system-ui,sans-serif;">
<div style="color:#e2e8f0; font-weight:700; font-size:1em; margin-bottom:8px;">How to use this tool</div>
<div style="display:flex; flex-wrap:wrap; gap:12px;">
<div style="flex:1; min-width:180px; background:#1e293b; border-radius:8px; padding:10px 12px;">
<div style="color:#60a5fa; font-weight:600; font-size:0.85em; margin-bottom:4px;">1. Explore the map</div>
<div style="color:#94a3b8; font-size:0.8em; line-height:1.4;">Each dot represents a prompt from the dataset, positioned by semantic similarity. Colors indicate attack categories. Hover to preview, scroll to zoom, drag to pan.</div>
</div>
<div style="flex:1; min-width:180px; background:#1e293b; border-radius:8px; padding:10px 12px;">
<div style="color:#f59e0b; font-weight:600; font-size:0.85em; margin-bottom:4px;">2. Click to analyze</div>
<div style="color:#94a3b8; font-size:0.8em; line-height:1.4;">Click any point to run it through <strong style="color:#cbd5e1;">Llama Prompt Guard 2</strong>. The right panel will show the risk classification, safety score, and confidence breakdown.</div>
</div>
<div style="flex:1; min-width:180px; background:#1e293b; border-radius:8px; padding:10px 12px;">
<div style="color:#22c55e; font-weight:600; font-size:0.85em; margin-bottom:4px;">3. Test your own prompts</div>
<div style="color:#94a3b8; font-size:0.8em; line-height:1.4;">Type or paste any prompt in the <strong style="color:#cbd5e1;">Custom prompt</strong> field and hit Analyze to check if it would be flagged as an injection attempt.</div>
</div>
</div>
</div>
"""
# CSS that keeps the click-index bridge textbox in the DOM (so the Plotly
# click JS can write to it) while hiding it from the user.
# BUG FIX: custom CSS must be passed to gr.Blocks(); Blocks.launch() has no
# `css` parameter, so the original demo.launch(css=...) raised a TypeError.
_HIDDEN_INPUT_CSS = (
    "#click-index-input { position:absolute !important; width:1px !important; "
    "height:1px !important; overflow:hidden !important; opacity:0 !important; "
    "pointer-events:none !important; }"
)
with gr.Blocks(
    title="GuardLLM - Prompt Security Visualizer",
    css=_HIDDEN_INPUT_CSS,
) as demo:
    gr.HTML(TITLE_HTML)
    gr.HTML(HOW_TO_HTML)
    # Hidden bridge textbox: the front-end click handler writes the clicked
    # point's dataset index here, which triggers on_index_input below.
    # It must stay visible=True so the element exists in the DOM; the CSS
    # above hides it visually instead.
    click_index = gr.Textbox(
        value="",
        visible=True,
        elem_id="click-index-input",
    )
    with gr.Row():
        # ---- Left: t-SNE chart + filters ----
        with gr.Column(scale=3):
            with gr.Row():
                select_all_btn = gr.Button("Select All", size="sm", scale=1)
                deselect_all_btn = gr.Button("Deselect All", size="sm", scale=1)
            category_filter = gr.CheckboxGroup(
                choices=UNIQUE_CATEGORIES,
                value=UNIQUE_CATEGORIES,
                label="Filter by category",
                interactive=True,
            )
            tsne_plot = gr.Plot(
                value=build_tsne_figure(),
                label="t-SNE Space",
                elem_id="tsne-chart",
            )
            gr.Markdown(
                "*Click a point to analyze it. "
                "Hover to preview text. Use scroll wheel to zoom.*"
            )
        # ---- Right: Analysis first, then stats (swapped) ----
        with gr.Column(scale=2):
            gr.Markdown("### Analysis Result")
            result_html = gr.HTML(value=empty_analysis_html())
            risk_md = gr.Markdown(value="")
            full_prompt = gr.Textbox(label="Full prompt", lines=3, interactive=False, visible=True)
            gr.Markdown("---")
            gr.Markdown("### Select a prompt")
            prompt_dropdown = gr.Dropdown(
                choices=DROPDOWN_CHOICES,
                label="Search dataset",
                filterable=True,
                interactive=True,
            )
            gr.Markdown("### Or analyze a custom prompt")
            manual_input = gr.Textbox(
                label="Custom prompt",
                placeholder="Type or paste a prompt...",
                lines=2,
            )
            analyze_btn = gr.Button("Analyze", variant="primary")
            gr.Markdown("---")
            gr.HTML(build_stats_html())
    # ---- Events ----
    category_filter.change(
        fn=on_filter_change,
        inputs=[category_filter],
        outputs=[tsne_plot],
    )
    select_all_btn.click(
        fn=select_all_categories,
        inputs=[],
        outputs=[category_filter, tsne_plot],
    )
    deselect_all_btn.click(
        fn=deselect_all_categories,
        inputs=[],
        outputs=[category_filter, tsne_plot],
    )
    click_index.change(
        fn=on_index_input,
        inputs=[click_index],
        outputs=[result_html, risk_md, full_prompt],
    )
    prompt_dropdown.change(
        fn=on_dropdown_select,
        inputs=[prompt_dropdown],
        outputs=[result_html, risk_md, full_prompt],
    )
    analyze_btn.click(
        fn=on_manual_analyze,
        inputs=[manual_input],
        outputs=[result_html, risk_md],
    )
    manual_input.submit(
        fn=on_manual_analyze,
        inputs=[manual_input],
        outputs=[result_html, risk_md],
    )
    # JS-only event: attach the Plotly click bridge once the UI has loaded.
    demo.load(fn=None, inputs=None, outputs=None, js=PLOTLY_CLICK_JS)
    gr.Markdown(
        """
---
<div style="text-align:center; color:#64748b; font-size:0.8em;">
<strong>GuardLLM</strong> - Prompt Security Visualizer<br>
Model: <a href="https://huggingface.co/meta-llama/Llama-Prompt-Guard-2-86M">
Llama Prompt Guard 2 (86M)</a> by Meta |
Dataset: <a href="https://huggingface.co/datasets/neuralchemy/Prompt-injection-dataset">
neuralchemy/Prompt-injection-dataset</a>
</div>
"""
    )
logger.info("Gradio app built. Ready to launch.")
if __name__ == "__main__":
    # Custom CSS is supplied to gr.Blocks() above; launch() does not accept it.
    demo.launch()