Spaces:
Running
Running
Initial deployment: ClinicalMatch AI v2.0 — FHIR R4 · MCP (9 tools) · A2A workflow · SHARP compliance · 100k synthetic patients · Neo4j graph · GraphRAG chatbot
59abb4f | from fhir_adapter import get_patient_profile, get_all_patient_ids | |
| from clinicaltrials_api import search_trials_sync, get_trial_details_sync | |
| from llm_client import parse_trial_protocol, score_patient_against_criteria | |
| import re | |
| try: | |
| from neo4j_setup import neo4j_conn as _neo4j | |
| except Exception: | |
| _neo4j = None | |
| # In-memory cache for parsed criteria and scores | |
| _criteria_cache: dict[str, dict] = {} | |
| _score_cache: dict[str, dict] = {} | |
| def _parse_age_string(age_str: str) -> int | None: | |
| if not age_str: | |
| return None | |
| match = re.search(r"(\d+)", age_str) | |
| return int(match.group(1)) if match else None | |
def _quick_eligibility_check(patient_profile: dict, trial: dict) -> tuple[bool, list[str]]:
    """Rule-based age/sex pre-check of a patient against a trial.

    Args:
        patient_profile: Patient dict; reads ``"age"`` and ``"gender"``.
        trial: Trial dict; reads ``"min_age"``, ``"max_age"`` (parsed with
            ``_parse_age_string``) and ``"sex"``.

    Returns:
        ``(passes, flags)``. ``flags`` holds one human-readable reason per failed
        rule, and ``passes`` is True when it is empty. Missing or unknown patient
        data (no age, empty or "unknown" gender) never produces a flag.
    """
    flags: list[str] = []

    age = patient_profile.get("age")
    # A missing age means "unknown", not "age 0"; defaulting to 0 used to flag
    # every patient without an age as below any positive minimum.
    if age is not None:
        min_age = _parse_age_string(trial.get("min_age", ""))
        max_age = _parse_age_string(trial.get("max_age", ""))
        if min_age and age < min_age:
            flags.append(f"Age {age} below minimum {min_age}")
        if max_age and age > max_age:
            flags.append(f"Age {age} above maximum {max_age}")

    # `or` covers keys that are present but None; an empty/blank trial sex is
    # treated as unrestricted instead of crashing on trial_sex[0].
    trial_sex = (trial.get("sex") or "").strip().upper() or "ALL"
    patient_sex = (patient_profile.get("gender") or "").strip().upper()
    sex_restricted = trial_sex not in ("ALL", "BOTH")
    # "UNKNOWN" carries no information, just like an empty gender.
    patient_sex_known = patient_sex not in ("", "UNKNOWN")
    if sex_restricted and patient_sex_known and patient_sex[0] != trial_sex[0]:
        flags.append(f"Sex mismatch: trial requires {trial_sex}")

    return not flags, flags
def get_criteria_for_trial(trial: dict) -> dict:
    """Return structured eligibility criteria for *trial*, memoized by NCT ID.

    If the trial has free-text ``eligibility_criteria``, it is parsed with
    ``parse_trial_protocol`` (imported from ``llm_client``). Otherwise a generic
    placeholder criteria dict is built from the trial's brief summary.

    Trials without an ``nct_id`` are never cached. They would all share the
    key ``""``, and the first such trial's criteria would be returned for every
    other one.

    Args:
        trial: Trial dict; reads ``nct_id``, ``eligibility_criteria`` and
            ``brief_summary``.

    Returns:
        The criteria dict. Cached instances are shared, so callers should not
        mutate them.
    """
    nct_id = trial.get("nct_id") or ""
    if nct_id and nct_id in _criteria_cache:
        return _criteria_cache[nct_id]

    eligibility_text = trial.get("eligibility_criteria") or ""
    if eligibility_text:
        criteria = parse_trial_protocol(eligibility_text)
    else:
        # `or` also covers a brief_summary key that is present but None/empty.
        summary = trial.get("brief_summary") or "target condition"
        criteria = {
            "inclusion_criteria": [f"Confirmed diagnosis of {summary[:50]}"],
            "exclusion_criteria": ["Prior participation in conflicting trials"],
            # NOTE(review): fallback assumes an adult trial; it ignores the
            # trial's own min_age/max_age fields.
            "age_range": {"min": 18, "max": None},
            "required_diagnoses": [],
            "required_biomarkers": [],
            "excluded_medications": [],
            "performance_status": None,
        }

    if nct_id:
        _criteria_cache[nct_id] = criteria
    return criteria
def score_patient_for_trial(patient_id: str, trial: dict) -> dict:
    """Score one patient against one trial, memoized per (patient, trial).

    Steps:

    1. Load the patient profile.
    2. Run the rule-based age/sex check.
    3. Score with ``score_patient_against_criteria``.
    4. If the rules failed, subtract 0.3 from the score (floored at 0), force
       ``eligible=False`` and append the rule flags to ``risk_flags``.
    5. Attach identifiers and an explainability ``match_path``.

    Returns:
        The scoring dict augmented with ``patient_id``, ``nct_id``,
        ``trial_title`` and ``match_path``. If the patient is unknown, returns
        an uncached ``{"error": ...}`` dict instead.
    """
    nct_id = trial.get("nct_id") or ""
    cache_key = f"{patient_id}:{nct_id}"
    # Only memoize trials that have an ID; otherwise distinct ID-less trials
    # would collide on "<patient_id>:" and return each other's scores.
    cacheable = bool(nct_id)
    if cacheable and cache_key in _score_cache:
        return _score_cache[cache_key]

    patient_profile = get_patient_profile(patient_id)
    if not patient_profile:
        return {"error": "Patient not found", "overall_score": 0.0, "eligible": False}

    passes_rules, rule_flags = _quick_eligibility_check(patient_profile, trial)
    criteria = get_criteria_for_trial(trial)
    # NOTE: the (presumably expensive) scorer runs even when the rule check
    # fails; the rule outcome only adjusts the result afterwards.
    result = score_patient_against_criteria(
        patient_profile, criteria, trial.get("title", "Clinical Trial")
    )

    if not passes_rules:
        base_score = result.get("overall_score")
        if base_score is None:  # missing or explicit None
            base_score = 0.5
        result["overall_score"] = max(0.0, base_score - 0.3)
        result["eligible"] = False
        # risk_flags may be absent or explicitly None in the scorer output.
        result["risk_flags"] = list(result.get("risk_flags") or []) + rule_flags

    result["patient_id"] = patient_id
    result["nct_id"] = nct_id
    result["trial_title"] = trial.get("title", "")
    result["match_path"] = _build_match_path(patient_profile, trial, criteria)

    if cacheable:
        _score_cache[cache_key] = result
    return result
# Filler words skipped when matching a required diagnosis against patient
# diagnoses word by word; as substrings they would match almost any text.
_DX_MATCH_STOPWORDS = frozenset({"and", "the", "with", "for", "not", "any", "from", "without"})


def _graph_biomarkers_for_trial(patient_id: str, nct_id: str) -> list[str]:
    """Return up to 3 patient biomarker names the Neo4j graph links to the trial.

    Best effort: returns ``[]`` when no graph connection was imported or when
    the query fails. A failure is logged at DEBUG level rather than raised.
    """
    if not _neo4j:
        return []
    try:
        rows = _neo4j.run_query(
            """
            MATCH (p:Patient {id: $pid})-[:HAS_BIOMARKER]->(b:Biomarker)
            MATCH (t:Trial {id: $nct_id})
            WHERE t.parsed_biomarkers CONTAINS b.name OR t.eligibility_criteria CONTAINS b.name
            RETURN b.name AS biomarker LIMIT 3
            """,
            {"pid": patient_id, "nct_id": nct_id},
        )
        return [row["biomarker"] for row in rows]
    except Exception:
        # The graph is optional enrichment and must never break scoring, but a
        # bare `pass` hid real query/schema errors from anyone debugging.
        import logging

        logging.getLogger(__name__).debug(
            "Neo4j match-path lookup failed for %s / %s", patient_id, nct_id, exc_info=True
        )
        return []


def _build_match_path(patient_profile: dict, trial: dict, criteria: dict) -> list[dict]:
    """
    Build a human-readable explainability path showing WHY a patient was matched.

    Each element is an edge dict with keys ``from``, ``rel``, ``to`` and
    ``note``, starting at ``Patient:<patient_id>``. Edges are added in order:

    1. ``HAS_BIOMARKER`` edges found in the Neo4j graph (when available);
    2. ``HAS_BIOMARKER`` edges for up to two ``required_biomarkers`` that occur
       in the patient's ``biomarkers`` keys or values;
    3. ``HAS_DIAGNOSIS`` edges for up to two ``required_diagnoses``. Each links
       to the first patient diagnosis that shares a significant word (longer
       than two letters and not a filler word) with the requirement;
    4. a terminal ``ELIGIBLE_FOR`` edge to ``Trial:<nct_id>``. It is appended
       unconditionally, whatever the patient's actual eligibility.

    Edges in steps 1-3 with the same relation and a case-insensitively equal
    target are emitted once. A biomarker found by both the graph and the FHIR
    profile therefore appears only once.
    """
    patient_id = patient_profile.get("patient_id", "")
    nct_id = trial.get("nct_id", "")
    patient_node = f"Patient:{patient_id}"

    path: list[dict] = []
    seen: set[tuple[str, str]] = set()

    def add_edge(rel: str, target: str, note: str) -> None:
        key = (rel, target.lower())
        if key not in seen:
            seen.add(key)
            path.append({"from": patient_node, "rel": rel, "to": target, "note": note})

    for name in _graph_biomarkers_for_trial(patient_id, nct_id):
        add_edge("HAS_BIOMARKER", f"Biomarker:{name}", "required by trial")

    # Lower-case the patient's biomarkers once instead of once per requirement.
    biomarkers = patient_profile.get("biomarkers") or {}
    biomarker_text = [(str(k).lower(), str(v).lower()) for k, v in biomarkers.items()]
    for item in (criteria.get("required_biomarkers") or [])[:2]:
        # Non-string or empty requirements are skipped: "" is a substring of
        # everything and would fabricate a match.
        needle = item.lower() if isinstance(item, str) else ""
        if needle and any(needle in k or needle in v for k, v in biomarker_text):
            add_edge("HAS_BIOMARKER", f"Biomarker:{item}", "matches trial requirement")

    patient_diagnoses = patient_profile.get("diagnosis_names") or []
    for dx in (criteria.get("required_diagnoses") or [])[:2]:
        if not isinstance(dx, str):
            continue
        # Short tokens ("of", "in") and filler words match unrelated diagnoses
        # as substrings, so only significant words are compared.
        words = [w for w in dx.lower().split() if len(w) > 2 and w not in _DX_MATCH_STOPWORDS]
        for patient_dx in patient_diagnoses:
            patient_dx_lower = str(patient_dx).lower()
            if any(w in patient_dx_lower for w in words):
                add_edge("HAS_DIAGNOSIS", f"Diagnosis:{patient_dx}", f"matches required: {dx}")
                break

    path.append({
        "from": patient_node,
        "rel": "ELIGIBLE_FOR",
        "to": f"Trial:{nct_id}",
        "note": (trial.get("title") or "")[:60],
    })
    return path
def match_patient_to_trials(patient_id: str, condition: str | None = None, top_n: int = 5) -> list[dict]:
    """Find the best-matching trials for a patient.

    Args:
        patient_id: Identifier passed to ``get_patient_profile``.
        condition: Search term for ``search_trials_sync``. Defaults to the
            patient's first diagnosis name, or ``"cancer"`` if there is none.
        top_n: Maximum number of trials to return. The search requests at least
            ``top_n`` candidates (and never fewer than 10), so larger requests
            are no longer capped at 10.

    Returns:
        Up to ``top_n`` trial dicts sorted by ``match_score`` (descending). Each
        is extended with ``match_score``, ``eligible``, ``match_summary`` and
        ``risk_flags``. Returns an empty list when the patient is unknown or
        ``top_n <= 0``.
    """
    # A negative top_n used to slice off the *last* trials (scored[:-k]).
    if top_n <= 0:
        return []

    patient_profile = get_patient_profile(patient_id)
    if not patient_profile:
        return []

    # Infer the search condition from the patient's diagnoses if not provided.
    if not condition:
        diagnoses = patient_profile.get("diagnosis_names")
        condition = diagnoses[0] if diagnoses else "cancer"

    trials = search_trials_sync(condition, page_size=max(10, top_n)) or []

    scored = []
    for trial in trials:
        score_result = score_patient_for_trial(patient_id, trial)
        scored.append({
            **trial,
            "match_score": score_result.get("overall_score", 0.0),
            "eligible": score_result.get("eligible", False),
            "match_summary": score_result.get("summary", ""),
            "risk_flags": score_result.get("risk_flags", []),
        })

    # `or 0.0` keeps a None score from breaking the sort.
    scored.sort(key=lambda entry: entry["match_score"] or 0.0, reverse=True)
    return scored[:top_n]
def find_eligible_patients_for_trial(nct_id: str, min_score: float = 0.4) -> list[dict]:
    """Screen every known patient against one trial.

    Each ID from ``get_all_patient_ids`` is scored with
    ``score_patient_for_trial``, which memoizes per (patient, trial). The first
    call for a trial therefore does one scoring pass per patient.

    Args:
        nct_id: Trial identifier passed to ``get_trial_details_sync``.
        min_score: Patients whose ``overall_score`` is not strictly greater than
            this value are omitted. Defaults to 0.4, the previously hard-coded
            cutoff.

    Returns:
        Matching patients sorted by ``match_score`` (descending), each with
        ``patient_id``, ``match_score``, ``eligible``, ``summary`` and
        ``risk_flags``. Returns an empty list when the trial cannot be found.
    """
    trial = get_trial_details_sync(nct_id)
    if not trial:
        return []

    results = []
    for patient_id in get_all_patient_ids():
        score_result = score_patient_for_trial(patient_id, trial)
        # `or 0.0` treats a missing or None score as zero instead of crashing.
        score = score_result.get("overall_score") or 0.0
        if score <= min_score:
            continue
        results.append({
            "patient_id": patient_id,
            "match_score": score,
            "eligible": score_result.get("eligible", False),
            "summary": score_result.get("summary", ""),
            "risk_flags": score_result.get("risk_flags", []),
        })

    results.sort(key=lambda entry: entry["match_score"], reverse=True)
    return results