File size: 2,402 Bytes
499170b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
"""Unit tests for SearchHandler."""
from unittest.mock import AsyncMock
import pytest
from src.tools.search_handler import SearchHandler
from src.utils.exceptions import SearchError
from src.utils.models import Citation, Evidence
class TestSearchHandler:
    """Behavioural checks for ``SearchHandler.execute`` using mocked tools."""

    @staticmethod
    def _stub_tool(name, **search_kwargs):
        """Return an ``AsyncMock`` tool called *name* with a mocked ``search``.

        ``search_kwargs`` are forwarded to the ``AsyncMock`` that replaces
        ``search`` (e.g. ``return_value=[...]`` or ``side_effect=...``).
        """
        tool = AsyncMock()
        # ``name`` is assigned after construction: as a constructor argument
        # it would label the mock itself instead of setting the attribute.
        tool.name = name
        tool.search = AsyncMock(**search_kwargs)
        return tool

    @pytest.mark.asyncio
    async def test_execute_aggregates_results(self):
        """Results from every tool are combined and no errors are reported."""
        # Arrange: two tools, each yielding exactly one piece of evidence.
        pubmed_hit = Evidence(
            content="Result 1",
            citation=Citation(source="pubmed", title="T1", url="u1", date="2024"),
        )
        web_hit = Evidence(
            content="Result 2",
            citation=Citation(source="web", title="T2", url="u2", date="2024"),
        )
        pubmed_tool = self._stub_tool("pubmed", return_value=[pubmed_hit])
        web_tool = self._stub_tool("web", return_value=[web_hit])

        # Act
        search_handler = SearchHandler(tools=[pubmed_tool, web_tool])
        outcome = await search_handler.execute("test query")

        # Assert: one hit per tool, both sources listed, nothing failed.
        hits_expected = 2
        assert outcome.total_found == hits_expected
        for source_name in ("pubmed", "web"):
            assert source_name in outcome.sources_searched
        assert len(outcome.errors) == 0

    @pytest.mark.asyncio
    async def test_execute_handles_tool_failure(self):
        """A tool raising ``SearchError`` does not stop the other tool's results."""
        # Arrange: one healthy tool and one whose search raises.
        good_hit = Evidence(
            content="Good result",
            citation=Citation(source="pubmed", title="T", url="u", date="2024"),
        )
        healthy_tool = self._stub_tool("pubmed", return_value=[good_hit])
        broken_tool = self._stub_tool("web", side_effect=SearchError("API down"))

        # Act
        search_handler = SearchHandler(tools=[healthy_tool, broken_tool])
        outcome = await search_handler.execute("test")

        # Assert: the healthy tool's hit survives and the failure is recorded
        # as a single error entry that mentions the failing tool.
        assert outcome.total_found == 1
        assert "pubmed" in outcome.sources_searched
        assert len(outcome.errors) == 1
        first_error = outcome.errors[0]
        assert "web" in first_error
|