| """ |
| embed_chromadb.py |
| ----------------- |
| Generates BioLORD-2023 embeddings for each Orphanet disease and stores |
| them in ChromaDB. |
| |
| Primary: ChromaDB HTTP client (Docker service at localhost:8000) |
| Fallback: ChromaDB PersistentClient (embedded, no server required) |
| |
Embedding text strategy (parts joined by single spaces):
    "<name> <definition> Also known as: <syn1>, <syn2>."
| """ |
|
|
| import os |
| import sys |
| from pathlib import Path |
| from lxml import etree |
| import chromadb |
| from chromadb.config import Settings |
| from sentence_transformers import SentenceTransformer |
| from dotenv import load_dotenv |
|
|
# Load configuration from the repository-root .env (two directories above
# this file); variables already set in the environment take precedence.
load_dotenv(Path(__file__).parents[2] / ".env")


# --- Runtime configuration (all overridable via environment variables) ---
CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")  # ChromaDB HTTP server host
CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8000"))  # ChromaDB HTTP server port
COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "rare_diseases")
EMBED_MODEL = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")
XML_PATH = Path(os.getenv("ORPHANET_XML", "./data/orphanet/en_product1.xml"))


# On-disk location for the embedded PersistentClient fallback (used when the
# HTTP server is unreachable).
CHROMA_PERSIST_DIR = Path(__file__).parents[2] / "data" / "chromadb"
# Batch size shared by embedding generation and ChromaDB upserts.
BATCH_SIZE = 32
|
|
|
|
| |
| |
| |
|
|
| def _text(element, xpath: str) -> str: |
| nodes = element.xpath(xpath) |
| if nodes: |
| val = nodes[0] |
| return (val.text or "").strip() if hasattr(val, "text") else str(val).strip() |
| return "" |
|
|
|
|
def parse_disorders(xml_path: Path) -> list[dict]:
    """Parse the Orphanet product-1 XML into a list of disorder records.

    Each record holds the ORPHA id, English name, definition, synonyms, and
    the pre-assembled text used for embedding. Disorders missing an
    OrphaCode or an English name are skipped.
    """
    print(f"Parsing {xml_path} ...")
    root = etree.parse(str(xml_path)).getroot()

    records: list[dict] = []
    for node in root.xpath("//Disorder"):
        code = _text(node, "OrphaCode")
        name = _text(node, "Name[@lang='en']")
        if not code or not name:
            continue

        definition = _text(node, "TextAuto[@lang='en']")
        synonyms = [
            syn.text.strip()
            for syn in node.xpath("SynonymList/Synonym[@lang='en']")
            if syn.text and syn.text.strip()
        ]

        # Assemble the embedding text: name, then optional definition, then
        # an optional synonym sentence — joined by single spaces.
        pieces = [name]
        if definition:
            pieces.append(definition)
        if synonyms:
            pieces.append(f"Also known as: {', '.join(synonyms)}.")

        records.append({
            "id": f"ORPHA:{code}",
            "orpha_code": code,
            "name": name,
            "definition": definition,
            "synonyms": synonyms,
            "embed_text": " ".join(pieces),
        })

    print(f" Parsed {len(records)} disorders.")
    return records
|
|
|
|
| |
| |
| |
|
|
def get_chroma_client() -> tuple[chromadb.ClientAPI, str]:
    """Connect to ChromaDB, preferring the HTTP server (Docker).

    Falls back to an on-disk embedded PersistentClient when the server is
    unreachable. Returns (client, backend_label).
    """
    settings = Settings(anonymized_telemetry=False)
    try:
        http_client = chromadb.HttpClient(
            host=CHROMA_HOST,
            port=CHROMA_PORT,
            settings=settings,
        )
        # A heartbeat failure means the server is down — trigger the fallback.
        http_client.heartbeat()
    except Exception as exc:
        print(f" ChromaDB HTTP not reachable ({exc}).")
    else:
        print(" ChromaDB HTTP server connected.")
        return http_client, "ChromaDB HTTP (Docker)"

    print(f" Using embedded PersistentClient at {CHROMA_PERSIST_DIR}")
    CHROMA_PERSIST_DIR.mkdir(parents=True, exist_ok=True)
    local_client = chromadb.PersistentClient(
        path=str(CHROMA_PERSIST_DIR),
        settings=settings,
    )
    return local_client, "ChromaDB Embedded (local)"
|
|
|
|
def get_or_create_collection(client: chromadb.ClientAPI, name: str) -> chromadb.Collection:
    """Drop any existing collection with this name and create a fresh one.

    The delete is best-effort: it raises when the collection does not exist
    yet, which is expected on a first run.
    """
    try:
        client.delete_collection(name)
    except Exception:
        pass
    else:
        print(f" Deleted existing collection '{name}'.")
    fresh = client.create_collection(
        name=name,
        # Cosine distance — pairs with the normalized embeddings we generate.
        metadata={"hnsw:space": "cosine"},
    )
    print(f" Created collection '{name}'.")
    return fresh
|
|
|
|
def upsert_in_batches(
    collection: chromadb.Collection,
    disorders: list[dict],
    embeddings: list[list[float]],
) -> None:
    """Upsert (id, embedding, document, metadata) rows in BATCH_SIZE chunks.

    ``disorders`` and ``embeddings`` must be parallel lists. Prints a
    carriage-return progress line per batch.
    """
    total = len(disorders)
    for start in range(0, total, BATCH_SIZE):
        stop = start + BATCH_SIZE
        batch = disorders[start:stop]
        collection.upsert(
            ids=[row["id"] for row in batch],
            embeddings=embeddings[start:stop],
            documents=[row["embed_text"] for row in batch],
            metadatas=[
                {
                    "orpha_code": row["orpha_code"],
                    "name": row["name"],
                    # Truncate long definitions so metadata stays compact.
                    "definition": row["definition"][:500] if row["definition"] else "",
                    "synonyms": ", ".join(row["synonyms"]),
                }
                for row in batch
            ],
        )
        print(f" Upserted {min(stop, total)} / {total} ...", end="\r")
    print()
|
|
|
|
| |
| |
| |
|
|
def main() -> None:
    """Pipeline entry point: parse Orphanet XML, embed every disorder with
    the sentence-transformer model, upsert the vectors into ChromaDB, then
    run a sanity-check semantic query.

    Exits with status 1 if the Orphanet XML has not been downloaded yet.
    """
    print("=" * 60)
    print("RareDx — Step 3: Embed Diseases into ChromaDB (BioLORD-2023)")
    print("=" * 60)

    # Fail fast if the upstream download step has not been run.
    if not XML_PATH.exists():
        print(f"ERROR: XML not found at {XML_PATH}. Run download_orphanet.py first.")
        sys.exit(1)

    disorders = parse_disorders(XML_PATH)

    # Load the embedding model (downloaded from HuggingFace on first use).
    print(f"\nLoading embedding model: {EMBED_MODEL}")
    print(" (First run will download ~440 MB from HuggingFace — please wait.)")
    model = SentenceTransformer(EMBED_MODEL)
    dim = model.get_sentence_embedding_dimension()
    print(f" Model loaded. Embedding dim: {dim}")

    # Embed all disorder texts in one batched pass. Normalized vectors pair
    # with the cosine space configured on the collection.
    print(f"\nGenerating embeddings for {len(disorders)} diseases...")
    texts = [d["embed_text"] for d in disorders]
    embeddings = model.encode(
        texts,
        batch_size=BATCH_SIZE,
        show_progress_bar=True,
        normalize_embeddings=True,
    )
    print(f" Embeddings shape: {embeddings.shape}")

    # Connect (HTTP server first, embedded fallback) and rebuild the collection.
    print("\nConnecting to ChromaDB...")
    chroma, backend_label = get_chroma_client()
    collection = get_or_create_collection(chroma, COLLECTION_NAME)

    print(f"\nUpserting {len(disorders)} documents...")
    upsert_in_batches(collection, disorders, embeddings.tolist())

    final_count = collection.count()
    print(f" Collection '{COLLECTION_NAME}' has {final_count} documents.")

    # Sanity check: a probe query should surface plausible disease matches.
    print("\nSanity check: semantic search for 'connective tissue disorder'")
    probe = model.encode(["connective tissue disorder"], normalize_embeddings=True)
    results = collection.query(query_embeddings=probe.tolist(), n_results=3)
    for meta in results["metadatas"][0]:
        print(f" -> [{meta['orpha_code']}] {meta['name']}")

    print(f"\nStep 3 complete — backend: {backend_label}")
|
|
|
|
# Script entry point — run the full embed pipeline only when executed
# directly, not when imported.
if __name__ == "__main__":
    main()
|
|