File size: 13,483 Bytes
38c016b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 |
"""
Unit tests for CrosswordGenerator to ensure robust crossword generation.
"""
import pytest
import asyncio
from unittest.mock import Mock, patch
import sys
from pathlib import Path
# Prepend the repository root (the parent of this file's directory) to the
# module search path so the ``src.services.crossword_generator`` import resolves
# regardless of the current working directory.
project_root = Path(__file__).parent.parent
sys.path[:0] = [str(project_root)]
from src.services.crossword_generator import CrosswordGenerator
@pytest.fixture
def sample_words():
    """Provide a fresh list of animal word entries shared by many tests.

    Each entry is a dict with ``word``, ``clue``, ``similarity`` and ``source``
    keys; ``source`` is always ``"test"``.
    """
    table = (
        ("DOG", "Man's best friend", 0.8),
        ("ELEPHANT", "Large mammal with trunk", 0.7),
        ("CAT", "Feline pet", 0.9),
        ("BUTTERFLY", "Colorful flying insect", 0.6),
        ("TIGER", "Striped big cat", 0.75),
        ("WHALE", "Largest marine mammal", 0.65),
    )
    return [
        {"word": word, "clue": clue, "similarity": score, "source": "test"}
        for word, clue, score in table
    ]
@pytest.fixture
def mock_vector_service():
    """Provide a stand-in vector search service that reports itself as ready.

    Only ``is_initialized`` is preset; individual tests configure any methods
    (e.g. ``find_similar_words``) they need.
    """
    return Mock(is_initialized=True)
class TestCrosswordGenerator:
    """Test cases for CrosswordGenerator's helpers and public entry points."""

    def test_init(self):
        """Test generator initialization with default settings."""
        generator = CrosswordGenerator()
        assert generator.max_attempts == 100
        assert generator.min_words == 6
        assert generator.max_words == 10
        assert generator.vector_service is None

    def test_init_with_vector_service(self, mock_vector_service):
        """Test that an injected vector service is stored on the generator."""
        generator = CrosswordGenerator(vector_service=mock_vector_service)
        assert generator.vector_service == mock_vector_service

    def test_sort_words_for_crossword(self, sample_words):
        """Test word sorting by crossword suitability.

        Every input word must survive the sort, and each returned entry must be
        a dict carrying a ``crossword_score``.
        """
        generator = CrosswordGenerator()
        sorted_words = generator._sort_words_for_crossword(sample_words)

        # Should return list of dicts with crossword_score
        assert len(sorted_words) == len(sample_words)
        assert all(isinstance(w, dict) for w in sorted_words)
        assert all("crossword_score" in w for w in sorted_words)

        # The result must be a permutation of the input: no word dropped,
        # duplicated, or invented. Strict descending score order is not
        # asserted because the sort is believed to add some randomization
        # (TODO confirm against _sort_words_for_crossword).
        assert sorted(w["word"] for w in sorted_words) == sorted(
            w["word"] for w in sample_words
        )

    def test_filter_by_difficulty(self, sample_words):
        """Test that each difficulty keeps only words in its length band."""
        generator = CrosswordGenerator()

        # Test easy difficulty (3-8 chars)
        easy_words = generator._filter_by_difficulty(sample_words, "easy")
        easy_lengths = [len(w["word"]) for w in easy_words]
        assert all(3 <= length <= 8 for length in easy_lengths)

        # Test medium difficulty (4-10 chars)
        medium_words = generator._filter_by_difficulty(sample_words, "medium")
        medium_lengths = [len(w["word"]) for w in medium_words]
        assert all(4 <= length <= 10 for length in medium_lengths)

        # Test hard difficulty (5-15 chars)
        hard_words = generator._filter_by_difficulty(sample_words, "hard")
        hard_lengths = [len(w["word"]) for w in hard_words]
        assert all(5 <= length <= 15 for length in hard_lengths)

    def test_calculate_grid_size(self):
        """Test grid size calculation."""
        generator = CrosswordGenerator()

        # Test with short words
        short_words = ["DOG", "CAT", "BAT"]
        size = generator._calculate_grid_size(short_words)
        assert size >= 8  # Minimum size
        assert size >= max(len(w) for w in short_words)  # Longest word length

        # Test with longer words
        long_words = ["ELEPHANT", "BUTTERFLY", "HIPPOPOTAMUS"]
        size = generator._calculate_grid_size(long_words)
        assert size >= 12  # Longest word (HIPPOPOTAMUS)

    def test_create_grid_word_processing(self, sample_words):
        """Test the critical word processing logic that was causing index errors."""
        generator = CrosswordGenerator()

        # This tests the fix for the list index out of range error
        result = generator._create_grid(sample_words)

        # Should not crash and should return a result or None
        assert result is None or isinstance(result, dict)

        # If result exists, it should have the correct structure
        if result:
            assert "grid" in result
            assert "clues" in result
            assert "placed_words" in result

    def test_create_grid_empty_words(self):
        """Test grid creation with empty word list."""
        generator = CrosswordGenerator()
        result = generator._create_grid([])
        assert result is None

    def test_create_grid_malformed_words(self):
        """Test grid creation with malformed word data."""
        generator = CrosswordGenerator()

        # Test with various malformed inputs
        malformed_words = [
            "just_string",  # String instead of dict
            {"no_word_key": "value"},  # Dict without 'word' key
            {"word": ""},  # Empty word
            None,  # None value
            123,  # Number
        ]

        # Should not crash, might return None
        result = generator._create_grid(malformed_words)
        assert result is None or isinstance(result, dict)

    def test_can_place_word_horizontal(self):
        """Test horizontal word placement validation."""
        generator = CrosswordGenerator()
        grid = [["." for _ in range(10)] for _ in range(10)]

        # Test valid placement
        assert generator._can_place_word(grid, "TEST", 5, 3, "horizontal")

        # Test boundary violations
        assert not generator._can_place_word(grid, "TOOLONG", 5, 7, "horizontal")  # Too long
        assert not generator._can_place_word(grid, "TEST", 5, -1, "horizontal")  # Negative col
        assert not generator._can_place_word(grid, "TEST", -1, 3, "horizontal")  # Negative row

    def test_can_place_word_vertical(self):
        """Test vertical word placement validation."""
        generator = CrosswordGenerator()
        grid = [["." for _ in range(10)] for _ in range(10)]

        # Test valid placement
        assert generator._can_place_word(grid, "TEST", 3, 5, "vertical")

        # Test boundary violations
        assert not generator._can_place_word(grid, "TOOLONG", 7, 5, "vertical")  # Too long
        assert not generator._can_place_word(grid, "TEST", -1, 5, "vertical")  # Negative row
        assert not generator._can_place_word(grid, "TEST", 3, -1, "vertical")  # Negative col

    def test_place_and_remove_word(self):
        """Test that placing then removing a word restores the touched cells."""
        generator = CrosswordGenerator()
        grid = [["." for _ in range(10)] for _ in range(10)]

        # Place word horizontally on row 5 starting at column 3
        original_state = generator._place_word(grid, "TEST", 5, 3, "horizontal")

        # Check placement: columns 3..6 of row 5 spell the word
        assert grid[5][3:7] == list("TEST")

        # Remove word using the state returned by _place_word
        generator._remove_word(grid, original_state)

        # Check removal: the same cells are empty again
        assert grid[5][3:7] == ["."] * 4

    def test_find_word_intersections(self):
        """Test finding intersections between words."""
        generator = CrosswordGenerator()

        # Test words with common letters
        intersections = generator._find_word_intersections("CAT", "DOG")
        assert len(intersections) == 0  # No common letters

        intersections = generator._find_word_intersections("CAT", "ACE")
        assert len(intersections) >= 1  # Common 'A' and 'C'

        # Verify intersection format and that each pair really is a shared letter.
        for intersection in intersections:
            assert "word_pos" in intersection
            assert "placed_pos" in intersection
            word_pos = intersection["word_pos"]
            placed_pos = intersection["placed_pos"]
            assert isinstance(word_pos, int)
            assert isinstance(placed_pos, int)
            # Positions must index into the words and select the same letter.
            assert 0 <= word_pos < len("CAT")
            assert 0 <= placed_pos < len("ACE")
            assert "CAT"[word_pos] == "ACE"[placed_pos]

    def test_create_simple_cross(self, sample_words):
        """Test simple cross creation as fallback."""
        # NOTE(review): the sample_words fixture is requested but not used here.
        generator = CrosswordGenerator()

        # Use words that have intersections
        words_with_intersection = [
            {"word": "CAT", "clue": "Feline"},
            {"word": "ACE", "clue": "Playing card"},
        ]
        word_list = ["CAT", "ACE"]

        result = generator._create_simple_cross(word_list, words_with_intersection)

        # NOTE(review): a None/empty result passes vacuously; tighten once the
        # fallback's contract for intersecting pairs is confirmed.
        if result:  # If intersection found
            assert "grid" in result
            assert "clues" in result
            assert "placed_words" in result
            assert len(result["placed_words"]) == 2

    def test_generate_clues(self, sample_words):
        """Test clue generation for placed words."""
        generator = CrosswordGenerator()
        placed_words = [
            {"word": "DOG", "row": 0, "col": 0, "direction": "horizontal", "number": 1},
            {"word": "CAT", "row": 0, "col": 0, "direction": "vertical", "number": 2},
        ]

        clues = generator._generate_clues(sample_words, placed_words)

        assert len(clues) == 2
        for clue in clues:
            assert "number" in clue
            assert "word" in clue
            assert "text" in clue
            assert "direction" in clue
            assert clue["direction"] in ["across", "down"]
            assert "position" in clue

    @pytest.mark.asyncio
    async def test_select_words_with_vector_service(self, mock_vector_service, sample_words):
        """Test word selection with vector service."""
        # Mock vector service methods.
        # NOTE(review): find_similar_words is a plain Mock; if _select_words
        # awaits it, an AsyncMock would be required — confirm against the
        # implementation.
        mock_vector_service.find_similar_words.return_value = sample_words
        generator = CrosswordGenerator(vector_service=mock_vector_service)

        words = await generator._select_words(["Animals"], "medium", True)

        assert len(words) <= generator.max_words
        assert all(isinstance(w, dict) for w in words)
        mock_vector_service.find_similar_words.assert_called_once()

    @pytest.mark.asyncio
    async def test_select_words_without_vector_service(self):
        """Test word selection without vector service."""
        generator = CrosswordGenerator()

        # Should fallback to empty/static words
        words = await generator._select_words(["Animals"], "medium", True)

        # Without vector service and no static files, should return empty or minimal
        assert isinstance(words, list)

    @pytest.mark.asyncio
    async def test_generate_puzzle_success(self, mock_vector_service, sample_words):
        """Test successful puzzle generation."""
        mock_vector_service.find_similar_words.return_value = sample_words
        generator = CrosswordGenerator(vector_service=mock_vector_service)

        # Mock the grid creation to return a simple result
        with patch.object(generator, '_create_grid') as mock_create_grid:
            mock_create_grid.return_value = {
                "grid": [["T", "E", "S", "T"], [".", ".", ".", "."]],
                "placed_words": [{"word": "TEST", "row": 0, "col": 0, "direction": "horizontal", "number": 1}],
                "clues": [{"number": 1, "word": "TEST", "text": "A test", "direction": "across", "position": {"row": 0, "col": 0}}]
            }
            result = await generator.generate_puzzle(["Animals"], "medium", True)

        assert result is not None
        assert "grid" in result
        assert "clues" in result
        assert "metadata" in result
        assert result["metadata"]["topics"] == ["Animals"]
        assert result["metadata"]["difficulty"] == "medium"
        assert result["metadata"]["aiGenerated"] is True

    @pytest.mark.asyncio
    async def test_generate_puzzle_insufficient_words(self, mock_vector_service):
        """Test puzzle generation with insufficient words."""
        # Return too few words
        mock_vector_service.find_similar_words.return_value = [
            {"word": "CAT", "clue": "Feline", "similarity": 0.8, "source": "test"}
        ]
        generator = CrosswordGenerator(vector_service=mock_vector_service)

        with pytest.raises(Exception, match="Not enough words generated"):
            await generator.generate_puzzle(["Animals"], "medium", True)

    @pytest.mark.asyncio
    async def test_generate_puzzle_grid_creation_fails(self, mock_vector_service, sample_words):
        """Test puzzle generation when grid creation fails."""
        mock_vector_service.find_similar_words.return_value = sample_words
        generator = CrosswordGenerator(vector_service=mock_vector_service)

        # Mock grid creation to fail
        with patch.object(generator, '_create_grid', return_value=None):
            with pytest.raises(Exception, match="Could not create crossword grid"):
                await generator.generate_puzzle(["Animals"], "medium", True)
if __name__ == "__main__":
    # Allow running this file directly with the Python interpreter. pytest.main
    # returns an exit code; pass it to sys.exit so test failures produce a
    # non-zero process exit status instead of always exiting 0.
    sys.exit(pytest.main([__file__, "-v"]))