# GuardLLM — Hugging Face Space (page header "Spaces: Running" was scrape residue)
| """ | |
| GuardLLM - Interactive Prompt Security Visualizer | |
| Combines t-SNE embedding visualization with real-time prompt risk analysis. | |
| Powered by Llama Prompt Guard 2 (86M) and neuralchemy/Prompt-injection-dataset. | |
| """ | |
import json
import logging
import sys
import traceback
from collections import Counter
from pathlib import Path

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
import torch
| # --------------------------------------------------------------------------- | |
| # Logging | |
| # --------------------------------------------------------------------------- | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format="%(asctime)s [%(levelname)s] %(message)s", | |
| handlers=[logging.StreamHandler(sys.stdout)], | |
| ) | |
| logger = logging.getLogger("GuardLLM") | |
| # --------------------------------------------------------------------------- | |
| # Color palette for categories | |
| # --------------------------------------------------------------------------- | |
| CATEGORY_COLORS = { | |
| "benign": "#22c55e", | |
| "direct_injection": "#ef4444", | |
| "jailbreak": "#f97316", | |
| "system_extraction": "#a855f7", | |
| "encoding_obfuscation": "#ec4899", | |
| "persona_replacement": "#f59e0b", | |
| "indirect_injection": "#e11d48", | |
| "token_smuggling": "#7c3aed", | |
| "many_shot": "#06b6d4", | |
| "crescendo": "#14b8a6", | |
| "context_overflow": "#8b5cf6", | |
| "prompt_leaking": "#d946ef", | |
| "unknown": "#64748b", | |
| } | |
| CATEGORY_LABELS = { | |
| "benign": "Benign", | |
| "direct_injection": "Direct Injection", | |
| "jailbreak": "Jailbreak", | |
| "system_extraction": "System Extraction", | |
| "encoding_obfuscation": "Encoding / Obfuscation", | |
| "persona_replacement": "Persona Replacement", | |
| "indirect_injection": "Indirect Injection", | |
| "token_smuggling": "Token Smuggling", | |
| "many_shot": "Many-Shot", | |
| "crescendo": "Crescendo", | |
| "context_overflow": "Context Overflow", | |
| "prompt_leaking": "Prompt Leaking", | |
| "unknown": "Unknown", | |
| } | |
| # --------------------------------------------------------------------------- | |
| # Lazy-loaded risk classifier (Llama Prompt Guard 2) | |
| # --------------------------------------------------------------------------- | |
| MODEL_ID = "meta-llama/Llama-Prompt-Guard-2-86M" | |
| LABELS = ["Benign", "Malicious"] | |
| _classifier = {"tokenizer": None, "model": None, "device": None} | |
def get_classifier():
    """Return the (tokenizer, model, device) triple, loading it on first use.

    The transformers import and model download are deferred until the first
    prompt is actually analyzed; later calls return the cached objects.
    """
    if _classifier["model"] is not None:
        return _classifier["tokenizer"], _classifier["model"], _classifier["device"]

    logger.info("Lazy-loading Llama Prompt Guard 2...")
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
    model.eval()
    model.to(device)

    _classifier["tokenizer"] = tokenizer
    _classifier["model"] = model
    _classifier["device"] = device
    logger.info("Classifier loaded on %s", device)
    return _classifier["tokenizer"], _classifier["model"], _classifier["device"]
| # --------------------------------------------------------------------------- | |
| # Load precomputed t-SNE data | |
| # --------------------------------------------------------------------------- | |
| CACHE_DIR = Path(__file__).parent / "cache" | |
| CACHE_FILE = CACHE_DIR / "embeddings_tsne.npz" | |
| META_FILE = CACHE_DIR / "metadata.json" | |
| logger.info("Loading precomputed t-SNE cache from %s", CACHE_DIR) | |
| if not CACHE_FILE.exists() or not META_FILE.exists(): | |
| raise RuntimeError( | |
| "Cache files not found in %s. Run precompute.py first." % CACHE_DIR | |
| ) | |
| _npz = np.load(CACHE_FILE) | |
| TSNE_COORDS = _npz["tsne_2d"] | |
| with open(META_FILE, "r", encoding="utf-8") as f: | |
| METADATA = json.load(f) | |
| logger.info("Loaded %d points for visualization", len(METADATA)) | |
| ALL_TEXTS = [m["text"] for m in METADATA] | |
| ALL_CATEGORIES = [m["category"] for m in METADATA] | |
| ALL_SEVERITIES = [m["severity"] for m in METADATA] | |
| ALL_LABELS_DS = [m["label"] for m in METADATA] | |
| UNIQUE_CATEGORIES = sorted(set(ALL_CATEGORIES)) | |
| DROPDOWN_CHOICES = [] | |
| for i, m in enumerate(METADATA): | |
| preview = m["text"][:70].replace("\n", " ") | |
| if len(m["text"]) > 70: | |
| preview += "..." | |
| DROPDOWN_CHOICES.append(f"{i} | {m['category']} | {preview}") | |
| # --------------------------------------------------------------------------- | |
| # Analysis function | |
| # --------------------------------------------------------------------------- | |
def analyze_prompt(text):
    """Classify *text* with Llama Prompt Guard 2.

    Args:
        text: Prompt to score; None/blank input short-circuits.

    Returns:
        Tuple ``(prob_dict, safety)`` where ``prob_dict`` maps each entry of
        LABELS to its softmax probability and ``safety`` is P(Benign).
        Returns ``({}, 0.0)`` for empty input.
    """
    if not text or not text.strip():
        return {}, 0.0
    tokenizer, model, device = get_classifier()
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=512, padding=True
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Softmax over the two logits; move to CPU for plain-float extraction.
    probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
    # NOTE: the previous version computed an unused `pred_idx` argmax here;
    # callers derive the predicted label from prob_dict themselves.
    prob_dict = {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
    # Safety score = probability of the Benign class (index 0 of LABELS).
    safety = float(probs[0])
    return prob_dict, safety
| # --------------------------------------------------------------------------- | |
| # Build the t-SNE Plotly figure | |
| # --------------------------------------------------------------------------- | |
def build_tsne_figure(selected_categories=None):
    """Build the t-SNE scatter figure, one Plotly trace per category.

    Args:
        selected_categories: Iterable of category keys to display, or None
            to show every category.

    Returns:
        A plotly.graph_objects.Figure styled for the dark theme, with each
        point's dataset index stored in ``customdata`` for the click bridge.
    """
    fig = go.Figure()
    for cat in UNIQUE_CATEGORIES:
        # Hoisted loop-invariant check: skip filtered-out categories before
        # scanning every point (previously this was re-tested per point
        # inside the comprehension).
        if selected_categories is not None and cat not in selected_categories:
            continue
        indices = [i for i, c in enumerate(ALL_CATEGORIES) if c == cat]
        if not indices:
            continue
        x = TSNE_COORDS[indices, 0].tolist()
        y = TSNE_COORDS[indices, 1].tolist()
        texts_preview = [
            ALL_TEXTS[i][:80].replace("\n", " ") + ("..." if len(ALL_TEXTS[i]) > 80 else "")
            for i in indices
        ]
        # Benign points have a null severity in the dataset; display "benign".
        severities = [ALL_SEVERITIES[i] or "benign" for i in indices]
        hover_texts = [
            f"<b>{CATEGORY_LABELS.get(cat, cat)}</b><br>"
            f"Severity: {sev}<br>"
            f"Index: {idx}<br>"
            f"<i>{txt}</i>"
            for idx, txt, sev in zip(indices, texts_preview, severities)
        ]
        color = CATEGORY_COLORS.get(cat, CATEGORY_COLORS["unknown"])
        label = CATEGORY_LABELS.get(cat, cat)
        fig.add_trace(go.Scatter(
            x=x, y=y,
            mode="markers",
            name=label,
            marker=dict(
                # Smaller markers for dense categories to reduce overplotting.
                size=5 if len(indices) > 500 else 7,
                color=color,
                opacity=0.7,
                line=dict(width=0.5, color="rgba(255,255,255,0.2)"),
            ),
            text=hover_texts,
            hoverinfo="text",
            # Dataset index per point; read by the plotly_click JS bridge.
            customdata=[str(i) for i in indices],
        ))
    fig.update_layout(
        template="plotly_dark",
        paper_bgcolor="#0f172a",
        plot_bgcolor="#1e293b",
        title=dict(
            text="t-SNE Embedding Space - Prompt Security Landscape",
            font=dict(size=16, color="#e2e8f0"),
            x=0.5,
        ),
        legend=dict(
            title=dict(text="Category", font=dict(color="#94a3b8")),
            bgcolor="rgba(15,23,42,0.9)",
            bordercolor="#334155",
            borderwidth=1,
            font=dict(color="#cbd5e1", size=10),
            itemsizing="constant",
        ),
        xaxis=dict(
            title="t-SNE 1", showgrid=True, gridcolor="#334155",
            zeroline=False, color="#94a3b8",
        ),
        yaxis=dict(
            title="t-SNE 2", showgrid=True, gridcolor="#334155",
            zeroline=False, color="#94a3b8",
        ),
        margin=dict(l=40, r=40, t=50, b=40),
        height=600,
        dragmode="pan",
    )
    return fig
| # --------------------------------------------------------------------------- | |
| # Callbacks | |
| # --------------------------------------------------------------------------- | |
def on_filter_change(categories):
    """Redraw the t-SNE plot for the checked categories.

    An empty/None selection maps to None, i.e. "no filter, show everything".
    """
    return build_tsne_figure(categories or None)
def select_all_categories():
    """Check every category box and redraw the full plot."""
    checkbox_update = gr.update(value=UNIQUE_CATEGORIES)
    figure = build_tsne_figure(UNIQUE_CATEGORIES)
    return checkbox_update, figure
def deselect_all_categories():
    """Uncheck every category box and redraw an empty plot."""
    figure = build_tsne_figure([])
    return gr.update(value=[]), figure
def _analyze_dataset_prompt(idx):
    """Shared worker for the dropdown and chart-click handlers.

    Runs the classifier on dataset prompt *idx* and renders the three UI
    outputs. (This code was previously duplicated verbatim in both
    on_dropdown_select and on_index_input.)

    Args:
        idx: Integer index into the dataset-aligned module lists.

    Returns:
        Tuple of (result_html, risk_markdown, full_prompt_text).
    """
    text = ALL_TEXTS[idx]
    category = ALL_CATEGORIES[idx]
    severity = ALL_SEVERITIES[idx] or "N/A"
    ground_truth = "Malicious" if ALL_LABELS_DS[idx] == 1 else "Benign"
    prob_dict, safety = analyze_prompt(text)
    pred_label = max(prob_dict, key=prob_dict.get)
    confidence = prob_dict[pred_label]
    result_html = build_result_html(pred_label, confidence, prob_dict, text)
    risk_text = build_risk_assessment(pred_label, confidence, prob_dict)
    risk_text += (
        f"\n\n---\n**Dataset metadata:**\n"
        f"- Category: **{CATEGORY_LABELS.get(category, category)}**\n"
        f"- Severity: **{severity}**\n"
        f"- Ground truth: **{ground_truth}**\n"
    )
    return result_html, risk_text, text

def on_dropdown_select(choice):
    """Handle a dropdown choice of the form '<idx> | <category> | <preview>'."""
    if not choice:
        return empty_analysis_html(), "*Select a prompt.*", ""
    try:
        idx = int(choice.split(" | ")[0])
        return _analyze_dataset_prompt(idx)
    except Exception as e:
        # logger.exception preserves the traceback (logger.error dropped it).
        logger.exception("Error handling dropdown selection")
        return empty_analysis_html(), f"Error: {e}", ""

def on_index_input(idx_str):
    """Handle a point index pushed from the hidden Plotly click-bridge textbox."""
    if not idx_str or not idx_str.strip():
        return empty_analysis_html(), "*Click a point on the chart.*", ""
    try:
        idx = int(idx_str.strip())
        if idx < 0 or idx >= len(ALL_TEXTS):
            return empty_analysis_html(), f"Invalid index: {idx}", ""
        return _analyze_dataset_prompt(idx)
    except Exception as e:
        logger.exception("Error handling chart click")
        return empty_analysis_html(), f"Error: {e}", ""
def on_manual_analyze(text):
    """Analyze a user-typed prompt; return (result_html, risk_markdown)."""
    if not text or not text.strip():
        return empty_analysis_html(), ""
    prob_dict, _safety = analyze_prompt(text)
    top_label = max(prob_dict, key=prob_dict.get)
    top_prob = prob_dict[top_label]
    return (
        build_result_html(top_label, top_prob, prob_dict, text),
        build_risk_assessment(top_label, top_prob, prob_dict),
    )
| # --------------------------------------------------------------------------- | |
| # UI builders | |
| # --------------------------------------------------------------------------- | |
def empty_analysis_html():
    """Return the placeholder card shown before any prompt has been analyzed."""
    placeholder = """
<div style="text-align:center; padding:30px; color:#94a3b8;">
    <p style="font-size:1em;">Click a point on the chart,<br>
    select a prompt from the list,<br>
    or enter a custom prompt below.</p>
</div>
"""
    return placeholder
def build_result_html(label, confidence, probs, text):
    """Render the classification result card as HTML for a gr.HTML component.

    Args:
        label: Predicted class name ("Benign" or "Malicious").
        confidence: Probability of the predicted class, in [0, 1].
        probs: Mapping of every LABELS entry to its probability.
        text: The analyzed prompt; a 150-char preview is embedded, escaped.

    Returns:
        HTML string with verdict, safety gauge, per-class bars, and preview.
    """
    color = "#22c55e" if label == "Benign" else "#ef4444"
    emoji = "\u2705" if label == "Benign" else "\u26a0\ufe0f"
    pct = confidence * 100
    safety_score = probs["Benign"] * 100
    # Traffic-light color for the safety gauge.
    safety_color = (
        "#22c55e" if safety_score >= 70
        else "#f59e0b" if safety_score >= 40
        else "#ef4444"
    )
    bars_html = ""
    for lbl in LABELS:
        p = probs[lbl] * 100
        c = "#22c55e" if lbl == "Benign" else "#ef4444"
        bars_html += f"""
        <div style="margin-bottom:8px;">
            <div style="display:flex; justify-content:space-between; margin-bottom:2px;">
                <span style="font-weight:600; color:#e2e8f0;">{lbl}</span>
                <span style="color:#cbd5e1; font-weight:600;">{p:.1f}%</span>
            </div>
            <div style="background:#1e293b; border-radius:8px; height:18px; overflow:hidden;">
                <div style="background:{c}; height:100%; width:{p}%; border-radius:8px;"></div>
            </div>
        </div>
        """
    # BUG FIX: the previous replace("<", "<") / replace(">", ">") calls were
    # no-ops, so raw (possibly hostile) prompt text was injected into the
    # HTML. Escape properly, ampersand first, before embedding the preview.
    preview = (
        text[:150]
        .replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
    )
    if len(text) > 150:
        preview += "..."
    return f"""
<div style="background:#0f172a; border-radius:12px; padding:18px; font-family:system-ui,sans-serif;">
    <div style="text-align:center; margin-bottom:14px;">
        <div style="font-size:2em;">{emoji}</div>
        <div style="font-size:1.2em; font-weight:700; color:{color};">{label}</div>
        <div style="color:#94a3b8; font-size:0.85em;">Confidence: {pct:.1f}%</div>
    </div>
    <div style="background:#1e293b; border-radius:10px; padding:12px; margin-bottom:10px;">
        <div style="display:flex; justify-content:space-between; margin-bottom:4px;">
            <span style="color:#e2e8f0; font-weight:600;">Safety Score</span>
            <span style="color:{safety_color}; font-weight:700; font-size:1.1em;">{safety_score:.0f}/100</span>
        </div>
        <div style="background:#334155; border-radius:8px; height:12px; overflow:hidden;">
            <div style="background:linear-gradient(90deg, #ef4444, #f59e0b, #22c55e);
                        height:100%; width:{safety_score}%; border-radius:8px;"></div>
        </div>
    </div>
    <div style="background:#1e293b; border-radius:10px; padding:12px; margin-bottom:10px;">
        {bars_html}
    </div>
    <div style="background:#1e293b; border-radius:10px; padding:12px;">
        <div style="color:#94a3b8; font-size:0.8em; margin-bottom:3px;">Analyzed prompt:</div>
        <div style="color:#cbd5e1; font-style:italic; word-break:break-word; font-size:0.85em;">"{preview}"</div>
    </div>
</div>
"""
def build_risk_assessment(label, confidence, probs):
    """Render the markdown risk summary for one classification result.

    Args:
        label: Predicted class name ("Benign" or "Malicious").
        confidence: Probability of the predicted class, in [0, 1].
        probs: Mapping with "Benign" and "Malicious" probabilities.

    Returns:
        Markdown string with the risk level, an explanation, and detail lines.
    """
    safety_score = probs["Benign"] * 100
    malicious_score = probs["Malicious"] * 100
    high_confidence = confidence > 0.85
    if label == "Benign":
        if high_confidence:
            level = "Low"
            desc = "This prompt appears **safe**. No injection or jailbreak patterns detected."
        else:
            level = "Moderate"
            desc = "Likely benign, but moderate confidence. Potentially ambiguous wording."
    elif high_confidence:
        level = "Critical"
        desc = "**Malicious prompt detected** with high confidence. Likely injection or jailbreak attempt."
    else:
        level = "High"
        desc = "**Malicious prompt detected.** Possible injection or jailbreak. Review recommended."
    details = (
        f"**Details:**\n"
        f"- Safety score: **{safety_score:.0f}/100**\n"
        f"- Predicted class: **{label}** ({confidence*100:.1f}%)\n"
        f"- P(Benign) = {probs['Benign']*100:.1f}% | P(Malicious) = {malicious_score:.1f}%\n"
    )
    return f"### Risk Level: {level}\n\n{desc}\n\n" + details
def build_stats_html():
    """Render dataset summary statistics (totals + per-category counts) as HTML."""
    total = len(METADATA)
    n_benign = sum(1 for m in METADATA if m["label"] == 0)
    n_malicious = total - n_benign
    # Counter replaces the manual dict-counting loop; most_common() already
    # yields categories in descending-count order (stable for ties).
    cat_counts = Counter(m["category"] for m in METADATA)
    cats_html = ""
    for cat, count in cat_counts.most_common():
        color = CATEGORY_COLORS.get(cat, CATEGORY_COLORS["unknown"])
        pct = count / total * 100
        label = CATEGORY_LABELS.get(cat, cat)
        cats_html += (
            f'<div style="display:flex; justify-content:space-between; padding:2px 0;">'
            f'<span style="color:{color}; font-weight:500; font-size:0.85em;">{label}</span>'
            f'<span style="color:#94a3b8; font-size:0.85em;">{count} ({pct:.1f}%)</span>'
            f'</div>'
        )
    return f"""
<div style="background:#0f172a; border-radius:12px; padding:14px; font-family:system-ui,sans-serif;">
    <div style="color:#e2e8f0; font-weight:700; margin-bottom:8px;">Dataset Statistics</div>
    <div style="display:flex; gap:10px; margin-bottom:10px;">
        <div style="flex:1; background:#1e293b; border-radius:8px; padding:8px; text-align:center;">
            <div style="color:#94a3b8; font-size:0.75em;">Total</div>
            <div style="color:#e2e8f0; font-weight:700; font-size:1.2em;">{total:,}</div>
        </div>
        <div style="flex:1; background:#1e293b; border-radius:8px; padding:8px; text-align:center;">
            <div style="color:#22c55e; font-size:0.75em;">Benign</div>
            <div style="color:#22c55e; font-weight:700; font-size:1.2em;">{n_benign:,}</div>
        </div>
        <div style="flex:1; background:#1e293b; border-radius:8px; padding:8px; text-align:center;">
            <div style="color:#ef4444; font-size:0.75em;">Malicious</div>
            <div style="color:#ef4444; font-weight:700; font-size:1.2em;">{n_malicious:,}</div>
        </div>
    </div>
    <div style="background:#1e293b; border-radius:8px; padding:8px;">
        {cats_html}
    </div>
</div>
"""
| # --------------------------------------------------------------------------- | |
| # JavaScript to bridge Plotly clicks -> Gradio | |
| # --------------------------------------------------------------------------- | |
| PLOTLY_CLICK_JS = """ | |
| () => { | |
| function setupClickHandler() { | |
| const plotEl = document.querySelector('#tsne-chart .js-plotly-plot'); | |
| if (!plotEl) { | |
| setTimeout(setupClickHandler, 500); | |
| return; | |
| } | |
| function handleClick(data) { | |
| if (data && data.points && data.points.length > 0) { | |
| const idx = data.points[0].customdata; | |
| if (idx !== undefined && idx !== null) { | |
| const inputEl = document.querySelector('#click-index-input textarea') || document.querySelector('#click-index-input input'); | |
| if (inputEl) { | |
| const proto = inputEl.tagName === 'TEXTAREA' | |
| ? window.HTMLTextAreaElement.prototype | |
| : window.HTMLInputElement.prototype; | |
| const nativeSetter = Object.getOwnPropertyDescriptor(proto, 'value').set; | |
| nativeSetter.call(inputEl, String(idx)); | |
| inputEl.dispatchEvent(new Event('input', { bubbles: true })); | |
| setTimeout(() => { | |
| inputEl.dispatchEvent(new Event('change', { bubbles: true })); | |
| }, 50); | |
| } | |
| } | |
| } | |
| } | |
| plotEl.on('plotly_click', handleClick); | |
| const observer = new MutationObserver(() => { | |
| const newPlot = document.querySelector('#tsne-chart .js-plotly-plot'); | |
| if (newPlot && !newPlot._hasClickHandler) { | |
| newPlot._hasClickHandler = true; | |
| newPlot.on('plotly_click', handleClick); | |
| } | |
| }); | |
| observer.observe(document.querySelector('#tsne-chart') || document.body, { | |
| childList: true, subtree: true | |
| }); | |
| } | |
| setTimeout(setupClickHandler, 1000); | |
| } | |
| """ | |
| # --------------------------------------------------------------------------- | |
| # Gradio Interface | |
| # --------------------------------------------------------------------------- | |
| TITLE_HTML = """ | |
| <div style="text-align:center; padding:10px 0 4px 0;"> | |
| <h1 style="font-size:1.8em; margin:0;">GuardLLM - Prompt Security Visualizer</h1> | |
| <p style="color:#94a3b8; font-size:0.95em; margin-top:4px;"> | |
| Interactive t-SNE embedding space • | |
| <a href="https://huggingface.co/meta-llama/Llama-Prompt-Guard-2-86M" target="_blank" style="color:#60a5fa;"> | |
| Llama Prompt Guard 2</a> • | |
| <a href="https://huggingface.co/datasets/neuralchemy/Prompt-injection-dataset" target="_blank" style="color:#60a5fa;"> | |
| neuralchemy dataset</a> | |
| </p> | |
| </div> | |
| """ | |
| HOW_TO_HTML = """ | |
| <div style="background:linear-gradient(135deg, #0f172a 0%, #1e293b 100%); border:1px solid #334155; border-radius:12px; padding:16px 20px; margin:0 0 8px 0; font-family:system-ui,sans-serif;"> | |
| <div style="color:#e2e8f0; font-weight:700; font-size:1em; margin-bottom:8px;">How to use this tool</div> | |
| <div style="display:flex; flex-wrap:wrap; gap:12px;"> | |
| <div style="flex:1; min-width:180px; background:#1e293b; border-radius:8px; padding:10px 12px;"> | |
| <div style="color:#60a5fa; font-weight:600; font-size:0.85em; margin-bottom:4px;">1. Explore the map</div> | |
| <div style="color:#94a3b8; font-size:0.8em; line-height:1.4;">Each dot represents a prompt from the dataset, positioned by semantic similarity. Colors indicate attack categories. Hover to preview, scroll to zoom, drag to pan.</div> | |
| </div> | |
| <div style="flex:1; min-width:180px; background:#1e293b; border-radius:8px; padding:10px 12px;"> | |
| <div style="color:#f59e0b; font-weight:600; font-size:0.85em; margin-bottom:4px;">2. Click to analyze</div> | |
| <div style="color:#94a3b8; font-size:0.8em; line-height:1.4;">Click any point to run it through <strong style="color:#cbd5e1;">Llama Prompt Guard 2</strong>. The right panel will show the risk classification, safety score, and confidence breakdown.</div> | |
| </div> | |
| <div style="flex:1; min-width:180px; background:#1e293b; border-radius:8px; padding:10px 12px;"> | |
| <div style="color:#22c55e; font-weight:600; font-size:0.85em; margin-bottom:4px;">3. Test your own prompts</div> | |
| <div style="color:#94a3b8; font-size:0.8em; line-height:1.4;">Type or paste any prompt in the <strong style="color:#cbd5e1;">Custom prompt</strong> field and hit Analyze to check if it would be flagged as an injection attempt.</div> | |
| </div> | |
| </div> | |
| </div> | |
| """ | |
# CSS that visually hides the click-bridge textbox while keeping it in the
# DOM, so the Plotly JS can write to it and Gradio still fires .change().
HIDDEN_INPUT_CSS = "#click-index-input { position:absolute !important; width:1px !important; height:1px !important; overflow:hidden !important; opacity:0 !important; pointer-events:none !important; }"

with gr.Blocks(
    title="GuardLLM - Prompt Security Visualizer",
    # BUG FIX: css belongs on gr.Blocks(); it was previously passed to
    # demo.launch(), which has no css parameter.
    css=HIDDEN_INPUT_CSS,
) as demo:
    gr.HTML(TITLE_HTML)
    gr.HTML(HOW_TO_HTML)
    # Hidden bridge textbox: the Plotly click JS writes the point index here.
    # It must stay visible=True so its .change() event fires; it is hidden
    # via the #click-index-input CSS above instead.
    click_index = gr.Textbox(
        value="",
        visible=True,
        elem_id="click-index-input",
    )
    with gr.Row():
        # ---- Left: t-SNE chart + filters ----
        with gr.Column(scale=3):
            with gr.Row():
                select_all_btn = gr.Button("Select All", size="sm", scale=1)
                deselect_all_btn = gr.Button("Deselect All", size="sm", scale=1)
            category_filter = gr.CheckboxGroup(
                choices=UNIQUE_CATEGORIES,
                value=UNIQUE_CATEGORIES,
                label="Filter by category",
                interactive=True,
            )
            tsne_plot = gr.Plot(
                value=build_tsne_figure(),
                label="t-SNE Space",
                elem_id="tsne-chart",
            )
            gr.Markdown(
                "*Click a point to analyze it. "
                "Hover to preview text. Use scroll wheel to zoom.*"
            )
        # ---- Right: Analysis first, then stats (swapped) ----
        with gr.Column(scale=2):
            gr.Markdown("### Analysis Result")
            result_html = gr.HTML(value=empty_analysis_html())
            risk_md = gr.Markdown(value="")
            full_prompt = gr.Textbox(label="Full prompt", lines=3, interactive=False, visible=True)
            gr.Markdown("---")
            gr.Markdown("### Select a prompt")
            prompt_dropdown = gr.Dropdown(
                choices=DROPDOWN_CHOICES,
                label="Search dataset",
                filterable=True,
                interactive=True,
            )
            gr.Markdown("### Or analyze a custom prompt")
            manual_input = gr.Textbox(
                label="Custom prompt",
                placeholder="Type or paste a prompt...",
                lines=2,
            )
            analyze_btn = gr.Button("Analyze", variant="primary")
            gr.Markdown("---")
            gr.HTML(build_stats_html())
    # ---- Events ----
    category_filter.change(
        fn=on_filter_change,
        inputs=[category_filter],
        outputs=[tsne_plot],
    )
    select_all_btn.click(
        fn=select_all_categories,
        inputs=[],
        outputs=[category_filter, tsne_plot],
    )
    deselect_all_btn.click(
        fn=deselect_all_categories,
        inputs=[],
        outputs=[category_filter, tsne_plot],
    )
    click_index.change(
        fn=on_index_input,
        inputs=[click_index],
        outputs=[result_html, risk_md, full_prompt],
    )
    prompt_dropdown.change(
        fn=on_dropdown_select,
        inputs=[prompt_dropdown],
        outputs=[result_html, risk_md, full_prompt],
    )
    analyze_btn.click(
        fn=on_manual_analyze,
        inputs=[manual_input],
        outputs=[result_html, risk_md],
    )
    manual_input.submit(
        fn=on_manual_analyze,
        inputs=[manual_input],
        outputs=[result_html, risk_md],
    )
    # Install the Plotly -> Gradio click bridge once the page has loaded.
    demo.load(fn=None, inputs=None, outputs=None, js=PLOTLY_CLICK_JS)
    gr.Markdown(
        """
---
<div style="text-align:center; color:#64748b; font-size:0.8em;">
<strong>GuardLLM</strong> - Prompt Security Visualizer<br>
Model: <a href="https://huggingface.co/meta-llama/Llama-Prompt-Guard-2-86M">
Llama Prompt Guard 2 (86M)</a> by Meta |
Dataset: <a href="https://huggingface.co/datasets/neuralchemy/Prompt-injection-dataset">
neuralchemy/Prompt-injection-dataset</a>
</div>
"""
    )

logger.info("Gradio app built. Ready to launch.")

if __name__ == "__main__":
    demo.launch()