"""Tests for agent tools (hybrid_search, list_documents, fetch_document, search_within_document, multi_query_search, summarize_document).""" from unittest.mock import MagicMock import pytest from src.agent.tools import ToolResultStore, make_retrieval_tools, _merge_results, _format_results from src.models import DocumentChunk, QueryResult from src.retrieval.hybrid import HybridSearchResult # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _chunk(chunk_id: str = "c1", document_id: str = "doc.pdf", text: str = "text", page_number: int = 1, chunk_index: int = 0) -> DocumentChunk: return DocumentChunk( chunk_id=chunk_id, document_id=document_id, text=text, metadata={"page_number": page_number, "chunk_index": chunk_index}, ) def _qr(chunk_id: str = "c1", document_id: str = "doc.pdf", text: str = "text", score: float = 0.8, source: str = "hybrid", page_number: int = 1) -> QueryResult: return QueryResult( chunk=_chunk(chunk_id=chunk_id, document_id=document_id, text=text, page_number=page_number), score=score, source=source, ) def _hybrid_result(results: list[QueryResult]) -> HybridSearchResult: return HybridSearchResult( dense_results=results, sparse_results=results, fused_results=results, ) @pytest.fixture def components(): """Create mock retriever, reranker, vector_store, and store.""" retriever = MagicMock() reranker = MagicMock() vector_store = MagicMock() store = ToolResultStore() return retriever, reranker, vector_store, store # --------------------------------------------------------------------------- # Unit tests for helper functions # --------------------------------------------------------------------------- class TestMergeResults: def test_merge_empty(self) -> None: assert _merge_results([], []) == [] def test_merge_keeps_higher_score(self) -> None: old = [_qr(chunk_id="c1", score=0.5)] new = [_qr(chunk_id="c1", score=0.9)] merged = _merge_results(old, new) assert len(merged) == 1 assert merged[0].score == 0.9 def test_merge_keeps_old_if_higher(self) -> None: old = [_qr(chunk_id="c1", score=0.9)] new = [_qr(chunk_id="c1", score=0.5)] merged = _merge_results(old, new) assert merged[0].score == 0.9 def test_merge_combines_different_ids(self) -> None: old = [_qr(chunk_id="c1", score=0.5)] new = [_qr(chunk_id="c2", score=0.9)] merged = _merge_results(old, new) assert len(merged) == 2 assert merged[0].chunk.chunk_id == "c2" # higher score first def test_merge_sorted_descending(self) -> None: results = [_qr(chunk_id=f"c{i}", score=s) for i, s in enumerate([0.3, 0.9, 0.6])] merged = _merge_results([], results) scores = [r.score for r in merged] assert scores == sorted(scores, reverse=True) class TestFormatResults: def test_empty_returns_no_results_message(self) -> None: result = _format_results([]) assert "Ingen relevante" in result def test_includes_document_id_and_score(self) -> None: results = [_qr(document_id="policy.pdf", score=0.85)] text = _format_results(results) assert "policy.pdf" in text assert "0.850" in text def test_includes_page_number(self) -> None: results = [_qr(page_number=5)] text = _format_results(results) assert "side 5" in text # --------------------------------------------------------------------------- # hybrid_search # --------------------------------------------------------------------------- class TestHybridSearch: def test_returns_formatted_results(self, components) -> None: retriever, reranker, vector_store, store = components results = [_qr(document_id="a.pdf", score=0.9, text="answer")] retriever.search_detailed.return_value = _hybrid_result(results) reranker.rerank.return_value = results tools = make_retrieval_tools(retriever, reranker, vector_store, store) hybrid_search = tools[0] output = hybrid_search.invoke({"query": "test", "top_k": 5}) assert "a.pdf" in output assert "answer" in output retriever.search_detailed.assert_called_once_with("test", top_k=5) def test_accumulates_in_store(self, components) -> None: retriever, reranker, vector_store, store = components results = [_qr(chunk_id="c1", score=0.8)] retriever.search_detailed.return_value = _hybrid_result(results) reranker.rerank.return_value = results tools = make_retrieval_tools(retriever, reranker, vector_store, store) tools[0].invoke({"query": "q1"}) assert len(store.retrieved) == 1 assert store.retrieved[0].chunk.chunk_id == "c1" assert len(store.tool_calls) == 1 assert store.tool_calls[0] == ("hybrid_search", "q1") def test_no_results(self, components) -> None: retriever, reranker, vector_store, store = components retriever.search_detailed.return_value = _hybrid_result([]) reranker.rerank.return_value = [] tools = make_retrieval_tools(retriever, reranker, vector_store, store) output = tools[0].invoke({"query": "nothing"}) assert "Ingen relevante" in output # --------------------------------------------------------------------------- # list_documents # --------------------------------------------------------------------------- class TestListDocuments: def test_returns_document_list(self, components) -> None: retriever, reranker, vector_store, store = components vector_store.list_document_ids.return_value = ["a.pdf", "b.pdf"] tools = make_retrieval_tools(retriever, reranker, vector_store, store) list_docs = tools[1] output = list_docs.invoke({}) assert "a.pdf" in output assert "b.pdf" in output assert "2 i alt" in output def test_empty_knowledge_base(self, components) -> None: retriever, reranker, vector_store, store = components vector_store.list_document_ids.return_value = [] tools = make_retrieval_tools(retriever, reranker, vector_store, store) output = tools[1].invoke({}) assert "empty" in output.lower() or "Ingen" in output # --------------------------------------------------------------------------- # fetch_document # --------------------------------------------------------------------------- class TestFetchDocument: def test_returns_full_text(self, components) -> None: retriever, reranker, vector_store, store = components chunks = [_chunk(chunk_id="c1", text="page1"), _chunk(chunk_id="c2", text="page2")] vector_store.get_chunks_by_document_id.return_value = chunks tools = make_retrieval_tools(retriever, reranker, vector_store, store) fetch = tools[2] output = fetch.invoke({"document_id": "doc.pdf"}) assert "page1" in output assert "page2" in output assert len(store.retrieved) == 2 def test_document_not_found(self, components) -> None: retriever, reranker, vector_store, store = components vector_store.get_chunks_by_document_id.return_value = [] tools = make_retrieval_tools(retriever, reranker, vector_store, store) output = tools[2].invoke({"document_id": "missing.pdf"}) assert "ikke fundet" in output # --------------------------------------------------------------------------- # search_within_document # --------------------------------------------------------------------------- class TestSearchWithinDocument: def test_reranks_document_chunks(self, components) -> None: retriever, reranker, vector_store, store = components chunks = [ _chunk(chunk_id="c1", text="irrelevant"), _chunk(chunk_id="c2", text="relevant answer"), ] vector_store.get_chunks_by_document_id.return_value = chunks reranker.rerank.return_value = [_qr(chunk_id="c2", text="relevant answer", score=0.95)] tools = make_retrieval_tools(retriever, reranker, vector_store, store) search_within = tools[3] output = search_within.invoke({"document_id": "doc.pdf", "query": "answer"}) assert "relevant answer" in output assert "0.950" in output reranker.rerank.assert_called_once() # Verify it passed all chunks to reranker candidates = reranker.rerank.call_args[0][1] assert len(candidates) == 2 def test_document_not_found(self, components) -> None: retriever, reranker, vector_store, store = components vector_store.get_chunks_by_document_id.return_value = [] tools = make_retrieval_tools(retriever, reranker, vector_store, store) output = tools[3].invoke({"document_id": "missing.pdf", "query": "test"}) assert "ikke fundet" in output def test_accumulates_in_store(self, components) -> None: retriever, reranker, vector_store, store = components chunks = [_chunk(chunk_id="c1")] vector_store.get_chunks_by_document_id.return_value = chunks reranker.rerank.return_value = [_qr(chunk_id="c1", score=0.7)] tools = make_retrieval_tools(retriever, reranker, vector_store, store) tools[3].invoke({"document_id": "doc.pdf", "query": "q"}) assert len(store.retrieved) == 1 assert store.tool_calls[-1][0] == "search_within_document" # --------------------------------------------------------------------------- # multi_query_search (requires llm_chain) # --------------------------------------------------------------------------- class TestMultiQuerySearch: def test_decomposes_and_searches(self, components) -> None: retriever, reranker, vector_store, store = components llm_chain = MagicMock() # LLM returns 2 sub-queries llm_chain.invoke.return_value = "eksamenregler bachelor\neksamensregler kandidat" results_a = [_qr(chunk_id="c1", score=0.9, text="bachelor exam")] results_b = [_qr(chunk_id="c2", score=0.85, text="master exam")] retriever.search_detailed.side_effect = [ _hybrid_result(results_a), _hybrid_result(results_b), ] reranker.rerank.side_effect = [results_a, results_b] tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain) multi_search = tools[4] output = multi_search.invoke({"question": "Compare exam rules"}) assert "delforespørgsler" in output assert retriever.search_detailed.call_count == 2 assert reranker.rerank.call_count == 2 assert len(store.retrieved) == 2 def test_fallback_when_decompose_fails(self, components) -> None: retriever, reranker, vector_store, store = components llm_chain = MagicMock() # LLM returns empty/garbage llm_chain.invoke.return_value = "" results = [_qr(chunk_id="c1", score=0.8)] retriever.search_detailed.return_value = _hybrid_result(results) reranker.rerank.return_value = results tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain) output = tools[4].invoke({"question": "original question"}) # Should fall back to the original question as single query assert retriever.search_detailed.call_count == 1 assert "0.800" in output def test_not_available_without_llm(self, components) -> None: retriever, reranker, vector_store, store = components tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=None) tool_names = [t.name for t in tools] assert "multi_query_search" not in tool_names assert "summarize_document" not in tool_names def test_deduplicates_across_sub_queries(self, components) -> None: retriever, reranker, vector_store, store = components llm_chain = MagicMock() llm_chain.invoke.return_value = "query1\nquery2" # Both sub-queries return the same chunk same_result = [_qr(chunk_id="c1", score=0.8)] retriever.search_detailed.return_value = _hybrid_result(same_result) reranker.rerank.return_value = same_result tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain) tools[4].invoke({"question": "test"}) # Should be deduplicated to 1 assert len(store.retrieved) == 1 # --------------------------------------------------------------------------- # summarize_document (requires llm_chain) # --------------------------------------------------------------------------- class TestSummarizeDocument: def test_generates_summary(self, components) -> None: retriever, reranker, vector_store, store = components llm_chain = MagicMock() llm_chain.invoke.return_value = "This document covers exam policies." chunks = [_chunk(chunk_id="c1", text="Exam rules...")] vector_store.get_chunks_by_document_id.return_value = chunks tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain) summarize = tools[5] output = summarize.invoke({"document_id": "exam.pdf"}) assert "Resumé af exam.pdf" in output assert "exam policies" in output llm_chain.invoke.assert_called_once() # Verify the prompt includes the document text prompt = llm_chain.invoke.call_args[0][0] assert "Exam rules" in prompt def test_document_not_found(self, components) -> None: retriever, reranker, vector_store, store = components llm_chain = MagicMock() vector_store.get_chunks_by_document_id.return_value = [] tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain) output = tools[5].invoke({"document_id": "missing.pdf"}) assert "ikke fundet" in output llm_chain.invoke.assert_not_called() def test_truncates_long_documents(self, components) -> None: retriever, reranker, vector_store, store = components llm_chain = MagicMock() llm_chain.invoke.return_value = "summary" # Create a document longer than 8000 chars long_text = "x" * 10000 chunks = [_chunk(chunk_id="c1", text=long_text)] vector_store.get_chunks_by_document_id.return_value = chunks tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain) tools[5].invoke({"document_id": "long.pdf"}) prompt = llm_chain.invoke.call_args[0][0] assert "forkortet" in prompt def test_registers_chunks_as_sources(self, components) -> None: retriever, reranker, vector_store, store = components llm_chain = MagicMock() llm_chain.invoke.return_value = "summary" chunks = [_chunk(chunk_id="c1"), _chunk(chunk_id="c2")] vector_store.get_chunks_by_document_id.return_value = chunks tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain) tools[5].invoke({"document_id": "doc.pdf"}) assert len(store.retrieved) == 2 assert store.tool_calls[-1] == ("summarize_document", "doc.pdf")