File size: 6,225 Bytes
743ac52
e46883d
743ac52
 
f8adedd
743ac52
 
 
 
f8adedd
 
 
 
 
 
 
743ac52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97aee42
 
743ac52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e46883d
743ac52
97aee42
e46883d
 
 
 
 
 
 
743ac52
 
 
 
 
 
 
e46883d
743ac52
e46883d
 
 
 
 
 
 
 
 
 
 
 
97aee42
e46883d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
743ac52
e46883d
 
 
 
 
 
 
 
 
 
 
743ac52
 
 
 
 
 
 
e46883d
743ac52
 
 
 
 
 
 
f8adedd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
"""Shared pytest fixtures for TrialPath test suite."""

from __future__ import annotations

import copy
import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

try:
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    pass

from app.services.mock_data import (
    MOCK_ELIGIBILITY_LEDGERS,
    MOCK_PATIENT_PROFILE,
    MOCK_TRIAL_CANDIDATES,
)
from trialpath.models import (
    EligibilityLedger,
    PatientProfile,
    SearchAnchors,
    TrialCandidate,
)

# ---------------------------------------------------------------------------
# Sample data fixtures
# ---------------------------------------------------------------------------


@pytest.fixture()
def sample_profile() -> PatientProfile:
    """Return a private deep copy of the shared mock patient profile.

    ``MOCK_PATIENT_PROFILE`` is a single module-level object. Returning it
    directly would let one test's mutations (including to nested objects
    such as the diagnosis or biomarkers) leak into every later test that
    requests this fixture. A deep copy keeps each test isolated.
    """
    return copy.deepcopy(MOCK_PATIENT_PROFILE)


@pytest.fixture()
def sample_trials() -> list[TrialCandidate]:
    """Return a fresh list of deep-copied mock trial candidates.

    Copying only the list (as ``list(...)`` does) still shares the candidate
    objects between tests. Deep-copying each element means a test may
    mutate a candidate without affecting any other test.
    """
    return [copy.deepcopy(trial) for trial in MOCK_TRIAL_CANDIDATES]


@pytest.fixture()
def sample_ledgers() -> list[EligibilityLedger]:
    """Return a fresh list of deep-copied mock eligibility ledgers.

    Each ledger is deep-copied so that per-test mutations of a ledger (not
    just of the list) stay local to the test that made them.
    """
    return [copy.deepcopy(ledger) for ledger in MOCK_ELIGIBILITY_LEDGERS]


@pytest.fixture()
def sample_anchors(sample_profile: PatientProfile) -> SearchAnchors:
    """Derive ``SearchAnchors`` from the ``sample_profile`` fixture.

    Fixture setup aborts with an ``AssertionError`` if the profile has no
    diagnosis or no performance status, rather than building partial anchors.
    """
    diagnosis = sample_profile.diagnosis
    performance = sample_profile.performance_status
    # These asserts narrow the Optional types for type checkers and stop setup on bad data.
    assert diagnosis is not None
    assert performance is not None

    marker_names = [marker.name for marker in sample_profile.biomarkers]
    return SearchAnchors(
        condition=diagnosis.primary_condition,
        subtype=diagnosis.histology,
        biomarkers=marker_names,
        stage=diagnosis.stage,
        age=sample_profile.demographics.age,
        performance_status_max=performance.value,
    )


# ---------------------------------------------------------------------------
# Service mock fixtures
# ---------------------------------------------------------------------------


@pytest.fixture()
def mock_medgemma():
    """Swap ``MedGemmaExtractor`` for a mock for the duration of one test.

    Yields the mock instance that the patched class returns when called:
    ``extract`` resolves to ``MOCK_PATIENT_PROFILE`` and
    ``evaluate_medical_criterion`` resolves to a fixed "met" verdict.
    """
    # NOTE(review): this patches the name in its defining module. A module that
    # already did ``from ...medgemma_extractor import MedGemmaExtractor`` keeps
    # the real class. Confirm consumers look the class up at call time.
    criterion_verdict = {
        "decision": "met",
        "confidence": 0.9,
        "reasoning": "Criterion satisfied based on profile data.",
    }
    extractor = MagicMock()
    extractor.extract = AsyncMock(return_value=MOCK_PATIENT_PROFILE)
    extractor.evaluate_medical_criterion = AsyncMock(return_value=criterion_verdict)

    with patch("trialpath.services.medgemma_extractor.MedGemmaExtractor") as extractor_cls:
        extractor_cls.return_value = extractor
        yield extractor


@pytest.fixture()
def mock_gemini():
    """Swap ``GeminiPlanner`` for a mock whose async methods return canned data.

    Yields the mock instance that the patched class returns when called.
    Each stubbed method is an ``AsyncMock``, so callers ``await`` it and get
    the value listed in ``canned_results`` below.
    """
    with patch("trialpath.services.gemini_planner.GeminiPlanner") as planner_cls:
        # Method name -> value the awaited call resolves to.
        canned_results = {
            "generate_search_anchors": SearchAnchors(
                condition="Non-Small Cell Lung Cancer",
                biomarkers=["EGFR"],
                stage="IIIB",
            ),
            "evaluate_eligibility": {
                "overall_assessment": "uncertain",
                "criteria": [],
            },
            "refine_search": SearchAnchors(
                condition="NSCLC",
                biomarkers=["EGFR"],
                stage="IIIB",
            ),
            "relax_search": SearchAnchors(
                condition="Lung Cancer",
            ),
            "slice_criteria": [
                {"text": "Age >= 18", "type": "structural"},
                {"text": "EGFR mutation positive", "type": "medical"},
            ],
            "evaluate_structural_criterion": {
                "decision": "met",
                "confidence": 0.95,
                "reasoning": "Patient is 62, meets age requirement.",
            },
            "aggregate_assessments": MOCK_ELIGIBILITY_LEDGERS[0],
            "analyze_gaps": [
                {
                    "description": "Brain MRI status unknown",
                    "recommended_action": "Order brain MRI",
                    "clinical_importance": "high",
                }
            ],
        }

        planner = MagicMock()
        for method_name, result in canned_results.items():
            setattr(planner, method_name, AsyncMock(return_value=result))

        planner_cls.return_value = planner
        yield planner


@pytest.fixture()
def mock_mcp():
    """Swap ``ClinicalTrialsMCPClient`` for an ``AsyncMock`` for one test.

    Yields the mock client instance. Awaiting ``search_studies(...)`` resolves
    to ``{"studies": [...]}`` built from every mock trial's ``model_dump()``.
    Awaiting ``get_study(...)`` resolves to the first mock trial's dump.
    """
    with patch("trialpath.services.mcp_client.ClinicalTrialsMCPClient") as client_cls:
        studies_payload = {
            "studies": [trial.model_dump() for trial in MOCK_TRIAL_CANDIDATES]
        }
        first_study = MOCK_TRIAL_CANDIDATES[0].model_dump()

        client = AsyncMock()
        client.configure_mock(
            **{
                "search_studies.return_value": studies_payload,
                "get_study.return_value": first_study,
            }
        )

        client_cls.return_value = client
        yield client


# ---------------------------------------------------------------------------
# Live service fixtures (require real API keys / running servers)
# ---------------------------------------------------------------------------


@pytest.fixture(scope="session")
def live_env():
    """Gate live-service tests on ``GEMINI_API_KEY`` being set.

    This fixture only checks the environment; it does not load anything
    (python-dotenv, when installed, runs at module import time). If the key
    is unset or empty, the test requesting this fixture (directly or through
    the ``live_*`` fixtures) is skipped. Returns ``None``.
    """
    gemini_key = os.environ.get("GEMINI_API_KEY")
    if gemini_key:
        return
    pytest.skip("GEMINI_API_KEY not set — skipping live tests")


@pytest.fixture(scope="session")
def live_gemini(live_env):
    """Build one real ``GeminiPlanner`` shared across the test session.

    ``live_env`` is requested only for its skip check. The import runs only
    when this fixture is instantiated, so runs that never request it do not
    import the real service module.
    """
    from trialpath.services.gemini_planner import GeminiPlanner

    planner = GeminiPlanner()
    return planner


@pytest.fixture(scope="session")
def live_mcp_client(live_env):
    """Build one real ``ClinicalTrialsMCPClient`` shared across the session.

    ``live_env`` is requested only for its skip check.
    """
    # NOTE(review): this gates on GEMINI_API_KEY via live_env even though the
    # MCP client may not need it. Confirm whether that coupling is intended.
    from trialpath.services.mcp_client import ClinicalTrialsMCPClient

    client = ClinicalTrialsMCPClient()
    return client


@pytest.fixture(scope="session")
def live_medgemma(live_env):
    """Build one real ``MedGemmaExtractor`` shared across the session.

    Requires ``GEMINI_API_KEY`` (checked through ``live_env``) and
    ``HF_TOKEN``. The requesting test is skipped if either is unset or empty.
    """
    if os.environ.get("HF_TOKEN"):
        from trialpath.services.medgemma_extractor import MedGemmaExtractor

        return MedGemmaExtractor()
    pytest.skip("HF_TOKEN not set — skipping MedGemma live tests")