Adapt tokenization_interns1.py to transformers>=5.0.0
#10
by Zhangyc02 - opened
tokenization_interns1.py  +151 -10  CHANGED
@@ -14,9 +14,10 @@
 # limitations under the License.
 """Tokenization classes for InternS1."""
 
-from typing import Union, Dict, List, Optional, Tuple
+from typing import List, Union, Dict, Optional, Tuple
 import json
 import os
+import unicodedata
 from functools import lru_cache
 from abc import ABC, abstractmethod
 import regex as re
@@ -25,22 +26,26 @@ import sentencepiece as spm
 from collections import OrderedDict
 
 from transformers.tokenization_utils_base import AddedToken, TextInput
-from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
 from transformers.utils import logging
+import transformers
+from packaging import version
+if version.parse(transformers.__version__) >= version.parse("5.0.0"):
+    from transformers.tokenization_python import PreTrainedTokenizer
+else:
+    from transformers.tokenization_utils import PreTrainedTokenizer
 
 
 logger = logging.get_logger(__name__)
 
 try:
-    from rdkit import Chem
-    from rdkit import RDLogger
+    from rdkit import Chem, RDLogger
 
     RDLogger.DisableLog("rdApp.error")
     RDLogger.DisableLog("rdApp.*")
     RDKIT_AVAILABLE = True
 except ImportError:
     logger.warning_once(
+        "If tokenization with SMILES formula is of necessity, please 'pip install RDKit' for better tokenization quality."
     )
     RDKIT_AVAILABLE = False
 
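The base class is now imported behind a version gate: `packaging.version.parse` gives a proper semantic comparison, so pre-release builds such as `5.0.0.dev0` already take the new path. Below is a minimal sketch of the same guard written as a `try`/`except` fallback instead of an explicit version check; the `transformers.tokenization_python` module path is taken from the hunk above and is otherwise an assumption:

```python
# Sketch: import the slow-tokenizer base class across transformers versions.
# Assumes the v5 module path used in this PR (transformers.tokenization_python);
# falls back to the pre-5.0 location if that import fails.
try:
    from transformers.tokenization_python import PreTrainedTokenizer  # transformers >= 5.0.0
except ImportError:
    from transformers.tokenization_utils import PreTrainedTokenizer   # transformers < 5.0.0
```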
@@ -341,7 +346,48 @@ class SmilesCheckModule(InternS1CheckModuleMixin):
         return self.check_brackets(text)
 
 
-class InternS1Tokenizer(Qwen2Tokenizer):
+@lru_cache
+# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
+    characters the bpe code barfs on.
+
+    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+    tables between utf-8 bytes and unicode strings.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
+def get_pairs(word):
+    """
+    Return set of symbol pairs in a word.
+
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+class InternS1Tokenizer(PreTrainedTokenizer):
     """
     Construct an InternS1 tokenizer. Based on byte-level Byte-Pair-Encoding.
 
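`bytes_to_unicode` and `get_pairs` are the standard GPT-2 helpers: the first maps every raw byte to a printable unicode stand-in, the second enumerates adjacent symbol pairs for BPE. A small sketch of what they return, assuming the two functions are importable from this `tokenization_interns1` module:

```python
# Illustration of the two GPT-2-style helpers added above (the import path is an
# assumption; they live in this modeling file).
from tokenization_interns1 import bytes_to_unicode, get_pairs

byte_encoder = bytes_to_unicode()
print(byte_encoder[ord("A")])  # 'A'  -> printable ASCII bytes map to themselves
print(byte_encoder[ord(" ")])  # 'Ġ'  -> space is remapped to a printable stand-in

print(get_pairs(tuple("hello")))  # {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}
```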
@@ -404,8 +450,57 @@ class InternS1Tokenizer(Qwen2Tokenizer):
         pad_token="<|endoftext|>",
         clean_up_tokenization_spaces=False,
         split_special_tokens=False,
+        special_tokens_pattern="none",
         **kwargs,
     ):
+        bos_token = (
+            AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(bos_token, str)
+            else bos_token
+        )
+        eos_token = (
+            AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(eos_token, str)
+            else eos_token
+        )
+        unk_token = (
+            AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(unk_token, str)
+            else unk_token
+        )
+        pad_token = (
+            AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(pad_token, str)
+            else pad_token
+        )
+
+        with open(vocab_file, encoding="utf-8") as vocab_handle:
+            self.encoder = json.load(vocab_handle)
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.errors = errors  # how to handle errors in decoding
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        bpe_merges = []
+        with open(merges_file, encoding="utf-8") as merges_handle:
+            for i, line in enumerate(merges_handle):
+                line = line.strip()
+                if (i == 0 and line.startswith("#version:")) or not line:
+                    continue
+                bpe_merges.append(tuple(line.split()))
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        # NOTE: the cache can grow without bound and will get really large for long running processes
+        # (esp. for texts of language that do not use space between word, e.g. Chinese); technically
+        # not a memory leak but appears as one.
+        # GPT2Tokenizer has the same problem, so let's be consistent.
+        self.cache = {}
+
+        self.pat = re.compile(PRETOKENIZE_REGEX)
+
+        if kwargs.get("add_prefix_space", False):
+            logger.warning_once(
+                f"{self.__class__.__name__} does not support `add_prefix_space`, setting it to True has no effect."
+            )
+
         self.extra_tokenizer_start_mapping = {}
         self.extra_tokenizer_end_mapping = {}
         self._extra_special_tokens = []
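The constructor now loads the `vocab_file`/`merges_file` pair and builds `bpe_ranks` itself instead of inheriting that logic from `Qwen2Tokenizer`. A toy illustration of the file formats it parses (the contents below are hypothetical, not taken from the repo):

```python
# Hypothetical toy vocab and merges data illustrating the parsing done in __init__.
encoder = {"h": 0, "e": 1, "l": 2, "o": 3, "he": 4, "hel": 5, "hell": 6, "hello": 7}
decoder = {v: k for k, v in encoder.items()}

# merges.txt: optional "#version:" header, then one space-separated merge per line,
# ordered from highest to lowest priority.
merges_text = "#version: 0.2\nh e\nhe l\nhel l\nhell o\n"
bpe_merges = []
for i, line in enumerate(merges_text.splitlines()):
    line = line.strip()
    if (i == 0 and line.startswith("#version:")) or not line:
        continue
    bpe_merges.append(tuple(line.split()))
bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
print(bpe_ranks)  # {('h', 'e'): 0, ('he', 'l'): 1, ('hel', 'l'): 2, ('hell', 'o'): 3}
```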
@@ -458,6 +553,7 @@ class InternS1Tokenizer(Qwen2Tokenizer):
             pad_token=pad_token,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
             split_special_tokens=split_special_tokens,
+            special_tokens_pattern="none",
             **kwargs,
         )
 
@@ -495,6 +591,10 @@ class InternS1Tokenizer(Qwen2Tokenizer):
         """Overload method"""
         return self.vocab_size
 
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
     @property
     def logical_auto_tokens(self):
         """Tokens that won't be decoded and only for switching tokenizer"""
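`get_vocab` simply overlays the added-token table on top of the base encoder; a toy illustration with made-up ids:

```python
# Made-up values showing how get_vocab combines the two tables.
encoder = {"Ġhello": 0, "Ġworld": 1}
added_tokens_encoder = {"<my_added_token>": 151665}  # hypothetical added token
vocab = dict(encoder, **added_tokens_encoder)
print(vocab)  # {'Ġhello': 0, 'Ġworld': 1, '<my_added_token>': 151665}
```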
@@ -631,9 +731,6 @@ class InternS1Tokenizer(Qwen2Tokenizer):
 
         text, kwargs = self.prepare_for_tokenization(text, **kwargs)
 
-        if kwargs:
-            logger.warning(f"Keyword arguments {kwargs} not recognized.")
-
         if hasattr(self, "do_lower_case") and self.do_lower_case:
             # convert non-special tokens to lowercase. Might be super slow as well?
             escaped_special_toks = [re.escape(s_tok) for s_tok in (self.all_special_tokens)]
@@ -783,6 +880,7 @@ class InternS1Tokenizer(Qwen2Tokenizer):
             self._added_tokens_encoder[token.content] = token_index
             if self.verbose:
                 logger.info(f"Adding {token} to the vocabulary")
+
         self._update_trie()
         self._update_total_vocab_size()
 
@@ -812,6 +910,49 @@ class InternS1Tokenizer(Qwen2Tokenizer):
         else:
             return self._bpe_tokenize(text)
 
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+                else:
+                    new_word.extend(word[i:j])
+                    i = j
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = " ".join(word)
+        self.cache[token] = word
+        return word
+
     def _bpe_tokenize(self, text, **kwargs):
         text = text.replace(
             "▁", " "
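The copied `bpe` method repeatedly merges the best-ranked adjacent pair until no known merge applies. A hand-traced toy example with a hypothetical merge table:

```python
# Hypothetical merge table; lower rank = higher priority.
bpe_ranks = {("h", "e"): 0, ("l", "l"): 1, ("he", "ll"): 2, ("hell", "o"): 3}

word = tuple("hello")  # ('h', 'e', 'l', 'l', 'o')
# pass 1: best-ranked pair is ('h', 'e')   -> ('he', 'l', 'l', 'o')
# pass 2: best-ranked pair is ('l', 'l')   -> ('he', 'll', 'o')
# pass 3: best-ranked pair is ('he', 'll') -> ('hell', 'o')
# pass 4: best-ranked pair is ('hell', 'o') -> ('hello',); single symbol, loop ends
# bpe() would cache and return "hello" (the final symbols joined with spaces).
```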
@@ -894,7 +1035,7 @@ class InternS1Tokenizer(Qwen2Tokenizer):
     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
         return self.decoder.get(index, "")
-
+
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
         text = "".join(tokens)
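A minimal smoke test for the adapted tokenizer; the repo id and `trust_remote_code` usage are assumptions about how this file is normally loaded, not something this PR changes:

```python
# Smoke test sketch (model id "internlm/Intern-S1" and trust_remote_code are assumptions).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("internlm/Intern-S1", trust_remote_code=True)
ids = tok("CCO is the SMILES for ethanol.")["input_ids"]
print(tok.decode(ids))
```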