"""Tests for RAG pipeline: chunker, embedder, store, retriever, and reranker."""

from __future__ import annotations

import numpy as np
import pytest

from agent_bench.rag.chunker import Chunk, chunk_fixed, chunk_recursive, chunk_text
from agent_bench.rag.embedder import Embedder
from agent_bench.rag.reranker import CrossEncoderReranker
from agent_bench.rag.retriever import Retriever
from agent_bench.rag.store import HybridStore, SearchResult


# --- Chunker tests ---


class TestChunker:
    """Size limits, text coverage, source naming, and dispatch for the chunkers."""

    SAMPLE_TEXT = (
        "FastAPI is a modern web framework.\n\n"
        "It is based on standard Python type hints.\n\n"
        "Path parameters are declared in the URL path using curly braces. "
        "You can specify their types using annotations.\n\n"
        "Query parameters are parsed automatically from the query string. "
        "They support default values and optional types.\n\n"
        "Request bodies use Pydantic models for validation."
    )

    def test_recursive_within_size_limits(self):
        chunk_size = 100
        overlap = 64
        chunks = chunk_recursive(
            self.SAMPLE_TEXT, "test.md", chunk_size=chunk_size, chunk_overlap=overlap
        )
        for c in chunks:
            # Overlap prepend may push up to overlap chars beyond chunk_size.
            # NOTE(review): the extra +1 presumably allows for a joining
            # separator between overlap and chunk body — confirm against
            # chunk_recursive.
            assert len(c.content) <= chunk_size + overlap + 1, (
                f"Chunk too long: {len(c.content)} chars"
            )
        assert len(chunks) > 1

    def test_fixed_within_size_limits(self):
        chunks = chunk_fixed(self.SAMPLE_TEXT, "test.md", chunk_size=100, chunk_overlap=20)
        for c in chunks:
            assert len(c.content) <= 100
        assert len(chunks) > 1

    def test_recursive_preserves_text(self):
        """Every word in the source should appear in at least one chunk."""
        chunks = chunk_recursive(self.SAMPLE_TEXT, "test.md", chunk_size=200)
        all_words = set(self.SAMPLE_TEXT.split())
        chunk_words = set()
        for c in chunks:
            chunk_words.update(c.content.split())
        assert all_words.issubset(chunk_words)

    def test_fixed_preserves_text_coverage(self):
        """Every word in the source should appear in at least one chunk."""
        chunks = chunk_fixed(self.SAMPLE_TEXT, "test.md", chunk_size=100, chunk_overlap=20)
        all_words = set(self.SAMPLE_TEXT.split())
        chunk_words = set()
        for c in chunks:
            chunk_words.update(c.content.split())
        assert all_words.issubset(chunk_words)

    def test_chunk_source_is_bare_filename(self):
        chunks = chunk_text(self.SAMPLE_TEXT, "fastapi_intro.md", strategy="recursive")
        for c in chunks:
            assert c.source == "fastapi_intro.md"
            assert "/" not in c.source

    def test_chunk_text_dispatcher(self):
        rec = chunk_text(self.SAMPLE_TEXT, "t.md", strategy="recursive", chunk_size=200)
        fix = chunk_text(self.SAMPLE_TEXT, "t.md", strategy="fixed", chunk_size=200)
        assert all(c.metadata.get("strategy") == "recursive" for c in rec)
        assert all(c.metadata.get("strategy") == "fixed" for c in fix)

    def test_empty_text(self):
        assert chunk_recursive("", "empty.md") == []
        assert chunk_fixed("", "empty.md") == []


# --- Embedder tests ---


class TestEmbedder:
    """Shape, normalization, caching, and distinctness of embeddings."""

    def test_embed_produces_correct_shape(self, mock_embedder: Embedder):
        vec = mock_embedder.embed("test sentence")
        assert vec.shape == (384,)

    def test_embed_is_normalized(self, mock_embedder: Embedder):
        vec = mock_embedder.embed("test sentence")
        norm = np.linalg.norm(vec)
        assert norm == pytest.approx(1.0, abs=1e-5)

    def test_embed_batch_shape(self, mock_embedder: Embedder):
        vecs = mock_embedder.embed_batch(["sentence one", "sentence two", "sentence three"])
        assert vecs.shape == (3, 384)

    def test_cache_hit_skips_model(self, mock_embedding_model, tmp_path):
        """Second embed() call for same text should use cache, not model."""
        embedder = Embedder(model=mock_embedding_model, cache_dir=str(tmp_path))
        _ = embedder.embed("cache test")
        calls_after_first = mock_embedding_model.call_count
        _ = embedder.embed("cache test")
        assert mock_embedding_model.call_count == calls_after_first

    def test_different_texts_produce_different_embeddings(self, mock_embedder: Embedder):
        v1 = mock_embedder.embed("path parameters")
        v2 = mock_embedder.embed("query parameters")
        assert not np.allclose(v1, v2)


# --- Store tests ---


class TestHybridStore:
    """Semantic, keyword, and hybrid search plus persistence of HybridStore."""

    def test_add_and_semantic_search(self, test_store: HybridStore, mock_embedder: Embedder):
        """Semantic search returns relevant result for a known query."""
        query_vec = mock_embedder.embed("path parameters curly braces")
        results = test_store.search(
            query_embedding=query_vec,
            query_text="path parameters curly braces",
            top_k=3,
            strategy="semantic",
        )
        assert len(results) > 0
        assert all(isinstance(r, SearchResult) for r in results)
        # Should have scores and ranks
        assert results[0].rank == 1
        assert results[0].retrieval_strategy == "semantic"

    def test_keyword_search(self, test_store: HybridStore, mock_embedder: Embedder):
        """BM25 keyword search finds chunks with matching terms."""
        query_vec = mock_embedder.embed("Pydantic models validation")
        results = test_store.search(
            query_embedding=query_vec,
            query_text="Pydantic models validation",
            top_k=3,
            strategy="keyword",
        )
        assert len(results) > 0
        # Top result should be the request body chunk (mentions Pydantic)
        assert "Pydantic" in results[0].chunk.content

    def test_hybrid_returns_results_from_both(
        self, test_store: HybridStore, mock_embedder: Embedder
    ):
        """RRF hybrid search returns fused results with positive, descending scores.

        NOTE(review): this does not verify that both the dense and sparse
        retrievers actually contributed to the fused list — only that the
        fused output is well-formed.
        """
        query_vec = mock_embedder.embed("path parameters FastAPI")
        results = test_store.search(
            query_embedding=query_vec,
            query_text="path parameters FastAPI",
            top_k=5,
            strategy="hybrid",
        )
        assert len(results) > 0
        assert all(r.retrieval_strategy == "hybrid" for r in results)
        # RRF scores should be positive and sorted descending
        assert all(r.score > 0 for r in results)
        for prev, nxt in zip(results, results[1:]):
            assert prev.score >= nxt.score

    def test_empty_store(self):
        store = HybridStore(dimension=384)
        # Seeded generator keeps the test deterministic across runs.
        dummy_vec = np.random.default_rng(0).standard_normal(384).astype(np.float32)
        results = store.search(
            query_embedding=dummy_vec, query_text="test", top_k=5, strategy="hybrid"
        )
        assert results == []

    def test_save_load_roundtrip(self, test_store: HybridStore, mock_embedder: Embedder, tmp_path):
        """Save and load preserves all data and produces same search results."""
        store_path = tmp_path / "test_store"

        # Search before save
        query_vec = mock_embedder.embed("path parameters")
        results_before = test_store.search(
            query_embedding=query_vec,
            query_text="path parameters",
            top_k=3,
            strategy="hybrid",
        )

        # Save and reload
        test_store.save(store_path)
        loaded = HybridStore.load(store_path, rrf_k=60)

        # Stats match
        assert loaded.stats().total_chunks == test_store.stats().total_chunks
        assert loaded.stats().faiss_index_size == test_store.stats().faiss_index_size

        # Search after load
        results_after = loaded.search(
            query_embedding=query_vec,
            query_text="path parameters",
            top_k=3,
            strategy="hybrid",
        )
        assert len(results_after) == len(results_before)
        assert [r.chunk.id for r in results_after] == [r.chunk.id for r in results_before]

    def test_stats(self, test_store: HybridStore):
        stats = test_store.stats()
        assert stats.total_chunks == 5
        assert stats.faiss_index_size == 5
        assert stats.unique_sources == 4  # 4 unique source files in sample chunks


# --- Retriever tests ---


class TestRetriever:
    """Async search through the Retriever facade."""

    @pytest.mark.asyncio
    async def test_search_returns_results(self, test_retriever: Retriever):
        result = await test_retriever.search("path parameters", top_k=3)
        assert len(result.results) > 0
        assert all(isinstance(r, SearchResult) for r in result.results)

    @pytest.mark.asyncio
    async def test_search_strategy_override(self, test_retriever: Retriever):
        result = await test_retriever.search("Pydantic models", top_k=3, strategy="keyword")
        assert len(result.results) > 0
        assert all(r.retrieval_strategy == "keyword" for r in result.results)


# --- Reranker tests ---


class MockCrossEncoder:
    """Mock cross-encoder that returns deterministic scores based on content length."""

    def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
        # Score based on content length — longer content scores higher
        # This gives a deterministic, predictable reordering
        return [float(len(content)) for _, content in pairs]


class TestCrossEncoderReranker:
    """Reordering, truncation, and score preservation of CrossEncoderReranker."""

    def _make_chunks(self, contents: list[str]) -> list[Chunk]:
        """Build one Chunk per content string, each from a distinct source file."""
        return [
            Chunk(id=f"c{i}", content=c, source=f"doc_{i}.md", chunk_index=0)
            for i, c in enumerate(contents)
        ]

    def test_reranker_reorders(self):
        """Reranker reorders chunks by cross-encoder score."""
        chunks = self._make_chunks(["short", "a medium length chunk", "longest chunk content here"])
        reranker = CrossEncoderReranker(model=MockCrossEncoder())
        result = reranker.rerank("test query", chunks, top_k=3)
        # MockCrossEncoder scores by content length, so longest first
        assert result[0][0].content == "longest chunk content here"
        assert result[1][0].content == "a medium length chunk"
        assert result[2][0].content == "short"

    def test_reranker_top_k(self):
        """Reranker returns exactly top_k results from a larger input."""
        chunks = self._make_chunks([f"content {i}" for i in range(20)])
        reranker = CrossEncoderReranker(model=MockCrossEncoder())
        result = reranker.rerank("test query", chunks, top_k=5)
        assert len(result) == 5

    @pytest.mark.asyncio
    async def test_reranker_disabled(self, mock_embedder: Embedder, test_store: HybridStore):
        """Retriever without reranker preserves RRF order."""
        retriever_no_reranker = Retriever(embedder=mock_embedder, store=test_store)
        retriever_with_none = Retriever(
            embedder=mock_embedder,
            store=test_store,
            reranker=None,
        )
        # Await directly (like the other async tests) instead of driving the
        # coroutine via asyncio.get_event_loop(), which is deprecated when no
        # loop is current and breaks once pytest-asyncio has closed its loop.
        results_a = await retriever_no_reranker.search("path parameters", top_k=3)
        results_b = await retriever_with_none.search("path parameters", top_k=3)
        # Guard against a vacuous pass where both result lists are empty.
        assert len(results_a.results) > 0
        assert [r.chunk.id for r in results_a.results] == [r.chunk.id for r in results_b.results]

    def test_reranker_empty_input(self):
        """Empty chunk list returns empty list."""
        reranker = CrossEncoderReranker(model=MockCrossEncoder())
        result = reranker.rerank("test query", [], top_k=5)
        assert result == []

    @pytest.mark.asyncio
    async def test_reranked_results_preserve_rrf_scores(
        self,
        mock_embedder: Embedder,
        test_store: HybridStore,
    ):
        """Reranked results carry original RRF scores, not 0.0.

        This is critical: the refusal gate in SearchTool checks max_score
        from the returned results. If reranking zeroes out scores, the
        refusal gate would reject every reranked query.
        """
        reranker = CrossEncoderReranker(model=MockCrossEncoder())
        retriever = Retriever(
            embedder=mock_embedder,
            store=test_store,
            reranker=reranker,
            reranker_top_k=3,
        )
        result = await retriever.search("path parameters", top_k=3)
        assert len(result.results) > 0
        # All scores must be positive (preserved from RRF), not 0.0
        scores = [r.score for r in result.results]
        assert all(s > 0 for s in scores), (
            f"Reranked scores should be positive RRF scores, got: {scores}"
        )

    @pytest.mark.asyncio
    async def test_refusal_with_reranker_enabled(self):
        """Integration: out-of-scope query with reranker on still refuses.

        The refusal gate fires on RRF max_score BEFORE reranking (go/no-go
        decision). This test validates the Feature 1 + Feature 2 interaction.

        NOTE(review): the MockRetriever below is constructed without any
        reranker, so this currently exercises the low-score refusal path
        only — confirm whether MockRetriever can be given a reranker.
        """
        from agent_bench.tools.search import SearchTool
        from tests.test_tools import MockChunk, MockRetriever, MockSearchResult

        # Low scores — should trigger refusal regardless of reranker
        low_score_results = [
            MockSearchResult(
                chunk=MockChunk(content="Unrelated content", source="irrelevant.md"),
                score=0.005,
            ),
        ]
        retriever = MockRetriever(results=low_score_results)
        tool = SearchTool(retriever=retriever, refusal_threshold=0.02)
        result = await tool.execute(query="how to cook pasta")
        assert result.metadata["refused"] is True
        assert "No relevant documents found" in result.result