na399's picture
Deploy THIRAWAT mapper app
25c66a0 verified
"""Drug Concept Entity Linking - HuggingFace Space"""
import os
import tempfile
import traceback
from pathlib import Path
import gradio as gr
import lancedb
from sentence_transformers import SentenceTransformer
import pandas as pd
# ===== CONFIG =====
# ดึงจาก Space Secrets (ตั้งค่าใน Settings > Secrets)
def _get_env(name: str, default: str | None = None) -> str | None:
value = os.environ.get(name)
if value is None:
return default
value = value.strip()
if not value or value.lower() == "none":
return default
return value
# Optional: can be left unset when the index repo is public.
HF_TOKEN = _get_env("HF_TOKEN")
# Hugging Face dataset repo holding the LanceDB index (change to your own repo).
INDEX_REPO = _get_env("INDEX_REPO", "amnnma/drug-concept-index")
# Local index directory; used instead of downloading when it contains a "db" folder.
LOCAL_INDEX_PATH = _get_env("LOCAL_INDEX_PATH", "data/lancedb")
# Accept the common truthy spellings (case-insensitive) so that e.g. DEBUG=true
# does not silently leave debug output disabled; "1" keeps working as before.
DEBUG = (_get_env("DEBUG", "0") or "0").lower() in {"1", "true", "yes", "on"}
# Model
# Sentence-embedding model used to encode search queries.
MODEL_ID = "cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR"
# Default number of results returned per query.
TOP_K = 10
class DrugConceptSearcher:
    """Embed drug queries and look them up in the ``concepts_drug`` LanceDB table.

    Construction is expensive: it loads the SentenceTransformer named by
    MODEL_ID and opens the table. It uses a local index under LOCAL_INDEX_PATH
    when one exists, and otherwise downloads INDEX_REPO from the Hugging Face Hub.
    """

    def __init__(self):
        self.model = None
        self.db = None
        self.table = None
        self._load()

    @staticmethod
    def _normalize_repo_id(raw) -> str:
        """Reduce an INDEX_REPO value (repo id or Hub URL) to ``owner/name``.

        Accepts ``owner/name``, ``datasets/owner/name`` and full URLs such as
        ``https://huggingface.co/datasets/owner/name/tree/main``. Only the two
        path segments identifying the repo are kept, so trailing
        ``/tree/...`` parts and trailing slashes are dropped.

        Raises:
            ValueError: if no repo segments remain after parsing.
        """
        from urllib.parse import urlparse

        text = str(raw).strip()
        if text.startswith("http"):
            # Only the URL path carries the repo id; drop scheme, host, query, fragment.
            text = urlparse(text).path
        segments = [seg for seg in text.split("/") if seg]
        # Hub URLs prefix the repo with its type; the download is always a dataset.
        if segments and segments[0] in ("datasets", "spaces"):
            segments = segments[1:]
        if not segments:
            raise ValueError(f"Could not extract a repo id from INDEX_REPO={raw!r}")
        return "/".join(segments[:2])

    @staticmethod
    def _download_root() -> Path:
        """Return a writable directory for the downloaded index, creating it if needed.

        Prefers ``$HF_DATA_DIR/lancedb`` (default ``/data/lancedb``). Falls back to
        the relative ``data/lancedb`` when that directory cannot be created.
        """
        root = Path(os.environ.get("HF_DATA_DIR", "/data")) / "lancedb"
        try:
            root.mkdir(parents=True, exist_ok=True)
        except OSError:
            root = Path("data/lancedb")
            root.mkdir(parents=True, exist_ok=True)
        return root

    def _download_index(self, repo_id: str) -> Path:
        """Snapshot dataset repo *repo_id* locally and return its root directory.

        An anonymous download is tried first. If it fails and HF_TOKEN is set,
        the download is retried with the token at the same revision; otherwise
        the original error propagates.
        """
        from huggingface_hub import snapshot_download

        download_kwargs = {
            "repo_id": repo_id,
            "repo_type": "dataset",
            # A blank HF_DATASET_REVISION falls back to "main" instead of an invalid "".
            "revision": os.environ.get("HF_DATASET_REVISION") or "main",
            "local_dir": str(self._download_root()),
        }
        # Avoid implicit token usage for public datasets.
        # NOTE(review): huggingface_hub may read this variable at import time, so the
        # explicit token=False below is what actually enforces it. Confirm this.
        os.environ["HF_HUB_DISABLE_IMPLICIT_TOKEN"] = "1"
        try:
            return Path(snapshot_download(token=False, **download_kwargs))
        except Exception:
            if not HF_TOKEN:
                raise
            # Retry authenticated (private/gated repo). Reuse the same kwargs so the
            # requested revision is honoured on this path too.
            return Path(snapshot_download(token=HF_TOKEN, **download_kwargs))

    def _load(self):
        """Load the embedding model and open the ``concepts_drug`` LanceDB table."""
        print("Loading model...")
        # Force slow tokenizer to avoid fast-tokenizer conversion issues on Space
        self.model = SentenceTransformer(MODEL_ID, tokenizer_kwargs={"use_fast": False})

        # Prefer a local index when available (useful for local runs).
        local_root = Path(LOCAL_INDEX_PATH) if LOCAL_INDEX_PATH else None
        if local_root and (local_root / "db").exists():
            index_root = local_root
            print(f"Connecting to local LanceDB at {index_root}...")
        else:
            repo_id = self._normalize_repo_id(INDEX_REPO or "amnnma/drug-concept-index")
            print(f"Connecting to LanceDB from {repo_id}...")
            index_root = self._download_index(repo_id)

        self.db = lancedb.connect(str(index_root / "db"))
        self.table = self.db.open_table("concepts_drug")
        print("✅ Ready!")

    def search(self, query: str, top_k: int = TOP_K):
        """Return up to *top_k* concepts closest to *query*, highest score first.

        Returns a DataFrame with columns concept_id, concept_name, concept_code,
        vocabulary_id and score. Returns an empty DataFrame for a blank query.
        """
        if not query or not query.strip():
            return pd.DataFrame()
        # gr.Slider delivers floats; LanceDB's limit() expects an integer count.
        top_k = int(top_k)
        query_emb = self.model.encode(query, normalize_embeddings=True)
        results = self.table.search(query_emb).limit(top_k).to_pandas()
        if "_distance" in results.columns:
            # NOTE(review): 1 - distance equals cosine similarity only if the table
            # uses the cosine metric. Confirm against how the index was built.
            results["score"] = 1 - results["_distance"]
        results = results.sort_values("score", ascending=False)
        return results[["concept_id", "concept_name", "concept_code", "vocabulary_id", "score"]]
# Initialize
# Process-wide searcher, created lazily by get_searcher() on first use.
searcher = None


def get_searcher():
    """Return the shared DrugConceptSearcher, constructing it on the first call."""
    global searcher
    if searcher is not None:
        return searcher
    searcher = DrugConceptSearcher()
    return searcher
def _format_results(results: pd.DataFrame, query: str) -> tuple[str, pd.DataFrame]:
    """Build the Markdown summary shown above the single-query results table.

    Returns a ``(markdown, results)`` pair. *results* is passed through
    unchanged so it can feed the table component directly. The summary names
    the first row as the top match, so *results* is expected to be ordered
    best-first.
    """
    if not results.empty:
        top = results.iloc[0]
        heading = f"## Results for: \"{query}\"\n\n"
        summary = f"**Top match:** {top['concept_name']} (score {top['score']:.4f})\n\n"
        return heading + summary, results
    return "No results found. Try a different search term.", results
def search_drugs(query: str, top_k: int):
    """Handle a single-query search from the UI.

    Returns ``(markdown, dataframe)``. Any exception is printed to stdout with
    its traceback and reported in the Markdown output instead (including the
    traceback when DEBUG is on), paired with an empty DataFrame.
    """
    try:
        hits = get_searcher().search(query, top_k)
        markdown, table = _format_results(hits, query)
        return markdown, table
    except Exception as exc:
        tb = traceback.format_exc()
        print("Search error:", exc)
        print(tb)
        if DEBUG:
            return f"❌ Error: {str(exc)}\n\n```\n{tb}\n```", pd.DataFrame()
        return f"❌ Error: {str(exc)}", pd.DataFrame()
def search_batch(queries_text: str, top_k: int):
    """Run one search per non-blank line of *queries_text* and export a CSV.

    Returns ``(markdown_summary, download_update)``. On success the download
    button points at the CSV and is made visible. On empty input, no results,
    or an error it is hidden. Exceptions are printed with their traceback and
    reported in the Markdown (traceback included when DEBUG is on).
    """
    try:
        if not queries_text or not queries_text.strip():
            return "Please enter clinical terms to search.", gr.update(visible=False)
        lines = [line.strip() for line in queries_text.splitlines() if line.strip()]
        if not lines:
            return "No valid queries found.", gr.update(visible=False)
        # gr.Slider passes a float; normalize so searches get an int and the
        # summary reads "10" rather than "10.0".
        top_k = int(top_k)
        s = get_searcher()
        rows = []
        for q in lines:
            results = s.search(q, top_k)
            rows.extend(
                {
                    "query_text": q,
                    "rank": rank,
                    "concept_id": row["concept_id"],
                    "concept_name": row["concept_name"],
                    "concept_code": row["concept_code"],
                    "vocabulary_id": row["vocabulary_id"],
                    "score": float(row["score"]),
                }
                for rank, (_, row) in enumerate(results.iterrows(), start=1)
            )
        if not rows:
            return "No results found.", gr.update(visible=False)
        df = pd.DataFrame(rows)
        tmp_dir = Path(tempfile.gettempdir()) / "thirawat_results"
        tmp_dir.mkdir(parents=True, exist_ok=True)
        # Give every request its own directory. A single fixed path would let
        # concurrent users overwrite, and then download, each other's results.
        # The user-facing file name stays "batch_results.csv".
        out_path = Path(tempfile.mkdtemp(dir=tmp_dir)) / "batch_results.csv"
        df.to_csv(out_path, index=False)
        md = f"""## Batch Search Complete
- **Queries processed:** {len(lines)}
- **Rows returned:** {len(rows)}
- **Top-K per query:** {top_k}
"""
        return md, gr.update(value=str(out_path), visible=True)
    except Exception as e:
        print("Batch search error:", e)
        print(traceback.format_exc())
        if DEBUG:
            return f"❌ Error: {str(e)}\n\n```\n{traceback.format_exc()}\n```", gr.update(visible=False)
        return f"❌ Error: {str(e)}", gr.update(visible=False)
# ===== GRADIO INTERFACE =====
# Two tabs: a single interactive query and a batch mode that exports a CSV.
# The Top-K sliders and their Clear resets all use TOP_K, so changing the
# default in one place keeps the UI consistent.
with gr.Blocks(title="THIRAWAT - Drug Concept Search") as demo:
    gr.HTML(
        """
<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 10px; margin-bottom: 20px;">
<h1 style="color: white; margin: 0; font-size: 2em;">THIRAWAT</h1>
<p style="color: rgba(255,255,255,0.9); margin: 5px 0 0 0;">Drug Concept Entity Linking</p>
<p style="color: rgba(255,255,255,0.8); margin: 5px 0 0 0;">Map drug names to OMOP concepts using SapBERT + LanceDB.</p>
</div>
"""
    )
    with gr.Tabs():
        with gr.Tab("Single Query"):
            with gr.Row():
                with gr.Column(scale=3):
                    query_input = gr.Textbox(
                        label="Drug name or query",
                        placeholder="e.g., aspirin, paracetamol, amoxicillin 500mg...",
                        lines=2,
                    )
                with gr.Column(scale=1):
                    # Display-only: the searcher always queries the drug table.
                    domain_hint = gr.Dropdown(
                        label="Domain",
                        choices=["Drug", "Condition", "Procedure", "Observation", "Device", "Unit"],
                        value="Drug",
                        interactive=False,
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=TOP_K,
                        step=1,
                        label="Number of results",
                    )
            with gr.Row():
                search_btn = gr.Button("Search", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
            output_md = gr.Markdown(label="Results")
            output_table = gr.Dataframe(label="Results Table", interactive=False)
        with gr.Tab("Batch Query"):
            with gr.Row():
                with gr.Column(scale=3):
                    batch_queries = gr.Textbox(
                        label="Drug names (one per line)",
                        placeholder="aspirin\nparacetamol\namoxicillin 500mg",
                        lines=10,
                    )
                with gr.Column(scale=1):
                    # Display-only, as in the single-query tab.
                    batch_domain_hint = gr.Dropdown(
                        label="Domain",
                        choices=["Drug", "Condition", "Procedure", "Observation", "Device", "Unit"],
                        value="Drug",
                        interactive=False,
                    )
                    batch_topk = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=TOP_K,
                        step=1,
                        label="Top-K per query",
                    )
            with gr.Row():
                batch_btn = gr.Button("Process Batch", variant="primary")
                batch_clear = gr.Button("Clear", variant="secondary")
            batch_output = gr.Markdown(label="Summary")
            batch_download = gr.DownloadButton(
                label="Download Results (CSV)",
                variant="secondary",
                visible=False,
            )

    def clear_single():
        """Reset the single-query tab: empty query, default Top-K, no results."""
        return "", TOP_K, "", pd.DataFrame()

    def clear_batch():
        """Reset the batch tab: empty input, default Top-K, no summary, hidden download."""
        return "", TOP_K, "", gr.update(visible=False)

    # api_name=False keeps these handlers out of the auto-generated public API.
    search_btn.click(
        fn=search_drugs,
        inputs=[query_input, top_k],
        outputs=[output_md, output_table],
        api_name=False,
    )
    clear_btn.click(
        fn=clear_single,
        outputs=[query_input, top_k, output_md, output_table],
        api_name=False,
    )
    batch_btn.click(
        fn=search_batch,
        inputs=[batch_queries, batch_topk],
        outputs=[batch_output, batch_download],
        api_name=False,
    )
    batch_clear.click(
        fn=clear_batch,
        outputs=[batch_queries, batch_topk, batch_output, batch_download],
        api_name=False,
    )
    gr.Markdown(
        """
---
**THIRAWAT** is a dense retrieval toolkit for mapping drug terminology to OMOP standard concepts.
"""
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)