Spaces:
Sleeping
Sleeping
| """ | |
| Unit tests for execute_tool() dispatcher in src/agent/tools.py. | |
| All DL singletons are patched at their source module — no checkpoints loaded. | |
| No Anthropic API calls needed. | |
| Run: pytest tests/agent/test_tools.py | |
| """ | |
| from unittest.mock import MagicMock, patch | |
| import pytest | |
| from src.agent.memory import SessionMemory | |
| from src.agent.tools import execute_tool | |
| from src.classifier.model import DomainResult | |
| from src.ner.model import Entity | |
| from src.next_action.model import EscalationAction | |
@pytest.fixture
def memory():
    """Provide a new SessionMemory instance to each test that requests it.

    The ``@pytest.fixture`` decorator is what lets pytest inject this into the
    ``memory`` parameter of the tests in this module. Without it, every such
    test errors with "fixture 'memory' not found". Default (function) scope
    means each test gets its own instance.
    """
    return SessionMemory()
| # --------------------------------------------------------------------------- | |
| # draft_complaint — pure Python, no external calls | |
| # --------------------------------------------------------------------------- | |
def test_draft_complaint_minimal_context(memory):
    """draft_complaint echoes the supplied context back with status 'proceed'."""
    complaint_ctx = dict(
        domain="ecommerce",
        provider="Flipkart",
        incident_date="2024-03-12",
        desired_resolution="full refund",
    )
    tool_args = {"complaint_context": complaint_ctx}
    outcome = execute_tool("draft_complaint", tool_args, memory)
    assert outcome["status"] == "proceed"
    assert outcome["complaint_context"] == complaint_ctx
def test_draft_complaint_missing_key_defaults_to_empty(memory):
    """Without a 'complaint_context' argument the tool proceeds with an empty context."""
    expected_status, expected_context = "proceed", {}
    response = execute_tool("draft_complaint", {}, memory)
    assert response["status"] == expected_status
    assert response["complaint_context"] == expected_context
| # --------------------------------------------------------------------------- | |
| # store_memory / get_memory | |
| # --------------------------------------------------------------------------- | |
def test_store_memory_writes_to_session(memory):
    """store_memory acknowledges the write and the value is readable from the session."""
    ack = execute_tool("store_memory", {"key": "domain", "value": "telecom"}, memory)
    expected_ack = {"status": "stored", "key": "domain"}
    assert ack == expected_ack
    stored_value = memory.get("domain")
    assert stored_value == "telecom"
def test_get_memory_reads_from_session(memory):
    """get_memory reports a value that was set directly on the session."""
    key, value = "domain", "telecom"
    memory.set(key, value)
    lookup = execute_tool("get_memory", {"key": key}, memory)
    assert lookup == {"key": key, "value": value}
def test_get_memory_missing_key_returns_none(memory):
    """Looking up a key that was never stored yields value None rather than an error."""
    missing_key = "absent"
    expected = {"key": missing_key, "value": None}
    assert execute_tool("get_memory", {"key": missing_key}, memory) == expected
def test_store_get_round_trip(memory):
    """A value written through store_memory is read back unchanged through get_memory."""
    key = "prior_contact"
    execute_tool("store_memory", {"key": key, "value": True}, memory)
    fetched = execute_tool("get_memory", {"key": key}, memory)
    assert fetched["value"] is True
| # --------------------------------------------------------------------------- | |
| # Unknown tool name | |
| # --------------------------------------------------------------------------- | |
def test_unknown_tool_returns_error(memory):
    """Dispatching an unregistered tool name yields an error that names the tool."""
    bogus_name = "nonexistent_tool"
    response = execute_tool(bogus_name, {}, memory)
    assert "error" in response
    assert bogus_name in response["error"]
| # --------------------------------------------------------------------------- | |
| # classify_domain — patch src.classifier.predict.classify | |
| # --------------------------------------------------------------------------- | |
def test_classify_domain_high_confidence(memory):
    """domain, confidence and low_confidence from the patched classifier are propagated."""
    probabilities = {
        "ecommerce": 0.92,
        "telecom": 0.03,
        "banking": 0.02,
        "cibil": 0.01,
        "insurance": 0.01,
        "general": 0.01,
    }
    stub = DomainResult(
        domain="ecommerce",
        confidence=0.92,
        all_probs=probabilities,
        low_confidence=False,
    )
    target = "src.classifier.predict.classify"
    with patch(target, return_value=stub):
        result = execute_tool("classify_domain", {"complaint_text": "Flipkart refund"}, memory)
    assert result["domain"] == "ecommerce"
    assert result["confidence"] == 0.92
    assert result["low_confidence"] is False
def test_classify_domain_low_confidence_propagated(memory):
    """The classifier's low_confidence flag reaches the tool result unchanged."""
    probabilities = dict(
        ecommerce=0.2, telecom=0.1, banking=0.1,
        cibil=0.1, insurance=0.1, general=0.4,
    )
    stub = DomainResult(
        domain="general",
        confidence=0.3,
        all_probs=probabilities,
        low_confidence=True,
    )
    target = "src.classifier.predict.classify"
    with patch(target, return_value=stub):
        result = execute_tool("classify_domain", {"complaint_text": "I have a complaint"}, memory)
    assert result["low_confidence"] is True
def test_classify_domain_exception_returns_error(memory):
    """A classifier exception is reported through an 'error' entry carrying its message."""
    failure = RuntimeError("model error")
    with patch("src.classifier.predict.classify", side_effect=failure):
        outcome = execute_tool("classify_domain", {"complaint_text": "test"}, memory)
    assert "error" in outcome
    assert "model error" in outcome["error"]
| # --------------------------------------------------------------------------- | |
| # extract_entities — patch src.ner.predict.extract_entities | |
| # --------------------------------------------------------------------------- | |
def test_extract_entities_returns_list_of_dicts(memory):
    """Entity objects from the NER predictor come back as a list of key-addressable records."""
    fake_entities = [
        Entity(text="Flipkart", label="ORG", start=0, end=8, confidence=0.91),
    ]
    with patch("src.ner.predict.extract_entities", return_value=fake_entities):
        result = execute_tool("extract_entities", {"text": "Flipkart refund"}, memory)
    assert isinstance(result, list)
    assert len(result) == 1
    first = result[0]
    # Checked in the same order as the original field-by-field assertions.
    expected_fields = {
        "text": "Flipkart",
        "label": "ORG",
        "start": 0,
        "end": 8,
        "confidence": 0.91,
    }
    for field_name, expected_value in expected_fields.items():
        assert first[field_name] == expected_value
def test_extract_entities_empty_list(memory):
    """When the predictor finds no entities the tool returns an empty list."""
    target = "src.ner.predict.extract_entities"
    with patch(target, return_value=[]):
        entities = execute_tool("extract_entities", {"text": "plain text"}, memory)
    assert entities == []
def test_extract_entities_exception_returns_error(memory):
    """An NER failure is reported through an 'error' entry instead of propagating."""
    boom = Exception("ner failed")
    with patch("src.ner.predict.extract_entities", side_effect=boom):
        outcome = execute_tool("extract_entities", {"text": "test"}, memory)
    assert "error" in outcome
    assert "ner failed" in outcome["error"]
| # --------------------------------------------------------------------------- | |
| # recommend_action — patch src.next_action.predict.recommend_action | |
| # --------------------------------------------------------------------------- | |
def _mock_actions():
    """Build two canned EscalationAction results used to stub recommend_action.

    A new list is returned on every call, so tests cannot share mutations.
    """
    company = EscalationAction(
        action="company_support",
        authority="Flipkart Grievance",
        url="https://flipkart.com",
        confidence=0.5,
    )
    helpline = EscalationAction(
        action="nch",
        authority="NCH",
        url="https://consumerhelpline.gov.in",
        confidence=0.3,
    )
    return [company, helpline]
def test_recommend_action_prior_contact_defaults_to_false(memory):
    """Omitting 'prior_contact' makes the dispatcher call the predictor with False.

    The predictor may receive ``prior_contact`` either as a keyword or as the
    third positional argument, so both forms are checked.
    """
    with patch("src.next_action.predict.recommend_action", return_value=_mock_actions()) as mock_fn:
        execute_tool("recommend_action", {"domain": "ecommerce"}, memory)
    # Fail with a clear message if the dispatcher never reached the predictor,
    # instead of a TypeError from unpacking a ``None`` call_args.
    mock_fn.assert_called_once()
    args, kwargs = mock_fn.call_args
    if "prior_contact" in kwargs:
        prior_contact = kwargs["prior_contact"]
    elif len(args) > 2:
        prior_contact = args[2]
    else:
        # NOTE(review): the argument was not forwarded at all. This branch
        # passes only if the predictor's own default is False; confirm that
        # against its signature.
        prior_contact = False
    assert prior_contact is False
def test_recommend_action_entities_defaults_to_empty(memory):
    """Omitting 'entities' makes the dispatcher call the predictor with an empty dict.

    The predictor may receive ``entities`` either as a keyword or as the
    second positional argument, so both forms are checked.
    """
    with patch("src.next_action.predict.recommend_action", return_value=_mock_actions()) as mock_fn:
        execute_tool("recommend_action", {"domain": "ecommerce"}, memory)
    # Fail with a clear message if the dispatcher never reached the predictor,
    # instead of a TypeError from unpacking a ``None`` call_args.
    mock_fn.assert_called_once()
    args, kwargs = mock_fn.call_args
    if "entities" in kwargs:
        entities_val = kwargs["entities"]
    elif len(args) > 1:
        entities_val = args[1]
    else:
        # NOTE(review): the argument was not forwarded at all. This branch
        # passes only if the predictor's own default is an empty dict;
        # confirm that against its signature.
        entities_val = {}
    assert entities_val == {}
def test_recommend_action_returns_list_of_dicts(memory):
    """Recommended actions come back as a list whose items expose an 'action' key."""
    tool_input = {"domain": "ecommerce", "entities": {}, "prior_contact": False}
    canned = _mock_actions()
    with patch("src.next_action.predict.recommend_action", return_value=canned):
        actions = execute_tool("recommend_action", tool_input, memory)
    assert isinstance(actions, list)
    assert actions[0]["action"] == "company_support"
def test_recommend_action_exception_returns_error(memory):
    """A predictor crash is surfaced as an 'error' entry carrying the message."""
    crash = RuntimeError("predictor failed")
    with patch("src.next_action.predict.recommend_action", side_effect=crash):
        outcome = execute_tool("recommend_action", {"domain": "ecommerce"}, memory)
    assert "error" in outcome
    assert "predictor failed" in outcome["error"]
| # --------------------------------------------------------------------------- | |
| # process_document — patch src.document_processor.processor.get_processor | |
| # --------------------------------------------------------------------------- | |
def test_process_document_successful_result(memory):
    """Processor output is returned with raw_text intact and entities as a list of records."""
    invoice_entity = Entity(text="Invoice", label="REF_ID", start=0, end=7, confidence=0.85)
    processor = MagicMock()
    processor.process.return_value = dict(
        raw_text="Invoice #12345 for ₹4,299",
        entities=[invoice_entity],
    )
    with patch("src.document_processor.processor.get_processor", return_value=processor):
        doc = execute_tool("process_document", {"file_path": "/tmp/invoice.pdf"}, memory)
    assert doc["raw_text"] == "Invoice #12345 for ₹4,299"
    extracted = doc["entities"]
    assert isinstance(extracted, list)
    assert extracted[0]["text"] == "Invoice"
def test_process_document_unsupported_extension_returns_error(memory):
    """A ValueError from the processor for an unknown extension becomes an 'error' entry."""
    # Dotted-key kwargs configure the child ``process`` mock in one step.
    processor = MagicMock(
        **{"process.side_effect": ValueError("Unsupported file extension '.xyz'")}
    )
    with patch("src.document_processor.processor.get_processor", return_value=processor):
        response = execute_tool("process_document", {"file_path": "/tmp/file.xyz"}, memory)
    assert "error" in response
    assert "Unsupported file extension" in response["error"]
def test_process_document_runtime_error_returns_error(memory):
    """A RuntimeError from the processor is reported through an 'error' entry."""
    # Dotted-key kwargs configure the child ``process`` mock in one step.
    processor = MagicMock(
        **{"process.side_effect": RuntimeError("tesseract not found")}
    )
    with patch("src.document_processor.processor.get_processor", return_value=processor):
        response = execute_tool("process_document", {"file_path": "/tmp/scan.png"}, memory)
    assert "error" in response
    assert "tesseract not found" in response["error"]