| """ |
| Medical embedding models for semantic search. |
| """ |
| import torch |
| import numpy as np |
| from typing import List, Union, Optional |
| from sentence_transformers import SentenceTransformer |
| from pathlib import Path |
| import os |
|
|
class MedicalEmbedder:
    """
    Medical domain embedding model wrapper around ``SentenceTransformer``.

    Short aliases in ``SUPPORTED_MODELS`` map to Hugging Face model ids:
    MedCPT (query and article encoders), PubMedBERT, BioBERT, BGE-small
    and all-MiniLM. Any other string is passed to ``SentenceTransformer``
    unchanged, as a model id or local path.
    """

    # Alias -> Hugging Face model id.
    # NOTE(review): the MedCPT, PubMedBERT and BioBERT checkpoints appear to be
    # plain transformer models rather than sentence-transformers packages; if
    # so, SentenceTransformer builds a default pooling head for them — confirm
    # that pooling matches how these models are meant to be used.
    SUPPORTED_MODELS = {
        "medcpt-query": "ncbi/MedCPT-Query-Encoder",
        "medcpt-article": "ncbi/MedCPT-Article-Encoder",
        "pubmedbert": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        "biobert": "dmis-lab/biobert-v1.1",
        "bge-small": "BAAI/bge-small-en-v1.5",
        "all-minilm": "sentence-transformers/all-MiniLM-L6-v2",
    }

    # Alias loaded when the requested model cannot be loaded.
    _FALLBACK_MODEL = "all-minilm"

    def __init__(
        self,
        model_name: str = "all-minilm",
        device: Optional[str] = None,
        cache_dir: Optional[str] = None,
    ) -> None:
        """
        Load the embedding model.

        Args:
            model_name: An alias from ``SUPPORTED_MODELS`` or any model id /
                path accepted by ``SentenceTransformer``.
            device: Torch device string. Defaults to ``"cuda"`` when CUDA is
                available, otherwise ``"cpu"``.
            cache_dir: Optional directory passed to ``SentenceTransformer`` as
                ``cache_folder`` (used for both the requested and the fallback
                model).

        If loading the requested model raises, the all-MiniLM model is loaded
        instead and ``self.model_name`` is set to ``"all-minilm"`` so it
        reflects the model actually in use. If the fallback load also raises,
        that exception propagates.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Unknown names are treated as raw model ids / local paths.
        model_path = self.SUPPORTED_MODELS.get(model_name, model_name)

        self.model_name = model_name
        print(f"🔄 Loading embedding model: {model_path} on {self.device}")

        try:
            self.model = SentenceTransformer(
                model_path,
                device=self.device,
                cache_folder=cache_dir,
            )
        except Exception as e:
            # Deliberate best-effort fallback: keep the embedder usable, but
            # surface the reason and record which model was really loaded.
            print(f"⚠️ Failed to load {model_path} ({e}), falling back to all-MiniLM")
            self.model_name = self._FALLBACK_MODEL
            self.model = SentenceTransformer(
                self.SUPPORTED_MODELS[self._FALLBACK_MODEL],
                device=self.device,
                cache_folder=cache_dir,
            )

        print(f"✅ Model loaded. Dimension: {self.embedding_dimension}")

    def embed(
        self,
        texts: Union[str, List[str]],
        batch_size: int = 32,
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Encode one or more texts into embedding vectors.

        Args:
            texts: A single string or a list of strings.
            batch_size: Batch size forwarded to ``encode``.
            show_progress: Whether ``encode`` displays a progress bar.
            normalize: L2-normalize each embedding, so that dot products
                between embeddings equal cosine similarities.

        Returns:
            An array with one row per input text. Empty input yields an
            array of shape ``(0, embedding_dimension)`` instead of the 1-D
            empty array ``encode`` would produce.
        """
        if isinstance(texts, str):
            texts = [texts]

        # len() rather than truthiness, so array-like inputs also work.
        if len(texts) == 0:
            return np.empty((0, self.embedding_dimension or 0), dtype=np.float32)

        return self.model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=show_progress,
            convert_to_numpy=True,
            normalize_embeddings=normalize,
        )

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a single query and return its 1-D vector (no progress bar)."""
        return self.embed(query, show_progress=False)[0]

    def embed_documents(self, documents: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed multiple documents, one row per document, with a progress bar."""
        return self.embed(documents, batch_size=batch_size, show_progress=True)

    @property
    def embedding_dimension(self) -> int:
        """Embedding dimension reported by the underlying model."""
        return self.model.get_sentence_embedding_dimension()

    def similarity(
        self,
        query: str,
        documents: List[str],
        batch_size: int = 32,
    ) -> np.ndarray:
        """
        Score each document against a query.

        Args:
            query: The query text.
            documents: Texts to score; may be empty.
            batch_size: Batch size used when embedding the documents.

        Returns:
            A 1-D array with one score per document (empty for no documents).
            Embeddings are L2-normalized by ``embed``'s default, so each score
            is the cosine similarity.
        """
        query_emb = self.embed_query(query)
        doc_embs = self.embed_documents(documents, batch_size=batch_size)

        return np.dot(doc_embs, query_emb)
|
|