Spaces:
Sleeping
Sleeping
| """ | |
| Unit tests for GUIDEAgent in src/agent/agent.py. | |
| The Anthropic client is replaced with a MagicMock so no API calls are made. | |
| wrap_anthropic is patched as identity to skip LangSmith wiring. | |
| Run: pytest tests/agent/test_agent.py | |
| """ | |
| import json | |
| from types import SimpleNamespace | |
| from unittest.mock import MagicMock, patch | |
| import pytest | |
| from src.agent.agent import GUIDEAgent, _FALLBACK_REPLY, _MAX_TOOL_ROUNDS | |
| # --------------------------------------------------------------------------- | |
| # Helpers — build mock content blocks | |
| # --------------------------------------------------------------------------- | |
| def _make_text_block(text: str): | |
| """Return a SimpleNamespace that looks like an Anthropic TextBlock.""" | |
| return SimpleNamespace(type="text", text=text) | |
| def _make_tool_use_block(name: str, tool_id: str, input_data: dict): | |
| """Return a SimpleNamespace that looks like an Anthropic ToolUseBlock.""" | |
| return SimpleNamespace(type="tool_use", name=name, id=tool_id, input=input_data) | |
| def _make_response(blocks, stop_reason: str): | |
| """Return a mock Message with .content and .stop_reason.""" | |
| msg = MagicMock() | |
| msg.content = blocks | |
| msg.stop_reason = stop_reason | |
| return msg | |
| # --------------------------------------------------------------------------- | |
| # Fixture — GUIDEAgent with mocked Anthropic client | |
| # --------------------------------------------------------------------------- | |
@pytest.fixture
def mock_client():
    """Provide a bare ``MagicMock`` that stands in for the Anthropic client.

    Registered as a pytest fixture so that tests and the ``agent`` fixture
    can request it by name. Without ``@pytest.fixture`` pytest reports
    "fixture 'mock_client' not found".
    """
    return MagicMock()


@pytest.fixture
def agent(mock_client):
    """Build a ``GUIDEAgent`` whose Anthropic client is ``mock_client``.

    Two patches are applied during construction:

    - ``wrap_anthropic`` becomes the identity function, so LangSmith
      wrapping is skipped.
    - ``anthropic.Anthropic`` returns ``mock_client``.

    The patches are exited as soon as ``return`` leaves the ``with`` block.
    NOTE(review): this presumes ``GUIDEAgent`` creates its client in
    ``__init__``. If it did so lazily, the patches would need to stay
    active (``yield`` instead of ``return``). Confirm against
    src/agent/agent.py.
    """
    with patch("src.agent.agent.wrap_anthropic", side_effect=lambda x: x), \
            patch("src.agent.agent.anthropic.Anthropic", return_value=mock_client):
        return GUIDEAgent(session_id="test-session")
| def _configure_stream(mock_client, responses): | |
| """ | |
| Configure mock_client.messages.stream to yield responses in order. | |
| Each call to __enter__ returns a fresh stream mock that returns | |
| the next response from the list via get_final_message(). | |
| """ | |
| stream_mocks = [] | |
| for resp in responses: | |
| stream = MagicMock() | |
| stream.text_stream = iter([]) | |
| stream.get_final_message.return_value = resp | |
| ctx = MagicMock() | |
| ctx.__enter__ = MagicMock(return_value=stream) | |
| ctx.__exit__ = MagicMock(return_value=False) | |
| stream_mocks.append(ctx) | |
| mock_client.messages.stream.side_effect = stream_mocks | |
| # --------------------------------------------------------------------------- | |
| # Task 4.2 — send_message end_turn | |
| # --------------------------------------------------------------------------- | |
def test_send_message_end_turn_returns_text(agent, mock_client):
    """A single end_turn response is returned verbatim as the reply text."""
    _configure_stream(
        mock_client,
        [_make_response([_make_text_block("Hello there")], stop_reason="end_turn")],
    )
    assert agent.send_message("hi") == "Hello there"
def test_send_message_appends_to_history(agent, mock_client):
    """After one exchange the history holds both a user and an assistant turn."""
    reply = _make_response([_make_text_block("Hi")], stop_reason="end_turn")
    _configure_stream(mock_client, [reply])
    agent.send_message("hello")
    seen_roles = {turn["role"] for turn in agent.get_history()}
    assert "user" in seen_roles
    assert "assistant" in seen_roles
| # --------------------------------------------------------------------------- | |
| # Task 4.3 — tool_use round followed by end_turn | |
| # --------------------------------------------------------------------------- | |
def test_send_message_tool_use_then_end_turn(agent, mock_client):
    """A tool_use round is followed by a second model call whose text is returned."""
    first = _make_response(
        [_make_tool_use_block("classify_domain", "tu_001", {"complaint_text": "Flipkart refund"})],
        stop_reason="tool_use",
    )
    second = _make_response([_make_text_block("Domain classified.")], stop_reason="end_turn")
    _configure_stream(mock_client, [first, second])
    # NOTE(review): this patch only takes effect if agent.py looks up
    # execute_tool through the src.agent.tools module at call time (not via
    # `from src.agent.tools import execute_tool`). Confirm.
    with patch("src.agent.tools.execute_tool", return_value={"domain": "ecommerce"}):
        reply = agent.send_message("Flipkart hasn't refunded me")
    assert reply == "Domain classified."
def test_tool_result_block_has_correct_tool_use_id(agent, mock_client):
    """The tool_result sent back to the model echoes the originating tool_use id."""
    first = _make_response(
        [_make_tool_use_block("classify_domain", "tu_abc123", {"complaint_text": "test"})],
        stop_reason="tool_use",
    )
    second = _make_response([_make_text_block("Done.")], stop_reason="end_turn")
    _configure_stream(mock_client, [first, second])
    with patch("src.agent.tools.execute_tool", return_value={"domain": "ecommerce"}):
        agent.send_message("test")
    # Tool results travel as a user turn whose content is a list of blocks.
    tool_result_turn = next(
        turn
        for turn in agent.get_history()
        if turn["role"] == "user" and isinstance(turn["content"], list)
    )
    leading_block = tool_result_turn["content"][0]
    assert leading_block["type"] == "tool_result"
    assert leading_block["tool_use_id"] == "tu_abc123"
| # --------------------------------------------------------------------------- | |
| # Task 4.4 — add_document prefix and queue clearing | |
| # --------------------------------------------------------------------------- | |
def test_add_document_prefix_prepended(agent, mock_client):
    """A queued document path appears as a marker in the next user turn."""
    _configure_stream(
        mock_client,
        [_make_response([_make_text_block("ok")], stop_reason="end_turn")],
    )
    agent.add_document("/tmp/bill.pdf")
    agent.send_message("check this")
    first_user_turn = next(turn for turn in agent.get_history() if turn["role"] == "user")
    assert "[Document uploaded: /tmp/bill.pdf]" in first_user_turn["content"]
def test_pending_documents_cleared_after_send(agent, mock_client):
    """The document queue is consumed by one send and not repeated on the next.

    The first user turn is checked for the marker too. Otherwise the test
    would pass vacuously if ``add_document`` never injected anything.
    """
    resp1 = _make_response([_make_text_block("ok")], stop_reason="end_turn")
    resp2 = _make_response([_make_text_block("ok2")], stop_reason="end_turn")
    _configure_stream(mock_client, [resp1, resp2])
    agent.add_document("/tmp/bill.pdf")
    agent.send_message("first message")
    agent.send_message("second message")
    user_turns = [m for m in agent.get_history() if m["role"] == "user"]
    # Precondition: the document was actually attached to the first turn.
    assert "[Document uploaded: /tmp/bill.pdf]" in user_turns[0]["content"]
    # Second user turn must NOT contain the document prefix.
    assert "[Document uploaded:" not in user_turns[1]["content"]
def test_two_documents_both_appear_in_prefix(agent, mock_client):
    """Every queued document gets its own marker in the same user turn."""
    _configure_stream(
        mock_client,
        [_make_response([_make_text_block("ok")], stop_reason="end_turn")],
    )
    for path in ("/tmp/doc1.pdf", "/tmp/doc2.png"):
        agent.add_document(path)
    agent.send_message("process both")
    first_user_turn = next(turn for turn in agent.get_history() if turn["role"] == "user")
    for marker in ("[Document uploaded: /tmp/doc1.pdf]", "[Document uploaded: /tmp/doc2.png]"):
        assert marker in first_user_turn["content"]
| # --------------------------------------------------------------------------- | |
| # Task 4.5 — confirm_entities HITL injection | |
| # --------------------------------------------------------------------------- | |
def test_confirm_entities_message_starts_with_user_confirmed(agent, mock_client):
    """confirm_entities injects a user turn tagged with the HITL marker."""
    _configure_stream(
        mock_client,
        [_make_response([_make_text_block("Draft generated.")], stop_reason="end_turn")],
    )
    agent.confirm_entities({"ORG": "Flipkart", "AMOUNT": "₹4,299"})
    first_user_turn = next(turn for turn in agent.get_history() if turn["role"] == "user")
    assert first_user_turn["content"].startswith("[USER CONFIRMED]:")
def test_confirm_entities_json_payload_matches_input(agent, mock_client):
    """The text after the HITL marker is JSON that round-trips to the input dict."""
    _configure_stream(
        mock_client,
        [_make_response([_make_text_block("Draft generated.")], stop_reason="end_turn")],
    )
    entities = {"ORG": "HDFC Bank", "AMOUNT": "₹5,000", "REF_ID": "TXN123"}
    agent.confirm_entities(entities)
    content = next(turn for turn in agent.get_history() if turn["role"] == "user")["content"]
    payload = content.removeprefix("[USER CONFIRMED]: ")
    assert json.loads(payload) == entities
| # --------------------------------------------------------------------------- | |
| # Task 4.6 — max tool rounds fallback | |
| # --------------------------------------------------------------------------- | |
def test_max_tool_rounds_returns_fallback(agent, mock_client):
    """A model that never stops requesting tools ends with the fallback reply."""
    looping = _make_response(
        [_make_tool_use_block("classify_domain", "tu_x", {"complaint_text": "test"})],
        stop_reason="tool_use",
    )
    # Exactly one scripted response per permitted round. An extra model call
    # would exhaust the side_effect list and raise StopIteration.
    _configure_stream(mock_client, [looping for _ in range(_MAX_TOOL_ROUNDS)])
    with patch("src.agent.tools.execute_tool", return_value={"domain": "ecommerce"}):
        reply = agent.send_message("loop forever")
    assert reply  # non-empty
    assert reply == _FALLBACK_REPLY