"""PrimeKG cross-attention — graph-RAG into the Diffusion Forcing denoiser. Now uses REAL EDGES from raras-app/data/graph-ml/hetero_graph.json: - disease → has_phenotype → phenotype (curated phenotype linkage) - disease → associated_with → gene (causal gene evidence) - gene → interacts_with → gene (PPI network) - phenotype → is_a → phenotype (HPO ontology) Ego-subgraph BFS: 1. Start from disease node (ORPHA → PrimeKG index) 2. 1-hop: pull connected phenotypes (top-K by edge weight or count) 3. 1-hop: pull connected genes 4. 2-hop: gene→gene neighbors (interacting partners) 5. Concatenate fused embeddings of all selected nodes → cross-attn context Falls back to cosine-similarity if graph not loaded. White-space architecture (May 2026): - EHRWorld, CLARITY, Time-Aware G-Transformer all skip KG conditioning - PhenoKG/RareNet use KG for RETRIEVAL (rare disease diagnosis) - We use it for GENERATION (counterfactual trajectory completion) """ from __future__ import annotations import os import json import logging from functools import lru_cache import numpy as np import torch import torch.nn as nn import torch.nn.functional as F log = logging.getLogger("gemeo.cdf.kg") # Try raras-app paths first (richer, including hetero_graph edges + node_texts) RARAS_KG_DIR = "/Users/dimas/raras-app/data/graph-ml" LOCAL_KG_DIR = os.path.join( os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data") def _kg_path(name: str) -> str: """Prefer raras-app path if available, fall back to local fp16.""" raras = os.path.join(RARAS_KG_DIR, name) if os.path.exists(raras): return raras local = os.path.join(LOCAL_KG_DIR, name) return local if os.path.exists(local) else None @lru_cache(maxsize=1) def load_kg(prefer_raras: bool = True) -> dict | None: """Load PrimeKG: fused embeddings + node ids + edges + texts. Returns dict: emb : {kind: torch.Tensor(N, 3072)} idx2id : {kind: {pos: id_str}} id2idx : {kind: {id_str: pos}} edges : {edge_type: {'src': [...], 'dst': [...]}} adj : {edge_type: {src_idx: [dst_idx, ...]}} -- precomputed texts : {kind: [str, ...]} -- aligned to position num_nodes : {kind: int} """ # Try raras-app full file first, then local fp16 emb_path = (os.path.join(RARAS_KG_DIR, "fused_embeddings.npz") if prefer_raras and os.path.exists(os.path.join(RARAS_KG_DIR, "fused_embeddings.npz")) else _kg_path("fused_embeddings_fp16.npz")) if not emb_path or not os.path.exists(emb_path): log.warning("PrimeKG fused embeddings not found") return None nids_path = _kg_path("node_ids.json") graph_path = _kg_path("hetero_graph.json") texts_path = _kg_path("node_texts.json") fz = np.load(emb_path) nids = json.load(open(nids_path)) if nids_path else {} graph = json.load(open(graph_path)) if graph_path else {"edges": {}, "num_nodes": {}} texts = json.load(open(texts_path)) if texts_path else {} out = {"emb": {}, "id2idx": {}, "idx2id": {}, "edges": {}, "adj": {}, "texts": texts, "num_nodes": graph.get("num_nodes", {})} for kind in ("disease", "phenotype", "gene"): if kind in fz.files: out["emb"][kind] = torch.from_numpy(fz[kind].astype(np.float32)) if kind in nids: out["idx2id"][kind] = {int(k): v for k, v in nids[kind].items()} out["id2idx"][kind] = {v: int(k) for k, v in nids[kind].items()} # Build adjacency from edges for edge_type, edata in graph.get("edges", {}).items(): adj = {} srcs = edata.get("src", []) if isinstance(edata, dict) else [] dsts = edata.get("dst", []) if isinstance(edata, dict) else [] for s, d in zip(srcs, dsts): adj.setdefault(int(s), []).append(int(d)) out["adj"][edge_type] = adj out["edges"][edge_type] = edata log.info(f" KG loaded from {emb_path}") log.info(f" disease={out['emb'].get('disease', torch.empty(0)).shape}, " f"phenotype={out['emb'].get('phenotype', torch.empty(0)).shape}, " f"gene={out['emb'].get('gene', torch.empty(0)).shape}") log.info(f" edges: {list(out['edges'].keys())}") return out def ego_subgraph_real(orpha_code: str, k_pheno: int = 16, k_gene: int = 16, k_gene_2hop: int = 0, kg: dict | None = None) -> torch.Tensor: """BFS ego-subgraph using REAL PrimeKG edges (not cosine similarity). Returns concatenated embeddings (N, 3072) where: - 1 disease node (the query) - up to k_pheno phenotype nodes (direct edges) - up to k_gene gene nodes (direct edges) - up to k_gene_2hop gene-gene 2-hop neighbors Falls back to cosine similarity if no edges available. """ if kg is None: kg = load_kg() if kg is None or "disease" not in kg["emb"]: return None d_id = kg["id2idx"]["disease"].get(str(orpha_code)) if d_id is None: return None d_emb = kg["emb"]["disease"][d_id] nodes = [d_emb.unsqueeze(0)] # Phenotype neighbors (via disease__has_phenotype__phenotype) adj = kg["adj"].get("disease__has_phenotype__phenotype", {}) pheno_neighbors = adj.get(d_id, []) if pheno_neighbors and "phenotype" in kg["emb"]: pheno_neighbors = pheno_neighbors[:k_pheno] nodes.append(kg["emb"]["phenotype"][pheno_neighbors]) elif "phenotype" in kg["emb"]: # Fallback: cosine similarity pool = kg["emb"]["phenotype"] sim = F.cosine_similarity(d_emb.unsqueeze(0), pool, dim=-1) top = sim.topk(min(k_pheno, pool.size(0))).indices nodes.append(pool[top]) # Gene neighbors (via disease__associated_with__gene) g_adj = kg["adj"].get("disease__associated_with__gene", {}) gene_neighbors = g_adj.get(d_id, []) if gene_neighbors and "gene" in kg["emb"]: gene_neighbors = gene_neighbors[:k_gene] nodes.append(kg["emb"]["gene"][gene_neighbors]) # 2-hop: gene-gene neighbors of the genes we just pulled if k_gene_2hop > 0: gg_adj = kg["adj"].get("gene__interacts_with__gene", {}) seen = set(gene_neighbors) second_hop = [] for g in gene_neighbors: for g2 in gg_adj.get(g, []): if g2 not in seen: second_hop.append(g2) seen.add(g2) if len(second_hop) >= k_gene_2hop: break if len(second_hop) >= k_gene_2hop: break if second_hop: nodes.append(kg["emb"]["gene"][second_hop]) elif "gene" in kg["emb"]: pool = kg["emb"]["gene"] sim = F.cosine_similarity(d_emb.unsqueeze(0), pool, dim=-1) top = sim.topk(min(k_gene, pool.size(0))).indices nodes.append(pool[top]) return torch.cat(nodes, dim=0) # Keep old API name for backward compat ego_subgraph = ego_subgraph_real class KGCrossAttention(nn.Module): """Cross-attention from sequence (B, T, d_model) to KG ego (B, N, d_model).""" def __init__(self, d_model: int, n_heads: int = 8, dropout: float = 0.1): super().__init__() self.n_heads = n_heads self.head_dim = d_model // n_heads self.q_proj = nn.Linear(d_model, d_model, bias=False) self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False) self.out_proj = nn.Linear(d_model, d_model, bias=False) self.norm_q = nn.LayerNorm(d_model) self.norm_kv = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) def forward(self, x_seq: torch.Tensor, x_kg: torch.Tensor) -> torch.Tensor: B, T, D = x_seq.shape _, N, _ = x_kg.shape q = self.q_proj(self.norm_q(x_seq)) kv = self.kv_proj(self.norm_kv(x_kg)) k, v = kv.chunk(2, dim=-1) q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2) k = k.reshape(B, N, self.n_heads, self.head_dim).transpose(1, 2) v = v.reshape(B, N, self.n_heads, self.head_dim).transpose(1, 2) out = F.scaled_dot_product_attention( q, k, v, dropout_p=self.dropout.p if self.training else 0.0) out = out.transpose(1, 2).reshape(B, T, D) return x_seq + self.dropout(self.out_proj(out)) class KGProjector(nn.Module): """Project 3072-d KG embeddings to d_model with LayerNorm.""" def __init__(self, kg_dim: int, d_model: int): super().__init__() self.proj = nn.Sequential( nn.Linear(kg_dim, d_model), nn.GELU(), nn.LayerNorm(d_model), ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.proj(x) def build_kg_batch(orpha_strings: list[str], d_model: int, projector: KGProjector, k_pheno: int = 16, k_gene: int = 16, k_gene_2hop: int = 0) -> torch.Tensor: """Build (B, N, d_model) batched KG context for a batch of patient ORPHAs. Falls back to zero context for missing ORPHAs. """ kg = load_kg() if kg is None: return torch.zeros(len(orpha_strings), 1, d_model, device=next(projector.parameters()).device) N = 1 + k_pheno + k_gene + k_gene_2hop egos = [] for orpha in orpha_strings: e = ego_subgraph_real(orpha, k_pheno, k_gene, k_gene_2hop, kg) if e is None: e = torch.zeros(N, kg["emb"]["disease"].size(-1)) elif e.size(0) < N: pad = torch.zeros(N - e.size(0), e.size(-1)) e = torch.cat([e, pad], dim=0) egos.append(e[:N]) egos = torch.stack(egos, dim=0) return projector(egos.to(next(projector.parameters()).device)) def precompute_kg_for_dataset(orpha_codes: list[str], projector: KGProjector, k_pheno: int = 16, k_gene: int = 16, batch_size: int = 32) -> torch.Tensor: """Pre-compute KG context for an entire dataset in batches. Returns (N_patients, kg_nodes, d_model) tensor on projector device. Saves to disk-cacheable format. """ out = [] for i in range(0, len(orpha_codes), batch_size): batch = orpha_codes[i:i + batch_size] ctx = build_kg_batch(batch, projector.proj[0].out_features, projector, k_pheno, k_gene) out.append(ctx.cpu()) return torch.cat(out, dim=0)