| """Unit tests for SearchAgent.""" |
|
|
| from unittest.mock import AsyncMock |
|
|
| import pytest |
|
|
| |
| pytest.importorskip("agent_framework") |
|
|
| from agent_framework import ChatMessage, Role |
|
|
| from src.agents.search_agent import SearchAgent |
| from src.utils.models import Citation, Evidence, SearchResult |
|
|
|
|
@pytest.fixture
def mock_handler() -> AsyncMock:
    """Provide an AsyncMock search handler with one canned PubMed hit.

    Awaiting ``handler.execute(...)`` yields a ``SearchResult`` for
    ``"test query"`` that holds exactly one ``Evidence`` item
    (content ``"test content"``, url ``"http://test.com"``, relevance 1.0).
    """
    citation = Citation(
        source="pubmed",
        title="Test Title",
        url="http://test.com",
        date="2023",
        authors=["Author A"],
    )
    single_hit = Evidence(content="test content", citation=citation, relevance=1.0)

    handler = AsyncMock()
    # AsyncMock makes `execute` awaitable; the awaited value is this result.
    handler.execute.return_value = SearchResult(
        query="test query",
        evidence=[single_hit],
        sources_searched=["pubmed"],
        total_found=1,
    )
    return handler
|
|
|
|
@pytest.mark.asyncio
async def test_run_executes_search(mock_handler: AsyncMock) -> None:
    """A plain-string query is searched, stored, and summarised in the reply."""
    evidence_store: dict[str, list[Evidence]] = {"current": []}
    search_agent = SearchAgent(mock_handler, evidence_store)

    result = await search_agent.run("test query")

    # The handler is awaited exactly once, with the raw query and a
    # per-tool result cap of 10.
    mock_handler.execute.assert_awaited_once_with("test query", max_results_per_tool=10)

    # The single canned hit from the fixture must land in the shared store.
    stored = evidence_store["current"]
    assert len(stored) == 1
    assert stored[0].content == "test content"

    # The first reply message comes from the assistant and reports the hit count.
    reply = result.messages[0]
    assert reply.role == Role.ASSISTANT
    assert "Found 1 sources" in reply.text
|
|
|
|
@pytest.mark.asyncio
async def test_run_handles_chat_message_input(mock_handler: AsyncMock) -> None:
    """A ``ChatMessage`` input is searched using its text, like a plain string.

    Besides checking the handler call, verify that the search result is
    stored. Otherwise the test would still pass if the message path skipped
    evidence handling entirely.
    """
    store: dict[str, list[Evidence]] = {"current": []}
    agent = SearchAgent(mock_handler, store)

    message = ChatMessage(role=Role.USER, text="test query")
    await agent.run(message)

    # The message text, not the message object, is what reaches the handler.
    mock_handler.execute.assert_awaited_once_with("test query", max_results_per_tool=10)
    # The ChatMessage path must populate the store just as string input does.
    assert len(store["current"]) == 1
    assert store["current"][0].content == "test content"
|
|
|
|
@pytest.mark.asyncio
async def test_run_handles_list_input(mock_handler: AsyncMock) -> None:
    """A list of messages is searched using the USER message's text.

    The SYSTEM message (``"sys"``) must not be used as the query. The search
    result must also reach the evidence store, just as it does for string input.
    """
    store: dict[str, list[Evidence]] = {"current": []}
    agent = SearchAgent(mock_handler, store)

    messages = [
        ChatMessage(role=Role.SYSTEM, text="sys"),
        ChatMessage(role=Role.USER, text="test query"),
    ]
    await agent.run(messages)

    # The query is "test query" (the user text), never the system prompt "sys".
    mock_handler.execute.assert_awaited_once_with("test query", max_results_per_tool=10)
    # Previously unchecked: the list-input path must populate the store too.
    assert len(store["current"]) == 1
    assert store["current"][0].content == "test content"
|
|
|
|
@pytest.mark.asyncio
async def test_run_uses_embeddings(mock_handler: AsyncMock) -> None:
    """When an embedding service is supplied, the agent uses it.

    The stubbed ``deduplicate`` returns one item (url ``u1``) and the stubbed
    ``search_similar`` returns one related hit (id ``u2``). Both calls must be
    awaited, and both URLs must appear in the evidence store.
    """
    evidence_store: dict[str, list[Evidence]] = {"current": []}

    embeddings = AsyncMock()
    # Stub: deduplicate resolves to a single Evidence item with url "u1".
    embeddings.deduplicate.return_value = [
        Evidence(
            content="unique content",
            citation=Citation(source="pubmed", url="u1", title="t1", date="2024"),
        )
    ]
    # Stub: search_similar resolves to one raw dict hit with id "u2".
    related_hit = {
        "id": "u2",
        "content": "related content",
        "metadata": {"source": "pubmed", "title": "related", "date": "2024"},
        "distance": 0.1,
    }
    embeddings.search_similar.return_value = [related_hit]

    search_agent = SearchAgent(mock_handler, evidence_store, embedding_service=embeddings)
    await search_agent.run("test query")

    embeddings.deduplicate.assert_awaited_once()
    embeddings.search_similar.assert_awaited_once_with("test query", n_results=5)

    # Both the deduplicated item and the related hit must be stored.
    stored_urls = {item.citation.url for item in evidence_store["current"]}
    assert "u1" in stored_urls
    assert "u2" in stored_urls
|
|