File size: 8,400 Bytes
4afcb3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""
output_guardrail.py
===================
Validates AI model responses before returning them to the user.

Checks:
  1. System prompt leakage  — did the model accidentally reveal its system prompt?
  2. Secret / API key leakage — API keys, tokens, passwords in the response
  3. PII leakage             — email addresses, phone numbers, SSNs, credit cards
  4. Unsafe content          — explicit instructions for harmful activities
  5. Excessive refusal leak  — model revealing it was jailbroken / restricted
  6. Known data exfiltration patterns

Each check is individually configurable and produces a labelled flag.
"""

from __future__ import annotations

import re
import logging
import time
from dataclasses import dataclass, field
from typing import List

logger = logging.getLogger("ai_firewall.output_guardrail")


# ---------------------------------------------------------------------------
# Pattern catalogue
# ---------------------------------------------------------------------------

class _Patterns:
    """Compiled regex catalogue, grouped by guardrail check category.

    Patterns that match natural language use ``re.I``. Patterns that match
    fixed credential formats (AWS ``AKIA…``, Google ``AIza…``, OpenAI
    ``sk-…``, GitHub ``ghp_…``, Slack ``xox…``) are case-sensitive: the
    real token prefixes have a fixed case, and matching them
    case-insensitively flagged ordinary text (e.g. ``Akiadema…``) as a
    leaked key.
    """

    # --- System prompt leakage: the model quoting/paraphrasing its own prompt ---
    SYSTEM_PROMPT_LEAK = [
        re.compile(r"my\s+(system\s+prompt|instructions?|directives?)\s+(is|are|say(s)?)\s*:?", re.I),
        re.compile(r"(i\s+was|i've\s+been)\s+(instructed|told|programmed|configured)\s+to", re.I),
        re.compile(r"(the\s+)?system\s+message\s+(says?|reads?|is)\s*:?", re.I),
        re.compile(r"(here\s+is|below\s+is)\s+(my\s+)?(full\s+|complete\s+)?(system\s+prompt|initial\s+instructions?)", re.I),
        re.compile(r"(confidential|hidden|secret)\s+(system\s+prompt|instructions?)", re.I),
    ]

    # --- API keys & secrets (format-specific patterns deliberately case-sensitive) ---
    SECRET_PATTERNS = [
        re.compile(r"sk-[a-zA-Z0-9]{20,}"),                                            # OpenAI
        re.compile(r"AIza[0-9A-Za-z\-_]{35}"),                                         # Google API
        re.compile(r"AKIA[0-9A-Z]{16}"),                                               # AWS access key
        re.compile(r"(?:ghp|ghs|gho|github_pat)_[a-zA-Z0-9]{36,}"),                    # GitHub tokens
        re.compile(r"xox[baprs]-[0-9]{10,}-[0-9A-Za-z\-]{20,}"),                       # Slack
        re.compile(r"(?:password|passwd|secret|api_key|apikey|token)\s*[:=]\s*[\"\']?[^\s\"\']{8,}[\"\']?", re.I),
        re.compile(r"Bearer\s+[a-zA-Z0-9._\-]{20,}", re.I),                            # Bearer tokens
        re.compile(r"-----BEGIN\s+(RSA|EC|OPENSSH|PGP)?\s*PRIVATE KEY-----"),          # Private keys
    ]

    # --- PII ---
    PII_PATTERNS = [
        re.compile(r"\b[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}\b"),          # Email
        re.compile(r"\b(?:\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b"),       # Phone (US-ish)
        re.compile(r"\b\d{3}-\d{2}-\d{4}\b"),                                          # SSN
        re.compile(r"\b(?:4[0-9]{12}(?:[0-9]{3})?|5[1-5][0-9]{14}|3[47][0-9]{13})\b"), # Credit card
        re.compile(r"\b[A-Z]{2}\d{6}[A-Z]?\b"),                                        # Passport-like
    ]

    # --- Unsafe content ---
    UNSAFE_CONTENT = [
        re.compile(r"(how\s+to)?\s*(make|build|synthesize|create)\s+(explosives?|bombs?|weapons?|poison)", re.I),
        re.compile(r"step[\s-]by[\s-]step\s+(guide|instructions?)\s+.{0,40}(hack|phish|exploit|malware)", re.I),
        re.compile(r"(bypass|disable|defeat)\s+(security|authentication|2fa|mfa|captcha)", re.I),
        re.compile(r"(execute|run)\s+(arbitrary|remote)\s+(code|commands?)", re.I),
    ]

    # --- Jailbreak confirmation ---
    JAILBREAK_CONFIRMS = [
        re.compile(r"(in\s+)?DAN\s+mode\s*:", re.I),
        re.compile(r"as\s+(DAN|an?\s+unrestricted|an?\s+uncensored)\s+(ai|assistant|model)\s*:", re.I),
        re.compile(r"(ignoring|without)\s+(my\s+)?(safety|ethical|content)\s+(guidelines?|filters?|restrictions?)", re.I),
        re.compile(r"developer\s+mode\s+(enabled|activated|on)\s*:", re.I),
    ]


# Severity weight per check category: the risk score a match in that
# category contributes. Higher = more dangerous to let through.
_SEVERITY = dict(
    system_prompt_leak=0.90,
    secret_leak=0.95,
    pii_leak=0.80,
    unsafe_content=0.85,
    jailbreak_confirmation=0.92,
)


@dataclass
class GuardrailResult:
    """Outcome of one output-guardrail validation pass."""

    is_safe: bool                            # True when risk stayed below threshold
    risk_score: float                        # highest triggered category severity
    flags: List[str] = field(default_factory=list)   # triggered category labels
    redacted_output: str = ""                # scrubbed response text
    latency_ms: float = 0.0                  # wall-clock validation time

    def to_dict(self) -> dict:
        """Serialise the result as a plain JSON-friendly dict.

        ``risk_score`` is rounded to 4 decimal places and ``latency_ms``
        to 2, matching what callers log/report.
        """
        payload = dict(vars(self))
        payload["risk_score"] = round(self.risk_score, 4)
        payload["latency_ms"] = round(self.latency_ms, 2)
        return payload


class OutputGuardrail:
    """
    Post-generation output guardrail.

    Scans the model's response for leakage and unsafe content before
    returning it to the caller.

    Parameters
    ----------
    threshold : float
        Risk score above which output is blocked (default 0.50).
    redact : bool
        If True, return a redacted version of the output with sensitive
        patterns replaced by [REDACTED] (default True).
    check_system_prompt_leak : bool
    check_secrets : bool
    check_pii : bool
    check_unsafe_content : bool
    check_jailbreak_confirmation : bool
    """

    def __init__(
        self,
        threshold: float = 0.50,
        redact: bool = True,
        check_system_prompt_leak: bool = True,
        check_secrets: bool = True,
        check_pii: bool = True,
        check_unsafe_content: bool = True,
        check_jailbreak_confirmation: bool = True,
    ) -> None:
        self.threshold = threshold
        self.redact = redact
        self.check_system_prompt_leak = check_system_prompt_leak
        self.check_secrets = check_secrets
        self.check_pii = check_pii
        self.check_unsafe_content = check_unsafe_content
        self.check_jailbreak_confirmation = check_jailbreak_confirmation

    # ------------------------------------------------------------------
    # Checks
    # ------------------------------------------------------------------

    def _run_patterns(self, text: str, patterns: list, label: str, out: str) -> tuple[float, List[str], str]:
        """Run one category of compiled patterns against *text*.

        Returns ``(score, flags, out)``: *score* is the category severity
        if any pattern matched (0.0 otherwise), *flags* holds at most one
        entry (*label* — one flag per category), and *out* is the running
        redaction buffer.

        BUGFIX: the previous implementation ``break``-ed on the first
        matching pattern, so a response containing several kinds of
        sensitive data in the SAME category (e.g. a ``password:`` pair AND
        a Bearer token) had only the first kind redacted; the rest leaked
        through in ``redacted_output``. Every matching pattern is now
        applied to the buffer. The early exit survives only in
        detection-only mode (``redact=False``), where one hit settles the
        category.
        """
        matched = False
        for p in patterns:
            if not p.search(text):
                continue
            matched = True
            if self.redact:
                # p.sub replaces ALL occurrences of this pattern in the buffer.
                out = p.sub("[REDACTED]", out)
            else:
                break  # detection only: first hit is enough for this category
        if matched:
            return _SEVERITY.get(label, 0.7), [label], out
        return 0.0, [], out

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def validate(self, output: str) -> GuardrailResult:
        """
        Validate a model response.

        Parameters
        ----------
        output : str
            Raw model response text.

        Returns
        -------
        GuardrailResult
            ``is_safe`` is True when the highest triggered category
            severity stays below ``self.threshold``; ``redacted_output``
            carries the scrubbed text when ``redact`` is enabled,
            otherwise the original.
        """
        t0 = time.perf_counter()

        max_score = 0.0
        all_flags: List[str] = []
        redacted = output

        # (enabled?, compiled patterns, category label) — one row per check.
        checks = [
            (self.check_system_prompt_leak, _Patterns.SYSTEM_PROMPT_LEAK, "system_prompt_leak"),
            (self.check_secrets,            _Patterns.SECRET_PATTERNS,    "secret_leak"),
            (self.check_pii,                _Patterns.PII_PATTERNS,       "pii_leak"),
            (self.check_unsafe_content,     _Patterns.UNSAFE_CONTENT,     "unsafe_content"),
            (self.check_jailbreak_confirmation, _Patterns.JAILBREAK_CONFIRMS, "jailbreak_confirmation"),
        ]

        for enabled, patterns, label in checks:
            if not enabled:
                continue
            # Detection always runs on the ORIGINAL text so one category's
            # redaction cannot hide evidence from a later category.
            score, flags, redacted = self._run_patterns(output, patterns, label, redacted)
            max_score = max(max_score, score)
            all_flags.extend(flags)

        is_safe = max_score < self.threshold
        latency = (time.perf_counter() - t0) * 1000

        result = GuardrailResult(
            is_safe=is_safe,
            risk_score=max_score,
            # sorted() for a deterministic flag order (list(set(...)) was
            # arbitrary, which made logs and tests flaky).
            flags=sorted(set(all_flags)),
            redacted_output=redacted if self.redact else output,
            latency_ms=latency,
        )

        if not is_safe:
            logger.warning("Output guardrail triggered! flags=%s score=%.3f", all_flags, max_score)

        return result

    def is_safe_output(self, output: str) -> bool:
        """Convenience wrapper: True when *output* passes validation."""
        return self.validate(output).is_safe