# T2FPipeline / testing.py
# Author: dhanesh — "commit files to HF hub" (commit cf6c975, 1.57 kB)
from transformers import Pipeline
import torch
class DeclarativePipeline(Pipeline):
    """Custom text2text pipeline that generates output from a (question, answer) pair.

    Implements the four-stage Hugging Face ``Pipeline`` contract:
    ``_sanitize_parameters`` routes caller kwargs, ``preprocess`` tokenizes,
    ``_forward`` runs beam-search generation, ``postprocess`` decodes to text.
    """

    def _sanitize_parameters(self, **kwargs):
        # Route only the keys preprocess() understands; forward and
        # postprocess take no runtime parameters in this pipeline.
        preprocess_kwargs = {
            key: kwargs[key] for key in ("answer", "question") if key in kwargs
        }
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs, question, answer):
        # Encode the pair as one sequence; only the first segment (the
        # question) is truncated if the pair exceeds 512 tokens.
        # NOTE(review): `inputs` is unused — question/answer carry the data.
        encoded = self.tokenizer(
            question,
            answer,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            truncation='only_first',
            return_attention_mask=True,
            return_tensors=self.framework,
        )
        return encoded

    def _forward(self, model_inputs):
        # Beam-search generation in inference mode (no gradient tracking).
        self.model.eval()
        with torch.no_grad():
            return self.model.generate(
                input_ids=model_inputs["input_ids"],
                attention_mask=model_inputs["attention_mask"],
                max_length=128,
                num_beams=4,
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                early_stopping=True,
            )

    def postprocess(self, model_outputs):
        # Decode each generated sequence and concatenate; with
        # num_return_sequences=1 this yields the single decoded string.
        decoded = (
            self.tokenizer.decode(
                sequence,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            for sequence in model_outputs
        )
        return "".join(decoded)