Commit d054f6c committed by LittleMouse
1 Parent(s): 6912cd9
Upload file
Files changed:
- .gitattributes +2 -0
- CosyVoice-BlankEN/merges.txt +0 -0
- CosyVoice-BlankEN/tokenizer_config.json +40 -0
- CosyVoice-BlankEN/vocab.json +0 -0
- asset/en_man1.mp3 +3 -0
- asset/en_man1.txt +1 -0
- asset/en_woman1.mp3 +3 -0
- asset/en_woman1.txt +1 -0
- asset/zh_man1.txt +1 -0
- asset/zh_man1.wav +3 -0
- asset/zh_man2.mp3 +3 -0
- asset/zh_man2.txt +1 -0
- asset/zh_woman1.txt +1 -0
- asset/zh_woman1.wav +3 -0
- frontend-onnx/campplus.onnx +3 -0
- frontend-onnx/speech_tokenizer_v2.onnx +3 -0
- pengzhendong/wetext +1 -0
- requirements.txt +10 -0
- scripts/audio.py +83 -0
- scripts/frontend.py +251 -0
- scripts/process_prompt.py +70 -0
- scripts/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +0 -0
- scripts/tokenizer/tokenizer.py +151 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
+*.mp3 filter=lfs diff=lfs merge=lfs -text
CosyVoice-BlankEN/merges.txt
ADDED
The diff for this file is too large to render. See raw diff.
CosyVoice-BlankEN/tokenizer_config.json
ADDED
@@ -0,0 +1,40 @@
+{
+  "add_prefix_space": false,
+  "added_tokens_decoder": {
+    "151643": {
+      "content": "<|endoftext|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151644": {
+      "content": "<|im_start|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151645": {
+      "content": "<|im_end|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    }
+  },
+  "additional_special_tokens": ["<|im_start|>", "<|im_end|>"],
+  "bos_token": null,
+  "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+  "clean_up_tokenization_spaces": false,
+  "eos_token": "<|im_end|>",
+  "errors": "replace",
+  "model_max_length": 32768,
+  "pad_token": "<|endoftext|>",
+  "split_special_tokens": false,
+  "tokenizer_class": "Qwen2Tokenizer",
+  "unk_token": null
+}
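This config wires up a Qwen2Tokenizer with <|im_end|> as the EOS token and <|endoftext|> for padding. A minimal loading sketch with Hugging Face transformers, assuming the CosyVoice-BlankEN directory also holds the vocab.json and merges.txt added in this commit:

    from transformers import AutoTokenizer

    # Loads tokenizer_config.json, vocab.json and merges.txt from the directory.
    tok = AutoTokenizer.from_pretrained("CosyVoice-BlankEN")
    print(tok.eos_token)               # '<|im_end|>'
    print(tok.encode("<|im_start|>"))  # should be the single added id, [151644]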
CosyVoice-BlankEN/vocab.json
ADDED
The diff for this file is too large to render. See raw diff.
asset/en_man1.mp3
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:461dd4cc9cf5bf6b774a9978cc9b7ca96033b214714b12413ecfe9eb1bf03ab9
+size 15309
asset/en_man1.txt
ADDED
@@ -0,0 +1 @@
+Because he has zero capacity to respond to the two and a half hour
asset/en_woman1.mp3
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:872ff69b74b37763cfc4a49bdd39d8a2acf51f428e42e1ab9fa3dfc0c4a2e3d4
+size 16941
asset/en_woman1.txt
ADDED
@@ -0,0 +1 @@
+But many of these southern girls have the same trouble, said Holly.
asset/zh_man1.txt
ADDED
@@ -0,0 +1 @@
+南方高温卷土重来，全国秋老虎地图出炉。
asset/zh_man1.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:da1153fca1303cd20470317a4ba93027cc5e172214b777747215add36f41109e
+size 1536044
asset/zh_man2.mp3
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cd012ac30fe1ffb5bc3e356a84f4f668a25a62c72f810ffae218f83cbcfdf53e
+size 31761
asset/zh_man2.txt
ADDED
@@ -0,0 +1 @@
+所以呢目标是非常有威力的，它是创造原则的全部。
asset/zh_woman1.txt
ADDED
@@ -0,0 +1 @@
+希望你以后能够做的比我还好呦。
asset/zh_woman1.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd199eb7109fd6ce9943cb297e3cf350c1073af014063dfadbdc100230526243
+size 111496
frontend-onnx/campplus.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a6ac6a63997761ae2997373e2ee1c47040854b4b759ea41ec48e4e42df0f4d73
+size 28303423
frontend-onnx/speech_tokenizer_v2.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d43342aa12163a80bf07bffb94c9de2e120a8df2f9917cd2f642e7f4219c6f71
+size 496082973
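Both ONNX graphs are consumed by scripts/frontend.py below. A quick way to check what each session expects, using plain onnxruntime on CPU; the fbank-shape comment reflects how frontend.py feeds campplus, not ONNX metadata:

    import onnxruntime

    sess = onnxruntime.InferenceSession("frontend-onnx/campplus.onnx",
                                        providers=["CPUExecutionProvider"])
    for inp in sess.get_inputs():
        # frontend.py passes 80-dim kaldi fbank features shaped (1, T, 80)
        print(inp.name, inp.shape, inp.type)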
pengzhendong/wetext
ADDED
@@ -0,0 +1 @@
+Subproject commit 8e93692beb2e7f7d0aab4807819abfff0c3dbe6d
requirements.txt
ADDED
@@ -0,0 +1,10 @@
+torch
+torchaudio
+soundfile
+numpy
+onnxruntime
+openai-whisper
+inflect
+transformers
+librosa
+wetext==0.0.4
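Setup is the usual two steps, assuming the pengzhendong/wetext submodule above has been fetched so the FST paths that scripts/frontend.py builds from it exist on disk:

    git submodule update --init pengzhendong/wetext
    pip install -r requirements.txt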
scripts/audio.py
ADDED
@@ -0,0 +1,83 @@
+import numpy as np
+import torch
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+from scipy.io.wavfile import read
+
+MAX_WAV_VALUE = 32768.0
+
+
+def load_wav(full_path):
+    sampling_rate, data = read(full_path)
+    return data, sampling_rate
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+    return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+    return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+    output = dynamic_range_compression_torch(magnitudes)
+    return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+    output = dynamic_range_decompression_torch(magnitudes)
+    return output
+
+
+mel_basis = {}
+hann_window = {}
+
+
+def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+    if torch.min(y) < -1.0:
+        print("min value is ", torch.min(y))
+    if torch.max(y) > 1.0:
+        print("max value is ", torch.max(y))
+
+    global mel_basis, hann_window  # pylint: disable=global-statement
+    print("fmax", fmax)
+    if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
+        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+        mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+    )
+    y = y.squeeze(1)
+
+    spec = torch.view_as_real(
+        torch.stft(
+            y,
+            n_fft,
+            hop_length=hop_size,
+            win_length=win_size,
+            window=hann_window[str(y.device)],
+            center=center,
+            pad_mode="reflect",
+            normalized=False,
+            onesided=True,
+            return_complex=True,
+        )
+    )
+
+    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+    spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
+    spec = spectral_normalize_torch(spec)
+
+    return spec
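As a quick smoke test, the sketch below calls mel_spectrogram with the same parameters that scripts/frontend.py binds via functools.partial (n_fft 1920, hop 480, 80 mel bins at 24 kHz). Run it from the scripts/ directory so the audio module resolves; the shape comment is illustrative, not part of this commit:

    import torch
    from audio import mel_spectrogram

    # One second of silence at 24 kHz, batched as (B, T).
    y = torch.zeros(1, 24000)
    mel = mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000,
                          hop_size=480, win_size=1920, fmin=0, fmax=8000, center=False)
    print(mel.shape)  # expected (1, 80, 50): one frame per 480-sample hop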
scripts/frontend.py
ADDED
@@ -0,0 +1,251 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+from functools import lru_cache
+from typing import Generator
+import json
+import onnxruntime
+import torch
+import numpy as np
+import whisper
+from typing import Callable
+import torchaudio.compliance.kaldi as kaldi
+import torchaudio
+import os
+import re
+import inflect
+from tokenizer.tokenizer import get_qwen_tokenizer
+from audio import mel_spectrogram
+
+try:
+    import ttsfrd
+    use_ttsfrd = True
+except ImportError:
+
+    from wetext import Normalizer as ZhNormalizer
+    from wetext import Normalizer as EnNormalizer
+    use_ttsfrd = False
+
+import logging
+logging.getLogger('frontend').setLevel(logging.WARNING)
+logging.basicConfig(level=logging.DEBUG,
+                    format='%(asctime)s %(levelname)s %(message)s')
+
+class CosyVoiceFrontEnd:
+
+    def __init__(self,
+                 pretrained_path: str,
+                 wetext_dir: str,
+                 campplus_model: str,
+                 speech_tokenizer_model: str,
+                 spk2info: str = '',
+                 allowed_special: str = 'all'):
+        self.tokenizer = get_qwen_tokenizer(pretrained_path, True)
+        self.feat_extractor = partial(
+            mel_spectrogram,
+            n_fft=1920,
+            num_mels=80,
+            sampling_rate=24000,
+            hop_size=480,
+            win_size=1920,
+            fmin=0,
+            fmax=8000,
+            center=False)
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
+        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
+                                                                     providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
+                                                                                "CPUExecutionProvider"])
+        if os.path.exists(spk2info):
+            self.spk2info = torch.load(spk2info, map_location=self.device)
+        else:
+            self.spk2info = {}
+        self.allowed_special = allowed_special
+        self.use_ttsfrd = use_ttsfrd
+        if self.use_ttsfrd:
+            self.frd = ttsfrd.TtsFrontendEngine()
+            ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
+                'failed to initialize ttsfrd resource'
+            self.frd.set_lang_type('pinyinvg')
+        else:
+            self.zh_tn_model = ZhNormalizer(remove_erhua=False, lang="zh", tagger_path=f"{wetext_dir}/zh/tn/tagger.fst", verbalizer_path=f"{wetext_dir}/zh/tn/verbalizer.fst")
+            self.en_tn_model = EnNormalizer(lang="zh", tagger_path=f"{wetext_dir}/zh/tn/tagger.fst", verbalizer_path=f"{wetext_dir}/zh/tn/verbalizer.fst")
+            self.inflect_parser = inflect.engine()
+
+    def _extract_text_token(self, text):
+        if isinstance(text, Generator):
+            logging.info('get tts_text generator, will return _extract_text_token_generator!')
+            # NOTE add a dummy text_token_len for compatibility
+            return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
+        else:
+            text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
+            text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
+            text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+            return text_token, text_token_len
+
+    def _extract_text_token_generator(self, text_generator):
+        for text in text_generator:
+            text_token, _ = self._extract_text_token(text)
+            for i in range(text_token.shape[1]):
+                yield text_token[:, i: i + 1]
+
+    def _extract_speech_token(self, speech):
+        assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
+        feat = whisper.log_mel_spectrogram(speech, n_mels=128)
+        speech_token = self.speech_tokenizer_session.run(None,
+                                                         {self.speech_tokenizer_session.get_inputs()[0].name:
+                                                          feat.detach().cpu().numpy(),
+                                                          self.speech_tokenizer_session.get_inputs()[1].name:
+                                                          np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+        speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
+        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
+        return speech_token, speech_token_len
+
+    def _extract_spk_embedding(self, speech):
+        feat = kaldi.fbank(speech,
+                           num_mel_bins=80,
+                           dither=0,
+                           sample_frequency=16000)
+        feat = feat - feat.mean(dim=0, keepdim=True)
+        embedding = self.campplus_session.run(None,
+                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+        embedding = torch.tensor([embedding]).to(self.device)
+        return embedding
+
+    def _extract_speech_feat(self, speech):
+        speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
+        speech_feat = speech_feat.unsqueeze(dim=0)
+        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
+        return speech_feat, speech_feat_len
+
+    # NOTE: contains_chinese, replace_blank, replace_corner_mark, remove_bracket,
+    # split_paragraph, spell_out_number and is_only_punctuation are not defined in
+    # this file; upstream CosyVoice imports them from cosyvoice.utils.frontend_utils.
+    def text_normalize(self, text, split=True, text_frontend=True):
+        if isinstance(text, Generator):
+            logging.info('get tts_text generator, will skip text_normalize!')
+            return [text]
+        if text_frontend is False or text == '':
+            return [text] if split is True else text
+        text = text.strip()
+        if self.use_ttsfrd:
+            texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
+            text = ''.join(texts)
+        else:
+            if contains_chinese(text):
+                text = self.zh_tn_model.normalize(text)
+                text = text.replace("\n", "")
+                text = replace_blank(text)
+                text = replace_corner_mark(text)
+                text = text.replace(".", "。")
+                text = text.replace(" - ", "，")
+                text = remove_bracket(text)
+                text = re.sub(r'[，,、]+$', '。', text)
+                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
+                                             token_min_n=60, merge_len=20, comma_split=False))
+            else:
+                text = self.en_tn_model.normalize(text)
+                text = spell_out_number(text, self.inflect_parser)
+                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
+                                             token_min_n=60, merge_len=20, comma_split=False))
+        texts = [i for i in texts if not is_only_punctuation(i)]
+        return texts if split is True else text
+
+    def frontend_sft(self, tts_text, spk_id):
+        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+        embedding = self.spk2info[spk_id]['embedding']
+        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
+        return model_input
+
+    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+        if zero_shot_spk_id == '':
+            prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
+            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+            speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+            speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+            if resample_rate == 24000:
+                # cosyvoice2, force speech_feat % speech_token = 2
+                token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
+                speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
+                speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
+            embedding = self._extract_spk_embedding(prompt_speech_16k)
+            model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+                           'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+                           'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+                           'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+                           'llm_embedding': embedding, 'flow_embedding': embedding}
+        else:
+            model_input = self.spk2info[zero_shot_spk_id]
+        model_input['text'] = tts_text_token
+        model_input['text_len'] = tts_text_token_len
+        return model_input
+
+    def process_prompt(self, prompt_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+        if zero_shot_spk_id == '':
+            prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
+            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+            speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+            speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+            if resample_rate == 24000:
+                # cosyvoice2, force speech_feat % speech_token = 2
+                token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
+                speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
+                speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
+            embedding = self._extract_spk_embedding(prompt_speech_16k)
+            model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+                           'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+                           'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+                           'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+                           'llm_embedding': embedding, 'flow_embedding': embedding}
+        else:
+            model_input = self.spk2info[zero_shot_spk_id]
+        return model_input
+
+    def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate, zero_shot_spk_id)
+        # in cross lingual mode, we remove prompt in llm
+        del model_input['prompt_text']
+        del model_input['prompt_text_len']
+        del model_input['llm_prompt_speech_token']
+        del model_input['llm_prompt_speech_token_len']
+        return model_input
+
+    def frontend_instruct(self, tts_text, spk_id, instruct_text):
+        model_input = self.frontend_sft(tts_text, spk_id)
+        # in instruct mode, we remove spk_embedding in llm due to information leakage
+        del model_input['llm_embedding']
+        instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
+        model_input['prompt_text'] = instruct_text_token
+        model_input['prompt_text_len'] = instruct_text_token_len
+        return model_input
+
+    def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+        model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate, zero_shot_spk_id)
+        del model_input['llm_prompt_speech_token']
+        del model_input['llm_prompt_speech_token_len']
+        return model_input
+
+    def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
+        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
+        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+        embedding = self._extract_spk_embedding(prompt_speech_16k)
+        source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
+        model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
+                       'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
+                       'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
+                       'flow_embedding': embedding}
+        return model_input
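To make the zero-shot data flow concrete, here is a hedged sketch of driving frontend_zero_shot directly. It assumes you run from the repository root with scripts/ on sys.path, that the wetext FSTs are in place under pengzhendong/wetext, and it reuses the load_wav helper from scripts/process_prompt.py; the prompt transcript matches asset/zh_woman1.txt from this commit:

    from frontend import CosyVoiceFrontEnd
    from process_prompt import load_wav

    frontend = CosyVoiceFrontEnd("CosyVoice-BlankEN", "pengzhendong/wetext",
                                 "frontend-onnx/campplus.onnx",
                                 "frontend-onnx/speech_tokenizer_v2.onnx")
    prompt_16k = load_wav("asset/zh_woman1.wav", 16000)
    inputs = frontend.frontend_zero_shot("你好，世界。", "希望你以后能够做的比我还好呦。",
                                         prompt_16k, 24000, zero_shot_spk_id="")
    # text/prompt tokens, prompt speech tokens, mel features and speaker embeddings
    print(sorted(inputs.keys()))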
scripts/process_prompt.py
ADDED
@@ -0,0 +1,70 @@
+import argparse
+import os
+import torch
+import torchaudio
+import numpy as np
+from frontend import CosyVoiceFrontEnd
+import soundfile as sf
+
+def load_wav(wav, target_sr):
+    speech, sample_rate = sf.read(wav, dtype='float32')
+
+    if speech.ndim == 1:
+        speech = torch.from_numpy(speech).unsqueeze(0)  # (1, T)
+    else:
+        speech = torch.from_numpy(speech).transpose(0, 1)  # (C, T)
+
+    speech = speech.mean(dim=0, keepdim=True)  # (1, T)
+
+    if sample_rate != target_sr:
+        assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
+        speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
+    return speech
+
+if __name__ == "__main__":
+
+    args = argparse.ArgumentParser()
+    args.add_argument('--model_dir', type=str, default="CosyVoice-BlankEN", help="tokenizer configuration directory")
+    args.add_argument('--wetext_dir', type=str, default="pengzhendong/wetext", help="path to wetext")
+    args.add_argument('--sample_rate', type=int, default=24000, help="sampling rate for prompt audio")
+    args.add_argument('--prompt_text', type=str, default="希望你以后能够做的比我还好呦。", help="text content of the prompt (reference) audio; plain text or a file path")
+    args.add_argument('--prompt_speech', type=str, default="asset/zero_shot_prompt.wav", help="path to the prompt (reference) audio")
+    args.add_argument('--output', type=str, default="prompt_files", help="output data storage directory")
+    args = args.parse_args()
+
+    os.makedirs(args.output, exist_ok=True)
+
+    frontend = CosyVoiceFrontEnd(f"{args.model_dir}",
+                                 args.wetext_dir,
+                                 "frontend-onnx/campplus.onnx",
+                                 "frontend-onnx/speech_tokenizer_v2.onnx",
+                                 f"{args.model_dir}/spk2info.pt",
+                                 "all")
+
+    prompt_speech_16k = load_wav(args.prompt_speech, 16000)
+    zero_shot_spk_id = ""
+
+    if os.path.isfile(args.prompt_text):
+        with open(args.prompt_text, "r") as f:
+            prompt_text = f.read()
+    else:
+        prompt_text = args.prompt_text
+    print("prompt_text", prompt_text)
+    model_input = frontend.process_prompt(prompt_text, prompt_speech_16k, args.sample_rate, zero_shot_spk_id)
+
+    # model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+    #                'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+    #                'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+    #                'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+    #                'llm_embedding': embedding, 'flow_embedding': embedding}
+    print("prompt speech token size:", model_input["flow_prompt_speech_token"].shape)
+    assert model_input["flow_prompt_speech_token"].shape[1] >= 75, f"speech_token length should be >= 75, but got {model_input['flow_prompt_speech_token'].shape[1]}"
+    for k, v in model_input.items():
+        if "_len" in k:
+            continue
+        shapes = [str(s) for s in v.shape]
+        shape_str = "_".join(shapes)
+        if v.dtype in (torch.int32, torch.int64):
+            np.savetxt(f"{args.output}/{k}.txt", v.detach().cpu().numpy().reshape(-1), fmt="%d", delimiter=",")
+        else:
+            np.savetxt(f"{args.output}/{k}.txt", v.detach().cpu().numpy().reshape(-1), delimiter=",")
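A typical invocation against the assets added in this commit, executed from the repository root so the relative frontend-onnx paths resolve (Python puts scripts/ on sys.path automatically when the file is run directly):

    python scripts/process_prompt.py --prompt_text asset/zh_woman1.txt --prompt_speech asset/zh_woman1.wav --output prompt_files

This writes one flattened tensor per model input (prompt_text, flow_prompt_speech_token, prompt_speech_feat, llm_embedding, and so on, skipping the *_len entries) as plain text files under prompt_files/.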
scripts/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken
ADDED
The diff for this file is too large to render. See raw diff.
scripts/tokenizer/tokenizer.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import os
|
| 3 |
+
from functools import lru_cache
|
| 4 |
+
from typing import Optional
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import AutoTokenizer
|
| 7 |
+
import tiktoken
|
| 8 |
+
|
| 9 |
+
LANGUAGES = {
|
| 10 |
+
"en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian",
|
| 11 |
+
"ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish",
|
| 12 |
+
"pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian",
|
| 13 |
+
"id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese", "he": "hebrew",
|
| 14 |
+
"uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish",
|
| 15 |
+
"hu": "hungarian", "ta": "tamil", "no": "norwegian", "th": "thai", "ur": "urdu", "hr": "croatian",
|
| 16 |
+
"bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh",
|
| 17 |
+
"sk": "slovak", "te": "telugu", "fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian",
|
| 18 |
+
"az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian",
|
| 19 |
+
"br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian",
|
| 20 |
+
"bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili", "gl": "galician", "mr": "marathi",
|
| 21 |
+
"pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali",
|
| 22 |
+
"af": "afrikaans", "oc": "occitan", "ka": "georgian", "be": "belarusian", "tg": "tajik",
|
| 23 |
+
"sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek",
|
| 24 |
+
"fo": "faroese", "ht": "haitian creole", "ps": "pashto", "tk": "turkmen", "nn": "nynorsk",
|
| 25 |
+
"mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan",
|
| 26 |
+
"tl": "tagalog", "mg": "malagasy", "as": "assamese", "tt": "tatar", "haw": "hawaiian",
|
| 27 |
+
"ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese",
|
| 28 |
+
"yue": "cantonese", "minnan": "minnan", "wuyu": "wuyu", "dialect": "dialect", "zh/en": "zh/en", "en/zh": "en/zh"
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
TO_LANGUAGE_CODE = {
|
| 32 |
+
**{language: code for code, language in LANGUAGES.items()},
|
| 33 |
+
"burmese": "my", "valencian": "ca", "flemish": "nl", "haitian": "ht", "letzeburgesch": "lb",
|
| 34 |
+
"pushto": "ps", "panjabi": "pa", "moldavian": "ro", "moldovan": "ro", "sinhalese": "si",
|
| 35 |
+
"castilian": "es", "mandarin": "zh",
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
AUDIO_EVENT = {
|
| 39 |
+
"ASR": "ASR", "AED": "AED", "SER": "SER", "Speech": "Speech", "/Speech": "/Speech",
|
| 40 |
+
"BGM": "BGM", "/BGM": "/BGM", "Laughter": "Laughter", "/Laughter": "/Laughter",
|
| 41 |
+
"Applause": "Applause", "/Applause": "/Applause",
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
EMOTION = {
|
| 45 |
+
"HAPPY": "HAPPY", "SAD": "SAD", "ANGRY": "ANGRY", "NEUTRAL": "NEUTRAL",
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
TTS_Vocal_Token = {
|
| 49 |
+
"TTS/B": "TTS/B", "TTS/O": "TTS/O", "TTS/Q": "TTS/Q", "TTS/A": "TTS/A", "TTS/CO": "TTS/CO",
|
| 50 |
+
"TTS/CL": "TTS/CL", "TTS/H": "TTS/H", **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
# ===== 构造 Encoding =====
|
| 54 |
+
@lru_cache(maxsize=None)
|
| 55 |
+
def get_encoding(name: str = "gpt2", num_languages: int = 99):
|
| 56 |
+
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
| 57 |
+
ranks = {
|
| 58 |
+
base64.b64decode(token): int(rank)
|
| 59 |
+
for token, rank in (line.split() for line in open(vocab_path) if line)
|
| 60 |
+
}
|
| 61 |
+
n_vocab = len(ranks)
|
| 62 |
+
special_tokens = {}
|
| 63 |
+
specials = [
|
| 64 |
+
"<|endoftext|>", "<|startoftranscript|>",
|
| 65 |
+
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
|
| 66 |
+
*[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
|
| 67 |
+
*[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
|
| 68 |
+
"<|translate|>", "<|transcribe|>", "<|startoflm|>", "<|startofprev|>",
|
| 69 |
+
"<|nospeech|>", "<|notimestamps|>",
|
| 70 |
+
*[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],
|
| 71 |
+
*[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],
|
| 72 |
+
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
|
| 73 |
+
]
|
| 74 |
+
for token in specials:
|
| 75 |
+
special_tokens[token] = n_vocab
|
| 76 |
+
n_vocab += 1
|
| 77 |
+
return tiktoken.Encoding(
|
| 78 |
+
name=os.path.basename(vocab_path),
|
| 79 |
+
explicit_n_vocab=n_vocab,
|
| 80 |
+
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
| 81 |
+
mergeable_ranks=ranks,
|
| 82 |
+
special_tokens=special_tokens,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
class SimpleTokenizer:
|
| 86 |
+
def __init__(self, encoding, num_languages: int = 99, language: Optional[str] = None, task: Optional[str] = None):
|
| 87 |
+
self.encoding = encoding
|
| 88 |
+
self.num_languages = num_languages
|
| 89 |
+
self.language = language
|
| 90 |
+
self.task = task
|
| 91 |
+
def encode(self, text: str):
|
| 92 |
+
return self.encoding.encode(text)
|
| 93 |
+
def decode(self, tokens: list):
|
| 94 |
+
return self.encoding.decode(tokens)
|
| 95 |
+
|
| 96 |
+
@lru_cache(maxsize=None)
|
| 97 |
+
def get_tokenizer(
|
| 98 |
+
multilingual: bool,
|
| 99 |
+
*,
|
| 100 |
+
num_languages: int = 99,
|
| 101 |
+
language: Optional[str] = None,
|
| 102 |
+
task: Optional[str] = None,
|
| 103 |
+
) -> SimpleTokenizer:
|
| 104 |
+
if language is not None:
|
| 105 |
+
language = language.lower()
|
| 106 |
+
if language not in LANGUAGES:
|
| 107 |
+
if language in TO_LANGUAGE_CODE:
|
| 108 |
+
language = TO_LANGUAGE_CODE[language]
|
| 109 |
+
else:
|
| 110 |
+
raise ValueError(f"Unsupported language: {language}")
|
| 111 |
+
if multilingual:
|
| 112 |
+
encoding_name = "multilingual_zh_ja_yue_char_del"
|
| 113 |
+
language = language or "en"
|
| 114 |
+
task = task or "transcribe"
|
| 115 |
+
else:
|
| 116 |
+
encoding_name = "gpt2"
|
| 117 |
+
language = None
|
| 118 |
+
task = None
|
| 119 |
+
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
|
| 120 |
+
return SimpleTokenizer(encoding=encoding, num_languages=num_languages, language=language, task=task)
|
| 121 |
+
|
| 122 |
+
class QwenTokenizer():
|
| 123 |
+
def __init__(self, token_path, skip_special_tokens=True):
|
| 124 |
+
super().__init__()
|
| 125 |
+
special_tokens = {
|
| 126 |
+
'eos_token': '<|endoftext|>',
|
| 127 |
+
'pad_token': '<|endoftext|>',
|
| 128 |
+
'additional_special_tokens': [
|
| 129 |
+
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
|
| 130 |
+
'[breath]', '<strong>', '</strong>', '[noise]',
|
| 131 |
+
'[laughter]', '[cough]', '[clucking]', '[accent]',
|
| 132 |
+
'[quick_breath]',
|
| 133 |
+
"<laughter>", "</laughter>",
|
| 134 |
+
"[hissing]", "[sigh]", "[vocalized-noise]",
|
| 135 |
+
"[lipsmack]", "[mn]"
|
| 136 |
+
]
|
| 137 |
+
}
|
| 138 |
+
self.special_tokens = special_tokens
|
| 139 |
+
self.tokenizer = AutoTokenizer.from_pretrained(token_path)
|
| 140 |
+
self.tokenizer.add_special_tokens(special_tokens)
|
| 141 |
+
self.skip_special_tokens = skip_special_tokens
|
| 142 |
+
def encode(self, text, **kwargs):
|
| 143 |
+
tokens = self.tokenizer([text], return_tensors="pt")
|
| 144 |
+
return tokens["input_ids"][0].cpu().tolist()
|
| 145 |
+
def decode(self, tokens):
|
| 146 |
+
tokens = torch.tensor(tokens, dtype=torch.int64)
|
| 147 |
+
return self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
|
| 148 |
+
|
| 149 |
+
@lru_cache(maxsize=None)
|
| 150 |
+
def get_qwen_tokenizer(token_path: str, skip_special_tokens: bool) -> QwenTokenizer:
|
| 151 |
+
return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
|
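A short round-trip sketch for the Qwen path, assuming it is run from the repository root with scripts/ on sys.path and the CosyVoice-BlankEN tokenizer files from this commit in place; note decode strips the added specials because skip_special_tokens is True:

    from tokenizer.tokenizer import get_qwen_tokenizer

    tok = get_qwen_tokenizer("CosyVoice-BlankEN", True)
    ids = tok.encode("hello <|endofprompt|>")
    print(ids)              # token ids, <|endofprompt|> mapping to one added special id
    print(tok.decode(ids))  # roughly 'hello', with the special token removed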