# NeMo/nemo/collections/common/tokenizers/bytelevel_tokenizers.py
# camenduru's picture
# thanks to NVIDIA ❤
# 7934b29
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Union
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
__all__ = ['ByteLevelProcessor', 'ByteLevelTokenizer']
class ByteLevelProcessor:
    """
    A very basic tokenization and detokenization class for use with byte-level
    tokenization.

    Tokenization and normalization are identity operations; detokenization
    joins the given tokens with single spaces.
    """

    def detokenize(self, tokens: List[str]) -> str:
        """Return the tokens joined into one string, separated by single spaces."""
        return ' '.join(tokens)

    def tokenize(self, text: str) -> str:
        """Return ``text`` unchanged."""
        return text

    def normalize(self, text: str) -> str:
        """Return ``text`` unchanged."""
        return text
class ByteLevelTokenizer(TokenizerSpec):
    """
    Tokenizer whose tokens are the UTF-8 bytes of the input text.

    A byte is represented by its own integer value, so tokens and ids are the
    same thing for ordinary text. Ids 256, 257 and 258 are the pad, bos and eos
    ids. Additional special tokens passed to the constructor receive ids
    counting down from 255.
    """

    def __init__(self, special_tokens: Optional[Union[Dict[str, str], List[str]]] = None):
        """
        Args:
            special_tokens: optional extra special tokens. Only iteration over
                the argument is used, so for a dict the keys are registered.
        """
        self.vocab_size = 259
        # Every id >= special_start is treated as special; see ids_to_text / id_to_token.
        self.special_start = 256
        # The built-in special ids map to themselves.
        self.special_token_to_id = {
            self.pad_id: self.pad_id,
            self.bos_id: self.bos_id,
            self.eos_id: self.eos_id,
        }
        special_tokens = {} if special_tokens is None else special_tokens
        # Each extra token lowers special_start by one and takes that id (255, 254, ...).
        # NOTE(review): these ids overlap raw byte values; presumably this relies on the
        # highest byte values never occurring in UTF-8 encoded text - confirm the
        # number of extra tokens stays small enough for that to hold.
        for tok in special_tokens:
            self.special_start -= 1
            self.special_token_to_id[tok] = self.special_start
        self.id_to_special_token = {v: k for k, v in self.special_token_to_id.items()}

    # no distinction between tokens and ids.
    def text_to_tokens(self, text):
        """Return the UTF-8 byte values of ``text``; same as :meth:`text_to_ids`."""
        return self.text_to_ids(text)

    def tokens_to_text(self, tokens):
        """Decode byte-value tokens back to text; same as :meth:`ids_to_text`."""
        return self.ids_to_text(tokens)

    def text_to_ids(self, text):
        """Return the UTF-8 encoding of ``text`` as a list of ints."""
        return list(text.encode('utf-8'))

    def ids_to_text(self, ids):
        """
        Decode ``ids`` as UTF-8 text.

        Ids at or above ``special_start`` are dropped first, undecodable bytes
        are ignored, and trailing whitespace is stripped from the result.
        """
        # remove special tokens.
        byte_values = [x for x in ids if x < self.special_start]
        return bytes(byte_values).decode('utf-8', errors='ignore').rstrip()

    def tokens_to_ids(self, tokens):
        """Map each token to its id; a single ``str`` is treated as one token."""
        if isinstance(tokens, str):
            tokens = [tokens]
        return [self.token_to_id(token) for token in tokens]

    def ids_to_tokens(self, ids):
        """Map each id to its token; a single ``int`` is treated as one id."""
        if isinstance(ids, int):
            ids = [ids]
        return [self.id_to_token(token_id) for token_id in ids]

    def token_to_id(self, token):
        """Return the id of a registered special token, otherwise ``token`` itself."""
        return self.special_token_to_id.get(token, token)

    def id_to_token(self, id):
        """
        Return ``id`` itself when it is below ``special_start``, otherwise the
        special token registered for it.

        Raises:
            KeyError: if ``id`` is at or above ``special_start`` but has no
                registered special token.
        """
        # The parameter name ``id`` is kept for compatibility with keyword callers.
        if id < self.special_start:
            return id
        return self.id_to_special_token[id]

    @property
    def pad_id(self):
        """Id of the padding token."""
        return 256

    @property
    def bos_id(self):
        """Id of the beginning-of-sequence token."""
        return 257

    @property
    def eos_id(self):
        """Id of the end-of-sequence token."""
        return 258

    @property
    def unk_id(self):
        """Id of the unknown token; unused, and equal to ``vocab_size``."""
        return 259  # unused