"""Drug Concept Entity Linking - HuggingFace Space""" import os import tempfile import traceback from pathlib import Path import gradio as gr import lancedb from sentence_transformers import SentenceTransformer import pandas as pd # ===== CONFIG ===== # ดึงจาก Space Secrets (ตั้งค่าใน Settings > Secrets) def _get_env(name: str, default: str | None = None) -> str | None: value = os.environ.get(name) if value is None: return default value = value.strip() if not value or value.lower() == "none": return default return value HF_TOKEN = _get_env("HF_TOKEN") # หรือไม่ใส่ก็ได้ถ้า public INDEX_REPO = _get_env("INDEX_REPO", "amnnma/drug-concept-index") # เปลี่ยนชื่อ repo LOCAL_INDEX_PATH = _get_env("LOCAL_INDEX_PATH", "data/lancedb") DEBUG = _get_env("DEBUG", "0") == "1" # Model MODEL_ID = "cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR" TOP_K = 10 class DrugConceptSearcher: def __init__(self): self.model = None self.db = None self.table = None self._load() def _load(self): """Load model and connect to LanceDB""" print("Loading model...") # Force slow tokenizer to avoid fast-tokenizer conversion issues on Space self.model = SentenceTransformer(MODEL_ID, tokenizer_kwargs={"use_fast": False}) # Prefer local index when available (useful for local runs) local_root = Path(LOCAL_INDEX_PATH) if LOCAL_INDEX_PATH else None if local_root and local_root.exists() and (local_root / "db").exists(): index_root = local_root print(f"Connecting to local LanceDB at {index_root}...") else: repo_id = INDEX_REPO or "amnnma/drug-concept-index" if not isinstance(repo_id, str): repo_id = str(repo_id) repo_id = repo_id.strip() if repo_id.startswith("http"): # Accept full HF URLs and extract the repo id parts = repo_id.split("/") if "datasets" in parts: repo_id = "/".join(parts[parts.index("datasets") + 1 :]).strip("/") elif "spaces" in parts: repo_id = "/".join(parts[parts.index("spaces") + 1 :]).strip("/") else: repo_id = "/".join(parts[-2:]).strip("/") if repo_id.startswith("datasets/"): 
repo_id = repo_id[len("datasets/") :] print(f"Connecting to LanceDB from {repo_id}...") # Download และ connect ไปยัง LanceDB ใน HF repo from huggingface_hub import snapshot_download # Download index (cache ไว้ใน /data) download_root = Path(os.environ.get("HF_DATA_DIR", "/data")) / "lancedb" try: download_root.mkdir(parents=True, exist_ok=True) except OSError: download_root = Path("data/lancedb") download_root.mkdir(parents=True, exist_ok=True) # Avoid implicit token usage for public datasets os.environ["HF_HUB_DISABLE_IMPLICIT_TOKEN"] = "1" try: index_root = Path( snapshot_download( repo_id=repo_id, repo_type="dataset", token=False, revision=os.environ.get("HF_DATASET_REVISION", "main"), local_dir=str(download_root), ) ) except Exception as e: if HF_TOKEN: index_root = Path( snapshot_download( repo_id=repo_id, repo_type="dataset", token=HF_TOKEN, local_dir=str(download_root), ) ) else: raise e # Connect to LanceDB self.db = lancedb.connect(str(index_root / "db")) self.table = self.db.open_table("concepts_drug") print("✅ Ready!") def search(self, query: str, top_k: int = TOP_K): """Search drug concepts""" if not query or not query.strip(): return pd.DataFrame() # Encode query query_emb = self.model.encode(query, normalize_embeddings=True) # Search results = self.table.search(query_emb).limit(top_k).to_pandas() # Format output if "_distance" in results.columns: results["score"] = 1 - results["_distance"] # Convert distance to similarity results = results.sort_values("score", ascending=False) return results[["concept_id", "concept_name", "concept_code", "vocabulary_id", "score"]] # Initialize searcher = None def get_searcher(): global searcher if searcher is None: searcher = DrugConceptSearcher() return searcher def _format_results(results: pd.DataFrame, query: str) -> tuple[str, pd.DataFrame]: if results.empty: return "No results found. 
Try a different search term.", results output = f"## Results for: \"{query}\"\n\n" best = results.iloc[0] output += f"**Top match:** {best['concept_name']} (score {best['score']:.4f})\n\n" return output, results def search_drugs(query: str, top_k: int): """Gradio search function (single query)""" try: s = get_searcher() results = s.search(query, top_k) output, table = _format_results(results, query) return output, table except Exception as e: print("Search error:", e) print(traceback.format_exc()) if DEBUG: return f"❌ Error: {str(e)}\n\n```\n{traceback.format_exc()}\n```", pd.DataFrame() return f"❌ Error: {str(e)}", pd.DataFrame() def search_batch(queries_text: str, top_k: int): """Gradio search function (batch queries)""" try: if not queries_text or not queries_text.strip(): return "Please enter clinical terms to search.", gr.update(visible=False) lines = [line.strip() for line in queries_text.splitlines() if line.strip()] if not lines: return "No valid queries found.", gr.update(visible=False) s = get_searcher() rows = [] for q in lines: results = s.search(q, top_k) for i, (_, row) in enumerate(results.iterrows(), start=1): rows.append( { "query_text": q, "rank": i, "concept_id": row["concept_id"], "concept_name": row["concept_name"], "concept_code": row["concept_code"], "vocabulary_id": row["vocabulary_id"], "score": float(row["score"]), } ) if not rows: return "No results found.", gr.update(visible=False) df = pd.DataFrame(rows) tmp_dir = Path(tempfile.gettempdir()) / "thirawat_results" tmp_dir.mkdir(parents=True, exist_ok=True) out_path = tmp_dir / "batch_results.csv" df.to_csv(out_path, index=False) md = f"""## Batch Search Complete - **Queries processed:** {len(lines)} - **Rows returned:** {len(rows)} - **Top-K per query:** {top_k} """ return md, gr.update(value=str(out_path), visible=True) except Exception as e: print("Batch search error:", e) print(traceback.format_exc()) if DEBUG: return f"❌ Error: {str(e)}\n\n```\n{traceback.format_exc()}\n```", 
gr.update(visible=False) return f"❌ Error: {str(e)}", gr.update(visible=False) # ===== GRADIO INTERFACE ===== with gr.Blocks(title="THIRAWAT - Drug Concept Search") as demo: gr.HTML( """
        <h1>Drug Concept Entity Linking</h1>
        <p>Map drug names to OMOP concepts using SapBERT + LanceDB.</p>