# TrialPath — trialpath/tests/test_gemini.py
# Last touched by yakilee in commit e46883d: "style: apply ruff format to entire codebase"
"""TDD tests for Gemini planner service."""
import json
from unittest.mock import MagicMock, patch
import pytest
from trialpath.models.eligibility_ledger import EligibilityLedger, OverallAssessment
from trialpath.models.search_anchors import SearchAnchors
from trialpath.services.gemini_planner import GeminiPlanner
@pytest.fixture
def sample_profile():
    """Minimal stage-IV NSCLC patient profile shared by the planner tests."""
    profile = {
        "patient_id": "P001",
        "demographics": {"age": 52, "sex": "female"},
        "diagnosis": {
            "primary_condition": "Non-Small Cell Lung Cancer",
            "histology": "adenocarcinoma",
            "stage": "IVa",
        },
        "biomarkers": [{"name": "EGFR", "result": "Exon 19 deletion"}],
        "performance_status": {"scale": "ECOG", "value": 1},
    }
    return profile
@pytest.fixture
def mock_gemini():
    """Patch the google.genai client and yield its generate_content mock."""
    with patch("google.genai.Client") as client_cls:
        generate = MagicMock()
        client_cls.return_value.models.generate_content = generate
        yield generate
def _set_mock_response(mock_generate, text: str):
mock_response = MagicMock()
mock_response.text = text
mock_generate.return_value = mock_response
class TestGeminiSearchAnchorsGeneration:
    """Test Gemini structured output for SearchAnchors generation."""

    @pytest.mark.asyncio
    async def test_search_anchors_has_correct_condition(self, sample_profile, mock_gemini):
        """Generated SearchAnchors should reference NSCLC."""
        _set_mock_response(
            mock_gemini,
            SearchAnchors(
                condition="Non-Small Cell Lung Cancer",
                subtype="adenocarcinoma",
                biomarkers=["EGFR exon 19 deletion"],
                stage="IV",
                age=52,
                performance_status_max=1,
            ).model_dump_json(),
        )
        planner = GeminiPlanner()
        anchors = await planner.generate_search_anchors(sample_profile)
        assert "lung" in anchors.condition.lower() or "nsclc" in anchors.condition.lower()
        assert anchors.age == 52

    @pytest.mark.asyncio
    async def test_search_anchors_includes_biomarkers(self, sample_profile, mock_gemini):
        """SearchAnchors should include patient biomarkers."""
        _set_mock_response(
            mock_gemini,
            SearchAnchors(
                condition="NSCLC",
                biomarkers=["EGFR exon 19 deletion"],
            ).model_dump_json(),
        )
        planner = GeminiPlanner()
        anchors = await planner.generate_search_anchors(sample_profile)
        assert len(anchors.biomarkers) > 0
        assert any("EGFR" in b for b in anchors.biomarkers)

    @pytest.mark.asyncio
    async def test_search_anchors_json_schema_passed(self, sample_profile, mock_gemini):
        """Verify that Gemini is called with response_json_schema."""
        _set_mock_response(mock_gemini, SearchAnchors(condition="NSCLC").model_dump_json())
        planner = GeminiPlanner()
        await planner.generate_search_anchors(sample_profile)
        # call_args.kwargs and call_args[1] are the same dict on a mock Call
        # object, so the previous nested fallback was redundant dead code.
        config = mock_gemini.call_args.kwargs.get("config", {})
        assert config.get("response_mime_type") == "application/json"
        assert "response_json_schema" in config
class TestGeminiEligibilityEvaluation:
    """Test Gemini eligibility evaluation output."""

    @pytest.mark.asyncio
    async def test_ledger_has_all_required_fields(self):
        """EligibilityLedger from Gemini should have patient_id, nct_id, assessment."""
        ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT01234567",
            overall_assessment=OverallAssessment.UNCERTAIN,
            criteria=[],
            gaps=[],
        )
        assert ledger.patient_id == "P001"
        assert ledger.nct_id == "NCT01234567"
        assert ledger.overall_assessment in OverallAssessment

    @pytest.mark.asyncio
    async def test_error_handling_invalid_json(self):
        """Should raise error on invalid Gemini JSON response."""
        with patch("google.genai.Client") as client_cls:
            bad_response = MagicMock()
            bad_response.text = "not valid json"
            client_cls.return_value.models.generate_content = MagicMock(
                return_value=bad_response
            )
            planner = GeminiPlanner()
            with pytest.raises(Exception):
                await planner.evaluate_eligibility({}, {}, None)
class TestGeminiRefineSearch:
    """Test search refinement methods."""

    @pytest.mark.asyncio
    async def test_refine_search_returns_search_anchors(self, mock_gemini):
        """refine_search should return valid SearchAnchors."""
        refined = SearchAnchors(
            condition="NSCLC",
            biomarkers=["EGFR exon 19 deletion"],
            trial_filters={"recruitment_status": ["Recruiting"], "phase": ["Phase 3"]},
        )
        _set_mock_response(mock_gemini, refined.model_dump_json())
        planner = GeminiPlanner()
        anchors = SearchAnchors(condition="NSCLC", biomarkers=["EGFR"])
        result = await planner.refine_search(anchors, result_count=100)
        assert isinstance(result, SearchAnchors)
        assert result.condition == "NSCLC"

    @pytest.mark.asyncio
    async def test_relax_search_returns_search_anchors(self, mock_gemini):
        """relax_search should return valid SearchAnchors."""
        relaxed = SearchAnchors(
            condition="NSCLC",
            trial_filters={"recruitment_status": ["Recruiting", "Not yet recruiting"]},
        )
        _set_mock_response(mock_gemini, relaxed.model_dump_json())
        planner = GeminiPlanner()
        anchors = SearchAnchors(condition="NSCLC", biomarkers=["EGFR"])
        result = await planner.relax_search(anchors, result_count=0)
        assert isinstance(result, SearchAnchors)

    @pytest.mark.asyncio
    async def test_refine_passes_result_count_in_prompt(self, mock_gemini):
        """refine_search should include result_count in prompt."""
        _set_mock_response(mock_gemini, SearchAnchors(condition="NSCLC").model_dump_json())
        planner = GeminiPlanner()
        anchors = SearchAnchors(condition="NSCLC")
        await planner.refine_search(anchors, result_count=150)
        # call_args.kwargs and call_args[1] alias the same dict, so the old
        # `call_args[1].get(...)` fallback could never supply a different value.
        prompt = mock_gemini.call_args.kwargs.get("contents", "")
        assert "150" in prompt
class TestGeminiSliceCriteria:
    """Test criteria slicing."""

    @pytest.mark.asyncio
    async def test_slice_criteria_returns_list(self, mock_gemini):
        """slice_criteria should return list of criterion dicts."""
        payload = {
            "criteria": [
                {
                    "criterion_id": "inc_1",
                    "type": "inclusion",
                    "text": "Age >= 18",
                    "category": "structural",
                },
                {
                    "criterion_id": "inc_2",
                    "type": "inclusion",
                    "text": "Confirmed NSCLC",
                    "category": "medical",
                },
            ]
        }
        _set_mock_response(mock_gemini, json.dumps(payload))
        planner = GeminiPlanner()
        sliced = await planner.slice_criteria({"nct_id": "NCT001"})
        assert len(sliced) == 2
        assert sliced[0]["criterion_id"] == "inc_1"
        assert sliced[1]["category"] == "medical"
class TestGeminiStructuralCriterion:
    """Test structural criterion evaluation."""

    @pytest.mark.asyncio
    async def test_evaluate_structural_returns_decision(self, mock_gemini):
        """evaluate_structural_criterion should return decision dict."""
        decision_payload = {
            "decision": "met",
            "reasoning": "Patient age 52 >= 18",
            "confidence": 0.99,
        }
        _set_mock_response(mock_gemini, json.dumps(decision_payload))
        planner = GeminiPlanner()
        outcome = await planner.evaluate_structural_criterion(
            "Age >= 18 years",
            {"demographics": {"age": 52}},
        )
        assert outcome["decision"] == "met"
        assert outcome["confidence"] == 0.99
class TestGeminiAggregateAssessments:
    """Test assessment aggregation."""

    @pytest.mark.asyncio
    async def test_aggregate_returns_ledger(self, mock_gemini):
        """aggregate_assessments should return EligibilityLedger."""
        expected = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT001",
            overall_assessment=OverallAssessment.LIKELY_ELIGIBLE,
            criteria=[],
            gaps=[],
        )
        _set_mock_response(mock_gemini, expected.model_dump_json())
        planner = GeminiPlanner()
        aggregated = await planner.aggregate_assessments({}, {}, [])
        assert isinstance(aggregated, EligibilityLedger)
        assert aggregated.overall_assessment == OverallAssessment.LIKELY_ELIGIBLE
class TestGeminiGapAnalysis:
    """Test gap analysis."""

    @pytest.mark.asyncio
    async def test_analyze_gaps_returns_list(self, mock_gemini):
        """analyze_gaps should return list of gap dicts."""
        _set_mock_response(
            mock_gemini,
            json.dumps(
                {
                    "gaps": [
                        {
                            "description": "Brain MRI results needed",
                            "recommended_action": "Upload brain MRI report",
                            "clinical_importance": "high",
                            "affected_trial_count": 2,
                        },
                    ]
                }
            ),
        )
        planner = GeminiPlanner()
        result = await planner.analyze_gaps({}, [])
        assert len(result) == 1
        assert result[0]["clinical_importance"] == "high"
        assert result[0]["affected_trial_count"] == 2

    @pytest.mark.asyncio
    async def test_analyze_gaps_uses_json_schema(self, mock_gemini):
        """analyze_gaps should pass JSON schema to Gemini."""
        _set_mock_response(mock_gemini, json.dumps({"gaps": []}))
        planner = GeminiPlanner()
        await planner.analyze_gaps({}, [])
        # call_args.kwargs is the same dict as call_args[1]; the old nested
        # fallback lookup was redundant, so read kwargs directly.
        config = mock_gemini.call_args.kwargs.get("config", {})
        assert config.get("response_mime_type") == "application/json"
        assert "response_json_schema" in config