"""Shared pytest fixtures for TrialPath test suite.""" from __future__ import annotations import os from unittest.mock import AsyncMock, MagicMock, patch import pytest try: from dotenv import load_dotenv load_dotenv() except ImportError: pass from app.services.mock_data import ( MOCK_ELIGIBILITY_LEDGERS, MOCK_PATIENT_PROFILE, MOCK_TRIAL_CANDIDATES, ) from trialpath.models import ( EligibilityLedger, PatientProfile, SearchAnchors, TrialCandidate, ) # --------------------------------------------------------------------------- # Sample data fixtures # --------------------------------------------------------------------------- @pytest.fixture() def sample_profile() -> PatientProfile: """Return the shared mock patient profile.""" return MOCK_PATIENT_PROFILE @pytest.fixture() def sample_trials() -> list[TrialCandidate]: """Return the shared mock trial candidates.""" return list(MOCK_TRIAL_CANDIDATES) @pytest.fixture() def sample_ledgers() -> list[EligibilityLedger]: """Return the shared mock eligibility ledgers.""" return list(MOCK_ELIGIBILITY_LEDGERS) @pytest.fixture() def sample_anchors(sample_profile: PatientProfile) -> SearchAnchors: """Build SearchAnchors from the mock profile.""" assert sample_profile.diagnosis is not None assert sample_profile.performance_status is not None return SearchAnchors( condition=sample_profile.diagnosis.primary_condition, subtype=sample_profile.diagnosis.histology, biomarkers=[b.name for b in sample_profile.biomarkers], stage=sample_profile.diagnosis.stage, age=sample_profile.demographics.age, performance_status_max=sample_profile.performance_status.value, ) # --------------------------------------------------------------------------- # Service mock fixtures # --------------------------------------------------------------------------- @pytest.fixture() def mock_medgemma(): """Patch MedGemmaExtractor with a mock that returns sample profile data.""" with patch("trialpath.services.medgemma_extractor.MedGemmaExtractor") as cls: instance = 
MagicMock() instance.extract = AsyncMock(return_value=MOCK_PATIENT_PROFILE) instance.evaluate_medical_criterion = AsyncMock( return_value={ "decision": "met", "confidence": 0.9, "reasoning": "Criterion satisfied based on profile data.", } ) cls.return_value = instance yield instance @pytest.fixture() def mock_gemini(): """Patch GeminiPlanner with a mock that returns structured outputs.""" with patch("trialpath.services.gemini_planner.GeminiPlanner") as cls: instance = MagicMock() instance.generate_search_anchors = AsyncMock( return_value=SearchAnchors( condition="Non-Small Cell Lung Cancer", biomarkers=["EGFR"], stage="IIIB", ) ) instance.evaluate_eligibility = AsyncMock( return_value={ "overall_assessment": "uncertain", "criteria": [], } ) instance.refine_search = AsyncMock( return_value=SearchAnchors( condition="NSCLC", biomarkers=["EGFR"], stage="IIIB", ) ) instance.relax_search = AsyncMock( return_value=SearchAnchors( condition="Lung Cancer", ) ) instance.slice_criteria = AsyncMock( return_value=[ {"text": "Age >= 18", "type": "structural"}, {"text": "EGFR mutation positive", "type": "medical"}, ] ) instance.evaluate_structural_criterion = AsyncMock( return_value={ "decision": "met", "confidence": 0.95, "reasoning": "Patient is 62, meets age requirement.", } ) instance.aggregate_assessments = AsyncMock(return_value=MOCK_ELIGIBILITY_LEDGERS[0]) instance.analyze_gaps = AsyncMock( return_value=[ { "description": "Brain MRI status unknown", "recommended_action": "Order brain MRI", "clinical_importance": "high", } ] ) cls.return_value = instance yield instance @pytest.fixture() def mock_mcp(): """Patch ClinicalTrialsMCPClient with a mock.""" with patch("trialpath.services.mcp_client.ClinicalTrialsMCPClient") as cls: instance = AsyncMock() instance.search_studies.return_value = { "studies": [t.model_dump() for t in MOCK_TRIAL_CANDIDATES] } instance.get_study.return_value = MOCK_TRIAL_CANDIDATES[0].model_dump() cls.return_value = instance yield instance # 
# ---------------------------------------------------------------------------
# Live service fixtures (require real API keys / running servers)
# ---------------------------------------------------------------------------


def _skip_unless_env(var_name: str, reason: str) -> None:
    """Call ``pytest.skip(reason)`` when env var *var_name* is unset or empty."""
    if os.environ.get(var_name):
        return
    pytest.skip(reason)


@pytest.fixture(scope="session")
def live_env():
    """Guard fixture for live tests: skip unless ``GEMINI_API_KEY`` is set.

    Only reads ``os.environ``; it does not load any .env file itself and
    returns nothing. Other live fixtures request it for this skip check.
    """
    _skip_unless_env("GEMINI_API_KEY", "GEMINI_API_KEY not set — skipping live tests")


@pytest.fixture(scope="session")
def live_gemini(live_env):
    """Provide one real GeminiPlanner (talks to the Gemini API) per session."""
    from trialpath.services.gemini_planner import GeminiPlanner

    planner = GeminiPlanner()
    return planner


@pytest.fixture(scope="session")
def live_mcp_client(live_env):
    """Provide one real ClinicalTrialsMCPClient per session."""
    from trialpath.services.mcp_client import ClinicalTrialsMCPClient

    client = ClinicalTrialsMCPClient()
    return client


@pytest.fixture(scope="session")
def live_medgemma(live_env):
    """Provide one real MedGemmaExtractor per session.

    Besides the ``live_env`` guard, this also skips when ``HF_TOKEN`` is unset.
    """
    _skip_unless_env("HF_TOKEN", "HF_TOKEN not set — skipping MedGemma live tests")
    from trialpath.services.medgemma_extractor import MedGemmaExtractor

    extractor = MedGemmaExtractor()
    return extractor