"""TDD tests for Parlant tool functions.""" import json from unittest.mock import AsyncMock, MagicMock, patch import pytest import trialpath.agent.tools as tools_module from trialpath.agent.tools import ( ALL_TOOLS, analyze_gaps, evaluate_trial_eligibility, extract_patient_profile, generate_search_anchors, refine_search_query, relax_search_query, search_clinical_trials, ) @pytest.fixture(autouse=True) def _reset_singletons(): """Reset cached service singletons between tests.""" tools_module._extractor = None tools_module._planner = None tools_module._mcp_client = None yield tools_module._extractor = None tools_module._planner = None tools_module._mcp_client = None @pytest.fixture def mock_context(): return MagicMock() class TestExtractPatientProfile: """Test extract_patient_profile tool.""" @pytest.mark.asyncio async def test_calls_medgemma_extractor(self, mock_context): """Should call MedGemmaExtractor.extract with correct args.""" profile = {"patient_id": "P001", "diagnosis": {"primary_condition": "NSCLC"}} with patch("trialpath.services.medgemma_extractor.MedGemmaExtractor") as MockExtractor: MockExtractor.return_value.extract = AsyncMock(return_value=profile) result = await extract_patient_profile.function( mock_context, document_urls=json.dumps(["doc1.pdf"]), metadata=json.dumps({"age": 52}), ) MockExtractor.return_value.extract.assert_called_once() assert result.data["patient_id"] == "P001" @pytest.mark.asyncio async def test_returns_tool_result_with_metadata(self, mock_context): """ToolResult should contain source metadata.""" with patch("trialpath.services.medgemma_extractor.MedGemmaExtractor") as MockExtractor: MockExtractor.return_value.extract = AsyncMock(return_value={}) result = await extract_patient_profile.function( mock_context, document_urls=json.dumps(["a.pdf", "b.pdf"]), metadata=json.dumps({}), ) assert result.metadata["source"] == "medgemma" assert result.metadata["doc_count"] == 2 class TestGenerateSearchAnchors: """Test generate_search_anchors tool.""" @pytest.mark.asyncio async def test_calls_gemini_planner(self, mock_context): """Should call GeminiPlanner.generate_search_anchors.""" from trialpath.models.search_anchors import SearchAnchors mock_anchors = SearchAnchors(condition="NSCLC") with patch("trialpath.services.gemini_planner.GeminiPlanner") as MockPlanner: MockPlanner.return_value.generate_search_anchors = AsyncMock(return_value=mock_anchors) result = await generate_search_anchors.function( mock_context, patient_profile=json.dumps({"patient_id": "P001"}), ) assert result.data["condition"] == "NSCLC" class TestSearchClinicalTrials: """Test search_clinical_trials tool.""" @pytest.mark.asyncio async def test_calls_mcp_client_and_normalizes(self, mock_context): """Should call MCP client and normalize results.""" raw_study = {"nctId": "NCT001", "title": "Test Trial"} with patch("trialpath.services.mcp_client.ClinicalTrialsMCPClient") as MockClient: MockClient.return_value.search = AsyncMock(return_value=[raw_study]) mock_trial = MagicMock() mock_trial.model_dump.return_value = {"nct_id": "NCT001", "title": "Test Trial"} MockClient.normalize_trial = MagicMock(return_value=mock_trial) result = await search_clinical_trials.function( mock_context, search_anchors=json.dumps({"condition": "NSCLC"}), ) assert result.data["count"] == 1 assert result.metadata["source"] == "clinicaltrials_mcp" class TestRefineSearchQuery: """Test refine_search_query tool.""" @pytest.mark.asyncio async def test_calls_gemini_refine(self, mock_context): """Should call GeminiPlanner.refine_search.""" from trialpath.models.search_anchors import SearchAnchors mock_refined = SearchAnchors(condition="NSCLC", biomarkers=["EGFR"]) with patch("trialpath.services.gemini_planner.GeminiPlanner") as MockPlanner: MockPlanner.return_value.refine_search = AsyncMock(return_value=mock_refined) result = await refine_search_query.function( mock_context, search_anchors=json.dumps({"condition": "NSCLC"}), result_count="100", ) assert result.metadata["action"] == "refine" assert result.metadata["prev_count"] == 100 class TestRelaxSearchQuery: """Test relax_search_query tool.""" @pytest.mark.asyncio async def test_calls_gemini_relax(self, mock_context): """Should call GeminiPlanner.relax_search.""" from trialpath.models.search_anchors import SearchAnchors mock_relaxed = SearchAnchors(condition="NSCLC") with patch("trialpath.services.gemini_planner.GeminiPlanner") as MockPlanner: MockPlanner.return_value.relax_search = AsyncMock(return_value=mock_relaxed) result = await relax_search_query.function( mock_context, search_anchors=json.dumps({"condition": "NSCLC"}), result_count="0", ) assert result.metadata["action"] == "relax" class TestEvaluateTrialEligibility: """Test evaluate_trial_eligibility tool.""" @pytest.mark.asyncio async def test_dual_model_evaluation(self, mock_context): """Should use MedGemma for medical and Gemini for structural criteria.""" from trialpath.models.eligibility_ledger import ( EligibilityLedger, OverallAssessment, ) mock_ledger = EligibilityLedger( patient_id="P001", nct_id="NCT001", overall_assessment=OverallAssessment.LIKELY_ELIGIBLE, criteria=[], gaps=[], ) with ( patch("trialpath.services.gemini_planner.GeminiPlanner") as MockPlanner, patch("trialpath.services.medgemma_extractor.MedGemmaExtractor") as MockExtractor, ): MockPlanner.return_value.slice_criteria = AsyncMock( return_value=[ { "criterion_id": "inc_1", "type": "inclusion", "text": "EGFR mutation", "category": "medical", }, { "criterion_id": "inc_2", "type": "inclusion", "text": "Age >= 18", "category": "structural", }, ] ) MockExtractor.return_value.evaluate_medical_criterion = AsyncMock( return_value={"decision": "met", "reasoning": "OK", "confidence": 0.9} ) MockPlanner.return_value.evaluate_structural_criterion = AsyncMock( return_value={"decision": "met", "reasoning": "OK", "confidence": 0.99} ) MockPlanner.return_value.aggregate_assessments = AsyncMock(return_value=mock_ledger) result = await evaluate_trial_eligibility.function( mock_context, patient_profile=json.dumps({"patient_id": "P001"}), trial_candidate=json.dumps({"nct_id": "NCT001"}), ) assert result.data["overall_assessment"] == "likely_eligible" assert result.metadata["criteria_count"] == 2 MockExtractor.return_value.evaluate_medical_criterion.assert_called_once() MockPlanner.return_value.evaluate_structural_criterion.assert_called_once() class TestAnalyzeGaps: """Test analyze_gaps tool.""" @pytest.mark.asyncio async def test_calls_gemini_gap_analysis(self, mock_context): """Should call GeminiPlanner.analyze_gaps.""" mock_gaps = [ { "description": "Brain MRI needed", "recommended_action": "Upload MRI", "clinical_importance": "high", "affected_trial_count": 2, } ] with patch("trialpath.services.gemini_planner.GeminiPlanner") as MockPlanner: MockPlanner.return_value.analyze_gaps = AsyncMock(return_value=mock_gaps) result = await analyze_gaps.function( mock_context, patient_profile=json.dumps({}), eligibility_ledgers=json.dumps([]), ) assert result.data["count"] == 1 assert result.data["gaps"][0]["clinical_importance"] == "high" class TestAllToolsExported: """Test ALL_TOOLS list completeness.""" def test_all_tools_has_7_entries(self): """ALL_TOOLS should contain exactly 7 tools.""" assert len(ALL_TOOLS) == 7 def test_all_tools_are_tool_entries(self): """Each item in ALL_TOOLS should be a ToolEntry.""" from parlant.sdk import ToolEntry for t in ALL_TOOLS: assert isinstance(t, ToolEntry), f"{t} is not a ToolEntry"