"""
reembed_chromadb.py
-------------------
Rebuilds ChromaDB embeddings with HPO-enriched disease descriptions.

Week 1 embedding text:
    "{name}. {definition}. Also known as: {synonyms}."

Week 2B embedding text (this script):
    "{name}. {definition}. Clinical features: {hpo_terms ordered by frequency}.
    Also known as: {synonyms}."
    (At most 30 HPO terms are included; phenotypes marked as excluded are omitted.)

Adding phenotype terms directly into the embedding space means ChromaDB
can now find diseases by their symptoms, not just by name similarity.
"""
import os
import sys
from pathlib import Path
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from dotenv import load_dotenv
# Directory two levels above the one holding this script, treated as the
# project root: both the .env file and the embedded ChromaDB store live under it.
_PROJECT_ROOT = Path(__file__).parents[2]
load_dotenv(_PROJECT_ROOT / ".env")

# Connection and model settings; each can be overridden via the environment/.env.
CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8000"))
COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "rare_diseases")
EMBED_MODEL = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")

# On-disk location used by the embedded (fallback) ChromaDB client.
CHROMA_PERSIST = _PROJECT_ROOT / "data" / "chromadb"
# Chunk size shared by the sentence-transformer encode and the ChromaDB upserts.
BATCH_SIZE = 32
# ---------------------------------------------------------------------------
# Build enriched document text per disease
# ---------------------------------------------------------------------------
def _as_sentence(text) -> str:
    """Return *text* stripped and ending in terminal punctuation.

    ``None`` and blank strings become ``""`` so callers can drop them. A
    period is appended only when the text does not already end in ``.``,
    ``!`` or ``?``.
    """
    cleaned = (text or "").strip()
    if not cleaned or cleaned.endswith((".", "!", "?")):
        return cleaned
    return cleaned + "."


def _compose_embed_text(
    name: str,
    definition: str,
    hpo_term_names: list[str],
    synonyms: list[str],
) -> str:
    """Assemble the enriched embedding text for one disease.

    Layout: "{name}. {definition}. Clinical features: {terms}. Also known as:
    {synonyms}." Any empty section is omitted entirely (no stray spaces).
    """
    parts = [_as_sentence(name), _as_sentence(definition)]
    if hpo_term_names:
        parts.append("Clinical features: " + ", ".join(hpo_term_names) + ".")
    if synonyms:
        parts.append("Also known as: " + ", ".join(synonyms) + ".")
    return " ".join(part for part in parts if part)


def build_documents(
    store,
    *,
    max_embed_terms: int = 30,
    max_metadata_terms: int = 15,
) -> list[dict]:
    """
    Pull every disease from the graph store and build HPO-enriched embed text.

    HPO terms reached over ``MANIFESTS_AS`` edges are sorted by the edge's
    ``frequency_order`` (lowest first = most frequent first). Edges with
    ``frequency_order == 5`` (excluded phenotypes) are skipped. Edges with no
    ``frequency_order`` get the default rank 9.

    Args:
        store: graph store exposing ``store.graph`` with networkx-style
            ``nodes(data=True)``, ``graph[node]`` adjacency and
            ``graph.nodes[node]`` access (presumably networkx — TODO confirm).
        max_embed_terms: cap on HPO terms placed in the embedding text
            (controls token length).
        max_metadata_terms: cap on HPO terms kept in the ``hpo_terms``
            metadata string.

    Returns:
        One dict per ``Disease`` node with keys ``id``, ``orpha_code``,
        ``name``, ``definition``, ``synonyms``, ``hpo_terms`` and
        ``embed_text``.
    """
    docs = []
    disease_nodes = [
        (nid, attrs)
        for nid, attrs in store.graph.nodes(data=True)
        if attrs.get("type") == "Disease"
    ]
    for nid, attrs in tqdm(disease_nodes, desc=" Building documents", unit="disease"):
        orpha_code = attrs["orpha_code"]
        # ``or ""`` also normalizes attributes stored explicitly as None; a None
        # definition would otherwise crash the [:500] slice in upsert_batches.
        name = attrs.get("name") or ""
        definition = attrs.get("definition") or ""

        # Collect synonyms and HPO terms from graph edges
        synonyms: list[str] = []
        ranked_terms: list[tuple] = []
        for neighbor, edata in store.graph[nid].items():
            vattrs = store.graph.nodes[neighbor]
            vtype = vattrs.get("type")
            if vtype == "Synonym":
                synonyms.append(vattrs["text"])
            elif vtype == "HPOTerm" and edata.get("label") == "MANIFESTS_AS":
                freq_order = edata.get("frequency_order")
                if freq_order is None:
                    # Missing or explicit None: use the default rank 9 so the
                    # sort below never has to compare None with a number.
                    freq_order = 9
                if freq_order == 5:
                    continue  # excluded phenotype
                term = vattrs.get("term") or ""
                if term:  # blank terms would leave ", , " gaps in the text
                    ranked_terms.append((freq_order, term))

        # Most frequent first; the sort is stable, so ties keep adjacency order.
        ranked_terms.sort(key=lambda pair: pair[0])
        hpo_term_names = [term for _, term in ranked_terms[:max_embed_terms]]

        docs.append({
            "id": f"ORPHA:{orpha_code}",
            "orpha_code": str(orpha_code),
            "name": name,
            "definition": definition,
            "synonyms": ", ".join(synonyms),
            "hpo_terms": ", ".join(hpo_term_names[:max_metadata_terms]),
            "embed_text": _compose_embed_text(name, definition, hpo_term_names, synonyms),
        })
    return docs
# ---------------------------------------------------------------------------
# ChromaDB helpers
# ---------------------------------------------------------------------------
def get_chroma_client() -> tuple[chromadb.ClientAPI, str]:
    """Connect to ChromaDB, preferring the HTTP server over local storage.

    First builds an HTTP client for CHROMA_HOST:CHROMA_PORT and pings it with
    ``heartbeat()``. If either step raises, it falls back to an embedded
    ``PersistentClient`` rooted at CHROMA_PERSIST, creating that directory
    if needed.

    Returns:
        ``(client, label)``, where ``label`` names the backend actually used.
    """
    try:
        remote = chromadb.HttpClient(
            host=CHROMA_HOST,
            port=CHROMA_PORT,
            settings=Settings(anonymized_telemetry=False),
        )
        remote.heartbeat()
    except Exception:
        # Deliberately broad: any connection/handshake failure means "no server".
        CHROMA_PERSIST.mkdir(parents=True, exist_ok=True)
        # Use a fresh Settings object rather than reusing the one handed to the
        # HTTP client, which that constructor may have modified.
        local = chromadb.PersistentClient(
            path=str(CHROMA_PERSIST),
            settings=Settings(anonymized_telemetry=False),
        )
        return local, "ChromaDB Embedded"
    return remote, "ChromaDB HTTP (Docker)"
def recreate_collection(client: chromadb.ClientAPI, name: str) -> chromadb.Collection:
    """Drop collection *name* if it exists, then create it fresh (cosine HNSW).

    Deletion is best-effort, because on a first run there is nothing to
    delete. The failure is still reported rather than silently discarded. If
    deletion failed for another reason, ``create_collection`` would likely
    fail next because the collection still exists, and the printed cause
    makes that diagnosable.

    Returns:
        The newly created, empty collection.
    """
    try:
        client.delete_collection(name)
    except Exception as exc:  # exception type for "missing" varies by chromadb version
        print(f" No existing collection '{name}' deleted ({type(exc).__name__}: {exc}).")
    else:
        print(f" Deleted existing collection '{name}'.")
    col = client.create_collection(name=name, metadata={"hnsw:space": "cosine"})
    print(f" Created collection '{name}'.")
    return col
def _chroma_metadata(doc: dict) -> dict:
    """Project one built document onto the metadata fields stored in ChromaDB."""
    return {
        "orpha_code": doc["orpha_code"],
        "name": doc["name"],
        "definition": doc["definition"][:500],  # keep metadata payload small
        "synonyms": doc["synonyms"],
        "hpo_terms": doc["hpo_terms"],
    }


def upsert_batches(col, docs: list[dict], embeddings) -> None:
    """Write *docs* and their vectors into collection *col* in BATCH_SIZE chunks.

    ``embeddings`` must be index-aligned with ``docs``: row ``i`` is the
    vector for ``docs[i]``. Each record's id is ``doc["id"]``, its stored
    document is ``doc["embed_text"]``, and its metadata is the display fields
    with the definition clipped to 500 characters.
    """
    for start in range(0, len(docs), BATCH_SIZE):
        stop = start + BATCH_SIZE
        chunk = docs[start:stop]
        col.upsert(
            ids=[doc["id"] for doc in chunk],
            embeddings=embeddings[start:stop],
            documents=[doc["embed_text"] for doc in chunk],
            metadatas=[_chroma_metadata(doc) for doc in chunk],
        )
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
    """Rebuild the ChromaDB collection from HPO-enriched disease documents.

    Steps: load the local graph store, build enriched texts, embed them with
    EMBED_MODEL, recreate COLLECTION_NAME and upsert everything, then run a
    phenotype-only probe query as a sanity check.

    Exits with a non-zero status (via ``sys.exit``) when the graph store
    yields no Disease documents.
    """
    print("=" * 60)
    print("RareDx — Week 2B Step 1: Re-embed with HPO-Enriched Text")
    print("=" * 60)

    # Load graph store (graph_store is expected alongside this script).
    sys.path.insert(0, str(Path(__file__).parent))
    from graph_store import LocalGraphStore

    store = LocalGraphStore()
    print(f"\nGraph: {store.disease_count():,} diseases | "
          f"{store.hpo_term_count():,} HPO terms | "
          f"{store.manifestation_count():,} phenotype edges")

    # Build documents
    print("\nBuilding HPO-enriched documents...")
    docs = build_documents(store)
    print(f" {len(docs):,} documents ready.")
    if not docs:
        # The default argument docs[0] below is evaluated eagerly, so an empty
        # list would raise IndexError there; fail with a clear message instead.
        sys.exit("No Disease nodes found in the graph store; nothing to embed.")

    # Sample — show the enrichment difference
    sample = next((d for d in docs if "Marfan" in d["name"]), docs[0])
    print(f"\n Sample — {sample['name']}:")
    preview = sample["embed_text"][:300]
    print(f" {preview}...")

    # Load model
    print(f"\nLoading {EMBED_MODEL}...")
    model = SentenceTransformer(EMBED_MODEL)
    print(f" Embedding dim: {model.get_sentence_embedding_dimension()}")

    # Embed (normalized, matching the collection's cosine space)
    print(f"\nEmbedding {len(docs):,} documents (batch={BATCH_SIZE})...")
    texts = [d["embed_text"] for d in docs]
    embeddings = model.encode(
        texts,
        batch_size=BATCH_SIZE,
        show_progress_bar=True,
        normalize_embeddings=True,
    )
    print(f" Shape: {embeddings.shape}")

    # Store
    print("\nConnecting to ChromaDB...")
    client, backend = get_chroma_client()
    print(f" Backend: {backend}")
    col = recreate_collection(client, COLLECTION_NAME)
    print(f"Upserting {len(docs):,} documents...")
    upsert_batches(col, docs, embeddings.tolist())
    print(f" Collection '{COLLECTION_NAME}': {col.count():,} documents.")

    # Sanity check — a purely phenotypic query should now surface Marfan
    print("\nSanity check: 'arachnodactyly tall stature ectopia lentis aortic dilation'")
    probe = model.encode(
        ["arachnodactyly tall stature ectopia lentis aortic dilation"],
        normalize_embeddings=True,
    )
    results = col.query(query_embeddings=probe.tolist(), n_results=5)
    for meta, dist in zip(results["metadatas"][0], results["distances"][0]):
        sim = round(1 - dist, 4)  # cosine distance -> similarity
        print(f" [{sim:.4f}] ORPHA:{meta['orpha_code']} {meta['name']}")
    print(f"\nStep 1 done — backend: {backend}")


if __name__ == "__main__":
    main()