import re

from fhir_adapter import get_patient_profile, get_all_patient_ids
from clinicaltrials_api import search_trials_sync, get_trial_details_sync
from llm_client import parse_trial_protocol, score_patient_against_criteria

try:
    from neo4j_setup import neo4j_conn as _neo4j
except Exception:
    # The graph database is optional: any import/connection failure disables
    # graph enrichment instead of breaking this module.
    _neo4j = None

# In-memory caches: parsed criteria keyed by NCT ID, and scores keyed by
# "patient_id:nct_id". NOTE(review): both are unbounded and live for the
# lifetime of the process.
_criteria_cache: dict[str, dict] = {}
_score_cache: dict[str, dict] = {}

# Divisors converting an age unit (singular, lowercase) into whole years.
_AGE_UNIT_DIVISORS = {"year": 1, "month": 12, "week": 52, "day": 365}
# A number followed by an optional alphabetic unit, e.g. "18 Years", "6 Months".
_AGE_PATTERN = re.compile(r"(\d+)\s*([A-Za-z]*)")


def _parse_age_string(age_str: str) -> int | None:
    """Convert an age string such as "18 Years" or "6 Months" into whole years.

    Month/week/day units (singular or plural, any case) are divided down to
    years and floored. A bare number or an unrecognised unit is taken as
    years.

    Returns:
        The age in whole years, or None for empty input or text without any
        digits (e.g. "N/A").
    """
    if not age_str:
        return None
    match = _AGE_PATTERN.search(str(age_str))
    if not match:
        return None
    value = int(match.group(1))
    unit = match.group(2).lower().rstrip("s")  # "Months" -> "month"
    return value // _AGE_UNIT_DIVISORS.get(unit, 1)


def _quick_eligibility_check(patient_profile: dict, trial: dict) -> tuple[bool, list[str]]:
    """Run a cheap rule-based age/sex screen of a patient against a trial.

    Args:
        patient_profile: Patient dict; reads ``age`` and ``gender``.
        trial: Trial dict; reads ``min_age``, ``max_age`` and ``sex``.

    Returns:
        ``(passes, flags)``. ``flags`` lists human-readable failure reasons,
        and ``passes`` is True exactly when ``flags`` is empty.

    A missing or non-numeric patient age, or a gender that is not male/female
    (e.g. "unknown" or "other"), is treated as unknown. It raises no flag,
    because the screen should not reject on data it does not have.
    """
    flags: list[str] = []

    age = patient_profile.get("age")
    if isinstance(age, (int, float)):
        min_age = _parse_age_string(trial.get("min_age") or "")
        max_age = _parse_age_string(trial.get("max_age") or "")
        # Compare against None explicitly: a parsed bound of 0 (e.g. a
        # "6 Months" maximum) is a real limit, not "no limit".
        if min_age is not None and age < min_age:
            flags.append(f"Age {age} below minimum {min_age}")
        if max_age is not None and age > max_age:
            flags.append(f"Age {age} above maximum {max_age}")

    trial_sex = (trial.get("sex") or "ALL").upper()
    patient_sex = (patient_profile.get("gender") or "").upper()
    # Only a definite male/female patient can conflict with a single-sex
    # trial. First letters are compared so "M"/"MALE" and "F"/"FEMALE" both
    # work.
    if (
        trial_sex not in ("ALL", "BOTH")
        and patient_sex[:1] in ("M", "F")
        and patient_sex[0] != trial_sex[0]
    ):
        flags.append(f"Sex mismatch: trial requires {trial_sex}")

    return not flags, flags


def get_criteria_for_trial(trial: dict) -> dict:
    """Return the structured eligibility criteria for a trial.

    The free-text ``eligibility_criteria`` is parsed with
    ``parse_trial_protocol``. When that text is absent, a generic placeholder
    criteria dict is built instead. Results are cached by ``nct_id``; trials
    without an ID are never cached, so they cannot share a cache entry.
    """
    nct_id = trial.get("nct_id", "")
    if nct_id and nct_id in _criteria_cache:
        return _criteria_cache[nct_id]

    eligibility_text = trial.get("eligibility_criteria", "")
    if eligibility_text:
        criteria = parse_trial_protocol(eligibility_text)
    else:
        # `or` also covers an explicit None or empty summary.
        summary = trial.get("brief_summary") or "target condition"
        criteria = {
            "inclusion_criteria": [f"Confirmed diagnosis of {summary[:50]}"],
            "exclusion_criteria": ["Prior participation in conflicting trials"],
            "age_range": {"min": 18, "max": None},
            "required_diagnoses": [],
            "required_biomarkers": [],
            "excluded_medications": [],
            "performance_status": None,
        }

    if nct_id:
        _criteria_cache[nct_id] = criteria
    return criteria
def score_patient_for_trial(
    patient_id: str, trial: dict, patient_profile: dict | None = None
) -> dict:
    """Score one patient against one trial.

    Combines the rule-based age/sex screen with LLM criteria scoring.

    Args:
        patient_id: Identifier passed to ``get_patient_profile``.
        trial: Trial dict; ``nct_id`` and ``title`` are read here, and other
            fields are read by the helpers.
        patient_profile: Optional already-fetched profile for ``patient_id``.
            When omitted, it is fetched here. This lets batch callers avoid
            one profile fetch per trial.

    Returns:
        The dict from ``score_patient_against_criteria``, augmented with
        ``patient_id``, ``nct_id``, ``trial_title`` and ``match_path``.
        If the patient cannot be found, an error dict with ``overall_score``
        0.0 and ``eligible`` False is returned instead.
    """
    nct_id = trial.get("nct_id", "")
    cache_key = f"{patient_id}:{nct_id}"
    # Without an NCT ID every trial would share one key, so don't cache.
    if nct_id and cache_key in _score_cache:
        return _score_cache[cache_key]

    if patient_profile is None:
        patient_profile = get_patient_profile(patient_id)
    if not patient_profile:
        return {"error": "Patient not found", "overall_score": 0.0, "eligible": False}

    passes_rules, rule_flags = _quick_eligibility_check(patient_profile, trial)
    criteria = get_criteria_for_trial(trial)
    # The LLM is consulted even when the rule screen fails. A failure then
    # penalises its score and forces ineligibility.
    result = score_patient_against_criteria(patient_profile, criteria, trial.get("title", "Clinical Trial"))

    if not passes_rules:
        base_score = result.get("overall_score")
        if base_score is None:
            base_score = 0.5
        result["overall_score"] = max(0.0, base_score - 0.3)
        result["eligible"] = False
        # Rebuild the list: the LLM may have returned None (or omitted it).
        result["risk_flags"] = list(result.get("risk_flags") or []) + rule_flags

    result["patient_id"] = patient_id
    result["nct_id"] = nct_id
    result["trial_title"] = trial.get("title", "")
    result["match_path"] = _build_match_path(patient_profile, trial, criteria)

    if nct_id:
        _score_cache[cache_key] = result
    return result


# Words too generic to justify a diagnosis match on their own.
_DX_STOPWORDS = frozenset(
    {"a", "an", "and", "at", "by", "for", "in", "of", "on", "or", "the", "to", "with"}
)


def _diagnosis_keywords(text: str) -> list[str]:
    """Return the significant lowercase words (3+ chars, non-stopword) of *text*."""
    return [
        word
        for word in re.findall(r"\w+", text.lower())
        if len(word) > 2 and word not in _DX_STOPWORDS
    ]


def _build_match_path(patient_profile: dict, trial: dict, criteria: dict) -> list[dict]:
    """
    Build a human-readable graph explainability path showing WHY a patient was matched.

    Returns a list of path nodes: Patient → biomarker/diagnosis/lab → Trial.
    Each node is a dict with "from", "rel", "to" and "note" keys. Biomarker
    nodes found by both the graph query and the FHIR criteria check are
    emitted only once.
    """
    path: list[dict] = []
    seen_targets: set[str] = set()
    patient_id = patient_profile.get("patient_id", "")
    nct_id = trial.get("nct_id", "")
    trial_title = (trial.get("title") or "")[:60]

    def add_node(rel: str, target: str, note: str) -> None:
        # Skip duplicates (case-insensitive on the target label).
        key = target.lower()
        if key in seen_targets:
            return
        seen_targets.add(key)
        path.append({"from": f"Patient:{patient_id}", "rel": rel, "to": target, "note": note})

    # Check graph for shared biomarker edges
    if _neo4j:
        try:
            rows = _neo4j.run_query(
                """
                MATCH (p:Patient {id: $pid})-[:HAS_BIOMARKER]->(b:Biomarker)
                MATCH (t:Trial {id: $nct_id})
                WHERE t.parsed_biomarkers CONTAINS b.name
                   OR t.eligibility_criteria CONTAINS b.name
                RETURN b.name AS biomarker
                LIMIT 3
                """,
                {"pid": patient_id, "nct_id": nct_id},
            )
            for row in rows or []:
                add_node("HAS_BIOMARKER", f"Biomarker:{row['biomarker']}", "required by trial")
        except Exception as exc:
            # Graph enrichment is best-effort: log and fall back to FHIR-only reasoning.
            import logging

            logging.getLogger(__name__).warning(
                "Neo4j match-path query failed for %s / %s: %s", patient_id, nct_id, exc
            )

    # Add FHIR-based reasoning nodes from the criteria match.
    # NOTE(review): assumes "biomarkers" is a dict of name -> value; confirm with fhir_adapter.
    biomarkers = patient_profile.get("biomarkers") or {}
    for item in (criteria.get("required_biomarkers") or [])[:2]:
        needle = str(item).lower()
        if any(needle in str(k).lower() or needle in str(v).lower() for k, v in biomarkers.items()):
            add_node("HAS_BIOMARKER", f"Biomarker:{item}", "matches trial requirement")

    patient_diagnoses = patient_profile.get("diagnosis_names") or []
    for dx in (criteria.get("required_diagnoses") or [])[:2]:
        # Ignore short words/stopwords: substring tests on "of"/"the" match almost anything.
        keywords = _diagnosis_keywords(str(dx))
        for patient_dx in patient_diagnoses:
            patient_dx_lower = str(patient_dx).lower()
            if any(word in patient_dx_lower for word in keywords):
                add_node("HAS_DIAGNOSIS", f"Diagnosis:{patient_dx}", f"matches required: {dx}")
                break

    # Terminal node.
    # NOTE(review): appended unconditionally, including for results scored as ineligible.
    path.append({
        "from": f"Patient:{patient_id}",
        "rel": "ELIGIBLE_FOR",
        "to": f"Trial:{nct_id}",
        "note": trial_title,
    })
    return path


def match_patient_to_trials(patient_id: str, condition: str | None = None, top_n: int = 5) -> list[dict]:
    """Find the best-matching trials for a patient.

    Args:
        patient_id: Patient to match.
        condition: Search term for trials. Defaults to the patient's first
            diagnosis name, or "cancer" if the patient has none.
        top_n: Maximum number of trials to return. Values <= 0 return [].

    Returns:
        Trial dicts, each extended with ``match_score``, ``eligible``,
        ``match_summary`` and ``risk_flags``. Sorted by descending
        ``match_score``. Empty if the patient is not found.
    """
    patient_profile = get_patient_profile(patient_id)
    if not patient_profile or top_n <= 0:
        return []

    # Infer condition from patient diagnoses if not provided
    if not condition:
        diagnoses = patient_profile.get("diagnosis_names") or []
        condition = diagnoses[0] if diagnoses else "cancer"

    # Fetch at least top_n candidates so a large top_n isn't silently capped.
    trials = search_trials_sync(condition, page_size=max(10, top_n))

    scored = []
    for trial in trials or []:
        score_result = score_patient_for_trial(patient_id, trial, patient_profile=patient_profile)
        scored.append({
            **trial,
            # `or` guards against an explicit None score breaking the sort.
            "match_score": score_result.get("overall_score") or 0.0,
            "eligible": score_result.get("eligible", False),
            "match_summary": score_result.get("summary", ""),
            "risk_flags": score_result.get("risk_flags", []),
        })
    scored.sort(key=lambda x: x["match_score"], reverse=True)
    return scored[:top_n]


def find_eligible_patients_for_trial(nct_id: str, min_score: float = 0.4) -> list[dict]:
    """Screen all known patients against a specific trial.

    Args:
        nct_id: Trial to screen against.
        min_score: Only patients scoring strictly above this are returned.

    Returns:
        Dicts with ``patient_id``, ``match_score``, ``eligible``, ``summary``
        and ``risk_flags``, sorted by descending ``match_score``. Empty if the
        trial cannot be fetched.
    """
    trial = get_trial_details_sync(nct_id)
    if not trial:
        return []
    if not trial.get("nct_id"):
        # Ensure results are labelled and cacheable even if the details omit the ID.
        trial = {**trial, "nct_id": nct_id}

    results = []
    for patient_id in get_all_patient_ids() or []:
        score_result = score_patient_for_trial(patient_id, trial)
        score = score_result.get("overall_score") or 0.0
        if score > min_score:
            results.append({
                "patient_id": patient_id,
                "match_score": score,
                "eligible": score_result.get("eligible", False),
                "summary": score_result.get("summary", ""),
                "risk_flags": score_result.get("risk_flags", []),
            })
    results.sort(key=lambda x: x["match_score"], reverse=True)
    return results