"""
Unit tests for execute_tool() dispatcher in src/agent/tools.py.

All DL singletons are patched at their source module — no checkpoints loaded.
No Anthropic API calls needed.

Run: pytest tests/agent/test_tools.py
"""

import inspect
from unittest.mock import MagicMock, patch

import pytest

from src.agent.memory import SessionMemory
from src.agent.tools import execute_tool
from src.classifier.model import DomainResult
from src.ner.model import Entity
from src.next_action.model import EscalationAction


@pytest.fixture
def memory():
    """Provide a fresh SessionMemory for each test (function-scoped fixture)."""
    return SessionMemory()


# ---------------------------------------------------------------------------
# draft_complaint — pure Python, no external calls
# ---------------------------------------------------------------------------


def test_draft_complaint_minimal_context(memory):
    """The supplied complaint_context is echoed back with status 'proceed'."""
    ctx = {
        "domain": "ecommerce",
        "provider": "Flipkart",
        "incident_date": "2024-03-12",
        "desired_resolution": "full refund",
    }
    result = execute_tool("draft_complaint", {"complaint_context": ctx}, memory)
    assert result["status"] == "proceed"
    assert result["complaint_context"] == ctx


def test_draft_complaint_missing_key_defaults_to_empty(memory):
    """Omitting complaint_context yields an empty context, not an error."""
    result = execute_tool("draft_complaint", {}, memory)
    assert result["status"] == "proceed"
    assert result["complaint_context"] == {}


# ---------------------------------------------------------------------------
# store_memory / get_memory
# ---------------------------------------------------------------------------


def test_store_memory_writes_to_session(memory):
    """store_memory acknowledges the key and persists the value in the session."""
    result = execute_tool("store_memory", {"key": "domain", "value": "telecom"}, memory)
    assert result == {"status": "stored", "key": "domain"}
    assert memory.get("domain") == "telecom"


def test_get_memory_reads_from_session(memory):
    """get_memory returns a value previously written directly to the session."""
    memory.set("domain", "telecom")
    result = execute_tool("get_memory", {"key": "domain"}, memory)
    assert result == {"key": "domain", "value": "telecom"}


def test_get_memory_missing_key_returns_none(memory):
    """An unknown key is reported with value None rather than raising."""
    result = execute_tool("get_memory", {"key": "absent"}, memory)
    assert result == {"key": "absent", "value": None}


def test_store_get_round_trip(memory):
    """A non-string value stored via the tool is read back unchanged via the tool."""
    execute_tool("store_memory", {"key": "prior_contact", "value": True}, memory)
    result = execute_tool("get_memory", {"key": "prior_contact"}, memory)
    assert result["value"] is True


# ---------------------------------------------------------------------------
# Unknown tool name
# ---------------------------------------------------------------------------


def test_unknown_tool_returns_error(memory):
    """An unrecognised tool name produces an error dict naming the tool."""
    result = execute_tool("nonexistent_tool", {}, memory)
    assert "error" in result
    assert "nonexistent_tool" in result["error"]


# ---------------------------------------------------------------------------
# classify_domain — patch src.classifier.predict.classify
# ---------------------------------------------------------------------------


def test_classify_domain_high_confidence(memory):
    """Domain, confidence and low_confidence are taken from the classifier result."""
    mock_result = DomainResult(
        domain="ecommerce",
        confidence=0.92,
        all_probs={
            "ecommerce": 0.92,
            "telecom": 0.03,
            "banking": 0.02,
            "cibil": 0.01,
            "insurance": 0.01,
            "general": 0.01,
        },
        low_confidence=False,
    )
    with patch("src.classifier.predict.classify", return_value=mock_result):
        result = execute_tool("classify_domain", {"complaint_text": "Flipkart refund"}, memory)
    assert result["domain"] == "ecommerce"
    assert result["confidence"] == 0.92
    assert result["low_confidence"] is False


def test_classify_domain_low_confidence_propagated(memory):
    """A low-confidence classifier result is surfaced to the caller."""
    mock_result = DomainResult(
        domain="general",
        confidence=0.3,
        all_probs={
            "ecommerce": 0.2,
            "telecom": 0.1,
            "banking": 0.1,
            "cibil": 0.1,
            "insurance": 0.1,
            "general": 0.4,
        },
        low_confidence=True,
    )
    with patch("src.classifier.predict.classify", return_value=mock_result):
        result = execute_tool(
            "classify_domain", {"complaint_text": "I have a complaint"}, memory
        )
    assert result["low_confidence"] is True


def test_classify_domain_exception_returns_error(memory):
    """A classifier exception is converted into an error dict carrying its message."""
    with patch("src.classifier.predict.classify", side_effect=RuntimeError("model error")):
        result = execute_tool("classify_domain", {"complaint_text": "test"}, memory)
    assert "error" in result
    assert "model error" in result["error"]


# ---------------------------------------------------------------------------
# extract_entities — patch src.ner.predict.extract_entities
# ---------------------------------------------------------------------------


def test_extract_entities_returns_list_of_dicts(memory):
    """Entity objects are serialised to dicts with all fields preserved."""
    mock_entities = [Entity(text="Flipkart", label="ORG", start=0, end=8, confidence=0.91)]
    with patch("src.ner.predict.extract_entities", return_value=mock_entities):
        result = execute_tool("extract_entities", {"text": "Flipkart refund"}, memory)
    assert isinstance(result, list)
    assert len(result) == 1
    assert result[0]["text"] == "Flipkart"
    assert result[0]["label"] == "ORG"
    assert result[0]["start"] == 0
    assert result[0]["end"] == 8
    assert result[0]["confidence"] == 0.91


def test_extract_entities_empty_list(memory):
    """No entities from the NER model yields an empty list."""
    with patch("src.ner.predict.extract_entities", return_value=[]):
        result = execute_tool("extract_entities", {"text": "plain text"}, memory)
    assert result == []


def test_extract_entities_exception_returns_error(memory):
    """An NER exception is converted into an error dict carrying its message."""
    with patch("src.ner.predict.extract_entities", side_effect=Exception("ner failed")):
        result = execute_tool("extract_entities", {"text": "test"}, memory)
    assert "error" in result
    assert "ner failed" in result["error"]


# ---------------------------------------------------------------------------
# recommend_action — patch src.next_action.predict.recommend_action
# ---------------------------------------------------------------------------


def _mock_actions():
    """Return two canned escalation steps, ordered as the tests expect."""
    return [
        EscalationAction(
            action="company_support",
            authority="Flipkart Grievance",
            url="https://flipkart.com",
            confidence=0.5,
        ),
        EscalationAction(
            action="nch",
            authority="NCH",
            url="https://consumerhelpline.gov.in",
            confidence=0.3,
        ),
    ]


def _bound_recommend_action_args(call):
    """Map a recorded recommend_action call onto the real function's parameters.

    Binding against the real signature resolves each argument the same way
    Python would, whether the dispatcher passed it positionally or by keyword.
    Defaults are filled in only from the real function's declared defaults,
    never from a hard-coded fallback. A fallback equal to the expected value
    would let the assertion pass even if the argument never reached the
    predictor.

    Call this only after the ``patch`` context has exited. Inside the context
    the module attribute is the mock, not the real function.
    """
    # Imported lazily, like patch() itself, so collection never depends on it.
    from src.next_action import predict as next_action_predict

    args, kwargs = call
    signature = inspect.signature(next_action_predict.recommend_action)
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()
    return bound.arguments


def test_recommend_action_prior_contact_defaults_to_false(memory):
    """Omitting prior_contact from the tool input results in prior_contact=False."""
    with patch(
        "src.next_action.predict.recommend_action", return_value=_mock_actions()
    ) as mock_fn:
        execute_tool("recommend_action", {"domain": "ecommerce"}, memory)
    # Fail with a clear message if the predictor was never reached.
    mock_fn.assert_called_once()
    assert _bound_recommend_action_args(mock_fn.call_args)["prior_contact"] is False


def test_recommend_action_entities_defaults_to_empty(memory):
    """Omitting entities from the tool input results in entities == {}."""
    with patch(
        "src.next_action.predict.recommend_action", return_value=_mock_actions()
    ) as mock_fn:
        execute_tool("recommend_action", {"domain": "ecommerce"}, memory)
    mock_fn.assert_called_once()
    assert _bound_recommend_action_args(mock_fn.call_args)["entities"] == {}


def test_recommend_action_returns_list_of_dicts(memory):
    """Escalation actions are serialised to dicts, preserving predictor order."""
    with patch("src.next_action.predict.recommend_action", return_value=_mock_actions()):
        result = execute_tool(
            "recommend_action",
            {"domain": "ecommerce", "entities": {}, "prior_contact": False},
            memory,
        )
    assert isinstance(result, list)
    assert result[0]["action"] == "company_support"


def test_recommend_action_exception_returns_error(memory):
    """A predictor exception is converted into an error dict carrying its message."""
    with patch(
        "src.next_action.predict.recommend_action",
        side_effect=RuntimeError("predictor failed"),
    ):
        result = execute_tool("recommend_action", {"domain": "ecommerce"}, memory)
    assert "error" in result
    assert "predictor failed" in result["error"]


# ---------------------------------------------------------------------------
# process_document — patch src.document_processor.processor.get_processor
# ---------------------------------------------------------------------------


def test_process_document_successful_result(memory):
    """Raw text is passed through and Entity objects are serialised to dicts."""
    mock_entity = Entity(text="Invoice", label="REF_ID", start=0, end=7, confidence=0.85)
    mock_processor = MagicMock()
    mock_processor.process.return_value = {
        "raw_text": "Invoice #12345 for ₹4,299",
        "entities": [mock_entity],
    }
    with patch(
        "src.document_processor.processor.get_processor", return_value=mock_processor
    ):
        result = execute_tool("process_document", {"file_path": "/tmp/invoice.pdf"}, memory)
    assert result["raw_text"] == "Invoice #12345 for ₹4,299"
    assert isinstance(result["entities"], list)
    assert result["entities"][0]["text"] == "Invoice"


def test_process_document_unsupported_extension_returns_error(memory):
    """A ValueError from the processor becomes an error dict carrying its message."""
    mock_processor = MagicMock()
    mock_processor.process.side_effect = ValueError("Unsupported file extension '.xyz'")
    with patch(
        "src.document_processor.processor.get_processor", return_value=mock_processor
    ):
        result = execute_tool("process_document", {"file_path": "/tmp/file.xyz"}, memory)
    assert "error" in result
    assert "Unsupported file extension" in result["error"]


def test_process_document_runtime_error_returns_error(memory):
    """A RuntimeError (e.g. missing OCR binary) becomes an error dict with its message."""
    mock_processor = MagicMock()
    mock_processor.process.side_effect = RuntimeError("tesseract not found")
    with patch(
        "src.document_processor.processor.get_processor", return_value=mock_processor
    ):
        result = execute_tool("process_document", {"file_path": "/tmp/scan.png"}, memory)
    assert "error" in result
    assert "tesseract not found" in result["error"]