Adapt tokenization_interns1.py to transformers>=5.0.0

#15
Files changed (1) hide show
  1. tokenization_interns1.py +6 -6
tokenization_interns1.py CHANGED
@@ -25,10 +25,12 @@ import regex as re
25
  import sentencepiece as spm
26
 
27
  from transformers.tokenization_utils_base import AddedToken, TextInput
28
- from transformers.tokenization_utils import PreTrainedTokenizer
29
  from transformers.utils import logging
30
- # from transformers.utils.import_utils import requires
31
-
 
 
 
32
 
33
  logger = logging.get_logger(__name__)
34
 
@@ -566,6 +568,7 @@ class InternS1Tokenizer(PreTrainedTokenizer):
566
  pad_token=pad_token,
567
  clean_up_tokenization_spaces=clean_up_tokenization_spaces,
568
  split_special_tokens=split_special_tokens,
 
569
  **kwargs,
570
  )
571
 
@@ -715,9 +718,6 @@ class InternS1Tokenizer(PreTrainedTokenizer):
715
 
716
  text, kwargs = self.prepare_for_tokenization(text, **kwargs)
717
 
718
- if kwargs:
719
- logger.warning(f"Keyword arguments {kwargs} not recognized.")
720
-
721
  if hasattr(self, "do_lower_case") and self.do_lower_case:
722
  # convert non-special tokens to lowercase. Might be super slow as well?
723
  escaped_special_toks = [re.escape(s_tok) for s_tok in (self.all_special_tokens)]
 
25
  import sentencepiece as spm
26
 
27
  from transformers.tokenization_utils_base import AddedToken, TextInput
 
28
  from transformers.utils import logging
29
+ import transformers
+ from packaging import version
30
+ if version.parse(transformers.__version__) >= version.parse("5.0.0"):
31
+ from transformers.tokenization_python import PreTrainedTokenizer
32
+ else:
33
+ from transformers.tokenization_utils import PreTrainedTokenizer
34
 
35
  logger = logging.get_logger(__name__)
36
 
 
568
  pad_token=pad_token,
569
  clean_up_tokenization_spaces=clean_up_tokenization_spaces,
570
  split_special_tokens=split_special_tokens,
571
+ special_tokens_pattern="bos_eos",
572
  **kwargs,
573
  )
574
 
 
718
 
719
  text, kwargs = self.prepare_for_tokenization(text, **kwargs)
720
 
 
 
 
721
  if hasattr(self, "do_lower_case") and self.do_lower_case:
722
  # convert non-special tokens to lowercase. Might be super slow as well?
723
  escaped_special_toks = [re.escape(s_tok) for s_tok in (self.all_special_tokens)]