File size: 2,454 Bytes
fa16bad
fea62df
fa16bad
fea62df
0231daa
fea62df
 
 
 
 
0231daa
fea62df
 
 
 
 
0231daa
fea62df
 
 
 
0231daa
fea62df
 
 
 
0231daa
fea62df
 
0231daa
 
 
fea62df
 
 
 
 
 
fa16bad
fea62df
fa16bad
0231daa
fea62df
fa16bad
fea62df
0231daa
fa16bad
 
fea62df
 
 
0231daa
fea62df
af36df4
 
fea62df
 
 
 
0231daa
 
 
fea62df
fa16bad
0231daa
fea62df
 
 
0231daa
 
fa16bad
fea62df
 
 
0231daa
fea62df
fa16bad
 
fea62df
fa16bad
fea62df
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from loguru import logger

from ..src.core.config import ModelConfig


class EmbeddingModel:
    """
    Embedding model wrapper for dense embeddings.

    Lazily loads a SentenceTransformer on first use and exposes separate
    entry points for query-side and document-side encoding.

    Attributes:
        config: ModelConfig instance identifying the model to load.
        model: SentenceTransformer instance (None until load() succeeds).
        _loaded: Flag indicating if the model is loaded.
    """

    def __init__(self, config: ModelConfig):
        self.config = config
        self.model: Optional[SentenceTransformer] = None
        self._loaded = False

    def load(self) -> None:
        """Load the embedding model onto CPU. Idempotent: no-op if already loaded.

        Raises:
            Exception: re-raised from SentenceTransformer construction on failure
                (logged before propagating).
        """
        if self._loaded:
            return

        logger.info(f"Loading embedding model: {self.config.name}")
        try:
            # NOTE(review): trust_remote_code=True executes model-repo code;
            # acceptable only for trusted model sources — confirm upstream.
            self.model = SentenceTransformer(
                self.config.name, device="cpu", trust_remote_code=True
            )
            self._loaded = True
            logger.success(f"Loaded embedding model: {self.config.id}")
        except Exception as e:
            logger.error(f"Failed to load embedding model {self.config.id}: {e}")
            raise

    def query_embed(
        self, text: List[str], prompt: Optional[str] = None
    ) -> List[List[float]]:
        """
        Generate query embeddings for a list of texts.

        Despite the singular parameter name, this accepts a list of texts and
        returns one embedding vector per input text (the previous
        ``List[float]`` return annotation was incorrect — the body returns a
        list of vectors, mirroring ``embed_documents``).

        Args:
            text: List of input query texts.
            prompt: Optional prompt for instruction-based models.

        Returns:
            List of embedding vectors, one per input text.

        Raises:
            Exception: re-raised from the encoder on failure (logged first).
        """
        if not self._loaded:
            self.load()  # lazy load on first use

        try:
            embeddings = self.model.encode_query(text, prompt=prompt)
            return [embedding.tolist() for embedding in embeddings]
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            raise

    def embed_documents(
        self, texts: List[str], prompt: Optional[str] = None
    ) -> List[List[float]]:
        """
        Generate document embeddings for a list of texts.

        Args:
            texts: List of input texts.
            prompt: Optional prompt for instruction-based models.

        Returns:
            List of embedding vectors, one per input text.

        Raises:
            Exception: re-raised from the encoder on failure (logged first).
        """
        if not self._loaded:
            self.load()  # lazy load on first use

        try:
            embeddings = self.model.encode_document(texts, prompt=prompt)
            return [embedding.tolist() for embedding in embeddings]
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            raise