Spaces:
Sleeping
Sleeping
| from sentence_transformers import SentenceTransformer | |
| from typing import List, Union | |
| import numpy as np | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class Embedder:
    """Wrap a SentenceTransformer model for query/passage retrieval embeddings.

    The default model is ``BAAI/bge-large-en-v1.5``. Queries are encoded with
    the retrieval instruction in ``QUERY_INSTRUCTION``; passages are encoded
    as-is.
    """

    # Retrieval instruction prepended to queries only. It is applied whatever
    # ``model_name`` is.
    # NOTE(review): this wording is the one published for the BGE English
    # models -- confirm it is appropriate before using another model family.
    QUERY_INSTRUCTION = "Represent this sentence for searching relevant passages: "

    def __init__(self, model_name: str = "BAAI/bge-large-en-v1.5", device: str = "cuda"):
        """Load the embedding model.

        Args:
            model_name: Model identifier passed to ``SentenceTransformer``.
            device: Device string passed to ``SentenceTransformer``.
        """
        self.model_name = model_name
        self.device = device
        self.model = SentenceTransformer(model_name, device=device)
        logger.info("Loaded embedding model: %s", model_name)

    def encode(self, texts: Union[str, List[str]], batch_size: int = 16) -> np.ndarray:
        """Encode texts to embeddings without adding any prefix.

        Args:
            texts: A single string or a list of strings. A single string is
                treated as a one-element list.
            batch_size: Batch size forwarded to the underlying model.

        Returns:
            Whatever the model returns for ``convert_to_numpy=True``
            (presumably one row per input text -- confirm against the
            installed sentence-transformers version).
        """
        if isinstance(texts, str):
            texts = [texts]
        return self.model.encode(
            texts,
            batch_size=batch_size,
            convert_to_numpy=True,
            # Only show a progress bar for larger jobs.
            show_progress_bar=len(texts) > 100,
        )

    def encode_queries(self, queries: Union[str, List[str]], batch_size: int = 16) -> np.ndarray:
        """Encode search queries, each prefixed with ``QUERY_INSTRUCTION``.

        Args:
            queries: A list of query strings; a single string is also accepted
                and treated as a one-element list.
            batch_size: Batch size forwarded to ``encode``.

        Returns:
            The query embeddings, or an empty 1-D array (``np.array([])``)
            when ``queries`` is empty.
        """
        if not queries:
            return np.array([])
        if isinstance(queries, str):
            # A bare string would otherwise be iterated character by character.
            queries = [queries]
        prefixed_queries = [f"{self.QUERY_INSTRUCTION}{query}" for query in queries]
        return self.encode(prefixed_queries, batch_size=batch_size)

    def encode_passages(self, passages: Union[str, List[str]], batch_size: int = 16) -> np.ndarray:
        """Encode passages (documents) without any instruction prefix.

        The retrieval instruction belongs on the query side only; prefixing
        passages with it as well removes the query/passage asymmetry the
        instruction is meant to provide.

        Args:
            passages: A list of passage strings; a single string is also
                accepted and treated as a one-element list.
            batch_size: Batch size forwarded to ``encode``.

        Returns:
            The passage embeddings, or an empty 1-D array (``np.array([])``)
            when ``passages`` is empty.
        """
        if not passages:
            return np.array([])
        if isinstance(passages, str):
            # A bare string would otherwise be iterated character by character.
            passages = [passages]
        return self.encode(list(passages), batch_size=batch_size)

    def get_dimension(self) -> int:
        """Return the embedding dimension reported by the underlying model."""
        return self.model.get_sentence_embedding_dimension()