File size: 4,704 Bytes
efd15e6
 
6ca375c
efd15e6
 
6ca375c
 
efd15e6
 
 
 
 
 
 
 
 
6ca375c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efd15e6
 
 
 
 
 
 
 
6ca375c
efd15e6
 
 
 
6ca375c
efd15e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ca375c
 
 
efd15e6
 
 
 
 
 
 
 
 
 
6ca375c
 
 
 
 
 
 
 
 
 
 
 
 
efd15e6
 
 
 
 
 
 
f665498
 
7d3f664
 
f665498
efd15e6
 
f665498
efd15e6
 
 
 
f665498
 
 
 
 
 
efd15e6
 
f665498
efd15e6
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""Post-generation output validation gate.

Four deterministic checks:
  1. PII leakage: reuses PIIRedactor to detect PII in LLM output
  2. URL validation: URLs must appear in retrieved chunks
  3. Secret leakage: deny-list of API key formats and env var literals
  4. Blocklist scan: configurable forbidden patterns
"""

from __future__ import annotations

import re

from agent_bench.security.pii_redactor import PIIRedactor
from agent_bench.security.types import OutputVerdict

# Always-on secret-leakage deny list. These fire regardless of config.
# Matches the well-known API-key prefixes and the common env var literals
# that a docs assistant should never emit. Each entry is
# (violation_name, compiled_pattern); the name is embedded in the
# violation message produced by OutputValidator._check_secrets.
_SECRET_PATTERNS: list[tuple[str, re.Pattern]] = [
    # The negative lookahead keeps Anthropic keys ("sk-ant-...") out of this
    # bucket so they are reported under their own name below.
    ("openai_api_key_format", re.compile(r"\bsk-(?!ant-)[A-Za-z0-9_\-]{20,}")),
    ("anthropic_api_key_format", re.compile(r"\bsk-ant-[A-Za-z0-9_\-]{20,}")),
    # "AIza" + 35 chars from [0-9A-Za-z_-]. No trailing \b: a key ending in
    # "-" followed by a space or punctuation has no word boundary after it,
    # so a trailing \b would let that key slip through. Like the other key
    # patterns, the tail is left unanchored (fail closed).
    ("google_api_key_format", re.compile(r"\bAIza[0-9A-Za-z_\-]{35}")),
    ("aws_access_key_format", re.compile(r"\b(?:AKIA|ASIA)[0-9A-Z]{16}\b")),
    ("github_token_format", re.compile(r"\bgh[pousr]_[A-Za-z0-9]{36,}\b")),
    ("bearer_token_header", re.compile(
        r"\b[Bb]earer\s+[A-Za-z0-9_\-\.=]{20,}",
    )),
    # The leading \b means prefixed names such as SQLALCHEMY_DATABASE_URL
    # (word char before "DATABASE_URL") do not match.
    ("env_var_literal", re.compile(
        r"\b(?:OPENAI_API_KEY|ANTHROPIC_API_KEY|"
        r"AWS_SECRET(?:_ACCESS_KEY)?|AWS_ACCESS_KEY(?:_ID)?|"
        r"GITHUB_TOKEN|DATABASE_URL|DB_PASSWORD)\s*=\s*\S+",
    )),
]


class OutputValidator:
    """Validate LLM output before returning it to the user.

    Runs up to four deterministic checks (PII leakage, URL grounding,
    secret leakage, configurable blocklist) and returns an
    ``OutputVerdict`` whose ``action`` is ``"block"`` when any check
    reports a violation.
    """

    # URL matcher applied to both the output and the retrieved chunks.
    # Whitespace, quotes, backticks, angle brackets, square brackets and
    # parentheses end a URL, so Markdown/HTML wrappers such as
    # ``[url](url)``, ``<url>`` or a backticked URL are not swallowed into
    # the captured text. Because the same pattern runs on both sides, a URL
    # cut short at one of these characters is cut identically in the
    # output and in the chunks.
    _URL_PATTERN = re.compile(r"https?://[^\s()\[\]<>\"'`]+")

    def __init__(
        self,
        pii_check: bool = True,
        url_check: bool = True,
        secret_check: bool = True,
        blocklist: list[str] | None = None,
    ) -> None:
        """Configure which checks ``validate`` runs.

        Args:
            pii_check: Run the PII leakage check.
            url_check: Require every URL in the output to appear in the
                retrieved chunks.
            secret_check: Run the secret-leakage deny list.
            blocklist: Extra regex patterns. A match anywhere in the output
                is a violation.

        Raises:
            re.error: If a blocklist entry is not a valid regular expression.
        """
        self.pii_check = pii_check
        self.url_check = url_check
        self.secret_check = secret_check
        self.blocklist_patterns = [re.compile(p) for p in (blocklist or [])]
        # Always define the attribute so that switching pii_check on after
        # construction does not raise AttributeError. _check_pii builds the
        # redactor lazily in that case.
        self._pii = PIIRedactor(mode="detect_only") if pii_check else None

    def validate(
        self,
        output: str,
        retrieved_chunks: list[str],
    ) -> OutputVerdict:
        """Run every enabled check against ``output``.

        Args:
            output: LLM-generated text about to be returned to the user.
            retrieved_chunks: Retrieved context texts. URLs in ``output``
                must appear somewhere in these.

        Returns:
            An ``OutputVerdict``. When no check fires: ``passed=True``,
            ``action="pass"``, no violations. Otherwise: ``passed=False``,
            ``action="block"``, and one message per violation, grouped in
            check order (PII, URL, secret, blocklist).
        """
        violations: list[str] = []

        if self.pii_check:
            violations.extend(self._check_pii(output))

        if self.url_check:
            violations.extend(self._check_urls(output, retrieved_chunks))

        if self.secret_check:
            violations.extend(self._check_secrets(output))

        if self.blocklist_patterns:
            violations.extend(self._check_blocklist(output))

        passed = not violations
        return OutputVerdict(
            passed=passed,
            violations=violations,
            action="pass" if passed else "block",
        )

    def _check_secrets(self, output: str) -> list[str]:
        """Fail closed on known-secret formats and env var assignments.

        The patterns are meant to match only credential-shaped text, not
        ordinary documentation content. Any hit is treated as a leaked
        credential and blocks the response before the client sees it.

        Returns:
            One message per matching pattern name, in deny-list order.
        """
        return [
            f"secret_leakage: {name} detected in output"
            for name, pattern in _SECRET_PATTERNS
            if pattern.search(output)
        ]

    def _check_pii(self, output: str) -> list[str]:
        """Return a single ``pii_leakage`` violation if the redactor finds PII."""
        if self._pii is None:
            # pii_check was enabled after __init__, so build the redactor now.
            self._pii = PIIRedactor(mode="detect_only")
        result = self._pii.redact(output)
        if result.redactions_count > 0:
            types = ", ".join(result.types_found)
            return [f"pii_leakage: {types} detected in output"]
        return []

    @staticmethod
    def _normalize_url(url: str) -> str:
        """Strip trailing punctuation then trailing slashes for comparison."""
        return url.rstrip(".,;:").rstrip("/")

    def _check_urls(self, output: str, retrieved_chunks: list[str]) -> list[str]:
        """Report URLs in ``output`` that do not occur in any retrieved chunk.

        Both sides are compared after ``_normalize_url``. Each offending URL
        (raw text as written in the output) is reported once, in order of
        first appearance. This keeps the violation list deterministic; a
        ``set`` would make the order depend on string hash randomization.
        """
        output_urls = self._URL_PATTERN.findall(output)
        if not output_urls:
            return []

        # Joining with a space cannot merge URLs across chunk boundaries,
        # because whitespace terminates a match.
        chunk_text = " ".join(retrieved_chunks)
        chunk_urls_normalized = {
            self._normalize_url(u) for u in self._URL_PATTERN.findall(chunk_text)
        }

        # dict.fromkeys deduplicates while preserving first-seen order.
        hallucinated = dict.fromkeys(
            url
            for url in output_urls
            if self._normalize_url(url) not in chunk_urls_normalized
        )
        return [f"url_hallucination: {url}" for url in hallucinated]

    def _check_blocklist(self, output: str) -> list[str]:
        """Return one violation per configured blocklist pattern found in ``output``."""
        return [
            f"blocklist: matched pattern '{pattern.pattern}'"
            for pattern in self.blocklist_patterns
            if pattern.search(output)
        ]