"""
graph_store.py
--------------
Lightweight local graph store that mirrors the Neo4j schema used by RareDx.
Uses NetworkX in-memory + JSON persistence as a drop-in fallback when the
Neo4j Docker service is unavailable.

Graph schema:
    (:Disease {orpha_code, name, definition, expert_link})
    (:Synonym {text})
    (:HPOTerm {hpo_id, term})
    (:Disease)-[:HAS_SYNONYM]->(:Synonym)
    (:Disease)-[:MANIFESTS_AS {frequency_label, frequency_order,
                               diagnostic_criteria}]->(:HPOTerm)

In the NetworkX graph, node ids are "Disease:<orpha_code>", "Synonym:<text>"
and "HPO:<hpo_id>"; the Neo4j label is stored in the node's ``type`` attribute
and the relationship type in the edge's ``label`` attribute.
"""

import json
from pathlib import Path
from typing import Optional

import networkx as nx

DEFAULT_PATH = Path(__file__).parents[2] / "data" / "graph_store.json"

# Ranking weight for each MANIFESTS_AS ``frequency_order`` value
# (1=Very frequent, 2=Frequent, 3=Occasional, 4=Rare, 5=Excluded, 0=Unknown).
# Orders missing from this table are weighted like Unknown (1).
FREQUENCY_WEIGHT = {1: 5, 2: 4, 3: 3, 4: 2, 5: -1, 0: 1}
# frequency_order value marking a phenotype as excluded for a disease.
EXCLUDED_FREQUENCY_ORDER = 5


class LocalGraphStore:
    """NetworkX-backed graph store with JSON persistence."""

    def __init__(self, path: Path = DEFAULT_PATH) -> None:
        """Open the store backed by *path*, loading it if the file exists."""
        self.path = path
        self.graph = nx.DiGraph()
        if path.exists():
            self._load()

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _load(self) -> None:
        """Populate ``self.graph`` from the JSON document at ``self.path``.

        The document has the shape written by :meth:`save`:
        ``{"nodes": [{"id": ..., **attrs}], "edges": [{"src", "dst", **attrs}]}``.
        """
        data = json.loads(self.path.read_text(encoding="utf-8"))
        for node in data.get("nodes", []):
            attrs = {k: v for k, v in node.items() if k != "id"}
            self.graph.add_node(node["id"], **attrs)
        for edge in data.get("edges", []):
            attrs = {k: v for k, v in edge.items() if k not in ("src", "dst")}
            self.graph.add_edge(edge["src"], edge["dst"], **attrs)

    def save(self) -> None:
        """Write the whole graph to ``self.path`` as JSON.

        The document is first written to a sibling ``<name>.tmp`` file and then
        moved over the target, so an interrupted write leaves the previously
        saved store intact instead of a truncated file.
        """
        self.path.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "nodes": [{"id": n, **self.graph.nodes[n]} for n in self.graph.nodes],
            "edges": [
                {"src": u, "dst": v, **d} for u, v, d in self.graph.edges(data=True)
            ],
        }
        tmp_path = self.path.with_name(self.path.name + ".tmp")
        tmp_path.write_text(
            json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8"
        )
        tmp_path.replace(self.path)

    # ------------------------------------------------------------------
    # Disease + Synonym write
    # ------------------------------------------------------------------

    def upsert_disease(
        self, orpha_code: int, name: str, definition: str, expert_link: str
    ) -> None:
        """Create the Disease node for *orpha_code*, or update its attributes."""
        nid = f"Disease:{orpha_code}"
        self.graph.add_node(
            nid,
            type="Disease",
            orpha_code=orpha_code,
            name=name,
            definition=definition,
            expert_link=expert_link,
        )

    def add_synonym(self, orpha_code: int, synonym_text: str) -> None:
        """Link a (shared) Synonym node to the Disease via HAS_SYNONYM.

        NOTE: unlike :meth:`add_manifestation`, this does not check that the
        disease exists; networkx's ``add_edge`` creates a bare, attribute-less
        "Disease:<code>" node in that case until :meth:`upsert_disease` fills it.
        """
        disease_nid = f"Disease:{orpha_code}"
        syn_nid = f"Synonym:{synonym_text}"
        self.graph.add_node(syn_nid, type="Synonym", text=synonym_text)
        self.graph.add_edge(disease_nid, syn_nid, label="HAS_SYNONYM")

    def upsert_disorders_bulk(self, disorders: list[dict]) -> int:
        """Upsert diseases with their synonyms, persist, and return the count.

        Each dict needs ``orpha_code`` and ``name``; ``definition``,
        ``expert_link`` and ``synonyms`` are optional.
        """
        for d in disorders:
            self.upsert_disease(
                orpha_code=d["orpha_code"],
                name=d["name"],
                definition=d.get("definition", ""),
                expert_link=d.get("expert_link", ""),
            )
            for syn in d.get("synonyms", []):
                self.add_synonym(d["orpha_code"], syn)
        self.save()
        return len(disorders)

    # ------------------------------------------------------------------
    # HPO write
    # ------------------------------------------------------------------

    def upsert_hpo_term(self, hpo_id: str, term: str) -> None:
        """Create or update an HPOTerm node."""
        nid = f"HPO:{hpo_id}"
        self.graph.add_node(nid, type="HPOTerm", hpo_id=hpo_id, term=term)

    def add_manifestation(
        self,
        orpha_code: int,
        hpo_id: str,
        frequency_label: str,
        frequency_order: int,
        diagnostic_criteria: str,
    ) -> None:
        """
        Add (:Disease)-[:MANIFESTS_AS {frequency_label, frequency_order,
        diagnostic_criteria}]->(:HPOTerm).

        frequency_order: 1=Very frequent, 2=Frequent, 3=Occasional, 4=Rare,
        5=Excluded, 0=Unknown.

        The call is a no-op when the disease has not been loaded yet.
        """
        disease_nid = f"Disease:{orpha_code}"
        hpo_nid = f"HPO:{hpo_id}"
        if disease_nid not in self.graph:
            return  # skip if disease not loaded yet
        self.graph.add_edge(
            disease_nid,
            hpo_nid,
            label="MANIFESTS_AS",
            frequency_label=frequency_label,
            frequency_order=frequency_order,
            diagnostic_criteria=diagnostic_criteria,
        )

    def upsert_hpo_bulk(self, associations: list[dict]) -> int:
        """Upsert HPO terms and their disease links, persist, return the count.

        associations: list of {orpha_code, hpo_id, term, frequency_label,
        frequency_order, diagnostic_criteria}. The returned count is
        ``len(associations)``; it includes entries whose disease was not loaded
        (their edge is skipped, but the HPOTerm node is still upserted).
        """
        for a in associations:
            self.upsert_hpo_term(a["hpo_id"], a["term"])
            self.add_manifestation(
                orpha_code=a["orpha_code"],
                hpo_id=a["hpo_id"],
                frequency_label=a["frequency_label"],
                frequency_order=a["frequency_order"],
                diagnostic_criteria=a["diagnostic_criteria"],
            )
        self.save()
        return len(associations)

    # ------------------------------------------------------------------
    # Disease read
    # ------------------------------------------------------------------

    def _match_node(self, node_type: str, attr: str, query: str) -> Optional[str]:
        """Return the id of the best node of *node_type* whose *attr* matches.

        Matching is case-insensitive. An exact match wins; otherwise the first
        node (in insertion order) whose value contains *query* is returned.
        A blank query matches nothing, since it would otherwise hit an
        arbitrary first node.
        """
        needle = query.lower()
        if not needle.strip():
            return None
        first_partial = None
        for nid, attrs in self.graph.nodes(data=True):
            if attrs.get("type") != node_type:
                continue
            value = (attrs.get(attr) or "").lower()
            if value == needle:
                return nid
            if first_partial is None and needle in value:
                first_partial = nid
        return first_partial

    def find_disease_by_name(self, name_fragment: str) -> Optional[dict]:
        """Case-insensitive name search returning a hydrated disease or None.

        An exact name match is preferred over a name that merely contains
        *name_fragment*; blank fragments return None.
        """
        nid = self._match_node("Disease", "name", name_fragment)
        if nid is None:
            return None
        return self._hydrate_disease(nid, self.graph.nodes[nid])

    def get_disease_by_orpha(self, orpha_code: int) -> Optional[dict]:
        """Return the hydrated disease for *orpha_code*, or None if absent."""
        nid = f"Disease:{orpha_code}"
        if nid in self.graph:
            return self._hydrate_disease(nid, self.graph.nodes[nid])
        return None

    def _hydrate_disease(self, nid: str, attrs: dict) -> dict:
        """Build a disease dict with its synonyms and HPO manifestations.

        ``hpo_terms`` is ordered from most to least frequent using
        FREQUENCY_WEIGHT (Very frequent ... Rare, then Unknown, then Excluded);
        ties keep edge insertion order.
        """
        synonyms, hpo_terms = [], []
        for v, edge_data in self.graph[nid].items():
            v_attrs = self.graph.nodes[v]
            vtype = v_attrs.get("type")
            if vtype == "Synonym":
                synonyms.append(v_attrs["text"])
            elif vtype == "HPOTerm":
                hpo_terms.append({
                    "hpo_id": v_attrs["hpo_id"],
                    "term": v_attrs["term"],
                    "frequency_label": edge_data.get("frequency_label", ""),
                    "frequency_order": edge_data.get("frequency_order", 0),
                    "diagnostic_criteria": edge_data.get("diagnostic_criteria", ""),
                })
        # Sorting on the raw order would put Unknown (0) before Very frequent (1).
        hpo_terms.sort(key=lambda t: -FREQUENCY_WEIGHT.get(t["frequency_order"], 1))
        return {
            "orpha_code": attrs["orpha_code"],
            "name": attrs["name"],
            "definition": attrs.get("definition", ""),
            "expert_link": attrs.get("expert_link", ""),
            "synonyms": synonyms,
            "hpo_terms": hpo_terms,
        }

    # ------------------------------------------------------------------
    # Phenotype-based diagnostic query
    # ------------------------------------------------------------------

    def find_diseases_by_hpo(
        self,
        hpo_ids: list[str],
        top_n: int = 10,
        min_matches: int = 1,
    ) -> list[dict]:
        """
        Given a list of HPO term IDs, find diseases that manifest those
        phenotypes. This is the core graph-based differential diagnosis query.

        Args:
            hpo_ids: HPO term IDs as stored in the HPOTerm ``hpo_id`` attribute.
                Duplicates are ignored; first occurrences fix the walk order.
            top_n: maximum number of diseases returned.
            min_matches: minimum number of matched query terms per disease.

        Returns:
            Up to ``top_n`` result dicts ranked by
              1. number of matching HPO terms (desc)
              2. sum of frequency weights of matched terms (desc)
                 (Very frequent=5, Frequent=4, Occasional=3, Rare=2, Unknown=1)
            Ties keep the deterministic query-walk order. Excluded associations
            are skipped entirely: they neither count as a match nor subtract.
        """
        # Order-preserving dedup: iterating a set would make tie-breaking and
        # matched_hpo order depend on per-process string hash randomization.
        query_nodes = list(dict.fromkeys(f"HPO:{hid}" for hid in hpo_ids))

        # Walk from each HPO node to Disease predecessors
        disease_scores: dict[str, dict] = {}
        for hpo_nid in query_nodes:
            if hpo_nid not in self.graph:
                continue
            hpo_attrs = self.graph.nodes[hpo_nid]
            for disease_nid in self.graph.predecessors(hpo_nid):
                if self.graph.nodes[disease_nid].get("type") != "Disease":
                    continue
                edge = self.graph[disease_nid][hpo_nid]
                if edge.get("label") != "MANIFESTS_AS":
                    continue
                order = edge.get("frequency_order", 0)
                if order == EXCLUDED_FREQUENCY_ORDER:
                    continue  # the phenotype is excluded for this disease

                scores = disease_scores.setdefault(
                    disease_nid,
                    {"match_count": 0, "freq_score": 0.0, "matched_hpo": []},
                )
                scores["match_count"] += 1
                scores["freq_score"] += FREQUENCY_WEIGHT.get(order, 1)
                scores["matched_hpo"].append({
                    # Fall back to the node id if the HPO node was created
                    # implicitly by add_edge and carries no attributes.
                    "hpo_id": hpo_attrs.get("hpo_id", hpo_nid.split(":", 1)[1]),
                    "term": hpo_attrs.get("term", ""),
                    "frequency_label": edge.get("frequency_label", ""),
                })

        # Filter minimum matches and rank (list.sort is stable, also reversed).
        ranked = [
            (nid, scores)
            for nid, scores in disease_scores.items()
            if scores["match_count"] >= min_matches
        ]
        ranked.sort(
            key=lambda x: (x[1]["match_count"], x[1]["freq_score"]), reverse=True
        )

        results = []
        for disease_nid, scores in ranked[:top_n]:
            attrs = self.graph.nodes[disease_nid]
            results.append({
                "orpha_code": attrs["orpha_code"],
                "name": attrs["name"],
                "definition": attrs.get("definition", ""),
                "match_count": scores["match_count"],
                "total_query_terms": len(query_nodes),
                "freq_score": round(scores["freq_score"], 2),
                "matched_hpo": scores["matched_hpo"],
            })
        return results

    def find_diseases_by_hpo_terms(
        self,
        term_names: list[str],
        top_n: int = 10,
        min_matches: int = 1,
    ) -> list[dict]:
        """
        Convenience wrapper: search by HPO term names (case-insensitive)
        instead of HPO IDs.

        Each name resolves to at most one HPOTerm: an exact term match if one
        exists, otherwise the first term containing the name. Blank or
        unresolvable names are ignored. See :meth:`find_diseases_by_hpo` for
        ranking and the meaning of *top_n* / *min_matches*.
        """
        hpo_ids = []
        for name in term_names:
            nid = self._match_node("HPOTerm", "term", name)
            if nid is not None:
                hpo_ids.append(self.graph.nodes[nid]["hpo_id"])
        return self.find_diseases_by_hpo(
            hpo_ids, top_n=top_n, min_matches=min_matches
        )

    # ------------------------------------------------------------------
    # Stats
    # ------------------------------------------------------------------

    def _count_nodes_of_type(self, node_type: str) -> int:
        """Count nodes whose ``type`` attribute equals *node_type*."""
        return sum(
            1 for _, d in self.graph.nodes(data=True) if d.get("type") == node_type
        )

    def disease_count(self) -> int:
        """Number of Disease nodes."""
        return self._count_nodes_of_type("Disease")

    def synonym_count(self) -> int:
        """Number of Synonym nodes."""
        return self._count_nodes_of_type("Synonym")

    def hpo_term_count(self) -> int:
        """Number of HPOTerm nodes."""
        return self._count_nodes_of_type("HPOTerm")

    def manifestation_count(self) -> int:
        """Number of MANIFESTS_AS edges."""
        return sum(
            1
            for _, _, d in self.graph.edges(data=True)
            if d.get("label") == "MANIFESTS_AS"
        )

    def edge_count(self) -> int:
        """Total number of edges of any label."""
        return self.graph.number_of_edges()