Taykhoom commited on
Commit
d0628b6
·
verified ·
1 Parent(s): 864eca9

Upload tokenization_rnamsm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. tokenization_rnamsm.py +5 -4
tokenization_rnamsm.py CHANGED
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Union
4
 
5
  import torch
6
  from transformers import PreTrainedTokenizer
 
7
 
8
 
9
  _VOCAB = {
@@ -154,9 +155,9 @@ class RNAMSMTokenizer(PreTrainedTokenizer):
154
  if return_tensors == "pt":
155
  input_ids = torch.tensor(input_ids, dtype=torch.long)
156
  attention_mask = torch.tensor(attention_mask, dtype=torch.long)
157
- return {"input_ids": input_ids, "attention_mask": attention_mask}
158
 
159
- return {"input_ids": input_ids, "attention_mask": attention_mask}
160
 
161
  def _tokenize_single(self, sequence, add_special_tokens=True):
162
  tokens = list(sequence)
@@ -223,9 +224,9 @@ class RNAMSMTokenizer(PreTrainedTokenizer):
223
  if return_tensors == "pt":
224
  batch_ids = torch.tensor(batch_ids, dtype=torch.long)
225
  batch_mask = torch.tensor(batch_mask, dtype=torch.long)
226
- return {"input_ids": batch_ids, "attention_mask": batch_mask}
227
 
228
- return {"input_ids": batch_ids, "attention_mask": batch_mask}
229
 
230
  def decode(self, token_ids, skip_special_tokens=False, **kwargs):
231
  if isinstance(token_ids, torch.Tensor):
 
4
 
5
  import torch
6
  from transformers import PreTrainedTokenizer
7
+ from transformers.tokenization_utils_base import BatchEncoding
8
 
9
 
10
  _VOCAB = {
 
155
  if return_tensors == "pt":
156
  input_ids = torch.tensor(input_ids, dtype=torch.long)
157
  attention_mask = torch.tensor(attention_mask, dtype=torch.long)
158
+ return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})
159
 
160
+ return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})
161
 
162
  def _tokenize_single(self, sequence, add_special_tokens=True):
163
  tokens = list(sequence)
 
224
  if return_tensors == "pt":
225
  batch_ids = torch.tensor(batch_ids, dtype=torch.long)
226
  batch_mask = torch.tensor(batch_mask, dtype=torch.long)
227
+ return BatchEncoding({"input_ids": batch_ids, "attention_mask": batch_mask})
228
 
229
+ return BatchEncoding({"input_ids": batch_ids, "attention_mask": batch_mask})
230
 
231
  def decode(self, token_ids, skip_special_tokens=False, **kwargs):
232
  if isinstance(token_ids, torch.Tensor):