# T2FPipeline / testing.py
# Author: DhaneshV — last commit 211926f ("Update testing.py"), 1.33 kB.
from transformers import Pipeline
import torch
class DeclarativePipeline(Pipeline):
    """Generate text from a (question, answer) pair with ``model.generate``.

    ``question`` and ``answer`` are tokenized as a sentence pair (padded to
    512 tokens, truncating only the question). Generation uses beam search
    by default, and the decoded text is returned as a single string.

    Supply the pair either as keyword arguments::

        pipe(inputs, question="...", answer="...")

    or as a dict input::

        pipe({"question": "...", "answer": "..."})

    Generation settings default to ``DEFAULT_GENERATE_KWARGS``. They can be
    overridden per call (or at construction) through ``generate_kwargs={...}``
    or as plain keyword arguments, e.g.
    ``pipe(x, question=q, answer=a, max_length=64)``.
    """

    # Defaults forwarded to ``model.generate``. These are the values that were
    # previously hard-coded in ``_forward``. The dict is copied, never mutated.
    DEFAULT_GENERATE_KWARGS = {
        "max_length": 128,
        "num_beams": 4,
        "num_return_sequences": 1,
        "no_repeat_ngram_size": 2,
        "early_stopping": True,
    }

    def _sanitize_parameters(self, question=None, answer=None, generate_kwargs=None, **kwargs):
        """Split call/constructor kwargs into (preprocess, forward, postprocess) dicts.

        ``transformers.Pipeline`` calls this with keyword arguments and merges
        the returned values as dicts, so each element must be a dict. The
        earlier ``*args`` version returned a tuple and rejected every keyword.
        Only explicitly supplied values are included, so per-call values
        override constructor values without clobbering them with ``None``.

        Args:
            question: First sequence of the tokenizer pair.
            answer: Second sequence of the tokenizer pair.
            generate_kwargs: Optional dict of extra ``model.generate`` arguments.
            **kwargs: Any other keyword is also forwarded to ``model.generate``.
                Explicit keywords win over entries in ``generate_kwargs``.

        Returns:
            ``(preprocess_params, forward_params, postprocess_params)``.
        """
        preprocess_params = {}
        if question is not None:
            preprocess_params["question"] = question
        if answer is not None:
            preprocess_params["answer"] = answer
        forward_params = {**(generate_kwargs or {}), **kwargs}
        return preprocess_params, forward_params, {}

    def preprocess(self, inputs, question=None, answer=None):
        """Tokenize the question/answer pair into model inputs.

        Args:
            inputs: The positional pipeline input. If ``question`` and
                ``answer`` are both omitted and ``inputs`` is a dict, they are
                read from its ``"question"`` and ``"answer"`` keys. Otherwise
                ``inputs`` is ignored.
            question: First sequence. It is the only one truncated when the
                pair exceeds 512 tokens.
            answer: Second sequence. ``truncation='only_first'`` leaves it
                untruncated.

        Returns:
            The tokenizer output (including ``input_ids`` and
            ``attention_mask``), padded to 512 tokens, as ``self.framework``
            tensors.

        Raises:
            ValueError: If no question was supplied by either route.
        """
        if question is None and answer is None and isinstance(inputs, dict):
            question = inputs.get("question")
            answer = inputs.get("answer")
        if question is None:
            raise ValueError(
                "DeclarativePipeline needs a question: pass question=... and "
                "answer=... or a dict input with 'question' and 'answer' keys."
            )
        return self.tokenizer(
            question,
            answer,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            # Truncate only the question so the answer text is kept intact.
            truncation='only_first',
            return_attention_mask=True,
            return_tensors=self.framework,
        )

    def _forward(self, model_inputs, **generate_kwargs):
        """Run ``model.generate`` on the tokenized pair.

        Args:
            model_inputs: Output of ``preprocess``. Only ``input_ids`` and
                ``attention_mask`` are read.
            **generate_kwargs: Overrides merged over
                ``DEFAULT_GENERATE_KWARGS``, as produced by
                ``_sanitize_parameters``.

        Returns:
            The generated token ids exactly as returned by ``model.generate``.
        """
        gen_kwargs = {**self.DEFAULT_GENERATE_KWARGS, **generate_kwargs}
        # Kept so that direct calls outside Pipeline.forward also run in
        # inference mode without tracking gradients.
        self.model.eval()
        with torch.no_grad():
            return self.model.generate(
                input_ids=model_inputs["input_ids"],
                attention_mask=model_inputs["attention_mask"],
                **gen_kwargs,
            )

    def postprocess(self, model_outputs):
        """Decode generated ids into one string.

        Special tokens are dropped and tokenization spaces are cleaned up.
        NOTE(review): if more than one sequence is decoded (for example
        ``num_return_sequences > 1``), the texts are concatenated with no
        separator. This preserves the original return type (a plain ``str``).
        """
        preds = self.tokenizer.batch_decode(
            model_outputs,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        return "".join(preds)