| """TDD tests for Gemini planner service.""" |
|
|
| import json |
| from unittest.mock import MagicMock, patch |
|
|
| import pytest |
|
|
| from trialpath.models.eligibility_ledger import EligibilityLedger, OverallAssessment |
| from trialpath.models.search_anchors import SearchAnchors |
| from trialpath.services.gemini_planner import GeminiPlanner |
|
|
|
|
@pytest.fixture
def sample_profile():
    """Return a minimal patient profile dict used as planner input.

    Describes patient ``P001``: 52-year-old female, Non-Small Cell Lung
    Cancer (adenocarcinoma, stage IVa), EGFR exon 19 deletion, ECOG 1.
    """
    diagnosis = {
        "primary_condition": "Non-Small Cell Lung Cancer",
        "histology": "adenocarcinoma",
        "stage": "IVa",
    }
    biomarkers = [{"name": "EGFR", "result": "Exon 19 deletion"}]
    profile = {
        "patient_id": "P001",
        "demographics": {"age": 52, "sex": "female"},
        "diagnosis": diagnosis,
        "biomarkers": biomarkers,
        "performance_status": {"scale": "ECOG", "value": 1},
    }
    return profile
|
|
|
|
@pytest.fixture
def mock_gemini():
    """Patch ``google.genai.Client`` and yield its ``models.generate_content`` mock.

    Only tests that request this fixture are patched (it is not autouse).
    While the test runs, every instance produced by the patched ``Client``
    class exposes the yielded mock as ``models.generate_content``.
    """
    # NOTE(review): this only intercepts the client if GeminiPlanner resolves
    # ``google.genai.Client`` at construction time (not a name imported
    # directly into its own module) — confirm against the service code.
    generate_content_mock = MagicMock()
    with patch("google.genai.Client") as client_cls:
        client_cls.return_value.models.generate_content = generate_content_mock
        yield generate_content_mock
|
|
|
|
def _set_mock_response(mock_generate, text: str):
    """Make ``mock_generate`` return a fake response whose ``.text`` is *text*.

    Args:
        mock_generate: The mock standing in for ``generate_content``.
        text: Value exposed as the fake response's ``.text`` attribute.
    """
    mock_generate.return_value = MagicMock(text=text)
|
|
|
|
class TestGeminiSearchAnchorsGeneration:
    """Test Gemini structured output for SearchAnchors generation."""

    @pytest.mark.asyncio
    async def test_search_anchors_has_correct_condition(self, sample_profile, mock_gemini):
        """Generated SearchAnchors should reference NSCLC."""
        _set_mock_response(
            mock_gemini,
            SearchAnchors(
                condition="Non-Small Cell Lung Cancer",
                subtype="adenocarcinoma",
                biomarkers=["EGFR exon 19 deletion"],
                stage="IV",
                age=52,
                performance_status_max=1,
            ).model_dump_json(),
        )

        planner = GeminiPlanner()
        anchors = await planner.generate_search_anchors(sample_profile)

        condition = anchors.condition.lower()
        assert "lung" in condition or "nsclc" in condition
        assert anchors.age == 52

    @pytest.mark.asyncio
    async def test_search_anchors_includes_biomarkers(self, sample_profile, mock_gemini):
        """SearchAnchors should include patient biomarkers."""
        _set_mock_response(
            mock_gemini,
            SearchAnchors(
                condition="NSCLC",
                biomarkers=["EGFR exon 19 deletion"],
            ).model_dump_json(),
        )

        planner = GeminiPlanner()
        anchors = await planner.generate_search_anchors(sample_profile)

        assert len(anchors.biomarkers) > 0
        assert any("EGFR" in b for b in anchors.biomarkers)

    @pytest.mark.asyncio
    async def test_search_anchors_json_schema_passed(self, sample_profile, mock_gemini):
        """Verify that Gemini is called with response_json_schema."""
        _set_mock_response(mock_gemini, SearchAnchors(condition="NSCLC").model_dump_json())

        planner = GeminiPlanner()
        await planner.generate_search_anchors(sample_profile)

        call_args = mock_gemini.call_args
        # Fail with a clear message rather than an AttributeError on None
        # if the planner never reached generate_content.
        assert call_args is not None, "generate_content was never called"
        # call_args[1] is the same dict as call_args.kwargs, so one lookup suffices.
        config = call_args.kwargs.get("config", {})
        assert config.get("response_mime_type") == "application/json"
        assert "response_json_schema" in config
|
|
|
|
class TestGeminiEligibilityEvaluation:
    """Test Gemini eligibility evaluation output."""

    def test_ledger_has_all_required_fields(self):
        """EligibilityLedger parsed from JSON should keep patient_id, nct_id, assessment.

        Round-trips the model through JSON, which is the form a structured
        Gemini response arrives in, instead of only re-reading the values
        passed to the constructor.
        """
        original = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT01234567",
            overall_assessment=OverallAssessment.UNCERTAIN,
            criteria=[],
            gaps=[],
        )
        parsed = EligibilityLedger.model_validate_json(original.model_dump_json())

        assert parsed.patient_id == "P001"
        assert parsed.nct_id == "NCT01234567"
        assert parsed.overall_assessment == OverallAssessment.UNCERTAIN

    @pytest.mark.asyncio
    async def test_error_handling_invalid_json(self, mock_gemini):
        """Should raise error on invalid Gemini JSON response."""
        _set_mock_response(mock_gemini, "not valid json")

        planner = GeminiPlanner()
        # NOTE(review): ``Exception`` also matches failures unrelated to JSON
        # parsing (e.g. the planner rejecting the empty/None arguments before
        # calling Gemini); narrow this once the planner's error type is pinned.
        with pytest.raises(Exception):
            await planner.evaluate_eligibility({}, {}, None)
|
|
|
|
class TestGeminiRefineSearch:
    """Test search refinement methods."""

    @pytest.mark.asyncio
    async def test_refine_search_returns_search_anchors(self, mock_gemini):
        """refine_search should return valid SearchAnchors."""
        refined = SearchAnchors(
            condition="NSCLC",
            biomarkers=["EGFR exon 19 deletion"],
            trial_filters={"recruitment_status": ["Recruiting"], "phase": ["Phase 3"]},
        )
        _set_mock_response(mock_gemini, refined.model_dump_json())

        planner = GeminiPlanner()
        anchors = SearchAnchors(condition="NSCLC", biomarkers=["EGFR"])
        result = await planner.refine_search(anchors, result_count=100)

        assert isinstance(result, SearchAnchors)
        assert result.condition == "NSCLC"

    @pytest.mark.asyncio
    async def test_relax_search_returns_search_anchors(self, mock_gemini):
        """relax_search should return valid SearchAnchors."""
        relaxed = SearchAnchors(
            condition="NSCLC",
            trial_filters={"recruitment_status": ["Recruiting", "Not yet recruiting"]},
        )
        _set_mock_response(mock_gemini, relaxed.model_dump_json())

        planner = GeminiPlanner()
        anchors = SearchAnchors(condition="NSCLC", biomarkers=["EGFR"])
        result = await planner.relax_search(anchors, result_count=0)

        assert isinstance(result, SearchAnchors)

    @pytest.mark.asyncio
    async def test_refine_passes_result_count_in_prompt(self, mock_gemini):
        """refine_search should include result_count in prompt."""
        _set_mock_response(mock_gemini, SearchAnchors(condition="NSCLC").model_dump_json())

        planner = GeminiPlanner()
        anchors = SearchAnchors(condition="NSCLC")
        await planner.refine_search(anchors, result_count=150)

        call_args = mock_gemini.call_args
        assert call_args is not None, "generate_content was never called"
        # call_args[1] is the same dict as call_args.kwargs, so one lookup suffices.
        prompt = call_args.kwargs.get("contents", "")
        # str() keeps the check identical for a plain string prompt and lets it
        # also work if contents is passed as a list of parts.
        assert "150" in str(prompt)
|
|
|
|
class TestGeminiSliceCriteria:
    """Test criteria slicing."""

    @pytest.mark.asyncio
    async def test_slice_criteria_returns_list(self, mock_gemini):
        """slice_criteria should return the criterion dicts from the response's ``criteria`` array."""
        age_criterion = {
            "criterion_id": "inc_1",
            "type": "inclusion",
            "text": "Age >= 18",
            "category": "structural",
        }
        diagnosis_criterion = {
            "criterion_id": "inc_2",
            "type": "inclusion",
            "text": "Confirmed NSCLC",
            "category": "medical",
        }
        payload = json.dumps({"criteria": [age_criterion, diagnosis_criterion]})
        _set_mock_response(mock_gemini, payload)

        result = await GeminiPlanner().slice_criteria({"nct_id": "NCT001"})

        assert len(result) == 2
        first, second = result[0], result[1]
        assert first["criterion_id"] == "inc_1"
        assert second["category"] == "medical"
|
|
|
|
class TestGeminiStructuralCriterion:
    """Test structural criterion evaluation."""

    @pytest.mark.asyncio
    async def test_evaluate_structural_returns_decision(self, mock_gemini):
        """evaluate_structural_criterion should return the decision dict Gemini produced."""
        decision_payload = {
            "decision": "met",
            "reasoning": "Patient age 52 >= 18",
            "confidence": 0.99,
        }
        _set_mock_response(mock_gemini, json.dumps(decision_payload))

        criterion_text = "Age >= 18 years"
        patient = {"demographics": {"age": 52}}
        result = await GeminiPlanner().evaluate_structural_criterion(criterion_text, patient)

        assert result["decision"] == "met"
        assert result["confidence"] == 0.99
|
|
|
|
class TestGeminiAggregateAssessments:
    """Test assessment aggregation."""

    @pytest.mark.asyncio
    async def test_aggregate_returns_ledger(self, mock_gemini):
        """aggregate_assessments should return an EligibilityLedger built from Gemini's JSON."""
        expected_assessment = OverallAssessment.LIKELY_ELIGIBLE
        canned_ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT001",
            overall_assessment=expected_assessment,
            criteria=[],
            gaps=[],
        )
        _set_mock_response(mock_gemini, canned_ledger.model_dump_json())

        ledger = await GeminiPlanner().aggregate_assessments({}, {}, [])

        assert isinstance(ledger, EligibilityLedger)
        assert ledger.overall_assessment == expected_assessment
|
|
|
|
class TestGeminiGapAnalysis:
    """Test gap analysis."""

    @pytest.mark.asyncio
    async def test_analyze_gaps_returns_list(self, mock_gemini):
        """analyze_gaps should return list of gap dicts."""
        _set_mock_response(
            mock_gemini,
            json.dumps(
                {
                    "gaps": [
                        {
                            "description": "Brain MRI results needed",
                            "recommended_action": "Upload brain MRI report",
                            "clinical_importance": "high",
                            "affected_trial_count": 2,
                        },
                    ]
                }
            ),
        )

        planner = GeminiPlanner()
        result = await planner.analyze_gaps({}, [])

        assert len(result) == 1
        assert result[0]["clinical_importance"] == "high"
        assert result[0]["affected_trial_count"] == 2

    @pytest.mark.asyncio
    async def test_analyze_gaps_uses_json_schema(self, mock_gemini):
        """analyze_gaps should pass JSON schema to Gemini."""
        _set_mock_response(mock_gemini, json.dumps({"gaps": []}))

        planner = GeminiPlanner()
        await planner.analyze_gaps({}, [])

        call_args = mock_gemini.call_args
        # Fail with a clear message rather than an AttributeError on None
        # if the planner never reached generate_content.
        assert call_args is not None, "generate_content was never called"
        # call_args[1] is the same dict as call_args.kwargs, so one lookup suffices.
        config = call_args.kwargs.get("config", {})
        assert config.get("response_mime_type") == "application/json"
        assert "response_json_schema" in config
|
|