gpt-wiki_85_pct_2 / handler.py
Christian2903's picture
Create handler.py
feeb591
raw
history blame contribute delete
987 Bytes
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.nn.functional import softmax
import torch
from typing import Any, Dict, List
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a sequence-classification model."""

    def __init__(self, path: str = ""):
        """Load the tokenizer and classification model from *path*.

        The tokenizer has no pad token of its own (GPT-style vocabulary),
        so the EOS token is reused for padding; the model config is kept
        in sync so padded positions are handled consistently.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

    # Fixed return annotation: the method returns a single dict, not a
    # list of dicts as the original `List[Dict[str, Any]]` claimed.
    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[float]]:
        """Classify the text(s) in ``data["inputs"]``.

        Args:
            data: Request payload; ``data["inputs"]`` is a string or a
                list of strings to classify.

        Returns:
            ``{"predictions": [...]}`` — one float per input: the softmax
            probability of class index 0.
        """
        batch_of_strings = data["inputs"]
        tokens = self.tokenizer(
            batch_of_strings, padding=True, truncation=True, return_tensors="pt"
        )
        # Inference only — disable gradient tracking.
        with torch.no_grad():
            outputs = self.model(**tokens)
        probabilities = softmax(outputs.logits, dim=1)
        # NOTE(review): only class index 0's probability is reported —
        # presumably the target label; confirm against the model's
        # id2label mapping.
        return {
            "predictions": [pred[0] for pred in probabilities.tolist()],
        }