| """Unit tests for EmbeddingService.""" |
|
|
| from unittest.mock import patch |
|
|
| import numpy as np |
| import pytest |
|
|
| |
| pytest.importorskip("chromadb") |
| pytest.importorskip("sentence_transformers") |
|
|
| from src.services.embeddings import EmbeddingService |
|
|
|
|
class TestEmbeddingService:
    """Tests for EmbeddingService with SentenceTransformer and chromadb.Client patched."""

    @pytest.fixture(autouse=True)
    def reset_singleton(self):
        """Clear the module-level shared model around every test.

        The fixture is autouse, so each test in this class starts with
        ``_shared_model`` set to ``None``; the value that was there before the
        test is put back during fixture teardown.
        """
        import src.services.embeddings as embeddings_module

        # Setup: remember whatever was cached, then start from a clean slate.
        saved_model = embeddings_module._shared_model
        embeddings_module._shared_model = None

        yield

        # Teardown: put the previously cached model back.
        embeddings_module._shared_model = saved_model

    @pytest.fixture
    def mock_sentence_transformer(self):
        """Patch SentenceTransformer and yield the mocked model instance."""
        with patch("src.services.embeddings.SentenceTransformer") as transformer_cls:
            fake_model = transformer_cls.return_value
            # Default encode() result: one 3-dimensional embedding.
            fake_model.encode.return_value = np.array([0.1, 0.2, 0.3])
            yield fake_model

    @pytest.fixture
    def mock_chroma_client(self):
        """Patch chromadb.Client and yield the mocked client instance.

        The collection returned by ``create_collection`` answers ``query``
        with a single canned hit unless a test overrides it.
        """
        single_hit = {
            "ids": [["id1"]],
            "documents": [["doc1"]],
            "metadatas": [[{"source": "pubmed"}]],
            "distances": [[0.1]],
        }
        with patch("src.services.embeddings.chromadb.Client") as client_cls:
            fake_client = client_cls.return_value
            fake_collection = fake_client.create_collection.return_value
            fake_collection.query.return_value = single_hit
            yield fake_client

    @pytest.mark.asyncio
    async def test_embed_returns_vector(self, mock_sentence_transformer, mock_chroma_client):
        """embed() should yield a list of three floats from a single encode call."""
        embedder = EmbeddingService()
        vector = await embedder.embed("testosterone libido")

        assert isinstance(vector, list)
        assert len(vector) == 3
        for component in vector:
            assert isinstance(component, float)

        # The model must have been asked to encode exactly once.
        mock_sentence_transformer.encode.assert_called_once()

    @pytest.mark.asyncio
    async def test_batch_embed_efficient(self, mock_sentence_transformer, mock_chroma_client):
        """embed_batch() should hand the whole list of texts to encode."""
        # Two inputs -> a 2-row matrix coming back from the mocked model.
        mock_sentence_transformer.encode.return_value = np.array([[0.1, 0.2], [0.3, 0.4]])

        embedder = EmbeddingService()
        inputs = ["text one", "text two"]

        vectors = await embedder.embed_batch(inputs)

        assert len(vectors) == 2
        assert isinstance(vectors[0], list)
        mock_sentence_transformer.encode.assert_called_with(inputs)

    @pytest.mark.asyncio
    async def test_add_and_search(self, mock_sentence_transformer, mock_chroma_client):
        """Adding evidence and then searching should hit the collection once each."""
        embedder = EmbeddingService()
        collection = mock_chroma_client.create_collection.return_value

        await embedder.add_evidence(
            evidence_id="test1",
            content="Testosterone activates androgen receptor pathway",
            metadata={"source": "pubmed"},
        )

        # The evidence must have been written to the collection.
        collection.add.assert_called_once()

        hits = await embedder.search_similar("AMPK activation drugs", n_results=1)

        # The search must have queried the collection and surfaced the canned hit.
        collection.query.assert_called_once()
        assert len(hits) == 1
        assert hits[0]["id"] == "id1"

    @pytest.mark.asyncio
    async def test_search_similar_empty_collection(
        self, mock_sentence_transformer, mock_chroma_client
    ):
        """Searching when the collection returns no hits should give an empty list."""
        no_hits = {
            "ids": [[]],
            "documents": [[]],
            "metadatas": [[]],
            "distances": [[]],
        }
        collection = mock_chroma_client.create_collection.return_value
        collection.query.return_value = no_hits

        embedder = EmbeddingService()
        hits = await embedder.search_similar("anything", n_results=5)

        assert hits == []

    @pytest.mark.asyncio
    async def test_deduplicate(self, mock_sentence_transformer, mock_chroma_client):
        """deduplicate() should keep only the item with no close match in the store."""
        from src.utils.models import Citation, Evidence

        embedder = EmbeddingService()
        collection = mock_chroma_client.create_collection.return_value

        # Successive query() calls: the first reports a very close existing
        # entry (distance 0.05), the second reports nothing at all.
        close_match = {
            "ids": [["existing_id"]],
            "documents": [["doc"]],
            "metadatas": [[{}]],
            "distances": [[0.05]],
        }
        no_match = {
            "ids": [[]],
            "documents": [[]],
            "metadatas": [[]],
            "distances": [[]],
        }
        collection.query.side_effect = [close_match, no_match]

        candidates = [
            Evidence(
                content=text,
                citation=Citation(source="pubmed", url=url, title=title, date="2024"),
            )
            for text, url, title in (
                ("Duplicate content", "u1", "t1"),
                ("Unique content", "u2", "t2"),
            )
        ]

        survivors = await embedder.deduplicate(candidates, threshold=0.9)

        # Only the second item (no match in the store) is expected to remain.
        assert len(survivors) == 1
        assert survivors[0].citation.url == "u2"
|
|