File size: 1,861 Bytes
db06013
 
 
 
 
 
 
 
 
 
 
 
 
 
5d0d255
db06013
 
 
 
 
 
 
 
 
 
 
 
 
5d0d255
db06013
 
 
 
 
 
 
 
5d0d255
db06013
 
 
 
 
 
 
 
 
 
8a3396b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from sentence_transformers import SentenceTransformer
from typing import List, Union
import numpy as np
import logging

logger = logging.getLogger(__name__)

class Embedder:
    """Wrap a ``SentenceTransformer`` model to produce retrieval embeddings.

    Retrieval models such as BGE are asymmetric. Queries are encoded with an
    instruction prefix, while passages (documents) are encoded as-is. The two
    prefixes are configurable so the wrapper also fits other
    instruction-tuned embedding models.
    """

    # Instruction the BGE English models document for query -> passage retrieval.
    DEFAULT_QUERY_INSTRUCTION = "Represent this sentence for searching relevant passages: "
    # BGE documents that passages need no instruction at all.
    DEFAULT_PASSAGE_INSTRUCTION = ""

    # Class-level fallbacks for the instance attributes set in __init__.
    query_instruction: str = DEFAULT_QUERY_INSTRUCTION
    passage_instruction: str = DEFAULT_PASSAGE_INSTRUCTION

    def __init__(
        self,
        model_name: str = "BAAI/bge-large-en-v1.5",
        device: str = "cuda",
        query_instruction: str = DEFAULT_QUERY_INSTRUCTION,
        passage_instruction: str = DEFAULT_PASSAGE_INSTRUCTION,
    ):
        """Load the embedding model.

        Args:
            model_name: Model identifier passed to ``SentenceTransformer``.
            device: Device string passed to ``SentenceTransformer``.
            query_instruction: Prefix prepended to every query in
                ``encode_queries``.
            passage_instruction: Prefix prepended to every passage in
                ``encode_passages``. The default is empty, as BGE requires.
        """
        self.model_name = model_name
        self.device = device
        self.query_instruction = query_instruction
        self.passage_instruction = passage_instruction
        self.model = SentenceTransformer(model_name, device=device)
        logger.info("Loaded embedding model: %s", model_name)

    @staticmethod
    def _with_prefix(texts: List[str], prefix: str) -> List[str]:
        """Return a new list with ``prefix`` prepended to each text."""
        return [f"{prefix}{text}" for text in texts]

    def _empty(self) -> np.ndarray:
        """Return a ``(0, dim)`` float32 array, matching the shape of non-empty results."""
        return np.empty((0, self.get_dimension()), dtype=np.float32)

    def encode(self, texts: Union[str, List[str]], batch_size: int = 16) -> np.ndarray:
        """Encode texts to embeddings.

        Args:
            texts: A single string or a sequence of strings. A single string
                is treated as a one-element batch.
            batch_size: Batch size forwarded to the model's ``encode``.

        Returns:
            Array with one row per input text. For empty input this is a
            ``(0, dim)`` float32 array, and the model is not called.
        """
        texts = [texts] if isinstance(texts, str) else list(texts)
        if not texts:
            return self._empty()

        return self.model.encode(
            texts,
            batch_size=batch_size,
            convert_to_numpy=True,
            # Only show a progress bar for large batches.
            show_progress_bar=len(texts) > 100,
        )

    def encode_queries(self, queries: List[str], batch_size: int = 16) -> np.ndarray:
        """Encode search queries, prepending ``self.query_instruction`` to each.

        Empty input returns a ``(0, dim)`` array.
        """
        return self.encode(self._with_prefix(queries, self.query_instruction), batch_size)

    def encode_passages(self, passages: List[str], batch_size: int = 16) -> np.ndarray:
        """Encode passages, prepending ``self.passage_instruction`` to each.

        The default passage instruction is empty: BGE passages must NOT carry
        the query instruction. Empty input returns a ``(0, dim)`` array.
        """
        return self.encode(self._with_prefix(passages, self.passage_instruction), batch_size)

    def get_dimension(self) -> int:
        """Return the embedding dimension reported by the underlying model."""
        return self.model.get_sentence_embedding_dimension()