import logging
from datetime import datetime
from typing import Any, AnyStr, Dict, List

import torch
from sentence_transformers import CrossEncoder

# Module-level logger, keyed by this module's import name.
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Score passages against a query with a sentence-transformers cross-encoder."""

    def __init__(self, path: str = ""):
        """Load the cross-encoder located at ``path``.

        Args:
            path: Model identifier or directory passed straight to ``CrossEncoder``.
        """
        # Use the GPU when torch can see one; otherwise fall back to CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cross_encoder = CrossEncoder(path, device=device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[float]]:
        """Score every passage against the query in the request payload.

        Args:
            data: Request payload. ``data["inputs"]`` must be a mapping holding
                a ``"query"`` and a list of ``"passages"``.

        Returns:
            A dictionary with a single key ``"scores"`` holding one float per
            passage, in the same order as the passages. Each raw model output
            is passed through a sigmoid. An empty passage list yields
            ``{"scores": []}`` without invoking the model.

        Raises:
            ValueError: If ``"inputs"`` is missing or not a mapping, if
                ``"query"`` is missing, or if ``"passages"`` is missing or is
                a bare string instead of a list.
        """
        inputs = data.get("inputs")
        if not isinstance(inputs, dict):
            raise ValueError(
                'Request payload must contain an "inputs" object with '
                '"query" and "passages".'
            )

        query = inputs.get("query")
        if query is None:
            raise ValueError('"inputs" must contain a "query".')

        passages = inputs.get("passages")
        # A bare string would otherwise be iterated character by character and
        # scored as many one-letter passages.
        if passages is None or isinstance(passages, (str, bytes)):
            raise ValueError('"inputs" must contain "passages" as a list.')
        passages = list(passages)

        logger.info("Query: %s", query)
        logger.info("N. of passages: %d", len(passages))

        # Nothing to score: skip the model call entirely.
        if not passages:
            return {"scores": []}

        start_time = datetime.now()
        scores = self.cross_encoder.predict(
            [(query, passage) for passage in passages],
            activation_fct=torch.nn.Sigmoid(),
        )
        logger.info(
            "Time to run cross-encoder for query '%s' with %d passages: %s",
            query,
            len(passages),
            datetime.now() - start_time,
        )
        logger.info("Scores: %s", scores)

        # predict() typically hands back an array/tensor; convert to plain
        # Python floats so the result matches the declared return type and can
        # be serialised as JSON.
        if hasattr(scores, "tolist"):
            score_list = scores.tolist()
        else:
            score_list = [float(score) for score in scores]

        return {"scores": score_list}