# raredx / backend /scripts /graph_store.py
# Aswin92's picture
# Upload folder using huggingface_hub
# 89c6379 verified
"""
graph_store.py
--------------
Lightweight local graph store that mirrors the Neo4j schema used by RareDx.
Uses NetworkX in-memory + JSON persistence as a drop-in fallback when
the Neo4j Docker service is unavailable.
Graph schema:
(:Disease {orpha_code, name, definition, expert_link})
(:Synonym {text})
(:HPOTerm {hpo_id, term})
(:Disease)-[:HAS_SYNONYM]->(:Synonym)
(:Disease)-[:MANIFESTS_AS {frequency_label, frequency_order, diagnostic_criteria}]->(:HPOTerm)
"""
import json
from pathlib import Path
from typing import Optional
import networkx as nx
# Default location of the persisted graph: "data/graph_store.json" under the
# directory two levels above the one containing this file.
DEFAULT_PATH = Path(__file__).parents[2] / "data" / "graph_store.json"
class LocalGraphStore:
    """NetworkX-backed graph store with JSON persistence.

    Node ids are prefixed strings: ``Disease:<orpha_code>``,
    ``Synonym:<text>`` and ``HPO:<hpo_id>``. Nodes written by this class
    carry a ``type`` attribute (``"Disease"``, ``"Synonym"`` or
    ``"HPOTerm"``) and edges carry a ``label`` attribute
    (``"HAS_SYNONYM"`` or ``"MANIFESTS_AS"``).
    """

    def __init__(self, path: Path = DEFAULT_PATH) -> None:
        """Create an empty graph and load ``path`` if that file exists."""
        self.path = path
        self.graph = nx.DiGraph()
        if path.exists():
            self._load()

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------
    def _load(self) -> None:
        """Populate ``self.graph`` from the JSON document at ``self.path``.

        The document is expected to hold a ``nodes`` list (each entry has an
        ``id`` plus arbitrary attributes) and an ``edges`` list (each entry
        has ``src``, ``dst`` plus arbitrary attributes), as written by
        :meth:`save`.
        """
        data = json.loads(self.path.read_text(encoding="utf-8"))
        for node in data.get("nodes", []):
            # Build a filtered copy instead of popping so the parsed
            # document is left untouched.
            attrs = {k: v for k, v in node.items() if k != "id"}
            self.graph.add_node(node["id"], **attrs)
        for edge in data.get("edges", []):
            attrs = {k: v for k, v in edge.items() if k not in ("src", "dst")}
            self.graph.add_edge(edge["src"], edge["dst"], **attrs)

    def save(self) -> None:
        """Serialise the whole graph to ``self.path`` as JSON.

        The parent directory is created when missing. The payload is first
        written to a sibling ``<name>.tmp`` file and then renamed over the
        target, so an interrupted write does not leave a truncated store.
        """
        self.path.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "nodes": [{"id": n, **self.graph.nodes[n]} for n in self.graph.nodes],
            "edges": [
                {"src": u, "dst": v, **d}
                for u, v, d in self.graph.edges(data=True)
            ],
        }
        payload = json.dumps(data, ensure_ascii=False, indent=2)
        tmp_path = self.path.with_name(self.path.name + ".tmp")
        tmp_path.write_text(payload, encoding="utf-8")
        tmp_path.replace(self.path)

    # ------------------------------------------------------------------
    # Disease + Synonym write
    # ------------------------------------------------------------------
    def upsert_disease(
        self, orpha_code: int, name: str, definition: str, expert_link: str
    ) -> None:
        """Create the ``Disease:<orpha_code>`` node or overwrite its attributes."""
        nid = f"Disease:{orpha_code}"
        self.graph.add_node(
            nid,
            type="Disease",
            orpha_code=orpha_code,
            name=name,
            definition=definition,
            expert_link=expert_link,
        )

    def add_synonym(self, orpha_code: int, synonym_text: str) -> None:
        """Attach a synonym to an existing disease via a HAS_SYNONYM edge.

        Does nothing when the disease has not been loaded: ``add_edge``
        would otherwise create an attribute-less ``Disease:<code>`` node
        that the read methods cannot hydrate.
        """
        disease_nid = f"Disease:{orpha_code}"
        if disease_nid not in self.graph:
            return  # skip if disease not loaded yet
        syn_nid = f"Synonym:{synonym_text}"
        self.graph.add_node(syn_nid, type="Synonym", text=synonym_text)
        self.graph.add_edge(disease_nid, syn_nid, label="HAS_SYNONYM")

    def upsert_disorders_bulk(self, disorders: list[dict]) -> int:
        """Upsert many diseases with their synonyms, then persist once.

        Each dict must provide ``orpha_code`` and ``name``; ``definition``
        and ``expert_link`` default to ``""`` and ``synonyms`` to no entries.

        Returns:
            The number of entries in ``disorders``.
        """
        for d in disorders:
            self.upsert_disease(
                orpha_code=d["orpha_code"],
                name=d["name"],
                definition=d.get("definition", ""),
                expert_link=d.get("expert_link", ""),
            )
            for syn in d.get("synonyms", []):
                self.add_synonym(d["orpha_code"], syn)
        self.save()
        return len(disorders)

    # ------------------------------------------------------------------
    # HPO write
    # ------------------------------------------------------------------
    def upsert_hpo_term(self, hpo_id: str, term: str) -> None:
        """Create or update an HPOTerm node."""
        nid = f"HPO:{hpo_id}"
        self.graph.add_node(nid, type="HPOTerm", hpo_id=hpo_id, term=term)

    def add_manifestation(
        self,
        orpha_code: int,
        hpo_id: str,
        frequency_label: str,
        frequency_order: int,
        diagnostic_criteria: str,
    ) -> None:
        """
        Add (:Disease)-[:MANIFESTS_AS {frequency_label, frequency_order, diagnostic_criteria}]->(:HPOTerm)
        frequency_order: 1=Very frequent, 2=Frequent, 3=Occasional, 4=Rare, 5=Excluded, 0=Unknown

        Does nothing when the disease has not been loaded. When the HPO term
        node does not exist yet, a placeholder with an empty ``term`` is
        created so the edge never points at an attribute-less node; a later
        ``upsert_hpo_term`` fills in the real term text.
        """
        disease_nid = f"Disease:{orpha_code}"
        hpo_nid = f"HPO:{hpo_id}"
        if disease_nid not in self.graph:
            return  # skip if disease not loaded yet
        if hpo_nid not in self.graph:
            self.upsert_hpo_term(hpo_id, "")
        self.graph.add_edge(
            disease_nid,
            hpo_nid,
            label="MANIFESTS_AS",
            frequency_label=frequency_label,
            frequency_order=frequency_order,
            diagnostic_criteria=diagnostic_criteria,
        )

    def upsert_hpo_bulk(self, associations: list[dict]) -> int:
        """
        associations: list of {orpha_code, hpo_id, term, frequency_label,
                               frequency_order, diagnostic_criteria}

        Upserts every HPO term, links it to its disease (when that disease
        is loaded), persists once, and returns ``len(associations)``.
        """
        for a in associations:
            self.upsert_hpo_term(a["hpo_id"], a["term"])
            self.add_manifestation(
                orpha_code=a["orpha_code"],
                hpo_id=a["hpo_id"],
                frequency_label=a["frequency_label"],
                frequency_order=a["frequency_order"],
                diagnostic_criteria=a["diagnostic_criteria"],
            )
        self.save()
        return len(associations)

    # ------------------------------------------------------------------
    # Disease read
    # ------------------------------------------------------------------
    def find_disease_by_name(self, name_fragment: str) -> Optional[dict]:
        """Case-insensitive contains search.

        Returns the hydrated record of the first Disease node (in graph
        iteration order) whose name contains ``name_fragment``, or ``None``.
        """
        fragment = name_fragment.lower()
        for nid, attrs in self.graph.nodes(data=True):
            if attrs.get("type") != "Disease":
                continue
            # ``or ""`` also covers a name stored as null in the JSON file.
            if fragment in (attrs.get("name") or "").lower():
                return self._hydrate_disease(nid, attrs)
        return None

    def get_disease_by_orpha(self, orpha_code: int) -> Optional[dict]:
        """Return the hydrated disease for ``orpha_code``, or ``None``."""
        nid = f"Disease:{orpha_code}"
        if nid not in self.graph:
            return None
        attrs = self.graph.nodes[nid]
        # Untyped nodes can exist in files written before add_synonym
        # guarded against unknown diseases; they cannot be hydrated.
        if attrs.get("type") != "Disease":
            return None
        return self._hydrate_disease(nid, attrs)

    def _hydrate_disease(self, nid: str, attrs: dict) -> dict:
        """Build the full disease record: attributes, synonyms and HPO terms.

        ``hpo_terms`` is sorted by ascending ``frequency_order``, so entries
        with order 0 (Unknown) come before order 1 (Very frequent).
        """
        synonyms, hpo_terms = [], []
        for v, edge_data in self.graph[nid].items():
            target = self.graph.nodes[v]
            vtype = target.get("type")
            if vtype == "Synonym":
                synonyms.append(target["text"])
            elif vtype == "HPOTerm":
                hpo_terms.append({
                    "hpo_id": target["hpo_id"],
                    "term": target["term"],
                    "frequency_label": edge_data.get("frequency_label", ""),
                    "frequency_order": edge_data.get("frequency_order", 0),
                    "diagnostic_criteria": edge_data.get("diagnostic_criteria", ""),
                })
        hpo_terms.sort(key=lambda x: x["frequency_order"])
        return {
            "orpha_code": attrs["orpha_code"],
            "name": attrs["name"],
            "definition": attrs.get("definition", ""),
            "expert_link": attrs.get("expert_link", ""),
            "synonyms": synonyms,
            "hpo_terms": hpo_terms,
        }

    # ------------------------------------------------------------------
    # Phenotype-based diagnostic query
    # ------------------------------------------------------------------
    def find_diseases_by_hpo(
        self,
        hpo_ids: list[str],
        top_n: int = 10,
        min_matches: int = 1,
    ) -> list[dict]:
        """
        Given a list of HPO term IDs, find diseases that manifest those phenotypes.
        Returns diseases ranked by:
        1. Number of matching HPO terms (desc)
        2. Sum of frequency weights of matched terms (desc)
           (Very frequent=5, Frequent=4, Occasional=3, Rare=2, Unknown=1)
        Associations marked Excluded (frequency_order 5) are ignored entirely.
        This is the core graph-based differential diagnosis query.

        Duplicate ids in ``hpo_ids`` are collapsed (first occurrence wins),
        so ``total_query_terms`` is the number of distinct ids queried and
        the output order is reproducible between runs.

        Args:
            hpo_ids: HPO identifiers as stored on the nodes (without the
                internal ``HPO:`` node-id prefix).
            top_n: Maximum number of diseases returned.
            min_matches: Minimum number of matched terms a disease needs.

        Returns:
            One dict per disease with ``orpha_code``, ``name``,
            ``definition``, ``match_count``, ``total_query_terms``,
            ``freq_score`` and ``matched_hpo``.
        """
        freq_weight = {1: 5, 2: 4, 3: 3, 4: 2, 0: 1}
        # Order-preserving de-duplication; iterating a set of strings would
        # make tie order depend on hash randomisation.
        unique_ids = list(dict.fromkeys(hpo_ids))

        # Walk from each HPO node to Disease predecessors
        disease_scores: dict[str, dict] = {}
        for hid in unique_ids:
            hpo_nid = f"HPO:{hid}"
            if hpo_nid not in self.graph:
                continue
            hpo_attrs = self.graph.nodes[hpo_nid]
            for disease_nid in self.graph.predecessors(hpo_nid):
                if self.graph.nodes[disease_nid].get("type") != "Disease":
                    continue
                edge = self.graph[disease_nid][hpo_nid]
                if edge.get("label") != "MANIFESTS_AS":
                    continue
                order = edge.get("frequency_order", 0)
                # Skip excluded phenotypes
                if order == 5:
                    continue
                scores = disease_scores.setdefault(
                    disease_nid,
                    {"match_count": 0, "freq_score": 0.0, "matched_hpo": []},
                )
                scores["match_count"] += 1
                scores["freq_score"] += freq_weight.get(order, 1)
                scores["matched_hpo"].append({
                    # .get: tolerate attribute-less HPO nodes from older files.
                    "hpo_id": hpo_attrs.get("hpo_id", hid),
                    "term": hpo_attrs.get("term", ""),
                    "frequency_label": edge.get("frequency_label", ""),
                })

        # Filter minimum matches and rank
        ranked = [
            (nid, scores)
            for nid, scores in disease_scores.items()
            if scores["match_count"] >= min_matches
        ]
        ranked.sort(
            key=lambda item: (item[1]["match_count"], item[1]["freq_score"]),
            reverse=True,
        )

        results = []
        for disease_nid, scores in ranked[:top_n]:
            attrs = self.graph.nodes[disease_nid]
            results.append({
                "orpha_code": attrs["orpha_code"],
                "name": attrs["name"],
                "definition": attrs.get("definition", ""),
                "match_count": scores["match_count"],
                "total_query_terms": len(unique_ids),
                "freq_score": round(scores["freq_score"], 2),
                "matched_hpo": scores["matched_hpo"],
            })
        return results

    def find_diseases_by_hpo_terms(
        self,
        term_names: list[str],
        top_n: int = 10,
        min_matches: int = 1,
    ) -> list[dict]:
        """
        Convenience wrapper: search by HPO term names (case-insensitive)
        instead of HPO IDs.

        Each name is resolved to the first HPOTerm node (in graph iteration
        order) whose term contains it. Blank names are skipped, since an
        empty string is contained in every term. Names that resolve to no
        term are dropped.
        """
        # Build the lowercase term index once instead of rescanning every
        # node for every query name.
        hpo_index = [
            ((attrs.get("term") or "").lower(), attrs["hpo_id"])
            for _, attrs in self.graph.nodes(data=True)
            if attrs.get("type") == "HPOTerm"
        ]
        hpo_ids = []
        for name in term_names:
            needle = name.strip().lower()
            if not needle:
                continue
            match = next((hid for term, hid in hpo_index if needle in term), None)
            if match is not None:
                hpo_ids.append(match)
        return self.find_diseases_by_hpo(hpo_ids, top_n=top_n, min_matches=min_matches)

    # ------------------------------------------------------------------
    # Stats
    # ------------------------------------------------------------------
    def _count_nodes(self, node_type: str) -> int:
        """Count nodes whose ``type`` attribute equals ``node_type``."""
        return sum(
            1 for _, d in self.graph.nodes(data=True) if d.get("type") == node_type
        )

    def disease_count(self) -> int:
        """Number of Disease nodes."""
        return self._count_nodes("Disease")

    def synonym_count(self) -> int:
        """Number of Synonym nodes."""
        return self._count_nodes("Synonym")

    def hpo_term_count(self) -> int:
        """Number of HPOTerm nodes."""
        return self._count_nodes("HPOTerm")

    def manifestation_count(self) -> int:
        """Number of MANIFESTS_AS edges."""
        return sum(
            1 for _, _, d in self.graph.edges(data=True)
            if d.get("label") == "MANIFESTS_AS"
        )

    def edge_count(self) -> int:
        """Total number of edges of any label."""
        return self.graph.number_of_edges()