Nomearod Claude Opus 4.6 (1M context) committed on
Commit
f5d9df4
·
1 Parent(s): 3542b5b

feat: langchain retriever wrapper over existing async hybrid retriever

Browse files

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

agent_bench/langchain_baseline/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """LangChain baseline: tool-calling agent for framework comparison."""
agent_bench/langchain_baseline/retriever.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangChain BaseRetriever wrapping agent-bench's async hybrid retriever."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ from typing import TYPE_CHECKING, Any, List
7
+
8
+ from langchain_core.callbacks import (
9
+ AsyncCallbackManagerForRetrieverRun,
10
+ CallbackManagerForRetrieverRun,
11
+ )
12
+ from langchain_core.documents import Document as LCDocument
13
+ from langchain_core.retrievers import BaseRetriever
14
+
15
+ if TYPE_CHECKING:
16
+ from agent_bench.rag.retriever import Retriever
17
+
18
+
19
class AgentBenchRetriever(BaseRetriever):
    """Wraps agent-bench's async Retriever as a LangChain retriever.

    Delegates to Retriever.search() which returns list[SearchResult].
    Each SearchResult has .chunk.content, .chunk.source, .chunk.id, .score.

    Attributes are declared as Pydantic fields:
        retriever: the wrapped agent-bench retriever; must expose an awaitable
            ``search(query, top_k=...)``.
        top_k: number of results requested from the wrapped retriever.
    """

    retriever: Any  # agent_bench.rag.retriever.Retriever (Pydantic can't validate it)
    top_k: int = 5

    model_config = {"arbitrary_types_allowed": True}

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[LCDocument]:
        """Search the wrapped retriever and convert hits to LangChain documents.

        Args:
            query: the search string forwarded unchanged to ``retriever.search``.
            run_manager: LangChain callback manager; accepted for interface
                compatibility and not used by this implementation.

        Returns:
            One ``Document`` per search result, in the order returned by the
            wrapped retriever, with ``source``, ``chunk_id`` and ``score``
            stored in ``metadata``.
        """
        results = await self.retriever.search(query, top_k=self.top_k)
        return [
            LCDocument(
                page_content=r.chunk.content,
                metadata={
                    "source": r.chunk.source,
                    "chunk_id": r.chunk.id,
                    "score": r.score,
                },
            )
            for r in results
        ]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[LCDocument]:
        """Sync fallback: drive the async implementation on a private event loop.

        When the calling thread has no running event loop, the private loop is
        run directly in this thread. When a loop is already running here (for
        example a notebook, or a sync call made from inside async code),
        ``run_until_complete`` would raise ``RuntimeError``, so the private
        loop is run in a short-lived worker thread instead and this call blocks
        until it finishes.

        Args:
            query: the search string.
            run_manager: sync callback manager; not forwarded, a no-op async
                manager is handed to the async implementation instead.

        Returns:
            The documents produced by ``_aget_relevant_documents``.
        """

        def _run_on_private_loop() -> List[LCDocument]:
            loop = asyncio.new_event_loop()
            try:
                return loop.run_until_complete(
                    self._aget_relevant_documents(
                        query,
                        run_manager=AsyncCallbackManagerForRetrieverRun.get_noop_manager(),
                    )
                )
            finally:
                loop.close()

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: safe to run one right here.
            return _run_on_private_loop()

        # A loop is already running in this thread; a second one cannot be
        # started here, so hand the work to a worker thread and wait for it.
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(_run_on_private_loop).result()
tests/test_langchain_baseline/__init__.py ADDED
File without changes
tests/test_langchain_baseline/test_retriever.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for LangChain retriever wrapper around agent-bench's async Retriever."""
2
+
3
+ from unittest.mock import AsyncMock, MagicMock
4
+
5
+ import pytest
6
+
7
+ from agent_bench.langchain_baseline.retriever import AgentBenchRetriever
8
+
9
+
10
+ def _make_mock_retriever(results=None):
11
+ """Create a mock of agent_bench.rag.retriever.Retriever."""
12
+ retriever = MagicMock()
13
+ if results is None:
14
+ # Default: one result with known fields
15
+ result = MagicMock()
16
+ result.chunk.content = "Path parameters use curly braces."
17
+ result.chunk.source = "fastapi_path_params.md"
18
+ result.chunk.id = "chunk_001"
19
+ result.score = 0.85
20
+ result.rank = 1
21
+ results = [result]
22
+ retriever.search = AsyncMock(return_value=results)
23
+ return retriever
24
+
25
+
26
async def test_returns_langchain_documents():
    """The default mocked hit is returned as one document with mapped fields."""
    # NOTE(review): no explicit asyncio marker here; presumably the project
    # runs async tests via its pytest configuration. Confirm.
    adapter = AgentBenchRetriever(retriever=_make_mock_retriever(), top_k=5)

    documents = await adapter.ainvoke("path parameters")

    assert len(documents) == 1
    first = documents[0]
    assert first.page_content == "Path parameters use curly braces."
    assert first.metadata["source"] == "fastapi_path_params.md"
    assert first.metadata["chunk_id"] == "chunk_001"
    assert first.metadata["score"] == 0.85
36
+
37
+
38
async def test_passes_top_k_to_underlying_retriever():
    """The wrapper's top_k is forwarded as a keyword to the wrapped search()."""
    backend = _make_mock_retriever()
    adapter = AgentBenchRetriever(retriever=backend, top_k=3)

    await adapter.ainvoke("test")

    backend.search.assert_called_once_with("test", top_k=3)
43
+
44
+
45
async def test_handles_empty_results():
    """An empty result list from the wrapped retriever yields no documents."""
    adapter = AgentBenchRetriever(
        retriever=_make_mock_retriever(results=[]),
        top_k=5,
    )

    documents = await adapter.ainvoke("nonsense")

    assert documents == []
50
+
51
+
52
async def test_multiple_results_preserve_order():
    """Documents come back in the same order as the wrapped search results."""
    hit_specs = (
        ("First", "a.md", "c1", 0.9),
        ("Second", "b.md", "c2", 0.7),
    )
    hits = []
    for content, source, chunk_id, score in hit_specs:
        hit = MagicMock()
        hit.chunk.content = content
        hit.chunk.source = source
        hit.chunk.id = chunk_id
        hit.score = score
        hits.append(hit)

    adapter = AgentBenchRetriever(
        retriever=_make_mock_retriever(results=hits),
        top_k=5,
    )
    documents = await adapter.ainvoke("test")

    assert len(documents) == 2
    assert documents[0].page_content == "First"
    assert documents[1].page_content == "Second"