| """Service integration tests: verify service chains produce correct outputs.""" |
|
|
| from __future__ import annotations |
|
|
| import pytest |
|
|
| from app.services.mock_data import ( |
| MOCK_PATIENT_PROFILE, |
| MOCK_TRIAL_CANDIDATES, |
| ) |
| from trialpath.models import ( |
| EligibilityLedger, |
| PatientProfile, |
| SearchAnchors, |
| ) |
|
|
|
|
class TestMedGemmaToProfileRoundtrip:
    """Chain under test: MedGemma extraction -> PatientProfile validation.

    Every test requests the ``mock_medgemma`` fixture, which is defined
    outside this module (presumably it stubs the model backend -- confirm
    in conftest). The extractor is imported inside each test, after the
    fixture has been set up.
    """

    @pytest.mark.asyncio
    async def test_extracted_profile_is_valid_model(self, mock_medgemma):
        """The extractor returns a PatientProfile whose patient_id is set."""
        from trialpath.services.medgemma_extractor import MedGemmaExtractor

        extracted = await MedGemmaExtractor().extract(["dummy.pdf"], {})

        assert isinstance(extracted, PatientProfile)
        assert extracted.patient_id is not None

    @pytest.mark.asyncio
    async def test_extracted_profile_has_prescreen_data(self, mock_medgemma):
        """The extracted profile reports having the minimum prescreen data."""
        from trialpath.services.medgemma_extractor import MedGemmaExtractor

        extracted = await MedGemmaExtractor().extract(["dummy.pdf"], {})

        assert extracted.has_minimum_prescreen_data()

    @pytest.mark.asyncio
    async def test_profile_serialization_roundtrip(self, mock_medgemma):
        """A profile dumped to JSON and parsed back keeps its id and age."""
        from trialpath.services.medgemma_extractor import MedGemmaExtractor

        original = await MedGemmaExtractor().extract(["dummy.pdf"], {})

        # Serialize and immediately re-validate through the model class.
        payload = original.model_dump_json()
        roundtripped = PatientProfile.model_validate_json(payload)

        assert roundtripped.patient_id == original.patient_id
        assert roundtripped.demographics.age == original.demographics.age
|
|
|
|
class TestGeminiSearchAnchorChain:
    """Chain under test: profile -> SearchAnchors -> search param generation.

    All tests rely on the ``mock_gemini`` fixture defined outside this
    module; the planner is imported inside each test body.
    """

    @pytest.mark.asyncio
    async def test_anchors_from_profile(self, mock_gemini, sample_profile):
        """The planner turns a profile into SearchAnchors with a condition."""
        from trialpath.services.gemini_planner import GeminiPlanner

        produced = await GeminiPlanner().generate_search_anchors(sample_profile)

        assert isinstance(produced, SearchAnchors)
        assert produced.condition is not None

    @pytest.mark.asyncio
    async def test_refine_narrows_results(self, mock_gemini, sample_anchors):
        """refine_search on an over-full result set returns SearchAnchors.

        Only the return type is checked here; the test does not compare the
        refined anchors against the input anchors.
        """
        from trialpath.services.gemini_planner import GeminiPlanner

        tightened = await GeminiPlanner().refine_search(
            sample_anchors,
            result_count=100,
            search_log=[],
        )

        assert isinstance(tightened, SearchAnchors)

    @pytest.mark.asyncio
    async def test_relax_broadens_results(self, mock_gemini, sample_anchors):
        """relax_search on an empty result set returns SearchAnchors.

        Only the return type is checked here; the test does not compare the
        relaxed anchors against the input anchors.
        """
        from trialpath.services.gemini_planner import GeminiPlanner

        loosened = await GeminiPlanner().relax_search(
            sample_anchors,
            result_count=0,
            search_log=[],
        )

        assert isinstance(loosened, SearchAnchors)
|
|
|
|
class TestDualModelEligibility:
    """Dual-model eligibility evaluation: MedGemma (medical) + Gemini (structural).

    Inputs come from the module-level mock data (``MOCK_PATIENT_PROFILE`` and
    ``MOCK_TRIAL_CANDIDATES``); the services are imported inside each test.
    """

    @pytest.mark.asyncio
    async def test_slice_criteria_returns_typed_list(self, mock_gemini):
        """Slicing a trial yields a non-empty list of typed criteria."""
        from trialpath.services.gemini_planner import GeminiPlanner

        gemini = GeminiPlanner()
        first_trial = MOCK_TRIAL_CANDIDATES[0]
        sliced = await gemini.slice_criteria(first_trial.model_dump())

        assert isinstance(sliced, list)
        assert len(sliced) > 0

        # Each criterion must be tagged with one of the two model routes.
        allowed_types = ("medical", "structural")
        for criterion in sliced:
            assert "type" in criterion
            assert criterion["type"] in allowed_types

    @pytest.mark.asyncio
    async def test_medical_criterion_via_medgemma(self, mock_medgemma):
        """MedGemma returns a known decision plus a confidence entry."""
        from trialpath.services.medgemma_extractor import MedGemmaExtractor

        medgemma = MedGemmaExtractor()
        verdict = await medgemma.evaluate_medical_criterion(
            criterion_text="EGFR mutation positive",
            patient_profile=MOCK_PATIENT_PROFILE,
            evidence_docs=[],
        )

        decision = verdict["decision"]
        assert decision in ("met", "not_met", "unknown")
        assert "confidence" in verdict

    @pytest.mark.asyncio
    async def test_structural_criterion_via_gemini(self, mock_gemini):
        """Gemini returns a known decision for a structural criterion."""
        from trialpath.services.gemini_planner import GeminiPlanner

        gemini = GeminiPlanner()
        verdict = await gemini.evaluate_structural_criterion(
            criterion_text="Age >= 18",
            patient_profile=MOCK_PATIENT_PROFILE,
        )

        decision = verdict["decision"]
        assert decision in ("met", "not_met", "unknown")

    @pytest.mark.asyncio
    async def test_aggregate_produces_ledger(self, mock_gemini):
        """Aggregating two assessments yields an EligibilityLedger with an NCT id."""
        from trialpath.services.gemini_planner import GeminiPlanner

        gemini = GeminiPlanner()
        trial_payload = MOCK_TRIAL_CANDIDATES[0].model_dump()
        per_criterion = [
            {"criterion": "Age >= 18", "decision": "met", "confidence": 0.95},
            {"criterion": "EGFR+", "decision": "met", "confidence": 0.9},
        ]

        result = await gemini.aggregate_assessments(
            profile=MOCK_PATIENT_PROFILE,
            trial=trial_payload,
            assessments=per_criterion,
        )

        assert isinstance(result, EligibilityLedger)
        assert result.nct_id is not None
|
|
|
|
class TestGapAnalysisAggregation:
    """Gap analysis across multiple ledgers.

    Uses the ``sample_profile`` / ``sample_ledgers`` fixtures, which are
    defined outside this module. Each ledger is read through its ``gaps``
    attribute and each gap through its ``description`` attribute.
    """

    @pytest.mark.asyncio
    async def test_gaps_identified_from_ledgers(self, mock_gemini, sample_profile, sample_ledgers):
        """Verify gap analysis returns a list of actionable items.

        Every returned item must contain both a ``"description"`` and a
        ``"recommended_action"`` key. An empty list is accepted.
        """
        from trialpath.services.gemini_planner import GeminiPlanner

        planner = GeminiPlanner()
        gaps = await planner.analyze_gaps(sample_profile, sample_ledgers)

        assert isinstance(gaps, list)
        for gap in gaps:
            assert "description" in gap
            assert "recommended_action" in gap

    def test_gap_deduplication_across_ledgers(self, sample_ledgers):
        """Verify gaps deduplicate by description across multiple ledgers.

        The first gap seen for each description is kept, in encounter order.
        The test then pins the dedup contract: surviving descriptions are
        unique, every description present in the ledgers is still
        represented, and the result is never longer than the input.
        """
        every_gap = [gap for ledger in sample_ledgers for gap in ledger.gaps]

        # setdefault keeps the FIRST gap per description; dicts preserve
        # insertion order, so encounter order is retained.
        first_by_description = {}
        for gap in every_gap:
            first_by_description.setdefault(gap.description, gap)
        deduped = list(first_by_description.values())

        descriptions = [gap.description for gap in deduped]

        # No description may survive more than once.
        assert len(descriptions) == len(set(descriptions))
        # Deduplication must not drop any distinct description.
        assert set(descriptions) == {gap.description for gap in every_gap}
        # Size bound kept from the original test.
        assert len(deduped) <= len(every_gap)
|
|
|
|
class TestFullPipelineChain:
    """End-to-end data flow through the service chain with mocked services."""

    @pytest.mark.asyncio
    async def test_profile_to_anchors_to_search_to_evaluate(
        self, mock_medgemma, mock_gemini, sample_profile
    ):
        """Walk the full chain: extract -> anchors -> search -> evaluate.

        ``sample_profile`` is requested but not read in the body; it stays
        in the signature so pytest sets up the same fixtures as before.
        """
        from trialpath.services.gemini_planner import GeminiPlanner
        from trialpath.services.medgemma_extractor import MedGemmaExtractor

        # Stage 1: extraction produces a profile.
        medgemma = MedGemmaExtractor()
        extracted = await medgemma.extract(["dummy.pdf"], {})
        assert isinstance(extracted, PatientProfile)

        # Stage 2: the profile is turned into search anchors.
        gemini = GeminiPlanner()
        search_anchors = await gemini.generate_search_anchors(extracted)
        assert isinstance(search_anchors, SearchAnchors)

        # Stage 3: one mock trial is evaluated into a ledger.
        candidate = MOCK_TRIAL_CANDIDATES[0]
        single_assessment = [
            {"criterion": "Age >= 18", "decision": "met", "confidence": 0.95},
        ]
        outcome = await gemini.aggregate_assessments(
            profile=extracted,
            trial=candidate.model_dump(),
            assessments=single_assessment,
        )
        assert isinstance(outcome, EligibilityLedger)

    def test_full_data_contracts_compatible(self, sample_profile, sample_trials, sample_ledgers):
        """Check that profile, trial and ledger fixtures fit together."""
        diagnosis = sample_profile.diagnosis
        assert diagnosis is not None

        # Anchors can be built directly from the profile's fields.
        condition = sample_profile.diagnosis.primary_condition
        biomarker_names = [marker.name for marker in sample_profile.biomarkers]
        stage = sample_profile.diagnosis.stage
        built = SearchAnchors(
            condition=condition,
            biomarkers=biomarker_names,
            stage=stage,
        )
        assert built.condition is not None

        # Every sample trial must have a ledger with a valid traffic light.
        ledgers_by_nct = {entry.nct_id: entry for entry in sample_ledgers}
        for candidate in sample_trials:
            assert candidate.nct_id in ledgers_by_nct
            matched = ledgers_by_nct[candidate.nct_id]
            assert matched.traffic_light in ("green", "yellow", "red")

        # The ledgers together must expose at least one gap.
        collected = [gap for entry in sample_ledgers for gap in entry.gaps]
        assert len(collected) >= 1
|
|