na399 commited on
Commit
25c66a0
·
verified ·
1 Parent(s): ca78da0

Deploy THIRAWAT mapper app

Browse files
.gitattributes CHANGED
@@ -1,35 +1,2 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ # Auto detect text files and perform LF normalization
2
+ * text=auto
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .venv/
2
+ __pycache__/
3
+ .pytest_cache/
4
+ .mypy_cache/
5
+ .ruff_cache/
6
+ .DS_Store
7
+
8
+ data/
9
+ !data/.gitkeep
10
+
11
+ temp/*.tmp
AGENTS.md ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Repository Guidelines
2
+
3
+ ## Project Structure & Module Organization
4
+ - `src/thirawat_demo/runtime/`: runtime configuration, index loading, reranking, and search orchestration.
5
+ - `src/thirawat_demo/space_ui.py`: Gradio interface logic; `app.py` is the local/Space entrypoint.
6
+ - `scripts/offline/`: offline pipeline scripts (`build_duckdb.py`, `build_lancedb_index.py`, `publish_index_hf.py`).
7
+ - `tests/`: pytest suite for runtime behavior and indexing workflow.
8
+ - `data/`: local/generated artifacts (DuckDB + LanceDB); avoid treating as source code.
9
+ - `docs/` and `spec/`: reference notes and specification examples.
10
+
11
+ ## Build, Test, and Development Commands
12
+ - `uv sync --python 3.11`: install dependencies from `uv.lock`.
13
+ - `uv run pytest`: run all tests in `tests/`.
14
+ - `uv run pytest tests/test_search_service.py -q`: run a focused test file.
15
+ - `uv run python app.py`: start the Gradio runtime app locally.
16
+ - `uv run python scripts/offline/build_duckdb.py --help`: inspect offline build options before running data jobs.
17
+
18
+ ## Coding Style & Naming Conventions
19
+ - Target Python 3.11+, 4-space indentation, and PEP 8-compatible formatting.
20
+ - Prefer explicit type hints and small, single-purpose functions.
21
+ - Naming: `snake_case` for modules/functions/variables, `PascalCase` for classes, `UPPER_SNAKE_CASE` for constants.
22
+ - Keep the architecture boundary strict: heavy indexing stays in `scripts/offline/`; runtime code in `src/thirawat_demo/runtime/`.
23
+
24
+ ## Testing Guidelines
25
+ - Framework: `pytest` (configured in `pyproject.toml` with `testpaths = ["tests"]`).
26
+ - File naming: `tests/test_*.py`; test function naming: `test_*`.
27
+ - Use `monkeypatch` for environment-variable behavior and `tmp_path` for filesystem-dependent cases.
28
+ - Add or update tests with every behavioral change, especially around config parsing, retrieval/reranking flow, and UI integration points.
29
+
30
+ ## Commit & Pull Request Guidelines
31
+ - Current history is minimal and uses short subject lines (example: `Initial commit`).
32
+ - Use concise, imperative commit subjects (prefer <= 72 chars), e.g., `Add post-score validation for runtime config`.
33
+ - PRs should include: purpose, key changes, verification commands run, and any required env var/data changes.
34
+ - Include screenshots for UI-impacting changes (`space_ui.py`) and note any model/index compatibility implications.
35
+
36
+ ## Security & Configuration Tips
37
+ - Never commit secrets (for example `HF_TOKEN`).
38
+ - Keep runtime config in environment variables (`INDEX_REPO`, `DEVICE`, `LANCEDB_TABLE`, etc.).
39
+ - Publish large index artifacts through the HF dataset workflow instead of Git.
README.md CHANGED
@@ -1,14 +1,122 @@
1
- ---
2
- title: THIRAWAT Mapper Demo
3
- emoji: 😻
4
- colorFrom: red
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 6.3.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- short_description: Demo of THIRAWAT Mapper
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # THIRAWAT Mapper Demo
2
+
3
+ This repository is intentionally split into:
4
+
5
+ 1. **Offline indexing pipeline** (Athena -> DuckDB -> LanceDB -> HF dataset upload)
6
+ 2. **Hugging Face Space runtime app** (download prebuilt index + serve Gradio UI)
7
+
8
+ ## Separation Rule
9
+
10
+ The Space runtime must **not** build indexes.
11
+ All heavy data preparation happens offline via scripts in `/scripts/offline`.
12
+
13
+ ## Project Layout
14
+
15
+ - `/Users/na399/GitHub/sidataplus/THIRAWAT-mapper-demo/scripts/offline/build_duckdb.py`
16
+ - `/Users/na399/GitHub/sidataplus/THIRAWAT-mapper-demo/scripts/offline/build_lancedb_index.py`
17
+ - `/Users/na399/GitHub/sidataplus/THIRAWAT-mapper-demo/scripts/offline/publish_index_hf.py`
18
+ - `/Users/na399/GitHub/sidataplus/THIRAWAT-mapper-demo/src/thirawat_demo/runtime/config.py`
19
+ - `/Users/na399/GitHub/sidataplus/THIRAWAT-mapper-demo/src/thirawat_demo/runtime/index_loader.py`
20
+ - `/Users/na399/GitHub/sidataplus/THIRAWAT-mapper-demo/src/thirawat_demo/runtime/search_service.py`
21
+ - `/Users/na399/GitHub/sidataplus/THIRAWAT-mapper-demo/src/thirawat_demo/space_ui.py`
22
+ - `/Users/na399/GitHub/sidataplus/THIRAWAT-mapper-demo/app.py`
23
+
24
+ ## Setup
25
+
26
+ ```bash
27
+ uv sync --python 3.11
28
+ ```
29
+
30
+ ## Offline Build Flow
31
+
32
+ ### 1) Athena -> DuckDB
33
+
34
+ ```bash
35
+ uv run python scripts/offline/build_duckdb.py \
36
+ --athena-dir /path/to/athena-export \
37
+ --out /Users/na399/GitHub/sidataplus/THIRAWAT-mapper-demo/data/derived/concepts.duckdb \
38
+ --overwrite
39
+ ```
40
+
41
+ ### 2) DuckDB -> LanceDB index
42
+
43
+ `--device auto` behavior:
44
+ - macOS + MPS available: `mps`
45
+ - otherwise if CUDA available: `cuda`
46
+ - otherwise: `cpu`
47
+
48
+ ```bash
49
+ uv run python scripts/offline/build_lancedb_index.py \
50
+ --duckdb /Users/na399/GitHub/sidataplus/THIRAWAT-mapper-demo/data/derived/concepts.duckdb \
51
+ --out-db /Users/na399/GitHub/sidataplus/THIRAWAT-mapper-demo/data/lancedb/db \
52
+ --table concepts_drug \
53
+ --device auto
54
+ ```
55
+
56
+ Default concept classes used by the builder:
57
+
58
+ - Clinical Drug
59
+ - Quant Clinical Drug
60
+ - Clinical Drug Comp
61
+ - Clinical Drug Form
62
+ - Branded Drug
63
+ - Quant Branded Drug
64
+ - Branded Drug Comp
65
+ - Branded Drug Form
66
+ - Ingredient
67
+
68
+ ### 3) Publish index artifact to HF dataset repo
69
+
70
+ ```bash
71
+ export HF_TOKEN=hf_xxx
72
+ uv run python scripts/offline/publish_index_hf.py \
73
+ --repo-id your-org/thirawat-mapper-demo-index \
74
+ --source-dir /Users/na399/GitHub/sidataplus/THIRAWAT-mapper-demo/data/lancedb \
75
+ --revision main
76
+ ```
77
+
78
+ ## HF Space Deployment Flow
79
+
80
+ In Space settings, configure env vars as needed:
81
+
82
+ - `INDEX_REPO` (required unless local index is baked into image)
83
+ - `INDEX_REVISION` (default `main`)
84
+ - `LANCEDB_TABLE` (default `concepts_drug`)
85
+ - `DEVICE` (default `auto`)
86
+ - `TOP_K_DEFAULT`, `CANDIDATE_TOPK`, `RETRIEVAL_TOPK` (optional)
87
+ - `POST_MODE` (default `tiebreak`; supported: `blend|tiebreak|lex`)
88
+ - `POST_WEIGHT` (default `0.05`; used for `blend`)
89
+ - `TIEBREAK_EPS` (default `0.01`)
90
+ - `TIEBREAK_TOPN` (default `50`)
91
+ - `POST_STRENGTH_WEIGHT` (default `0.6`)
92
+ - `POST_JACCARD_WEIGHT` (default `0.4`)
93
+ - `POST_BRAND_PENALTY` (default `0.3`)
94
+ - `POST_MINMAX` (default `true`)
95
+ - `BRAND_STRICT` (default `false`)
96
+ - `HF_TOKEN` (optional for private repos)
97
+
98
+ Space entrypoint:
99
+
100
+ ```bash
101
+ python app.py
102
+ ```
103
+
104
+ ## Local Runtime Test
105
+
106
+ ```bash
107
+ export INDEX_REPO=your-org/thirawat-mapper-demo-index
108
+ export DEVICE=auto
109
+ uv run python app.py
110
+ ```
111
+
112
+ ## Notes
113
+
114
+ - Runtime search enforces two-stage retrieval: SapBERT vector retrieval + THIRAWAT reranking.
115
+ - Runtime reranker loads PEFT adapters directly (trainer-style) and does not merge checkpoints to disk at startup.
116
+ - Runtime ranking applies THIRAWAT deterministic post/tie-break rules.
117
+ - Warning note: `No sentence-transformers model found with name cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR.` can appear during reranker startup. This is informational in the current setup: PyLate ColBERT first checks for a Sentence-Transformers-packaged model and then falls back to building from the base Hugging Face encoder.
118
+ - Retrieval embeddings use Hugging Face `transformers` directly, while reranking uses PyLate ColBERT (which relies on Sentence-Transformers internals). So `sentence-transformers` is still required for reranking even though retrieval itself does not depend on it.
119
+ - Gradio UI includes concept-class multi-select filters (include-only). Defaults come from index manifest when present.
120
+ - Domain is fixed to `Drug` in this v1 app.
121
+ - Upstream doc issues are tracked in:
122
+ - `/Users/na399/GitHub/sidataplus/THIRAWAT-mapper-demo/docs/upstream-instruction-issues.md`
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hugging Face Space entrypoint."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from pathlib import Path
7
+ import socket
8
+ import sys
9
+
10
+ ROOT = Path(__file__).resolve().parent
11
+ SRC = ROOT / "src"
12
+ if str(SRC) not in sys.path:
13
+ sys.path.insert(0, str(SRC))
14
+
15
+ from thirawat_demo.space_ui import build_demo
16
+
17
+
18
+ def _is_port_free(port: int) -> bool:
19
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
20
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
21
+ return sock.connect_ex(("127.0.0.1", port)) != 0
22
+
23
+
24
+ def _resolve_server_port() -> int:
25
+ """Use fixed PORT in managed environments; choose first free local port."""
26
+ port_env = os.getenv("PORT")
27
+ if port_env:
28
+ return int(port_env)
29
+ gradio_port_env = os.getenv("GRADIO_SERVER_PORT")
30
+ if gradio_port_env:
31
+ return int(gradio_port_env)
32
+ for candidate in range(7860, 7871):
33
+ if _is_port_free(candidate):
34
+ return candidate
35
+ return 7860
36
+
37
+
38
# Built at import time: Hugging Face Spaces imports `demo` from this module,
# so the Gradio app must exist even when __main__ does not run.
demo = build_demo()

if __name__ == "__main__":
    # Local / container execution path: pick a port and serve on all interfaces.
    port = _resolve_server_port()
    print(f"Starting THIRAWAT Mapper Demo on http://127.0.0.1:{port}", flush=True)
    demo.launch(server_name="0.0.0.0", server_port=port, share=False)
docs/upstream-instruction-issues.md ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Upstream Instruction Issues
2
+
3
+ This note captures discrepancies found while implementing the demo so they can be fixed upstream.
4
+
5
+ ## 1) `athena2duckdb` README default output mismatch
6
+ - Repo: <https://github.com/sidataplus/athena2duckdb>
7
+ - README says `--out` default is `vocab.duckdb`.
8
+ - CLI source (`src/athena2duckdb/cli.py`) uses `omop_vocab.duckdb`.
9
+ - Suggested fix:
10
+ - Align README default value with code, or
11
+ - Change code default to match README.
12
+
13
+ ## 2) `THIRAWAT-mapper` README says reranker is gated
14
+ - Repo: <https://github.com/sidataplus/THIRAWAT-mapper>
15
+ - README notes `sidataplus/THIRAWAT-SapBERT` is gated.
16
+ - Verified on **February 10, 2026** via HF model API that `gated=false`, `private=false`.
17
+ - Suggested fix:
18
+ - Update README access notes and include verification date.
19
+
20
+ ## 3) `THIRAWAT-mapper` Cloudflare markdown fence issue
21
+ - Repo: <https://github.com/sidataplus/THIRAWAT-mapper>
22
+ - Cloudflare section has a duplicated fenced code opener in README, causing broken rendering.
23
+ - Suggested fix:
24
+ - Remove the extra code fence and validate markdown rendering.
25
+
26
+ ## 4) `THIRAWAT-mapper` index build docs under-document MPS usage
27
+ - Repo: <https://github.com/sidataplus/THIRAWAT-mapper>
28
+ - README examples mostly show `--device cuda`.
29
+ - Apple Silicon users can run with `--device mps`; CPU fallback is also valid.
30
+ - Suggested fix:
31
+ - Add a short device matrix and one MPS example command.
32
+
33
+ ## 5) `THIRAWAT-mapper` docs can confuse on `--profiles-table`
34
+ - Repo: <https://github.com/sidataplus/THIRAWAT-mapper>
35
+ - README examples require `--profiles-table concept_profiles`.
36
+ - Athena-to-DuckDB outputs typically do not contain `concept_profiles`.
37
+ - Current builder code does fallback to inline profile creation if `profiles_table` is absent.
38
+ - Suggested fix:
39
+ - Document fallback behavior clearly and show an Athena-only example path.
40
+
41
+ ## 6) `athena2duckdb==0.1.0` import-time syntax error in `loader.py`
42
+ - Repo: <https://github.com/sidataplus/athena2duckdb>
43
+ - In our local run on **February 10, 2026**, importing `athena2duckdb` fails:
44
+ - `SyntaxError: f-string expression part cannot include a backslash`
45
+ - Affected code pattern in `loader.py`:
46
+ - `f" {',\\n '.join(columns_sql)}\\n"` inside an f-string expression.
47
+ - Impact:
48
+ - `build_duckdb.py` cannot execute with published `athena2duckdb==0.1.0`.
49
+ - Suggested fix:
50
+ - Refactor string assembly to avoid backslash-containing expression inside the f-string and release `0.1.1`.
51
+
52
+ ## 7) `THIRAWAT-mapper` docs mix `transformers` defaults with `--backend st` examples
53
+ - Repo: <https://github.com/sidataplus/thirawat-mapper>
54
+ - README states embedding backend default is `transformers` and says `--backend transformers` can be omitted.
55
+ - But `docs/retrieval_reranker.md` examples repeatedly use `--backend st` (vector build/eval/RAG examples).
56
+ - Suggested fix:
57
+ - Standardize examples to default `transformers`, or
58
+ - Add a clear "when to use `st` vs `transformers`" note and keep examples consistent with that guidance.
59
+
60
+ ## 8) `THIRAWAT-mapper` lacks troubleshooting note for SapBERT sentence-transformers warning
61
+ - Repo: <https://github.com/sidataplus/thirawat-mapper>
62
+ - During ColBERT/PyLate startup with SapBERT, users can see:
63
+ - `No sentence-transformers model found with name cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR.`
64
+ - This is often benign fallback behavior, but docs do not explain it, so users may treat it as a hard failure.
65
+ - Suggested fix:
66
+ - Add a short troubleshooting note clarifying expected behavior and when it is truly an error (actual load failure).
67
+
68
+ ## 9) `THIRAWAT-mapper` docs under-document deterministic post/tie-break inference rules
69
+ - Repo: <https://github.com/sidataplus/thirawat-mapper>
70
+ - Installed package behavior (`thirawat_mapper==0.1.4`) includes deterministic tie-break/post ranking (`tiebreak_rerank`, `enrich_with_post_scores`, `eps`, `topn`, and ordered tie-break keys).
71
+ - Public markdown docs focus on blend-style post scoring but do not document tie-break mode semantics and knobs.
72
+ - Suggested fix:
73
+ - Add an inference-ranking section documenting `blend|lex|tiebreak` modes, default parameters, and deterministic sorting behavior.
pyproject.toml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "thirawat-mapper-demo"
3
+ version = "0.1.0"
4
+ description = "THIRAWAT mapper demo split into offline index pipeline and HF Space runtime app."
5
+ readme = "README.md"
6
+ requires-python = ">=3.11,<3.13"
7
+ dependencies = [
8
+ "athena2duckdb==0.1.0",
9
+ "gradio==6.5.1",
10
+ "huggingface_hub>=0.27.0",
11
+ "lancedb==0.29.2",
12
+ "pandas>=2.2.0",
13
+ "peft>=0.17.0",
14
+ "thirawat-mapper==0.1.4",
15
+ ]
16
+
17
+ [dependency-groups]
18
+ dev = [
19
+ "pytest>=8.3.0",
20
+ ]
21
+
22
+ [tool.pytest.ini_options]
23
+ testpaths = ["tests"]
24
+
25
+ [build-system]
26
+ requires = ["hatchling>=1.24.0"]
27
+ build-backend = "hatchling.build"
28
+
29
+ [tool.hatch.build.targets.wheel]
30
+ packages = ["src/thirawat_demo"]
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio==6.5.1
2
+ huggingface_hub>=0.27.0
3
+ lancedb==0.29.2
4
+ pandas>=2.2.0
5
+ peft>=0.17.0
6
+ thirawat-mapper==0.1.4
scripts/offline/build_duckdb.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Offline: convert Athena vocabulary export to DuckDB."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ from pathlib import Path
8
+
9
+ REPO_ROOT = Path(__file__).resolve().parents[2]
10
+ DEFAULT_DUCKDB_PATH = REPO_ROOT / "data" / "derived" / "concepts.duckdb"
11
+
12
+
13
def parse_args() -> argparse.Namespace:
    """Define and parse the CLI for the Athena -> DuckDB build step."""
    parser = argparse.ArgumentParser(description="Build DuckDB from Athena vocabulary export.")
    parser.add_argument("--athena-dir", required=True, help="Directory containing Athena CSV/TSV files.")
    parser.add_argument(
        "--out",
        default=str(DEFAULT_DUCKDB_PATH),
        help=f"Output DuckDB path (default: {DEFAULT_DUCKDB_PATH}).",
    )
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing DuckDB file.")
    parser.add_argument("--threads", type=int, default=None, help="DuckDB threads (default: auto).")
    parser.add_argument("--schema", default="main", help="Target schema name (default: main).")
    # Athena exports are tab-separated by default, hence the "\t" delimiter.
    parser.add_argument("--sep", default="\t", help="Input delimiter (default: tab).")
    parser.add_argument("--encoding", default="UTF-8", help="Input encoding (default: UTF-8).")
    return parser.parse_args()
27
+
28
+
29
def run(args: argparse.Namespace) -> int:
    """Load an Athena export into DuckDB and verify row counts.

    Returns 0 on success, 2 when any table's row count does not match its
    source file.  Raises FileNotFoundError when the Athena directory is
    missing and RuntimeError when athena2duckdb cannot be imported.
    """
    try:
        from athena2duckdb import CSVOptions, load_vocab_dir, verify_row_counts
    except SyntaxError as exc:
        # athena2duckdb==0.1.0 ships a module that fails to parse; see
        # docs/upstream-instruction-issues.md item 6.
        raise RuntimeError(
            "Failed to import athena2duckdb due to an upstream syntax error. "
            "Current workaround: use an already-built DuckDB file (for example a previous vocab.duckdb) "
            "and continue with scripts/offline/build_lancedb_index.py."
        ) from exc

    athena_dir = Path(args.athena_dir).expanduser().resolve()
    out_path = Path(args.out).expanduser().resolve()

    # Validate the input before creating any output directories so a bad
    # --athena-dir does not leave empty directories behind.
    if not athena_dir.exists():
        raise FileNotFoundError(f"Athena directory does not exist: {athena_dir}")
    out_path.parent.mkdir(parents=True, exist_ok=True)

    csv_options = CSVOptions(sep=args.sep, encoding=args.encoding)

    summary = load_vocab_dir(
        input_dir=athena_dir,
        out_path=out_path,
        csv_options=csv_options,
        overwrite=bool(args.overwrite),
        threads=args.threads,
        schema=args.schema,
    )
    print(f"Loaded {len(summary.vocab_files)} tables into {summary.db_path} (schema {summary.schema}).")

    # Re-count rows in both the source files and the loaded tables to catch
    # silent truncation or delimiter/encoding problems.
    results = verify_row_counts(
        db_path=summary.db_path,
        vocab_files=summary.vocab_files,
        csv_options=csv_options,
        threads=args.threads,
        schema=summary.schema,
    )
    mismatches = [result for result in results if not result.matches]
    for result in results:
        status = "OK" if result.matches else "MISMATCH"
        print(
            f"{status:9s} table={result.table_name:<25s} "
            f"csv_rows={result.csv_rows:,} table_rows={result.table_rows:,}"
        )

    if mismatches:
        print(f"Found {len(mismatches)} row-count mismatches.")
        return 2
    return 0
77
+
78
+
79
def main() -> int:
    """CLI entrypoint: parse arguments and run the build."""
    return run(parse_args())
82
+
83
+
84
+ if __name__ == "__main__":
85
+ raise SystemExit(main())
scripts/offline/build_lancedb_index.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Offline: build LanceDB index for THIRAWAT mapper."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ from pathlib import Path
8
+ import platform
9
+
10
+ REPO_ROOT = Path(__file__).resolve().parents[2]
11
+ DEFAULT_DUCKDB_PATH = REPO_ROOT / "data" / "derived" / "concepts.duckdb"
12
+ DEFAULT_LANCEDB_DIR = REPO_ROOT / "data" / "lancedb" / "db"
13
+
14
+ DEFAULT_DOMAIN_IDS = ["Drug"]
15
+ DEFAULT_CONCEPT_CLASSES = [
16
+ "Clinical Drug",
17
+ "Quant Clinical Drug",
18
+ "Clinical Drug Comp",
19
+ "Clinical Drug Form",
20
+ "Branded Drug",
21
+ "Quant Branded Drug",
22
+ "Branded Drug Comp",
23
+ "Branded Drug Form",
24
+ "Ingredient",
25
+ ]
26
+ DEFAULT_EXCLUDED_CONCEPT_CLASSES = [
27
+ "Clinical Drug Box",
28
+ "Branded Drug Box",
29
+ "Branded Pack Box",
30
+ "Clinical Pack Box",
31
+ "Marketed Product",
32
+ "Quant Branded Box",
33
+ "Quant Clinical Box",
34
+ ]
35
+ DEFAULT_EXTRA_COLUMNS = ["concept_name", "concept_code", "domain_id", "vocabulary_id", "concept_class_id"]
36
+ DEFAULT_MODEL_ID = "cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR"
37
+
38
+
39
+ def _split_multi(values: list[str] | None) -> list[str]:
40
+ if not values:
41
+ return []
42
+ items: list[str] = []
43
+ for value in values:
44
+ parts = [part.strip() for part in str(value).split(",")]
45
+ items.extend([part for part in parts if part])
46
+ return items
47
+
48
+
49
def resolve_device(device: str) -> str:
    """Map a requested device string to a concrete torch device name.

    Anything other than "auto" passes through (trimmed and lowercased).
    "auto" resolves to mps on macOS when available, else cuda, else cpu;
    cpu is returned when torch cannot be imported.
    """
    normalized = device.strip().lower()
    if normalized != "auto":
        return normalized

    try:
        import torch  # type: ignore
    except Exception:
        # No torch at all -> only the CPU path is possible.
        return "cpu"

    # Policy requested for this repo:
    # auto => macOS MPS first, then CUDA, then CPU.
    mps_backend = getattr(torch.backends, "mps", None)
    on_macos = platform.system().lower() == "darwin"
    if on_macos and mps_backend and torch.backends.mps.is_available():
        return "mps"
    return "cuda" if torch.cuda.is_available() else "cpu"
69
+
70
+
71
def parse_args() -> argparse.Namespace:
    """Define and parse the CLI for the DuckDB -> LanceDB index build step."""
    parser = argparse.ArgumentParser(description="Build LanceDB index using thirawat-mapper.")
    parser.add_argument("--duckdb", default=str(DEFAULT_DUCKDB_PATH), help="Path to DuckDB vocabulary file.")
    parser.add_argument("--profiles-table", default="concept_profiles", help="Profiles table name.")
    parser.add_argument("--concepts-table", default="concept", help="Concepts table name.")
    parser.add_argument("--out-db", default=str(DEFAULT_LANCEDB_DIR), help="Output LanceDB directory.")
    parser.add_argument("--table", default="concepts_drug", help="Output LanceDB table name.")

    # Filter flags accept either repeated use or comma-separated values; the
    # raw lists collected here are normalized later via _split_multi().
    parser.add_argument(
        "--domain-id",
        action="append",
        default=None,
        help="Domain filters, comma-separated or repeated (default: Drug).",
    )
    parser.add_argument(
        "--concept-class-id",
        action="append",
        default=None,
        help="Concept class filters, comma-separated or repeated.",
    )
    parser.add_argument(
        "--exclude-concept-class-id",
        action="append",
        default=None,
        help="Concept class exclusions, comma-separated or repeated.",
    )
    parser.add_argument(
        "--extra-column",
        action="append",
        default=None,
        help="Extra columns to carry into index table (comma-separated or repeated).",
    )

    parser.add_argument("--batch-size", type=int, default=256, help="Embedding batch size.")
    parser.add_argument("--model-id", default=DEFAULT_MODEL_ID, help="Encoder model id.")
    parser.add_argument("--pooling", choices=["cls", "mean"], default="cls", help="Pooling type.")
    parser.add_argument("--max-length", type=int, default=128, help="Encoder max token length.")
    parser.add_argument(
        "--device",
        choices=["auto", "mps", "cuda", "cpu"],
        default="auto",
        help="Index build device; auto resolves to mps (Darwin), then cuda, then cpu.",
    )
    parser.add_argument("--trust-remote-code", action="store_true", help="Pass trust_remote_code to encoder.")
    return parser.parse_args()
116
+
117
+
118
def run(args: argparse.Namespace) -> int:
    """Resolve paths/filters/device and delegate to thirawat_mapper's builder.

    Returns 0 on success; raises FileNotFoundError when the DuckDB input is
    missing.
    """
    from thirawat_mapper.index.build import main as thirawat_index_build_main

    duckdb_path = Path(args.duckdb).expanduser().resolve()
    out_db = Path(args.out_db).expanduser().resolve()
    out_db.mkdir(parents=True, exist_ok=True)

    if not duckdb_path.exists():
        raise FileNotFoundError(f"DuckDB file does not exist: {duckdb_path}")

    resolved_device = resolve_device(args.device)
    # Empty CLI filter lists fall back to the repo defaults.
    domain_ids = _split_multi(args.domain_id) or DEFAULT_DOMAIN_IDS
    concept_classes = _split_multi(args.concept_class_id) or DEFAULT_CONCEPT_CLASSES
    excluded_classes = _split_multi(args.exclude_concept_class_id) or DEFAULT_EXCLUDED_CONCEPT_CLASSES
    extra_columns = _split_multi(args.extra_column) or DEFAULT_EXTRA_COLUMNS

    # Assemble the delegated CLI as (flag, value) pairs, then flatten.
    option_pairs = [
        ("--duckdb", str(duckdb_path)),
        ("--profiles-table", args.profiles_table),
        ("--concepts-table", args.concepts_table),
        ("--domain-id", ",".join(domain_ids)),
        ("--concept-class-id", ",".join(concept_classes)),
        ("--exclude-concept-class-id", ",".join(excluded_classes)),
        ("--extra-column", ",".join(extra_columns)),
        ("--out-db", str(out_db)),
        ("--table", args.table),
        ("--batch-size", str(args.batch_size)),
        ("--model-id", args.model_id),
        ("--pooling", args.pooling),
        ("--max-length", str(args.max_length)),
        ("--device", resolved_device),
    ]
    cli_args = [token for pair in option_pairs for token in pair]
    if args.trust_remote_code:
        cli_args.append("--trust-remote-code")

    print(f"Resolved build device: {resolved_device}")
    print("Invoking: python -m thirawat_mapper.index.build " + " ".join(cli_args))
    thirawat_index_build_main(cli_args)
    return 0
171
+
172
+
173
def main() -> int:
    """CLI entrypoint: parse arguments and run the index build."""
    return run(parse_args())
176
+
177
+
178
+ if __name__ == "__main__":
179
+ raise SystemExit(main())
scripts/offline/publish_index_hf.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Offline: publish built index artifact to Hugging Face dataset repo."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ from pathlib import Path
8
+ import os
9
+
10
+ from huggingface_hub import HfApi, create_repo, upload_folder
11
+
12
+ REPO_ROOT = Path(__file__).resolve().parents[2]
13
+ DEFAULT_SOURCE_DIR = REPO_ROOT / "data" / "lancedb"
14
+
15
+
16
def parse_args() -> argparse.Namespace:
    """Define and parse the CLI for publishing the index to a HF dataset repo."""
    parser = argparse.ArgumentParser(description="Upload index artifact to Hugging Face dataset repo.")
    parser.add_argument("--repo-id", required=True, help="Target HF dataset repo (e.g. org/name).")
    parser.add_argument(
        "--source-dir",
        default=str(DEFAULT_SOURCE_DIR),
        help=f"Directory to upload (default: {DEFAULT_SOURCE_DIR}).",
    )
    parser.add_argument("--revision", default="main", help="Target branch/revision (default: main).")
    parser.add_argument("--private", action="store_true", help="Create repo as private if it does not exist.")
    # Indirection via --token-env keeps the token itself out of argv/history.
    parser.add_argument(
        "--token-env",
        default="HF_TOKEN",
        help="Environment variable name holding HF write token (default: HF_TOKEN).",
    )
    return parser.parse_args()
32
+
33
+
34
def run(args: argparse.Namespace) -> int:
    """Create the dataset repo/branch if needed and upload the index folder.

    Returns 0 on success.  Raises FileNotFoundError when the source
    directory is missing and ValueError when the token env var is unset.
    """
    source_dir = Path(args.source_dir).expanduser().resolve()
    if not source_dir.exists():
        raise FileNotFoundError(f"Source directory does not exist: {source_dir}")

    token = os.getenv(args.token_env)
    if not token:
        raise ValueError(f"Missing Hugging Face token in environment variable: {args.token_env}")

    # Idempotent: exist_ok makes re-publishing to an existing repo a no-op here.
    create_repo(
        repo_id=args.repo_id,
        repo_type="dataset",
        private=bool(args.private),
        token=token,
        exist_ok=True,
    )

    api = HfApi(token=token)
    if args.revision != "main":
        try:
            api.create_branch(repo_id=args.repo_id, repo_type="dataset", branch=args.revision, exist_ok=True)
        except TypeError:
            # Backward-compatible path for clients without exist_ok support.
            try:
                api.create_branch(repo_id=args.repo_id, repo_type="dataset", branch=args.revision)
            except Exception as exc:
                # Best effort (the branch may already exist), but surface the
                # problem instead of silently swallowing it: upload_folder
                # below will fail loudly if the revision truly does not exist.
                print(f"Warning: could not create branch {args.revision!r}: {exc}")

    commit_info = upload_folder(
        repo_id=args.repo_id,
        repo_type="dataset",
        folder_path=str(source_dir),
        path_in_repo=".",
        revision=args.revision,
        commit_message="Update THIRAWAT mapper demo index artifact",
        token=token,
    )
    print(f"Uploaded index artifact to {args.repo_id}@{args.revision}")
    print(f"Commit URL: {commit_info.commit_url}")
    return 0
74
+
75
+
76
def main() -> int:
    """CLI entrypoint: parse arguments and run the upload."""
    return run(parse_args())
79
+
80
+
81
+ if __name__ == "__main__":
82
+ raise SystemExit(main())
83
+
spec/athena ADDED
Binary file (948 Bytes). View file
 
spec/example.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Drug Concept Entity Linking - HuggingFace Space"""
2
+
3
+ import os
4
+ import tempfile
5
+ import traceback
6
+ from pathlib import Path
7
+ import gradio as gr
8
+ import lancedb
9
+ from sentence_transformers import SentenceTransformer
10
+ import pandas as pd
11
+
12
+ # ===== CONFIG =====
13
# Pulled from Space Secrets (configure under Settings > Secrets)
14
+ def _get_env(name: str, default: str | None = None) -> str | None:
15
+ value = os.environ.get(name)
16
+ if value is None:
17
+ return default
18
+ value = value.strip()
19
+ if not value or value.lower() == "none":
20
+ return default
21
+ return value
22
+
23
+
24
HF_TOKEN = _get_env("HF_TOKEN")  # optional; can be omitted for public repos
INDEX_REPO = _get_env("INDEX_REPO", "amnnma/drug-concept-index")  # override to point at another repo
LOCAL_INDEX_PATH = _get_env("LOCAL_INDEX_PATH", "data/lancedb")  # preferred local index location
DEBUG = _get_env("DEBUG", "0") == "1"  # when "1", surface tracebacks in the UI

# Model
MODEL_ID = "cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR"  # SapBERT retrieval encoder
TOP_K = 10  # default number of candidates returned per query
32
+
33
class DrugConceptSearcher:
    """Loads the SapBERT encoder and a LanceDB concept index, and serves searches."""

    def __init__(self):
        # Populated by _load(); kept as attributes so search() can reuse them.
        self.model = None
        self.db = None
        self.table = None
        self._load()

    def _load(self):
        """Load model and connect to LanceDB.

        Prefers a local index directory; otherwise downloads the dataset
        snapshot from the Hugging Face Hub (anonymous first, token fallback).
        """
        print("Loading model...")
        # Force slow tokenizer to avoid fast-tokenizer conversion issues on Space
        self.model = SentenceTransformer(MODEL_ID, tokenizer_kwargs={"use_fast": False})

        # Prefer local index when available (useful for local runs)
        local_root = Path(LOCAL_INDEX_PATH) if LOCAL_INDEX_PATH else None
        if local_root and local_root.exists() and (local_root / "db").exists():
            index_root = local_root
            print(f"Connecting to local LanceDB at {index_root}...")
        else:
            repo_id = INDEX_REPO or "amnnma/drug-concept-index"
            if not isinstance(repo_id, str):
                repo_id = str(repo_id)
            repo_id = repo_id.strip()
            if repo_id.startswith("http"):
                # Accept full HF URLs and extract the repo id
                parts = repo_id.split("/")
                if "datasets" in parts:
                    repo_id = "/".join(parts[parts.index("datasets") + 1 :]).strip("/")
                elif "spaces" in parts:
                    repo_id = "/".join(parts[parts.index("spaces") + 1 :]).strip("/")
                else:
                    # Fall back to the last two path segments (owner/name).
                    repo_id = "/".join(parts[-2:]).strip("/")
            if repo_id.startswith("datasets/"):
                repo_id = repo_id[len("datasets/") :]
            print(f"Connecting to LanceDB from {repo_id}...")
            # Download and connect to the LanceDB stored in the HF repo.
            from huggingface_hub import snapshot_download

            # Download the index (cached under /data on Spaces with persistent storage).
            download_root = Path(os.environ.get("HF_DATA_DIR", "/data")) / "lancedb"
            try:
                download_root.mkdir(parents=True, exist_ok=True)
            except OSError:
                # /data may not be writable outside Spaces; use a relative dir instead.
                download_root = Path("data/lancedb")
                download_root.mkdir(parents=True, exist_ok=True)

            # Avoid implicit token usage for public datasets
            os.environ["HF_HUB_DISABLE_IMPLICIT_TOKEN"] = "1"
            try:
                # First attempt: anonymous download (token=False).
                index_root = Path(
                    snapshot_download(
                        repo_id=repo_id,
                        repo_type="dataset",
                        token=False,
                        revision=os.environ.get("HF_DATASET_REVISION", "main"),
                        local_dir=str(download_root),
                    )
                )
            except Exception as e:
                # Retry with the configured token (private repos); otherwise re-raise.
                if HF_TOKEN:
                    index_root = Path(
                        snapshot_download(
                            repo_id=repo_id,
                            repo_type="dataset",
                            token=HF_TOKEN,
                            local_dir=str(download_root),
                        )
                    )
                else:
                    raise e

        # Connect to LanceDB
        self.db = lancedb.connect(str(index_root / "db"))
        self.table = self.db.open_table("concepts_drug")
        print("✅ Ready!")

    def search(self, query: str, top_k: int = TOP_K):
        """Search drug concepts; returns a DataFrame of the top_k candidates."""
        if not query or not query.strip():
            return pd.DataFrame()

        # Encode query
        query_emb = self.model.encode(query, normalize_embeddings=True)

        # Search
        results = self.table.search(query_emb).limit(top_k).to_pandas()

        # Format output
        if "_distance" in results.columns:
            results["score"] = 1 - results["_distance"]  # Convert distance to similarity
            results = results.sort_values("score", ascending=False)

        return results[["concept_id", "concept_name", "concept_code", "vocabulary_id", "score"]]
126
+
127
# Initialize lazily: model/index loading is expensive, so defer to first request.
searcher = None

def get_searcher():
    """Return the process-wide DrugConceptSearcher, creating it on first use."""
    global searcher
    if searcher is None:
        searcher = DrugConceptSearcher()
    return searcher
135
+
136
def _format_results(results: pd.DataFrame, query: str) -> tuple[str, pd.DataFrame]:
    """Build the markdown summary shown above the results table.

    Returns the markdown text together with the (unmodified) results frame.
    """
    if results.empty:
        return "No results found. Try a different search term.", results

    top = results.iloc[0]
    header = f"## Results for: \"{query}\"\n\n"
    summary = f"**Top match:** {top['concept_name']} (score {top['score']:.4f})\n\n"
    return header + summary, results
144
+
145
+
146
def search_drugs(query: str, top_k: int):
    """Gradio handler for a single query; returns (markdown, dataframe)."""
    try:
        hits = get_searcher().search(query, top_k)
        return _format_results(hits, query)
    except Exception as e:
        # Report the error in the UI rather than crashing the app.
        print("Search error:", e)
        print(traceback.format_exc())
        if DEBUG:
            return f"❌ Error: {str(e)}\n\n```\n{traceback.format_exc()}\n```", pd.DataFrame()
        return f"❌ Error: {str(e)}", pd.DataFrame()
160
+
161
+
162
def search_batch(queries_text: str, top_k: int):
    """Gradio handler for batch queries.

    Parses one query per line, searches each, writes all ranked hits to a
    temp CSV, and returns (markdown summary, DownloadButton update).
    """
    try:
        if not queries_text or not queries_text.strip():
            return "Please enter clinical terms to search.", gr.update(visible=False)

        # One query per non-empty line.
        lines = [line.strip() for line in queries_text.splitlines() if line.strip()]
        if not lines:
            return "No valid queries found.", gr.update(visible=False)

        s = get_searcher()
        rows = []
        for q in lines:
            results = s.search(q, top_k)
            # rank is 1-based within each query's result set.
            for i, (_, row) in enumerate(results.iterrows(), start=1):
                rows.append(
                    {
                        "query_text": q,
                        "rank": i,
                        "concept_id": row["concept_id"],
                        "concept_name": row["concept_name"],
                        "concept_code": row["concept_code"],
                        "vocabulary_id": row["vocabulary_id"],
                        "score": float(row["score"]),
                    }
                )

        if not rows:
            return "No results found.", gr.update(visible=False)

        # Persist the flat result table so the DownloadButton can serve it.
        df = pd.DataFrame(rows)
        tmp_dir = Path(tempfile.gettempdir()) / "thirawat_results"
        tmp_dir.mkdir(parents=True, exist_ok=True)
        out_path = tmp_dir / "batch_results.csv"
        df.to_csv(out_path, index=False)

        md = f"""## Batch Search Complete

- **Queries processed:** {len(lines)}
- **Rows returned:** {len(rows)}
- **Top-K per query:** {top_k}
"""
        return md, gr.update(value=str(out_path), visible=True)
    except Exception as e:
        # Report the error in the UI rather than crashing the app.
        print("Batch search error:", e)
        print(traceback.format_exc())
        if DEBUG:
            return f"❌ Error: {str(e)}\n\n```\n{traceback.format_exc()}\n```", gr.update(visible=False)
        return f"❌ Error: {str(e)}", gr.update(visible=False)
211
+
212
# ===== GRADIO INTERFACE =====
# Two tabs: "Single Query" (one term -> markdown + table) and
# "Batch Query" (multi-line input -> summary + CSV download).
with gr.Blocks(title="THIRAWAT - Drug Concept Search") as demo:
    gr.HTML(
        """
        <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 10px; margin-bottom: 20px;">
            <h1 style="color: white; margin: 0; font-size: 2em;">THIRAWAT</h1>
            <p style="color: rgba(255,255,255,0.9); margin: 5px 0 0 0;">Drug Concept Entity Linking</p>
            <p style="color: rgba(255,255,255,0.8); margin: 5px 0 0 0;">Map drug names to OMOP concepts using SapBERT + LanceDB.</p>
        </div>
        """
    )

    with gr.Tabs():
        with gr.Tab("Single Query"):
            with gr.Row():
                with gr.Column(scale=3):
                    query_input = gr.Textbox(
                        label="Drug name or query",
                        placeholder="e.g., aspirin, paracetamol, amoxicillin 500mg...",
                        lines=2,
                    )
                with gr.Column(scale=1):
                    # Only the Drug domain is indexed, so the dropdown is display-only.
                    domain_hint = gr.Dropdown(
                        label="Domain",
                        choices=["Drug", "Condition", "Procedure", "Observation", "Device", "Unit"],
                        value="Drug",
                        interactive=False,
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=10,
                        step=1,
                        label="Number of results",
                    )

            with gr.Row():
                search_btn = gr.Button("Search", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")

            output_md = gr.Markdown(label="Results")
            output_table = gr.Dataframe(label="Results Table", interactive=False)

        with gr.Tab("Batch Query"):
            with gr.Row():
                with gr.Column(scale=3):
                    batch_queries = gr.Textbox(
                        label="Drug names (one per line)",
                        placeholder="aspirin\nparacetamol\namoxicillin 500mg",
                        lines=10,
                    )
                with gr.Column(scale=1):
                    # Display-only, mirroring the single-query tab.
                    batch_domain_hint = gr.Dropdown(
                        label="Domain",
                        choices=["Drug", "Condition", "Procedure", "Observation", "Device", "Unit"],
                        value="Drug",
                        interactive=False,
                    )
                    batch_topk = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=10,
                        step=1,
                        label="Top-K per query",
                    )

            with gr.Row():
                batch_btn = gr.Button("Process Batch", variant="primary")
                batch_clear = gr.Button("Clear", variant="secondary")

            batch_output = gr.Markdown(label="Summary")
            # Hidden until a batch run produces a CSV to serve.
            batch_download = gr.DownloadButton(
                label="Download Results (CSV)",
                variant="secondary",
                visible=False,
            )

    # Reset helpers: restore inputs to defaults and clear outputs.
    def clear_single():
        return "", 10, "", pd.DataFrame()

    def clear_batch():
        return "", 10, "", gr.update(visible=False)

    # Event wiring (api_name=False keeps handlers off the public API surface).
    search_btn.click(
        fn=search_drugs,
        inputs=[query_input, top_k],
        outputs=[output_md, output_table],
        api_name=False,
    )
    clear_btn.click(
        fn=clear_single,
        outputs=[query_input, top_k, output_md, output_table],
        api_name=False,
    )

    batch_btn.click(
        fn=search_batch,
        inputs=[batch_queries, batch_topk],
        outputs=[batch_output, batch_download],
        api_name=False,
    )
    batch_clear.click(
        fn=clear_batch,
        outputs=[batch_queries, batch_topk, batch_output, batch_download],
        api_name=False,
    )
    gr.Markdown(
        """
        ---

        **THIRAWAT** is a dense retrieval toolkit for mapping drug terminology to OMOP standard concepts.
        """
    )

if __name__ == "__main__":
    # Bind to all interfaces on the standard Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
spec/spec.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# THIRAWAT-mapper-demo

End-to-end demo of THIRAWAT-mapper, a tool for mapping concepts from non-standard terminologies to standard terminologies in the OHDSI/OMOP CDM.

## Key steps

1. Convert the [vocab set downloaded from Athena](spec/athena) into DuckDB format using [`athena2duckdb`](https://pypi.org/project/athena2duckdb/) (`pip install athena2duckdb`).
2. Follow the instructions in [THIRAWAT-mapper](https://github.com/sidataplus/THIRAWAT-mapper).
3. Use [SapBERT-UMLS-2020AB-all-lang-from-XLMR](https://huggingface.co/cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR) for retrieval with CLS pooling and [THIRAWAT-SapBERT](https://huggingface.co/sidataplus/THIRAWAT-SapBERT) as the ColBERT reranker.
4. Build a complete Gradio app; see the preliminary example in [spec/example.py](spec/example.py).
5. Package everything into a [Hugging Face Space](https://huggingface.co/spaces/sidataplus/THIRAWAT-mapper-demo).
src/thirawat_demo/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """THIRAWAT mapper demo package."""
2
+
src/thirawat_demo/runtime/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """Runtime modules for the Hugging Face Space app."""
2
+
3
+ from .config import RuntimeConfig
4
+ from .index_loader import resolve_lancedb_dir
5
+ from .peft_reranker import ThirawatPeftReranker
6
+ from .search_service import SearchService
7
+
8
+ __all__ = ["RuntimeConfig", "resolve_lancedb_dir", "ThirawatPeftReranker", "SearchService"]
src/thirawat_demo/runtime/config.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Runtime configuration for the Gradio app."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ import os
8
+
9
+ DEFAULT_ENCODER_MODEL_ID = "cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR"
10
+ DEFAULT_RERANKER_ID = "sidataplus/THIRAWAT-SapBERT"
11
+
12
+
13
+ def _clean_env(name: str, default: str | None = None) -> str | None:
14
+ value = os.getenv(name)
15
+ if value is None:
16
+ return default
17
+ cleaned = value.strip()
18
+ if cleaned == "" or cleaned.lower() == "none":
19
+ return default
20
+ return cleaned
21
+
22
+
23
def _env_int(name: str, default: int) -> int:
    """Integer-valued env var with a fallback; int() raises on malformed input."""
    raw = _clean_env(name)
    return default if raw is None else int(raw)
28
+
29
+
30
def _env_float(name: str, default: float) -> float:
    """Float-valued env var with a fallback; float() raises on malformed input."""
    raw = _clean_env(name)
    return default if raw is None else float(raw)
35
+
36
+
37
def _env_bool(name: str, default: bool = False) -> bool:
    """Boolean env var; truthy spellings are 1/true/yes/on (case-insensitive)."""
    raw = _clean_env(name)
    if raw is None:
        return default
    return raw.lower() in {"1", "true", "yes", "on"}
42
+
43
+
44
@dataclass(frozen=True)
class RuntimeConfig:
    """Immutable runtime settings for the Space app, resolved from env vars.

    Construct via :meth:`from_env`; each field maps to an environment
    variable of the same (upper-cased) name.
    """

    local_index_path: Path  # preferred on-disk LanceDB location
    index_repo: str | None  # HF dataset repo to download the index from
    index_revision: str  # branch/revision of the index repo
    hf_token: str | None  # optional token for private repos
    hf_data_dir: Path  # persistent storage root on Spaces
    lancedb_table: str  # table name inside the LanceDB database
    top_k_default: int  # default UI result count
    candidate_topk: int  # candidates passed to the reranker
    retrieval_topk: int  # first-stage retrieval depth
    device: str  # cpu/cuda/mps/auto
    encoder_model_id: str  # SapBERT retrieval encoder
    reranker_id: str  # THIRAWAT ColBERT reranker
    post_mode: str  # post-scoring strategy: blend/tiebreak/lex
    post_weight: float
    tiebreak_eps: float
    tiebreak_topn: int
    post_strength_weight: float
    post_jaccard_weight: float
    post_brand_penalty: float
    post_minmax: bool
    brand_strict: bool
    debug: bool

    @classmethod
    def from_env(cls) -> "RuntimeConfig":
        """Build a validated config from environment variables (with defaults)."""
        local_index_path = Path(_clean_env("LOCAL_INDEX_PATH", "data/lancedb") or "data/lancedb")
        index_repo = _clean_env("INDEX_REPO")
        index_revision = _clean_env("INDEX_REVISION", "main") or "main"
        hf_token = _clean_env("HF_TOKEN")
        hf_data_dir = Path(_clean_env("HF_DATA_DIR", "/data") or "/data")
        lancedb_table = _clean_env("LANCEDB_TABLE", "concepts_drug") or "concepts_drug"
        top_k_default = _env_int("TOP_K_DEFAULT", 10)
        candidate_topk = _env_int("CANDIDATE_TOPK", 100)
        retrieval_topk = _env_int("RETRIEVAL_TOPK", 200)
        device = (_clean_env("DEVICE", "auto") or "auto").lower()
        encoder_model_id = _clean_env("ENCODER_MODEL_ID", DEFAULT_ENCODER_MODEL_ID) or DEFAULT_ENCODER_MODEL_ID
        reranker_id = _clean_env("RERANKER_ID", DEFAULT_RERANKER_ID) or DEFAULT_RERANKER_ID
        post_mode = (_clean_env("POST_MODE", "tiebreak") or "tiebreak").strip().lower()
        post_weight = _env_float("POST_WEIGHT", 0.05)
        tiebreak_eps = _env_float("TIEBREAK_EPS", 0.01)
        tiebreak_topn = _env_int("TIEBREAK_TOPN", 50)
        post_strength_weight = _env_float("POST_STRENGTH_WEIGHT", 0.6)
        post_jaccard_weight = _env_float("POST_JACCARD_WEIGHT", 0.4)
        post_brand_penalty = _env_float("POST_BRAND_PENALTY", 0.3)
        post_minmax = _env_bool("POST_MINMAX", True)
        brand_strict = _env_bool("BRAND_STRICT", False)
        debug = _env_bool("DEBUG", False)

        # Validate numeric/enum settings before constructing the frozen instance.
        if candidate_topk <= 0:
            raise ValueError("CANDIDATE_TOPK must be > 0.")
        if retrieval_topk < candidate_topk:
            # Retrieval must at least cover the rerank candidate pool.
            retrieval_topk = candidate_topk
        if top_k_default <= 0:
            raise ValueError("TOP_K_DEFAULT must be > 0.")
        if device not in {"cpu", "cuda", "mps", "auto"}:
            raise ValueError("DEVICE must be one of: auto, cpu, cuda, mps.")
        if post_mode not in {"blend", "tiebreak", "lex"}:
            raise ValueError("POST_MODE must be one of: blend, tiebreak, lex.")
        if tiebreak_topn <= 0:
            raise ValueError("TIEBREAK_TOPN must be > 0.")

        # Either a usable local index or a repo to download from is required.
        local_has_index = (local_index_path / "db").exists() or (
            local_index_path.name == "db" and local_index_path.exists()
        )
        if not local_has_index and not index_repo:
            raise ValueError(
                "INDEX_REPO is required when LOCAL_INDEX_PATH does not point to an existing index."
            )

        return cls(
            local_index_path=local_index_path,
            index_repo=index_repo,
            index_revision=index_revision,
            hf_token=hf_token,
            hf_data_dir=hf_data_dir,
            lancedb_table=lancedb_table,
            top_k_default=top_k_default,
            candidate_topk=candidate_topk,
            retrieval_topk=retrieval_topk,
            device=device,
            encoder_model_id=encoder_model_id,
            reranker_id=reranker_id,
            post_mode=post_mode,
            post_weight=post_weight,
            tiebreak_eps=tiebreak_eps,
            tiebreak_topn=tiebreak_topn,
            post_strength_weight=post_strength_weight,
            post_jaccard_weight=post_jaccard_weight,
            post_brand_penalty=post_brand_penalty,
            post_minmax=post_minmax,
            brand_strict=brand_strict,
            debug=debug,
        )
src/thirawat_demo/runtime/index_loader.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Index download/load helpers for runtime app."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+ from huggingface_hub import snapshot_download
8
+
9
+ from .config import RuntimeConfig
10
+
11
+
12
def _can_connect_lancedb(path: Path) -> bool:
    """Return True when *path* can be opened as a LanceDB database."""
    try:
        import lancedb

        lancedb.connect(str(path))
    except Exception:
        return False
    return True
20
+
21
+
22
def _find_lancedb_dir(root: Path) -> Path:
    """Locate a connectable LanceDB directory under *root*.

    Checks conventional locations first (root itself when named "db",
    then root/"db", then root), then scans recursively for any nested
    "db" directory. Raises FileNotFoundError when nothing connects.
    """
    ordered: list[Path] = []
    if root.name == "db":
        ordered.append(root)
    ordered.extend((root / "db", root))

    checked: set[Path] = set()
    for candidate in ordered:
        if candidate in checked:
            continue
        checked.add(candidate)
        if candidate.exists() and candidate.is_dir() and _can_connect_lancedb(candidate):
            return candidate

    # Fallback: search the whole tree for a nested "db" directory.
    for candidate in root.rglob("db"):
        if candidate.is_dir() and _can_connect_lancedb(candidate):
            return candidate

    raise FileNotFoundError(f"Could not find a valid LanceDB directory under: {root}")
42
+
43
+
44
def resolve_lancedb_dir(config: RuntimeConfig) -> Path:
    """Return a usable LanceDB directory, downloading from the Hub if needed.

    Prefers the configured local index; otherwise snapshots the dataset
    repo into the persistent data dir (falling back to a relative path
    when that dir is not writable).
    """
    local = config.local_index_path
    if local.exists():
        return _find_lancedb_dir(local)

    if not config.index_repo:
        raise ValueError("INDEX_REPO must be configured when local index is unavailable.")

    download_root = config.hf_data_dir / "lancedb"
    try:
        download_root.mkdir(parents=True, exist_ok=True)
    except OSError:
        # /data may be read-only outside Spaces; fall back to a relative dir.
        download_root = Path("data/lancedb")
        download_root.mkdir(parents=True, exist_ok=True)

    snapshot_path = Path(
        snapshot_download(
            repo_id=config.index_repo,
            repo_type="dataset",
            revision=config.index_revision,
            # token=False forces an anonymous download when no token is set.
            token=config.hf_token or False,
            local_dir=str(download_root),
        )
    )
    return _find_lancedb_dir(snapshot_path)
69
+
src/thirawat_demo/runtime/peft_reranker.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PEFT-aware THIRAWAT reranker compatible with LanceDB rerank API."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from os import PathLike
6
+ from pathlib import Path
7
+ from typing import List, Optional
8
+
9
+ import numpy as np
10
+ import pyarrow as pa
11
+ import torch
12
+
13
+ from thirawat_mapper.models.bms_pooling import bms_scores
14
+
15
+ DEFAULT_RERANKER_ID = "sidataplus/THIRAWAT-SapBERT"
16
+
17
+
18
def disable_mean_resizing() -> None:
    """Patch HF resize_token_embeddings default to mean_resizing=False.

    Idempotent module-wide monkeypatch: marks PreTrainedModel with
    `_mean_resizing_patched` so repeated calls are no-ops. Silently does
    nothing when transformers is unavailable.
    """
    try:
        from transformers import PreTrainedModel  # type: ignore
    except Exception:
        return
    if getattr(PreTrainedModel, "_mean_resizing_patched", False):
        return

    orig_resize = PreTrainedModel.resize_token_embeddings

    # Same signature as the original, but defaults mean_resizing to False.
    def patched(self, new_num_tokens=None, pad_to_multiple_of=None, mean_resizing=False, **kwargs):
        return orig_resize(
            self,
            new_num_tokens=new_num_tokens,
            pad_to_multiple_of=pad_to_multiple_of,
            mean_resizing=mean_resizing,
            **kwargs,
        )

    PreTrainedModel.resize_token_embeddings = patched  # type: ignore[assignment]
    PreTrainedModel._mean_resizing_patched = True  # type: ignore[attr-defined]
40
+
41
+
42
+ def _resolve_model_id(model_id: str | PathLike[str]) -> str:
43
+ try:
44
+ path = Path(model_id).expanduser()
45
+ except TypeError:
46
+ return str(model_id)
47
+ if path.exists():
48
+ return str(path.resolve())
49
+ return str(model_id)
50
+
51
+
52
def _load_colbert_with_peft(
    model_id: str,
    device: Optional[str],
    max_query_len: int,
    max_doc_len: int,
):
    """Load ColBERT and attach PEFT adapter when the checkpoint is adapter-only.

    Returns the assembled ColBERT model, or None whenever any optional
    dependency (peft, pylate, safetensors) is missing or the checkpoint
    is not a PEFT adapter — callers then fall back to a plain ColBERT load.
    """

    disable_mean_resizing()

    try:
        from peft import PeftConfig, PeftModel  # type: ignore
    except Exception:
        return None

    # Adapter-only checkpoints carry a PeftConfig; full models do not.
    try:
        peft_cfg = PeftConfig.from_pretrained(model_id)
    except Exception:
        return None

    try:
        from pylate.models import ColBERT  # type: ignore
        from safetensors.torch import load_file  # type: ignore
    except Exception:
        return None

    # Build the base ColBERT, but keep the adapter checkpoint's tokenizer.
    base_id = peft_cfg.base_model_name_or_path
    tok_kwargs = {"tokenizer_name_or_path": model_id}
    model = ColBERT(
        model_name_or_path=base_id,
        device=device or None,
        query_length=int(max_query_len),
        document_length=int(max_doc_len),
        tokenizer_kwargs=tok_kwargs,
    )

    # Align the base encoder's vocab with the (possibly extended) tokenizer.
    encoder = model._first_module().auto_model  # type: ignore[attr-defined]
    try:
        encoder.resize_token_embeddings(len(model.tokenizer), mean_resizing=False)
    except Exception:
        pass

    # Attach the adapter weights; failure here means we cannot use PEFT at all.
    try:
        encoder = PeftModel.from_pretrained(encoder, model_id, is_trainable=False)
        encoder.eval()
        model._first_module().auto_model = encoder  # type: ignore[attr-defined]
    except Exception:
        return None

    # Optionally restore the projection head from 1_Dense (safetensors
    # preferred, legacy pytorch_model.bin as fallback); best effort.
    dense_dir = Path(model_id).expanduser() / "1_Dense"
    dense_weights = None
    if dense_dir.exists():
        safetensor_path = dense_dir / "model.safetensors"
        if safetensor_path.exists():
            try:
                dense_weights = load_file(safetensor_path)
            except Exception:
                dense_weights = None
        if dense_weights is None:
            legacy_bin = dense_dir / "pytorch_model.bin"
            if legacy_bin.exists():
                try:
                    dense_weights = torch.load(legacy_bin, map_location="cpu")
                except Exception:
                    dense_weights = None
    if dense_weights and len(model) > 1:
        try:
            model[1].load_state_dict(dense_weights, strict=False)
        except Exception:
            pass
    return model
123
+
124
+
125
class _PylateColbert:
    """Thin wrapper around a pylate ColBERT for query/document token encoding."""

    def __init__(self, model_id: str, device: Optional[str], max_query_len: int = 128, max_doc_len: int = 128):
        from pylate.models import ColBERT  # type: ignore

        # Prefer the PEFT-adapter load path; fall back to a plain checkpoint.
        peft_model = _load_colbert_with_peft(model_id, device, max_query_len, max_doc_len)
        if peft_model is not None:
            self.model = peft_model
        else:
            self.model = ColBERT(
                model_name_or_path=model_id,
                device=device or None,
                query_length=int(max_query_len),
                document_length=int(max_doc_len),
            )
        # Best-effort device placement; ColBERT may already be on the device.
        try:
            if device:
                self.model.to(torch.device(device))
        except Exception:
            pass

    def encode_query(self, text: str):
        """Return (token_embeddings, attention_mask) for a single query string."""
        tok = self.model.tokenize([text], is_query=True, pad=True)
        # Move tokenized tensors onto the model's device when possible.
        try:
            device = next(self.model.parameters()).device
            tok = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in tok.items()}
        except Exception:
            pass
        out = self.model(tok)
        return out["token_embeddings"], out.get("attention_mask")

    def encode_docs(self, texts: List[str]):
        """Return (token_embeddings, attention_mask) for a batch of documents."""
        tok = self.model.tokenize(texts, is_query=False, pad=False)
        try:
            device = next(self.model.parameters()).device
            tok = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in tok.items()}
        except Exception:
            pass
        out = self.model(tok)
        return out["token_embeddings"], out.get("attention_mask")
164
+
165
+
166
class ThirawatPeftReranker:
    """LanceDB-compatible THIRAWAT reranker with PEFT adapter support.

    Implements the rerank/rerank_vector/rerank_fts/rerank_hybrid surface
    LanceDB expects; all paths score candidates' text against the query
    with BMS-pooled ColBERT similarity and sort by `_relevance_score`.
    """

    def __init__(
        self,
        model_id: str | PathLike[str] = DEFAULT_RERANKER_ID,
        *,
        device: str | None = None,
        return_score: str = "all",
        column: str = "profile_text",
    ) -> None:
        self.model_id = _resolve_model_id(model_id)
        self.device = device
        self.return_score = return_score
        # `score` mirrors `return_score` for LanceDB's reranker protocol.
        self.score = return_score
        self.column = column
        # Lazily constructed on first use (model load is expensive).
        self._scorer: Optional[_PylateColbert] = None

    @property
    def scorer(self) -> _PylateColbert:
        """Lazily build and cache the underlying ColBERT scorer."""
        if self._scorer is None:
            self._scorer = _PylateColbert(self.model_id, self.device)
        return self._scorer

    def rerank(self, query: str | None, results: pa.Table) -> pa.Table:
        return self.rerank_vector(query, results)

    def rerank_vector(self, query: str | None, vector_results: pa.Table) -> pa.Table:
        """Score and re-sort candidates by ColBERT relevance to *query*."""
        # Prefer the configured text column, fall back to the normalized one.
        col = self.column if self.column in vector_results.column_names else None
        if col is None and "profile_text_norm" in vector_results.column_names:
            col = "profile_text_norm"
        if col is None:
            raise ValueError(f"Candidate table missing '{self.column}' (or 'profile_text_norm') column.")

        texts: List[str] = vector_results[col].to_pylist()
        qtext = str(query or "").strip()
        if not qtext:
            raise ValueError("A text query is required for reranking.")

        q_emb, q_mask = self.scorer.encode_query(qtext)
        d_emb, d_mask = self.scorer.encode_docs(texts)
        with torch.no_grad():
            scores = bms_scores(q_emb, d_emb, q_mask, d_mask)
        scores = scores.detach().float().cpu().numpy().reshape(-1)
        df = vector_results.to_pandas()
        df["_relevance_score"] = scores.astype(np.float32)
        # mergesort keeps the original (retrieval) order among equal scores.
        df = df.sort_values("_relevance_score", ascending=False, kind="mergesort").reset_index(drop=True)
        return pa.Table.from_pandas(df, preserve_index=False)

    def rerank_fts(self, query: str | None, fts_results: pa.Table) -> pa.Table:
        return self.rerank_vector(query, fts_results)

    def rerank_hybrid(self, query: str | None, vector_results: pa.Table, fts_results: pa.Table) -> pa.Table:
        # Hybrid input: only the vector candidates are reranked here.
        return self.rerank_vector(query, vector_results)
220
+
221
+
222
+ __all__ = [
223
+ "DEFAULT_RERANKER_ID",
224
+ "ThirawatPeftReranker",
225
+ "_load_colbert_with_peft",
226
+ "disable_mean_resizing",
227
+ ]
src/thirawat_demo/runtime/search_service.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Two-stage retrieval+rereanking runtime service."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Any, Sequence
8
+
9
+ import pandas as pd
10
+ from thirawat_mapper.infer.utils import enrich_with_post_scores, should_apply_post
11
+
12
+ from .config import RuntimeConfig
13
+ from .peft_reranker import ThirawatPeftReranker
14
+
15
# Default OMOP drug concept classes offered in the UI when the index
# manifest does not declare its own list.
DEFAULT_CONCEPT_CLASSES = [
    "Clinical Drug",
    "Quant Clinical Drug",
    "Clinical Drug Comp",
    "Clinical Drug Form",
    "Branded Drug",
    "Quant Branded Drug",
    "Branded Drug Comp",
    "Branded Drug Form",
    "Ingredient",
]
26
+
27
+
28
+ class SearchService:
29
+ """Runtime search service with mandatory two-stage ranking."""
30
+
31
    def __init__(self, config: RuntimeConfig, lancedb_dir: Path) -> None:
        """Store configuration; heavy components are created lazily in startup()."""
        self.config = config
        self.lancedb_dir = lancedb_dir
        # Runtime components, populated by startup().
        self._table = None
        self._vector_column: str | None = None
        self._embedder = None
        self._reranker = None
        self._normalize_text = None
        # Concept-class choices read from the index manifest (with defaults).
        self._available_concept_classes = self._load_manifest_concept_classes()
40
+
41
    def startup(self) -> None:
        """Initialize table, embedder, and reranker, warming them up eagerly.

        Raises RuntimeError with a specific message for each failure stage
        (imports, table connection, model init, warmup) so deploy logs are
        actionable.
        """
        try:
            from thirawat_mapper.infer.utils import configure_torch_for_infer, resolve_device
            from thirawat_mapper.models import SapBERTEmbedder
            from thirawat_mapper.utils import connect_table, normalize_text_value
        except Exception as exc:  # pragma: no cover - import-time failure path
            raise RuntimeError(f"Failed to import thirawat-mapper runtime dependencies: {exc}") from exc

        device = self.config.device
        if device == "auto":
            device = resolve_device("auto")

        configure_torch_for_infer(device)

        try:
            table, vector_column = connect_table(self.lancedb_dir, self.config.lancedb_table)
        except Exception as exc:
            raise RuntimeError(
                f"Failed to connect LanceDB table '{self.config.lancedb_table}' in '{self.lancedb_dir}': {exc}"
            ) from exc

        try:
            embedder = SapBERTEmbedder(
                model_id=self.config.encoder_model_id,
                device=device,
                batch_size=64,
                max_length=128,
                pooling="cls",
            )
            reranker = ThirawatPeftReranker(
                model_id=self.config.reranker_id,
                device=device,
                return_score="all",
            )
        except Exception as exc:
            raise RuntimeError(f"Failed to initialize embedding/reranker models: {exc}") from exc

        # Force lazy components to load at startup so runtime errors fail fast.
        try:
            _ = embedder.encode(["startup warmup"])
            _ = reranker.scorer.encode_query("startup warmup")
        except Exception as exc:
            raise RuntimeError(f"Failed to warm up models at startup: {exc}") from exc

        # Only assign once everything succeeded, keeping the service unstarted on failure.
        self._table = table
        self._vector_column = vector_column
        self._embedder = embedder
        self._reranker = reranker
        self._normalize_text = normalize_text_value
90
+
91
    def available_concept_classes(self) -> list[str]:
        """Return concept-class choices: manifest values, else distinct table
        values, else the hard-coded defaults."""
        if self._available_concept_classes:
            return list(self._available_concept_classes)

        self._ensure_started()
        assert self._table is not None

        # Fall back to scanning the table; this materializes it, so it is
        # only attempted when the manifest provided nothing.
        try:
            df = self._table.to_arrow().to_pandas()
        except Exception:
            return list(DEFAULT_CONCEPT_CLASSES)

        if "concept_class_id" not in df.columns:
            return list(DEFAULT_CONCEPT_CLASSES)

        values = sorted(
            {
                str(value).strip()
                for value in df["concept_class_id"].dropna().tolist()
                if str(value).strip()
            }
        )
        return values or list(DEFAULT_CONCEPT_CLASSES)
114
+
115
def _load_manifest_concept_classes(self) -> list[str]:
    """Read concept classes from the sidecar manifest file, else defaults.

    The manifest is ``<table>_manifest.json`` next to the LanceDB directory;
    any read/parse problem or malformed payload yields the default list.
    """
    fallback = list(DEFAULT_CONCEPT_CLASSES)
    manifest_path = self.lancedb_dir / f"{self.config.lancedb_table}_manifest.json"
    if not manifest_path.exists():
        return fallback
    try:
        payload = json.loads(manifest_path.read_text(encoding="utf-8"))
    except Exception:
        return fallback
    raw_values = payload.get("concept_class_id")
    if not isinstance(raw_values, list):
        return fallback
    # Trimmed, non-empty strings only; keep manifest order.
    cleaned = [text for text in (str(item).strip() for item in raw_values) if text]
    return cleaned if cleaned else fallback
128
+
129
def _ensure_started(self) -> None:
    """Lazily run ``startup()`` when any runtime component is still missing."""
    components = (self._table, self._vector_column, self._embedder, self._reranker)
    if any(component is None for component in components):
        self.startup()
132
+
133
def search(self, query: str, top_k: int, concept_class_ids: Sequence[str] | None = None) -> pd.DataFrame:
    """Two-stage concept search: vector retrieval then THIRAWAT reranking.

    Args:
        query: Raw user text; blank input yields an empty result frame.
        top_k: Number of rows to return after reranking and post-scoring.
        concept_class_ids: Optional filter. ``None`` means no filtering; a
            sequence with no usable values returns an empty frame.

    Returns:
        DataFrame with the columns from ``_ordered_output_columns()``,
        ranked 1..N.

    Raises:
        RuntimeError: filtering requested but the table lacks a
            ``concept_class_id`` column.
    """
    # Validate the filter BEFORE starting the (expensive) runtime: an
    # explicitly-empty selection short-circuits to an empty frame.
    where_clause = self._build_concept_class_where(concept_class_ids)
    if concept_class_ids is not None and where_clause is None:
        return pd.DataFrame(columns=self._ordered_output_columns())

    self._ensure_started()
    assert self._table is not None
    assert self._vector_column is not None
    assert self._embedder is not None
    assert self._reranker is not None
    assert self._normalize_text is not None

    if not query or not query.strip():
        return pd.DataFrame(columns=self._ordered_output_columns())

    normalized_query = self._normalize_text(query)
    query_emb = self._embedder.encode([normalized_query])[0]

    builder = self._table.search(
        query_emb.astype(float).tolist(),
        vector_column_name=self._vector_column,
        query_type="vector",
    )
    if where_clause:
        # Fail loudly rather than silently returning unfiltered results.
        schema_names = set(getattr(getattr(self._table, "schema", None), "names", []))
        if "concept_class_id" not in schema_names:
            raise RuntimeError("Requested concept class filtering but table has no 'concept_class_id' column.")
        builder = builder.where(where_clause)

    builder = builder.distance_type("cosine").limit(self.config.retrieval_topk)

    # Mandatory two-stage retrieval: vector candidates then THIRAWAT reranking.
    result = builder.rerank(reranker=self._reranker, query_string=normalized_query).limit(
        self.config.candidate_topk
    )

    df = result.to_arrow().to_pandas()
    if df.empty:
        return pd.DataFrame(columns=self._ordered_output_columns())

    # Heuristic post-scoring (optional) -> canonical score columns -> schema fill.
    df = self._apply_post_scoring(df, normalized_query)
    df = self._finalize_scores(df)
    df = self._ensure_columns(
        df,
        required=[
            "concept_id",
            "concept_name",
            "concept_code",
            "vocabulary_id",
            "concept_class_id",
            "score",
            "retrieval_score",
        ],
    )
    # 1-based rank in final order, then trim to the caller's top_k.
    df.insert(0, "rank", list(range(1, len(df) + 1)))
    return df[self._ordered_output_columns()].head(top_k).reset_index(drop=True)
189
+
190
+ def search_batch(
191
+ self,
192
+ queries: list[str],
193
+ top_k: int,
194
+ concept_class_ids: Sequence[str] | None = None,
195
+ ) -> pd.DataFrame:
196
+ rows: list[dict[str, Any]] = []
197
+ for query in queries:
198
+ result = self.search(query, top_k, concept_class_ids=concept_class_ids)
199
+ if result.empty:
200
+ continue
201
+ for _, row in result.iterrows():
202
+ rows.append(
203
+ {
204
+ "query_text": query,
205
+ "rank": int(row["rank"]),
206
+ "concept_id": row["concept_id"],
207
+ "concept_name": row["concept_name"],
208
+ "concept_code": row["concept_code"],
209
+ "vocabulary_id": row["vocabulary_id"],
210
+ "concept_class_id": row.get("concept_class_id"),
211
+ "score": float(row["score"]),
212
+ "retrieval_score": float(row["retrieval_score"]),
213
+ }
214
+ )
215
+ if not rows:
216
+ return pd.DataFrame(
217
+ columns=[
218
+ "query_text",
219
+ "rank",
220
+ "concept_id",
221
+ "concept_name",
222
+ "concept_code",
223
+ "vocabulary_id",
224
+ "concept_class_id",
225
+ "score",
226
+ "retrieval_score",
227
+ ]
228
+ )
229
+ return pd.DataFrame(rows)
230
+
231
def _apply_post_scoring(self, df: pd.DataFrame, normalized_query: str) -> pd.DataFrame:
    """Optionally apply mapper post-scoring heuristics to reranked rows.

    When the configured mode/weight enable post-scoring, delegates to
    ``enrich_with_post_scores`` with all configured knobs. Otherwise sorts
    by the reranker relevance when available (stable mergesort preserves
    the incoming order for ties) and returns the frame re-indexed.
    """
    if should_apply_post(self.config.post_mode, self.config.post_weight):
        return enrich_with_post_scores(
            df,
            normalized_query,
            post_strength_weight=self.config.post_strength_weight,
            post_jaccard_weight=self.config.post_jaccard_weight,
            post_brand_penalty=self.config.post_brand_penalty,
            post_minmax=self.config.post_minmax,
            post_weight=self.config.post_weight,
            prefer_brand=True,
            post_mode=self.config.post_mode,
            tiebreak_eps=self.config.tiebreak_eps,
            tiebreak_topn=self.config.tiebreak_topn,
            brand_strict=self.config.brand_strict,
        )

    # No post-scoring: prefer the reranker's relevance column, else any
    # existing 'score' column, else leave the order untouched.
    base_col = "_relevance_score" if "_relevance_score" in df.columns else "score"
    if base_col in df.columns:
        return df.sort_values(base_col, ascending=False, kind="mergesort").reset_index(drop=True)
    return df.reset_index(drop=True)
252
+
253
@staticmethod
def _build_concept_class_where(concept_class_ids: Sequence[str] | None) -> str | None:
    """Build a SQL-style WHERE expression over ``concept_class_id``.

    Returns ``None`` when no filter applies: either the input is ``None``
    (no filtering requested) or it contains no usable values. Values are
    trimmed, deduplicated, sorted, and single-quote escaped ('' doubling).
    """
    if concept_class_ids is None:
        return None
    unique = {str(item).strip() for item in concept_class_ids} - {""}
    if not unique:
        return None
    quoted = ["'" + value.replace("'", "''") + "'" for value in sorted(unique)]
    if len(quoted) == 1:
        return f"concept_class_id = {quoted[0]}"
    return "concept_class_id IN (" + ",".join(quoted) + ")"
264
+
265
@staticmethod
def _ensure_columns(df: pd.DataFrame, required: list[str]) -> pd.DataFrame:
    """Add any missing required columns (filled with None); mutates and returns df."""
    missing = [name for name in required if name not in df.columns]
    for name in missing:
        df[name] = None
    return df
271
+
272
@staticmethod
def _finalize_scores(df: pd.DataFrame) -> pd.DataFrame:
    """Derive canonical ``score`` and ``retrieval_score`` columns.

    Preference order for ``score``: ``final_score`` (post-scoring) >
    ``_relevance_score`` (reranker) > ``1 - _distance`` (cosine distance
    from vector search) > pre-existing ``score`` > 0.0.
    ``retrieval_score`` keeps the pre-rerank vector similarity when a
    distance is available, else mirrors ``score``. Mutates ``df`` in place
    and returns it re-indexed.
    """
    if "final_score" in df.columns:
        df["score"] = pd.to_numeric(df["final_score"], errors="coerce").fillna(0.0)
    elif "_relevance_score" in df.columns:
        df["score"] = pd.to_numeric(df["_relevance_score"], errors="coerce").fillna(0.0)
    elif "_distance" in df.columns:
        # Cosine distance -> similarity; unparseable distances count as 1.0
        # (i.e. similarity 0).
        distance = pd.to_numeric(df["_distance"], errors="coerce").fillna(1.0)
        df["score"] = 1.0 - distance
    elif "score" in df.columns:
        df["score"] = pd.to_numeric(df["score"], errors="coerce").fillna(0.0)
    else:
        df["score"] = 0.0

    if "retrieval_score" in df.columns:
        df["retrieval_score"] = pd.to_numeric(df["retrieval_score"], errors="coerce").fillna(df["score"])
    elif "_distance" in df.columns:
        distance = pd.to_numeric(df["_distance"], errors="coerce").fillna(1.0)
        df["retrieval_score"] = 1.0 - distance
    else:
        df["retrieval_score"] = df["score"]

    return df.reset_index(drop=True)
295
+
296
@staticmethod
def _ordered_output_columns() -> list[str]:
    """Column order of the public ``search()`` result frame."""
    columns = (
        "rank",
        "concept_id",
        "concept_name",
        "concept_code",
        "vocabulary_id",
        "concept_class_id",
        "score",
        "retrieval_score",
    )
    # Fresh list per call so callers may mutate their copy freely.
    return list(columns)
src/thirawat_demo/space_ui.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradio UI for Hugging Face Space runtime."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import tempfile
7
+ import traceback
8
+ from pathlib import Path
9
+ from typing import Sequence
10
+
11
+ import gradio as gr
12
+ import pandas as pd
13
+
14
+ from thirawat_demo.runtime import RuntimeConfig, SearchService, resolve_lancedb_dir
15
+
16
+ DEFAULT_CONCEPT_CLASSES = [
17
+ "Clinical Drug",
18
+ "Quant Clinical Drug",
19
+ "Clinical Drug Comp",
20
+ "Clinical Drug Form",
21
+ "Branded Drug",
22
+ "Quant Branded Drug",
23
+ "Branded Drug Comp",
24
+ "Branded Drug Form",
25
+ "Ingredient",
26
+ ]
27
+ DEFAULT_SINGLE_QUERY = "Augmentin 875/125"
28
+
29
+ _SERVICE: SearchService | None = None
30
+ _CONFIG: RuntimeConfig | None = None
31
+ _RESULT_COLUMNS = [
32
+ "rank",
33
+ "concept_id",
34
+ "concept_name",
35
+ "concept_code",
36
+ "vocabulary_id",
37
+ "concept_class",
38
+ ]
39
+
40
+
41
+ def _get_service(*, startup: bool = True) -> tuple[SearchService, RuntimeConfig]:
42
+ global _SERVICE, _CONFIG
43
+ if _SERVICE is None or _CONFIG is None:
44
+ _CONFIG = RuntimeConfig.from_env()
45
+ db_dir = resolve_lancedb_dir(_CONFIG)
46
+ _SERVICE = SearchService(config=_CONFIG, lancedb_dir=db_dir)
47
+ if startup:
48
+ _SERVICE.startup()
49
+ return _SERVICE, _CONFIG
50
+
51
+
52
+ def _resolve_concept_class_choices() -> list[str]:
53
+ try:
54
+ config = RuntimeConfig.from_env()
55
+ except Exception:
56
+ return list(DEFAULT_CONCEPT_CLASSES)
57
+
58
+ local = config.local_index_path
59
+ local_has_index = (local / "db").exists() or (local.name == "db" and local.exists())
60
+ if not local_has_index:
61
+ return list(DEFAULT_CONCEPT_CLASSES)
62
+
63
+ try:
64
+ db_dir = resolve_lancedb_dir(config)
65
+ service = SearchService(config=config, lancedb_dir=db_dir)
66
+ return service.available_concept_classes()
67
+ except Exception:
68
+ return list(DEFAULT_CONCEPT_CLASSES)
69
+
70
+
71
+ def _top_k_default() -> int:
72
+ raw = (os.getenv("TOP_K_DEFAULT") or "").strip()
73
+ if not raw:
74
+ return 10
75
+ try:
76
+ value = int(raw)
77
+ except ValueError:
78
+ return 10
79
+ return value if value > 0 else 10
80
+
81
+
82
+ def _empty_results_df() -> pd.DataFrame:
83
+ return pd.DataFrame(columns=_RESULT_COLUMNS)
84
+
85
+
86
+ def _to_display_results(results: pd.DataFrame) -> pd.DataFrame:
87
+ if results is None or results.empty:
88
+ return _empty_results_df()
89
+
90
+ df = results.copy()
91
+ concept_class = (
92
+ df["concept_class_id"].astype(str)
93
+ if "concept_class_id" in df.columns
94
+ else pd.Series([""] * len(df), index=df.index, dtype=object)
95
+ )
96
+ concept_class = concept_class.where(~concept_class.isin(["None", "nan"]), "")
97
+ df["concept_class"] = concept_class
98
+ for column in _RESULT_COLUMNS:
99
+ if column not in df.columns:
100
+ df[column] = None
101
+ return df[_RESULT_COLUMNS].reset_index(drop=True)
102
+
103
+
104
+ def _format_single_results(query: str, results: pd.DataFrame) -> tuple[str, pd.DataFrame]:
105
+ if results.empty:
106
+ return "No results found. Try a different query.", _empty_results_df()
107
+
108
+ top = results.iloc[0]
109
+ md = (
110
+ f"## Results for: \"{query}\"\n\n"
111
+ f"**Top match:** {top['concept_name']} (score {float(top['score']):.4f})\n\n"
112
+ f"Rows returned: **{len(results)}**"
113
+ )
114
+ return md, _to_display_results(results)
115
+
116
+
117
+ def _validate_concept_classes(concept_class_ids: Sequence[str] | None) -> list[str] | None:
118
+ if concept_class_ids is None:
119
+ return None
120
+ cleaned = [str(value).strip() for value in concept_class_ids if str(value).strip()]
121
+ return cleaned
122
+
123
+
124
+ def search_single(query: str, top_k: int, concept_class_ids: Sequence[str]) -> tuple[str, pd.DataFrame]:
125
+ try:
126
+ selected = _validate_concept_classes(concept_class_ids)
127
+ if selected is not None and not selected:
128
+ return "Please select at least one concept class.", _empty_results_df()
129
+ service, _ = _get_service()
130
+ results = service.search(query=query, top_k=top_k, concept_class_ids=selected)
131
+ return _format_single_results(query, results)
132
+ except Exception as exc: # pragma: no cover - runtime path
133
+ stack = traceback.format_exc()
134
+ print(stack)
135
+ return f"Error: {exc}", _empty_results_df()
136
+
137
+
138
+ def search_batch(queries_text: str, top_k: int, concept_class_ids: Sequence[str]) -> tuple[str, gr.update]:
139
+ try:
140
+ lines = [line.strip() for line in (queries_text or "").splitlines() if line.strip()]
141
+ if not lines:
142
+ return "Please provide one query per line.", gr.update(visible=False)
143
+
144
+ selected = _validate_concept_classes(concept_class_ids)
145
+ if selected is not None and not selected:
146
+ return "Please select at least one concept class.", gr.update(visible=False)
147
+
148
+ service, _ = _get_service()
149
+ results = service.search_batch(lines, top_k=top_k, concept_class_ids=selected)
150
+ if results.empty:
151
+ return "No batch results found.", gr.update(visible=False)
152
+
153
+ out_dir = Path(tempfile.gettempdir()) / "thirawat_mapper_demo"
154
+ out_dir.mkdir(parents=True, exist_ok=True)
155
+ out_path = out_dir / "batch_results.csv"
156
+ results.to_csv(out_path, index=False)
157
+
158
+ md = (
159
+ "## Batch Search Complete\n\n"
160
+ f"- Queries processed: **{len(lines)}**\n"
161
+ f"- Rows returned: **{len(results)}**\n"
162
+ f"- Top-K per query: **{top_k}**"
163
+ )
164
+ return md, gr.update(value=str(out_path), visible=True)
165
+ except Exception as exc: # pragma: no cover - runtime path
166
+ stack = traceback.format_exc()
167
+ print(stack)
168
+ return f"Error: {exc}", gr.update(visible=False)
169
+
170
+
171
+ def _clear_single(concept_classes: Sequence[str]) -> tuple[str, int, Sequence[str], str, pd.DataFrame]:
172
+ return DEFAULT_SINGLE_QUERY, _top_k_default(), concept_classes, "", _empty_results_df()
173
+
174
+
175
+ def _clear_batch(concept_classes: Sequence[str]) -> tuple[str, int, Sequence[str], str, gr.update]:
176
+ return "", _top_k_default(), concept_classes, "", gr.update(visible=False)
177
+
178
+
179
+ def build_demo() -> gr.Blocks:
180
+ top_k_default = _top_k_default()
181
+ concept_class_choices = _resolve_concept_class_choices()
182
+ concept_class_default = list(concept_class_choices)
183
+
184
+ with gr.Blocks(title="THIRAWAT Mapper Demo") as demo:
185
+ gr.Markdown(
186
+ """
187
+ # THIRAWAT Mapper Demo
188
+ Map non-standard drug terms to OMOP standard concepts using SapBERT retrieval + THIRAWAT reranking.
189
+ """
190
+ )
191
+ with gr.Tabs():
192
+ with gr.Tab("Single Query"):
193
+ with gr.Row():
194
+ with gr.Column(scale=3):
195
+ query_input = gr.Textbox(
196
+ label="Drug query",
197
+ placeholder="e.g., aspirin, amoxicillin 500 mg, paracetamol",
198
+ value=DEFAULT_SINGLE_QUERY,
199
+ lines=1,
200
+ )
201
+ with gr.Column(scale=1):
202
+ domain = gr.Dropdown(
203
+ label="Domain",
204
+ choices=["Drug"],
205
+ value="Drug",
206
+ interactive=False,
207
+ )
208
+ top_k = gr.Slider(
209
+ minimum=1,
210
+ maximum=50,
211
+ value=top_k_default,
212
+ step=1,
213
+ label="Number of results",
214
+ )
215
+ with gr.Row():
216
+ single_search = gr.Button("Search", variant="primary")
217
+ single_clear = gr.Button("Clear")
218
+ with gr.Row():
219
+ single_concept_classes = gr.Dropdown(
220
+ label="Concept classes",
221
+ choices=concept_class_choices,
222
+ value=concept_class_default,
223
+ multiselect=True,
224
+ filterable=True,
225
+ allow_custom_value=False,
226
+ interactive=True,
227
+ )
228
+ single_md = gr.Markdown(label="Summary")
229
+ single_table = gr.Dataframe(
230
+ label="Results",
231
+ headers=_RESULT_COLUMNS,
232
+ value=_empty_results_df(),
233
+ interactive=False,
234
+ )
235
+
236
+ with gr.Tab("Batch Query"):
237
+ with gr.Row():
238
+ with gr.Column(scale=3):
239
+ batch_input = gr.Textbox(
240
+ label="Queries (one per line)",
241
+ placeholder="aspirin\namoxicillin 500 mg\nparacetamol",
242
+ lines=10,
243
+ )
244
+ with gr.Column(scale=1):
245
+ batch_domain = gr.Dropdown(
246
+ label="Domain",
247
+ choices=["Drug"],
248
+ value="Drug",
249
+ interactive=False,
250
+ )
251
+ batch_top_k = gr.Slider(
252
+ minimum=1,
253
+ maximum=50,
254
+ value=top_k_default,
255
+ step=1,
256
+ label="Top-K per query",
257
+ )
258
+ with gr.Row():
259
+ batch_search_btn = gr.Button("Process Batch", variant="primary")
260
+ batch_clear_btn = gr.Button("Clear")
261
+ with gr.Row():
262
+ batch_concept_classes = gr.Dropdown(
263
+ label="Concept classes",
264
+ choices=concept_class_choices,
265
+ value=concept_class_default,
266
+ multiselect=True,
267
+ filterable=True,
268
+ allow_custom_value=False,
269
+ interactive=True,
270
+ )
271
+ batch_md = gr.Markdown(label="Summary")
272
+ batch_download = gr.DownloadButton(
273
+ label="Download CSV",
274
+ variant="secondary",
275
+ visible=False,
276
+ )
277
+
278
+ single_search.click(
279
+ fn=search_single,
280
+ inputs=[query_input, top_k, single_concept_classes],
281
+ outputs=[single_md, single_table],
282
+ api_name="search_single",
283
+ )
284
+ query_input.submit(
285
+ fn=search_single,
286
+ inputs=[query_input, top_k, single_concept_classes],
287
+ outputs=[single_md, single_table],
288
+ api_name=False,
289
+ )
290
+ single_clear.click(
291
+ fn=lambda: _clear_single(concept_class_default),
292
+ outputs=[query_input, top_k, single_concept_classes, single_md, single_table],
293
+ api_name=False,
294
+ )
295
+ batch_search_btn.click(
296
+ fn=search_batch,
297
+ inputs=[batch_input, batch_top_k, batch_concept_classes],
298
+ outputs=[batch_md, batch_download],
299
+ api_name="search_batch",
300
+ )
301
+ batch_clear_btn.click(
302
+ fn=lambda: _clear_batch(concept_class_default),
303
+ outputs=[batch_input, batch_top_k, batch_concept_classes, batch_md, batch_download],
304
+ api_name=False,
305
+ )
306
+
307
+ domain.change(lambda: "Drug", outputs=domain, api_name=False)
308
+ batch_domain.change(lambda: "Drug", outputs=batch_domain, api_name=False)
309
+
310
+ return demo
tests/conftest.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Test bootstrap for src-layout imports."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ import sys
7
+
8
+ ROOT = Path(__file__).resolve().parents[1]
9
+ SRC = ROOT / "src"
10
+ if str(SRC) not in sys.path:
11
+ sys.path.insert(0, str(SRC))
12
+
tests/test_build_lancedb_index.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import importlib.util
4
+ from pathlib import Path
5
+ import sys
6
+ import types
7
+
8
+
9
def _load_script_module():
    """Import scripts/offline/build_lancedb_index.py as a throwaway module.

    The script lives outside the package tree, so it is loaded by file path
    via importlib rather than a normal import.
    """
    root = Path(__file__).resolve().parents[1]
    script_path = root / "scripts" / "offline" / "build_lancedb_index.py"
    spec = importlib.util.spec_from_file_location("build_lancedb_index", script_path)
    assert spec is not None
    assert spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
18
+
19
+
20
def test_default_concept_classes_exact():
    """The offline index script must pin the exact drug concept-class list."""
    module = _load_script_module()
    expected = [
        "Clinical Drug",
        "Quant Clinical Drug",
        "Clinical Drug Comp",
        "Clinical Drug Form",
        "Branded Drug",
        "Quant Branded Drug",
        "Branded Drug Comp",
        "Branded Drug Form",
        "Ingredient",
    ]
    assert module.DEFAULT_CONCEPT_CLASSES == expected
34
+
35
+
36
def test_resolve_device_prefers_mps_on_darwin(monkeypatch):
    """On macOS, 'auto' picks MPS even when CUDA also reports available."""
    module = _load_script_module()
    monkeypatch.setattr(module.platform, "system", lambda: "Darwin")

    # Stand-in torch so the test does not require a real torch install/GPU.
    fake_torch = types.SimpleNamespace(
        cuda=types.SimpleNamespace(is_available=lambda: True),
        backends=types.SimpleNamespace(mps=types.SimpleNamespace(is_available=lambda: True)),
    )
    monkeypatch.setitem(sys.modules, "torch", fake_torch)
    assert module.resolve_device("auto") == "mps"
46
+
47
+
48
def test_resolve_device_explicit_passthrough():
    """Explicit device names bypass auto-detection unchanged."""
    module = _load_script_module()
    assert module.resolve_device("cpu") == "cpu"
    assert module.resolve_device("cuda") == "cuda"
    assert module.resolve_device("mps") == "mps"
53
+
tests/test_peft_reranker.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ import types
5
+
6
+ import thirawat_demo.runtime.peft_reranker as peft_mod
7
+
8
+
9
def test_load_colbert_with_peft_resizes_before_adapter(monkeypatch):
    """Token-embedding resize must happen BEFORE the PEFT adapter is attached.

    All heavy dependencies (pylate, peft, safetensors) are replaced with fakes
    that record calls into ``events``; the assertion at the end checks the
    relative order of the "resize" and "peft_load" events.
    """
    events: list[tuple] = []

    class FakeEncoder:
        def resize_token_embeddings(self, size, mean_resizing=False):
            events.append(("resize", size, mean_resizing))

    class FakeFirstModule:
        def __init__(self):
            self.auto_model = FakeEncoder()

    class FakeDense:
        def load_state_dict(self, state_dict, strict=False):
            events.append(("dense_load", strict))

    class FakeColBERT:
        # Mimics the minimal pylate ColBERT surface the loader touches:
        # _first_module().auto_model, tokenizer length, and model[1] (dense).
        def __init__(self, *args, **kwargs):
            events.append(("colbert_init", kwargs.get("model_name_or_path"), kwargs.get("tokenizer_kwargs")))
            self._first = FakeFirstModule()
            self._dense = FakeDense()
            self.tokenizer = [0, 1, 2, 3]

        def _first_module(self):
            return self._first

        def __len__(self):
            return 2

        def __getitem__(self, idx):
            assert idx == 1
            return self._dense

    class FakePeftCfg:
        base_model_name_or_path = "cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR"

    class FakePeftConfig:
        @staticmethod
        def from_pretrained(model_id):
            events.append(("peft_config", model_id))
            return FakePeftCfg()

    class FakePeftModel:
        @staticmethod
        def from_pretrained(encoder, model_id, is_trainable=False):
            events.append(("peft_load", model_id, is_trainable))

            class Wrapped:
                def eval(self):
                    events.append(("peft_eval",))

            return Wrapped()

    # Install fake modules so the lazy imports inside the loader resolve here.
    pylate_module = types.ModuleType("pylate")
    pylate_models = types.ModuleType("pylate.models")
    pylate_models.ColBERT = FakeColBERT
    pylate_module.models = pylate_models

    safetensors_module = types.ModuleType("safetensors")
    safetensors_torch = types.ModuleType("safetensors.torch")
    safetensors_torch.load_file = lambda _: {}
    safetensors_module.torch = safetensors_torch

    monkeypatch.setattr(peft_mod, "disable_mean_resizing", lambda: events.append(("disable_mean_resizing",)))
    monkeypatch.setitem(
        sys.modules,
        "peft",
        types.SimpleNamespace(PeftConfig=FakePeftConfig, PeftModel=FakePeftModel),
    )
    monkeypatch.setitem(sys.modules, "pylate", pylate_module)
    monkeypatch.setitem(sys.modules, "pylate.models", pylate_models)
    monkeypatch.setitem(sys.modules, "safetensors", safetensors_module)
    monkeypatch.setitem(sys.modules, "safetensors.torch", safetensors_torch)

    model = peft_mod._load_colbert_with_peft("sidataplus/THIRAWAT-SapBERT", "cpu", 128, 128)
    assert model is not None

    # The ordering assertion that gives the test its name.
    resize_idx = next(idx for idx, entry in enumerate(events) if entry[0] == "resize")
    peft_idx = next(idx for idx, entry in enumerate(events) if entry[0] == "peft_load")
    assert resize_idx < peft_idx
tests/test_runtime_config.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+ import pytest
6
+
7
+ from thirawat_demo.runtime.config import RuntimeConfig
8
+
9
+
10
+ def test_from_env_requires_index_repo_when_local_missing(monkeypatch):
11
+ monkeypatch.delenv("INDEX_REPO", raising=False)
12
+ monkeypatch.setenv("LOCAL_INDEX_PATH", "missing-local-index")
13
+ with pytest.raises(ValueError):
14
+ RuntimeConfig.from_env()
15
+
16
+
17
+ def test_from_env_accepts_local_db_without_repo(monkeypatch, tmp_path: Path):
18
+ local_index = tmp_path / "data" / "lancedb" / "db"
19
+ local_index.mkdir(parents=True)
20
+
21
+ monkeypatch.delenv("INDEX_REPO", raising=False)
22
+ monkeypatch.setenv("LOCAL_INDEX_PATH", str(local_index.parent))
23
+ config = RuntimeConfig.from_env()
24
+ assert config.local_index_path == local_index.parent
25
+ assert config.index_repo is None
26
+
27
+
28
+ def test_retrieval_topk_is_raised_to_candidate_topk(monkeypatch, tmp_path: Path):
29
+ local_index = tmp_path / "idx" / "db"
30
+ local_index.mkdir(parents=True)
31
+
32
+ monkeypatch.setenv("LOCAL_INDEX_PATH", str(local_index.parent))
33
+ monkeypatch.setenv("CANDIDATE_TOPK", "200")
34
+ monkeypatch.setenv("RETRIEVAL_TOPK", "10")
35
+
36
+ config = RuntimeConfig.from_env()
37
+ assert config.retrieval_topk == 200
38
+
39
+
40
+ def test_device_defaults_to_auto(monkeypatch, tmp_path: Path):
41
+ local_index = tmp_path / "idx" / "db"
42
+ local_index.mkdir(parents=True)
43
+ monkeypatch.setenv("LOCAL_INDEX_PATH", str(local_index.parent))
44
+ monkeypatch.delenv("DEVICE", raising=False)
45
+
46
+ config = RuntimeConfig.from_env()
47
+ assert config.device == "auto"
48
+
49
+
50
+ def test_post_config_defaults(monkeypatch, tmp_path: Path):
51
+ local_index = tmp_path / "idx" / "db"
52
+ local_index.mkdir(parents=True)
53
+ monkeypatch.setenv("LOCAL_INDEX_PATH", str(local_index.parent))
54
+ monkeypatch.delenv("POST_MODE", raising=False)
55
+ monkeypatch.delenv("POST_WEIGHT", raising=False)
56
+ monkeypatch.delenv("TIEBREAK_EPS", raising=False)
57
+ monkeypatch.delenv("TIEBREAK_TOPN", raising=False)
58
+ monkeypatch.delenv("POST_STRENGTH_WEIGHT", raising=False)
59
+ monkeypatch.delenv("POST_JACCARD_WEIGHT", raising=False)
60
+ monkeypatch.delenv("POST_BRAND_PENALTY", raising=False)
61
+ monkeypatch.delenv("POST_MINMAX", raising=False)
62
+ monkeypatch.delenv("BRAND_STRICT", raising=False)
63
+
64
+ config = RuntimeConfig.from_env()
65
+ assert config.post_mode == "tiebreak"
66
+ assert config.post_weight == 0.05
67
+ assert config.tiebreak_eps == 0.01
68
+ assert config.tiebreak_topn == 50
69
+ assert config.post_strength_weight == 0.6
70
+ assert config.post_jaccard_weight == 0.4
71
+ assert config.post_brand_penalty == 0.3
72
+ assert config.post_minmax is True
73
+ assert config.brand_strict is False
74
+
75
+
76
+ def test_invalid_post_mode_raises(monkeypatch, tmp_path: Path):
77
+ local_index = tmp_path / "idx" / "db"
78
+ local_index.mkdir(parents=True)
79
+ monkeypatch.setenv("LOCAL_INDEX_PATH", str(local_index.parent))
80
+ monkeypatch.setenv("POST_MODE", "unsupported")
81
+ with pytest.raises(ValueError):
82
+ RuntimeConfig.from_env()
tests/test_search_service.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import pyarrow as pa
8
+
9
+ import thirawat_demo.runtime.search_service as service_mod
10
+ from thirawat_demo.runtime.config import RuntimeConfig
11
+ from thirawat_demo.runtime.search_service import SearchService
12
+
13
+
14
def _make_config(tmp_path: Path) -> RuntimeConfig:
    """Fully-populated RuntimeConfig fixture rooted at tmp_path (cpu device)."""
    return RuntimeConfig(
        local_index_path=tmp_path / "data" / "lancedb",
        index_repo=None,
        index_revision="main",
        hf_token=None,
        hf_data_dir=tmp_path / "data",
        lancedb_table="concepts_drug",
        top_k_default=10,
        candidate_topk=100,
        retrieval_topk=200,
        device="cpu",
        encoder_model_id="cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR",
        reranker_id="sidataplus/THIRAWAT-SapBERT",
        post_mode="tiebreak",
        post_weight=0.05,
        tiebreak_eps=0.01,
        tiebreak_topn=50,
        post_strength_weight=0.6,
        post_jaccard_weight=0.4,
        post_brand_penalty=0.3,
        post_minmax=True,
        brand_strict=False,
        debug=False,
    )
39
+
40
+
41
def test_build_concept_class_where_sorted_and_escaped():
    """WHERE builder dedupes, sorts, and doubles embedded single quotes."""
    expr = SearchService._build_concept_class_where(["Ingredient", "Clinical Drug", "O'Reilly", "Ingredient"])
    assert expr == "concept_class_id IN ('Clinical Drug','Ingredient','O''Reilly')"
44
+
45
+
46
def test_apply_post_scoring_uses_mapper_tiebreak(monkeypatch, tmp_path: Path):
    """_apply_post_scoring must forward the configured knobs to the mapper helpers."""
    config = _make_config(tmp_path)
    service = SearchService(config=config, lancedb_dir=tmp_path)
    source = pd.DataFrame(
        {
            "concept_name": ["a", "b"],
            "profile_text": ["", ""],
            "_relevance_score": [0.9, 0.89],
        }
    )
    # Record every argument the service passes down.
    called: dict[str, object] = {}

    def fake_should_apply_post(mode: str, weight: float) -> bool:
        called["mode"] = mode
        called["weight"] = weight
        return True

    def fake_enrich(df: pd.DataFrame, query_text_norm: str, **kwargs):
        called["query"] = query_text_norm
        called["kwargs"] = kwargs
        out = df.copy()
        out["final_score"] = [0.1, 0.2]
        # Reversed order proves the service uses the enriched frame as-is.
        return out.iloc[[1, 0]].reset_index(drop=True)

    monkeypatch.setattr(service_mod, "should_apply_post", fake_should_apply_post)
    monkeypatch.setattr(service_mod, "enrich_with_post_scores", fake_enrich)

    out = service._apply_post_scoring(source, "aspirin 81 mg")
    assert out["concept_name"].tolist() == ["b", "a"]
    assert called["mode"] == "tiebreak"
    assert called["weight"] == 0.05
    assert called["query"] == "aspirin 81 mg"
    assert called["kwargs"]["tiebreak_eps"] == 0.01
    assert called["kwargs"]["tiebreak_topn"] == 50
80
+
81
+
82
def test_search_applies_concept_class_filter(monkeypatch, tmp_path: Path):
    """search() must push a concept-class WHERE clause into the LanceDB builder.

    The LanceDB table/builder, embedder, and reranker are all faked so the
    test exercises only the service's query-construction and post-processing.
    """
    config = _make_config(tmp_path)
    service = SearchService(config=config, lancedb_dir=tmp_path)

    class FakeBuilder:
        # Captures the WHERE clause; every other builder call is a no-op chain.
        def __init__(self):
            self.where_clause = None

        def where(self, clause):
            self.where_clause = clause
            return self

        def distance_type(self, _):
            return self

        def limit(self, _):
            return self

        def rerank(self, reranker=None, query_string=None):
            return self

        def to_arrow(self):
            frame = pd.DataFrame(
                {
                    "concept_id": [1191],
                    "concept_name": ["aspirin"],
                    "concept_code": ["1191"],
                    "vocabulary_id": ["RxNorm"],
                    "_relevance_score": [0.998],
                    "_distance": [0.2],
                    "concept_class_id": ["Ingredient"],
                }
            )
            return pa.Table.from_pandas(frame, preserve_index=False)

    class FakeSchema:
        names = ["concept_id", "concept_name", "concept_code", "vocabulary_id", "concept_class_id", "vector"]

    class FakeTable:
        schema = FakeSchema()

        def __init__(self, builder: FakeBuilder):
            self._builder = builder

        def search(self, *args, **kwargs):
            return self._builder

    class FakeEmbedder:
        def encode(self, texts):
            return np.array([[0.0, 1.0]], dtype=float)

    builder = FakeBuilder()
    # Inject the fakes directly so startup() (and model loading) never runs.
    service._table = FakeTable(builder)
    service._vector_column = "vector"
    service._embedder = FakeEmbedder()
    service._reranker = object()
    service._normalize_text = lambda value: value.strip()
    monkeypatch.setattr(SearchService, "_apply_post_scoring", lambda self, df, q: df)

    result = service.search("aspirin", top_k=5, concept_class_ids=["Ingredient", "Clinical Drug"])
    assert builder.where_clause == "concept_class_id IN ('Clinical Drug','Ingredient')"
    assert result["concept_name"].tolist() == ["aspirin"]
    assert result["concept_class_id"].tolist() == ["Ingredient"]
    # retrieval_score = 1 - _distance = 1 - 0.2
    assert result["retrieval_score"].tolist() == [0.8]
146
+
147
+
148
def test_search_returns_empty_when_no_concept_class_selected(tmp_path: Path):
    """An explicitly-empty class selection short-circuits to an empty frame."""
    config = _make_config(tmp_path)
    service = SearchService(config=config, lancedb_dir=tmp_path)
    out = service.search("aspirin", top_k=10, concept_class_ids=[])
    assert out.empty
tests/test_space_ui.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import pandas as pd
4
+
5
+ from thirawat_demo import space_ui
6
+
7
+
8
def test_to_display_results_replaces_scores_with_concept_class():
    """The display projection drops both score columns and renames
    ``concept_class_id`` to ``concept_class``."""
    source = pd.DataFrame(
        [
            {
                "rank": 1,
                "concept_id": 123,
                "concept_name": "aspirin",
                "concept_code": "1191",
                "vocabulary_id": "RxNorm",
                "concept_class_id": "Ingredient",
                "score": 0.9,
                "retrieval_score": 0.8,
            }
        ]
    )
    shown = space_ui._to_display_results(source)
    expected_columns = [
        "rank",
        "concept_id",
        "concept_name",
        "concept_code",
        "vocabulary_id",
        "concept_class",
    ]
    assert shown.columns.tolist() == expected_columns
    assert shown.iloc[0]["concept_class"] == "Ingredient"
    assert "score" not in shown.columns
    assert "retrieval_score" not in shown.columns
33
+
34
+
35
def test_to_display_results_handles_missing_concept_class_id():
    """Rows that lack ``concept_class_id`` fall back to an empty string."""
    source = pd.DataFrame(
        [
            {
                "rank": 1,
                "concept_id": 123,
                "concept_name": "aspirin",
                "concept_code": "1191",
                "vocabulary_id": "RxNorm",
            }
        ]
    )
    shown = space_ui._to_display_results(source)
    assert shown.iloc[0]["concept_class"] == ""
47
+
48
+
49
def test_format_single_results_keeps_score_text_but_returns_projected_table():
    """The markdown keeps the score text while the returned table uses the
    score-free display projection."""
    source = pd.DataFrame(
        [
            {
                "rank": 1,
                "concept_id": 123,
                "concept_name": "aspirin",
                "concept_code": "1191",
                "vocabulary_id": "RxNorm",
                "concept_class_id": "Ingredient",
                "score": 0.91,
                "retrieval_score": 0.81,
            }
        ]
    )
    markdown, frame = space_ui._format_single_results("aspirin", source)
    assert "score 0.9100" in markdown
    assert "concept_class" in frame.columns
    assert "score" not in frame.columns
67
+
68
def test_search_single_requires_at_least_one_concept_class():
    """search_single refuses to run when no concept class is selected."""
    message, frame = space_ui.search_single("aspirin", 10, [])
    assert message == "Please select at least one concept class."
    assert frame.empty
72
+
73
+
74
def test_search_batch_requires_at_least_one_concept_class():
    """search_batch refuses to run and keeps the download control hidden."""
    message, download_update = space_ui.search_batch("aspirin", 10, [])
    assert message == "Please select at least one concept class."
    assert download_update["visible"] is False
78
+
79
+
80
def test_clear_single_resets_default_query_and_concept_classes():
    """_clear_single restores the default query and top-k while echoing the
    caller's concept-class selection back unchanged."""
    selection = ["Ingredient", "Clinical Drug"]
    query, top_k, classes, markdown, frame = space_ui._clear_single(selection)
    assert query == space_ui.DEFAULT_SINGLE_QUERY
    assert top_k == 10
    assert classes == selection
    assert markdown == ""
    assert frame.empty
88
+
89
+
90
def test_clear_batch_resets_concept_classes_and_hides_download():
    """_clear_batch empties the query box, restores top-k, echoes the class
    selection, and hides the download control."""
    selection = ["Ingredient", "Clinical Drug"]
    queries, top_k, classes, markdown, download_update = space_ui._clear_batch(selection)
    assert queries == ""
    assert top_k == 10
    assert classes == selection
    assert markdown == ""
    assert download_update["visible"] is False
uv.lock ADDED
The diff for this file is too large to render. See raw diff