"""
reembed_chromadb.py
-------------------
Rebuilds ChromaDB embeddings with HPO-enriched disease descriptions.

Week 1 embedding text:
    "{name}. {definition}. Also known as: {synonyms}."

Week 2B embedding text (this script):
    "{name}. {definition}. Clinical features: {hpo_terms ordered by frequency}.
    Also known as: {synonyms}."

Adding phenotype terms directly into the embedding space means ChromaDB can
now find diseases by symptoms, not just by name similarity.
"""

import os
import sys
from pathlib import Path

import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from dotenv import load_dotenv

load_dotenv(Path(__file__).parents[2] / ".env")

CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8000"))
COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "rare_diseases")
EMBED_MODEL = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")
CHROMA_PERSIST = Path(__file__).parents[2] / "data" / "chromadb"
BATCH_SIZE = 32

# Maximum number of HPO terms placed in the embed text (controls token length).
MAX_EMBED_HPO_TERMS = 30
# Number of HPO terms kept in the ChromaDB metadata field "hpo_terms".
MAX_METADATA_HPO_TERMS = 15
# frequency_order assumed when a MANIFESTS_AS edge has no frequency annotation;
# it sorts after the annotated frequencies.
DEFAULT_FREQUENCY_ORDER = 9
# frequency_order value marking an excluded phenotype; such terms are skipped.
EXCLUDED_FREQUENCY_ORDER = 5


# ---------------------------------------------------------------------------
# Build enriched document text per disease
# ---------------------------------------------------------------------------

def _as_sentence(text: str) -> str:
    """Strip *text* and append a period unless it already ends in . ! or ?."""
    text = text.strip()
    if text and text[-1] not in ".!?":
        text += "."
    return text


def build_documents(store) -> list[dict]:
    """
    Pull every disease from the graph store and build HPO-enriched embed text.

    Disease nodes are those whose ``type`` attribute is ``"Disease"``. For each
    one, neighbouring ``Synonym`` nodes contribute synonyms, and neighbouring
    ``HPOTerm`` nodes reached through a ``MANIFESTS_AS`` edge contribute
    phenotype terms.

    HPO terms are sorted by the edge's ``frequency_order`` (most frequent
    first). Excluded phenotypes (``frequency_order == 5``) are dropped, and at
    most ``MAX_EMBED_HPO_TERMS`` terms are embedded.

    Args:
        store: graph store exposing a ``graph`` attribute with networkx-style
            ``nodes(data=True)``, ``graph[nid]`` and ``graph.nodes[v]`` access.

    Returns:
        One dict per disease with the keys ``id``, ``orpha_code``, ``name``,
        ``definition``, ``synonyms``, ``hpo_terms`` and ``embed_text``.
    """
    graph = store.graph
    docs = []
    disease_nodes = [
        (nid, attrs)
        for nid, attrs in graph.nodes(data=True)
        if attrs.get("type") == "Disease"
    ]

    for nid, attrs in tqdm(disease_nodes, desc=" Building documents", unit="disease"):
        orpha_code = attrs["orpha_code"]
        # `or ""` also covers attributes that exist but are None, which would
        # otherwise break the string joins and ChromaDB metadata below.
        name = attrs.get("name") or ""
        definition = attrs.get("definition") or ""

        # Collect synonyms and HPO terms from graph edges
        synonyms = []
        hpo_terms = []
        for v, edata in graph[nid].items():
            vattrs = graph.nodes[v]
            vtype = vattrs.get("type")
            if vtype == "Synonym":
                text = vattrs.get("text")
                if text:
                    synonyms.append(text)
            elif vtype == "HPOTerm" and edata.get("label") == "MANIFESTS_AS":
                freq_order = edata.get("frequency_order")
                if freq_order is None:
                    # Missing or None frequency: sort last. A None left in the
                    # list would make the sort below raise TypeError.
                    freq_order = DEFAULT_FREQUENCY_ORDER
                if freq_order == EXCLUDED_FREQUENCY_ORDER:
                    continue
                term = vattrs.get("term")
                if term:  # skip blank terms so the text has no ", , " runs
                    hpo_terms.append((freq_order, term))

        # Sort HPO terms: most frequent first (stable for ties)
        hpo_terms.sort(key=lambda pair: pair[0])
        hpo_term_names = [term for _, term in hpo_terms[:MAX_EMBED_HPO_TERMS]]

        # Build enriched text: "{name}. {definition}. Clinical features: ... ."
        parts = []
        if name:
            parts.append(_as_sentence(name))
        if definition:
            parts.append(_as_sentence(definition))
        if hpo_term_names:
            parts.append("Clinical features: " + ", ".join(hpo_term_names) + ".")
        if synonyms:
            parts.append("Also known as: " + ", ".join(synonyms) + ".")
        embed_text = " ".join(parts)

        docs.append({
            "id": f"ORPHA:{orpha_code}",
            "orpha_code": str(orpha_code),
            "name": name,
            "definition": definition,
            "synonyms": ", ".join(synonyms),
            # Store only a subset of the terms in metadata
            "hpo_terms": ", ".join(hpo_term_names[:MAX_METADATA_HPO_TERMS]),
            "embed_text": embed_text,
        })

    return docs


# ---------------------------------------------------------------------------
# ChromaDB helpers
# ---------------------------------------------------------------------------

def get_chroma_client() -> tuple[chromadb.ClientAPI, str]:
    """
    Return a ChromaDB client and a human-readable backend label.

    An HTTP client at CHROMA_HOST:CHROMA_PORT is tried first, verified with
    ``heartbeat()``. If that raises anything, the function falls back to an
    embedded PersistentClient stored under CHROMA_PERSIST, creating that
    directory if needed.
    """
    try:
        client = chromadb.HttpClient(
            host=CHROMA_HOST,
            port=CHROMA_PORT,
            settings=Settings(anonymized_telemetry=False),
        )
        client.heartbeat()
        return client, "ChromaDB HTTP (Docker)"
    except Exception:
        # Deliberate best-effort fallback: any failure to reach the server
        # (connection refused, bad host, ...) means "use the embedded store".
        CHROMA_PERSIST.mkdir(parents=True, exist_ok=True)
        client = chromadb.PersistentClient(
            path=str(CHROMA_PERSIST),
            settings=Settings(anonymized_telemetry=False),
        )
        return client, "ChromaDB Embedded"


def recreate_collection(client: chromadb.ClientAPI, name: str) -> chromadb.Collection:
    """Drop collection *name* if it exists, then create it fresh with cosine distance."""
    try:
        client.delete_collection(name)
        print(f" Deleted existing collection '{name}'.")
    except Exception:
        # The exception raised for a missing collection differs between
        # chromadb versions. If deletion failed for another reason,
        # create_collection below will still raise because the name is taken.
        pass
    col = client.create_collection(name=name, metadata={"hnsw:space": "cosine"})
    print(f" Created collection '{name}'.")
    return col


def upsert_batches(col, docs: list[dict], embeddings) -> None:
    """
    Upsert *docs* and their precomputed *embeddings* into *col* in BATCH_SIZE chunks.

    *embeddings* must be index-aligned with *docs*. The stored definition is
    truncated to 500 characters to keep the metadata small.
    """
    for i in range(0, len(docs), BATCH_SIZE):
        bd = docs[i : i + BATCH_SIZE]
        be = embeddings[i : i + BATCH_SIZE]
        col.upsert(
            ids=[d["id"] for d in bd],
            embeddings=be,
            documents=[d["embed_text"] for d in bd],
            metadatas=[
                {
                    "orpha_code": d["orpha_code"],
                    "name": d["name"],
                    "definition": d["definition"][:500],
                    "synonyms": d["synonyms"],
                    "hpo_terms": d["hpo_terms"],
                }
                for d in bd
            ],
        )


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
    """Rebuild the ChromaDB collection from HPO-enriched disease documents."""
    print("=" * 60)
    print("RareDx — Week 2B Step 1: Re-embed with HPO-Enriched Text")
    print("=" * 60)

    # Load graph store
    sys.path.insert(0, str(Path(__file__).parent))
    from graph_store import LocalGraphStore

    store = LocalGraphStore()
    print(f"\nGraph: {store.disease_count():,} diseases | "
          f"{store.hpo_term_count():,} HPO terms | "
          f"{store.manifestation_count():,} phenotype edges")

    # Build documents
    print("\nBuilding HPO-enriched documents...")
    docs = build_documents(store)
    print(f" {len(docs):,} documents ready.")
    if not docs:
        # Abort before recreate_collection() would wipe the existing data.
        sys.exit("No Disease nodes found in the graph store; "
                 "aborting without touching ChromaDB.")

    # Sample — show the enrichment difference
    sample = next((d for d in docs if "Marfan" in d["name"]), docs[0])
    print(f"\n Sample — {sample['name']}:")
    preview = sample["embed_text"][:300]
    print(f" {preview}...")

    # Load model
    print(f"\nLoading {EMBED_MODEL}...")
    model = SentenceTransformer(EMBED_MODEL)
    print(f" Embedding dim: {model.get_sentence_embedding_dimension()}")

    # Embed (normalized so cosine distance in ChromaDB is well-behaved)
    print(f"\nEmbedding {len(docs):,} documents (batch={BATCH_SIZE})...")
    texts = [d["embed_text"] for d in docs]
    embeddings = model.encode(
        texts,
        batch_size=BATCH_SIZE,
        show_progress_bar=True,
        normalize_embeddings=True,
    )
    print(f" Shape: {embeddings.shape}")

    # Store
    print("\nConnecting to ChromaDB...")
    client, backend = get_chroma_client()
    print(f" Backend: {backend}")
    col = recreate_collection(client, COLLECTION_NAME)
    print(f"Upserting {len(docs):,} documents...")
    upsert_batches(col, docs, embeddings.tolist())
    print(f" Collection '{COLLECTION_NAME}': {col.count():,} documents.")

    # Sanity check — now "arachnodactyly tall stature ectopia lentis" should hit Marfan
    print("\nSanity check: 'arachnodactyly tall stature ectopia lentis aortic dilation'")
    probe = model.encode(
        ["arachnodactyly tall stature ectopia lentis aortic dilation"],
        normalize_embeddings=True,
    )
    results = col.query(query_embeddings=probe.tolist(), n_results=5)
    for meta, dist in zip(results["metadatas"][0], results["distances"][0]):
        # The collection uses cosine distance, so similarity = 1 - distance.
        sim = round(1 - dist, 4)
        print(f" [{sim:.4f}] ORPHA:{meta['orpha_code']} {meta['name']}")

    print(f"\nStep 1 done — backend: {backend}")


if __name__ == "__main__":
    main()