# File-viewer header captured with this source (not code):
# LittleMouse · "Upload file" · d054f6c · 7.06 kB
import base64
import os
from functools import lru_cache
from typing import Optional
import torch
from transformers import AutoTokenizer
import tiktoken
# Language code -> English language name.
# Insertion order is significant: get_encoding() creates one "<|code|>" special
# token for each of the first `num_languages` keys (99 by default, i.e. "en"
# through "su"), so reordering or inserting entries shifts special-token ids.
LANGUAGES = {
"en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian",
"ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish",
"pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian",
"id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese", "he": "hebrew",
"uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish",
"hu": "hungarian", "ta": "tamil", "no": "norwegian", "th": "thai", "ur": "urdu", "hr": "croatian",
"bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh",
"sk": "slovak", "te": "telugu", "fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian",
"az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian",
"br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian",
"bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili", "gl": "galician", "mr": "marathi",
"pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali",
"af": "afrikaans", "oc": "occitan", "ka": "georgian", "be": "belarusian", "tg": "tajik",
"sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek",
"fo": "faroese", "ht": "haitian creole", "ps": "pashto", "tk": "turkmen", "nn": "nynorsk",
"mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan",
"tl": "tagalog", "mg": "malagasy", "as": "assamese", "tt": "tatar", "haw": "hawaiian",
"ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese",
# Entries 100-105 below are project-specific tags; they only receive a
# "<|...|>" token in get_encoding() when num_languages > 99.
"yue": "cantonese", "minnan": "minnan", "wuyu": "wuyu", "dialect": "dialect", "zh/en": "zh/en", "en/zh": "en/zh"
}
# Language name (or a common alias) -> language code.
# Built by inverting LANGUAGES, then adding extra aliases on top. get_tokenizer()
# uses it so callers can pass a language name instead of a code.
TO_LANGUAGE_CODE = {
**{language: code for code, language in LANGUAGES.items()},
# Additional aliases that are not the canonical names in LANGUAGES.
"burmese": "my", "valencian": "ca", "flemish": "nl", "haitian": "ht", "letzeburgesch": "lb",
"pushto": "ps", "panjabi": "pa", "moldavian": "ro", "moldovan": "ro", "sinhalese": "si",
"castilian": "es", "mandarin": "zh",
}
# The tag tables below are identity maps (every key maps to itself).
# get_encoding() turns each key into a "<|key|>" special token, in insertion
# order, so the tuple order here determines those token ids.

# Audio-event / task tags (slash-prefixed entries are closing tags).
AUDIO_EVENT = {
    tag: tag
    for tag in (
        "ASR", "AED", "SER", "Speech", "/Speech",
        "BGM", "/BGM", "Laughter", "/Laughter",
        "Applause", "/Applause",
    )
}

# Emotion labels.
EMOTION = {label: label for label in ("HAPPY", "SAD", "ANGRY", "NEUTRAL")}

# Vocal/TTS tags: seven fixed tags followed by TTS/SP01 .. TTS/SP13.
TTS_Vocal_Token = {
    tag: tag
    for tag in (
        "TTS/B", "TTS/O", "TTS/Q", "TTS/A", "TTS/CO", "TTS/CL", "TTS/H",
        *(f"TTS/SP{i:02d}" for i in range(1, 14)),
    )
}
# ===== Build the Encoding =====
@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    """Build a ``tiktoken.Encoding`` from ``assets/<name>.tiktoken`` plus special tokens.

    The vocabulary file is read from the ``assets`` directory next to this
    module. Each non-blank line holds a base64-encoded token and its integer
    rank. Special tokens are appended after the mergeable ranks: their ids
    start at ``len(ranks)`` and follow the order of the ``specials`` list
    below.

    Results are memoized per ``(name, num_languages)`` by ``lru_cache``.

    Args:
        name: Vocabulary file stem; the file read is ``assets/<name>.tiktoken``.
        num_languages: How many leading keys of ``LANGUAGES`` get a
            ``<|code|>`` special token.

    Returns:
        A ``tiktoken.Encoding`` whose ``explicit_n_vocab`` counts both the
        mergeable ranks and the special tokens.

    Raises:
        OSError: If the vocabulary file cannot be opened.
        ValueError: If a non-blank line does not split into exactly two
            fields, or its rank is not an integer.
    """
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")

    # Use a context manager so the file is closed deterministically.
    # Lines read from a file keep their trailing newline, so test
    # `line.strip()` (not `line`) to actually skip blank lines.
    with open(vocab_path, encoding="utf-8") as vocab_file:
        ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in vocab_file if line.strip())
        }

    # Order matters: each special token's id is its position in this list
    # offset by len(ranks).
    specials = [
        "<|endoftext|>", "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        *[f"<|{audio_event}|>" for audio_event in AUDIO_EVENT],
        *[f"<|{emotion}|>" for emotion in EMOTION],
        "<|translate|>", "<|transcribe|>", "<|startoflm|>", "<|startofprev|>",
        "<|nospeech|>", "<|notimestamps|>",
        *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],
        *[f"<|{tts}|>" for tts in TTS_Vocal_Token],
        # 1501 timestamp tokens "<|0.00|>" .. "<|30.00|>" in steps of 0.02
        # (presumably seconds at 20 ms resolution — confirm with the decoder).
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]

    n_base = len(ranks)
    special_tokens = {token: n_base + offset for offset, token in enumerate(specials)}

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_base + len(specials),
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
class SimpleTokenizer:
    """Thin facade over an encoding object exposing ``encode``/``decode``.

    ``num_languages``, ``language`` and ``task`` are stored as plain
    attributes; neither ``encode`` nor ``decode`` consults them.
    """

    def __init__(
        self,
        encoding,
        num_languages: int = 99,
        language: Optional[str] = None,
        task: Optional[str] = None,
    ):
        self.encoding, self.num_languages = encoding, num_languages
        self.language, self.task = language, task

    def encode(self, text: str):
        """Return the token ids produced by ``self.encoding.encode(text)``."""
        encoder = self.encoding
        return encoder.encode(text)

    def decode(self, tokens: list):
        """Return the text produced by ``self.encoding.decode(tokens)``."""
        decoder = self.encoding
        return decoder.decode(tokens)
@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,
) -> SimpleTokenizer:
    """Return a memoized ``SimpleTokenizer`` for the requested configuration.

    ``language`` may be a code from ``LANGUAGES`` or a name/alias from
    ``TO_LANGUAGE_CODE`` (case-insensitive); it is normalized to a code.

    In multilingual mode, the ``multilingual_zh_ja_yue_char_del`` vocabulary
    is used, and ``language``/``task`` default to ``"en"``/``"transcribe"``.
    Otherwise the ``gpt2`` vocabulary is used, and ``language``/``task`` are
    forced to ``None``.

    Raises:
        ValueError: If ``language`` is neither a known code nor a known name.
    """
    # Normalize the language to a code before choosing the vocabulary.
    if language is not None:
        lowered = language.lower()
        if lowered in LANGUAGES:
            language = lowered
        elif lowered in TO_LANGUAGE_CODE:
            language = TO_LANGUAGE_CODE[lowered]
        else:
            raise ValueError(f"Unsupported language: {lowered}")

    if not multilingual:
        encoding_name, language, task = "gpt2", None, None
    else:
        encoding_name = "multilingual_zh_ja_yue_char_del"
        language = language or "en"
        task = task or "transcribe"

    encoding = get_encoding(name=encoding_name, num_languages=num_languages)
    return SimpleTokenizer(
        encoding=encoding,
        num_languages=num_languages,
        language=language,
        task=task,
    )
class QwenTokenizer:
    """Wrapper around a Hugging Face ``AutoTokenizer`` with extra special tokens.

    On construction, the tokenizer at ``token_path`` is loaded. It is then
    extended with ``<|endoftext|>`` as both EOS and PAD token, plus a fixed
    list of additional special tokens: chat markers, ``<|endofprompt|>``, and
    bracketed/tagged markers such as ``[breath]`` and ``<laughter>``.

    Args:
        token_path: Passed unchanged to ``AutoTokenizer.from_pretrained``.
        skip_special_tokens: Forwarded to ``batch_decode`` by :meth:`decode`.
    """

    def __init__(self, token_path, skip_special_tokens=True):
        super().__init__()
        additional = [
            '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
            '[breath]', '<strong>', '</strong>', '[noise]',
            '[laughter]', '[cough]', '[clucking]', '[accent]',
            '[quick_breath]',
            "<laughter>", "</laughter>",
            "[hissing]", "[sigh]", "[vocalized-noise]",
            "[lipsmack]", "[mn]",
        ]
        self.special_tokens = dict(
            eos_token='<|endoftext|>',
            pad_token='<|endoftext|>',
            additional_special_tokens=additional,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
        self.tokenizer.add_special_tokens(self.special_tokens)
        self.skip_special_tokens = skip_special_tokens

    def encode(self, text, **kwargs):
        """Tokenize *text* and return its ids as a Python list.

        Extra keyword arguments are accepted for call-site compatibility and
        ignored.
        """
        batch = self.tokenizer([text], return_tensors="pt")
        first_row = batch["input_ids"][0]
        return first_row.cpu().tolist()

    def decode(self, tokens):
        """Decode a sequence of token ids back to text.

        The ids are converted to an int64 tensor and decoded as a one-item
        batch; ``self.skip_special_tokens`` controls whether special tokens
        are dropped.
        """
        id_tensor = torch.tensor(tokens, dtype=torch.int64)
        texts = self.tokenizer.batch_decode(
            [id_tensor], skip_special_tokens=self.skip_special_tokens
        )
        return texts[0]
@lru_cache(maxsize=None)
def get_qwen_tokenizer(token_path: str, skip_special_tokens: bool) -> QwenTokenizer:
    """Return a memoized ``QwenTokenizer`` for ``token_path``.

    ``lru_cache`` keys on the call arguments, so repeated identical calls
    reuse one instance instead of reloading the underlying tokenizer.
    """
    tokenizer = QwenTokenizer(
        token_path=token_path,
        skip_special_tokens=skip_special_tokens,
    )
    return tokenizer