# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""

import collections
import logging
import os
import math
import unicodedata

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

logger = logging.getLogger(__name__)


def _env_int(name):
    """Return environment variable ``name`` parsed as an int.

    Returns ``None`` when the variable is unset or does not hold an integer
    (a warning is logged in the latter case).
    """
    raw = os.getenv(name)
    if raw is None:
        return None
    try:
        return int(raw)
    except ValueError:
        logger.warning("Ignoring non-integer value %r for environment variable %s", raw, name)
        return None


# File name / location of the vocabulary, configured through the environment.
# Both values are ``None`` when the corresponding variable is unset.
VOCAB_FILES_NAMES = {"vocab_file": os.getenv("VOCAB_NAME")}

PRETRAINED_VOCAB_FILES_MAP = {"vocab_file": {"dna": os.getenv("VOCAB_PATH")}}

# ``os.getenv`` yields a string, so the value is converted to a number here.
# NOTE(review): the base class is presumed to compare this against integer
# sequence lengths - confirm for the pinned ``transformers`` version.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"dna": _env_int("POSITIONAL_EMBEDDINGS_SIZE")}

PRETRAINED_INIT_CONFIGURATION = {"dna": {"do_lower_case": False}}

# Maps ``str(vocabulary size)`` to the k-mer length (kept as a string).
# This table was referenced by ``DNATokenizer.__init__`` but never defined,
# which made every instantiation fail with ``NameError``.
# NOTE(review): assumes a vocabulary of all 4**k DNA k-mers plus the five
# special tokens ([PAD], [UNK], [CLS], [SEP], [MASK]) - confirm against the
# vocabulary files actually in use.
VOCAB_KMER = {str(4 ** k + 5): str(k) for k in range(3, 7)}

# Number of content tokens per piece when a long single sequence is split;
# each piece additionally receives one [CLS] and one [SEP] (510 + 2 = 512).
_MAX_PIECE_LEN = 510


def _piece_starts(num_tokens):
    """Return the start offsets of the pieces a single sequence is split into.

    An empty sequence still yields one (empty) piece so that the result is
    ``[CLS] [SEP]``. A length that is an exact multiple of ``_MAX_PIECE_LEN``
    does not produce a trailing empty piece.
    """
    return range(0, max(num_tokens, 1), _MAX_PIECE_LEN)


def load_vocab(vocab_file):
    """Load a one-token-per-line vocabulary file into an ordered dict.

    Args:
        vocab_file: Path to a UTF-8 text file with one token per line.

    Returns:
        ``collections.OrderedDict`` mapping each token to its 0-based line
        index. If a token occurs more than once, the last index wins.
    """
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        for index, line in enumerate(reader):
            vocab[line.rstrip("\n")] = index
    return vocab


def whitespace_tokenize(text):
    """Strip ``text`` and split it on runs of whitespace.

    Returns an empty list for empty or whitespace-only input.
    """
    text = text.strip()
    if not text:
        return []
    return text.split()


class DNATokenizer(PreTrainedTokenizer):
    r"""
    Vocabulary-lookup tokenizer built on :class:`BasicTokenizer`.

    Text is cleaned, split on whitespace and punctuation, and every resulting
    token is mapped to an id through the vocabulary file (unknown tokens map
    to ``unk_token``). No wordpiece splitting is performed.

    Single sequences longer than 510 tokens are split into consecutive pieces
    of at most 510 tokens, each wrapped in its own ``[CLS] ... [SEP]``.

    Args:
        vocab_file: Path to a one-token-per-line vocabulary file.
        do_lower_case: Forwarded to :class:`BasicTokenizer`, which stores it
            but does not act on it.
        never_split: List of tokens which will never be split during
            tokenization.
        tokenize_chinese_chars: Forwarded to :class:`BasicTokenizer`, which
            stores it but does not act on it.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        do_lower_case=False,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        **kwargs
    ):
        """Construct a DNATokenizer.

        Args:
            **vocab_file**: Path to a one-token-per-line vocabulary file.
            **do_lower_case**: (`optional`) boolean (default False)
                Passed on to the basic tokenizer.
            **never_split**: (`optional`) list of string
                List of tokens which will never be split during tokenization.
            **tokenize_chinese_chars**: (`optional`) boolean (default True)
                Passed on to the basic tokenizer.

        Raises:
            ValueError: If ``vocab_file`` is not an existing file.
        """
        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs,
        )
        # ``max_len`` is not assigned in this class; it is provided by the
        # base class initialiser.
        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a pretrained "
                "model use `tokenizer = DNATokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
            )
        self.vocab = load_vocab(vocab_file)
        # ``None`` when the vocabulary size is not listed in VOCAB_KMER.
        self.kmer = VOCAB_KMER.get(str(len(self.vocab)))
        self.ids_to_tokens = collections.OrderedDict((ids, tok) for tok, ids in self.vocab.items())
        self.basic_tokenizer = BasicTokenizer(
            do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars
        )

    @property
    def vocab_size(self):
        """Number of entries in the loaded vocabulary."""
        return len(self.vocab)

    def _tokenize(self, text):
        """Split ``text`` with the basic tokenizer, keeping special tokens intact."""
        return list(self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens))

    def _convert_token_to_id(self, token):
        """Convert a token (str) to an id using the vocab; unknown tokens map to the ``unk_token`` id."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Convert an index (integer) to a token (str) using the vocab; unknown ids map to ``unk_token``."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Join a sequence of tokens (string) with spaces, removing any `` ##`` continuation markers."""
        return " ".join(tokens).replace(" ##", "").strip()

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences by
        concatenating and adding special tokens.

        Formats:
            single sequence: [CLS] X [SEP]
            long single sequence (more than 510 ids):
                [CLS] X[0:510] [SEP] [CLS] X[510:1020] [SEP] ...
            pair of sequences: [CLS] A [SEP] B [SEP]

        Args:
            token_ids_0: list of ids (without special tokens)
            token_ids_1: Optional second list of ids for sequence pairs

        Returns:
            list of ids with the special tokens inserted.
        """
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        if token_ids_1 is not None:
            return cls + token_ids_0 + sep + token_ids_1 + sep

        output = []
        for start in _piece_starts(len(token_ids_0)):
            output.extend(cls + token_ids_0[start:start + _MAX_PIECE_LEN] + sep)
        return output

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """
        Build the special-tokens mask matching ``build_inputs_with_special_tokens``.

        Args:
            token_ids_0: list of ids (must not contain special tokens unless
                ``already_has_special_tokens`` is True)
            token_ids_1: Optional list of ids (must not contain special tokens),
                necessary when fetching the mask for sequence pairs
            already_has_special_tokens: (default False) Set to True if the token
                list is already formatted with special tokens for the model

        Returns:
            A list of integers in the range [0, 1]: 1 for a special token,
            0 for a sequence token.

        Raises:
            ValueError: If ``already_has_special_tokens`` is True and a second
                sequence is supplied.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            special_ids = (self.sep_token_id, self.cls_token_id)
            return [1 if token_id in special_ids else 0 for token_id in token_ids_0]

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

        total = len(token_ids_0)
        mask = []
        for start in _piece_starts(total):
            piece_len = min(total - start, _MAX_PIECE_LEN)
            mask.extend([1] + ([0] * piece_len) + [1])
        return mask

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """
        Create the token type ids matching ``build_inputs_with_special_tokens``.

        A sequence pair mask has the following format:
        0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence

        If ``token_ids_1`` is None, only zeros are returned: one per id plus
        two per piece (the [CLS] and [SEP] around each piece).
        """
        if token_ids_1 is None:
            num_pieces = len(_piece_starts(len(token_ids_0)))
            return (len(token_ids_0) + 2 * num_pieces) * [0]
        # [CLS] A [SEP] -> 0s, B [SEP] -> 1s
        return (len(token_ids_0) + 2) * [0] + (len(token_ids_1) + 1) * [1]

    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary to a directory or file.

        Args:
            vocab_path: Target file, or an existing directory in which the
                file named by ``VOCAB_FILES_NAMES["vocab_file"]`` is written
                (``vocab.txt`` when that name is not configured).

        Returns:
            One-element tuple holding the path of the written file.
        """
        if os.path.isdir(vocab_path):
            # The configured name comes from the environment and may be None,
            # which os.path.join would reject with a TypeError.
            file_name = VOCAB_FILES_NAMES["vocab_file"] or "vocab.txt"
            vocab_file = os.path.join(vocab_path, file_name)
        else:
            vocab_file = vocab_path

        index = 0
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".format(vocab_file)
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)


class BasicTokenizer(object):
    """Runs basic tokenization (text cleanup, accent stripping, punctuation splitting)."""

    def __init__(self, do_lower_case=False, never_split=None, tokenize_chinese_chars=True):
        """ Constructs a BasicTokenizer.

        Args:
            **do_lower_case**: Stored on the instance; not used by ``tokenize``.
            **never_split**: (`optional`) list of str
                List of tokens not to split.
            **tokenize_chinese_chars**: (`optional`) boolean (default True)
                Stored on the instance; not used by ``tokenize``.
        """
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = never_split
        self.tokenize_chinese_chars = tokenize_chinese_chars

    def tokenize(self, text, never_split=None):
        """ Basic tokenization of a piece of text.

        The text is cleaned, split on whitespace, stripped of accents and
        split on punctuation. Tokens listed in ``never_split`` (instance-level
        or passed here) are left untouched.

        Args:
            **text**: The string to tokenize.
            **never_split**: (`optional`) list of str
                Additional tokens not to split.

        Returns:
            list of token strings.
        """
        # list() keeps the concatenation working for tuples and other iterables.
        never_split = list(self.never_split) + list(never_split if never_split is not None else [])
        text = self._clean_text(text)

        split_tokens = []
        for token in whitespace_tokenize(text):
            if token not in never_split:
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        return whitespace_tokenize(" ".join(split_tokens))

    def _run_strip_accents(self, text):
        """Strips accents (Unicode category ``Mn`` marks after NFD normalisation) from a piece of text."""
        decomposed = unicodedata.normalize("NFD", text)
        return "".join(char for char in decomposed if unicodedata.category(char) != "Mn")

    def _run_split_on_punc(self, text, never_split=None):
        """Splits a piece of text so that every punctuation character becomes its own token."""
        if never_split is not None and text in never_split:
            return [text]

        pieces = []
        start_new_word = True
        for char in text:
            if _is_punctuation(char):
                pieces.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    pieces.append([])
                start_new_word = False
                pieces[-1].append(char)
        return ["".join(piece) for piece in pieces]

    def _clean_text(self, text):
        """Removes NUL, U+FFFD and control characters and maps every whitespace character to a space."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            output.append(" " if _is_whitespace(char) else char)
        return "".join(output)


def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char in (" ", "\t", "\n", "\r"):
        return True
    return unicodedata.category(char) == "Zs"


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char in ("\t", "\n", "\r"):
        return False
    return unicodedata.category(char).startswith("C")


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if 33 <= cp <= 47 or 58 <= cp <= 64 or 91 <= cp <= 96 or 123 <= cp <= 126:
        return True
    return unicodedata.category(char).startswith("P")