File size: 10,185 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
"""Active learning — what to ask next.

Goal: pick the HPO term whose answer most reduces uncertainty over the
current differential diagnosis distribution. Inspired by PhenoDP.

Bootstrap (today):
  1. Get the current differential (top-K diseases with probabilities).
  2. For each candidate HPO across these diseases, compute an information-gain
     proxy: I(H) = H(D) - E_h[H(D|h)] where h ∈ {present, absent}.
  3. Estimate P(h|D_i) from KG annotations (`Disease -[:HAS_PHENOTYPE]-> Phenotype`),
     clamped to [0.05, 0.95]; disease/phenotype pairs without an annotation use 0.05.
  4. Rank by I(H), return top-N with rationale.

Wraps `phenotype_recommender.recommend_phenotypes` if available — that already
implements a discriminative-power score; we add SUS-PCDT awareness on top.
"""
from __future__ import annotations
import logging
import math
from typing import Optional

from .types import NextQuestion

logger = logging.getLogger("gemeo.ask")


async def _safe_query(cypher: str, params: Optional[dict] = None) -> list:
    """Query the *knowledge* KG (raras-app, ~10k diseases / 11k HPOs).

    This module computes MaxInfoGain over disease/phenotype frequency
    profiles which are stored in the raras-app KG, NOT in the cases
    Neo4j (Aura). Previously this routed through space_graph._safe_query
    → Aura, which has only 20 curated diseases → next-questions always
    returned 0. tools.run_query targets the knowledge KG via NEO4J_URI.

    Args:
        cypher: Cypher statement to execute.
        params: Query parameters; ``None`` is treated as an empty dict.

    Returns:
        The rows produced by the first backend that answers. Returns ``[]``
        when both ``tools.run_query`` and ``space_graph._safe_query`` fail,
        or when the answering backend yields ``None``. Never raises for
        ordinary errors (best-effort by design).
    """
    bound = params or {}

    # Primary backend: the knowledge KG.
    try:
        from tools import run_query
        rows = await run_query(cypher, bound, timeout=10.0)
        # Callers iterate the result, so never hand back None.
        return [] if rows is None else rows
    except Exception as e:
        logger.debug("tools.run_query failed, falling back to space_graph: %s", e)

    # Secondary backend: the legacy space_graph helper.
    try:
        from space_graph import _safe_query as q
        rows = await q(cypher, bound, timeout=10.0)
        return [] if rows is None else rows
    except Exception as e:
        logger.debug("cypher failed: %s", e)
        return []


def _entropy(probs: list[float]) -> float:
    """Return the Shannon entropy (in bits) of the positive mass in *probs*.

    The weights do not need to be normalised: they are divided by the sum of
    the strictly positive entries first. Entries that are zero or negative
    contribute nothing. When there is no positive mass the result is 0.0.
    """
    total = sum(weight for weight in probs if weight > 0)
    if total <= 0:
        return 0.0
    shares = (weight / total for weight in probs)
    return -sum(share * math.log2(share) for share in shares if share > 0)


async def _disease_phenotype_matrix(orpha_codes: list[str]) -> dict:
    """Load phenotype frequencies for the given diseases from the KG.

    Returns ``{orpha: {hpo_id: frequency}}`` with every frequency clamped to
    [0.05, 0.95], plus one extra ``"_names"`` entry that maps ``hpo_id`` to the
    phenotype name for rows that carry one. An empty ``orpha_codes`` returns
    ``{}`` without issuing a query.
    """
    if not orpha_codes:
        return {}

    cypher = """
        MATCH (d:Disease)-[r:HAS_PHENOTYPE]->(p:Phenotype)
        WHERE d.orphaCode IN $orphas
        RETURN d.orphaCode AS orpha,
               p.hpoId AS hpo,
               p.name AS hpo_name,
               coalesce(r.frequency, r.prevalence, 0.5) AS freq
        """
    records = await _safe_query(cypher, {"orphas": orpha_codes})

    by_disease: dict = {}
    labels: dict = {}
    for record in records:
        disease = record.get("orpha")
        term = record.get("hpo")
        raw_freq = record.get("freq", 0.5)
        if not (disease and term):
            continue
        try:
            value = float(raw_freq)
        except Exception:
            # Frequency that cannot be read as a number: use the neutral 0.5.
            value = 0.5
        profile = by_disease.setdefault(disease, {})
        # Keep frequencies away from 0 and 1.
        profile[term] = max(0.05, min(0.95, value))
        label = record.get("hpo_name")
        if label:
            labels[term] = label

    # Names travel in the same dict under a reserved key.
    by_disease["_names"] = labels
    return by_disease


async def _is_in_pcdt(hpo_id: str, orpha_candidates: list[str]) -> bool:
    """Report whether the PCDT of any candidate disease contains *hpo_id*.

    The PCDT is fetched with ``brazilian_context.get_pcdt``; its values are
    stringified, joined and searched case-insensitively for the HPO id.
    Returns False when ``brazilian_context`` cannot be imported. Candidates
    whose lookup raises, or that have no PCDT, are skipped.
    """
    try:
        from brazilian_context import get_pcdt
    except ImportError:
        return False

    def _mentions(code) -> bool:
        # True when this disease's PCDT text contains the HPO id.
        try:
            document = get_pcdt(code)
        except Exception:
            return False
        if not document:
            return False
        haystack = " ".join(map(str, document.values())).lower()
        return hpo_id.lower() in haystack

    # any() stops at the first candidate that matches.
    return any(_mentions(code) for code in orpha_candidates)


async def _info_gain_path(
    differential: list[dict],
    already_present: set,
    top_n: int,
) -> list[NextQuestion]:
    """Rank candidate HPO terms by expected entropy reduction over the differential.

    For every phenotype annotated to at least one candidate disease (and not
    already recorded) this computes

        gain = H(D) - [P(h) * H(D | h present) + P(not h) * H(D | h absent)]

    where ``P(h | D_i)`` comes from ``_disease_phenotype_matrix`` and defaults
    to 0.05 for disease/phenotype pairs without an annotation.

    Args:
        differential: Dicts with an ``"orpha"`` code and an optional
            ``"probability"`` weight (default 1.0). Weights are normalised
            here. Entries without an orpha code are ignored. When a code is
            repeated the last weight wins and the disease is counted once.
            A ``None`` or unparseable weight is treated like a missing one;
            negative or non-finite weights carry no mass.
        already_present: HPO ids that must not be suggested again.
        top_n: Maximum number of questions to return; values <= 0 yield ``[]``.

    Returns:
        Up to ``top_n`` ``NextQuestion`` objects, highest gain first (ties
        broken by HPO id so the order is reproducible). Empty when the
        weights carry no positive mass or no term has a positive gain.
    """
    # A negative top_n would otherwise slice "all but the last k".
    if top_n <= 0:
        return []

    missing_freq = 0.05  # P(h | D) assumed when the KG has no annotation
    already_present = already_present or set()

    priors: dict = {}
    for entry in differential:
        orpha = entry.get("orpha")
        if not orpha:
            continue
        raw = entry.get("probability", 1.0)
        try:
            weight = float(1.0 if raw is None else raw)
        except (TypeError, ValueError):
            weight = 1.0
        if not math.isfinite(weight) or weight < 0:
            weight = 0.0
        priors[orpha] = weight

    prior_sum = sum(priors.values())
    if prior_sum <= 0:
        return []
    priors = {orpha: weight / prior_sum for orpha, weight in priors.items()}

    # Iterate each disease exactly once. Iterating the raw differential would
    # count a repeated orpha code several times in P(h) and the posteriors.
    orphas = list(priors)
    weights = [priors[orpha] for orpha in orphas]

    matrix = await _disease_phenotype_matrix(orphas)
    names = matrix.pop("_names", {})

    # Candidate questions: every annotated phenotype not already recorded.
    candidate_hpos = {
        hpo
        for hpos in matrix.values()
        for hpo in hpos
        if hpo not in already_present
    }

    base_entropy = _entropy(weights)
    scored = []
    for hpo in candidate_hpos:
        freqs = [matrix.get(orpha, {}).get(hpo, missing_freq) for orpha in orphas]

        # Unnormalised posteriors; _entropy normalises them itself.
        post_present = [w * f for w, f in zip(weights, freqs)]
        post_absent = [w * (1 - f) for w, f in zip(weights, freqs)]

        p_h = sum(post_present)
        p_not_h = 1 - p_h
        if p_h <= 0 or p_not_h <= 0:
            continue

        expected_post = p_h * _entropy(post_present) + p_not_h * _entropy(post_absent)
        gain = base_entropy - expected_post
        if gain <= 0:
            continue

        # Diseases for which this phenotype is either very common or very rare.
        discriminates = [
            orpha for orpha, f in zip(orphas, freqs) if f >= 0.7 or f <= 0.15
        ]

        scored.append({
            "hpo": hpo,
            "name": names.get(hpo, hpo),
            "gain": gain,
            "discriminates": discriminates[:5],
        })

    # Candidates come from a set, so break ties on the HPO id for a stable order.
    scored.sort(key=lambda item: (-item["gain"], str(item["hpo"])))

    out = []
    for s in scored[:top_n]:
        in_pcdt = await _is_in_pcdt(s["hpo"], orphas)
        rationale = (
            f"Reduces diagnostic entropy by {s['gain']:.2f} bits."
            f" Discriminates among: {', '.join(s['discriminates']) or 'differential'}."
        )
        out.append(NextQuestion(
            hpo_id=s["hpo"],
            name=s["name"],
            rationale=rationale,
            information_gain=round(s["gain"], 4),
            discriminates_between=s["discriminates"],
            asks_in_pcdt=in_pcdt,
        ))
    return out


def _first_truthy(record, *keys):
    """Return the first truthy value stored under *keys* on a dict or object, else None."""
    for key in keys:
        if isinstance(record, dict):
            value = record.get(key)
        else:
            value = getattr(record, key, None)
        if value:
            return value
    return None


def _coded_hypotheses(space) -> list:
    """Return the hypotheses in ``space._hypotheses`` that carry an ``orpha_code``."""
    store = getattr(space, "_hypotheses", {}) or {}
    return [hyp for hyp in store.values() if getattr(hyp, "orpha_code", None)]


async def recommend(space, top_n: int = 5) -> list[NextQuestion]:
    """Recommend the next phenotypes/labs/tests to investigate.

    Two strategies are tried in order:

    1. ``phenotype_recommender.recommend_phenotypes`` when it can be imported.
       Its recommendations (dicts or objects) are adapted to ``NextQuestion``
       and enriched with the PCDT flag.
    2. The in-house information-gain ranking (``_info_gain_path``) over the
       hypotheses found in ``space._hypotheses``.

    Args:
        space: Object exposing a ``_hypotheses`` mapping whose values have
            ``orpha_code`` / ``probability`` attributes, and optionally a
            ``get_current_snapshot()`` method whose result has ``phenotypes``.
        top_n: Maximum number of questions; values <= 0 yield ``[]``.

    Returns:
        A list of ``NextQuestion``; empty when there is nothing to ask.
    """
    # A negative top_n would otherwise slice "all but the last k" below.
    if top_n <= 0:
        return []

    # 1) Try the existing phenotype_recommender (it has its own discriminative score).
    # NOTE: its signature is recommend_phenotypes(space, max_recommendations=...),
    # and its records expose `discriminative_power` / `reason` / `discriminates`.
    # The older `score` / `rationale` / `discriminates_between` names are accepted too.
    try:
        from phenotype_recommender import recommend_phenotypes
        result = await recommend_phenotypes(space, max_recommendations=top_n)
        if hasattr(result, "recommendations"):
            recs = result.recommendations
        elif isinstance(result, dict):
            recs = result.get("recommendations", [])
        else:
            recs = []

        if recs:
            orphas = [hyp.orpha_code for hyp in _coded_hypotheses(space)]
            out = []
            for rec in recs[:top_n]:
                hpo = _first_truthy(rec, "hpo_id")
                if not hpo:
                    continue
                gain = float(
                    _first_truthy(rec, "discriminative_power", "score", "information_gain")
                    or 0
                )
                rationale = _first_truthy(rec, "reason", "rationale")
                disc = _first_truthy(rec, "discriminates", "discriminates_between") or []
                in_pcdt = await _is_in_pcdt(hpo, orphas)
                out.append(NextQuestion(
                    hpo_id=hpo,
                    name=_first_truthy(rec, "name") or hpo,
                    rationale=rationale or "Discriminative for current differential.",
                    information_gain=gain,
                    discriminates_between=disc,
                    asks_in_pcdt=in_pcdt,
                ))
            if out:
                return out
    except ImportError:
        # The recommender is optional; fall through to the in-house path.
        pass
    except Exception as e:
        # Best-effort: keep going, but leave a traceback for whoever debugs it.
        logger.debug("phenotype_recommender failed: %s", e, exc_info=True)

    # 2) Fallback: in-house info-gain path.
    differential = []
    for hyp in _coded_hypotheses(space):
        probability = getattr(hyp, "probability", None)
        differential.append({
            "orpha": hyp.orpha_code,
            "name": getattr(hyp, "disease_name", "") or getattr(hyp, "name", ""),
            # An unset (None) probability gets the same 0.5 as a missing attribute.
            "probability": 0.5 if probability is None else probability,
        })
    if not differential:
        return []

    snap = space.get_current_snapshot() if hasattr(space, "get_current_snapshot") else None
    already_present = set()
    # Phenotype entries may be dicts or objects; tolerate a snapshot without any.
    for phenotype in (getattr(snap, "phenotypes", None) or []):
        hpo_id = _first_truthy(phenotype, "hpo_id")
        if hpo_id:
            already_present.add(hpo_id)

    return await _info_gain_path(differential, already_present, top_n)