# raredx/backend/scripts/embed_chromadb.py
# Uploaded to the Hugging Face Hub by Aswin92 ("Upload folder using
# huggingface_hub", commit 89c6379, verified).
"""
embed_chromadb.py
-----------------
Generates BioLORD-2023 embeddings for each Orphanet disease and stores
them in ChromaDB.
Primary: ChromaDB HTTP client (Docker service at localhost:8000)
Fallback: ChromaDB PersistentClient (embedded, no server required)
Embedding text strategy:
"<name>. <definition>. Also known as: <syn1>, <syn2>, ..."
"""
import os
import sys
from pathlib import Path
from lxml import etree
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
load_dotenv(Path(__file__).parents[2] / ".env")
CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8000"))
COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "rare_diseases")
EMBED_MODEL = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")
XML_PATH = Path(os.getenv("ORPHANET_XML", "./data/orphanet/en_product1.xml"))
CHROMA_PERSIST_DIR = Path(__file__).parents[2] / "data" / "chromadb"
BATCH_SIZE = 32
# ---------------------------------------------------------------------------
# XML parsing
# ---------------------------------------------------------------------------
def _text(element, xpath: str) -> str:
nodes = element.xpath(xpath)
if nodes:
val = nodes[0]
return (val.text or "").strip() if hasattr(val, "text") else str(val).strip()
return ""
def parse_disorders(xml_path: Path) -> list[dict]:
    """Parse the Orphanet product1 XML into a list of disease records.

    Each record carries the ORPHA id, English name, definition, synonyms,
    and the pre-built text used for embedding. Disorders lacking either an
    OrphaCode or an English name are skipped entirely.
    """
    print(f"Parsing {xml_path} ...")
    root = etree.parse(str(xml_path)).getroot()
    records: list[dict] = []
    for node in root.xpath("//Disorder"):
        code = _text(node, "OrphaCode")
        name = _text(node, "Name[@lang='en']")
        # Both identifiers are mandatory; skip incomplete entries.
        if not code or not name:
            continue
        definition = _text(node, "TextAuto[@lang='en']")
        synonyms = [
            syn.text.strip()
            for syn in node.xpath("SynonymList/Synonym[@lang='en']")
            if syn.text and syn.text.strip()
        ]
        # Embedding text: name, optional definition, optional synonym list.
        pieces = [name]
        if definition:
            pieces.append(definition)
        if synonyms:
            pieces.append(f"Also known as: {', '.join(synonyms)}.")
        records.append({
            "id": f"ORPHA:{code}",
            "orpha_code": code,
            "name": name,
            "definition": definition,
            "synonyms": synonyms,
            "embed_text": " ".join(pieces),
        })
    print(f" Parsed {len(records)} disorders.")
    return records
# ---------------------------------------------------------------------------
# ChromaDB client — HTTP first, persistent fallback
# ---------------------------------------------------------------------------
def get_chroma_client() -> tuple[chromadb.ClientAPI, str]:
    """
    Try HTTP client (Docker). On failure, fall back to embedded PersistentClient.
    Returns (client, backend_label).
    """
    try:
        http_client = chromadb.HttpClient(
            host=CHROMA_HOST,
            port=CHROMA_PORT,
            settings=Settings(anonymized_telemetry=False),
        )
        # heartbeat() raises if no server is listening — our liveness probe.
        http_client.heartbeat()
        print(" ChromaDB HTTP server connected.")
        return http_client, "ChromaDB HTTP (Docker)"
    except Exception as exc:
        print(f" ChromaDB HTTP not reachable ({exc}).")
        print(f" Using embedded PersistentClient at {CHROMA_PERSIST_DIR}")
    # Fallback path: embedded client persisting to the local data directory.
    CHROMA_PERSIST_DIR.mkdir(parents=True, exist_ok=True)
    embedded = chromadb.PersistentClient(
        path=str(CHROMA_PERSIST_DIR),
        settings=Settings(anonymized_telemetry=False),
    )
    return embedded, "ChromaDB Embedded (local)"
def get_or_create_collection(client: chromadb.ClientAPI, name: str) -> chromadb.Collection:
    """Drop any existing collection called *name*, then create a fresh one.

    NOTE(review): despite the name, this always rebuilds the collection so
    stale embeddings never linger between runs. Cosine distance is selected
    via the HNSW metadata key.
    """
    try:
        client.delete_collection(name)
        print(f" Deleted existing collection '{name}'.")
    except Exception:
        # delete_collection raises when the collection does not exist yet —
        # that is the normal first-run case, so it is deliberately ignored.
        pass
    fresh = client.create_collection(
        name=name,
        metadata={"hnsw:space": "cosine"},
    )
    print(f" Created collection '{name}'.")
    return fresh
def upsert_in_batches(
    collection: chromadb.Collection,
    disorders: list[dict],
    embeddings: list[list[float]],
) -> None:
    """Upsert documents, embeddings and metadata in BATCH_SIZE-sized chunks.

    Metadata definitions are truncated to 500 characters to keep entries
    small; synonyms are flattened to a comma-separated string because
    ChromaDB metadata values must be scalars.
    """
    total = len(disorders)
    for start in range(0, total, BATCH_SIZE):
        stop = start + BATCH_SIZE
        chunk = disorders[start:stop]
        metas = []
        for rec in chunk:
            metas.append({
                "orpha_code": rec["orpha_code"],
                "name": rec["name"],
                "definition": rec["definition"][:500] if rec["definition"] else "",
                "synonyms": ", ".join(rec["synonyms"]),
            })
        collection.upsert(
            ids=[rec["id"] for rec in chunk],
            embeddings=embeddings[start:stop],
            documents=[rec["embed_text"] for rec in chunk],
            metadatas=metas,
        )
        print(f" Upserted {min(stop, total)} / {total} ...", end="\r")
    print()
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
    """End-to-end step 3: parse Orphanet XML, embed, store, sanity-probe."""
    banner = "=" * 60
    print(banner)
    print("RareDx — Step 3: Embed Diseases into ChromaDB (BioLORD-2023)")
    print(banner)

    # Guard: the XML must already have been downloaded by the earlier step.
    if not XML_PATH.exists():
        print(f"ERROR: XML not found at {XML_PATH}. Run download_orphanet.py first.")
        sys.exit(1)

    disorders = parse_disorders(XML_PATH)

    # Load BioLORD-2023
    print(f"\nLoading embedding model: {EMBED_MODEL}")
    print(" (First run will download ~440 MB from HuggingFace — please wait.)")
    encoder = SentenceTransformer(EMBED_MODEL)
    print(f" Model loaded. Embedding dim: {encoder.get_sentence_embedding_dimension()}")

    # Generate normalized embeddings (cosine space expects unit vectors).
    print(f"\nGenerating embeddings for {len(disorders)} diseases...")
    vectors = encoder.encode(
        [rec["embed_text"] for rec in disorders],
        batch_size=BATCH_SIZE,
        show_progress_bar=True,
        normalize_embeddings=True,
    )
    print(f" Embeddings shape: {vectors.shape}")

    # Connect to ChromaDB and rebuild the collection.
    print("\nConnecting to ChromaDB...")
    client, backend_label = get_chroma_client()
    collection = get_or_create_collection(client, COLLECTION_NAME)

    print(f"\nUpserting {len(disorders)} documents...")
    upsert_in_batches(collection, disorders, vectors.tolist())
    print(f" Collection '{COLLECTION_NAME}' has {collection.count()} documents.")

    # Sanity check: a semantic probe should surface plausible diseases.
    print("\nSanity check: semantic search for 'connective tissue disorder'")
    probe = encoder.encode(["connective tissue disorder"], normalize_embeddings=True)
    hits = collection.query(query_embeddings=probe.tolist(), n_results=3)
    for meta in hits["metadatas"][0]:
        print(f" -> [{meta['orpha_code']}] {meta['name']}")

    print(f"\nStep 3 complete — backend: {backend_label}")
if __name__ == "__main__":
main()