| """Active learning — what to ask next. |
| |
| Goal: pick the HPO term whose answer most reduces uncertainty over the |
| current differential diagnosis distribution. Inspired by PhenoDP. |
| |
| Bootstrap (today): |
| 1. Get the current differential (top-K diseases with probabilities). |
| 2. For each candidate HPO across these diseases, compute an information-gain |
| proxy: I(H) = H(D) - E_h[H(D|h)] where h ∈ {present, absent}. |
| 3. Estimate P(h|D_i) from KG annotations (`Disease -[:HAS_PHENOTYPE]-> HPO`) |
| with a Laplace-smoothed prevalence prior. |
| 4. Rank by I(H), return top-N with rationale. |
| |
| Wraps `phenotype_recommender.recommend_phenotypes` if available — that already |
| implements a discriminative-power score; we add SUS-PCDT awareness on top. |
| """ |
| from __future__ import annotations |
| import logging |
| import math |
| from typing import Optional |
|
|
| from .types import NextQuestion |
|
|
| logger = logging.getLogger("gemeo.ask") |
|
|
|
|
async def _safe_query(cypher: str, params: Optional[dict] = None) -> list:
    """Query the *knowledge* KG (raras-app, ~10k diseases / 11k HPOs).

    This module computes MaxInfoGain over disease/phenotype frequency
    profiles which are stored in the raras-app KG, NOT in the cases
    Neo4j (Aura). Previously this routed through space_graph._safe_query
    → Aura, which has only 20 curated diseases → next-questions always
    returned 0. tools.run_query targets the knowledge KG via NEO4J_URI.

    The call is best-effort: ``tools.run_query`` is tried first, then
    ``space_graph._safe_query``; if both fail the failure is logged at
    debug level and an empty list is returned instead of raising.

    Args:
        cypher: Cypher statement to execute.
        params: Query parameters. ``None`` (or an empty mapping) is sent
            as an empty dict.

    Returns:
        Whatever the first working backend returns (callers in this module
        iterate it as rows exposing ``.get``), or ``[]`` when neither
        backend could be imported or executed.
    """
    bound_params = params or {}

    # Preferred backend: the knowledge KG.
    try:
        from tools import run_query
        return await run_query(cypher, bound_params, timeout=10.0)
    except Exception as exc:
        logger.debug(
            "tools.run_query failed, falling back to space_graph: %s", exc
        )

    # Fallback backend. Imported lazily so a missing module degrades to an
    # empty result rather than an import error at module load.
    try:
        from space_graph import _safe_query as fallback_query
        return await fallback_query(cypher, bound_params, timeout=10.0)
    except Exception as exc:
        logger.debug("cypher failed: %s", exc)
        return []
|
|
|
|
def _entropy(probs: list[float]) -> float:
    """Return the Shannon entropy, in bits, of an unnormalised distribution.

    Only strictly positive, finite weights take part; they are renormalised
    to sum to 1 before the entropy is computed. Zero, negative, NaN and
    infinite entries are ignored.

    Args:
        probs: Non-normalised weights (they do not have to sum to 1).

    Returns:
        The entropy in bits as a non-negative float. ``0.0`` when there is
        no usable mass (empty input, all entries non-positive/non-finite).
    """
    weights = [p for p in probs if p > 0 and math.isfinite(p)]
    total = sum(weights)
    if total <= 0 or not math.isfinite(total):
        return 0.0
    shares = (w / total for w in weights)
    # A share can underflow to 0.0 for extremely small weights; skip it so
    # log2 is never evaluated at zero.
    bits = -sum(q * math.log2(q) for q in shares if q > 0)
    # Clamp: negating an all-zero sum yields -0.0 (and rounding could in
    # principle dip below zero); entropy is reported as >= +0.0.
    return max(0.0, bits)
|
|
|
|
async def _disease_phenotype_matrix(orpha_codes: list[str]) -> dict:
    """Return ``{orpha: {hpo_id: prevalence_score}}`` from the KG.

    Each prevalence score is taken from the ``freq`` column of the query
    (``coalesce(r.frequency, r.prevalence, 0.5)``), converted to float and
    clamped to ``[0.05, 0.95]``. Values that cannot be converted, or that
    are NaN, fall back to ``0.5``.

    The returned dict additionally carries one reserved key, ``"_names"``,
    mapping ``hpo_id -> phenotype name`` for every row that had a name.
    Callers are expected to pop it before iterating diseases.

    Args:
        orpha_codes: Orpha codes to load phenotype profiles for.

    Returns:
        The matrix described above, or ``{}`` (without ``"_names"``) when
        ``orpha_codes`` is empty.
    """
    if not orpha_codes:
        return {}
    rows = await _safe_query(
        """
        MATCH (d:Disease)-[r:HAS_PHENOTYPE]->(p:Phenotype)
        WHERE d.orphaCode IN $orphas
        RETURN d.orphaCode AS orpha,
               p.hpoId AS hpo,
               p.name AS hpo_name,
               coalesce(r.frequency, r.prevalence, 0.5) AS freq
        """,
        {"orphas": orpha_codes},
    )

    default_freq = 0.5
    floor, ceiling = 0.05, 0.95

    matrix: dict = {}
    names: dict = {}
    for row in rows or []:
        orpha = row.get("orpha")
        hpo = row.get("hpo")
        if not orpha or not hpo:
            continue
        try:
            freq = float(row.get("freq", default_freq))
        except (TypeError, ValueError, OverflowError):
            # Non-numeric frequency (None, free text, ...) -> neutral value.
            freq = default_freq
        if math.isnan(freq):
            # NaN compares false against everything, so it would otherwise
            # slip through the clamp below and come out as the ceiling.
            freq = default_freq
        # Keep every score away from 0 and 1 so no answer is "impossible".
        matrix.setdefault(orpha, {})[hpo] = max(floor, min(ceiling, freq))
        if row.get("hpo_name"):
            names[hpo] = row["hpo_name"]
    matrix["_names"] = names
    return matrix
|
|
|
|
def _pcdt_text(pcdt) -> str:
    """Flatten a PCDT payload into one lower-cased string for searching."""
    values = getattr(pcdt, "values", None)
    if callable(values):
        parts = values()  # mapping-like payload: search its values
    elif isinstance(pcdt, (list, tuple, set)):
        parts = pcdt
    else:
        parts = (pcdt,)
    return " ".join(str(part) for part in parts).lower()


async def _is_in_pcdt(hpo_id: str, orpha_candidates: list[str]) -> bool:
    """Check if any candidate disease's PCDT mentions this HPO.

    The check is a case-insensitive substring search of ``hpo_id`` in the
    stringified PCDT payload returned by ``brazilian_context.get_pcdt``.
    It is best-effort: if ``brazilian_context`` cannot be imported, or a
    lookup raises, that source is simply treated as "no mention".

    Args:
        hpo_id: HPO identifier to look for.
        orpha_candidates: Orpha codes whose PCDTs are searched, in order.

    Returns:
        ``True`` as soon as one PCDT contains ``hpo_id``; ``False``
        otherwise, including when ``hpo_id`` or the candidate list is empty.
    """
    # An empty needle is a substring of every text; never report it as found.
    if not hpo_id or not orpha_candidates:
        return False
    try:
        from brazilian_context import get_pcdt
    except ImportError:
        return False

    needle = hpo_id.lower()
    for orpha in orpha_candidates:
        try:
            pcdt = get_pcdt(orpha)
        except Exception as exc:
            # One failing lookup must not hide matches in the other PCDTs.
            logger.debug("get_pcdt(%r) failed: %s", orpha, exc)
            continue
        if not pcdt:
            continue
        if needle in _pcdt_text(pcdt):
            return True
    return False
|
|
|
|
def _coerce_prior(value) -> float:
    """Turn a hypothesis probability into a usable non-negative weight."""
    if value is None:
        return 1.0  # same weight as an entry with no "probability" key
    try:
        weight = float(value)
    except (TypeError, ValueError, OverflowError):
        return 0.0
    if not math.isfinite(weight) or weight < 0:
        return 0.0
    return weight


async def _info_gain_path(
    differential: list[dict],
    already_present: set,
    top_n: int,
) -> list[NextQuestion]:
    """Rank candidate HPO terms by expected information gain.

    For every HPO annotated on at least one disease of the differential
    (and not in ``already_present``) the gain is::

        I(h) = H(D) - [ P(h) * H(D | h present) + P(not h) * H(D | h absent) ]

    where ``P(h | disease)`` comes from ``_disease_phenotype_matrix`` and
    defaults to ``0.05`` for diseases not annotated with the term.

    Args:
        differential: Dicts with an ``"orpha"`` code and an optional
            ``"probability"`` (missing/``None`` counts as 1.0; non-numeric,
            non-finite or negative values count as 0). Entries without a
            code are ignored. If a code appears more than once, the last
            probability is used and the disease is counted once.
        already_present: HPO ids that must not be proposed again.
        top_n: Maximum number of questions to return.

    Returns:
        Up to ``top_n`` ``NextQuestion`` objects ordered by decreasing gain
        (ties broken by HPO id); ``[]`` when there is no usable prior mass,
        ``top_n`` is not positive, or no term has a positive gain.
    """
    absent_freq = 0.05                # assumed P(h | disease) when not annotated
    high_freq, low_freq = 0.7, 0.15   # "clearly present" / "clearly absent" cut-offs

    priors: dict = {}
    for entry in differential or []:
        code = entry.get("orpha")
        if code:
            priors[code] = _coerce_prior(entry.get("probability", 1.0))
    prior_sum = sum(priors.values())
    if prior_sum <= 0 or top_n <= 0:
        return []
    priors = {code: weight / prior_sum for code, weight in priors.items()}
    # Derive the disease list from the de-duplicated priors so a code that
    # is repeated in the differential is not counted twice below.
    orphas = list(priors)

    matrix = await _disease_phenotype_matrix(orphas)
    names = matrix.pop("_names", {})

    candidate_hpos = {
        hpo
        for profile in matrix.values()
        for hpo in profile
        if hpo not in already_present
    }

    base_entropy = _entropy(list(priors.values()))
    scored = []
    for hpo in candidate_hpos:
        freqs = [matrix.get(code, {}).get(hpo, absent_freq) for code in orphas]

        # Unnormalised posteriors over diseases for each possible answer.
        post_present = [priors[code] * f for code, f in zip(orphas, freqs)]
        post_absent = [priors[code] * (1 - f) for code, f in zip(orphas, freqs)]

        p_h = sum(post_present)  # marginal probability the finding is present
        p_not_h = 1 - p_h
        if p_h <= 0 or p_not_h <= 0:
            continue

        expected_post = (
            p_h * _entropy(post_present) + p_not_h * _entropy(post_absent)
        )
        gain = base_entropy - expected_post
        if gain <= 0:
            continue

        # Diseases for which the term is either very likely or very unlikely.
        discriminates = [
            code
            for code, f in zip(orphas, freqs)
            if f >= high_freq or f <= low_freq
        ]
        scored.append({
            "hpo": hpo,
            "name": names.get(hpo, hpo),
            "gain": gain,
            "discriminates": discriminates[:5],
        })

    # Secondary key keeps the order reproducible when gains tie (the
    # candidates come from a set, whose iteration order is not stable).
    scored.sort(key=lambda item: (-item["gain"], str(item["hpo"])))

    out = []
    for s in scored[:top_n]:
        in_pcdt = await _is_in_pcdt(s["hpo"], orphas)
        rationale = (
            f"Reduces diagnostic entropy by {s['gain']:.2f} bits."
            f" Discriminates among: {', '.join(str(c) for c in s['discriminates']) or 'differential'}."
        )
        out.append(NextQuestion(
            hpo_id=s["hpo"],
            name=s["name"],
            rationale=rationale,
            information_gain=round(s["gain"], 4),
            discriminates_between=s["discriminates"],
            asks_in_pcdt=in_pcdt,
        ))
    return out
|
|
|
|
def _first_truthy(record, *keys):
    """Return the first truthy value among ``keys``, read from a dict or object."""
    for key in keys:
        if isinstance(record, dict):
            value = record.get(key)
        else:
            value = getattr(record, key, None)
        if value:
            return value
    return None


def _coded_hypotheses(space) -> list:
    """Return the hypotheses in ``space._hypotheses`` that carry an orpha code."""
    return [
        hyp
        for hyp in (getattr(space, "_hypotheses", {}) or {}).values()
        if getattr(hyp, "orpha_code", None)
    ]


async def recommend(space, top_n: int = 5) -> list[NextQuestion]:
    """Recommend the next phenotypes/labs/tests to investigate.

    Two strategies are tried in order:

    1. ``phenotype_recommender.recommend_phenotypes`` — if it is importable
       and yields recommendations, they are converted to ``NextQuestion``
       and flagged with whether a candidate disease's PCDT mentions them.
       Recommendations may be dicts or objects; entries without an
       ``hpo_id`` are skipped without consuming one of the ``top_n`` slots.
    2. Otherwise, the information-gain ranking over the hypotheses found
       in ``space._hypotheses`` (those with an ``orpha_code``), excluding
       HPO ids already listed in the current snapshot's phenotypes.

    Args:
        space: Object holding ``_hypotheses`` (a mapping of hypothesis
            objects) and, optionally, ``get_current_snapshot()``.
        top_n: Maximum number of questions to return.

    Returns:
        A list of at most ``top_n`` ``NextQuestion`` objects; ``[]`` when
        there are no coded hypotheses and the recommender produced nothing.
    """
    try:
        from phenotype_recommender import recommend_phenotypes
        result = await recommend_phenotypes(space, max_recommendations=top_n)
        if hasattr(result, "recommendations"):
            recs = result.recommendations
        elif isinstance(result, dict):
            recs = result.get("recommendations", [])
        else:
            recs = []
        if recs:
            orphas = [hyp.orpha_code for hyp in _coded_hypotheses(space)]
            out = []
            for rec in recs:
                if len(out) >= top_n:
                    break
                hpo = _first_truthy(rec, "hpo_id")
                if not hpo:
                    continue
                gain = float(
                    _first_truthy(
                        rec, "discriminative_power", "score", "information_gain"
                    )
                    or 0
                )
                rationale = _first_truthy(rec, "reason", "rationale")
                disc = _first_truthy(rec, "discriminates", "discriminates_between")
                in_pcdt = await _is_in_pcdt(hpo, orphas)
                out.append(NextQuestion(
                    hpo_id=hpo,
                    name=_first_truthy(rec, "name") or hpo,
                    rationale=rationale or "Discriminative for current differential.",
                    information_gain=gain,
                    discriminates_between=disc or [],
                    asks_in_pcdt=in_pcdt,
                ))
            if out:
                return out
    except ImportError:
        # Recommender not installed: use the information-gain path.
        pass
    except Exception as exc:
        logger.debug("phenotype_recommender failed: %s", exc)

    # Fallback: information gain over the current differential.
    differential = []
    for hyp in _coded_hypotheses(space):
        probability = getattr(hyp, "probability", 0.5)
        differential.append({
            "orpha": hyp.orpha_code,
            "name": getattr(hyp, "disease_name", "") or getattr(hyp, "name", ""),
            # An explicit None would break the float() conversion downstream.
            "probability": 0.5 if probability is None else probability,
        })
    if not differential:
        return []

    snap = space.get_current_snapshot() if hasattr(space, "get_current_snapshot") else None
    already_present = set()
    if snap:
        for phenotype in _first_truthy(snap, "phenotypes") or []:
            hpo_id = _first_truthy(phenotype, "hpo_id")
            if hpo_id:
                already_present.add(hpo_id)

    return await _info_gain_path(differential, already_present, top_n)
|
|