Spaces:
Running
Running
fix(security): output validation on /ask/stream, correct audit endpoint
Browse files- /ask/stream now runs output validation on the assembled answer
after streaming completes. PII in streamed output triggers a
"[Output filtered for safety]" SSE chunk.
- _write_audit() takes an endpoint parameter instead of hardcoding
"/ask". Stream audit records are labeled "/ask/stream".
- _write_audit() records output_validation independently of result
metadata, so streaming audit includes validation verdicts.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- agent_bench/serving/routes.py +48 -16
- tests/test_security_integration.py +28 -1
agent_bench/serving/routes.py
CHANGED
|
@@ -190,7 +190,10 @@ async def ask_stream(body: AskRequest, request: Request) -> StreamingResponse:
|
|
| 190 |
sec_config = getattr(request.app.state.config, "security", None)
|
| 191 |
action = sec_config.injection.action if sec_config else "block"
|
| 192 |
if not verdict.safe and action == "block":
|
| 193 |
-
_write_audit(
|
|
|
|
|
|
|
|
|
|
| 194 |
from fastapi.responses import JSONResponse
|
| 195 |
return JSONResponse(
|
| 196 |
status_code=403,
|
|
@@ -209,9 +212,12 @@ async def ask_stream(body: AskRequest, request: Request) -> StreamingResponse:
|
|
| 209 |
|
| 210 |
start = time.perf_counter()
|
| 211 |
|
|
|
|
|
|
|
| 212 |
async def event_generator():
|
| 213 |
full_answer: list[str] = []
|
| 214 |
cost_usd = 0.0
|
|
|
|
| 215 |
async for event in orchestrator.run_stream(
|
| 216 |
question=body.question,
|
| 217 |
system_prompt=system_prompt,
|
|
@@ -219,24 +225,47 @@ async def ask_stream(body: AskRequest, request: Request) -> StreamingResponse:
|
|
| 219 |
strategy=body.retrieval_strategy,
|
| 220 |
history=history,
|
| 221 |
):
|
|
|
|
|
|
|
| 222 |
if event.type == "chunk" and event.content:
|
| 223 |
full_answer.append(event.content)
|
| 224 |
if event.type == "done" and event.metadata:
|
| 225 |
cost_usd = event.metadata.get("estimated_cost_usd", 0.0)
|
| 226 |
yield event.to_sse()
|
| 227 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
# Record metrics and persist session after streaming completes
|
| 229 |
latency_ms = (time.perf_counter() - start) * 1000
|
| 230 |
metrics.record(latency_ms=latency_ms, cost_usd=cost_usd)
|
| 231 |
|
| 232 |
if body.session_id and conversation_store:
|
| 233 |
conversation_store.append(body.session_id, "user", body.question)
|
| 234 |
-
conversation_store.append(
|
| 235 |
-
body.session_id, "assistant", "".join(full_answer)
|
| 236 |
-
)
|
| 237 |
|
| 238 |
# --- Security: audit log for streaming ---
|
| 239 |
-
_write_audit(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
return StreamingResponse(
|
| 242 |
event_generator(),
|
|
@@ -317,6 +346,7 @@ def _write_audit(
|
|
| 317 |
body: AskRequest,
|
| 318 |
request_id: str,
|
| 319 |
injection_verdict: dict,
|
|
|
|
| 320 |
blocked: bool = False,
|
| 321 |
result: object | None = None,
|
| 322 |
output_verdict_data: dict | None = None,
|
|
@@ -332,22 +362,24 @@ def _write_audit(
|
|
| 332 |
"request_id": request_id,
|
| 333 |
"session_id": body.session_id,
|
| 334 |
"client_ip": audit_logger.hash_ip(client_ip),
|
| 335 |
-
"endpoint":
|
| 336 |
"input_query": body.question,
|
| 337 |
"injection_verdict": injection_verdict,
|
| 338 |
}
|
| 339 |
|
| 340 |
if blocked:
|
| 341 |
record["blocked"] = True
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
|
|
|
|
|
|
| 352 |
|
| 353 |
audit_logger.log(record)
|
|
|
|
| 190 |
sec_config = getattr(request.app.state.config, "security", None)
|
| 191 |
action = sec_config.injection.action if sec_config else "block"
|
| 192 |
if not verdict.safe and action == "block":
|
| 193 |
+
_write_audit(
|
| 194 |
+
request, body, request_id, injection_verdict_data,
|
| 195 |
+
endpoint="/ask/stream", blocked=True,
|
| 196 |
+
)
|
| 197 |
from fastapi.responses import JSONResponse
|
| 198 |
return JSONResponse(
|
| 199 |
status_code=403,
|
|
|
|
| 212 |
|
| 213 |
start = time.perf_counter()
|
| 214 |
|
| 215 |
+
output_validator = getattr(request.app.state, "output_validator", None)
|
| 216 |
+
|
| 217 |
async def event_generator():
|
| 218 |
full_answer: list[str] = []
|
| 219 |
cost_usd = 0.0
|
| 220 |
+
all_sources: list[str] = []
|
| 221 |
async for event in orchestrator.run_stream(
|
| 222 |
question=body.question,
|
| 223 |
system_prompt=system_prompt,
|
|
|
|
| 225 |
strategy=body.retrieval_strategy,
|
| 226 |
history=history,
|
| 227 |
):
|
| 228 |
+
if event.type == "sources" and event.sources:
|
| 229 |
+
all_sources = [s.get("source", "") for s in event.sources]
|
| 230 |
if event.type == "chunk" and event.content:
|
| 231 |
full_answer.append(event.content)
|
| 232 |
if event.type == "done" and event.metadata:
|
| 233 |
cost_usd = event.metadata.get("estimated_cost_usd", 0.0)
|
| 234 |
yield event.to_sse()
|
| 235 |
|
| 236 |
+
# --- Security: output validation (post-generation) ---
|
| 237 |
+
answer_text = "".join(full_answer)
|
| 238 |
+
output_verdict_data: dict = {"passed": True, "violations": []}
|
| 239 |
+
if output_validator:
|
| 240 |
+
from agent_bench.serving.schemas import StreamEvent as SE
|
| 241 |
+
out_verdict = output_validator.validate(
|
| 242 |
+
output=answer_text,
|
| 243 |
+
retrieved_chunks=[], # chunks already redacted by SearchTool
|
| 244 |
+
)
|
| 245 |
+
output_verdict_data = {
|
| 246 |
+
"passed": out_verdict.passed,
|
| 247 |
+
"violations": out_verdict.violations,
|
| 248 |
+
}
|
| 249 |
+
if not out_verdict.passed and out_verdict.action == "block":
|
| 250 |
+
yield SE(
|
| 251 |
+
type="chunk",
|
| 252 |
+
content="\n\n[Output filtered for safety]",
|
| 253 |
+
).to_sse()
|
| 254 |
+
|
| 255 |
# Record metrics and persist session after streaming completes
|
| 256 |
latency_ms = (time.perf_counter() - start) * 1000
|
| 257 |
metrics.record(latency_ms=latency_ms, cost_usd=cost_usd)
|
| 258 |
|
| 259 |
if body.session_id and conversation_store:
|
| 260 |
conversation_store.append(body.session_id, "user", body.question)
|
| 261 |
+
conversation_store.append(body.session_id, "assistant", answer_text)
|
|
|
|
|
|
|
| 262 |
|
| 263 |
# --- Security: audit log for streaming ---
|
| 264 |
+
_write_audit(
|
| 265 |
+
request, body, request_id, injection_verdict_data,
|
| 266 |
+
endpoint="/ask/stream",
|
| 267 |
+
output_verdict_data=output_verdict_data,
|
| 268 |
+
)
|
| 269 |
|
| 270 |
return StreamingResponse(
|
| 271 |
event_generator(),
|
|
|
|
| 346 |
body: AskRequest,
|
| 347 |
request_id: str,
|
| 348 |
injection_verdict: dict,
|
| 349 |
+
endpoint: str = "/ask",
|
| 350 |
blocked: bool = False,
|
| 351 |
result: object | None = None,
|
| 352 |
output_verdict_data: dict | None = None,
|
|
|
|
| 362 |
"request_id": request_id,
|
| 363 |
"session_id": body.session_id,
|
| 364 |
"client_ip": audit_logger.hash_ip(client_ip),
|
| 365 |
+
"endpoint": endpoint,
|
| 366 |
"input_query": body.question,
|
| 367 |
"injection_verdict": injection_verdict,
|
| 368 |
}
|
| 369 |
|
| 370 |
if blocked:
|
| 371 |
record["blocked"] = True
|
| 372 |
+
else:
|
| 373 |
+
if result is not None:
|
| 374 |
+
record.update({
|
| 375 |
+
"retrieved_chunks": [s.source for s in getattr(result, "sources", [])],
|
| 376 |
+
"llm_provider": getattr(result, "provider", ""),
|
| 377 |
+
"llm_model": getattr(result, "model", ""),
|
| 378 |
+
"output_tokens": getattr(result, "usage", None) and result.usage.output_tokens,
|
| 379 |
+
"grounded_refusal": not bool(getattr(result, "sources", [])),
|
| 380 |
+
"response_latency_ms": getattr(result, "latency_ms", 0),
|
| 381 |
+
})
|
| 382 |
+
if output_verdict_data is not None:
|
| 383 |
+
record["output_validation"] = output_verdict_data
|
| 384 |
|
| 385 |
audit_logger.log(record)
|
tests/test_security_integration.py
CHANGED
|
@@ -143,7 +143,7 @@ class TestStreamInjectionBlocking:
|
|
| 143 |
assert resp.status_code == 200
|
| 144 |
|
| 145 |
@pytest.mark.asyncio
|
| 146 |
-
async def
|
| 147 |
app = _make_security_app(tmp_path)
|
| 148 |
audit_path = tmp_path / "audit.jsonl"
|
| 149 |
transport = ASGITransport(app=app)
|
|
@@ -157,6 +157,33 @@ class TestStreamInjectionBlocking:
|
|
| 157 |
record = json.loads(audit_path.read_text().strip().split("\n")[0])
|
| 158 |
assert "request_id" in record
|
| 159 |
assert "injection_verdict" in record
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
|
| 162 |
class TestAuditLogging:
|
|
|
|
| 143 |
assert resp.status_code == 200
|
| 144 |
|
| 145 |
@pytest.mark.asyncio
|
| 146 |
+
async def test_stream_audit_written_with_correct_endpoint(self, tmp_path):
|
| 147 |
app = _make_security_app(tmp_path)
|
| 148 |
audit_path = tmp_path / "audit.jsonl"
|
| 149 |
transport = ASGITransport(app=app)
|
|
|
|
| 157 |
record = json.loads(audit_path.read_text().strip().split("\n")[0])
|
| 158 |
assert "request_id" in record
|
| 159 |
assert "injection_verdict" in record
|
| 160 |
+
assert record["endpoint"] == "/ask/stream"
|
| 161 |
+
assert "output_validation" in record
|
| 162 |
+
|
| 163 |
+
@pytest.mark.asyncio
|
| 164 |
+
async def test_stream_output_validation_runs(self, tmp_path):
|
| 165 |
+
"""Output containing PII should trigger output validation on stream."""
|
| 166 |
+
from unittest.mock import AsyncMock, patch
|
| 167 |
+
from agent_bench.core.types import TokenUsage
|
| 168 |
+
from agent_bench.serving.schemas import StreamEvent
|
| 169 |
+
|
| 170 |
+
app = _make_security_app(tmp_path)
|
| 171 |
+
|
| 172 |
+
# Mock the orchestrator to return PII in the streamed answer
|
| 173 |
+
async def fake_run_stream(**kwargs):
|
| 174 |
+
yield StreamEvent(type="sources", sources=[])
|
| 175 |
+
yield StreamEvent(type="chunk", content="Contact john@example.com for help.")
|
| 176 |
+
yield StreamEvent(type="done", metadata={"estimated_cost_usd": 0.0})
|
| 177 |
+
|
| 178 |
+
app.state.orchestrator.run_stream = fake_run_stream
|
| 179 |
+
|
| 180 |
+
transport = ASGITransport(app=app)
|
| 181 |
+
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
| 182 |
+
resp = await client.post("/ask/stream", json={
|
| 183 |
+
"question": "How do I contact support?",
|
| 184 |
+
})
|
| 185 |
+
# The response should contain the safety filter message
|
| 186 |
+
assert "[Output filtered for safety]" in resp.text
|
| 187 |
|
| 188 |
|
| 189 |
class TestAuditLogging:
|