DeepCritical / tests /unit /agents /test_search_agent.py
VibecoderMcSwaggins's picture
fix(ci): add E402/RUF100 to ignore for cross-env compatibility
ecbc47b
raw
history blame
2.55 kB
"""Unit tests for SearchAgent."""
from unittest.mock import AsyncMock
import pytest
# Skip all tests if agent_framework not installed (optional dep)
pytest.importorskip("agent_framework")
from agent_framework import ChatMessage, Role
from src.agents.search_agent import SearchAgent
from src.utils.models import Citation, Evidence, SearchResult
@pytest.fixture
def mock_handler() -> AsyncMock:
    """Provide an async search-handler stub that yields a single PubMed hit.

    ``execute`` is an ``AsyncMock`` child; awaiting it returns a
    ``SearchResult`` for "test query" holding exactly one ``Evidence`` item.
    """
    single_citation = Citation(
        source="pubmed",
        title="Test Title",
        url="http://test.com",
        date="2023",
        authors=["Author A"],
    )
    single_hit = Evidence(
        content="test content",
        citation=single_citation,
        relevance=1.0,
    )
    canned_result = SearchResult(
        query="test query",
        evidence=[single_hit],
        sources_searched=["pubmed"],
        total_found=1,
    )

    stub = AsyncMock()
    # Whatever arguments the agent passes, awaiting execute() yields the canned result.
    stub.execute.return_value = canned_result
    return stub
@pytest.mark.asyncio
async def test_run_executes_search(mock_handler: AsyncMock) -> None:
    """A plain-string run searches once, fills the store, and reports the count."""
    evidence_store: dict[str, list[Evidence]] = {"current": []}
    search_agent = SearchAgent(mock_handler, evidence_store)

    result = await search_agent.run("test query")

    # The handler is awaited exactly once, with the query and max_results_per_tool=10.
    mock_handler.execute.assert_awaited_once_with("test query", max_results_per_tool=10)

    # The handler's single evidence item ends up under store["current"].
    recorded = evidence_store["current"]
    assert len(recorded) == 1
    assert recorded[0].content == "test content"

    # The first reply message comes from the assistant and mentions the hit count.
    reply = result.messages[0]
    assert reply.role == Role.ASSISTANT
    assert "Found 1 sources" in reply.text
@pytest.mark.asyncio
async def test_run_handles_chat_message_input(mock_handler: AsyncMock) -> None:
    """A single ChatMessage is accepted and its text is used as the search query."""
    search_agent = SearchAgent(mock_handler, {"current": []})

    await search_agent.run(ChatMessage(role=Role.USER, text="test query"))

    mock_handler.execute.assert_awaited_once_with("test query", max_results_per_tool=10)
@pytest.mark.asyncio
async def test_run_handles_list_input(mock_handler: AsyncMock) -> None:
    """A list of messages is accepted; the query comes from the user message, not the system one."""
    conversation = [
        ChatMessage(role=Role.SYSTEM, text="sys"),
        ChatMessage(role=Role.USER, text="test query"),
    ]
    search_agent = SearchAgent(mock_handler, {"current": []})

    await search_agent.run(conversation)

    mock_handler.execute.assert_awaited_once_with("test query", max_results_per_tool=10)