File size: 1,334 Bytes
cf6c975 211926f cf6c975 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
from transformers import Pipeline
import torch
class DeclarativePipeline(Pipeline):
    """Seq2seq pipeline that turns a (question, answer) pair into generated text.

    Tokenizes the pair as a two-segment input, runs beam-search generation,
    and returns the decoded sequences concatenated into a single string.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) dicts.

        Bug fix: the original signature was ``(self, *args)`` returning
        ``args, {}, {}``.  The base ``Pipeline`` invokes this method with
        keyword arguments and expects three *dicts* back; a ``*args``-only
        signature rejects ``question=``/``answer=`` with a TypeError, and the
        returned tuple fails again when ``Pipeline.__call__`` merges it with
        ``{**params, ...}``.  Here the two keys that ``preprocess`` consumes
        are routed to the preprocess stage; forward/postprocess take none.
        """
        preprocess_kwargs = {}
        for key in ("question", "answer"):
            if key in kwargs:
                preprocess_kwargs[key] = kwargs[key]
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs, question, answer):
        """Tokenize the question/answer pair into model-ready tensors.

        ``inputs`` (the pipeline's positional argument) is intentionally
        unused; the actual text arrives via the ``question``/``answer``
        keywords routed by ``_sanitize_parameters``.
        ``truncation='only_first'`` trims only the first segment (the
        question) when the pair exceeds 512 tokens, keeping the answer whole.
        """
        return self.tokenizer(
            question,
            answer,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            truncation='only_first',
            return_attention_mask=True,
            return_tensors=self.framework,
        )

    def _forward(self, model_inputs):
        """Run beam-search generation without tracking gradients.

        Returns the raw generated token-id tensor from ``model.generate``.
        """
        self.model.eval()  # ensure dropout etc. are disabled for inference
        with torch.no_grad():
            generate_ids = self.model.generate(
                input_ids=model_inputs["input_ids"],
                attention_mask=model_inputs["attention_mask"],
                max_length=128,
                num_beams=4,
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                early_stopping=True,
            )
        return generate_ids

    def postprocess(self, model_outputs):
        """Decode generated ids to text and join all sequences into one string."""
        preds = [
            self.tokenizer.decode(
                gen_id,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            for gen_id in model_outputs
        ]
        return "".join(preds)
|