# Source: abc123 / crossword-app / backend-py / test-unit / test_vector_search.py
# Committed by: vimalk78 (per page header)
# Commit: 486eff6 - feat(crossword): generated crosswords with clues
"""
Unit tests for VectorSearchService.
"""
import pytest
import asyncio
import os
import tempfile
import json
from unittest.mock import Mock, patch, MagicMock
import sys
from pathlib import Path
import numpy as np
# Put the parent of this file's directory at the front of sys.path so the
# `src.services...` import below can be resolved. The import must stay after
# the sys.path change, which is why it is not grouped with the imports above.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.services.vector_search import VectorSearchService
@pytest.fixture
def mock_sentence_transformer():
    """Provide a Mock standing in for a SentenceTransformer model.

    The returned mock exposes:

    * ``encode`` -- always returns the same ``(5, 384)`` float array, no
      matter what input it is called with. Tests that need a different
      shape override ``encode.return_value`` themselves.
    * ``tokenizer.get_vocab`` -- returns a small fixed vocabulary of
      15 lowercase animal words mapped to integer ids.
    """
    model = Mock()

    # Seeded generator: the embeddings are arbitrary, but they must be the
    # same on every run so a failure can be reproduced.
    rng = np.random.default_rng(42)
    model.encode.return_value = rng.random((5, 384))  # 5 rows, 384 dimensions

    # Mock tokenizer with a tiny vocabulary.
    tokenizer = Mock()
    tokenizer.get_vocab.return_value = {
        "dog": 1, "cat": 2, "elephant": 3, "tiger": 4, "whale": 5,
        "bird": 6, "fish": 7, "lion": 8, "bear": 9, "rabbit": 10,
        "horse": 11, "sheep": 12, "goat": 13, "duck": 14, "chicken": 15,
    }
    model.tokenizer = tokenizer

    return model
class TestVectorSearchService:
    """Test cases for VectorSearchService.

    The synchronous tests exercise the pure helper methods on a freshly
    constructed (uninitialized) service. The async tests replace the
    SentenceTransformer and FAISS dependencies with mocks so that no model
    is loaded and no real index is built.
    """

    def test_init(self):
        """Test service initialization."""
        service = VectorSearchService()

        # Nothing is loaded until initialize() is awaited.
        assert service.model is None
        assert service.vocab is None
        assert service.word_embeddings is None
        assert service.faiss_index is None
        assert service.is_initialized is False

        # Check default configuration
        assert "all-mpnet-base-v2" in service.model_name
        assert service.min_similarity_threshold == 0.45
        assert service.max_results == 40

    def test_filter_vocabulary(self):
        """Test vocabulary filtering."""
        service = VectorSearchService()

        vocab_dict = {
            "dog": 1, "cat": 2, "elephant": 3,  # Good words
            "the": 4, "and": 5, "##ing": 6,  # Should be filtered
            "dogs": 7, "cats": 8,  # Plurals - should be filtered
            "a": 9, "ab": 10,  # Too short
            "supercalifragilisticexpialidocious": 11,  # Too long
            "[CLS]": 12, "<start>": 13,  # Special tokens
        }

        filtered = service._filter_vocabulary(vocab_dict)

        # Should keep good words (the filtered vocabulary is upper-cased)
        assert "DOG" in filtered
        assert "CAT" in filtered
        assert "ELEPHANT" in filtered

        # Should filter out bad words
        assert "THE" not in filtered
        assert "AND" not in filtered
        assert "DOGS" not in filtered
        assert "CATS" not in filtered
        assert "A" not in filtered
        assert "[CLS]" not in filtered

    def test_is_plural(self):
        """Test plural detection."""
        service = VectorSearchService()

        # Test plurals
        assert service._is_plural("DOGS") is True
        assert service._is_plural("CATS") is True
        assert service._is_plural("BIRDS") is True

        # Test non-plurals, including words that merely end in "S"
        assert service._is_plural("DOG") is False
        assert service._is_plural("CLASS") is False  # Ends in SS
        assert service._is_plural("BUS") is False  # Ends in US
        assert service._is_plural("THIS") is False  # Ends in IS
        assert service._is_plural("CAT") is False

    def test_is_boring_word(self):
        """Test boring word detection."""
        service = VectorSearchService()

        # Test boring words
        assert service._is_boring_word("RUNNING") is True  # ING ending
        assert service._is_boring_word("EDUCATION") is True  # TION ending
        assert service._is_boring_word("HAPPINESS") is True  # NESS ending
        assert service._is_boring_word("GET") is True  # Common short word

        # Test interesting words
        assert service._is_boring_word("DOG") is False
        assert service._is_boring_word("ELEPHANT") is False
        assert service._is_boring_word("COMPUTER") is False

    def test_matches_difficulty(self):
        """Test difficulty matching."""
        service = VectorSearchService()

        # Easy: 3-8 chars
        assert service._matches_difficulty("DOG", "easy") is True  # 3 chars
        assert service._matches_difficulty("ELEPHANT", "easy") is True  # 8 chars
        assert service._matches_difficulty("AB", "easy") is False  # Too short
        assert service._matches_difficulty("SUPERLONGSWORD", "easy") is False  # Too long

        # Medium: 4-10 chars
        assert service._matches_difficulty("CATS", "medium") is True  # 4 chars
        assert service._matches_difficulty("BUTTERFLIES", "medium") is False  # 11 chars

        # Hard: 5-15 chars
        assert service._matches_difficulty("TIGER", "hard") is True  # 5 chars
        assert service._matches_difficulty("DOG", "hard") is False  # Too short

    def test_generate_clue(self):
        """Test clue generation."""
        service = VectorSearchService()

        # Test topic-specific clues
        clue = service._generate_clue("ELEPHANT", "Animals")
        assert "elephant" in clue.lower()
        assert "animal" in clue.lower()

        clue = service._generate_clue("COMPUTER", "Technology")
        assert "computer" in clue.lower()
        assert "tech" in clue.lower()

        # Test generic clue
        clue = service._generate_clue("WORD", "Unknown")
        assert "word" in clue.lower()
        assert "unknown" in clue.lower()

    def test_is_interesting_word(self):
        """Test interesting word detection.

        These assertions pin the current behavior of the check rather than
        an ideal specification.
        """
        service = VectorSearchService()

        # Words matching the topic: the singular form is accepted, while the
        # form identical to the topic name ("ANIMALS") is rejected.
        assert service._is_interesting_word("ANIMAL", "Animals") is True
        assert service._is_interesting_word("ANIMALS", "Animals") is False

        # Obvious category words: MAMMAL is accepted, WILDLIFE is rejected.
        assert service._is_interesting_word("MAMMAL", "Animals") is True
        assert service._is_interesting_word("WILDLIFE", "Animals") is False

        # Abstract words (-tion / -ness endings) are accepted by this check.
        assert service._is_interesting_word("EDUCATION", "School") is True
        assert service._is_interesting_word("HAPPINESS", "Emotions") is True

        # Test good words
        assert service._is_interesting_word("ELEPHANT", "Animals") is True
        assert service._is_interesting_word("COMPUTER", "Technology") is True

    @pytest.mark.asyncio
    @patch('src.services.vector_search.SentenceTransformer')
    @patch('src.services.vector_search.faiss')
    async def test_initialize_success(self, mock_faiss, mock_transformer_class, mock_sentence_transformer):
        """Test successful service initialization."""
        # Setup mocks: the patched class hands back the fixture model, and
        # the patched faiss module hands back a mock index.
        mock_transformer_class.return_value = mock_sentence_transformer
        mock_index = Mock()
        mock_faiss.IndexFlatIP.return_value = mock_index
        mock_faiss.normalize_L2 = Mock()

        service = VectorSearchService()
        await service.initialize()

        assert service.is_initialized is True
        assert service.model == mock_sentence_transformer
        assert service.vocab is not None
        assert service.faiss_index == mock_index

    @pytest.mark.asyncio
    @patch('src.services.vector_search.SentenceTransformer')
    async def test_initialize_failure(self, mock_transformer_class):
        """Test service initialization failure."""
        # Make SentenceTransformer raise an exception
        mock_transformer_class.side_effect = Exception("Model load failed")

        service = VectorSearchService()

        with pytest.raises(Exception, match="Model load failed"):
            await service.initialize()

        # Deliberately outside the `with` block: statements placed after the
        # raising call inside pytest.raises would never execute.
        assert service.is_initialized is False

    @pytest.mark.asyncio
    async def test_find_similar_words_not_initialized(self):
        """Test word search when service not initialized."""
        service = VectorSearchService()

        words = await service.find_similar_words("Animals", "medium", 5)

        # Should return empty list when not initialized and no fallback
        assert len(words) == 0

    @pytest.mark.asyncio
    @patch('src.services.vector_search.faiss')
    async def test_find_similar_words_initialized(self, mock_faiss, mock_sentence_transformer):
        """Test word search when service is initialized."""
        # Setup service as initialized without running initialize()
        service = VectorSearchService()
        service.is_initialized = True
        service.model = mock_sentence_transformer
        service.vocab = ["ELEPHANT", "TIGER", "LION", "BEAR", "WHALE"]

        # Mock FAISS search results: one query row, five hits whose indices
        # point into service.vocab.
        mock_index = Mock()
        mock_index.search.return_value = (
            np.array([[0.8, 0.7, 0.6, 0.5, 0.4]]),  # Scores
            np.array([[0, 1, 2, 3, 4]]),  # Indices
        )
        service.faiss_index = mock_index

        # Mock embedding generation (overrides the fixture's default array)
        mock_sentence_transformer.encode.return_value = np.array([[0.1, 0.2, 0.3]])
        mock_faiss.normalize_L2 = Mock()

        words = await service.find_similar_words("Animals", "medium", 5)

        assert len(words) > 0
        assert all(w["source"] == "vector_search" for w in words)
        assert all("similarity" in w for w in words)
        assert mock_index.search.call_count >= 1

    @pytest.mark.asyncio
    async def test_cleanup(self):
        """Test service cleanup."""
        service = VectorSearchService()
        service.model = Mock()
        service.word_embeddings = Mock()
        service.faiss_index = Mock()
        service.is_initialized = True

        await service.cleanup()

        assert service.is_initialized is False
# Allow running this file directly (`python test_vector_search.py`).
if __name__ == "__main__":
    # Pass pytest's exit code to the shell so a failing run exits non-zero.
    sys.exit(pytest.main([__file__, "-v"]))