# CTA/backend/matching_engine.py
# Commit 59abb4f (TheQuantEd) — Initial deployment: ClinicalMatch AI v2.0 — FHIR R4 · MCP (9 tools) · A2A workflow · SHARP compliance · 100k synthetic patients · Neo4j graph · GraphRAG chatbot
import logging
import re

from clinicaltrials_api import search_trials_sync, get_trial_details_sync
from fhir_adapter import get_patient_profile, get_all_patient_ids
from llm_client import parse_trial_protocol, score_patient_against_criteria
# Optional graph backend. Any exception raised while importing the Neo4j
# connection (not only ImportError) leaves _neo4j as None; code that uses it
# guards with `if _neo4j:` and skips graph enrichment.
_neo4j = None
try:
    from neo4j_setup import neo4j_conn as _neo4j
except Exception:
    pass

# Process-lifetime memo tables (never evicted):
#   _criteria_cache: trial NCT ID -> parsed eligibility criteria dict
#   _score_cache:    "<patient_id>:<nct_id>" -> scoring result dict
_criteria_cache: dict[str, dict] = dict()
_score_cache: dict[str, dict] = dict()
def _parse_age_string(age_str: str) -> int | None:
if not age_str:
return None
match = re.search(r"(\d+)", age_str)
return int(match.group(1)) if match else None
def _quick_eligibility_check(patient_profile: dict, trial: dict) -> tuple[bool, list[str]]:
    """Cheap rule-based eligibility check on age bounds and sex.

    Missing or empty data on either side is never treated as a mismatch.

    Args:
        patient_profile: Patient dict; reads ``age`` and ``gender``.
        trial: Trial dict; reads ``min_age``, ``max_age`` and ``sex``.

    Returns:
        ``(passes, flags)``. ``passes`` is True when no rule failed, and
        ``flags`` holds one human-readable message per failed rule.
    """
    flags: list[str] = []

    # A missing age previously defaulted to 0 and was flagged "below minimum".
    age = patient_profile.get("age")
    if age is not None:
        min_age = _parse_age_string(trial.get("min_age", ""))
        max_age = _parse_age_string(trial.get("max_age", ""))
        if min_age and age < min_age:
            flags.append(f"Age {age} below minimum {min_age}")
        if max_age and age > max_age:
            flags.append(f"Age {age} above maximum {max_age}")

    # `or` also covers keys that are present but None/"" (an empty trial sex
    # previously raised IndexError on trial_sex[0]).
    trial_sex = (trial.get("sex") or "ALL").upper()
    patient_sex = (patient_profile.get("gender") or "").upper()
    # Only the first letter is compared, so "female" matches "FEMALE" or "F".
    if trial_sex not in ("ALL", "BOTH") and patient_sex and patient_sex[0] != trial_sex[0]:
        flags.append(f"Sex mismatch: trial requires {trial_sex}")

    return not flags, flags
def get_criteria_for_trial(trial: dict) -> dict:
    """Return structured eligibility criteria for *trial*, memoised by NCT ID.

    If the trial has free-text ``eligibility_criteria``, it is parsed with
    ``parse_trial_protocol``. Otherwise a generic placeholder criteria dict is
    built from ``brief_summary``.

    Args:
        trial: Trial dict; reads ``nct_id``, ``eligibility_criteria`` and
            ``brief_summary``.

    Returns:
        The criteria dict. On a cache hit this is the cached object itself, so
        callers should not mutate it.
    """
    nct_id = trial.get("nct_id", "")
    # Trials without an ID are never cached. Keying them all under "" made
    # every such trial reuse the first one's criteria.
    if nct_id and nct_id in _criteria_cache:
        return _criteria_cache[nct_id]

    eligibility_text = trial.get("eligibility_criteria", "")
    if eligibility_text:
        criteria = parse_trial_protocol(eligibility_text)
    else:
        # `or` also guards a brief_summary that is present but None or empty.
        summary = trial.get("brief_summary") or "target condition"
        criteria = {
            "inclusion_criteria": [f"Confirmed diagnosis of {summary[:50]}"],
            "exclusion_criteria": ["Prior participation in conflicting trials"],
            "age_range": {"min": 18, "max": None},
            "required_diagnoses": [],
            "required_biomarkers": [],
            "excluded_medications": [],
            "performance_status": None,
        }

    if nct_id:
        _criteria_cache[nct_id] = criteria
    return criteria
def score_patient_for_trial(patient_id: str, trial: dict) -> dict:
    """Score one patient against one trial, memoised per (patient, trial).

    The score comes from ``score_patient_against_criteria``. If the rule-based
    ``_quick_eligibility_check`` fails, the score is reduced by 0.3 (floored at
    0.0), ``eligible`` is forced to False, and the rule messages are appended
    to ``risk_flags``. The result is then annotated with ``patient_id``,
    ``nct_id``, ``trial_title`` and an explainability ``match_path``.

    Args:
        patient_id: Identifier passed to ``get_patient_profile``.
        trial: Trial dict (``nct_id``, ``title`` and eligibility fields).

    Returns:
        The result dict. If the patient is not found, returns an error dict
        with ``overall_score`` 0.0 and ``eligible`` False; this dict is not
        cached.
    """
    nct_id = trial.get("nct_id", "")
    # Only identifiable trials are cached. Otherwise every trial without an ID
    # would share the "<patient_id>:" key.
    cache_key = f"{patient_id}:{nct_id}" if nct_id else None
    if cache_key is not None and cache_key in _score_cache:
        return _score_cache[cache_key]

    patient_profile = get_patient_profile(patient_id)
    if not patient_profile:
        return {"error": "Patient not found", "overall_score": 0.0, "eligible": False}

    # Rule-based check; its outcome adjusts the scorer's result below.
    passes_rules, rule_flags = _quick_eligibility_check(patient_profile, trial)
    criteria = get_criteria_for_trial(trial)
    result = score_patient_against_criteria(patient_profile, criteria, trial.get("title", "Clinical Trial"))

    if not passes_rules:
        result["overall_score"] = max(0.0, result.get("overall_score", 0.5) - 0.3)
        result["eligible"] = False
        # Build a new list: this tolerates risk_flags being None and never
        # mutates the list object returned by the scorer.
        result["risk_flags"] = [*(result.get("risk_flags") or []), *rule_flags]

    result["patient_id"] = patient_id
    result["nct_id"] = nct_id
    result["trial_title"] = trial.get("title", "")
    result["match_path"] = _build_match_path(patient_profile, trial, criteria)

    if cache_key is not None:
        _score_cache[cache_key] = result
    return result
# Connector words ignored when matching required diagnoses against patient
# diagnoses; used as substrings they produce spurious matches.
_DX_STOPWORDS = frozenset({"and", "the", "for", "with", "without", "not"})


def _dx_keywords(text: str) -> list[str]:
    """Return the lower-cased words of *text* worth matching: 3+ chars and not a connector."""
    return [w for w in str(text).lower().split() if len(w) >= 3 and w not in _DX_STOPWORDS]


def _build_match_path(patient_profile: dict, trial: dict, criteria: dict) -> list[dict]:
    """
    Build a human-readable explainability path showing WHY a patient was matched.

    Each element is an edge dict with keys ``from``, ``rel``, ``to`` and
    ``note``. Edges are added in this order:

    1. Neo4j, when ``_neo4j`` is available: up to 3 patient biomarkers whose
       name appears in the trial's ``parsed_biomarkers`` or
       ``eligibility_criteria``. Query failures are logged at DEBUG level and
       otherwise ignored.
    2. Up to the first two ``required_biomarkers`` from *criteria* that occur
       (case-insensitive substring) in a key or value of the patient's
       ``biomarkers`` mapping.
    3. For each of the first two ``required_diagnoses``, the first patient
       diagnosis containing one of its significant words (see ``_dx_keywords``).
    4. A terminal ``ELIGIBLE_FOR`` edge to the trial.

    An edge with the same ``rel``/``to`` as an earlier one is dropped.
    """
    patient_id = patient_profile.get("patient_id", "")
    nct_id = trial.get("nct_id", "")
    trial_title = (trial.get("title") or "")[:60]
    patient_node = f"Patient:{patient_id}"

    path: list[dict] = []
    seen: set[tuple[str, str]] = set()

    def add(rel: str, to: str, note: str) -> None:
        # The graph and FHIR sources can report the same biomarker; keep the first.
        if (rel, to) in seen:
            return
        seen.add((rel, to))
        path.append({"from": patient_node, "rel": rel, "to": to, "note": note})

    # 1. Shared biomarker edges from the knowledge graph (best effort).
    if _neo4j:
        try:
            rows = _neo4j.run_query(
                """
                MATCH (p:Patient {id: $pid})-[:HAS_BIOMARKER]->(b:Biomarker)
                MATCH (t:Trial {id: $nct_id})
                WHERE t.parsed_biomarkers CONTAINS b.name OR t.eligibility_criteria CONTAINS b.name
                RETURN b.name AS biomarker LIMIT 3
                """,
                {"pid": patient_id, "nct_id": nct_id},
            )
            for row in rows:
                add("HAS_BIOMARKER", f"Biomarker:{row['biomarker']}", "required by trial")
        except Exception:
            # Graph enrichment is optional; keep the FHIR-based explanation.
            logging.getLogger(__name__).debug(
                "Neo4j match-path lookup failed for %s / %s", patient_id, nct_id, exc_info=True
            )

    # 2. Required biomarkers present in the patient's FHIR-derived profile.
    biomarkers = patient_profile.get("biomarkers") or {}
    for item in (criteria.get("required_biomarkers") or [])[:2]:
        needle = item.lower()
        if any(needle in str(k).lower() or needle in str(v).lower() for k, v in biomarkers.items()):
            add("HAS_BIOMARKER", f"Biomarker:{item}", "matches trial requirement")

    # 3. Patient diagnoses sharing a significant word with a required diagnosis.
    patient_dxs = patient_profile.get("diagnosis_names") or []
    for dx in (criteria.get("required_diagnoses") or [])[:2]:
        keywords = _dx_keywords(dx)
        if not keywords:
            continue
        for patient_dx in patient_dxs:
            dx_text = str(patient_dx).lower()
            if any(word in dx_text for word in keywords):
                add("HAS_DIAGNOSIS", f"Diagnosis:{patient_dx}", f"matches required: {dx}")
                break

    # 4. Terminal edge. NOTE: this is appended unconditionally and does not
    # reflect any eligibility verdict.
    add("ELIGIBLE_FOR", f"Trial:{nct_id}", trial_title)
    return path
def match_patient_to_trials(patient_id: str, condition: str | None = None, top_n: int = 5) -> list[dict]:
    """Find the best-matching trials for a patient, highest score first.

    Args:
        patient_id: Identifier passed to ``get_patient_profile``.
        condition: Search term for ``search_trials_sync``. Defaults to the
            patient's first ``diagnosis_names`` entry, or ``"cancer"`` if
            there is none.
        top_n: Maximum number of trials to return. Values <= 0 return ``[]``.

    Returns:
        Trial dicts extended with ``match_score``, ``eligible``,
        ``match_summary`` and ``risk_flags``, sorted by ``match_score``
        descending. Returns ``[]`` if the patient is unknown.
    """
    # A negative top_n would otherwise slice results off the end.
    if top_n <= 0:
        return []

    patient_profile = get_patient_profile(patient_id)
    if not patient_profile:
        return []

    # Infer the search condition from the patient's diagnoses if not provided.
    if not condition:
        diagnoses = patient_profile.get("diagnosis_names") or []
        condition = diagnoses[0] if diagnoses else "cancer"

    # Fetch at least top_n candidates; a fixed page of 10 silently capped the
    # result size whenever top_n > 10.
    trials = search_trials_sync(condition, page_size=max(10, top_n))

    scored = []
    for trial in trials or []:
        score_result = score_patient_for_trial(patient_id, trial)
        scored.append({
            **trial,
            "match_score": score_result.get("overall_score", 0.0),
            "eligible": score_result.get("eligible", False),
            "match_summary": score_result.get("summary", ""),
            "risk_flags": score_result.get("risk_flags", []),
        })

    scored.sort(key=lambda x: x["match_score"], reverse=True)
    return scored[:top_n]
def find_eligible_patients_for_trial(nct_id: str, min_score: float = 0.4) -> list[dict]:
    """Screen all known patients against a specific trial.

    Every ID from ``get_all_patient_ids`` is scored with
    ``score_patient_for_trial``. Each uncached patient/trial pair calls
    ``score_patient_against_criteria``, so this can be expensive for large
    populations.

    The filter is a score threshold, not the ``eligible`` flag: a patient
    flagged ineligible can still be listed if the score exceeds *min_score*.

    Args:
        nct_id: Trial identifier passed to ``get_trial_details_sync``.
        min_score: Exclusive lower bound on ``overall_score``. The default of
            0.4 is the previously hard-coded cut-off.

    Returns:
        Dicts with ``patient_id``, ``match_score``, ``eligible``, ``summary``
        and ``risk_flags``, sorted by ``match_score`` descending. Returns
        ``[]`` if the trial cannot be fetched.
    """
    trial = get_trial_details_sync(nct_id)
    if not trial:
        return []

    results = []
    for patient_id in get_all_patient_ids():
        score_result = score_patient_for_trial(patient_id, trial)
        if score_result.get("overall_score", 0) > min_score:
            results.append({
                "patient_id": patient_id,
                "match_score": score_result.get("overall_score", 0.0),
                "eligible": score_result.get("eligible", False),
                "summary": score_result.get("summary", ""),
                "risk_flags": score_result.get("risk_flags", []),
            })

    results.sort(key=lambda x: x["match_score"], reverse=True)
    return results