Nomearod Claude Opus 4.6 (1M context) commited on
Commit
77c4ed4
·
1 Parent(s): 91e4512

fix: stream stage events live, thread source_chunks, fix LangChain wrapper

Browse files

- Stage events now yield immediately to the client instead of being
buffered. Only the chunk is held back for output validation.
- source_chunks threaded through _orchestrator_done metadata and passed
to output_validator.validate(), matching /ask behavior.
- LangChain AgentBenchRetriever updated for RetrievalResult wrapper.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

agent_bench/agents/orchestrator.py CHANGED
@@ -197,6 +197,7 @@ class Orchestrator:
197
  messages.append(Message(role=Role.USER, content=question))
198
  tools = self.registry.get_definitions()
199
  all_sources: list[str] = []
 
200
  total_cost = 0.0
201
  total_input_tokens = 0
202
  total_output_tokens = 0
@@ -275,6 +276,10 @@ class Orchestrator:
275
 
276
  if "sources" in result.metadata:
277
  all_sources.extend(result.metadata["sources"])
 
 
 
 
278
 
279
  # Max iterations hit — force text answer without tools
280
  # (same pattern as run(): explicit call after loop)
@@ -314,6 +319,7 @@ class Orchestrator:
314
  "tokens_in": total_input_tokens,
315
  "tokens_out": total_output_tokens,
316
  "iterations": iteration if iteration else 1,
 
317
  },
318
  )
319
 
 
197
  messages.append(Message(role=Role.USER, content=question))
198
  tools = self.registry.get_definitions()
199
  all_sources: list[str] = []
200
+ all_source_chunks: list[str] = []
201
  total_cost = 0.0
202
  total_input_tokens = 0
203
  total_output_tokens = 0
 
276
 
277
  if "sources" in result.metadata:
278
  all_sources.extend(result.metadata["sources"])
279
+ if "source_chunks" in result.metadata:
280
+ all_source_chunks.extend(
281
+ result.metadata["source_chunks"]
282
+ )
283
 
284
  # Max iterations hit — force text answer without tools
285
  # (same pattern as run(): explicit call after loop)
 
319
  "tokens_in": total_input_tokens,
320
  "tokens_out": total_output_tokens,
321
  "iterations": iteration if iteration else 1,
322
+ "source_chunks": all_source_chunks,
323
  },
324
  )
325
 
agent_bench/langchain_baseline/retriever.py CHANGED
@@ -17,7 +17,7 @@ from langchain_core.retrievers import BaseRetriever
17
  class AgentBenchRetriever(BaseRetriever):
18
  """Wraps agent-bench's async Retriever as a LangChain retriever.
19
 
20
- Delegates to Retriever.search() which returns list[SearchResult].
21
  Each SearchResult has .chunk.content, .chunk.source, .chunk.id, .score.
22
  """
23
 
@@ -32,7 +32,7 @@ class AgentBenchRetriever(BaseRetriever):
32
  *,
33
  run_manager: AsyncCallbackManagerForRetrieverRun,
34
  ) -> List[LCDocument]:
35
- results = await self.retriever.search(query, top_k=self.top_k)
36
  return [
37
  LCDocument(
38
  page_content=r.chunk.content,
@@ -42,7 +42,7 @@ class AgentBenchRetriever(BaseRetriever):
42
  "score": r.score,
43
  },
44
  )
45
- for r in results
46
  ]
47
 
48
  def _get_relevant_documents(
 
17
  class AgentBenchRetriever(BaseRetriever):
18
  """Wraps agent-bench's async Retriever as a LangChain retriever.
19
 
20
+ Delegates to Retriever.search() which returns RetrievalResult.
21
  Each SearchResult has .chunk.content, .chunk.source, .chunk.id, .score.
22
  """
23
 
 
32
  *,
33
  run_manager: AsyncCallbackManagerForRetrieverRun,
34
  ) -> List[LCDocument]:
35
+ retrieval_result = await self.retriever.search(query, top_k=self.top_k)
36
  return [
37
  LCDocument(
38
  page_content=r.chunk.content,
 
42
  "score": r.score,
43
  },
44
  )
45
+ for r in retrieval_result.results
46
  ]
47
 
48
  def _get_relevant_documents(
agent_bench/serving/routes.py CHANGED
@@ -250,8 +250,10 @@ async def ask_stream(body: AskRequest, request: Request) -> StreamingResponse:
250
  "verdict": injection_verdict_data,
251
  }).to_sse()
252
 
253
- # Buffer orchestrator events for output validation
254
- buffered_events: list = []
 
 
255
  full_answer: list[str] = []
256
  done_meta: dict = {}
257
  async for event in orchestrator.run_stream(
@@ -261,21 +263,28 @@ async def ask_stream(body: AskRequest, request: Request) -> StreamingResponse:
261
  strategy=body.retrieval_strategy,
262
  history=history,
263
  ):
264
- buffered_events.append(event)
 
 
 
 
265
  if event.type == "chunk" and event.content:
266
  full_answer.append(event.content)
267
- if event.type == "_orchestrator_done" and event.metadata:
268
- done_meta = event.metadata
 
 
269
 
270
  # --- Security: output validation (post-generation, monitor mode) ---
271
  answer_text = "".join(full_answer)
272
  filtered_answer = answer_text
273
  output_verdict_data: dict = {"passed": True, "violations": []}
274
  output_blocked = False
 
275
  if output_validator:
276
  out_verdict = output_validator.validate(
277
  output=answer_text,
278
- retrieved_chunks=[],
279
  )
280
  output_verdict_data = {
281
  "passed": out_verdict.passed,
@@ -288,15 +297,11 @@ async def ask_stream(body: AskRequest, request: Request) -> StreamingResponse:
288
  "The output was filtered for safety."
289
  )
290
 
291
- # Yield buffered orchestrator events (stage events + legacy events)
292
- # Filter out _orchestrator_done — route handler emits the real done event
293
- for event in buffered_events:
294
- if event.type == "_orchestrator_done":
295
- continue
296
- if output_blocked and event.type == "chunk":
297
- yield StreamEvent(type="chunk", content=filtered_answer).to_sse()
298
- else:
299
- yield event.to_sse()
300
 
301
  # --- Output validation stage (monitor mode, after chunk) ---
302
  yield StreamEvent(type="stage", metadata={
@@ -320,7 +325,8 @@ async def ask_stream(body: AskRequest, request: Request) -> StreamingResponse:
320
  }).to_sse()
321
 
322
  # Record metrics and persist session
323
- metrics.record(latency_ms=latency_ms, cost_usd=done_meta.get("estimated_cost_usd", 0.0))
 
324
 
325
  if body.session_id and conversation_store:
326
  conversation_store.append(body.session_id, "user", body.question)
 
250
  "verdict": injection_verdict_data,
251
  }).to_sse()
252
 
253
+ # Stream orchestrator events live. Stage events are yielded
254
+ # immediately so the dashboard can animate in real time.
255
+ # Only the chunk content is accumulated for post-stream
256
+ # output validation (monitor mode).
257
  full_answer: list[str] = []
258
  done_meta: dict = {}
259
  async for event in orchestrator.run_stream(
 
263
  strategy=body.retrieval_strategy,
264
  history=history,
265
  ):
266
+ if event.type == "_orchestrator_done":
267
+ # Extract metadata, don't yield to client
268
+ if event.metadata:
269
+ done_meta = event.metadata
270
+ continue
271
  if event.type == "chunk" and event.content:
272
  full_answer.append(event.content)
273
+ # Don't yield chunk yet — validate first
274
+ continue
275
+ # Yield stage and sources events live
276
+ yield event.to_sse()
277
 
278
  # --- Security: output validation (post-generation, monitor mode) ---
279
  answer_text = "".join(full_answer)
280
  filtered_answer = answer_text
281
  output_verdict_data: dict = {"passed": True, "violations": []}
282
  output_blocked = False
283
+ source_chunks = done_meta.get("source_chunks", [])
284
  if output_validator:
285
  out_verdict = output_validator.validate(
286
  output=answer_text,
287
+ retrieved_chunks=source_chunks,
288
  )
289
  output_verdict_data = {
290
  "passed": out_verdict.passed,
 
297
  "The output was filtered for safety."
298
  )
299
 
300
+ # Yield the (possibly filtered) answer chunk
301
+ yield StreamEvent(
302
+ type="chunk",
303
+ content=filtered_answer if output_blocked else answer_text,
304
+ ).to_sse()
 
 
 
 
305
 
306
  # --- Output validation stage (monitor mode, after chunk) ---
307
  yield StreamEvent(type="stage", metadata={
 
325
  }).to_sse()
326
 
327
  # Record metrics and persist session
328
+ cost = done_meta.get("estimated_cost_usd", 0.0)
329
+ metrics.record(latency_ms=latency_ms, cost_usd=cost)
330
 
331
  if body.session_id and conversation_store:
332
  conversation_store.append(body.session_id, "user", body.question)
tests/test_langchain_baseline/test_retriever.py CHANGED
@@ -5,6 +5,14 @@ from unittest.mock import AsyncMock, MagicMock
5
  from agent_bench.langchain_baseline.retriever import AgentBenchRetriever
6
 
7
 
 
 
 
 
 
 
 
 
8
  def _make_mock_retriever(results=None):
9
  """Create a mock of agent_bench.rag.retriever.Retriever."""
10
  retriever = MagicMock()
@@ -17,7 +25,9 @@ def _make_mock_retriever(results=None):
17
  result.score = 0.85
18
  result.rank = 1
19
  results = [result]
20
- retriever.search = AsyncMock(return_value=results)
 
 
21
  return retriever
22
 
23
 
 
5
  from agent_bench.langchain_baseline.retriever import AgentBenchRetriever
6
 
7
 
8
+ def _make_retrieval_result(results):
9
+ """Wrap a list of mock SearchResults in a RetrievalResult-like object."""
10
+ rr = MagicMock()
11
+ rr.results = results
12
+ rr.pre_rerank_count = 0
13
+ return rr
14
+
15
+
16
  def _make_mock_retriever(results=None):
17
  """Create a mock of agent_bench.rag.retriever.Retriever."""
18
  retriever = MagicMock()
 
25
  result.score = 0.85
26
  result.rank = 1
27
  results = [result]
28
+ retriever.search = AsyncMock(
29
+ return_value=_make_retrieval_result(results),
30
+ )
31
  return retriever
32
 
33