File size: 1,395 Bytes
e9989e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37

from sentence_transformers import SentenceTransformer
from typing import Dict, List, Any, Union

class EndpointHandler:
    """Callable request handler that embeds text with a SentenceTransformer model.

    The model is loaded once in ``__init__``; each call encodes the request's
    ``inputs`` and returns the embeddings as plain nested Python lists.
    """

    def __init__(self, model_path="bge-large-en/"):
        """Load the SentenceTransformer model found at ``model_path``."""
        # Loaded once here so individual requests only pay for encoding.
        self.model = SentenceTransformer(model_path)

    def __call__(self, data: Dict[str, Any]) -> Union[List[List[float]], List[float]]:
        """Encode the request inputs and return them as plain Python lists.

        Args:
            data: Request payload with the keys:
                ``inputs``: a single string or a list of strings to embed.
                ``kwargs`` (optional): dict of options:
                    ``normalize_embeddings`` (default ``True``): forwarded
                        to ``model.encode``.
                    ``is_query`` (default ``False``): when truthy,
                        ``query_instruction`` is prepended to every input.
                    ``query_instruction`` (default ``""``): prefix used for
                        queries.

        Returns:
            The embeddings converted with ``tolist()``: one list of floats
            for a single string input, a list of such lists for a list input
            (assuming ``model.encode`` returns an array shaped like its
            input, as SentenceTransformer documents).

        Raises:
            KeyError: if ``data`` has no ``"inputs"`` key.
        """
        inputs = data["inputs"]
        # An explicit ``"kwargs": null`` is treated the same as a missing key.
        kwargs = data.get("kwargs") or {}
        normalize_embeddings = kwargs.get("normalize_embeddings", True)

        # Queries get the instruction prefix; passages are encoded as-is.
        if kwargs.get("is_query", False):
            instruction = kwargs.get("query_instruction") or ""
            if isinstance(inputs, list):
                inputs = [instruction + query for query in inputs]
            else:
                inputs = instruction + inputs

        embeddings = self.model.encode(
            inputs, normalize_embeddings=normalize_embeddings
        )

        # Array-like results (numpy arrays, tensors) expose ``tolist()``;
        # convert them so the response is serializable. The previous check
        # was inverted: it called ``tolist()`` only on real lists, which do
        # not have that method, and returned arrays unconverted.
        if hasattr(embeddings, "tolist"):
            return embeddings.tolist()
        # Fallback for a plain sequence of per-input arrays/tensors.
        return [
            row.tolist() if hasattr(row, "tolist") else row for row in embeddings
        ]