File size: 10,565 Bytes
a0fa886
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
"""PrimeKG cross-attention — graph-RAG into the Diffusion Forcing denoiser.

Now uses REAL EDGES from raras-app/data/graph-ml/hetero_graph.json:
  - disease → has_phenotype → phenotype  (curated phenotype linkage)
  - disease → associated_with → gene     (causal gene evidence)
  - gene → interacts_with → gene         (PPI network)
  - phenotype → is_a → phenotype         (HPO ontology)

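Ego-subgraph BFS:
  1. Start from disease node (ORPHA → PrimeKG index)
  2. 1-hop: pull connected phenotypes (first K in edge-list order; no weighting)
  3. 1-hop: pull connected genes (first K in edge-list order)
  4. 2-hop (optional, off by default): gene→gene neighbors (interacting partners)
  5. Concatenate fused embeddings of all selected nodes → cross-attn context

Falls back to cosine similarity against the disease embedding, per relation,
when the disease has no phenotype / gene edges (including when the graph file
is not loaded).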

White-space architecture (May 2026):
  - EHRWorld, CLARITY, Time-Aware G-Transformer all skip KG conditioning
  - PhenoKG/RareNet use KG for RETRIEVAL (rare disease diagnosis)
  - We use it for GENERATION (counterfactual trajectory completion)
"""
from __future__ import annotations
import os
import json
import logging
from functools import lru_cache

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

log = logging.getLogger("gemeo.cdf.kg")

# Try raras-app paths first (richer, including hetero_graph edges + node_texts).
# The default points at one developer's checkout; set the RARAS_KG_DIR
# environment variable to use a different location on other machines.
RARAS_KG_DIR = os.environ.get("RARAS_KG_DIR", "/Users/dimas/raras-app/data/graph-ml")
# Repo-local fallback: "<parent of this file's directory>/data".
LOCAL_KG_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")


def _kg_path(name: str) -> str:
    """Prefer raras-app path if available, fall back to local fp16."""
    raras = os.path.join(RARAS_KG_DIR, name)
    if os.path.exists(raras):
        return raras
    local = os.path.join(LOCAL_KG_DIR, name)
    return local if os.path.exists(local) else None


def _read_json(path: str | None, default):
    """Parse the JSON file at ``path``; return ``default`` when ``path`` is falsy.

    The file is opened in a ``with`` block so the handle is always closed.
    """
    if not path:
        return default
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


@lru_cache(maxsize=1)
def load_kg(prefer_raras: bool = True) -> dict | None:
    """Load PrimeKG: fused embeddings + node ids + edges + texts.

    Embeddings are read from ``RARAS_KG_DIR/fused_embeddings.npz`` when
    ``prefer_raras`` is true and that file exists. Otherwise they come from
    ``fused_embeddings_fp16.npz``, as resolved by ``_kg_path``. The JSON side
    files (``node_ids.json``, ``hetero_graph.json``, ``node_texts.json``) are
    optional; a missing file yields an empty mapping.

    The result is memoised by ``lru_cache(maxsize=1)``, including a ``None``
    result. Call ``load_kg.cache_clear()`` to force a reload.

    Returns:
        ``None`` if no embedding file is found, otherwise a dict:
            emb        : {kind: float32 torch.Tensor}  -- one row per node;
                         presumably (N, 3072), TODO confirm against the npz
            idx2id     : {kind: {pos: id_str}}
            id2idx     : {kind: {id_str: pos}}
            edges      : {edge_type: raw edge payload from hetero_graph.json}
            adj        : {edge_type: {src_idx: [dst_idx, ...]}}  -- precomputed
            texts      : contents of node_texts.json ({} if absent)
            num_nodes  : hetero_graph.json["num_nodes"] ({} if absent)
    """
    # Try raras-app full file first, then local fp16
    raras_full = os.path.join(RARAS_KG_DIR, "fused_embeddings.npz")
    if prefer_raras and os.path.exists(raras_full):
        emb_path = raras_full
    else:
        emb_path = _kg_path("fused_embeddings_fp16.npz")
    if not emb_path or not os.path.exists(emb_path):
        log.warning("PrimeKG fused embeddings not found")
        return None

    nids = _read_json(_kg_path("node_ids.json"), {})
    graph = _read_json(_kg_path("hetero_graph.json"),
                       {"edges": {}, "num_nodes": {}})
    texts = _read_json(_kg_path("node_texts.json"), {})

    out = {"emb": {}, "id2idx": {}, "idx2id": {}, "edges": {}, "adj": {},
           "texts": texts, "num_nodes": graph.get("num_nodes", {})}

    # The NpzFile keeps the archive open until closed; the context manager
    # closes it. astype() returns a copy, so the tensors remain valid afterwards.
    with np.load(emb_path) as fz:
        for kind in ("disease", "phenotype", "gene"):
            if kind not in fz.files:
                continue
            out["emb"][kind] = torch.from_numpy(fz[kind].astype(np.float32))
            if kind in nids:
                out["idx2id"][kind] = {int(k): v for k, v in nids[kind].items()}
                out["id2idx"][kind] = {v: int(k) for k, v in nids[kind].items()}

    # Build adjacency lists (src -> [dst, ...]) from the parallel src/dst arrays.
    for edge_type, edata in graph.get("edges", {}).items():
        is_dict = isinstance(edata, dict)
        srcs = edata.get("src", []) if is_dict else []
        dsts = edata.get("dst", []) if is_dict else []
        adj = {}
        for s, d in zip(srcs, dsts):
            adj.setdefault(int(s), []).append(int(d))
        out["adj"][edge_type] = adj
        out["edges"][edge_type] = edata

    log.info(f"  KG loaded from {emb_path}")
    log.info(f"  disease={out['emb'].get('disease', torch.empty(0)).shape}, "
             f"phenotype={out['emb'].get('phenotype', torch.empty(0)).shape}, "
             f"gene={out['emb'].get('gene', torch.empty(0)).shape}")
    log.info(f"  edges: {list(out['edges'].keys())}")
    return out


def ego_subgraph_real(orpha_code: str, k_pheno: int = 16, k_gene: int = 16,
                      k_gene_2hop: int = 0, kg: dict | None = None) -> torch.Tensor:
    """BFS ego-subgraph using REAL PrimeKG edges (not cosine similarity).

    Returns concatenated embeddings (N, 3072) where:
      - 1 disease node (the query)
      - up to k_pheno phenotype nodes (direct edges)
      - up to k_gene gene nodes (direct edges)
      - up to k_gene_2hop gene-gene 2-hop neighbors

    Falls back to cosine similarity if no edges available.
    """
    if kg is None:
        kg = load_kg()
    if kg is None or "disease" not in kg["emb"]:
        return None

    d_id = kg["id2idx"]["disease"].get(str(orpha_code))
    if d_id is None:
        return None

    d_emb = kg["emb"]["disease"][d_id]
    nodes = [d_emb.unsqueeze(0)]

    # Phenotype neighbors (via disease__has_phenotype__phenotype)
    adj = kg["adj"].get("disease__has_phenotype__phenotype", {})
    pheno_neighbors = adj.get(d_id, [])
    if pheno_neighbors and "phenotype" in kg["emb"]:
        pheno_neighbors = pheno_neighbors[:k_pheno]
        nodes.append(kg["emb"]["phenotype"][pheno_neighbors])
    elif "phenotype" in kg["emb"]:
        # Fallback: cosine similarity
        pool = kg["emb"]["phenotype"]
        sim = F.cosine_similarity(d_emb.unsqueeze(0), pool, dim=-1)
        top = sim.topk(min(k_pheno, pool.size(0))).indices
        nodes.append(pool[top])

    # Gene neighbors (via disease__associated_with__gene)
    g_adj = kg["adj"].get("disease__associated_with__gene", {})
    gene_neighbors = g_adj.get(d_id, [])
    if gene_neighbors and "gene" in kg["emb"]:
        gene_neighbors = gene_neighbors[:k_gene]
        nodes.append(kg["emb"]["gene"][gene_neighbors])

        # 2-hop: gene-gene neighbors of the genes we just pulled
        if k_gene_2hop > 0:
            gg_adj = kg["adj"].get("gene__interacts_with__gene", {})
            seen = set(gene_neighbors)
            second_hop = []
            for g in gene_neighbors:
                for g2 in gg_adj.get(g, []):
                    if g2 not in seen:
                        second_hop.append(g2)
                        seen.add(g2)
                        if len(second_hop) >= k_gene_2hop: break
                if len(second_hop) >= k_gene_2hop: break
            if second_hop:
                nodes.append(kg["emb"]["gene"][second_hop])
    elif "gene" in kg["emb"]:
        pool = kg["emb"]["gene"]
        sim = F.cosine_similarity(d_emb.unsqueeze(0), pool, dim=-1)
        top = sim.topk(min(k_gene, pool.size(0))).indices
        nodes.append(pool[top])

    return torch.cat(nodes, dim=0)


# Backward-compatible alias: keeps the old public name `ego_subgraph` working.
# It is the same function object as `ego_subgraph_real`.
ego_subgraph = ego_subgraph_real


class KGCrossAttention(nn.Module):
    """Pre-norm residual cross-attention from a sequence to KG context tokens.

    Queries come from ``x_seq`` (B, T, d_model); keys and values come from
    ``x_kg`` (B, N, d_model). Both inputs are LayerNorm-ed before projection,
    and the projected attention output is added back onto ``x_seq``.

    Args:
        d_model: model width; must be a positive multiple of ``n_heads``.
        n_heads: number of attention heads.
        dropout: probability used both for attention-weight dropout (training
            only) and for dropout on the output projection.

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model: int, n_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        # Fail fast: otherwise the per-head reshape in forward() breaks with
        # an opaque shape error (or d_model // 0 raises ZeroDivisionError here).
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be a positive multiple of n_heads ({n_heads})")
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False)  # fused K and V
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.norm_q = nn.LayerNorm(d_model)
        self.norm_kv = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_seq: torch.Tensor, x_kg: torch.Tensor) -> torch.Tensor:
        """Attend from ``x_seq`` (B, T, D) to ``x_kg`` (B, N, D); return (B, T, D).

        No key mask is applied, so every KG row (including any zero-padding
        rows produced upstream) is attended to.
        """
        B, T, D = x_seq.shape
        _, N, _ = x_kg.shape
        q = self.q_proj(self.norm_q(x_seq))
        kv = self.kv_proj(self.norm_kv(x_kg))
        k, v = kv.chunk(2, dim=-1)
        # (B, L, D) -> (B, heads, L, head_dim) for scaled_dot_product_attention
        q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout.p if self.training else 0.0)
        out = out.transpose(1, 2).reshape(B, T, D)
        return x_seq + self.dropout(self.out_proj(out))


class KGProjector(nn.Module):
    """Map raw KG node embeddings (``kg_dim`` wide) into the model width.

    Pipeline: Linear(kg_dim -> d_model) -> GELU -> LayerNorm(d_model).
    The layers are held in ``self.proj`` (an ``nn.Sequential``), which fixes
    the parameter names to ``proj.0.*`` (Linear) and ``proj.2.*`` (LayerNorm).
    Keep this layout: saved state dicts and ``proj[0]`` lookups depend on it.
    """

    def __init__(self, kg_dim: int, d_model: int):
        super().__init__()
        layers = [
            nn.Linear(kg_dim, d_model),  # proj.0
            nn.GELU(),                   # proj.1 (parameter-free)
            nn.LayerNorm(d_model),       # proj.2
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` from (..., kg_dim) to (..., d_model)."""
        projected = self.proj(x)
        return projected


def build_kg_batch(orpha_strings: list[str], d_model: int,
                   projector: KGProjector,
                   k_pheno: int = 16, k_gene: int = 16,
                   k_gene_2hop: int = 0,
                   kg: dict | None = None) -> torch.Tensor:
    """Build a batched, projected KG context for a batch of patient ORPHAs.

    Each code is expanded with ``ego_subgraph_real`` into at most
    ``N = 1 + k_pheno + k_gene + k_gene_2hop`` raw node embeddings. These are
    zero-padded to exactly ``N`` rows, stacked, and passed through ``projector``.

    Args:
        orpha_strings: ORPHA codes, one per batch element.
        d_model: width of the fallback / empty-batch output. Otherwise the
            width is whatever ``projector`` produces.
        projector: maps raw KG embeddings to the model width.
        k_pheno, k_gene, k_gene_2hop: per-relation neighbour budgets.
        kg: an already-loaded KG dict (as returned by ``load_kg``). It is
            loaded with ``load_kg()`` when omitted.

    Returns:
        (B, N, width) tensor on the projector's device. When the KG cannot be
        loaded, returns an all-zero (B, 1, d_model) tensor instead.

    Note:
        Unknown ORPHA codes and padding rows are zero *before* projection. With
        ``KGProjector`` (Linear -> GELU -> LayerNorm) their projected rows are
        generally not zero.
    """
    device = next(projector.parameters()).device
    if kg is None:
        kg = load_kg()
    if kg is None:
        return torch.zeros(len(orpha_strings), 1, d_model, device=device)

    n_nodes = 1 + k_pheno + k_gene + k_gene_2hop
    if not orpha_strings:
        # torch.stack([]) would raise; return an empty, correctly-shaped batch.
        return torch.zeros(0, n_nodes, d_model, device=device)

    emb = kg["emb"]
    # Raw feature width for all-zero rows. Fall back to the projector's input
    # width if the KG has no disease embeddings (then every ego is None).
    feat_dim = (emb["disease"].size(-1) if "disease" in emb
                else projector.proj[0].in_features)

    egos = []
    for orpha in orpha_strings:
        ego = ego_subgraph_real(orpha, k_pheno, k_gene, k_gene_2hop, kg)
        if ego is None:
            ego = torch.zeros(n_nodes, feat_dim)
        elif ego.size(0) < n_nodes:
            pad = torch.zeros(n_nodes - ego.size(0), ego.size(-1))
            ego = torch.cat([ego, pad], dim=0)
        egos.append(ego[:n_nodes])
    batch = torch.stack(egos, dim=0)
    return projector(batch.to(device))


def precompute_kg_for_dataset(orpha_codes: list[str], projector: KGProjector,
                              k_pheno: int = 16, k_gene: int = 16,
                              batch_size: int = 32,
                              k_gene_2hop: int = 0) -> torch.Tensor:
    """Pre-compute the projected KG context for an entire dataset, in batches.

    Runs ``build_kg_batch`` on consecutive slices of ``orpha_codes`` under
    ``torch.no_grad()`` and concatenates the results on the CPU. The returned
    tensor therefore holds no autograd graph and is safe to cache (e.g. with
    ``torch.save``) and index per patient.

    Args:
        orpha_codes: ORPHA code per patient.
        projector: ``KGProjector``; its ``proj[0].out_features`` is the width.
        k_pheno, k_gene, k_gene_2hop: per-relation neighbour budgets.
        batch_size: number of codes per ``build_kg_batch`` call (>= 1).

    Returns:
        CPU tensor of shape (len(orpha_codes), kg_nodes, d_model), where
        ``kg_nodes = 1 + k_pheno + k_gene + k_gene_2hop``. If the KG cannot be
        loaded and the input is non-empty, ``kg_nodes`` is 1 (all-zero context).

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    d_model = projector.proj[0].out_features
    if not orpha_codes:
        # torch.cat([]) would raise; return an empty, correctly-shaped tensor.
        return torch.zeros(0, 1 + k_pheno + k_gene + k_gene_2hop, d_model)

    out = []
    # Without no_grad, every cached slice would keep the projector's autograd
    # graph alive for the whole dataset (large memory, single-use graph).
    with torch.no_grad():
        for start in range(0, len(orpha_codes), batch_size):
            batch = orpha_codes[start:start + batch_size]
            ctx = build_kg_batch(batch, d_model, projector,
                                 k_pheno, k_gene, k_gene_2hop)
            out.append(ctx.cpu())
    return torch.cat(out, dim=0)