kashif HF Staff commited on
Commit
6411d45
·
verified ·
1 Parent(s): 6300b57

tokenizer: fix decode() to handle torch tensor input via .tolist()

Browse files
Files changed (1) hide show
  1. tokenizer.py +4 -2
tokenizer.py CHANGED
@@ -323,7 +323,7 @@ class HybridDNATokenizer(PreTrainedTokenizer):
323
  else:
324
  base_ids = self._base_tokenizer.encode(
325
  segment_content,
326
- add_special_tokens=False
327
  )
328
  token_ids.extend(base_ids)
329
  if return_token_mask:
@@ -345,6 +345,8 @@ class HybridDNATokenizer(PreTrainedTokenizer):
345
  skip_special_tokens: bool = False,
346
  **kwargs
347
  ) -> str:
 
 
348
  if isinstance(token_ids, int):
349
  token_ids = [token_ids]
350
 
@@ -437,7 +439,7 @@ class HybridDNATokenizer(PreTrainedTokenizer):
437
  UserWarning
438
  )
439
  add_special_tokens = False
440
-
441
  is_batch = isinstance(text, list)
442
  texts = text if is_batch else [text]
443
 
 
323
  else:
324
  base_ids = self._base_tokenizer.encode(
325
  segment_content,
326
+ add_special_tokens=add_special_tokens
327
  )
328
  token_ids.extend(base_ids)
329
  if return_token_mask:
 
345
  skip_special_tokens: bool = False,
346
  **kwargs
347
  ) -> str:
348
+ if hasattr(token_ids, 'tolist'):
349
+ token_ids = token_ids.tolist()
350
  if isinstance(token_ids, int):
351
  token_ids = [token_ids]
352
 
 
439
  UserWarning
440
  )
441
  add_special_tokens = False
442
+
443
  is_batch = isinstance(text, list)
444
  texts = text if is_batch else [text]
445