# TrialPath — trialpath/tests/test_tools.py
# Last commit e46883d by yakilee: "style: apply ruff format to entire codebase"
"""TDD tests for Parlant tool functions."""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import trialpath.agent.tools as tools_module
from trialpath.agent.tools import (
ALL_TOOLS,
analyze_gaps,
evaluate_trial_eligibility,
extract_patient_profile,
generate_search_anchors,
refine_search_query,
relax_search_query,
search_clinical_trials,
)
@pytest.fixture(autouse=True)
def _reset_singletons():
    """Clear the cached service singletons held by ``trialpath.agent.tools``.

    Applied automatically to every test. The cache slots are set to ``None``
    both before and after the test body, so no test can observe a service
    instance cached by another one.
    """
    cached_slots = ("_extractor", "_planner", "_mcp_client")

    def _clear_cache():
        for slot in cached_slots:
            setattr(tools_module, slot, None)

    _clear_cache()
    yield
    _clear_cache()
@pytest.fixture
def mock_context():
    """Provide a throwaway tool-call context.

    Each tool function under test takes the context as its first positional
    argument. None of these tests inspect it, so a bare ``MagicMock`` is enough.
    """
    context = MagicMock()
    return context
class TestExtractPatientProfile:
    """Test extract_patient_profile tool."""

    @pytest.mark.asyncio
    async def test_calls_medgemma_extractor(self, mock_context):
        """Should delegate to MedGemmaExtractor.extract once and expose its profile.

        NOTE: the arguments forwarded to ``extract`` are not asserted here;
        this test pins only the call count and the returned ``data``.
        """
        profile = {"patient_id": "P001", "diagnosis": {"primary_condition": "NSCLC"}}
        with patch("trialpath.services.medgemma_extractor.MedGemmaExtractor") as MockExtractor:
            MockExtractor.return_value.extract = AsyncMock(return_value=profile)
            result = await extract_patient_profile.function(
                mock_context,
                document_urls=json.dumps(["doc1.pdf"]),
                metadata=json.dumps({"age": 52}),
            )
            MockExtractor.return_value.extract.assert_called_once()
        assert result.data["patient_id"] == "P001"

    @pytest.mark.asyncio
    async def test_returns_tool_result_with_metadata(self, mock_context):
        """ToolResult should contain source metadata."""
        with patch("trialpath.services.medgemma_extractor.MedGemmaExtractor") as MockExtractor:
            MockExtractor.return_value.extract = AsyncMock(return_value={})
            result = await extract_patient_profile.function(
                mock_context,
                document_urls=json.dumps(["a.pdf", "b.pdf"]),
                metadata=json.dumps({}),
            )
        # Two URLs were supplied, so the tool should report two documents.
        assert result.metadata["source"] == "medgemma"
        assert result.metadata["doc_count"] == 2
class TestGenerateSearchAnchors:
    """Test generate_search_anchors tool."""

    @pytest.mark.asyncio
    async def test_calls_gemini_planner(self, mock_context):
        """Should call GeminiPlanner.generate_search_anchors and return its anchors."""
        from trialpath.models.search_anchors import SearchAnchors

        mock_anchors = SearchAnchors(condition="NSCLC")
        with patch("trialpath.services.gemini_planner.GeminiPlanner") as MockPlanner:
            MockPlanner.return_value.generate_search_anchors = AsyncMock(return_value=mock_anchors)
            result = await generate_search_anchors.function(
                mock_context,
                patient_profile=json.dumps({"patient_id": "P001"}),
            )
            # The docstring promises delegation to the planner; verify it explicitly.
            MockPlanner.return_value.generate_search_anchors.assert_called_once()
        assert result.data["condition"] == "NSCLC"
class TestSearchClinicalTrials:
    """Test search_clinical_trials tool."""

    @pytest.mark.asyncio
    async def test_calls_mcp_client_and_normalizes(self, mock_context):
        """Should call the MCP client's search once and normalize its results."""
        raw_study = {"nctId": "NCT001", "title": "Test Trial"}
        with patch("trialpath.services.mcp_client.ClinicalTrialsMCPClient") as MockClient:
            MockClient.return_value.search = AsyncMock(return_value=[raw_study])
            mock_trial = MagicMock()
            mock_trial.model_dump.return_value = {"nct_id": "NCT001", "title": "Test Trial"}
            # normalize_trial is stubbed on the class itself, as the tool is
            # expected to call it there (presumably a static/class method).
            MockClient.normalize_trial = MagicMock(return_value=mock_trial)
            result = await search_clinical_trials.function(
                mock_context,
                search_anchors=json.dumps({"condition": "NSCLC"}),
            )
            # The docstring promises an MCP call; pin it explicitly.
            MockClient.return_value.search.assert_called_once()
        assert result.data["count"] == 1
        assert result.metadata["source"] == "clinicaltrials_mcp"
class TestRefineSearchQuery:
    """Test refine_search_query tool."""

    @pytest.mark.asyncio
    async def test_calls_gemini_refine(self, mock_context):
        """Should call GeminiPlanner.refine_search and tag the result as a refinement."""
        from trialpath.models.search_anchors import SearchAnchors

        mock_refined = SearchAnchors(condition="NSCLC", biomarkers=["EGFR"])
        with patch("trialpath.services.gemini_planner.GeminiPlanner") as MockPlanner:
            MockPlanner.return_value.refine_search = AsyncMock(return_value=mock_refined)
            result = await refine_search_query.function(
                mock_context,
                search_anchors=json.dumps({"condition": "NSCLC"}),
                result_count="100",
            )
            MockPlanner.return_value.refine_search.assert_called_once()
        assert result.metadata["action"] == "refine"
        # result_count is passed as a string; the tool reports it back as an int.
        assert result.metadata["prev_count"] == 100
class TestRelaxSearchQuery:
    """Test relax_search_query tool."""

    @pytest.mark.asyncio
    async def test_calls_gemini_relax(self, mock_context):
        """Should call GeminiPlanner.relax_search and tag the result as a relaxation."""
        from trialpath.models.search_anchors import SearchAnchors

        mock_relaxed = SearchAnchors(condition="NSCLC")
        with patch("trialpath.services.gemini_planner.GeminiPlanner") as MockPlanner:
            MockPlanner.return_value.relax_search = AsyncMock(return_value=mock_relaxed)
            result = await relax_search_query.function(
                mock_context,
                search_anchors=json.dumps({"condition": "NSCLC"}),
                result_count="0",
            )
            MockPlanner.return_value.relax_search.assert_called_once()
        assert result.metadata["action"] == "relax"
class TestEvaluateTrialEligibility:
    """Tests for the evaluate_trial_eligibility tool."""

    @pytest.mark.asyncio
    async def test_dual_model_evaluation(self, mock_context):
        """Medical criteria are judged by MedGemma, structural ones by Gemini."""
        from trialpath.models.eligibility_ledger import (
            EligibilityLedger,
            OverallAssessment,
        )

        expected_ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT001",
            overall_assessment=OverallAssessment.LIKELY_ELIGIBLE,
            criteria=[],
            gaps=[],
        )
        # One criterion per category, so each model should be consulted exactly once.
        medical_criterion = {
            "criterion_id": "inc_1",
            "type": "inclusion",
            "text": "EGFR mutation",
            "category": "medical",
        }
        structural_criterion = {
            "criterion_id": "inc_2",
            "type": "inclusion",
            "text": "Age >= 18",
            "category": "structural",
        }

        with (
            patch("trialpath.services.gemini_planner.GeminiPlanner") as MockPlanner,
            patch("trialpath.services.medgemma_extractor.MedGemmaExtractor") as MockExtractor,
        ):
            planner = MockPlanner.return_value
            extractor = MockExtractor.return_value

            planner.slice_criteria = AsyncMock(
                return_value=[medical_criterion, structural_criterion]
            )
            extractor.evaluate_medical_criterion = AsyncMock(
                return_value={"decision": "met", "reasoning": "OK", "confidence": 0.9}
            )
            planner.evaluate_structural_criterion = AsyncMock(
                return_value={"decision": "met", "reasoning": "OK", "confidence": 0.99}
            )
            planner.aggregate_assessments = AsyncMock(return_value=expected_ledger)

            result = await evaluate_trial_eligibility.function(
                mock_context,
                patient_profile=json.dumps({"patient_id": "P001"}),
                trial_candidate=json.dumps({"nct_id": "NCT001"}),
            )

        assert result.data["overall_assessment"] == "likely_eligible"
        assert result.metadata["criteria_count"] == 2
        extractor.evaluate_medical_criterion.assert_called_once()
        planner.evaluate_structural_criterion.assert_called_once()
class TestAnalyzeGaps:
    """Test analyze_gaps tool."""

    @pytest.mark.asyncio
    async def test_calls_gemini_gap_analysis(self, mock_context):
        """Should call GeminiPlanner.analyze_gaps and wrap the gaps with a count."""
        mock_gaps = [
            {
                "description": "Brain MRI needed",
                "recommended_action": "Upload MRI",
                "clinical_importance": "high",
                "affected_trial_count": 2,
            }
        ]
        with patch("trialpath.services.gemini_planner.GeminiPlanner") as MockPlanner:
            MockPlanner.return_value.analyze_gaps = AsyncMock(return_value=mock_gaps)
            result = await analyze_gaps.function(
                mock_context,
                patient_profile=json.dumps({}),
                eligibility_ledgers=json.dumps([]),
            )
            # The docstring promises delegation to the planner; verify it explicitly.
            MockPlanner.return_value.analyze_gaps.assert_called_once()
        assert result.data["count"] == 1
        assert result.data["gaps"][0]["clinical_importance"] == "high"
class TestAllToolsExported:
    """Sanity checks on the exported ALL_TOOLS registry."""

    def test_all_tools_has_7_entries(self):
        """ALL_TOOLS must expose exactly seven tools."""
        expected_tool_count = 7
        assert len(ALL_TOOLS) == expected_tool_count

    def test_all_tools_are_tool_entries(self):
        """Every member of ALL_TOOLS must be a Parlant ToolEntry."""
        from parlant.sdk import ToolEntry

        for entry in ALL_TOOLS:
            assert isinstance(entry, ToolEntry), f"{entry} is not a ToolEntry"