timmers commited on
Commit
123aa6b
Β·
verified Β·
1 Parent(s): 8ef444f

Upload pillar_a_kg_proposer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pillar_a_kg_proposer.py +163 -80
pillar_a_kg_proposer.py CHANGED
@@ -1,81 +1,164 @@
1
- """GEMEO v2.0 β€” Pillar A: KG zero-shot onset proposer (worked demonstration).
2
-
3
- Given a patient's primary rare disease (ORPHA), use PrimeKG graph embeddings to
4
- propose the most likely UNSEEN onset candidates β€” related diseases (comorbid
5
- progression), phenotypes (complications), and genes β€” that the patient has not
6
- yet manifested. This is the structural-novelty source that a frequency /
7
- repeat-last-code baseline cannot fabricate.
8
-
9
- Mechanism: cosine similarity in the 64-dim PrimeKG graph-embedding space, which
10
- encodes disease–disease, disease–phenotype, disease–gene proximity. Candidates
11
- the patient already has are excluded (the new-onset constraint).
12
-
13
- This is a runnable demonstration on the major rare diseases in the GEMEO/SUS
14
- cohort. Full quantitative recall requires phenotype-annotated trajectories
15
- (Mayo multimodal substrate); here we show the mechanism produces biologically
16
- correct candidates.
 
 
 
 
 
 
 
 
 
 
17
  """
18
- import numpy as np, json, os
19
-
20
- BASE = os.path.expanduser("~/rarasnet-swarm-py/gemeo/data")
21
- emb = np.load(f"{BASE}/graph_embeddings.npz")
22
- node_ids = json.load(open(f"{BASE}/node_ids.json"))
23
-
24
- # Build orpha_id -> (node_type, row) index
25
- disease_emb = emb["disease"]; phen_emb = emb["phenotype"]; gene_emb = emb["gene"]
26
- # node_ids[type] is {row_index_str: orpha_or_hpo_or_gene_id}
27
- dis_id2row = {v: int(k) for k, v in node_ids["disease"].items()}
28
- phen_row2id = {int(k): v for k, v in node_ids["phenotype"].items()}
29
- gene_row2id = {int(k): v for k, v in node_ids["gene"].items()}
30
- dis_row2id = {int(k): v for k, v in node_ids["disease"].items()}
31
-
32
- def norm(x): return x / (np.linalg.norm(x, axis=-1, keepdims=True) + 1e-8)
33
- D = norm(disease_emb); P = norm(phen_emb); G = norm(gene_emb)
34
-
35
- # Cohort rare diseases (ORPHA codes from RARE_CIDS_APAC)
36
- COHORT = {
37
- "100": "Ataxia-telangiectasia", "646": "Niemann-Pick / Gaucher cluster",
38
- "355": "Gaucher disease", "98896": "Duchenne muscular dystrophy",
39
- "70": "Spinal muscular atrophy type 1", "586": "Cystic fibrosis",
40
- "579": "Mucopolysaccharidosis type I", "580": "Mucopolysaccharidosis type II",
41
- "905": "Wilson disease", "95": "Friedreich ataxia", "558": "Marfan syndrome",
42
- "636": "Neurofibromatosis type 1", "778": "Rett syndrome", "183660": "SCID",
43
- }
44
-
45
- def propose(orpha, k=8):
46
- """Return top-k unseen onset candidates (diseases, phenotypes, genes)."""
47
- if orpha not in dis_id2row:
48
- return None
49
- row = dis_id2row[orpha]
50
- q = D[row]
51
- # Related diseases (comorbid/progression)
52
- dsim = D @ q; dsim[row] = -1
53
- top_dis = [(dis_row2id[int(i)], float(dsim[i])) for i in np.argsort(-dsim)[:k]]
54
- # Related phenotypes (complications)
55
- psim = P @ q
56
- top_phen = [(phen_row2id[int(i)], float(psim[i])) for i in np.argsort(-psim)[:k]]
57
- # Related genes
58
- gsim = G @ q
59
- top_gene = [(gene_row2id[int(i)], float(gsim[i])) for i in np.argsort(-gsim)[:k]]
60
- return {"diseases": top_dis, "phenotypes": top_phen, "genes": top_gene}
61
-
62
- print("=" * 78)
63
- print("GEMEO v2.0 β€” Pillar A: KG zero-shot onset proposer (worked demonstration)")
64
- print("=" * 78)
65
- results = {}
66
- n_mapped = 0
67
- for orpha, name in COHORT.items():
68
- r = propose(orpha, k=6)
69
- if r is None:
70
- print(f"\n[{orpha}] {name}: NOT in PrimeKG disease nodes")
71
- continue
72
- n_mapped += 1
73
- results[orpha] = {"name": name, **r}
74
- print(f"\n[ORPHA:{orpha}] {name}")
75
- print(f" β†’ candidate comorbid-disease onsets: {[d for d,_ in r['diseases'][:5]]}")
76
- print(f" β†’ candidate phenotype complications: {[p for p,_ in r['phenotypes'][:5]]}")
77
- print(f" β†’ associated genes: {[g for g,_ in r['genes'][:5]]}")
78
-
79
- print(f"\n{n_mapped}/{len(COHORT)} cohort diseases mapped to PrimeKG.")
80
- json.dump(results, open("/tmp/pillar_a_demo.json", "w"), indent=2)
81
- print("Saved /tmp/pillar_a_demo.json")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GEMEO v2.0 β€” Pillar A: KG onset proposer via Random-Walk-with-Restart (RWR).
2
+
3
+ Given a patient's *manifested* clinical state (their diseases + observed
4
+ phenotypes + known variant genes), propose the most likely UNSEEN first-onset
5
+ candidates β€” genes, phenotypes (complications), and related diseases β€” that the
6
+ patient has not yet had.
7
+
8
+ Method (SOTA, training-free): **Random Walk with Restart on the heterogeneous
9
+ knowledge graph**, the state-of-the-art guilt-by-association algorithm for
10
+ network-based gene/disease prioritization (RWRH / MultiXrank lineage;
11
+ Valdeolivas 2019, Bioinformatics; Picart-Armada 2023). The walk restarts from
12
+ the patient's seed nodes with probability `r` and otherwise diffuses along typed
13
+ edges; the stationary visitation probability ranks every node by network
14
+ proximity to the patient's actual state. This is real link prediction by
15
+ network propagation β€” not embedding cosine similarity.
16
+
17
+ Key properties:
18
+ - **New-onset filter**: every already-manifested node is removed from the
19
+ ranking (we propose what the patient does NOT yet have).
20
+ - **Genomic seeding (optional)**: if variant pathogenicity scores are supplied
21
+ (e.g. from Evo 2 / AlphaMissense), the patient's variant-bearing genes are
22
+ weighted into the restart vector β€” the genome steers the proposal.
23
+ - **Traceable evidence**: each candidate ships its shortest path back to a seed.
24
+
25
+ Runs out-of-the-box on the built-in rare-disease fixture; transparently uses the
26
+ full PrimeKG graph when present (see rare_disease_kg.load_kg).
27
  """
28
+ from __future__ import annotations
29
+ import json, os, sys
30
+ from collections import deque
31
+
32
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
33
+ from rare_disease_kg import load_kg, KG
34
+
35
+
36
+ def rwr(kg: KG, seeds: dict[str, float], restart: float = 0.30,
37
+ tol: float = 1e-8, max_iter: int = 200) -> dict[str, float]:
38
+ """Random Walk with Restart. `seeds` maps node id -> nonneg weight.
39
+
40
+ Returns a stationary distribution over all nodes (column-normalized
41
+ transition, uniform over typed neighbors). Pure-Python sparse iteration β€”
42
+ no numpy dependency, scales fine to PrimeKG via the adjacency sets.
43
+ """
44
+ nodes = kg.nodes
45
+ s = {n: 0.0 for n in nodes}
46
+ z = sum(seeds.values()) or 1.0
47
+ for n, w in seeds.items():
48
+ if n in s:
49
+ s[n] = w / z
50
+ p = dict(s)
51
+ deg = {n: max(1, len(kg.neighbors(n))) for n in nodes}
52
+ for _ in range(max_iter):
53
+ nxt = {n: restart * s[n] for n in nodes}
54
+ for u in nodes:
55
+ pu = p[u]
56
+ if pu == 0.0:
57
+ continue
58
+ share = (1.0 - restart) * pu / deg[u]
59
+ for v in kg.neighbors(u):
60
+ nxt[v] += share
61
+ diff = sum(abs(nxt[n] - p[n]) for n in nodes)
62
+ p = nxt
63
+ if diff < tol:
64
+ break
65
+ return p
66
+
67
+
68
+ def shortest_path(kg: KG, src: str, dst: str, max_hops: int = 4):
69
+ """BFS evidence path src→…→dst with relation labels (for auditability)."""
70
+ if src == dst:
71
+ return [src]
72
+ seen = {src}; q = deque([(src, [src])])
73
+ while q:
74
+ u, path = q.popleft()
75
+ if len(path) > max_hops + 1:
76
+ continue
77
+ for v in kg.neighbors(u):
78
+ if v in seen:
79
+ continue
80
+ np_ = path + [v]
81
+ if v == dst:
82
+ return np_
83
+ seen.add(v); q.append((v, np_))
84
+ return None
85
+
86
+
87
+ def path_str(kg: KG, path) -> str:
88
+ if not path:
89
+ return "(no path within horizon)"
90
+ out = []
91
+ for a, b in zip(path, path[1:]):
92
+ rel = kg.edge_label.get((a, b), "β€”")
93
+ out.append(f"{a} β€”{rel}β†’ ")
94
+ return "".join(out) + path[-1]
95
+
96
+
97
+ def propose(kg: KG, manifested: list[str], variant_genes: dict[str, float] | None = None,
98
+ k: int = 8, restart: float = 0.30):
99
+ """Propose unseen first-onset candidates for a patient.
100
+
101
+ manifested: node ids the patient already has (diseases/phenotypes/genes)
102
+ variant_genes: optional {gene_id: pathogenicity in [0,1]} from a genomic
103
+ model (Evo 2 / AlphaMissense) β€” adds genomic seeds.
104
+ """
105
+ manifested = [m for m in manifested if m in kg.idx]
106
+ seeds = {m: 1.0 for m in manifested}
107
+ if variant_genes:
108
+ for g, path_score in variant_genes.items():
109
+ if g in kg.idx:
110
+ seeds[g] = seeds.get(g, 0.0) + 2.0 * float(path_score) # genome weighted up
111
+ if not seeds:
112
+ return {"error": "no seed nodes map to the KG"}
113
+ p = rwr(kg, seeds, restart=restart)
114
+ manifested_set = set(manifested)
115
+ out = {"genes": [], "phenotypes": [], "diseases": []}
116
+ bucket = {"gene": "genes", "phenotype": "phenotypes", "disease": "diseases"}
117
+ ranked = sorted(((n, sc) for n, sc in p.items() if sc > 0 and n not in manifested_set),
118
+ key=lambda x: -x[1])
119
+ for nid, sc in ranked:
120
+ b = bucket.get(kg.ntype.get(nid))
121
+ if not b or len(out[b]) >= k:
122
+ continue
123
+ # evidence path back to the nearest seed
124
+ best_path = None
125
+ for seed in manifested:
126
+ pth = shortest_path(kg, seed, nid)
127
+ if pth and (best_path is None or len(pth) < len(best_path)):
128
+ best_path = pth
129
+ out[b].append({"id": nid, "name": kg.names.get(nid, nid),
130
+ "rwr_score": round(sc, 6), "evidence": path_str(kg, best_path)})
131
+ return out
132
+
133
+
134
+ def _demo():
135
+ kg, src = load_kg()
136
+ print("=" * 80)
137
+ print(f"GEMEO v2.0 β€” Pillar A: RWR onset proposer [KG source: {src}, {len(kg.nodes)} nodes]")
138
+ print("=" * 80)
139
+ # A patient who presents as Marfan (disease known) β€” what onsets does the KG propose?
140
+ cases = [
141
+ ("Marfan presentation", ["ORPHA:558"], None),
142
+ ("Marfan, genome-seeded by a pathogenic FBN1 variant",
143
+ ["HP:0004942"], {"FBN1": 0.97}), # only an aortic phenotype + a genomic hit, no disease label
144
+ ("Duchenne presentation", ["ORPHA:98896"], None),
145
+ ]
146
+ out = {}
147
+ for label, manifested, variants in cases:
148
+ r = propose(kg, manifested, variant_genes=variants, k=5)
149
+ out[label] = {"manifested": manifested, "variants": variants, **r}
150
+ print(f"\n[{label}] seeds={manifested} variants={variants or 'β€”'}")
151
+ for kind in ("genes", "phenotypes", "diseases"):
152
+ if r.get(kind):
153
+ print(f" {kind}:")
154
+ for c in r[kind][:4]:
155
+ print(f" β€’ {c['id']} ({c['name']}) rwr={c['rwr_score']}")
156
+ print(f" evidence: {c['evidence']}")
157
+ json.dump(out, open("/tmp/pillar_a_demo.json", "w"), indent=2)
158
+ print("\nNew-onset filter: every already-manifested node is excluded from the ranking.")
159
+ print("Genomic seeding: pass variant_genes={gene: pathogenicity} from Evo 2 / AlphaMissense.")
160
+ print("Saved /tmp/pillar_a_demo.json")
161
+
162
+
163
+ if __name__ == "__main__":
164
+ _demo()