"""Active learning — what to ask next. Goal: pick the HPO term whose answer most reduces uncertainty over the current differential diagnosis distribution. Inspired by PhenoDP. Bootstrap (today): 1. Get the current differential (top-K diseases with probabilities). 2. For each candidate HPO across these diseases, compute an information-gain proxy: I(H) = H(D) - E_h[H(D|h)] where h ∈ {present, absent}. 3. Estimate P(h|D_i) from KG annotations (`Disease -[:HAS_PHENOTYPE]-> HPO`) with a Laplace-smoothed prevalence prior. 4. Rank by I(H), return top-N with rationale. Wraps `phenotype_recommender.recommend_phenotypes` if available — that already implements a discriminative-power score; we add SUS-PCDT awareness on top. """ from __future__ import annotations import logging import math from typing import Optional from .types import NextQuestion logger = logging.getLogger("gemeo.ask") async def _safe_query(cypher: str, params: dict = None) -> list: """Query the *knowledge* KG (raras-app, ~10k diseases / 11k HPOs). This module computes MaxInfoGain over disease/phenotype frequency profiles which are stored in the raras-app KG, NOT in the cases Neo4j (Aura). Previously this routed through space_graph._safe_query → Aura, which has only 20 curated diseases → next-questions always returned 0. tools.run_query targets the knowledge KG via NEO4J_URI. """ try: from tools import run_query return await run_query(cypher, params or {}, timeout=10.0) except Exception as e: logger.debug(f"tools.run_query failed, falling back to space_graph: {e}") try: from space_graph import _safe_query as q return await q(cypher, params or {}, timeout=10.0) except Exception as e: logger.debug(f"cypher failed: {e}") return [] def _entropy(probs: list[float]) -> float: s = sum(p for p in probs if p > 0) if s <= 0: return 0.0 norm = [p / s for p in probs] return -sum(p * math.log2(p) for p in norm if p > 0) async def _disease_phenotype_matrix(orpha_codes: list[str]) -> dict: """Returns {orpha: {hpo_id: prevalence_score}} from KG.""" if not orpha_codes: return {} rows = await _safe_query( """ MATCH (d:Disease)-[r:HAS_PHENOTYPE]->(p:Phenotype) WHERE d.orphaCode IN $orphas RETURN d.orphaCode AS orpha, p.hpoId AS hpo, p.name AS hpo_name, coalesce(r.frequency, r.prevalence, 0.5) AS freq """, {"orphas": orpha_codes}, ) matrix: dict = {} names: dict = {} for r in rows: orpha = r.get("orpha"); hpo = r.get("hpo"); freq = r.get("freq", 0.5) if not orpha or not hpo: continue try: f = float(freq) except Exception: f = 0.5 matrix.setdefault(orpha, {})[hpo] = max(0.05, min(0.95, f)) if r.get("hpo_name"): names[hpo] = r["hpo_name"] matrix["_names"] = names return matrix async def _is_in_pcdt(hpo_id: str, orpha_candidates: list[str]) -> bool: """Check if any candidate disease's PCDT mentions this HPO.""" try: from brazilian_context import get_pcdt except ImportError: return False for orpha in orpha_candidates: try: pcdt = get_pcdt(orpha) except Exception: continue if not pcdt: continue text = " ".join(str(v) for v in pcdt.values()).lower() if hpo_id.lower() in text: return True return False async def _info_gain_path( differential: list[dict], already_present: set, top_n: int, ) -> list[NextQuestion]: """Compute info-gain over the differential.""" orphas = [d["orpha"] for d in differential if d.get("orpha")] priors = {d["orpha"]: float(d.get("probability", 1.0)) for d in differential if d.get("orpha")} prior_sum = sum(priors.values()) if prior_sum <= 0: return [] priors = {k: v / prior_sum for k, v in priors.items()} matrix = await _disease_phenotype_matrix(orphas) names = matrix.pop("_names", {}) # collect candidate HPOs (not already present) candidate_hpos = set() for orpha, hpos in matrix.items(): for h in hpos.keys(): if h not in already_present: candidate_hpos.add(h) base_entropy = _entropy(list(priors.values())) scored = [] for hpo in candidate_hpos: # P(hpo | D_i) = matrix[D_i].get(hpo, 0.05) p_h = sum(priors.get(o, 0) * matrix.get(o, {}).get(hpo, 0.05) for o in orphas) p_not_h = 1 - p_h if p_h <= 0 or p_not_h <= 0: continue # posterior given hpo present post_present = [] post_absent = [] for o in orphas: f = matrix.get(o, {}).get(hpo, 0.05) post_present.append(priors.get(o, 0) * f) post_absent.append(priors.get(o, 0) * (1 - f)) ent_present = _entropy(post_present) ent_absent = _entropy(post_absent) expected_post = p_h * ent_present + p_not_h * ent_absent gain = base_entropy - expected_post if gain <= 0: continue # which diseases does this HPO discriminate among? discriminates = [] for o in orphas: f = matrix.get(o, {}).get(hpo, 0.05) if f >= 0.7 or f <= 0.15: discriminates.append(o) scored.append({ "hpo": hpo, "name": names.get(hpo, hpo), "gain": gain, "discriminates": discriminates[:5], }) scored.sort(key=lambda x: x["gain"], reverse=True) out = [] for s in scored[:top_n]: in_pcdt = await _is_in_pcdt(s["hpo"], orphas) rationale = ( f"Reduces diagnostic entropy by {s['gain']:.2f} bits." f" Discriminates among: {', '.join(s['discriminates']) or 'differential'}." ) out.append(NextQuestion( hpo_id=s["hpo"], name=s["name"], rationale=rationale, information_gain=round(s["gain"], 4), discriminates_between=s["discriminates"], asks_in_pcdt=in_pcdt, )) return out async def recommend(space, top_n: int = 5) -> list[NextQuestion]: """Recommend the next phenotypes/labs/tests to investigate.""" # 1) try the existing phenotype_recommender (it has its own discriminative score). # NOTE: signature is recommend_phenotypes(space, max_recommendations=...). # We previously passed top_n=top_n which raised TypeError (silently swallowed) # → next-questions always returned []. PhenotypeRecommendation also exposes # `discriminative_power` + `reason` + `discriminates`, NOT `score`/`rationale`/ # `discriminates_between`. try: from phenotype_recommender import recommend_phenotypes result = await recommend_phenotypes(space, max_recommendations=top_n) if result is not None: recs = result.recommendations if hasattr(result, "recommendations") else ( result.get("recommendations", []) if isinstance(result, dict) else [] ) if recs: # adapt + enrich with PCDT awareness orphas = [] for hyp in (getattr(space, "_hypotheses", {}) or {}).values(): if getattr(hyp, "orpha_code", None): orphas.append(hyp.orpha_code) out = [] for r in recs[:top_n]: if isinstance(r, dict): hpo = r.get("hpo_id"); name = r.get("name") gain = float( r.get("discriminative_power") or r.get("score") or r.get("information_gain") or 0 ) rationale = r.get("reason") or r.get("rationale") or "" disc = r.get("discriminates") or r.get("discriminates_between") or [] else: hpo = getattr(r, "hpo_id", None); name = getattr(r, "name", None) gain = float( getattr(r, "discriminative_power", 0) or getattr(r, "score", 0) or getattr(r, "information_gain", 0) or 0 ) rationale = getattr(r, "reason", "") or getattr(r, "rationale", "") disc = getattr(r, "discriminates", []) or getattr(r, "discriminates_between", []) or [] if not hpo: continue in_pcdt = await _is_in_pcdt(hpo, orphas) out.append(NextQuestion( hpo_id=hpo, name=name or hpo, rationale=rationale or "Discriminative for current differential.", information_gain=gain, discriminates_between=disc, asks_in_pcdt=in_pcdt, )) if out: return out except ImportError: pass except Exception as e: logger.debug(f"phenotype_recommender failed: {e}") # 2) fallback: in-house info-gain path differential = [] for hyp in (getattr(space, "_hypotheses", {}) or {}).values(): if getattr(hyp, "orpha_code", None): differential.append({ "orpha": hyp.orpha_code, "name": getattr(hyp, "disease_name", "") or getattr(hyp, "name", ""), "probability": getattr(hyp, "probability", 0.5), }) if not differential: return [] snap = space.get_current_snapshot() if hasattr(space, "get_current_snapshot") else None already_present = set() if snap: for p in snap.phenotypes: if p.get("hpo_id"): already_present.add(p["hpo_id"]) return await _info_gain_path(differential, already_present, top_n)