Nomearod Claude Opus 4.6 (1M context) commited on
Commit
02f7f66
·
1 Parent(s): 14985f8

fix(security): output validation on /ask/stream, correct audit endpoint

Browse files

- /ask/stream now runs output validation on the assembled answer
after streaming completes. PII in streamed output triggers a
"[Output filtered for safety]" SSE chunk.
- _write_audit() takes an endpoint parameter instead of hardcoding
"/ask". Stream audit records are labeled "/ask/stream".
- _write_audit() records output_validation independently of result
metadata, so streaming audit includes validation verdicts.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

agent_bench/serving/routes.py CHANGED
@@ -190,7 +190,10 @@ async def ask_stream(body: AskRequest, request: Request) -> StreamingResponse:
190
  sec_config = getattr(request.app.state.config, "security", None)
191
  action = sec_config.injection.action if sec_config else "block"
192
  if not verdict.safe and action == "block":
193
- _write_audit(request, body, request_id, injection_verdict_data, blocked=True)
 
 
 
194
  from fastapi.responses import JSONResponse
195
  return JSONResponse(
196
  status_code=403,
@@ -209,9 +212,12 @@ async def ask_stream(body: AskRequest, request: Request) -> StreamingResponse:
209
 
210
  start = time.perf_counter()
211
 
 
 
212
  async def event_generator():
213
  full_answer: list[str] = []
214
  cost_usd = 0.0
 
215
  async for event in orchestrator.run_stream(
216
  question=body.question,
217
  system_prompt=system_prompt,
@@ -219,24 +225,47 @@ async def ask_stream(body: AskRequest, request: Request) -> StreamingResponse:
219
  strategy=body.retrieval_strategy,
220
  history=history,
221
  ):
 
 
222
  if event.type == "chunk" and event.content:
223
  full_answer.append(event.content)
224
  if event.type == "done" and event.metadata:
225
  cost_usd = event.metadata.get("estimated_cost_usd", 0.0)
226
  yield event.to_sse()
227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  # Record metrics and persist session after streaming completes
229
  latency_ms = (time.perf_counter() - start) * 1000
230
  metrics.record(latency_ms=latency_ms, cost_usd=cost_usd)
231
 
232
  if body.session_id and conversation_store:
233
  conversation_store.append(body.session_id, "user", body.question)
234
- conversation_store.append(
235
- body.session_id, "assistant", "".join(full_answer)
236
- )
237
 
238
  # --- Security: audit log for streaming ---
239
- _write_audit(request, body, request_id, injection_verdict_data)
 
 
 
 
240
 
241
  return StreamingResponse(
242
  event_generator(),
@@ -317,6 +346,7 @@ def _write_audit(
317
  body: AskRequest,
318
  request_id: str,
319
  injection_verdict: dict,
 
320
  blocked: bool = False,
321
  result: object | None = None,
322
  output_verdict_data: dict | None = None,
@@ -332,22 +362,24 @@ def _write_audit(
332
  "request_id": request_id,
333
  "session_id": body.session_id,
334
  "client_ip": audit_logger.hash_ip(client_ip),
335
- "endpoint": "/ask",
336
  "input_query": body.question,
337
  "injection_verdict": injection_verdict,
338
  }
339
 
340
  if blocked:
341
  record["blocked"] = True
342
- elif result is not None:
343
- record.update({
344
- "retrieved_chunks": [s.source for s in getattr(result, "sources", [])],
345
- "llm_provider": getattr(result, "provider", ""),
346
- "llm_model": getattr(result, "model", ""),
347
- "output_tokens": getattr(result, "usage", None) and result.usage.output_tokens,
348
- "output_validation": output_verdict_data or {},
349
- "grounded_refusal": not bool(getattr(result, "sources", [])),
350
- "response_latency_ms": getattr(result, "latency_ms", 0),
351
- })
 
 
352
 
353
  audit_logger.log(record)
 
190
  sec_config = getattr(request.app.state.config, "security", None)
191
  action = sec_config.injection.action if sec_config else "block"
192
  if not verdict.safe and action == "block":
193
+ _write_audit(
194
+ request, body, request_id, injection_verdict_data,
195
+ endpoint="/ask/stream", blocked=True,
196
+ )
197
  from fastapi.responses import JSONResponse
198
  return JSONResponse(
199
  status_code=403,
 
212
 
213
  start = time.perf_counter()
214
 
215
+ output_validator = getattr(request.app.state, "output_validator", None)
216
+
217
  async def event_generator():
218
  full_answer: list[str] = []
219
  cost_usd = 0.0
220
+ all_sources: list[str] = []
221
  async for event in orchestrator.run_stream(
222
  question=body.question,
223
  system_prompt=system_prompt,
 
225
  strategy=body.retrieval_strategy,
226
  history=history,
227
  ):
228
+ if event.type == "sources" and event.sources:
229
+ all_sources = [s.get("source", "") for s in event.sources]
230
  if event.type == "chunk" and event.content:
231
  full_answer.append(event.content)
232
  if event.type == "done" and event.metadata:
233
  cost_usd = event.metadata.get("estimated_cost_usd", 0.0)
234
  yield event.to_sse()
235
 
236
+ # --- Security: output validation (post-generation) ---
237
+ answer_text = "".join(full_answer)
238
+ output_verdict_data: dict = {"passed": True, "violations": []}
239
+ if output_validator:
240
+ from agent_bench.serving.schemas import StreamEvent as SE
241
+ out_verdict = output_validator.validate(
242
+ output=answer_text,
243
+ retrieved_chunks=[], # chunks already redacted by SearchTool
244
+ )
245
+ output_verdict_data = {
246
+ "passed": out_verdict.passed,
247
+ "violations": out_verdict.violations,
248
+ }
249
+ if not out_verdict.passed and out_verdict.action == "block":
250
+ yield SE(
251
+ type="chunk",
252
+ content="\n\n[Output filtered for safety]",
253
+ ).to_sse()
254
+
255
  # Record metrics and persist session after streaming completes
256
  latency_ms = (time.perf_counter() - start) * 1000
257
  metrics.record(latency_ms=latency_ms, cost_usd=cost_usd)
258
 
259
  if body.session_id and conversation_store:
260
  conversation_store.append(body.session_id, "user", body.question)
261
+ conversation_store.append(body.session_id, "assistant", answer_text)
 
 
262
 
263
  # --- Security: audit log for streaming ---
264
+ _write_audit(
265
+ request, body, request_id, injection_verdict_data,
266
+ endpoint="/ask/stream",
267
+ output_verdict_data=output_verdict_data,
268
+ )
269
 
270
  return StreamingResponse(
271
  event_generator(),
 
346
  body: AskRequest,
347
  request_id: str,
348
  injection_verdict: dict,
349
+ endpoint: str = "/ask",
350
  blocked: bool = False,
351
  result: object | None = None,
352
  output_verdict_data: dict | None = None,
 
362
  "request_id": request_id,
363
  "session_id": body.session_id,
364
  "client_ip": audit_logger.hash_ip(client_ip),
365
+ "endpoint": endpoint,
366
  "input_query": body.question,
367
  "injection_verdict": injection_verdict,
368
  }
369
 
370
  if blocked:
371
  record["blocked"] = True
372
+ else:
373
+ if result is not None:
374
+ record.update({
375
+ "retrieved_chunks": [s.source for s in getattr(result, "sources", [])],
376
+ "llm_provider": getattr(result, "provider", ""),
377
+ "llm_model": getattr(result, "model", ""),
378
+ "output_tokens": getattr(result, "usage", None) and result.usage.output_tokens,
379
+ "grounded_refusal": not bool(getattr(result, "sources", [])),
380
+ "response_latency_ms": getattr(result, "latency_ms", 0),
381
+ })
382
+ if output_verdict_data is not None:
383
+ record["output_validation"] = output_verdict_data
384
 
385
  audit_logger.log(record)
tests/test_security_integration.py CHANGED
@@ -143,7 +143,7 @@ class TestStreamInjectionBlocking:
143
  assert resp.status_code == 200
144
 
145
  @pytest.mark.asyncio
146
- async def test_stream_audit_written(self, tmp_path):
147
  app = _make_security_app(tmp_path)
148
  audit_path = tmp_path / "audit.jsonl"
149
  transport = ASGITransport(app=app)
@@ -157,6 +157,33 @@ class TestStreamInjectionBlocking:
157
  record = json.loads(audit_path.read_text().strip().split("\n")[0])
158
  assert "request_id" in record
159
  assert "injection_verdict" in record
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
 
162
  class TestAuditLogging:
 
143
  assert resp.status_code == 200
144
 
145
  @pytest.mark.asyncio
146
+ async def test_stream_audit_written_with_correct_endpoint(self, tmp_path):
147
  app = _make_security_app(tmp_path)
148
  audit_path = tmp_path / "audit.jsonl"
149
  transport = ASGITransport(app=app)
 
157
  record = json.loads(audit_path.read_text().strip().split("\n")[0])
158
  assert "request_id" in record
159
  assert "injection_verdict" in record
160
+ assert record["endpoint"] == "/ask/stream"
161
+ assert "output_validation" in record
162
+
163
+ @pytest.mark.asyncio
164
+ async def test_stream_output_validation_runs(self, tmp_path):
165
+ """Output containing PII should trigger output validation on stream."""
166
+ from unittest.mock import AsyncMock, patch
167
+ from agent_bench.core.types import TokenUsage
168
+ from agent_bench.serving.schemas import StreamEvent
169
+
170
+ app = _make_security_app(tmp_path)
171
+
172
+ # Mock the orchestrator to return PII in the streamed answer
173
+ async def fake_run_stream(**kwargs):
174
+ yield StreamEvent(type="sources", sources=[])
175
+ yield StreamEvent(type="chunk", content="Contact john@example.com for help.")
176
+ yield StreamEvent(type="done", metadata={"estimated_cost_usd": 0.0})
177
+
178
+ app.state.orchestrator.run_stream = fake_run_stream
179
+
180
+ transport = ASGITransport(app=app)
181
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
182
+ resp = await client.post("/ask/stream", json={
183
+ "question": "How do I contact support?",
184
+ })
185
+ # The response should contain the safety filter message
186
+ assert "[Output filtered for safety]" in resp.text
187
 
188
 
189
  class TestAuditLogging: