File size: 1,443 Bytes
cf6c975 51592ee bb39afc cf6c975 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
from transformers import Pipeline
import torch
class DeclarativePipeline(Pipeline):
    """Seq2seq pipeline that rewrites a ``question#answer`` pair into a
    declarative sentence via beam-search generation.

    Input format: a single string ``"<question>#<answer>"``. Only the FIRST
    ``#`` is treated as the delimiter, so answers may themselves contain
    ``#`` characters.
    """

    def _sanitize_parameters(self, **kwargs):
        """Route call-time kwargs to the three stages.

        No kwargs are currently supported; everything is fixed in the
        stage methods below.
        """
        preprocess_kwargs = {}
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs):
        """Tokenize a ``question#answer`` string into model inputs.

        Raises:
            ValueError: if *inputs* contains no ``#`` delimiter.
        """
        # Split exactly once so an answer containing '#' is kept intact
        # (plain split("#")[1] would silently truncate it).
        parts = inputs.split("#", 1)
        if len(parts) != 2:
            raise ValueError(
                f"expected input of the form 'question#answer', got: {inputs!r}"
            )
        question, answer = parts
        return self.tokenizer(
            question,
            answer,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            # Truncate only the question; the (usually short) answer is
            # preserved in full.
            truncation='only_first',
            return_attention_mask=True,
            return_tensors=self.framework
        )

    def _forward(self, model_inputs):
        """Generate declarative-sentence token ids with beam search."""
        self.model.eval()
        # Inference only — disable autograd to save memory.
        with torch.no_grad():
            generate_ids = self.model.generate(
                input_ids=model_inputs["input_ids"],
                attention_mask=model_inputs["attention_mask"],
                max_length=128,
                num_beams=4,
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                early_stopping=True,
            )
        return generate_ids

    def postprocess(self, model_outputs):
        """Decode generated ids to a single string.

        With ``num_return_sequences=1`` there is one sequence per input,
        so the empty-string join simply unwraps it.
        """
        preds = [
            self.tokenizer.decode(gen_id,
                                  skip_special_tokens=True,
                                  clean_up_tokenization_spaces=True)
            for gen_id in model_outputs
        ]
        return "".join(preds)
|