# raredx/backend/scripts/symptom_parser.py
# (uploaded via huggingface_hub; revision 89c6379, verified)
"""
symptom_parser.py
-----------------
Maps free-text clinical symptoms to HPO term IDs using BioLORD-2023
semantic similarity — no string matching, no exact-name lookup.
Algorithm:
1. Build an HPO embedding index: embed all 8,701 HPO terms with BioLORD.
2. Segment the clinical note into candidate phrases.
3. Embed each phrase and find the nearest HPO term by cosine similarity.
4. Return matches above a confidence threshold.
The index is cached to disk so it only needs to be built once.
Can be used as a module (SymptomParser class) or as a CLI:
python symptom_parser.py "tall stature, displaced lens, heart murmur"
"""
import io
import json
import sys
import re
from dataclasses import dataclass
from pathlib import Path
import numpy as np
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
# Load the repo-root .env (two directories above this script) so settings
# such as EMBED_MODEL are available via os.getenv.
load_dotenv(Path(__file__).parents[2] / ".env")

# On-disk cache for the HPO embedding index (built once, then reused).
INDEX_DIR = Path(__file__).parents[2] / "data" / "hpo_index"
EMBED_FILE = INDEX_DIR / "embeddings.npy"  # float32 matrix, one row per HPO term
TERMS_FILE = INDEX_DIR / "terms.json"  # parallel list of {"hpo_id": ..., "term": ...}

# Multi-word phrase threshold — catches paraphrases well.
DEFAULT_THRESHOLD = 0.55
# Single-word threshold — higher because a single word has no context;
# only exact or near-exact HPO terms (e.g. "scoliosis" → 0.95) should pass.
SINGLE_WORD_THRESHOLD = 0.82
@dataclass
class HPOMatch:
    """One phrase-to-HPO-term match produced by SymptomParser.parse()."""
    phrase: str  # source phrase segmented out of the clinical note
    hpo_id: str  # identifier of the matched HPO term
    term: str  # human-readable name of the matched HPO term
    score: float  # cosine similarity, rounded to 4 decimal places
# ---------------------------------------------------------------------------
# Index build / load
# ---------------------------------------------------------------------------
def build_hpo_index(model: SentenceTransformer) -> tuple[np.ndarray, list[dict]]:
    """
    Embed all HPOTerm nodes from the graph store and cache the result.

    Args:
        model: sentence-transformers model used to embed the term names.

    Returns:
        Tuple of (embeddings [N, D] float32, terms [{"hpo_id": ..., "term": ...}]).

    Raises:
        RuntimeError: if the graph store contains no HPOTerm nodes.

    Side effects:
        Writes EMBED_FILE and TERMS_FILE under INDEX_DIR.
    """
    # Make the sibling graph_store module importable; guard so repeated
    # calls don't keep prepending duplicate entries to sys.path.
    scripts_dir = str(Path(__file__).parent)
    if scripts_dir not in sys.path:
        sys.path.insert(0, scripts_dir)
    from graph_store import LocalGraphStore
    store = LocalGraphStore()
    terms = [
        {"hpo_id": attrs["hpo_id"], "term": attrs["term"]}
        for _, attrs in store.graph.nodes(data=True)
        if attrs.get("type") == "HPOTerm"
    ]
    if not terms:
        raise RuntimeError("No HPOTerm nodes in graph store. Run ingest_hpo.py first.")
    print(f" Building HPO index for {len(terms):,} terms...")
    texts = [t["term"] for t in terms]
    # Convert to float32 once: the original converted twice (one copy for
    # np.save, a second for the return value).
    embeddings = model.encode(
        texts,
        batch_size=128,
        show_progress_bar=True,
        normalize_embeddings=True,  # unit vectors → dot product == cosine
    ).astype(np.float32)
    INDEX_DIR.mkdir(parents=True, exist_ok=True)
    np.save(str(EMBED_FILE), embeddings)
    TERMS_FILE.write_text(json.dumps(terms, ensure_ascii=False), encoding="utf-8")
    print(f" Index saved to {INDEX_DIR}")
    return embeddings, terms
def load_hpo_index(model: SentenceTransformer, force_rebuild: bool = False):
    """
    Load the cached HPO index, rebuilding it when missing, stale, or forced.

    Args:
        model: embedding model, used only when a rebuild is needed.
        force_rebuild: when True, always rebuild and overwrite the cache.

    Returns:
        Tuple of (embeddings [N, D], terms [{"hpo_id": ..., "term": ...}]).
    """
    if not force_rebuild and EMBED_FILE.exists() and TERMS_FILE.exists():
        embeddings = np.load(str(EMBED_FILE))
        terms = json.loads(TERMS_FILE.read_text(encoding="utf-8"))
        # Staleness check (the original promised this but never did it):
        # the two cache files must agree on the number of terms, otherwise
        # argmax indices into `terms` would be silently wrong.
        if embeddings.shape[0] == len(terms):
            return embeddings, terms
        print(" Cached HPO index is inconsistent; rebuilding...")
    return build_hpo_index(model)
# ---------------------------------------------------------------------------
# Note segmentation
# ---------------------------------------------------------------------------
# Clinical notes typically list symptoms as comma-separated phrases,
# sometimes separated by semicolons, periods, or conjunctions.
_SPLIT_RE = re.compile(r"[,;]|\band\b|\bwith\b|\bplus\b", re.IGNORECASE)

# One token that is almost certainly not a symptom (demographics, clinical
# filler, stop words).  Single-word symptoms like "scoliosis" must NOT match.
_SKIP_TOKEN = (
    r"(?:"
    r"\d+[\s-]*(?:year|month|week|day|yr|mo)s?[\s-]*(?:old)?"  # age
    r"|male|female|man|woman|boy|girl"                         # sex/gender
    r"|patient|presents?|has|have|had|history|noted"           # clinical filler
    r"|found|showing|revealed|demonstrated"                    # more filler
    r"|with|and|the|a|an|of|in|on|at|to|by"                    # stop words
    r"|left|right|bilateral|unilateral"                        # laterality alone
    r")"
)
# A phrase is skipped when it consists ONLY of skip tokens.  Matching token
# *runs* — not just a single token, as before — also drops combined
# demographics such as "18 year old male" or "patient presents", which
# previously leaked through and were embedded as symptom candidates.
_SKIP_RE = re.compile(rf"^\s*(?:{_SKIP_TOKEN}\s*)+$", re.IGNORECASE)


def segment_note(note: str) -> list[str]:
    """
    Split a clinical note into candidate symptom phrases.

    Splits on commas/semicolons and the conjunctions "and"/"with"/"plus",
    strips whitespace and trailing periods, and drops phrases made up
    entirely of demographic/filler tokens (including multi-token runs like
    "18 year old male").  Single words are allowed through but are held to
    a higher similarity threshold in SymptomParser.parse().

    Args:
        note: free-text clinical note.

    Returns:
        Ordered list of non-empty candidate phrases.
    """
    phrases = []
    for raw in _SPLIT_RE.split(note):
        cleaned = raw.strip().rstrip(".")
        if cleaned and not _SKIP_RE.match(cleaned):
            phrases.append(cleaned)
    return phrases
# ---------------------------------------------------------------------------
# SymptomParser
# ---------------------------------------------------------------------------
class SymptomParser:
    """
    Maps free-text clinical notes to HPO term matches using BioLORD embeddings.

    Usage:
        parser = SymptomParser(model)
        matches = parser.parse("tall stature, displaced lens, heart murmur")
    """

    def __init__(
        self,
        model: SentenceTransformer,
        threshold: float = DEFAULT_THRESHOLD,
        force_rebuild: bool = False,
    ) -> None:
        self.model = model
        self.threshold = threshold
        print("Loading HPO embedding index...")
        self.embeddings, self.terms = load_hpo_index(model, force_rebuild)
        n_terms = len(self.terms)
        dim = self.embeddings.shape[1]
        print(f" Index ready: {n_terms:,} HPO terms, "
              f"dim={dim}")

    def parse(self, clinical_note: str) -> list[HPOMatch]:
        """
        Parse a clinical note and return HPO matches above threshold.
        Deduplicates by HPO ID (keeps highest-scoring match per term).
        """
        candidates = segment_note(clinical_note)
        if not candidates:
            return []
        # One batched embedding call for every candidate phrase: (P, D).
        phrase_vecs = self.model.encode(
            candidates,
            normalize_embeddings=True,
            show_progress_bar=False,
        )
        # Both sides are unit-normalized, so the dot product IS the cosine
        # similarity against the whole HPO index: (P, N).
        similarity = phrase_vecs @ self.embeddings.T
        top_cols = similarity.argmax(axis=1)

        best_by_id: dict[str, HPOMatch] = {}
        for row, (phrase, col) in enumerate(zip(candidates, top_cols)):
            score = float(similarity[row, col])
            # A lone word carries no context, so it must clear the stricter bar.
            multi_word = len(phrase.split()) > 1
            cutoff = self.threshold if multi_word else SINGLE_WORD_THRESHOLD
            if score < cutoff:
                continue
            entry = self.terms[col]
            candidate = HPOMatch(
                phrase=phrase,
                hpo_id=entry["hpo_id"],
                term=entry["term"],
                score=round(score, 4),
            )
            # Retain only the strongest phrase per HPO ID.
            current = best_by_id.get(candidate.hpo_id)
            if current is None or current.score < candidate.score:
                best_by_id[candidate.hpo_id] = candidate

        return sorted(best_by_id.values(), key=lambda m: m.score, reverse=True)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main() -> None:
    """CLI entry point: parse a note from argv (or a demo note) and print matches."""
    # Re-wrap stdout so non-ASCII output never raises on narrow consoles.
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
    import os
    model_name = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")

    if len(sys.argv) > 1:
        note = " ".join(sys.argv[1:])
    else:
        note = (
            "18 year old male, extremely tall, displaced lens in left eye, "
            "heart murmur, flexible joints, scoliosis"
        )

    banner = "=" * 60
    print(banner)
    print("RareDx Symptom Parser — HPO Semantic Matching")
    print(banner)
    print(f"\nInput: {note}\n")

    parser = SymptomParser(SentenceTransformer(model_name))
    matches = parser.parse(note)

    print(f"\nMatched {len(matches)} HPO terms:\n")
    print(f" {'Score':>6} {'HPO ID':<12} {'Term':<40} Phrase")
    print(f" {'-'*6} {'-'*12} {'-'*40} {'-'*30}")
    for m in matches:
        print(f" {m.score:>6.4f} {m.hpo_id:<12} {m.term:<40} \"{m.phrase}\"")


if __name__ == "__main__":
    main()