"""
Agent 4: AI Prompt Injection Detection Module
Uses a fine‑tuned DeBERTa model (MNLI) + rule‑based patterns.
"""
import re
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
class PromptInjectionAgent:
    """
    Detects prompt injection and jailbreak attempts in user inputs.

    Combines a transformer model (trained on MNLI) with heuristic
    regex rules; the two signals are blended into a single risk score
    and compared against a configurable threshold.
    """

    def __init__(self, model_name="mrm8488/deberta-v3-small-finetuned-mnli", threshold=0.6):
        """
        Args:
            model_name: Hugging Face model identifier for a DeBERTa MNLI model.
            threshold: Confidence threshold above which input is flagged as injection.
        """
        print("Loading Prompt Injection Agent (MNLI-based)...")
        self.threshold = threshold
        # Load tokenizer and model once at construction time (network/disk I/O).
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.eval()  # inference mode (disables dropout)
        print("✓ Model loaded successfully")
        # Rule-based patterns (covers common jailbreak attempts).
        # NOTE: short tokens carry \b word boundaries so that e.g. "dan"
        # does not fire on "dance"/"abundant" and "act as" does not fire
        # inside "impact assessment".
        self.injection_patterns = [
            (r"ignore previous instructions", "instruction_override"),
            (r"ignore all previous", "instruction_override"),
            (r"disregard previous", "instruction_override"),
            (r"system prompt", "system_override"),
            (r"you are now", "role_playing"),
            (r"\bact as\b", "role_playing"),
            (r"new role:", "role_playing"),
            (r"forget your instructions", "instruction_override"),
            (r"do anything now", "privilege_escalation"),
            (r"you must", "privilege_escalation"),
            (r"you are free", "jailbreak"),
            (r"no restrictions", "jailbreak"),
            (r"override", "instruction_override"),
            (r"jailbreak", "jailbreak"),
            (r"\bdan\b", "jailbreak"),  # DAN mode
            (r"developer mode", "jailbreak"),
            (r"chatgpt, you are now", "role_playing"),
            (r"you are an ai with no ethics", "role_playing"),
            (r"output raw", "attention_diversion"),
            (r"base64 decode", "attention_diversion"),
        ]

    def _rule_scan(self, text: str):
        """
        Run the regex heuristics over *text* (case-insensitively).

        Returns:
            (score, matched_patterns, attack_categories) where each regex hit
            adds a fixed 0.3 to the score (capped later by the caller).
        """
        text_lower = text.lower()
        score = 0.0
        matched_patterns = []
        attack_categories = []
        for pattern, category in self.injection_patterns:
            if re.search(pattern, text_lower):
                score += 0.3
                matched_patterns.append(pattern)
                attack_categories.append(category)
        return score, matched_patterns, attack_categories

    def _contradiction_probability(self, text: str) -> float:
        """
        Score *text* with the MNLI model and return P(contradiction).
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )
        with torch.no_grad():
            outputs = self.model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        # MNLI class order assumed: 0 = entailment, 1 = neutral, 2 = contradiction.
        # NOTE(review): verify against the model's id2label config — some MNLI
        # checkpoints use the reverse ordering.
        return probs[0][2].item()

    def analyze(self, text: str) -> dict:
        """
        Analyze input text for prompt injection.

        Returns:
            dict with keys:
                prompt_injection_detected (bool): final decision
                confidence (float): combined risk score
                risk_score (float): same as confidence (for backward compatibility)
                matched_patterns (list): regex patterns that fired
                attack_categories (list): types of injection detected
                explanation (list): human-readable reasons
        """
        rule_score, matched_patterns, attack_categories = self._rule_scan(text)
        contradiction_prob = self._contradiction_probability(text)

        # 70% weight on contradiction probability, 30% on (capped) rule score.
        combined_risk = 0.7 * contradiction_prob + 0.3 * min(rule_score, 1.0)
        detected = combined_risk > self.threshold

        unique_categories = list(set(attack_categories))
        explanation = [f"Contradiction probability: {contradiction_prob:.1%}"]
        if unique_categories:
            explanation.append(f"Rule matches: {', '.join(unique_categories)}")
        if detected:
            explanation.append(
                f"Combined risk {combined_risk:.1%} exceeds threshold {self.threshold:.0%}"
            )

        return {
            "prompt_injection_detected": detected,
            "confidence": combined_risk,
            "risk_score": combined_risk,  # alias for compatibility
            "matched_patterns": matched_patterns,
            "attack_categories": unique_categories,
            "explanation": explanation,
        }