from transformers import Pipeline
import torch


class DeclarativePipeline(Pipeline):
    """Seq2seq pipeline over ``"<question>#<answer>"`` strings.

    The input string is split on ``"#"`` into a (question, answer) pair,
    tokenized as a sentence pair, run through ``model.generate`` with beam
    search, and the generated ids are decoded back into a single string.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) dicts.

        No runtime parameters are supported yet, so all kwargs are ignored.
        """
        return {}, {}, {}

    def preprocess(self, inputs):
        """Tokenize a ``"question#answer"`` string as a sentence pair.

        Args:
            inputs: A string containing at least one ``"#"`` separator.
                The text before the first ``"#"`` is the question, the text
                between the first and second ``"#"`` is the answer; any
                further segments are discarded (matches original behavior).

        Returns:
            The tokenizer's encoding (tensors in ``self.framework`` format).

        Raises:
            IndexError: If *inputs* contains no ``"#"`` separator.
        """
        # Split once instead of calling .split("#") twice as before.
        parts = inputs.split("#")
        question, answer = parts[0], parts[1]
        return self.tokenizer(
            question,
            answer,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            # Truncate only the question side so the answer stays intact.
            truncation='only_first',
            return_attention_mask=True,
            return_tensors=self.framework,
        )

    def _forward(self, model_inputs):
        """Run beam-search generation on the encoded inputs.

        Returns the raw generated token-id tensor from ``model.generate``.
        """
        self.model.eval()  # disable dropout etc. for deterministic inference
        with torch.no_grad():  # no gradients needed for generation
            return self.model.generate(
                input_ids=model_inputs["input_ids"],
                attention_mask=model_inputs["attention_mask"],
                max_length=128,
                num_beams=4,
                num_return_sequences=1,
                no_repeat_ngram_size=2,  # forbid repeating any bigram
                early_stopping=True,
            )

    def postprocess(self, model_outputs):
        """Decode generated id sequences into one concatenated string.

        With ``num_return_sequences=1`` there is a single sequence per input,
        so the join normally yields just that one decoded text.
        """
        return "".join(
            self.tokenizer.decode(
                gen_id,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            for gen_id in model_outputs
        )