# gemeo-arch / reference_impl / primekg_attention.py
# Uploaded by: timmers
# GEMEO Architecture v1.0 — spec + reference impl + figure
# Commit a0fa886 (verified)
"""PrimeKG cross-attention — graph-RAG into the Diffusion Forcing denoiser.
Now uses REAL EDGES from raras-app/data/graph-ml/hetero_graph.json:
- disease → has_phenotype → phenotype (curated phenotype linkage)
- disease → associated_with → gene (causal gene evidence)
- gene → interacts_with → gene (PPI network)
- phenotype → is_a → phenotype (HPO ontology)
Ego-subgraph BFS:
1. Start from disease node (ORPHA → PrimeKG index)
2. 1-hop: pull connected phenotypes (first K in edge-list order; no weighting is applied)
3. 1-hop: pull connected genes
4. 2-hop: gene→gene neighbors (interacting partners)
5. Concatenate fused embeddings of all selected nodes → cross-attn context
Falls back to cosine similarity (per node type) when the disease has no edges of
that type — including when hetero_graph.json is not loaded.
White-space architecture (May 2026):
- EHRWorld, CLARITY, Time-Aware G-Transformer all skip KG conditioning
- PhenoKG/RareNet use KG for RETRIEVAL (rare disease diagnosis)
- We use it for GENERATION (counterfactual trajectory completion)
"""
from __future__ import annotations
import os
import json
import logging
from functools import lru_cache
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
log = logging.getLogger("gemeo.cdf.kg")
# Try raras-app paths first (richer, including hetero_graph edges + node_texts).
# The location can be overridden with the RARAS_KG_DIR environment variable so
# the module works outside the original author's machine; the default is the
# original hard-coded path.
RARAS_KG_DIR = os.environ.get("RARAS_KG_DIR", "/Users/dimas/raras-app/data/graph-ml")
# Bundled fallback: the "data" directory next to this file's parent directory.
LOCAL_KG_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")
def _kg_path(name: str) -> str | None:
    """Resolve a KG data file, preferring RARAS_KG_DIR over LOCAL_KG_DIR.

    Args:
        name: Bare file name to look up (e.g. ``"node_ids.json"``).

    Returns:
        The first existing path among ``RARAS_KG_DIR/name`` and
        ``LOCAL_KG_DIR/name``, or ``None`` when neither exists (callers
        must handle the ``None`` case).
    """
    for base in (RARAS_KG_DIR, LOCAL_KG_DIR):
        candidate = os.path.join(base, name)
        if os.path.exists(candidate):
            return candidate
    return None
def _read_json(path: str | None, default):
    """Load a JSON file as UTF-8, or return *default* when *path* is falsy.

    The file is opened in a ``with`` block so the handle is closed promptly.
    """
    if not path:
        return default
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


@lru_cache(maxsize=1)
def load_kg(prefer_raras: bool = True) -> dict | None:
    """Load PrimeKG: fused embeddings + node ids + edges + texts.

    The result is memoised by ``lru_cache``, so every caller receives the
    same dict object.

    Args:
        prefer_raras: When True and ``RARAS_KG_DIR/fused_embeddings.npz``
            exists, load it; otherwise resolve ``fused_embeddings_fp16.npz``
            via ``_kg_path``.

    Returns:
        ``None`` if no embedding file is found, otherwise a dict with:
          emb       : {kind: float32 torch.Tensor(N, 3072)} for each of
                      disease/phenotype/gene present in the .npz
          idx2id    : {kind: {pos: id_str}}     (empty if node_ids.json missing)
          id2idx    : {kind: {id_str: pos}}     (empty if node_ids.json missing)
          edges     : {edge_type: raw edge payload from hetero_graph.json}
          adj       : {edge_type: {src_idx: [dst_idx, ...]}} -- precomputed
          texts     : contents of node_texts.json ({} if missing)
          num_nodes : {kind: int} from hetero_graph.json ({} if missing)
    """
    full_emb_path = os.path.join(RARAS_KG_DIR, "fused_embeddings.npz")
    if prefer_raras and os.path.exists(full_emb_path):
        emb_path = full_emb_path
    else:
        emb_path = _kg_path("fused_embeddings_fp16.npz")
    if not emb_path or not os.path.exists(emb_path):
        log.warning("PrimeKG fused embeddings not found")
        return None

    nids = _read_json(_kg_path("node_ids.json"), {})
    graph = _read_json(_kg_path("hetero_graph.json"), {"edges": {}, "num_nodes": {}})
    texts = _read_json(_kg_path("node_texts.json"), {})

    out = {"emb": {}, "id2idx": {}, "idx2id": {}, "edges": {}, "adj": {},
           "texts": texts, "num_nodes": graph.get("num_nodes", {})}

    # An .npz archive keeps its file open; read what we need inside the
    # context so it is closed afterwards. astype() copies, so the tensors do
    # not reference the archive.
    with np.load(emb_path) as fz:
        for kind in ("disease", "phenotype", "gene"):
            if kind in fz.files:
                out["emb"][kind] = torch.from_numpy(fz[kind].astype(np.float32))

    for kind in ("disease", "phenotype", "gene"):
        if kind in nids:
            out["idx2id"][kind] = {int(k): v for k, v in nids[kind].items()}
            out["id2idx"][kind] = {v: int(k) for k, v in nids[kind].items()}

    # Build adjacency lists (src index -> list of dst indices) from edges.
    for edge_type, edata in graph.get("edges", {}).items():
        srcs = edata.get("src", []) if isinstance(edata, dict) else []
        dsts = edata.get("dst", []) if isinstance(edata, dict) else []
        adj: dict[int, list[int]] = {}
        for s, d in zip(srcs, dsts):
            adj.setdefault(int(s), []).append(int(d))
        out["adj"][edge_type] = adj
        out["edges"][edge_type] = edata

    empty = torch.empty(0)
    log.info(" KG loaded from %s", emb_path)
    log.info(" disease=%s, phenotype=%s, gene=%s",
             out["emb"].get("disease", empty).shape,
             out["emb"].get("phenotype", empty).shape,
             out["emb"].get("gene", empty).shape)
    log.info(" edges: %s", list(out["edges"].keys()))
    return out
def _cosine_topk(query: torch.Tensor, pool: torch.Tensor, k: int) -> torch.Tensor:
    """Return the ``min(k, len(pool))`` rows of *pool* most cosine-similar to *query*."""
    sim = F.cosine_similarity(query.unsqueeze(0), pool, dim=-1)
    top = sim.topk(min(k, pool.size(0))).indices
    return pool[top]


def _two_hop_genes(seeds: list[int], gg_adj: dict, limit: int) -> list[int]:
    """Collect up to *limit* gene-gene neighbours of *seeds*.

    Seeds themselves and already-collected genes are skipped; neighbours are
    visited in seed order, then adjacency-list order.
    """
    seen = set(seeds)
    found: list[int] = []
    for g in seeds:
        for g2 in gg_adj.get(g, []):
            if g2 in seen:
                continue
            found.append(g2)
            seen.add(g2)
            if len(found) >= limit:
                return found
    return found


def ego_subgraph_real(orpha_code: str, k_pheno: int = 16, k_gene: int = 16,
                      k_gene_2hop: int = 0, kg: dict | None = None) -> torch.Tensor | None:
    """BFS ego-subgraph using REAL PrimeKG edges (not cosine similarity).

    Args:
        orpha_code: Disease id; looked up as ``str(orpha_code)`` in
            ``kg["id2idx"]["disease"]``.
        k_pheno: Max phenotype neighbours (first K in adjacency order).
        k_gene: Max gene neighbours (first K in adjacency order).
        k_gene_2hop: Max gene-gene neighbours of the selected genes.
        kg: Pre-loaded KG dict (as returned by ``load_kg``); loaded if None.

    Returns:
        Concatenated embeddings ``(n, D)``: 1 disease row, then up to
        ``k_pheno`` phenotype rows, up to ``k_gene`` gene rows and (only
        when the disease has gene edges) up to ``k_gene_2hop`` 2-hop gene
        rows. ``None`` when the KG, the disease embeddings, or the disease
        id are unavailable.

    For phenotypes and genes independently, falls back to the top-K
    cosine-similar nodes when the disease has no edges of that type.
    """
    if kg is None:
        kg = load_kg()
    if kg is None or "disease" not in kg["emb"]:
        return None
    # .get() avoids a KeyError when node_ids.json was missing at load time.
    d_id = kg["id2idx"].get("disease", {}).get(str(orpha_code))
    if d_id is None:
        return None

    emb = kg["emb"]
    d_emb = emb["disease"][d_id]
    nodes = [d_emb.unsqueeze(0)]

    # Phenotype neighbours (via disease__has_phenotype__phenotype).
    pheno = kg["adj"].get("disease__has_phenotype__phenotype", {}).get(d_id, [])
    if "phenotype" in emb:
        if pheno:
            nodes.append(emb["phenotype"][pheno[:k_pheno]])
        else:
            nodes.append(_cosine_topk(d_emb, emb["phenotype"], k_pheno))

    # Gene neighbours (via disease__associated_with__gene).
    genes = kg["adj"].get("disease__associated_with__gene", {}).get(d_id, [])
    if "gene" in emb:
        if genes:
            genes = genes[:k_gene]
            nodes.append(emb["gene"][genes])
            # 2-hop: gene-gene neighbours of the genes we just pulled.
            if k_gene_2hop > 0:
                gg_adj = kg["adj"].get("gene__interacts_with__gene", {})
                second_hop = _two_hop_genes(genes, gg_adj, k_gene_2hop)
                if second_hop:
                    nodes.append(emb["gene"][second_hop])
        else:
            nodes.append(_cosine_topk(d_emb, emb["gene"], k_gene))

    return torch.cat(nodes, dim=0)


# Keep old API name for backward compat
ego_subgraph = ego_subgraph_real
class KGCrossAttention(nn.Module):
    """Pre-norm cross-attention from a sequence (B, T, d_model) to KG ego nodes (B, N, d_model).

    Queries come from the LayerNorm'd sequence, keys/values from the
    LayerNorm'd KG nodes; the projected attention output is added back to
    the input sequence (residual).

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Used both as attention dropout (training only) and on the
            output projection.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model: int, n_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        # Fail early instead of with an opaque reshape error in forward().
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.norm_q = nn.LayerNorm(d_model)
        self.norm_kv = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_seq: torch.Tensor, x_kg: torch.Tensor,
                kg_mask: torch.Tensor | None = None) -> torch.Tensor:
        """Attend from ``x_seq`` (B, T, D) to ``x_kg`` (B, N, D).

        Args:
            x_seq: Query sequence, shape (B, T, D).
            x_kg: KG context nodes, shape (B, N, D).
            kg_mask: Optional (B, N) mask, True for real KG nodes and False
                for padding. Batch rows with no True entry receive no KG
                update (output equals ``x_seq`` for that row) instead of NaN.

        Returns:
            ``x_seq`` plus the dropped-out, projected attention output, (B, T, D).
        """
        B, T, D = x_seq.shape
        N = x_kg.size(1)

        q = self.q_proj(self.norm_q(x_seq))
        k, v = self.kv_proj(self.norm_kv(x_kg)).chunk(2, dim=-1)
        q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, N, self.n_heads, self.head_dim).transpose(1, 2)

        attn_mask = None
        empty_rows = None
        if kg_mask is not None:
            kg_mask = kg_mask.to(device=x_kg.device, dtype=torch.bool)
            empty_rows = ~kg_mask.any(dim=-1)  # (B,)
            # SDPA yields NaN for a query with no attendable key; let such
            # rows attend everywhere here and zero their output below.
            kg_mask = kg_mask | empty_rows.unsqueeze(-1)
            attn_mask = kg_mask[:, None, None, :]  # (B, 1, 1, N): broadcast over heads/queries

        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask,
            dropout_p=self.dropout.p if self.training else 0.0)
        out = out.transpose(1, 2).reshape(B, T, D)
        if empty_rows is not None:
            out = out.masked_fill(empty_rows[:, None, None], 0.0)
        return x_seq + self.dropout(self.out_proj(out))
class KGProjector(nn.Module):
    """Map raw KG node embeddings (``kg_dim`` wide) into ``d_model`` space.

    The pipeline is Linear -> GELU -> LayerNorm, stored as the
    ``nn.Sequential`` ``self.proj`` so that ``proj[0]`` is the input Linear
    (other code reads ``proj[0].out_features`` to recover ``d_model``).
    """

    def __init__(self, kg_dim: int, d_model: int):
        super().__init__()
        layers = [
            nn.Linear(kg_dim, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.proj(x)
        return projected
def build_kg_batch(orpha_strings: list[str], d_model: int,
                   projector: KGProjector,
                   k_pheno: int = 16, k_gene: int = 16,
                   k_gene_2hop: int = 0) -> torch.Tensor:
    """Build (B, N, d_model) batched KG context for a batch of patient ORPHAs.

    N = 1 + k_pheno + k_gene + k_gene_2hop. Each ego-subgraph is zero-padded
    up to N rows; an ORPHA with no ego-subgraph gets an all-zero (N, D) input.
    Padding is applied BEFORE the projector, so padded rows come out as
    projector(0), not as zeros.

    Args:
        orpha_strings: One disease id per batch element.
        d_model: Width of the zero context returned when the KG is unavailable.
        projector: Maps raw KG embeddings to d_model; its first parameter's
            device and dtype are used for the projector input.
        k_pheno, k_gene, k_gene_2hop: Passed through to ``ego_subgraph_real``.

    Returns:
        ``projector(egos)`` of shape (B, N, d_model). If the KG (or its
        disease embeddings) is unavailable, a (B, 1, d_model) zero tensor on
        the projector device.
    """
    param = next(projector.parameters())
    device = param.device
    kg = load_kg()
    # Without disease embeddings no ego can be built; also avoids the
    # KeyError the missing-ORPHA branch would otherwise hit.
    if kg is None or "disease" not in kg["emb"]:
        return torch.zeros(len(orpha_strings), 1, d_model, device=device)

    n_nodes = 1 + k_pheno + k_gene + k_gene_2hop
    if not orpha_strings:
        # torch.stack([]) would raise; return an empty, correctly-shaped batch.
        return torch.zeros(0, n_nodes, d_model, device=device, dtype=param.dtype)

    kg_dim = kg["emb"]["disease"].size(-1)
    egos = []
    for orpha in orpha_strings:
        ego = ego_subgraph_real(orpha, k_pheno, k_gene, k_gene_2hop, kg)
        if ego is None:
            ego = torch.zeros(n_nodes, kg_dim)
        elif ego.size(0) < n_nodes:
            pad = ego.new_zeros(n_nodes - ego.size(0), ego.size(-1))
            ego = torch.cat([ego, pad], dim=0)
        egos.append(ego[:n_nodes])

    # Match the projector's dtype as well as device (e.g. fp16/bf16 models).
    batch = torch.stack(egos, dim=0).to(device=device, dtype=param.dtype)
    return projector(batch)
def precompute_kg_for_dataset(orpha_codes: list[str], projector: KGProjector,
                              k_pheno: int = 16, k_gene: int = 16,
                              batch_size: int = 32,
                              k_gene_2hop: int = 0) -> torch.Tensor:
    """Pre-compute KG context for an entire dataset in batches.

    Runs under ``torch.no_grad()``: the result is a detached cache and must
    not keep an autograd graph over the whole dataset alive.

    Args:
        orpha_codes: One disease id per patient.
        projector: Projector whose ``proj[0].out_features`` defines d_model.
        k_pheno, k_gene, k_gene_2hop: Ego-subgraph sizes (see
            ``build_kg_batch``).
        batch_size: Patients per ``build_kg_batch`` call; must be >= 1.

    Returns:
        CPU tensor of shape (N_patients, kg_nodes, d_model), where kg_nodes is
        1 + k_pheno + k_gene + k_gene_2hop (or 1 if the KG is unavailable).
        Nothing is written to disk; the caller may ``torch.save`` it.

    Raises:
        ValueError: If ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    d_model = projector.proj[0].out_features
    if not orpha_codes:
        # torch.cat([]) would raise; return an empty, correctly-shaped tensor.
        return torch.zeros(0, 1 + k_pheno + k_gene + k_gene_2hop, d_model)

    chunks = []
    with torch.no_grad():
        for start in range(0, len(orpha_codes), batch_size):
            ctx = build_kg_batch(orpha_codes[start:start + batch_size], d_model,
                                 projector, k_pheno, k_gene, k_gene_2hop)
            chunks.append(ctx.cpu())
    return torch.cat(chunks, dim=0)