Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| from typing import List | |
| from sentence_transformers import SentenceTransformer | |
| from tqdm import tqdm | |
class EmbeddingManager:
    """Load a SentenceTransformer model and produce normalized embeddings.

    The model is loaded eagerly by the constructor, on CUDA when
    ``torch.cuda.is_available()`` is true and on the CPU otherwise.
    """

    def __init__(self, model_name: str = "pritamdeka/S-BioBERT-snli-multinli-stsb"):
        """Record the model name, select a device and load the model.

        Args:
            model_name: Identifier passed to ``SentenceTransformer`` as
                ``model_name_or_path``.
        """
        self.model_name = model_name
        self.model = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.load_model()

    def load_model(self) -> None:
        """Instantiate the SentenceTransformer and store it on ``self.model``."""
        print("Loading embedding model:", self.model_name)
        print('Using device', self.device)
        self.model = SentenceTransformer(model_name_or_path=self.model_name, device=self.device)
        print("Model loaded.")

    def get_model(self):
        """Return the loaded model object (``None`` if nothing is loaded)."""
        return self.model

    def _require_model(self):
        """Return the loaded model, raising if it has not been loaded."""
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        return self.model

    def embed_texts(self, texts: List[str], batch_size: int = 16) -> np.ndarray:
        """Embed a list of texts, batch by batch, with a progress bar.

        Args:
            texts: Strings to embed.
            batch_size: Number of texts encoded per step; must be >= 1.

        Returns:
            A 2-D array with one normalized embedding per input text, in
            input order. For empty ``texts`` an array of shape ``(0, dim)``
            is returned, where ``dim`` is the dimension reported by the model.

        Raises:
            RuntimeError: If no model is loaded.
            ValueError: If ``batch_size`` is smaller than 1.
        """
        model = self._require_model()
        if batch_size < 1:
            # A zero step makes range() fail and a negative one yields no
            # batches at all; reject both with a clear message instead.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if len(texts) == 0:
            # np.vstack([]) raises, so build the empty result explicitly.
            # float32 is assumed to match encode()'s usual output dtype.
            dim = model.get_sentence_embedding_dimension() or 0
            return np.empty((0, dim), dtype=np.float32)

        batches = []
        for start in tqdm(range(0, len(texts), batch_size), desc="Embedding texts"):
            chunk = texts[start:start + batch_size]
            encoded = model.encode(
                chunk,
                batch_size=batch_size,
                show_progress_bar=False,
                convert_to_numpy=True,
                normalize_embeddings=True,
            )
            # Keep each batch as one 2-D array rather than splitting it
            # into per-row arrays before stacking.
            batches.append(encoded)
        return np.vstack(batches)

    def embed_query(self, text: str) -> np.ndarray:
        """Embed a single query string and return it as a flat 1-D array.

        Raises:
            RuntimeError: If no model is loaded.
        """
        model = self._require_model()
        return model.encode(text, convert_to_numpy=True, normalize_embeddings=True).flatten()