# Source path: guide/tests/agent/test_agent.py
# Repository listing metadata: Saravanakumar R, "openspec cma orchestration tests",
# commit 2feb381, file size 8.56 kB
"""
Unit tests for GUIDEAgent in src/agent/agent.py.
The Anthropic client is replaced with a MagicMock so no API calls are made.
wrap_anthropic is patched as identity to skip LangSmith wiring.
Run: pytest tests/agent/test_agent.py
"""
import json
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from src.agent.agent import GUIDEAgent, _FALLBACK_REPLY, _MAX_TOOL_ROUNDS
# ---------------------------------------------------------------------------
# Helpers — build mock content blocks
# ---------------------------------------------------------------------------
def _make_text_block(text: str):
"""Return a SimpleNamespace that looks like an Anthropic TextBlock."""
return SimpleNamespace(type="text", text=text)
def _make_tool_use_block(name: str, tool_id: str, input_data: dict):
"""Return a SimpleNamespace that looks like an Anthropic ToolUseBlock."""
return SimpleNamespace(type="tool_use", name=name, id=tool_id, input=input_data)
def _make_response(blocks, stop_reason: str):
"""Return a mock Message with .content and .stop_reason."""
msg = MagicMock()
msg.content = blocks
msg.stop_reason = stop_reason
return msg
# ---------------------------------------------------------------------------
# Fixture — GUIDEAgent with mocked Anthropic client
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_client():
    """Provide a fresh MagicMock standing in for the Anthropic client."""
    return MagicMock()
@pytest.fixture
def agent(mock_client):
    """Build a GUIDEAgent whose Anthropic client is ``mock_client``.

    ``wrap_anthropic`` is patched to return its argument unchanged and
    ``anthropic.Anthropic`` is patched to return ``mock_client``. Both
    patches are active only while the agent is being constructed.
    """
    with patch("src.agent.agent.wrap_anthropic", side_effect=lambda x: x), \
            patch("src.agent.agent.anthropic.Anthropic", return_value=mock_client):
        return GUIDEAgent(session_id="test-session")
def _configure_stream(mock_client, responses):
"""
Configure mock_client.messages.stream to yield responses in order.
Each call to __enter__ returns a fresh stream mock that returns
the next response from the list via get_final_message().
"""
stream_mocks = []
for resp in responses:
stream = MagicMock()
stream.text_stream = iter([])
stream.get_final_message.return_value = resp
ctx = MagicMock()
ctx.__enter__ = MagicMock(return_value=stream)
ctx.__exit__ = MagicMock(return_value=False)
stream_mocks.append(ctx)
mock_client.messages.stream.side_effect = stream_mocks
# ---------------------------------------------------------------------------
# Task 4.2 — send_message end_turn
# ---------------------------------------------------------------------------
def test_send_message_end_turn_returns_text(agent, mock_client):
    """A single end_turn response makes send_message return its text."""
    response = _make_response([_make_text_block("Hello there")], stop_reason="end_turn")
    _configure_stream(mock_client, [response])

    result = agent.send_message("hi")

    assert result == "Hello there"
def test_send_message_appends_to_history(agent, mock_client):
    """After one exchange, history holds both a user and an assistant turn."""
    response = _make_response([_make_text_block("Hi")], stop_reason="end_turn")
    _configure_stream(mock_client, [response])

    agent.send_message("hello")

    history = agent.get_history()
    roles = [m["role"] for m in history]
    assert "user" in roles
    assert "assistant" in roles
# ---------------------------------------------------------------------------
# Task 4.3 — tool_use round followed by end_turn
# ---------------------------------------------------------------------------
def test_send_message_tool_use_then_end_turn(agent, mock_client):
    """A tool_use round followed by end_turn returns the final round's text."""
    tool_block = _make_tool_use_block("classify_domain", "tu_001", {"complaint_text": "Flipkart refund"})
    round1 = _make_response([tool_block], stop_reason="tool_use")
    round2 = _make_response([_make_text_block("Domain classified.")], stop_reason="end_turn")
    _configure_stream(mock_client, [round1, round2])

    # Stub tool execution so no real tool runs during the tool_use round.
    with patch("src.agent.tools.execute_tool", return_value={"domain": "ecommerce"}):
        result = agent.send_message("Flipkart hasn't refunded me")

    assert result == "Domain classified."
def test_tool_result_block_has_correct_tool_use_id(agent, mock_client):
    """The tool_result recorded in history echoes the tool_use block's id."""
    tool_block = _make_tool_use_block("classify_domain", "tu_abc123", {"complaint_text": "test"})
    round1 = _make_response([tool_block], stop_reason="tool_use")
    round2 = _make_response([_make_text_block("Done.")], stop_reason="end_turn")
    _configure_stream(mock_client, [round1, round2])

    with patch("src.agent.tools.execute_tool", return_value={"domain": "ecommerce"}):
        agent.send_message("test")

    history = agent.get_history()
    # The tool-result turn is the user turn whose content is a list of blocks.
    # A default of None lets a missing turn fail as an assertion, not a
    # StopIteration.
    tool_result_turn = next(
        (
            m for m in history
            if m["role"] == "user" and isinstance(m["content"], list)
        ),
        None,
    )
    assert tool_result_turn is not None, "no user turn with list content found in history"
    assert tool_result_turn["content"][0]["type"] == "tool_result"
    assert tool_result_turn["content"][0]["tool_use_id"] == "tu_abc123"
# ---------------------------------------------------------------------------
# Task 4.4 — add_document prefix and queue clearing
# ---------------------------------------------------------------------------
def test_add_document_prefix_prepended(agent, mock_client):
    """A document added before send_message shows up in the user turn."""
    response = _make_response([_make_text_block("ok")], stop_reason="end_turn")
    _configure_stream(mock_client, [response])

    agent.add_document("/tmp/bill.pdf")
    agent.send_message("check this")

    history = agent.get_history()
    user_turn = next(m for m in history if m["role"] == "user")
    assert "[Document uploaded: /tmp/bill.pdf]" in user_turn["content"]
def test_pending_documents_cleared_after_send(agent, mock_client):
    """A document marker is attached to one message only, not the next."""
    resp1 = _make_response([_make_text_block("ok")], stop_reason="end_turn")
    resp2 = _make_response([_make_text_block("ok2")], stop_reason="end_turn")
    _configure_stream(mock_client, [resp1, resp2])

    agent.add_document("/tmp/bill.pdf")
    agent.send_message("first message")
    agent.send_message("second message")

    history = agent.get_history()
    user_turns = [m for m in history if m["role"] == "user"]
    # Second user turn must NOT contain the document prefix
    assert "[Document uploaded:" not in user_turns[1]["content"]
def test_two_documents_both_appear_in_prefix(agent, mock_client):
    """Two documents added before one message both appear in that user turn."""
    response = _make_response([_make_text_block("ok")], stop_reason="end_turn")
    _configure_stream(mock_client, [response])

    agent.add_document("/tmp/doc1.pdf")
    agent.add_document("/tmp/doc2.png")
    agent.send_message("process both")

    history = agent.get_history()
    user_turn = next(m for m in history if m["role"] == "user")
    assert "[Document uploaded: /tmp/doc1.pdf]" in user_turn["content"]
    assert "[Document uploaded: /tmp/doc2.png]" in user_turn["content"]
# ---------------------------------------------------------------------------
# Task 4.5 — confirm_entities HITL injection
# ---------------------------------------------------------------------------
def test_confirm_entities_message_starts_with_user_confirmed(agent, mock_client):
    """confirm_entities records a user turn starting with "[USER CONFIRMED]:"."""
    response = _make_response([_make_text_block("Draft generated.")], stop_reason="end_turn")
    _configure_stream(mock_client, [response])

    agent.confirm_entities({"ORG": "Flipkart", "AMOUNT": "₹4,299"})

    history = agent.get_history()
    user_turn = next(m for m in history if m["role"] == "user")
    assert user_turn["content"].startswith("[USER CONFIRMED]:")
def test_confirm_entities_json_payload_matches_input(agent, mock_client):
    """The JSON after the "[USER CONFIRMED]:" marker round-trips to the input."""
    response = _make_response([_make_text_block("Draft generated.")], stop_reason="end_turn")
    _configure_stream(mock_client, [response])
    entities = {"ORG": "HDFC Bank", "AMOUNT": "₹5,000", "REF_ID": "TXN123"}

    agent.confirm_entities(entities)

    history = agent.get_history()
    user_turn = next(m for m in history if m["role"] == "user")
    content = user_turn["content"]
    marker = "[USER CONFIRMED]:"
    # Check the marker explicitly: removeprefix() silently returns the string
    # unchanged when the prefix is absent, which would surface as an opaque
    # JSONDecodeError instead of a clear assertion failure.
    assert content.startswith(marker)
    # json.loads ignores leading whitespace, so the separator after the
    # marker does not need to be matched exactly.
    assert json.loads(content[len(marker):]) == entities
# ---------------------------------------------------------------------------
# Task 4.6 — max tool rounds fallback
# ---------------------------------------------------------------------------
def test_max_tool_rounds_returns_fallback(agent, mock_client):
    """If every round asks for a tool, send_message returns _FALLBACK_REPLY."""
    tool_block = _make_tool_use_block("classify_domain", "tu_x", {"complaint_text": "test"})
    always_tool = _make_response([tool_block], stop_reason="tool_use")
    # Exactly _MAX_TOOL_ROUNDS tool_use responses are queued; none ends the turn.
    _configure_stream(mock_client, [always_tool] * _MAX_TOOL_ROUNDS)

    with patch("src.agent.tools.execute_tool", return_value={"domain": "ecommerce"}):
        result = agent.send_message("loop forever")

    assert result  # non-empty
    assert result == _FALLBACK_REPLY