""" Intake-based trial matching — accepts raw clinical data (SI units) and scores it against Trial nodes in the graph. No patient ID required. SI unit reference: Hemoglobin: g/dL (×10 → g/L) WBC: ×10⁹/L ANC: ×10⁹/L Platelets: ×10⁹/L Creatinine: μmol/L (÷88.4 → mg/dL) eGFR: mL/min/1.73m² Bilirubin: μmol/L (÷17.1 → mg/dL) ALT/AST: U/L Albumin: g/dL """ import re import uuid from typing import Optional from neo4j_setup import neo4j_conn # ── Biomarker registry ──────────────────────────────────────────────────────── # Maps graph node id → human label → search terms found in eligibility text BIOMARKER_REGISTRY = { "HER2_POS": ("HER2 Positive", ["HER2-positive", "HER2+", "HER2 amplified", "HER2/neu positive"]), "HER2_NEG": ("HER2 Negative", ["HER2-negative", "HER2-"]), "ER_POS": ("ER Positive", ["ER-positive", "ER+", "estrogen receptor positive"]), "PR_POS": ("PR Positive", ["PR-positive", "PR+", "progesterone receptor positive"]), "BRCA1_MUT": ("BRCA1 Mutation", ["BRCA1", "BRCA1 mutation", "BRCA1-mutated"]), "BRCA2_MUT": ("BRCA2 Mutation", ["BRCA2", "BRCA2 mutation", "BRCA2-mutated"]), "EGFR_MUT": ("EGFR Mutation", ["EGFR mutation", "EGFR-mutated", "EGFR exon 19", "EGFR exon 21"]), "ALK_POS": ("ALK Rearrangement",["ALK rearrangement", "ALK-positive", "ALK fusion"]), "ROS1_POS": ("ROS1 Rearrangement",["ROS1 rearrangement", "ROS1-positive", "ROS1 fusion"]), "PD_L1_POS": ("PD-L1 Positive", ["PD-L1", "PD-L1 positive", "PDL1"]), "KRAS_WT": ("KRAS Wild-type", ["KRAS wild-type", "KRAS WT", "KRAS-wildtype"]), "BRAF_MUT": ("BRAF V600E", ["BRAF V600E", "BRAF mutation", "BRAF-mutated"]), "MSI_H": ("MSI-High", ["MSI-H", "microsatellite instability-high", "MSI high", "dMMR"]), "NRAS_MUT": ("NRAS Mutation", ["NRAS mutation", "NRAS-mutated"]), "FLT3_MUT": ("FLT3 Mutation", ["FLT3 mutation", "FLT3-mutated", "FLT3-ITD"]), "IDH1_MUT": ("IDH1 Mutation", ["IDH1 mutation", "IDH1-mutated"]), "IDH2_MUT": ("IDH2 Mutation", ["IDH2 mutation", "IDH2-mutated"]), "BCR_ABL": ("BCR-ABL", ["BCR-ABL", "Philadelphia chromosome", "Ph-positive"]), "TRIPLE_NEG":("Triple Negative", ["triple-negative", "TNBC", "triple negative breast"]), } # ── Age parsing ─────────────────────────────────────────────────────────────── def _parse_age_years(age_str: str) -> Optional[int]: """'45 Years' → 45, '6 Months' → 0, '' → None""" if not age_str: return None m = re.search(r"(\d+)\s*year", age_str, re.I) if m: return int(m.group(1)) m = re.search(r"(\d+)\s*month", age_str, re.I) if m: return 0 m = re.search(r"(\d+)", age_str) if m: return int(m.group(1)) return None # ── ECOG parsing from eligibility text ──────────────────────────────────────── def _max_ecog_from_text(text: str) -> Optional[int]: """Extract maximum allowed ECOG from eligibility criteria text.""" patterns = [ r"ECOG\s+(?:performance\s+status\s+)?(?:of\s+)?(?:0\s*(?:or|-)\s*)?([0-4])", r"performance\s+status\s+(?:of\s+)?(?:0\s*(?:or|-)\s*)?([0-4])", r"Karnofsky\s+.*?(\d{2,3})\s*%", # convert KPS to ECOG approximately ] for pat in patterns: m = re.search(pat, text, re.I) if m: val = int(m.group(1)) if "Karnofsky" in pat: # KPS 80-100 ≈ ECOG 0-1, 60-70 ≈ 2, 40-50 ≈ 3 kps = val val = 0 if kps >= 80 else 1 if kps >= 70 else 2 if kps >= 60 else 3 return val return None # ── Lab value checking against eligibility text ─────────────────────────────── def _check_labs(labs: dict, eligibility_text: str) -> list[dict]: """ Parse common lab thresholds from eligibility text and check patient values. Returns list of {criterion, patient_value, threshold, met}. """ results = [] text = eligibility_text or "" def _find_threshold(patterns): for pat in patterns: m = re.search(pat, text, re.I) if m: return float(m.group(1)) return None # Hemoglobin ≥ threshold (g/dL in text; patient value in g/dL) hgb = labs.get("hemoglobin") if hgb is not None: # Try to find "hemoglobin >= X" or "Hgb >= X g/dL" thresh = _find_threshold([ r"hemoglobin\s*[≥>=]+\s*([\d.]+)\s*g/dL", r"Hgb\s*[≥>=]+\s*([\d.]+)", r"hemoglobin\s+of\s+at\s+least\s+([\d.]+)", ]) if thresh: results.append({"criterion": f"Hemoglobin ≥ {thresh} g/dL", "patient_value": f"{hgb} g/dL", "met": hgb >= thresh}) # Platelets ≥ threshold (×10⁹/L) plt = labs.get("platelets") if plt is not None: thresh = _find_threshold([ r"platelet[s]?\s*[≥>=]+\s*([\d,]+)\s*[×x]?\s*10[⁹9]/L", r"platelet[s]?\s+count\s*[≥>=]+\s*([\d,]+)", r"platelet[s]?\s+of\s+at\s+least\s+([\d,]+)", ]) if thresh: thresh_val = thresh / 1000 if thresh > 1000 else thresh # normalise if stored as /µL results.append({"criterion": f"Platelets ≥ {thresh_val} ×10⁹/L", "patient_value": f"{plt} ×10⁹/L", "met": plt >= thresh_val}) # Creatinine ≤ threshold (μmol/L patient; text may be mg/dL or μmol/L) cr = labs.get("creatinine") # patient value in μmol/L if cr is not None: # Most trial text uses mg/dL; convert patient value for comparison cr_mgdl = cr / 88.4 thresh = _find_threshold([ r"creatinine\s*[≤<=]+\s*([\d.]+)\s*mg/dL", r"serum\s+creatinine\s*[≤<=]+\s*([\d.]+)", ]) if thresh: results.append({"criterion": f"Creatinine ≤ {thresh} mg/dL ({round(thresh*88.4)} μmol/L)", "patient_value": f"{cr} μmol/L ({round(cr_mgdl, 2)} mg/dL)", "met": cr_mgdl <= thresh}) # eGFR ≥ threshold egfr = labs.get("egfr") if egfr is not None: thresh = _find_threshold([ r"(?:eGFR|GFR|creatinine\s+clearance)\s*[≥>=]+\s*([\d.]+)", r"glomerular\s+filtration\s+rate\s*[≥>=]+\s*([\d.]+)", ]) if thresh: results.append({"criterion": f"eGFR ≥ {thresh} mL/min/1.73m²", "patient_value": f"{egfr} mL/min", "met": egfr >= thresh}) # Bilirubin ≤ threshold (μmol/L patient; text usually mg/dL) bili = labs.get("bilirubin") if bili is not None: bili_mgdl = bili / 17.1 thresh = _find_threshold([ r"(?:total\s+)?bilirubin\s*[≤<=]+\s*([\d.]+)\s*(?:×\s*)?ULN", r"(?:total\s+)?bilirubin\s*[≤<=]+\s*([\d.]+)\s*mg/dL", ]) if thresh: # If "× ULN", ULN for bilirubin ≈ 1.0 mg/dL results.append({"criterion": f"Bilirubin ≤ {thresh} mg/dL ({round(thresh*17.1)} μmol/L)", "patient_value": f"{bili} μmol/L ({round(bili_mgdl, 2)} mg/dL)", "met": bili_mgdl <= thresh}) # ANC ≥ threshold (×10⁹/L) anc = labs.get("anc") if anc is not None: thresh = _find_threshold([ r"(?:ANC|absolute\s+neutrophil\s+count)\s*[≥>=]+\s*([\d.]+)\s*[×x]?\s*10[⁹9]/L", r"neutrophil[s]?\s*[≥>=]+\s*([\d.]+)", ]) if thresh: results.append({"criterion": f"ANC ≥ {thresh} ×10⁹/L", "patient_value": f"{anc} ×10⁹/L", "met": anc >= thresh}) return results # ── Main scoring function ───────────────────────────────────────────────────── def score_intake_against_trial(intake: dict, trial: dict) -> dict: """ Score a clinical intake profile against a single trial. Returns {score, eligible, criteria_breakdown, risk_flags}. """ breakdown = [] risk_flags = [] points = 0 max_points = 0 age = intake.get("age") sex = intake.get("sex", "").upper() ecog = intake.get("ecog") biomarkers = set(intake.get("biomarkers", [])) labs = intake.get("labs", {}) prior_chemo = intake.get("prior_chemo", False) eligibility_text = trial.get("eligibility_criteria", "") # ── Age (25 pts) ────────────────────────────────────────────────────────── max_points += 25 min_age = _parse_age_years(trial.get("min_age", "")) max_age = _parse_age_years(trial.get("max_age", "")) if age is not None: age_ok = True note = "" if min_age and age < min_age: age_ok = False note = f"Trial requires ≥{min_age} years" risk_flags.append(f"Below minimum age ({age} < {min_age})") if max_age and age > max_age: age_ok = False note = f"Trial requires ≤{max_age} years" risk_flags.append(f"Above maximum age ({age} > {max_age})") if age_ok: points += 25 note = f"Within range ({min_age or '≥18'}–{max_age or 'no max'})" breakdown.append({"criterion": "Age", "met": age_ok, "patient_value": f"{age} years", "note": note, "category": "demographics"}) # ── Sex (15 pts) ────────────────────────────────────────────────────────── max_points += 15 trial_sex = (trial.get("sex") or "ALL").upper() sex_ok = trial_sex in ("ALL", sex, "") if not sex_ok: risk_flags.append(f"Sex mismatch (trial requires {trial_sex})") else: points += 15 breakdown.append({"criterion": "Sex", "met": sex_ok, "patient_value": sex or "Not specified", "note": f"Trial: {trial_sex}", "category": "demographics"}) # ── ECOG (15 pts) ───────────────────────────────────────────────────────── max_points += 15 max_ecog = _max_ecog_from_text(eligibility_text) if ecog is not None and max_ecog is not None: ecog_ok = ecog <= max_ecog if not ecog_ok: risk_flags.append(f"ECOG {ecog} exceeds trial max ({max_ecog})") else: points += 15 breakdown.append({"criterion": "ECOG Performance Status", "met": ecog_ok, "patient_value": f"ECOG {ecog}", "note": f"Trial requires ≤{max_ecog}", "category": "performance"}) elif ecog is not None: points += 10 # partial credit — can't verify from text breakdown.append({"criterion": "ECOG Performance Status", "met": None, "patient_value": f"ECOG {ecog}", "note": "Could not parse limit from trial text", "category": "performance"}) # ── Biomarkers (30 pts) ─────────────────────────────────────────────────── max_points += 30 if biomarkers: matched_bm = [] for bm_id in biomarkers: info = BIOMARKER_REGISTRY.get(bm_id) if not info: continue label, search_terms = info found_in_text = any(term.lower() in eligibility_text.lower() for term in search_terms) matched_bm.append((label, found_in_text)) relevant = [m for m in matched_bm if m[1]] if relevant: points += 30 breakdown.append({ "criterion": "Biomarker Profile", "met": True, "patient_value": ", ".join(l for l, _ in relevant), "note": f"{len(relevant)} of your biomarkers appear in trial criteria", "category": "molecular", }) elif matched_bm: points += 5 breakdown.append({ "criterion": "Biomarker Profile", "met": None, "patient_value": ", ".join(l for l, _ in matched_bm), "note": "None of your biomarkers explicitly appear in criteria", "category": "molecular", }) # ── Lab values (15 pts) ─────────────────────────────────────────────────── if labs: max_points += 15 lab_results = _check_labs(labs, eligibility_text) if lab_results: all_ok = all(r["met"] for r in lab_results) any_fail = any(not r["met"] for r in lab_results) if all_ok: points += 15 elif not any_fail: points += 8 for r in lab_results: if not r["met"]: risk_flags.append(f"Lab out of range: {r['criterion']}") for r in lab_results: breakdown.append({ "criterion": r["criterion"], "met": r["met"], "patient_value": r["patient_value"], "note": "", "category": "labs", }) else: points += 8 # no parseable lab criteria — give partial credit score = points / max_points if max_points > 0 else 0 eligible = score >= 0.65 and not any("mismatch" in f or "exceeds" in f for f in risk_flags) return { "score": round(score, 3), "eligible": eligible, "criteria_breakdown": breakdown, "risk_flags": risk_flags, "points": points, "max_points": max_points, } # ── Graph query + batch scoring ─────────────────────────────────────────────── def match_intake_to_trials(intake: dict, condition: str, limit: int = 10) -> list[dict]: """ Query trials from the graph matching the condition, score each against intake, return ranked list. """ rows = neo4j_conn.run_query( """ MATCH (t:Trial) WHERE toLower(t.condition) CONTAINS toLower($condition) AND t.status IN ['RECRUITING', 'NOT_YET_RECRUITING'] RETURN t.id AS nct_id, t.title AS title, t.phase AS phase, t.condition AS condition, t.min_age AS min_age, t.max_age AS max_age, t.sex AS sex, t.eligibility_criteria AS eligibility_criteria, t.sponsor AS sponsor, t.location_count AS location_count, t.last_updated AS last_updated, t.ctgov_url AS ctgov_url LIMIT $limit """, {"condition": condition, "limit": limit * 3}, # over-fetch, then rank ) if not rows: return [] scored = [] for trial in rows: result = score_intake_against_trial(intake, trial) scored.append({ **trial, **result, }) scored.sort(key=lambda x: x["score"], reverse=True) return scored[:limit] def save_intake_as_patient(intake: dict) -> str: """Optionally persist the intake as a Patient node for long-term graph enrichment.""" pid = f"P_INTAKE_{uuid.uuid4().hex[:8].upper()}" neo4j_conn.run_query( """ MERGE (p:Patient {id: $id}) SET p += { age: $age, sex: $sex, ecog: $ecog, condition: $condition, source: 'intake_form', created_at: datetime() } """, { "id": pid, "age": intake.get("age"), "sex": intake.get("sex", ""), "ecog": intake.get("ecog"), "condition": intake.get("condition", ""), }, ) for bm_id in intake.get("biomarkers", []): neo4j_conn.run_query( """ MATCH (p:Patient {id: $pid}) MERGE (b:Biomarker {id: $bm_id}) ON CREATE SET b.name = $name MERGE (p)-[:HAS_BIOMARKER]->(b) """, {"pid": pid, "bm_id": bm_id, "name": BIOMARKER_REGISTRY.get(bm_id, (bm_id,))[0]}, ) return pid