Spaces:
Build error
Build error
Add file CosyVoice
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- model/.ipynb_checkpoints/README-checkpoint.md +24 -0
- model/.ipynb_checkpoints/cosyvoice_notebook-checkpoint.ipynb +113 -0
- model/.ipynb_checkpoints/gitignore-checkpoint.txt +7 -0
- model/.ipynb_checkpoints/requirements-checkpoint.txt +6 -0
- model/README.md +24 -0
- model/cosyvoice/cli/.ipynb_checkpoints/cosyvoice-checkpoint.py +179 -0
- model/cosyvoice/cli/.ipynb_checkpoints/frontend-checkpoint.py +211 -0
- model/cosyvoice/cli/__init__.py +0 -0
- model/cosyvoice/cli/__pycache__/__init__.cpython-310.pyc +0 -0
- model/cosyvoice/cli/__pycache__/__init__.cpython-311.pyc +0 -0
- model/cosyvoice/cli/__pycache__/__init__.cpython-312.pyc +0 -0
- model/cosyvoice/cli/__pycache__/cosyvoice.cpython-310.pyc +0 -0
- model/cosyvoice/cli/__pycache__/cosyvoice.cpython-311.pyc +0 -0
- model/cosyvoice/cli/__pycache__/cosyvoice.cpython-312.pyc +0 -0
- model/cosyvoice/cli/__pycache__/frontend.cpython-310.pyc +0 -0
- model/cosyvoice/cli/__pycache__/frontend.cpython-311.pyc +0 -0
- model/cosyvoice/cli/__pycache__/frontend.cpython-312.pyc +0 -0
- model/cosyvoice/cli/__pycache__/model.cpython-310.pyc +0 -0
- model/cosyvoice/cli/__pycache__/model.cpython-311.pyc +0 -0
- model/cosyvoice/cli/cosyvoice.py +179 -0
- model/cosyvoice/cli/frontend.py +211 -0
- model/cosyvoice/cli/model.py +461 -0
- model/cosyvoice/dataset/__init__.py +0 -0
- model/cosyvoice/dataset/__pycache__/__init__.cpython-310.pyc +0 -0
- model/cosyvoice/dataset/__pycache__/__init__.cpython-311.pyc +0 -0
- model/cosyvoice/dataset/__pycache__/processor.cpython-310.pyc +0 -0
- model/cosyvoice/dataset/__pycache__/processor.cpython-311.pyc +0 -0
- model/cosyvoice/dataset/dataset.py +164 -0
- model/cosyvoice/dataset/processor.py +435 -0
- model/cosyvoice/flow/__pycache__/decoder.cpython-310.pyc +0 -0
- model/cosyvoice/flow/__pycache__/decoder.cpython-311.pyc +0 -0
- model/cosyvoice/flow/__pycache__/flow.cpython-310.pyc +0 -0
- model/cosyvoice/flow/__pycache__/flow.cpython-311.pyc +0 -0
- model/cosyvoice/flow/__pycache__/flow_matching.cpython-310.pyc +0 -0
- model/cosyvoice/flow/__pycache__/flow_matching.cpython-311.pyc +0 -0
- model/cosyvoice/flow/decoder.py +902 -0
- model/cosyvoice/flow/flow.py +289 -0
- model/cosyvoice/flow/flow_matching.py +344 -0
- model/cosyvoice/flow/length_regulator.py +70 -0
- model/cosyvoice/hifigan/__pycache__/discriminator.cpython-310.pyc +0 -0
- model/cosyvoice/hifigan/__pycache__/discriminator.cpython-311.pyc +0 -0
- model/cosyvoice/hifigan/__pycache__/f0_predictor.cpython-310.pyc +0 -0
- model/cosyvoice/hifigan/__pycache__/f0_predictor.cpython-311.pyc +0 -0
- model/cosyvoice/hifigan/__pycache__/generator.cpython-310.pyc +0 -0
- model/cosyvoice/hifigan/__pycache__/generator.cpython-311.pyc +0 -0
- model/cosyvoice/hifigan/__pycache__/hifigan.cpython-310.pyc +0 -0
- model/cosyvoice/hifigan/__pycache__/hifigan.cpython-311.pyc +0 -0
- model/cosyvoice/hifigan/discriminator.py +230 -0
- model/cosyvoice/hifigan/f0_predictor.py +58 -0
- model/cosyvoice/hifigan/generator.py +414 -0
model/.ipynb_checkpoints/README-checkpoint.md
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
```bash
|
| 2 |
+
# 1. 리포지토리 클론 및 이동
|
| 3 |
+
git clone https://github.com/yourusername/cosyvoice.git
|
| 4 |
+
cd cosyvoice
|
| 5 |
+
|
| 6 |
+
# 2. 의존성 설치
|
| 7 |
+
pip install -r requirements.txt
|
| 8 |
+
|
| 9 |
+
# 3. 사전학습 모델 다운로드 (Python 코드 실행)
|
| 10 |
+
from modelscope import snapshot_download
|
| 11 |
+
|
| 12 |
+
# CosyVoice2 TTS 모델
|
| 13 |
+
snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
|
| 14 |
+
|
| 15 |
+
# CosyVoice 전처리기 (frontend)
|
| 16 |
+
snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
|
| 17 |
+
|
| 18 |
+
# 4. frontend 리소스 압축 해제
|
| 19 |
+
cd pretrained_models/CosyVoice-ttsfrd/
|
| 20 |
+
unzip resource.zip -d .
|
| 21 |
+
|
| 22 |
+
# 5. ttsfrd 의존성 설치
|
| 23 |
+
pip install ttsfrd_dependency-0.1-py3-none-any.whl
|
| 24 |
+
pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
|
model/.ipynb_checkpoints/cosyvoice_notebook-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "8ae68495-7e0f-4308-b63a-f1d09f2850c7",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stderr",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"/home/pasong/miniconda3/envs/cosyvoice/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
| 14 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
| 15 |
+
]
|
| 16 |
+
}
|
| 17 |
+
],
|
| 18 |
+
"source": [
|
| 19 |
+
"import sys\n",
|
| 20 |
+
"sys.path.append('third_party/Matcha-TTS')\n",
|
| 21 |
+
"from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2\n",
|
| 22 |
+
"from cosyvoice.utils.file_utils import load_wav\n",
|
| 23 |
+
"import torchaudio"
|
| 24 |
+
]
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"cell_type": "code",
|
| 28 |
+
"execution_count": 2,
|
| 29 |
+
"id": "6a38d09e-c3a2-43cf-a684-0fd3b47e4b09",
|
| 30 |
+
"metadata": {},
|
| 31 |
+
"outputs": [
|
| 32 |
+
{
|
| 33 |
+
"name": "stderr",
|
| 34 |
+
"output_type": "stream",
|
| 35 |
+
"text": [
|
| 36 |
+
"/home/pasong/miniconda3/envs/cosyvoice/lib/python3.10/site-packages/diffusers/models/lora.py:393: FutureWarning: `LoRACompatibleLinear` is deprecated and will be removed in version 1.0.0. Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`.\n",
|
| 37 |
+
" deprecate(\"LoRACompatibleLinear\", \"1.0.0\", deprecation_message)\n",
|
| 38 |
+
"2025-04-19 19:46:33,344 INFO input frame rate=25\n",
|
| 39 |
+
"/home/pasong/miniconda3/envs/cosyvoice/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\n",
|
| 40 |
+
" warnings.warn(\"torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\")\n",
|
| 41 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
|
| 42 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
|
| 43 |
+
"\u001b[0;93m2025-04-19 19:46:35.693712543 [W:onnxruntime:, transformer_memcpy.cc:74 ApplyImpl] 8 Memcpy nodes are added to the graph main_graph for CUDAExecutionProvider. It might have negative impact on performance (including unable to run CUDA graph). Set session_options.log_severity_level=1 to see the detail logs before this message.\u001b[m\n",
|
| 44 |
+
"\u001b[0;93m2025-04-19 19:46:35.697676173 [W:onnxruntime:, session_state.cc:1166 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.\u001b[m\n",
|
| 45 |
+
"\u001b[0;93m2025-04-19 19:46:35.697686892 [W:onnxruntime:, session_state.cc:1168 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.\u001b[m\n",
|
| 46 |
+
" 0%| | 0/1 [00:00<?, ?it/s]2025-04-19 19:46:41,474 INFO synthesis text 공룡이 밤양갱을 몰래 먹고 도망쳤어요。\n"
|
| 47 |
+
]
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"name": "stdout",
|
| 51 |
+
"output_type": "stream",
|
| 52 |
+
"text": [
|
| 53 |
+
"text.cc: festival_Text_init\n",
|
| 54 |
+
"open voice lang map failed\n"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"name": "stderr",
|
| 59 |
+
"output_type": "stream",
|
| 60 |
+
"text": [
|
| 61 |
+
"2025-04-19 19:46:46,661 INFO yield speech len 5.92, rtf 0.8762226314158054\n",
|
| 62 |
+
"100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:06<00:00, 6.36s/it]\n"
|
| 63 |
+
]
|
| 64 |
+
}
|
| 65 |
+
],
|
| 66 |
+
"source": [
|
| 67 |
+
"cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False)\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"# 프롬프트로 사용할 음성 파일 (화자의 목소리 담긴 wav, 16kHz로)\n",
|
| 70 |
+
"prompt_speech_16k = load_wav('./asset/tts_test.wav', 16000)\n",
|
| 71 |
+
"\n",
|
| 72 |
+
"for i, j in enumerate(\n",
|
| 73 |
+
" cosyvoice.inference_zero_shot(\n",
|
| 74 |
+
" '공룡이 밤양갱을 몰래 먹고 도망쳤어요.', #tts할 문장\n",
|
| 75 |
+
" prompt_text='오느른 커피 안 마실 꺼야', #tts_test의 발음문장\n",
|
| 76 |
+
" prompt_speech_16k=prompt_speech_16k,\n",
|
| 77 |
+
" text_frontend=True, \n",
|
| 78 |
+
" )\n",
|
| 79 |
+
"):\n",
|
| 80 |
+
" torchaudio.save(f'korean_tts_{i}.wav', j['tts_speech'], cosyvoice.sample_rate)"
|
| 81 |
+
]
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
"cell_type": "code",
|
| 85 |
+
"execution_count": null,
|
| 86 |
+
"id": "899101b0-3556-4ab5-8477-d182da48357d",
|
| 87 |
+
"metadata": {},
|
| 88 |
+
"outputs": [],
|
| 89 |
+
"source": []
|
| 90 |
+
}
|
| 91 |
+
],
|
| 92 |
+
"metadata": {
|
| 93 |
+
"kernelspec": {
|
| 94 |
+
"display_name": "Python (cosyvoice)",
|
| 95 |
+
"language": "python",
|
| 96 |
+
"name": "cosyvoice"
|
| 97 |
+
},
|
| 98 |
+
"language_info": {
|
| 99 |
+
"codemirror_mode": {
|
| 100 |
+
"name": "ipython",
|
| 101 |
+
"version": 3
|
| 102 |
+
},
|
| 103 |
+
"file_extension": ".py",
|
| 104 |
+
"mimetype": "text/x-python",
|
| 105 |
+
"name": "python",
|
| 106 |
+
"nbconvert_exporter": "python",
|
| 107 |
+
"pygments_lexer": "ipython3",
|
| 108 |
+
"version": "3.10.16"
|
| 109 |
+
}
|
| 110 |
+
},
|
| 111 |
+
"nbformat": 4,
|
| 112 |
+
"nbformat_minor": 5
|
| 113 |
+
}
|
model/.ipynb_checkpoints/gitignore-checkpoint.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
.ipynb_checkpoints/
|
| 3 |
+
*.pt
|
| 4 |
+
*.pth
|
| 5 |
+
*.onnx
|
| 6 |
+
pretrained_models/CosyVoice2-0.5B/
|
| 7 |
+
pretrained_models/CosyVoice-ttsfrd/resource/
|
model/.ipynb_checkpoints/requirements-checkpoint.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
HyperPyYAML==1.2.2
|
| 2 |
+
modelscope==1.20.0
|
| 3 |
+
tensorrt_cu12==10.0.1
|
| 4 |
+
torch==2.3.1+cu121
|
| 5 |
+
torchaudio==2.3.1+cu121
|
| 6 |
+
tqdm==4.67.1
|
model/README.md
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
```bash
|
| 2 |
+
# 1. 리포지토리 클론 및 이동
|
| 3 |
+
git clone https://github.com/yourusername/cosyvoice.git
|
| 4 |
+
cd cosyvoice
|
| 5 |
+
|
| 6 |
+
# 2. 의존성 설치
|
| 7 |
+
pip install -r requirements.txt
|
| 8 |
+
|
| 9 |
+
# 3. 사전학습 모델 다운로드 (Python 코드 실행)
|
| 10 |
+
from modelscope import snapshot_download
|
| 11 |
+
|
| 12 |
+
# CosyVoice2 TTS 모델
|
| 13 |
+
snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
|
| 14 |
+
|
| 15 |
+
# CosyVoice 전처리기 (frontend)
|
| 16 |
+
snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
|
| 17 |
+
|
| 18 |
+
# 4. frontend 리소스 압축 해제
|
| 19 |
+
cd pretrained_models/CosyVoice-ttsfrd/
|
| 20 |
+
unzip resource.zip -d .
|
| 21 |
+
|
| 22 |
+
# 5. ttsfrd 의존성 설치
|
| 23 |
+
pip install ttsfrd_dependency-0.1-py3-none-any.whl
|
| 24 |
+
pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
|
model/cosyvoice/cli/.ipynb_checkpoints/cosyvoice-checkpoint.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import os
|
| 15 |
+
import time
|
| 16 |
+
from typing import Generator
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from hyperpyyaml import load_hyperpyyaml
|
| 19 |
+
from modelscope import snapshot_download
|
| 20 |
+
import torch
|
| 21 |
+
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
| 22 |
+
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
|
| 23 |
+
from cosyvoice.utils.file_utils import logging
|
| 24 |
+
from cosyvoice.utils.class_utils import get_model_type
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class CosyVoice:
|
| 28 |
+
|
| 29 |
+
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
|
| 30 |
+
self.instruct = True if '-Instruct' in model_dir else False
|
| 31 |
+
self.model_dir = model_dir
|
| 32 |
+
self.fp16 = fp16
|
| 33 |
+
if not os.path.exists(model_dir):
|
| 34 |
+
model_dir = snapshot_download(model_dir)
|
| 35 |
+
hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir)
|
| 36 |
+
if not os.path.exists(hyper_yaml_path):
|
| 37 |
+
raise ValueError('{} not found!'.format(hyper_yaml_path))
|
| 38 |
+
with open(hyper_yaml_path, 'r') as f:
|
| 39 |
+
configs = load_hyperpyyaml(f)
|
| 40 |
+
assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
|
| 41 |
+
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
| 42 |
+
configs['feat_extractor'],
|
| 43 |
+
'{}/campplus.onnx'.format(model_dir),
|
| 44 |
+
'{}/speech_tokenizer_v1.onnx'.format(model_dir),
|
| 45 |
+
'{}/spk2info.pt'.format(model_dir),
|
| 46 |
+
configs['allowed_special'])
|
| 47 |
+
self.sample_rate = configs['sample_rate']
|
| 48 |
+
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
|
| 49 |
+
load_jit, load_trt, fp16 = False, False, False
|
| 50 |
+
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
|
| 51 |
+
self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
|
| 52 |
+
self.model.load('{}/llm.pt'.format(model_dir),
|
| 53 |
+
'{}/flow.pt'.format(model_dir),
|
| 54 |
+
'{}/hift.pt'.format(model_dir))
|
| 55 |
+
if load_jit:
|
| 56 |
+
self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
| 57 |
+
'{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
| 58 |
+
'{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
| 59 |
+
if load_trt:
|
| 60 |
+
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
| 61 |
+
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
| 62 |
+
self.fp16)
|
| 63 |
+
del configs
|
| 64 |
+
|
| 65 |
+
def list_available_spks(self):
|
| 66 |
+
spks = list(self.frontend.spk2info.keys())
|
| 67 |
+
return spks
|
| 68 |
+
|
| 69 |
+
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
|
| 70 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
| 71 |
+
model_input = self.frontend.frontend_sft(i, spk_id)
|
| 72 |
+
start_time = time.time()
|
| 73 |
+
logging.info('synthesis text {}'.format(i))
|
| 74 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 75 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 76 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 77 |
+
yield model_output
|
| 78 |
+
start_time = time.time()
|
| 79 |
+
|
| 80 |
+
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
| 81 |
+
prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
|
| 82 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
| 83 |
+
if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
|
| 84 |
+
logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
|
| 85 |
+
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
|
| 86 |
+
start_time = time.time()
|
| 87 |
+
logging.info('synthesis text {}'.format(i))
|
| 88 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 89 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 90 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 91 |
+
yield model_output
|
| 92 |
+
start_time = time.time()
|
| 93 |
+
|
| 94 |
+
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
| 95 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
| 96 |
+
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
|
| 97 |
+
start_time = time.time()
|
| 98 |
+
logging.info('synthesis text {}'.format(i))
|
| 99 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 100 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 101 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 102 |
+
yield model_output
|
| 103 |
+
start_time = time.time()
|
| 104 |
+
|
| 105 |
+
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
|
| 106 |
+
assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
|
| 107 |
+
if self.instruct is False:
|
| 108 |
+
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
|
| 109 |
+
instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
|
| 110 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
| 111 |
+
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
|
| 112 |
+
start_time = time.time()
|
| 113 |
+
logging.info('synthesis text {}'.format(i))
|
| 114 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 115 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 116 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 117 |
+
yield model_output
|
| 118 |
+
start_time = time.time()
|
| 119 |
+
|
| 120 |
+
def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
|
| 121 |
+
model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
|
| 122 |
+
start_time = time.time()
|
| 123 |
+
for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
|
| 124 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 125 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 126 |
+
yield model_output
|
| 127 |
+
start_time = time.time()
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class CosyVoice2(CosyVoice):
|
| 131 |
+
|
| 132 |
+
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False):
|
| 133 |
+
self.instruct = True if '-Instruct' in model_dir else False
|
| 134 |
+
self.model_dir = model_dir
|
| 135 |
+
self.fp16 = fp16
|
| 136 |
+
if not os.path.exists(model_dir):
|
| 137 |
+
model_dir = snapshot_download(model_dir)
|
| 138 |
+
hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
|
| 139 |
+
if not os.path.exists(hyper_yaml_path):
|
| 140 |
+
raise ValueError('{} not found!'.format(hyper_yaml_path))
|
| 141 |
+
with open(hyper_yaml_path, 'r') as f:
|
| 142 |
+
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
|
| 143 |
+
assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
|
| 144 |
+
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
| 145 |
+
configs['feat_extractor'],
|
| 146 |
+
'{}/campplus.onnx'.format(model_dir),
|
| 147 |
+
'{}/speech_tokenizer_v2.onnx'.format(model_dir),
|
| 148 |
+
'{}/spk2info.pt'.format(model_dir),
|
| 149 |
+
configs['allowed_special'])
|
| 150 |
+
self.sample_rate = configs['sample_rate']
|
| 151 |
+
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
|
| 152 |
+
load_jit, load_trt, fp16 = False, False, False
|
| 153 |
+
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
|
| 154 |
+
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache)
|
| 155 |
+
self.model.load('{}/llm.pt'.format(model_dir),
|
| 156 |
+
'{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir),
|
| 157 |
+
'{}/hift.pt'.format(model_dir))
|
| 158 |
+
if load_jit:
|
| 159 |
+
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
| 160 |
+
if load_trt:
|
| 161 |
+
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
| 162 |
+
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
| 163 |
+
self.fp16)
|
| 164 |
+
del configs
|
| 165 |
+
|
| 166 |
+
def inference_instruct(self, *args, **kwargs):
|
| 167 |
+
raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
|
| 168 |
+
|
| 169 |
+
def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
| 170 |
+
assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
|
| 171 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
| 172 |
+
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
|
| 173 |
+
start_time = time.time()
|
| 174 |
+
logging.info('synthesis text {}'.format(i))
|
| 175 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 176 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 177 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 178 |
+
yield model_output
|
| 179 |
+
start_time = time.time()
|
model/cosyvoice/cli/.ipynb_checkpoints/frontend-checkpoint.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from functools import partial
|
| 15 |
+
from typing import Generator
|
| 16 |
+
import json
|
| 17 |
+
import onnxruntime
|
| 18 |
+
import torch
|
| 19 |
+
import numpy as np
|
| 20 |
+
import whisper
|
| 21 |
+
from typing import Callable
|
| 22 |
+
import torchaudio.compliance.kaldi as kaldi
|
| 23 |
+
import torchaudio
|
| 24 |
+
import os
|
| 25 |
+
import re
|
| 26 |
+
import inflect
|
| 27 |
+
try:
|
| 28 |
+
import ttsfrd
|
| 29 |
+
use_ttsfrd = True
|
| 30 |
+
except ImportError:
|
| 31 |
+
print("failed to import ttsfrd, use WeTextProcessing instead")
|
| 32 |
+
from tn.chinese.normalizer import Normalizer as ZhNormalizer
|
| 33 |
+
from tn.english.normalizer import Normalizer as EnNormalizer
|
| 34 |
+
use_ttsfrd = False
|
| 35 |
+
from cosyvoice.utils.file_utils import logging
|
| 36 |
+
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class CosyVoiceFrontEnd:
|
| 40 |
+
|
| 41 |
+
def __init__(self,
|
| 42 |
+
get_tokenizer: Callable,
|
| 43 |
+
feat_extractor: Callable,
|
| 44 |
+
campplus_model: str,
|
| 45 |
+
speech_tokenizer_model: str,
|
| 46 |
+
spk2info: str = '',
|
| 47 |
+
allowed_special: str = 'all'):
|
| 48 |
+
self.tokenizer = get_tokenizer()
|
| 49 |
+
self.feat_extractor = feat_extractor
|
| 50 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 51 |
+
option = onnxruntime.SessionOptions()
|
| 52 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 53 |
+
option.intra_op_num_threads = 1
|
| 54 |
+
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
|
| 55 |
+
self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
|
| 56 |
+
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
|
| 57 |
+
"CPUExecutionProvider"])
|
| 58 |
+
if os.path.exists(spk2info):
|
| 59 |
+
self.spk2info = torch.load(spk2info, map_location=self.device)
|
| 60 |
+
else:
|
| 61 |
+
self.spk2info = {}
|
| 62 |
+
self.allowed_special = allowed_special
|
| 63 |
+
self.use_ttsfrd = use_ttsfrd
|
| 64 |
+
if self.use_ttsfrd:
|
| 65 |
+
self.frd = ttsfrd.TtsFrontendEngine()
|
| 66 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 67 |
+
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
|
| 68 |
+
'failed to initialize ttsfrd resource'
|
| 69 |
+
self.frd.set_lang_type('pinyinvg')
|
| 70 |
+
else:
|
| 71 |
+
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
|
| 72 |
+
self.en_tn_model = EnNormalizer()
|
| 73 |
+
self.inflect_parser = inflect.engine()
|
| 74 |
+
|
| 75 |
+
def _extract_text_token(self, text):
|
| 76 |
+
if isinstance(text, Generator):
|
| 77 |
+
logging.info('get tts_text generator, will return _extract_text_token_generator!')
|
| 78 |
+
# NOTE add a dummy text_token_len for compatibility
|
| 79 |
+
return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
|
| 80 |
+
else:
|
| 81 |
+
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
|
| 82 |
+
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
|
| 83 |
+
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
|
| 84 |
+
return text_token, text_token_len
|
| 85 |
+
|
| 86 |
+
def _extract_text_token_generator(self, text_generator):
|
| 87 |
+
for text in text_generator:
|
| 88 |
+
text_token, _ = self._extract_text_token(text)
|
| 89 |
+
for i in range(text_token.shape[1]):
|
| 90 |
+
yield text_token[:, i: i + 1]
|
| 91 |
+
|
| 92 |
+
def _extract_speech_token(self, speech):
|
| 93 |
+
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
|
| 94 |
+
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
|
| 95 |
+
speech_token = self.speech_tokenizer_session.run(None,
|
| 96 |
+
{self.speech_tokenizer_session.get_inputs()[0].name:
|
| 97 |
+
feat.detach().cpu().numpy(),
|
| 98 |
+
self.speech_tokenizer_session.get_inputs()[1].name:
|
| 99 |
+
np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
|
| 100 |
+
speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
|
| 101 |
+
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
|
| 102 |
+
return speech_token, speech_token_len
|
| 103 |
+
|
| 104 |
+
def _extract_spk_embedding(self, speech):
|
| 105 |
+
feat = kaldi.fbank(speech,
|
| 106 |
+
num_mel_bins=80,
|
| 107 |
+
dither=0,
|
| 108 |
+
sample_frequency=16000)
|
| 109 |
+
feat = feat - feat.mean(dim=0, keepdim=True)
|
| 110 |
+
embedding = self.campplus_session.run(None,
|
| 111 |
+
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
| 112 |
+
embedding = torch.tensor([embedding]).to(self.device)
|
| 113 |
+
return embedding
|
| 114 |
+
|
| 115 |
+
def _extract_speech_feat(self, speech):
    """Extract acoustic features for the flow model; returns (feat, feat_len)."""
    feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
    feat = feat.unsqueeze(dim=0)  # add batch dimension back: (1, T, D)
    feat_len = torch.tensor([feat.shape[1]], dtype=torch.int32).to(self.device)
    return feat, feat_len
|
| 121 |
+
def text_normalize(self, text, split=True, text_frontend=True):
    """Normalize text and optionally split it into synthesizable sentences.

    Generators are passed through untouched (normalization is skipped).
    With text_frontend disabled only the split/no-split wrapping happens.
    Returns a list of pieces when split is True, else the normalized string.
    """
    if isinstance(text, Generator):
        logging.info('get tts_text generator, will skip text_normalize!')
        return [text]
    if text_frontend is False:
        return [text] if split is True else text
    text = text.strip()
    if self.use_ttsfrd:
        sentences = json.loads(self.frd.do_voicegen_frd(text))["sentences"]
        texts = [s["text"] for s in sentences]
        text = ''.join(texts)
    else:
        tokenize = partial(self.tokenizer.encode, allowed_special=self.allowed_special)
        if contains_chinese(text):
            text = self.zh_tn_model.normalize(text)
            text = text.replace("\n", "")
            text = replace_blank(text)
            text = replace_corner_mark(text)
            text = text.replace(".", "。")
            text = text.replace(" - ", ",")
            text = remove_bracket(text)
            # trailing commas become a full stop
            text = re.sub(r'[,,、]+$', '。', text)
            texts = list(split_paragraph(text, tokenize, "zh", token_max_n=80,
                                         token_min_n=60, merge_len=20, comma_split=False))
        else:
            text = self.en_tn_model.normalize(text)
            text = spell_out_number(text, self.inflect_parser)
            texts = list(split_paragraph(text, tokenize, "en", token_max_n=80,
                                         token_min_n=60, merge_len=20, comma_split=False))
    texts = [piece for piece in texts if not is_only_punctuation(piece)]
    return texts if split is True else text
|
| 151 |
+
def frontend_sft(self, tts_text, spk_id):
    """Build model inputs for SFT synthesis with a pre-registered speaker id."""
    token, token_len = self._extract_text_token(tts_text)
    spk_embedding = self.spk2info[spk_id]['embedding']
    return {'text': token, 'text_len': token_len,
            'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
|
| 157 |
+
def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
    """Build model inputs for zero-shot cloning from a 16 kHz reference prompt."""
    text_token, text_token_len = self._extract_text_token(tts_text)
    prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
    resampler = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)
    prompt_speech_resample = resampler(prompt_speech_16k)
    speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
    speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
    if resample_rate == 24000:
        # cosyvoice2, force speech_feat % speech_token = 2
        token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
        speech_feat = speech_feat[:, :2 * token_len]
        speech_feat_len[:] = 2 * token_len
        speech_token = speech_token[:, :token_len]
        speech_token_len[:] = token_len
    embedding = self._extract_spk_embedding(prompt_speech_16k)
    return {'text': text_token, 'text_len': text_token_len,
            'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
            'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
            'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
            'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
            'llm_embedding': embedding, 'flow_embedding': embedding}
|
| 177 |
+
def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
    """Zero-shot inputs with the llm text/speech prompt removed (cross-lingual)."""
    model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
    # in cross lingual mode, we remove prompt in llm
    for key in ('prompt_text', 'prompt_text_len',
                'llm_prompt_speech_token', 'llm_prompt_speech_token_len'):
        del model_input[key]
    return model_input
|
| 186 |
+
def frontend_instruct(self, tts_text, spk_id, instruct_text):
    """SFT inputs where an instruction text acts as the llm prompt."""
    model_input = self.frontend_sft(tts_text, spk_id)
    # in instruct mode, we remove spk_embedding in llm due to information leakage
    del model_input['llm_embedding']
    instruct_token, instruct_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
    model_input['prompt_text'] = instruct_token
    model_input['prompt_text_len'] = instruct_token_len
    return model_input
|
| 195 |
+
def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
    """CosyVoice2 instruct inputs: zero-shot prompt without llm speech tokens."""
    model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate)
    for key in ('llm_prompt_speech_token', 'llm_prompt_speech_token_len'):
        del model_input[key]
    return model_input
|
| 201 |
+
def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
    """Build model inputs for voice conversion of source speech into the prompt voice."""
    prompt_token, prompt_token_len = self._extract_speech_token(prompt_speech_16k)
    resampler = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)
    prompt_feat, prompt_feat_len = self._extract_speech_feat(resampler(prompt_speech_16k))
    embedding = self._extract_spk_embedding(prompt_speech_16k)
    source_token, source_token_len = self._extract_speech_token(source_speech_16k)
    return {'source_speech_token': source_token, 'source_speech_token_len': source_token_len,
            'flow_prompt_speech_token': prompt_token, 'flow_prompt_speech_token_len': prompt_token_len,
            'prompt_speech_feat': prompt_feat, 'prompt_speech_feat_len': prompt_feat_len,
            'flow_embedding': embedding}
|
model/cosyvoice/cli/__init__.py
ADDED
|
File without changes
|
model/cosyvoice/cli/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (145 Bytes). View file
|
|
|
model/cosyvoice/cli/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (161 Bytes). View file
|
|
|
model/cosyvoice/cli/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (149 Bytes). View file
|
|
|
model/cosyvoice/cli/__pycache__/cosyvoice.cpython-310.pyc
ADDED
|
Binary file (7.38 kB). View file
|
|
|
model/cosyvoice/cli/__pycache__/cosyvoice.cpython-311.pyc
ADDED
|
Binary file (16.2 kB). View file
|
|
|
model/cosyvoice/cli/__pycache__/cosyvoice.cpython-312.pyc
ADDED
|
Binary file (14.9 kB). View file
|
|
|
model/cosyvoice/cli/__pycache__/frontend.cpython-310.pyc
ADDED
|
Binary file (8.58 kB). View file
|
|
|
model/cosyvoice/cli/__pycache__/frontend.cpython-311.pyc
ADDED
|
Binary file (16.4 kB). View file
|
|
|
model/cosyvoice/cli/__pycache__/frontend.cpython-312.pyc
ADDED
|
Binary file (15 kB). View file
|
|
|
model/cosyvoice/cli/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (14.6 kB). View file
|
|
|
model/cosyvoice/cli/__pycache__/model.cpython-311.pyc
ADDED
|
Binary file (37 kB). View file
|
|
|
model/cosyvoice/cli/cosyvoice.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import os
|
| 15 |
+
import time
|
| 16 |
+
from typing import Generator
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from hyperpyyaml import load_hyperpyyaml
|
| 19 |
+
from modelscope import snapshot_download
|
| 20 |
+
import torch
|
| 21 |
+
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
| 22 |
+
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
|
| 23 |
+
from cosyvoice.utils.file_utils import logging
|
| 24 |
+
from cosyvoice.utils.class_utils import get_model_type
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class CosyVoice:
    """High-level CosyVoice TTS wrapper.

    Loads the hyperpyyaml config and model weights from model_dir (downloading
    via modelscope when absent) and exposes sft / zero-shot / cross-lingual /
    instruct / voice-conversion inference generators that yield audio chunks.
    """

    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
        # '-Instruct' in the directory name marks an instruct-capable checkpoint
        self.instruct = '-Instruct' in model_dir
        self.model_dir = model_dir
        self.fp16 = fp16
        if not os.path.exists(model_dir):
            model_dir = snapshot_download(model_dir)
        hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir)
        if not os.path.exists(hyper_yaml_path):
            raise ValueError('{} not found!'.format(hyper_yaml_path))
        with open(hyper_yaml_path, 'r') as f:
            configs = load_hyperpyyaml(f)
        assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                          configs['feat_extractor'],
                                          '{}/campplus.onnx'.format(model_dir),
                                          '{}/speech_tokenizer_v1.onnx'.format(model_dir),
                                          '{}/spk2info.pt'.format(model_dir),
                                          configs['allowed_special'])
        self.sample_rate = configs['sample_rate']
        # JIT/TRT/fp16 all require CUDA; silently fall back on CPU-only hosts
        if not torch.cuda.is_available() and (load_jit or load_trt or fp16):
            load_jit, load_trt, fp16 = False, False, False
            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
        self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
        self.model.load('{}/llm.pt'.format(model_dir),
                        '{}/flow.pt'.format(model_dir),
                        '{}/hift.pt'.format(model_dir))
        # precision tag used in exported artifact filenames
        dtype = 'fp16' if self.fp16 else 'fp32'
        if load_jit:
            self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, dtype),
                                '{}/llm.llm.{}.zip'.format(model_dir, dtype),
                                '{}/flow.encoder.{}.zip'.format(model_dir, dtype))
        if load_trt:
            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, dtype),
                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                self.fp16)
        del configs

    def list_available_spks(self):
        """Return the speaker ids registered in the frontend's spk2info table."""
        return list(self.frontend.spk2info.keys())

    def _generate(self, model_input, stream, speed):
        """Run model.tts on one prepared input, yielding chunks with RTF logging."""
        start_time = time.time()
        for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
            speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
            logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
            yield model_output
            start_time = time.time()

    def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
        """Synthesize tts_text with a pre-registered speaker id."""
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_sft(i, spk_id)
            logging.info('synthesis text {}'.format(i))
            yield from self._generate(model_input, stream, speed)

    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
        """Synthesize tts_text by cloning the voice of a 16 kHz reference prompt."""
        prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
                logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
            logging.info('synthesis text {}'.format(i))
            yield from self._generate(model_input, stream, speed)

    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
        """Synthesize tts_text in the prompt voice without a text prompt (cross-lingual)."""
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
            logging.info('synthesis text {}'.format(i))
            yield from self._generate(model_input, stream, speed)

    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
        """Synthesize tts_text for a registered speaker, steered by instruct_text."""
        assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
        if self.instruct is False:
            raise ValueError('{} do not support instruct inference'.format(self.model_dir))
        instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
            logging.info('synthesis text {}'.format(i))
            yield from self._generate(model_input, stream, speed)

    def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
        """Voice conversion: re-speak source speech in the prompt speaker's voice."""
        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
        start_time = time.time()
        # uses model.vc (not model.tts), so not routed through _generate
        for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
            speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
            logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
            yield model_output
            start_time = time.time()
|
| 129 |
+
|
| 130 |
+
class CosyVoice2(CosyVoice):
    """CosyVoice2 variant: cosyvoice2.yaml config, v2 speech tokenizer,
    CosyVoice2Model weights and optional cached flow decoding."""

    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False):
        # '-Instruct' in the directory name marks an instruct-capable checkpoint
        self.instruct = '-Instruct' in model_dir
        self.model_dir = model_dir
        self.fp16 = fp16
        if not os.path.exists(model_dir):
            model_dir = snapshot_download(model_dir)
        hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
        if not os.path.exists(hyper_yaml_path):
            raise ValueError('{} not found!'.format(hyper_yaml_path))
        with open(hyper_yaml_path, 'r') as f:
            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
        assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                          configs['feat_extractor'],
                                          '{}/campplus.onnx'.format(model_dir),
                                          '{}/speech_tokenizer_v2.onnx'.format(model_dir),
                                          '{}/spk2info.pt'.format(model_dir),
                                          configs['allowed_special'])
        self.sample_rate = configs['sample_rate']
        # JIT/TRT/fp16 all require CUDA; silently fall back on CPU-only hosts
        if not torch.cuda.is_available() and (load_jit or load_trt or fp16):
            load_jit, load_trt, fp16 = False, False, False
            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache)
        self.model.load('{}/llm.pt'.format(model_dir),
                        '{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir),
                        '{}/hift.pt'.format(model_dir))
        # precision tag used in exported artifact filenames
        dtype = 'fp16' if self.fp16 else 'fp32'
        if load_jit:
            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, dtype))
        if load_trt:
            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, dtype),
                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                self.fp16)
        del configs

    def inference_instruct(self, *args, **kwargs):
        """CosyVoice2 replaces v1 instruct inference with inference_instruct2."""
        raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')

    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
        """Synthesize tts_text in the prompt voice, steered by instruct_text."""
        assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                yield model_output
                start_time = time.time()
model/cosyvoice/cli/frontend.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from functools import partial
|
| 15 |
+
from typing import Generator
|
| 16 |
+
import json
|
| 17 |
+
import onnxruntime
|
| 18 |
+
import torch
|
| 19 |
+
import numpy as np
|
| 20 |
+
import whisper
|
| 21 |
+
from typing import Callable
|
| 22 |
+
import torchaudio.compliance.kaldi as kaldi
|
| 23 |
+
import torchaudio
|
| 24 |
+
import os
|
| 25 |
+
import re
|
| 26 |
+
import inflect
|
| 27 |
+
try:
|
| 28 |
+
import ttsfrd
|
| 29 |
+
use_ttsfrd = True
|
| 30 |
+
except ImportError:
|
| 31 |
+
print("failed to import ttsfrd, use WeTextProcessing instead")
|
| 32 |
+
from tn.chinese.normalizer import Normalizer as ZhNormalizer
|
| 33 |
+
from tn.english.normalizer import Normalizer as EnNormalizer
|
| 34 |
+
use_ttsfrd = False
|
| 35 |
+
from cosyvoice.utils.file_utils import logging
|
| 36 |
+
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class CosyVoiceFrontEnd:
|
| 40 |
+
|
| 41 |
+
def __init__(self,
|
| 42 |
+
get_tokenizer: Callable,
|
| 43 |
+
feat_extractor: Callable,
|
| 44 |
+
campplus_model: str,
|
| 45 |
+
speech_tokenizer_model: str,
|
| 46 |
+
spk2info: str = '',
|
| 47 |
+
allowed_special: str = 'all'):
|
| 48 |
+
self.tokenizer = get_tokenizer()
|
| 49 |
+
self.feat_extractor = feat_extractor
|
| 50 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 51 |
+
option = onnxruntime.SessionOptions()
|
| 52 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 53 |
+
option.intra_op_num_threads = 1
|
| 54 |
+
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
|
| 55 |
+
self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
|
| 56 |
+
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
|
| 57 |
+
"CPUExecutionProvider"])
|
| 58 |
+
if os.path.exists(spk2info):
|
| 59 |
+
self.spk2info = torch.load(spk2info, map_location=self.device)
|
| 60 |
+
else:
|
| 61 |
+
self.spk2info = {}
|
| 62 |
+
self.allowed_special = allowed_special
|
| 63 |
+
self.use_ttsfrd = use_ttsfrd
|
| 64 |
+
if self.use_ttsfrd:
|
| 65 |
+
self.frd = ttsfrd.TtsFrontendEngine()
|
| 66 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 67 |
+
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
|
| 68 |
+
'failed to initialize ttsfrd resource'
|
| 69 |
+
self.frd.set_lang_type('pinyinvg')
|
| 70 |
+
else:
|
| 71 |
+
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
|
| 72 |
+
self.en_tn_model = EnNormalizer()
|
| 73 |
+
self.inflect_parser = inflect.engine()
|
| 74 |
+
|
| 75 |
+
def _extract_text_token(self, text):
|
| 76 |
+
if isinstance(text, Generator):
|
| 77 |
+
logging.info('get tts_text generator, will return _extract_text_token_generator!')
|
| 78 |
+
# NOTE add a dummy text_token_len for compatibility
|
| 79 |
+
return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
|
| 80 |
+
else:
|
| 81 |
+
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
|
| 82 |
+
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
|
| 83 |
+
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
|
| 84 |
+
return text_token, text_token_len
|
| 85 |
+
|
| 86 |
+
def _extract_text_token_generator(self, text_generator):
|
| 87 |
+
for text in text_generator:
|
| 88 |
+
text_token, _ = self._extract_text_token(text)
|
| 89 |
+
for i in range(text_token.shape[1]):
|
| 90 |
+
yield text_token[:, i: i + 1]
|
| 91 |
+
|
| 92 |
+
def _extract_speech_token(self, speech):
|
| 93 |
+
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
|
| 94 |
+
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
|
| 95 |
+
speech_token = self.speech_tokenizer_session.run(None,
|
| 96 |
+
{self.speech_tokenizer_session.get_inputs()[0].name:
|
| 97 |
+
feat.detach().cpu().numpy(),
|
| 98 |
+
self.speech_tokenizer_session.get_inputs()[1].name:
|
| 99 |
+
np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
|
| 100 |
+
speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
|
| 101 |
+
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
|
| 102 |
+
return speech_token, speech_token_len
|
| 103 |
+
|
| 104 |
+
def _extract_spk_embedding(self, speech):
|
| 105 |
+
feat = kaldi.fbank(speech,
|
| 106 |
+
num_mel_bins=80,
|
| 107 |
+
dither=0,
|
| 108 |
+
sample_frequency=16000)
|
| 109 |
+
feat = feat - feat.mean(dim=0, keepdim=True)
|
| 110 |
+
embedding = self.campplus_session.run(None,
|
| 111 |
+
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
| 112 |
+
embedding = torch.tensor([embedding]).to(self.device)
|
| 113 |
+
return embedding
|
| 114 |
+
|
| 115 |
+
def _extract_speech_feat(self, speech):
|
| 116 |
+
speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
|
| 117 |
+
speech_feat = speech_feat.unsqueeze(dim=0)
|
| 118 |
+
speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
|
| 119 |
+
return speech_feat, speech_feat_len
|
| 120 |
+
|
| 121 |
+
def text_normalize(self, text, split=True, text_frontend=True):
|
| 122 |
+
if isinstance(text, Generator):
|
| 123 |
+
logging.info('get tts_text generator, will skip text_normalize!')
|
| 124 |
+
return [text]
|
| 125 |
+
if text_frontend is False:
|
| 126 |
+
return [text] if split is True else text
|
| 127 |
+
text = text.strip()
|
| 128 |
+
if self.use_ttsfrd:
|
| 129 |
+
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
|
| 130 |
+
text = ''.join(texts)
|
| 131 |
+
else:
|
| 132 |
+
if contains_chinese(text):
|
| 133 |
+
text = self.zh_tn_model.normalize(text)
|
| 134 |
+
text = text.replace("\n", "")
|
| 135 |
+
text = replace_blank(text)
|
| 136 |
+
text = replace_corner_mark(text)
|
| 137 |
+
text = text.replace(".", "。")
|
| 138 |
+
text = text.replace(" - ", ",")
|
| 139 |
+
text = remove_bracket(text)
|
| 140 |
+
text = re.sub(r'[,,、]+$', '。', text)
|
| 141 |
+
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
|
| 142 |
+
token_min_n=60, merge_len=20, comma_split=False))
|
| 143 |
+
else:
|
| 144 |
+
text = self.en_tn_model.normalize(text)
|
| 145 |
+
text = spell_out_number(text, self.inflect_parser)
|
| 146 |
+
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
|
| 147 |
+
token_min_n=60, merge_len=20, comma_split=False))
|
| 148 |
+
texts = [i for i in texts if not is_only_punctuation(i)]
|
| 149 |
+
return texts if split is True else text
|
| 150 |
+
|
| 151 |
+
def frontend_sft(self, tts_text, spk_id):
|
| 152 |
+
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
| 153 |
+
embedding = self.spk2info[spk_id]['embedding']
|
| 154 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
|
| 155 |
+
return model_input
|
| 156 |
+
|
| 157 |
+
def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
|
| 158 |
+
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
| 159 |
+
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
|
| 160 |
+
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
|
| 161 |
+
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
|
| 162 |
+
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
| 163 |
+
if resample_rate == 24000:
|
| 164 |
+
# cosyvoice2, force speech_feat % speech_token = 2
|
| 165 |
+
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
|
| 166 |
+
speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
|
| 167 |
+
speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
|
| 168 |
+
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
| 169 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
| 170 |
+
'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
|
| 171 |
+
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
| 172 |
+
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
| 173 |
+
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
| 174 |
+
'llm_embedding': embedding, 'flow_embedding': embedding}
|
| 175 |
+
return model_input
|
| 176 |
+
|
| 177 |
+
def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
|
| 178 |
+
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
|
| 179 |
+
# in cross lingual mode, we remove prompt in llm
|
| 180 |
+
del model_input['prompt_text']
|
| 181 |
+
del model_input['prompt_text_len']
|
| 182 |
+
del model_input['llm_prompt_speech_token']
|
| 183 |
+
del model_input['llm_prompt_speech_token_len']
|
| 184 |
+
return model_input
|
| 185 |
+
|
| 186 |
+
def frontend_instruct(self, tts_text, spk_id, instruct_text):
    """Build instruct-mode inputs on top of the sft frontend: drop the llm
    speaker embedding and install the instruct text as the llm text prompt."""
    model_input = self.frontend_sft(tts_text, spk_id)
    # in instruct mode, we remove spk_embedding in llm due to information leakage
    model_input.pop('llm_embedding')
    token, token_len = self._extract_text_token(instruct_text + '<endofprompt>')
    model_input['prompt_text'] = token
    model_input['prompt_text_len'] = token_len
    return model_input
def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
    """Build CosyVoice2 instruct inputs: zero-shot frontend with the instruct
    text as prompt, minus the llm prompt speech tokens."""
    model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate)
    # llm must not condition on prompt speech tokens in instruct2 mode
    for key in ('llm_prompt_speech_token', 'llm_prompt_speech_token_len'):
        model_input.pop(key)
    return model_input
def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
    """Build voice-conversion inputs: flow-only fields (no text / llm prompt),
    with source tokens to convert and prompt-derived timbre conditioning."""
    prompt_token, prompt_token_len = self._extract_speech_token(prompt_speech_16k)
    resampled_prompt = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
    prompt_feat, prompt_feat_len = self._extract_speech_feat(resampled_prompt)
    spk_embedding = self._extract_spk_embedding(prompt_speech_16k)
    source_token, source_token_len = self._extract_speech_token(source_speech_16k)
    return {
        'source_speech_token': source_token, 'source_speech_token_len': source_token_len,
        'flow_prompt_speech_token': prompt_token, 'flow_prompt_speech_token_len': prompt_token_len,
        'prompt_speech_feat': prompt_feat, 'prompt_speech_feat_len': prompt_feat_len,
        'flow_embedding': spk_embedding,
    }
model/cosyvoice/cli/model.py
ADDED
|
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import os
|
| 15 |
+
from typing import Generator
|
| 16 |
+
import torch
|
| 17 |
+
import numpy as np
|
| 18 |
+
import threading
|
| 19 |
+
import time
|
| 20 |
+
from torch.nn import functional as F
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
import uuid
|
| 23 |
+
from cosyvoice.utils.common import fade_in_out
|
| 24 |
+
from cosyvoice.utils.file_utils import convert_onnx_to_trt
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class CosyVoiceModel:
    """Three-stage TTS inference pipeline.

    llm predicts speech tokens from text, flow converts tokens to mel
    spectrograms, hift vocodes mel to waveform.  Supports offline and chunked
    streaming inference; per-session state lives in dicts keyed by a uuid so
    multiple inference threads can run concurrently.
    """

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        if self.fp16 is True:
            # only llm/flow run in half precision; hift stays fp32
            self.llm.half()
            self.flow.half()
        # streaming hop bounds expressed in speech tokens (2s min / 4s max of audio)
        self.token_min_hop_len = 2 * self.flow.input_frame_rate
        self.token_max_hop_len = 4 * self.flow.input_frame_rate
        self.token_overlap_len = 20
        # mel fade in out
        # mel frames matching token_overlap_len — assumes 22050 Hz output with hop 256; TODO confirm
        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
        self.mel_window = np.hamming(2 * self.mel_overlap_len)
        # hift cache
        self.mel_cache_len = 20
        self.source_cache_len = int(self.mel_cache_len * 256)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.stream_scale_factor = 1
        # NOTE(review): assert allows == 1 but the message says "greater than 1"
        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
        # run llm on a side CUDA stream so flow/hift can overlap with token generation
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dict used to store session related variable
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.mel_overlap_dict = {}
        self.flow_cache_dict = {}
        self.hift_cache_dict = {}

    def load(self, llm_model, flow_model, hift_model):
        """Load checkpoints for llm / flow / hift and switch them to eval mode."""
        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
        self.llm.to(self.device).eval()
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
        self.flow.to(self.device).eval()
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()

    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
        """Swap eager submodules for TorchScript-compiled versions."""
        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
        self.llm.text_encoder = llm_text_encoder
        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
        self.llm.llm = llm_llm
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
        """Replace the flow decoder estimator with a TensorRT engine,
        exporting from ONNX first when the engine file does not exist."""
        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
        if not os.path.exists(flow_decoder_estimator_model):
            convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
        if os.path.getsize(flow_decoder_estimator_model) == 0:
            raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model))
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(flow_decoder_estimator_model, 'rb') as f:
            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert self.flow.decoder.estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()

    def get_trt_kwargs(self):
        """Dynamic-shape profiles for the TensorRT export of the flow decoder."""
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
        opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200)]
        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
        input_names = ["x", "mask", "mu", "cond"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
        """Producer thread body: stream llm speech tokens into
        tts_speech_token_dict[uuid] and raise llm_end_dict[uuid] when done.

        A Generator `text` selects the bistream (streaming-text) path, which is
        only available on CosyVoice2Model.
        """
        with self.llm_context, torch.cuda.amp.autocast(self.fp16):
            if isinstance(text, Generator):
                assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
                for i in self.llm.inference_bistream(text=text,
                                                     prompt_text=prompt_text.to(self.device),
                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                     embedding=llm_embedding.to(self.device)):
                    self.tts_speech_token_dict[uuid].append(i)
            else:
                for i in self.llm.inference(text=text.to(self.device),
                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_text=prompt_text.to(self.device),
                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                            embedding=llm_embedding.to(self.device)):
                    self.tts_speech_token_dict[uuid].append(i)
        self.llm_end_dict[uuid] = True

    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
        """Run flow + hift on one token chunk, cross-fading against per-session
        mel/speech caches so consecutive streaming chunks join seamlessly.

        finalize=False keeps overlap/cache state for the next chunk and trims
        the emitted audio; finalize=True flushes everything (and may
        time-stretch via `speed`, non-stream mode only).
        """
        with torch.cuda.amp.autocast(self.fp16):
            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
                                                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_token=prompt_token.to(self.device),
                                                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_feat=prompt_feat.to(self.device),
                                                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                                      embedding=embedding.to(self.device),
                                                                      flow_cache=self.flow_cache_dict[uuid])

        # mel overlap fade in out
        if self.mel_overlap_dict[uuid].shape[2] != 0:
            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                # time-stretch the mel along frames to change speaking rate
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech

    # NOTE(review): tensor default arguments are shared across calls; safe here
    # only because they are never mutated in place.
    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
        """Generator yielding {'tts_speech': waveform} chunks.

        Spawns llm_job in a background thread; with stream=True, polls the
        token buffer and vocodes hops of growing size, otherwise waits for all
        tokens and vocodes once.  Session state is cleaned up at the end.
        """
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        p.start()
        if stream is True:
            token_hop_len = self.token_min_hop_len
            while True:
                time.sleep(0.1)
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
                        .unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break
            p.join()
            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
            self.flow_cache_dict.pop(this_uuid)
        torch.cuda.empty_cache()

    def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
        """Voice conversion: same flow/hift path as tts(), but the speech
        tokens come from the source utterance instead of the llm, so the token
        buffer is pre-filled and llm_end is True from the start."""
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
            self.hift_cache_dict[this_uuid] = None
            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
        if stream is True:
            token_hop_len = self.token_min_hop_len
            while True:
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
                        .unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break
            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
            self.flow_cache_dict.pop(this_uuid)
        torch.cuda.empty_cache()
class CosyVoice2Model(CosyVoiceModel):
    """CosyVoice 2 variant: replaces mel-overlap cross-fading with an explicit
    flow encoder/decoder cache (use_flow_cache) and a fixed token hop, and
    supports streaming text input via llm.inference_bistream."""

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool,
                 use_flow_cache: bool):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        self.use_flow_cache = use_flow_cache
        if self.fp16 is True:
            # only llm/flow run in half precision; hift stays fp32
            self.llm.half()
            self.flow.half()
        # stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
        self.token_hop_len = 25
        # -1 means "keep everything" (non-cache mode); otherwise keep one hop of kv cache
        self.flow_decoder_required_cache_size = -1 if use_flow_cache is False else 1 * self.token_hop_len
        # hift cache
        self.mel_cache_len = 8
        # 480 samples per mel frame — assumes 24 kHz output; TODO confirm
        self.source_cache_len = int(self.mel_cache_len * 480)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dict used to store session related variable
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.flow_cache_dict = {}
        self.hift_cache_dict = {}

    def init_flow_cache(self):
        """Build zero-filled per-session caches for the flow encoder and
        decoder (conv states plus empty kv caches); halved when fp16."""
        encoder_cache = {'offset': 0,
                         'pre_lookahead_layer_conv2_cache': torch.zeros(1, 512, 2).to(self.device),
                         'encoders_kv_cache': torch.zeros(6, 1, 8, 0, 64 * 2).to(self.device),
                         'upsample_offset': 0,
                         'upsample_conv_cache': torch.zeros(1, 512, 4).to(self.device),
                         'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)}
        decoder_cache = {'offset': 0,
                         'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device),
                         'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device),
                         'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device),
                         'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, 0, 512, 2).to(self.device),
                         'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device),
                         'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device),
                         'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)}
        if self.fp16 is True:
            for cache in [encoder_cache, decoder_cache]:
                for k, v in cache.items():
                    if isinstance(v, torch.Tensor):
                        cache[k] = v.half()
        cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache}
        return cache

    def trim_flow_cache(self, cache):
        """Truncate decoder kv caches to the required size so streaming memory
        stays bounded; no-op when the required size is -1 (non-cache mode)."""
        if self.flow_decoder_required_cache_size > 0:
            cache['decoder_cache']['down_blocks_kv_cache'] = cache['decoder_cache']['down_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
            cache['decoder_cache']['mid_blocks_kv_cache'] = cache['decoder_cache']['mid_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
            cache['decoder_cache']['up_blocks_kv_cache'] = cache['decoder_cache']['up_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
        return cache

    def load_jit(self, flow_encoder_model):
        """Swap the flow encoder for a TorchScript version (llm stays eager)."""
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def get_trt_kwargs(self):
        """Dynamic-shape profiles for the TensorRT export, including the three
        decoder kv-cache inputs; only valid in flow-cache mode."""
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (1, 4, 2, 0, 512, 2), (12, 4, 2, 0, 512, 2), (1, 4, 2, 0, 512, 2)]
        opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200), (1, 4, 2, 100, 512, 2), (12, 4, 2, 100, 512, 2), (1, 4, 2, 100, 512, 2)]
        max_shape = [(2, 80, 1500), (2, 1, 1500), (2, 80, 1500), (2, 80, 1500), (1, 4, 2, 200, 512, 2), (12, 4, 2, 200, 512, 2), (1, 4, 2, 200, 512, 2)]
        input_names = ["x", "mask", "mu", "cond", 'down_blocks_kv_cache', 'mid_blocks_kv_cache', 'up_blocks_kv_cache']
        assert self.use_flow_cache is True, "get_trt_kwargs is set for flow cache mode. If you want to use trt with use_flow_cache=False, please set higher max_shape"
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
        """Run flow + hift on one token chunk using the explicit flow cache
        (no mel-overlap fade: continuity comes from the cache itself)."""
        with torch.cuda.amp.autocast(self.fp16):
            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
                                                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_token=prompt_token.to(self.device),
                                                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_feat=prompt_feat.to(self.device),
                                                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                                      embedding=embedding.to(self.device),
                                                                      cache=self.flow_cache_dict[uuid],
                                                                      finalize=finalize)
        self.flow_cache_dict[uuid] = self.trim_flow_cache(self.flow_cache_dict[uuid])
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                # time-stretch the mel along frames to change speaking rate
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech

    # NOTE(review): tensor default arguments are shared across calls; safe here
    # only because they are never mutated in place.
    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
        """Generator yielding {'tts_speech': waveform} chunks.

        stream=True requires use_flow_cache=True and uses a fixed hop of
        token_hop_len plus the flow pre-lookahead; the flow prompt is fed only
        on the first chunk, then zeroed (the cache carries it forward).
        stream=False requires use_flow_cache=False and vocodes everything once.
        """
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
            self.flow_cache_dict[this_uuid] = self.init_flow_cache()
        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        p.start()
        if stream is True:
            assert self.use_flow_cache is True, "set use_flow_cache=True if you want to use stream inference to avoid OOM"
            # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
            flow_prompt_speech_token = flow_prompt_speech_token[:, -self.flow_decoder_required_cache_size:]
            # feat is trimmed to 2x the token count — presumably 2 mel frames per token; verify against flow config
            prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size * 2:]
            while True:
                time.sleep(0.1)
                if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    # NOTE in cache inference mode, we only use flow_prompt_speech_token/prompt_speech_feat in first chunk
                    flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device)
                    prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][self.token_hop_len:]
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < self.token_hop_len + self.flow.pre_lookahead_len:
                    break
            p.join()
            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            assert self.use_flow_cache is False, "set use_flow_cache=False for nonstream inference"
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
            self.flow_cache_dict.pop(this_uuid)
        torch.cuda.empty_cache()
model/cosyvoice/dataset/__init__.py
ADDED
|
File without changes
|
model/cosyvoice/dataset/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (149 Bytes). View file
|
|
|
model/cosyvoice/dataset/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (165 Bytes). View file
|
|
|
model/cosyvoice/dataset/__pycache__/processor.cpython-310.pyc
ADDED
|
Binary file (12.8 kB). View file
|
|
|
model/cosyvoice/dataset/__pycache__/processor.cpython-311.pyc
ADDED
|
Binary file (22.8 kB). View file
|
|
|
model/cosyvoice/dataset/dataset.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
| 2 |
+
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import random
|
| 17 |
+
import json
|
| 18 |
+
import math
|
| 19 |
+
from functools import partial
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.distributed as dist
|
| 23 |
+
from torch.utils.data import IterableDataset
|
| 24 |
+
from cosyvoice.utils.file_utils import read_lists, read_json_lists
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Processor(IterableDataset):
    """Lazily applies a processing function to an upstream iterable dataset.

    Chaining ``apply`` builds a pipeline in which each stage wraps the
    previous one; nothing is computed until the final stage is iterated.
    """

    def __init__(self, source, f, *args, **kw):
        assert callable(f)
        self.source = source
        self.f = f
        self.args = args
        self.kw = kw

    def set_epoch(self, epoch):
        # Forward the epoch to the underlying source (used for shuffle seeding).
        self.source.set_epoch(epoch)

    def __iter__(self):
        """Return an iterator over the source dataset processed by ``self.f``."""
        assert self.source is not None
        assert callable(self.f)
        return self.f(iter(self.source), *self.args, **self.kw)

    def apply(self, f):
        """Wrap this stage in a new Processor that runs ``f`` on its output."""
        assert callable(f)
        return Processor(self, f, *self.args, **self.kw)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class DistributedSampler:
    """Tracks rank / world-size / dataloader-worker context for sharding a data list."""

    def __init__(self, shuffle=True, partition=True):
        self.epoch = -1
        self.update()
        self.shuffle = shuffle
        self.partition = partition

    def update(self):
        """Refresh the distributed and worker context; return it as a dict."""
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank, self.world_size = 0, 1
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers
        else:
            self.worker_id, self.num_workers = 0, 1
        return dict(rank=self.rank,
                    world_size=self.world_size,
                    worker_id=self.worker_id,
                    num_workers=self.num_workers)

    def set_epoch(self, epoch):
        self.epoch = epoch

    def sample(self, data):
        """Return the index subset of ``data`` owned by this rank and worker.

        Args:
            data(List): input data list

        Returns:
            List: index list after rank/worker partitioning
        """
        indices = list(range(len(data)))
        if self.partition:
            if self.shuffle:
                # epoch-seeded shuffle keeps all ranks consistent
                random.Random(self.epoch).shuffle(indices)
            # replicate so every rank receives at least one index
            if len(indices) < self.world_size:
                indices = indices * math.ceil(self.world_size / len(indices))
                indices = indices[:self.world_size]
            indices = indices[self.rank::self.world_size]
        if len(indices) < self.num_workers:
            indices = indices * math.ceil(self.num_workers / len(indices))
            indices = indices[:self.num_workers]
        return indices[self.worker_id::self.num_workers]
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class DataList(IterableDataset):
    """Iterable over a list of data shards, partitioned per rank/worker."""

    def __init__(self, lists, shuffle=True, partition=True):
        self.lists = lists
        self.sampler = DistributedSampler(shuffle, partition)

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def __iter__(self):
        sampler_info = self.sampler.update()
        for index in self.sampler.sample(self.lists):
            # attach rank/worker info so downstream stages can shard further
            yield dict(src=self.lists[index], **sampler_info)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def Dataset(data_list_file,
            data_pipeline,
            mode='train',
            gan=False,
            shuffle=True,
            partition=True,
            tts_file='',
            prompt_utt2data=''):
    """Construct a processing-pipeline dataset from a shard list file.

    Two shuffle stages exist: a global shuffle at the shard (tar/raw file)
    level inside ``DataList`` and a sample-level shuffle inside the
    pipeline's shuffle stage.

    Args:
        data_list_file: file listing data shards, one per line
        data_pipeline: list of processor functions applied in order
        mode: 'train' or 'inference'
        gan: if True, bind gan flag into the final (padding) pipeline stage
        shuffle: shard-level shuffling
        partition: whether to partition shards in terms of rank
        tts_file: json of utt -> tts texts (inference only)
        prompt_utt2data: json mapping utt -> shard (inference only)
    """
    assert mode in ['train', 'inference']
    lists = read_lists(data_list_file)
    if mode == 'inference':
        with open(tts_file) as f:
            tts_data = json.load(f)
        utt2lists = read_json_lists(prompt_utt2data)
        # keep only shards that actually contain requested utterances
        lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
    dataset = DataList(lists, shuffle=shuffle, partition=partition)
    if mode == 'inference':
        # bind tts_data into the opener stage so it can expand rows per text
        data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
    if gan is True:
        # bind the gan flag into the padding stage
        data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
    for func in data_pipeline:
        dataset = Processor(dataset, func, mode=mode)
    return dataset
|
model/cosyvoice/dataset/processor.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import logging
|
| 15 |
+
import random
|
| 16 |
+
|
| 17 |
+
import pyarrow.parquet as pq
|
| 18 |
+
from io import BytesIO
|
| 19 |
+
import torch
|
| 20 |
+
import torchaudio
|
| 21 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
import pyworld as pw
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def parquet_opener(data, mode='train', tts_data=None):
    """Open each parquet shard referenced by the stream and yield row samples.

    Args:
        data: Iterable[{src: url_or_path, ...}] of shard descriptors
        mode: 'train' yields one sample per row; 'inference' expands each
            row into one sample per tts text
        tts_data: dict utt -> list of tts texts (inference only)

    Returns:
        Iterable[dict]: one fresh dict per sample
    """
    # BUG FIX: the original used a mutable default argument (tts_data={});
    # default {} is shared across calls, a classic Python hazard.
    if tts_data is None:
        tts_data = {}
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            for df in pq.ParquetFile(url).iter_batches(batch_size=64):
                df = df.to_pandas()
                for i in range(len(df)):
                    if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
                        continue
                    sample.update(dict(df.loc[i]))
                    if mode == 'train':
                        # NOTE do not return sample directly, must initialize a new dict
                        yield {**sample}
                    else:
                        for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
                            yield {**sample, 'tts_index': index, 'tts_text': text}
        except Exception as ex:
            # best-effort: a corrupt/unreachable shard is logged and skipped
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """Decode audio and drop samples with out-of-range feature/label lengths.

    Args:
        data: Iterable[{audio_data, text_token, speech_token, ...}]
        max_length: drop utterances longer than this (in 10 ms frames)
        min_length: drop utterances shorter than this (in 10 ms frames)
        token_max_length / token_min_length: bounds on text token count
        min_output_input_ratio / max_output_input_ratio: bounds on
            token_length / feats_length (10 ms frames)

    Returns:
        Iterable[dict] with decoded mono 'speech' and 'sample_rate' attached
    """
    for sample in data:
        sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
        # mix down to mono
        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
        del sample['audio_data']
        # 100 frames per second (one frame every 10 ms)
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        n_tokens = len(sample['text_token'])
        if not (min_length <= num_frames <= max_length):
            continue
        if not (token_min_length <= n_tokens <= token_max_length):
            continue
        if len(sample['speech_token']) == 0:
            continue
        if num_frames != 0:
            ratio = n_tokens / num_frames
            if not (min_output_input_ratio <= ratio <= max_output_input_ratio):
                continue
        yield sample
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """Resample each waveform to ``resample_rate`` and peak-normalize.

    Samples whose native rate is below ``min_sample_rate`` are dropped
    rather than upsampled.

    Args:
        data: Iterable[{speech, sample_rate, ...}]
        resample_rate: target sample rate

    Returns:
        Iterable[dict] with 'speech' at the target rate
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        native_rate = sample['sample_rate']
        wav = sample['speech']
        if native_rate != resample_rate:
            # upsampling from very low rates hurts quality; drop those samples
            if native_rate < min_sample_rate:
                continue
            sample['sample_rate'] = resample_rate
            sample['speech'] = torchaudio.transforms.Resample(
                orig_freq=native_rate, new_freq=resample_rate)(wav)
        peak = sample['speech'].abs().max()
        if peak > 1:
            # normalize to [-1, 1] to avoid clipping downstream
            sample['speech'] = sample['speech'] / peak
        yield sample
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def truncate(data, truncate_length=24576, mode='train'):
    """Force every waveform to exactly ``truncate_length`` samples.

    Longer waveforms get a random crop; shorter ones are zero-padded
    at the end.

    Args:
        data: Iterable[{speech, ...}]
        truncate_length: fixed output length in samples
    """
    for sample in data:
        wav = sample['speech']
        length = wav.shape[1]
        if length > truncate_length:
            offset = random.randint(0, length - truncate_length)
            wav = wav[:, offset:offset + truncate_length]
        else:
            pad = torch.zeros(1, truncate_length - length)
            wav = torch.concat([wav, pad], dim=1)
        sample['speech'] = wav
        yield sample
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def compute_fbank(data,
                  feat_extractor,
                  mode='train'):
    """Attach a filterbank feature matrix under ``speech_feat``.

    The extractor output is squeezed on dim 0 and transposed so the
    stored matrix is time-major — assumes the extractor returns a
    batch-first (1, n_mels, T) tensor (TODO confirm against the
    configured extractor).

    Args:
        data: Iterable[{speech, sample_rate, utt, text_token, ...}]
        feat_extractor: callable mapping a waveform tensor to features
    """
    for sample in data:
        for key in ('sample_rate', 'speech', 'utt', 'text_token'):
            assert key in sample
        feats = feat_extractor(sample['speech'])
        sample['speech_feat'] = feats.squeeze(dim=0).transpose(0, 1)
        yield sample
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def compute_f0(data, sample_rate, hop_size, mode='train'):
    """Attach a pitch (f0) contour under ``pitch_feat``, time-aligned to speech_feat.

    Uses pyworld harvest with a dio fallback, refined by stonemask, then
    linearly interpolated to the fbank frame count.

    Args:
        data: Iterable[{speech, speech_feat, ...}]
        sample_rate: audio sample rate in Hz
        hop_size: feature hop size in samples
    """
    frame_period = hop_size * 1000 / sample_rate
    for sample in data:
        for key in ('sample_rate', 'speech', 'utt', 'text_token'):
            assert key in sample
        wav = sample['speech'].squeeze(dim=0).numpy().astype('double')
        _f0, t = pw.harvest(wav, sample_rate, frame_period=frame_period)
        if sum(_f0 != 0) < 5:
            # harvest found almost nothing voiced (algorithm failure); retry with dio
            _f0, t = pw.dio(wav, sample_rate, frame_period=frame_period)
        f0 = pw.stonemask(wav, _f0, t, sample_rate)
        # stretch the contour to match the number of fbank frames
        f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1),
                           size=sample['speech_feat'].shape[0], mode='linear').view(-1)
        sample['pitch_feat'] = f0
        yield sample
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def parse_embedding(data, normalize, mode='train'):
    """Convert utt/spk embeddings to float32 tensors, optionally L2-normalized.

    Args:
        data: Iterable[{utt_embedding, spk_embedding, ...}]
        normalize: apply F.normalize along dim 0 when True
    """
    for sample in data:
        for key in ('utt_embedding', 'spk_embedding'):
            emb = torch.tensor(sample[key], dtype=torch.float32)
            if normalize:
                emb = F.normalize(emb, dim=0)
            sample[key] = emb
        yield sample
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """Encode sample text into token ids with the tokenizer from ``get_tokenizer``.

    In inference mode the tts text is tokenized as well.

    Args:
        data: Iterable[{text, [tts_text], ...}]
        get_tokenizer: zero-arg factory returning a tokenizer with .encode
        allowed_special: passed through to tokenizer.encode
    """
    tokenizer = get_tokenizer()
    for sample in data:
        assert 'text' in sample
        sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
        if mode == 'inference':
            sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
        yield sample
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def shuffle(data, shuffle_size=10000, mode='train'):
    """Locally shuffle the stream using a bounded buffer.

    Args:
        data: Iterable[dict]
        shuffle_size: buffer size; each full buffer is shuffled and flushed
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            yield from buf
            buf = []
    # flush the leftover tail, shuffled as well
    random.shuffle(buf)
    yield from buf
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def sort(data, sort_size=500, mode='train'):
    """Sort buffered samples by feature length.

    Used after shuffle and before batch so that utterances of similar
    length land in the same batch; ``sort_size`` should be smaller than
    the shuffle buffer size.

    Args:
        data: Iterable[{speech_feat, ...}]
        sort_size: buffer size for sorting
    """
    feat_len = lambda s: s['speech_feat'].size(0)
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=feat_len)
            yield from buf
            buf = []
    # the leftover tail
    buf.sort(key=feat_len)
    yield from buf
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def static_batch(data, batch_size=16):
    """Group the stream into lists of ``batch_size`` samples (last may be smaller).

    Args:
        data: Iterable[dict]
        batch_size: fixed batch size

    Returns:
        Iterable[List[dict]]
    """
    pending = []
    for sample in data:
        pending.append(sample)
        if len(pending) >= batch_size:
            yield pending
            pending = []
    if pending:
        yield pending
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """Batch samples until the padded frame total would exceed the budget.

    The padded cost of a batch is (longest sample frames) * (batch size).

    BUG FIX: the original yielded an empty list when the very first sample
    alone exceeded ``max_frames_in_batch`` (buf was flushed before anything
    was appended); a batch is now only flushed when non-empty, so oversize
    samples form their own singleton batch instead.

    Args:
        data: Iterable[{speech_feat, ...}]
        max_frames_in_batch: max padded frames in one batch

    Returns:
        Iterable[List[dict]] — every yielded batch is non-empty
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'speech_feat' in sample
        assert isinstance(sample['speech_feat'], torch.Tensor)
        new_sample_frames = sample['speech_feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch and len(buf) > 0:
            # adding this sample would blow the budget; flush the current batch
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """Wrapper dispatching to static or dynamic batching.

    Inference always uses batch size 1.

    Raises:
        ValueError: for an unsupported ``batch_type``. (The original logged
        via logging.fatal and implicitly returned None, which made callers
        fail later with a confusing TypeError when iterating None.)
    """
    if mode == 'inference':
        return static_batch(data, 1)
    if batch_type == 'static':
        return static_batch(data, batch_size)
    if batch_type == 'dynamic':
        return dynamic_batch(data, max_frames_in_batch)
    logging.fatal('Unsupported batch type {}'.format(batch_type))
    raise ValueError('Unsupported batch type {}'.format(batch_type))
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def padding(data, use_spk_embedding, mode='train', gan=False):
    """Collate a list of samples into one padded training batch.

    Samples are sorted by descending feature length before padding.

    Args:
        data: Iterable[List[sample_dict]]
        use_spk_embedding: use spk_embedding (True) or utt_embedding (False)
            as the conditioning "embedding" entry
        mode: 'train' or 'inference' (inference adds tts text fields)
        gan: keep waveforms and add pitch features for GAN training

    Returns:
        Iterable[dict]: one collated batch dict per input list
    """
    for sample in data:
        assert isinstance(sample, list)
        # BUG FIX: sort key was speech_feat.size(1) — the constant feature
        # dimension — which made the descending length sort a no-op. The
        # time axis is size(0), consistent with every length tensor below.
        speech_feat_len = torch.tensor([x['speech_feat'].size(0) for x in sample],
                                       dtype=torch.int32)
        order = torch.argsort(speech_feat_len, descending=True)

        utts = [sample[i]['utt'] for i in order]
        speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
        speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
        speech = pad_sequence(speech, batch_first=True, padding_value=0)
        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
        speech_token = pad_sequence(speech_token,
                                    batch_first=True,
                                    padding_value=0)
        speech_feat = [sample[i]['speech_feat'] for i in order]
        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
        speech_feat = pad_sequence(speech_feat,
                                   batch_first=True,
                                   padding_value=0)
        text = [sample[i]['text'] for i in order]
        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
        batch = {
            "utts": utts,
            "speech": speech,
            "speech_len": speech_len,
            "speech_token": speech_token,
            "speech_token_len": speech_token_len,
            "speech_feat": speech_feat,
            "speech_feat_len": speech_feat_len,
            "text": text,
            "text_token": text_token,
            "text_token_len": text_token_len,
            "utt_embedding": utt_embedding,
            "spk_embedding": spk_embedding,
        }
        if gan is True:
            # in gan train, we additionally need the pitch contour
            pitch_feat = [sample[i]['pitch_feat'] for i in order]
            pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
            pitch_feat = pad_sequence(pitch_feat,
                                      batch_first=True,
                                      padding_value=0)
            batch["pitch_feat"] = pitch_feat
            batch["pitch_feat_len"] = pitch_feat_len
        else:
            # only gan train needs raw speech, delete it to save memory
            del batch["speech"]
            del batch["speech_len"]
        if mode == 'inference':
            tts_text = [sample[i]['tts_text'] for i in order]
            tts_index = [sample[i]['tts_index'] for i in order]
            tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
            tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
            tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
            batch.update({'tts_text': tts_text,
                          'tts_index': tts_index,
                          'tts_text_token': tts_text_token,
                          'tts_text_token_len': tts_text_token_len})
        # select the conditioning embedding used downstream
        if use_spk_embedding is True:
            batch["embedding"] = batch["spk_embedding"]
        else:
            batch["embedding"] = batch["utt_embedding"]
        yield batch
|
model/cosyvoice/flow/__pycache__/decoder.cpython-310.pyc
ADDED
|
Binary file (22.6 kB). View file
|
|
|
model/cosyvoice/flow/__pycache__/decoder.cpython-311.pyc
ADDED
|
Binary file (48.9 kB). View file
|
|
|
model/cosyvoice/flow/__pycache__/flow.cpython-310.pyc
ADDED
|
Binary file (7.68 kB). View file
|
|
|
model/cosyvoice/flow/__pycache__/flow.cpython-311.pyc
ADDED
|
Binary file (17.1 kB). View file
|
|
|
model/cosyvoice/flow/__pycache__/flow_matching.cpython-310.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
model/cosyvoice/flow/__pycache__/flow_matching.cpython-311.pyc
ADDED
|
Binary file (22.8 kB). View file
|
|
|
model/cosyvoice/flow/decoder.py
ADDED
|
@@ -0,0 +1,902 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import Tuple, Optional, Dict, Any
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from einops import pack, rearrange, repeat
|
| 19 |
+
from diffusers.models.attention_processor import Attention, AttnProcessor2_0, inspect, logger, deprecate
|
| 20 |
+
from cosyvoice.utils.common import mask_to_bias
|
| 21 |
+
from cosyvoice.utils.mask import add_optional_chunk_mask
|
| 22 |
+
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
|
| 23 |
+
from matcha.models.components.transformer import BasicTransformerBlock, maybe_allow_in_graph
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Transpose(torch.nn.Module):
    """Swap two tensor dimensions; lets ``torch.transpose`` live inside ``nn.Sequential``."""

    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        # Dimensions exchanged on every forward call.
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with ``self.dim0`` and ``self.dim1`` swapped.

        Fix: the original annotated the return as ``Tuple[torch.Tensor]`` but a
        single tensor is returned; the annotation now matches the behavior.
        """
        return torch.transpose(x, self.dim0, self.dim1)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class CausalConv1d(torch.nn.Conv1d):
    """1-D convolution with left-only (causal) padding and a streaming cache.

    The layer never looks ahead: ``kernel_size - 1`` frames of left context are
    either zero-padded (first chunk) or taken from the cache returned by the
    previous call, so chunked inference matches full-sequence inference.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        super(CausalConv1d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=0,  # padding handled manually in forward to stay causal
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        # Cache bookkeeping assumes stride 1 (one output frame per input frame).
        assert stride == 1
        self.causal_padding = kernel_size - 1

    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convolve ``x`` causally; return ``(output, new_cache)``.

        An empty cache (size-0 time axis) means "first chunk": pad with zeros on
        the left. Otherwise the cache must hold exactly ``causal_padding`` frames
        of left context from the previous chunk.
        """
        if cache.size(2) != 0:
            assert cache.size(2) == self.causal_padding
            padded = torch.concat([cache, x], dim=2)
        else:
            padded = F.pad(x, (self.causal_padding, 0), value=0.0)
        # Last causal_padding frames become the left context for the next chunk.
        new_cache = padded[:, :, -self.causal_padding:]
        return super(CausalConv1d, self).forward(padded), new_cache
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class CausalBlock1D(Block1D):
    """``Block1D`` whose convolution is replaced by a cache-aware ``CausalConv1d``.

    The sequential stack is: causal conv -> transpose -> LayerNorm (over
    channels) -> transpose back -> Mish. Only the first layer consumes/produces
    the streaming cache.
    """

    def __init__(self, dim: int, dim_out: int):
        super(CausalBlock1D, self).__init__(dim, dim_out)
        # Overwrite the parent's block with a causal, LayerNorm-based stack.
        self.block = torch.nn.Sequential(
            CausalConv1d(dim, dim_out, 3),
            Transpose(1, 2),
            nn.LayerNorm(dim_out),
            Transpose(1, 2),
            nn.Mish(),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the stack to ``x * mask``; return ``(masked_output, conv_cache)``."""
        out, cache = self.block[0](x * mask, cache)
        for layer in self.block[1:]:
            out = layer(out)
        return out * mask, cache
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class CausalResnetBlock1D(ResnetBlock1D):
    """``ResnetBlock1D`` with both sub-blocks swapped for causal, cache-aware ones."""

    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
        # Replace the parent's non-causal blocks; mlp/res_conv are inherited.
        self.block1 = CausalBlock1D(dim, dim_out)
        self.block2 = CausalBlock1D(dim_out, dim_out)

    def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor,
                block1_cache: torch.Tensor = torch.zeros(0, 0, 0), block2_cache: torch.Tensor = torch.zeros(0, 0, 0)
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Residual block: block1 -> add time embedding -> block2 -> skip connection.

        Returns the output plus the updated per-block conv caches for streaming.
        """
        hidden, block1_cache = self.block1(x, mask, block1_cache)
        # Inject the (projected) timestep embedding, broadcast over time.
        hidden = hidden + self.mlp(time_emb).unsqueeze(-1)
        hidden, block2_cache = self.block2(hidden, mask, block2_cache)
        return hidden + self.res_conv(x * mask), block1_cache, block2_cache
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class CausalAttnProcessor2_0(AttnProcessor2_0):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

    Variant of the stock diffusers processor that threads a key/value ``cache``
    through the call so self-attention can run incrementally: cached key/value
    projections from earlier chunks are prepended before attention, and the
    projections of the current chunk are returned as the next cache.
    """

    def __init__(self):
        super(CausalAttnProcessor2_0, self).__init__()

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        *args,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, torch.Tensor]:
        # Returns (hidden_states, cache); `cache` stacks the current key/value
        # projections along a trailing dim of size 2 (see torch.stack below).
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. \
                `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten a (B, C, H, W) input to (B, H*W, C) for attention.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            # NOTE do not use attn.prepare_attention_mask as we have already provided the correct attention_mask
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, attn.heads, 1, 1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            # Self-attention: keys/values come from the same sequence.
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key_cache = attn.to_k(encoder_hidden_states)
        value_cache = attn.to_v(encoder_hidden_states)
        # NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2)
        if cache.size(0) != 0:
            # Prepend cached keys/values so attention covers past + current chunk.
            key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
            value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
        else:
            key, value = key_cache, value_cache
        # New cache holds only the CURRENT chunk's projections; the caller is
        # responsible for accumulating across chunks if longer context is needed.
        cache = torch.stack([key_cache, value_cache], dim=3)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            # Restore the original (B, C, H, W) layout.
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states, cache
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
@maybe_allow_in_graph
class CausalAttention(Attention):
    """diffusers ``Attention`` hard-wired to ``CausalAttnProcessor2_0`` so every
    call accepts and returns a key/value cache for streaming inference.

    NOTE(review): the parent constructor is invoked with positional arguments,
    so this class depends on the exact parameter order of the installed
    diffusers version — verify when upgrading diffusers.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block: bool = False,
        processor: Optional["AttnProcessor2_0"] = None,
        out_dim: int = None,
    ):
        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax,
                                              cross_attention_norm, cross_attention_norm_num_groups, qk_norm, added_kv_proj_dim, norm_num_groups,
                                              spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection,
                                              _from_deprecated_attn_block, processor, out_dim)
        # Unconditionally replace whatever processor the parent selected with
        # the cache-aware causal processor.
        processor = CausalAttnProcessor2_0()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        **cross_attention_kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        The forward method of the `Attention` class.

        Args:
            hidden_states (`torch.Tensor`):
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):
                The attention mask to use. If `None`, no mask is applied.
            cache (`torch.Tensor`):
                Key/value cache forwarded to the processor; size-0 first dim
                means "no cached context".
            **cross_attention_kwargs:
                Additional keyword arguments to pass along to the cross attention.

        Returns:
            `torch.Tensor`: The output of the attention layer, plus the updated cache.
        """
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty

        # Filter kwargs down to what the processor's signature accepts, warning
        # about (rather than crashing on) anything unexpected.
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}

        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cache=cache,
            **cross_attention_kwargs,
        )
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
@maybe_allow_in_graph
class CausalBasicTransformerBlock(BasicTransformerBlock):
    """matcha ``BasicTransformerBlock`` whose self-attention (``attn1``) is
    replaced with cache-aware ``CausalAttention`` so the block can decode in
    streaming chunks. ``forward`` returns ``(hidden_states, cache)``.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout,
                                                          cross_attention_dim, activation_fn, num_embeds_ada_norm,
                                                          attention_bias, only_cross_attention, double_self_attention,
                                                          upcast_attention, norm_elementwise_affine, norm_type, final_dropout)
        # Swap the parent's attn1 for the causal, cache-threading variant; all
        # other sub-modules (norms, attn2, feed-forward) are inherited as-is.
        self.attn1 = CausalAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        # attn1 is CausalAttention, so it also returns the updated cache.
        attn_output, cache = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
            cache=cache,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention (attn2 is the inherited, non-causal attention).
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: \
                    {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.")

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states, cache
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
class ConditionalDecoder(nn.Module):
    """1-D U-Net decoder conditioned on a timestep embedding (and optionally
    speaker/cond features): down blocks -> mid blocks -> up blocks with skip
    connections, each stage pairing a ResNet block with transformer blocks.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
    ):
        """
        This decoder requires an input with the same shape of the target. So, if your text content
        is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # Encoder path: each stage is (resnet, transformer blocks, downsample).
        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            # The last stage keeps the time resolution (plain conv) instead of
            # downsampling.
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        # Bottleneck: num_mid_blocks stages at the deepest channel count.
        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            # NOTE(review): this assigns to the local `out_channels` (the ctor
            # argument), not `output_channel`; harmless because
            # self.out_channels was captured above and output_channel already
            # equals channels[-1] here, but it looks like a typo — confirm.
            out_channels = channels[-1]
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        # Decoder path mirrors the encoder; input channels are doubled by the
        # skip-connection concatenation.
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def initialize_weights(self):
        # Kaiming init for conv/linear weights, unit-scale for GroupNorm.
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (_type_): shape (batch_size, 1, time)
            mu (_type_): conditioning features concatenated onto x along channels
            t (_type_): shape (batch_size)
            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (_type_, optional): placeholder for future use. Defaults to None.
            streaming: unused in this (non-causal) decoder; kept for interface
                parity with CausalConditionalDecoder.forward.

        Raises:
            ValueError: _description_
            ValueError: _description_

        Returns:
            _type_: masked output of shape (batch_size, out_channels, time)
        """

        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)

        # Concatenate x with the conditioning along the channel axis.
        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            # Broadcast the speaker embedding over time before concatenating.
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # Full (non-chunked) attention mask, expanded to every query position.
            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            # Subsample the mask by 2 to track the downsampled time axis
            # (assumes Downsample1D halves T — confirm against Downsample1D).
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            # Trim x to the skip's length (upsampling can overshoot by a frame),
            # then concatenate along channels.
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
class CausalConditionalDecoder(ConditionalDecoder):
    """Causal (streaming-capable) variant of ``ConditionalDecoder``: every
    resnet/transformer/conv sub-module is replaced by its Causal* counterpart,
    and attention can be restricted to chunked context when streaming.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        static_chunk_size=50,
        num_decoding_left_chunks=2,
    ):
        """
        This decoder requires an input with the same shape of the target. So, if your text content
        is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
        """
        # Deliberately skip ConditionalDecoder.__init__ (which would build
        # non-causal sub-modules) and initialize nn.Module directly.
        torch.nn.Module.__init__(self)
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        # Chunked-attention hyper-parameters used by forward when streaming.
        self.static_chunk_size = static_chunk_size
        self.num_decoding_left_chunks = num_decoding_left_chunks
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # Encoder path with causal blocks; mirrors ConditionalDecoder.__init__.
        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    CausalBasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            # NOTE(review): assigns the ctor argument `out_channels`, not
            # `output_channel` — harmless here (self.out_channels captured
            # above), but mirrors the same apparent typo in ConditionalDecoder.
            out_channels = channels[-1]
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    CausalBasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            # Doubled input channels account for skip-connection concatenation.
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = CausalResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    CausalBasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else CausalConv1d(output_channel, output_channel, 3)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = CausalBlock1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()
|
| 709 |
+
|
| 710 |
+
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
    """Forward pass of the causal UNet1D conditional decoder.

    Runs a down / mid / up UNet over the channel-packed input and returns the
    estimated output masked to the valid frames.

    Args:
        x (torch.Tensor): shape (batch_size, in_channels, time).
        mask (torch.Tensor): shape (batch_size, 1, time); 1 on valid frames.
        mu (torch.Tensor): conditioning features packed onto ``x`` along the
            channel axis (same time length as ``x``).
        t (torch.Tensor): shape (batch_size); diffusion timestep fed to the
            time embedding MLP.
        spks (torch.Tensor, optional): shape (batch_size, condition_channels);
            broadcast over time and packed onto ``x``. Defaults to None.
        cond (torch.Tensor, optional): extra conditioning packed onto ``x``
            along channels. Defaults to None.
        streaming (bool): when True, attention uses chunked causal masks
            (static_chunk_size / num_decoding_left_chunks); when False, a
            full (non-chunked) mask is used.

    Returns:
        torch.Tensor: shape (batch_size, out_channels, time), zeroed outside
        ``mask``.
    """

    # Timestep embedding; .to(t.dtype) keeps the embedding in t's dtype.
    t = self.time_embeddings(t).to(t.dtype)
    t = self.time_mlp(t)

    # Channel-concatenate the conditioning stream onto the input.
    x = pack([x, mu], "b * t")[0]

    if spks is not None:
        # Broadcast the per-utterance speaker vector across time, then pack.
        spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
        x = pack([x, spks], "b * t")[0]
    if cond is not None:
        x = pack([x, cond], "b * t")[0]

    hiddens = []
    masks = [mask]
    for resnet, transformer_blocks, downsample in self.down_blocks:
        mask_down = masks[-1]
        x, _, _ = resnet(x, mask_down, t)
        x = rearrange(x, "b c t -> b t c").contiguous()
        if streaming is True:
            # Chunked causal attention for streaming synthesis.
            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
        else:
            # Full attention; repeat expands the (b, 1, t) mask to (b, t, t).
            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
        attn_mask = mask_to_bias(attn_mask, x.dtype)
        for transformer_block in transformer_blocks:
            x, _ = transformer_block(
                hidden_states=x,
                attention_mask=attn_mask,
                timestep=t,
            )
        x = rearrange(x, "b t c -> b c t").contiguous()
        hiddens.append(x)  # Save hidden states for skip connections
        x, _ = downsample(x * mask_down)
        # Downsampling halves the time axis; subsample the mask to match.
        masks.append(mask_down[:, :, ::2])
    # Drop the deepest mask: the mid blocks run at the last pre-downsample rate.
    masks = masks[:-1]
    mask_mid = masks[-1]

    for resnet, transformer_blocks in self.mid_blocks:
        x, _, _ = resnet(x, mask_mid, t)
        x = rearrange(x, "b c t -> b t c").contiguous()
        if streaming is True:
            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
        else:
            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
        attn_mask = mask_to_bias(attn_mask, x.dtype)
        for transformer_block in transformer_blocks:
            x, _ = transformer_block(
                hidden_states=x,
                attention_mask=attn_mask,
                timestep=t,
            )
        x = rearrange(x, "b t c -> b c t").contiguous()

    for resnet, transformer_blocks, upsample in self.up_blocks:
        mask_up = masks.pop()
        skip = hiddens.pop()
        # Trim x to the skip length (upsampling may overshoot by a frame),
        # then channel-concatenate the UNet skip connection.
        x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
        x, _, _ = resnet(x, mask_up, t)
        x = rearrange(x, "b c t -> b t c").contiguous()
        if streaming is True:
            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
        else:
            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
        attn_mask = mask_to_bias(attn_mask, x.dtype)
        for transformer_block in transformer_blocks:
            x, _ = transformer_block(
                hidden_states=x,
                attention_mask=attn_mask,
                timestep=t,
            )
        x = rearrange(x, "b t c -> b c t").contiguous()
        x, _ = upsample(x * mask_up)
    x, _ = self.final_block(x, mask_up)
    output = self.final_proj(x * mask_up)
    return output * mask
|
| 801 |
+
|
| 802 |
+
def forward_chunk(self, x, mask, mu, t, spks=None, cond=None,
                  down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
                  down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
                  mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
                  mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
                  up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
                  up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
                  final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0)
                  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Chunk-wise (streaming) forward pass with explicit conv / KV caches.

    Same network as :meth:`forward`, but each block consumes and re-emits
    convolution caches and attention key/value caches so successive chunks
    can be processed incrementally.

    NOTE(review): the *_cache defaults are zero-size tensors created at
    function-definition time; the body indexes into them
    (``down_blocks_conv_cache[index]``) and writes slices in place, so
    callers are expected to pass properly pre-shaped caches — the defaults
    would raise on the first index. Confirm against the caller.

    Args:
        x (torch.Tensor): shape (batch_size, in_channels, time).
        mask (torch.Tensor): shape (batch_size, 1, time).
        mu (torch.Tensor): conditioning features packed onto ``x``.
        t (torch.Tensor): shape (batch_size); diffusion timestep.
        spks (torch.Tensor, optional): shape (batch_size, condition_channels).
        cond (torch.Tensor, optional): extra conditioning packed onto ``x``.
        *_conv_cache / *_kv_cache: per-stage streaming caches (mutated in
            place for the conv caches; KV caches are re-allocated fresh).

    Returns:
        Tuple of (output, down_blocks_conv_cache, down_blocks_kv_cache_new,
        mid_blocks_conv_cache, mid_blocks_kv_cache_new, up_blocks_conv_cache,
        up_blocks_kv_cache_new, final_blocks_conv_cache).
    """

    t = self.time_embeddings(t).to(t.dtype)
    t = self.time_mlp(t)

    x = pack([x, mu], "b * t")[0]

    if spks is not None:
        spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
        x = pack([x, spks], "b * t")[0]
    if cond is not None:
        x = pack([x, cond], "b * t")[0]

    hiddens = []
    masks = [mask]

    # Fresh KV caches for this chunk. NOTE(review): shapes hard-code
    # 1 down block / 12 mid blocks / 1 up block, 4 transformer blocks each,
    # head-split dim 512 — presumably tied to the configured decoder_params;
    # confirm if the architecture is changed.
    down_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
    mid_blocks_kv_cache_new = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x.device)
    up_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
    for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks):
        mask_down = masks[-1]
        # Conv cache layout per down block: [:320] and [320:576] are the two
        # resnet conv caches, [576:] is the downsample conv cache.
        # NOTE(review): these split points assume a fixed channel config.
        x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = \
            resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576])
        x = rearrange(x, "b c t -> b t c").contiguous()
        # Attend over the current chunk plus all cached past frames.
        attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + down_blocks_kv_cache.size(3), device=x.device).bool()
        attn_mask = mask_to_bias(attn_mask, x.dtype)
        for i, transformer_block in enumerate(transformer_blocks):
            x, down_blocks_kv_cache_new[index, i] = transformer_block(
                hidden_states=x,
                attention_mask=attn_mask,
                timestep=t,
                cache=down_blocks_kv_cache[index, i],
            )
        x = rearrange(x, "b t c -> b c t").contiguous()
        hiddens.append(x)  # Save hidden states for skip connections
        x, down_blocks_conv_cache[index][:, 576:] = downsample(x * mask_down, down_blocks_conv_cache[index][:, 576:])
        masks.append(mask_down[:, :, ::2])
    masks = masks[:-1]
    mask_mid = masks[-1]

    for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks):
        # Mid conv cache split: [:256] / [256:] for the two resnet convs.
        x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = \
            resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:])
        x = rearrange(x, "b c t -> b t c").contiguous()
        attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + mid_blocks_kv_cache.size(3), device=x.device).bool()
        attn_mask = mask_to_bias(attn_mask, x.dtype)
        for i, transformer_block in enumerate(transformer_blocks):
            x, mid_blocks_kv_cache_new[index, i] = transformer_block(
                hidden_states=x,
                attention_mask=attn_mask,
                timestep=t,
                cache=mid_blocks_kv_cache[index, i]
            )
        x = rearrange(x, "b t c -> b c t").contiguous()

    for index, (resnet, transformer_blocks, upsample) in enumerate(self.up_blocks):
        mask_up = masks.pop()
        skip = hiddens.pop()
        x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
        # Up conv cache split: [:512] / [512:768] resnet convs, [768:] upsample.
        x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = \
            resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768])
        x = rearrange(x, "b c t -> b t c").contiguous()
        attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + up_blocks_kv_cache.size(3), device=x.device).bool()
        attn_mask = mask_to_bias(attn_mask, x.dtype)
        for i, transformer_block in enumerate(transformer_blocks):
            x, up_blocks_kv_cache_new[index, i] = transformer_block(
                hidden_states=x,
                attention_mask=attn_mask,
                timestep=t,
                cache=up_blocks_kv_cache[index, i]
            )
        x = rearrange(x, "b t c -> b c t").contiguous()
        x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:])
    x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache)
    output = self.final_proj(x * mask_up)
    return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, \
        up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache
|
model/cosyvoice/flow/flow.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import logging
|
| 15 |
+
import random
|
| 16 |
+
from typing import Dict, Optional
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from torch.nn import functional as F
|
| 20 |
+
from omegaconf import DictConfig
|
| 21 |
+
from cosyvoice.utils.mask import make_pad_mask
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MaskedDiffWithXvec(torch.nn.Module):
    # Token-to-mel flow-matching model conditioned on an x-vector speaker
    # embedding: speech tokens are embedded, text-encoded, length-regulated to
    # the mel frame rate, and the flow-matching decoder produces mel features.
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 encoder: torch.nn.Module = None,
                 length_regulator: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
        # NOTE(review): decoder_conf / mel_feat_conf defaults are mutable dicts
        # shared across instances; they are only stored here, not mutated.
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        # Projects the x-vector down to the mel channel count for conditioning.
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.length_regulator = length_regulator
        self.only_mask_loss = only_mask_loss

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Compute the flow-matching training loss for one batch.

        Expects batch keys: speech_token, speech_token_len, speech_feat,
        speech_feat_len, embedding. Returns ``{'loss': loss}``.
        """
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, feat_len)

        # get conditions: with probability 0.5 per utterance, expose a random
        # prefix (up to 30% of the frames) of the target mel as the condition.
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(feat_len)).to(h)
        # NOTE this is unnecessary, feat/h already same shape
        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  flow_cache):
        """Generate mel features for ``token`` given a prompt and speaker x-vector.

        Returns (mel, flow_cache): mel of shape (1, n_mels, mel_len2) covering
        only the non-prompt part, plus the updated decoder flow cache.
        """
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat speech token and prompt speech token
        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        # Target mel length: tokens -> seconds (input_frame_rate) -> mel frames
        # (22050 Hz sample rate / 256 hop size).
        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)

        # get conditions: the prompt mel is the condition; the rest is zeros.
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, flow_cache = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            prompt_len=mel_len1,
            cache=flow_cache
        )
        # Strip the prompt frames; return only the newly generated part.
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), flow_cache
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class CausalMaskedDiffWithXvec(torch.nn.Module):
    # Causal (streaming-capable) variant of MaskedDiffWithXvec: the encoder
    # upsamples tokens to mel rate itself (token_mel_ratio) and inference runs
    # chunk-by-chunk with explicit encoder / decoder caches.
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 token_mel_ratio: int = 2,
                 pre_lookahead_len: int = 3,
                 encoder: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
        # NOTE(review): decoder_conf / mel_feat_conf defaults are mutable dicts
        # shared across instances; they are only stored here, not mutated.
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        # Projects the x-vector down to the mel channel count for conditioning.
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.only_mask_loss = only_mask_loss
        # Number of mel frames emitted per speech token by the encoder upsampler.
        self.token_mel_ratio = token_mel_ratio
        # Tokens held back per chunk as right-context for the lookahead layer.
        self.pre_lookahead_len = pre_lookahead_len

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Compute the flow-matching training loss for one batch.

        Expects batch keys: speech_token, speech_token_len, speech_feat,
        speech_feat_len, embedding. Returns ``{'loss': loss}``.
        """
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # NOTE unified training, static_chunk_size > 0 or = 0
        streaming = True if random.random() < 0.5 else False

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len, streaming=streaming)
        h = self.encoder_proj(h)

        # get conditions: align the target mel to the encoder output length,
        # then (with prob 0.5 per utterance) expose a random prefix as condition.
        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds,
            streaming=streaming,
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  cache,
                  finalize):
        """Chunk-wise mel generation.

        ``cache`` is a dict with 'encoder_cache' (kwargs for the encoder's
        forward_chunk) and 'decoder_cache'; both are updated and returned.
        When ``finalize`` is False the last pre_lookahead_len tokens are split
        off as right-context instead of being synthesized.

        Returns (mel, cache): mel of shape (1, n_mels, mel_len2) covering only
        the non-prompt part.
        """
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        if finalize is True:
            h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, **cache['encoder_cache'])
        else:
            # Hold back lookahead tokens as context; they are synthesized later.
            token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
            h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, context=context, **cache['encoder_cache'])
        # Persist the encoder streaming state for the next chunk.
        cache['encoder_cache']['offset'] = encoder_cache[0]
        cache['encoder_cache']['pre_lookahead_layer_conv2_cache'] = encoder_cache[1]
        cache['encoder_cache']['encoders_kv_cache'] = encoder_cache[2]
        cache['encoder_cache']['upsample_offset'] = encoder_cache[3]
        cache['encoder_cache']['upsample_conv_cache'] = encoder_cache[4]
        cache['encoder_cache']['upsample_kv_cache'] = encoder_cache[5]
        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
        h = self.encoder_proj(h)

        # get conditions: the prompt mel is the condition; the rest is zeros.
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, cache['decoder_cache'] = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            cache=cache['decoder_cache']
        )
        # Strip the prompt frames; return only the newly generated part.
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), cache
|
model/cosyvoice/flow/flow_matching.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import threading
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from matcha.models.components.flow_matching import BASECFM
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ConditionalCFM(BASECFM):
|
| 21 |
+
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
    """Conditional flow-matching model driven by an injected estimator.

    Args:
        in_channels: feature channel count, forwarded to BASECFM as n_feats.
        cfm_params: config object providing sigma_min / solver fields for
            BASECFM plus t_scheduler, training_cfg_rate, inference_cfg_rate.
        n_spks: number of speakers, forwarded to BASECFM.
        spk_emb_dim: speaker embedding dim, forwarded to BASECFM.
        estimator: the vector-field network; its architecture is decided by
            the caller — this class only drives the ODE solve around it.
    """
    super().__init__(
        n_feats=in_channels,
        cfm_params=cfm_params,
        n_spks=n_spks,
        spk_emb_dim=spk_emb_dim,
    )
    self.t_scheduler = cfm_params.t_scheduler
    self.training_cfg_rate = cfm_params.training_cfg_rate
    self.inference_cfg_rate = cfm_params.inference_cfg_rate
    # Removed a dead local reassignment of `in_channels`
    # (`in_channels + spk_emb_dim if n_spks > 0`): the estimator is injected
    # above rather than built here, so the value was never used.
    # Just change the architecture of the estimator here
    self.estimator = estimator
    # Serializes access to the shared estimator (e.g. a single TRT context)
    # across concurrent inference calls.
    self.lock = threading.Lock()
|
| 35 |
+
|
| 36 |
+
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=None):
    """Forward diffusion: sample noise and integrate the CFM ODE.

    Args:
        mu (torch.Tensor): output of encoder,
            shape (batch_size, n_feats, mel_timesteps).
            NOTE: its first ``cache`` frames are overwritten in place when a
            non-empty cache is supplied.
        mask (torch.Tensor): output mask, shape (batch_size, 1, mel_timesteps).
        n_timesteps (int): number of diffusion steps.
        temperature (float, optional): scaling applied to the initial noise.
            Defaults to 1.0.
        spks (torch.Tensor, optional): speaker embedding,
            shape (batch_size, spk_emb_dim). Defaults to None.
        cond: conditioning features forwarded to the solver.
        prompt_len (int): number of leading prompt frames pinned in the cache.
        cache (torch.Tensor, optional): streaming cache of shape
            (batch, n_feats, cached_frames, 2) stacking (z, mu) for the prompt
            and the 34-frame overlap of the previous chunk. None (the default)
            means "no cached frames" (first chunk).

    Returns:
        tuple: (generated mel-spectrogram of shape
        (batch_size, n_feats, mel_timesteps), updated cache tensor).
    """
    if cache is None:
        # Empty cache sentinel. This replaces a mutable `torch.zeros(1, 80, 0, 2)`
        # default argument (evaluated once at import time, always on CPU);
        # behavior is identical since only its zero-size time axis was used.
        cache = torch.zeros(1, 80, 0, 2, device=mu.device, dtype=mu.dtype)

    z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
    cache_size = cache.shape[2]
    # fix prompt and overlap part mu and z: reuse the cached noise and encoder
    # output so consecutive chunks stay consistent across the overlap.
    if cache_size != 0:
        z[:, :, :cache_size] = cache[:, :, :, 0]
        mu[:, :, :cache_size] = cache[:, :, :, 1]
    # Cache the prompt frames plus the last 34 frames for the next chunk.
    z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
    mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
    cache = torch.stack([z_cache, mu_cache], dim=-1)

    t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
    if self.t_scheduler == 'cosine':
        # Cosine schedule: denser steps near t=0.
        t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
    return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
|
| 70 |
+
|
| 71 |
+
    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: conditioning features, passed through to the estimator
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(dim=0)

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        # Do not use concat, it may cause memory format changed and trt infer with wrong results!
        # Pre-allocated batch-of-2 buffers: row 0 carries the conditional pass and
        # row 1 the unconditional (CFG) pass; filling them in place keeps the
        # memory layout stable across steps.
        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
        spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        for step in range(1, len(t_span)):
            # Classifier-Free Guidance inference introduced in VoiceBox
            x_in[:] = x
            mask_in[:] = mask
            mu_in[0] = mu  # row 1 stays zero -> unconditional branch
            t_in[:] = t.unsqueeze(0)
            spks_in[0] = spks
            cond_in[0] = cond
            dphi_dt = self.forward_estimator(
                x_in, mask_in,
                mu_in, t_in,
                spks_in,
                cond_in
            )
            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
            # CFG blend: extrapolate the conditional prediction away from the unconditional one
            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1].float()
|
| 123 |
+
|
| 124 |
+
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
| 125 |
+
if isinstance(self.estimator, torch.nn.Module):
|
| 126 |
+
return self.estimator(x, mask, mu, t, spks, cond)
|
| 127 |
+
else:
|
| 128 |
+
with self.lock:
|
| 129 |
+
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
| 130 |
+
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
| 131 |
+
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
| 132 |
+
self.estimator.set_input_shape('t', (2,))
|
| 133 |
+
self.estimator.set_input_shape('spks', (2, 80))
|
| 134 |
+
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
| 135 |
+
# run trt engine
|
| 136 |
+
assert self.estimator.execute_v2([x.contiguous().data_ptr(),
|
| 137 |
+
mask.contiguous().data_ptr(),
|
| 138 |
+
mu.contiguous().data_ptr(),
|
| 139 |
+
t.contiguous().data_ptr(),
|
| 140 |
+
spks.contiguous().data_ptr(),
|
| 141 |
+
cond.contiguous().data_ptr(),
|
| 142 |
+
x.data_ptr()]) is True
|
| 143 |
+
return x
|
| 144 |
+
|
| 145 |
+
    def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
        """Computes diffusion (conditional flow matching) loss.

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: conditioning features, dropped together with mu/spks under CFG
            streaming (bool): forwarded to the estimator

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        # NOTE(review): `t` (sequence length) is immediately shadowed by the sampled
        # timestep below; the length value is never used.
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t = 1 - torch.cos(t * 0.5 * torch.pi)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        # linear interpolation between noise and target, and its target velocity
        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        # during training, we randomly drop condition to trade off mode coverage and sample fidelity
        if self.training_cfg_rate > 0:
            cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
            mu = mu * cfg_mask.view(-1, 1, 1)
            spks = spks * cfg_mask.view(-1, 1)
            cond = cond * cfg_mask.view(-1, 1, 1)

        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
        # MSE over valid (unmasked) positions, normalized by frame count and feature dim
        loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
        return loss, y
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class CausalConditionalCFM(ConditionalCFM):
    """Chunk-streaming variant of ConditionalCFM.

    Draws noise from a fixed pre-sampled pool so that successive chunks starting
    at different offsets see consistent noise, and threads the estimator's
    per-step conv/kv caches through successive calls.
    """

    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
        super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
        # Fixed noise pool of shape (1, 80, 15000): sampled once at construction
        # so streaming chunks reuse identical noise values at the same offsets.
        self.rand_noise = torch.randn([1, 80, 50 * 300])

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, cache={}):
        """Forward diffusion for one streaming chunk.

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: conditioning features forwarded to the estimator
            cache (dict): streaming state; must contain 'offset' plus the
                estimator conv/kv cache tensors from the previous call.
                NOTE(review): mutable default argument — `cache.pop('offset')`
                would raise KeyError on the shared default dict, so callers are
                expected to always pass an initialized cache; confirm at call sites.

        Returns:
            mel: generated mel-spectrogram chunk
                shape: (batch_size, n_feats, mel_timesteps)
            cache: updated streaming state (with advanced 'offset')
        """

        offset = cache.pop('offset')
        # Slice this chunk's noise out of the fixed pool at the running offset.
        z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature
        z = z[:, :, offset:]
        offset += mu.size(2)
        # fix prompt and overlap part mu and z
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            # cosine warp of the time grid: finer integration steps near t = 0
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        mel, cache = self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, cache=cache)
        cache['offset'] = offset
        return mel, cache

    def solve_euler(self, x, t_span, mu, mask, spks, cond, cache):
        """
        Fixed euler solver for ODEs, threading per-step estimator caches.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: conditioning features, passed through to the estimator
            cache (dict): per-step conv caches (updated in place) and kv caches
                (extended along the time axis, dim 4, before returning)
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(dim=0)

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        # estimator cache for each step
        # Leading dim 10 holds one slot per solver step; inner dims mirror the
        # estimator's attention-cache layout — NOTE(review): confirm the
        # (blocks, heads, 2, T, 512, 2) interpretation against the estimator.
        down_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x)
        mid_blocks_kv_cache_new = torch.zeros(10, 12, 4, 2, x.size(2), 512, 2).to(x)
        up_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x)

        # Do not use concat, it may cause memory format changed and trt infer with wrong results!
        # Batch-of-2 buffers: row 0 conditional pass, row 1 unconditional (CFG) pass.
        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
        spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        for step in range(1, len(t_span)):
            # Classifier-Free Guidance inference introduced in VoiceBox
            x_in[:] = x
            mask_in[:] = mask
            mu_in[0] = mu
            t_in[:] = t.unsqueeze(0)
            spks_in[0] = spks
            cond_in[0] = cond
            # Select this step's slice of every cache tensor.
            cache_step = {k: v[step - 1] for k, v in cache.items()}
            dphi_dt, cache_step = self.forward_estimator(
                x_in, mask_in,
                mu_in, t_in,
                spks_in,
                cond_in,
                cache_step
            )
            # Write back per-step caches: conv caches in place, kv caches into
            # the freshly allocated buffers (concatenated after the loop).
            cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
            down_blocks_kv_cache_new[step - 1] = cache_step[1]
            cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
            mid_blocks_kv_cache_new[step - 1] = cache_step[3]
            cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
            up_blocks_kv_cache_new[step - 1] = cache_step[5]
            cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
            # CFG blend: extrapolate conditional prediction away from unconditional one.
            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
        # Grow the kv caches along the time axis (dim 4) with this chunk's entries.
        cache['down_blocks_kv_cache'] = torch.concat([cache['down_blocks_kv_cache'], down_blocks_kv_cache_new], dim=4)
        cache['mid_blocks_kv_cache'] = torch.concat([cache['mid_blocks_kv_cache'], mid_blocks_kv_cache_new], dim=4)
        cache['up_blocks_kv_cache'] = torch.concat([cache['up_blocks_kv_cache'], up_blocks_kv_cache_new], dim=4)
        return sol[-1].float(), cache

    def forward_estimator(self, x, mask, mu, t, spks, cond, cache):
        # Dispatch between the native torch estimator and a TensorRT engine;
        # both paths return (output, 7-tuple of updated cache tensors).
        if isinstance(self.estimator, torch.nn.Module):
            x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache)
            cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7)
        else:
            with self.lock:  # TRT execution context is shared across threads
                self.estimator.set_input_shape('x', (2, 80, x.size(2)))
                self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
                self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
                self.estimator.set_input_shape('t', (2,))
                self.estimator.set_input_shape('spks', (2, 80))
                self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
                self.estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape)
                self.estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape)
                self.estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape)
                self.estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape)
                self.estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape)
                self.estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape)
                self.estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape)
                # run trt engine
                # Fresh output buffers for the kv caches; conv caches are reused in place.
                down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
                mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x)
                up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
                # Binding order must match the engine's I/O order: inputs, then outputs.
                assert self.estimator.execute_v2([x.contiguous().data_ptr(),
                                                  mask.contiguous().data_ptr(),
                                                  mu.contiguous().data_ptr(),
                                                  t.contiguous().data_ptr(),
                                                  spks.contiguous().data_ptr(),
                                                  cond.contiguous().data_ptr(),
                                                  cache['down_blocks_conv_cache'].contiguous().data_ptr(),
                                                  cache['down_blocks_kv_cache'].contiguous().data_ptr(),
                                                  cache['mid_blocks_conv_cache'].contiguous().data_ptr(),
                                                  cache['mid_blocks_kv_cache'].contiguous().data_ptr(),
                                                  cache['up_blocks_conv_cache'].contiguous().data_ptr(),
                                                  cache['up_blocks_kv_cache'].contiguous().data_ptr(),
                                                  cache['final_blocks_conv_cache'].contiguous().data_ptr(),
                                                  x.data_ptr(),
                                                  cache['down_blocks_conv_cache'].data_ptr(),
                                                  down_blocks_kv_cache_out.data_ptr(),
                                                  cache['mid_blocks_conv_cache'].data_ptr(),
                                                  mid_blocks_kv_cache_out.data_ptr(),
                                                  cache['up_blocks_conv_cache'].data_ptr(),
                                                  up_blocks_kv_cache_out.data_ptr(),
                                                  cache['final_blocks_conv_cache'].data_ptr()]) is True
                cache = (cache['down_blocks_conv_cache'],
                         down_blocks_kv_cache_out,
                         cache['mid_blocks_conv_cache'],
                         mid_blocks_kv_cache_out,
                         cache['up_blocks_conv_cache'],
                         up_blocks_kv_cache_out,
                         cache['final_blocks_conv_cache'])
        return x, cache
|
model/cosyvoice/flow/length_regulator.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import Tuple
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch
|
| 17 |
+
from torch.nn import functional as F
|
| 18 |
+
from cosyvoice.utils.mask import make_pad_mask
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class InterpolateRegulator(nn.Module):
    """Length regulator: linearly interpolates token features to the target
    mel length, then smooths them with a small convolutional stack."""

    def __init__(
        self,
        channels: int,
        sampling_ratios: Tuple,
        out_channels: int = None,
        groups: int = 1,
    ):
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels
        layers = nn.ModuleList([])
        if len(sampling_ratios) > 0:
            # One (conv, groupnorm, mish) trio per sampling ratio.
            for _ in sampling_ratios:
                layers.extend([
                    nn.Conv1d(channels, channels, 3, 1, 1),
                    nn.GroupNorm(groups, channels),
                    nn.Mish(),
                ])
        # Final 1x1 projection to the output channel count.
        layers.append(
            nn.Conv1d(channels, out_channels, 1, 1)
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x, ylens=None):
        """Stretch x (B, T, D) to the lengths in ylens and smooth; returns
        (regulated features (B, max(ylens), D), ylens)."""
        valid = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
        stretched = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
        out = self.model(stretched).transpose(1, 2).contiguous()
        return out * valid, ylens

    def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
        """Interpolate prompt tokens (x1) and generated tokens (x2) separately.

        In inference mode the prompt and the head/mid/tail of the generated
        tokens are interpolated independently, giving a clean separation point
        in the resulting mel.
        NOTE: 20 corresponds to token_overlap_len in cosyvoice/cli/model.py.
        Inputs are (B, T, D); returns (features (B, mel_len1+mel_len2, D), total length).
        """
        if x2.shape[1] > 40:
            edge = int(20 / input_frame_rate * 22050 / 256)
            head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=edge, mode='linear')
            body = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - edge * 2,
                                 mode='linear')
            tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=edge, mode='linear')
            x2 = torch.concat([head, body, tail], dim=2)
        else:
            x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
        if x1.shape[1] != 0:
            x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
            x = torch.concat([x1, x2], dim=2)
        else:
            x = x2
        out = self.model(x).transpose(1, 2).contiguous()
        return out, mel_len1 + mel_len2
|
model/cosyvoice/hifigan/__pycache__/discriminator.cpython-310.pyc
ADDED
|
Binary file (8.74 kB). View file
|
|
|
model/cosyvoice/hifigan/__pycache__/discriminator.cpython-311.pyc
ADDED
|
Binary file (16.1 kB). View file
|
|
|
model/cosyvoice/hifigan/__pycache__/f0_predictor.cpython-310.pyc
ADDED
|
Binary file (1.44 kB). View file
|
|
|
model/cosyvoice/hifigan/__pycache__/f0_predictor.cpython-311.pyc
ADDED
|
Binary file (2.82 kB). View file
|
|
|
model/cosyvoice/hifigan/__pycache__/generator.cpython-310.pyc
ADDED
|
Binary file (11.8 kB). View file
|
|
|
model/cosyvoice/hifigan/__pycache__/generator.cpython-311.pyc
ADDED
|
Binary file (23.4 kB). View file
|
|
|
model/cosyvoice/hifigan/__pycache__/hifigan.cpython-310.pyc
ADDED
|
Binary file (2.58 kB). View file
|
|
|
model/cosyvoice/hifigan/__pycache__/hifigan.cpython-311.pyc
ADDED
|
Binary file (4.69 kB). View file
|
|
|
model/cosyvoice/hifigan/discriminator.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
try:
|
| 5 |
+
from torch.nn.utils.parametrizations import weight_norm, spectral_norm
|
| 6 |
+
except ImportError:
|
| 7 |
+
from torch.nn.utils import weight_norm, spectral_norm
|
| 8 |
+
from typing import List, Optional, Tuple
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
from torchaudio.transforms import Spectrogram
|
| 11 |
+
|
| 12 |
+
LRELU_SLOPE = 0.1
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class MultipleDiscriminator(nn.Module):
    """Combines a multi-period (mpd) and a multi-resolution (mrd)
    discriminator, concatenating their score and feature-map lists."""

    def __init__(self, mpd: nn.Module, mrd: nn.Module):
        super().__init__()
        self.mpd = mpd
        self.mrd = mrd

    def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
        """Run both sub-discriminators on real (y) and generated (y_hat) audio."""
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        # MPD consumes a (B, 1, T) input, so add the channel dimension first.
        mpd_out = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
        mrd_out = self.mrd(y, y_hat)
        for real_scores, gen_scores, real_feats, gen_feats in (mpd_out, mrd_out):
            y_d_rs += real_scores
            y_d_gs += gen_scores
            fmap_rs += real_feats
            fmap_gs += gen_feats
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class MultiResolutionDiscriminator(nn.Module):
    def __init__(
        self,
        fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
        num_embeddings: Optional[int] = None,
    ):
        """
        Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
        Additionally, it allows incorporating conditional information with a learned embeddings table.

        Args:
            fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
            num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
                Defaults to None.
        """

        super().__init__()
        # One sub-discriminator per FFT resolution.
        self.discriminators = nn.ModuleList(
            [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
        )

    def forward(
        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
        """Score real (y) and generated (y_hat) audio at every resolution."""
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []

        for disc in self.discriminators:
            score_real, feats_real = disc(x=y, cond_embedding_id=bandwidth_id)
            score_gen, feats_gen = disc(x=y_hat, cond_embedding_id=bandwidth_id)
            y_d_rs.append(score_real)
            fmap_rs.append(feats_real)
            y_d_gs.append(score_gen)
            fmap_gs.append(feats_gen)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class DiscriminatorR(nn.Module):
    """Single-resolution discriminator operating on the complex STFT of a
    waveform, split into frequency bands that each get their own conv stack.

    Fix over the original: the conv-stack factory was a lambda assigned to a
    name (PEP 8 E731); it is now a local function with identical behavior and
    identical parameter/state-dict layout.
    """

    def __init__(
        self,
        window_length: int,
        num_embeddings: Optional[int] = None,
        channels: int = 32,
        hop_factor: float = 0.25,
        bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
    ):
        super().__init__()
        self.window_length = window_length
        self.hop_factor = hop_factor
        # power=None keeps the complex spectrum; real/imag become the 2 input channels.
        self.spec_fn = Spectrogram(
            n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
        )
        n_fft = window_length // 2 + 1
        # Convert relative band edges to absolute bin indices.
        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
        self.bands = bands

        def make_convs() -> nn.ModuleList:
            # One copy of this conv topology per frequency band.
            return nn.ModuleList(
                [
                    weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
                    weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
                    weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
                    weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
                    weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
                ]
            )

        self.band_convs = nn.ModuleList([make_convs() for _ in range(len(self.bands))])

        if num_embeddings is not None:
            # Learned conditional bias, added to the score after conv_post.
            self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
            torch.nn.init.zeros_(self.emb.weight)

        self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))

    def spectrogram(self, x):
        """Return the complex spectrogram of x, split into frequency bands,
        each laid out as (B, 2, frames, band_bins)."""
        # Remove DC offset
        x = x - x.mean(dim=-1, keepdims=True)
        # Peak normalize the volume of input audio
        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
        x = self.spec_fn(x)
        x = torch.view_as_real(x)
        x = rearrange(x, "b f t c -> b c t f")
        # Split into bands
        x_bands = [x[..., b[0]: b[1]] for b in self.bands]
        return x_bands

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
        """Score waveform x; returns (score tensor, list of feature maps)."""
        x_bands = self.spectrogram(x)
        fmap = []
        x = []
        for band, stack in zip(x_bands, self.band_convs):
            for i, layer in enumerate(stack):
                band = layer(band)
                band = torch.nn.functional.leaky_relu(band, 0.1)
                if i > 0:
                    fmap.append(band)
            x.append(band)
        # Re-join the band activations along the frequency axis.
        x = torch.cat(x, dim=-1)
        if cond_embedding_id is not None:
            emb = self.emb(cond_embedding_id)
            h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
        else:
            h = 0
        x = self.conv_post(x)
        fmap.append(x)
        x += h

        return x, fmap
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class MultiResSpecDiscriminator(torch.nn.Module):
    """Bundle of three SpecDiscriminators at different STFT resolutions."""

    def __init__(self,
                 fft_sizes=[1024, 2048, 512],
                 hop_sizes=[120, 240, 50],
                 win_lengths=[600, 1200, 240],
                 window="hann_window"):

        super(MultiResSpecDiscriminator, self).__init__()
        # Exactly three resolutions, one SpecDiscriminator each.
        self.discriminators = nn.ModuleList([
            SpecDiscriminator(fft_sizes[i], hop_sizes[i], win_lengths[i], window)
            for i in range(3)
        ])

    def forward(self, y, y_hat):
        """Score real (y) and generated (y_hat) audio at every resolution."""
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for disc in self.discriminators:
            score_real, feats_real = disc(y)
            score_gen, feats_gen = disc(y_hat)
            y_d_rs.append(score_real)
            fmap_rs.append(feats_real)
            y_d_gs.append(score_gen)
            fmap_gs.append(feats_gen)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (Tensor): Window tensor (already on the right device).
    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    spec = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
    # Magnitude of the complex spectrum, re-laid-out as (B, frames, bins).
    return spec.abs().transpose(2, 1)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class SpecDiscriminator(nn.Module):
    """Discriminator operating on the STFT magnitude spectrogram of a waveform."""

    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
        super(SpecDiscriminator, self).__init__()
        # Preserve the original identity comparison: anything other than the
        # literal False selects spectral_norm.
        norm_f = spectral_norm if use_spectral_norm is not False else weight_norm
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        # Resolve e.g. "hann_window" to torch.hann_window and build the window.
        self.window = getattr(torch, window)(win_length)
        self.discriminators = nn.ModuleList([
            norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
        ])

        self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))

    def forward(self, y):
        """Score audio y (B, 1, T); returns (flattened score, feature maps)."""
        fmap = []
        y = y.squeeze(1)
        # Magnitude spectrogram, then restore a channel dim for the 2-D convs.
        y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
        y = y.unsqueeze(1)
        for layer in self.discriminators:
            y = F.leaky_relu(layer(y), LRELU_SLOPE)
            fmap.append(y)

        y = self.out(y)
        fmap.append(y)

        return torch.flatten(y, 1, -1), fmap
|
model/cosyvoice/hifigan/f0_predictor.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
try:
|
| 17 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 18 |
+
except ImportError:
|
| 19 |
+
from torch.nn.utils import weight_norm
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ConvRNNF0Predictor(nn.Module):
    """Convolutional F0 (pitch) predictor.

    Maps a feature map of shape [B, in_channels, T] to a non-negative
    per-frame output via five weight-normalised Conv1d+ELU stages followed
    by a linear classifier; the absolute value is taken at the end so the
    prediction can be used directly as an F0 magnitude.
    """

    def __init__(self,
                 num_class: int = 1,
                 in_channels: int = 80,
                 cond_channels: int = 512
                 ):
        super().__init__()

        self.num_class = num_class
        # Build the 5-stage conv stack in a loop: the first convolution
        # projects in_channels -> cond_channels, the remaining four keep
        # cond_channels. Module order matches a hand-written Sequential,
        # so parameter names/state_dict keys are unchanged.
        widths = [in_channels] + [cond_channels] * 5
        stages = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stages.append(
                weight_norm(nn.Conv1d(n_in, n_out, kernel_size=3, padding=1))
            )
            stages.append(nn.ELU())
        self.condnet = nn.Sequential(*stages)
        self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return abs(logits): [B, in_channels, T] -> [B, T] when num_class == 1."""
        hidden = self.condnet(x)
        hidden = hidden.transpose(1, 2)  # [B, T, cond_channels] for the Linear
        return torch.abs(self.classifier(hidden).squeeze(-1))
|
model/cosyvoice/hifigan/generator.py
ADDED
|
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""HIFI-GAN"""
|
| 16 |
+
|
| 17 |
+
from typing import Dict, Optional, List
|
| 18 |
+
import numpy as np
|
| 19 |
+
from scipy.signal import get_window
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from torch.nn import Conv1d
|
| 24 |
+
from torch.nn import ConvTranspose1d
|
| 25 |
+
from torch.nn.utils import remove_weight_norm
|
| 26 |
+
try:
|
| 27 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 28 |
+
except ImportError:
|
| 29 |
+
from torch.nn.utils import weight_norm
|
| 30 |
+
from torch.distributions.uniform import Uniform
|
| 31 |
+
|
| 32 |
+
from cosyvoice.transformer.activation import Snake
|
| 33 |
+
from cosyvoice.utils.common import get_padding
|
| 34 |
+
from cosyvoice.utils.common import init_weights
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
"""hifigan based generator implementation.
|
| 38 |
+
|
| 39 |
+
This code is modified from https://github.com/jik876/hifi-gan
|
| 40 |
+
,https://github.com/kan-bayashi/ParallelWaveGAN and
|
| 41 |
+
https://github.com/NVIDIA/BigVGAN
|
| 42 |
+
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class ResBlock(torch.nn.Module):
    """Residual block module in HiFiGAN/BigVGAN.

    Each branch is Snake -> dilated Conv1d -> Snake -> dilation-1 Conv1d,
    with a residual connection around the whole branch; one branch per
    entry in ``dilations``.
    """
    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: List[int] = [1, 3, 5],
    ):
        super(ResBlock, self).__init__()
        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()

        # Create conv pairs in dilation order so parameter initialisation
        # consumes the RNG in the same sequence as the reference layout.
        for d in dilations:
            first = Conv1d(
                channels,
                channels,
                kernel_size,
                1,
                dilation=d,
                padding=get_padding(kernel_size, d)
            )
            self.convs1.append(weight_norm(first))
            second = Conv1d(
                channels,
                channels,
                kernel_size,
                1,
                dilation=1,
                padding=get_padding(kernel_size, 1)
            )
            self.convs2.append(weight_norm(second))
        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)
        # One Snake activation ahead of every convolution.
        self.activations1 = nn.ModuleList(
            [Snake(channels, alpha_logscale=False) for _ in self.convs1]
        )
        self.activations2 = nn.ModuleList(
            [Snake(channels, alpha_logscale=False) for _ in self.convs2]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for act1, conv1, act2, conv2 in zip(
                self.activations1, self.convs1, self.activations2, self.convs2):
            residual = x
            x = conv1(act1(x))
            x = conv2(act2(x))
            x = x + residual
        return x

    def remove_weight_norm(self):
        """Strip weight-norm parametrisation from every convolution."""
        for conv in self.convs1:
            remove_weight_norm(conv)
        for conv in self.convs2:
            remove_weight_norm(conv)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class SineGen(torch.nn.Module):
    """Harmonic sine-source generator.

    SineGen(samp_rate, harmonic_num=0,
            sine_amp=0.1, noise_std=0.003,
            voiced_threshold=0)

    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of the sine waveform (default 0.1)
    noise_std: std of Gaussian noise in voiced frames (default 0.003)
    voiced_threshold: F0 threshold for voiced/unvoiced classification
        (default 0)
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # Voiced/unvoiced mask: 1.0 where f0 exceeds the threshold.
        return (f0 > self.voiced_threshold).type(torch.float32)

    @torch.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: (sine_waves, uv, noise); sine_waves and noise are
            [B, harmonic_num + 1, sample_len], uv is [B, 1, sample_len]
        """
        n_harm = self.harmonic_num + 1

        # Normalised instantaneous frequency per harmonic: f0 * (i + 1) / fs.
        rad_values = torch.zeros((f0.size(0), n_harm, f0.size(-1))).to(f0.device)
        for order in range(n_harm):
            rad_values[:, order: order + 1, :] = f0 * (order + 1) / self.sampling_rate

        # Integrate frequency to phase; mod 1 keeps the cumulative sum small.
        theta_mat = 2 * np.pi * (torch.cumsum(rad_values, dim=-1) % 1)
        # Random initial phase per harmonic; the fundamental keeps phase 0.
        phase_dist = Uniform(low=-np.pi, high=np.pi)
        init_phase = phase_dist.sample(sample_shape=(f0.size(0), n_harm, 1)).to(rad_values.device)
        init_phase[:, 0, :] = 0

        # Harmonic sine waveforms.
        sine_waves = self.sine_amp * torch.sin(theta_mat + init_phase)

        # Voiced/unvoiced mask from the fundamental.
        uv = self._f02uv(f0)

        # Noise floor: std = sine_amp / 3 in unvoiced frames (comparable to
        # the sine amplitude there), noise_std in voiced frames.
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # Silence the sines in unvoiced frames, then add the noise.
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()

        # NOTE: upsample_scale is accepted for interface compatibility with
        # callers (HiFTGenerator passes it) but is not used here.
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length 1)
        """
        # source for harmonic branch; the sine generator is frozen
        # (no gradients flow into it).
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
            sine_wavs = sine_wavs.transpose(1, 2)
            uv = uv.transpose(1, 2)
        # collapse the harmonic stack into one excitation channel
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv

    def remove_weight_norm(self):
        """No-op: this module contains no weight-normalised layers.

        Defined so that HiFTGenerator.remove_weight_norm(), which calls
        self.m_source.remove_weight_norm(), does not raise AttributeError.
        """
        pass
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class HiFTGenerator(nn.Module):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493

    Converts mel features into a waveform: an F0 predictor derives a pitch
    contour from the mel input, a harmonic+noise source signal is built from
    it (SourceModuleHnNSF), the source's STFT is fused into each upsampling
    stage, and the final waveform is produced by an inverse STFT.
    """
    def __init__(
            self,
            in_channels: int = 80,
            base_channels: int = 512,
            nb_harmonics: int = 8,
            sampling_rate: int = 22050,
            nsf_alpha: float = 0.1,
            nsf_sigma: float = 0.003,
            nsf_voiced_threshold: float = 10,
            upsample_rates: List[int] = [8, 8],
            upsample_kernel_sizes: List[int] = [16, 16],
            istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
            resblock_kernel_sizes: List[int] = [3, 7, 11],
            resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            source_resblock_kernel_sizes: List[int] = [7, 11],
            source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
            lrelu_slope: float = 0.1,
            audio_limit: float = 0.99,
            f0_predictor: torch.nn.Module = None,
    ):
        super(HiFTGenerator, self).__init__()

        self.out_channels = 1
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_params = istft_params
        self.lrelu_slope = lrelu_slope
        self.audio_limit = audio_limit  # output waveform is clamped to +/- this value

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        # Neural-source-filter module: turns the sample-rate F0 contour into
        # a harmonic excitation (plus noise / uv outputs).
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sampling_rate,
            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold)
        # Frame-rate F0 -> sample-rate: total hop = prod(upsample_rates) * hop_len.
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])

        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # Up: transposed-conv upsampling stack; channels halve at each stage.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # Down: bring the source STFT (n_fft + 2 = real+imag bins) to the
        # temporal resolution and channel width of each upsampling stage.
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
            if u == 1:
                # Last stage already matches the source resolution: 1x1 conv.
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
                )
            else:
                # Strided conv downsamples the source STFT by the cumulative rate u.
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
                )

            self.source_resblocks.append(
                ResBlock(base_channels // (2 ** (i + 1)), k, d)
            )

        # Multi-receptive-field fusion: num_kernels ResBlocks per upsampling
        # stage, averaged in decode().
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d))

        # Projects to n_fft + 2 channels: n_fft/2 + 1 magnitude bins followed
        # by n_fft/2 + 1 phase bins for the inverse STFT.
        self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        # Hann analysis/synthesis window shared by _stft and _istft.
        self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
        self.f0_predictor = f0_predictor

    def remove_weight_norm(self):
        """Strip weight-norm parametrisation from all layers (for inference/export)."""
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        # NOTE(review): requires SourceModuleHnNSF to expose remove_weight_norm;
        # verify it is defined there.
        self.m_source.remove_weight_norm()
        for l in self.source_downs:
            remove_weight_norm(l)
        for l in self.source_resblocks:
            l.remove_weight_norm()

    def _stft(self, x):
        """STFT of x [B, T]; returns (real, imag), each [B, n_fft/2 + 1, frames]."""
        spec = torch.stft(
            x,
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
            return_complex=True)
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        """Inverse STFT from magnitude and phase; returns waveform [B, T]."""
        # Cap magnitude to avoid overflow from exp() upstream.
        magnitude = torch.clip(magnitude, max=1e2)
        real = magnitude * torch.cos(phase)
        img = magnitude * torch.sin(phase)
        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
                                        self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
        return inverse_transform

    def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
        """Decode mel features x [B, C, T] with source signal s [B, 1, T*hop] to a waveform.

        NOTE: the default for s is a tensor created once at import time; it is
        read-only here, so the shared default is harmless.
        """
        # Source signal in the STFT domain: real and imaginary parts stacked
        # along channels -> [B, n_fft + 2, frames].
        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if i == self.num_upsamples - 1:
                # Pad one frame on the left so the length matches s_stft.
                x = self.reflection_pad(x)

            # fusion: add the downsampled source branch into this stage
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si

            # Average the num_kernels parallel ResBlocks of this stage.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        # First half of channels: log-magnitude; second half: phase.
        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundancy

        x = self._istft(magnitude, phase)
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """Training forward pass.

        Expects batch['speech_feat'] of shape [B, T, C]; returns the
        generated waveform and the predicted frame-rate F0 contour.
        """
        speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
        # mel->f0
        f0 = self.f0_predictor(speech_feat)
        # f0->source
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        # mel+source->speech
        generated_speech = self.decode(x=speech_feat, s=s)
        return generated_speech, f0

    @torch.inference_mode()
    def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> "tuple[torch.Tensor, torch.Tensor]":
        """Inference pass; returns (waveform, source) so the source can be
        cached and re-injected on the next chunk for streaming synthesis."""
        # mel->f0
        f0 = self.f0_predictor(speech_feat)
        # f0->source
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        # use cache_source to avoid glitch at chunk boundaries: overwrite the
        # overlapping prefix of the freshly generated source with the cached one
        if cache_source.shape[2] != 0:
            s[:, :, :cache_source.shape[2]] = cache_source
        generated_speech = self.decode(x=speech_feat, s=s)
        return generated_speech, s
|