| | """Shared pytest fixtures for TrialPath test suite.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import os |
| | from unittest.mock import AsyncMock, MagicMock, patch |
| |
|
| | import pytest |
| |
|
| | try: |
| | from dotenv import load_dotenv |
| |
|
| | load_dotenv() |
| | except ImportError: |
| | pass |
| |
|
| | from app.services.mock_data import ( |
| | MOCK_ELIGIBILITY_LEDGERS, |
| | MOCK_PATIENT_PROFILE, |
| | MOCK_TRIAL_CANDIDATES, |
| | ) |
| | from trialpath.models import ( |
| | EligibilityLedger, |
| | PatientProfile, |
| | SearchAnchors, |
| | TrialCandidate, |
| | ) |
| |
|
| | |
| | |
| | |
| |
|
| |
|
@pytest.fixture()
def sample_profile() -> PatientProfile:
    """Return an independent deep copy of the mock patient profile.

    ``MOCK_PATIENT_PROFILE`` is a single module-level object. Handing out the
    object itself would let a test that mutates its profile change what every
    later test receives, so each request gets its own copy instead (the list
    fixtures in this file likewise hand out a new list per request).

    Returns:
        A deep copy of ``MOCK_PATIENT_PROFILE``.
    """
    # Function-scope import so this fixture does not depend on the
    # file-level import block.
    import copy

    return copy.deepcopy(MOCK_PATIENT_PROFILE)
| |
|
| |
|
@pytest.fixture()
def sample_trials() -> list[TrialCandidate]:
    """Provide the mock trial candidates as a newly built list.

    The list object is new on every request; its elements are the same
    objects held by ``MOCK_TRIAL_CANDIDATES`` (shallow copy).
    """
    trials = [*MOCK_TRIAL_CANDIDATES]
    return trials
| |
|
| |
|
@pytest.fixture()
def sample_ledgers() -> list[EligibilityLedger]:
    """Provide the mock eligibility ledgers as a newly built list.

    Only the outer list is new; the ledger objects themselves are the ones
    stored in ``MOCK_ELIGIBILITY_LEDGERS``.
    """
    ledgers: list[EligibilityLedger] = []
    ledgers.extend(MOCK_ELIGIBILITY_LEDGERS)
    return ledgers
| |
|
| |
|
@pytest.fixture()
def sample_anchors(sample_profile: PatientProfile) -> SearchAnchors:
    """Derive a ``SearchAnchors`` instance from the ``sample_profile`` fixture.

    Args:
        sample_profile: Profile supplied by the ``sample_profile`` fixture.
            Its ``diagnosis`` and ``performance_status`` must not be ``None``.

    Returns:
        Anchors populated from the profile's diagnosis, biomarker names,
        demographics and performance status.

    Raises:
        AssertionError: If ``diagnosis`` or ``performance_status`` is ``None``.
    """
    diagnosis = sample_profile.diagnosis
    assert diagnosis is not None
    performance_status = sample_profile.performance_status
    assert performance_status is not None

    # Entries are listed in the order the keyword arguments are evaluated.
    anchor_fields = {
        "condition": diagnosis.primary_condition,
        "subtype": diagnosis.histology,
        "biomarkers": [marker.name for marker in sample_profile.biomarkers],
        "stage": diagnosis.stage,
        "age": sample_profile.demographics.age,
        "performance_status_max": performance_status.value,
    }
    return SearchAnchors(**anchor_fields)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@pytest.fixture()
def mock_medgemma():
    """Replace ``MedGemmaExtractor`` with a stub for the duration of a test.

    While the fixture is active, calling the patched class returns the
    yielded stub. The stub's ``extract`` coroutine resolves to
    ``MOCK_PATIENT_PROFILE`` and ``evaluate_medical_criterion`` resolves to a
    fixed "met" verdict dict.

    Yields:
        The ``MagicMock`` instance that the patched class returns.
    """
    # NOTE(review): this patches the name on the defining module; code that
    # did ``from ... import MedGemmaExtractor`` elsewhere may still hold the
    # real class -- confirm against the modules under test.
    target = "trialpath.services.medgemma_extractor.MedGemmaExtractor"
    criterion_verdict = {
        "decision": "met",
        "confidence": 0.9,
        "reasoning": "Criterion satisfied based on profile data.",
    }
    with patch(target) as extractor_cls:
        extractor = MagicMock()
        extractor.extract = AsyncMock(return_value=MOCK_PATIENT_PROFILE)
        extractor.evaluate_medical_criterion = AsyncMock(
            return_value=criterion_verdict
        )
        extractor_cls.return_value = extractor
        yield extractor
| |
|
| |
|
@pytest.fixture()
def mock_gemini():
    """Replace ``GeminiPlanner`` with a stub that returns canned results.

    While the fixture is active, calling the patched class returns the
    yielded stub. Each planner method listed below is an ``AsyncMock`` whose
    awaited result is the fixed value paired with it.

    Yields:
        The ``MagicMock`` instance that the patched class returns.
    """
    with patch("trialpath.services.gemini_planner.GeminiPlanner") as planner_cls:
        # Method name -> value the awaited stub resolves to.
        canned_results = {
            "generate_search_anchors": SearchAnchors(
                condition="Non-Small Cell Lung Cancer",
                biomarkers=["EGFR"],
                stage="IIIB",
            ),
            "evaluate_eligibility": {
                "overall_assessment": "uncertain",
                "criteria": [],
            },
            "refine_search": SearchAnchors(
                condition="NSCLC",
                biomarkers=["EGFR"],
                stage="IIIB",
            ),
            "relax_search": SearchAnchors(
                condition="Lung Cancer",
            ),
            "slice_criteria": [
                {"text": "Age >= 18", "type": "structural"},
                {"text": "EGFR mutation positive", "type": "medical"},
            ],
            "evaluate_structural_criterion": {
                "decision": "met",
                "confidence": 0.95,
                "reasoning": "Patient is 62, meets age requirement.",
            },
            "aggregate_assessments": MOCK_ELIGIBILITY_LEDGERS[0],
            "analyze_gaps": [
                {
                    "description": "Brain MRI status unknown",
                    "recommended_action": "Order brain MRI",
                    "clinical_importance": "high",
                }
            ],
        }

        planner = MagicMock()
        for method_name, result in canned_results.items():
            setattr(planner, method_name, AsyncMock(return_value=result))

        planner_cls.return_value = planner
        yield planner
| |
|
| |
|
@pytest.fixture()
def mock_mcp():
    """Replace ``ClinicalTrialsMCPClient`` with an ``AsyncMock`` stub.

    ``search_studies`` resolves to a dict whose ``"studies"`` entry holds the
    ``model_dump()`` of every mock trial candidate; ``get_study`` resolves to
    the ``model_dump()`` of the first candidate.

    Yields:
        The ``AsyncMock`` instance that the patched class returns.
    """
    target = "trialpath.services.mcp_client.ClinicalTrialsMCPClient"
    with patch(target) as client_cls:
        client = AsyncMock()

        study_payloads = [trial.model_dump() for trial in MOCK_TRIAL_CANDIDATES]
        client.search_studies.return_value = {"studies": study_payloads}

        # Dumped separately so this dict is not the same object as
        # study_payloads[0].
        first_trial = MOCK_TRIAL_CANDIDATES[0]
        client.get_study.return_value = first_trial.model_dump()

        client_cls.return_value = client
        yield client
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@pytest.fixture(scope="session")
def live_env():
    """Gate live tests on the ``GEMINI_API_KEY`` environment variable.

    Skips the requesting test when the variable is unset or empty; otherwise
    does nothing and provides ``None``.
    """
    api_key = os.environ.get("GEMINI_API_KEY")
    if api_key:
        return
    pytest.skip("GEMINI_API_KEY not set — skipping live tests")
| |
|
| |
|
@pytest.fixture(scope="session")
def live_gemini(live_env):
    """Construct a real ``GeminiPlanner`` for live tests.

    Args:
        live_env: Requested only so the ``live_env`` skip check runs first;
            its value is not used.

    Returns:
        A ``GeminiPlanner`` built with no arguments.
    """
    # Imported here so the planner module is only loaded when a test
    # actually requests this fixture.
    from trialpath.services.gemini_planner import GeminiPlanner

    planner = GeminiPlanner()
    return planner
| |
|
| |
|
@pytest.fixture(scope="session")
def live_mcp_client(live_env):
    """Construct a real ``ClinicalTrialsMCPClient`` for live tests.

    Args:
        live_env: Requested only so the ``live_env`` skip check runs first;
            its value is not used.

    Returns:
        A ``ClinicalTrialsMCPClient`` built with no arguments.
    """
    # Imported here so the client module is only loaded when a test
    # actually requests this fixture.
    from trialpath.services.mcp_client import ClinicalTrialsMCPClient

    client = ClinicalTrialsMCPClient()
    return client
| |
|
| |
|
@pytest.fixture(scope="session")
def live_medgemma(live_env):
    """Construct a real ``MedGemmaExtractor`` for live tests.

    Skips the requesting test when ``HF_TOKEN`` is unset or empty. The check
    happens before the extractor module is imported.

    Args:
        live_env: Requested only so the ``live_env`` skip check runs first;
            its value is not used.

    Returns:
        A ``MedGemmaExtractor`` built with no arguments.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        pytest.skip("HF_TOKEN not set — skipping MedGemma live tests")

    # Imported only after the token check has passed.
    from trialpath.services.medgemma_extractor import MedGemmaExtractor

    extractor = MedGemmaExtractor()
    return extractor
| |
|