# SheildSense_API_SDK / ai_firewall / output_guardrail.py
"""
output_guardrail.py
===================
Validates AI model responses before returning them to the user.
Checks:
1. System prompt leakage — did the model accidentally reveal its system prompt?
2. Secret / API key leakage — API keys, tokens, passwords in the response
3. PII leakage — email addresses, phone numbers, SSNs, credit cards
4. Unsafe content — explicit instructions for harmful activities
5. Excessive refusal leak — model revealing it was jailbroken / restricted
6. Known data exfiltration patterns
Each check is individually configurable and produces a labelled flag.
"""
from __future__ import annotations
import re
import logging
import time
from dataclasses import dataclass, field
from typing import List
logger = logging.getLogger("ai_firewall.output_guardrail")
# ---------------------------------------------------------------------------
# Pattern catalogue
# ---------------------------------------------------------------------------
class _Patterns:
    """Compiled regex catalogue, grouped by guardrail check category.

    All patterns are compiled once at import time, so per-request scanning
    only pays for ``search``/``sub`` calls.  Each list below corresponds to
    one flag label in ``_SEVERITY``.
    """

    # --- System prompt leakage ---
    # Phrasings a model typically produces when echoing its hidden
    # system prompt or initial instructions back to the user.
    SYSTEM_PROMPT_LEAK = [
        re.compile(r"my\s+(system\s+prompt|instructions?|directives?)\s+(is|are|say(s)?)\s*:?", re.I),
        re.compile(r"(i\s+was|i've\s+been)\s+(instructed|told|programmed|configured)\s+to", re.I),
        re.compile(r"(the\s+)?system\s+message\s+(says?|reads?|is)\s*:?", re.I),
        re.compile(r"(here\s+is|below\s+is)\s+(my\s+)?(full\s+|complete\s+)?(system\s+prompt|initial\s+instructions?)", re.I),
        re.compile(r"(confidential|hidden|secret)\s+(system\s+prompt|instructions?)", re.I),
    ]

    # --- API keys & secrets ---
    # Vendor-specific token formats plus generic key=value credential shapes.
    # NOTE(review): several vendor prefixes (sk-, AIza, AKIA) are case-exact
    # in the wild; the re.I flag here also matches lowercased variants —
    # confirm that is intended (it widens detection, never narrows it).
    SECRET_PATTERNS = [
        re.compile(r"sk-[a-zA-Z0-9]{20,}", re.I),  # OpenAI
        re.compile(r"AIza[0-9A-Za-z\-_]{35}", re.I),  # Google API
        re.compile(r"AKIA[0-9A-Z]{16}", re.I),  # AWS access key
        re.compile(r"(?:ghp|ghs|gho|github_pat)_[a-zA-Z0-9]{36,}", re.I),  # GitHub tokens
        re.compile(r"xox[baprs]-[0-9]{10,}-[0-9A-Za-z\-]{20,}", re.I),  # Slack
        re.compile(r"(?:password|passwd|secret|api_key|apikey|token)\s*[:=]\s*[\"\']?[^\s\"\']{8,}[\"\']?", re.I),
        re.compile(r"Bearer\s+[a-zA-Z0-9._\-]{20,}", re.I),  # Bearer tokens
        re.compile(r"-----BEGIN\s+(RSA|EC|OPENSSH|PGP)?\s*PRIVATE KEY-----"),  # Private keys
    ]

    # --- PII ---
    # Heuristic personally-identifiable-information shapes; prone to false
    # positives by design (e.g. any xxx-xx-xxxx digit group flags as SSN).
    PII_PATTERNS = [
        re.compile(r"\b[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}\b"),  # Email
        re.compile(r"\b(?:\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b"),  # Phone (US-ish)
        re.compile(r"\b\d{3}-\d{2}-\d{4}\b"),  # SSN
        re.compile(r"\b(?:4[0-9]{12}(?:[0-9]{3})?|5[1-5][0-9]{14}|3[47][0-9]{13})\b"),  # Credit card (Visa/MC/Amex prefixes)
        re.compile(r"\b[A-Z]{2}\d{6}[A-Z]?\b"),  # Passport-like
    ]

    # --- Unsafe content ---
    # Explicit harmful-activity instructions in the model's own response.
    UNSAFE_CONTENT = [
        re.compile(r"(how\s+to)?\s*(make|build|synthesize|create)\s+(explosives?|bombs?|weapons?|poison)", re.I),
        re.compile(r"step[\s-]by[\s-]step\s+(guide|instructions?)\s+.{0,40}(hack|phish|exploit|malware)", re.I),
        re.compile(r"(bypass|disable|defeat)\s+(security|authentication|2fa|mfa|captcha)", re.I),
        re.compile(r"(execute|run)\s+(arbitrary|remote)\s+(code|commands?)", re.I),
    ]

    # --- Jailbreak confirmation ---
    # The model announcing it has entered an unrestricted / jailbroken persona.
    JAILBREAK_CONFIRMS = [
        re.compile(r"(in\s+)?DAN\s+mode\s*:", re.I),
        re.compile(r"as\s+(DAN|an?\s+unrestricted|an?\s+uncensored)\s+(ai|assistant|model)\s*:", re.I),
        re.compile(r"(ignoring|without)\s+(my\s+)?(safety|ethical|content)\s+(guidelines?|filters?|restrictions?)", re.I),
        re.compile(r"developer\s+mode\s+(enabled|activated|on)\s*:", re.I),
    ]
# Severity weights per check category.
# When a category fires, its weight becomes that category's risk score;
# OutputGuardrail._run_patterns falls back to 0.7 for an unknown label.
# The overall risk score of a response is the max across fired categories.
_SEVERITY = {
    "system_prompt_leak": 0.90,
    "secret_leak": 0.95,
    "pii_leak": 0.80,
    "unsafe_content": 0.85,
    "jailbreak_confirmation": 0.92,
}
@dataclass
class GuardrailResult:
    """Outcome of a single output-guardrail validation pass."""

    is_safe: bool                                    # risk_score stayed below the threshold
    risk_score: float                                # highest severity among fired categories
    flags: List[str] = field(default_factory=list)   # labels of fired categories
    redacted_output: str = ""                        # output with sensitive spans replaced
    latency_ms: float = 0.0                          # wall-clock validation time

    def to_dict(self) -> dict:
        """Serialize to a JSON-friendly dict, rounding the float fields."""
        payload = dict(
            is_safe=self.is_safe,
            risk_score=round(self.risk_score, 4),
            flags=self.flags,
            redacted_output=self.redacted_output,
            latency_ms=round(self.latency_ms, 2),
        )
        return payload
class OutputGuardrail:
    """
    Post-generation output guardrail.

    Scans the model's response for leakage and unsafe content before
    returning it to the caller.

    Parameters
    ----------
    threshold : float
        Risk score at or above which output is blocked (default 0.50).
    redact : bool
        If True, return a redacted version of the output with sensitive
        patterns replaced by [REDACTED] (default True).
    check_system_prompt_leak : bool
    check_secrets : bool
    check_pii : bool
    check_unsafe_content : bool
    check_jailbreak_confirmation : bool
        Per-category toggles; all enabled by default.
    """

    def __init__(
        self,
        threshold: float = 0.50,
        redact: bool = True,
        check_system_prompt_leak: bool = True,
        check_secrets: bool = True,
        check_pii: bool = True,
        check_unsafe_content: bool = True,
        check_jailbreak_confirmation: bool = True,
    ) -> None:
        self.threshold = threshold
        self.redact = redact
        self.check_system_prompt_leak = check_system_prompt_leak
        self.check_secrets = check_secrets
        self.check_pii = check_pii
        self.check_unsafe_content = check_unsafe_content
        self.check_jailbreak_confirmation = check_jailbreak_confirmation

    # ------------------------------------------------------------------
    # Checks
    # ------------------------------------------------------------------
    def _run_patterns(
        self, text: str, patterns: List[re.Pattern], label: str, out: str
    ) -> tuple[float, List[str], str]:
        """
        Scan *text* against every pattern of one category.

        The category is flagged (and scored) at most once, but when
        redaction is enabled EVERY matching pattern is substituted in
        *out* — not just the first one.  (The previous first-hit `break`
        meant a response containing e.g. both an email and an SSN left
        the SSN unredacted.)

        Returns ``(score, flags, redacted_out)``.
        """
        matched = False
        for pattern in patterns:
            if pattern.search(text):
                matched = True
                if self.redact:
                    out = pattern.sub("[REDACTED]", out)
                else:
                    break  # nothing to redact; one hit settles the category
        if not matched:
            return 0.0, [], out
        return _SEVERITY.get(label, 0.7), [label], out

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def validate(self, output: str) -> GuardrailResult:
        """
        Validate a model response.

        Parameters
        ----------
        output : str
            Raw model response text.

        Returns
        -------
        GuardrailResult
        """
        t0 = time.perf_counter()
        max_score = 0.0
        all_flags: List[str] = []
        redacted = output
        # (enabled?, compiled patterns, category label) — scanned in order.
        checks = [
            (self.check_system_prompt_leak, _Patterns.SYSTEM_PROMPT_LEAK, "system_prompt_leak"),
            (self.check_secrets, _Patterns.SECRET_PATTERNS, "secret_leak"),
            (self.check_pii, _Patterns.PII_PATTERNS, "pii_leak"),
            (self.check_unsafe_content, _Patterns.UNSAFE_CONTENT, "unsafe_content"),
            (self.check_jailbreak_confirmation, _Patterns.JAILBREAK_CONFIRMS, "jailbreak_confirmation"),
        ]
        for enabled, patterns, label in checks:
            if not enabled:
                continue
            # Always match against the ORIGINAL output; redactions
            # accumulate across categories in `redacted`.
            score, flags, redacted = self._run_patterns(output, patterns, label, redacted)
            max_score = max(max_score, score)
            all_flags.extend(flags)
        is_safe = max_score < self.threshold
        latency = (time.perf_counter() - t0) * 1000
        result = GuardrailResult(
            is_safe=is_safe,
            risk_score=max_score,
            # Deduplicate while preserving check order — list(set(...))
            # made the flag order nondeterministic across runs.
            flags=list(dict.fromkeys(all_flags)),
            redacted_output=redacted if self.redact else output,
            latency_ms=latency,
        )
        if not is_safe:
            logger.warning("Output guardrail triggered! flags=%s score=%.3f", all_flags, max_score)
        return result

    def is_safe_output(self, output: str) -> bool:
        """Convenience wrapper: True iff *output* passes validation."""
        return self.validate(output).is_safe