"""Tests for EmbeddingServiceProtocol compliance.

TDD: check that the embedding service implementations (EmbeddingService and
LlamaIndexRAGService) conform to the EmbeddingServiceProtocol interface.
"""

import asyncio
from unittest.mock import patch

import pytest

# The service module under test references both `chromadb` and
# `SentenceTransformer` (the fixtures below patch them there), so skip the
# whole module when either optional backend is not installed.
for _required in ("chromadb", "sentence_transformers"):
    pytest.importorskip(_required)
del _required
|
|
|
|
class TestEmbeddingServiceProtocolCompliance:
    """Verify EmbeddingService implements EmbeddingServiceProtocol.

    ``SentenceTransformer`` and ``chromadb.Client`` are patched inside
    ``src.services.embeddings``, so these tests exercise only the method
    surface (names, async-ness, call signatures, return shapes).
    """

    @staticmethod
    def _assert_is_async_method(obj, name):
        """Assert that *obj* exposes *name* as a coroutine (``async def``) method."""
        import inspect

        assert hasattr(obj, name), f"missing method: {name}"
        # inspect.iscoroutinefunction instead of asyncio.iscoroutinefunction,
        # which is deprecated as of Python 3.14.
        assert inspect.iscoroutinefunction(getattr(obj, name)), f"{name} is not async"

    @pytest.fixture
    def mock_sentence_transformer(self, monkeypatch):
        """Mock sentence transformer to avoid loading actual model.

        Sets ``src.services.embeddings._shared_model`` to ``None`` for the
        duration of the test (presumably a module-level model cache, so the
        service is forced to build its model from the patched
        ``SentenceTransformer``). ``monkeypatch`` restores whatever value the
        attribute held beforehand. This means a value cached elsewhere is not
        discarded, and anything the test stores there does not leak into
        later tests.

        Yields:
            The mock instance returned by the patched ``SentenceTransformer``
            constructor; its ``encode`` returns a fixed 3-element array.
        """
        import numpy as np

        import src.services.embeddings

        # raising=False matches a plain assignment, which also works when the
        # attribute does not exist yet.
        monkeypatch.setattr(src.services.embeddings, "_shared_model", None, raising=False)

        with patch("src.services.embeddings.SentenceTransformer") as mock_st_class:
            mock_model = mock_st_class.return_value
            mock_model.encode.return_value = np.array([0.1, 0.2, 0.3])
            yield mock_model

    @pytest.fixture
    def mock_chroma_client(self):
        """Mock ChromaDB client whose created collection returns one query hit.

        ``query`` on the collection returned by ``create_collection`` yields a
        single result: id ``id1``, document ``doc1``, metadata
        ``{"source": "pubmed"}``, distance ``0.1``.

        Yields:
            The mock client instance returned by the patched constructor.
        """
        with patch("src.services.embeddings.chromadb.Client") as mock_client_class:
            mock_client = mock_client_class.return_value
            mock_collection = mock_client.create_collection.return_value
            mock_collection.query.return_value = {
                "ids": [["id1"]],
                "documents": [["doc1"]],
                "metadatas": [[{"source": "pubmed"}]],
                "distances": [[0.1]],
            }
            yield mock_client

    def test_has_add_evidence_method(self, mock_sentence_transformer, mock_chroma_client):
        """EmbeddingService should have async add_evidence method."""
        from src.services.embeddings import EmbeddingService

        self._assert_is_async_method(EmbeddingService(), "add_evidence")

    def test_has_search_similar_method(self, mock_sentence_transformer, mock_chroma_client):
        """EmbeddingService should have async search_similar method."""
        from src.services.embeddings import EmbeddingService

        self._assert_is_async_method(EmbeddingService(), "search_similar")

    def test_has_deduplicate_method(self, mock_sentence_transformer, mock_chroma_client):
        """EmbeddingService should have async deduplicate method."""
        from src.services.embeddings import EmbeddingService

        self._assert_is_async_method(EmbeddingService(), "deduplicate")

    @pytest.mark.asyncio
    async def test_add_evidence_signature(self, mock_sentence_transformer, mock_chroma_client):
        """add_evidence should accept (evidence_id, content, metadata)."""
        from src.services.embeddings import EmbeddingService

        service = EmbeddingService()

        # The call itself is the check: a signature mismatch raises TypeError.
        await service.add_evidence(
            evidence_id="test-id",
            content="test content",
            metadata={"source": "pubmed", "title": "Test"},
        )

    @pytest.mark.asyncio
    async def test_search_similar_signature(self, mock_sentence_transformer, mock_chroma_client):
        """search_similar should accept (query, n_results) and return list[dict]."""
        from src.services.embeddings import EmbeddingService

        service = EmbeddingService()

        results = await service.search_similar("test query", n_results=5)

        assert isinstance(results, list)
        # The mocked collection always returns one hit, so an empty list means
        # the result mapping is broken. Without this, the shape checks below
        # would pass vacuously.
        assert results, "expected the mocked query hit to be returned"
        first = results[0]
        assert isinstance(first, dict)
        missing_keys = {"id", "content", "metadata", "distance"} - first.keys()
        assert not missing_keys, f"result is missing keys: {sorted(missing_keys)}"

    @pytest.mark.asyncio
    async def test_deduplicate_signature(self, mock_sentence_transformer, mock_chroma_client):
        """deduplicate should accept (evidence, threshold) and return list[Evidence]."""
        from src.services.embeddings import EmbeddingService
        from src.utils.models import Citation, Evidence

        service = EmbeddingService()

        # Replace the fixture's single hit with an empty result so the query
        # reports no stored neighbours for the incoming evidence.
        mock_chroma_client.create_collection.return_value.query.return_value = {
            "ids": [[]],
            "documents": [[]],
            "metadatas": [[]],
            "distances": [[]],
        }

        evidence = [
            Evidence(
                content="test",
                citation=Citation(source="pubmed", url="u1", title="t1", date="2024"),
            )
        ]

        results = await service.deduplicate(evidence, threshold=0.9)

        assert isinstance(results, list)
        assert all(isinstance(e, Evidence) for e in results)
|
|
|
|
class TestProtocolTypeChecking:
    """Structural checks that EmbeddingService matches EmbeddingServiceProtocol."""

    def test_embedding_service_satisfies_protocol(self):
        """EmbeddingService should satisfy EmbeddingServiceProtocol.

        Checks the following without instantiating anything:

        - every required method name is declared on the protocol, so this
          list cannot silently drift from the interface;
        - each of those methods is present as a callable on EmbeddingService.
        """
        from src.services.embedding_protocol import EmbeddingServiceProtocol
        from src.services.embeddings import EmbeddingService

        required = ("add_evidence", "search_similar", "deduplicate")

        undeclared = [name for name in required if not hasattr(EmbeddingServiceProtocol, name)]
        assert not undeclared, f"EmbeddingServiceProtocol does not declare: {undeclared}"

        missing = [
            name for name in required if not callable(getattr(EmbeddingService, name, None))
        ]
        assert not missing, f"EmbeddingService is missing protocol methods: {missing}"
|
|