# MedSpace / src/embeddings/embedding_models.py
# Uploaded by kbsss via huggingface_hub (commit f373e2b, 3.37 kB).
"""
Medical embedding models for semantic search.
"""
import torch
import numpy as np
from typing import List, Union, Optional
from sentence_transformers import SentenceTransformer
from pathlib import Path
import os
class MedicalEmbedder:
    """
    Medical domain embedding model wrapper.

    Thin wrapper around ``sentence_transformers.SentenceTransformer`` that
    resolves short aliases (see ``SUPPORTED_MODELS``) to model IDs, picks a
    torch device, and exposes query/document embedding and cosine-similarity
    helpers.

    Aliases cover MedCPT (query/article encoders), PubMedBERT, BioBERT,
    BGE-small and all-MiniLM; any other string is passed straight to
    ``SentenceTransformer`` as a model ID or local path.
    """

    # Alias -> model ID passed to SentenceTransformer.
    # NOTE(review): the MedCPT entries are separate query and article
    # encoders; presumably queries should use "medcpt-query" and the corpus
    # "medcpt-article" -- confirm against how callers build indexes.
    SUPPORTED_MODELS = {
        "medcpt-query": "ncbi/MedCPT-Query-Encoder",
        "medcpt-article": "ncbi/MedCPT-Article-Encoder",
        "pubmedbert": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        "biobert": "dmis-lab/biobert-v1.1",
        "bge-small": "BAAI/bge-small-en-v1.5",
        "all-minilm": "sentence-transformers/all-MiniLM-L6-v2",  # Fallback, fast
    }

    def __init__(
        self,
        model_name: str = "all-minilm",  # Default to fast model for testing
        device: Optional[str] = None,
        cache_dir: Optional[str] = None,
    ):
        """
        Load the embedding model.

        Args:
            model_name: A key of ``SUPPORTED_MODELS`` or any model ID / path
                accepted by ``SentenceTransformer``.
            device: Torch device string; defaults to ``"cuda"`` when
                ``torch.cuda.is_available()``, otherwise ``"cpu"``.
            cache_dir: Optional cache directory forwarded to
                ``SentenceTransformer(cache_folder=...)`` for both the
                requested model and the fallback.

        If loading the requested model raises, the error is printed and the
        ``all-minilm`` model is loaded instead; ``self.model_name`` is then set
        to ``"all-minilm"`` so it reflects the model actually in use. If the
        fallback load also fails, that exception propagates.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        model_path = self._resolve_model_path(model_name)
        self.model_name = model_name

        print(f"🔄 Loading embedding model: {model_path} on {self.device}")
        try:
            self.model = SentenceTransformer(
                model_path,
                device=self.device,
                cache_folder=cache_dir,
            )
        except Exception as e:
            # Best-effort fallback, but report *why* the requested model failed
            # instead of silently discarding the exception.
            print(f"⚠️ Failed to load {model_path} ({e}), falling back to all-MiniLM")
            self.model = SentenceTransformer(
                self.SUPPORTED_MODELS["all-minilm"],
                device=self.device,
                cache_folder=cache_dir,
            )
            self.model_name = "all-minilm"
        else:
            # Kept outside the try so a failure here cannot trigger the fallback.
            print(f"✅ Model loaded. Dimension: {self.embedding_dimension}")

    @classmethod
    def _resolve_model_path(cls, model_name: str) -> str:
        """Map a ``SUPPORTED_MODELS`` alias to its model ID; other names pass through unchanged."""
        return cls.SUPPORTED_MODELS.get(model_name, model_name)

    def embed(
        self,
        texts: Union[str, List[str]],
        batch_size: int = 32,
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Generate embeddings for one or more texts.

        Args:
            texts: A single string or an iterable of strings.
            batch_size: Batch size forwarded to ``SentenceTransformer.encode``.
            show_progress: Whether ``encode`` displays a progress bar.
            normalize: If True, embeddings are L2-normalized by ``encode``
                (``normalize_embeddings``), so dot products equal cosine
                similarities.

        Returns:
            A 2-D NumPy array with one row per input text (a single string
            still yields one row). Empty input yields a ``(0, dim)`` array.
        """
        if isinstance(texts, str):
            texts = [texts]
        else:
            # Materialize tuples/generators so emptiness can be checked and
            # encode() always receives a list of strings.
            texts = list(texts)

        if not texts:
            # encode([]) can come back as a 1-D empty array rather than a
            # (0, dim) matrix, which breaks matrix ops downstream (np.dot).
            return np.empty((0, self.embedding_dimension), dtype=np.float32)

        return self.model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=show_progress,
            convert_to_numpy=True,
            normalize_embeddings=normalize,
        )

    def embed_query(self, query: str, normalize: bool = True) -> np.ndarray:
        """Embed a single query (no progress bar) and return its 1-D embedding vector."""
        return self.embed(query, show_progress=False, normalize=normalize)[0]

    def embed_documents(
        self,
        documents: List[str],
        batch_size: int = 32,
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """Embed multiple documents; returns one row per document (see ``embed``)."""
        return self.embed(
            documents,
            batch_size=batch_size,
            show_progress=show_progress,
            normalize=normalize,
        )

    @property
    def embedding_dimension(self) -> int:
        """Embedding dimension reported by the underlying SentenceTransformer."""
        # NOTE(review): get_sentence_embedding_dimension() may return None for
        # some model configurations -- confirm for the models in use.
        return self.model.get_sentence_embedding_dimension()

    def similarity(self, query: str, documents: List[str]) -> np.ndarray:
        """
        Cosine similarity between ``query`` and each document.

        Returns:
            A 1-D array with one score per document, in input order. An empty
            ``documents`` sequence returns an empty array without calling the
            model.
        """
        documents = list(documents)
        if not documents:
            return np.empty(0, dtype=np.float32)

        # normalize=True on both sides makes the dot product a cosine similarity.
        query_emb = self.embed_query(query, normalize=True)
        doc_embs = self.embed_documents(documents, batch_size=32, normalize=True)
        return np.dot(doc_embs, query_emb)