# TrialPath / conftest.py
# Last commit by: yakilee
# Commit message: test: update tests for evidence-linked mock data and new features
# Commit: f8adedd
"""Shared pytest fixtures for TrialPath test suite."""
from __future__ import annotations
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
try:
    from dotenv import load_dotenv
except ImportError:
    # python-dotenv is optional: without it, tests only see variables that
    # are already present in the process environment.
    pass
else:
    # Only the import is guarded, so an ImportError raised from inside
    # load_dotenv() itself is not silently mistaken for "not installed".
    load_dotenv()
from app.services.mock_data import (
MOCK_ELIGIBILITY_LEDGERS,
MOCK_PATIENT_PROFILE,
MOCK_TRIAL_CANDIDATES,
)
from trialpath.models import (
EligibilityLedger,
PatientProfile,
SearchAnchors,
TrialCandidate,
)
# ---------------------------------------------------------------------------
# Sample data fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def sample_profile() -> PatientProfile:
    """Return a private deep copy of the shared mock patient profile.

    Handing out the module-level ``MOCK_PATIENT_PROFILE`` object itself would
    let one test's in-place edits (including to nested fields) leak into
    every later test that uses the same constant; a deep copy isolates them.
    """
    import copy

    return copy.deepcopy(MOCK_PATIENT_PROFILE)
@pytest.fixture()
def sample_trials() -> list[TrialCandidate]:
    """Return deep copies of the shared mock trial candidates.

    ``list(...)`` alone copies only the outer container; the candidate
    objects would still be shared with ``MOCK_TRIAL_CANDIDATES``, so an
    in-place edit in one test could leak into later ones.
    """
    import copy

    return copy.deepcopy(list(MOCK_TRIAL_CANDIDATES))
@pytest.fixture()
def sample_ledgers() -> list[EligibilityLedger]:
    """Return deep copies of the shared mock eligibility ledgers.

    ``list(...)`` alone copies only the outer container; the ledger objects
    would still be shared with ``MOCK_ELIGIBILITY_LEDGERS``, so an in-place
    edit in one test could leak into later ones.
    """
    import copy

    return copy.deepcopy(list(MOCK_ELIGIBILITY_LEDGERS))
@pytest.fixture()
def sample_anchors(sample_profile: PatientProfile) -> SearchAnchors:
    """Derive ``SearchAnchors`` from the fields of ``sample_profile``.

    The profile must carry both a diagnosis and a performance status, since
    both are dereferenced below; otherwise the requesting test fails on the
    ``assert``.
    """
    diagnosis = sample_profile.diagnosis
    perf_status = sample_profile.performance_status
    assert diagnosis is not None
    assert perf_status is not None

    marker_names = [marker.name for marker in sample_profile.biomarkers]
    return SearchAnchors(
        condition=diagnosis.primary_condition,
        subtype=diagnosis.histology,
        biomarkers=marker_names,
        stage=diagnosis.stage,
        age=sample_profile.demographics.age,
        performance_status_max=perf_status.value,
    )
# ---------------------------------------------------------------------------
# Service mock fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def mock_medgemma():
    """Swap ``MedGemmaExtractor`` for a mock with canned async responses.

    While the fixture is active, instantiating
    ``trialpath.services.medgemma_extractor.MedGemmaExtractor`` yields the
    same mock instance that is handed to the test, so calls can be asserted.

    Canned (awaitable) results:
      * ``extract`` -> ``MOCK_PATIENT_PROFILE``
      * ``evaluate_medical_criterion`` -> a "met" verdict, confidence 0.9

    NOTE: ``patch`` replaces the attribute on that module only; code that
    bound its own ``from ... import MedGemmaExtractor`` earlier is unaffected.
    """
    met_verdict = {
        "decision": "met",
        "confidence": 0.9,
        "reasoning": "Criterion satisfied based on profile data.",
    }
    target = "trialpath.services.medgemma_extractor.MedGemmaExtractor"
    with patch(target) as extractor_cls:
        extractor = MagicMock()
        extractor.extract = AsyncMock(return_value=MOCK_PATIENT_PROFILE)
        extractor.evaluate_medical_criterion = AsyncMock(return_value=met_verdict)
        extractor_cls.return_value = extractor
        yield extractor
@pytest.fixture()
def mock_gemini():
    """Swap ``GeminiPlanner`` for a mock whose async methods return canned data.

    While the fixture is active, instantiating
    ``trialpath.services.gemini_planner.GeminiPlanner`` yields the same mock
    instance that is handed to the test.

    NOTE: ``patch`` replaces the attribute on that module only; code that
    bound its own ``from ... import GeminiPlanner`` earlier is unaffected.
    """
    target = "trialpath.services.gemini_planner.GeminiPlanner"
    with patch(target) as planner_cls:
        planner = MagicMock()
        # Method name -> value that awaiting the corresponding mock yields.
        canned_results = {
            "generate_search_anchors": SearchAnchors(
                condition="Non-Small Cell Lung Cancer",
                biomarkers=["EGFR"],
                stage="IIIB",
            ),
            "evaluate_eligibility": {
                "overall_assessment": "uncertain",
                "criteria": [],
            },
            "refine_search": SearchAnchors(
                condition="NSCLC",
                biomarkers=["EGFR"],
                stage="IIIB",
            ),
            "relax_search": SearchAnchors(condition="Lung Cancer"),
            "slice_criteria": [
                {"text": "Age >= 18", "type": "structural"},
                {"text": "EGFR mutation positive", "type": "medical"},
            ],
            "evaluate_structural_criterion": {
                "decision": "met",
                "confidence": 0.95,
                "reasoning": "Patient is 62, meets age requirement.",
            },
            "aggregate_assessments": MOCK_ELIGIBILITY_LEDGERS[0],
            "analyze_gaps": [
                {
                    "description": "Brain MRI status unknown",
                    "recommended_action": "Order brain MRI",
                    "clinical_importance": "high",
                }
            ],
        }
        for method_name, result in canned_results.items():
            setattr(planner, method_name, AsyncMock(return_value=result))
        planner_cls.return_value = planner
        yield planner
@pytest.fixture()
def mock_mcp():
    """Swap ``ClinicalTrialsMCPClient`` for an ``AsyncMock`` with canned results.

    While the fixture is active, instantiating
    ``trialpath.services.mcp_client.ClinicalTrialsMCPClient`` yields the same
    mock instance that is handed to the test.

    Awaited results:
      * ``search_studies(...)`` -> ``{"studies": [<dumped mock trial>, ...]}``
      * ``get_study(...)`` -> the dump of the first mock trial
    """
    studies_payload = {
        "studies": [candidate.model_dump() for candidate in MOCK_TRIAL_CANDIDATES]
    }
    first_study = MOCK_TRIAL_CANDIDATES[0].model_dump()

    target = "trialpath.services.mcp_client.ClinicalTrialsMCPClient"
    with patch(target) as client_cls:
        client = AsyncMock()
        # Attributes of an AsyncMock are AsyncMocks too, so ``return_value``
        # is what awaiting each method call produces.
        client.search_studies.return_value = studies_payload
        client.get_study.return_value = first_study
        client_cls.return_value = client
        yield client
# ---------------------------------------------------------------------------
# Live service fixtures (require real API keys / running servers)
# ---------------------------------------------------------------------------
@pytest.fixture(scope="session")
def live_env():
    """Gate live tests on ``GEMINI_API_KEY`` being set and non-empty.

    Every test that requests this fixture, directly or through another
    fixture, is skipped when the key is missing. Tests that do not depend on
    it run normally. The fixture itself provides no value (returns None).
    """
    gemini_key = os.environ.get("GEMINI_API_KEY")
    if gemini_key:
        return
    pytest.skip("GEMINI_API_KEY not set — skipping live tests")
@pytest.fixture(scope="session")
def live_gemini(live_env):
    """Build one real ``GeminiPlanner`` shared across the test session.

    ``live_env`` is requested only for its skip-if-no-key side effect. The
    planner module is imported at fixture setup, not when this conftest is
    imported.
    """
    from trialpath.services.gemini_planner import GeminiPlanner as Planner

    planner = Planner()
    return planner
@pytest.fixture(scope="session")
def live_mcp_client(live_env):
    """Build one real ``ClinicalTrialsMCPClient`` shared across the session.

    Depends on ``live_env``, so it is skipped whenever ``GEMINI_API_KEY`` is
    unset. The client module is imported at fixture setup, not when this
    conftest is imported.
    """
    from trialpath.services.mcp_client import ClinicalTrialsMCPClient as Client

    client = Client()
    return client
@pytest.fixture(scope="session")
def live_medgemma(live_env):
    """Build one real ``MedGemmaExtractor`` shared across the session.

    Skipped unless both ``GEMINI_API_KEY`` (via ``live_env``) and ``HF_TOKEN``
    are set. The extractor module is imported only once ``HF_TOKEN`` has been
    found.
    """
    if os.environ.get("HF_TOKEN"):
        from trialpath.services.medgemma_extractor import MedGemmaExtractor

        return MedGemmaExtractor()
    pytest.skip("HF_TOKEN not set — skipping MedGemma live tests")