| """TDD tests for ClinicalTrials MCP client.""" |
|
|
| from unittest.mock import AsyncMock, MagicMock, patch |
|
|
| import pytest |
|
|
| from trialpath.models.search_anchors import GeographyFilter, SearchAnchors, TrialFilters |
| from trialpath.services.mcp_client import ClinicalTrialsMCPClient, MCPError |
|
|
|
|
class TestMCPClient:
    """Test ClinicalTrials MCP client against a mocked ``httpx.AsyncClient``."""

    @pytest.fixture
    def client(self):
        return ClinicalTrialsMCPClient(mcp_url="http://localhost:3000")

    @pytest.fixture
    def sample_anchors(self):
        return SearchAnchors(
            condition="Non-Small Cell Lung Cancer",
            subtype="adenocarcinoma",
            biomarkers=["EGFR exon 19 deletion"],
            stage="IV",
            age=52,
            geography=GeographyFilter(country="United States"),
            trial_filters=TrialFilters(
                recruitment_status=["Recruiting"],
                phase=["Phase 3"],
            ),
        )

    def _mock_httpx(self, MockHTTP, response_data):
        """Wire a patched ``httpx.AsyncClient`` so every POST returns ``response_data``.

        Args:
            MockHTTP: The mock produced by ``patch("httpx.AsyncClient")``.
            response_data: JSON-serialisable payload served as the response body.

        Returns:
            The inner mock client, whose ``post`` is an ``AsyncMock`` that tests
            can inspect after the call under test.
        """
        import json as _json

        mock_response = MagicMock()
        mock_response.raise_for_status = MagicMock()
        # The payload is exposed via ``.text`` with a JSON content type;
        # ``.json()`` is deliberately left unconfigured.
        mock_response.text = _json.dumps(response_data)
        mock_response.headers = {"content-type": "application/json"}

        mock_client = MagicMock()
        mock_client.post = AsyncMock(return_value=mock_response)

        # Supports ``async with httpx.AsyncClient(...) as c:`` -- ``c`` is mock_client.
        mock_client_ctx = MagicMock()
        mock_client_ctx.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client_ctx.__aexit__ = AsyncMock(return_value=None)
        MockHTTP.return_value = mock_client_ctx
        return mock_client

    @staticmethod
    def _posted_arguments(mock_client):
        """Return ``body["params"]["arguments"]`` from the last POST made by the client.

        Fails with a clear assertion if ``post`` was never awaited, and with a
        ``KeyError('json')`` if the request body was not sent as ``json=``
        (httpx's ``post`` only accepts ``json`` as a keyword argument).
        """
        mock_client.post.assert_awaited()
        body = mock_client.post.call_args.kwargs["json"]
        return body["params"]["arguments"]

    @pytest.mark.asyncio
    async def test_search_builds_correct_query(self, client, sample_anchors):
        """Search should put condition and subtype into the query string.

        Biomarkers are part of the anchors but are not asserted on here.
        """
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_client = self._mock_httpx(MockHTTP, {"result": {"studies": []}})

            await client.search(sample_anchors)

            query = self._posted_arguments(mock_client)["query"]
            assert "Non-Small Cell Lung Cancer" in query
            assert "adenocarcinoma" in query

    @pytest.mark.asyncio
    async def test_search_includes_country_filter(self, client, sample_anchors):
        """Search should pass country as a parameter."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_client = self._mock_httpx(MockHTTP, {"result": {"studies": []}})

            await client.search(sample_anchors)

            args = self._posted_arguments(mock_client)
            assert args.get("country") == "United States"

    @pytest.mark.asyncio
    async def test_search_includes_recruitment_status_filter(self, client, sample_anchors):
        """Search should include recruitment status in filter expression."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_client = self._mock_httpx(MockHTTP, {"result": {"studies": []}})

            await client.search(sample_anchors)

            filter_str = self._posted_arguments(mock_client).get("filter", "")
            assert "OverallStatus" in filter_str
            assert "Recruiting" in filter_str

    @pytest.mark.asyncio
    async def test_get_study_by_nct_id(self, client):
        """get_study should return the study record for the requested NCT ID."""
        with patch("httpx.AsyncClient") as MockHTTP:
            self._mock_httpx(
                MockHTTP, {"result": {"studies": [{"nctId": "NCT01234567", "title": "Test Trial"}]}}
            )

            result = await client.get_study("NCT01234567")
            assert result["nctId"] == "NCT01234567"

    @pytest.mark.asyncio
    async def test_mcp_error_handling(self, client):
        """Should raise MCPError on MCP server error response."""
        with patch("httpx.AsyncClient") as MockHTTP:
            # -32600 is the JSON-RPC "Invalid Request" error code.
            self._mock_httpx(MockHTTP, {"error": {"code": -32600, "message": "Invalid request"}})

            with pytest.raises(MCPError, match="Invalid request"):
                await client.get_study("NCT00000000")

    @pytest.mark.asyncio
    async def test_find_eligible_passes_demographics(self, client):
        """find_eligible should pass patient demographics correctly."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_client = self._mock_httpx(
                MockHTTP, {"result": {"eligibleStudies": [], "totalMatches": 0}}
            )

            await client.find_eligible(
                age=52,
                sex="Female",
                conditions=["NSCLC"],
                country="United States",
            )

            args = self._posted_arguments(mock_client)
            assert args["age"] == 52
            assert args["sex"] == "Female"
            assert args["conditions"] == ["NSCLC"]
|
|
|
|
class TestSearchDirectEnhanced:
    """Test enhanced search_direct with intervention and eligibility dimensions."""

    @pytest.fixture
    def client(self):
        return ClinicalTrialsMCPClient(mcp_url="http://localhost:3000")

    @staticmethod
    def _stub_studies_response(mock_get, payload=None):
        """Make the patched ``requests.get`` return a response whose ``.json()`` is ``payload``.

        Args:
            mock_get: The mock produced by ``patch("requests.get")``.
            payload: Body returned by ``.json()``; defaults to an empty
                ``{"studies": []}`` result set. ``None`` is used as the default
                to avoid sharing a mutable dict between tests.

        Returns:
            The configured response mock.
        """
        mock_resp = MagicMock()
        mock_resp.json.return_value = {"studies": []} if payload is None else payload
        mock_resp.raise_for_status = MagicMock()
        mock_get.return_value = mock_resp
        return mock_resp

    @staticmethod
    def _query_params(mock_get):
        """Return the ``params`` mapping passed on the last ``requests.get`` call.

        ``requests.get(url, params=None, **kwargs)`` accepts ``params`` either
        by keyword or as the second positional argument, so both are checked.
        Returns ``{}`` when no params were passed at all.
        """
        mock_get.assert_called()
        call = mock_get.call_args
        if "params" in call.kwargs:
            return call.kwargs["params"]
        if len(call.args) > 1:
            return call.args[1]
        return {}

    @pytest.mark.asyncio
    async def test_search_direct_uses_query_cond_and_intr(self, client):
        """search_direct should use query.cond for condition and query.intr for intervention."""
        anchors = SearchAnchors(
            condition="Non-Small Cell Lung Cancer",
            subtype="adenocarcinoma",
            interventions=["osimertinib", "erlotinib"],
        )

        with patch("requests.get") as mock_get:
            self._stub_studies_response(mock_get)

            await client.search_direct(anchors)

            params = self._query_params(mock_get)
            assert params["query.cond"] == "Non-Small Cell Lung Cancer adenocarcinoma"
            # Only the first intervention is expected in the direct query.
            assert params["query.intr"] == "osimertinib"

    @pytest.mark.asyncio
    async def test_search_direct_uses_eligibility_keywords(self, client):
        """search_direct should use query.term for eligibility keywords."""
        anchors = SearchAnchors(
            condition="NSCLC",
            eligibility_keywords=["ECOG 0-1", "stage IV"],
        )

        with patch("requests.get") as mock_get:
            self._stub_studies_response(mock_get)

            await client.search_direct(anchors)

            params = self._query_params(mock_get)
            assert params["query.term"] == "ECOG 0-1 stage IV"

    @pytest.mark.asyncio
    async def test_search_direct_without_interventions_backwards_compatible(self, client):
        """search_direct without new fields should still work (backward compatible)."""
        anchors = SearchAnchors(
            condition="NSCLC",
            biomarkers=["EGFR exon 19 deletion"],
        )

        with patch("requests.get") as mock_get:
            self._stub_studies_response(mock_get)

            await client.search_direct(anchors)

            params = self._query_params(mock_get)
            assert params["query.cond"] == "NSCLC"
            assert "query.intr" not in params
            # The biomarker is expected to be reduced to its gene symbol.
            assert params["query.term"] == "EGFR"

    @pytest.mark.asyncio
    async def test_search_multi_variant_includes_intervention_variant(self, client):
        """search_multi_variant should fire intervention-specific search variants."""
        anchors = SearchAnchors(
            condition="NSCLC",
            interventions=["osimertinib", "erlotinib"],
            eligibility_keywords=["ECOG 0-1"],
        )

        call_count = 0

        async def mock_search_direct(variant_anchors):
            # Each variant call yields one trial with a distinct NCT ID, so the
            # merged result size equals the number of calls.
            nonlocal call_count
            call_count += 1
            return [
                {
                    "protocolSection": {
                        "identificationModule": {
                            "nctId": f"NCT{call_count:08d}",
                            "briefTitle": f"Trial {call_count}",
                        },
                        "statusModule": {},
                        "designModule": {},
                        "conditionsModule": {},
                        "eligibilityModule": {},
                        "contactsLocationsModule": {},
                    }
                }
            ]

        with patch.object(client, "search_direct", side_effect=mock_search_direct):
            results = await client.search_multi_variant(anchors)

        # NOTE(review): 5 is the number of search variants expected for these
        # anchors -- confirm against search_multi_variant if its variant set changes.
        assert call_count == 5
        # One unique trial per variant, so nothing should be de-duplicated away.
        assert len(results) == 5
|
|
|
|
class TestNormalizeTrial:
    """Tests for ``ClinicalTrialsMCPClient.normalize_trial``."""

    def test_normalize_full_ctgov_response(self):
        """A complete ClinicalTrials.gov study record maps onto every TrialCandidate field."""
        # Note: the exclusion header intentionally lacks a trailing colon.
        criteria_text = (
            "Inclusion Criteria:\n- Confirmed NSCLC\n- ECOG 0-1\n"
            "Exclusion Criteria\n- Active brain metastases"
        )
        protocol = {
            "identificationModule": {
                "nctId": "NCT05012345",
                "briefTitle": "Test NSCLC Trial",
            },
            "statusModule": {"overallStatus": "Recruiting"},
            "designModule": {"phases": ["PHASE3"]},
            "conditionsModule": {"conditions": ["Non-Small Cell Lung Cancer"]},
            "eligibilityModule": {
                "minimumAge": "18 Years",
                "maximumAge": "80 Years",
                "eligibilityCriteria": criteria_text,
            },
            "contactsLocationsModule": {
                "locations": [{"country": "United States", "city": "Boston"}],
            },
        }

        candidate = ClinicalTrialsMCPClient.normalize_trial({"protocolSection": protocol})

        # Identification, status and design.
        assert candidate.nct_id == "NCT05012345"
        assert candidate.title == "Test NSCLC Trial"
        assert candidate.status == "Recruiting"
        assert candidate.phase == "PHASE3"
        assert len(candidate.conditions) == 1

        # Age bounds parsed from "<n> Years" strings.
        age_range = candidate.age_range
        assert age_range is not None
        assert age_range.min == 18
        assert age_range.max == 80

        # Sites.
        sites = candidate.locations
        assert len(sites) == 1
        assert sites[0].city == "Boston"

        # Eligibility text split into inclusion / exclusion sections.
        criteria = candidate.eligibility_text
        assert criteria is not None
        assert "Confirmed NSCLC" in criteria.inclusion
        assert "brain metastases" in criteria.exclusion

    def test_normalize_minimal_response(self):
        """A flat record with only nctId/title falls back to those fields and leaves the rest unset."""
        candidate = ClinicalTrialsMCPClient.normalize_trial(
            {"nctId": "NCT00000001", "title": "Minimal Trial"}
        )

        assert candidate.nct_id == "NCT00000001"
        assert candidate.title == "Minimal Trial"
        for unset_field in ("phase", "age_range", "eligibility_text"):
            assert getattr(candidate, unset_field) is None, unset_field

    def test_normalize_returns_trial_candidate_type(self):
        """normalize_trial hands back a TrialCandidate model instance."""
        from trialpath.models.trial_candidate import TrialCandidate

        candidate = ClinicalTrialsMCPClient.normalize_trial({"nctId": "NCT001", "title": "T"})
        assert isinstance(candidate, TrialCandidate)

    def test_normalize_fingerprint_text(self):
        """fingerprint_text should contain the title and the conditions.

        A phase is present in the record but is not asserted on here.
        """
        protocol = {
            "identificationModule": {
                "nctId": "NCT001",
                "briefTitle": "EGFR Trial",
            },
            "conditionsModule": {"conditions": ["NSCLC"]},
            "designModule": {"phases": ["PHASE2"]},
            "statusModule": {},
            "eligibilityModule": {},
            "contactsLocationsModule": {},
        }

        fingerprint = ClinicalTrialsMCPClient.normalize_trial(
            {"protocolSection": protocol}
        ).fingerprint_text

        for fragment in ("EGFR Trial", "NSCLC"):
            assert fragment in fingerprint
|
|