Spaces:
Sleeping
Sleeping
| """Unit tests for retrieval components — no real embedder or Supabase calls.""" | |
| import unittest | |
| from unittest.mock import MagicMock | |
| from rag_engine.retrieval.query_preprocessor import QueryPreprocessor | |
| from rag_engine.retrieval.context_builder import ContextBuilder | |
| from rag_engine.retrieval.retriever import PolicyRetriever | |
| # ====================================================================== # | |
| # QueryPreprocessor | |
| # ====================================================================== # | |
class TestQueryPreprocessor(unittest.TestCase):
    """Tests for ``QueryPreprocessor.preprocess`` and ``extract_filters``.

    Uses ``unittest`` assertion methods rather than bare ``assert`` so the
    checks still run under ``python -O`` and failures report both values.
    """

    def setUp(self):
        # A fresh preprocessor for every test method.
        self.p = QueryPreprocessor()

    def test_preprocessor_strips_whitespace(self):
        """Leading and trailing whitespace is removed from the query."""
        self.assertEqual(self.p.preprocess(" hello "), "hello")

    def test_preprocessor_adds_question_mark(self):
        """A question without a trailing '?' comes back containing one."""
        self.assertIn("?", self.p.preprocess("is fire covered"))

    def test_preprocessor_no_double_question_mark(self):
        """A query that already ends in '?' keeps exactly one."""
        result = self.p.preprocess("is fire covered?")
        self.assertEqual(result.count("?"), 1)

    def test_extract_filters_policy_id(self):
        """The supplied policy id is echoed back in the filters."""
        filters = self.p.extract_filters("anything", policy_id="P1")
        self.assertEqual(filters["policy_id"], "P1")

    def test_extract_filters_flood(self):
        """A query mentioning flood yields the 'flood' coverage category."""
        filters = self.p.extract_filters("flood damage", policy_id="P1")
        self.assertEqual(filters["coverage_category"], "flood")

    def test_extract_filters_deductible(self):
        """A query mentioning the deductible sets the deductible flag."""
        # The policy id is passed positionally here, unlike the tests above.
        filters = self.p.extract_filters("what is the deductible", "P1")
        # Identity check: the flag must be the bool True, not merely truthy.
        self.assertIs(filters["deductible_related"], True)
| # ====================================================================== # | |
| # PolicyRetriever | |
| # ====================================================================== # | |
class TestPolicyRetriever(unittest.TestCase):
    """Tests for ``PolicyRetriever.retrieve`` using mocked store and embedder.

    Uses ``unittest`` assertion methods rather than bare ``assert`` so the
    checks still run under ``python -O`` and failures report both values.
    """

    def _make_mocks(self):
        """Return a ``(store, embedder)`` pair of ``MagicMock`` objects.

        The embedder's ``embed_query`` is stubbed to return a fixed
        768-element vector so no real embedding model is needed.
        """
        mock_store = MagicMock()
        mock_embedder = MagicMock()
        mock_embedder.embed_query.return_value = [0.1] * 768
        return mock_store, mock_embedder

    def test_retriever_calls_embed_and_search(self):
        """A retrieve call embeds once, searches once, and returns the hit."""
        mock_store, mock_embedder = self._make_mocks()
        mock_store.similarity_search.return_value = [
            {"content": "flood excluded", "metadata": {}, "score": 0.9}
        ]
        retriever = PolicyRetriever(mock_store, mock_embedder)
        results = retriever.retrieve("flood damage", "POL-001", k=5)
        mock_embedder.embed_query.assert_called_once()
        mock_store.similarity_search.assert_called_once()
        self.assertEqual(len(results), 1)

    def test_retriever_fallback_on_empty(self):
        """An empty first search result triggers a second search call."""
        mock_store, mock_embedder = self._make_mocks()
        # First call returns empty -> triggers fallback; second call returns
        # one result.
        mock_store.similarity_search.side_effect = [
            [],
            [{"content": "fallback result", "metadata": {}, "score": 0.5}],
        ]
        retriever = PolicyRetriever(mock_store, mock_embedder)
        # Only the number of search calls is checked, so the return value
        # is intentionally not bound.
        retriever.retrieve("flood damage", "POL-001")
        self.assertEqual(mock_store.similarity_search.call_count, 2)
| # ====================================================================== # | |
| # ContextBuilder | |
| # ====================================================================== # | |
class TestContextBuilder(unittest.TestCase):
    """Tests for ``ContextBuilder.build``.

    Uses ``unittest`` assertion methods rather than bare ``assert`` so the
    checks still run under ``python -O`` and failures report both values.
    """

    def test_context_builder_formats_correctly(self):
        """The built context names the source, its section, and its text."""
        results = [
            {
                "content": "Flood is excluded",
                "metadata": {
                    "section_name": "Exclusions",
                    "clause_type": "exclusion",
                },
                "score": 0.95,
            }
        ]
        cb = ContextBuilder()
        context = cb.build(results)
        self.assertIn("Source 1", context)
        self.assertIn("Exclusions", context)
        self.assertIn("Flood is excluded", context)

    def test_context_builder_empty_input(self):
        """An empty result list builds an empty or whitespace-only string."""
        cb = ContextBuilder()
        result = cb.build([])
        # strip() == "" already covers the exact-empty-string case.
        self.assertEqual(result.strip(), "")
# Run the suite when this file is executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()