File size: 503 Bytes
cf9604f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from transformers import pipeline

class EndpointHandler:
    """Callable request handler wrapping a text-classification pipeline.

    The pipeline is built once in ``__init__`` and reused for every request
    handled through ``__call__``.
    """

    def __init__(self, path: str = ""):
        """Build the classification pipeline.

        Args:
            path: Location passed to ``pipeline`` as both the model and the
                tokenizer source.
        """
        self.pipeline = pipeline(
            "text-classification",
            model=path,
            tokenizer=path,
            # Per the transformers text-classification pipeline docs,
            # top_k=None requests a score for every label rather than only
            # the best one.
            top_k=None,
            # Forwarded to the tokenizer: cut inputs longer than max_length.
            truncation=True,
            max_length=512,
        )

    def __call__(self, data: dict) -> list:
        """Classify the text(s) found under ``data["inputs"]``.

        Args:
            data: Request payload. ``"inputs"`` may be a single string or a
                sequence of strings; a single string is wrapped into a
                one-element batch.

        Returns:
            Whatever the pipeline returns for the batch, or an empty list
            when ``"inputs"`` is missing, ``None``, or an empty list/tuple.
        """
        inputs = data.get("inputs")

        # Check for str first: an empty string is still a valid single input
        # and must be wrapped, not treated as "no input".
        if isinstance(inputs, str):
            inputs = [inputs]
        elif inputs is None or (isinstance(inputs, (list, tuple)) and not inputs):
            # Nothing to classify; skip the model call instead of sending an
            # empty batch (or None) through the pipeline.
            return []

        return self.pipeline(inputs)