File size: 1,453 Bytes
7591fca
 
bdd4a60
7591fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from typing import Literal

from loguru import logger
from sentence_transformers import SentenceTransformer

from scientific_rag.settings import settings


class EmbeddingEncoder:
    """Thin wrapper around a SentenceTransformer configured from app settings.

    Inputs are tagged with a ``"query: "`` or ``"passage: "`` prefix before
    encoding, the input convention used by E5/BGE-family embedding models.
    Output vectors are L2-normalised and returned as plain Python lists.
    """

    def __init__(self) -> None:
        # Pull all model configuration from the shared settings object.
        self.model_name = settings.embedding_model_name
        self.device = settings.embedding_device
        self.batch_size = settings.embedding_batch_size

        logger.info(f"Loading embedding model: {self.model_name} on {self.device}")
        self._model = SentenceTransformer(self.model_name, device=self.device)
        logger.info(f"Embedding model loaded. Dimension: {self.embedding_dim}")

    @property
    def embedding_dim(self) -> int:
        """Dimensionality of the vectors produced by the loaded model."""
        return self._model.get_sentence_embedding_dimension()

    def encode(
        self,
        texts: list[str],
        mode: Literal["query", "passage"],
        batch_size: int | None = None,
        show_progress: bool = False,
    ) -> list[list[float]]:
        """Embed ``texts`` and return one normalised vector per input text.

        Args:
            texts: Raw strings to embed.
            mode: Selects the E5/BGE role prefix — ``"query"`` or ``"passage"``.
            batch_size: Overrides the configured batch size when truthy;
                falls back to ``self.batch_size`` otherwise.
            show_progress: Forwarded to the model's progress-bar flag.
        """
        effective_batch = batch_size or self.batch_size

        # specific prefixing for E5/BGE models
        tag = "query: " if mode == "query" else "passage: "
        tagged_texts = [tag + text for text in texts]

        vectors = self._model.encode(
            tagged_texts,
            batch_size=effective_batch,
            show_progress_bar=show_progress,
            normalize_embeddings=True,
            convert_to_numpy=True,
        )

        # numpy array -> list[list[float]] for JSON-friendly downstream use
        return vectors.tolist()


encoder = EmbeddingEncoder()