TrialPath / tests /test_service_integration.py
yakilee's picture
style: apply ruff format to entire codebase
e46883d
"""Service integration tests: verify service chains produce correct outputs."""
from __future__ import annotations
import pytest
from app.services.mock_data import (
MOCK_PATIENT_PROFILE,
MOCK_TRIAL_CANDIDATES,
)
from trialpath.models import (
EligibilityLedger,
PatientProfile,
SearchAnchors,
)
class TestMedGemmaToProfileRoundtrip:
    """MedGemma extraction → PatientProfile validation chain."""

    @staticmethod
    async def _extract_profile():
        """Run MedGemmaExtractor.extract on a placeholder document list.

        Callers request the ``mock_medgemma`` fixture themselves; this helper
        only removes the construct-and-call boilerplate shared by the tests.
        """
        from trialpath.services.medgemma_extractor import MedGemmaExtractor

        return await MedGemmaExtractor().extract(["dummy.pdf"], {})

    @pytest.mark.asyncio
    async def test_extracted_profile_is_valid_model(self, mock_medgemma):
        """Extractor output must be a PatientProfile that carries a patient_id."""
        extracted = await self._extract_profile()
        assert isinstance(extracted, PatientProfile)
        assert extracted.patient_id is not None

    @pytest.mark.asyncio
    async def test_extracted_profile_has_prescreen_data(self, mock_medgemma):
        """Extractor output must satisfy has_minimum_prescreen_data()."""
        extracted = await self._extract_profile()
        assert extracted.has_minimum_prescreen_data()

    @pytest.mark.asyncio
    async def test_profile_serialization_roundtrip(self, mock_medgemma):
        """A profile dumped to JSON and re-validated keeps its patient_id and age."""
        original = await self._extract_profile()
        restored = PatientProfile.model_validate_json(original.model_dump_json())
        assert restored.patient_id == original.patient_id
        assert restored.demographics.age == original.demographics.age
class TestGeminiSearchAnchorChain:
    """Profile → SearchAnchors → search param generation."""

    @pytest.mark.asyncio
    async def test_anchors_from_profile(self, mock_gemini, sample_profile):
        """generate_search_anchors must return SearchAnchors with a condition set."""
        from trialpath.services.gemini_planner import GeminiPlanner

        generated = await GeminiPlanner().generate_search_anchors(sample_profile)
        assert isinstance(generated, SearchAnchors)
        assert generated.condition is not None

    @pytest.mark.asyncio
    async def test_refine_narrows_results(self, mock_gemini, sample_anchors):
        """refine_search (called with result_count=100) must return SearchAnchors.

        NOTE(review): despite the test name, no narrowing is asserted here —
        only the return type is checked.
        """
        from trialpath.services.gemini_planner import GeminiPlanner

        narrowed = await GeminiPlanner().refine_search(
            sample_anchors,
            result_count=100,
            search_log=[],
        )
        assert isinstance(narrowed, SearchAnchors)

    @pytest.mark.asyncio
    async def test_relax_broadens_results(self, mock_gemini, sample_anchors):
        """relax_search (called with result_count=0) must return SearchAnchors.

        NOTE(review): despite the test name, no broadening is asserted here —
        only the return type is checked.
        """
        from trialpath.services.gemini_planner import GeminiPlanner

        broadened = await GeminiPlanner().relax_search(
            sample_anchors,
            result_count=0,
            search_log=[],
        )
        assert isinstance(broadened, SearchAnchors)
class TestDualModelEligibility:
    """Dual-model eligibility evaluation: MedGemma (medical) + Gemini (structural)."""

    @pytest.mark.asyncio
    async def test_slice_criteria_returns_typed_list(self, mock_gemini):
        """slice_criteria must return a non-empty list of medical/structural criteria."""
        from trialpath.services.gemini_planner import GeminiPlanner

        first_trial = MOCK_TRIAL_CANDIDATES[0]
        sliced = await GeminiPlanner().slice_criteria(first_trial.model_dump())

        assert isinstance(sliced, list)
        assert sliced  # at least one criterion came back
        for criterion in sliced:
            assert "type" in criterion
            assert criterion["type"] in ("medical", "structural")

    @pytest.mark.asyncio
    async def test_medical_criterion_via_medgemma(self, mock_medgemma):
        """MedGemma must return a met/not_met/unknown decision plus a confidence."""
        from trialpath.services.medgemma_extractor import MedGemmaExtractor

        verdict = await MedGemmaExtractor().evaluate_medical_criterion(
            criterion_text="EGFR mutation positive",
            patient_profile=MOCK_PATIENT_PROFILE,
            evidence_docs=[],
        )
        assert verdict["decision"] in ("met", "not_met", "unknown")
        assert "confidence" in verdict

    @pytest.mark.asyncio
    async def test_structural_criterion_via_gemini(self, mock_gemini):
        """Gemini must return a met/not_met/unknown decision for a structural rule."""
        from trialpath.services.gemini_planner import GeminiPlanner

        verdict = await GeminiPlanner().evaluate_structural_criterion(
            criterion_text="Age >= 18",
            patient_profile=MOCK_PATIENT_PROFILE,
        )
        assert verdict["decision"] in ("met", "not_met", "unknown")

    @pytest.mark.asyncio
    async def test_aggregate_produces_ledger(self, mock_gemini):
        """Aggregating per-criterion results must yield an EligibilityLedger with an nct_id."""
        from trialpath.services.gemini_planner import GeminiPlanner

        per_criterion = [
            {"criterion": "Age >= 18", "decision": "met", "confidence": 0.95},
            {"criterion": "EGFR+", "decision": "met", "confidence": 0.9},
        ]
        ledger = await GeminiPlanner().aggregate_assessments(
            profile=MOCK_PATIENT_PROFILE,
            trial=MOCK_TRIAL_CANDIDATES[0].model_dump(),
            assessments=per_criterion,
        )
        assert isinstance(ledger, EligibilityLedger)
        assert ledger.nct_id is not None
class TestGapAnalysisAggregation:
    """Gap analysis across multiple ledgers."""

    @pytest.mark.asyncio
    async def test_gaps_identified_from_ledgers(self, mock_gemini, sample_profile, sample_ledgers):
        """analyze_gaps must return a list whose items carry a description and an action."""
        from trialpath.services.gemini_planner import GeminiPlanner

        planner = GeminiPlanner()
        gaps = await planner.analyze_gaps(sample_profile, sample_ledgers)
        assert isinstance(gaps, list)
        for gap in gaps:
            assert "description" in gap
            assert "recommended_action" in gap

    def test_gap_deduplication_across_ledgers(self, sample_ledgers):
        """Deduplicating gaps by description keeps exactly one entry per distinct gap.

        NOTE(review): the dedup logic lives in this test rather than in a
        service call; point it at the production helper if one exists.
        """
        all_gaps = []
        seen = set()
        for ledger in sample_ledgers:
            for gap in ledger.gaps:
                if gap.description not in seen:
                    seen.add(gap.description)
                    all_gaps.append(gap)

        descriptions = [gap.description for gap in all_gaps]
        # No description may survive twice after deduplication...
        assert len(descriptions) == len(set(descriptions))
        # ...and no distinct gap may be dropped along the way.
        expected = {gap.description for lg in sample_ledgers for gap in lg.gaps}
        assert set(descriptions) == expected
        # Unique gaps can never outnumber the raw total across all ledgers.
        total = sum(len(lg.gaps) for lg in sample_ledgers)
        assert len(all_gaps) <= total
class TestFullPipelineChain:
    """End-to-end data flow through service chain with mocked services."""

    @pytest.mark.asyncio
    async def test_profile_to_anchors_to_search_to_evaluate(self, mock_medgemma, mock_gemini):
        """Verify chain: extract → anchors → evaluate.

        No live trial search runs here; the first entry of
        MOCK_TRIAL_CANDIDATES stands in for the search result.
        """
        from trialpath.services.gemini_planner import GeminiPlanner
        from trialpath.services.medgemma_extractor import MedGemmaExtractor

        # Step 1: extract a profile from a placeholder document.
        extractor = MedGemmaExtractor()
        profile = await extractor.extract(["dummy.pdf"], {})
        assert isinstance(profile, PatientProfile)

        # Step 2: derive search anchors from the extracted profile.
        planner = GeminiPlanner()
        anchors = await planner.generate_search_anchors(profile)
        assert isinstance(anchors, SearchAnchors)

        # Step 3: aggregate a (hand-written) assessment into a ledger for a
        # mock trial, which substitutes for the search step.
        trial = MOCK_TRIAL_CANDIDATES[0]
        ledger = await planner.aggregate_assessments(
            profile=profile,
            trial=trial.model_dump(),
            assessments=[{"criterion": "Age >= 18", "decision": "met", "confidence": 0.95}],
        )
        assert isinstance(ledger, EligibilityLedger)

    def test_full_data_contracts_compatible(self, sample_profile, sample_trials, sample_ledgers):
        """Verify profile, anchors, trials, ledgers and gaps interoperate end-to-end."""
        diagnosis = sample_profile.diagnosis
        assert diagnosis is not None

        # Profile → SearchAnchors
        anchors = SearchAnchors(
            condition=diagnosis.primary_condition,
            biomarkers=[b.name for b in sample_profile.biomarkers],
            stage=diagnosis.stage,
        )
        assert anchors.condition is not None

        # Every trial must have a ledger with a recognised traffic light.
        ledger_map = {lg.nct_id: lg for lg in sample_ledgers}
        for trial in sample_trials:
            assert trial.nct_id in ledger_map
            assert ledger_map[trial.nct_id].traffic_light in ("green", "yellow", "red")

        # The mock data is expected to surface at least one gap.
        all_gaps = [gap for lg in sample_ledgers for gap in lg.gaps]
        assert len(all_gaps) >= 1