# TrialPath / tests / test_e2e.py
# yakilee's picture
# style: apply ruff format to entire codebase
# e46883d
"""End-to-end smoke test: full journey with mocked services."""
from __future__ import annotations
import pytest
from app.services.mock_data import (
MOCK_ELIGIBILITY_LEDGERS,
MOCK_PATIENT_PROFILE,
MOCK_TRIAL_CANDIDATES,
)
from app.services.state_manager import JOURNEY_STATES
from trialpath.models import (
EligibilityLedger,
PatientProfile,
SearchAnchors,
TrialCandidate,
)
class TestE2EJourney:
    """Exercise the five-state journey on a plain session dict.

    The expected order is INGEST -> PRESCREEN -> VALIDATE_TRIALS ->
    GAP_FOLLOWUP -> SUMMARY. Each test only reads and writes the dict
    returned by ``_build_session_state``.
    """

    def _build_session_state(self) -> dict:
        """Return a fresh, minimal session dict simulating Streamlit state."""
        return dict(
            journey_state="INGEST",
            parlant_session_id=None,
            parlant_agent_id=None,
            parlant_session_active=False,
            patient_profile=None,
            uploaded_files=[],
            search_anchors=None,
            trial_candidates=[],
            eligibility_ledger=[],
            last_event_offset=0,
        )

    def test_full_journey_state_transitions(self):
        """Step through every transition and check the state before each one."""
        session = self._build_session_state()

        # INGEST -> PRESCREEN: the patient profile is stored.
        assert session["journey_state"] == "INGEST"
        session.update(
            patient_profile=MOCK_PATIENT_PROFILE,
            journey_state="PRESCREEN",
        )

        # PRESCREEN -> VALIDATE_TRIALS: anchors and candidate trials are stored.
        assert session["journey_state"] == "PRESCREEN"
        session.update(
            search_anchors=SearchAnchors(
                condition="Non-Small Cell Lung Cancer",
                biomarkers=["EGFR"],
                stage="IIIB",
            ),
            trial_candidates=list(MOCK_TRIAL_CANDIDATES),
            journey_state="VALIDATE_TRIALS",
        )

        # VALIDATE_TRIALS -> GAP_FOLLOWUP: the eligibility ledgers are stored.
        assert session["journey_state"] == "VALIDATE_TRIALS"
        session.update(
            eligibility_ledger=list(MOCK_ELIGIBILITY_LEDGERS),
            journey_state="GAP_FOLLOWUP",
        )

        # GAP_FOLLOWUP -> SUMMARY: nothing new is stored, only the state moves.
        assert session["journey_state"] == "GAP_FOLLOWUP"
        session["journey_state"] = "SUMMARY"
        assert session["journey_state"] == "SUMMARY"

    def test_journey_produces_exportable_data(self):
        """Check the SUMMARY state holds what the doctor packet export needs."""
        session = self._build_session_state()
        session.update(
            patient_profile=MOCK_PATIENT_PROFILE,
            trial_candidates=list(MOCK_TRIAL_CANDIDATES),
            eligibility_ledger=list(MOCK_ELIGIBILITY_LEDGERS),
            journey_state="SUMMARY",
        )

        assert isinstance(session["patient_profile"], PatientProfile)
        assert len(session["trial_candidates"]) == 3
        ledgers = session["eligibility_ledger"]
        assert len(ledgers) == 3

        # The mock ledgers are expected to contain exactly one of each colour.
        for colour in ("green", "yellow", "red"):
            matching = [lg for lg in ledgers if lg.traffic_light == colour]
            assert len(matching) == 1

    def test_gap_loop_back_to_ingest(self):
        """Check GAP_FOLLOWUP can return to INGEST without losing stored data."""
        session = self._build_session_state()
        session.update(
            patient_profile=MOCK_PATIENT_PROFILE,
            trial_candidates=list(MOCK_TRIAL_CANDIDATES),
            eligibility_ledger=list(MOCK_ELIGIBILITY_LEDGERS),
            journey_state="GAP_FOLLOWUP",
        )

        # The user decides to upload more documents.
        session["journey_state"] = "INGEST"
        assert session["journey_state"] == "INGEST"

        # Data gathered earlier is still there for re-evaluation.
        assert session["patient_profile"] is not None
        assert len(session["trial_candidates"]) == 3

    def test_all_journey_states_reachable(self):
        """Check each of the 5 journey states can be set on the session."""
        session = self._build_session_state()
        seen = []
        for next_state in JOURNEY_STATES:
            session["journey_state"] = next_state
            seen.append(session["journey_state"])

        assert seen == JOURNEY_STATES
        assert len(seen) == 5
class TestE2EWithMockedServices:
    """Follow the data flow through the services with mocked backends."""

    @pytest.mark.asyncio
    async def test_extract_to_search_to_evaluate_chain(self, mock_medgemma, mock_gemini):
        """Run extraction, anchor generation, criterion evaluation and aggregation.

        ``mock_medgemma`` and ``mock_gemini`` are fixtures defined outside this
        file; presumably they stub the model backends -- verify in conftest.
        """
        from trialpath.services.gemini_planner import GeminiPlanner
        from trialpath.services.medgemma_extractor import MedGemmaExtractor

        first_trial = MOCK_TRIAL_CANDIDATES[0]

        # Step 1: extract the patient profile from a document list.
        extractor = MedGemmaExtractor()
        profile = await extractor.extract(["patient_notes.pdf"], {})
        assert isinstance(profile, PatientProfile)

        # Step 2: derive search anchors from that profile.
        planner = GeminiPlanner()
        anchors = await planner.generate_search_anchors(profile)
        assert isinstance(anchors, SearchAnchors)

        # Step 3: slice the first mock trial into individual criteria.
        criteria = await planner.slice_criteria(first_trial.model_dump())
        assert len(criteria) >= 1

        # Step 4: "medical" criteria go to the extractor, all others to the planner.
        assessments = []
        for criterion in criteria:
            is_medical = criterion["type"] == "medical"
            text = criterion["text"]
            verdict = (
                await extractor.evaluate_medical_criterion(text, profile, [])
                if is_medical
                else await planner.evaluate_structural_criterion(text, profile)
            )
            assessments.append(
                {
                    "criterion": text,
                    "decision": verdict["decision"],
                    # Fall back to 0.5 when the evaluator reports no confidence.
                    "confidence": verdict.get("confidence", 0.5),
                }
            )
        assert len(assessments) == len(criteria)

        # Step 5: aggregate the assessments into a ledger (fresh dump of the trial).
        ledger = await planner.aggregate_assessments(
            profile=profile,
            trial=first_trial.model_dump(),
            assessments=assessments,
        )
        assert isinstance(ledger, EligibilityLedger)

    def test_data_contracts_survive_serialization(self):
        """Round-trip every data contract through JSON and compare a key field."""
        # PatientProfile
        profile_copy = PatientProfile.model_validate_json(
            MOCK_PATIENT_PROFILE.model_dump_json()
        )
        assert profile_copy.patient_id == MOCK_PATIENT_PROFILE.patient_id

        # TrialCandidate
        for trial in MOCK_TRIAL_CANDIDATES:
            trial_copy = TrialCandidate.model_validate_json(trial.model_dump_json())
            assert trial_copy.nct_id == trial.nct_id

        # EligibilityLedger
        for ledger in MOCK_ELIGIBILITY_LEDGERS:
            ledger_copy = EligibilityLedger.model_validate_json(ledger.model_dump_json())
            assert ledger_copy.nct_id == ledger.nct_id

        # SearchAnchors
        original_anchors = SearchAnchors(
            condition="NSCLC",
            biomarkers=["EGFR", "ALK"],
            stage="IV",
        )
        anchors_copy = SearchAnchors.model_validate_json(
            original_anchors.model_dump_json()
        )
        assert anchors_copy.condition == "NSCLC"
class TestE2ELatencyBudget:
    """Verify operations complete within latency budget (mocked)."""

    @pytest.mark.asyncio
    async def test_mock_operations_are_fast(self, mock_medgemma, mock_gemini):
        """Run the full mocked chain and require it to finish in under 1 second.

        ``mock_medgemma`` and ``mock_gemini`` are fixtures defined outside this
        file; presumably they stub the model backends -- verify in conftest.
        Only elapsed time is checked: evaluation results are discarded and the
        aggregation step receives an empty assessment list.
        """
        import time

        from trialpath.services.gemini_planner import GeminiPlanner
        from trialpath.services.medgemma_extractor import MedGemmaExtractor

        # perf_counter is the highest-resolution clock for timing a short span;
        # the imports above are deliberately left outside the timed region.
        start = time.perf_counter()

        extractor = MedGemmaExtractor()
        profile = await extractor.extract(["doc.pdf"], {})

        planner = GeminiPlanner()
        await planner.generate_search_anchors(profile)

        criteria = await planner.slice_criteria(MOCK_TRIAL_CANDIDATES[0].model_dump())
        for c in criteria:
            # "medical" criteria go to the extractor, all others to the planner.
            if c["type"] == "medical":
                await extractor.evaluate_medical_criterion(c["text"], profile, [])
            else:
                await planner.evaluate_structural_criterion(c["text"], profile)

        await planner.aggregate_assessments(
            profile=profile,
            trial=MOCK_TRIAL_CANDIDATES[0].model_dump(),
            assessments=[],
        )
        await planner.analyze_gaps(profile, list(MOCK_ELIGIBILITY_LEDGERS))

        elapsed = time.perf_counter() - start

        # With mocks, should complete well under 1 second
        assert elapsed < 1.0, f"Mock pipeline took {elapsed:.2f}s, expected < 1s"