"""TDD tests for Gemini planner service.""" import json from unittest.mock import MagicMock, patch import pytest from trialpath.models.eligibility_ledger import EligibilityLedger, OverallAssessment from trialpath.models.search_anchors import SearchAnchors from trialpath.services.gemini_planner import GeminiPlanner @pytest.fixture def sample_profile(): return { "patient_id": "P001", "demographics": {"age": 52, "sex": "female"}, "diagnosis": { "primary_condition": "Non-Small Cell Lung Cancer", "histology": "adenocarcinoma", "stage": "IVa", }, "biomarkers": [ {"name": "EGFR", "result": "Exon 19 deletion"}, ], "performance_status": {"scale": "ECOG", "value": 1}, } @pytest.fixture def mock_gemini(): """Fixture that patches Gemini client for all tests.""" with patch("google.genai.Client") as MockClient: mock_generate = MagicMock() MockClient.return_value.models.generate_content = mock_generate yield mock_generate def _set_mock_response(mock_generate, text: str): mock_response = MagicMock() mock_response.text = text mock_generate.return_value = mock_response class TestGeminiSearchAnchorsGeneration: """Test Gemini structured output for SearchAnchors generation.""" @pytest.mark.asyncio async def test_search_anchors_has_correct_condition(self, sample_profile, mock_gemini): """Generated SearchAnchors should reference NSCLC.""" _set_mock_response( mock_gemini, SearchAnchors( condition="Non-Small Cell Lung Cancer", subtype="adenocarcinoma", biomarkers=["EGFR exon 19 deletion"], stage="IV", age=52, performance_status_max=1, ).model_dump_json(), ) planner = GeminiPlanner() anchors = await planner.generate_search_anchors(sample_profile) assert "lung" in anchors.condition.lower() or "nsclc" in anchors.condition.lower() assert anchors.age == 52 @pytest.mark.asyncio async def test_search_anchors_includes_biomarkers(self, sample_profile, mock_gemini): """SearchAnchors should include patient biomarkers.""" _set_mock_response( mock_gemini, SearchAnchors( condition="NSCLC", biomarkers=["EGFR exon 19 deletion"], ).model_dump_json(), ) planner = GeminiPlanner() anchors = await planner.generate_search_anchors(sample_profile) assert len(anchors.biomarkers) > 0 assert any("EGFR" in b for b in anchors.biomarkers) @pytest.mark.asyncio async def test_search_anchors_json_schema_passed(self, sample_profile, mock_gemini): """Verify that Gemini is called with response_json_schema.""" _set_mock_response(mock_gemini, SearchAnchors(condition="NSCLC").model_dump_json()) planner = GeminiPlanner() await planner.generate_search_anchors(sample_profile) call_args = mock_gemini.call_args config = call_args.kwargs.get("config", call_args[1].get("config", {})) assert config.get("response_mime_type") == "application/json" assert "response_json_schema" in config class TestGeminiEligibilityEvaluation: """Test Gemini eligibility evaluation output.""" @pytest.mark.asyncio async def test_ledger_has_all_required_fields(self): """EligibilityLedger from Gemini should have patient_id, nct_id, assessment.""" mock_ledger = EligibilityLedger( patient_id="P001", nct_id="NCT01234567", overall_assessment=OverallAssessment.UNCERTAIN, criteria=[], gaps=[], ) assert mock_ledger.patient_id == "P001" assert mock_ledger.nct_id == "NCT01234567" assert mock_ledger.overall_assessment in OverallAssessment @pytest.mark.asyncio async def test_error_handling_invalid_json(self): """Should raise error on invalid Gemini JSON response.""" with patch("google.genai.Client") as MockClient: mock_response = MagicMock() mock_response.text = "not valid json" MockClient.return_value.models.generate_content = MagicMock(return_value=mock_response) planner = GeminiPlanner() with pytest.raises(Exception): await planner.evaluate_eligibility({}, {}, None) class TestGeminiRefineSearch: """Test search refinement methods.""" @pytest.mark.asyncio async def test_refine_search_returns_search_anchors(self, mock_gemini): """refine_search should return valid SearchAnchors.""" refined = SearchAnchors( condition="NSCLC", biomarkers=["EGFR exon 19 deletion"], trial_filters={"recruitment_status": ["Recruiting"], "phase": ["Phase 3"]}, ) _set_mock_response(mock_gemini, refined.model_dump_json()) planner = GeminiPlanner() anchors = SearchAnchors(condition="NSCLC", biomarkers=["EGFR"]) result = await planner.refine_search(anchors, result_count=100) assert isinstance(result, SearchAnchors) assert result.condition == "NSCLC" @pytest.mark.asyncio async def test_relax_search_returns_search_anchors(self, mock_gemini): """relax_search should return valid SearchAnchors.""" relaxed = SearchAnchors( condition="NSCLC", trial_filters={"recruitment_status": ["Recruiting", "Not yet recruiting"]}, ) _set_mock_response(mock_gemini, relaxed.model_dump_json()) planner = GeminiPlanner() anchors = SearchAnchors(condition="NSCLC", biomarkers=["EGFR"]) result = await planner.relax_search(anchors, result_count=0) assert isinstance(result, SearchAnchors) @pytest.mark.asyncio async def test_refine_passes_result_count_in_prompt(self, mock_gemini): """refine_search should include result_count in prompt.""" _set_mock_response(mock_gemini, SearchAnchors(condition="NSCLC").model_dump_json()) planner = GeminiPlanner() anchors = SearchAnchors(condition="NSCLC") await planner.refine_search(anchors, result_count=150) call_args = mock_gemini.call_args prompt = call_args.kwargs.get("contents", call_args[1].get("contents", "")) assert "150" in prompt class TestGeminiSliceCriteria: """Test criteria slicing.""" @pytest.mark.asyncio async def test_slice_criteria_returns_list(self, mock_gemini): """slice_criteria should return list of criterion dicts.""" _set_mock_response( mock_gemini, json.dumps( { "criteria": [ { "criterion_id": "inc_1", "type": "inclusion", "text": "Age >= 18", "category": "structural", }, { "criterion_id": "inc_2", "type": "inclusion", "text": "Confirmed NSCLC", "category": "medical", }, ] } ), ) planner = GeminiPlanner() result = await planner.slice_criteria({"nct_id": "NCT001"}) assert len(result) == 2 assert result[0]["criterion_id"] == "inc_1" assert result[1]["category"] == "medical" class TestGeminiStructuralCriterion: """Test structural criterion evaluation.""" @pytest.mark.asyncio async def test_evaluate_structural_returns_decision(self, mock_gemini): """evaluate_structural_criterion should return decision dict.""" _set_mock_response( mock_gemini, json.dumps( { "decision": "met", "reasoning": "Patient age 52 >= 18", "confidence": 0.99, } ), ) planner = GeminiPlanner() result = await planner.evaluate_structural_criterion( "Age >= 18 years", {"demographics": {"age": 52}}, ) assert result["decision"] == "met" assert result["confidence"] == 0.99 class TestGeminiAggregateAssessments: """Test assessment aggregation.""" @pytest.mark.asyncio async def test_aggregate_returns_ledger(self, mock_gemini): """aggregate_assessments should return EligibilityLedger.""" ledger = EligibilityLedger( patient_id="P001", nct_id="NCT001", overall_assessment=OverallAssessment.LIKELY_ELIGIBLE, criteria=[], gaps=[], ) _set_mock_response(mock_gemini, ledger.model_dump_json()) planner = GeminiPlanner() result = await planner.aggregate_assessments({}, {}, []) assert isinstance(result, EligibilityLedger) assert result.overall_assessment == OverallAssessment.LIKELY_ELIGIBLE class TestGeminiGapAnalysis: """Test gap analysis.""" @pytest.mark.asyncio async def test_analyze_gaps_returns_list(self, mock_gemini): """analyze_gaps should return list of gap dicts.""" _set_mock_response( mock_gemini, json.dumps( { "gaps": [ { "description": "Brain MRI results needed", "recommended_action": "Upload brain MRI report", "clinical_importance": "high", "affected_trial_count": 2, }, ] } ), ) planner = GeminiPlanner() result = await planner.analyze_gaps({}, []) assert len(result) == 1 assert result[0]["clinical_importance"] == "high" assert result[0]["affected_trial_count"] == 2 @pytest.mark.asyncio async def test_analyze_gaps_uses_json_schema(self, mock_gemini): """analyze_gaps should pass JSON schema to Gemini.""" _set_mock_response(mock_gemini, json.dumps({"gaps": []})) planner = GeminiPlanner() await planner.analyze_gaps({}, []) call_args = mock_gemini.call_args config = call_args.kwargs.get("config", call_args[1].get("config", {})) assert config.get("response_mime_type") == "application/json" assert "response_json_schema" in config