Deploy THIRAWAT mapper app
- .gitattributes +2 -35
- .gitignore +11 -0
- AGENTS.md +39 -0
- README.md +122 -14
- app.py +43 -0
- docs/upstream-instruction-issues.md +73 -0
- pyproject.toml +30 -0
- requirements.txt +6 -0
- scripts/offline/build_duckdb.py +85 -0
- scripts/offline/build_lancedb_index.py +179 -0
- scripts/offline/publish_index_hf.py +83 -0
- spec/athena +0 -0
- spec/example.py +327 -0
- spec/spec.md +11 -0
- src/thirawat_demo/__init__.py +2 -0
- src/thirawat_demo/runtime/__init__.py +8 -0
- src/thirawat_demo/runtime/config.py +138 -0
- src/thirawat_demo/runtime/index_loader.py +69 -0
- src/thirawat_demo/runtime/peft_reranker.py +227 -0
- src/thirawat_demo/runtime/search_service.py +307 -0
- src/thirawat_demo/space_ui.py +310 -0
- tests/conftest.py +12 -0
- tests/test_build_lancedb_index.py +53 -0
- tests/test_peft_reranker.py +87 -0
- tests/test_runtime_config.py +82 -0
- tests/test_search_service.py +152 -0
- tests/test_space_ui.py +97 -0
- uv.lock +0 -0
.gitattributes
CHANGED
@@ -1,35 +1,2 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+# Auto detect text files and perform LF normalization
+* text=auto
.gitignore
ADDED
@@ -0,0 +1,11 @@
.venv/
__pycache__/
.pytest_cache/
.mypy_cache/
.ruff_cache/
.DS_Store

data/
!data/.gitkeep

temp/*.tmp
AGENTS.md
ADDED
@@ -0,0 +1,39 @@
# Repository Guidelines

## Project Structure & Module Organization
- `src/thirawat_demo/runtime/`: runtime configuration, index loading, reranking, and search orchestration.
- `src/thirawat_demo/space_ui.py`: Gradio interface logic; `app.py` is the local/Space entrypoint.
- `scripts/offline/`: offline pipeline scripts (`build_duckdb.py`, `build_lancedb_index.py`, `publish_index_hf.py`).
- `tests/`: pytest suite for runtime behavior and indexing workflow.
- `data/`: local/generated artifacts (DuckDB + LanceDB); avoid treating as source code.
- `docs/` and `spec/`: reference notes and specification examples.

## Build, Test, and Development Commands
- `uv sync --python 3.11`: install dependencies from `uv.lock`.
- `uv run pytest`: run all tests in `tests/`.
- `uv run pytest tests/test_search_service.py -q`: run a focused test file.
- `uv run python app.py`: start the Gradio runtime app locally.
- `uv run python scripts/offline/build_duckdb.py --help`: inspect offline build options before running data jobs.

## Coding Style & Naming Conventions
- Target Python 3.11+, 4-space indentation, and PEP 8-compatible formatting.
- Prefer explicit type hints and small, single-purpose functions.
- Naming: `snake_case` for modules/functions/variables, `PascalCase` for classes, `UPPER_SNAKE_CASE` for constants.
- Keep the architecture boundary strict: heavy indexing stays in `scripts/offline/`; runtime code in `src/thirawat_demo/runtime/`.

## Testing Guidelines
- Framework: `pytest` (configured in `pyproject.toml` with `testpaths = ["tests"]`).
- File naming: `tests/test_*.py`; test function naming: `test_*`.
- Use `monkeypatch` for environment-variable behavior and `tmp_path` for filesystem-dependent cases.
- Add or update tests with every behavioral change, especially around config parsing, retrieval/reranking flow, and UI integration points.

## Commit & Pull Request Guidelines
- Current history is minimal and uses short subject lines (example: `Initial commit`).
- Use concise, imperative commit subjects (prefer <= 72 chars), e.g., `Add post-score validation for runtime config`.
- PRs should include: purpose, key changes, verification commands run, and any required env var/data changes.
- Include screenshots for UI-impacting changes (`space_ui.py`) and note any model/index compatibility implications.

## Security & Configuration Tips
- Never commit secrets (for example `HF_TOKEN`).
- Keep runtime config in environment variables (`INDEX_REPO`, `DEVICE`, `LANCEDB_TABLE`, etc.).
- Publish large index artifacts through the HF dataset workflow instead of Git.
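A minimal sketch of the testing pattern the guidelines above call for (`monkeypatch` for environment variables, `tmp_path` for filesystem cases). `RuntimeConfig.from_env` exists in this repo; the test name and the specific values asserted here are illustrative assumptions, not copied from the repository's test suite.

```python
# Illustrative pytest sketch only: env-driven config test using monkeypatch and tmp_path.
from pathlib import Path

from thirawat_demo.runtime.config import RuntimeConfig


def test_from_env_reads_index_repo(monkeypatch, tmp_path: Path) -> None:
    # Point LOCAL_INDEX_PATH at a non-existent temp dir so INDEX_REPO must be used.
    monkeypatch.setenv("LOCAL_INDEX_PATH", str(tmp_path / "missing"))
    monkeypatch.setenv("INDEX_REPO", "your-org/thirawat-mapper-demo-index")
    monkeypatch.setenv("LANCEDB_TABLE", "concepts_drug")

    config = RuntimeConfig.from_env()

    assert config.index_repo == "your-org/thirawat-mapper-demo-index"
    assert config.lancedb_table == "concepts_drug"
```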
README.md
CHANGED
@@ -1,14 +1,122 @@
# THIRAWAT Mapper Demo

This repository is intentionally split into:

1. **Offline indexing pipeline** (Athena -> DuckDB -> LanceDB -> HF dataset upload)
2. **Hugging Face Space runtime app** (download prebuilt index + serve Gradio UI)

## Separation Rule

The Space runtime must **not** build indexes.
All heavy data preparation happens offline via scripts in `scripts/offline/`.

## Project Layout

- `scripts/offline/build_duckdb.py`
- `scripts/offline/build_lancedb_index.py`
- `scripts/offline/publish_index_hf.py`
- `src/thirawat_demo/runtime/config.py`
- `src/thirawat_demo/runtime/index_loader.py`
- `src/thirawat_demo/runtime/search_service.py`
- `src/thirawat_demo/space_ui.py`
- `app.py`

## Setup

```bash
uv sync --python 3.11
```

## Offline Build Flow

### 1) Athena -> DuckDB

```bash
uv run python scripts/offline/build_duckdb.py \
  --athena-dir /path/to/athena-export \
  --out data/derived/concepts.duckdb \
  --overwrite
```

### 2) DuckDB -> LanceDB index

`--device auto` behavior:
- macOS + MPS available: `mps`
- otherwise if CUDA available: `cuda`
- otherwise: `cpu`

```bash
uv run python scripts/offline/build_lancedb_index.py \
  --duckdb data/derived/concepts.duckdb \
  --out-db data/lancedb/db \
  --table concepts_drug \
  --device auto
```

Default concept classes used by the builder:

- Clinical Drug
- Quant Clinical Drug
- Clinical Drug Comp
- Clinical Drug Form
- Branded Drug
- Quant Branded Drug
- Branded Drug Comp
- Branded Drug Form
- Ingredient

### 3) Publish index artifact to HF dataset repo

```bash
export HF_TOKEN=hf_xxx
uv run python scripts/offline/publish_index_hf.py \
  --repo-id your-org/thirawat-mapper-demo-index \
  --source-dir data/lancedb \
  --revision main
```

## HF Space Deployment Flow

In Space settings, configure env vars as needed:

- `INDEX_REPO` (required unless a local index is baked into the image)
- `INDEX_REVISION` (default `main`)
- `LANCEDB_TABLE` (default `concepts_drug`)
- `DEVICE` (default `auto`)
- `TOP_K_DEFAULT`, `CANDIDATE_TOPK`, `RETRIEVAL_TOPK` (optional)
- `POST_MODE` (default `tiebreak`; supported: `blend|tiebreak|lex`)
- `POST_WEIGHT` (default `0.05`; used for `blend`)
- `TIEBREAK_EPS` (default `0.01`)
- `TIEBREAK_TOPN` (default `50`)
- `POST_STRENGTH_WEIGHT` (default `0.6`)
- `POST_JACCARD_WEIGHT` (default `0.4`)
- `POST_BRAND_PENALTY` (default `0.3`)
- `POST_MINMAX` (default `true`)
- `BRAND_STRICT` (default `false`)
- `HF_TOKEN` (optional for private repos)

Space entrypoint:

```bash
python app.py
```

## Local Runtime Test

```bash
export INDEX_REPO=your-org/thirawat-mapper-demo-index
export DEVICE=auto
uv run python app.py
```

## Notes

- Runtime search enforces two-stage retrieval: SapBERT vector retrieval + THIRAWAT reranking.
- The runtime reranker loads PEFT adapters directly (trainer-style) and does not merge checkpoints to disk at startup.
- Runtime ranking applies THIRAWAT's deterministic post/tie-break rules.
- The warning `No sentence-transformers model found with name cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR.` can appear during reranker startup. It is informational in the current setup: PyLate ColBERT first checks for a Sentence-Transformers-packaged model and then falls back to building from the base Hugging Face encoder.
- Retrieval embeddings use Hugging Face `transformers` directly, while reranking uses PyLate ColBERT (which relies on Sentence-Transformers internals). So `sentence-transformers` is still required for reranking even though retrieval itself does not depend on it.
- The Gradio UI includes concept-class multi-select filters (include-only). Defaults come from the index manifest when present.
- Domain is fixed to `Drug` in this v1 app.
- Upstream doc issues are tracked in:
  - `docs/upstream-instruction-issues.md`
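The tie-break knobs listed above (`POST_MODE=tiebreak`, `TIEBREAK_EPS`, `TIEBREAK_TOPN`) are easiest to read as a small sorting rule. The sketch below is a simplified illustration of that idea only, not the `thirawat_mapper` implementation; the secondary sort keys and the `post_score` field name are assumptions made for the example.

```python
# Simplified illustration of eps-window tie-breaking (not thirawat_mapper's code).
# Within the top `topn` candidates, anything scoring within `eps` of the best
# reranker score is re-ordered by deterministic secondary keys; the rest keep
# their score order.
def tiebreak_order(candidates: list[dict], eps: float = 0.01, topn: int = 50) -> list[dict]:
    ranked = sorted(candidates, key=lambda c: c["score"], reverse=True)
    head, tail = ranked[:topn], ranked[topn:]
    if not head:
        return ranked
    best = head[0]["score"]
    tied = [c for c in head if best - c["score"] <= eps]
    rest = [c for c in head if best - c["score"] > eps]
    # Assumed keys: higher lexical post score, then shorter name, then smaller concept_id.
    tied.sort(key=lambda c: (-c.get("post_score", 0.0), len(c["concept_name"]), c["concept_id"]))
    return tied + rest + tail


print(tiebreak_order([
    {"concept_id": 1, "concept_name": "aspirin 81 MG Oral Tablet [Brand]", "score": 0.925, "post_score": 0.2},
    {"concept_id": 2, "concept_name": "aspirin 81 MG Oral Tablet", "score": 0.920, "post_score": 0.4},
    {"concept_id": 3, "concept_name": "aspirin", "score": 0.800, "post_score": 0.9},
]))
```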
app.py
ADDED
@@ -0,0 +1,43 @@
"""Hugging Face Space entrypoint."""

from __future__ import annotations

import os
from pathlib import Path
import socket
import sys

ROOT = Path(__file__).resolve().parent
SRC = ROOT / "src"
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))

from thirawat_demo.space_ui import build_demo


def _is_port_free(port: int) -> bool:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return sock.connect_ex(("127.0.0.1", port)) != 0


def _resolve_server_port() -> int:
    """Use fixed PORT in managed environments; choose first free local port."""
    port_env = os.getenv("PORT")
    if port_env:
        return int(port_env)
    gradio_port_env = os.getenv("GRADIO_SERVER_PORT")
    if gradio_port_env:
        return int(gradio_port_env)
    for candidate in range(7860, 7871):
        if _is_port_free(candidate):
            return candidate
    return 7860


demo = build_demo()

if __name__ == "__main__":
    port = _resolve_server_port()
    print(f"Starting THIRAWAT Mapper Demo on http://127.0.0.1:{port}", flush=True)
    demo.launch(server_name="0.0.0.0", server_port=port, share=False)
docs/upstream-instruction-issues.md
ADDED
@@ -0,0 +1,73 @@
# Upstream Instruction Issues

This note captures discrepancies found while implementing the demo so they can be fixed upstream.

## 1) `athena2duckdb` README default output mismatch
- Repo: <https://github.com/sidataplus/athena2duckdb>
- README says `--out` default is `vocab.duckdb`.
- CLI source (`src/athena2duckdb/cli.py`) uses `omop_vocab.duckdb`.
- Suggested fix:
  - Align README default value with code, or
  - Change code default to match README.

## 2) `THIRAWAT-mapper` README says reranker is gated
- Repo: <https://github.com/sidataplus/THIRAWAT-mapper>
- README notes `sidataplus/THIRAWAT-SapBERT` is gated.
- Verified on **February 10, 2026** via HF model API that `gated=false`, `private=false`.
- Suggested fix:
  - Update README access notes and include verification date.

## 3) `THIRAWAT-mapper` Cloudflare markdown fence issue
- Repo: <https://github.com/sidataplus/THIRAWAT-mapper>
- Cloudflare section has a duplicated fenced code opener in README, causing broken rendering.
- Suggested fix:
  - Remove the extra code fence and validate markdown rendering.

## 4) `THIRAWAT-mapper` index build docs under-document MPS usage
- Repo: <https://github.com/sidataplus/THIRAWAT-mapper>
- README examples mostly show `--device cuda`.
- Apple Silicon users can run with `--device mps`; CPU fallback is also valid.
- Suggested fix:
  - Add a short device matrix and one MPS example command.

## 5) `THIRAWAT-mapper` docs can confuse on `--profiles-table`
- Repo: <https://github.com/sidataplus/THIRAWAT-mapper>
- README examples require `--profiles-table concept_profiles`.
- Athena-to-DuckDB outputs typically do not contain `concept_profiles`.
- Current builder code falls back to inline profile creation if `profiles_table` is absent.
- Suggested fix:
  - Document fallback behavior clearly and show an Athena-only example path.

## 6) `athena2duckdb==0.1.0` import-time syntax error in `loader.py`
- Repo: <https://github.com/sidataplus/athena2duckdb>
- In our local run on **February 10, 2026**, importing `athena2duckdb` fails:
  - `SyntaxError: f-string expression part cannot include a backslash`
- Affected code pattern in `loader.py`:
  - `f" {',\\n '.join(columns_sql)}\\n"` inside an f-string expression.
- Impact:
  - `build_duckdb.py` cannot execute with published `athena2duckdb==0.1.0`.
- Suggested fix:
  - Refactor string assembly to avoid a backslash-containing expression inside the f-string and release `0.1.1`.

## 7) `THIRAWAT-mapper` docs mix `transformers` defaults with `--backend st` examples
- Repo: <https://github.com/sidataplus/thirawat-mapper>
- README states the embedding backend default is `transformers` and says `--backend transformers` can be omitted.
- But `docs/retrieval_reranker.md` examples repeatedly use `--backend st` (vector build/eval/RAG examples).
- Suggested fix:
  - Standardize examples to the default `transformers`, or
  - Add a clear "when to use `st` vs `transformers`" note and keep examples consistent with that guidance.

## 8) `THIRAWAT-mapper` lacks troubleshooting note for SapBERT sentence-transformers warning
- Repo: <https://github.com/sidataplus/thirawat-mapper>
- During ColBERT/PyLate startup with SapBERT, users can see:
  - `No sentence-transformers model found with name cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR.`
- This is often benign fallback behavior, but the docs do not explain it, so users may treat it as a hard failure.
- Suggested fix:
  - Add a short troubleshooting note clarifying expected behavior and when it is truly an error (actual load failure).

## 9) `THIRAWAT-mapper` docs under-document deterministic post/tie-break inference rules
- Repo: <https://github.com/sidataplus/thirawat-mapper>
- Installed package behavior (`thirawat_mapper==0.1.4`) includes deterministic tie-break/post ranking (`tiebreak_rerank`, `enrich_with_post_scores`, `eps`, `topn`, and ordered tie-break keys).
- Public markdown docs focus on blend-style post scoring but do not document tie-break mode semantics and knobs.
- Suggested fix:
  - Add an inference-ranking section documenting `blend|lex|tiebreak` modes, default parameters, and deterministic sorting behavior.
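For item 6, the failure mode and one possible refactor look like this in isolation; the snippet reproduces the reported pattern only and is not copied from `athena2duckdb`.

```python
# Python versions before 3.12 reject backslashes inside f-string expressions,
# which matches the SyntaxError reported for athena2duckdb 0.1.0. A common
# refactor builds the backslash-containing separator outside the f-string.
columns_sql = ["concept_id BIGINT", "concept_name VARCHAR"]

# Fails to compile on Python 3.11:
#   sql = f" {',\n '.join(columns_sql)}\n"

# Works on all supported versions:
joined = ",\n ".join(columns_sql)
sql = f" {joined}\n"
print(sql)
```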
pyproject.toml
ADDED
@@ -0,0 +1,30 @@
[project]
name = "thirawat-mapper-demo"
version = "0.1.0"
description = "THIRAWAT mapper demo split into offline index pipeline and HF Space runtime app."
readme = "README.md"
requires-python = ">=3.11,<3.13"
dependencies = [
    "athena2duckdb==0.1.0",
    "gradio==6.5.1",
    "huggingface_hub>=0.27.0",
    "lancedb==0.29.2",
    "pandas>=2.2.0",
    "peft>=0.17.0",
    "thirawat-mapper==0.1.4",
]

[dependency-groups]
dev = [
    "pytest>=8.3.0",
]

[tool.pytest.ini_options]
testpaths = ["tests"]

[build-system]
requires = ["hatchling>=1.24.0"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/thirawat_demo"]
requirements.txt
ADDED
@@ -0,0 +1,6 @@
gradio==6.5.1
huggingface_hub>=0.27.0
lancedb==0.29.2
pandas>=2.2.0
peft>=0.17.0
thirawat-mapper==0.1.4
scripts/offline/build_duckdb.py
ADDED
@@ -0,0 +1,85 @@
#!/usr/bin/env python3
"""Offline: convert Athena vocabulary export to DuckDB."""

from __future__ import annotations

import argparse
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parents[2]
DEFAULT_DUCKDB_PATH = REPO_ROOT / "data" / "derived" / "concepts.duckdb"


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Build DuckDB from Athena vocabulary export.")
    parser.add_argument("--athena-dir", required=True, help="Directory containing Athena CSV/TSV files.")
    parser.add_argument(
        "--out",
        default=str(DEFAULT_DUCKDB_PATH),
        help=f"Output DuckDB path (default: {DEFAULT_DUCKDB_PATH}).",
    )
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing DuckDB file.")
    parser.add_argument("--threads", type=int, default=None, help="DuckDB threads (default: auto).")
    parser.add_argument("--schema", default="main", help="Target schema name (default: main).")
    parser.add_argument("--sep", default="\t", help="Input delimiter (default: tab).")
    parser.add_argument("--encoding", default="UTF-8", help="Input encoding (default: UTF-8).")
    return parser.parse_args()


def run(args: argparse.Namespace) -> int:
    try:
        from athena2duckdb import CSVOptions, load_vocab_dir, verify_row_counts
    except SyntaxError as exc:
        raise RuntimeError(
            "Failed to import athena2duckdb due to an upstream syntax error. "
            "Current workaround: use an already-built DuckDB file (for example a previous vocab.duckdb) "
            "and continue with scripts/offline/build_lancedb_index.py."
        ) from exc

    athena_dir = Path(args.athena_dir).expanduser().resolve()
    out_path = Path(args.out).expanduser().resolve()
    out_path.parent.mkdir(parents=True, exist_ok=True)

    if not athena_dir.exists():
        raise FileNotFoundError(f"Athena directory does not exist: {athena_dir}")

    csv_options = CSVOptions(sep=args.sep, encoding=args.encoding)

    summary = load_vocab_dir(
        input_dir=athena_dir,
        out_path=out_path,
        csv_options=csv_options,
        overwrite=bool(args.overwrite),
        threads=args.threads,
        schema=args.schema,
    )
    print(f"Loaded {len(summary.vocab_files)} tables into {summary.db_path} (schema {summary.schema}).")

    results = verify_row_counts(
        db_path=summary.db_path,
        vocab_files=summary.vocab_files,
        csv_options=csv_options,
        threads=args.threads,
        schema=summary.schema,
    )
    mismatches = [result for result in results if not result.matches]
    for result in results:
        status = "OK" if result.matches else "MISMATCH"
        print(
            f"{status:9s} table={result.table_name:<25s} "
            f"csv_rows={result.csv_rows:,} table_rows={result.table_rows:,}"
        )

    if mismatches:
        print(f"Found {len(mismatches)} row-count mismatches.")
        return 2
    return 0


def main() -> int:
    args = parse_args()
    return run(args)


if __name__ == "__main__":
    raise SystemExit(main())
scripts/offline/build_lancedb_index.py
ADDED
@@ -0,0 +1,179 @@
#!/usr/bin/env python3
"""Offline: build LanceDB index for THIRAWAT mapper."""

from __future__ import annotations

import argparse
from pathlib import Path
import platform

REPO_ROOT = Path(__file__).resolve().parents[2]
DEFAULT_DUCKDB_PATH = REPO_ROOT / "data" / "derived" / "concepts.duckdb"
DEFAULT_LANCEDB_DIR = REPO_ROOT / "data" / "lancedb" / "db"

DEFAULT_DOMAIN_IDS = ["Drug"]
DEFAULT_CONCEPT_CLASSES = [
    "Clinical Drug",
    "Quant Clinical Drug",
    "Clinical Drug Comp",
    "Clinical Drug Form",
    "Branded Drug",
    "Quant Branded Drug",
    "Branded Drug Comp",
    "Branded Drug Form",
    "Ingredient",
]
DEFAULT_EXCLUDED_CONCEPT_CLASSES = [
    "Clinical Drug Box",
    "Branded Drug Box",
    "Branded Pack Box",
    "Clinical Pack Box",
    "Marketed Product",
    "Quant Branded Box",
    "Quant Clinical Box",
]
DEFAULT_EXTRA_COLUMNS = ["concept_name", "concept_code", "domain_id", "vocabulary_id", "concept_class_id"]
DEFAULT_MODEL_ID = "cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR"


def _split_multi(values: list[str] | None) -> list[str]:
    if not values:
        return []
    items: list[str] = []
    for value in values:
        parts = [part.strip() for part in str(value).split(",")]
        items.extend([part for part in parts if part])
    return items


def resolve_device(device: str) -> str:
    requested = device.strip().lower()
    if requested != "auto":
        return requested
    try:
        import torch  # type: ignore
    except Exception:
        return "cpu"

    is_darwin = platform.system().lower() == "darwin"
    has_mps = bool(getattr(torch.backends, "mps", None)) and torch.backends.mps.is_available()
    has_cuda = bool(torch.cuda.is_available())

    # Policy requested for this repo:
    # auto => macOS MPS first, then CUDA, then CPU.
    if is_darwin and has_mps:
        return "mps"
    if has_cuda:
        return "cuda"
    return "cpu"


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Build LanceDB index using thirawat-mapper.")
    parser.add_argument("--duckdb", default=str(DEFAULT_DUCKDB_PATH), help="Path to DuckDB vocabulary file.")
    parser.add_argument("--profiles-table", default="concept_profiles", help="Profiles table name.")
    parser.add_argument("--concepts-table", default="concept", help="Concepts table name.")
    parser.add_argument("--out-db", default=str(DEFAULT_LANCEDB_DIR), help="Output LanceDB directory.")
    parser.add_argument("--table", default="concepts_drug", help="Output LanceDB table name.")

    parser.add_argument(
        "--domain-id",
        action="append",
        default=None,
        help="Domain filters, comma-separated or repeated (default: Drug).",
    )
    parser.add_argument(
        "--concept-class-id",
        action="append",
        default=None,
        help="Concept class filters, comma-separated or repeated.",
    )
    parser.add_argument(
        "--exclude-concept-class-id",
        action="append",
        default=None,
        help="Concept class exclusions, comma-separated or repeated.",
    )
    parser.add_argument(
        "--extra-column",
        action="append",
        default=None,
        help="Extra columns to carry into index table (comma-separated or repeated).",
    )

    parser.add_argument("--batch-size", type=int, default=256, help="Embedding batch size.")
    parser.add_argument("--model-id", default=DEFAULT_MODEL_ID, help="Encoder model id.")
    parser.add_argument("--pooling", choices=["cls", "mean"], default="cls", help="Pooling type.")
    parser.add_argument("--max-length", type=int, default=128, help="Encoder max token length.")
    parser.add_argument(
        "--device",
        choices=["auto", "mps", "cuda", "cpu"],
        default="auto",
        help="Index build device; auto resolves to mps (Darwin), then cuda, then cpu.",
    )
    parser.add_argument("--trust-remote-code", action="store_true", help="Pass trust_remote_code to encoder.")
    return parser.parse_args()


def run(args: argparse.Namespace) -> int:
    from thirawat_mapper.index.build import main as thirawat_index_build_main

    duckdb_path = Path(args.duckdb).expanduser().resolve()
    out_db = Path(args.out_db).expanduser().resolve()
    out_db.mkdir(parents=True, exist_ok=True)

    if not duckdb_path.exists():
        raise FileNotFoundError(f"DuckDB file does not exist: {duckdb_path}")

    resolved_device = resolve_device(args.device)
    domain_ids = _split_multi(args.domain_id) or DEFAULT_DOMAIN_IDS
    concept_classes = _split_multi(args.concept_class_id) or DEFAULT_CONCEPT_CLASSES
    excluded_classes = _split_multi(args.exclude_concept_class_id) or DEFAULT_EXCLUDED_CONCEPT_CLASSES
    extra_columns = _split_multi(args.extra_column) or DEFAULT_EXTRA_COLUMNS

    cli_args = [
        "--duckdb",
        str(duckdb_path),
        "--profiles-table",
        args.profiles_table,
        "--concepts-table",
        args.concepts_table,
        "--domain-id",
        ",".join(domain_ids),
        "--concept-class-id",
        ",".join(concept_classes),
        "--exclude-concept-class-id",
        ",".join(excluded_classes),
        "--extra-column",
        ",".join(extra_columns),
        "--out-db",
        str(out_db),
        "--table",
        args.table,
        "--batch-size",
        str(args.batch_size),
        "--model-id",
        args.model_id,
        "--pooling",
        args.pooling,
        "--max-length",
        str(args.max_length),
        "--device",
        resolved_device,
    ]
    if args.trust_remote_code:
        cli_args.append("--trust-remote-code")

    print(f"Resolved build device: {resolved_device}")
    print("Invoking: python -m thirawat_mapper.index.build " + " ".join(cli_args))
    thirawat_index_build_main(cli_args)
    return 0


def main() -> int:
    args = parse_args()
    return run(args)


if __name__ == "__main__":
    raise SystemExit(main())
scripts/offline/publish_index_hf.py
ADDED
@@ -0,0 +1,83 @@
#!/usr/bin/env python3
"""Offline: publish built index artifact to Hugging Face dataset repo."""

from __future__ import annotations

import argparse
from pathlib import Path
import os

from huggingface_hub import HfApi, create_repo, upload_folder

REPO_ROOT = Path(__file__).resolve().parents[2]
DEFAULT_SOURCE_DIR = REPO_ROOT / "data" / "lancedb"


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Upload index artifact to Hugging Face dataset repo.")
    parser.add_argument("--repo-id", required=True, help="Target HF dataset repo (e.g. org/name).")
    parser.add_argument(
        "--source-dir",
        default=str(DEFAULT_SOURCE_DIR),
        help=f"Directory to upload (default: {DEFAULT_SOURCE_DIR}).",
    )
    parser.add_argument("--revision", default="main", help="Target branch/revision (default: main).")
    parser.add_argument("--private", action="store_true", help="Create repo as private if it does not exist.")
    parser.add_argument(
        "--token-env",
        default="HF_TOKEN",
        help="Environment variable name holding HF write token (default: HF_TOKEN).",
    )
    return parser.parse_args()


def run(args: argparse.Namespace) -> int:
    source_dir = Path(args.source_dir).expanduser().resolve()
    if not source_dir.exists():
        raise FileNotFoundError(f"Source directory does not exist: {source_dir}")

    token = os.getenv(args.token_env)
    if not token:
        raise ValueError(f"Missing Hugging Face token in environment variable: {args.token_env}")

    create_repo(
        repo_id=args.repo_id,
        repo_type="dataset",
        private=bool(args.private),
        token=token,
        exist_ok=True,
    )

    api = HfApi(token=token)
    if args.revision != "main":
        try:
            api.create_branch(repo_id=args.repo_id, repo_type="dataset", branch=args.revision, exist_ok=True)
        except TypeError:
            # Backward-compatible path for clients without exist_ok support.
            try:
                api.create_branch(repo_id=args.repo_id, repo_type="dataset", branch=args.revision)
            except Exception:
                pass

    commit_info = upload_folder(
        repo_id=args.repo_id,
        repo_type="dataset",
        folder_path=str(source_dir),
        path_in_repo=".",
        revision=args.revision,
        commit_message="Update THIRAWAT mapper demo index artifact",
        token=token,
    )
    print(f"Uploaded index artifact to {args.repo_id}@{args.revision}")
    print(f"Commit URL: {commit_info.commit_url}")
    return 0


def main() -> int:
    args = parse_args()
    return run(args)


if __name__ == "__main__":
    raise SystemExit(main())
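The runtime side consumes the artifact this script publishes. A minimal sketch of that download-and-open step with `huggingface_hub.snapshot_download` and `lancedb`; the repo id, revision, and table name are placeholders, and the `db/` subdirectory assumption follows the layout produced by the build script above.

```python
# Minimal consumer-side sketch: fetch the published LanceDB artifact and open the table.
from pathlib import Path

import lancedb
from huggingface_hub import snapshot_download

index_root = Path(
    snapshot_download(
        repo_id="your-org/thirawat-mapper-demo-index",  # placeholder repo id
        repo_type="dataset",
        revision="main",
    )
)
db = lancedb.connect(str(index_root / "db"))
table = db.open_table("concepts_drug")
print(table.count_rows())
```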
spec/athena
ADDED
Binary file (948 Bytes).
spec/example.py
ADDED
@@ -0,0 +1,327 @@
"""Drug Concept Entity Linking - HuggingFace Space"""

import os
import tempfile
import traceback
from pathlib import Path
import gradio as gr
import lancedb
from sentence_transformers import SentenceTransformer
import pandas as pd

# ===== CONFIG =====
# Pulled from Space Secrets (set in Settings > Secrets)
def _get_env(name: str, default: str | None = None) -> str | None:
    value = os.environ.get(name)
    if value is None:
        return default
    value = value.strip()
    if not value or value.lower() == "none":
        return default
    return value


HF_TOKEN = _get_env("HF_TOKEN")  # optional if the dataset is public
INDEX_REPO = _get_env("INDEX_REPO", "amnnma/drug-concept-index")  # change to your repo name
LOCAL_INDEX_PATH = _get_env("LOCAL_INDEX_PATH", "data/lancedb")
DEBUG = _get_env("DEBUG", "0") == "1"

# Model
MODEL_ID = "cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR"
TOP_K = 10

class DrugConceptSearcher:
    def __init__(self):
        self.model = None
        self.db = None
        self.table = None
        self._load()

    def _load(self):
        """Load model and connect to LanceDB"""
        print("Loading model...")
        # Force slow tokenizer to avoid fast-tokenizer conversion issues on Space
        self.model = SentenceTransformer(MODEL_ID, tokenizer_kwargs={"use_fast": False})

        # Prefer local index when available (useful for local runs)
        local_root = Path(LOCAL_INDEX_PATH) if LOCAL_INDEX_PATH else None
        if local_root and local_root.exists() and (local_root / "db").exists():
            index_root = local_root
            print(f"Connecting to local LanceDB at {index_root}...")
        else:
            repo_id = INDEX_REPO or "amnnma/drug-concept-index"
            if not isinstance(repo_id, str):
                repo_id = str(repo_id)
            repo_id = repo_id.strip()
            if repo_id.startswith("http"):
                # Accept full HF URLs and extract the repo id
                parts = repo_id.split("/")
                if "datasets" in parts:
                    repo_id = "/".join(parts[parts.index("datasets") + 1 :]).strip("/")
                elif "spaces" in parts:
                    repo_id = "/".join(parts[parts.index("spaces") + 1 :]).strip("/")
                else:
                    repo_id = "/".join(parts[-2:]).strip("/")
            if repo_id.startswith("datasets/"):
                repo_id = repo_id[len("datasets/") :]
            print(f"Connecting to LanceDB from {repo_id}...")
            # Download and connect to the LanceDB index stored in the HF repo
            from huggingface_hub import snapshot_download

            # Download the index (cached under /data)
            download_root = Path(os.environ.get("HF_DATA_DIR", "/data")) / "lancedb"
            try:
                download_root.mkdir(parents=True, exist_ok=True)
            except OSError:
                download_root = Path("data/lancedb")
                download_root.mkdir(parents=True, exist_ok=True)

            # Avoid implicit token usage for public datasets
            os.environ["HF_HUB_DISABLE_IMPLICIT_TOKEN"] = "1"
            try:
                index_root = Path(
                    snapshot_download(
                        repo_id=repo_id,
                        repo_type="dataset",
                        token=False,
                        revision=os.environ.get("HF_DATASET_REVISION", "main"),
                        local_dir=str(download_root),
                    )
                )
            except Exception as e:
                if HF_TOKEN:
                    index_root = Path(
                        snapshot_download(
                            repo_id=repo_id,
                            repo_type="dataset",
                            token=HF_TOKEN,
                            local_dir=str(download_root),
                        )
                    )
                else:
                    raise e

        # Connect to LanceDB
        self.db = lancedb.connect(str(index_root / "db"))
        self.table = self.db.open_table("concepts_drug")
        print("✅ Ready!")

    def search(self, query: str, top_k: int = TOP_K):
        """Search drug concepts"""
        if not query or not query.strip():
            return pd.DataFrame()

        # Encode query
        query_emb = self.model.encode(query, normalize_embeddings=True)

        # Search
        results = self.table.search(query_emb).limit(top_k).to_pandas()

        # Format output
        if "_distance" in results.columns:
            results["score"] = 1 - results["_distance"]  # Convert distance to similarity
            results = results.sort_values("score", ascending=False)

        return results[["concept_id", "concept_name", "concept_code", "vocabulary_id", "score"]]

# Initialize
searcher = None

def get_searcher():
    global searcher
    if searcher is None:
        searcher = DrugConceptSearcher()
    return searcher

def _format_results(results: pd.DataFrame, query: str) -> tuple[str, pd.DataFrame]:
    if results.empty:
        return "No results found. Try a different search term.", results

    output = f"## Results for: \"{query}\"\n\n"
    best = results.iloc[0]
    output += f"**Top match:** {best['concept_name']} (score {best['score']:.4f})\n\n"
    return output, results


def search_drugs(query: str, top_k: int):
    """Gradio search function (single query)"""
    try:
        s = get_searcher()
        results = s.search(query, top_k)

        output, table = _format_results(results, query)
        return output, table
    except Exception as e:
        print("Search error:", e)
        print(traceback.format_exc())
        if DEBUG:
            return f"❌ Error: {str(e)}\n\n```\n{traceback.format_exc()}\n```", pd.DataFrame()
        return f"❌ Error: {str(e)}", pd.DataFrame()


def search_batch(queries_text: str, top_k: int):
    """Gradio search function (batch queries)"""
    try:
        if not queries_text or not queries_text.strip():
            return "Please enter clinical terms to search.", gr.update(visible=False)

        lines = [line.strip() for line in queries_text.splitlines() if line.strip()]
        if not lines:
            return "No valid queries found.", gr.update(visible=False)

        s = get_searcher()
        rows = []
        for q in lines:
            results = s.search(q, top_k)
            for i, (_, row) in enumerate(results.iterrows(), start=1):
                rows.append(
                    {
                        "query_text": q,
                        "rank": i,
                        "concept_id": row["concept_id"],
                        "concept_name": row["concept_name"],
                        "concept_code": row["concept_code"],
                        "vocabulary_id": row["vocabulary_id"],
                        "score": float(row["score"]),
                    }
                )

        if not rows:
            return "No results found.", gr.update(visible=False)

        df = pd.DataFrame(rows)
        tmp_dir = Path(tempfile.gettempdir()) / "thirawat_results"
        tmp_dir.mkdir(parents=True, exist_ok=True)
        out_path = tmp_dir / "batch_results.csv"
        df.to_csv(out_path, index=False)

        md = f"""## Batch Search Complete

- **Queries processed:** {len(lines)}
- **Rows returned:** {len(rows)}
- **Top-K per query:** {top_k}
"""
        return md, gr.update(value=str(out_path), visible=True)
    except Exception as e:
        print("Batch search error:", e)
        print(traceback.format_exc())
        if DEBUG:
            return f"❌ Error: {str(e)}\n\n```\n{traceback.format_exc()}\n```", gr.update(visible=False)
        return f"❌ Error: {str(e)}", gr.update(visible=False)

# ===== GRADIO INTERFACE =====
with gr.Blocks(title="THIRAWAT - Drug Concept Search") as demo:
    gr.HTML(
        """
        <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 10px; margin-bottom: 20px;">
            <h1 style="color: white; margin: 0; font-size: 2em;">THIRAWAT</h1>
            <p style="color: rgba(255,255,255,0.9); margin: 5px 0 0 0;">Drug Concept Entity Linking</p>
            <p style="color: rgba(255,255,255,0.8); margin: 5px 0 0 0;">Map drug names to OMOP concepts using SapBERT + LanceDB.</p>
        </div>
        """
    )

    with gr.Tabs():
        with gr.Tab("Single Query"):
            with gr.Row():
                with gr.Column(scale=3):
                    query_input = gr.Textbox(
                        label="Drug name or query",
                        placeholder="e.g., aspirin, paracetamol, amoxicillin 500mg...",
                        lines=2,
                    )
                with gr.Column(scale=1):
                    domain_hint = gr.Dropdown(
                        label="Domain",
                        choices=["Drug", "Condition", "Procedure", "Observation", "Device", "Unit"],
                        value="Drug",
                        interactive=False,
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=10,
                        step=1,
                        label="Number of results",
                    )

            with gr.Row():
                search_btn = gr.Button("Search", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")

            output_md = gr.Markdown(label="Results")
            output_table = gr.Dataframe(label="Results Table", interactive=False)

        with gr.Tab("Batch Query"):
            with gr.Row():
                with gr.Column(scale=3):
                    batch_queries = gr.Textbox(
                        label="Drug names (one per line)",
                        placeholder="aspirin\nparacetamol\namoxicillin 500mg",
                        lines=10,
                    )
                with gr.Column(scale=1):
                    batch_domain_hint = gr.Dropdown(
                        label="Domain",
                        choices=["Drug", "Condition", "Procedure", "Observation", "Device", "Unit"],
                        value="Drug",
                        interactive=False,
                    )
                    batch_topk = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=10,
                        step=1,
                        label="Top-K per query",
                    )

            with gr.Row():
                batch_btn = gr.Button("Process Batch", variant="primary")
                batch_clear = gr.Button("Clear", variant="secondary")

            batch_output = gr.Markdown(label="Summary")
            batch_download = gr.DownloadButton(
                label="Download Results (CSV)",
                variant="secondary",
                visible=False,
            )

    def clear_single():
        return "", 10, "", pd.DataFrame()

    def clear_batch():
        return "", 10, "", gr.update(visible=False)

    search_btn.click(
        fn=search_drugs,
        inputs=[query_input, top_k],
        outputs=[output_md, output_table],
        api_name=False,
    )
    clear_btn.click(
        fn=clear_single,
        outputs=[query_input, top_k, output_md, output_table],
        api_name=False,
    )

    batch_btn.click(
        fn=search_batch,
        inputs=[batch_queries, batch_topk],
        outputs=[batch_output, batch_download],
        api_name=False,
    )
    batch_clear.click(
        fn=clear_batch,
        outputs=[batch_queries, batch_topk, batch_output, batch_download],
        api_name=False,
    )
    gr.Markdown(
        """
        ---

        **THIRAWAT** is a dense retrieval toolkit for mapping drug terminology to OMOP standard concepts.
        """
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
spec/spec.md
ADDED
@@ -0,0 +1,11 @@
# THIRAWAT-mapper-demo

End-to-end demo of THIRAWAT-mapper, a tool for mapping concepts from non-standard terminologies to standard terminologies in the OHDSI/OMOP CDM.

## Key steps

1. Turn the [vocab set downloaded from Athena](spec/athena) into DuckDB format using [`athena2duckdb`](https://pypi.org/project/athena2duckdb/) (`pip install athena2duckdb`).
2. Follow the instructions in [THIRAWAT-mapper](https://github.com/sidataplus/THIRAWAT-mapper).
3. Use <https://huggingface.co/cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR> for retrieval with CLS pooling and <https://huggingface.co/sidataplus/THIRAWAT-SapBERT> for the ColBERT reranker.
4. Build a complete Gradio app; see the preliminary example in [spec/example.py](spec/example.py).
5. Package everything into a Hugging Face Space: <https://huggingface.co/spaces/sidataplus/THIRAWAT-mapper-demo>.
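Step 3 above calls for SapBERT retrieval embeddings with CLS pooling. A minimal sketch of that encoding with plain `transformers`; the max length, batch handling, and L2 normalization here are assumptions made for cosine-style search, not requirements from the spec.

```python
# Minimal CLS-pooling encoder sketch for the retrieval step described in spec/spec.md.
import torch
from transformers import AutoModel, AutoTokenizer

MODEL_ID = "cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID)
model.eval()


@torch.no_grad()
def encode(texts: list[str]) -> torch.Tensor:
    batch = tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors="pt")
    hidden = model(**batch).last_hidden_state  # (batch, seq_len, dim)
    cls = hidden[:, 0, :]                      # CLS pooling: first token of each sequence
    return torch.nn.functional.normalize(cls, dim=-1)


print(encode(["aspirin 81 mg oral tablet"]).shape)
```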
src/thirawat_demo/__init__.py
ADDED
@@ -0,0 +1,2 @@
"""THIRAWAT mapper demo package."""
src/thirawat_demo/runtime/__init__.py
ADDED
@@ -0,0 +1,8 @@
"""Runtime modules for the Hugging Face Space app."""

from .config import RuntimeConfig
from .index_loader import resolve_lancedb_dir
from .peft_reranker import ThirawatPeftReranker
from .search_service import SearchService

__all__ = ["RuntimeConfig", "resolve_lancedb_dir", "ThirawatPeftReranker", "SearchService"]
src/thirawat_demo/runtime/config.py
ADDED
@@ -0,0 +1,138 @@
"""Runtime configuration for the Gradio app."""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
import os

DEFAULT_ENCODER_MODEL_ID = "cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR"
DEFAULT_RERANKER_ID = "sidataplus/THIRAWAT-SapBERT"


def _clean_env(name: str, default: str | None = None) -> str | None:
    value = os.getenv(name)
    if value is None:
        return default
    cleaned = value.strip()
    if cleaned == "" or cleaned.lower() == "none":
        return default
    return cleaned


def _env_int(name: str, default: int) -> int:
    value = _clean_env(name)
    if value is None:
        return default
    return int(value)


def _env_float(name: str, default: float) -> float:
    value = _clean_env(name)
    if value is None:
        return default
    return float(value)


def _env_bool(name: str, default: bool = False) -> bool:
    value = _clean_env(name)
    if value is None:
        return default
    return value.lower() in {"1", "true", "yes", "on"}


@dataclass(frozen=True)
class RuntimeConfig:
    local_index_path: Path
    index_repo: str | None
    index_revision: str
    hf_token: str | None
    hf_data_dir: Path
    lancedb_table: str
    top_k_default: int
    candidate_topk: int
    retrieval_topk: int
    device: str
    encoder_model_id: str
    reranker_id: str
    post_mode: str
    post_weight: float
    tiebreak_eps: float
    tiebreak_topn: int
    post_strength_weight: float
    post_jaccard_weight: float
    post_brand_penalty: float
    post_minmax: bool
    brand_strict: bool
    debug: bool

    @classmethod
    def from_env(cls) -> "RuntimeConfig":
        local_index_path = Path(_clean_env("LOCAL_INDEX_PATH", "data/lancedb") or "data/lancedb")
        index_repo = _clean_env("INDEX_REPO")
        index_revision = _clean_env("INDEX_REVISION", "main") or "main"
        hf_token = _clean_env("HF_TOKEN")
        hf_data_dir = Path(_clean_env("HF_DATA_DIR", "/data") or "/data")
        lancedb_table = _clean_env("LANCEDB_TABLE", "concepts_drug") or "concepts_drug"
        top_k_default = _env_int("TOP_K_DEFAULT", 10)
        candidate_topk = _env_int("CANDIDATE_TOPK", 100)
        retrieval_topk = _env_int("RETRIEVAL_TOPK", 200)
        device = (_clean_env("DEVICE", "auto") or "auto").lower()
        encoder_model_id = _clean_env("ENCODER_MODEL_ID", DEFAULT_ENCODER_MODEL_ID) or DEFAULT_ENCODER_MODEL_ID
        reranker_id = _clean_env("RERANKER_ID", DEFAULT_RERANKER_ID) or DEFAULT_RERANKER_ID
        post_mode = (_clean_env("POST_MODE", "tiebreak") or "tiebreak").strip().lower()
        post_weight = _env_float("POST_WEIGHT", 0.05)
        tiebreak_eps = _env_float("TIEBREAK_EPS", 0.01)
        tiebreak_topn = _env_int("TIEBREAK_TOPN", 50)
        post_strength_weight = _env_float("POST_STRENGTH_WEIGHT", 0.6)
        post_jaccard_weight = _env_float("POST_JACCARD_WEIGHT", 0.4)
        post_brand_penalty = _env_float("POST_BRAND_PENALTY", 0.3)
        post_minmax = _env_bool("POST_MINMAX", True)
        brand_strict = _env_bool("BRAND_STRICT", False)
        debug = _env_bool("DEBUG", False)

        if candidate_topk <= 0:
            raise ValueError("CANDIDATE_TOPK must be > 0.")
        if retrieval_topk < candidate_topk:
            retrieval_topk = candidate_topk
        if top_k_default <= 0:
            raise ValueError("TOP_K_DEFAULT must be > 0.")
        if device not in {"cpu", "cuda", "mps", "auto"}:
            raise ValueError("DEVICE must be one of: auto, cpu, cuda, mps.")
        if post_mode not in {"blend", "tiebreak", "lex"}:
            raise ValueError("POST_MODE must be one of: blend, tiebreak, lex.")
        if tiebreak_topn <= 0:
            raise ValueError("TIEBREAK_TOPN must be > 0.")

        local_has_index = (local_index_path / "db").exists() or (
            local_index_path.name == "db" and local_index_path.exists()
        )
        if not local_has_index and not index_repo:
            raise ValueError(
                "INDEX_REPO is required when LOCAL_INDEX_PATH does not point to an existing index."
            )

        return cls(
            local_index_path=local_index_path,
            index_repo=index_repo,
            index_revision=index_revision,
            hf_token=hf_token,
            hf_data_dir=hf_data_dir,
            lancedb_table=lancedb_table,
            top_k_default=top_k_default,
            candidate_topk=candidate_topk,
            retrieval_topk=retrieval_topk,
            device=device,
            encoder_model_id=encoder_model_id,
            reranker_id=reranker_id,
            post_mode=post_mode,
            post_weight=post_weight,
            tiebreak_eps=tiebreak_eps,
            tiebreak_topn=tiebreak_topn,
            post_strength_weight=post_strength_weight,
            post_jaccard_weight=post_jaccard_weight,
            post_brand_penalty=post_brand_penalty,
            post_minmax=post_minmax,
            brand_strict=brand_strict,
            debug=debug,
        )
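A short, hedged sketch of how this configuration is read from the environment; it assumes a local LanceDB index already exists under `data/lancedb/db`, so `INDEX_REPO` can stay unset:

```python
# Illustrative usage of RuntimeConfig.from_env(); values here are examples only.
import os

os.environ["LOCAL_INDEX_PATH"] = "data/lancedb"  # assumes data/lancedb/db exists
os.environ["DEVICE"] = "cpu"
os.environ["POST_MODE"] = "tiebreak"

from thirawat_demo.runtime.config import RuntimeConfig

config = RuntimeConfig.from_env()
print(config.encoder_model_id, config.retrieval_topk, config.candidate_topk)
```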
src/thirawat_demo/runtime/index_loader.py
ADDED
@@ -0,0 +1,69 @@
"""Index download/load helpers for runtime app."""

from __future__ import annotations

from pathlib import Path

from huggingface_hub import snapshot_download

from .config import RuntimeConfig


def _can_connect_lancedb(path: Path) -> bool:
    try:
        import lancedb

        lancedb.connect(str(path))
        return True
    except Exception:
        return False


def _find_lancedb_dir(root: Path) -> Path:
    candidates = []
    if root.name == "db":
        candidates.append(root)
    candidates.append(root / "db")
    candidates.append(root)

    seen: set[Path] = set()
    for candidate in candidates:
        if candidate in seen:
            continue
        seen.add(candidate)
        if candidate.exists() and candidate.is_dir() and _can_connect_lancedb(candidate):
            return candidate

    for candidate in root.rglob("db"):
        if candidate.is_dir() and _can_connect_lancedb(candidate):
            return candidate

    raise FileNotFoundError(f"Could not find a valid LanceDB directory under: {root}")


def resolve_lancedb_dir(config: RuntimeConfig) -> Path:
    local = config.local_index_path
    if local.exists():
        return _find_lancedb_dir(local)

    if not config.index_repo:
        raise ValueError("INDEX_REPO must be configured when local index is unavailable.")

    download_root = config.hf_data_dir / "lancedb"
    try:
        download_root.mkdir(parents=True, exist_ok=True)
    except OSError:
        download_root = Path("data/lancedb")
        download_root.mkdir(parents=True, exist_ok=True)

    snapshot_path = Path(
        snapshot_download(
            repo_id=config.index_repo,
            repo_type="dataset",
            revision=config.index_revision,
            token=config.hf_token or False,
            local_dir=str(download_root),
        )
    )
    return _find_lancedb_dir(snapshot_path)
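A hedged sketch of how this loader is meant to be driven when no local index is present; the dataset repo id below is a placeholder, not a published dataset:

```python
# Illustrative only: resolve a LanceDB directory, downloading a Hub dataset
# snapshot when the local path is missing. "your-org/your-index-dataset" is a
# placeholder repo id.
import os

os.environ["INDEX_REPO"] = "your-org/your-index-dataset"
os.environ["LOCAL_INDEX_PATH"] = "missing-local-index"  # forces the download path

from thirawat_demo.runtime import RuntimeConfig, resolve_lancedb_dir

config = RuntimeConfig.from_env()
db_dir = resolve_lancedb_dir(config)  # a directory lancedb.connect() accepts
print(db_dir)
```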
src/thirawat_demo/runtime/peft_reranker.py
ADDED
@@ -0,0 +1,227 @@
"""PEFT-aware THIRAWAT reranker compatible with LanceDB rerank API."""

from __future__ import annotations

from os import PathLike
from pathlib import Path
from typing import List, Optional

import numpy as np
import pyarrow as pa
import torch

from thirawat_mapper.models.bms_pooling import bms_scores

DEFAULT_RERANKER_ID = "sidataplus/THIRAWAT-SapBERT"


def disable_mean_resizing() -> None:
    """Patch HF resize_token_embeddings default to mean_resizing=False."""
    try:
        from transformers import PreTrainedModel  # type: ignore
    except Exception:
        return
    if getattr(PreTrainedModel, "_mean_resizing_patched", False):
        return

    orig_resize = PreTrainedModel.resize_token_embeddings

    def patched(self, new_num_tokens=None, pad_to_multiple_of=None, mean_resizing=False, **kwargs):
        return orig_resize(
            self,
            new_num_tokens=new_num_tokens,
            pad_to_multiple_of=pad_to_multiple_of,
            mean_resizing=mean_resizing,
            **kwargs,
        )

    PreTrainedModel.resize_token_embeddings = patched  # type: ignore[assignment]
    PreTrainedModel._mean_resizing_patched = True  # type: ignore[attr-defined]


def _resolve_model_id(model_id: str | PathLike[str]) -> str:
    try:
        path = Path(model_id).expanduser()
    except TypeError:
        return str(model_id)
    if path.exists():
        return str(path.resolve())
    return str(model_id)


def _load_colbert_with_peft(
    model_id: str,
    device: Optional[str],
    max_query_len: int,
    max_doc_len: int,
):
    """Load ColBERT and attach PEFT adapter when the checkpoint is adapter-only."""

    disable_mean_resizing()

    try:
        from peft import PeftConfig, PeftModel  # type: ignore
    except Exception:
        return None

    try:
        peft_cfg = PeftConfig.from_pretrained(model_id)
    except Exception:
        return None

    try:
        from pylate.models import ColBERT  # type: ignore
        from safetensors.torch import load_file  # type: ignore
    except Exception:
        return None

    base_id = peft_cfg.base_model_name_or_path
    tok_kwargs = {"tokenizer_name_or_path": model_id}
    model = ColBERT(
        model_name_or_path=base_id,
        device=device or None,
        query_length=int(max_query_len),
        document_length=int(max_doc_len),
        tokenizer_kwargs=tok_kwargs,
    )

    encoder = model._first_module().auto_model  # type: ignore[attr-defined]
    try:
        encoder.resize_token_embeddings(len(model.tokenizer), mean_resizing=False)
    except Exception:
        pass

    try:
        encoder = PeftModel.from_pretrained(encoder, model_id, is_trainable=False)
        encoder.eval()
        model._first_module().auto_model = encoder  # type: ignore[attr-defined]
    except Exception:
        return None

    dense_dir = Path(model_id).expanduser() / "1_Dense"
    dense_weights = None
    if dense_dir.exists():
        safetensor_path = dense_dir / "model.safetensors"
        if safetensor_path.exists():
            try:
                dense_weights = load_file(safetensor_path)
            except Exception:
                dense_weights = None
        if dense_weights is None:
            legacy_bin = dense_dir / "pytorch_model.bin"
            if legacy_bin.exists():
                try:
                    dense_weights = torch.load(legacy_bin, map_location="cpu")
                except Exception:
                    dense_weights = None
    if dense_weights and len(model) > 1:
        try:
            model[1].load_state_dict(dense_weights, strict=False)
        except Exception:
            pass
    return model


class _PylateColbert:
    def __init__(self, model_id: str, device: Optional[str], max_query_len: int = 128, max_doc_len: int = 128):
        from pylate.models import ColBERT  # type: ignore

        peft_model = _load_colbert_with_peft(model_id, device, max_query_len, max_doc_len)
        if peft_model is not None:
            self.model = peft_model
        else:
            self.model = ColBERT(
                model_name_or_path=model_id,
                device=device or None,
                query_length=int(max_query_len),
                document_length=int(max_doc_len),
            )
        try:
            if device:
                self.model.to(torch.device(device))
        except Exception:
            pass

    def encode_query(self, text: str):
        tok = self.model.tokenize([text], is_query=True, pad=True)
        try:
            device = next(self.model.parameters()).device
            tok = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in tok.items()}
        except Exception:
            pass
        out = self.model(tok)
        return out["token_embeddings"], out.get("attention_mask")

    def encode_docs(self, texts: List[str]):
        tok = self.model.tokenize(texts, is_query=False, pad=False)
        try:
            device = next(self.model.parameters()).device
            tok = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in tok.items()}
        except Exception:
            pass
        out = self.model(tok)
        return out["token_embeddings"], out.get("attention_mask")


class ThirawatPeftReranker:
    """LanceDB-compatible THIRAWAT reranker with PEFT adapter support."""

    def __init__(
        self,
        model_id: str | PathLike[str] = DEFAULT_RERANKER_ID,
        *,
        device: str | None = None,
        return_score: str = "all",
        column: str = "profile_text",
    ) -> None:
        self.model_id = _resolve_model_id(model_id)
        self.device = device
        self.return_score = return_score
        self.score = return_score
        self.column = column
        self._scorer: Optional[_PylateColbert] = None

    @property
    def scorer(self) -> _PylateColbert:
        if self._scorer is None:
            self._scorer = _PylateColbert(self.model_id, self.device)
        return self._scorer

    def rerank(self, query: str | None, results: pa.Table) -> pa.Table:
        return self.rerank_vector(query, results)

    def rerank_vector(self, query: str | None, vector_results: pa.Table) -> pa.Table:
        col = self.column if self.column in vector_results.column_names else None
        if col is None and "profile_text_norm" in vector_results.column_names:
            col = "profile_text_norm"
        if col is None:
            raise ValueError(f"Candidate table missing '{self.column}' (or 'profile_text_norm') column.")

        texts: List[str] = vector_results[col].to_pylist()
        qtext = str(query or "").strip()
        if not qtext:
            raise ValueError("A text query is required for reranking.")

        q_emb, q_mask = self.scorer.encode_query(qtext)
        d_emb, d_mask = self.scorer.encode_docs(texts)
        with torch.no_grad():
            scores = bms_scores(q_emb, d_emb, q_mask, d_mask)
        scores = scores.detach().float().cpu().numpy().reshape(-1)
        df = vector_results.to_pandas()
        df["_relevance_score"] = scores.astype(np.float32)
        df = df.sort_values("_relevance_score", ascending=False, kind="mergesort").reset_index(drop=True)
        return pa.Table.from_pandas(df, preserve_index=False)

    def rerank_fts(self, query: str | None, fts_results: pa.Table) -> pa.Table:
        return self.rerank_vector(query, fts_results)

    def rerank_hybrid(self, query: str | None, vector_results: pa.Table, fts_results: pa.Table) -> pa.Table:
        return self.rerank_vector(query, vector_results)


__all__ = [
    "DEFAULT_RERANKER_ID",
    "ThirawatPeftReranker",
    "_load_colbert_with_peft",
    "disable_mean_resizing",
]
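A hedged sketch of how a LanceDB-compatible reranker like this one is attached to a vector search. It mirrors the call pattern used by `SearchService` below; the table name, query vector, and query text are placeholders:

```python
# Illustrative only: plug the reranker into a LanceDB vector search builder.
# Assumes the table has the "profile_text" column the reranker expects.
import lancedb
from thirawat_demo.runtime.peft_reranker import ThirawatPeftReranker

db = lancedb.connect("data/lancedb/db")
table = db.open_table("concepts_drug")
reranker = ThirawatPeftReranker(device="cpu")

query_vector = [0.0] * 768  # placeholder; normally the SapBERT CLS embedding of the query
results = (
    table.search(query_vector, query_type="vector")
    .limit(200)
    .rerank(reranker=reranker, query_string="amoxicillin 500 mg")
    .limit(100)
    .to_pandas()
)
```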
src/thirawat_demo/runtime/search_service.py
ADDED
@@ -0,0 +1,307 @@
"""Two-stage retrieval + reranking runtime service."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Sequence

import pandas as pd
from thirawat_mapper.infer.utils import enrich_with_post_scores, should_apply_post

from .config import RuntimeConfig
from .peft_reranker import ThirawatPeftReranker

DEFAULT_CONCEPT_CLASSES = [
    "Clinical Drug",
    "Quant Clinical Drug",
    "Clinical Drug Comp",
    "Clinical Drug Form",
    "Branded Drug",
    "Quant Branded Drug",
    "Branded Drug Comp",
    "Branded Drug Form",
    "Ingredient",
]


class SearchService:
    """Runtime search service with mandatory two-stage ranking."""

    def __init__(self, config: RuntimeConfig, lancedb_dir: Path) -> None:
        self.config = config
        self.lancedb_dir = lancedb_dir
        self._table = None
        self._vector_column: str | None = None
        self._embedder = None
        self._reranker = None
        self._normalize_text = None
        self._available_concept_classes = self._load_manifest_concept_classes()

    def startup(self) -> None:
        try:
            from thirawat_mapper.infer.utils import configure_torch_for_infer, resolve_device
            from thirawat_mapper.models import SapBERTEmbedder
            from thirawat_mapper.utils import connect_table, normalize_text_value
        except Exception as exc:  # pragma: no cover - import-time failure path
            raise RuntimeError(f"Failed to import thirawat-mapper runtime dependencies: {exc}") from exc

        device = self.config.device
        if device == "auto":
            device = resolve_device("auto")

        configure_torch_for_infer(device)

        try:
            table, vector_column = connect_table(self.lancedb_dir, self.config.lancedb_table)
        except Exception as exc:
            raise RuntimeError(
                f"Failed to connect LanceDB table '{self.config.lancedb_table}' in '{self.lancedb_dir}': {exc}"
            ) from exc

        try:
            embedder = SapBERTEmbedder(
                model_id=self.config.encoder_model_id,
                device=device,
                batch_size=64,
                max_length=128,
                pooling="cls",
            )
            reranker = ThirawatPeftReranker(
                model_id=self.config.reranker_id,
                device=device,
                return_score="all",
            )
        except Exception as exc:
            raise RuntimeError(f"Failed to initialize embedding/reranker models: {exc}") from exc

        # Force lazy components to load at startup so runtime errors fail fast.
        try:
            _ = embedder.encode(["startup warmup"])
            _ = reranker.scorer.encode_query("startup warmup")
        except Exception as exc:
            raise RuntimeError(f"Failed to warm up models at startup: {exc}") from exc

        self._table = table
        self._vector_column = vector_column
        self._embedder = embedder
        self._reranker = reranker
        self._normalize_text = normalize_text_value

    def available_concept_classes(self) -> list[str]:
        if self._available_concept_classes:
            return list(self._available_concept_classes)

        self._ensure_started()
        assert self._table is not None

        try:
            df = self._table.to_arrow().to_pandas()
        except Exception:
            return list(DEFAULT_CONCEPT_CLASSES)

        if "concept_class_id" not in df.columns:
            return list(DEFAULT_CONCEPT_CLASSES)

        values = sorted(
            {
                str(value).strip()
                for value in df["concept_class_id"].dropna().tolist()
                if str(value).strip()
            }
        )
        return values or list(DEFAULT_CONCEPT_CLASSES)

    def _load_manifest_concept_classes(self) -> list[str]:
        manifest_path = self.lancedb_dir / f"{self.config.lancedb_table}_manifest.json"
        if not manifest_path.exists():
            return list(DEFAULT_CONCEPT_CLASSES)
        try:
            payload = json.loads(manifest_path.read_text(encoding="utf-8"))
        except Exception:
            return list(DEFAULT_CONCEPT_CLASSES)
        raw_values = payload.get("concept_class_id")
        if not isinstance(raw_values, list):
            return list(DEFAULT_CONCEPT_CLASSES)
        values = [str(value).strip() for value in raw_values if str(value).strip()]
        return values or list(DEFAULT_CONCEPT_CLASSES)

    def _ensure_started(self) -> None:
        if self._table is None or self._vector_column is None or self._embedder is None or self._reranker is None:
            self.startup()

    def search(self, query: str, top_k: int, concept_class_ids: Sequence[str] | None = None) -> pd.DataFrame:
        where_clause = self._build_concept_class_where(concept_class_ids)
        if concept_class_ids is not None and where_clause is None:
            return pd.DataFrame(columns=self._ordered_output_columns())

        self._ensure_started()
        assert self._table is not None
        assert self._vector_column is not None
        assert self._embedder is not None
        assert self._reranker is not None
        assert self._normalize_text is not None

        if not query or not query.strip():
            return pd.DataFrame(columns=self._ordered_output_columns())

        normalized_query = self._normalize_text(query)
        query_emb = self._embedder.encode([normalized_query])[0]

        builder = self._table.search(
            query_emb.astype(float).tolist(),
            vector_column_name=self._vector_column,
            query_type="vector",
        )
        if where_clause:
            schema_names = set(getattr(getattr(self._table, "schema", None), "names", []))
            if "concept_class_id" not in schema_names:
                raise RuntimeError("Requested concept class filtering but table has no 'concept_class_id' column.")
            builder = builder.where(where_clause)

        builder = builder.distance_type("cosine").limit(self.config.retrieval_topk)

        # Mandatory two-stage retrieval: vector candidates then THIRAWAT reranking.
        result = builder.rerank(reranker=self._reranker, query_string=normalized_query).limit(
            self.config.candidate_topk
        )

        df = result.to_arrow().to_pandas()
        if df.empty:
            return pd.DataFrame(columns=self._ordered_output_columns())

        df = self._apply_post_scoring(df, normalized_query)
        df = self._finalize_scores(df)
        df = self._ensure_columns(
            df,
            required=[
                "concept_id",
                "concept_name",
                "concept_code",
                "vocabulary_id",
                "concept_class_id",
                "score",
                "retrieval_score",
            ],
        )
        df.insert(0, "rank", list(range(1, len(df) + 1)))
        return df[self._ordered_output_columns()].head(top_k).reset_index(drop=True)

    def search_batch(
        self,
        queries: list[str],
        top_k: int,
        concept_class_ids: Sequence[str] | None = None,
    ) -> pd.DataFrame:
        rows: list[dict[str, Any]] = []
        for query in queries:
            result = self.search(query, top_k, concept_class_ids=concept_class_ids)
            if result.empty:
                continue
            for _, row in result.iterrows():
                rows.append(
                    {
                        "query_text": query,
                        "rank": int(row["rank"]),
                        "concept_id": row["concept_id"],
                        "concept_name": row["concept_name"],
                        "concept_code": row["concept_code"],
                        "vocabulary_id": row["vocabulary_id"],
                        "concept_class_id": row.get("concept_class_id"),
                        "score": float(row["score"]),
                        "retrieval_score": float(row["retrieval_score"]),
                    }
                )
        if not rows:
            return pd.DataFrame(
                columns=[
                    "query_text",
                    "rank",
                    "concept_id",
                    "concept_name",
                    "concept_code",
                    "vocabulary_id",
                    "concept_class_id",
                    "score",
                    "retrieval_score",
                ]
            )
        return pd.DataFrame(rows)

    def _apply_post_scoring(self, df: pd.DataFrame, normalized_query: str) -> pd.DataFrame:
        if should_apply_post(self.config.post_mode, self.config.post_weight):
            return enrich_with_post_scores(
                df,
                normalized_query,
                post_strength_weight=self.config.post_strength_weight,
                post_jaccard_weight=self.config.post_jaccard_weight,
                post_brand_penalty=self.config.post_brand_penalty,
                post_minmax=self.config.post_minmax,
                post_weight=self.config.post_weight,
                prefer_brand=True,
                post_mode=self.config.post_mode,
                tiebreak_eps=self.config.tiebreak_eps,
                tiebreak_topn=self.config.tiebreak_topn,
                brand_strict=self.config.brand_strict,
            )

        base_col = "_relevance_score" if "_relevance_score" in df.columns else "score"
        if base_col in df.columns:
            return df.sort_values(base_col, ascending=False, kind="mergesort").reset_index(drop=True)
        return df.reset_index(drop=True)

    @staticmethod
    def _build_concept_class_where(concept_class_ids: Sequence[str] | None) -> str | None:
        if concept_class_ids is None:
            return None
        values = sorted({str(value).strip() for value in concept_class_ids if str(value).strip()})
        if not values:
            return None
        escaped = [value.replace("'", "''") for value in values]
        if len(escaped) == 1:
            return f"concept_class_id = '{escaped[0]}'"
        return "concept_class_id IN (" + ",".join(f"'{value}'" for value in escaped) + ")"

    @staticmethod
    def _ensure_columns(df: pd.DataFrame, required: list[str]) -> pd.DataFrame:
        for column in required:
            if column not in df.columns:
                df[column] = None
        return df

    @staticmethod
    def _finalize_scores(df: pd.DataFrame) -> pd.DataFrame:
        if "final_score" in df.columns:
            df["score"] = pd.to_numeric(df["final_score"], errors="coerce").fillna(0.0)
        elif "_relevance_score" in df.columns:
            df["score"] = pd.to_numeric(df["_relevance_score"], errors="coerce").fillna(0.0)
        elif "_distance" in df.columns:
            distance = pd.to_numeric(df["_distance"], errors="coerce").fillna(1.0)
            df["score"] = 1.0 - distance
        elif "score" in df.columns:
            df["score"] = pd.to_numeric(df["score"], errors="coerce").fillna(0.0)
        else:
            df["score"] = 0.0

        if "retrieval_score" in df.columns:
            df["retrieval_score"] = pd.to_numeric(df["retrieval_score"], errors="coerce").fillna(df["score"])
        elif "_distance" in df.columns:
            distance = pd.to_numeric(df["_distance"], errors="coerce").fillna(1.0)
            df["retrieval_score"] = 1.0 - distance
        else:
            df["retrieval_score"] = df["score"]

        return df.reset_index(drop=True)

    @staticmethod
    def _ordered_output_columns() -> list[str]:
        return [
            "rank",
            "concept_id",
            "concept_name",
            "concept_code",
            "vocabulary_id",
            "concept_class_id",
            "score",
            "retrieval_score",
        ]
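A hedged sketch of how the service is wired end to end, mirroring the pattern the Space UI uses (config, index resolution, startup, then a query); it assumes the models and a built index are available locally:

```python
# Illustrative only: config -> index -> service -> one query.
from thirawat_demo.runtime import RuntimeConfig, SearchService, resolve_lancedb_dir

config = RuntimeConfig.from_env()
db_dir = resolve_lancedb_dir(config)

service = SearchService(config=config, lancedb_dir=db_dir)
service.startup()  # fail fast: connects the table and warms up both models

hits = service.search("Augmentin 875/125", top_k=5, concept_class_ids=["Branded Drug"])
print(hits[["rank", "concept_id", "concept_name", "score"]])
```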
src/thirawat_demo/space_ui.py
ADDED
@@ -0,0 +1,310 @@
"""Gradio UI for Hugging Face Space runtime."""

from __future__ import annotations

import os
import tempfile
import traceback
from pathlib import Path
from typing import Sequence

import gradio as gr
import pandas as pd

from thirawat_demo.runtime import RuntimeConfig, SearchService, resolve_lancedb_dir

DEFAULT_CONCEPT_CLASSES = [
    "Clinical Drug",
    "Quant Clinical Drug",
    "Clinical Drug Comp",
    "Clinical Drug Form",
    "Branded Drug",
    "Quant Branded Drug",
    "Branded Drug Comp",
    "Branded Drug Form",
    "Ingredient",
]
DEFAULT_SINGLE_QUERY = "Augmentin 875/125"

_SERVICE: SearchService | None = None
_CONFIG: RuntimeConfig | None = None
_RESULT_COLUMNS = [
    "rank",
    "concept_id",
    "concept_name",
    "concept_code",
    "vocabulary_id",
    "concept_class",
]


def _get_service(*, startup: bool = True) -> tuple[SearchService, RuntimeConfig]:
    global _SERVICE, _CONFIG
    if _SERVICE is None or _CONFIG is None:
        _CONFIG = RuntimeConfig.from_env()
        db_dir = resolve_lancedb_dir(_CONFIG)
        _SERVICE = SearchService(config=_CONFIG, lancedb_dir=db_dir)
        if startup:
            _SERVICE.startup()
    return _SERVICE, _CONFIG


def _resolve_concept_class_choices() -> list[str]:
    try:
        config = RuntimeConfig.from_env()
    except Exception:
        return list(DEFAULT_CONCEPT_CLASSES)

    local = config.local_index_path
    local_has_index = (local / "db").exists() or (local.name == "db" and local.exists())
    if not local_has_index:
        return list(DEFAULT_CONCEPT_CLASSES)

    try:
        db_dir = resolve_lancedb_dir(config)
        service = SearchService(config=config, lancedb_dir=db_dir)
        return service.available_concept_classes()
    except Exception:
        return list(DEFAULT_CONCEPT_CLASSES)


def _top_k_default() -> int:
    raw = (os.getenv("TOP_K_DEFAULT") or "").strip()
    if not raw:
        return 10
    try:
        value = int(raw)
    except ValueError:
        return 10
    return value if value > 0 else 10


def _empty_results_df() -> pd.DataFrame:
    return pd.DataFrame(columns=_RESULT_COLUMNS)


def _to_display_results(results: pd.DataFrame) -> pd.DataFrame:
    if results is None or results.empty:
        return _empty_results_df()

    df = results.copy()
    concept_class = (
        df["concept_class_id"].astype(str)
        if "concept_class_id" in df.columns
        else pd.Series([""] * len(df), index=df.index, dtype=object)
    )
    concept_class = concept_class.where(~concept_class.isin(["None", "nan"]), "")
    df["concept_class"] = concept_class
    for column in _RESULT_COLUMNS:
        if column not in df.columns:
            df[column] = None
    return df[_RESULT_COLUMNS].reset_index(drop=True)


def _format_single_results(query: str, results: pd.DataFrame) -> tuple[str, pd.DataFrame]:
    if results.empty:
        return "No results found. Try a different query.", _empty_results_df()

    top = results.iloc[0]
    md = (
        f"## Results for: \"{query}\"\n\n"
        f"**Top match:** {top['concept_name']} (score {float(top['score']):.4f})\n\n"
        f"Rows returned: **{len(results)}**"
    )
    return md, _to_display_results(results)


def _validate_concept_classes(concept_class_ids: Sequence[str] | None) -> list[str] | None:
    if concept_class_ids is None:
        return None
    cleaned = [str(value).strip() for value in concept_class_ids if str(value).strip()]
    return cleaned


def search_single(query: str, top_k: int, concept_class_ids: Sequence[str]) -> tuple[str, pd.DataFrame]:
    try:
        selected = _validate_concept_classes(concept_class_ids)
        if selected is not None and not selected:
            return "Please select at least one concept class.", _empty_results_df()
        service, _ = _get_service()
        results = service.search(query=query, top_k=top_k, concept_class_ids=selected)
        return _format_single_results(query, results)
    except Exception as exc:  # pragma: no cover - runtime path
        stack = traceback.format_exc()
        print(stack)
        return f"Error: {exc}", _empty_results_df()


def search_batch(queries_text: str, top_k: int, concept_class_ids: Sequence[str]) -> tuple[str, gr.update]:
    try:
        lines = [line.strip() for line in (queries_text or "").splitlines() if line.strip()]
        if not lines:
            return "Please provide one query per line.", gr.update(visible=False)

        selected = _validate_concept_classes(concept_class_ids)
        if selected is not None and not selected:
            return "Please select at least one concept class.", gr.update(visible=False)

        service, _ = _get_service()
        results = service.search_batch(lines, top_k=top_k, concept_class_ids=selected)
        if results.empty:
            return "No batch results found.", gr.update(visible=False)

        out_dir = Path(tempfile.gettempdir()) / "thirawat_mapper_demo"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / "batch_results.csv"
        results.to_csv(out_path, index=False)

        md = (
            "## Batch Search Complete\n\n"
            f"- Queries processed: **{len(lines)}**\n"
            f"- Rows returned: **{len(results)}**\n"
            f"- Top-K per query: **{top_k}**"
        )
        return md, gr.update(value=str(out_path), visible=True)
    except Exception as exc:  # pragma: no cover - runtime path
        stack = traceback.format_exc()
        print(stack)
        return f"Error: {exc}", gr.update(visible=False)


def _clear_single(concept_classes: Sequence[str]) -> tuple[str, int, Sequence[str], str, pd.DataFrame]:
    return DEFAULT_SINGLE_QUERY, _top_k_default(), concept_classes, "", _empty_results_df()


def _clear_batch(concept_classes: Sequence[str]) -> tuple[str, int, Sequence[str], str, gr.update]:
    return "", _top_k_default(), concept_classes, "", gr.update(visible=False)


def build_demo() -> gr.Blocks:
    top_k_default = _top_k_default()
    concept_class_choices = _resolve_concept_class_choices()
    concept_class_default = list(concept_class_choices)

    with gr.Blocks(title="THIRAWAT Mapper Demo") as demo:
        gr.Markdown(
            """
            # THIRAWAT Mapper Demo
            Map non-standard drug terms to OMOP standard concepts using SapBERT retrieval + THIRAWAT reranking.
            """
        )
        with gr.Tabs():
            with gr.Tab("Single Query"):
                with gr.Row():
                    with gr.Column(scale=3):
                        query_input = gr.Textbox(
                            label="Drug query",
                            placeholder="e.g., aspirin, amoxicillin 500 mg, paracetamol",
                            value=DEFAULT_SINGLE_QUERY,
                            lines=1,
                        )
                    with gr.Column(scale=1):
                        domain = gr.Dropdown(
                            label="Domain",
                            choices=["Drug"],
                            value="Drug",
                            interactive=False,
                        )
                        top_k = gr.Slider(
                            minimum=1,
                            maximum=50,
                            value=top_k_default,
                            step=1,
                            label="Number of results",
                        )
                with gr.Row():
                    single_search = gr.Button("Search", variant="primary")
                    single_clear = gr.Button("Clear")
                with gr.Row():
                    single_concept_classes = gr.Dropdown(
                        label="Concept classes",
                        choices=concept_class_choices,
                        value=concept_class_default,
                        multiselect=True,
                        filterable=True,
                        allow_custom_value=False,
                        interactive=True,
                    )
                single_md = gr.Markdown(label="Summary")
                single_table = gr.Dataframe(
                    label="Results",
                    headers=_RESULT_COLUMNS,
                    value=_empty_results_df(),
                    interactive=False,
                )

            with gr.Tab("Batch Query"):
                with gr.Row():
                    with gr.Column(scale=3):
                        batch_input = gr.Textbox(
                            label="Queries (one per line)",
                            placeholder="aspirin\namoxicillin 500 mg\nparacetamol",
                            lines=10,
                        )
                    with gr.Column(scale=1):
                        batch_domain = gr.Dropdown(
                            label="Domain",
                            choices=["Drug"],
                            value="Drug",
                            interactive=False,
                        )
                        batch_top_k = gr.Slider(
                            minimum=1,
                            maximum=50,
                            value=top_k_default,
                            step=1,
                            label="Top-K per query",
                        )
                with gr.Row():
                    batch_search_btn = gr.Button("Process Batch", variant="primary")
                    batch_clear_btn = gr.Button("Clear")
                with gr.Row():
                    batch_concept_classes = gr.Dropdown(
                        label="Concept classes",
                        choices=concept_class_choices,
                        value=concept_class_default,
                        multiselect=True,
                        filterable=True,
                        allow_custom_value=False,
                        interactive=True,
                    )
                batch_md = gr.Markdown(label="Summary")
                batch_download = gr.DownloadButton(
                    label="Download CSV",
                    variant="secondary",
                    visible=False,
                )

        single_search.click(
            fn=search_single,
            inputs=[query_input, top_k, single_concept_classes],
            outputs=[single_md, single_table],
            api_name="search_single",
        )
        query_input.submit(
            fn=search_single,
            inputs=[query_input, top_k, single_concept_classes],
            outputs=[single_md, single_table],
            api_name=False,
        )
        single_clear.click(
            fn=lambda: _clear_single(concept_class_default),
            outputs=[query_input, top_k, single_concept_classes, single_md, single_table],
            api_name=False,
        )
        batch_search_btn.click(
            fn=search_batch,
            inputs=[batch_input, batch_top_k, batch_concept_classes],
            outputs=[batch_md, batch_download],
            api_name="search_batch",
        )
        batch_clear_btn.click(
            fn=lambda: _clear_batch(concept_class_default),
            outputs=[batch_input, batch_top_k, batch_concept_classes, batch_md, batch_download],
            api_name=False,
        )

        domain.change(lambda: "Drug", outputs=domain, api_name=False)
        batch_domain.change(lambda: "Drug", outputs=batch_domain, api_name=False)

    return demo
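The repository's own `app.py` is the actual Space entry point; a minimal sketch of what such an entry point looks like when it simply mounts this UI (illustrative, not the shipped file):

```python
# Sketch of a Space entry point that mounts the Blocks UI built above.
from thirawat_demo.space_ui import build_demo

demo = build_demo()

if __name__ == "__main__":
    demo.launch()
```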
tests/conftest.py
ADDED
@@ -0,0 +1,12 @@
"""Test bootstrap for src-layout imports."""

from __future__ import annotations

from pathlib import Path
import sys

ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))
tests/test_build_lancedb_index.py
ADDED
@@ -0,0 +1,53 @@
from __future__ import annotations

import importlib.util
from pathlib import Path
import sys
import types


def _load_script_module():
    root = Path(__file__).resolve().parents[1]
    script_path = root / "scripts" / "offline" / "build_lancedb_index.py"
    spec = importlib.util.spec_from_file_location("build_lancedb_index", script_path)
    assert spec is not None
    assert spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def test_default_concept_classes_exact():
    module = _load_script_module()
    expected = [
        "Clinical Drug",
        "Quant Clinical Drug",
        "Clinical Drug Comp",
        "Clinical Drug Form",
        "Branded Drug",
        "Quant Branded Drug",
        "Branded Drug Comp",
        "Branded Drug Form",
        "Ingredient",
    ]
    assert module.DEFAULT_CONCEPT_CLASSES == expected


def test_resolve_device_prefers_mps_on_darwin(monkeypatch):
    module = _load_script_module()
    monkeypatch.setattr(module.platform, "system", lambda: "Darwin")

    fake_torch = types.SimpleNamespace(
        cuda=types.SimpleNamespace(is_available=lambda: True),
        backends=types.SimpleNamespace(mps=types.SimpleNamespace(is_available=lambda: True)),
    )
    monkeypatch.setitem(sys.modules, "torch", fake_torch)
    assert module.resolve_device("auto") == "mps"


def test_resolve_device_explicit_passthrough():
    module = _load_script_module()
    assert module.resolve_device("cpu") == "cpu"
    assert module.resolve_device("cuda") == "cuda"
    assert module.resolve_device("mps") == "mps"
tests/test_peft_reranker.py
ADDED
@@ -0,0 +1,87 @@
from __future__ import annotations

import sys
import types

import thirawat_demo.runtime.peft_reranker as peft_mod


def test_load_colbert_with_peft_resizes_before_adapter(monkeypatch):
    events: list[tuple] = []

    class FakeEncoder:
        def resize_token_embeddings(self, size, mean_resizing=False):
            events.append(("resize", size, mean_resizing))

    class FakeFirstModule:
        def __init__(self):
            self.auto_model = FakeEncoder()

    class FakeDense:
        def load_state_dict(self, state_dict, strict=False):
            events.append(("dense_load", strict))

    class FakeColBERT:
        def __init__(self, *args, **kwargs):
            events.append(("colbert_init", kwargs.get("model_name_or_path"), kwargs.get("tokenizer_kwargs")))
            self._first = FakeFirstModule()
            self._dense = FakeDense()
            self.tokenizer = [0, 1, 2, 3]

        def _first_module(self):
            return self._first

        def __len__(self):
            return 2

        def __getitem__(self, idx):
            assert idx == 1
            return self._dense

    class FakePeftCfg:
        base_model_name_or_path = "cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR"

    class FakePeftConfig:
        @staticmethod
        def from_pretrained(model_id):
            events.append(("peft_config", model_id))
            return FakePeftCfg()

    class FakePeftModel:
        @staticmethod
        def from_pretrained(encoder, model_id, is_trainable=False):
            events.append(("peft_load", model_id, is_trainable))

            class Wrapped:
                def eval(self):
                    events.append(("peft_eval",))

            return Wrapped()

    pylate_module = types.ModuleType("pylate")
    pylate_models = types.ModuleType("pylate.models")
    pylate_models.ColBERT = FakeColBERT
    pylate_module.models = pylate_models

    safetensors_module = types.ModuleType("safetensors")
    safetensors_torch = types.ModuleType("safetensors.torch")
    safetensors_torch.load_file = lambda _: {}
    safetensors_module.torch = safetensors_torch

    monkeypatch.setattr(peft_mod, "disable_mean_resizing", lambda: events.append(("disable_mean_resizing",)))
    monkeypatch.setitem(
        sys.modules,
        "peft",
        types.SimpleNamespace(PeftConfig=FakePeftConfig, PeftModel=FakePeftModel),
    )
    monkeypatch.setitem(sys.modules, "pylate", pylate_module)
    monkeypatch.setitem(sys.modules, "pylate.models", pylate_models)
    monkeypatch.setitem(sys.modules, "safetensors", safetensors_module)
    monkeypatch.setitem(sys.modules, "safetensors.torch", safetensors_torch)

    model = peft_mod._load_colbert_with_peft("sidataplus/THIRAWAT-SapBERT", "cpu", 128, 128)
    assert model is not None

    resize_idx = next(idx for idx, entry in enumerate(events) if entry[0] == "resize")
    peft_idx = next(idx for idx, entry in enumerate(events) if entry[0] == "peft_load")
    assert resize_idx < peft_idx
tests/test_runtime_config.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+from __future__ import annotations
+
+from pathlib import Path
+
+import pytest
+
+from thirawat_demo.runtime.config import RuntimeConfig
+
+
+def test_from_env_requires_index_repo_when_local_missing(monkeypatch):
+    monkeypatch.delenv("INDEX_REPO", raising=False)
+    monkeypatch.setenv("LOCAL_INDEX_PATH", "missing-local-index")
+    with pytest.raises(ValueError):
+        RuntimeConfig.from_env()
+
+
+def test_from_env_accepts_local_db_without_repo(monkeypatch, tmp_path: Path):
+    local_index = tmp_path / "data" / "lancedb" / "db"
+    local_index.mkdir(parents=True)
+
+    monkeypatch.delenv("INDEX_REPO", raising=False)
+    monkeypatch.setenv("LOCAL_INDEX_PATH", str(local_index.parent))
+    config = RuntimeConfig.from_env()
+    assert config.local_index_path == local_index.parent
+    assert config.index_repo is None
+
+
+def test_retrieval_topk_is_raised_to_candidate_topk(monkeypatch, tmp_path: Path):
+    local_index = tmp_path / "idx" / "db"
+    local_index.mkdir(parents=True)
+
+    monkeypatch.setenv("LOCAL_INDEX_PATH", str(local_index.parent))
+    monkeypatch.setenv("CANDIDATE_TOPK", "200")
+    monkeypatch.setenv("RETRIEVAL_TOPK", "10")
+
+    config = RuntimeConfig.from_env()
+    assert config.retrieval_topk == 200
+
+
+def test_device_defaults_to_auto(monkeypatch, tmp_path: Path):
+    local_index = tmp_path / "idx" / "db"
+    local_index.mkdir(parents=True)
+    monkeypatch.setenv("LOCAL_INDEX_PATH", str(local_index.parent))
+    monkeypatch.delenv("DEVICE", raising=False)
+
+    config = RuntimeConfig.from_env()
+    assert config.device == "auto"
+
+
+def test_post_config_defaults(monkeypatch, tmp_path: Path):
+    local_index = tmp_path / "idx" / "db"
+    local_index.mkdir(parents=True)
+    monkeypatch.setenv("LOCAL_INDEX_PATH", str(local_index.parent))
+    monkeypatch.delenv("POST_MODE", raising=False)
+    monkeypatch.delenv("POST_WEIGHT", raising=False)
+    monkeypatch.delenv("TIEBREAK_EPS", raising=False)
+    monkeypatch.delenv("TIEBREAK_TOPN", raising=False)
+    monkeypatch.delenv("POST_STRENGTH_WEIGHT", raising=False)
+    monkeypatch.delenv("POST_JACCARD_WEIGHT", raising=False)
+    monkeypatch.delenv("POST_BRAND_PENALTY", raising=False)
+    monkeypatch.delenv("POST_MINMAX", raising=False)
+    monkeypatch.delenv("BRAND_STRICT", raising=False)
+
+    config = RuntimeConfig.from_env()
+    assert config.post_mode == "tiebreak"
+    assert config.post_weight == 0.05
+    assert config.tiebreak_eps == 0.01
+    assert config.tiebreak_topn == 50
+    assert config.post_strength_weight == 0.6
+    assert config.post_jaccard_weight == 0.4
+    assert config.post_brand_penalty == 0.3
+    assert config.post_minmax is True
+    assert config.brand_strict is False
+
+
+def test_invalid_post_mode_raises(monkeypatch, tmp_path: Path):
+    local_index = tmp_path / "idx" / "db"
+    local_index.mkdir(parents=True)
+    monkeypatch.setenv("LOCAL_INDEX_PATH", str(local_index.parent))
+    monkeypatch.setenv("POST_MODE", "unsupported")
+    with pytest.raises(ValueError):
+        RuntimeConfig.from_env()
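The assertions above pin down three behaviors of RuntimeConfig.from_env without showing the implementation: RETRIEVAL_TOPK is raised to at least CANDIDATE_TOPK, DEVICE defaults to "auto", and an unknown POST_MODE is rejected. A minimal sketch of env parsing consistent with those assertions follows; it is illustrative only, and the allowed mode set and numeric defaults are assumptions (the real logic lives in src/thirawat_demo/runtime/config.py):

import os

# Assumed mode set; only the "tiebreak" default and the rejection of unknown
# values are actually pinned down by the tests above.
_ALLOWED_POST_MODES = {"off", "blend", "tiebreak"}

def _read_env_sketch() -> dict:
    # Defaults here are assumptions for illustration.
    candidate_topk = int(os.environ.get("CANDIDATE_TOPK", "100"))
    retrieval_topk = int(os.environ.get("RETRIEVAL_TOPK", "200"))
    post_mode = os.environ.get("POST_MODE", "tiebreak")
    if post_mode not in _ALLOWED_POST_MODES:
        raise ValueError(f"Unsupported POST_MODE: {post_mode!r}")
    return {
        "candidate_topk": candidate_topk,
        # Never retrieve fewer rows than the reranker needs as candidates.
        "retrieval_topk": max(retrieval_topk, candidate_topk),
        "device": os.environ.get("DEVICE", "auto"),
        "post_mode": post_mode,
    }
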
tests/test_search_service.py
ADDED
|
@@ -0,0 +1,152 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+
+import thirawat_demo.runtime.search_service as service_mod
+from thirawat_demo.runtime.config import RuntimeConfig
+from thirawat_demo.runtime.search_service import SearchService
+
+
+def _make_config(tmp_path: Path) -> RuntimeConfig:
+    return RuntimeConfig(
+        local_index_path=tmp_path / "data" / "lancedb",
+        index_repo=None,
+        index_revision="main",
+        hf_token=None,
+        hf_data_dir=tmp_path / "data",
+        lancedb_table="concepts_drug",
+        top_k_default=10,
+        candidate_topk=100,
+        retrieval_topk=200,
+        device="cpu",
+        encoder_model_id="cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR",
+        reranker_id="sidataplus/THIRAWAT-SapBERT",
+        post_mode="tiebreak",
+        post_weight=0.05,
+        tiebreak_eps=0.01,
+        tiebreak_topn=50,
+        post_strength_weight=0.6,
+        post_jaccard_weight=0.4,
+        post_brand_penalty=0.3,
+        post_minmax=True,
+        brand_strict=False,
+        debug=False,
+    )
+
+
+def test_build_concept_class_where_sorted_and_escaped():
+    expr = SearchService._build_concept_class_where(["Ingredient", "Clinical Drug", "O'Reilly", "Ingredient"])
+    assert expr == "concept_class_id IN ('Clinical Drug','Ingredient','O''Reilly')"
+
+
+def test_apply_post_scoring_uses_mapper_tiebreak(monkeypatch, tmp_path: Path):
+    config = _make_config(tmp_path)
+    service = SearchService(config=config, lancedb_dir=tmp_path)
+    source = pd.DataFrame(
+        {
+            "concept_name": ["a", "b"],
+            "profile_text": ["", ""],
+            "_relevance_score": [0.9, 0.89],
+        }
+    )
+    called: dict[str, object] = {}
+
+    def fake_should_apply_post(mode: str, weight: float) -> bool:
+        called["mode"] = mode
+        called["weight"] = weight
+        return True
+
+    def fake_enrich(df: pd.DataFrame, query_text_norm: str, **kwargs):
+        called["query"] = query_text_norm
+        called["kwargs"] = kwargs
+        out = df.copy()
+        out["final_score"] = [0.1, 0.2]
+        return out.iloc[[1, 0]].reset_index(drop=True)
+
+    monkeypatch.setattr(service_mod, "should_apply_post", fake_should_apply_post)
+    monkeypatch.setattr(service_mod, "enrich_with_post_scores", fake_enrich)
+
+    out = service._apply_post_scoring(source, "aspirin 81 mg")
+    assert out["concept_name"].tolist() == ["b", "a"]
+    assert called["mode"] == "tiebreak"
+    assert called["weight"] == 0.05
+    assert called["query"] == "aspirin 81 mg"
+    assert called["kwargs"]["tiebreak_eps"] == 0.01
+    assert called["kwargs"]["tiebreak_topn"] == 50
+
+
+def test_search_applies_concept_class_filter(monkeypatch, tmp_path: Path):
+    config = _make_config(tmp_path)
+    service = SearchService(config=config, lancedb_dir=tmp_path)
+
+    class FakeBuilder:
+        def __init__(self):
+            self.where_clause = None
+
+        def where(self, clause):
+            self.where_clause = clause
+            return self
+
+        def distance_type(self, _):
+            return self
+
+        def limit(self, _):
+            return self
+
+        def rerank(self, reranker=None, query_string=None):
+            return self
+
+        def to_arrow(self):
+            frame = pd.DataFrame(
+                {
+                    "concept_id": [1191],
+                    "concept_name": ["aspirin"],
+                    "concept_code": ["1191"],
+                    "vocabulary_id": ["RxNorm"],
+                    "_relevance_score": [0.998],
+                    "_distance": [0.2],
+                    "concept_class_id": ["Ingredient"],
+                }
+            )
+            return pa.Table.from_pandas(frame, preserve_index=False)
+
+    class FakeSchema:
+        names = ["concept_id", "concept_name", "concept_code", "vocabulary_id", "concept_class_id", "vector"]
+
+    class FakeTable:
+        schema = FakeSchema()
+
+        def __init__(self, builder: FakeBuilder):
+            self._builder = builder
+
+        def search(self, *args, **kwargs):
+            return self._builder
+
+    class FakeEmbedder:
+        def encode(self, texts):
+            return np.array([[0.0, 1.0]], dtype=float)
+
+    builder = FakeBuilder()
+    service._table = FakeTable(builder)
+    service._vector_column = "vector"
+    service._embedder = FakeEmbedder()
+    service._reranker = object()
+    service._normalize_text = lambda value: value.strip()
+    monkeypatch.setattr(SearchService, "_apply_post_scoring", lambda self, df, q: df)
+
+    result = service.search("aspirin", top_k=5, concept_class_ids=["Ingredient", "Clinical Drug"])
+    assert builder.where_clause == "concept_class_id IN ('Clinical Drug','Ingredient')"
+    assert result["concept_name"].tolist() == ["aspirin"]
+    assert result["concept_class_id"].tolist() == ["Ingredient"]
+    assert result["retrieval_score"].tolist() == [0.8]
+
+
+def test_search_returns_empty_when_no_concept_class_selected(tmp_path: Path):
+    config = _make_config(tmp_path)
+    service = SearchService(config=config, lancedb_dir=tmp_path)
+    out = service.search("aspirin", top_k=10, concept_class_ids=[])
+    assert out.empty
|
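test_build_concept_class_where_sorted_and_escaped pins down the exact filter string: duplicates removed, values sorted, and embedded single quotes doubled so a value such as O'Reilly stays inside one SQL string literal. A standalone sketch that satisfies that assertion (illustrative only; in the Space this is the SearchService._build_concept_class_where static method):

def build_concept_class_where(concept_class_ids: list[str]) -> str:
    # Deduplicate and sort for a stable clause, then double single quotes
    # so each value survives as a single SQL literal.
    quoted = ",".join(
        "'{}'".format(value.replace("'", "''"))
        for value in sorted(set(concept_class_ids))
    )
    return f"concept_class_id IN ({quoted})"

# Reproduces the expected clause from the test above.
assert build_concept_class_where(
    ["Ingredient", "Clinical Drug", "O'Reilly", "Ingredient"]
) == "concept_class_id IN ('Clinical Drug','Ingredient','O''Reilly')"
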
tests/test_space_ui.py
ADDED
|
@@ -0,0 +1,97 @@
+from __future__ import annotations
+
+import pandas as pd
+
+from thirawat_demo import space_ui
+
+
+def test_to_display_results_replaces_scores_with_concept_class():
+    raw = pd.DataFrame(
+        {
+            "rank": [1],
+            "concept_id": [123],
+            "concept_name": ["aspirin"],
+            "concept_code": ["1191"],
+            "vocabulary_id": ["RxNorm"],
+            "concept_class_id": ["Ingredient"],
+            "score": [0.9],
+            "retrieval_score": [0.8],
+        }
+    )
+    out = space_ui._to_display_results(raw)
+    assert out.columns.tolist() == [
+        "rank",
+        "concept_id",
+        "concept_name",
+        "concept_code",
+        "vocabulary_id",
+        "concept_class",
+    ]
+    assert out.iloc[0]["concept_class"] == "Ingredient"
+    assert "score" not in out.columns
+    assert "retrieval_score" not in out.columns
+
+
+def test_to_display_results_handles_missing_concept_class_id():
+    raw = pd.DataFrame(
+        {
+            "rank": [1],
+            "concept_id": [123],
+            "concept_name": ["aspirin"],
+            "concept_code": ["1191"],
+            "vocabulary_id": ["RxNorm"],
+        }
+    )
+    out = space_ui._to_display_results(raw)
+    assert out.iloc[0]["concept_class"] == ""
+
+
+def test_format_single_results_keeps_score_text_but_returns_projected_table():
+    raw = pd.DataFrame(
+        {
+            "rank": [1],
+            "concept_id": [123],
+            "concept_name": ["aspirin"],
+            "concept_code": ["1191"],
+            "vocabulary_id": ["RxNorm"],
+            "concept_class_id": ["Ingredient"],
+            "score": [0.91],
+            "retrieval_score": [0.81],
+        }
+    )
+    md, table = space_ui._format_single_results("aspirin", raw)
+    assert "score 0.9100" in md
+    assert "concept_class" in table.columns
+    assert "score" not in table.columns
+
+
+def test_search_single_requires_at_least_one_concept_class():
+    md, table = space_ui.search_single("aspirin", 10, [])
+    assert md == "Please select at least one concept class."
+    assert table.empty
+
+
+def test_search_batch_requires_at_least_one_concept_class():
+    md, download = space_ui.search_batch("aspirin", 10, [])
+    assert md == "Please select at least one concept class."
+    assert download["visible"] is False
+
+
+def test_clear_single_resets_default_query_and_concept_classes():
+    concept_classes = ["Ingredient", "Clinical Drug"]
+    query, top_k, returned_classes, md, table = space_ui._clear_single(concept_classes)
+    assert query == space_ui.DEFAULT_SINGLE_QUERY
+    assert top_k == 10
+    assert returned_classes == concept_classes
+    assert md == ""
+    assert table.empty
+
+
+def test_clear_batch_resets_concept_classes_and_hides_download():
+    concept_classes = ["Ingredient", "Clinical Drug"]
+    queries, top_k, returned_classes, md, download = space_ui._clear_batch(concept_classes)
+    assert queries == ""
+    assert top_k == 10
+    assert returned_classes == concept_classes
+    assert md == ""
+    assert download["visible"] is False
|
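The UI tests constrain _to_display_results to a fixed six-column projection: score columns dropped, concept_class_id surfaced as concept_class, and an empty string when that column is missing. A minimal sketch consistent with those assertions (hypothetical; not necessarily the actual helper in src/thirawat_demo/space_ui.py):

import pandas as pd

_DISPLAY_COLUMNS = [
    "rank", "concept_id", "concept_name", "concept_code", "vocabulary_id", "concept_class",
]

def to_display_results(raw: pd.DataFrame) -> pd.DataFrame:
    out = raw.copy()
    # Fall back to an empty string when concept_class_id is absent, then keep
    # only the display columns, which drops score and retrieval_score.
    out["concept_class"] = out.get("concept_class_id", pd.Series("", index=out.index))
    return out[_DISPLAY_COLUMNS]
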
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff