import torch


def predict_fn(data, model_and_tokenizer):
    """Run inference on text that may exceed the model's 512-token window.

    The input is tokenized WITHOUT special tokens, split into 510-token
    chunks, and each chunk is wrapped with the tokenizer's CLS/SEP ids
    (510 content tokens + 2 specials = 512, the model's maximum length).
    Short final chunks are right-padded to exactly 512 positions, and the
    whole stack of chunks is fed to the model in one batched forward pass.

    Args:
        data: dict-like request payload; the text is read from
            ``data["inputs"]`` (falling back to the payload itself when the
            key is absent). NOTE: ``pop`` mutates the payload, as before.
        model_and_tokenizer: ``(model, tokenizer)`` pair.

    Returns:
        Whatever the model's forward pass returns for the chunk batch.
    """
    model, tokenizer = model_and_tokenizer

    sentences = data.pop("inputs", data)

    # Tokenize without special tokens so CLS/SEP can be re-inserted per chunk.
    encoded_input = tokenizer(sentences, add_special_tokens=False, return_tensors='pt')

    chunk_size = 510          # content tokens per chunk
    max_len = chunk_size + 2  # + CLS + SEP = 512, the model window

    # Prefer the tokenizer's own special-token ids; fall back to the
    # original hard-coded BERT values (101/102/0) for backward compatibility.
    cls_id = getattr(tokenizer, "cls_token_id", None)
    cls_id = 101 if cls_id is None else cls_id
    sep_id = getattr(tokenizer, "sep_token_id", None)
    sep_id = 102 if sep_id is None else sep_id
    pad_id = getattr(tokenizer, "pad_token_id", None)
    pad_id = 0 if pad_id is None else pad_id

    id_chunks = list(encoded_input['input_ids'][0].split(chunk_size))
    mask_chunks = list(encoded_input['attention_mask'][0].split(chunk_size))

    for i in range(len(id_chunks)):
        # Wrap each chunk as [CLS] ... [SEP]; build directly in integer dtype
        # instead of float tensors that need a later cast.
        id_chunks[i] = torch.cat([
            torch.tensor([cls_id], dtype=torch.long),
            id_chunks[i].long(),
            torch.tensor([sep_id], dtype=torch.long),
        ])
        mask_chunks[i] = torch.cat([
            torch.tensor([1], dtype=torch.long),
            mask_chunks[i].long(),
            torch.tensor([1], dtype=torch.long),
        ])

        # Right-pad the final (short) chunk up to the fixed window size;
        # padded positions get pad_id and attention mask 0.
        pad_len = max_len - id_chunks[i].shape[0]
        if pad_len > 0:
            id_chunks[i] = torch.cat([
                id_chunks[i],
                torch.full((pad_len,), pad_id, dtype=torch.long),
            ])
            mask_chunks[i] = torch.cat([
                mask_chunks[i],
                torch.zeros(pad_len, dtype=torch.long),
            ])

    input_dict = {
        'input_ids': torch.stack(id_chunks).long(),
        'attention_mask': torch.stack(mask_chunks).int(),
    }

    # Pure inference path: disable autograd to avoid building a graph.
    with torch.no_grad():
        output = model(**input_dict)
    print("inference.py")
    return output