Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""Rule-based 30-day readmission risk classification engine.

Reference implementation of the algorithm described in ALGORITHM_DESIGN.md.

Input:  TOON lines (CLUSTER|Keyword|Value|Timestamp)
Output: Risk classification + days-to-readmission prediction

Usage:
    # From a TOON string
    engine = ReadmissionRiskEngine()
    result = engine.score_from_toon(toon_text)
    print(result)

    # From a TOON file
    result = engine.score_from_file("path/to/extraction.txt")

    # From JSONL training data
    results = engine.score_from_jsonl("dspy_fine_tuning/data/trainset_full.jsonl")
"""
| from __future__ import annotations | |
| import json | |
| import math | |
| import re | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple, Union | |
| # --------------------------------------------------------------------------- | |
| # Data classes | |
| # --------------------------------------------------------------------------- | |
@dataclass
class ParsedFact:
    """One TOON line (CLUSTER|Keyword|Value|Timestamp) after parsing."""

    cluster: str
    keyword: str
    # A float when ``is_numeric`` is True, otherwise the raw string value.
    value: Union[float, str]
    timestamp: str
    is_numeric: bool
    # False when a numeric value falls outside the configured plausibility bounds.
    plausibility_ok: bool = True
@dataclass
class ClusterScore:
    """Points awarded to one cluster plus the human-readable reasons."""

    cluster: str
    score: int
    # Nominal ceiling for the cluster as passed by the scorer that built this object.
    max_score: int
    # One entry per rule that contributed points; a fresh list per instance.
    contributing_factors: List[str] = field(default_factory=list)
@dataclass
class InteractionResult:
    """A cross-cluster pattern that fired, with the bonus points it adds."""

    pattern_id: str
    pattern_name: str
    bonus: int
    # Free-text summary of the values that triggered the pattern.
    description: str
@dataclass
class SurvivalCurve:
    """P(readmit by day t) for several horizons."""

    # Maps horizon in days to cumulative probability,
    # e.g. {7: 0.05, 14: 0.12, 21: 0.18, 30: 0.23}.
    horizons: Dict[int, float]
@dataclass
class RiskResult:
    """Complete output of one scoring run: scores, classification, timing and explanation."""

    # Scores
    composite_score: int
    cluster_scores: Dict[str, ClusterScore]
    interaction_bonus: int
    interactions_triggered: List[InteractionResult]
    # Risk classification
    probability: float
    risk_category: str  # Low / Medium / High / Critical
    risk_color: str
    # Days prediction
    estimated_days: float
    days_bucket: str  # "0-7 days" / "8-14 days" / "15-30 days"
    survival_curve: SurvivalCurve
    # Explainability
    risk_factors: List[str]
    protective_factors: List[str]
    missing_clusters: List[str]
    data_completeness: float
    confidence: str  # high / medium / low
    # Raw data
    n_facts_parsed: int
    n_facts_dropped: int
| # --------------------------------------------------------------------------- | |
| # Constants | |
| # --------------------------------------------------------------------------- | |
# Cluster names the TOON parser accepts; lines naming any other cluster are dropped.
VALID_CLUSTERS = {
    "DEMOGRAPHICS",
    "VITALS",
    "LABS",
    "PROBLEMS",
    "SYMPTOMS",
    "MEDICATIONS",
    "PROCEDURES",
    "UTILIZATION",
    "DISPOSITION",
}

# Clusters whose values must parse as a number; unparseable lines are dropped.
NUMERIC_CLUSTERS = {"VITALS", "LABS", "UTILIZATION"}

# Clusters de-duplicated on (cluster, keyword) while parsing.
OBJECTIVE_CLUSTERS = {
    "DEMOGRAPHICS",
    "VITALS",
    "LABS",
    "UTILIZATION",
    "DISPOSITION",
}
| # --------------------------------------------------------------------------- | |
| # Engine | |
| # --------------------------------------------------------------------------- | |
| class ReadmissionRiskEngine: | |
| """Main entry point for readmission risk scoring.""" | |
def __init__(self, config_dir: Optional[Path] = None):
    """Load the scoring configuration and precompute lookup tables.

    Args:
        config_dir: Directory containing ``scoring_rules.json``,
            ``snomed_problem_groups.json`` and ``symptom_urgency_groups.json``.
            Defaults to the ``config`` directory next to this module.
            A plain string path is accepted as well.

    Raises:
        KeyError: If an expected section is absent from the loaded configuration.
    """
    if config_dir is None:
        config_dir = Path(__file__).parent / "config"
    # Coerce so callers may pass a str; _load_json joins paths with "/".
    self._config_dir = Path(config_dir)

    self._scoring_rules = self._load_json("scoring_rules.json")
    self._problem_groups = self._load_json("snomed_problem_groups.json")["groups"]
    self._symptom_groups = self._load_json("symptom_urgency_groups.json")["groups"]

    # Lowercase synonym -> group id lookups used by the concept mapper.
    self._problem_synonym_index = self._build_synonym_index(self._problem_groups)
    self._symptom_synonym_index = self._build_synonym_index(self._symptom_groups)

    # Logistic calibration: probability = sigmoid(alpha + beta * score).
    calibration = self._scoring_rules["_meta"]["calibration"]
    self._alpha = calibration["alpha"]
    self._beta = calibration["beta"]

    # Days-to-readmission model parameters.
    models = self._scoring_rules["DAYS_PREDICTION"]["models"]
    regression = models["regression"]["parameters"]
    self._d_max = regression["D_max"]
    self._gamma = regression["gamma"]
    self._k_base = models["survival"]["parameters"]["k_base"]
| # -- Loading helpers ---------------------------------------------------- | |
| def _load_json(self, filename: str) -> Dict[str, Any]: | |
| p = self._config_dir / filename | |
| return json.loads(p.read_text(encoding="utf-8")) | |
| def _build_synonym_index(groups: List[Dict]) -> Dict[str, str]: | |
| """Map lowercase synonym → group id.""" | |
| idx: Dict[str, str] = {} | |
| for g in groups: | |
| gid = g["id"] | |
| for syn in g.get("synonyms", []): | |
| key = syn.strip().lower() | |
| if key not in idx: | |
| idx[key] = gid | |
| return idx | |
| def _match_to_group( | |
| self, | |
| keyword: str, | |
| synonym_index: Dict[str, str], | |
| groups: List[Dict], | |
| ) -> Optional[Dict]: | |
| """Smart matching: exact > word-boundary substring > raw substring. | |
| Avoids false matches like 'tia' in 'essential' by preferring | |
| word-boundary matches and longer synonyms. | |
| """ | |
| kw_lower = keyword.strip().lower() | |
| # 1) Exact match (full keyword == synonym) | |
| gid = synonym_index.get(kw_lower) | |
| if gid: | |
| return self._group_by_id(groups, gid) | |
| # Tokenize keyword into words for word-boundary matching | |
| kw_words = set(re.split(r"[\s,;/\-()]+", kw_lower)) | |
| # 2) Word-boundary match: synonym is a whole word within the keyword | |
| # OR keyword starts/ends with the synonym as a distinct token | |
| best_wb_match: Optional[str] = None | |
| best_wb_len = 0 | |
| # 3) Raw substring match (fallback, requires min 4 chars to avoid noise) | |
| best_sub_match: Optional[str] = None | |
| best_sub_len = 0 | |
| for syn, gid in synonym_index.items(): | |
| if syn not in kw_lower: | |
| continue | |
| # Check if it's a word-boundary match | |
| is_word_match = ( | |
| syn in kw_words # exact word token | |
| or kw_lower.startswith(syn + " ") | |
| or kw_lower.endswith(" " + syn) | |
| or (" " + syn + " ") in kw_lower | |
| ) | |
| if is_word_match and len(syn) > best_wb_len: | |
| best_wb_match = gid | |
| best_wb_len = len(syn) | |
| elif not is_word_match and len(syn) >= 4 and len(syn) > best_sub_len: | |
| # Only use raw substring for synonyms >= 4 chars | |
| best_sub_match = gid | |
| best_sub_len = len(syn) | |
| # Prefer word-boundary matches over raw substring | |
| chosen = best_wb_match or best_sub_match | |
| if chosen: | |
| return self._group_by_id(groups, chosen) | |
| return None | |
| # -- Layer 1: Parser & Normalizer ---------------------------------------- | |
| def _try_parse_float(value: str) -> Optional[float]: | |
| """Best-effort numeric parse. | |
| Stage2 should emit numeric-only values for numeric fields, but in practice | |
| we sometimes see light decoration like '3 days'. For scoring purposes we | |
| accept the first numeric token, but we avoid parsing ratios like '120/80'. | |
| """ | |
| s = (value or "").strip() | |
| if not s: | |
| return None | |
| # Avoid BP-style ratios and similar formats. | |
| if "/" in s: | |
| return None | |
| # Fast path: pure float | |
| try: | |
| return float(s) | |
| except Exception: | |
| pass | |
| # Fallback: extract first numeric token | |
| m = re.search(r"[-+]?\d+(?:\.\d+)?", s) | |
| if not m: | |
| return None | |
| try: | |
| return float(m.group(0)) | |
| except Exception: | |
| return None | |
| def _split_semantic_items(value: str, *, limit: int = 20) -> List[str]: | |
| """Split a semicolon/comma/newline separated list into normalized items.""" | |
| raw = (value or "").strip() | |
| if not raw: | |
| return [] | |
| parts: List[str] = [] | |
| for seg in re.split(r"[;\n]+", raw): | |
| seg = seg.strip() | |
| if not seg: | |
| continue | |
| for item in seg.split(","): | |
| it = " ".join(item.strip().split()) | |
| if not it: | |
| continue | |
| parts.append(it.strip(" -")) | |
| if len(parts) >= limit: | |
| break | |
| if len(parts) >= limit: | |
| break | |
| # Dedup while preserving order. | |
| out: List[str] = [] | |
| seen: set[str] = set() | |
| for it in parts: | |
| k = it.casefold() | |
| if k in seen: | |
| continue | |
| seen.add(k) | |
| out.append(it) | |
| return out | |
| def _strip_prefix(keyword: str, prefixes: List[str]) -> str: | |
| k = (keyword or "").strip() | |
| k_cf = k.casefold() | |
| for p in prefixes: | |
| p_cf = p.casefold() | |
| if k_cf.startswith(p_cf): | |
| k = k[len(p) :].strip() | |
| k_cf = k.casefold() | |
| return k | |
| def _normalize_discharge_disposition(value: str) -> str: | |
| """Normalize common discharge disposition variants to the scoring allowlist.""" | |
| v = (value or "").strip() | |
| v_cf = v.casefold() | |
| if not v: | |
| return v | |
| # Canonical allowlist (scoring_rules.json): Home, Home with Services, Rehab, SNF, LTAC, Hospice, AMA | |
| if v_cf in {"home with service", "home w service", "home with svc", "home w/ service"}: | |
| return "Home with Services" | |
| if v_cf in {"home with services", "home w services", "home w/ services", "home health", "home health care"}: | |
| return "Home with Services" | |
| if v_cf in {"hospice residence", "hospice care"}: | |
| return "Hospice" | |
| return v | |
| def _normalize_mental_status(value: str) -> str: | |
| v = (value or "").strip() | |
| v_cf = v.casefold() | |
| if not v: | |
| return v | |
| if "alert" in v_cf and "orient" in v_cf: | |
| return "alert" | |
| if v_cf in {"a&o", "ao", "a/ox3", "a/ox4"}: | |
| return "alert" | |
| return v | |
def parse_toon(self, toon_text: str) -> Tuple[Dict[str, List[ParsedFact]], int, int]:
    """Parse TOON text into structured facts.

    Each non-blank, non-comment line must have exactly four '|'-separated
    fields (CLUSTER|Keyword|Value|Timestamp) and a cluster listed in
    VALID_CLUSTERS; other lines are counted as dropped.

    Returns (facts_by_cluster, n_parsed, n_dropped).
    """
    problem_prefixes = [
        "PMH:", "PMH/Comorbidities:", "Discharge Dx:", "Working Dx:",
        "Complication:", "Complications:",
    ]
    symptom_prefixes = ["ADM:", "DC:"]
    # Aggregate PROBLEMS keywords -> (value, timestamp) stamped on every
    # item expanded from the line's value list.
    aggregate_stamps = {
        "discharge dx": ("acute", "Discharge"),
        "working dx": ("acute", "Discharge"),
        "complication": ("acute", "Discharge"),
        "complications": ("acute", "Discharge"),
        "pmh/comorbidities": ("chronic", "Past"),
        "pmh": ("chronic", "Past"),
        "comorbidities": ("chronic", "Past"),
        "past medical history": ("chronic", "Past"),
    }

    facts: Dict[str, List[ParsedFact]] = {}
    n_parsed = 0
    n_dropped = 0
    # (cluster, keyword) -> index of the kept fact inside facts[cluster]
    objective_index: Dict[Tuple[str, str], int] = {}

    for raw_line in toon_text.strip().splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        parts = line.split("|")
        if len(parts) != 4:
            n_dropped += 1
            continue
        cluster, keyword, value, timestamp = (p.strip() for p in parts)
        if cluster not in VALID_CLUSTERS:
            n_dropped += 1
            continue

        # Strip common semantic prefixes embedded in the keyword.
        if cluster == "PROBLEMS":
            keyword = self._strip_prefix(keyword, problem_prefixes)
            # Expand common Stage2 aggregate semantic lines into per-item facts.
            # This makes the scorer robust to model drift like:
            #   PROBLEMS|Discharge Dx|CHF; COPD|Discharge
            # instead of emitting one line per diagnosis.
            stamp = aggregate_stamps.get(keyword.strip().casefold())
            items = self._split_semantic_items(value) if stamp else []
            if items:
                item_value, item_timestamp = stamp
                bucket = facts.setdefault("PROBLEMS", [])
                for it in items:
                    bucket.append(ParsedFact(
                        cluster="PROBLEMS",
                        keyword=it,
                        value=item_value,
                        timestamp=item_timestamp,
                        is_numeric=False,
                        plausibility_ok=True,
                    ))
                n_parsed += len(items)
                continue
        elif cluster == "SYMPTOMS":
            keyword = self._strip_prefix(keyword, symptom_prefixes)

        # Numeric parsing:
        # - Strictly numeric clusters and "range" keywords MUST parse (else drop).
        # - "mixed" keywords (e.g. PROCEDURES Mechanical Ventilation days) parse
        #   when possible and otherwise keep the raw string.
        kw_rules = self._scoring_rules.get(cluster, {}).get("keywords", {}).get(keyword, {})
        kw_type = kw_rules.get("type") if isinstance(kw_rules, dict) else None
        must_parse = cluster in NUMERIC_CLUSTERS or kw_type == "range"

        is_numeric = False
        parsed_value: Union[float, str] = value
        if must_parse or kw_type == "mixed":
            number = self._try_parse_float(value)
            if number is not None:
                parsed_value = number
                is_numeric = True
            elif must_parse:
                n_dropped += 1
                continue

        # Plausibility check
        plausibility_ok = True
        if is_numeric:
            plausibility_ok = self._check_plausibility(cluster, keyword, parsed_value)

        fact = ParsedFact(
            cluster=cluster,
            keyword=keyword,
            value=parsed_value,
            timestamp=timestamp,
            is_numeric=is_numeric,
            plausibility_ok=plausibility_ok,
        )

        # Dedup for objective clusters: the first occurrence wins, except that
        # a plausible later reading replaces an implausible earlier one so a
        # garbage value cannot hide a usable measurement.
        # NOTE(review): timestamps are not compared when choosing — confirm
        # whether e.g. Discharge readings should take precedence.
        if cluster in OBJECTIVE_CLUSTERS:
            key = (cluster, keyword)
            if key in objective_index:
                position = objective_index[key]
                kept = facts[cluster][position]
                if plausibility_ok and not kept.plausibility_ok:
                    facts[cluster][position] = fact
                n_dropped += 1
                continue
            objective_index[key] = len(facts.get(cluster, []))

        facts.setdefault(cluster, []).append(fact)
        n_parsed += 1

    return facts, n_parsed, n_dropped
| def _check_plausibility(self, cluster: str, keyword: str, value: float) -> bool: | |
| cluster_rules = self._scoring_rules.get(cluster, {}).get("keywords", {}) | |
| kw_rules = cluster_rules.get(keyword, {}) | |
| plaus = kw_rules.get("plausibility") | |
| if plaus: | |
| return plaus["min"] <= value <= plaus["max"] | |
| return True | |
| # -- Layer 2: Concept Mapper -------------------------------------------- | |
| def map_problem_to_group(self, keyword: str) -> Optional[Dict]: | |
| """Map a PROBLEMS keyword to a SNOMED concept group.""" | |
| return self._match_to_group(keyword, self._problem_synonym_index, self._problem_groups) | |
| def map_symptom_to_group(self, keyword: str) -> Optional[Dict]: | |
| """Map a SYMPTOMS keyword to an urgency group.""" | |
| return self._match_to_group(keyword, self._symptom_synonym_index, self._symptom_groups) | |
| def _group_by_id(groups: List[Dict], gid: str) -> Optional[Dict]: | |
| for g in groups: | |
| if g["id"] == gid: | |
| return g | |
| return None | |
| # -- Layer 3: Cluster Scorers ------------------------------------------- | |
| def _score_range_keyword(self, rules: Dict, value: float) -> Tuple[int, str]: | |
| """Score a numeric value using range rules. Returns (score, label).""" | |
| for r in rules.get("ranges", []): | |
| if r["min"] <= value <= r["max"]: | |
| return r["score"], r.get("label", "") | |
| return 0, "" | |
def score_demographics(self, facts: List[ParsedFact]) -> ClusterScore:
    """Score the DEMOGRAPHICS cluster from Age (range rule) and Sex (value table).

    When no numeric Age fact is present, the Age rule's ``missing_score``
    (default 2) is added instead.
    """
    # NOTE(review): the total is not capped at the declared max_score of 10,
    # unlike several other cluster scorers — confirm this is intended.
    keyword_rules = self._scoring_rules["DEMOGRAPHICS"]["keywords"]
    total = 0
    notes: List[str] = []
    saw_age = False

    for fact in facts:
        if fact.keyword == "Age" and fact.is_numeric:
            saw_age = True
            points, band = self._score_range_keyword(keyword_rules["Age"], fact.value)
            total += points
            if points > 0:
                notes.append(f"Age {int(fact.value)} ({band}, +{points})")
        elif fact.keyword == "Sex":
            sex = str(fact.value).lower()
            points = keyword_rules["Sex"]["values"].get(sex, 0)
            total += points
            if points > 0:
                notes.append(f"Sex={sex} (+{points})")

    if not saw_age:
        fallback = keyword_rules["Age"].get("missing_score", 2)
        total += fallback
        notes.append(f"Age missing (default +{fallback})")

    return ClusterScore("DEMOGRAPHICS", total, 10, notes)
def score_vitals(self, facts: List[ParsedFact]) -> ClusterScore:
    """Score the VITALS cluster by summing range-rule points per keyword.

    Facts that are non-numeric, implausible, unknown to the rules, or whose
    rule type is ``"no_direct_score"`` contribute nothing.
    """
    # NOTE(review): the total is not capped at the declared max_score of 25 — confirm.
    keyword_rules = self._scoring_rules["VITALS"]["keywords"]
    total = 0
    notes: List[str] = []

    for fact in facts:
        if not (fact.is_numeric and fact.plausibility_ok):
            continue
        spec = keyword_rules.get(fact.keyword)
        if spec and spec.get("type") != "no_direct_score":
            points, band = self._score_range_keyword(spec, fact.value)
            total += points
            if points > 0:
                notes.append(f"{fact.keyword}={fact.value} ({band}, +{points})")

    return ClusterScore("VITALS", total, 25, notes)
def score_labs(self, facts: List[ParsedFact]) -> ClusterScore:
    """Score the LABS cluster by summing range-rule points per keyword.

    Non-numeric or implausible facts, and keywords without a rule, are skipped.
    """
    # NOTE(review): the total is not capped at the declared max_score of 30 — confirm.
    keyword_rules = self._scoring_rules["LABS"]["keywords"]
    total = 0
    notes: List[str] = []

    for fact in facts:
        if not (fact.is_numeric and fact.plausibility_ok):
            continue
        spec = keyword_rules.get(fact.keyword)
        if spec:
            points, band = self._score_range_keyword(spec, fact.value)
            total += points
            if points > 0:
                notes.append(f"{fact.keyword}={fact.value} ({band}, +{points})")

    return ClusterScore("LABS", total, 30, notes)
def score_problems(self, facts: List[ParsedFact]) -> ClusterScore:
    """Score the PROBLEMS cluster from mapped concept groups.

    Only facts valued chronic/acute/exist count. Each concept group
    contributes its highest ``risk_weight`` once; having more than three
    distinct groups adds a multimorbidity bonus of ``min(n - 3, 5)``.
    The result is capped at 40.
    """
    counted_values = {"chronic", "acute", "exist"}
    weight_by_group: Dict[str, int] = {}
    notes: List[str] = []

    for fact in facts:
        if str(fact.value).lower().strip() not in counted_values:
            continue
        group = self.map_problem_to_group(fact.keyword)
        if not group:
            continue
        group_id, weight = group["id"], group["risk_weight"]
        if group_id not in weight_by_group or weight > weight_by_group[group_id]:
            weight_by_group[group_id] = weight
            notes.append(f"{fact.keyword} → {group['name']} (weight {weight})")

    # Multimorbidity bonus
    group_count = len(weight_by_group)
    bonus = 0
    if group_count > 3:
        bonus = min(group_count - 3, 5)
        notes.append(f"Multimorbidity: {group_count} groups (+{bonus})")

    total = min(sum(weight_by_group.values()) + bonus, 40)
    return ClusterScore("PROBLEMS", total, 40, notes)
def score_symptoms(self, facts: List[ParsedFact]) -> ClusterScore:
    """Score the SYMPTOMS cluster from mapped urgency groups.

    A symptom is active when its value is "yes" (x1.0) or "severe" (x1.5).
    Each urgency group contributes its highest weighted value once; more
    than three active symptoms add 2 points. The rounded total is capped
    at 15.
    """
    severity_weight = {"severe": 1.5, "yes": 1.0, "no": 0.0}
    best_by_group: Dict[str, float] = {}
    notes: List[str] = []
    n_active = 0

    for fact in facts:
        answer = str(fact.value).lower().strip()
        multiplier = severity_weight.get(answer, 0.0)
        if multiplier == 0.0:
            continue
        # Counted even when the keyword maps to no group.
        n_active += 1
        group = self.map_symptom_to_group(fact.keyword)
        if not group:
            continue
        group_id = group["id"]
        weighted = group["risk_weight"] * multiplier
        if group_id not in best_by_group or weighted > best_by_group[group_id]:
            best_by_group[group_id] = weighted
            notes.append(f"{fact.keyword}={answer} → {group['name']} (+{weighted:.1f})")

    # Active symptom count bonus
    extra = 0
    if n_active > 3:
        extra = 2
        notes.append(f"Active symptoms: {n_active} (>3, +2)")

    total = min(int(round(sum(best_by_group.values()) + extra)), 15)
    return ClusterScore("SYMPTOMS", total, 15, notes)
def score_medications(self, facts: List[ParsedFact]) -> ClusterScore:
    """Score the MEDICATIONS cluster from range and categorical keyword rules.

    If a numeric "Medication Count" of at least 5 was seen and no factor
    mentioning "Polypharmacy" scored, 3 points of derived polypharmacy are
    added. The total is capped at 15.
    """
    keyword_rules = self._scoring_rules["MEDICATIONS"]["keywords"]
    total = 0
    notes: List[str] = []
    med_count: Optional[float] = None

    for fact in facts:
        spec = keyword_rules.get(fact.keyword)
        if not spec:
            continue
        rule_type = spec["type"]
        if rule_type == "range" and fact.is_numeric:
            points, band = self._score_range_keyword(spec, fact.value)
            total += points
            if fact.keyword == "Medication Count":
                med_count = fact.value
            if points > 0:
                notes.append(f"{fact.keyword}={fact.value} ({band}, +{points})")
        elif rule_type == "categorical":
            answer = str(fact.value).lower().strip()
            points = spec["values"].get(answer, 0)
            total += points
            if points > 0:
                notes.append(f"{fact.keyword}={answer} (+{points})")

    # Derived polypharmacy: if med_count >= 5 and Polypharmacy not already scored
    already_scored = any("Polypharmacy" in note for note in notes)
    if med_count is not None and med_count >= 5 and not already_scored:
        total += 3
        notes.append(f"Derived Polypharmacy (Med Count={int(med_count)} >=5, +3)")

    return ClusterScore("MEDICATIONS", min(total, 15), 15, notes)
def score_procedures(self, facts: List[ParsedFact]) -> ClusterScore:
    """Score the PROCEDURES cluster.

    Specific procedures (Mechanical Ventilation, Dialysis, Surgery) are
    scored first. The generic "Any Procedure" fact is used only as a
    fallback when no specific procedure scored. The total is capped at 15.

    Mechanical Ventilation is a mixed keyword: a numeric value scores only
    when it is greater than zero (so "0" days does not count), and a
    non-numeric value scores unless it is empty or "no".
    """
    rules = self._scoring_rules["PROCEDURES"]["keywords"]
    score = 0
    factors: List[str] = []
    specific_scored = False

    for f in facts:
        kw_rules = rules.get(f.keyword)
        if not kw_rules:
            continue
        if f.keyword == "Mechanical Ventilation":
            if f.is_numeric:
                positive = f.value > 0
                shown = f"{f.value} days"
            else:
                positive = str(f.value).lower().strip() not in ("", "no")
                shown = f"{f.value}"
            if positive:
                pts = kw_rules["score_if_any_positive"]
                score += pts
                factors.append(f"Mechanical Ventilation={shown} (+{pts})")
                specific_scored = True
        elif f.keyword in ("Dialysis", "Surgery"):
            val = str(f.value).lower().strip()
            pts = kw_rules["values"].get(val, 0)
            score += pts
            if pts > 0:
                factors.append(f"{f.keyword}={val} (+{pts})")
                specific_scored = True
        # "Any Procedure" is handled below as a fallback.

    # Fallback: only the first "Any Procedure" fact is considered.
    if not specific_scored:
        generic = next((f for f in facts if f.keyword == "Any Procedure"), None)
        if generic is not None:
            val = str(generic.value).lower().strip()
            pts = rules["Any Procedure"]["values"].get(val, 0)
            score += pts
            if pts > 0:
                factors.append(f"Any Procedure={val} (generic fallback, +{pts})")

    return ClusterScore("PROCEDURES", min(score, 15), 15, factors)
def score_utilization(self, facts: List[ParsedFact]) -> ClusterScore:
    """Score the UTILIZATION cluster by summing range-rule points, capped at 20."""
    # NOTE(review): unlike the vitals/labs scorers, plausibility_ok is not
    # checked here — confirm this is intended.
    keyword_rules = self._scoring_rules["UTILIZATION"]["keywords"]
    total = 0
    notes: List[str] = []

    for fact in facts:
        spec = keyword_rules.get(fact.keyword) if fact.is_numeric else None
        if not spec:
            continue
        points, band = self._score_range_keyword(spec, fact.value)
        total += points
        if points > 0:
            notes.append(f"{fact.keyword}={fact.value} ({band}, +{points})")

    return ClusterScore("UTILIZATION", min(total, 20), 20, notes)
def score_disposition(self, facts: List[ParsedFact]) -> ClusterScore:
    """Score the DISPOSITION cluster from categorical value tables, capped at 15.

    "Discharge Disposition" and "Mental Status" values are normalized before
    lookup. The lookup tries the value as written, then its lowercase form.
    """
    keyword_rules = self._scoring_rules["DISPOSITION"]["keywords"]
    total = 0
    notes: List[str] = []

    for fact in facts:
        spec = keyword_rules.get(fact.keyword)
        if not spec:
            continue
        answer = str(fact.value).strip()
        if fact.keyword == "Discharge Disposition":
            answer = self._normalize_discharge_disposition(answer)
        elif fact.keyword == "Mental Status":
            answer = self._normalize_mental_status(answer)

        # Try exact match first, then case-insensitive
        table = spec["values"]
        points = table[answer] if answer in table else table.get(answer.lower(), 0)
        total += points
        if points > 0:
            notes.append(f"{fact.keyword}={answer} (+{points})")

    return ClusterScore("DISPOSITION", min(total, 15), 15, notes)
| # -- Layer 4: Pattern Detector ------------------------------------------ | |
def detect_interactions(
    self,
    facts: Dict[str, List[ParsedFact]],
    cluster_scores: Dict[str, ClusterScore],
) -> List[InteractionResult]:
    """Detect cross-cluster clinical patterns.

    Args:
        facts: Parsed facts keyed by cluster name.
        cluster_scores: Per-cluster scores; accepted for interface
            compatibility but not read by the current patterns.

    Returns:
        One InteractionResult per pattern that fired, in the fixed order:
        sepsis, AKI, decompensated HF, frailty, unstable discharge,
        respiratory failure, metabolic crisis, bleeding risk.
    """
    results: List[InteractionResult] = []

    def get_val(cluster: str, keyword: str) -> Optional[float]:
        """First plausible numeric value for a cluster/keyword, else None."""
        for f in facts.get(cluster, []):
            # Implausible readings are ignored, matching the vitals/labs scorers.
            if f.keyword == keyword and f.is_numeric and f.plausibility_ok:
                return f.value
        return None

    def get_str(cluster: str, keyword: str) -> Optional[str]:
        """First value for a cluster/keyword as a lowercase stripped string."""
        for f in facts.get(cluster, []):
            if f.keyword == keyword:
                return str(f.value).lower().strip()
        return None

    def active_group_ids(cluster: str, active_values: Tuple[str, ...], mapper) -> set:
        """Ids of concept groups hit by facts whose value marks them active."""
        ids = set()
        for f in facts.get(cluster, []):
            if str(f.value).lower().strip() in active_values:
                group = mapper(f.keyword)
                if group:
                    ids.add(group["id"])
        return ids

    # Map each fact to its group once; several patterns reuse these sets.
    problem_groups = active_group_ids(
        "PROBLEMS", ("chronic", "acute", "exist"), self.map_problem_to_group
    )
    symptom_groups = active_group_ids(
        "SYMPTOMS", ("yes", "severe"), self.map_symptom_to_group
    )

    # --- Sepsis Pattern ---
    hr = get_val("VITALS", "Heart Rate")
    sbp = get_val("VITALS", "Systolic BP")
    rr = get_val("VITALS", "Respiratory Rate")
    wbc = get_val("LABS", "WBC")
    temp = get_val("VITALS", "Temperature")
    if hr is not None and hr > 100:
        has_hemodynamic = (sbp is not None and sbp < 100) or (rr is not None and rr > 22)
        # NOTE(review): the 100.4 threshold presumably assumes Fahrenheit — confirm.
        has_infection = (
            (wbc is not None and (wbc > 12 or wbc < 4))
            or (temp is not None and temp > 100.4)
        )
        if has_hemodynamic and has_infection:
            results.append(InteractionResult(
                "sepsis_pattern", "Sepsis / SIRS Pattern", 10,
                f"HR={hr}, SBP={sbp}, RR={rr}, WBC={wbc}, Temp={temp}",
            ))

    # --- AKI Pattern ---
    cr = get_val("LABS", "Creatinine")
    bun = get_val("LABS", "BUN")
    k = get_val("LABS", "Potassium")
    na = get_val("LABS", "Sodium")
    bicarb = get_val("LABS", "Bicarbonate")
    if cr is not None and cr > 1.5 and bun is not None and bun > 30:
        has_electrolyte = (
            (k is not None and k > 5.0)
            or (na is not None and na < 135)
            or (bicarb is not None and bicarb < 22)
        )
        if has_electrolyte:
            results.append(InteractionResult(
                "aki_pattern", "Acute Kidney Injury Pattern", 8,
                f"Cr={cr}, BUN={bun}, K={k}, Na={na}, Bicarb={bicarb}",
            ))

    # --- Decompensated HF ---
    if "heart_failure" in problem_groups:
        has_decomp_sign = (
            "edema_fluid" in symptom_groups
            or "respiratory_distress" in symptom_groups
            or (bun is not None and bun > 40)
        )
        if has_decomp_sign:
            results.append(InteractionResult(
                "decompensated_hf", "Decompensated Heart Failure", 8,
                "Heart failure + fluid overload/dyspnea/elevated BUN",
            ))

    # --- Frailty Syndrome ---
    age = get_val("DEMOGRAPHICS", "Age")
    hgb = get_val("LABS", "Hemoglobin")
    mental = get_str("DISPOSITION", "Mental Status")
    disp = get_str("DISPOSITION", "Discharge Disposition")
    n_problem_groups = len(problem_groups)
    if age is not None and age > 75:
        frailty_count = 0
        if n_problem_groups >= 3:
            frailty_count += 1
        if hgb is not None and hgb < 10:
            frailty_count += 1
        if mental in ("confused", "lethargic"):
            frailty_count += 1
        if disp in ("snf", "ltac", "rehab"):
            frailty_count += 1
        if frailty_count >= 2:
            results.append(InteractionResult(
                "frailty_syndrome", "Frailty Syndrome", 6,
                f"Age={age}, problems={n_problem_groups}, Hgb={hgb}, mental={mental}, disp={disp}",
            ))

    # --- Unstable Discharge ---
    if disp == "ama":
        results.append(InteractionResult(
            "unstable_discharge", "Unstable Discharge (AMA)", 5,
            "Discharge Against Medical Advice",
        ))
    elif mental in ("confused", "lethargic") and disp in ("home", None):
        # A missing disposition is treated like a discharge home here.
        results.append(InteractionResult(
            "unstable_discharge", "Unstable Discharge (altered + Home)", 5,
            f"Mental={mental}, Disposition={disp}",
        ))

    # --- Respiratory Failure ---
    spo2 = get_val("VITALS", "SpO2")
    if spo2 is not None and spo2 < 92:
        has_resp = (rr is not None and rr > 24) or "respiratory_distress" in symptom_groups
        if has_resp:
            results.append(InteractionResult(
                "respiratory_failure", "Respiratory Failure Pattern", 6,
                f"SpO2={spo2}, RR={rr}",
            ))

    # --- Metabolic Crisis ---
    glucose = get_val("LABS", "Glucose")
    if glucose is not None and glucose > 300:
        has_metabolic = (
            (bicarb is not None and bicarb < 18)
            or (k is not None and k > 5.5)
        )
        if has_metabolic:
            results.append(InteractionResult(
                "metabolic_crisis", "Metabolic Crisis (DKA/HHS)", 6,
                f"Glucose={glucose}, Bicarb={bicarb}, K={k}",
            ))

    # --- Bleeding Risk ---
    plt = get_val("LABS", "Platelet")
    anticoag = get_str("MEDICATIONS", "Anticoagulation")
    if hgb is not None and hgb < 8:
        has_bleed_risk = (
            (plt is not None and plt < 100)
            or anticoag == "yes"
        )
        if has_bleed_risk:
            results.append(InteractionResult(
                "bleeding_risk", "Active Bleeding Risk", 6,
                f"Hgb={hgb}, Plt={plt}, Anticoag={anticoag}",
            ))

    return results
| # -- Layer 5: Risk Aggregator ------------------------------------------- | |
| def _logistic(self, score: int) -> float: | |
| """Convert composite score to probability via logistic function.""" | |
| z = self._alpha + self._beta * score | |
| return 1.0 / (1.0 + math.exp(-z)) | |
| def _classify_risk(self, score: int) -> Tuple[str, str]: | |
| """Return (category, color) for a given composite score.""" | |
| for cat in self._scoring_rules["_meta"]["risk_categories"]: | |
| if cat["score_min"] <= score <= cat["score_max"]: | |
| return cat["name"], cat["color"] | |
| return "Critical", "red" | |
| # -- Layer 6: Days Predictor -------------------------------------------- | |
| def _predict_days(self, score: int) -> float: | |
| """Estimate days to readmission (point estimate).""" | |
| return max(1.0, self._d_max * math.exp(-self._gamma * score)) | |
| def _predict_bucket(self, estimated_days: float) -> str: | |
| if estimated_days <= 7: | |
| return "0-7 days" | |
| elif estimated_days <= 14: | |
| return "8-14 days" | |
| else: | |
| return "15-30 days" | |
def _predict_survival(self, score: int, p_30d: float) -> SurvivalCurve:
    """Compute P(readmit by day t) for the 7, 14, 21 and 30 day horizons.

    The 30-day probability is distributed over time with the shape
    ``(1 - exp(-k * t / 30)) / (1 - exp(-k))``, where
    ``k = k_base + 0.02 * (score - 30)`` floored at 0.5. Each value is
    clamped to [0, 1] and rounded to 4 decimals.
    """
    # Floor the rate to avoid degenerate cases.
    rate = max(0.5, self._k_base + 0.02 * (score - 30))
    normaliser = 1.0 - math.exp(-rate)
    if abs(normaliser) < 1e-9:
        normaliser = 1e-9

    def cumulative(day: int) -> float:
        share = (1.0 - math.exp(-(day / 30.0) * rate)) / normaliser
        return round(min(max(p_30d * share, 0.0), 1.0), 4)

    return SurvivalCurve(horizons={day: cumulative(day) for day in (7, 14, 21, 30)})
| # -- Main Scoring Pipeline ----------------------------------------------- | |
def score(self, facts: Dict[str, List[ParsedFact]], n_parsed: int = 0, n_dropped: int = 0) -> RiskResult:
    """Run the full scoring pipeline on facts grouped by cluster.

    Args:
        facts: Parsed facts keyed by cluster name; absent clusters are
            scored from an empty list.
        n_parsed: Number of facts parsed, passed through to the result.
        n_dropped: Number of facts dropped, passed through to the result.

    Returns:
        A ``RiskResult`` holding the composite score, per-cluster scores,
        triggered interactions, probability and risk category, the
        days-to-readmission prediction, explanatory factors and
        data-completeness information.
    """
    # Layer 3: one scorer per cluster, applied in a fixed order so that the
    # order of the reported risk factors is stable.
    scorers = (
        ("DEMOGRAPHICS", self.score_demographics),
        ("VITALS", self.score_vitals),
        ("LABS", self.score_labs),
        ("PROBLEMS", self.score_problems),
        ("SYMPTOMS", self.score_symptoms),
        ("MEDICATIONS", self.score_medications),
        ("PROCEDURES", self.score_procedures),
        ("UTILIZATION", self.score_utilization),
        ("DISPOSITION", self.score_disposition),
    )
    cluster_scores: Dict[str, ClusterScore] = {}
    for name, scorer in scorers:
        cluster_scores[name] = scorer(facts.get(name, []))

    # Layer 4: cross-cluster interaction patterns add bonus points.
    interactions = self.detect_interactions(facts, cluster_scores)
    interaction_bonus = sum(hit.bonus for hit in interactions)

    # Layer 5: aggregate into a composite score, probability and category.
    composite = sum(entry.score for entry in cluster_scores.values()) + interaction_bonus
    probability = self._logistic(composite)
    category, color = self._classify_risk(composite)

    # Layer 6: days-to-readmission prediction.
    est_days = self._predict_days(composite)
    bucket = self._predict_bucket(est_days)
    survival = self._predict_survival(composite, probability)

    # Explainability: cluster-level factors first, then triggered patterns.
    risk_factors: List[str] = [
        factor
        for entry in cluster_scores.values()
        for factor in entry.contributing_factors
    ]
    risk_factors.extend(
        f"[PATTERN] {hit.pattern_name} (+{hit.bonus})" for hit in interactions
    )

    # A cluster counts as protective only when data was present AND it
    # contributed no points.
    protective_factors: List[str] = [
        f"Normal {name.lower()} at discharge"
        for name in ("VITALS", "LABS")
        if cluster_scores[name].score == 0 and facts.get(name)
    ]
    if cluster_scores["DISPOSITION"].score == 0 and facts.get("DISPOSITION"):
        protective_factors.append("Stable disposition (Home, alert)")

    # Missing data: clusters that are absent or empty lower the confidence.
    missing_clusters = [name for name in VALID_CLUSTERS if not facts.get(name)]
    completeness = 1.0 - len(missing_clusters) / len(VALID_CLUSTERS)
    if completeness >= 0.7:
        confidence = "high"
    elif completeness >= 0.5:
        confidence = "medium"
    else:
        confidence = "low"

    return RiskResult(
        composite_score=composite,
        cluster_scores=cluster_scores,
        interaction_bonus=interaction_bonus,
        interactions_triggered=interactions,
        probability=round(probability, 4),
        risk_category=category,
        risk_color=color,
        estimated_days=round(est_days, 1),
        days_bucket=bucket,
        survival_curve=survival,
        risk_factors=risk_factors,
        protective_factors=protective_factors,
        missing_clusters=sorted(missing_clusters),
        data_completeness=round(completeness, 2),
        confidence=confidence,
        n_facts_parsed=n_parsed,
        n_facts_dropped=n_dropped,
    )
| # -- Convenience Methods ------------------------------------------------ | |
def score_from_toon(self, toon_text: str) -> RiskResult:
    """Parse raw TOON text and score the resulting facts.

    The parsed/dropped fact counts reported by ``parse_toon`` are forwarded
    to ``score`` so they appear on the returned result.
    """
    parsed_facts, parsed_count, dropped_count = self.parse_toon(toon_text)
    return self.score(parsed_facts, n_parsed=parsed_count, n_dropped=dropped_count)
def score_from_file(self, path: Union[str, Path]) -> RiskResult:
    """Read a UTF-8 TOON text file and score its contents."""
    toon_path = Path(path)
    return self.score_from_toon(toon_path.read_text(encoding="utf-8"))
def score_from_jsonl(self, path: Union[str, Path], limit: int = 0) -> List[Tuple[str, RiskResult]]:
    """Score every entry of a JSONL file (trainset_full format).

    Each non-blank line is parsed as a JSON object and its ``completion``
    field is scored as TOON text. Entries whose ``completion`` is missing or
    empty are skipped. Blank or whitespace-only lines are ignored rather
    than aborting the whole run with a ``json.JSONDecodeError``.

    Args:
        path: Location of the JSONL file (read as UTF-8).
        limit: Maximum number of leading lines to read; ``0`` reads the
            whole file. Blank and skipped lines count towards the limit.

    Returns:
        A list of ``(hadm_id, RiskResult)`` pairs in file order. When an
        entry has no ``hadm_id`` the identifier falls back to
        ``row_<zero-based line index>``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    results: List[Tuple[str, RiskResult]] = []
    with Path(path).open("r", encoding="utf-8") as handle:
        for i, line in enumerate(handle):
            if limit and i >= limit:
                break
            # Tolerate blank separator/trailing lines; json.loads("") raises.
            if not line.strip():
                continue
            obj = json.loads(line)
            completion = obj.get("completion", "")
            if not completion:
                continue
            hadm_id = str(obj.get("hadm_id", f"row_{i}"))
            results.append((hadm_id, self.score_from_toon(completion)))
    return results
| # --------------------------------------------------------------------------- | |
| # Pretty-printing | |
| # --------------------------------------------------------------------------- | |
def format_result(result: RiskResult, hadm_id: str = "") -> str:
    """Render a ``RiskResult`` as a multi-section plain-text report.

    Sections appear in a fixed order: header, risk summary, days-to-readmission
    prediction (with the survival curve sorted by horizon), per-cluster
    scores, then the optional risk-factor, protective-factor, detected-pattern
    and missing-data sections (each omitted when empty), and finally the
    parsed/dropped fact counts.

    Args:
        result: The scoring result to describe.
        hadm_id: Optional admission identifier added to the header.

    Returns:
        The report as a single newline-joined string.
    """
    title = "=== Readmission Risk Report"
    if hadm_id:
        title += f" (hadm_id: {hadm_id})"
    title += " ==="

    report: List[str] = [
        title,
        "",
        # Summary
        f"RISK: {result.risk_category} ({result.risk_color})",
        f"Probability of 30-day readmission: {result.probability:.1%}",
        f"Composite score: {result.composite_score}",
        f"Confidence: {result.confidence} (data completeness: {result.data_completeness:.0%})",
        "",
        # Days prediction
        "--- Days-to-Readmission Prediction ---",
        f"Point estimate: ~{result.estimated_days:.0f} days",
        f"Bucket: {result.days_bucket}",
        "Survival curve:",
    ]
    report.extend(
        f" P(readmit by day {day:2d}): {prob:.1%}"
        for day, prob in sorted(result.survival_curve.horizons.items())
    )
    report.append("")

    # Cluster breakdown, in the canonical reporting order.
    report.append("--- Cluster Scores ---")
    cluster_order = (
        "DEMOGRAPHICS", "VITALS", "LABS", "PROBLEMS", "SYMPTOMS",
        "MEDICATIONS", "PROCEDURES", "UTILIZATION", "DISPOSITION",
    )
    for name in cluster_order:
        entry = result.cluster_scores.get(name)
        if entry:
            report.append(f" {name}: {entry.score}/{entry.max_score}")
    report += [
        f" INTERACTIONS: +{result.interaction_bonus}",
        f" TOTAL: {result.composite_score}",
        "",
    ]

    # Optional sections are emitted only when they have content.
    if result.risk_factors:
        report.append("--- Risk Factors ---")
        report.extend(f" - {factor}" for factor in result.risk_factors)
        report.append("")

    if result.protective_factors:
        report.append("--- Protective Factors ---")
        report.extend(f" + {factor}" for factor in result.protective_factors)
        report.append("")

    if result.interactions_triggered:
        report.append("--- Clinical Patterns Detected ---")
        for pattern in result.interactions_triggered:
            report.append(f" [{pattern.pattern_id}] {pattern.pattern_name}: +{pattern.bonus} pts")
            report.append(f" Evidence: {pattern.description}")
        report.append("")

    if result.missing_clusters:
        report.append(f"--- Missing Data ({len(result.missing_clusters)} clusters) ---")
        report.extend(f" ? {name}" for name in result.missing_clusters)
        report.append("")

    report.append(f"Facts parsed: {result.n_facts_parsed}, dropped: {result.n_facts_dropped}")
    return "\n".join(report)
| # --------------------------------------------------------------------------- | |
| # CLI | |
| # --------------------------------------------------------------------------- | |
def main():
    """Command-line entry point for the readmission risk engine.

    Sub-commands:
        file PATH    Score one TOON text file and print its report.
        jsonl PATH   Score every entry of a JSONL file. ``--limit N`` caps
                     the number of lines read; ``--summary`` prints aggregate
                     statistics instead of one report per entry.
        inline       Score TOON text read from standard input.

    Without a sub-command the usage help is printed.
    """
    import argparse
    import sys
    from collections import Counter

    ap = argparse.ArgumentParser(description="Rule-based 30-day readmission risk engine")
    sub = ap.add_subparsers(dest="cmd")

    # Score a single TOON file
    p_file = sub.add_parser("file", help="Score a single TOON file")
    p_file.add_argument("path", help="Path to TOON text file")

    # Score from JSONL
    p_jsonl = sub.add_parser("jsonl", help="Score all entries in a JSONL file")
    p_jsonl.add_argument("path", help="Path to JSONL file")
    p_jsonl.add_argument("--limit", type=int, default=0, help="Limit number of entries")
    p_jsonl.add_argument("--summary", action="store_true", help="Show summary statistics only")

    # Score from inline TOON text (takes no arguments of its own)
    sub.add_parser("inline", help="Score inline TOON text (pipe to stdin)")

    args = ap.parse_args()
    engine = ReadmissionRiskEngine()

    if args.cmd == "file":
        result = engine.score_from_file(args.path)
        print(format_result(result))
    elif args.cmd == "jsonl":
        results = engine.score_from_jsonl(args.path, limit=args.limit)
        if args.summary:
            total = len(results)
            print(f"=== Summary ({total} patients) ===")
            if not results:
                # Nothing was scored (empty file, or no entry had a
                # completion): the means and min/max below are undefined and
                # would raise ZeroDivisionError / ValueError.
                print("No scorable entries found.")
                return
            scores = [r.composite_score for _, r in results]
            probs = [r.probability for _, r in results]
            days = [r.estimated_days for _, r in results]
            categories = Counter(r.risk_category for _, r in results)
            # The reported median is the upper-middle element for even counts.
            print(f"Score: mean={sum(scores)/total:.1f}, "
                  f"min={min(scores)}, max={max(scores)}, "
                  f"median={sorted(scores)[total//2]}")
            print(f"P(readmit): mean={sum(probs)/total:.1%}")
            print("Risk categories:")
            for cat in ["Low", "Medium", "High", "Critical"]:
                n = categories[cat]
                pct = n / total * 100
                print(f" {cat}: {n} ({pct:.0f}%)")
            print(f"Days estimate: mean={sum(days)/total:.1f}, "
                  f"min={min(days):.1f}, max={max(days):.1f}")
        else:
            for hadm_id, result in results:
                print(format_result(result, hadm_id))
                print("\n" + "=" * 60 + "\n")
    elif args.cmd == "inline":
        toon_text = sys.stdin.read()
        result = engine.score_from_toon(toon_text)
        print(format_result(result))
    else:
        ap.print_help()
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()