# guide/tests/agent/test_tools.py
# Last commit: "openspec cma orchestration tests" (2feb381) — Saravanakumar R
"""
Unit tests for execute_tool() dispatcher in src/agent/tools.py.
All DL singletons are patched at their source module — no checkpoints loaded.
No Anthropic API calls needed.
Run: pytest tests/agent/test_tools.py
"""
from unittest.mock import MagicMock, patch
import pytest
from src.agent.memory import SessionMemory
from src.agent.tools import execute_tool
from src.classifier.model import DomainResult
from src.ner.model import Entity
from src.next_action.model import EscalationAction
@pytest.fixture
def memory():
    """Provide a newly constructed ``SessionMemory`` to each test that requests it."""
    session = SessionMemory()
    return session
# ---------------------------------------------------------------------------
# draft_complaint — pure Python, no external calls
# ---------------------------------------------------------------------------
def test_draft_complaint_minimal_context(memory):
    """draft_complaint returns status "proceed" and echoes the supplied context."""
    complaint_context = {
        "domain": "ecommerce",
        "provider": "Flipkart",
        "incident_date": "2024-03-12",
        "desired_resolution": "full refund",
    }
    tool_input = {"complaint_context": complaint_context}

    outcome = execute_tool("draft_complaint", tool_input, memory)

    assert outcome["status"] == "proceed"
    assert outcome["complaint_context"] == complaint_context
def test_draft_complaint_missing_key_defaults_to_empty(memory):
    """With no ``complaint_context`` in the input, the result carries an empty one."""
    empty_input = {}

    outcome = execute_tool("draft_complaint", empty_input, memory)

    assert outcome["status"] == "proceed"
    assert outcome["complaint_context"] == {}
# ---------------------------------------------------------------------------
# store_memory / get_memory
# ---------------------------------------------------------------------------
def test_store_memory_writes_to_session(memory):
    """store_memory acknowledges the write and the value is readable from the session."""
    payload = {"key": "domain", "value": "telecom"}

    acknowledgement = execute_tool("store_memory", payload, memory)

    expected_ack = {"status": "stored", "key": "domain"}
    assert acknowledgement == expected_ack
    assert memory.get("domain") == "telecom"
def test_get_memory_reads_from_session(memory):
    """get_memory returns the key together with the value previously set on the session."""
    memory.set("domain", "telecom")
    query = {"key": "domain"}

    lookup = execute_tool("get_memory", query, memory)

    expected = {"key": "domain", "value": "telecom"}
    assert lookup == expected
def test_get_memory_missing_key_returns_none(memory):
    """Looking up a key that was never stored yields a ``None`` value."""
    query = {"key": "absent"}

    lookup = execute_tool("get_memory", query, memory)

    expected = {"key": "absent", "value": None}
    assert lookup == expected
def test_store_get_round_trip(memory):
    """A value written via store_memory comes back unchanged via get_memory."""
    write_request = {"key": "prior_contact", "value": True}
    read_request = {"key": "prior_contact"}

    execute_tool("store_memory", write_request, memory)
    fetched = execute_tool("get_memory", read_request, memory)

    # Identity check: the stored boolean must come back as the real True.
    assert fetched["value"] is True
# ---------------------------------------------------------------------------
# Unknown tool name
# ---------------------------------------------------------------------------
def test_unknown_tool_returns_error(memory):
    """An unrecognised tool name produces an "error" entry that mentions the name."""
    outcome = execute_tool("nonexistent_tool", {}, memory)

    assert "error" in outcome
    error_text = outcome["error"]
    assert "nonexistent_tool" in error_text
# ---------------------------------------------------------------------------
# classify_domain — patch src.classifier.predict.classify
# ---------------------------------------------------------------------------
def test_classify_domain_high_confidence(memory):
    """classify_domain exposes domain, confidence and low_confidence from the classifier result."""
    probabilities = {
        "ecommerce": 0.92,
        "telecom": 0.03,
        "banking": 0.02,
        "cibil": 0.01,
        "insurance": 0.01,
        "general": 0.01,
    }
    stub_result = DomainResult(
        domain="ecommerce",
        confidence=0.92,
        all_probs=probabilities,
        low_confidence=False,
    )
    tool_input = {"complaint_text": "Flipkart refund"}

    # The classifier is patched, so the stub above is what the dispatcher sees.
    with patch("src.classifier.predict.classify", return_value=stub_result):
        outcome = execute_tool("classify_domain", tool_input, memory)

    assert outcome["domain"] == "ecommerce"
    assert outcome["confidence"] == 0.92
    assert outcome["low_confidence"] is False
def test_classify_domain_low_confidence_propagated(memory):
    """A low_confidence=True classifier result shows up as low_confidence True in the output."""
    probabilities = {
        "ecommerce": 0.2,
        "telecom": 0.1,
        "banking": 0.1,
        "cibil": 0.1,
        "insurance": 0.1,
        "general": 0.4,
    }
    stub_result = DomainResult(
        domain="general",
        confidence=0.3,
        all_probs=probabilities,
        low_confidence=True,
    )
    tool_input = {"complaint_text": "I have a complaint"}

    with patch("src.classifier.predict.classify", return_value=stub_result):
        outcome = execute_tool("classify_domain", tool_input, memory)

    assert outcome["low_confidence"] is True
def test_classify_domain_exception_returns_error(memory):
    """An exception raised by the patched classifier is reported under "error" with its message."""
    failure = RuntimeError("model error")

    with patch("src.classifier.predict.classify", side_effect=failure):
        outcome = execute_tool("classify_domain", {"complaint_text": "test"}, memory)

    assert "error" in outcome
    error_text = outcome["error"]
    assert "model error" in error_text
# ---------------------------------------------------------------------------
# extract_entities — patch src.ner.predict.extract_entities
# ---------------------------------------------------------------------------
def test_extract_entities_returns_list_of_dicts(memory):
    """extract_entities returns a list whose items expose each Entity field by key."""
    expected_fields = {
        "text": "Flipkart",
        "label": "ORG",
        "start": 0,
        "end": 8,
        "confidence": 0.91,
    }
    stub_entities = [Entity(**expected_fields)]

    with patch("src.ner.predict.extract_entities", return_value=stub_entities):
        outcome = execute_tool("extract_entities", {"text": "Flipkart refund"}, memory)

    assert isinstance(outcome, list)
    assert len(outcome) == 1
    first = outcome[0]
    # Checked in the same order as the fields are declared above.
    for field, value in expected_fields.items():
        assert first[field] == value
def test_extract_entities_empty_list(memory):
    """When the patched extractor finds nothing, the tool result is an empty list."""
    no_entities = []

    with patch("src.ner.predict.extract_entities", return_value=no_entities):
        outcome = execute_tool("extract_entities", {"text": "plain text"}, memory)

    assert outcome == []
def test_extract_entities_exception_returns_error(memory):
    """An exception raised by the patched extractor is reported under "error" with its message."""
    failure = Exception("ner failed")

    with patch("src.ner.predict.extract_entities", side_effect=failure):
        outcome = execute_tool("extract_entities", {"text": "test"}, memory)

    assert "error" in outcome
    error_text = outcome["error"]
    assert "ner failed" in error_text
# ---------------------------------------------------------------------------
# recommend_action — patch src.next_action.predict.recommend_action
# ---------------------------------------------------------------------------
def _mock_actions():
    """Build the two canned ``EscalationAction`` objects used as the patched predictor's return value."""
    # (action, authority, url, confidence) for each canned recommendation.
    specs = (
        ("company_support", "Flipkart Grievance", "https://flipkart.com", 0.5),
        ("nch", "NCH", "https://consumerhelpline.gov.in", 0.3),
    )
    return [
        EscalationAction(action=action, authority=authority, url=url, confidence=confidence)
        for action, authority, url, confidence in specs
    ]
def test_recommend_action_prior_contact_defaults_to_false(memory):
    """When the tool input omits ``prior_contact``, the predictor is called with False.

    The value is looked up first as a keyword argument, then as the third
    positional argument of the patched ``recommend_action`` call.
    """
    with patch("src.next_action.predict.recommend_action", return_value=_mock_actions()) as mock_fn:
        execute_tool("recommend_action", {"domain": "ecommerce"}, memory)

    # Without this check an uncalled mock has call_args None, and the unpack
    # below would die with an opaque TypeError instead of a readable failure.
    mock_fn.assert_called()
    args, kwargs = mock_fn.call_args

    if "prior_contact" in kwargs:
        prior_contact = kwargs["prior_contact"]
    elif len(args) > 2:
        prior_contact = args[2]
    else:
        # Argument not passed at all: treated as False, as in the original assertion.
        prior_contact = False

    assert prior_contact is False
def test_recommend_action_entities_defaults_to_empty(memory):
    """When the tool input omits ``entities``, the predictor is called with an empty dict.

    The value is looked up first as a keyword argument, then as the second
    positional argument of the patched ``recommend_action`` call.
    """
    with patch("src.next_action.predict.recommend_action", return_value=_mock_actions()) as mock_fn:
        execute_tool("recommend_action", {"domain": "ecommerce"}, memory)

    # Without this check an uncalled mock has call_args None, and the unpack
    # below would die with an opaque TypeError instead of a readable failure.
    mock_fn.assert_called()
    args, kwargs = mock_fn.call_args

    if "entities" in kwargs:
        entities_val = kwargs["entities"]
    elif len(args) > 1:
        entities_val = args[1]
    else:
        # Argument not passed at all: treated as {}, as in the original assertion.
        entities_val = {}

    assert entities_val == {}
def test_recommend_action_returns_list_of_dicts(memory):
    """recommend_action returns a list whose first item exposes the first action by key."""
    tool_input = {"domain": "ecommerce", "entities": {}, "prior_contact": False}
    canned = _mock_actions()

    with patch("src.next_action.predict.recommend_action", return_value=canned):
        outcome = execute_tool("recommend_action", tool_input, memory)

    assert isinstance(outcome, list)
    top = outcome[0]
    assert top["action"] == "company_support"
def test_recommend_action_exception_returns_error(memory):
    """An exception raised by the patched predictor is reported under "error" with its message."""
    failure = RuntimeError("predictor failed")

    with patch("src.next_action.predict.recommend_action", side_effect=failure):
        outcome = execute_tool("recommend_action", {"domain": "ecommerce"}, memory)

    assert "error" in outcome
    error_text = outcome["error"]
    assert "predictor failed" in error_text
# ---------------------------------------------------------------------------
# process_document — patch src.document_processor.processor.get_processor
# ---------------------------------------------------------------------------
def test_process_document_successful_result(memory):
    """process_document returns the processor's raw text and its entities as a list of keyed items."""
    invoice_entity = Entity(text="Invoice", label="REF_ID", start=0, end=7, confidence=0.85)
    processed = {
        "raw_text": "Invoice #12345 for ₹4,299",
        "entities": [invoice_entity],
    }
    stub_processor = MagicMock()
    stub_processor.process.return_value = processed

    # get_processor is patched to hand back the stub instead of a real processor.
    with patch("src.document_processor.processor.get_processor", return_value=stub_processor):
        outcome = execute_tool("process_document", {"file_path": "/tmp/invoice.pdf"}, memory)

    assert outcome["raw_text"] == "Invoice #12345 for ₹4,299"
    entities = outcome["entities"]
    assert isinstance(entities, list)
    assert entities[0]["text"] == "Invoice"
def test_process_document_unsupported_extension_returns_error(memory):
    """A ValueError from the processor is reported under "error" with its message."""
    stub_processor = MagicMock()
    stub_processor.process.side_effect = ValueError("Unsupported file extension '.xyz'")
    tool_input = {"file_path": "/tmp/file.xyz"}

    with patch("src.document_processor.processor.get_processor", return_value=stub_processor):
        outcome = execute_tool("process_document", tool_input, memory)

    assert "error" in outcome
    error_text = outcome["error"]
    assert "Unsupported file extension" in error_text
def test_process_document_runtime_error_returns_error(memory):
    """A RuntimeError from the processor is reported under "error" with its message."""
    stub_processor = MagicMock()
    stub_processor.process.side_effect = RuntimeError("tesseract not found")
    tool_input = {"file_path": "/tmp/scan.png"}

    with patch("src.document_processor.processor.get_processor", return_value=stub_processor):
        outcome = execute_tool("process_document", tool_input, memory)

    assert "error" in outcome
    error_text = outcome["error"]
    assert "tesseract not found" in error_text