""" embed_chromadb.py ----------------- Generates BioLORD-2023 embeddings for each Orphanet disease and stores them in ChromaDB. Primary: ChromaDB HTTP client (Docker service at localhost:8000) Fallback: ChromaDB PersistentClient (embedded, no server required) Embedding text strategy: ". . Also known as: , , ..." """ import os import sys from pathlib import Path from lxml import etree import chromadb from chromadb.config import Settings from sentence_transformers import SentenceTransformer from dotenv import load_dotenv load_dotenv(Path(__file__).parents[2] / ".env") CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost") CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8000")) COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "rare_diseases") EMBED_MODEL = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023") XML_PATH = Path(os.getenv("ORPHANET_XML", "./data/orphanet/en_product1.xml")) CHROMA_PERSIST_DIR = Path(__file__).parents[2] / "data" / "chromadb" BATCH_SIZE = 32 # --------------------------------------------------------------------------- # XML parsing # --------------------------------------------------------------------------- def _text(element, xpath: str) -> str: nodes = element.xpath(xpath) if nodes: val = nodes[0] return (val.text or "").strip() if hasattr(val, "text") else str(val).strip() return "" def parse_disorders(xml_path: Path) -> list[dict]: print(f"Parsing {xml_path} ...") tree = etree.parse(str(xml_path)) root = tree.getroot() disorders = [] for disorder in root.xpath("//Disorder"): orpha_code = _text(disorder, "OrphaCode") name = _text(disorder, "Name[@lang='en']") definition = _text(disorder, "TextAuto[@lang='en']") synonyms = [ s.text.strip() for s in disorder.xpath("SynonymList/Synonym[@lang='en']") if s.text and s.text.strip() ] if not orpha_code or not name: continue parts = [name] if definition: parts.append(definition) if synonyms: parts.append(f"Also known as: {', '.join(synonyms)}.") embed_text = " ".join(parts) disorders.append({ "id": f"ORPHA:{orpha_code}", "orpha_code": orpha_code, "name": name, "definition": definition, "synonyms": synonyms, "embed_text": embed_text, }) print(f" Parsed {len(disorders)} disorders.") return disorders # --------------------------------------------------------------------------- # ChromaDB client — HTTP first, persistent fallback # --------------------------------------------------------------------------- def get_chroma_client() -> tuple[chromadb.ClientAPI, str]: """ Try HTTP client (Docker). On failure, fall back to embedded PersistentClient. Returns (client, backend_label). """ try: client = chromadb.HttpClient( host=CHROMA_HOST, port=CHROMA_PORT, settings=Settings(anonymized_telemetry=False), ) client.heartbeat() print(" ChromaDB HTTP server connected.") return client, "ChromaDB HTTP (Docker)" except Exception as exc: print(f" ChromaDB HTTP not reachable ({exc}).") print(f" Using embedded PersistentClient at {CHROMA_PERSIST_DIR}") CHROMA_PERSIST_DIR.mkdir(parents=True, exist_ok=True) client = chromadb.PersistentClient( path=str(CHROMA_PERSIST_DIR), settings=Settings(anonymized_telemetry=False), ) return client, "ChromaDB Embedded (local)" def get_or_create_collection(client: chromadb.ClientAPI, name: str) -> chromadb.Collection: try: client.delete_collection(name) print(f" Deleted existing collection '{name}'.") except Exception: pass collection = client.create_collection( name=name, metadata={"hnsw:space": "cosine"}, ) print(f" Created collection '{name}'.") return collection def upsert_in_batches( collection: chromadb.Collection, disorders: list[dict], embeddings: list[list[float]], ) -> None: for i in range(0, len(disorders), BATCH_SIZE): bd = disorders[i : i + BATCH_SIZE] be = embeddings[i : i + BATCH_SIZE] collection.upsert( ids=[d["id"] for d in bd], embeddings=be, documents=[d["embed_text"] for d in bd], metadatas=[ { "orpha_code": d["orpha_code"], "name": d["name"], "definition": d["definition"][:500] if d["definition"] else "", "synonyms": ", ".join(d["synonyms"]), } for d in bd ], ) print(f" Upserted {min(i + BATCH_SIZE, len(disorders))} / {len(disorders)} ...", end="\r") print() # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main() -> None: print("=" * 60) print("RareDx — Step 3: Embed Diseases into ChromaDB (BioLORD-2023)") print("=" * 60) if not XML_PATH.exists(): print(f"ERROR: XML not found at {XML_PATH}. Run download_orphanet.py first.") sys.exit(1) disorders = parse_disorders(XML_PATH) # Load BioLORD-2023 print(f"\nLoading embedding model: {EMBED_MODEL}") print(" (First run will download ~440 MB from HuggingFace — please wait.)") model = SentenceTransformer(EMBED_MODEL) dim = model.get_sentence_embedding_dimension() print(f" Model loaded. Embedding dim: {dim}") # Generate embeddings print(f"\nGenerating embeddings for {len(disorders)} diseases...") texts = [d["embed_text"] for d in disorders] embeddings = model.encode( texts, batch_size=BATCH_SIZE, show_progress_bar=True, normalize_embeddings=True, ) print(f" Embeddings shape: {embeddings.shape}") # Connect to ChromaDB print("\nConnecting to ChromaDB...") chroma, backend_label = get_chroma_client() collection = get_or_create_collection(chroma, COLLECTION_NAME) print(f"\nUpserting {len(disorders)} documents...") upsert_in_batches(collection, disorders, embeddings.tolist()) final_count = collection.count() print(f" Collection '{COLLECTION_NAME}' has {final_count} documents.") # Sanity check print("\nSanity check: semantic search for 'connective tissue disorder'") probe = model.encode(["connective tissue disorder"], normalize_embeddings=True) results = collection.query(query_embeddings=probe.tolist(), n_results=3) for meta in results["metadatas"][0]: print(f" -> [{meta['orpha_code']}] {meta['name']}") print(f"\nStep 3 complete — backend: {backend_label}") if __name__ == "__main__": main()