File size: 307 Bytes
7c7c54d
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
class CustomProvider:
    """Callable that maps a piece of text to a single predicted class index.

    Wraps a tokenizer/model pair: the text is tokenized into PyTorch tensors,
    fed to the model, and the arg-max over the last axis of ``outputs.logits``
    is returned as a Python ``int``.
    """

    def __init__(self, model, tokenizer):
        """Store the model and tokenizer used by ``__call__``.

        Args:
            model: Callable accepting the tokenizer's output as keyword
                arguments and returning an object with a ``logits`` tensor.
                Presumably a Hugging Face classification model — TODO confirm
                against callers.
            tokenizer: Callable accepting ``(text, return_tensors="pt")`` and
                returning a mapping of model input names to tensors.
        """
        self.model = model
        self.tokenizer = tokenizer

    def __call__(self, text):
        """Return the predicted class index for *text*.

        Args:
            text: Input passed straight to the tokenizer. It must yield exactly
                one prediction (e.g. a single string); otherwise ``.item()``
                fails.

        Returns:
            int: Index of the largest logit along the last dimension.

        Raises:
            RuntimeError: If the arg-max has more than one element (e.g.
                batched input, or per-token logits).
        """
        import torch  # local import keeps the torch dependency scoped here

        inputs = self.tokenizer(text, return_tensors="pt")

        # The tokenizer produces CPU tensors. If the model exposes a device
        # (e.g. it was moved to GPU), move the tensors there to avoid a
        # device-mismatch error. Models without `.device` are left as-is.
        device = getattr(self.model, "device", None)
        if device is not None:
            inputs = {
                key: value.to(device) if isinstance(value, torch.Tensor) else value
                for key, value in inputs.items()
            }

        # Pure inference: skip building the autograd graph to save memory/time.
        # NOTE(review): the model is not switched to eval mode here; callers
        # presumably do that — confirm, otherwise dropout makes output random.
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.logits.argmax(dim=-1).item()