File size: 6,225 Bytes
743ac52 e46883d 743ac52 f8adedd 743ac52 f8adedd 743ac52 97aee42 743ac52 e46883d 743ac52 97aee42 e46883d 743ac52 e46883d 743ac52 e46883d 97aee42 e46883d 743ac52 e46883d 743ac52 e46883d 743ac52 f8adedd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 | """Shared pytest fixtures for TrialPath test suite."""
from __future__ import annotations

import copy
import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
# python-dotenv is optional: when it is installed, load a .env file (using
# its default lookup) into os.environ so the live fixtures below can read API
# keys from it; otherwise rely on the ambient environment. Only the import is
# guarded, so a failure while actually loading the file is not swallowed.
try:
    from dotenv import load_dotenv
except ImportError:
    pass
else:
    load_dotenv()
from app.services.mock_data import (
MOCK_ELIGIBILITY_LEDGERS,
MOCK_PATIENT_PROFILE,
MOCK_TRIAL_CANDIDATES,
)
from trialpath.models import (
EligibilityLedger,
PatientProfile,
SearchAnchors,
TrialCandidate,
)
# ---------------------------------------------------------------------------
# Sample data fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def sample_profile() -> PatientProfile:
    """Return a private deep copy of the mock patient profile.

    A copy is handed out so that a test mutating the profile (or its nested
    diagnosis/biomarker data) cannot leak changes into
    ``MOCK_PATIENT_PROFILE``, which other fixtures (e.g. ``mock_medgemma``)
    and later tests also read.
    """
    return copy.deepcopy(MOCK_PATIENT_PROFILE)
@pytest.fixture()
def sample_trials() -> list[TrialCandidate]:
    """Return a fresh, deep-copied list of the mock trial candidates.

    Both the list and every candidate in it are copied, so a test that
    mutates a candidate cannot leak changes into ``MOCK_TRIAL_CANDIDATES``
    (which ``mock_mcp`` also serialises for its canned responses).
    """
    return list(copy.deepcopy(MOCK_TRIAL_CANDIDATES))
@pytest.fixture()
def sample_ledgers() -> list[EligibilityLedger]:
    """Return a fresh, deep-copied list of the mock eligibility ledgers.

    Both the list and every ledger in it are copied, so a test that mutates
    a ledger cannot leak changes into ``MOCK_ELIGIBILITY_LEDGERS`` (whose
    first entry ``mock_gemini`` also returns from ``aggregate_assessments``).
    """
    return list(copy.deepcopy(MOCK_ELIGIBILITY_LEDGERS))
@pytest.fixture()
def sample_anchors(sample_profile: PatientProfile) -> SearchAnchors:
    """Derive SearchAnchors from the ``sample_profile`` fixture.

    Fails fast with an AssertionError if the profile lacks a diagnosis or a
    performance status, since both are dereferenced to build the anchors.
    """
    diagnosis = sample_profile.diagnosis
    performance = sample_profile.performance_status
    assert diagnosis is not None
    assert performance is not None
    biomarker_names = [marker.name for marker in sample_profile.biomarkers]
    return SearchAnchors(
        condition=diagnosis.primary_condition,
        subtype=diagnosis.histology,
        biomarkers=biomarker_names,
        stage=diagnosis.stage,
        age=sample_profile.demographics.age,
        performance_status_max=performance.value,
    )
# ---------------------------------------------------------------------------
# Service mock fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def mock_medgemma():
    """Replace MedGemmaExtractor with a mock for the duration of a test.

    Yields the mock instance that the patched class returns when constructed:
    ``extract`` resolves to ``MOCK_PATIENT_PROFILE`` and
    ``evaluate_medical_criterion`` resolves to a fixed "met" assessment.
    """
    criterion_result = {
        "decision": "met",
        "confidence": 0.9,
        "reasoning": "Criterion satisfied based on profile data.",
    }
    extractor = MagicMock(
        extract=AsyncMock(return_value=MOCK_PATIENT_PROFILE),
        evaluate_medical_criterion=AsyncMock(return_value=criterion_result),
    )
    # NOTE(review): this patches the name in its defining module; code that
    # bound `MedGemmaExtractor` via `from ... import` beforehand keeps the
    # real class — confirm this is where callers look it up.
    with patch(
        "trialpath.services.medgemma_extractor.MedGemmaExtractor"
    ) as extractor_cls:
        extractor_cls.return_value = extractor
        yield extractor
@pytest.fixture()
def mock_gemini():
    """Replace GeminiPlanner with a mock whose async methods return canned data.

    Yields the mock instance that the patched class returns when constructed.
    Each planner method below is an ``AsyncMock`` with a fixed return value:

    * ``generate_search_anchors`` / ``refine_search`` / ``relax_search`` —
      canned ``SearchAnchors``;
    * ``evaluate_eligibility`` — an "uncertain" overall assessment;
    * ``slice_criteria`` — one structural and one medical criterion;
    * ``evaluate_structural_criterion`` — a "met" decision;
    * ``aggregate_assessments`` — the first mock eligibility ledger;
    * ``analyze_gaps`` — a single high-importance gap.
    """
    canned_results = {
        "generate_search_anchors": SearchAnchors(
            condition="Non-Small Cell Lung Cancer",
            biomarkers=["EGFR"],
            stage="IIIB",
        ),
        "evaluate_eligibility": {
            "overall_assessment": "uncertain",
            "criteria": [],
        },
        "refine_search": SearchAnchors(
            condition="NSCLC",
            biomarkers=["EGFR"],
            stage="IIIB",
        ),
        "relax_search": SearchAnchors(condition="Lung Cancer"),
        "slice_criteria": [
            {"text": "Age >= 18", "type": "structural"},
            {"text": "EGFR mutation positive", "type": "medical"},
        ],
        "evaluate_structural_criterion": {
            "decision": "met",
            "confidence": 0.95,
            "reasoning": "Patient is 62, meets age requirement.",
        },
        "aggregate_assessments": MOCK_ELIGIBILITY_LEDGERS[0],
        "analyze_gaps": [
            {
                "description": "Brain MRI status unknown",
                "recommended_action": "Order brain MRI",
                "clinical_importance": "high",
            }
        ],
    }
    with patch("trialpath.services.gemini_planner.GeminiPlanner") as planner_cls:
        planner = MagicMock()
        # Wire every canned result up as an awaitable method on the mock.
        for method_name, result in canned_results.items():
            setattr(planner, method_name, AsyncMock(return_value=result))
        planner_cls.return_value = planner
        yield planner
@pytest.fixture()
def mock_mcp():
    """Replace ClinicalTrialsMCPClient with an AsyncMock for one test.

    Yields the mock instance that the patched class returns when constructed.
    Awaiting ``search_studies`` yields ``{"studies": [...]}`` built from the
    dumped mock trial candidates; awaiting ``get_study`` yields the dump of
    the first candidate.
    """
    studies_payload = {
        "studies": [trial.model_dump() for trial in MOCK_TRIAL_CANDIDATES]
    }
    first_study = MOCK_TRIAL_CANDIDATES[0].model_dump()
    with patch("trialpath.services.mcp_client.ClinicalTrialsMCPClient") as client_cls:
        client = AsyncMock()
        client.search_studies.return_value = studies_payload
        client.get_study.return_value = first_study
        client_cls.return_value = client
        yield client
# ---------------------------------------------------------------------------
# Live service fixtures (require real API keys / running servers)
# ---------------------------------------------------------------------------
@pytest.fixture(scope="session")
def live_env():
    """Gate live-service tests on GEMINI_API_KEY being available.

    Session-scoped: every test or fixture that depends on this one is skipped
    when the variable is unset or empty. Yields no value.
    """
    api_key = os.environ.get("GEMINI_API_KEY")
    if api_key:
        return
    pytest.skip("GEMINI_API_KEY not set — skipping live tests")
@pytest.fixture(scope="session")
def live_gemini(live_env):
    """Provide one real GeminiPlanner shared by the whole test session.

    Depends on ``live_env``, so it is skipped when GEMINI_API_KEY is absent.
    The service module is imported only when this fixture is first requested.
    """
    from trialpath.services.gemini_planner import GeminiPlanner as _Planner

    planner = _Planner()
    return planner
@pytest.fixture(scope="session")
def live_mcp_client(live_env):
    """Provide one real ClinicalTrialsMCPClient shared by the test session.

    Depends on ``live_env``, so it is skipped whenever GEMINI_API_KEY is
    absent. The client module is imported only when this fixture is first
    requested.
    """
    from trialpath.services.mcp_client import ClinicalTrialsMCPClient as _Client

    client = _Client()
    return client
@pytest.fixture(scope="session")
def live_medgemma(live_env):
    """Provide one real MedGemmaExtractor shared by the test session.

    Skipped when GEMINI_API_KEY is absent (via ``live_env``) or when
    HF_TOKEN is unset or empty. The extractor module is imported only after
    both checks pass.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        pytest.skip("HF_TOKEN not set — skipping MedGemma live tests")
    from trialpath.services.medgemma_extractor import MedGemmaExtractor as _Extractor

    extractor = _Extractor()
    return extractor
|