|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
from torch.nn.functional import softmax |
|
|
|
|
|
import torch |
|
|
from typing import Any, Dict, List |
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Inference Endpoints handler for a sequence-classification model.

    Loads a tokenizer and classification model from ``path`` and serves
    batched text-classification requests via ``__call__``.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model from ``path`` (local dir or hub id)."""
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        # Decoder-style checkpoints (e.g. GPT-2) ship without a pad token,
        # which breaks batched padding. Reuse EOS — but only when the
        # tokenizer doesn't already define one, so we never clobber a
        # legitimate pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        # Inference only: disable dropout etc. for deterministic outputs.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[float]]:
        """Classify a batch of strings.

        Args:
            data: request payload; ``data["inputs"]`` is the list of
                strings to classify (the tokenizer also accepts a single
                string).

        Returns:
            ``{"predictions": [...]}`` with one float per input: the
            softmax probability of label index 0 for that example.
        """
        batch_of_strings = data["inputs"]

        tokens = self.tokenizer(
            batch_of_strings, padding=True, truncation=True, return_tensors="pt"
        )

        # No gradients needed at inference time.
        with torch.no_grad():
            outputs = self.model(**tokens)

        # Per-example probability distribution over labels.
        probabilities = softmax(outputs.logits, dim=1)

        # NOTE(review): only the probability of label 0 is returned; the
        # remaining labels' probabilities are discarded — confirm label 0
        # is the intended score for this deployment.
        return {
            "predictions": [pred[0] for pred in probabilities.tolist()],
        }