| | """TDD tests for Parlant tool functions.""" |
| |
|
| | import json |
| | from unittest.mock import AsyncMock, MagicMock, patch |
| |
|
| | import pytest |
| |
|
| | import trialpath.agent.tools as tools_module |
| | from trialpath.agent.tools import ( |
| | ALL_TOOLS, |
| | analyze_gaps, |
| | evaluate_trial_eligibility, |
| | extract_patient_profile, |
| | generate_search_anchors, |
| | refine_search_query, |
| | relax_search_query, |
| | search_clinical_trials, |
| | ) |
| |
|
| |
|
@pytest.fixture(autouse=True)
def _reset_singletons():
    """Clear the cached service singletons around every test.

    The tools module memoizes its service clients at module level; stale
    instances from one test would leak mock state into the next.
    """
    _singleton_attrs = ("_extractor", "_planner", "_mcp_client")
    for attr in _singleton_attrs:
        setattr(tools_module, attr, None)
    yield
    for attr in _singleton_attrs:
        setattr(tools_module, attr, None)
| |
|
| |
|
@pytest.fixture
def mock_context():
    """Provide a generic stand-in for the tool-invocation context object."""
    context = MagicMock()
    return context
| |
|
| |
|
class TestExtractPatientProfile:
    """Tests for the extract_patient_profile tool."""

    @pytest.mark.asyncio
    async def test_calls_medgemma_extractor(self, mock_context):
        """Should call MedGemmaExtractor.extract with correct args."""
        expected_profile = {"patient_id": "P001", "diagnosis": {"primary_condition": "NSCLC"}}

        with patch("trialpath.services.medgemma_extractor.MedGemmaExtractor") as extractor_cls:
            extract_mock = AsyncMock(return_value=expected_profile)
            extractor_cls.return_value.extract = extract_mock

            result = await extract_patient_profile.function(
                mock_context,
                document_urls=json.dumps(["doc1.pdf"]),
                metadata=json.dumps({"age": 52}),
            )

            extract_mock.assert_called_once()
            assert result.data["patient_id"] == "P001"

    @pytest.mark.asyncio
    async def test_returns_tool_result_with_metadata(self, mock_context):
        """ToolResult should contain source metadata."""
        with patch("trialpath.services.medgemma_extractor.MedGemmaExtractor") as extractor_cls:
            extractor_cls.return_value.extract = AsyncMock(return_value={})

            result = await extract_patient_profile.function(
                mock_context,
                document_urls=json.dumps(["a.pdf", "b.pdf"]),
                metadata=json.dumps({}),
            )

            # Metadata records which backend produced the profile and how
            # many documents were consumed.
            assert result.metadata["source"] == "medgemma"
            assert result.metadata["doc_count"] == 2
| |
|
| |
|
class TestGenerateSearchAnchors:
    """Tests for the generate_search_anchors tool."""

    @pytest.mark.asyncio
    async def test_calls_gemini_planner(self, mock_context):
        """Should call GeminiPlanner.generate_search_anchors."""
        from trialpath.models.search_anchors import SearchAnchors

        anchors = SearchAnchors(condition="NSCLC")

        with patch("trialpath.services.gemini_planner.GeminiPlanner") as planner_cls:
            planner_cls.return_value.generate_search_anchors = AsyncMock(
                return_value=anchors
            )

            result = await generate_search_anchors.function(
                mock_context,
                patient_profile=json.dumps({"patient_id": "P001"}),
            )

            assert result.data["condition"] == "NSCLC"
| |
|
| |
|
class TestSearchClinicalTrials:
    """Tests for the search_clinical_trials tool."""

    @pytest.mark.asyncio
    async def test_calls_mcp_client_and_normalizes(self, mock_context):
        """Should call MCP client and normalize results."""
        raw_study = {"nctId": "NCT001", "title": "Test Trial"}

        with patch("trialpath.services.mcp_client.ClinicalTrialsMCPClient") as client_cls:
            client_cls.return_value.search = AsyncMock(return_value=[raw_study])

            normalized = MagicMock()
            normalized.model_dump.return_value = {"nct_id": "NCT001", "title": "Test Trial"}
            # normalize_trial lives on the class itself, not the instance.
            client_cls.normalize_trial = MagicMock(return_value=normalized)

            result = await search_clinical_trials.function(
                mock_context,
                search_anchors=json.dumps({"condition": "NSCLC"}),
            )

            assert result.data["count"] == 1
            assert result.metadata["source"] == "clinicaltrials_mcp"
| |
|
| |
|
class TestRefineSearchQuery:
    """Tests for the refine_search_query tool."""

    @pytest.mark.asyncio
    async def test_calls_gemini_refine(self, mock_context):
        """Should call GeminiPlanner.refine_search."""
        from trialpath.models.search_anchors import SearchAnchors

        refined = SearchAnchors(condition="NSCLC", biomarkers=["EGFR"])

        with patch("trialpath.services.gemini_planner.GeminiPlanner") as planner_cls:
            planner_cls.return_value.refine_search = AsyncMock(return_value=refined)

            result = await refine_search_query.function(
                mock_context,
                search_anchors=json.dumps({"condition": "NSCLC"}),
                result_count="100",
            )

            # The tool records its action and parses the string count to int.
            assert result.metadata["action"] == "refine"
            assert result.metadata["prev_count"] == 100
| |
|
| |
|
class TestRelaxSearchQuery:
    """Tests for the relax_search_query tool."""

    @pytest.mark.asyncio
    async def test_calls_gemini_relax(self, mock_context):
        """Should call GeminiPlanner.relax_search."""
        from trialpath.models.search_anchors import SearchAnchors

        relaxed = SearchAnchors(condition="NSCLC")

        with patch("trialpath.services.gemini_planner.GeminiPlanner") as planner_cls:
            planner_cls.return_value.relax_search = AsyncMock(return_value=relaxed)

            result = await relax_search_query.function(
                mock_context,
                search_anchors=json.dumps({"condition": "NSCLC"}),
                result_count="0",
            )

            assert result.metadata["action"] == "relax"
| |
|
| |
|
class TestEvaluateTrialEligibility:
    """Tests for the evaluate_trial_eligibility tool."""

    @pytest.mark.asyncio
    async def test_dual_model_evaluation(self, mock_context):
        """Should use MedGemma for medical and Gemini for structural criteria."""
        from trialpath.models.eligibility_ledger import (
            EligibilityLedger,
            OverallAssessment,
        )

        ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT001",
            overall_assessment=OverallAssessment.LIKELY_ELIGIBLE,
            criteria=[],
            gaps=[],
        )

        # One criterion of each category so we can verify the routing:
        # "medical" -> MedGemma, "structural" -> Gemini.
        sliced_criteria = [
            {
                "criterion_id": "inc_1",
                "type": "inclusion",
                "text": "EGFR mutation",
                "category": "medical",
            },
            {
                "criterion_id": "inc_2",
                "type": "inclusion",
                "text": "Age >= 18",
                "category": "structural",
            },
        ]

        with (
            patch("trialpath.services.gemini_planner.GeminiPlanner") as planner_cls,
            patch("trialpath.services.medgemma_extractor.MedGemmaExtractor") as extractor_cls,
        ):
            planner = planner_cls.return_value
            extractor = extractor_cls.return_value

            planner.slice_criteria = AsyncMock(return_value=sliced_criteria)
            extractor.evaluate_medical_criterion = AsyncMock(
                return_value={"decision": "met", "reasoning": "OK", "confidence": 0.9}
            )
            planner.evaluate_structural_criterion = AsyncMock(
                return_value={"decision": "met", "reasoning": "OK", "confidence": 0.99}
            )
            planner.aggregate_assessments = AsyncMock(return_value=ledger)

            result = await evaluate_trial_eligibility.function(
                mock_context,
                patient_profile=json.dumps({"patient_id": "P001"}),
                trial_candidate=json.dumps({"nct_id": "NCT001"}),
            )

            assert result.data["overall_assessment"] == "likely_eligible"
            assert result.metadata["criteria_count"] == 2
            # Each model should have been consulted exactly once.
            extractor.evaluate_medical_criterion.assert_called_once()
            planner.evaluate_structural_criterion.assert_called_once()
| |
|
| |
|
class TestAnalyzeGaps:
    """Tests for the analyze_gaps tool."""

    @pytest.mark.asyncio
    async def test_calls_gemini_gap_analysis(self, mock_context):
        """Should call GeminiPlanner.analyze_gaps."""
        gap_list = [
            {
                "description": "Brain MRI needed",
                "recommended_action": "Upload MRI",
                "clinical_importance": "high",
                "affected_trial_count": 2,
            }
        ]

        with patch("trialpath.services.gemini_planner.GeminiPlanner") as planner_cls:
            planner_cls.return_value.analyze_gaps = AsyncMock(return_value=gap_list)

            result = await analyze_gaps.function(
                mock_context,
                patient_profile=json.dumps({}),
                eligibility_ledgers=json.dumps([]),
            )

            assert result.data["count"] == 1
            assert result.data["gaps"][0]["clinical_importance"] == "high"
| |
|
| |
|
class TestAllToolsExported:
    """Tests for ALL_TOOLS list completeness."""

    def test_all_tools_has_7_entries(self):
        """ALL_TOOLS should contain exactly 7 tools."""
        assert len(ALL_TOOLS) == 7

    def test_all_tools_are_tool_entries(self):
        """Each item in ALL_TOOLS should be a ToolEntry."""
        from parlant.sdk import ToolEntry

        for tool in ALL_TOOLS:
            assert isinstance(tool, ToolEntry), f"{tool} is not a ToolEntry"
| |
|