Spaces:
Running
Running
File size: 2,553 Bytes
15459e9 5a82171 ecbc47b 5a82171 ecbc47b 15459e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
"""Unit tests for SearchAgent."""
from unittest.mock import AsyncMock
import pytest
# Skip all tests if agent_framework not installed (optional dep)
pytest.importorskip("agent_framework")
from agent_framework import ChatMessage, Role
from src.agents.search_agent import SearchAgent
from src.utils.models import Citation, Evidence, SearchResult
@pytest.fixture
def mock_handler() -> AsyncMock:
    """Provide an async search-handler stub that returns one canned hit.

    Awaiting ``execute`` on the returned mock yields a ``SearchResult`` for
    "test query". The result holds a single ``Evidence`` item with content
    "test content" and relevance 1.0. Its citation is a "pubmed" record titled
    "Test Title" at http://test.com.
    """
    citation = Citation(
        source="pubmed",
        title="Test Title",
        url="http://test.com",
        date="2023",
        authors=["Author A"],
    )
    single_hit = Evidence(content="test content", citation=citation, relevance=1.0)
    canned_result = SearchResult(
        query="test query",
        evidence=[single_hit],
        sources_searched=["pubmed"],
        total_found=1,
    )

    stub = AsyncMock()
    # Setting return_value on the auto-created async child means
    # `await stub.execute(...)` resolves to the canned result.
    stub.execute.return_value = canned_result
    return stub
@pytest.mark.asyncio
async def test_run_executes_search(mock_handler: AsyncMock) -> None:
    """Running on a plain string queries the handler and records the evidence."""
    evidence_store: dict = {"current": []}
    search_agent = SearchAgent(mock_handler, evidence_store)

    reply = await search_agent.run("test query")

    # The handler is awaited exactly once, with the query text and a per-tool cap of 10.
    mock_handler.execute.assert_awaited_once_with("test query", max_results_per_tool=10)

    # The single canned Evidence item ends up under the store's "current" key.
    collected = evidence_store["current"]
    assert len(collected) == 1
    assert collected[0].content == "test content"

    # The first reply message comes from the assistant and reports the hit count.
    first = reply.messages[0]
    assert first.role == Role.ASSISTANT
    assert "Found 1 sources" in first.text
@pytest.mark.asyncio
async def test_run_handles_chat_message_input(mock_handler: AsyncMock) -> None:
    """A single user ChatMessage is searched using its text as the query."""
    search_agent = SearchAgent(mock_handler, {"current": []})

    await search_agent.run(ChatMessage(role=Role.USER, text="test query"))

    mock_handler.execute.assert_awaited_once_with("test query", max_results_per_tool=10)
@pytest.mark.asyncio
async def test_run_handles_list_input(mock_handler: AsyncMock) -> None:
    """With a system + user message list, the user's text becomes the query."""
    search_agent = SearchAgent(mock_handler, {"current": []})
    system_msg = ChatMessage(role=Role.SYSTEM, text="sys")
    user_msg = ChatMessage(role=Role.USER, text="test query")

    await search_agent.run([system_msg, user_msg])

    # Only the user's text reaches the handler; the "sys" prompt is not searched.
    mock_handler.execute.assert_awaited_once_with("test query", max_results_per_tool=10)
|