File size: 2,553 Bytes
15459e9
 
 
 
 
 
5a82171
 
 
ecbc47b
5a82171
ecbc47b
 
15459e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""Unit tests for SearchAgent."""

from unittest.mock import AsyncMock

import pytest

# Skip all tests if agent_framework not installed (optional dep)
pytest.importorskip("agent_framework")

from agent_framework import ChatMessage, Role

from src.agents.search_agent import SearchAgent
from src.utils.models import Citation, Evidence, SearchResult


@pytest.fixture
def mock_handler() -> AsyncMock:
    """Provide an async search handler stub that returns one canned PubMed hit."""
    citation = Citation(
        source="pubmed",
        title="Test Title",
        url="http://test.com",
        date="2023",
        authors=["Author A"],
    )
    single_hit = Evidence(content="test content", citation=citation, relevance=1.0)
    canned_result = SearchResult(
        query="test query",
        evidence=[single_hit],
        sources_searched=["pubmed"],
        total_found=1,
    )

    stub = AsyncMock()
    # Awaiting stub.execute(...) yields the canned result above.
    stub.execute.return_value = canned_result
    return stub


@pytest.mark.asyncio
async def test_run_executes_search(mock_handler: AsyncMock) -> None:
    """Running the agent on a plain string searches once and records the evidence."""
    evidence_store: dict = {"current": []}
    search_agent = SearchAgent(mock_handler, evidence_store)

    result = await search_agent.run("test query")

    # The handler receives the raw query together with a per-tool cap of 10.
    mock_handler.execute.assert_awaited_once_with("test query", max_results_per_tool=10)

    # The fixture's single evidence item ends up under the "current" key.
    stored = evidence_store["current"]
    assert len(stored) == 1
    assert stored[0].content == "test content"

    # The first reply message comes from the assistant and reports the hit count.
    first_message = result.messages[0]
    assert first_message.role == Role.ASSISTANT
    assert "Found 1 sources" in first_message.text


@pytest.mark.asyncio
async def test_run_handles_chat_message_input(mock_handler: AsyncMock) -> None:
    """A single ChatMessage input: its text is what reaches the search handler."""
    agent = SearchAgent(mock_handler, {"current": []})

    await agent.run(ChatMessage(role=Role.USER, text="test query"))

    mock_handler.execute.assert_awaited_once_with("test query", max_results_per_tool=10)


@pytest.mark.asyncio
async def test_run_handles_list_input(mock_handler: AsyncMock) -> None:
    """A message list input: the user's text, not the system prompt, becomes the query."""
    conversation = [
        ChatMessage(role=Role.SYSTEM, text="sys"),
        ChatMessage(role=Role.USER, text="test query"),
    ]
    agent = SearchAgent(mock_handler, {"current": []})

    await agent.run(conversation)

    mock_handler.execute.assert_awaited_once_with("test query", max_results_per_tool=10)