policylens-rag-api / rag_engine /tests /test_retrieval.py
DEVJHAWAR11
sync: deploy to HuggingFace Space
5b7955a
"""Unit tests for retrieval components — no real embedder or Supabase calls."""
import unittest
from unittest.mock import MagicMock
from rag_engine.retrieval.query_preprocessor import QueryPreprocessor
from rag_engine.retrieval.context_builder import ContextBuilder
from rag_engine.retrieval.retriever import PolicyRetriever
# ====================================================================== #
# QueryPreprocessor
# ====================================================================== #
class TestQueryPreprocessor(unittest.TestCase):
def setUp(self):
self.p = QueryPreprocessor()
def test_preprocessor_strips_whitespace(self):
assert self.p.preprocess(" hello ") == "hello"
def test_preprocessor_adds_question_mark(self):
assert "?" in self.p.preprocess("is fire covered")
def test_preprocessor_no_double_question_mark(self):
result = self.p.preprocess("is fire covered?")
assert result.count("?") == 1
def test_extract_filters_policy_id(self):
filters = self.p.extract_filters("anything", policy_id="P1")
assert filters["policy_id"] == "P1"
def test_extract_filters_flood(self):
filters = self.p.extract_filters("flood damage", policy_id="P1")
assert filters["coverage_category"] == "flood"
def test_extract_filters_deductible(self):
filters = self.p.extract_filters("what is the deductible", "P1")
assert filters["deductible_related"] is True
# ====================================================================== #
# PolicyRetriever
# ====================================================================== #
class TestPolicyRetriever(unittest.TestCase):
def _make_mocks(self):
mock_store = MagicMock()
mock_embedder = MagicMock()
mock_embedder.embed_query.return_value = [0.1] * 768
return mock_store, mock_embedder
def test_retriever_calls_embed_and_search(self):
mock_store, mock_embedder = self._make_mocks()
mock_store.similarity_search.return_value = [
{"content": "flood excluded", "metadata": {}, "score": 0.9}
]
retriever = PolicyRetriever(mock_store, mock_embedder)
results = retriever.retrieve("flood damage", "POL-001", k=5)
mock_embedder.embed_query.assert_called_once()
mock_store.similarity_search.assert_called_once()
assert len(results) == 1
def test_retriever_fallback_on_empty(self):
mock_store, mock_embedder = self._make_mocks()
# First call returns empty → triggers fallback; second call returns 1 result
mock_store.similarity_search.side_effect = [
[],
[{"content": "fallback result", "metadata": {}, "score": 0.5}],
]
retriever = PolicyRetriever(mock_store, mock_embedder)
results = retriever.retrieve("flood damage", "POL-001")
assert mock_store.similarity_search.call_count == 2
# ====================================================================== #
# ContextBuilder
# ====================================================================== #
class TestContextBuilder(unittest.TestCase):
def test_context_builder_formats_correctly(self):
results = [
{
"content": "Flood is excluded",
"metadata": {
"section_name": "Exclusions",
"clause_type": "exclusion",
},
"score": 0.95,
}
]
cb = ContextBuilder()
context = cb.build(results)
assert "Source 1" in context
assert "Exclusions" in context
assert "Flood is excluded" in context
def test_context_builder_empty_input(self):
cb = ContextBuilder()
result = cb.build([])
assert result == "" or result.strip() == ""
if __name__ == "__main__":
unittest.main()