import base64
import os
from functools import lru_cache
from typing import Optional
import torch
from transformers import AutoTokenizer
import tiktoken
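# Whisper-style language-code -> name map, extended with Chinese dialect entries
# ("yue", "minnan", "wuyu", "dialect") and mixed-language labels ("zh/en", "en/zh").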
LANGUAGES = {
"en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian",
"ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish",
"pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian",
"id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese", "he": "hebrew",
"uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish",
"hu": "hungarian", "ta": "tamil", "no": "norwegian", "th": "thai", "ur": "urdu", "hr": "croatian",
"bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh",
"sk": "slovak", "te": "telugu", "fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian",
"az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian",
"br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian",
"bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili", "gl": "galician", "mr": "marathi",
"pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali",
"af": "afrikaans", "oc": "occitan", "ka": "georgian", "be": "belarusian", "tg": "tajik",
"sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek",
"fo": "faroese", "ht": "haitian creole", "ps": "pashto", "tk": "turkmen", "nn": "nynorsk",
"mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan",
"tl": "tagalog", "mg": "malagasy", "as": "assamese", "tt": "tatar", "haw": "hawaiian",
"ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese",
"yue": "cantonese", "minnan": "minnan", "wuyu": "wuyu", "dialect": "dialect", "zh/en": "zh/en", "en/zh": "en/zh"
}
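
# Reverse lookup (name -> code), plus common alternative names for some languages.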
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my", "valencian": "ca", "flemish": "nl", "haitian": "ht", "letzeburgesch": "lb",
    "pushto": "ps", "panjabi": "pa", "moldavian": "ro", "moldovan": "ro", "sinhalese": "si",
    "castilian": "es", "mandarin": "zh",
}
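
# Audio event tags; "/X" entries close the span opened by the matching "X" tag.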
AUDIO_EVENT = {
"ASR": "ASR", "AED": "AED", "SER": "SER", "Speech": "Speech", "/Speech": "/Speech",
"BGM": "BGM", "/BGM": "/BGM", "Laughter": "Laughter", "/Laughter": "/Laughter",
"Applause": "Applause", "/Applause": "/Applause",
}
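
# Emotion labels used as special tokens.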
EMOTION = {
"HAPPY": "HAPPY", "SAD": "SAD", "ANGRY": "ANGRY", "NEUTRAL": "NEUTRAL",
}
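
# TTS vocal-control tokens, including speaker tags TTS/SP01 through TTS/SP13.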
TTS_Vocal_Token = {
"TTS/B": "TTS/B", "TTS/O": "TTS/O", "TTS/Q": "TTS/Q", "TTS/A": "TTS/A", "TTS/CO": "TTS/CO",
"TTS/CL": "TTS/CL", "TTS/H": "TTS/H", **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
}
# ===== Build the Encoding =====
@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    """Load the BPE vocab from assets/<name>.tiktoken and append special tokens after it."""
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    with open(vocab_path) as f:
        ranks = {
            base64.b64decode(token): int(rank)
            # `line.strip()` skips blank lines; a bare "\n" is truthy, so `if line` would not.
            for token, rank in (line.split() for line in f if line.strip())
        }
    n_vocab = len(ranks)
    special_tokens = {}
    specials = [
        "<|endoftext|>", "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
        *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
        "<|translate|>", "<|transcribe|>", "<|startoflm|>", "<|startofprev|>",
        "<|nospeech|>", "<|notimestamps|>",
        *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],
        *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],
        # Timestamp tokens <|0.00|> through <|30.00|> in 20 ms steps.
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]
    # Special tokens take ids immediately after the base vocabulary.
    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1
    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
class SimpleTokenizer:
    """Thin wrapper around a tiktoken Encoding that records language/task settings."""

    def __init__(self, encoding, num_languages: int = 99, language: Optional[str] = None, task: Optional[str] = None):
        self.encoding = encoding
        self.num_languages = num_languages
        self.language = language
        self.task = task

    def encode(self, text: str):
        return self.encoding.encode(text)

    def decode(self, tokens: list):
        return self.encoding.decode(tokens)
@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,
) -> SimpleTokenizer:
    """Resolve language/task defaults and return a cached SimpleTokenizer."""
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if multilingual:
        encoding_name = "multilingual_zh_ja_yue_char_del"
        language = language or "en"
        task = task or "transcribe"
    else:
        encoding_name = "gpt2"
        language = None
        task = None

    encoding = get_encoding(name=encoding_name, num_languages=num_languages)
    return SimpleTokenizer(encoding=encoding, num_languages=num_languages, language=language, task=task)
class QwenTokenizer:
    """Wraps a Hugging Face AutoTokenizer and registers TTS-oriented special tokens."""

    def __init__(self, token_path, skip_special_tokens=True):
        special_tokens = {
            'eos_token': '<|endoftext|>',
            'pad_token': '<|endoftext|>',
            'additional_special_tokens': [
                '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
                '[breath]', '<strong>', '</strong>', '[noise]',
                '[laughter]', '[cough]', '[clucking]', '[accent]',
                '[quick_breath]',
                "<laughter>", "</laughter>",
                "[hissing]", "[sigh]", "[vocalized-noise]",
                "[lipsmack]", "[mn]"
            ]
        }
        self.special_tokens = special_tokens
        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
        self.tokenizer.add_special_tokens(special_tokens)
        self.skip_special_tokens = skip_special_tokens

    def encode(self, text, **kwargs):
        tokens = self.tokenizer([text], return_tensors="pt")
        return tokens["input_ids"][0].cpu().tolist()

    def decode(self, tokens):
        tokens = torch.tensor(tokens, dtype=torch.int64)
        return self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
@lru_cache(maxsize=None)
def get_qwen_tokenizer(token_path: str, skip_special_tokens: bool) -> QwenTokenizer:
    """Cached factory: one QwenTokenizer per (token_path, skip_special_tokens) pair."""
    return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
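
# --- Minimal usage sketch (illustrative only, not part of the module API). ---
# Assumes assets/multilingual_zh_ja_yue_char_del.tiktoken exists next to this
# file; the Qwen path below is a placeholder, not a path this repo ships.
if __name__ == "__main__":
    tokenizer = get_tokenizer(multilingual=True, language="zh", task="transcribe")
    ids = tokenizer.encode("hello world")
    print(ids, "->", tokenizer.decode(ids))

    # Hypothetical checkpoint path; substitute a real Qwen tokenizer directory.
    qwen = get_qwen_tokenizer(token_path="Qwen/Qwen2-0.5B", skip_special_tokens=True)
    print(qwen.decode(qwen.encode("[breath] hello [laughter]")))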