TrialPath / trialpath /tests /test_mcp.py
yakilee's picture
fix: correct logger variable and test mocks for enhanced search_direct
008813e
"""TDD tests for ClinicalTrials MCP client."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from trialpath.models.search_anchors import GeographyFilter, SearchAnchors, TrialFilters
from trialpath.services.mcp_client import ClinicalTrialsMCPClient, MCPError
class TestMCPClient:
    """Test ClinicalTrials MCP client.

    Every test patches ``httpx.AsyncClient`` so that the client's POST is
    captured by a mock and answered with a canned JSON body. The request-side
    tests then inspect the ``json=`` keyword of that POST.
    """

    @pytest.fixture
    def client(self):
        """Return the client under test, configured with a local MCP URL."""
        return ClinicalTrialsMCPClient(mcp_url="http://localhost:3000")

    @pytest.fixture
    def sample_anchors(self):
        """Return search anchors with a condition, subtype, biomarker and filters set."""
        return SearchAnchors(
            condition="Non-Small Cell Lung Cancer",
            subtype="adenocarcinoma",
            biomarkers=["EGFR exon 19 deletion"],
            stage="IV",
            age=52,
            geography=GeographyFilter(country="United States"),
            trial_filters=TrialFilters(
                recruitment_status=["Recruiting"],
                phase=["Phase 3"],
            ),
        )

    def _mock_httpx(self, MockHTTP, response_data):
        """Wire a patched ``httpx.AsyncClient`` to answer POSTs with ``response_data``.

        Args:
            MockHTTP: The mock that ``patch("httpx.AsyncClient")`` yields.
            response_data: JSON-serialisable object exposed as the response text.

        Returns:
            The inner mock client; its ``post`` attribute is an ``AsyncMock``
            whose ``call_args`` records the request the code under test sent.
        """
        import json as _json

        # Mock the response from client.post()
        mock_response = MagicMock()
        mock_response.raise_for_status = MagicMock()
        mock_response.text = _json.dumps(response_data)
        mock_response.headers = {"content-type": "application/json"}

        mock_client = MagicMock()
        mock_client.post = AsyncMock(return_value=mock_response)

        # AsyncClient() itself is an async context manager
        mock_client_ctx = MagicMock()
        mock_client_ctx.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client_ctx.__aexit__ = AsyncMock(return_value=None)
        MockHTTP.return_value = mock_client_ctx
        return mock_client

    @staticmethod
    def _posted_arguments(mock_client):
        """Return ``params.arguments`` from the JSON body of the last recorded POST.

        ``call_args[1]`` and ``call_args.kwargs`` are the same mapping, so a
        single ``kwargs.get("json", {})`` lookup is sufficient.

        Raises:
            AssertionError: If ``post`` was never called.
            KeyError: If the body has no ``params`` / ``arguments`` entry.
        """
        call_args = mock_client.post.call_args
        assert call_args is not None, "client never POSTed to the MCP endpoint"
        body = call_args.kwargs.get("json", {})
        return body["params"]["arguments"]

    @pytest.mark.asyncio
    async def test_search_builds_correct_query(self, client, sample_anchors):
        """Search should combine condition, subtype, and biomarkers into query."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_client = self._mock_httpx(MockHTTP, {"result": {"studies": []}})
            await client.search(sample_anchors)

            query = self._posted_arguments(mock_client)["query"]
            assert "Non-Small Cell Lung Cancer" in query
            assert "adenocarcinoma" in query

    @pytest.mark.asyncio
    async def test_search_includes_country_filter(self, client, sample_anchors):
        """Search should pass country as a parameter."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_client = self._mock_httpx(MockHTTP, {"result": {"studies": []}})
            await client.search(sample_anchors)

            args = self._posted_arguments(mock_client)
            assert args.get("country") == "United States"

    @pytest.mark.asyncio
    async def test_search_includes_recruitment_status_filter(self, client, sample_anchors):
        """Search should include recruitment status in filter expression."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_client = self._mock_httpx(MockHTTP, {"result": {"studies": []}})
            await client.search(sample_anchors)

            # A missing "filter" key becomes "" so the failure is a plain assertion.
            filter_str = self._posted_arguments(mock_client).get("filter", "")
            assert "OverallStatus" in filter_str
            assert "Recruiting" in filter_str

    @pytest.mark.asyncio
    async def test_get_study_by_nct_id(self, client):
        """Should call get_study tool with correct NCT ID."""
        with patch("httpx.AsyncClient") as MockHTTP:
            self._mock_httpx(
                MockHTTP, {"result": {"studies": [{"nctId": "NCT01234567", "title": "Test Trial"}]}}
            )
            result = await client.get_study("NCT01234567")
            assert result["nctId"] == "NCT01234567"

    @pytest.mark.asyncio
    async def test_mcp_error_handling(self, client):
        """Should raise MCPError on MCP server error response."""
        with patch("httpx.AsyncClient") as MockHTTP:
            self._mock_httpx(MockHTTP, {"error": {"code": -32600, "message": "Invalid request"}})
            with pytest.raises(MCPError, match="Invalid request"):
                await client.get_study("NCT00000000")

    @pytest.mark.asyncio
    async def test_find_eligible_passes_demographics(self, client):
        """find_eligible should pass patient demographics correctly."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_client = self._mock_httpx(
                MockHTTP, {"result": {"eligibleStudies": [], "totalMatches": 0}}
            )
            await client.find_eligible(
                age=52,
                sex="Female",
                conditions=["NSCLC"],
                country="United States",
            )

            args = self._posted_arguments(mock_client)
            assert args["age"] == 52
            assert args["sex"] == "Female"
            assert args["conditions"] == ["NSCLC"]
class TestSearchDirectEnhanced:
    """Test enhanced search_direct with intervention and eligibility dimensions.

    The ``search_direct`` tests patch ``requests.get`` and inspect the
    ``params=`` keyword it was called with.
    """

    @pytest.fixture
    def client(self):
        """Return the client under test, configured with a local MCP URL."""
        return ClinicalTrialsMCPClient(mcp_url="http://localhost:3000")

    @staticmethod
    async def _search_direct_params(client, anchors):
        """Run ``client.search_direct(anchors)`` and return the captured ``params``.

        ``requests.get`` is patched for the duration of the call and answers
        with a response whose ``json()`` is ``{"studies": []}``.
        ``call_args[1]`` and ``call_args.kwargs`` are the same mapping, so a
        single ``kwargs.get("params", {})`` lookup is sufficient.

        Raises:
            AssertionError: If ``requests.get`` was never called.
        """
        with patch("requests.get") as mock_get:
            mock_resp = MagicMock()
            mock_resp.json.return_value = {"studies": []}
            mock_resp.raise_for_status = MagicMock()
            mock_get.return_value = mock_resp
            await client.search_direct(anchors)

        call_args = mock_get.call_args
        assert call_args is not None, "search_direct never called requests.get"
        return call_args.kwargs.get("params", {})

    @pytest.mark.asyncio
    async def test_search_direct_uses_query_cond_and_intr(self, client):
        """search_direct should use query.cond for condition and query.intr for intervention."""
        anchors = SearchAnchors(
            condition="Non-Small Cell Lung Cancer",
            subtype="adenocarcinoma",
            interventions=["osimertinib", "erlotinib"],
        )
        params = await self._search_direct_params(client, anchors)

        assert params["query.cond"] == "Non-Small Cell Lung Cancer adenocarcinoma"
        # Two interventions are supplied; the expected value is the first one only.
        assert params["query.intr"] == "osimertinib"

    @pytest.mark.asyncio
    async def test_search_direct_uses_eligibility_keywords(self, client):
        """search_direct should use query.term for eligibility keywords."""
        anchors = SearchAnchors(
            condition="NSCLC",
            eligibility_keywords=["ECOG 0-1", "stage IV"],
        )
        params = await self._search_direct_params(client, anchors)

        assert params["query.term"] == "ECOG 0-1 stage IV"

    @pytest.mark.asyncio
    async def test_search_direct_without_interventions_backwards_compatible(self, client):
        """search_direct without new fields should still work (backward compatible)."""
        anchors = SearchAnchors(
            condition="NSCLC",
            biomarkers=["EGFR exon 19 deletion"],
        )
        params = await self._search_direct_params(client, anchors)

        assert params["query.cond"] == "NSCLC"
        assert "query.intr" not in params
        # The expected term is just the leading token of the biomarker string.
        assert params["query.term"] == "EGFR"

    @pytest.mark.asyncio
    async def test_search_multi_variant_includes_intervention_variant(self, client):
        """search_multi_variant should fire intervention-specific search variants."""
        anchors = SearchAnchors(
            condition="NSCLC",
            interventions=["osimertinib", "erlotinib"],
            eligibility_keywords=["ECOG 0-1"],
        )
        call_count = 0

        async def mock_search_direct(_variant_anchors):
            """Count the call and return one study whose NCT ID embeds the call number."""
            nonlocal call_count
            call_count += 1
            return [
                {
                    "protocolSection": {
                        "identificationModule": {
                            "nctId": f"NCT0000000{call_count}",
                            "briefTitle": f"Trial {call_count}",
                        },
                        "statusModule": {},
                        "designModule": {},
                        "conditionsModule": {},
                        "eligibilityModule": {},
                        "contactsLocationsModule": {},
                    }
                }
            ]

        with patch.object(client, "search_direct", side_effect=mock_search_direct):
            results = await client.search_multi_variant(anchors)

        # Variants: full(1) + broad(1) + interventions(2) + eligibility(1) = 5
        assert call_count == 5
        # All unique NCT IDs
        assert len(results) == 5
class TestNormalizeTrial:
    """Test normalize_trial conversion."""

    def test_normalize_full_ctgov_response(self):
        """A payload with every protocolSection module populated maps onto each asserted field."""
        criteria_text = (
            "Inclusion Criteria:\n- Confirmed NSCLC\n- ECOG 0-1\n"
            "Exclusion Criteria\n- Active brain metastases"
        )
        protocol = {
            "identificationModule": {
                "nctId": "NCT05012345",
                "briefTitle": "Test NSCLC Trial",
            },
            "statusModule": {"overallStatus": "Recruiting"},
            "designModule": {"phases": ["PHASE3"]},
            "conditionsModule": {"conditions": ["Non-Small Cell Lung Cancer"]},
            "eligibilityModule": {
                "minimumAge": "18 Years",
                "maximumAge": "80 Years",
                "eligibilityCriteria": criteria_text,
            },
            "contactsLocationsModule": {
                "locations": [{"country": "United States", "city": "Boston"}],
            },
        }

        candidate = ClinicalTrialsMCPClient.normalize_trial({"protocolSection": protocol})

        # Identification, status and design.
        assert candidate.nct_id == "NCT05012345"
        assert candidate.title == "Test NSCLC Trial"
        assert candidate.status == "Recruiting"
        assert candidate.phase == "PHASE3"
        assert len(candidate.conditions) == 1

        # The "<n> Years" strings are expected to come back as plain numbers.
        assert candidate.age_range is not None
        assert candidate.age_range.min == 18
        assert candidate.age_range.max == 80

        # Locations.
        assert len(candidate.locations) == 1
        assert candidate.locations[0].city == "Boston"

        # Eligibility text is expected to be split into inclusion and exclusion parts.
        assert candidate.eligibility_text is not None
        assert "Confirmed NSCLC" in candidate.eligibility_text.inclusion
        assert "brain metastases" in candidate.eligibility_text.exclusion

    def test_normalize_minimal_response(self):
        """A flat payload with only nctId and title leaves the optional fields as None."""
        flat_payload = {"nctId": "NCT00000001", "title": "Minimal Trial"}

        candidate = ClinicalTrialsMCPClient.normalize_trial(flat_payload)

        assert candidate.nct_id == "NCT00000001"
        assert candidate.title == "Minimal Trial"
        assert candidate.phase is None
        assert candidate.age_range is None
        assert candidate.eligibility_text is None

    def test_normalize_returns_trial_candidate_type(self):
        """The value returned by normalize_trial is a TrialCandidate."""
        from trialpath.models.trial_candidate import TrialCandidate

        candidate = ClinicalTrialsMCPClient.normalize_trial({"nctId": "NCT001", "title": "T"})

        assert isinstance(candidate, TrialCandidate)

    def test_normalize_fingerprint_text(self):
        """fingerprint_text should combine title, conditions, and phase."""
        protocol = {
            "identificationModule": {
                "nctId": "NCT001",
                "briefTitle": "EGFR Trial",
            },
            "conditionsModule": {"conditions": ["NSCLC"]},
            "designModule": {"phases": ["PHASE2"]},
            "statusModule": {},
            "eligibilityModule": {},
            "contactsLocationsModule": {},
        }

        candidate = ClinicalTrialsMCPClient.normalize_trial({"protocolSection": protocol})

        # Only the title and the condition are checked here.
        assert "EGFR Trial" in candidate.fingerprint_text
        assert "NSCLC" in candidate.fingerprint_text