import gradio as gr import os from typing import Dict, List from datasets import load_dataset from sentence_transformers import SentenceTransformer, util from transformers import pipeline # ============================================================ # BACKEND: GUARDRAIL LOGIC (Same as before, optimized) # ============================================================ class SafetyFinding: def __init__(self, label: str, severity: str, message: str): self.label = label self.severity = severity self.message = message def to_dict(self): return {"label": self.label, "severity": self.severity, "message": self.message} class GuardrailSystem: def __init__(self): print("âī¸ Loading Guardrail Models... Please wait.") # 1. Load Heuristic Keywords self.unsafe_terms = ["bomb", "kill", "suicide", "explosive", "hack", "rob", "steal", "drugs", "murder"] self.jailbreak_terms = ["ignore previous", "system prompt", "jailbreak", "developer mode"] # 2. Load HuggingFace Moderator (Lazy loading recommended, but here we init upfront) self.moderator = pipeline("text-classification", model="unitary/toxic-bert") # 3. Load JailbreakBench Embeddings dataset = load_dataset("JailbreakBench/JBB-Behaviors", "behaviors", split="harmful") self.malicious_goals = [row["Goal"] for row in dataset if row and row["Goal"]] self.embedder = SentenceTransformer("all-MiniLM-L6-v2") self.goal_embeddings = self.embedder.encode(self.malicious_goals, convert_to_tensor=True) self.threshold = 0.5 def check_heuristics(self, text): findings = [] for term in self.unsafe_terms: if term in text.lower(): findings.append(SafetyFinding("unsafe_keyword", "high", f"Detected unsafe term: '{term}'")) for term in self.jailbreak_terms: if term in text.lower(): findings.append(SafetyFinding("jailbreak_keyword", "high", f"Detected jailbreak term: '{term}'")) if "@" in text: findings.append(SafetyFinding("pii_leak", "medium", "Potential PII (Email) detected")) return findings def check_similarity(self, text): findings = [] if not text.strip(): return findings user_emb = self.embedder.encode(text, convert_to_tensor=True) cos_scores = util.cos_sim(user_emb, self.goal_embeddings)[0] max_score = float(cos_scores.max()) if max_score >= self.threshold: findings.append(SafetyFinding("jailbreak_similarity", "high", f"Semantic Match to Jailbreak (Score: {max_score:.2f})")) return findings def check_moderation(self, text): findings = [] if not text.strip(): return findings results = self.moderator(text, truncation=True) for r in results: if r["label"] in ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"] and r["score"] > 0.7: findings.append(SafetyFinding("model_moderation", "high", f"Model Flag: {r['label']} ({r['score']:.2f})")) return findings def run_checks(self, user_prompt, simulated_response): findings = [] # Input Checks findings += self.check_heuristics(user_prompt) findings += self.check_similarity(user_prompt) # Output Checks (Simulated) findings += self.check_heuristics(simulated_response) findings += self.check_moderation(user_prompt + " " + simulated_response) # Decision status = "ALLOWED" if any(f.severity == "high" for f in findings): status = "BLOCKED" elif any(f.severity == "medium" for f in findings): status = "REDACTED" return status, findings # Initialize System (Global to keep in memory) guard = GuardrailSystem() # ============================================================ # FRONTEND: PROFESSIONAL UI LOGIC # ============================================================ def analyze_prompt(user_prompt): # Simulate LLM Generation for the demo simulated_output = "This is a harmless AI response." if "bomb" in user_prompt.lower(): simulated_output = "Here are instructions for..." if "email" in user_prompt.lower(): simulated_output = "Contact me at user@example.com" # Run Guardrails status, findings = guard.run_checks(user_prompt, simulated_output) # Generate HTML Status Card color_map = {"ALLOWED": "green", "BLOCKED": "red", "REDACTED": "orange"} icon_map = {"ALLOWED": "â ", "BLOCKED": "đĄī¸", "REDACTED": "â ī¸"} html_status = f"""
Guardrail decision based on {len(findings)} risk factors.