Hugging Face Spaces — Space status: Sleeping
import os
from dataclasses import dataclass
from typing import Dict, List

import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
| # ============================================================ | |
| # BACKEND: GUARDRAIL LOGIC (Same as before, optimized) | |
| # ============================================================ | |
@dataclass
class SafetyFinding:
    """A single guardrail detection result.

    Attributes:
        label:    machine-readable category (e.g. "unsafe_keyword", "pii_leak").
        severity: "high" | "medium" | "low" — drives the final allow/block decision.
        message:  human-readable explanation shown in the UI.
    """

    label: str
    severity: str
    message: str

    def to_dict(self) -> Dict[str, str]:
        """Return a plain-dict form suitable for JSON display in the UI."""
        return {"label": self.label, "severity": self.severity, "message": self.message}
class GuardrailSystem:
    """Three-layer prompt/response guardrail.

    Layers:
      1. Keyword heuristics — banned terms, jailbreak phrases, a crude PII check.
      2. Semantic similarity — embeds the prompt and compares it against known
         jailbreak goals from the JailbreakBench dataset.
      3. Model moderation — toxicity classification via unitary/toxic-bert.
    """

    # toxic-bert labels that count as a moderation hit (frozenset: O(1) membership).
    _FLAGGED_LABELS = frozenset(
        {"toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"}
    )

    def __init__(self, threshold: float = 0.5):
        """Load all models and reference data up front.

        Args:
            threshold: minimum cosine similarity to a known jailbreak goal
                before a prompt is flagged (default 0.5, matching prior behavior).
        """
        print("⚙️ Loading Guardrail Models... Please wait.")
        # 1. Heuristic keyword lists.
        self.unsafe_terms = ["bomb", "kill", "suicide", "explosive", "hack", "rob", "steal", "drugs", "murder"]
        self.jailbreak_terms = ["ignore previous", "system prompt", "jailbreak", "developer mode"]
        # 2. HuggingFace moderation classifier (loaded eagerly so requests stay fast).
        self.moderator = pipeline("text-classification", model="unitary/toxic-bert")
        # 3. JailbreakBench goal embeddings for semantic matching.
        dataset = load_dataset("JailbreakBench/JBB-Behaviors", "behaviors", split="harmful")
        self.malicious_goals = [row["Goal"] for row in dataset if row and row["Goal"]]
        self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
        self.goal_embeddings = self.embedder.encode(self.malicious_goals, convert_to_tensor=True)
        self.threshold = threshold

    def check_heuristics(self, text: str) -> List["SafetyFinding"]:
        """Scan *text* for banned keywords, jailbreak phrases, and crude PII."""
        findings = []
        lowered = text.lower()  # lowercase once, not once per keyword
        for term in self.unsafe_terms:
            if term in lowered:
                findings.append(SafetyFinding("unsafe_keyword", "high", f"Detected unsafe term: '{term}'"))
        for term in self.jailbreak_terms:
            if term in lowered:
                findings.append(SafetyFinding("jailbreak_keyword", "high", f"Detected jailbreak term: '{term}'"))
        # NOTE(review): "@" is a very coarse email heuristic (flags any mention of
        # an address); kept as-is to preserve demo behavior — a regex would be stricter.
        if "@" in text:
            findings.append(SafetyFinding("pii_leak", "medium", "Potential PII (Email) detected"))
        return findings

    def check_similarity(self, text: str) -> List["SafetyFinding"]:
        """Flag *text* if it is semantically close to a known jailbreak goal."""
        findings = []
        if not text.strip():
            return findings
        user_emb = self.embedder.encode(text, convert_to_tensor=True)
        cos_scores = util.cos_sim(user_emb, self.goal_embeddings)[0]
        max_score = float(cos_scores.max())
        if max_score >= self.threshold:
            findings.append(SafetyFinding("jailbreak_similarity", "high", f"Semantic Match to Jailbreak (Score: {max_score:.2f})"))
        return findings

    def check_moderation(self, text: str) -> List["SafetyFinding"]:
        """Run the toxic-bert classifier and flag high-confidence toxic labels."""
        findings = []
        if not text.strip():
            return findings
        results = self.moderator(text, truncation=True)
        for r in results:
            if r["label"] in self._FLAGGED_LABELS and r["score"] > 0.7:
                findings.append(SafetyFinding("model_moderation", "high", f"Model Flag: {r['label']} ({r['score']:.2f})"))
        return findings

    def run_checks(self, user_prompt: str, simulated_response: str):
        """Run all guardrail layers and decide the final disposition.

        Returns:
            (status, findings) where status is "ALLOWED", "REDACTED" (any
            medium-severity finding) or "BLOCKED" (any high-severity finding).
        """
        findings = []
        # Input checks.
        findings += self.check_heuristics(user_prompt)
        findings += self.check_similarity(user_prompt)
        # Output checks (simulated response); moderation sees the full exchange.
        findings += self.check_heuristics(simulated_response)
        findings += self.check_moderation(user_prompt + " " + simulated_response)
        # Decision: highest severity wins.
        status = "ALLOWED"
        if any(f.severity == "high" for f in findings):
            status = "BLOCKED"
        elif any(f.severity == "medium" for f in findings):
            status = "REDACTED"
        return status, findings
# Initialize System (Global to keep in memory)
# Instantiated once at import time so every Gradio request reuses the same
# loaded models instead of reloading them per call.
guard = GuardrailSystem()
| # ============================================================ | |
| # FRONTEND: PROFESSIONAL UI LOGIC | |
| # ============================================================ | |
def analyze_prompt(user_prompt):
    """Run the guardrail pipeline on *user_prompt* and build the UI payloads.

    Returns a (status_html, findings_list, displayed_output) triple wired to
    the Gradio components: an HTML status card, the findings as plain dicts,
    and either the simulated LLM reply or a blocked-content placeholder.
    """
    lowered = user_prompt.lower()
    # Fake an LLM reply for the demo; risky prompts get risky-looking replies.
    simulated_output = "This is a harmless AI response."
    if "bomb" in lowered:
        simulated_output = "Here are instructions for..."
    if "email" in lowered:
        simulated_output = "Contact me at user@example.com"
    # Run Guardrails
    status, findings = guard.run_checks(user_prompt, simulated_output)
    # Status-card styling keyed by the guardrail decision.
    decor = {
        "ALLOWED": ("green", "✅"),
        "BLOCKED": ("red", "🛡️"),
        "REDACTED": ("orange", "⚠️"),
    }
    color, icon = decor[status]
    html_status = f"""
    <div style='background-color: var(--background-fill-secondary); border-left: 5px solid {color}; padding: 20px; border-radius: 8px; box-shadow: 0 4px 6px rgba(0,0,0,0.1);'>
        <h2 style='color: {color}; margin: 0;'>{icon} {status}</h2>
        <p style='margin-top: 5px; opacity: 0.8;'>Guardrail decision based on {len(findings)} risk factors.</p>
    </div>
    """
    # Findings as plain dicts for the JSON viewer.
    clean_findings = [f.to_dict() for f in findings]
    shown_output = simulated_output if status == "ALLOWED" else "[CONTENT BLOCKED BY POLICY]"
    return html_status, clean_findings, shown_output
# Custom CSS for a clean look
# Injected into gr.Blocks(css=...) below; sets the app font and header style.
custom_css = """
.gradio-container {font-family: 'Inter', sans-serif;}
h1 {text-align: center; color: #2d3748;}
"""
# Create the App
# Layout: header row, then a two-column row (inputs on the left, analytics on
# the right), a footer, and finally the button -> analyze_prompt event wiring.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate"), css=custom_css) as demo:
    # Header
    with gr.Row():
        gr.Markdown(
            """
            # 🛡️ Enterprise AI Guardrail System
            ### Real-time safety filtering using Semantic Search, BERT Moderation, and Heuristics.
            """
        )
    # Main Interface
    with gr.Row():
        # Left Column: Inputs
        with gr.Column(scale=1):
            gr.Markdown("### 📥 Input Simulation")
            input_text = gr.Textbox(
                lines=5,
                label="User Prompt",
                placeholder="Enter a prompt to test the guardrails (e.g., 'how to build a bomb' or 'hello')...",
                elem_id="input_box"
            )
            analyze_btn = gr.Button("🛡️ Run Safety Check", variant="primary", size="lg")
            with gr.Accordion("ℹ️ How it works", open=False):
                gr.Markdown("""
                1. **Heuristics:** Checks for banned keywords.
                2. **Vector Database:** Compares prompt against known jailbreaks (JailbreakBench).
                3. **BERT Classifier:** Scans for toxic tones.
                """)
        # Right Column: Analytics
        with gr.Column(scale=1):
            gr.Markdown("### 📊 Live Analytics")
            # Receives the HTML status card built by analyze_prompt.
            status_display = gr.HTML(label="Decision")
            with gr.Tabs():
                with gr.TabItem("Findings"):
                    findings_json = gr.JSON(label="Risk Factors Detected")
                with gr.TabItem("Raw Output"):
                    final_output = gr.Code(label="LLM Response", language="markdown")
    # Footer
    gr.Markdown(
        """
        ---
        <div style="text-align: center; opacity: 0.5; font-size: 0.8rem;">
        Built for AI Safety Portfolio | Powered by HuggingFace Transformers & Gradio
        </div>
        """
    )
    # Event Linking
    # analyze_prompt returns (html, findings, output) matching the outputs list order.
    analyze_btn.click(
        fn=analyze_prompt,
        inputs=input_text,
        outputs=[status_display, findings_json, final_output]
    )
# Launch the Gradio server only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()