File size: 1,790 Bytes
fb433ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b5f761
fb433ce
0b5f761
 
 
fb433ce
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from typing import Dict, List, Any
from sentence_transformers import SentenceTransformer
import torch


class EndpointHandler:
    """Custom inference handler that turns sentences into embeddings.

    Wraps a ``SentenceTransformer`` loaded from a local directory and exposes
    it through ``__call__``, which receives the parsed request body.
    """

    def __init__(self, path: str = ""):
        """
        Initialise the SentenceTransformer model.

        Args:
            path: Local directory where the model artifacts are stored.

        The model is placed on CUDA when ``torch.cuda.is_available()`` is
        true, otherwise on CPU, and is switched to eval mode.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SentenceTransformer(path, device=device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List]:
        """
        Main inference entry point.

        Expected request body (single or batch):
            { "inputs": "A single sentence" }
            { "inputs": ["Sentence one", "Sentence two"] }

        Optional parameters (every key is optional; "parameters" itself may
        be omitted or null):
            { "inputs": [...],
              "parameters": { "normalize_embeddings": true,
                              "batch_size": 32,
                              "instruction": "Represent this sentence: " } }

        ``normalize_embeddings`` defaults to true and ``batch_size`` to 32.
        ``instruction``, when non-empty, is prepended verbatim to every input
        string before encoding (no separator is inserted).

        Returns:
            { "embeddings": [[float, ...], ...] } -- one row per input string.
            A single-string input still yields a one-row list, and an empty
            input list yields { "embeddings": [] } without calling the model.

        Raises:
            ValueError: if ``inputs`` is missing, is not a string or a
                list/tuple of strings, or ``parameters`` is not an object.
        """
        inputs = data.get("inputs")
        if inputs is None:
            raise ValueError("`inputs` field is required in the request body.")

        # Normalise to a list so a single sentence and a batch share one path.
        if isinstance(inputs, str):
            inputs = [inputs]
        elif isinstance(inputs, (list, tuple)) and all(isinstance(s, str) for s in inputs):
            inputs = list(inputs)
        else:
            # Fail fast with a clear client error instead of an opaque
            # tokenizer / concatenation error further down.
            raise ValueError("`inputs` must be a string or a list of strings.")

        # A JSON body may send "parameters": null; `.get(key, {})` would
        # return None in that case, so fall back on any falsy value.
        parameters = data.get("parameters") or {}
        if not isinstance(parameters, dict):
            raise ValueError("`parameters` must be a JSON object.")
        normalize_embeddings = parameters.get("normalize_embeddings", True)
        batch_size = parameters.get("batch_size", 32)
        instruction = parameters.get("instruction", "")

        # Nothing to encode: skip the model round-trip entirely.
        if not inputs:
            return {"embeddings": []}

        if instruction:
            inputs = [instruction + text for text in inputs]

        with torch.no_grad():
            embeddings = self.model.encode(
                inputs,
                batch_size=batch_size,
                normalize_embeddings=normalize_embeddings,
                convert_to_numpy=True,
            )

        return {"embeddings": embeddings.tolist()}