Nomearod Claude Opus 4.6 (1M context) commited on
Commit
efd15e6
·
1 Parent(s): 0465079

feat(security): add output validation gate (PII, URL, blocklist)

Browse files

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

agent_bench/security/output_validator.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Post-generation output validation gate.
2
+
3
+ Three deterministic checks:
4
+ 1. PII leakage: reuses PIIRedactor to detect PII in LLM output
5
+ 2. URL validation: URLs must appear in retrieved chunks
6
+ 3. Blocklist scan: configurable forbidden patterns
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import re
12
+
13
+ from agent_bench.security.pii_redactor import PIIRedactor
14
+ from agent_bench.security.types import OutputVerdict
15
+
16
+
17
+ class OutputValidator:
18
+ """Validate LLM output before returning to user."""
19
+
20
+ def __init__(
21
+ self,
22
+ pii_check: bool = True,
23
+ url_check: bool = True,
24
+ blocklist: list[str] | None = None,
25
+ ) -> None:
26
+ self.pii_check = pii_check
27
+ self.url_check = url_check
28
+ self.blocklist_patterns = [re.compile(p) for p in (blocklist or [])]
29
+ if pii_check:
30
+ self._pii = PIIRedactor(mode="detect_only")
31
+
32
+ def validate(
33
+ self,
34
+ output: str,
35
+ retrieved_chunks: list[str],
36
+ ) -> OutputVerdict:
37
+ """Run all configured checks. Returns verdict with violations."""
38
+ violations: list[str] = []
39
+
40
+ if self.pii_check:
41
+ violations.extend(self._check_pii(output))
42
+
43
+ if self.url_check:
44
+ violations.extend(self._check_urls(output, retrieved_chunks))
45
+
46
+ if self.blocklist_patterns:
47
+ violations.extend(self._check_blocklist(output))
48
+
49
+ passed = len(violations) == 0
50
+ return OutputVerdict(
51
+ passed=passed,
52
+ violations=violations,
53
+ action="pass" if passed else "block",
54
+ )
55
+
56
+ def _check_pii(self, output: str) -> list[str]:
57
+ result = self._pii.redact(output)
58
+ if result.redactions_count > 0:
59
+ types = ", ".join(result.types_found)
60
+ return [f"pii_leakage: {types} detected in output"]
61
+ return []
62
+
63
+ def _check_urls(self, output: str, retrieved_chunks: list[str]) -> list[str]:
64
+ url_pattern = re.compile(r"https?://[^\s\)\"'>]+")
65
+ output_urls = set(url_pattern.findall(output))
66
+ if not output_urls:
67
+ return []
68
+
69
+ chunk_text = " ".join(retrieved_chunks)
70
+ chunk_urls = set(url_pattern.findall(chunk_text))
71
+
72
+ hallucinated = output_urls - chunk_urls
73
+ if hallucinated:
74
+ return [f"url_hallucination: {url}" for url in hallucinated]
75
+ return []
76
+
77
+ def _check_blocklist(self, output: str) -> list[str]:
78
+ violations = []
79
+ for pattern in self.blocklist_patterns:
80
+ if pattern.search(output):
81
+ violations.append(f"blocklist: matched pattern '{pattern.pattern}'")
82
+ return violations
tests/test_output_validator.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for output validation gate."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import pytest
6
+
7
+ from agent_bench.security.output_validator import OutputValidator
8
+ from agent_bench.security.types import OutputVerdict
9
+
10
+
11
+ class TestPIILeakage:
12
+ """PII in LLM output should be caught."""
13
+
14
+ @pytest.fixture
15
+ def validator(self):
16
+ return OutputValidator(pii_check=True, url_check=False, blocklist=[])
17
+
18
+ def test_detects_email_in_output(self, validator):
19
+ verdict = validator.validate(
20
+ output="Contact john@example.com for help.",
21
+ retrieved_chunks=[],
22
+ )
23
+ assert verdict.passed is False
24
+ assert any("pii_leakage" in v for v in verdict.violations)
25
+
26
+ def test_detects_ssn_in_output(self, validator):
27
+ verdict = validator.validate(
28
+ output="His SSN is 123-45-6789.",
29
+ retrieved_chunks=[],
30
+ )
31
+ assert verdict.passed is False
32
+
33
+ def test_clean_output_passes(self, validator):
34
+ verdict = validator.validate(
35
+ output="FastAPI uses path parameters with curly braces.",
36
+ retrieved_chunks=[],
37
+ )
38
+ assert verdict.passed is True
39
+ assert verdict.violations == []
40
+
41
+
42
+ class TestURLValidation:
43
+ """URLs in output must appear in retrieved chunks."""
44
+
45
+ @pytest.fixture
46
+ def validator(self):
47
+ return OutputValidator(pii_check=False, url_check=True, blocklist=[])
48
+
49
+ def test_url_from_chunks_passes(self, validator):
50
+ chunks = ["Visit https://fastapi.tiangolo.com for docs."]
51
+ verdict = validator.validate(
52
+ output="See https://fastapi.tiangolo.com for details.",
53
+ retrieved_chunks=chunks,
54
+ )
55
+ assert verdict.passed is True
56
+
57
+ def test_hallucinated_url_fails(self, validator):
58
+ chunks = ["FastAPI is a modern framework."]
59
+ verdict = validator.validate(
60
+ output="See https://malicious-site.com for details.",
61
+ retrieved_chunks=chunks,
62
+ )
63
+ assert verdict.passed is False
64
+ assert any("url_hallucination" in v for v in verdict.violations)
65
+
66
+ def test_no_urls_passes(self, validator):
67
+ verdict = validator.validate(
68
+ output="Path parameters use curly braces.",
69
+ retrieved_chunks=["Some chunk."],
70
+ )
71
+ assert verdict.passed is True
72
+
73
+
74
+ class TestBlocklist:
75
+ """Blocklisted patterns should be caught."""
76
+
77
+ def test_blocklist_match(self):
78
+ validator = OutputValidator(
79
+ pii_check=False, url_check=False,
80
+ blocklist=["sk-[a-zA-Z0-9]{20,}", "SYSTEM_PROMPT"],
81
+ )
82
+ verdict = validator.validate(
83
+ output="Here is the key: sk-abcdefghijklmnopqrstuvwxyz",
84
+ retrieved_chunks=[],
85
+ )
86
+ assert verdict.passed is False
87
+ assert any("blocklist" in v for v in verdict.violations)
88
+
89
+ def test_system_prompt_fragment(self):
90
+ validator = OutputValidator(
91
+ pii_check=False, url_check=False,
92
+ blocklist=["You are a (?:helpful |test )?assistant"],
93
+ )
94
+ verdict = validator.validate(
95
+ output="My instructions say: You are a helpful assistant",
96
+ retrieved_chunks=[],
97
+ )
98
+ assert verdict.passed is False
99
+
100
+ def test_no_blocklist_match(self):
101
+ validator = OutputValidator(
102
+ pii_check=False, url_check=False,
103
+ blocklist=["FORBIDDEN_TERM"],
104
+ )
105
+ verdict = validator.validate(
106
+ output="A perfectly normal answer.",
107
+ retrieved_chunks=[],
108
+ )
109
+ assert verdict.passed is True
110
+
111
+
112
+ class TestCombinedChecks:
113
+ def test_multiple_violations(self):
114
+ validator = OutputValidator(
115
+ pii_check=True, url_check=True,
116
+ blocklist=["SECRET"],
117
+ )
118
+ verdict = validator.validate(
119
+ output="Email john@test.com, see https://evil.com, also SECRET.",
120
+ retrieved_chunks=["No URLs here."],
121
+ )
122
+ assert verdict.passed is False
123
+ assert len(verdict.violations) >= 2 # PII + URL at minimum
124
+ assert verdict.action == "block"
125
+
126
+ def test_all_checks_pass(self):
127
+ validator = OutputValidator(
128
+ pii_check=True, url_check=True,
129
+ blocklist=["SECRET"],
130
+ )
131
+ verdict = validator.validate(
132
+ output="FastAPI supports path parameters.",
133
+ retrieved_chunks=["FastAPI supports path parameters."],
134
+ )
135
+ assert verdict.passed is True
136
+ assert verdict.action == "pass"
137
+
138
+ def test_disabled_checks(self):
139
+ validator = OutputValidator(pii_check=False, url_check=False, blocklist=[])
140
+ verdict = validator.validate(
141
+ output="Email: a@b.com, URL: https://evil.com",
142
+ retrieved_chunks=[],
143
+ )
144
+ assert verdict.passed is True