# agentbench / agent_bench/security/injection_detector.py
# Commit 6ca375c (Nomearod): security: fail-closed on secret extraction and
# env var leakage
"""Prompt injection detection.
Two-tier detection:
Tier 1 — Heuristic regex (local, <1ms): catches common injection patterns
Tier 2 — DeBERTa classifier (Modal GPU): high-confidence arbiter
Deployments without GPU run heuristic-only.
"""
from __future__ import annotations
import base64
import re
import structlog
from agent_bench.security.types import SecurityVerdict
logger = structlog.get_logger()
# --- Tier 1: Heuristic patterns ---
# Each pattern is (name, compiled_regex).
# Patterns use word boundaries and case-insensitive matching.
# Ordered from most specific to least specific.
_HEURISTIC_PATTERNS: list[tuple[str, re.Pattern]] = [
# Role/identity hijacking
("role_switch", re.compile(
r"\byou\s+are\s+now\b", re.IGNORECASE
)),
("act_as", re.compile(
r"\b(?:from\s+now\s+on\s+)?(?:you\s+will\s+)?act\s+(?:as\s+(?:if\s+)?)", re.IGNORECASE
)),
("pretend", re.compile(
r"\bpretend\s+you\s+are\b", re.IGNORECASE
)),
# Instruction override
("ignore_previous", re.compile(
r"\bignore\s+(?:all\s+)?(?:previous|prior|above|earlier|your|my)\s+(?:instructions?|context|rules|guidelines|directives)\b",
re.IGNORECASE,
)),
("disregard", re.compile(
r"\bdisregard\s+(?:all\s+)?(?:your|previous|prior)?\s*(?:instructions?|rules|guidelines)\b",
re.IGNORECASE,
)),
("forget_instructions", re.compile(
r"\bforget\s+(?:all\s+|everything\s+)?(?:you\s+were\s+told|previous|prior|your\s+instructions?|your\s+context)\b",
re.IGNORECASE,
)),
("do_not_follow", re.compile(
r"\bdo\s+not\s+follow\s+(?:your\s+)?(?:original\s+)?instructions?\b",
re.IGNORECASE,
)),
# System prompt extraction
("reveal_prompt", re.compile(
r"\b(?:reveal|show|display|output|print|repeat|tell\s+me|give\s+me|share|leak|dump|paste|write\s+out)\s+(?:me\s+)?(?:your\s+)?(?:system\s+prompt|initial\s+instructions?|instructions?\s+verbatim|original\s+instructions?|hidden\s+prompt|internal\s+prompt)\b",
re.IGNORECASE,
)),
("what_is_prompt", re.compile(
r"\bwhat\s+(?:is|are)\s+your\s+(?:system\s+prompt|instructions?|initial\s+prompt|hidden\s+prompt)\b",
re.IGNORECASE,
)),
# Direct prompt requests (catches "give me your system prompt")
("give_prompt", re.compile(
r"\b(?:give|send|copy|provide)\s+(?:me\s+)?(?:the\s+|your\s+)?(?:system\s+prompt|full\s+prompt|original\s+prompt|system\s+instructions?|internal\s+instructions?|hidden\s+instructions?)\b",
re.IGNORECASE,
)),
# Prompt as a noun target (catches "I want your system prompt")
("want_prompt", re.compile(
r"\b(?:i\s+want|i\s+need|hand\s+over|access)\s+(?:to\s+see\s+)?(?:your\s+)?(?:system\s+prompt|internal\s+prompt|original\s+instructions?|system\s+instructions?)\b",
re.IGNORECASE,
)),
# Secret / credential extraction
# Gated on extraction-verb + determiner ("the/your/exact/...") to avoid
# false-positives on educational questions like "What is an API key?".
("api_key_extract", re.compile(
r"\b(?:what\s+is|what\s+are|tell\s+me|give\s+me|show\s+me|"
r"reveal|share|print|output|copy|send|dump|leak|hand\s+over|disclose)\s+"
r"(?:me\s+)?"
r"(?:the|your|exact|actual|current|configured|real)\s+"
r"(?:exact\s+|current\s+|actual\s+|configured\s+|real\s+)?"
r"(?:api\s+key|api_key|secret\s+key|access\s+token|"
r"auth\s+token|bearer\s+token|private\s+key)\b",
re.IGNORECASE,
)),
("credential_extract", re.compile(
r"\b(?:what\s+are|tell\s+me|give\s+me|show\s+me|"
r"reveal|share|dump|leak|disclose|hand\s+over)\s+"
r"(?:me\s+)?"
r"(?:the|your)\s+"
r"(?:credentials?|secrets?|passwords?|"
r"auth\s+details?|login\s+details?)\b",
re.IGNORECASE,
)),
("env_var_extract", re.compile(
r"\b(?:what(?:\s+are)?|tell\s+me|give\s+me|show\s+me|"
r"reveal|share|dump|leak|print|list|read)\s+"
r"(?:me\s+)?"
r"(?:the\s+|your\s+|all\s+)?"
r"(?:environment\s+variables?|env\s+vars?|env\s+variables?|"
r"process\s+env|\.env\s+file|\.env\s+contents?)\b",
re.IGNORECASE,
)),
# Literal known-secret env var names. Fail closed: mentioning these by
# name in a question to a docs assistant is almost always an extraction
# attempt. Narrow scope (not generic "API_KEY") to reduce false positives.
("known_secret_literal", re.compile(
r"(?:OPENAI_API_KEY|ANTHROPIC_API_KEY|"
r"AWS_SECRET(?:_ACCESS_KEY)?|AWS_ACCESS_KEY(?:_ID)?|"
r"GITHUB_TOKEN|DATABASE_URL|DB_PASSWORD)",
re.IGNORECASE,
)),
# System message injection
("system_prefix", re.compile(
r"^(?:system\s*:|###\s*SYSTEM\s*###|```system)", re.IGNORECASE | re.MULTILINE
)),
("system_block", re.compile(
r"```system\b", re.IGNORECASE
)),
# Jailbreak keywords
("jailbreak", re.compile(
r"\b(?:DAN|jailbreak|jailbroken|unrestricted\s+(?:AI|assistant|mode))\b",
re.IGNORECASE,
)),
("no_restrictions", re.compile(
r"\b(?:no|without|remove)\s+(?:content\s+policy|safety\s+guidelines|restrictions|filters|guardrails)\b",
re.IGNORECASE,
)),
]
class InjectionDetector:
    """Two-tier prompt-injection detector.

    Tier 1 ("heuristic") runs the module's regex patterns against the raw
    text and against the decoded form of any base64-looking tokens in it.
    Tier 2 ("classifier") POSTs the text to ``classifier_url`` and is only
    consulted by :meth:`detect_async`.
    """

    # Runs of base64-alphabet characters long enough to hide an instruction.
    # Compiled once rather than on every _check_base64 call.
    _B64_TOKEN = re.compile(r"[A-Za-z0-9+/]{20,}={0,2}")

    def __init__(
        self,
        tiers: list[str] | None = None,
        classifier_url: str = "",
        enabled: bool = True,
    ) -> None:
        """Configure the detector.

        Args:
            tiers: Tier names to run, in order. ``None`` or an empty list
                selects ``["heuristic", "classifier"]``.
            classifier_url: Endpoint for the tier-2 classifier. When empty,
                the classifier tier is skipped even if listed in ``tiers``.
            enabled: When False, every input is reported safe.
        """
        self.tiers = tiers or ["heuristic", "classifier"]
        self.classifier_url = classifier_url
        self.enabled = enabled

    def _is_skipped(self, text: str) -> bool:
        """Return True when detection is disabled or *text* is blank."""
        return not self.enabled or not text.strip()

    def detect(self, text: str) -> SecurityVerdict:
        """Run the synchronous tier (heuristic only) and return a verdict.

        The classifier needs an async HTTP call, so it is never consulted
        here; use :meth:`detect_async` for the full pipeline.
        """
        if self._is_skipped(text) or "heuristic" not in self.tiers:
            return SecurityVerdict(safe=True, tier="heuristic", confidence=1.0)
        # _heuristic already returns the canonical safe verdict when nothing
        # matches, so its result can be passed straight through.
        return self._heuristic(text)

    async def detect_async(self, text: str) -> SecurityVerdict:
        """Run every configured tier, including the remote classifier.

        Returns the first unsafe verdict. When the classifier runs, its
        verdict is returned as-is so callers see its real confidence — in
        particular ``confidence=0.0`` when the classifier was unreachable and
        the request was allowed by the fail-open policy in :meth:`_classify`.
        """
        if self._is_skipped(text):
            return SecurityVerdict(safe=True, tier="heuristic", confidence=1.0)

        if "heuristic" in self.tiers:
            verdict = self._heuristic(text)
            if not verdict.safe:
                return verdict

        if "classifier" in self.tiers and self.classifier_url:
            return await self._classify(text)

        # The classifier did not run (not configured or no URL), so report
        # the heuristic tier instead of the last *configured* tier.
        return SecurityVerdict(safe=True, tier="heuristic", confidence=1.0)

    @staticmethod
    def _match_pattern(text: str) -> str | None:
        """Return the name of the first heuristic pattern found in *text*."""
        return next(
            (name for name, pattern in _HEURISTIC_PATTERNS if pattern.search(text)),
            None,
        )

    def _heuristic(self, text: str) -> SecurityVerdict:
        """Tier 1: regex heuristics over the text and its base64 payloads."""
        b64_verdict = self._check_base64(text)
        if b64_verdict is not None:
            return b64_verdict

        name = self._match_pattern(text)
        if name is None:
            return SecurityVerdict(safe=True, tier="heuristic", confidence=1.0)

        logger.warning("injection_detected", tier="heuristic", pattern=name)
        return SecurityVerdict(
            safe=False,
            tier="heuristic",
            confidence=1.0,
            matched_pattern=name,
        )

    @staticmethod
    def _decode_b64_token(token: str) -> str | None:
        """Decode a base64 token to text, tolerating missing ``=`` padding.

        Attackers can strip padding to make a strict decoder reject the
        payload, so padding is recomputed before decoding. Returns ``None``
        when the token still cannot be decoded.
        """
        core = token.rstrip("=")
        if len(core) % 4 == 1:
            # A single trailing character encodes less than one byte and makes
            # the whole token invalid; drop it so the rest still decodes.
            core = core[:-1]
        try:
            raw = base64.b64decode(core + "=" * (-len(core) % 4))
        except ValueError:  # binascii.Error is a ValueError subclass
            return None
        return raw.decode("utf-8", errors="ignore")

    def _check_base64(self, text: str) -> SecurityVerdict | None:
        """Flag base64 tokens whose decoded text matches a heuristic pattern.

        Decoded text is matched as-is (not lower-cased); each pattern carries
        its own case flags. Returns an unsafe verdict for the first matching
        token, or ``None`` if no token decodes to a match.
        """
        for match in self._B64_TOKEN.finditer(text):
            decoded = self._decode_b64_token(match.group())
            if decoded is None:
                continue
            name = self._match_pattern(decoded)
            if name is None:
                continue
            logger.warning(
                "injection_detected",
                tier="heuristic",
                pattern="base64_injection",
                decoded_match=name,
            )
            return SecurityVerdict(
                safe=False,
                tier="heuristic",
                confidence=1.0,
                matched_pattern="base64_injection",
            )
        return None

    async def _classify(self, text: str) -> SecurityVerdict:
        """Tier 2: DeBERTa classifier via Modal endpoint.

        POSTs ``{"text": text}`` to ``classifier_url`` and reads ``label`` and
        ``score`` from the JSON response. The text is flagged when
        ``label == "INJECTION"`` and ``score > 0.5``; ``confidence`` is the
        raw score.

        Any failure (network, HTTP status, malformed response) is logged and
        the request is allowed: this tier fails open and signals that with
        ``confidence=0.0``.
        """
        # Imported lazily so the module loads without httpx unless the
        # classifier is actually called.
        import httpx

        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.post(
                    self.classifier_url,
                    json={"text": text},
                )
                resp.raise_for_status()
                data = resp.json()
                label = data.get("label", "SAFE")
                score = float(data.get("score", 0.0))
                is_injection = label == "INJECTION" and score > 0.5
                if is_injection:
                    logger.warning("injection_detected", tier="classifier", score=score)
                return SecurityVerdict(
                    safe=not is_injection,
                    tier="classifier",
                    confidence=score,
                )
        except Exception as exc:
            logger.error("classifier_error", error=str(exc))
            # Fail open: if classifier is unavailable, allow the request
            return SecurityVerdict(safe=True, tier="classifier", confidence=0.0)