Spaces:
Sleeping
Sleeping
File size: 1,861 Bytes
from sentence_transformers import SentenceTransformer
from typing import List, Union
import numpy as np
import logging
logger = logging.getLogger(__name__)
class Embedder:
    """Thin wrapper around a SentenceTransformer model for retrieval embedding.

    Follows the BGE retrieval convention: short queries are prepended with an
    instruction string before encoding, while passages are encoded as-is
    (the BGE model card specifies no instruction for passages).
    """

    # Instruction prepended to queries only (BGE s2p retrieval convention).
    QUERY_PREFIX = "Represent this sentence for searching relevant passages: "

    def __init__(self, model_name: str = "BAAI/bge-large-en-v1.5", device: str = "cuda"):
        """Load the embedding model onto the requested device.

        Args:
            model_name: HuggingFace model id of the sentence-transformers model.
            device: Torch device string (e.g. "cuda", "cpu").
        """
        self.model_name = model_name
        self.device = device
        self.model = SentenceTransformer(model_name, device=device)
        # Lazy %-formatting: no string work unless INFO is enabled.
        logger.info("Loaded embedding model: %s", model_name)

    def encode(self, texts: Union[str, List[str]], batch_size: int = 16) -> np.ndarray:
        """Encode raw texts to embeddings.

        Args:
            texts: A single string or a list of strings.
            batch_size: Batch size forwarded to the underlying model.

        Returns:
            ndarray of shape (n, dim); a single string is treated as a
            one-element list, yielding shape (1, dim).
        """
        if isinstance(texts, str):
            texts = [texts]
        return self.model.encode(
            texts,
            batch_size=batch_size,
            convert_to_numpy=True,
            # Progress bars only pay off on large batches.
            show_progress_bar=len(texts) > 100,
        )

    def encode_queries(self, queries: List[str], batch_size: int = 16) -> np.ndarray:
        """Encode queries, applying the BGE query instruction prefix.

        Returns an empty array for an empty input list.
        """
        if not queries:
            return np.array([])
        prefixed = [self.QUERY_PREFIX + q for q in queries]
        return self.encode(prefixed, batch_size)

    def encode_passages(self, passages: List[str], batch_size: int = 16) -> np.ndarray:
        """Encode passages with no instruction prefix.

        Fix: the original prepended the *query* instruction to passages as
        well; per the BGE model card, passages must be encoded without any
        instruction. NOTE: re-index existing corpora after this change, since
        passage embeddings differ from those produced by the old code.

        Returns an empty array for an empty input list.
        """
        if not passages:
            return np.array([])
        return self.encode(passages, batch_size)

    def get_dimension(self) -> int:
        """Return the embedding dimensionality of the loaded model."""
        return self.model.get_sentence_embedding_dimension()