"""
symptom_parser.py
-----------------
Maps free-text clinical symptoms to HPO term IDs using BioLORD-2023
semantic similarity — no string matching, no exact-name lookup.

Algorithm:
  1. Build an HPO embedding index: embed all 8,701 HPO terms with BioLORD.
  2. Segment the clinical note into candidate phrases.
  3. Embed each phrase and find the nearest HPO term by cosine similarity.
  4. Return matches above a confidence threshold.

The index is cached to disk so it only needs to be built once.

Can be used as a module (SymptomParser class) or as a CLI:
  python symptom_parser.py "tall stature, displaced lens, heart murmur"
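
Example (module usage; the model name matches the CLI default below):
  from sentence_transformers import SentenceTransformer
  model = SentenceTransformer("FremyCompany/BioLORD-2023")
  parser = SymptomParser(model)
  for m in parser.parse("tall stature, displaced lens, heart murmur"):
      print(m.hpo_id, m.term, m.score)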
"""

import io
import json
import os
import re
import sys
from dataclasses import dataclass
from pathlib import Path

import numpy as np
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv

load_dotenv(Path(__file__).parents[2] / ".env")

INDEX_DIR  = Path(__file__).parents[2] / "data" / "hpo_index"
EMBED_FILE = INDEX_DIR / "embeddings.npy"
TERMS_FILE = INDEX_DIR / "terms.json"

# Multi-word phrase threshold — catches paraphrases well.
DEFAULT_THRESHOLD = 0.55

# Single-word threshold — higher because a single word has no context;
# only exact or near-exact HPO terms (e.g. "scoliosis" → 0.95) should pass.
SINGLE_WORD_THRESHOLD = 0.82


@dataclass
class HPOMatch:
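    """One segmented phrase mapped to its nearest HPO term, with cosine score."""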
    phrase:    str
    hpo_id:    str
    term:      str
    score:     float


# ---------------------------------------------------------------------------
# Index build / load
# ---------------------------------------------------------------------------

def build_hpo_index(model: SentenceTransformer) -> tuple[np.ndarray, list[dict]]:
    """
    Embed all HPOTerm nodes from the graph store.
    Returns (embeddings [N, D], terms [{"hpo_id": ..., "term": ...}]).
    """
    sys.path.insert(0, str(Path(__file__).parent))
    from graph_store import LocalGraphStore

    store = LocalGraphStore()
    terms = [
        {"hpo_id": attrs["hpo_id"], "term": attrs["term"]}
        for _, attrs in store.graph.nodes(data=True)
        if attrs.get("type") == "HPOTerm"
    ]

    if not terms:
        raise RuntimeError("No HPOTerm nodes in graph store. Run ingest_hpo.py first.")

    print(f"  Building HPO index for {len(terms):,} terms...")
    texts = [t["term"] for t in terms]
    embeddings = model.encode(
        texts,
        batch_size=128,
        show_progress_bar=True,
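        # Unit-normalize now so the dot products in SymptomParser.parse()
        # are cosine similarities.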
        normalize_embeddings=True,
    )

    INDEX_DIR.mkdir(parents=True, exist_ok=True)
    np.save(str(EMBED_FILE), embeddings.astype(np.float32))
    TERMS_FILE.write_text(json.dumps(terms, ensure_ascii=False), encoding="utf-8")
    print(f"  Index saved to {INDEX_DIR}")

    return embeddings.astype(np.float32), terms


def load_hpo_index(model: SentenceTransformer, force_rebuild: bool = False):
    """Load cached index or build it if missing / stale."""
    if not force_rebuild and EMBED_FILE.exists() and TERMS_FILE.exists():
        embeddings = np.load(str(EMBED_FILE))
        terms = json.loads(TERMS_FILE.read_text(encoding="utf-8"))
        return embeddings, terms

    return build_hpo_index(model)


# ---------------------------------------------------------------------------
# Note segmentation
# ---------------------------------------------------------------------------

# Clinical notes typically list symptoms as comma- or semicolon-separated
# phrases, often joined by conjunctions ("and", "with", "plus").
_SPLIT_RE = re.compile(r"[,;]|\band\b|\bwith\b|\bplus\b", re.IGNORECASE)
# Tokens that are almost certainly not symptoms (demographics, filler words).
# Single-word symptoms like "scoliosis" must NOT match this.
_SKIP_RE = re.compile(
    r"^\s*("
    r"\d+[\s-]*(year|month|week|day|yr|mo)s?[\s-]*(old)?"  # age
    r"|male|female|man|woman|boy|girl"                       # sex/gender
    r"|patient|presents?|has|have|had|history|noted"         # clinical filler
    r"|found|showing|revealed|demonstrated"                  # more filler
    r"|with|and|the|a|an|of|in|on|at|to|by"                 # stop words
    r"|left|right|bilateral|unilateral"                      # laterality alone
    r")\s*$",
    re.IGNORECASE,
)


def segment_note(note: str) -> list[str]:
    """
    Split a clinical note into candidate symptom phrases.

    Single words are allowed through but are held to a higher BioLORD
    similarity threshold in SymptomParser.parse().
    Demographic / filler tokens are still stripped by _SKIP_RE.
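
    Example:
        >>> segment_note("tall stature, displaced lens and heart murmur")
        ['tall stature', 'displaced lens', 'heart murmur']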
    """
    raw_phrases = _SPLIT_RE.split(note)
    phrases = []
    for p in raw_phrases:
        p = p.strip().rstrip(".")
        if not p or _SKIP_RE.match(p):
            continue
        phrases.append(p)
    return phrases


# ---------------------------------------------------------------------------
# SymptomParser
# ---------------------------------------------------------------------------

class SymptomParser:
    """
    Maps free-text clinical notes to HPO term matches using BioLORD embeddings.

    Usage:
        parser = SymptomParser(model)
        matches = parser.parse("tall stature, displaced lens, heart murmur")
    """

    def __init__(
        self,
        model: SentenceTransformer,
        threshold: float = DEFAULT_THRESHOLD,
        force_rebuild: bool = False,
    ) -> None:
        self.model     = model
        self.threshold = threshold
        print("Loading HPO embedding index...")
        self.embeddings, self.terms = load_hpo_index(model, force_rebuild)
        print(f"  Index ready: {len(self.terms):,} HPO terms, "
              f"dim={self.embeddings.shape[1]}")

    def parse(self, clinical_note: str) -> list[HPOMatch]:
        """
        Parse a clinical note and return HPO matches above threshold.
        Deduplicates by HPO ID (keeps highest-scoring match per term).
        """
        phrases = segment_note(clinical_note)
        if not phrases:
            return []

        # Embed all phrases in one batch
        phrase_embs = self.model.encode(
            phrases,
            normalize_embeddings=True,
            show_progress_bar=False,
        )  # (P, D)

        # Cosine similarity against entire HPO index: (P, N)
        sims = phrase_embs @ self.embeddings.T  # normalized, so dot = cosine

        # For each phrase pick the best HPO term
        best_indices = np.argmax(sims, axis=1)
        best_scores  = sims[np.arange(len(phrases)), best_indices]

        # Collect matches above threshold.
        # Single-word phrases need a stricter threshold to avoid false positives.
        seen_hpo: dict[str, HPOMatch] = {}
        for phrase, idx, score in zip(phrases, best_indices, best_scores):
            is_single_word = len(phrase.split()) == 1
            cutoff = SINGLE_WORD_THRESHOLD if is_single_word else self.threshold
            if float(score) < cutoff:
                continue
            t = self.terms[idx]
            hpo_id = t["hpo_id"]
            match  = HPOMatch(
                phrase=phrase,
                hpo_id=hpo_id,
                term=t["term"],
                score=round(float(score), 4),
            )
            # Keep the highest-scoring phrase for each HPO ID
            if hpo_id not in seen_hpo or seen_hpo[hpo_id].score < match.score:
                seen_hpo[hpo_id] = match

        # Sort by score descending
        return sorted(seen_hpo.values(), key=lambda m: m.score, reverse=True)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def main() -> None:
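    # Force UTF-8 output (replacing unencodable characters) so HPO terms print
    # safely on consoles that default to a non-UTF-8 encoding.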
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")

    embed_model = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")
    note = " ".join(sys.argv[1:]) if len(sys.argv) > 1 else (
        "18 year old male, extremely tall, displaced lens in left eye, "
        "heart murmur, flexible joints, scoliosis"
    )

    print("=" * 60)
    print("RareDx Symptom Parser — HPO Semantic Matching")
    print("=" * 60)
    print(f"\nInput: {note}\n")

    model  = SentenceTransformer(embed_model)
    parser = SymptomParser(model)
    matches = parser.parse(note)

    print(f"\nMatched {len(matches)} HPO terms:\n")
    print(f"  {'Score':>6}  {'HPO ID':<12}  {'Term':<40}  Phrase")
    print(f"  {'-'*6}  {'-'*12}  {'-'*40}  {'-'*30}")
    for m in matches:
        print(f"  {m.score:>6.4f}  {m.hpo_id:<12}  {m.term:<40}  \"{m.phrase}\"")


if __name__ == "__main__":
    main()