noteguard-agent / src /retrieve.py
Chaeyoon
deploy: v0.2.0 Gold RAP structure
a08af1a
Raw
History Blame Contribute Delete
2.81 kB
"""Superlinked retrieval node — PHI-safe semantic search over de-identified notes.

GUARANTEE: every document entering the index and every chunk leaving it passes
through assert_clean(). PHI cannot enter or leave this module.

Compatibility: sentence-transformers >= 5.6 renamed _model_config to
_get_model_config(); we add a compat property before importing superlinked so
no site-packages edits are needed.

Note: no `from __future__ import annotations` here — Superlinked's SchemaFactory
inspects class annotations at runtime and needs real type objects, not strings.
"""
# --- compat shim: sentence-transformers 5.6 + superlinked 37 --------------------
# Has to run before the superlinked import that follows this block.
try:
    from sentence_transformers import SentenceTransformer as _ST
except ImportError:
    # sentence-transformers is not importable, so there is nothing to patch.
    pass
else:
    if not hasattr(_ST, "_model_config"):
        # Re-expose the old attribute name as a read-only property that
        # forwards to the renamed accessor.
        setattr(_ST, "_model_config", property(lambda self: self._get_model_config()))
# ---------------------------------------------------------------------------------
import superlinked.framework as sl
from src.deid import NoteGuard
# Embedding model name handed to the text-similarity space defined in this module.
_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Superlinked schema for one indexed note: an id field plus the note text.
# The annotations are kept as real type objects (not strings) because the
# module docstring states Superlinked inspects them at runtime.
class _NoteSchema(sl.Schema):
    note_id: sl.IdField  # identifier of the note
    text: sl.String  # note body; the field the similarity space embeds
# Module-level objects shared by every NoteIndex instance.
_note = _NoteSchema()
# Text-similarity space over the schema's text field, embedded with _MODEL.
_space = sl.TextSimilaritySpace(text=_note.text, model=_MODEL)
# Index built from that single space; NoteIndex queries run against it.
_index = sl.Index([_space])
class NoteIndex:
    """In-memory Superlinked index over de-identified clinical notes.

    Note texts are passed through ``ng.assert_clean`` before they are stored
    (``add_notes``) and again before they are returned (``retrieve``); the
    query string is checked as well.
    """

    def __init__(self) -> None:
        """Start an in-memory executor and prepare the reusable query."""
        self._source = sl.InMemorySource(_note)
        executor = sl.InMemoryExecutor(sources=[self._source], indices=[_index])
        self._app = executor.run()
        # One parameterised query; the search text and the result cap are
        # bound on every retrieve() call.
        self._query = (
            sl.Query(_index)
            .find(_note)
            .select_all()
            .similar(_space.text, sl.Param("query_text"))
            .limit(sl.Param("limit"))
        )
        # Number of notes successfully handed to the source so far. Also
        # seeds the default note ids and gates the empty-index shortcut.
        self._count = 0

    def add_notes(self, notes: list[dict], ng: "NoteGuard") -> None:
        """Index de-identified notes, all or nothing.

        Args:
            notes: Mappings with a ``"text"`` entry and an optional
                ``"note_id"``. A missing text defaults to ``""``; a missing
                id defaults to the note's 1-based running position in this
                index, as a string.
            ng: Guard whose ``assert_clean`` is called on every text.

        Raises:
            Whatever ``ng.assert_clean`` raises for a rejected text. When
            that happens nothing from the batch is stored and the note
            counter is left unchanged.
        """
        rows = []
        # Number the batch with a local counter: if a note is rejected, or
        # the put below fails, self._count must not claim notes that were
        # never stored (retrieve() relies on it to detect an empty index).
        for position, note in enumerate(notes, start=self._count + 1):
            text = note.get("text", "")
            ng.assert_clean(text)
            rows.append({"note_id": note.get("note_id", str(position)), "text": text})
        if not rows:
            return
        self._source.put(rows)
        self._count += len(rows)

    def retrieve(self, query_text: str, ng: "NoteGuard", top_k: int = 3) -> list[str]:
        """Return the text of the notes most similar to ``query_text``.

        Args:
            query_text: Search text; checked with ``ng.assert_clean`` first.
            ng: Guard whose ``assert_clean`` is called on the query and on
                every returned chunk.
            top_k: Result cap passed to the query as its ``limit`` parameter.

        Returns:
            The non-empty ``"text"`` fields of the query result entries, in
            result order. An empty list if no notes have been indexed.

        Raises:
            Whatever ``ng.assert_clean`` raises for the query or for a chunk.
        """
        ng.assert_clean(query_text)
        if self._count == 0:
            # Nothing indexed yet, so skip the query entirely.
            return []
        result = self._app.query(self._query, query_text=query_text, limit=top_k)
        chunks = []
        for entry in result.entries:
            raw = entry.fields.get("text")
            if raw is None:
                # A missing or null field must not be returned as the
                # literal string "None".
                continue
            text = str(raw)
            if not text:
                continue
            ng.assert_clean(text)
            chunks.append(text)
        return chunks