File size: 720 Bytes
22d1f99
 
 
 
 
 
 
 
 
 
9c92905
22d1f99
 
 
 
9ea6a04
22d1f99
c6afee5
8ebf3e6
9c92905
9ea6a04
c6afee5
22d1f99
c6afee5
22d1f99
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numbers

from transformers.utils import logging
from FlagEmbedding import FlagReranker

# Set the transformers library log verbosity to INFO so that the
# logger.info(...) calls made by the handler below are emitted.
logging.set_verbosity_info()
# Reuse the "transformers" library logger for this module's own messages.
logger = logging.get_logger("transformers")
# NOTE(review): looks like a leftover smoke test of the logging setup at
# import time; consider removing once the logging config is confirmed.
logger.info("INFO")


class EndpointHandler:
    """Inference handler that scores candidate texts against a query.

    Every request scores each candidate text against a single query with a
    ``FlagReranker`` and returns the scores in the same order as the input
    texts.
    """

    def __init__(
        self,
        path="",
        *,
        model_name_or_path="BAAI/bge-reranker-large",
        use_fp16=True,
    ):
        """Load the reranker once, when the handler is constructed.

        Args:
            path: Presumably the endpoint repository directory supplied by the
                serving toolkit (TODO confirm). It is not used; the model is
                loaded from ``model_name_or_path`` instead.
            model_name_or_path: Model id or path passed to ``FlagReranker``.
                Defaults to ``"BAAI/bge-reranker-large"``.
            use_fp16: Forwarded to ``FlagReranker`` as ``use_fp16``.
        """
        self.reranker = FlagReranker(model_name_or_path, use_fp16=use_fp16)

    def __call__(self, inputs):
        """Score every text in the request against the request's query.

        Args:
            inputs: Request payload shaped like
                ``{"inputs": {"query": str, "texts": list[str]}}``. A single
                string for ``texts`` is treated as a one-element list.

        Returns:
            ``{"scores": [...]}`` with one float per text, in input order.
            An empty ``texts`` yields ``{"scores": []}``.

        Raises:
            KeyError: if ``"inputs"``, ``"texts"`` or (for a non-empty
                ``texts``) ``"query"`` is missing.
        """
        data = inputs["inputs"]
        logger.info("Inference started")

        texts = data["texts"]
        # A bare string would otherwise be iterated character by character.
        texts = [texts] if isinstance(texts, str) else list(texts)
        if not texts:
            # Nothing to score; skip the model call entirely.
            return {"scores": []}

        query = data["query"]
        pairs = [[query, text] for text in texts]
        logger.info("Scoring %d texts", len(pairs))

        # One batched call instead of a separate forward pass per text.
        raw_scores = self.reranker.compute_score(pairs)
        scores = self._as_score_list(raw_scores)
        logger.debug("Scores: %s", scores)
        return {"scores": scores}

    @staticmethod
    def _as_score_list(raw_scores):
        """Normalize ``compute_score`` output to a list of floats.

        A single pair may be scored as a bare number rather than a list, so a
        scalar is wrapped in a one-element list.
        """
        if isinstance(raw_scores, numbers.Real):
            return [float(raw_scores)]
        return [float(score) for score in raw_scores]