import torch
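
# For context: in a SageMaker Hugging Face inference script, predict_fn below
# receives the (model, tokenizer) pair returned by model_fn. A minimal sketch
# of such a model_fn is shown here as an assumption; the original file only
# contains predict_fn, and the sequence-classification checkpoint type is
# illustrative, not confirmed by the source.
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def model_fn(model_dir):
    # Load the fine-tuned model and its tokenizer saved under model_dir
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.eval()
    return model, tokenizer
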

def predict_fn(data, model_and_tokenizer):
    # Unpack the model and tokenizer returned by model_fn
    model, tokenizer = model_and_tokenizer

    # Tokenize the input text without special tokens; [CLS]/[SEP] are added
    # per chunk below
    sentences = data.pop("inputs", data)
    encoded_input = tokenizer(sentences, add_special_tokens=False, return_tensors='pt')

    # Split the token ids and attention mask into chunks of 510 tokens,
    # leaving room for [CLS] and [SEP] in each 512-token window
    input_id_chunks = list(encoded_input['input_ids'][0].split(510))
    mask_chunks = list(encoded_input['attention_mask'][0].split(510))

    for i in range(len(input_id_chunks)):
        # Wrap each chunk with BERT's [CLS] (id 101) and [SEP] (id 102) tokens
        input_id_chunks[i] = torch.cat([
            torch.tensor([101]), input_id_chunks[i], torch.tensor([102])
        ])
        mask_chunks[i] = torch.cat([
            torch.tensor([1]), mask_chunks[i], torch.tensor([1])
        ])

        # Pad the (final) chunk up to the model's 512-token limit
        pad_len = 512 - input_id_chunks[i].shape[0]
        if pad_len > 0:
            input_id_chunks[i] = torch.cat(
                [input_id_chunks[i], torch.zeros(pad_len, dtype=torch.long)]
            )
            mask_chunks[i] = torch.cat(
                [mask_chunks[i], torch.zeros(pad_len, dtype=torch.long)]
            )

    # Stack the chunks into a single batch *after* the loop; stacking and
    # returning inside the loop would process only the first chunk
    input_ids = torch.stack(input_id_chunks)
    attention_masks = torch.stack(mask_chunks)

    input_dict = {
        'input_ids': input_ids.long(),
        'attention_mask': attention_masks.int()
    }
    with torch.no_grad():
        output = model(**input_dict)
    return output
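

# A quick local smoke test (hypothetical, not part of the original script):
# load a public BERT checkpoint and run predict_fn on a long input to exercise
# the chunking path. The model name and input text are placeholder assumptions.
if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    mdl = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
    mdl.eval()

    long_text = "word " * 1200  # long enough to yield multiple 510-token chunks
    result = predict_fn({"inputs": long_text}, (mdl, tok))
    print(result.logits.shape)  # one row of logits per chunk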