Biorrith's picture
Initial app
82892a4
import logging
import json
import re
import torch
from pathlib import Path
from unicodedata import category
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
# Special tokens
SOT = "[START]"
EOT = "[STOP]"
UNK = "[UNK]"
SPACE = "[SPACE]"
SPECIAL_TOKENS = [SOT, EOT, UNK, SPACE, "[PAD]", "[SEP]", "[CLS]", "[MASK]"]
logger = logging.getLogger(__name__)
class EnTokenizer:
def __init__(self, vocab_file_path):
self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
self.check_vocabset_sot_eot()
def check_vocabset_sot_eot(self):
voc = self.tokenizer.get_vocab()
assert SOT in voc
assert EOT in voc
def text_to_tokens(self, text: str):
text_tokens = self.encode(text)
text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
return text_tokens
def encode( self, txt: str, verbose=False):
"""
clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer
"""
txt = txt.replace(' ', SPACE)
code = self.tokenizer.encode(txt)
ids = code.ids
return ids
def decode(self, seq):
if isinstance(seq, torch.Tensor):
seq = seq.cpu().numpy()
txt: str = self.tokenizer.decode(seq,
skip_special_tokens=False)
txt = txt.replace(' ', '')
txt = txt.replace(SPACE, ' ')
txt = txt.replace(EOT, '')
txt = txt.replace(UNK, '')
return txt
# Model repository
REPO_ID = "ResembleAI/chatterbox"
class DaEnTokenizer:
def __init__(self, vocab_file_path):
self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
self.check_vocabset_sot_eot()
def check_vocabset_sot_eot(self):
voc = self.tokenizer.get_vocab()
assert SOT in voc
assert EOT in voc
def text_to_tokens(self, text: str, language_id: str = None):
text_tokens = self.encode(text, language_id=language_id)
text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
return text_tokens
def encode(self, txt: str, language_id: str = None):
# Prepend language token
if language_id:
txt = f"[{language_id.lower()}]{txt}"
txt = txt.replace(' ', SPACE)
return self.tokenizer.encode(txt).ids
def decode(self, seq):
if isinstance(seq, torch.Tensor):
seq = seq.cpu().numpy()
txt = self.tokenizer.decode(seq, skip_special_tokens=False)
txt = txt.replace(' ', '').replace(SPACE, ' ').replace(EOT, '').replace(UNK, '')
return txt