Spaces:
Running
Running
File size: 7,728 Bytes
59abb4f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 | from fhir_adapter import get_patient_profile, get_all_patient_ids
from clinicaltrials_api import search_trials_sync, get_trial_details_sync
from llm_client import parse_trial_protocol, score_patient_against_criteria
import re
try:
from neo4j_setup import neo4j_conn as _neo4j
except Exception:
_neo4j = None
# Process-local, in-memory caches. Nothing in the visible code evicts or
# expires entries, so they grow for the lifetime of the process.
#   _criteria_cache: trial NCT id -> criteria dict (filled by get_criteria_for_trial)
#   _score_cache:    "<patient_id>:<nct_id>" -> score dict (filled by score_patient_for_trial)
_criteria_cache: dict[str, dict] = {}
_score_cache: dict[str, dict] = {}
def _parse_age_string(age_str: str) -> int | None:
if not age_str:
return None
match = re.search(r"(\d+)", age_str)
return int(match.group(1)) if match else None
def _quick_eligibility_check(patient_profile: dict, trial: dict) -> tuple[bool, list[str]]:
    """Rule-based age/sex pre-filter, cheap compared with the criteria scorer.

    Args:
        patient_profile: Patient record; the ``age`` and ``gender`` keys are read.
        trial: Trial record; the ``min_age``, ``max_age`` and ``sex`` keys are read.

    Returns:
        ``(passes, flags)``: ``flags`` holds one human-readable message per
        violated rule and ``passes`` is ``True`` exactly when ``flags`` is empty.
    """
    flags: list[str] = []

    # Only compare ages when the patient's age is actually known. Defaulting a
    # missing age to 0 would report a bogus "Age 0 below minimum ..." flag, and
    # a None age would raise TypeError on comparison.
    age = patient_profile.get("age")
    if isinstance(age, (int, float)):
        min_age = _parse_age_string(trial.get("min_age", ""))
        max_age = _parse_age_string(trial.get("max_age", ""))
        # "is not None" rather than truthiness so a parsed bound of 0 still counts.
        if min_age is not None and age < min_age:
            flags.append(f"Age {age} below minimum {min_age}")
        if max_age is not None and age > max_age:
            flags.append(f"Age {age} above maximum {max_age}")

    # Missing, None or blank trial sex means "no restriction"; this also keeps
    # the first-letter comparison below from indexing an empty string.
    trial_sex = str(trial.get("sex") or "").strip().upper() or "ALL"
    patient_sex = str(patient_profile.get("gender") or "").strip().upper()
    # Compare initials so "M"/"MALE" and "F"/"FEMALE" spellings agree.
    if trial_sex not in ("ALL", "BOTH") and patient_sex and patient_sex[0] != trial_sex[0]:
        flags.append(f"Sex mismatch: trial requires {trial_sex}")

    return not flags, flags
def get_criteria_for_trial(trial: dict) -> dict:
    """Return structured eligibility criteria for *trial*, cached by NCT id.

    When the trial carries ``eligibility_criteria`` text it is handed to
    ``parse_trial_protocol``; otherwise a generic placeholder criteria dict is
    built from the first 50 characters of ``brief_summary``.

    Args:
        trial: Trial record; reads ``nct_id``, ``eligibility_criteria`` and
            ``brief_summary``.

    Returns:
        The criteria dict. For trials with an ``nct_id`` the same dict object
        is returned on later calls, so callers should not mutate it.
    """
    nct_id = trial.get("nct_id") or ""
    # Trials without an id must not share the "" cache slot, otherwise every
    # id-less trial would get the criteria of whichever one was parsed first.
    if nct_id and nct_id in _criteria_cache:
        return _criteria_cache[nct_id]

    eligibility_text = trial.get("eligibility_criteria") or ""
    if eligibility_text:
        criteria = parse_trial_protocol(eligibility_text)
    else:
        # "or" (not a .get default) so a None/empty summary also falls back.
        condition = (trial.get("brief_summary") or "target condition")[:50]
        criteria = {
            "inclusion_criteria": [f"Confirmed diagnosis of {condition}"],
            "exclusion_criteria": ["Prior participation in conflicting trials"],
            "age_range": {"min": 18, "max": None},
            "required_diagnoses": [],
            "required_biomarkers": [],
            "excluded_medications": [],
            "performance_status": None,
        }

    if nct_id:
        _criteria_cache[nct_id] = criteria
    return criteria
def score_patient_for_trial(patient_id: str, trial: dict) -> dict:
    """Score one patient against one trial and attach an explanation path.

    The rule-based age/sex check does not short-circuit the criteria scorer:
    ``score_patient_against_criteria`` is always called, and a failed rule
    check then lowers the score by 0.3 (floored at 0.0), forces
    ``eligible`` to ``False`` and appends the rule messages to ``risk_flags``.

    Args:
        patient_id: Identifier passed to ``get_patient_profile``.
        trial: Trial record; reads ``nct_id`` and ``title`` here.

    Returns:
        The scorer's result dict, extended with ``patient_id``, ``nct_id``,
        ``trial_title`` and ``match_path``. If no profile is found, a dict with
        ``error``, ``overall_score`` 0.0 and ``eligible`` False (not cached).
        Cached results are returned as the same dict object on later calls.
    """
    nct_id = trial.get("nct_id") or ""
    # Only cache when the trial has an id; otherwise every id-less trial would
    # collide on the key "<patient_id>:" and return another trial's score.
    cache_key = f"{patient_id}:{nct_id}" if nct_id else None
    if cache_key is not None and cache_key in _score_cache:
        return _score_cache[cache_key]

    patient_profile = get_patient_profile(patient_id)
    if not patient_profile:
        return {"error": "Patient not found", "overall_score": 0.0, "eligible": False}

    # Quick rule-based check (age / sex)
    passes_rules, rule_flags = _quick_eligibility_check(patient_profile, trial)

    criteria = get_criteria_for_trial(trial)
    result = score_patient_against_criteria(patient_profile, criteria, trial.get("title", "Clinical Trial"))

    if not passes_rules:
        # Treat an absent or None score as the neutral 0.5 before penalising.
        base_score = result.get("overall_score")
        if base_score is None:
            base_score = 0.5
        result["overall_score"] = max(0.0, base_score - 0.3)
        result["eligible"] = False
        # The scorer may return risk_flags as None; normalise before extending.
        result["risk_flags"] = [*(result.get("risk_flags") or []), *rule_flags]

    result["patient_id"] = patient_id
    result["nct_id"] = nct_id
    result["trial_title"] = trial.get("title", "")
    result["match_path"] = _build_match_path(patient_profile, trial, criteria)

    if cache_key is not None:
        _score_cache[cache_key] = result
    return result
def _build_match_path(patient_profile: dict, trial: dict, criteria: dict) -> list[dict]:
    """
    Build a human-readable graph explainability path showing WHY a patient was matched.

    Edges come from three sources, in this order:
      1. the graph database (when ``_neo4j`` is available): up to 3 patient
         biomarkers whose name appears in the trial's stored text;
      2. the first 2 ``required_biomarkers`` in *criteria* that appear in the
         patient's ``biomarkers`` mapping (key or value, case-insensitive);
      3. the first 2 ``required_diagnoses`` in *criteria*, each paired with the
         first patient diagnosis sharing at least one word with it.
    A terminal ``ELIGIBLE_FOR`` edge to the trial is always appended; this
    function does not itself evaluate eligibility.

    Args:
        patient_profile: Reads ``patient_id``, ``biomarkers``, ``diagnosis_names``.
        trial: Reads ``nct_id`` and ``title`` (title truncated to 60 chars).
        criteria: Reads ``required_biomarkers`` and ``required_diagnoses``.

    Returns:
        A list of edge dicts with keys ``from``, ``rel``, ``to`` and ``note``.
    """
    path: list[dict] = []
    patient_id = patient_profile.get("patient_id", "")
    nct_id = trial.get("nct_id", "")
    trial_title = (trial.get("title") or "")[:60]
    patient_node = f"Patient:{patient_id}"

    # Check graph for shared biomarker edges
    if _neo4j:
        try:
            rows = _neo4j.run_query(
                """
                MATCH (p:Patient {id: $pid})-[:HAS_BIOMARKER]->(b:Biomarker)
                MATCH (t:Trial {id: $nct_id})
                WHERE t.parsed_biomarkers CONTAINS b.name OR t.eligibility_criteria CONTAINS b.name
                RETURN b.name AS biomarker LIMIT 3
                """,
                {"pid": patient_id, "nct_id": nct_id},
            )
            for row in rows or []:
                path.append({
                    "from": patient_node,
                    "rel": "HAS_BIOMARKER",
                    "to": f"Biomarker:{row['biomarker']}",
                    "note": "required by trial",
                })
        except Exception as exc:
            # The graph lookup is best-effort: the path is still built from the
            # profile below. Record the failure instead of hiding it entirely.
            import logging

            logging.getLogger(__name__).warning(
                "Graph biomarker lookup failed for patient %s / trial %s: %s",
                patient_id, nct_id, exc,
            )

    # Add FHIR-based reasoning nodes from the criteria match
    biomarkers = patient_profile.get("biomarkers") or {}
    for item in (criteria.get("required_biomarkers") or [])[:2]:
        needle = str(item).lower()
        if not needle:
            # An empty requirement is a substring of everything; skip it.
            continue
        if any(needle in str(k).lower() or needle in str(v).lower()
               for k, v in biomarkers.items()):
            path.append({
                "from": patient_node,
                "rel": "HAS_BIOMARKER",
                "to": f"Biomarker:{item}",
                "note": "matches trial requirement",
            })

    diagnosis_names = patient_profile.get("diagnosis_names") or []
    for dx in (criteria.get("required_diagnoses") or [])[:2]:
        dx_words = str(dx).lower().split()
        for patient_dx in diagnosis_names:
            patient_dx_lower = str(patient_dx).lower()
            # Loose match: any single word of the requirement appearing in the
            # patient's diagnosis counts; only the first hit per requirement is kept.
            if any(word in patient_dx_lower for word in dx_words):
                path.append({
                    "from": patient_node,
                    "rel": "HAS_DIAGNOSIS",
                    "to": f"Diagnosis:{patient_dx}",
                    "note": f"matches required: {dx}",
                })
                break

    # Terminal node
    path.append({
        "from": patient_node,
        "rel": "ELIGIBLE_FOR",
        "to": f"Trial:{nct_id}",
        "note": trial_title,
    })
    return path
def match_patient_to_trials(patient_id: str, condition: str | None = None, top_n: int = 5) -> list[dict]:
    """Rank trials for one patient and return the best ``top_n``.

    Args:
        patient_id: Identifier passed to ``get_patient_profile``.
        condition: Search term for ``search_trials_sync``. When falsy, the
            patient's first ``diagnosis_names`` entry is used, or ``"cancer"``
            if the profile lists none.
        top_n: Number of results to keep after sorting.

    Returns:
        Trial dicts (at most 10 are requested from the search) extended with
        ``match_score``, ``eligible``, ``match_summary`` and ``risk_flags``,
        sorted by ``match_score`` descending. Empty when no profile is found.
    """
    profile = get_patient_profile(patient_id)
    if not profile:
        return []

    # Infer condition from patient diagnoses if not provided
    if not condition:
        known_diagnoses = profile.get("diagnosis_names")
        condition = known_diagnoses[0] if known_diagnoses else "cancer"

    def _annotate(candidate: dict) -> dict:
        # Merge the trial's own fields with the scoring outcome for this patient.
        outcome = score_patient_for_trial(patient_id, candidate)
        return {
            **candidate,
            "match_score": outcome.get("overall_score", 0.0),
            "eligible": outcome.get("eligible", False),
            "match_summary": outcome.get("summary", ""),
            "risk_flags": outcome.get("risk_flags", []),
        }

    candidates = search_trials_sync(condition, page_size=10)
    ranked = sorted(
        [_annotate(candidate) for candidate in candidates],
        key=lambda entry: entry["match_score"],
        reverse=True,
    )
    return ranked[:top_n]
def find_eligible_patients_for_trial(nct_id: str, *, min_score: float = 0.4) -> list[dict]:
    """Screen all known patients against a specific trial.

    Every id from ``get_all_patient_ids`` is scored with
    ``score_patient_for_trial``; patients whose score is strictly greater than
    ``min_score`` are kept.

    Args:
        nct_id: Trial identifier passed to ``get_trial_details_sync``.
        min_score: Exclusive lower bound on ``overall_score``. Defaults to 0.4,
            the threshold that was previously hard-coded.

    Returns:
        Dicts with ``patient_id``, ``match_score``, ``eligible``, ``summary``
        and ``risk_flags``, sorted by ``match_score`` descending. Empty when
        the trial cannot be loaded.
    """
    trial = get_trial_details_sync(nct_id)
    if not trial:
        return []

    results = []
    for patient_id in get_all_patient_ids():
        score_result = score_patient_for_trial(patient_id, trial)
        # A missing or None score counts as 0.0 so the comparison and the
        # sort below never see None.
        score = score_result.get("overall_score") or 0.0
        if score > min_score:
            results.append({
                "patient_id": patient_id,
                "match_score": score,
                "eligible": score_result.get("eligible", False),
                "summary": score_result.get("summary", ""),
                "risk_flags": score_result.get("risk_flags", []),
            })

    results.sort(key=lambda entry: entry["match_score"], reverse=True)
    return results
|