# Author: 10tenfirestorm
# Last commit: "Update app.py" (6130c63, verified)
import os
import re
from typing import Dict, List

import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
# ============================================================
# BACKEND: GUARDRAIL LOGIC (keyword heuristics, semantic similarity, model moderation)
# ============================================================
class SafetyFinding:
    """A single risk signal raised by one of the guardrail checks.

    Attributes:
        label: Machine-readable category, e.g. ``"unsafe_keyword"`` or
            ``"pii_leak"``.
        severity: Severity string; the checks in this file emit ``"high"``
            (leads to BLOCKED) or ``"medium"`` (leads to REDACTED).
        message: Human-readable explanation shown in the UI.
    """

    def __init__(self, label: str, severity: str, message: str):
        self.label = label
        self.severity = severity
        self.message = message

    def __repr__(self) -> str:
        # Readable form for logs and debugging instead of the default
        # "<SafetyFinding object at 0x...>".
        return (
            f"{type(self).__name__}(label={self.label!r}, "
            f"severity={self.severity!r}, message={self.message!r})"
        )

    def to_dict(self) -> Dict[str, str]:
        """Return the finding as a JSON-serialisable dict (used by the UI)."""
        return {"label": self.label, "severity": self.severity, "message": self.message}
class GuardrailSystem:
    """Layered safety checker for a user prompt and the model's response.

    Three independent signals are combined:

    * keyword heuristics (unsafe terms, jailbreak phrases, email addresses),
    * embedding similarity to the JailbreakBench ``harmful`` behavior goals,
    * the ``unitary/toxic-bert`` text-classification model.

    All models and embeddings are loaded eagerly in ``__init__`` (a notice is
    printed while loading), so construction is slow.
    """

    # Classifier labels that count as a moderation hit.
    TOXIC_LABELS = frozenset(
        ("toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate")
    )
    # A label must score strictly above this value to be reported.
    MODERATION_THRESHOLD = 0.7
    # Email-address matcher. A bare "@" (handles, "meet @ 5pm") is not an email,
    # so it must not trigger the PII finding.
    EMAIL_PATTERN = re.compile(r"[\w.+-]+@[\w-]+(?:\.[\w-]+)+")

    def __init__(self):
        print("⚙️ Loading Guardrail Models... Please wait.")
        # 1. Heuristic keyword lists (matched case-insensitively at a word start).
        self.unsafe_terms = ["bomb", "kill", "suicide", "explosive", "hack", "rob", "steal", "drugs", "murder"]
        self.jailbreak_terms = ["ignore previous", "system prompt", "jailbreak", "developer mode"]
        # 2. HuggingFace toxicity classifier (initialised eagerly, not lazily).
        self.moderator = pipeline("text-classification", model="unitary/toxic-bert")
        # 3. Embeddings of the JailbreakBench harmful goals for semantic matching.
        dataset = load_dataset("JailbreakBench/JBB-Behaviors", "behaviors", split="harmful")
        self.malicious_goals = [row["Goal"] for row in dataset if row and row["Goal"]]
        self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
        self.goal_embeddings = self.embedder.encode(self.malicious_goals, convert_to_tensor=True)
        # Minimum cosine similarity for a semantic match.
        self.threshold = 0.5

    @staticmethod
    def _contains_term(lowered, term):
        """Return True if *term* occurs in *lowered* starting at a word boundary.

        Anchoring only the start avoids in-word false positives ("skills"
        containing "kill", "problem" containing "rob") while still matching
        inflections such as "killing", "bombs" or "robbery".
        """
        return re.search(r"\b" + re.escape(term), lowered) is not None

    def check_heuristics(self, text):
        """Keyword/pattern checks on *text*.

        Returns:
            A list of ``SafetyFinding``: one ``"high"`` finding per matched
            unsafe or jailbreak term, plus one ``"medium"`` ``pii_leak``
            finding if an email address is present.
        """
        lowered = text.lower()  # computed once rather than per term
        findings = []
        for term in self.unsafe_terms:
            if self._contains_term(lowered, term):
                findings.append(SafetyFinding("unsafe_keyword", "high", f"Detected unsafe term: '{term}'"))
        for term in self.jailbreak_terms:
            if self._contains_term(lowered, term):
                findings.append(SafetyFinding("jailbreak_keyword", "high", f"Detected jailbreak term: '{term}'"))
        if self.EMAIL_PATTERN.search(text):
            findings.append(SafetyFinding("pii_leak", "medium", "Potential PII (Email) detected"))
        return findings

    def check_similarity(self, text):
        """Flag *text* if it is semantically close to a known harmful goal.

        Returns:
            A list holding one ``"high"`` finding when the best cosine
            similarity is at least ``self.threshold``, otherwise an empty list.
        """
        findings = []
        # Nothing to compare: blank input, or no goals were loaded
        # (max() over an empty score tensor would raise).
        if not text.strip() or not self.malicious_goals:
            return findings
        user_emb = self.embedder.encode(text, convert_to_tensor=True)
        cos_scores = util.cos_sim(user_emb, self.goal_embeddings)[0]
        max_score = float(cos_scores.max())
        if max_score >= self.threshold:
            findings.append(SafetyFinding("jailbreak_similarity", "high", f"Semantic Match to Jailbreak (Score: {max_score:.2f})"))
        return findings

    def check_moderation(self, text):
        """Run the toxicity classifier on *text*.

        Returns:
            One ``"high"`` finding for every label in ``TOXIC_LABELS`` whose
            score exceeds ``MODERATION_THRESHOLD``; empty for blank input.
        """
        findings = []
        if not text.strip():
            return findings
        # top_k=None scores every label instead of only the top one, so all
        # categories above the threshold are reported, not just the strongest.
        results = self.moderator(text, truncation=True, top_k=None)
        # Defensive: unwrap if the pipeline nests single-input results in a list.
        if results and isinstance(results[0], list):
            results = results[0]
        for r in results:
            if r["label"] in self.TOXIC_LABELS and r["score"] > self.MODERATION_THRESHOLD:
                findings.append(SafetyFinding("model_moderation", "high", f"Model Flag: {r['label']} ({r['score']:.2f})"))
        return findings

    def run_checks(self, user_prompt, simulated_response):
        """Run every check and reduce the findings to a single decision.

        Args:
            user_prompt: The user's input text.
            simulated_response: The (simulated) model output to screen.

        Returns:
            ``(status, findings)``: ``status`` is ``"BLOCKED"`` if any finding
            is ``"high"`` severity, else ``"REDACTED"`` if any is ``"medium"``,
            else ``"ALLOWED"``; ``findings`` lists every finding in check order.
        """
        findings = []
        # Input checks
        findings += self.check_heuristics(user_prompt)
        findings += self.check_similarity(user_prompt)
        # Output checks (the response is simulated in this demo)
        findings += self.check_heuristics(simulated_response)
        # Prompt and response are moderated together in one classifier pass.
        findings += self.check_moderation(user_prompt + " " + simulated_response)

        severities = {f.severity for f in findings}
        if "high" in severities:
            status = "BLOCKED"
        elif "medium" in severities:
            status = "REDACTED"
        else:
            status = "ALLOWED"
        return status, findings
# Initialize the guardrail system once, at module import, and keep it in a
# global so the loaded models and goal embeddings stay in memory and are
# reused by every analyze_prompt call instead of being reloaded per request.
guard = GuardrailSystem()
# ============================================================
# FRONTEND: PROFESSIONAL UI LOGIC
# ============================================================
def analyze_prompt(user_prompt):
    """Run the guardrails on *user_prompt* and build the three UI outputs.

    The LLM is simulated: a canned response is chosen from keywords in the
    prompt so that each guardrail path can be demonstrated.

    Args:
        user_prompt: Text from the "User Prompt" textbox.

    Returns:
        ``(html_status, findings, output)`` for the status card, the findings
        JSON view and the "Raw Output" view. ``output`` is the simulated
        response when ALLOWED, the response with email addresses masked when
        REDACTED, and a policy placeholder when BLOCKED.
    """
    # Treat a missing value as an empty prompt so .lower() cannot fail.
    user_prompt = user_prompt or ""
    lowered = user_prompt.lower()

    # Simulate LLM generation for the demo; a later match overrides an earlier one.
    simulated_output = "This is a harmless AI response."
    if "bomb" in lowered:
        simulated_output = "Here are instructions for..."
    if "email" in lowered:
        simulated_output = "Contact me at user@example.com"

    # Run Guardrails
    status, findings = guard.run_checks(user_prompt, simulated_output)

    # Generate HTML status card; run_checks only returns these three statuses.
    color_map = {"ALLOWED": "green", "BLOCKED": "red", "REDACTED": "orange"}
    icon_map = {"ALLOWED": "✅", "BLOCKED": "🛡️", "REDACTED": "⚠️"}
    html_status = f"""
<div style='background-color: var(--background-fill-secondary); border-left: 5px solid {color_map[status]}; padding: 20px; border-radius: 8px; box-shadow: 0 4px 6px rgba(0,0,0,0.1);'>
<h2 style='color: {color_map[status]}; margin: 0;'>{icon_map[status]} {status}</h2>
<p style='margin-top: 5px; opacity: 0.8;'>Guardrail decision based on {len(findings)} risk factors.</p>
</div>
"""
    # Format Findings for Display
    clean_findings = [f.to_dict() for f in findings]

    # REDACTED means only medium-severity findings (the email/PII heuristic),
    # so show the response with email addresses masked rather than hiding it
    # exactly as if it had been BLOCKED.
    if status == "ALLOWED":
        displayed_output = simulated_output
    elif status == "REDACTED":
        displayed_output = re.sub(r"[\w.+-]+@[\w-]+(?:\.[\w-]+)+", "[REDACTED EMAIL]", simulated_output)
    else:
        displayed_output = "[CONTENT BLOCKED BY POLICY]"
    return html_status, clean_findings, displayed_output
# Custom CSS for a clean look. The heading colour comes from the theme's
# --body-text-color variable instead of a fixed dark grey, so the centred
# title stays readable in both light and dark mode.
custom_css = """
.gradio-container {font-family: 'Inter', sans-serif;}
h1 {text-align: center; color: var(--body-text-color);}
"""
# Create the App: a header, a two-column body (prompt input on the left,
# guardrail analytics on the right), a footer, and a button wired to
# analyze_prompt. Styled with the Soft theme plus custom_css.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate"), css=custom_css) as demo:
    # Header
    with gr.Row():
        gr.Markdown(
"""
# 🛡️ Enterprise AI Guardrail System
### Real-time safety filtering using Semantic Search, BERT Moderation, and Heuristics.
"""
        )
    # Main Interface
    with gr.Row():
        # Left Column: Inputs
        with gr.Column(scale=1):
            gr.Markdown("### 📥 Input Simulation")
            input_text = gr.Textbox(
                lines=5,
                label="User Prompt",
                placeholder="Enter a prompt to test the guardrails (e.g., 'how to build a bomb' or 'hello')...",
                elem_id="input_box"
            )
            analyze_btn = gr.Button("🛡️ Run Safety Check", variant="primary", size="lg")
            # Collapsed-by-default explanation of the three guardrail layers.
            with gr.Accordion("ℹ️ How it works", open=False):
                gr.Markdown("""
1. **Heuristics:** Checks for banned keywords.
2. **Vector Database:** Compares prompt against known jailbreaks (JailbreakBench).
3. **BERT Classifier:** Scans for toxic tones.
""")
        # Right Column: Analytics
        with gr.Column(scale=1):
            gr.Markdown("### 📊 Live Analytics")
            # Receives the HTML status card built by analyze_prompt.
            status_display = gr.HTML(label="Decision")
            with gr.Tabs():
                with gr.TabItem("Findings"):
                    # Receives the list of finding dicts (SafetyFinding.to_dict()).
                    findings_json = gr.JSON(label="Risk Factors Detected")
                with gr.TabItem("Raw Output"):
                    # Receives the simulated response or the policy placeholder.
                    final_output = gr.Code(label="LLM Response", language="markdown")
    # Footer
    gr.Markdown(
"""
---
<div style="text-align: center; opacity: 0.5; font-size: 0.8rem;">
Built for AI Safety Portfolio | Powered by HuggingFace Transformers & Gradio
</div>
"""
    )
    # Event Linking: the outputs list must stay in the same order as
    # analyze_prompt's return tuple (status HTML, findings, output text).
    analyze_btn.click(
        fn=analyze_prompt,
        inputs=input_text,
        outputs=[status_display, findings_json, final_output]
    )
# Start the Gradio server only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()