Spaces:
Sleeping
Sleeping
feat: langchain retriever wrapper over existing async hybrid retriever
Browse filesCo-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
agent_bench/langchain_baseline/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""LangChain baseline: tool-calling agent for framework comparison."""
|
agent_bench/langchain_baseline/retriever.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LangChain BaseRetriever wrapping agent-bench's async hybrid retriever."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
from typing import TYPE_CHECKING, Any, List
|
| 7 |
+
|
| 8 |
+
from langchain_core.callbacks import (
|
| 9 |
+
AsyncCallbackManagerForRetrieverRun,
|
| 10 |
+
CallbackManagerForRetrieverRun,
|
| 11 |
+
)
|
| 12 |
+
from langchain_core.documents import Document as LCDocument
|
| 13 |
+
from langchain_core.retrievers import BaseRetriever
|
| 14 |
+
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
from agent_bench.rag.retriever import Retriever
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class AgentBenchRetriever(BaseRetriever):
|
| 20 |
+
"""Wraps agent-bench's async Retriever as a LangChain retriever.
|
| 21 |
+
|
| 22 |
+
Delegates to Retriever.search() which returns list[SearchResult].
|
| 23 |
+
Each SearchResult has .chunk.content, .chunk.source, .chunk.id, .score.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
retriever: Any # agent_bench.rag.retriever.Retriever (Pydantic can't validate it)
|
| 27 |
+
top_k: int = 5
|
| 28 |
+
|
| 29 |
+
model_config = {"arbitrary_types_allowed": True}
|
| 30 |
+
|
| 31 |
+
async def _aget_relevant_documents(
|
| 32 |
+
self,
|
| 33 |
+
query: str,
|
| 34 |
+
*,
|
| 35 |
+
run_manager: AsyncCallbackManagerForRetrieverRun,
|
| 36 |
+
) -> List[LCDocument]:
|
| 37 |
+
results = await self.retriever.search(query, top_k=self.top_k)
|
| 38 |
+
return [
|
| 39 |
+
LCDocument(
|
| 40 |
+
page_content=r.chunk.content,
|
| 41 |
+
metadata={
|
| 42 |
+
"source": r.chunk.source,
|
| 43 |
+
"chunk_id": r.chunk.id,
|
| 44 |
+
"score": r.score,
|
| 45 |
+
},
|
| 46 |
+
)
|
| 47 |
+
for r in results
|
| 48 |
+
]
|
| 49 |
+
|
| 50 |
+
def _get_relevant_documents(
|
| 51 |
+
self,
|
| 52 |
+
query: str,
|
| 53 |
+
*,
|
| 54 |
+
run_manager: CallbackManagerForRetrieverRun,
|
| 55 |
+
) -> List[LCDocument]:
|
| 56 |
+
"""Sync fallback: runs async implementation in a new event loop thread."""
|
| 57 |
+
loop = asyncio.new_event_loop()
|
| 58 |
+
try:
|
| 59 |
+
return loop.run_until_complete(
|
| 60 |
+
self._aget_relevant_documents(
|
| 61 |
+
query,
|
| 62 |
+
run_manager=AsyncCallbackManagerForRetrieverRun.get_noop_manager(),
|
| 63 |
+
)
|
| 64 |
+
)
|
| 65 |
+
finally:
|
| 66 |
+
loop.close()
|
tests/test_langchain_baseline/__init__.py
ADDED
|
File without changes
|
tests/test_langchain_baseline/test_retriever.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for LangChain retriever wrapper around agent-bench's async Retriever."""
|
| 2 |
+
|
| 3 |
+
from unittest.mock import AsyncMock, MagicMock
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from agent_bench.langchain_baseline.retriever import AgentBenchRetriever
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _make_mock_retriever(results=None):
|
| 11 |
+
"""Create a mock of agent_bench.rag.retriever.Retriever."""
|
| 12 |
+
retriever = MagicMock()
|
| 13 |
+
if results is None:
|
| 14 |
+
# Default: one result with known fields
|
| 15 |
+
result = MagicMock()
|
| 16 |
+
result.chunk.content = "Path parameters use curly braces."
|
| 17 |
+
result.chunk.source = "fastapi_path_params.md"
|
| 18 |
+
result.chunk.id = "chunk_001"
|
| 19 |
+
result.score = 0.85
|
| 20 |
+
result.rank = 1
|
| 21 |
+
results = [result]
|
| 22 |
+
retriever.search = AsyncMock(return_value=results)
|
| 23 |
+
return retriever
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
async def test_returns_langchain_documents():
|
| 27 |
+
mock_ret = _make_mock_retriever()
|
| 28 |
+
wrapper = AgentBenchRetriever(retriever=mock_ret, top_k=5)
|
| 29 |
+
docs = await wrapper.ainvoke("path parameters")
|
| 30 |
+
|
| 31 |
+
assert len(docs) == 1
|
| 32 |
+
assert docs[0].page_content == "Path parameters use curly braces."
|
| 33 |
+
assert docs[0].metadata["source"] == "fastapi_path_params.md"
|
| 34 |
+
assert docs[0].metadata["chunk_id"] == "chunk_001"
|
| 35 |
+
assert docs[0].metadata["score"] == 0.85
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
async def test_passes_top_k_to_underlying_retriever():
|
| 39 |
+
mock_ret = _make_mock_retriever()
|
| 40 |
+
wrapper = AgentBenchRetriever(retriever=mock_ret, top_k=3)
|
| 41 |
+
await wrapper.ainvoke("test")
|
| 42 |
+
mock_ret.search.assert_called_once_with("test", top_k=3)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
async def test_handles_empty_results():
|
| 46 |
+
mock_ret = _make_mock_retriever(results=[])
|
| 47 |
+
wrapper = AgentBenchRetriever(retriever=mock_ret, top_k=5)
|
| 48 |
+
docs = await wrapper.ainvoke("nonsense")
|
| 49 |
+
assert docs == []
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
async def test_multiple_results_preserve_order():
|
| 53 |
+
r1 = MagicMock()
|
| 54 |
+
r1.chunk.content = "First"
|
| 55 |
+
r1.chunk.source = "a.md"
|
| 56 |
+
r1.chunk.id = "c1"
|
| 57 |
+
r1.score = 0.9
|
| 58 |
+
|
| 59 |
+
r2 = MagicMock()
|
| 60 |
+
r2.chunk.content = "Second"
|
| 61 |
+
r2.chunk.source = "b.md"
|
| 62 |
+
r2.chunk.id = "c2"
|
| 63 |
+
r2.score = 0.7
|
| 64 |
+
|
| 65 |
+
mock_ret = _make_mock_retriever(results=[r1, r2])
|
| 66 |
+
wrapper = AgentBenchRetriever(retriever=mock_ret, top_k=5)
|
| 67 |
+
docs = await wrapper.ainvoke("test")
|
| 68 |
+
|
| 69 |
+
assert len(docs) == 2
|
| 70 |
+
assert docs[0].page_content == "First"
|
| 71 |
+
assert docs[1].page_content == "Second"
|