# safe_rag/retriever/embedder.py
# Author: goodmodeler
# Last change: "MOD: batch size" (commit 5d0d255)
from sentence_transformers import SentenceTransformer
from typing import List, Union
import numpy as np
import logging
logger = logging.getLogger(__name__)
class Embedder:
    """Wrap a SentenceTransformer model for embedding queries and passages.

    Defaults to BAAI/bge-large-en-v1.5. Per the BGE model card, a retrieval
    instruction must be prepended to *queries only* (for short
    query-to-passage retrieval); passages are embedded as-is.
    """

    # Instruction prefix recommended by the BGE-v1.5 model card.
    # Applied to queries only — never to passages.
    QUERY_PREFIX = "Represent this sentence for searching relevant passages: "

    def __init__(self, model_name: str = "BAAI/bge-large-en-v1.5", device: str = "cuda"):
        """Load the embedding model onto the given device.

        Args:
            model_name: HuggingFace model id understood by SentenceTransformer.
            device: torch device string, e.g. "cuda" or "cpu".
        """
        self.model_name = model_name
        self.device = device
        self.model = SentenceTransformer(model_name, device=device)
        # Lazy %-args: message is only formatted if INFO is enabled.
        logger.info("Loaded embedding model: %s", model_name)

    def encode(self, texts: Union[str, List[str]], batch_size: int = 16) -> np.ndarray:
        """Encode raw text(s) to embeddings.

        A single string is treated as a one-element batch. The result is
        always 2-D with shape (n_texts, dim), including the empty case,
        so callers can vstack/concatenate without special-casing.

        Args:
            texts: one string or a list of strings.
            batch_size: encoder batch size.

        Returns:
            np.ndarray of shape (n_texts, dim).
        """
        if isinstance(texts, str):
            texts = [texts]
        if not texts:
            # Consistent 2-D empty result instead of a shape-(0,) array.
            return np.empty((0, self.get_dimension()))
        return self.model.encode(
            texts,
            batch_size=batch_size,
            convert_to_numpy=True,
            # Progress bar only for non-trivial workloads.
            show_progress_bar=len(texts) > 100,
        )

    def encode_queries(self, queries: List[str], batch_size: int = 16) -> np.ndarray:
        """Encode search queries, prepending the BGE retrieval instruction."""
        if not queries:
            return np.empty((0, self.get_dimension()))
        prefixed = [self.QUERY_PREFIX + q for q in queries]
        return self.encode(prefixed, batch_size)

    def encode_passages(self, passages: List[str], batch_size: int = 16) -> np.ndarray:
        """Encode corpus passages (no instruction prefix).

        BUGFIX: the BGE-v1.5 model card specifies the retrieval instruction
        is added to queries only; passages must be embedded without any
        prefix. Previously the query instruction was incorrectly prepended
        here too, corrupting passage embeddings.
        """
        if not passages:
            return np.empty((0, self.get_dimension()))
        return self.encode(passages, batch_size)

    def get_dimension(self) -> int:
        """Return the embedding dimensionality of the loaded model."""
        return self.model.get_sentence_embedding_dimension()