Spaces:
Running
Running
fix: stream stage events live, thread source_chunks, fix LangChain wrapper
Browse files
- Stage events are now yielded to the client immediately instead of being
buffered. Only the answer chunk is held back for output validation.
- source_chunks are now threaded through the _orchestrator_done metadata and
passed to output_validator.validate() as retrieved_chunks, matching /ask behavior.
- LangChain AgentBenchRetriever updated to unwrap the RetrievalResult returned by Retriever.search() (it now iterates over its .results).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
agent_bench/agents/orchestrator.py
CHANGED
|
@@ -197,6 +197,7 @@ class Orchestrator:
|
|
| 197 |
messages.append(Message(role=Role.USER, content=question))
|
| 198 |
tools = self.registry.get_definitions()
|
| 199 |
all_sources: list[str] = []
|
|
|
|
| 200 |
total_cost = 0.0
|
| 201 |
total_input_tokens = 0
|
| 202 |
total_output_tokens = 0
|
|
@@ -275,6 +276,10 @@ class Orchestrator:
|
|
| 275 |
|
| 276 |
if "sources" in result.metadata:
|
| 277 |
all_sources.extend(result.metadata["sources"])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
|
| 279 |
# Max iterations hit — force text answer without tools
|
| 280 |
# (same pattern as run(): explicit call after loop)
|
|
@@ -314,6 +319,7 @@ class Orchestrator:
|
|
| 314 |
"tokens_in": total_input_tokens,
|
| 315 |
"tokens_out": total_output_tokens,
|
| 316 |
"iterations": iteration if iteration else 1,
|
|
|
|
| 317 |
},
|
| 318 |
)
|
| 319 |
|
|
|
|
| 197 |
messages.append(Message(role=Role.USER, content=question))
|
| 198 |
tools = self.registry.get_definitions()
|
| 199 |
all_sources: list[str] = []
|
| 200 |
+
all_source_chunks: list[str] = []
|
| 201 |
total_cost = 0.0
|
| 202 |
total_input_tokens = 0
|
| 203 |
total_output_tokens = 0
|
|
|
|
| 276 |
|
| 277 |
if "sources" in result.metadata:
|
| 278 |
all_sources.extend(result.metadata["sources"])
|
| 279 |
+
if "source_chunks" in result.metadata:
|
| 280 |
+
all_source_chunks.extend(
|
| 281 |
+
result.metadata["source_chunks"]
|
| 282 |
+
)
|
| 283 |
|
| 284 |
# Max iterations hit — force text answer without tools
|
| 285 |
# (same pattern as run(): explicit call after loop)
|
|
|
|
| 319 |
"tokens_in": total_input_tokens,
|
| 320 |
"tokens_out": total_output_tokens,
|
| 321 |
"iterations": iteration if iteration else 1,
|
| 322 |
+
"source_chunks": all_source_chunks,
|
| 323 |
},
|
| 324 |
)
|
| 325 |
|
agent_bench/langchain_baseline/retriever.py
CHANGED
|
@@ -17,7 +17,7 @@ from langchain_core.retrievers import BaseRetriever
|
|
| 17 |
class AgentBenchRetriever(BaseRetriever):
|
| 18 |
"""Wraps agent-bench's async Retriever as a LangChain retriever.
|
| 19 |
|
| 20 |
-
Delegates to Retriever.search() which returns
|
| 21 |
Each SearchResult has .chunk.content, .chunk.source, .chunk.id, .score.
|
| 22 |
"""
|
| 23 |
|
|
@@ -32,7 +32,7 @@ class AgentBenchRetriever(BaseRetriever):
|
|
| 32 |
*,
|
| 33 |
run_manager: AsyncCallbackManagerForRetrieverRun,
|
| 34 |
) -> List[LCDocument]:
|
| 35 |
-
|
| 36 |
return [
|
| 37 |
LCDocument(
|
| 38 |
page_content=r.chunk.content,
|
|
@@ -42,7 +42,7 @@ class AgentBenchRetriever(BaseRetriever):
|
|
| 42 |
"score": r.score,
|
| 43 |
},
|
| 44 |
)
|
| 45 |
-
for r in results
|
| 46 |
]
|
| 47 |
|
| 48 |
def _get_relevant_documents(
|
|
|
|
| 17 |
class AgentBenchRetriever(BaseRetriever):
|
| 18 |
"""Wraps agent-bench's async Retriever as a LangChain retriever.
|
| 19 |
|
| 20 |
+
Delegates to Retriever.search() which returns RetrievalResult.
|
| 21 |
Each SearchResult has .chunk.content, .chunk.source, .chunk.id, .score.
|
| 22 |
"""
|
| 23 |
|
|
|
|
| 32 |
*,
|
| 33 |
run_manager: AsyncCallbackManagerForRetrieverRun,
|
| 34 |
) -> List[LCDocument]:
|
| 35 |
+
retrieval_result = await self.retriever.search(query, top_k=self.top_k)
|
| 36 |
return [
|
| 37 |
LCDocument(
|
| 38 |
page_content=r.chunk.content,
|
|
|
|
| 42 |
"score": r.score,
|
| 43 |
},
|
| 44 |
)
|
| 45 |
+
for r in retrieval_result.results
|
| 46 |
]
|
| 47 |
|
| 48 |
def _get_relevant_documents(
|
agent_bench/serving/routes.py
CHANGED
|
@@ -250,8 +250,10 @@ async def ask_stream(body: AskRequest, request: Request) -> StreamingResponse:
|
|
| 250 |
"verdict": injection_verdict_data,
|
| 251 |
}).to_sse()
|
| 252 |
|
| 253 |
-
#
|
| 254 |
-
|
|
|
|
|
|
|
| 255 |
full_answer: list[str] = []
|
| 256 |
done_meta: dict = {}
|
| 257 |
async for event in orchestrator.run_stream(
|
|
@@ -261,21 +263,28 @@ async def ask_stream(body: AskRequest, request: Request) -> StreamingResponse:
|
|
| 261 |
strategy=body.retrieval_strategy,
|
| 262 |
history=history,
|
| 263 |
):
|
| 264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
if event.type == "chunk" and event.content:
|
| 266 |
full_answer.append(event.content)
|
| 267 |
-
|
| 268 |
-
|
|
|
|
|
|
|
| 269 |
|
| 270 |
# --- Security: output validation (post-generation, monitor mode) ---
|
| 271 |
answer_text = "".join(full_answer)
|
| 272 |
filtered_answer = answer_text
|
| 273 |
output_verdict_data: dict = {"passed": True, "violations": []}
|
| 274 |
output_blocked = False
|
|
|
|
| 275 |
if output_validator:
|
| 276 |
out_verdict = output_validator.validate(
|
| 277 |
output=answer_text,
|
| 278 |
-
retrieved_chunks=
|
| 279 |
)
|
| 280 |
output_verdict_data = {
|
| 281 |
"passed": out_verdict.passed,
|
|
@@ -288,15 +297,11 @@ async def ask_stream(body: AskRequest, request: Request) -> StreamingResponse:
|
|
| 288 |
"The output was filtered for safety."
|
| 289 |
)
|
| 290 |
|
| 291 |
-
# Yield
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
if
|
| 295 |
-
|
| 296 |
-
if output_blocked and event.type == "chunk":
|
| 297 |
-
yield StreamEvent(type="chunk", content=filtered_answer).to_sse()
|
| 298 |
-
else:
|
| 299 |
-
yield event.to_sse()
|
| 300 |
|
| 301 |
# --- Output validation stage (monitor mode, after chunk) ---
|
| 302 |
yield StreamEvent(type="stage", metadata={
|
|
@@ -320,7 +325,8 @@ async def ask_stream(body: AskRequest, request: Request) -> StreamingResponse:
|
|
| 320 |
}).to_sse()
|
| 321 |
|
| 322 |
# Record metrics and persist session
|
| 323 |
-
|
|
|
|
| 324 |
|
| 325 |
if body.session_id and conversation_store:
|
| 326 |
conversation_store.append(body.session_id, "user", body.question)
|
|
|
|
| 250 |
"verdict": injection_verdict_data,
|
| 251 |
}).to_sse()
|
| 252 |
|
| 253 |
+
# Stream orchestrator events live. Stage events are yielded
|
| 254 |
+
# immediately so the dashboard can animate in real time.
|
| 255 |
+
# Only the chunk content is accumulated for post-stream
|
| 256 |
+
# output validation (monitor mode).
|
| 257 |
full_answer: list[str] = []
|
| 258 |
done_meta: dict = {}
|
| 259 |
async for event in orchestrator.run_stream(
|
|
|
|
| 263 |
strategy=body.retrieval_strategy,
|
| 264 |
history=history,
|
| 265 |
):
|
| 266 |
+
if event.type == "_orchestrator_done":
|
| 267 |
+
# Extract metadata, don't yield to client
|
| 268 |
+
if event.metadata:
|
| 269 |
+
done_meta = event.metadata
|
| 270 |
+
continue
|
| 271 |
if event.type == "chunk" and event.content:
|
| 272 |
full_answer.append(event.content)
|
| 273 |
+
# Don't yield chunk yet — validate first
|
| 274 |
+
continue
|
| 275 |
+
# Yield stage and sources events live
|
| 276 |
+
yield event.to_sse()
|
| 277 |
|
| 278 |
# --- Security: output validation (post-generation, monitor mode) ---
|
| 279 |
answer_text = "".join(full_answer)
|
| 280 |
filtered_answer = answer_text
|
| 281 |
output_verdict_data: dict = {"passed": True, "violations": []}
|
| 282 |
output_blocked = False
|
| 283 |
+
source_chunks = done_meta.get("source_chunks", [])
|
| 284 |
if output_validator:
|
| 285 |
out_verdict = output_validator.validate(
|
| 286 |
output=answer_text,
|
| 287 |
+
retrieved_chunks=source_chunks,
|
| 288 |
)
|
| 289 |
output_verdict_data = {
|
| 290 |
"passed": out_verdict.passed,
|
|
|
|
| 297 |
"The output was filtered for safety."
|
| 298 |
)
|
| 299 |
|
| 300 |
+
# Yield the (possibly filtered) answer chunk
|
| 301 |
+
yield StreamEvent(
|
| 302 |
+
type="chunk",
|
| 303 |
+
content=filtered_answer if output_blocked else answer_text,
|
| 304 |
+
).to_sse()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
|
| 306 |
# --- Output validation stage (monitor mode, after chunk) ---
|
| 307 |
yield StreamEvent(type="stage", metadata={
|
|
|
|
| 325 |
}).to_sse()
|
| 326 |
|
| 327 |
# Record metrics and persist session
|
| 328 |
+
cost = done_meta.get("estimated_cost_usd", 0.0)
|
| 329 |
+
metrics.record(latency_ms=latency_ms, cost_usd=cost)
|
| 330 |
|
| 331 |
if body.session_id and conversation_store:
|
| 332 |
conversation_store.append(body.session_id, "user", body.question)
|
tests/test_langchain_baseline/test_retriever.py
CHANGED
|
@@ -5,6 +5,14 @@ from unittest.mock import AsyncMock, MagicMock
|
|
| 5 |
from agent_bench.langchain_baseline.retriever import AgentBenchRetriever
|
| 6 |
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
def _make_mock_retriever(results=None):
|
| 9 |
"""Create a mock of agent_bench.rag.retriever.Retriever."""
|
| 10 |
retriever = MagicMock()
|
|
@@ -17,7 +25,9 @@ def _make_mock_retriever(results=None):
|
|
| 17 |
result.score = 0.85
|
| 18 |
result.rank = 1
|
| 19 |
results = [result]
|
| 20 |
-
retriever.search = AsyncMock(
|
|
|
|
|
|
|
| 21 |
return retriever
|
| 22 |
|
| 23 |
|
|
|
|
| 5 |
from agent_bench.langchain_baseline.retriever import AgentBenchRetriever
|
| 6 |
|
| 7 |
|
| 8 |
+
def _make_retrieval_result(results):
|
| 9 |
+
"""Wrap a list of mock SearchResults in a RetrievalResult-like object."""
|
| 10 |
+
rr = MagicMock()
|
| 11 |
+
rr.results = results
|
| 12 |
+
rr.pre_rerank_count = 0
|
| 13 |
+
return rr
|
| 14 |
+
|
| 15 |
+
|
| 16 |
def _make_mock_retriever(results=None):
|
| 17 |
"""Create a mock of agent_bench.rag.retriever.Retriever."""
|
| 18 |
retriever = MagicMock()
|
|
|
|
| 25 |
result.score = 0.85
|
| 26 |
result.rank = 1
|
| 27 |
results = [result]
|
| 28 |
+
retriever.search = AsyncMock(
|
| 29 |
+
return_value=_make_retrieval_result(results),
|
| 30 |
+
)
|
| 31 |
return retriever
|
| 32 |
|
| 33 |
|