Adapt tokenization_interns1.py to transformers>=5.0.0

#10
by Zhangyc02 - opened
Files changed (1)
  1. tokenization_interns1.py +151 -10
tokenization_interns1.py CHANGED
@@ -14,9 +14,10 @@
 # limitations under the License.
 """Tokenization classes for InternS1."""

-from typing import Union, Dict, List, Optional, Tuple
+from typing import List, Union, Dict, List, Optional, Tuple
 import json
 import os
+import unicodedata
 from functools import lru_cache
 from abc import ABC, abstractmethod
 import regex as re
@@ -25,22 +26,26 @@ import sentencepiece as spm
 from collections import OrderedDict

 from transformers.tokenization_utils_base import AddedToken, TextInput
-from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
 from transformers.utils import logging
+import transformers
+from packaging import version
+if version.parse(transformers.__version__) >= version.parse("5.0.0"):
+    from transformers.tokenization_python import PreTrainedTokenizer
+else:
+    from transformers.tokenization_utils import PreTrainedTokenizer


 logger = logging.get_logger(__name__)

 try:
-    from rdkit import Chem
-    from rdkit import RDLogger
+    from rdkit import Chem, RDLogger

     RDLogger.DisableLog("rdApp.error")
     RDLogger.DisableLog("rdApp.*")
     RDKIT_AVAILABLE = True
 except ImportError:
     logger.warning_once(
-        f"If tokenization with SMILES formula is of necessity, please 'pip install RDKit' for better tokenization quality."
+        "If tokenization with SMILES formula is of necessity, please 'pip install RDKit' for better tokenization quality."
     )
     RDKIT_AVAILABLE = False

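To confirm which branch of the new import guard a given environment takes, reviewers can run a small probe. This is a minimal sketch: the transformers-5.x module path (`transformers.tokenization_python`) is assumed from this patch and is not verified here, and `packaging` must be installed.

```python
# Probe the same version gate this patch adds, without importing the tokenizer itself.
# NOTE: "transformers.tokenization_python" is the 5.x path assumed by this PR.
import importlib

import transformers
from packaging import version

if version.parse(transformers.__version__) >= version.parse("5.0.0"):
    module_name = "transformers.tokenization_python"  # path introduced by this patch
else:
    module_name = "transformers.tokenization_utils"   # stable path on transformers 4.x

tokenization_module = importlib.import_module(module_name)
print(transformers.__version__, "->", tokenization_module.PreTrainedTokenizer)
```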
@@ -341,7 +346,48 @@ class SmilesCheckModule(InternS1CheckModuleMixin):
         return self.check_brackets(text)


-class InternS1Tokenizer(Qwen2Tokenizer):
+@lru_cache
+# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
+    characters the bpe code barfs on.
+
+    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+    tables between utf-8 bytes and unicode strings.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
+def get_pairs(word):
+    """
+    Return set of symbol pairs in a word.
+
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+class InternS1Tokenizer(PreTrainedTokenizer):
     """
     Construct an InternS1 tokenizer. Based on byte-level Byte-Pair-Encoding.

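Because the class no longer inherits this machinery from `Qwen2Tokenizer`, the copied GPT-2 helpers can be sanity-checked in isolation. A minimal sketch, assuming only that the patched `tokenization_interns1.py` sits on the import path:

```python
# Quick sanity check of the byte-level BPE helpers introduced above.
# Assumes the patched tokenization_interns1.py is importable from the working directory.
from tokenization_interns1 import bytes_to_unicode, get_pairs

byte_to_char = bytes_to_unicode()
assert len(byte_to_char) == 256        # every byte value maps to a printable unicode char
assert byte_to_char[ord("A")] == "A"   # printable ASCII maps to itself
assert byte_to_char[ord(" ")] != " "   # whitespace/control bytes are remapped for BPE

# get_pairs() yields the adjacent symbol pairs that the merge loop ranks.
print(get_pairs(("l", "o", "w", "e", "r")))
# {('l', 'o'), ('o', 'w'), ('w', 'e'), ('e', 'r')}
```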
@@ -404,8 +450,57 @@ class InternS1Tokenizer(Qwen2Tokenizer):
         pad_token="<|endoftext|>",
         clean_up_tokenization_spaces=False,
         split_special_tokens=False,
+        special_tokens_pattern="none",
         **kwargs,
     ):
+        bos_token = (
+            AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(bos_token, str)
+            else bos_token
+        )
+        eos_token = (
+            AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(eos_token, str)
+            else eos_token
+        )
+        unk_token = (
+            AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(unk_token, str)
+            else unk_token
+        )
+        pad_token = (
+            AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(pad_token, str)
+            else pad_token
+        )
+
+        with open(vocab_file, encoding="utf-8") as vocab_handle:
+            self.encoder = json.load(vocab_handle)
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.errors = errors  # how to handle errors in decoding
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        bpe_merges = []
+        with open(merges_file, encoding="utf-8") as merges_handle:
+            for i, line in enumerate(merges_handle):
+                line = line.strip()
+                if (i == 0 and line.startswith("#version:")) or not line:
+                    continue
+                bpe_merges.append(tuple(line.split()))
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        # NOTE: the cache can grow without bound and will get really large for long running processes
+        # (esp. for texts of language that do not use space between word, e.g. Chinese); technically
+        # not a memory leak but appears as one.
+        # GPT2Tokenizer has the same problem, so let's be consistent.
+        self.cache = {}
+
+        self.pat = re.compile(PRETOKENIZE_REGEX)
+
+        if kwargs.get("add_prefix_space", False):
+            logger.warning_once(
+                f"{self.__class__.__name__} does not support `add_prefix_space`, setting it to True has no effect."
+            )
+
         self.extra_tokenizer_start_mapping = {}
         self.extra_tokenizer_end_mapping = {}
         self._extra_special_tokens = []
@@ -458,6 +553,7 @@ class InternS1Tokenizer(Qwen2Tokenizer):
             pad_token=pad_token,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
             split_special_tokens=split_special_tokens,
+            special_tokens_pattern="none",
             **kwargs,
         )

@@ -495,6 +591,10 @@ class InternS1Tokenizer(Qwen2Tokenizer):
         """Overload method"""
         return self.vocab_size

+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
     @property
     def logical_auto_tokens(self):
         """Tokens that won't be decoded and only for switching tokenizer"""
@@ -631,9 +731,6 @@ class InternS1Tokenizer(Qwen2Tokenizer):

         text, kwargs = self.prepare_for_tokenization(text, **kwargs)

-        if kwargs:
-            logger.warning(f"Keyword arguments {kwargs} not recognized.")
-
         if hasattr(self, "do_lower_case") and self.do_lower_case:
             # convert non-special tokens to lowercase. Might be super slow as well?
             escaped_special_toks = [re.escape(s_tok) for s_tok in (self.all_special_tokens)]
@@ -783,6 +880,7 @@ class InternS1Tokenizer(Qwen2Tokenizer):
             self._added_tokens_encoder[token.content] = token_index
             if self.verbose:
                 logger.info(f"Adding {token} to the vocabulary")
+
         self._update_trie()
         self._update_total_vocab_size()

@@ -812,6 +910,49 @@ class InternS1Tokenizer(Qwen2Tokenizer):
         else:
             return self._bpe_tokenize(text)

+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+                else:
+                    new_word.extend(word[i:j])
+                    i = j
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = " ".join(word)
+        self.cache[token] = word
+        return word
+
     def _bpe_tokenize(self, text, **kwargs):
         text = text.replace(
             "▁", " "
@@ -894,7 +1035,7 @@ class InternS1Tokenizer(Qwen2Tokenizer):
     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
         return self.decoder.get(index, "")
-
+
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
         text = "".join(tokens)
 