File size: 6,626 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""PatientEncoder — turn a ClinicalSnapshot into a vector.

Strategy:
  1. **Bootstrap (works TODAY)**: L2-normalized weighted mean of pre-computed
     embeddings from raras-app graph-ml — disease (×2), phenotype (×1),
     gene (×1). This matches what raras-app does in
     `generate-patient-embeddings.mjs`, so embeddings live in the same space
     as `Patient.embedding` in Neo4j and `/api/graph/similar-patients`.

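  2. **HGT (Phase 2, gemeo/train/hgt.py)**: trained heterogeneous graph
     transformer that produces patient embeddings via attention over
     the patient's HPO+Gene+Lab subgraph. Replaces step (1) when the
     checkpoint at `HGT_CKPT` (default `artifacts/hgt_patient_encoder.pt`
     next to this module, overridable via `$GEMEO_HGT_CKPT`) loads and the
     loaded model exposes `encode_patient`.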

The fallback chain is deterministic — `encode()` always returns a vector
(all zeros when nothing resolves, tagged `"bootstrap_empty"`).
"""
from __future__ import annotations
import logging
import os
from typing import Optional

from . import bridge

logger = logging.getLogger("gemeo.encoder")

# Vector width shared with the raras-app fused embeddings and the Neo4j vector index.
DEFAULT_DIM = 3072

# Default checkpoint location: artifacts/hgt_patient_encoder.pt next to this module.
_DEFAULT_HGT_CKPT = os.path.join(
    os.path.dirname(__file__), "artifacts", "hgt_patient_encoder.pt"
)
# Trained HGT patient-encoder checkpoint; $GEMEO_HGT_CKPT overrides the default.
HGT_CKPT = os.environ.get("GEMEO_HGT_CKPT", _DEFAULT_HGT_CKPT)


def _l2_normalize(vec):
    """Scale ``vec`` to unit Euclidean length.

    A vector whose norm is below 1e-9 (e.g. all zeros) is returned unchanged —
    the very same object — instead of being divided by a near-zero norm.
    """
    import numpy as np
    magnitude = np.linalg.norm(vec)
    return vec if magnitude < 1e-9 else vec / magnitude


def encode_bootstrap(
    phenotypes: list,
    diseases: list,
    genes: list,
    *,
    dim: int = DEFAULT_DIM,
    weight_disease: float = 2.0,
    weight_phenotype: float = 1.0,
    weight_gene: float = 1.0,
):
    """Aggregate raras-app fused embeddings — same space as Neo4j Patient.embedding.

    Each id is resolved through ``bridge``. Every embedding of length ``dim``
    is added with its category weight, and the weighted mean is
    L2-normalized. A lookup that returns ``None`` or a vector of another
    length counts as a miss (misses are logged at DEBUG level).

    Args:
        phenotypes: list of HPO ids ["HP:0001250", ...] or dicts {"hpo_id": ...}
        diseases:   list of ORPHA codes ["79253", ...] / [79253, ...]
                    or dicts {"orpha": ...}
        genes:      list of HGNC symbols ["GBA", ...] or dicts {"symbol": ...};
                    symbols are upper-cased before lookup
        dim:        expected embedding width; other widths are treated as misses
        weight_disease / weight_phenotype / weight_gene:
                    per-category weight in the mean

    Returns:
        (vector: np.ndarray (dim,), quality: str)
        quality ∈ {"empty", "partial", "full"}:
          "empty"   — nothing resolved (or total weight 0); vector is all zeros
          "full"    — every extracted id resolved
          "partial" — at least one extracted id was a miss
    """
    import numpy as np

    def _ids(coll, key_options):
        # Pull raw ids from a list that mixes bare ids and dicts. For a dict,
        # the first key with a truthy value wins. Bare ints are accepted just
        # like strings (dict-wrapped ints already were, and lookups call str()).
        # Bools and empty strings are skipped: they can never resolve and
        # would only downgrade quality.
        out = []
        for c in coll or []:
            if isinstance(c, dict):
                for k in key_options:
                    if c.get(k):
                        out.append(c[k])
                        break
            elif isinstance(c, (str, int)) and not isinstance(c, bool) and c != "":
                out.append(c)
        return out

    hpo_ids = _ids(phenotypes, ["hpo_id", "hpoId", "id"])
    orpha_ids = _ids(diseases, ["orpha", "orpha_code", "orphaCode", "code"])
    gene_ids = _ids(genes, ["symbol", "gene", "name"])

    # (kind, ids, weight, lookup). The category order is fixed (disease,
    # phenotype, gene) so the float32 accumulation order is stable. Lookups are
    # wrapped in lambdas so `bridge` is only touched when there are ids.
    plan = (
        ("disease", orpha_ids, weight_disease,
         lambda key: bridge.lookup_disease_embedding(str(key))),
        ("phenotype", hpo_ids, weight_phenotype,
         lambda key: bridge.lookup_phenotype_embedding(str(key))),
        ("gene", gene_ids, weight_gene,
         lambda key: bridge.lookup_gene_embedding(str(key).upper())),
    )

    accum = np.zeros(dim, dtype=np.float32)
    total_weight = 0
    hits = {"disease": 0, "phenotype": 0, "gene": 0}
    misses = {"disease": 0, "phenotype": 0, "gene": 0}

    for kind, ids, weight, lookup in plan:
        for raw_id in ids:
            v = lookup(raw_id)
            if v is not None and v.shape[0] == dim:
                accum += weight * v.astype(np.float32)
                total_weight += weight
                hits[kind] += 1
            else:
                misses[kind] += 1

    if any(misses.values()):
        # Explains a "partial"/"empty" result without raising the log level.
        logger.debug(
            "bootstrap encode: unresolved ids per kind %s (resolved %s)",
            misses,
            hits,
        )

    if total_weight == 0:
        return np.zeros(dim, dtype=np.float32), "empty"

    avg = _l2_normalize(accum / total_weight)

    n_input = len(hpo_ids) + len(orpha_ids) + len(gene_ids)
    quality = "full" if sum(hits.values()) == n_input else "partial"
    return avg, quality


# ─── HGT slot (Phase 2) ────────────────────────────────────────────────────

_HGT_MODEL = None
# (path, mtime) of the last checkpoint that failed to load. It lets
# `_try_load_hgt` skip re-reading a known-bad file (and re-logging the
# warning) on every encode() call, while still retrying once the file changes.
_HGT_FAILED_CKPT = None


def _try_load_hgt():
    """Return the cached HGT patient encoder, loading it on first use.

    Returns None when no checkpoint is present at ``HGT_CKPT`` or when loading
    it failed. A successful load is cached in ``_HGT_MODEL`` for the rest of
    the process. A failed load is remembered by (path, mtime), so the warning
    is logged once per checkpoint version instead of on every call.
    """
    global _HGT_MODEL, _HGT_FAILED_CKPT
    if _HGT_MODEL is not None:
        return _HGT_MODEL
    try:
        stamp = (HGT_CKPT, os.path.getmtime(HGT_CKPT))
    except OSError:
        # No (accessible) checkpoint yet; callers fall back to bootstrap.
        return None
    if stamp == _HGT_FAILED_CKPT:
        return None
    try:
        import torch
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects,
        # which can execute code. HGT_CKPT can be redirected via
        # $GEMEO_HGT_CKPT, so only point it at trusted checkpoints.
        model = torch.load(HGT_CKPT, map_location="cpu", weights_only=False)
    except Exception as e:
        _HGT_FAILED_CKPT = stamp
        logger.warning("HGT checkpoint exists but failed to load: %s", e)
        return None
    _HGT_MODEL = model
    logger.info("Loaded HGT patient encoder from %s", HGT_CKPT)
    return _HGT_MODEL


def encode_hgt(snapshot_dict: dict, dim: int = DEFAULT_DIM):
    """Run the trained HGT patient encoder. Returns None if model unavailable.

    ``snapshot_dict`` is forwarded unchanged to ``model.encode_patient``.
    Returns None, so the caller falls back to bootstrap, when:
      * no model could be loaded,
      * the model has no ``encode_patient`` method,
      * ``encode_patient`` raises (logged with traceback), or
      * the result has a ``shape`` whose last dimension is not ``dim``.
        A vector of another width is not comparable with the bootstrap or
        Neo4j embeddings.
    Results without a ``shape`` attribute are returned as-is.
    """
    model = _try_load_hgt()
    if model is None:
        return None
    try:
        # The actual forward signature is presumably defined by
        # gemeo/train/hgt.py. Callers in this module pass a dict with
        # phenotypes / diseases / genes (labs may be added by other callers).
        if not hasattr(model, "encode_patient"):
            return None
        v = model.encode_patient(snapshot_dict)
    except Exception as e:
        logger.exception("HGT encode failed, falling back to bootstrap: %s", e)
        return None

    shape = getattr(v, "shape", None)
    if shape is not None and (len(shape) == 0 or shape[-1] != dim):
        logger.warning(
            "HGT encoder returned shape %s, expected last dim %d; "
            "falling back to bootstrap",
            tuple(shape),
            dim,
        )
        return None
    return v


def encode(snapshot_dict: dict, *, dim: int = DEFAULT_DIM):
    """Encode a patient snapshot, preferring HGT and falling back to bootstrap.

    The whole ``snapshot_dict`` is handed to the HGT encoder. The bootstrap
    path reads these keys (a missing key means "none"):
        phenotypes: [{hpo_id, name, ...}, ...]
        diseases:   [{orpha, ...}, ...] (optional — confirmed/probable dx)
        genes:      [{symbol, ...}, ...]

    Returns:
        (vector, source) where source is "hgt" when the HGT encoder produced
        the vector, otherwise "bootstrap_<quality>" with quality taken from
        `encode_bootstrap` ("empty" / "partial" / "full").
    """
    hgt_vec = encode_hgt(snapshot_dict, dim=dim)
    if hgt_vec is None:
        bootstrap_inputs = {
            field: snapshot_dict.get(field, [])
            for field in ("phenotypes", "diseases", "genes")
        }
        vec, quality = encode_bootstrap(dim=dim, **bootstrap_inputs)
        return vec, "bootstrap_" + quality
    return hgt_vec, "hgt"


def encode_patient_space(space) -> tuple:
    """Convenience: encode a `PatientSpace` object directly.

    If ``space`` offers ``get_current_snapshot()`` and it returns a snapshot,
    that snapshot's ``phenotypes``, ``diagnoses`` (passed as diseases), and
    ``genes`` are encoded. Otherwise only the ORPHA codes of the hypotheses
    in ``space._hypotheses`` are used, with no phenotypes or genes.

    Returns:
        The ``(vector, source)`` tuple produced by `encode`.
    """
    snap = None
    if hasattr(space, "get_current_snapshot"):
        snap = space.get_current_snapshot()

    if snap is not None:
        return encode({
            "phenotypes": snap.phenotypes,
            "diseases": snap.diagnoses,
            "genes": snap.genes,
        })

    # No snapshot available: fall back to the hypotheses' ORPHA codes.
    hypotheses = getattr(space, "_hypotheses", {}) or {}
    diseases = [
        {"orpha": hyp.orpha_code}
        for hyp in hypotheses.values()
        if hasattr(hyp, "orpha_code") and hyp.orpha_code
    ]
    return encode({"phenotypes": [], "diseases": diseases, "genes": []})