Spaces:
Sleeping
Sleeping
| """Drug Concept Entity Linking - HuggingFace Space""" | |
| import os | |
| import tempfile | |
| import traceback | |
| from pathlib import Path | |
| import gradio as gr | |
| import lancedb | |
| from sentence_transformers import SentenceTransformer | |
| import pandas as pd | |
| # ===== CONFIG ===== | |
| # Read from Space Secrets (configure them under Settings > Secrets) | |
| def _get_env(name: str, default: str | None = None) -> str | None: | |
| value = os.environ.get(name) | |
| if value is None: | |
| return default | |
| value = value.strip() | |
| if not value or value.lower() == "none": | |
| return default | |
| return value | |
# Hub access token; may be left unset when the index repo is public.
HF_TOKEN = _get_env("HF_TOKEN")
# Dataset repo that hosts the LanceDB index; change this to use another repo.
INDEX_REPO = _get_env("INDEX_REPO", default="amnnma/drug-concept-index")
# Directory checked for a local copy of the index before downloading.
LOCAL_INDEX_PATH = _get_env("LOCAL_INDEX_PATH", default="data/lancedb")
# Set DEBUG=1 to include tracebacks in the error text shown in the UI.
DEBUG = "1" == _get_env("DEBUG", default="0")
| # Model | |
| MODEL_ID = "cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR" | |
| TOP_K = 10 | |
| class DrugConceptSearcher: | |
| def __init__(self): | |
| self.model = None | |
| self.db = None | |
| self.table = None | |
| self._load() | |
| def _load(self): | |
| """Load model and connect to LanceDB""" | |
| print("Loading model...") | |
| # Force slow tokenizer to avoid fast-tokenizer conversion issues on Space | |
| self.model = SentenceTransformer(MODEL_ID, tokenizer_kwargs={"use_fast": False}) | |
| # Prefer local index when available (useful for local runs) | |
| local_root = Path(LOCAL_INDEX_PATH) if LOCAL_INDEX_PATH else None | |
| if local_root and local_root.exists() and (local_root / "db").exists(): | |
| index_root = local_root | |
| print(f"Connecting to local LanceDB at {index_root}...") | |
| else: | |
| repo_id = INDEX_REPO or "amnnma/drug-concept-index" | |
| if not isinstance(repo_id, str): | |
| repo_id = str(repo_id) | |
| repo_id = repo_id.strip() | |
| if repo_id.startswith("http"): | |
| # Accept full HF URLs and extract the repo id | |
| parts = repo_id.split("/") | |
| if "datasets" in parts: | |
| repo_id = "/".join(parts[parts.index("datasets") + 1 :]).strip("/") | |
| elif "spaces" in parts: | |
| repo_id = "/".join(parts[parts.index("spaces") + 1 :]).strip("/") | |
| else: | |
| repo_id = "/".join(parts[-2:]).strip("/") | |
| if repo_id.startswith("datasets/"): | |
| repo_id = repo_id[len("datasets/") :] | |
| print(f"Connecting to LanceDB from {repo_id}...") | |
| # Download และ connect ไปยัง LanceDB ใน HF repo | |
| from huggingface_hub import snapshot_download | |
| # Download index (cache ไว้ใน /data) | |
| download_root = Path(os.environ.get("HF_DATA_DIR", "/data")) / "lancedb" | |
| try: | |
| download_root.mkdir(parents=True, exist_ok=True) | |
| except OSError: | |
| download_root = Path("data/lancedb") | |
| download_root.mkdir(parents=True, exist_ok=True) | |
| # Avoid implicit token usage for public datasets | |
| os.environ["HF_HUB_DISABLE_IMPLICIT_TOKEN"] = "1" | |
| try: | |
| index_root = Path( | |
| snapshot_download( | |
| repo_id=repo_id, | |
| repo_type="dataset", | |
| token=False, | |
| revision=os.environ.get("HF_DATASET_REVISION", "main"), | |
| local_dir=str(download_root), | |
| ) | |
| ) | |
| except Exception as e: | |
| if HF_TOKEN: | |
| index_root = Path( | |
| snapshot_download( | |
| repo_id=repo_id, | |
| repo_type="dataset", | |
| token=HF_TOKEN, | |
| local_dir=str(download_root), | |
| ) | |
| ) | |
| else: | |
| raise e | |
| # Connect to LanceDB | |
| self.db = lancedb.connect(str(index_root / "db")) | |
| self.table = self.db.open_table("concepts_drug") | |
| print("✅ Ready!") | |
| def search(self, query: str, top_k: int = TOP_K): | |
| """Search drug concepts""" | |
| if not query or not query.strip(): | |
| return pd.DataFrame() | |
| # Encode query | |
| query_emb = self.model.encode(query, normalize_embeddings=True) | |
| # Search | |
| results = self.table.search(query_emb).limit(top_k).to_pandas() | |
| # Format output | |
| if "_distance" in results.columns: | |
| results["score"] = 1 - results["_distance"] # Convert distance to similarity | |
| results = results.sort_values("score", ascending=False) | |
| return results[["concept_id", "concept_name", "concept_code", "vocabulary_id", "score"]] | |
| # Initialize | |
searcher = None


def get_searcher():
    """Return the module-wide DrugConceptSearcher, building it on first use.

    The instance is stored in the module-level ``searcher`` variable, so the
    model load and index connection happen at most once per process.
    """
    global searcher
    if searcher is not None:
        return searcher
    searcher = DrugConceptSearcher()
    return searcher
def _format_results(results: pd.DataFrame, query: str) -> tuple[str, pd.DataFrame]:
    """Build the Markdown summary shown above the results table.

    Args:
        results: Search hits; the first row is reported as the top match and
            must provide ``concept_name`` and ``score``.
        query: The text the user searched for, echoed in the heading.

    Returns:
        A ``(markdown, results)`` pair. ``results`` is passed through
        unchanged; for an empty frame the Markdown is a "no results" hint.
    """
    if results.empty:
        return "No results found. Try a different search term.", results

    top_hit = results.iloc[0]
    heading = f'## Results for: "{query}"\n\n'
    top_line = f"**Top match:** {top_hit['concept_name']} (score {top_hit['score']:.4f})\n\n"
    return heading + top_line, results
def search_drugs(query: str, top_k: int):
    """Gradio handler for a single query.

    Args:
        query: Drug name typed by the user.
        top_k: Number of results requested.

    Returns:
        A ``(markdown, dataframe)`` pair for the summary and the results
        table. On any failure the error is printed and returned as Markdown
        (with the traceback when ``DEBUG`` is on) alongside an empty frame.
    """
    try:
        hits = get_searcher().search(query, top_k)
        return _format_results(hits, query)
    except Exception as err:
        trace = traceback.format_exc()
        print("Search error:", err)
        print(trace)
        message = f"❌ Error: {str(err)}"
        if DEBUG:
            message += f"\n\n```\n{trace}\n```"
        return message, pd.DataFrame()
def search_batch(queries_text: str, top_k: int):
    """Gradio handler for batch queries (one query per line).

    Each non-blank line is searched separately and all hits are written to a
    CSV with the columns ``query_text``, ``rank``, ``concept_id``,
    ``concept_name``, ``concept_code``, ``vocabulary_id`` and ``score``.

    Args:
        queries_text: Text box content; blank lines are ignored.
        top_k: Number of results requested per query.

    Returns:
        A ``(markdown, download_update)`` pair: a summary plus an update that
        points the download button at the CSV, or a message plus an update
        hiding the button when there is nothing to download or an error
        occurred.
    """
    try:
        queries = [ln.strip() for ln in (queries_text or "").splitlines() if ln.strip()]
        if not queries:
            return "Please enter clinical terms to search.", gr.update(visible=False)

        s = get_searcher()
        columns = ["concept_id", "concept_name", "concept_code", "vocabulary_id", "score"]
        frames = []
        for q in queries:
            hits = s.search(q, top_k)
            if hits.empty:
                continue
            frame = hits[columns].reset_index(drop=True)
            frame["score"] = frame["score"].astype(float)
            # Rank follows the order returned by the searcher (best first).
            frame.insert(0, "rank", range(1, len(frame) + 1))
            frame.insert(0, "query_text", q)
            frames.append(frame)
        if not frames:
            return "No results found.", gr.update(visible=False)
        df = pd.concat(frames, ignore_index=True)

        tmp_dir = Path(tempfile.gettempdir()) / "thirawat_results"
        tmp_dir.mkdir(parents=True, exist_ok=True)
        # One fresh directory per request: a single shared file would let
        # concurrent users overwrite, and then download, each other's results.
        run_dir = Path(tempfile.mkdtemp(prefix="batch_", dir=tmp_dir))
        out_path = run_dir / "batch_results.csv"
        df.to_csv(out_path, index=False)

        md = (
            "## Batch Search Complete\n"
            f"- **Queries processed:** {len(queries)}\n"
            f"- **Rows returned:** {len(df)}\n"
            f"- **Top-K per query:** {top_k}\n"
        )
        return md, gr.update(value=str(out_path), visible=True)
    except Exception as e:
        trace = traceback.format_exc()
        print("Batch search error:", e)
        print(trace)
        if DEBUG:
            return f"❌ Error: {str(e)}\n\n```\n{trace}\n```", gr.update(visible=False)
        return f"❌ Error: {str(e)}", gr.update(visible=False)
| # ===== GRADIO INTERFACE ===== | |
with gr.Blocks(title="THIRAWAT - Drug Concept Search") as demo:
    # Banner with the app name and tagline.
    gr.HTML(
        '\n'
        '<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 10px; margin-bottom: 20px;">\n'
        '<h1 style="color: white; margin: 0; font-size: 2em;">THIRAWAT</h1>\n'
        '<p style="color: rgba(255,255,255,0.9); margin: 5px 0 0 0;">Drug Concept Entity Linking</p>\n'
        '<p style="color: rgba(255,255,255,0.8); margin: 5px 0 0 0;">Map drug names to OMOP concepts using SapBERT + LanceDB.</p>\n'
        '</div>\n'
    )

    with gr.Tabs():
        # ----- Tab 1: one query, results shown as a table -----
        with gr.Tab("Single Query"):
            with gr.Row():
                with gr.Column(scale=3):
                    query_input = gr.Textbox(
                        label="Drug name or query",
                        placeholder="e.g., aspirin, paracetamol, amoxicillin 500mg...",
                        lines=2,
                    )
                with gr.Column(scale=1):
                    # Read-only selector; it is not wired into any handler.
                    domain_hint = gr.Dropdown(
                        label="Domain",
                        choices=["Drug", "Condition", "Procedure", "Observation", "Device", "Unit"],
                        value="Drug",
                        interactive=False,
                    )
                    top_k = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Number of results")
            with gr.Row():
                search_btn = gr.Button("Search", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
            output_md = gr.Markdown(label="Results")
            output_table = gr.Dataframe(label="Results Table", interactive=False)

        # ----- Tab 2: many queries, results offered as a CSV download -----
        with gr.Tab("Batch Query"):
            with gr.Row():
                with gr.Column(scale=3):
                    batch_queries = gr.Textbox(
                        label="Drug names (one per line)",
                        placeholder="aspirin\nparacetamol\namoxicillin 500mg",
                        lines=10,
                    )
                with gr.Column(scale=1):
                    # Read-only selector; it is not wired into any handler.
                    batch_domain_hint = gr.Dropdown(
                        label="Domain",
                        choices=["Drug", "Condition", "Procedure", "Observation", "Device", "Unit"],
                        value="Drug",
                        interactive=False,
                    )
                    batch_topk = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Top-K per query")
            with gr.Row():
                batch_btn = gr.Button("Process Batch", variant="primary")
                batch_clear = gr.Button("Clear", variant="secondary")
            batch_output = gr.Markdown(label="Summary")
            # Hidden until a batch run produces a file to download.
            batch_download = gr.DownloadButton(label="Download Results (CSV)", variant="secondary", visible=False)

    def clear_single():
        """Reset the single-query tab: empty query, slider at 10, no output."""
        return "", 10, "", pd.DataFrame()

    def clear_batch():
        """Reset the batch tab: empty queries, slider at 10, no summary, button hidden."""
        return "", 10, "", gr.update(visible=False)

    # Event wiring; api_name=False is passed for every handler.
    search_btn.click(
        fn=search_drugs,
        inputs=[query_input, top_k],
        outputs=[output_md, output_table],
        api_name=False,
    )
    clear_btn.click(
        fn=clear_single,
        outputs=[query_input, top_k, output_md, output_table],
        api_name=False,
    )
    batch_btn.click(
        fn=search_batch,
        inputs=[batch_queries, batch_topk],
        outputs=[batch_output, batch_download],
        api_name=False,
    )
    batch_clear.click(
        fn=clear_batch,
        outputs=[batch_queries, batch_topk, batch_output, batch_download],
        api_name=False,
    )

    # Footer blurb.
    gr.Markdown(
        "\n"
        "---\n"
        "**THIRAWAT** is a dense retrieval toolkit for mapping drug terminology to OMOP standard concepts.\n"
    )
if __name__ == "__main__":
    # Serve on all interfaces at port 7860, without a public share link.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    demo.launch(**launch_options)