Spaces:
Sleeping
Sleeping
| """Prompt injection detection. | |
| Two-tier detection: | |
| Tier 1 — Heuristic regex (local, <1ms): catches common injection patterns | |
| Tier 2 — DeBERTa classifier (Modal GPU): high-confidence arbiter | |
| Deployments without GPU run heuristic-only. | |
| """ | |
| from __future__ import annotations | |
| import base64 | |
| import re | |
| import structlog | |
| from agent_bench.security.types import SecurityVerdict | |
| logger = structlog.get_logger() | |
| # --- Tier 1: Heuristic patterns --- | |
| # Each pattern is (name, compiled_regex). | |
| # Patterns use word boundaries and case-insensitive matching. | |
| # Ordered from most specific to least specific. | |
| _HEURISTIC_PATTERNS: list[tuple[str, re.Pattern]] = [ | |
| # Role/identity hijacking | |
| ("role_switch", re.compile( | |
| r"\byou\s+are\s+now\b", re.IGNORECASE | |
| )), | |
| ("act_as", re.compile( | |
| r"\b(?:from\s+now\s+on\s+)?(?:you\s+will\s+)?act\s+(?:as\s+(?:if\s+)?)", re.IGNORECASE | |
| )), | |
| ("pretend", re.compile( | |
| r"\bpretend\s+you\s+are\b", re.IGNORECASE | |
| )), | |
| # Instruction override | |
| ("ignore_previous", re.compile( | |
| r"\bignore\s+(?:all\s+)?(?:previous|prior|above|earlier|your|my)\s+(?:instructions?|context|rules|guidelines|directives)\b", | |
| re.IGNORECASE, | |
| )), | |
| ("disregard", re.compile( | |
| r"\bdisregard\s+(?:all\s+)?(?:your|previous|prior)?\s*(?:instructions?|rules|guidelines)\b", | |
| re.IGNORECASE, | |
| )), | |
| ("forget_instructions", re.compile( | |
| r"\bforget\s+(?:all\s+|everything\s+)?(?:you\s+were\s+told|previous|prior|your\s+instructions?|your\s+context)\b", | |
| re.IGNORECASE, | |
| )), | |
| ("do_not_follow", re.compile( | |
| r"\bdo\s+not\s+follow\s+(?:your\s+)?(?:original\s+)?instructions?\b", | |
| re.IGNORECASE, | |
| )), | |
| # System prompt extraction | |
| ("reveal_prompt", re.compile( | |
| r"\b(?:reveal|show|display|output|print|repeat|tell\s+me|give\s+me|share|leak|dump|paste|write\s+out)\s+(?:me\s+)?(?:your\s+)?(?:system\s+prompt|initial\s+instructions?|instructions?\s+verbatim|original\s+instructions?|hidden\s+prompt|internal\s+prompt)\b", | |
| re.IGNORECASE, | |
| )), | |
| ("what_is_prompt", re.compile( | |
| r"\bwhat\s+(?:is|are)\s+your\s+(?:system\s+prompt|instructions?|initial\s+prompt|hidden\s+prompt)\b", | |
| re.IGNORECASE, | |
| )), | |
| # Direct prompt requests (catches "give me your system prompt") | |
| ("give_prompt", re.compile( | |
| r"\b(?:give|send|copy|provide)\s+(?:me\s+)?(?:the\s+|your\s+)?(?:system\s+prompt|full\s+prompt|original\s+prompt|system\s+instructions?|internal\s+instructions?|hidden\s+instructions?)\b", | |
| re.IGNORECASE, | |
| )), | |
| # Prompt as a noun target (catches "I want your system prompt") | |
| ("want_prompt", re.compile( | |
| r"\b(?:i\s+want|i\s+need|hand\s+over|access)\s+(?:to\s+see\s+)?(?:your\s+)?(?:system\s+prompt|internal\s+prompt|original\s+instructions?|system\s+instructions?)\b", | |
| re.IGNORECASE, | |
| )), | |
| # Secret / credential extraction | |
| # Gated on extraction-verb + determiner ("the/your/exact/...") to avoid | |
| # false-positives on educational questions like "What is an API key?". | |
| ("api_key_extract", re.compile( | |
| r"\b(?:what\s+is|what\s+are|tell\s+me|give\s+me|show\s+me|" | |
| r"reveal|share|print|output|copy|send|dump|leak|hand\s+over|disclose)\s+" | |
| r"(?:me\s+)?" | |
| r"(?:the|your|exact|actual|current|configured|real)\s+" | |
| r"(?:exact\s+|current\s+|actual\s+|configured\s+|real\s+)?" | |
| r"(?:api\s+key|api_key|secret\s+key|access\s+token|" | |
| r"auth\s+token|bearer\s+token|private\s+key)\b", | |
| re.IGNORECASE, | |
| )), | |
| ("credential_extract", re.compile( | |
| r"\b(?:what\s+are|tell\s+me|give\s+me|show\s+me|" | |
| r"reveal|share|dump|leak|disclose|hand\s+over)\s+" | |
| r"(?:me\s+)?" | |
| r"(?:the|your)\s+" | |
| r"(?:credentials?|secrets?|passwords?|" | |
| r"auth\s+details?|login\s+details?)\b", | |
| re.IGNORECASE, | |
| )), | |
| ("env_var_extract", re.compile( | |
| r"\b(?:what(?:\s+are)?|tell\s+me|give\s+me|show\s+me|" | |
| r"reveal|share|dump|leak|print|list|read)\s+" | |
| r"(?:me\s+)?" | |
| r"(?:the\s+|your\s+|all\s+)?" | |
| r"(?:environment\s+variables?|env\s+vars?|env\s+variables?|" | |
| r"process\s+env|\.env\s+file|\.env\s+contents?)\b", | |
| re.IGNORECASE, | |
| )), | |
| # Literal known-secret env var names. Fail closed: mentioning these by | |
| # name in a question to a docs assistant is almost always an extraction | |
| # attempt. Narrow scope (not generic "API_KEY") to reduce false positives. | |
| ("known_secret_literal", re.compile( | |
| r"(?:OPENAI_API_KEY|ANTHROPIC_API_KEY|" | |
| r"AWS_SECRET(?:_ACCESS_KEY)?|AWS_ACCESS_KEY(?:_ID)?|" | |
| r"GITHUB_TOKEN|DATABASE_URL|DB_PASSWORD)", | |
| re.IGNORECASE, | |
| )), | |
| # System message injection | |
| ("system_prefix", re.compile( | |
| r"^(?:system\s*:|###\s*SYSTEM\s*###|```system)", re.IGNORECASE | re.MULTILINE | |
| )), | |
| ("system_block", re.compile( | |
| r"```system\b", re.IGNORECASE | |
| )), | |
| # Jailbreak keywords | |
| ("jailbreak", re.compile( | |
| r"\b(?:DAN|jailbreak|jailbroken|unrestricted\s+(?:AI|assistant|mode))\b", | |
| re.IGNORECASE, | |
| )), | |
| ("no_restrictions", re.compile( | |
| r"\b(?:no|without|remove)\s+(?:content\s+policy|safety\s+guidelines|restrictions|filters|guardrails)\b", | |
| re.IGNORECASE, | |
| )), | |
| ] | |
| class InjectionDetector: | |
| """Two-tier injection detection.""" | |
| def __init__( | |
| self, | |
| tiers: list[str] | None = None, | |
| classifier_url: str = "", | |
| enabled: bool = True, | |
| ) -> None: | |
| self.tiers = tiers or ["heuristic", "classifier"] | |
| self.classifier_url = classifier_url | |
| self.enabled = enabled | |
| def detect(self, text: str) -> SecurityVerdict: | |
| """Run detection tiers in order. Return on first match.""" | |
| if not self.enabled or not text.strip(): | |
| return SecurityVerdict(safe=True, tier="heuristic", confidence=1.0) | |
| # Tier 1: Heuristic | |
| if "heuristic" in self.tiers: | |
| verdict = self._heuristic(text) | |
| if not verdict.safe: | |
| return verdict | |
| # Tier 2: Classifier (async call needed — see detect_async) | |
| # Synchronous detect() only runs heuristic. Use detect_async() for | |
| # the full pipeline including the Modal classifier. | |
| return SecurityVerdict(safe=True, tier="heuristic", confidence=1.0) | |
| async def detect_async(self, text: str) -> SecurityVerdict: | |
| """Run all configured tiers including async classifier.""" | |
| if not self.enabled or not text.strip(): | |
| return SecurityVerdict(safe=True, tier="heuristic", confidence=1.0) | |
| # Tier 1: Heuristic | |
| if "heuristic" in self.tiers: | |
| verdict = self._heuristic(text) | |
| if not verdict.safe: | |
| return verdict | |
| # Tier 2: Classifier | |
| if "classifier" in self.tiers and self.classifier_url: | |
| verdict = await self._classify(text) | |
| if not verdict.safe: | |
| return verdict | |
| return SecurityVerdict(safe=True, tier=self.tiers[-1], confidence=1.0) | |
| def _heuristic(self, text: str) -> SecurityVerdict: | |
| """Tier 1: regex-based heuristic detection.""" | |
| # Check base64-encoded payloads | |
| b64_verdict = self._check_base64(text) | |
| if b64_verdict is not None: | |
| return b64_verdict | |
| for name, pattern in _HEURISTIC_PATTERNS: | |
| if pattern.search(text): | |
| logger.warning("injection_detected", tier="heuristic", pattern=name) | |
| return SecurityVerdict( | |
| safe=False, | |
| tier="heuristic", | |
| confidence=1.0, | |
| matched_pattern=name, | |
| ) | |
| return SecurityVerdict(safe=True, tier="heuristic", confidence=1.0) | |
| def _check_base64(self, text: str) -> SecurityVerdict | None: | |
| """Check for base64-encoded injection payloads.""" | |
| b64_pattern = re.compile(r"[A-Za-z0-9+/]{20,}={0,2}") | |
| for match in b64_pattern.finditer(text): | |
| try: | |
| decoded = base64.b64decode(match.group()).decode("utf-8", errors="ignore").lower() | |
| for name, pattern in _HEURISTIC_PATTERNS: | |
| if pattern.search(decoded): | |
| logger.warning( | |
| "injection_detected", | |
| tier="heuristic", | |
| pattern="base64_injection", | |
| decoded_match=name, | |
| ) | |
| return SecurityVerdict( | |
| safe=False, | |
| tier="heuristic", | |
| confidence=1.0, | |
| matched_pattern="base64_injection", | |
| ) | |
| except Exception: | |
| continue | |
| return None | |
| async def _classify(self, text: str) -> SecurityVerdict: | |
| """Tier 2: DeBERTa classifier via Modal endpoint.""" | |
| import httpx | |
| try: | |
| async with httpx.AsyncClient(timeout=10.0) as client: | |
| resp = await client.post( | |
| self.classifier_url, | |
| json={"text": text}, | |
| ) | |
| resp.raise_for_status() | |
| data = resp.json() | |
| label = data.get("label", "SAFE") | |
| score = float(data.get("score", 0.0)) | |
| is_injection = label == "INJECTION" and score > 0.5 | |
| if is_injection: | |
| logger.warning("injection_detected", tier="classifier", score=score) | |
| return SecurityVerdict( | |
| safe=not is_injection, | |
| tier="classifier", | |
| confidence=score, | |
| ) | |
| except Exception as exc: | |
| logger.error("classifier_error", error=str(exc)) | |
| # Fail open: if classifier is unavailable, allow the request | |
| return SecurityVerdict(safe=True, tier="classifier", confidence=0.0) | |