flash_gen_bert_para / handler.py
cguynup's picture
Upload 2 files
ed8a7fb
raw
history blame
790 Bytes
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer
import torch
class EndpointHandler():
    """Inference-endpoint handler that scores (answer, paraphrase) pairs
    with an ONNX-optimized sequence-classification model.

    The call contract follows the Hugging Face custom-handler convention:
    the payload dict must contain ``"answers"`` and ``"paraphrases"``
    (parallel lists of strings); the result is a list of predicted class
    indices, one per pair.
    """

    def __init__(self, path=""):
        """Load the ONNX Runtime model and its tokenizer from *path*.

        Args:
            path: Directory containing the exported ORT model and
                tokenizer files (the repo/checkpoint root).
        """
        self.model = ORTModelForSequenceClassification.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data):
        """Classify each (answer, paraphrase) pair in the request payload.

        Args:
            data: Request dict with ``"answers"`` and ``"paraphrases"``
                keys (lists of strings of equal length). NOTE(review):
                ``pop`` raises KeyError when a key is missing — presumably
                the serving layer guarantees both keys; verify upstream.

        Returns:
            list[int]: Predicted label index for each pair — native
            Python ints so the result is JSON-serializable.
        """
        answers = data.pop("answers")
        paraphrases = data.pop("paraphrases")
        # Sentence-pair encoding; max_length=253 matches the exported
        # model's expected sequence budget — TODO confirm against export.
        inputs = self.tokenizer(
            answers,
            paraphrases,
            max_length=253,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )
        # no_grad is a no-op for ORT inference but harmless; kept so the
        # handler also works if the model is swapped for a torch one.
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits = outputs.logits
        # .tolist() (not list(...numpy())) yields native ints rather than
        # numpy.int64 scalars, which json.dumps cannot serialize.
        return torch.argmax(logits, dim=-1).tolist()