from typing import Any, Dict, List

import torch
from sentence_transformers import SentenceTransformer


class EndpointHandler:
    """Inference handler that encodes sentences into embeddings with a SentenceTransformer."""

    def __init__(self, path: str = ""):
        """
        Initialise the SentenceTransformer model.

        `path` is the local directory where model artifacts are stored.
        The model is placed on CUDA when `torch.cuda.is_available()`, otherwise on CPU.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SentenceTransformer(path, device=device)
        self.model.eval()

    @staticmethod
    def _normalise_inputs(inputs: Any) -> List[str]:
        """Validate the raw `inputs` value and return it as a list of strings."""
        if inputs is None:
            raise ValueError("`inputs` field is required in the request body.")
        # A single sentence is treated as a batch of one.
        if isinstance(inputs, str):
            return [inputs]
        if not isinstance(inputs, list) or not all(isinstance(item, str) for item in inputs):
            raise ValueError("`inputs` must be a string or a list of strings.")
        return inputs

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List]:
        """
        Main inference entry point.

        Expected request body (single or batch):
            { "inputs": "A single sentence" }
            { "inputs": ["Sentence one", "Sentence two"] }

        Optional parameters:
            {
                "inputs": [...],
                "parameters": {
                    "normalize_embeddings": true,   # default true
                    "batch_size": 32,               # default 32, must be a positive int
                    "instruction": "..."            # default "", prepended to every input
                }
            }

        Returns:
            { "embeddings": [[float, ...], ...] }  # one row per input; [] for empty input

        Raises:
            ValueError: if `inputs` is missing or not a string / list of strings,
                if `parameters` is not an object, or if `batch_size` is not a
                positive integer.
        """
        sentences = self._normalise_inputs(data.get("inputs"))

        # `"parameters": null` in the JSON body arrives as None; treat it like an absent field.
        parameters = data.get("parameters") or {}
        if not isinstance(parameters, dict):
            raise ValueError("`parameters` must be an object.")

        normalize_embeddings = parameters.get("normalize_embeddings", True)
        batch_size = parameters.get("batch_size", 32)
        instruction = parameters.get("instruction", "")

        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(batch_size, bool) or not isinstance(batch_size, int) or batch_size < 1:
            raise ValueError("`batch_size` must be a positive integer.")

        if instruction:
            sentences = [instruction + sentence for sentence in sentences]

        # Return an explicit empty result rather than relying on how the installed
        # sentence-transformers version handles an empty batch.
        if not sentences:
            return {"embeddings": []}

        with torch.no_grad():
            embeddings = self.model.encode(
                sentences,
                batch_size=batch_size,
                normalize_embeddings=normalize_embeddings,
                convert_to_numpy=True,
            )

        return {"embeddings": embeddings.tolist()}