# raredx/backend/scripts/reembed_chromadb.py
# Uploaded by Aswin92 using huggingface_hub (commit 89c6379, verified).
"""
reembed_chromadb.py
-------------------
Rebuilds ChromaDB embeddings with HPO-enriched disease descriptions.
Week 1 embedding text:
"{name}. {definition}. Also known as: {synonyms}."
Week 2B embedding text (this script):
"{name}. {definition}. Clinical features: {up to 30 HPO terms, most
frequent first}. Also known as: {synonyms}."
Phenotypes marked as excluded (frequency_order 5) are left out.
Adding phenotype terms directly into the embedding space means ChromaDB
can now find diseases by symptoms, not just by name similarity.
"""
import os
import sys
from pathlib import Path
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from dotenv import load_dotenv
# Paths are resolved relative to the repository root (raredx/), two levels
# above this script's directory, so the script works from any CWD.
_PROJECT_ROOT = Path(__file__).parents[2]
load_dotenv(_PROJECT_ROOT / ".env")

# Connection / model settings, overridable through the environment
# (or the .env file loaded above).
CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8000"))
COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "rare_diseases")
EMBED_MODEL = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")

# On-disk store used when the ChromaDB HTTP server cannot be reached.
CHROMA_PERSIST = _PROJECT_ROOT / "data" / "chromadb"

# Documents per encode() batch and per Chroma upsert() call.
BATCH_SIZE = 32
# ---------------------------------------------------------------------------
# Build enriched document text per disease
# ---------------------------------------------------------------------------
def _as_sentence(text: str) -> str:
    """Strip *text* and end it with a period unless it already ends a sentence."""
    text = text.strip()
    if text and text[-1] not in ".!?":
        text += "."
    return text


def _collect_disease_features(graph, nid) -> tuple[list[str], list[str]]:
    """Return (synonyms, HPO term names sorted most-frequent-first) for disease node *nid*.

    Only neighbours typed ``Synonym`` or ``HPOTerm`` (the latter via a
    ``MANIFESTS_AS`` edge) are considered. Empty/missing texts are skipped so
    they never produce blank list entries in the embedding text.
    """
    synonyms: list[str] = []
    hpo_terms: list[tuple[int, str]] = []
    for neighbor, edata in graph[nid].items():
        vattrs = graph.nodes[neighbor]
        vtype = vattrs.get("type")
        if vtype == "Synonym":
            text = vattrs.get("text")
            if text:
                synonyms.append(text)
        elif vtype == "HPOTerm" and edata.get("label") == "MANIFESTS_AS":
            freq_order = edata.get("frequency_order")
            if freq_order is None:
                # Missing frequency -> 9 so it sorts after the graded orders
                # (presumably small integers); also avoids None-vs-int sort errors.
                freq_order = 9
            if freq_order == 5:
                # frequency_order 5 marks an excluded phenotype.
                continue
            term = vattrs.get("term")
            if term:
                hpo_terms.append((freq_order, term))
    # Stable sort: terms with equal frequency keep their graph order.
    hpo_terms.sort(key=lambda pair: pair[0])
    return synonyms, [term for _, term in hpo_terms]


def build_documents(
    store,
    *,
    max_embed_terms: int = 30,
    max_meta_terms: int = 15,
) -> list[dict]:
    """
    Pull every disease from the graph store and build HPO-enriched embed text.

    Embedding text format (empty sections are omitted):
        "{name}. {definition}. Clinical features: {terms}. Also known as: {synonyms}."

    HPO terms are sorted by the edge's frequency_order (most frequent first);
    excluded phenotypes (frequency_order == 5) are skipped.

    Args:
        store: graph store exposing a networkx-style ``graph`` attribute
            (``main`` passes a ``LocalGraphStore``).
        max_embed_terms: cap on HPO terms in the embedding text, to control
            token length.
        max_meta_terms: cap on HPO terms kept in the ``hpo_terms`` metadata
            string (taken from the already-capped embedding list).

    Returns:
        One dict per ``Disease`` node with keys ``id``, ``orpha_code``,
        ``name``, ``definition``, ``synonyms``, ``hpo_terms`` and ``embed_text``.

    Raises:
        KeyError: if a Disease node has no ``orpha_code`` attribute.
    """
    docs: list[dict] = []
    disease_nodes = [
        (nid, attrs)
        for nid, attrs in store.graph.nodes(data=True)
        if attrs.get("type") == "Disease"
    ]
    for nid, attrs in tqdm(disease_nodes, desc=" Building documents", unit="disease"):
        orpha_code = attrs["orpha_code"]
        # `or ""` also covers attributes that exist but are None: None would
        # break str.join below and is rejected as a Chroma metadata value.
        name = (attrs.get("name") or "").strip()
        definition = (attrs.get("definition") or "").strip()

        synonyms, hpo_term_names = _collect_disease_features(store.graph, nid)
        hpo_term_names = hpo_term_names[:max_embed_terms]

        # Build enriched text; each section is its own sentence so the name
        # does not run straight into the definition.
        parts = [_as_sentence(name), _as_sentence(definition)]
        if hpo_term_names:
            parts.append("Clinical features: " + ", ".join(hpo_term_names) + ".")
        if synonyms:
            parts.append("Also known as: " + ", ".join(synonyms) + ".")
        embed_text = " ".join(part for part in parts if part)

        docs.append({
            "id": f"ORPHA:{orpha_code}",
            "orpha_code": str(orpha_code),
            "name": name,
            "definition": definition,
            "synonyms": ", ".join(synonyms),
            "hpo_terms": ", ".join(hpo_term_names[:max_meta_terms]),  # subset for metadata
            "embed_text": embed_text,
        })
    return docs
# ---------------------------------------------------------------------------
# ChromaDB helpers
# ---------------------------------------------------------------------------
def get_chroma_client() -> tuple[chromadb.ClientAPI, str]:
    """
    Return a ChromaDB client and a human-readable backend label.

    Tries the HTTP server at CHROMA_HOST:CHROMA_PORT first and confirms it
    answers a heartbeat. On any failure it falls back to an embedded
    PersistentClient under CHROMA_PERSIST (created if missing).

    The fallback is deliberate best-effort, but the failure reason is printed
    so a wrong host/port is not silently masked by the embedded store.

    Returns:
        (client, label) where label is "ChromaDB HTTP (Docker)" or
        "ChromaDB Embedded".
    """
    try:
        client = chromadb.HttpClient(
            host=CHROMA_HOST,
            port=CHROMA_PORT,
            settings=Settings(anonymized_telemetry=False),
        )
        client.heartbeat()  # fail here if the server is not actually reachable
    except Exception as exc:  # any connection/config failure -> embedded fallback
        print(
            f" ChromaDB HTTP at {CHROMA_HOST}:{CHROMA_PORT} unavailable "
            f"({type(exc).__name__}: {exc}); falling back to embedded store."
        )
    else:
        return client, "ChromaDB HTTP (Docker)"

    CHROMA_PERSIST.mkdir(parents=True, exist_ok=True)
    client = chromadb.PersistentClient(
        path=str(CHROMA_PERSIST),
        settings=Settings(anonymized_telemetry=False),
    )
    return client, "ChromaDB Embedded"
def recreate_collection(client: chromadb.ClientAPI, name: str) -> chromadb.Collection:
    """
    Drop collection *name* if it can be deleted, then create it fresh.

    The new collection uses cosine distance for its HNSW index. Any error
    from the delete step is ignored on purpose (e.g. the collection does not
    exist yet on a first run); creation errors propagate.
    """
    try:
        client.delete_collection(name)
    except Exception:
        pass  # best effort: nothing to delete, or deletion not possible
    else:
        print(f" Deleted existing collection '{name}'.")

    collection = client.create_collection(
        name=name,
        metadata={"hnsw:space": "cosine"},
    )
    print(f" Created collection '{name}'.")
    return collection
def upsert_batches(col, docs: list[dict], embeddings, batch_size: "int | None" = None) -> None:
    """
    Upsert *docs* and their precomputed *embeddings* into *col* in batches.

    Args:
        col: Chroma collection (anything with an ``upsert(ids=, embeddings=,
            documents=, metadatas=)`` method).
        docs: documents as produced by ``build_documents``; ``embed_text`` is
            stored as the Chroma document, the other fields as metadata.
        embeddings: sequence aligned index-for-index with *docs*
            (``main`` passes ``ndarray.tolist()``).
        batch_size: documents per upsert call; defaults to ``BATCH_SIZE``.

    Raises:
        ValueError: if *batch_size* < 1 or ``len(embeddings) != len(docs)``.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(embeddings) != len(docs):
        # Both sequences are sliced independently below; a length mismatch
        # would otherwise silently drop vectors or fail obscurely inside Chroma.
        raise ValueError(
            f"got {len(embeddings)} embeddings for {len(docs)} documents"
        )

    for start in range(0, len(docs), batch_size):
        stop = start + batch_size
        batch_docs = docs[start:stop]
        col.upsert(
            ids=[d["id"] for d in batch_docs],
            embeddings=embeddings[start:stop],
            documents=[d["embed_text"] for d in batch_docs],
            metadatas=[
                {
                    "orpha_code": d["orpha_code"],
                    "name": d["name"],
                    # Truncated to 500 chars to keep metadata compact.
                    "definition": d["definition"][:500],
                    "synonyms": d["synonyms"],
                    "hpo_terms": d["hpo_terms"],
                }
                for d in batch_docs
            ],
        )
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
    """
    Rebuild the ChromaDB collection with HPO-enriched disease embeddings.

    Loads the local graph store, builds one document per disease, embeds the
    documents with EMBED_MODEL, recreates COLLECTION_NAME, upserts, and runs a
    phenotype query as a sanity check.

    Exits with a message before touching ChromaDB if no disease documents are
    built, so an existing collection is never replaced by an empty one.
    """
    print("=" * 60)
    print("RareDx — Week 2B Step 1: Re-embed with HPO-Enriched Text")
    print("=" * 60)

    # Load graph store (graph_store is imported from this script's directory).
    sys.path.insert(0, str(Path(__file__).parent))
    from graph_store import LocalGraphStore

    store = LocalGraphStore()
    print(f"\nGraph: {store.disease_count():,} diseases | "
          f"{store.hpo_term_count():,} HPO terms | "
          f"{store.manifestation_count():,} phenotype edges")

    # Build documents
    print("\nBuilding HPO-enriched documents...")
    docs = build_documents(store)
    print(f" {len(docs):,} documents ready.")
    if not docs:
        # Stop before recreate_collection() deletes the existing collection.
        # This also avoids an IndexError from the eagerly evaluated docs[0] below.
        sys.exit("No disease documents were built from the graph store; "
                 "leaving the ChromaDB collection untouched.")

    # Sample — show the enriched embedding text for one disease
    sample = next((d for d in docs if "Marfan" in d["name"]), docs[0])
    print(f"\n Sample — {sample['name']}:")
    preview = sample["embed_text"][:300]
    print(f" {preview}...")

    # Load model
    print(f"\nLoading {EMBED_MODEL}...")
    model = SentenceTransformer(EMBED_MODEL)
    print(f" Embedding dim: {model.get_sentence_embedding_dimension()}")

    # Embed (normalized vectors, matching the cosine-space collection)
    print(f"\nEmbedding {len(docs):,} documents (batch={BATCH_SIZE})...")
    texts = [d["embed_text"] for d in docs]
    embeddings = model.encode(
        texts,
        batch_size=BATCH_SIZE,
        show_progress_bar=True,
        normalize_embeddings=True,
    )
    print(f" Shape: {embeddings.shape}")

    # Store
    print("\nConnecting to ChromaDB...")
    client, backend = get_chroma_client()
    print(f" Backend: {backend}")
    col = recreate_collection(client, COLLECTION_NAME)
    print(f"Upserting {len(docs):,} documents...")
    upsert_batches(col, docs, embeddings.tolist())
    stored = col.count()
    print(f" Collection '{COLLECTION_NAME}': {stored:,} documents.")

    # Sanity check — a Marfan-like phenotype query should now hit Marfan syndrome
    probe_text = "arachnodactyly tall stature ectopia lentis aortic dilation"
    print(f"\nSanity check: '{probe_text}'")
    probe = model.encode([probe_text], normalize_embeddings=True)
    # Never request more neighbours than the collection holds.
    results = col.query(
        query_embeddings=probe.tolist(),
        n_results=max(1, min(5, stored)),
    )
    for meta, dist in zip(results["metadatas"][0], results["distances"][0]):
        sim = round(1 - dist, 4)  # collection uses cosine distance
        print(f" [{sim:.4f}] ORPHA:{meta['orpha_code']} {meta['name']}")
    print(f"\nStep 1 done — backend: {backend}")


if __name__ == "__main__":
    main()