"""PrimeKG cross-attention — graph-RAG into the Diffusion Forcing denoiser.

Now uses REAL EDGES from raras-app/data/graph-ml/hetero_graph.json:
  - disease → has_phenotype → phenotype (curated phenotype linkage)
  - disease → associated_with → gene (causal gene evidence)
  - gene → interacts_with → gene (PPI network)
  - phenotype → is_a → phenotype (HPO ontology)

Ego-subgraph BFS:
  1. Start from the disease node (ORPHA → PrimeKG index)
  2. 1-hop: pull connected phenotypes (the first K in stored edge order;
     edges are not ranked by weight or count)
  3. 1-hop: pull connected genes
  4. 2-hop: gene→gene neighbors (interacting partners)
  5. Concatenate fused embeddings of all selected nodes → cross-attn context

Falls back to cosine similarity when a disease has no edges of the relevant
type (for example, when the graph file is not loaded).

White-space architecture (May 2026):
  - EHRWorld, CLARITY, Time-Aware G-Transformer all skip KG conditioning
  - PhenoKG/RareNet use KG for RETRIEVAL (rare disease diagnosis)
  - We use it for GENERATION (counterfactual trajectory completion)
"""
| from __future__ import annotations |
| import os |
| import json |
| import logging |
| from functools import lru_cache |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
# Module-level logger for the KG conditioning code.
log = logging.getLogger("gemeo.cdf.kg")

# Primary location of the PrimeKG artefacts. The default is the historical
# developer-machine path; setting the RARAS_KG_DIR environment variable
# overrides it so the module can run on other hosts without code changes.
RARAS_KG_DIR = (os.environ.get("RARAS_KG_DIR")
                or "/Users/dimas/raras-app/data/graph-ml")

# Fallback location: the "data" directory that sits next to the directory
# containing this file (i.e. <parent-of-this-file's-directory>/data).
LOCAL_KG_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")
|
|
|
|
def _kg_path(name: str) -> str | None:
    """Resolve a KG artefact file name to an existing path.

    ``RARAS_KG_DIR`` is searched first, then ``LOCAL_KG_DIR``.

    Args:
        name: File name of the artefact, e.g. ``"node_ids.json"``.

    Returns:
        The first candidate path that exists on disk, or ``None`` when the
        file is present in neither directory.
    """
    for base_dir in (RARAS_KG_DIR, LOCAL_KG_DIR):
        candidate = os.path.join(base_dir, name)
        if os.path.exists(candidate):
            return candidate
    return None
|
|
|
|
def _read_json(path: str | None, default: dict) -> dict:
    """Parse the JSON file at *path*; return *default* when *path* is falsy."""
    if not path:
        return default
    # Context manager so the handle is closed even if parsing fails.
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


@lru_cache(maxsize=1)
def load_kg(prefer_raras: bool = True) -> dict | None:
    """Load PrimeKG: fused embeddings + node ids + edges + texts.

    The result is memoised by ``lru_cache`` (one entry, keyed on the call
    arguments), so a ``None`` result is cached as well and every caller
    receives the same dict object.

    Args:
        prefer_raras: When true and ``fused_embeddings.npz`` exists in
            ``RARAS_KG_DIR``, that file is used. Otherwise
            ``fused_embeddings_fp16.npz`` is resolved through ``_kg_path``.

    Returns:
        ``None`` when no embedding file can be found, else a dict with:
            emb       : {kind: float32 torch.Tensor}, one row per node
                        (the original notes describe it as (N, 3072))
            idx2id    : {kind: {pos: id_str}}
            id2idx    : {kind: {id_str: pos}}
            edges     : {edge_type: raw edge record from hetero_graph.json}
            adj       : {edge_type: {src_idx: [dst_idx, ...]}} -- precomputed
            texts     : contents of node_texts.json ({} when absent)
            num_nodes : ``num_nodes`` entry of hetero_graph.json ({} when absent)
        Only the kinds "disease", "phenotype" and "gene" are read.
    """
    full_precision = os.path.join(RARAS_KG_DIR, "fused_embeddings.npz")
    if prefer_raras and os.path.exists(full_precision):
        emb_path = full_precision
    else:
        emb_path = _kg_path("fused_embeddings_fp16.npz")
    if not emb_path or not os.path.exists(emb_path):
        log.warning("PrimeKG fused embeddings not found")
        return None

    # Side files are optional: a missing file degrades to an empty structure.
    nids = _read_json(_kg_path("node_ids.json"), {})
    graph = _read_json(_kg_path("hetero_graph.json"),
                       {"edges": {}, "num_nodes": {}})
    texts = _read_json(_kg_path("node_texts.json"), {})

    out = {"emb": {}, "id2idx": {}, "idx2id": {}, "edges": {}, "adj": {},
           "texts": texts, "num_nodes": graph.get("num_nodes", {})}

    kinds = ("disease", "phenotype", "gene")

    # ``astype`` copies the data, so the tensors stay valid after the archive
    # is closed at the end of the ``with`` block.
    with np.load(emb_path) as fz:
        for kind in kinds:
            if kind in fz.files:
                out["emb"][kind] = torch.from_numpy(fz[kind].astype(np.float32))

    for kind in kinds:
        if kind in nids:
            # node_ids.json maps position (stored as a JSON string key) -> id.
            out["idx2id"][kind] = {int(k): v for k, v in nids[kind].items()}
            out["id2idx"][kind] = {v: int(k) for k, v in nids[kind].items()}

    # Precompute adjacency lists so neighbour lookup is a dict access.
    for edge_type, edata in graph.get("edges", {}).items():
        is_record = isinstance(edata, dict)
        srcs = edata.get("src", []) if is_record else []
        dsts = edata.get("dst", []) if is_record else []
        adj = {}
        for s, d in zip(srcs, dsts):
            adj.setdefault(int(s), []).append(int(d))
        out["adj"][edge_type] = adj
        out["edges"][edge_type] = edata

    no_emb = torch.empty(0)
    log.info(" KG loaded from %s", emb_path)
    log.info(" disease=%s, phenotype=%s, gene=%s",
             out["emb"].get("disease", no_emb).shape,
             out["emb"].get("phenotype", no_emb).shape,
             out["emb"].get("gene", no_emb).shape)
    log.info(" edges: %s", list(out["edges"].keys()))
    return out
|
|
|
|
def _cosine_top_k(query: torch.Tensor, pool: torch.Tensor, k: int) -> torch.Tensor:
    """Return the (at most *k*) rows of *pool* most cosine-similar to *query*."""
    sim = F.cosine_similarity(query.unsqueeze(0), pool, dim=-1)
    top = sim.topk(min(k, pool.size(0))).indices
    return pool[top]


def ego_subgraph_real(orpha_code: str, k_pheno: int = 16, k_gene: int = 16,
                      k_gene_2hop: int = 0,
                      kg: dict | None = None) -> torch.Tensor | None:
    """Collect the ego-subgraph embeddings of a disease from real KG edges.

    Rows of the result, in order:
      - 1 disease node (the query)
      - up to ``k_pheno`` phenotype nodes (direct edges, first K as stored)
      - up to ``k_gene`` gene nodes (direct edges, first K as stored)
      - up to ``k_gene_2hop`` gene→gene neighbours of the selected genes

    For phenotypes and genes separately, when the disease has no edge of that
    type the rows are instead the top-K nodes by cosine similarity to the
    disease embedding. The 2-hop expansion only runs on real gene edges.

    Args:
        orpha_code: Disease identifier; converted with ``str`` before lookup
            in ``kg["id2idx"]["disease"]``.
        k_pheno: Maximum number of phenotype rows.
        k_gene: Maximum number of 1-hop gene rows.
        k_gene_2hop: Maximum number of 2-hop gene rows (0 disables the hop).
        kg: KG dict as produced by ``load_kg``; loaded on demand when ``None``.

    Returns:
        Tensor of shape (n_selected, emb_dim), or ``None`` when the KG is
        unavailable, has no disease embeddings or id mapping, or does not
        know ``orpha_code``.
    """
    if kg is None:
        kg = load_kg()
    if kg is None or "disease" not in kg["emb"]:
        return None

    # ``id2idx`` has no "disease" entry when node_ids.json was not found;
    # treat that as "unknown disease" instead of raising KeyError.
    d_id = kg.get("id2idx", {}).get("disease", {}).get(str(orpha_code))
    if d_id is None:
        return None

    emb = kg["emb"]
    adj = kg.get("adj", {})
    d_emb = emb["disease"][d_id]
    nodes = [d_emb.unsqueeze(0)]

    # --- 1-hop phenotypes (real edges, else cosine fallback) ---
    if "phenotype" in emb:
        pheno_neighbors = adj.get(
            "disease__has_phenotype__phenotype", {}).get(d_id, [])
        if pheno_neighbors:
            nodes.append(emb["phenotype"][pheno_neighbors[:k_pheno]])
        else:
            nodes.append(_cosine_top_k(d_emb, emb["phenotype"], k_pheno))

    # --- 1-hop genes (+ optional 2-hop), else cosine fallback ---
    if "gene" in emb:
        gene_neighbors = adj.get(
            "disease__associated_with__gene", {}).get(d_id, [])
        if gene_neighbors:
            gene_neighbors = gene_neighbors[:k_gene]
            nodes.append(emb["gene"][gene_neighbors])

            if k_gene_2hop > 0:
                gg_adj = adj.get("gene__interacts_with__gene", {})
                seen = set(gene_neighbors)  # never re-add a 1-hop gene
                second_hop = []
                for g in gene_neighbors:
                    for g2 in gg_adj.get(g, []):
                        if g2 in seen:
                            continue
                        seen.add(g2)
                        second_hop.append(g2)
                        if len(second_hop) >= k_gene_2hop:
                            break
                    if len(second_hop) >= k_gene_2hop:
                        break
                if second_hop:
                    nodes.append(emb["gene"][second_hop])
        else:
            nodes.append(_cosine_top_k(d_emb, emb["gene"], k_gene))

    return torch.cat(nodes, dim=0)


# Alias: ``ego_subgraph`` refers to the real-edge implementation.
ego_subgraph = ego_subgraph_real
|
|
|
|
class KGCrossAttention(nn.Module):
    """Cross-attention from sequence (B, T, d_model) to KG ego (B, N, d_model).

    Queries come from the sequence, keys and values from the KG context. Both
    inputs are layer-normalised before projection and the attended output is
    added back onto the sequence (residual connection).

    Args:
        d_model: Feature width of both inputs; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability, used both inside the attention (training
            mode only) and on the output projection.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model: int, n_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        # Fail fast: otherwise head_dim is silently truncated and the error
        # only surfaces as a reshape failure inside forward().
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        # Keys and values share one projection; they are split in forward().
        self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.norm_q = nn.LayerNorm(d_model)
        self.norm_kv = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_seq: torch.Tensor, x_kg: torch.Tensor) -> torch.Tensor:
        """Attend from ``x_seq`` (B, T, D) to ``x_kg`` (B, N, D).

        Returns:
            Tensor of shape (B, T, D): ``x_seq`` plus the projected attention
            output.
        """
        B, T, D = x_seq.shape
        _, N, _ = x_kg.shape
        q = self.q_proj(self.norm_q(x_seq))
        k, v = self.kv_proj(self.norm_kv(x_kg)).chunk(2, dim=-1)

        # (B, L, D) -> (B, n_heads, L, head_dim)
        q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, N, self.n_heads, self.head_dim).transpose(1, 2)

        # The functional API does not see module mode, so attention dropout
        # is switched off explicitly outside training.
        attn_dropout = self.dropout.p if self.training else 0.0
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=attn_dropout)

        # (B, n_heads, T, head_dim) -> (B, T, D)
        out = out.transpose(1, 2).reshape(B, T, D)
        return x_seq + self.dropout(self.out_proj(out))
|
|
|
|
class KGProjector(nn.Module):
    """Map raw KG embeddings of width ``kg_dim`` to the model width ``d_model``.

    The mapping is Linear -> GELU -> LayerNorm, held in ``self.proj`` as an
    ``nn.Sequential`` (index 0 is the linear layer).
    """

    def __init__(self, kg_dim: int, d_model: int):
        super().__init__()
        stages = [
            nn.Linear(kg_dim, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
        ]
        self.proj = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` (last dimension ``kg_dim``) to last dimension ``d_model``."""
        projected = self.proj(x)
        return projected
|
|
|
|
def build_kg_batch(orpha_strings: list[str], d_model: int,
                   projector: KGProjector,
                   k_pheno: int = 16, k_gene: int = 16,
                   k_gene_2hop: int = 0) -> torch.Tensor:
    """Build (B, N, d_model) batched KG context for a batch of patient ORPHAs.

    Each disease's ego-subgraph is padded with zero rows (or truncated) to
    ``N = 1 + k_pheno + k_gene + k_gene_2hop`` rows and then passed through
    ``projector``.

    Args:
        orpha_strings: One ORPHA code per batch element.
        d_model: Output feature width; only used to shape the zero tensors
            returned for the no-KG and empty-batch cases.
        projector: Module mapping raw KG embeddings to ``d_model``; its
            parameters determine the output device.
        k_pheno, k_gene, k_gene_2hop: Neighbour budgets forwarded to
            ``ego_subgraph_real``.

    Returns:
        Tensor on the projector's device:
          - (B, 1, d_model) of zeros when the KG cannot be loaded;
          - (0, N, d_model) for an empty batch;
          - (B, N, d_model) otherwise. ORPHAs missing from the KG get all-zero
            *input* rows; after the projector these are zero only if the
            projector maps zeros to zeros.
    """
    device = next(projector.parameters()).device
    kg = load_kg()
    if kg is None:
        return torch.zeros(len(orpha_strings), 1, d_model, device=device)

    N = 1 + k_pheno + k_gene + k_gene_2hop
    if len(orpha_strings) == 0:
        # torch.stack refuses an empty list; return a well-shaped empty batch.
        return torch.zeros(0, N, d_model, device=device)

    # Width of the raw KG embeddings, needed for the all-zero placeholder.
    # Falls back to the projector's input width when the KG carries no
    # disease embeddings (every lookup then returns None).
    disease_emb = kg["emb"].get("disease")
    if disease_emb is not None:
        kg_dim = disease_emb.size(-1)
    else:
        kg_dim = projector.proj[0].in_features

    egos = []
    for orpha in orpha_strings:
        e = ego_subgraph_real(orpha, k_pheno, k_gene, k_gene_2hop, kg)
        if e is None:
            e = torch.zeros(N, kg_dim)
        elif e.size(0) < N:
            pad = e.new_zeros(N - e.size(0), e.size(-1))
            e = torch.cat([e, pad], dim=0)
        egos.append(e[:N])  # truncate if the ego-subgraph exceeds the budget

    return projector(torch.stack(egos, dim=0).to(device))
|
|
|
|
def precompute_kg_for_dataset(orpha_codes: list[str], projector: KGProjector,
                              k_pheno: int = 16, k_gene: int = 16,
                              batch_size: int = 32,
                              k_gene_2hop: int = 0) -> torch.Tensor:
    """Pre-compute KG context for an entire dataset in batches.

    Args:
        orpha_codes: One ORPHA code per patient.
        projector: KG projector; ``projector.proj[0].out_features`` is taken
            as ``d_model``.
        k_pheno, k_gene, k_gene_2hop: Neighbour budgets forwarded to
            ``build_kg_batch``.
        batch_size: Number of patients projected per call; must be positive.

    Returns:
        CPU tensor of shape (N_patients, kg_nodes, d_model); each batch is
        moved to the CPU before concatenation. An empty ``orpha_codes`` yields
        a (0, 1 + k_pheno + k_gene + k_gene_2hop, d_model) tensor.

    Raises:
        ValueError: If ``batch_size`` is not positive.

    Note:
        No ``torch.no_grad()`` is applied here, so the result keeps its
        autograd history; wrap the call yourself when caching to disk.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    d_model = projector.proj[0].out_features
    if len(orpha_codes) == 0:
        # torch.cat refuses an empty list; return a well-shaped empty result.
        return torch.zeros(0, 1 + k_pheno + k_gene + k_gene_2hop, d_model)

    chunks = [
        build_kg_batch(orpha_codes[start:start + batch_size], d_model,
                       projector, k_pheno, k_gene, k_gene_2hop).cpu()
        for start in range(0, len(orpha_codes), batch_size)
    ]
    return torch.cat(chunks, dim=0)
|
|