File size: 1,362 Bytes
9c37331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from typing import Union, List
from langchain.embeddings import HuggingFaceEmbeddings

class EmbeddingModel:
    """
    A flexible embedding model wrapper supporting multiple backend models.

    Only the ``"huggingface"`` backend is constructed by ``_load_model``;
    any other ``model_type`` raises ``ValueError`` at construction time.
    """

    def __init__(self, model_type: str = "huggingface", model_name: str = "all-MiniLM-L6-v2"):
        """
        Store the configuration and load the backend model immediately.

        :param model_type: Backend identifier; ``"huggingface"`` is the only
            value ``_load_model`` accepts.
        :param model_name: Model name forwarded to the backend constructor.
        :raises ValueError: If ``model_type`` is not supported.
        """
        self.model_type = model_type
        self.model_name = model_name
        self.model = self._load_model()

    def _load_model(self):
        """
        Instantiate the backend selected by ``self.model_type``.

        :return: A ``HuggingFaceEmbeddings`` instance built with
            ``self.model_name`` when the type is ``"huggingface"``.
        :raises ValueError: For any other model type.
        """
        if self.model_type == "huggingface":
            return HuggingFaceEmbeddings(model_name=self.model_name)

        # Implementation for other model types can be added here

        raise ValueError(f"Unsupported model type: {self.model_type}")

    def embed(
        self, text: Union[str, List[str]]
    ) -> Union[List[float], List[List[float]]]:
        """
        Generate embeddings for the given text.

        :param text: A string or list of strings.
        :return: One embedding for a single string, or a list of embeddings
            (one per input, in input order) for a list of strings. An empty
            list yields an empty list without calling the model.
        :raises NotImplementedError: If ``self.model_type`` has no embedding
            implementation.
        """
        if self.model_type == "huggingface":
            if isinstance(text, (list, tuple)):
                if not text:
                    return []
                # Embed the whole batch in one call instead of issuing one
                # embed_query call per item.
                return self.model.embed_documents(list(text))
            return self.model.embed_query(text)

        if self.model_type == "sentence_transformers":
            # NOTE(review): _load_model never builds this backend, so this
            # branch only runs when self.model was supplied some other way
            # (e.g. a subclass overriding _load_model). It presumably expects
            # an object with a SentenceTransformer-style encode() -- confirm.
            return self.model.encode(text, convert_to_tensor=True).tolist()

        raise NotImplementedError(f"Embedding for {self.model_type} is not implemented.")