| from transformers import Pipeline | |
| import torch | |
class DeclarativePipeline(Pipeline):
    """Seq2seq pipeline that turns a (question, answer) text pair into generated text.

    ``preprocess`` encodes the pair with the pipeline's tokenizer,
    ``_forward`` runs beam-search generation under ``torch.no_grad()``,
    and ``postprocess`` decodes the generated ids into a single string.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) dicts.

        BUG FIX: the base ``Pipeline`` invokes this as
        ``self._sanitize_parameters(**kwargs)`` and unpacks each returned
        element with ``**`` into the corresponding stage. The previous
        ``*args`` signature raised ``TypeError`` on any keyword call, and
        returning the ``args`` tuple as the first element broke the
        ``**preprocess_params`` unpack. All three elements must be dicts;
        ``question``/``answer`` are routed to ``preprocess``.
        """
        preprocess_kwargs = {}
        for key in ("question", "answer"):
            if key in kwargs:
                preprocess_kwargs[key] = kwargs[key]
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs, question, answer):
        """Tokenize the question/answer pair as one encoded sequence pair.

        ``inputs`` is the positional argument the base ``Pipeline`` always
        forwards; it is unused here — the text comes from the keyword
        parameters routed by ``_sanitize_parameters``.
        """
        return self.tokenizer(
            question,
            answer,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            # 'only_first' truncates the question and keeps the answer intact
            truncation='only_first',
            return_attention_mask=True,
            return_tensors=self.framework,
        )

    def _forward(self, model_inputs):
        """Run beam-search generation on the encoded inputs.

        Returns the raw generated-id tensor. ``eval()`` + ``no_grad()``
        disable dropout and gradient tracking during inference.
        """
        self.model.eval()
        with torch.no_grad():
            return self.model.generate(
                input_ids=model_inputs["input_ids"],
                attention_mask=model_inputs["attention_mask"],
                max_length=128,
                num_beams=4,
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                early_stopping=True,
            )

    def postprocess(self, model_outputs):
        """Decode each generated sequence and concatenate into one string."""
        preds = [
            self.tokenizer.decode(
                seq_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            for seq_ids in model_outputs
        ]
        return "".join(preds)