BERT-Sentiment / handler.py
Christian2903's picture
Create handler.py
399efb5
raw
history blame contribute delete
907 Bytes
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import DataCollatorWithPadding
from torch.nn.functional import softmax
import torch
from typing import Any, Dict, List
class EndpointHandler:
    """Sequence-classification inference handler.

    Loads a tokenizer and a sequence-classification model from ``path`` and,
    when called, returns the softmax probability of label index 0 for each
    input text.

    NOTE(review): the ``EndpointHandler(path)`` / ``__call__(data)`` shape
    matches the Hugging Face Inference Endpoints custom-handler convention;
    presumably that is how this class is invoked — confirm.
    """

    def __init__(self, path=""):
        """Load the tokenizer and classification model stored at ``path``.

        Args:
            path: Directory or hub identifier passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        # Inference only: make sure dropout etc. are disabled.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[float]]:
        """Classify the texts in ``data["inputs"]``.

        Args:
            data: Request payload; ``data["inputs"]`` is a single string or a
                list of strings. Raises ``KeyError`` if ``"inputs"`` is absent.

        Returns:
            ``{"predictions": [p0, p1, ...]}`` where each ``p`` is the softmax
            probability the model assigns to label index 0 for the
            corresponding input. Which sentiment label index 0 denotes depends
            on the model's config — TODO confirm against ``id2label``.
        """
        batch_of_strings = data["inputs"]

        # An empty batch would otherwise reach the model as a degenerate
        # (rank-1, length-0) tensor; answer it directly instead.
        if isinstance(batch_of_strings, (list, tuple)) and not batch_of_strings:
            return {"predictions": []}

        tokens = self.tokenizer(
            batch_of_strings, padding=True, truncation=True, return_tensors="pt"
        )

        # Forward pass only (no loss is computed); disable autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(**tokens)

        # Normalise logits over the class axis, then keep the label-0 column.
        probabilities = softmax(outputs.logits, dim=-1)
        return {
            "predictions": probabilities[:, 0].tolist(),
        }