Spaces:
Sleeping
Sleeping
| import logging | |
| import json | |
| import re | |
| import torch | |
| from pathlib import Path | |
| from unicodedata import category | |
| from tokenizers import Tokenizer | |
| from huggingface_hub import hf_hub_download | |
# Special tokens
SOT = "[START]"
EOT = "[STOP]"
UNK = "[UNK]"
SPACE = "[SPACE]"
SPECIAL_TOKENS = [SOT, EOT, UNK, SPACE, "[PAD]", "[SEP]", "[CLS]", "[MASK]"]

logger = logging.getLogger(__name__)


class EnTokenizer:
    """English text tokenizer backed by a HuggingFace `tokenizers.Tokenizer`.

    Spaces are swapped for the [SPACE] marker before encoding and restored
    on decode; [STOP]/[UNK] markers are stripped from decoded text.
    """

    def __init__(self, vocab_file_path):
        self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
        self.check_vocabset_sot_eot()

    def check_vocabset_sot_eot(self):
        # Fail fast when the loaded vocab lacks the start/stop markers.
        vocabulary = self.tokenizer.get_vocab()
        assert SOT in vocabulary
        assert EOT in vocabulary

    def text_to_tokens(self, text: str):
        # Token ids with a leading batch axis: shape (1, num_tokens).
        return torch.IntTensor(self.encode(text)).unsqueeze(0)

    def encode(self, txt: str, verbose=False):
        """Replace spaces with [SPACE], then encode with the tokenizer."""
        # `verbose` is accepted for interface compatibility; it is unused here.
        return self.tokenizer.encode(txt.replace(' ', SPACE)).ids

    def decode(self, seq):
        """Invert encode(): ids -> text with [SPACE] restored, markers removed."""
        if isinstance(seq, torch.Tensor):
            seq = seq.cpu().numpy()
        text = self.tokenizer.decode(seq, skip_special_tokens=False)
        # Literal spaces come from the tokenizer joining tokens, not from the
        # input text, so drop them before mapping [SPACE] back to ' '.
        for marker, repl in ((' ', ''), (SPACE, ' '), (EOT, ''), (UNK, '')):
            text = text.replace(marker, repl)
        return text
# Model repository
# Hugging Face Hub repo id for the pretrained Chatterbox assets; presumably
# consumed with `hf_hub_download` elsewhere in this module — it is not
# referenced by the code visible here (TODO confirm against callers).
REPO_ID = "ResembleAI/chatterbox"
class DaEnTokenizer:
    """Tokenizer wrapper that supports an optional language-id prefix.

    Same space-marker scheme as EnTokenizer, but encode() can prepend a
    lowercase "[<language_id>]" tag (e.g. "[da]") to the text first.
    (Danish/English per the class name — contract otherwise identical.)
    """

    def __init__(self, vocab_file_path):
        self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
        self.check_vocabset_sot_eot()

    def check_vocabset_sot_eot(self):
        # Fail fast when the loaded vocab lacks the start/stop markers.
        vocabulary = self.tokenizer.get_vocab()
        assert SOT in vocabulary
        assert EOT in vocabulary

    def text_to_tokens(self, text: str, language_id: str = None):
        # Token ids with a leading batch axis: shape (1, num_tokens).
        ids = self.encode(text, language_id=language_id)
        return torch.IntTensor(ids).unsqueeze(0)

    def encode(self, txt: str, language_id: str = None):
        """Encode text, optionally prefixed with a "[<language_id>]" tag."""
        if language_id:
            txt = "[" + language_id.lower() + "]" + txt
        return self.tokenizer.encode(txt.replace(' ', SPACE)).ids

    def decode(self, seq):
        """Invert encode(): ids -> text with [SPACE] restored, markers removed."""
        if isinstance(seq, torch.Tensor):
            seq = seq.cpu().numpy()
        text = self.tokenizer.decode(seq, skip_special_tokens=False)
        for marker, repl in ((' ', ''), (SPACE, ' '), (EOT, ''), (UNK, '')):
            text = text.replace(marker, repl)
        return text