"""
reembed_chromadb.py
-------------------
Rebuilds ChromaDB embeddings with HPO-enriched disease descriptions.

Week 1 embedding text:
    "{name}. {definition}. Also known as: {synonyms}."

Week 2B embedding text (this script):
    "{name}. {definition}. Clinical features: {up to 30 HPO terms, most
    frequent first}. Also known as: {synonyms}."

Adding phenotype terms directly into the embedding space means ChromaDB
can now find diseases by symptoms, not just by name similarity.
"""
|
|
| import os |
| import sys |
| from pathlib import Path |
|
|
| import chromadb |
| from chromadb.config import Settings |
| from sentence_transformers import SentenceTransformer |
| from tqdm import tqdm |
| from dotenv import load_dotenv |
|
|
# Two directories above the folder containing this script (presumably the
# repository root); both the .env file and the embedded Chroma store live here.
_PROJECT_ROOT = Path(__file__).parents[2]

load_dotenv(_PROJECT_ROOT / ".env")

# Connection / model settings; each can be overridden from the environment
# (including the .env file loaded above).
CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8000"))
COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "rare_diseases")
EMBED_MODEL = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")
# On-disk location used by the embedded (PersistentClient) fallback.
CHROMA_PERSIST = _PROJECT_ROOT / "data" / "chromadb"

# Shared by the SentenceTransformer encode call and the Chroma upsert chunking.
BATCH_SIZE = 32
|
|
|
|
| |
| |
| |
|
|
def _as_sentence(text: str) -> str:
    """Return *text* stripped and ending in terminal punctuation, or '' if blank.

    Text that already ends in '.', '!' or '?' is only stripped; anything else
    gets a '.' appended so consecutive parts read as separate sentences.
    """
    text = text.strip()
    if text and not text.endswith((".", "!", "?")):
        text += "."
    return text


def build_documents(
    store,
    *,
    max_embed_terms: int = 30,
    max_meta_terms: int = 15,
) -> list[dict]:
    """
    Pull every disease from the graph store and build HPO-enriched embed text.

    For every node whose ``type`` attribute is ``"Disease"``:

    * neighbouring ``Synonym`` nodes contribute their ``text``;
    * neighbouring ``HPOTerm`` nodes reached over an edge labelled
      ``MANIFESTS_AS`` contribute their ``term``, sorted ascending by the
      edge's ``frequency_order``. Lower values come first, i.e. most
      frequent first, per the original intent; confirm the scale against the
      graph loader. A missing or ``None`` value sorts as 9, after ranked
      terms. Edges with ``frequency_order == 5`` are skipped.

    The embed text is
    ``"{name}. {definition}. Clinical features: {terms}. Also known as: {synonyms}."``.
    Empty sections are omitted.

    Args:
        store: graph store exposing ``store.graph`` with a networkx-style API:
            ``nodes(data=True)``, ``graph[node]`` adjacency mapping of
            neighbour -> edge attrs, and ``nodes[node]`` attribute lookup.
        max_embed_terms: maximum number of HPO terms placed in the embed text.
        max_meta_terms: maximum number of those terms kept in the
            ``hpo_terms`` metadata string.

    Returns:
        One dict per disease with keys ``id`` (``"ORPHA:<code>"``),
        ``orpha_code``, ``name``, ``definition``, ``synonyms`` and
        ``hpo_terms`` (comma-joined strings), and ``embed_text``.

    Raises:
        KeyError: if a Disease node has no ``orpha_code`` attribute.
    """
    graph = store.graph
    docs = []
    disease_nodes = [
        (nid, attrs)
        for nid, attrs in graph.nodes(data=True)
        if attrs.get("type") == "Disease"
    ]

    for nid, attrs in tqdm(disease_nodes, desc=" Building documents", unit="disease"):
        orpha_code = attrs["orpha_code"]
        # `or ""` also covers attributes stored explicitly as None, which would
        # break str.join below and are not valid Chroma metadata values.
        name = attrs.get("name") or ""
        definition = attrs.get("definition") or ""

        synonyms: list[str] = []
        hpo_terms: list[tuple[int, str]] = []

        for neighbour, edata in graph[nid].items():
            vattrs = graph.nodes[neighbour]
            vtype = vattrs.get("type")

            if vtype == "Synonym":
                text = vattrs.get("text")
                if text:
                    synonyms.append(text)

            elif vtype == "HPOTerm" and edata.get("label") == "MANIFESTS_AS":
                freq_order = edata.get("frequency_order")
                if freq_order is None:
                    freq_order = 9  # unranked: sort after every ranked term
                # NOTE(review): 5 appears to mark an excluded / 0% frequency — confirm.
                if freq_order == 5:
                    continue
                term = vattrs.get("term")
                if term:  # blank terms would leave empty slots in the joined list
                    hpo_terms.append((freq_order, term))

        # Stable sort: terms with equal frequency_order keep adjacency order.
        hpo_terms.sort(key=lambda pair: pair[0])
        hpo_term_names = [term for _, term in hpo_terms[:max_embed_terms]]

        parts = [_as_sentence(name), _as_sentence(definition)]
        if hpo_term_names:
            parts.append("Clinical features: " + ", ".join(hpo_term_names) + ".")
        if synonyms:
            parts.append("Also known as: " + ", ".join(synonyms) + ".")
        embed_text = " ".join(part for part in parts if part)

        docs.append({
            "id": f"ORPHA:{orpha_code}",
            "orpha_code": str(orpha_code),
            "name": name,
            "definition": definition,
            "synonyms": ", ".join(synonyms),
            "hpo_terms": ", ".join(hpo_term_names[:max_meta_terms]),
            "embed_text": embed_text,
        })

    return docs
|
|
|
|
| |
| |
| |
|
|
def get_chroma_client() -> tuple[chromadb.ClientAPI, str]:
    """
    Connect to ChromaDB, preferring the HTTP server over the embedded store.

    First builds ``chromadb.HttpClient`` for CHROMA_HOST:CHROMA_PORT and
    checks it with ``heartbeat()``. If either step raises, the reason is
    printed and a ``chromadb.PersistentClient`` rooted at CHROMA_PERSIST is
    used instead. The directory is created if needed.

    Returns:
        ``(client, backend_label)``, where ``backend_label`` names the backend
        actually in use.
    """
    try:
        client = chromadb.HttpClient(
            host=CHROMA_HOST, port=CHROMA_PORT,
            settings=Settings(anonymized_telemetry=False),
        )
        client.heartbeat()
    except Exception as exc:  # best-effort probe: any failure => use embedded store
        # Say why: a wrong host/port would otherwise silently send the whole
        # re-embed to the local store instead of the intended server.
        print(f" ChromaDB HTTP at {CHROMA_HOST}:{CHROMA_PORT} unavailable "
              f"({type(exc).__name__}: {exc}); falling back to embedded store "
              f"at {CHROMA_PERSIST}.")
        CHROMA_PERSIST.mkdir(parents=True, exist_ok=True)
        client = chromadb.PersistentClient(
            path=str(CHROMA_PERSIST),
            settings=Settings(anonymized_telemetry=False),
        )
        return client, "ChromaDB Embedded"
    return client, "ChromaDB HTTP (Docker)"
|
|
|
|
def recreate_collection(client: chromadb.ClientAPI, name: str) -> chromadb.Collection:
    """
    Drop the collection called *name* if possible, then create it afresh.

    Deletion is best-effort: any exception raised by ``delete_collection``
    (e.g. because the collection does not exist yet) is ignored. The new
    collection is created with ``{"hnsw:space": "cosine"}`` metadata.

    Returns:
        The newly created collection.
    """
    try:
        client.delete_collection(name)
    except Exception:
        pass  # nothing to drop (or drop failed); create_collection below decides
    else:
        print(f" Deleted existing collection '{name}'.")

    index_metadata = {"hnsw:space": "cosine"}
    fresh = client.create_collection(name=name, metadata=index_metadata)
    print(f" Created collection '{name}'.")
    return fresh
|
|
|
|
def upsert_batches(col, docs: list[dict], embeddings) -> None:
    """
    Upsert *docs* with their precomputed *embeddings* into *col*.

    Records are written in chunks of BATCH_SIZE; ``embeddings[i]`` is paired
    with ``docs[i]``. Each record stores ``embed_text`` as the document and
    a flat metadata dict; the definition is cut to its first 500 characters.
    """
    total = len(docs)
    for start in range(0, total, BATCH_SIZE):
        window = slice(start, start + BATCH_SIZE)
        chunk = docs[window]
        vectors = embeddings[window]

        ids = [doc["id"] for doc in chunk]
        texts = [doc["embed_text"] for doc in chunk]
        metas = [
            {
                "orpha_code": doc["orpha_code"],
                "name": doc["name"],
                "definition": doc["definition"][:500],
                "synonyms": doc["synonyms"],
                "hpo_terms": doc["hpo_terms"],
            }
            for doc in chunk
        ]

        col.upsert(ids=ids, embeddings=vectors, documents=texts, metadatas=metas)
|
|
|
|
| |
| |
| |
|
|
def main() -> None:
    """
    Rebuild the Chroma collection from the local graph store.

    Pipeline:
    1. Load ``LocalGraphStore``.
    2. Build HPO-enriched documents.
    3. Encode them with ``SentenceTransformer`` (normalized embeddings).
    4. Drop and recreate the collection.
    5. Upsert in batches.
    6. Run one probe query and print the top hits as a sanity check.

    Exits via ``sys.exit`` before connecting to ChromaDB when the graph yields
    no disease documents, so the existing collection is never dropped in
    favour of an empty one.
    """
    print("=" * 60)
    print("RareDx — Week 2B Step 1: Re-embed with HPO-Enriched Text")
    print("=" * 60)

    # graph_store.py sits next to this script; make it importable regardless
    # of the current working directory.
    sys.path.insert(0, str(Path(__file__).parent))
    from graph_store import LocalGraphStore
    store = LocalGraphStore()
    print(f"\nGraph: {store.disease_count():,} diseases | "
          f"{store.hpo_term_count():,} HPO terms | "
          f"{store.manifestation_count():,} phenotype edges")

    print("\nBuilding HPO-enriched documents...")
    docs = build_documents(store)
    print(f" {len(docs):,} documents ready.")
    if not docs:
        # Nothing to embed: stop here with a clear message, before
        # recreate_collection() deletes the existing collection.
        sys.exit("No disease documents were built from the graph store; "
                 "aborting without touching ChromaDB.")

    sample = next((d for d in docs if "Marfan" in d["name"]), docs[0])
    print(f"\n Sample — {sample['name']}:")
    sample_text = sample["embed_text"]
    preview = sample_text[:300]
    # Only mark the preview as truncated when it actually is.
    ellipsis = "..." if len(sample_text) > len(preview) else ""
    print(f" {preview}{ellipsis}")

    print(f"\nLoading {EMBED_MODEL}...")
    model = SentenceTransformer(EMBED_MODEL)
    print(f" Embedding dim: {model.get_sentence_embedding_dimension()}")

    print(f"\nEmbedding {len(docs):,} documents (batch={BATCH_SIZE})...")
    texts = [d["embed_text"] for d in docs]
    embeddings = model.encode(
        texts,
        batch_size=BATCH_SIZE,
        show_progress_bar=True,
        normalize_embeddings=True,
    )
    print(f" Shape: {embeddings.shape}")

    print("\nConnecting to ChromaDB...")
    client, backend = get_chroma_client()
    print(f" Backend: {backend}")
    col = recreate_collection(client, COLLECTION_NAME)

    print(f"Upserting {len(docs):,} documents...")
    upsert_batches(col, docs, embeddings.tolist())
    doc_count = col.count()
    print(f" Collection '{COLLECTION_NAME}': {doc_count:,} documents.")

    probe_query = "arachnodactyly tall stature ectopia lentis aortic dilation"
    print(f"\nSanity check: '{probe_query}'")
    probe = model.encode([probe_query], normalize_embeddings=True)
    # Never ask for more neighbours than the collection holds.
    results = col.query(query_embeddings=probe.tolist(), n_results=min(5, doc_count))
    for meta, dist in zip(results["metadatas"][0], results["distances"][0]):
        # recreate_collection() builds the index with cosine space, so
        # similarity = 1 - distance.
        sim = 1 - dist
        print(f" [{sim:.4f}] ORPHA:{meta['orpha_code']} {meta['name']}")

    print(f"\nStep 1 done — backend: {backend}")
|
|
|
|
# Run the re-embedding pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|