File size: 5,747 Bytes
7baf8ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
"""Tests for EmbeddingServiceProtocol compliance.
TDD: These tests verify that both EmbeddingService and LlamaIndexRAGService
implement the EmbeddingServiceProtocol interface correctly.
"""
import asyncio
import inspect
from unittest.mock import patch

import pytest
# Skip if chromadb not available
pytest.importorskip("chromadb")
pytest.importorskip("sentence_transformers")
class TestEmbeddingServiceProtocolCompliance:
    """Verify EmbeddingService implements EmbeddingServiceProtocol.

    Every test here runs with ``SentenceTransformer`` and ``chromadb.Client``
    patched inside ``src.services.embeddings``. No real model or vector store
    is used.
    """

    @pytest.fixture
    def mock_sentence_transformer(self):
        """Mock sentence transformer to avoid loading actual model.

        Yields the mocked model instance. Its ``encode`` always returns the
        fixed vector ``[0.1, 0.2, 0.3]``.
        """
        import numpy as np

        import src.services.embeddings

        # Reset the module-level cached model (``_shared_model``) so the
        # patched class is what EmbeddingService picks up in this test.
        src.services.embeddings._shared_model = None
        with patch("src.services.embeddings.SentenceTransformer") as mock_st_class:
            mock_model = mock_st_class.return_value
            mock_model.encode.return_value = np.array([0.1, 0.2, 0.3])
            yield mock_model
        # Teardown: clear the cache so the mock does not leak into later tests.
        src.services.embeddings._shared_model = None

    @pytest.fixture
    def mock_chroma_client(self):
        """Mock ChromaDB client.

        The collection returned by ``create_collection`` answers every
        ``query`` with a single hit: id ``id1``, document ``doc1``,
        metadata ``{"source": "pubmed"}``, distance ``0.1``.
        """
        # NOTE(review): only ``create_collection`` is configured; this assumes
        # EmbeddingService obtains its collection through it -- confirm.
        with patch("src.services.embeddings.chromadb.Client") as mock_client_class:
            mock_client = mock_client_class.return_value
            mock_collection = mock_client.create_collection.return_value
            mock_collection.query.return_value = {
                "ids": [["id1"]],
                "documents": [["doc1"]],
                "metadatas": [[{"source": "pubmed"}]],
                "distances": [[0.1]],
            }
            yield mock_client

    def test_has_add_evidence_method(self, mock_sentence_transformer, mock_chroma_client):
        """EmbeddingService should have async add_evidence method."""
        from src.services.embeddings import EmbeddingService

        service = EmbeddingService()
        assert hasattr(service, "add_evidence")
        assert inspect.iscoroutinefunction(service.add_evidence)

    def test_has_search_similar_method(self, mock_sentence_transformer, mock_chroma_client):
        """EmbeddingService should have async search_similar method."""
        from src.services.embeddings import EmbeddingService

        service = EmbeddingService()
        assert hasattr(service, "search_similar")
        assert inspect.iscoroutinefunction(service.search_similar)

    def test_has_deduplicate_method(self, mock_sentence_transformer, mock_chroma_client):
        """EmbeddingService should have async deduplicate method."""
        from src.services.embeddings import EmbeddingService

        service = EmbeddingService()
        assert hasattr(service, "deduplicate")
        assert inspect.iscoroutinefunction(service.deduplicate)

    @pytest.mark.asyncio
    async def test_add_evidence_signature(self, mock_sentence_transformer, mock_chroma_client):
        """add_evidence should accept (evidence_id, content, metadata)."""
        from src.services.embeddings import EmbeddingService

        service = EmbeddingService()
        # Passing the three arguments by keyword must not raise.
        await service.add_evidence(
            evidence_id="test-id",
            content="test content",
            metadata={"source": "pubmed", "title": "Test"},
        )

    @pytest.mark.asyncio
    async def test_search_similar_signature(self, mock_sentence_transformer, mock_chroma_client):
        """search_similar should accept (query, n_results) and return list[dict]."""
        from src.services.embeddings import EmbeddingService

        service = EmbeddingService()
        results = await service.search_similar("test query", n_results=5)
        assert isinstance(results, list)
        # The mocked collection always returns one hit. Require a non-empty
        # result so the shape checks below cannot be skipped vacuously.
        assert results, "expected the mocked Chroma hit to be returned"
        first = results[0]
        assert isinstance(first, dict)
        for key in ("id", "content", "metadata", "distance"):
            assert key in first, f"missing key {key!r} in search result"

    @pytest.mark.asyncio
    async def test_deduplicate_signature(self, mock_sentence_transformer, mock_chroma_client):
        """deduplicate should accept (evidence, threshold) and return list[Evidence]."""
        from src.services.embeddings import EmbeddingService
        from src.utils.models import Citation, Evidence

        service = EmbeddingService()
        # Make the collection report no stored neighbours, so no near-duplicate
        # can be found. This test only checks the call signature and the
        # element type of the returned list.
        mock_chroma_client.create_collection.return_value.query.return_value = {
            "ids": [[]],
            "documents": [[]],
            "metadatas": [[]],
            "distances": [[]],
        }
        evidence = [
            Evidence(
                content="test",
                citation=Citation(source="pubmed", url="u1", title="t1", date="2024"),
            )
        ]
        results = await service.deduplicate(evidence, threshold=0.9)
        assert isinstance(results, list)
        assert all(isinstance(e, Evidence) for e in results)
class TestProtocolTypeChecking:
    """Verify Protocol works with type checking."""

    def test_embedding_service_satisfies_protocol(self):
        """EmbeddingService should satisfy EmbeddingServiceProtocol.

        This is a structural check on the classes themselves; nothing is
        instantiated, so no mocks are needed. For each required method:

        * EmbeddingServiceProtocol must declare it;
        * EmbeddingService must implement it;
        * the implementation must be a coroutine function (``async def``).
        """
        from src.services.embedding_protocol import EmbeddingServiceProtocol
        from src.services.embeddings import EmbeddingService

        required_methods = ("add_evidence", "search_similar", "deduplicate")
        for name in required_methods:
            assert hasattr(EmbeddingServiceProtocol, name), (
                f"EmbeddingServiceProtocol does not declare {name}"
            )
            implementation = getattr(EmbeddingService, name, None)
            assert implementation is not None, f"EmbeddingService is missing {name}"
            assert inspect.iscoroutinefunction(implementation), (
                f"EmbeddingService.{name} must be async"
            )
|