# Provenance (copied from the hosting page):
#   timmers's picture
#   GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
#   089d665 verified
"""Active learning — what to ask next.
Goal: pick the HPO term whose answer most reduces uncertainty over the
current differential diagnosis distribution. Inspired by PhenoDP.
Bootstrap (today):
1. Get the current differential (top-K diseases with probabilities).
2. For each candidate HPO across these diseases, compute an information-gain
proxy: I(H) = H(D) - E_h[H(D|h)] where h ∈ {present, absent}.
3. Estimate P(h|D_i) from KG annotations (`Disease -[:HAS_PHENOTYPE]-> HPO`),
clamping annotated frequencies to [0.05, 0.95] and assuming 0.05 for
disease/HPO pairs that have no annotation.
4. Rank by I(H), return top-N with rationale.
Wraps `phenotype_recommender.recommend_phenotypes` if available — that already
implements a discriminative-power score; we add SUS-PCDT awareness on top.
"""
from __future__ import annotations
import logging
import math
from typing import Optional
from .types import NextQuestion
# Module-level logger. The dotted name places it under the "gemeo" logger
# hierarchy, so handlers/levels configured for "gemeo" also apply here.
logger = logging.getLogger("gemeo.ask")
async def _safe_query(
    cypher: str,
    params: Optional[dict] = None,
    *,
    timeout: float = 10.0,
) -> list:
    """Query the *knowledge* KG (raras-app, ~10k diseases / 11k HPOs).

    This module computes MaxInfoGain over disease/phenotype frequency
    profiles which are stored in the raras-app KG, NOT in the cases
    Neo4j (Aura). Previously this routed through space_graph._safe_query
    → Aura, which has only 20 curated diseases → next-questions always
    returned 0. tools.run_query targets the knowledge KG via NEO4J_URI.

    Best-effort: tries ``tools.run_query`` first and, if importing or calling
    it fails, ``space_graph._safe_query``. Never raises.

    Args:
        cypher: Cypher query text.
        params: Query parameters; ``None`` means no parameters.
        timeout: Forwarded unchanged to the backend's ``timeout`` argument
            (presumably seconds — confirm against the backends).

    Returns:
        The result rows, or ``[]`` when both backends fail or return nothing.
    """
    query_params = params or {}
    try:
        from tools import run_query

        # Normalise a None result so callers can always iterate the rows.
        return await run_query(cypher, query_params, timeout=timeout) or []
    except Exception as e:
        # The fallback targets the cases graph, which lacks most diseases, so
        # results will be near-empty; make that degradation visible.
        logger.warning("tools.run_query failed, falling back to space_graph: %s", e)
    try:
        from space_graph import _safe_query as graph_query

        return await graph_query(cypher, query_params, timeout=timeout) or []
    except Exception as e:
        logger.debug("cypher failed: %s", e)
        return []
def _entropy(probs: list[float]) -> float:
    """Return the Shannon entropy, in bits, of an unnormalised weight vector.

    Each weight is divided by the total of the positive weights. Zero or
    negative weights contribute nothing. When no weight is positive
    (including an empty list), the entropy is 0.0.
    """
    total = sum(p for p in probs if p > 0)
    if total <= 0:
        return 0.0
    bits = 0.0
    for weight in probs:
        share = weight / total
        if share > 0:
            bits -= share * math.log2(share)
    return bits
async def _disease_phenotype_matrix(orpha_codes: list[str]) -> dict:
    """Fetch disease → phenotype frequency profiles from the knowledge KG.

    Args:
        orpha_codes: Orphanet codes of the diseases to look up.

    Returns:
        ``{orpha: {hpo_id: frequency}}`` with every frequency clamped to
        [0.05, 0.95], plus a reserved ``"_names"`` key that maps
        ``hpo_id -> phenotype name`` (callers must pop it before iterating
        diseases). A relationship with no ``frequency``/``prevalence``
        property, or one whose value is not convertible to ``float``, gets
        0.5. Returns ``{}`` for empty input without querying.
    """
    if not orpha_codes:
        return {}
    rows = await _safe_query(
        """
        MATCH (d:Disease)-[r:HAS_PHENOTYPE]->(p:Phenotype)
        WHERE d.orphaCode IN $orphas
        RETURN d.orphaCode AS orpha,
               p.hpoId AS hpo,
               p.name AS hpo_name,
               coalesce(r.frequency, r.prevalence, 0.5) AS freq
        """,
        {"orphas": orpha_codes},
    )
    matrix: dict = {}
    names: dict = {}
    for row in rows:
        orpha = row.get("orpha")
        hpo = row.get("hpo")
        if not orpha or not hpo:
            continue
        try:
            freq = float(row.get("freq", 0.5))
        except (TypeError, ValueError, OverflowError):
            # NOTE(review): non-numeric frequencies (e.g. textual frequency
            # labels) are treated as unknown — confirm the KG's encoding.
            freq = 0.5
        # Clamp away from 0 and 1 so that no single annotation makes a
        # phenotype certain or impossible for a disease.
        matrix.setdefault(orpha, {})[hpo] = max(0.05, min(0.95, freq))
        hpo_name = row.get("hpo_name")
        if hpo_name:
            names[hpo] = hpo_name
    matrix["_names"] = names
    return matrix
async def _is_in_pcdt(hpo_id: str, orpha_candidates: list[str]) -> bool:
    """Return True if any candidate disease's PCDT mentions ``hpo_id``.

    PCDTs are obtained from ``brazilian_context.get_pcdt(orpha)``.
    Best-effort: if that module cannot be imported the answer is False.
    Candidates whose lookup raises, or returns an empty/falsy PCDT, are
    skipped. Matching is a case-insensitive substring search. The search
    runs over the PCDT's values when it is mapping-like, and over its
    string form otherwise.
    """
    try:
        from brazilian_context import get_pcdt
    except ImportError:
        return False
    needle = hpo_id.lower()
    for orpha in orpha_candidates:
        try:
            pcdt = get_pcdt(orpha)
        except Exception:
            continue  # one failing lookup must not hide the other candidates
        if not pcdt:
            continue
        if callable(getattr(pcdt, "values", None)):
            text = " ".join(str(v) for v in pcdt.values())
        else:
            # NOTE(review): get_pcdt's return type isn't visible here. A
            # non-mapping result used to raise AttributeError out of this
            # helper, so fall back to its string form instead.
            text = str(pcdt)
        if needle in text.lower():
            return True
    return False
def _expected_information_gain(
    hpo: str,
    orphas: list,
    priors: dict,
    matrix: dict,
    base_entropy: float,
) -> float:
    """Return I(H) = H(D) - E_h[H(D|h)] for asking about ``hpo``.

    ``priors`` must be normalised and contain every code in ``orphas``.
    Disease/HPO pairs missing from ``matrix`` are assumed to have
    frequency 0.05. Returns 0.0 when P(h) is not strictly inside (0, 1).
    """
    weights = [priors[o] for o in orphas]
    freqs = [matrix.get(o, {}).get(hpo, 0.05) for o in orphas]
    # Unnormalised posteriors P(D_i) * P(h|D_i) and P(D_i) * P(not h|D_i).
    present = [w * f for w, f in zip(weights, freqs)]
    absent = [w * (1 - f) for w, f in zip(weights, freqs)]
    p_h = sum(present)
    p_not_h = 1 - p_h
    if p_h <= 0 or p_not_h <= 0:
        return 0.0
    expected_post = p_h * _entropy(present) + p_not_h * _entropy(absent)
    return base_entropy - expected_post


async def _info_gain_path(
    differential: list[dict],
    already_present: set,
    top_n: int,
) -> list[NextQuestion]:
    """Rank not-yet-recorded HPO terms by expected information gain.

    Args:
        differential: Dicts with an ``"orpha"`` code (entries without one are
            ignored) and an optional ``"probability"`` (missing or ``None`` is
            treated as 1.0). Probabilities are renormalised to sum to 1. If a
            code appears more than once, the last entry's probability is used
            and the disease is counted once.
        already_present: HPO ids to exclude as candidate questions.
        top_n: Maximum number of questions to return.

    Returns:
        Up to ``top_n`` NextQuestion objects, highest gain first. Ties are
        broken by HPO id so the order is reproducible. Returns ``[]`` when
        the differential has no positive probability mass.
    """
    priors: dict = {}
    for d in differential:
        orpha = d.get("orpha")
        if not orpha:
            continue
        prob = d.get("probability")
        priors[orpha] = 1.0 if prob is None else float(prob)
    prior_sum = sum(priors.values())
    if prior_sum <= 0:
        return []
    priors = {o: p / prior_sum for o, p in priors.items()}
    # Iterate unique diseases only. A duplicated code used to be counted once
    # per occurrence, which inflated P(h) (even past 1) and skewed entropies.
    orphas = list(priors)

    matrix = await _disease_phenotype_matrix(orphas)
    names = matrix.pop("_names", {})
    candidate_hpos = {
        hpo for hpos in matrix.values() for hpo in hpos if hpo not in already_present
    }
    base_entropy = _entropy(list(priors.values()))

    scored = []
    for hpo in candidate_hpos:
        gain = _expected_information_gain(hpo, orphas, priors, matrix, base_entropy)
        if gain <= 0:
            continue
        # Diseases for which the answer is strongly expected one way or the
        # other (frequency >= 0.7 or <= 0.15).
        discriminates = [
            o for o in orphas
            if not 0.15 < matrix.get(o, {}).get(hpo, 0.05) < 0.7
        ]
        scored.append({
            "hpo": hpo,
            "name": names.get(hpo, hpo),
            "gain": gain,
            "discriminates": discriminates[:5],
        })
    # Set iteration order varies between processes for str keys. The HPO-id
    # tie-break keeps results with equal gain in a stable order.
    scored.sort(key=lambda s: (-s["gain"], s["hpo"]))

    out = []
    for s in scored[:top_n]:
        in_pcdt = await _is_in_pcdt(s["hpo"], orphas)
        rationale = (
            f"Reduces diagnostic entropy by {s['gain']:.2f} bits."
            f" Discriminates among: {', '.join(map(str, s['discriminates'])) or 'differential'}."
        )
        out.append(NextQuestion(
            hpo_id=s["hpo"],
            name=s["name"],
            rationale=rationale,
            information_gain=round(s["gain"], 4),
            discriminates_between=s["discriminates"],
            asks_in_pcdt=in_pcdt,
        ))
    return out
def _hypothesis_orphas(space) -> list:
    """Return orpha codes of the hypotheses in ``space._hypotheses`` that have one."""
    hypotheses = getattr(space, "_hypotheses", {}) or {}
    return [
        hyp.orpha_code
        for hyp in hypotheses.values()
        if getattr(hyp, "orpha_code", None)
    ]


def _adapt_recommendation(rec) -> tuple:
    """Normalise one recommender item (a dict or an object).

    Returns the tuple ``(hpo_id, name, gain, rationale, discriminates)``.

    PhenotypeRecommendation exposes ``discriminative_power`` + ``reason`` +
    ``discriminates``. The ``score``/``information_gain``/``rationale``/
    ``discriminates_between`` spellings are accepted as fallbacks.
    """

    def field(key):
        if isinstance(rec, dict):
            return rec.get(key)
        return getattr(rec, key, None)

    gain = float(
        field("discriminative_power")
        or field("score")
        or field("information_gain")
        or 0
    )
    rationale = field("reason") or field("rationale") or ""
    discriminates = field("discriminates") or field("discriminates_between") or []
    return field("hpo_id"), field("name"), gain, rationale, discriminates


def _present_hpos(space) -> set:
    """Return the HPO ids recorded in the space's current snapshot.

    Snapshot phenotypes may be dicts or objects exposing ``hpo_id``. A space
    without ``get_current_snapshot``, or a falsy snapshot, yields an empty set.
    """
    snap = space.get_current_snapshot() if hasattr(space, "get_current_snapshot") else None
    present: set = set()
    if not snap:
        return present
    for p in getattr(snap, "phenotypes", None) or []:
        hpo = p.get("hpo_id") if isinstance(p, dict) else getattr(p, "hpo_id", None)
        if hpo:
            present.add(hpo)
    return present


async def _from_phenotype_recommender(space, top_n: int) -> list:
    """Adapt ``phenotype_recommender`` output to NextQuestion objects.

    Each question is enriched with a PCDT flag. Best-effort: returns ``[]``
    when the module is missing, yields nothing, or raises.
    """
    try:
        from phenotype_recommender import recommend_phenotypes
    except ImportError:
        return []
    try:
        # NOTE: signature is recommend_phenotypes(space, max_recommendations=...).
        # We previously passed top_n=top_n which raised TypeError (silently
        # swallowed) → next-questions always returned [].
        result = await recommend_phenotypes(space, max_recommendations=top_n)
        if result is None:
            return []
        if hasattr(result, "recommendations"):
            recs = result.recommendations
        elif isinstance(result, dict):
            recs = result.get("recommendations", [])
        else:
            recs = []
        if not recs:
            return []
        orphas = _hypothesis_orphas(space)
        out = []
        for rec in recs[:top_n]:
            hpo, name, gain, rationale, disc = _adapt_recommendation(rec)
            if not hpo:
                continue
            in_pcdt = await _is_in_pcdt(hpo, orphas)
            out.append(NextQuestion(
                hpo_id=hpo,
                name=name or hpo,
                rationale=rationale or "Discriminative for current differential.",
                information_gain=gain,
                discriminates_between=disc,
                asks_in_pcdt=in_pcdt,
            ))
        return out
    except Exception as e:
        # Any recommender failure just falls back to the in-house path.
        logger.debug("phenotype_recommender failed: %s", e)
        return []


async def recommend(space, top_n: int = 5) -> list[NextQuestion]:
    """Recommend the next phenotypes/labs/tests to investigate.

    1. Try ``phenotype_recommender.recommend_phenotypes``, which has its own
       discriminative score, and enrich its output with PCDT awareness.
    2. If that is unavailable, fails, or yields nothing, fall back to the
       in-house information-gain ranking over the hypotheses that carry an
       orpha code. HPO terms already present in the current snapshot are
       excluded.

    Returns ``[]`` when neither path produces a question.
    """
    recommended = await _from_phenotype_recommender(space, top_n)
    if recommended:
        return recommended

    differential = [
        {
            "orpha": hyp.orpha_code,
            "name": getattr(hyp, "disease_name", "") or getattr(hyp, "name", ""),
            "probability": getattr(hyp, "probability", 0.5),
        }
        for hyp in (getattr(space, "_hypotheses", {}) or {}).values()
        if getattr(hyp, "orpha_code", None)
    ]
    if not differential:
        return []
    return await _info_gain_path(differential, _present_hpos(space), top_n)