CTA / backend /intake_matching.py
TheQuantEd's picture
Initial deployment: ClinicalMatch AI v2.0 — FHIR R4 · MCP (9 tools) · A2A workflow · SHARP compliance · 100k synthetic patients · Neo4j graph · GraphRAG chatbot
59abb4f
"""
Intake-based trial matching β€” accepts raw clinical data (SI units) and scores
it against Trial nodes in the graph. No patient ID required.
SI unit reference:
Hemoglobin: g/dL (Γ—10 β†’ g/L)
WBC: Γ—10⁹/L
ANC: Γ—10⁹/L
Platelets: Γ—10⁹/L
Creatinine: ΞΌmol/L (Γ·88.4 β†’ mg/dL)
eGFR: mL/min/1.73mΒ²
Bilirubin: ΞΌmol/L (Γ·17.1 β†’ mg/dL)
ALT/AST: U/L
Albumin: g/dL
"""
import re
import uuid
from typing import Optional
from neo4j_setup import neo4j_conn
# ── Biomarker registry ────────────────────────────────────────────────────────
# Maps biomarker id → (human-readable label, search terms looked for in trial
# eligibility text). The id is also used as the Biomarker node id when an
# intake is persisted to the graph; ids missing from this table are skipped by
# the scorer.
BIOMARKER_REGISTRY = {
    "HER2_POS": ("HER2 Positive", ["HER2-positive", "HER2+", "HER2 amplified", "HER2/neu positive"]),
    "HER2_NEG": ("HER2 Negative", ["HER2-negative", "HER2-"]),
    "ER_POS": ("ER Positive", ["ER-positive", "ER+", "estrogen receptor positive"]),
    "PR_POS": ("PR Positive", ["PR-positive", "PR+", "progesterone receptor positive"]),
    "BRCA1_MUT": ("BRCA1 Mutation", ["BRCA1", "BRCA1 mutation", "BRCA1-mutated"]),
    "BRCA2_MUT": ("BRCA2 Mutation", ["BRCA2", "BRCA2 mutation", "BRCA2-mutated"]),
    "EGFR_MUT": ("EGFR Mutation", ["EGFR mutation", "EGFR-mutated", "EGFR exon 19", "EGFR exon 21"]),
    "ALK_POS": ("ALK Rearrangement",["ALK rearrangement", "ALK-positive", "ALK fusion"]),
    "ROS1_POS": ("ROS1 Rearrangement",["ROS1 rearrangement", "ROS1-positive", "ROS1 fusion"]),
    "PD_L1_POS": ("PD-L1 Positive", ["PD-L1", "PD-L1 positive", "PDL1"]),
    "KRAS_WT": ("KRAS Wild-type", ["KRAS wild-type", "KRAS WT", "KRAS-wildtype"]),
    "BRAF_MUT": ("BRAF V600E", ["BRAF V600E", "BRAF mutation", "BRAF-mutated"]),
    "MSI_H": ("MSI-High", ["MSI-H", "microsatellite instability-high", "MSI high", "dMMR"]),
    "NRAS_MUT": ("NRAS Mutation", ["NRAS mutation", "NRAS-mutated"]),
    "FLT3_MUT": ("FLT3 Mutation", ["FLT3 mutation", "FLT3-mutated", "FLT3-ITD"]),
    "IDH1_MUT": ("IDH1 Mutation", ["IDH1 mutation", "IDH1-mutated"]),
    "IDH2_MUT": ("IDH2 Mutation", ["IDH2 mutation", "IDH2-mutated"]),
    "BCR_ABL": ("BCR-ABL", ["BCR-ABL", "Philadelphia chromosome", "Ph-positive"]),
    "TRIPLE_NEG":("Triple Negative", ["triple-negative", "TNBC", "triple negative breast"]),
}
# ── Age parsing ───────────────────────────────────────────────────────────────
def _parse_age_years(age_str: str) -> Optional[int]:
"""'45 Years' β†’ 45, '6 Months' β†’ 0, '' β†’ None"""
if not age_str:
return None
m = re.search(r"(\d+)\s*year", age_str, re.I)
if m:
return int(m.group(1))
m = re.search(r"(\d+)\s*month", age_str, re.I)
if m:
return 0
m = re.search(r"(\d+)", age_str)
if m:
return int(m.group(1))
return None
# ── ECOG parsing from eligibility text ────────────────────────────────────────
def _max_ecog_from_text(text: str) -> Optional[int]:
"""Extract maximum allowed ECOG from eligibility criteria text."""
patterns = [
r"ECOG\s+(?:performance\s+status\s+)?(?:of\s+)?(?:0\s*(?:or|-)\s*)?([0-4])",
r"performance\s+status\s+(?:of\s+)?(?:0\s*(?:or|-)\s*)?([0-4])",
r"Karnofsky\s+.*?(\d{2,3})\s*%", # convert KPS to ECOG approximately
]
for pat in patterns:
m = re.search(pat, text, re.I)
if m:
val = int(m.group(1))
if "Karnofsky" in pat:
# KPS 80-100 β‰ˆ ECOG 0-1, 60-70 β‰ˆ 2, 40-50 β‰ˆ 3
kps = val
val = 0 if kps >= 80 else 1 if kps >= 70 else 2 if kps >= 60 else 3
return val
return None
# ── Lab value checking against eligibility text ───────────────────────────────
def _check_labs(labs: dict, eligibility_text: str) -> list[dict]:
"""
Parse common lab thresholds from eligibility text and check patient values.
Returns list of {criterion, patient_value, threshold, met}.
"""
results = []
text = eligibility_text or ""
def _find_threshold(patterns):
for pat in patterns:
m = re.search(pat, text, re.I)
if m:
return float(m.group(1))
return None
# Hemoglobin β‰₯ threshold (g/dL in text; patient value in g/dL)
hgb = labs.get("hemoglobin")
if hgb is not None:
# Try to find "hemoglobin >= X" or "Hgb >= X g/dL"
thresh = _find_threshold([
r"hemoglobin\s*[β‰₯>=]+\s*([\d.]+)\s*g/dL",
r"Hgb\s*[β‰₯>=]+\s*([\d.]+)",
r"hemoglobin\s+of\s+at\s+least\s+([\d.]+)",
])
if thresh:
results.append({"criterion": f"Hemoglobin β‰₯ {thresh} g/dL", "patient_value": f"{hgb} g/dL", "met": hgb >= thresh})
# Platelets β‰₯ threshold (Γ—10⁹/L)
plt = labs.get("platelets")
if plt is not None:
thresh = _find_threshold([
r"platelet[s]?\s*[β‰₯>=]+\s*([\d,]+)\s*[Γ—x]?\s*10[⁹9]/L",
r"platelet[s]?\s+count\s*[β‰₯>=]+\s*([\d,]+)",
r"platelet[s]?\s+of\s+at\s+least\s+([\d,]+)",
])
if thresh:
thresh_val = thresh / 1000 if thresh > 1000 else thresh # normalise if stored as /Β΅L
results.append({"criterion": f"Platelets β‰₯ {thresh_val} Γ—10⁹/L", "patient_value": f"{plt} Γ—10⁹/L", "met": plt >= thresh_val})
# Creatinine ≀ threshold (ΞΌmol/L patient; text may be mg/dL or ΞΌmol/L)
cr = labs.get("creatinine") # patient value in ΞΌmol/L
if cr is not None:
# Most trial text uses mg/dL; convert patient value for comparison
cr_mgdl = cr / 88.4
thresh = _find_threshold([
r"creatinine\s*[≀<=]+\s*([\d.]+)\s*mg/dL",
r"serum\s+creatinine\s*[≀<=]+\s*([\d.]+)",
])
if thresh:
results.append({"criterion": f"Creatinine ≀ {thresh} mg/dL ({round(thresh*88.4)} ΞΌmol/L)", "patient_value": f"{cr} ΞΌmol/L ({round(cr_mgdl, 2)} mg/dL)", "met": cr_mgdl <= thresh})
# eGFR β‰₯ threshold
egfr = labs.get("egfr")
if egfr is not None:
thresh = _find_threshold([
r"(?:eGFR|GFR|creatinine\s+clearance)\s*[β‰₯>=]+\s*([\d.]+)",
r"glomerular\s+filtration\s+rate\s*[β‰₯>=]+\s*([\d.]+)",
])
if thresh:
results.append({"criterion": f"eGFR β‰₯ {thresh} mL/min/1.73mΒ²", "patient_value": f"{egfr} mL/min", "met": egfr >= thresh})
# Bilirubin ≀ threshold (ΞΌmol/L patient; text usually mg/dL)
bili = labs.get("bilirubin")
if bili is not None:
bili_mgdl = bili / 17.1
thresh = _find_threshold([
r"(?:total\s+)?bilirubin\s*[≀<=]+\s*([\d.]+)\s*(?:Γ—\s*)?ULN",
r"(?:total\s+)?bilirubin\s*[≀<=]+\s*([\d.]+)\s*mg/dL",
])
if thresh:
# If "Γ— ULN", ULN for bilirubin β‰ˆ 1.0 mg/dL
results.append({"criterion": f"Bilirubin ≀ {thresh} mg/dL ({round(thresh*17.1)} ΞΌmol/L)", "patient_value": f"{bili} ΞΌmol/L ({round(bili_mgdl, 2)} mg/dL)", "met": bili_mgdl <= thresh})
# ANC β‰₯ threshold (Γ—10⁹/L)
anc = labs.get("anc")
if anc is not None:
thresh = _find_threshold([
r"(?:ANC|absolute\s+neutrophil\s+count)\s*[β‰₯>=]+\s*([\d.]+)\s*[Γ—x]?\s*10[⁹9]/L",
r"neutrophil[s]?\s*[β‰₯>=]+\s*([\d.]+)",
])
if thresh:
results.append({"criterion": f"ANC β‰₯ {thresh} Γ—10⁹/L", "patient_value": f"{anc} Γ—10⁹/L", "met": anc >= thresh})
return results
# ── Main scoring function ─────────────────────────────────────────────────────
def _biomarker_term_in_text(term: str, text: str) -> bool:
    """Return True when *term* is mentioned in *text* (case-insensitive).

    The term must not be glued to a preceding letter or digit, and a term that
    ends in a hyphen (e.g. "HER2-") must not be the stem of a longer
    hyphenated word, so "HER2-" no longer matches "HER2-positive".
    """
    pattern = r"(?<![A-Za-z0-9])" + re.escape(term)
    if term.endswith("-"):
        pattern += r"(?![A-Za-z0-9])"
    return re.search(pattern, text, re.I) is not None


def score_intake_against_trial(intake: dict, trial: dict) -> dict:
    """
    Score a clinical intake profile against a single trial.

    Points are awarded per category: age 25, sex 15, ECOG 15, biomarkers 30 and
    labs 15 (labs only count towards the maximum when the intake carries lab
    values). The score is points / max_points.

    Args:
        intake: Patient data read via the keys ``age``, ``sex``, ``ecog``,
            ``biomarkers`` (ids from BIOMARKER_REGISTRY) and ``labs``.
        trial: Trial data read via the keys ``min_age``, ``max_age``, ``sex``
            and ``eligibility_criteria``. Missing or None values are tolerated.

    Returns:
        {score, eligible, criteria_breakdown, risk_flags, points, max_points}.
        ``eligible`` requires score ≥ 0.65 and no hard exclusion (age outside
        the trial limits, sex mismatch, or ECOG above the trial maximum).
    """
    breakdown = []
    risk_flags = []
    points = 0
    max_points = 0
    # Tracked explicitly rather than inferred from flag wording, so that every
    # hard criterion (including age) blocks eligibility.
    hard_exclusion = False

    age = intake.get("age")
    sex = (intake.get("sex") or "").upper()
    ecog = intake.get("ecog")
    # De-duplicate while keeping the caller's order so output is deterministic.
    biomarkers = list(dict.fromkeys(intake.get("biomarkers") or []))
    labs = intake.get("labs") or {}
    # Graph rows can carry an explicit None for a missing property.
    eligibility_text = trial.get("eligibility_criteria") or ""

    # ── Age (25 pts) ──────────────────────────────────────────────────────────
    max_points += 25
    min_age = _parse_age_years(trial.get("min_age", ""))
    max_age = _parse_age_years(trial.get("max_age", ""))
    if age is not None:
        age_ok = True
        note = ""
        if min_age and age < min_age:
            age_ok = False
            note = f"Trial requires ≥{min_age} years"
            risk_flags.append(f"Below minimum age ({age} < {min_age})")
        if max_age and age > max_age:
            age_ok = False
            note = f"Trial requires ≤{max_age} years"
            risk_flags.append(f"Above maximum age ({age} > {max_age})")
        if age_ok:
            points += 25
            note = f"Within range ({min_age or '≥18'}–{max_age or 'no max'})"
        else:
            hard_exclusion = True
        breakdown.append({"criterion": "Age", "met": age_ok, "patient_value": f"{age} years", "note": note, "category": "demographics"})

    # ── Sex (15 pts) ──────────────────────────────────────────────────────────
    max_points += 15
    trial_sex = (trial.get("sex") or "ALL").upper()
    sex_ok = trial_sex in ("ALL", sex, "")
    if sex_ok:
        points += 15
    else:
        hard_exclusion = True
        risk_flags.append(f"Sex mismatch (trial requires {trial_sex})")
    breakdown.append({"criterion": "Sex", "met": sex_ok, "patient_value": sex or "Not specified", "note": f"Trial: {trial_sex}", "category": "demographics"})

    # ── ECOG (15 pts) ─────────────────────────────────────────────────────────
    max_points += 15
    max_ecog = _max_ecog_from_text(eligibility_text)
    if ecog is not None and max_ecog is not None:
        ecog_ok = ecog <= max_ecog
        if ecog_ok:
            points += 15
        else:
            hard_exclusion = True
            risk_flags.append(f"ECOG {ecog} exceeds trial max ({max_ecog})")
        breakdown.append({"criterion": "ECOG Performance Status", "met": ecog_ok, "patient_value": f"ECOG {ecog}", "note": f"Trial requires ≤{max_ecog}", "category": "performance"})
    elif ecog is not None:
        points += 10  # partial credit — can't verify from text
        breakdown.append({"criterion": "ECOG Performance Status", "met": None, "patient_value": f"ECOG {ecog}", "note": "Could not parse limit from trial text", "category": "performance"})

    # ── Biomarkers (30 pts) ───────────────────────────────────────────────────
    max_points += 30
    if biomarkers:
        matched_bm = []
        for bm_id in biomarkers:
            info = BIOMARKER_REGISTRY.get(bm_id)
            if not info:
                continue  # unknown id: nothing to search for
            label, search_terms = info
            found_in_text = any(
                _biomarker_term_in_text(term, eligibility_text) for term in search_terms
            )
            matched_bm.append((label, found_in_text))
        relevant = [label for label, found in matched_bm if found]
        if relevant:
            points += 30
            breakdown.append({
                "criterion": "Biomarker Profile",
                "met": True,
                "patient_value": ", ".join(relevant),
                "note": f"{len(relevant)} of your biomarkers appear in trial criteria",
                "category": "molecular",
            })
        elif matched_bm:
            points += 5
            breakdown.append({
                "criterion": "Biomarker Profile",
                "met": None,
                "patient_value": ", ".join(label for label, _ in matched_bm),
                "note": "None of your biomarkers explicitly appear in criteria",
                "category": "molecular",
            })

    # ── Lab values (15 pts) ───────────────────────────────────────────────────
    if labs:
        max_points += 15
        lab_results = _check_labs(labs, eligibility_text)
        if lab_results:
            # Full credit only when every parsed criterion is met; any failure
            # earns nothing and is surfaced as a (soft) risk flag.
            if all(r["met"] for r in lab_results):
                points += 15
            for r in lab_results:
                if not r["met"]:
                    risk_flags.append(f"Lab out of range: {r['criterion']}")
                breakdown.append({
                    "criterion": r["criterion"],
                    "met": r["met"],
                    "patient_value": r["patient_value"],
                    "note": "",
                    "category": "labs",
                })
        else:
            points += 8  # no parseable lab criteria — give partial credit

    score = points / max_points if max_points > 0 else 0
    eligible = score >= 0.65 and not hard_exclusion
    return {
        "score": round(score, 3),
        "eligible": eligible,
        "criteria_breakdown": breakdown,
        "risk_flags": risk_flags,
        "points": points,
        "max_points": max_points,
    }
# ── Graph query + batch scoring ───────────────────────────────────────────────
def match_intake_to_trials(intake: dict, condition: str, limit: int = 10) -> list[dict]:
    """
    Fetch open trials for a condition, score each one against the intake and
    return the best-scoring ones.

    Only trials whose condition contains ``condition`` (case-insensitive) and
    whose status is RECRUITING or NOT_YET_RECRUITING are considered. Three
    times ``limit`` rows are requested so that ranking has candidates to choose
    from. Each returned dict holds the trial columns merged with the scoring
    result; at most ``limit`` entries come back, highest score first.
    """
    cypher = """
MATCH (t:Trial)
WHERE toLower(t.condition) CONTAINS toLower($condition)
AND t.status IN ['RECRUITING', 'NOT_YET_RECRUITING']
RETURN t.id AS nct_id, t.title AS title, t.phase AS phase,
t.condition AS condition, t.min_age AS min_age, t.max_age AS max_age,
t.sex AS sex, t.eligibility_criteria AS eligibility_criteria,
t.sponsor AS sponsor, t.location_count AS location_count,
t.last_updated AS last_updated, t.ctgov_url AS ctgov_url
LIMIT $limit
"""
    # Over-fetch, then rank.
    params = {"condition": condition, "limit": limit * 3}
    candidates = neo4j_conn.run_query(cypher, params)
    if not candidates:
        return []

    # Scoring keys override trial columns on a name clash.
    ranked = sorted(
        [{**row, **score_intake_against_trial(intake, row)} for row in candidates],
        key=lambda entry: entry["score"],
        reverse=True,
    )
    return ranked[:limit]
def save_intake_as_patient(intake: dict) -> str:
    """
    Persist an intake as a Patient node and link its biomarkers.

    A fresh id of the form ``P_INTAKE_<8 hex chars>`` is generated, the Patient
    node is merged with age, sex, ECOG and condition, and one HAS_BIOMARKER
    relationship is merged per biomarker id in the intake. Returns the new
    patient id.
    """
    patient_id = f"P_INTAKE_{uuid.uuid4().hex[:8].upper()}"

    patient_params = {
        "id": patient_id,
        "age": intake.get("age"),
        "sex": intake.get("sex", ""),
        "ecog": intake.get("ecog"),
        "condition": intake.get("condition", ""),
    }
    neo4j_conn.run_query(
        """
MERGE (p:Patient {id: $id})
SET p += {
age: $age, sex: $sex, ecog: $ecog, condition: $condition,
source: 'intake_form', created_at: datetime()
}
""",
        patient_params,
    )

    link_biomarker = """
MATCH (p:Patient {id: $pid})
MERGE (b:Biomarker {id: $bm_id})
ON CREATE SET b.name = $name
MERGE (p)-[:HAS_BIOMARKER]->(b)
"""
    for biomarker_id in intake.get("biomarkers", []):
        # Ids missing from the registry use the raw id as their display name.
        registry_entry = BIOMARKER_REGISTRY.get(biomarker_id, (biomarker_id,))
        neo4j_conn.run_query(
            link_biomarker,
            {"pid": patient_id, "bm_id": biomarker_id, "name": registry_entry[0]},
        )
    return patient_id