File size: 7,728 Bytes
59abb4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
from fhir_adapter import get_patient_profile, get_all_patient_ids
from clinicaltrials_api import search_trials_sync, get_trial_details_sync
from llm_client import parse_trial_protocol, score_patient_against_criteria
import re

try:
    from neo4j_setup import neo4j_conn as _neo4j
except Exception:
    _neo4j = None

# Process-local memo tables; nothing in this module ever evicts entries.
#   _criteria_cache: NCT ID -> structured criteria (see get_criteria_for_trial)
#   _score_cache:    "patient_id:nct_id" -> scoring result (see score_patient_for_trial)
_criteria_cache: dict[str, dict] = dict()
_score_cache: dict[str, dict] = dict()


def _parse_age_string(age_str: str) -> int | None:
    if not age_str:
        return None
    match = re.search(r"(\d+)", age_str)
    return int(match.group(1)) if match else None


def _quick_eligibility_check(patient_profile: dict, trial: dict) -> tuple[bool, list[str]]:
    """Rule-based pre-filter on age and sex before expensive LLM scoring.

    Args:
        patient_profile: Patient record; reads the ``"age"`` and ``"gender"`` keys.
        trial: Trial record; reads ``"min_age"``, ``"max_age"`` and ``"sex"``.

    Returns:
        ``(passes, flags)``. ``passes`` is True when no rule was violated.
        ``flags`` holds one human-readable message per violation. Missing
        data never produces a flag: an unknown patient age skips the age
        rules, and only a definite male/female mismatch is reported for sex.
    """
    flags: list[str] = []

    age = patient_profile.get("age")
    # Without a known age we cannot judge the limits; defaulting to 0 would
    # wrongly flag the patient as below every minimum age.
    if age is not None:
        min_age = _parse_age_string(trial.get("min_age") or "")
        max_age = _parse_age_string(trial.get("max_age") or "")
        if min_age is not None and age < min_age:
            flags.append(f"Age {age} below minimum {min_age}")
        if max_age is not None and age > max_age:
            flags.append(f"Age {age} above maximum {max_age}")

    # `or` covers keys that are present but None/empty, which would otherwise
    # crash on .upper() or on indexing an empty string.
    trial_sex = (trial.get("sex") or "ALL").upper()
    patient_sex = (patient_profile.get("gender") or "").upper()
    # Compare leading letters so "MALE"/"FEMALE" and "M"/"F" both work.
    # Values starting with anything other than M/F (e.g. "other",
    # "unknown") are not treated as a mismatch and are left to the LLM scorer.
    trial_initial = trial_sex[:1]
    patient_initial = patient_sex[:1]
    if (
        trial_sex not in ("ALL", "BOTH")
        and trial_initial in ("M", "F")
        and patient_initial in ("M", "F")
        and patient_initial != trial_initial
    ):
        flags.append(f"Sex mismatch: trial requires {trial_sex}")

    return not flags, flags


def get_criteria_for_trial(trial: dict) -> dict:
    """Return structured eligibility criteria for ``trial``, memoised by NCT ID.

    When the trial carries free-text ``eligibility_criteria``, that text is
    parsed by ``parse_trial_protocol``. Otherwise a generic placeholder
    criteria dict is synthesised from the trial's ``brief_summary``.

    Args:
        trial: Trial record; reads ``nct_id``, ``eligibility_criteria`` and
            ``brief_summary``.

    Returns:
        The criteria dict. Later calls for the same NCT ID return the same
        object, so callers should treat it as read-only. Trials without an
        NCT ID are never cached.
    """
    nct_id = trial.get("nct_id") or ""
    # Only use the cache with a real ID. Otherwise every ID-less trial would
    # share the "" entry and receive another trial's criteria.
    if nct_id and nct_id in _criteria_cache:
        return _criteria_cache[nct_id]

    eligibility_text = trial.get("eligibility_criteria") or ""
    if eligibility_text:
        criteria = parse_trial_protocol(eligibility_text)
    else:
        # `or` also handles a present-but-None summary, which cannot be sliced.
        summary = trial.get("brief_summary") or "target condition"
        criteria = {
            "inclusion_criteria": [f"Confirmed diagnosis of {summary[:50]}"],
            "exclusion_criteria": ["Prior participation in conflicting trials"],
            "age_range": {"min": 18, "max": None},
            "required_diagnoses": [],
            "required_biomarkers": [],
            "excluded_medications": [],
            "performance_status": None,
        }

    if nct_id:
        _criteria_cache[nct_id] = criteria
    return criteria


def score_patient_for_trial(patient_id: str, trial: dict) -> dict:
    """Score one patient against one trial, combining the LLM score with rule checks.

    Loads the patient via ``get_patient_profile`` and the trial's criteria via
    ``get_criteria_for_trial``, then gets a score from
    ``score_patient_against_criteria``. If the rule-based pre-filter
    (``_quick_eligibility_check``) fails, three things happen:

    - the score is lowered by 0.3, floored at 0.0; a missing score is
      treated as 0.5 before the penalty;
    - ``eligible`` is forced to False;
    - the rule messages are appended to ``risk_flags``.

    Args:
        patient_id: Identifier passed to ``get_patient_profile``.
        trial: Trial record; ``nct_id`` and ``title`` are used.

    Returns:
        The scoring dict, augmented with ``patient_id``, ``nct_id``,
        ``trial_title`` and ``match_path``. If the patient cannot be found,
        returns ``{"error": "Patient not found", "overall_score": 0.0,
        "eligible": False}``, which is not cached. Other results are memoised
        per patient/trial pair when the trial has an NCT ID, and the cached
        dict itself is returned on later calls.
    """
    nct_id = trial.get("nct_id") or ""
    # Without an NCT ID every trial would map to the key "<patient_id>:" and
    # the patient would get the first such trial's score for all of them.
    cache_key = f"{patient_id}:{nct_id}" if nct_id else None
    if cache_key is not None and cache_key in _score_cache:
        return _score_cache[cache_key]

    patient_profile = get_patient_profile(patient_id)
    if not patient_profile:
        return {"error": "Patient not found", "overall_score": 0.0, "eligible": False}

    # Rule-based pre-filter. NOTE: the LLM is still consulted when the rules
    # fail; the rule outcome only penalises the LLM result below.
    passes_rules, rule_flags = _quick_eligibility_check(patient_profile, trial)

    criteria = get_criteria_for_trial(trial)
    result = score_patient_against_criteria(patient_profile, criteria, trial.get("title", "Clinical Trial"))

    if not passes_rules:
        llm_score = result.get("overall_score")
        if llm_score is None:
            llm_score = 0.5
        result["overall_score"] = max(0.0, llm_score - 0.3)
        result["eligible"] = False
        # Tolerate a missing or None risk_flags value from the LLM.
        result["risk_flags"] = list(result.get("risk_flags") or []) + rule_flags

    result["patient_id"] = patient_id
    result["nct_id"] = trial.get("nct_id", "")
    result["trial_title"] = trial.get("title", "")
    result["match_path"] = _build_match_path(patient_profile, trial, criteria)

    if cache_key is not None:
        _score_cache[cache_key] = result
    return result


def _build_match_path(patient_profile: dict, trial: dict, criteria: dict) -> list[dict]:
    """
    Build a human-readable explainability path showing WHY a patient was matched.

    Returns a list of edge dicts with ``from``, ``rel``, ``to`` and ``note``
    keys (Patient -> biomarker/diagnosis -> Trial), gathered in this order:

    1. Up to three biomarkers that the Neo4j graph links to the patient via
       ``HAS_BIOMARKER`` and that occur in the trial node's
       ``parsed_biomarkers`` or ``eligibility_criteria``. This step runs only
       when a graph connection was imported. A failing query is logged and
       skipped so the rest of the path is still built.
    2. Up to two of ``criteria["required_biomarkers"]`` that appear in the
       patient's ``biomarkers`` keys or values. A biomarker already produced
       in step 1 is not repeated.
    3. Up to two of ``criteria["required_diagnoses"]``. Each is linked to the
       first patient diagnosis containing one of its significant words.
       Stop-words and tokens shorter than three characters are ignored, so
       words like "of" or "2" do not create spurious matches.

    A terminal ``ELIGIBLE_FOR`` edge to the trial is always appended. This
    function does not check whether the patient was actually judged eligible.
    """
    # Words that carry no diagnostic meaning on their own.
    stop_words = frozenset({
        "and", "the", "with", "without", "for", "from", "non", "not", "any",
        "other", "type", "stage", "grade", "confirmed", "diagnosis",
        "diagnosed", "history",
    })

    path: list[dict] = []
    patient_id = patient_profile.get("patient_id", "")
    nct_id = trial.get("nct_id", "")
    trial_title = (trial.get("title") or "")[:60]
    # Lower-cased biomarker names already emitted, to avoid duplicate edges.
    seen_biomarkers: set[str] = set()

    # Check graph for shared biomarker edges
    if _neo4j:
        try:
            rows = _neo4j.run_query(
                """
                MATCH (p:Patient {id: $pid})-[:HAS_BIOMARKER]->(b:Biomarker)
                MATCH (t:Trial {id: $nct_id})
                WHERE t.parsed_biomarkers CONTAINS b.name OR t.eligibility_criteria CONTAINS b.name
                RETURN b.name AS biomarker LIMIT 3
                """,
                {"pid": patient_id, "nct_id": nct_id},
            )
            for row in rows:
                biomarker = row["biomarker"]
                seen_biomarkers.add(str(biomarker).lower())
                path.append({
                    "from": f"Patient:{patient_id}",
                    "rel": "HAS_BIOMARKER",
                    "to": f"Biomarker:{biomarker}",
                    "note": "required by trial",
                })
        except Exception:
            # Graph enrichment is best-effort, but record why it was skipped
            # instead of silently dropping it.
            import logging

            logging.getLogger(__name__).warning(
                "Graph biomarker lookup failed for patient %s / trial %s; "
                "continuing without graph edges",
                patient_id,
                nct_id,
                exc_info=True,
            )

    # Add FHIR-based reasoning nodes from the criteria match
    biomarkers = patient_profile.get("biomarkers") or {}
    for item in (criteria.get("required_biomarkers") or [])[:2]:
        needle = str(item).lower()
        # An empty requirement would "match" every biomarker.
        if not needle or needle in seen_biomarkers:
            continue
        if any(
            needle in str(key).lower() or needle in str(value).lower()
            for key, value in biomarkers.items()
        ):
            seen_biomarkers.add(needle)
            path.append({
                "from": f"Patient:{patient_id}",
                "rel": "HAS_BIOMARKER",
                "to": f"Biomarker:{item}",
                "note": "matches trial requirement",
            })

    for dx in (criteria.get("required_diagnoses") or [])[:2]:
        keywords = [
            word
            for word in re.findall(r"[a-z0-9]+", str(dx).lower())
            if len(word) > 2 and word not in stop_words
        ]
        if not keywords:
            continue
        for patient_dx in patient_profile.get("diagnosis_names") or []:
            patient_dx_lower = str(patient_dx).lower()
            if any(word in patient_dx_lower for word in keywords):
                path.append({
                    "from": f"Patient:{patient_id}",
                    "rel": "HAS_DIAGNOSIS",
                    "to": f"Diagnosis:{patient_dx}",
                    "note": f"matches required: {dx}",
                })
                break

    # Terminal node
    path.append({
        "from": f"Patient:{patient_id}",
        "rel": "ELIGIBLE_FOR",
        "to": f"Trial:{nct_id}",
        "note": trial_title,
    })
    return path


def match_patient_to_trials(patient_id: str, condition: str | None = None, top_n: int = 5) -> list[dict]:
    """Find the best-matching trials for a patient.

    Args:
        patient_id: Identifier passed to ``get_patient_profile``.
        condition: Search term for ``search_trials_sync``. Defaults to the
            patient's first diagnosis name, or ``"cancer"`` if there is none.
        top_n: Maximum number of trials to return. Values <= 0 return ``[]``.

    Returns:
        Trial dicts, each containing every original trial field plus
        ``match_score``, ``eligible``, ``match_summary`` and ``risk_flags``.
        They are sorted by ``match_score``, highest first. The list is empty
        if the patient is not found.
    """
    # Nothing can be returned, so skip the per-trial LLM scoring entirely.
    # This also stops a negative top_n from slicing off the last item.
    if top_n <= 0:
        return []

    patient_profile = get_patient_profile(patient_id)
    if not patient_profile:
        return []

    # Infer condition from patient diagnoses if not provided
    if not condition:
        diagnoses = patient_profile.get("diagnosis_names")
        condition = diagnoses[0] if diagnoses else "cancer"

    # Request at least top_n candidates so a larger top_n is not silently
    # capped at the default page of 10.
    trials = search_trials_sync(condition, page_size=max(10, top_n)) or []

    scored = []
    for trial in trials:
        score_result = score_patient_for_trial(patient_id, trial)
        scored.append({
            **trial,
            # A None score would break the sort below.
            "match_score": score_result.get("overall_score") or 0.0,
            "eligible": score_result.get("eligible", False),
            "match_summary": score_result.get("summary", ""),
            "risk_flags": score_result.get("risk_flags", []),
        })

    scored.sort(key=lambda x: x["match_score"], reverse=True)
    return scored[:top_n]


def find_eligible_patients_for_trial(nct_id: str, min_score: float = 0.4) -> list[dict]:
    """Screen all known patients against a specific trial.

    Args:
        nct_id: Trial identifier passed to ``get_trial_details_sync``.
        min_score: A patient is kept only when their ``overall_score`` is
            strictly greater than this value. Defaults to 0.4.

    Returns:
        One dict per kept patient (``patient_id``, ``match_score``,
        ``eligible``, ``summary``, ``risk_flags``), sorted by ``match_score``
        with the best first. Kept patients are not necessarily eligible;
        check the ``eligible`` field. The list is empty if the trial cannot
        be fetched.
    """
    trial = get_trial_details_sync(nct_id)
    if not trial:
        return []

    results = []
    for patient_id in get_all_patient_ids() or []:
        score_result = score_patient_for_trial(patient_id, trial)
        # Treat a missing/None score as 0.0 rather than raising on the compare.
        score = score_result.get("overall_score") or 0.0
        if score <= min_score:
            continue
        results.append({
            "patient_id": patient_id,
            "match_score": score,
            "eligible": score_result.get("eligible", False),
            "summary": score_result.get("summary", ""),
            "risk_flags": score_result.get("risk_flags", []),
        })

    results.sort(key=lambda x: x["match_score"], reverse=True)
    return results