File size: 1,443 Bytes
cf6c975
 
 
 
 
51592ee
 
 
 
 
bb39afc
 
 
cf6c975
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from transformers import Pipeline
import torch

class DeclarativePipeline(Pipeline):
    """Generation pipeline that turns a ``"question#answer"`` string into text.

    The input is a single string holding a question and an answer joined by a
    ``'#'`` separator.  ``preprocess`` encodes the pair, ``_forward`` runs
    beam-search generation, and ``postprocess`` decodes the generated token
    ids back into one output string.
    """

    def _sanitize_parameters(self, **kwargs):
        # This pipeline exposes no runtime-tunable parameters: forward empty
        # kwarg dicts to preprocess / _forward / postprocess.
        preprocess_kwargs = {}
        return preprocess_kwargs, {}, {}


    def preprocess(self, inputs):
        """Split ``inputs`` on ``'#'`` and tokenize the (question, answer) pair.

        Args:
            inputs: a string of the form ``"question#answer"``.

        Returns:
            The tokenizer's encoding (padded/truncated to 512 tokens) as
            tensors for ``self.framework``.

        Raises:
            ValueError: if ``inputs`` contains no ``'#'`` separator
                (previously this surfaced as an uninformative IndexError).
        """
        # Split once instead of twice, and fail loudly with a clear message
        # when the separator is missing.  parts[0]/parts[1] matches the
        # original behavior for inputs containing more than one '#'.
        parts = inputs.split("#")
        if len(parts) < 2:
            raise ValueError(
                f"expected input of the form 'question#answer', got: {inputs!r}"
            )
        question, answer = parts[0], parts[1]
        return self.tokenizer(
            question,
            answer,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            truncation='only_first',  # truncate the question; keep the answer intact
            return_attention_mask=True,
            return_tensors=self.framework
        )

    def _forward(self, model_inputs):
        """Run beam-search generation on the tokenized inputs and return the ids."""
        self.model.eval()  # disable dropout etc. for deterministic inference
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            generate_ids = self.model.generate(
                    input_ids=model_inputs["input_ids"],
                    attention_mask=model_inputs["attention_mask"],
                    max_length=128,
                    num_beams=4,
                    num_return_sequences=1,
                    no_repeat_ngram_size=2,  # avoid repeated bigrams in the output
                    early_stopping=True,
                )
        return generate_ids

    def postprocess(self, model_outputs):
        """Decode generated token ids into a single output string.

        With ``num_return_sequences=1`` there is normally exactly one
        sequence; joining keeps the return type a plain ``str`` regardless.
        """
        preds = [
            self.tokenizer.decode(gen_id,
                                  skip_special_tokens=True,
                                  clean_up_tokenization_spaces=True)
            for gen_id in model_outputs
        ]
        return "".join(preds)