Spaces:
Running
Running
| """Superlinked retrieval node — PHI-safe semantic search over de-identified notes. | |
| GUARANTEE: every document entering the index and every chunk leaving it passes | |
| through assert_clean(). PHI cannot enter or leave this module. | |
| sentence-transformers >= 5.6 renamed _model_config to _get_model_config(); we add | |
| a compat property before importing superlinked so no site-packages edits are needed. | |
| Note: no `from __future__ import annotations` here — Superlinked's SchemaFactory | |
| inspects class annotations at runtime and needs real type objects, not strings. | |
| """ | |
| # --- compat shim: sentence-transformers 5.6 + superlinked 37 -------------------- | |
| try: | |
| from sentence_transformers import SentenceTransformer as _ST | |
| if not hasattr(_ST, "_model_config"): | |
| _ST._model_config = property(lambda self: self._get_model_config()) # type: ignore[attr-defined] | |
| except ImportError: | |
| pass | |
| # --------------------------------------------------------------------------------- | |
| import superlinked.framework as sl | |
| from src.deid import NoteGuard | |
# Embedding model identifier, passed as ``model=`` to the sl.TextSimilaritySpace
# that embeds each note's ``text`` field.
_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Superlinked schema describing one indexed note. Superlinked reads these
# annotations at runtime, which is why the module avoids postponed (string)
# annotations. The field types must stay real ``sl.*`` objects.
class _NoteSchema(sl.Schema):
    # Row id. NoteIndex.add_notes fills it from note["note_id"] or, when absent,
    # from its running note counter.
    note_id: sl.IdField
    # De-identified note body. This is the field the TextSimilaritySpace embeds.
    text: sl.String
# Module-level Superlinked objects shared by every NoteIndex instance:
#   _note  - the schema instance that sources and queries refer to,
#   _space - a text similarity space over _note.text using _MODEL,
#   _index - the index built from that single space.
_note = _NoteSchema()
_space = sl.TextSimilaritySpace(text=_note.text, model=_MODEL)
_index = sl.Index([_space])
class NoteIndex:
    """In-memory Superlinked index over de-identified clinical notes.

    Every note text is passed to ``NoteGuard.assert_clean`` before it is
    indexed, and every chunk is passed to it again before it is returned.
    The query text is checked as well.
    """

    def __init__(self) -> None:
        self._source = sl.InMemorySource(_note)
        executor = sl.InMemoryExecutor(sources=[self._source], indices=[_index])
        self._app = executor.run()
        # Parameterised query built once and reused by retrieve(); the
        # ``query_text`` and ``limit`` params are bound on each call.
        self._query = (
            sl.Query(_index)
            .find(_note)
            .select_all()
            .similar(_space.text, sl.Param("query_text"))
            .limit(sl.Param("limit"))
        )
        # Notes successfully handed to the source so far. It also seeds
        # auto-generated note ids. Zero means nothing is indexed yet.
        self._count = 0

    def add_notes(self, notes: list[dict], ng: NoteGuard) -> None:
        """Index de-identified notes. Raises if any PHI is detected.

        The whole batch is validated before any state changes. A PHI failure
        on any note therefore leaves both the index and the note counter
        exactly as they were.

        Args:
            notes: dicts with a ``"text"`` key (missing -> ``""``) and an
                optional ``"note_id"``. A note without an id gets the string
                form of the running note count.
            ng: guard whose ``assert_clean`` is called on every text.

        Raises:
            Whatever ``ng.assert_clean`` raises on the first dirty text.
        """
        # Materialise once so a generator argument is not exhausted by the
        # validation pass.
        notes = list(notes)
        texts = [note.get("text", "") for note in notes]
        # Validate before touching self._count or the source.
        for text in texts:
            ng.assert_clean(text)

        rows = []
        count = self._count
        for note, text in zip(notes, texts):
            count += 1
            rows.append({"note_id": note.get("note_id", str(count)), "text": text})
        if rows:
            self._source.put(rows)
        # Commit the counter only after the source accepted the batch.
        self._count = count

    def retrieve(self, query_text: str, ng: NoteGuard, top_k: int = 3) -> list[str]:
        """Return de-identified context chunks. Raises if PHI is found anywhere.

        Args:
            query_text: search text. It is checked with ``ng.assert_clean``
                before anything else happens.
            ng: guard whose ``assert_clean`` is applied to the query and to
                every returned chunk.
            top_k: maximum number of chunks to return. A value ``<= 0``
                returns ``[]`` without querying.

        Returns:
            The non-empty ``text`` fields of the query result entries, in
            result order.

        Raises:
            Whatever ``ng.assert_clean`` raises for the query or any chunk.
        """
        ng.assert_clean(query_text)
        if self._count == 0 or top_k <= 0:
            return []
        result = self._app.query(self._query, query_text=query_text, limit=top_k)
        chunks = []
        for entry in result.entries:
            value = entry.fields.get("text")
            # Skip missing/None/empty values. Otherwise str(None) would surface
            # the literal "None" as a context chunk.
            if not value:
                continue
            text = str(value)
            ng.assert_clean(text)
            chunks.append(text)
        return chunks