Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
- .ipynb_checkpoints/requirements_vllm-checkpoint.txt +40 -0
- .ipynb_checkpoints/speed_test-checkpoint.ipynb +486 -0
- CODE_OF_CONDUCT.md +76 -0
- FAQ.md +16 -0
- LICENSE +201 -0
- README.md +237 -3
- Untitled.ipynb +60 -0
- asset/dingding.png +0 -0
- cosyvoice/__init__.py +0 -0
- cosyvoice/bin/average_model.py +92 -0
- cosyvoice/bin/export_jit.py +91 -0
- cosyvoice/bin/export_onnx.py +116 -0
- cosyvoice/bin/export_trt.sh +10 -0
- cosyvoice/bin/inference.py +115 -0
- cosyvoice/bin/train.py +170 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +214 -0
- cosyvoice/cli/frontend.py +289 -0
- cosyvoice/cli/model.py +466 -0
- cosyvoice/dataset/dataset.py +164 -0
- cosyvoice/hifigan/__pycache__/generator.cpython-310.pyc +0 -0
- cosyvoice/llm/llm.py +434 -0
- cosyvoice/llm/llm_vllm.py +212 -0
- cosyvoice/llm/vllm_use_cosyvoice2_model.py +263 -0
- cosyvoice/tokenizer/__pycache__/tokenizer.cpython-310.pyc +0 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +0 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/__pycache__/__init__.cpython-310.pyc +0 -0
- cosyvoice/transformer/__pycache__/activation.cpython-310.pyc +0 -0
- cosyvoice/transformer/__pycache__/attention.cpython-310.pyc +0 -0
- cosyvoice/transformer/__pycache__/convolution.cpython-310.pyc +0 -0
- cosyvoice/transformer/__pycache__/embedding.cpython-310.pyc +0 -0
- cosyvoice/transformer/__pycache__/encoder_layer.cpython-310.pyc +0 -0
- cosyvoice/transformer/__pycache__/label_smoothing_loss.cpython-310.pyc +0 -0
- cosyvoice/transformer/__pycache__/positionwise_feed_forward.cpython-310.pyc +0 -0
- cosyvoice/transformer/__pycache__/subsampling.cpython-310.pyc +0 -0
- cosyvoice/transformer/__pycache__/upsample_encoder.cpython-310.pyc +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +294 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
.ipynb_checkpoints/Untitled-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [],
|
| 3 |
+
"metadata": {},
|
| 4 |
+
"nbformat": 4,
|
| 5 |
+
"nbformat_minor": 5
|
| 6 |
+
}
|
.ipynb_checkpoints/requirements_vllm-checkpoint.txt
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
vllm==0.7.3
|
| 2 |
+
pydantic==2.10.6
|
| 3 |
+
torch==2.5.1
|
| 4 |
+
torchaudio==2.5.1
|
| 5 |
+
|
| 6 |
+
conformer==0.3.2
|
| 7 |
+
|
| 8 |
+
diffusers==0.32.2
|
| 9 |
+
gdown==5.1.0
|
| 10 |
+
grpcio==1.57.0
|
| 11 |
+
grpcio-tools==1.57.0
|
| 12 |
+
hydra-core==1.3.2
|
| 13 |
+
HyperPyYAML==1.2.2
|
| 14 |
+
inflect==7.3.1
|
| 15 |
+
librosa==0.10.2
|
| 16 |
+
|
| 17 |
+
lightning==2.5.0.post0
|
| 18 |
+
matplotlib==3.7.5
|
| 19 |
+
modelscope==1.15.0
|
| 20 |
+
|
| 21 |
+
networkx==3.4.2
|
| 22 |
+
omegaconf==2.3.0
|
| 23 |
+
onnx==1.17.0
|
| 24 |
+
|
| 25 |
+
onnxruntime-gpu==1.19.0; sys_platform == 'linux'
|
| 26 |
+
|
| 27 |
+
#openai-whisper==20231117
|
| 28 |
+
openai-whisper==20240930
|
| 29 |
+
protobuf==4.25
|
| 30 |
+
pyworld==0.3.4
|
| 31 |
+
rich==13.7.1
|
| 32 |
+
soundfile==0.12.1
|
| 33 |
+
tensorboard==2.14.0
|
| 34 |
+
wget==3.2
|
| 35 |
+
WeTextProcessing==1.0.3
|
| 36 |
+
|
| 37 |
+
# trt use
|
| 38 |
+
tensorrt-cu12==10.0.1
|
| 39 |
+
tensorrt-cu12-bindings==10.0.1
|
| 40 |
+
tensorrt-cu12-libs==10.0.1
|
.ipynb_checkpoints/speed_test-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"## 测试效果\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"- 测试代码: [speed_test.ipynb](speed_test.ipynb)\n",
|
| 10 |
+
"- 测试环境: Intel i5-12400 CPU, 48GB RAM, 1x NVIDIA GeForce RTX 4070\n",
|
| 11 |
+
"- 运行环境: Ubuntu 24.04.1 LTS, cuda 12.4, python 3.10.16\n",
|
| 12 |
+
"- 测试说明: 单任务执行的数据(非并发测试)\n"
|
| 13 |
+
]
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"cell_type": "markdown",
|
| 17 |
+
"metadata": {},
|
| 18 |
+
"source": [
|
| 19 |
+
"## 默认情况下使用"
|
| 20 |
+
]
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"cell_type": "code",
|
| 24 |
+
"execution_count": null,
|
| 25 |
+
"metadata": {},
|
| 26 |
+
"outputs": [],
|
| 27 |
+
"source": [
|
| 28 |
+
"import time\n",
|
| 29 |
+
"import asyncio\n",
|
| 30 |
+
"import torchaudio\n",
|
| 31 |
+
"\n",
|
| 32 |
+
"import sys\n",
|
| 33 |
+
"sys.path.append('third_party/Matcha-TTS')\n",
|
| 34 |
+
"\n",
|
| 35 |
+
"from cosyvoice.cli.cosyvoice import CosyVoice2\n",
|
| 36 |
+
"from cosyvoice.utils.file_utils import load_wav\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"prompt_text = '希望你以后能够做得比我还好哟'\n",
|
| 39 |
+
"prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)\n",
|
| 40 |
+
"\n",
|
| 41 |
+
"# cosyvoice = CosyVoice2('./pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=True)\n",
|
| 42 |
+
"cosyvoice = CosyVoice2('./pretrained_models/CosyVoice2-0.5B', load_jit=True, load_trt=True, fp16=True)"
|
| 43 |
+
]
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"cell_type": "markdown",
|
| 47 |
+
"metadata": {},
|
| 48 |
+
"source": [
|
| 49 |
+
"## 使用vllm加速llm推理\n",
|
| 50 |
+
"\n",
|
| 51 |
+
"#### 1. **安装依赖**\n",
|
| 52 |
+
"\n",
|
| 53 |
+
"(该依赖环境下可以运行原本cosyvoice2代码)\n",
|
| 54 |
+
"```bash\n",
|
| 55 |
+
"pip install -r requirements_vllm.txt\n",
|
| 56 |
+
"```\n",
|
| 57 |
+
"\n",
|
| 58 |
+
"#### 2. **文件复制**\n",
|
| 59 |
+
"将 pretrained_models/CosyVoice2-0.5B/CosyVoice-BlankEN 文件夹下的部分文件复制到下载的CosyVoice2-0.5B模型文件夹下,并替换 config.json 文件中的 Qwen2ForCausalLM 为 CosyVoice2Model。\n",
|
| 60 |
+
"```bash\n",
|
| 61 |
+
"cp pretrained_models/CosyVoice2-0.5B/CosyVoice-BlankEN/{config.json,tokenizer_config.json,vocab.json,merges.txt} pretrained_models/CosyVoice2-0.5B/\n",
|
| 62 |
+
"sed -i 's/Qwen2ForCausalLM/CosyVoice2Model/' pretrained_models/CosyVoice2-0.5B/config.json\n",
|
| 63 |
+
"```\n",
|
| 64 |
+
"\n",
|
| 65 |
+
"#### **注意:**\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"- 使用 load_trt 后,需要进行 **预热** 10次推理以上,使用流式推理预热效果较好\n",
|
| 68 |
+
"- 在 jupyter notebook 中,如果要使用 **vllm** 运行下列代码,需要将vllm_use_cosyvoice2_model.py正确复制到 vllm 包中,并注册到 _VLLM_MODELS 字典中。运行下面的 code 完成"
|
| 69 |
+
]
|
| 70 |
+
},
|
| 71 |
+
{
|
| 72 |
+
"cell_type": "code",
|
| 73 |
+
"execution_count": null,
|
| 74 |
+
"metadata": {},
|
| 75 |
+
"outputs": [],
|
| 76 |
+
"source": [
|
| 77 |
+
"import os\n",
|
| 78 |
+
"import shutil\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"# 获取vllm包的安装路径\n",
|
| 81 |
+
"try:\n",
|
| 82 |
+
" import vllm\n",
|
| 83 |
+
"except ImportError:\n",
|
| 84 |
+
" raise ImportError(\"vllm package not installed\")\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"\n",
|
| 87 |
+
"vllm_path = os.path.dirname(vllm.__file__)\n",
|
| 88 |
+
"print(f\"vllm package path: {vllm_path}\")\n",
|
| 89 |
+
"\n",
|
| 90 |
+
"# 定义目标路径\n",
|
| 91 |
+
"target_dir = os.path.join(vllm_path, \"model_executor\", \"models\")\n",
|
| 92 |
+
"target_file = os.path.join(target_dir, \"cosyvoice2.py\")\n",
|
| 93 |
+
"\n",
|
| 94 |
+
"# 复制模型文件\n",
|
| 95 |
+
"source_file = \"./cosyvoice/llm/vllm_use_cosyvoice2_model.py\"\n",
|
| 96 |
+
"if not os.path.exists(source_file):\n",
|
| 97 |
+
" raise FileNotFoundError(f\"Source file {source_file} not found\")\n",
|
| 98 |
+
"\n",
|
| 99 |
+
"shutil.copy(source_file, target_file)\n",
|
| 100 |
+
"print(f\"Copied {source_file} to {target_file}\")\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"# 修改registry.py文件\n",
|
| 103 |
+
"registry_path = os.path.join(target_dir, \"registry.py\")\n",
|
| 104 |
+
"new_entry = ' \"CosyVoice2Model\": (\"cosyvoice2\", \"CosyVoice2Model\"), # noqa: E501\\n'\n",
|
| 105 |
+
"\n",
|
| 106 |
+
"# 读取并修改文件内容\n",
|
| 107 |
+
"with open(registry_path, \"r\") as f:\n",
|
| 108 |
+
" lines = f.readlines()\n",
|
| 109 |
+
"\n",
|
| 110 |
+
"# 检查是否已存在条目\n",
|
| 111 |
+
"entry_exists = any(\"CosyVoice2Model\" in line for line in lines)\n",
|
| 112 |
+
"\n",
|
| 113 |
+
"if not entry_exists:\n",
|
| 114 |
+
" # 寻找插入位置\n",
|
| 115 |
+
" insert_pos = None\n",
|
| 116 |
+
" for i, line in enumerate(lines):\n",
|
| 117 |
+
" if line.strip().startswith(\"**_FALLBACK_MODEL\"):\n",
|
| 118 |
+
" insert_pos = i + 1\n",
|
| 119 |
+
" break\n",
|
| 120 |
+
" \n",
|
| 121 |
+
" if insert_pos is None:\n",
|
| 122 |
+
" raise ValueError(\"Could not find insertion point in registry.py\")\n",
|
| 123 |
+
" \n",
|
| 124 |
+
" # 插入新条目\n",
|
| 125 |
+
" lines.insert(insert_pos, new_entry)\n",
|
| 126 |
+
" \n",
|
| 127 |
+
" # 写回文件\n",
|
| 128 |
+
" with open(registry_path, \"w\") as f:\n",
|
| 129 |
+
" f.writelines(lines)\n",
|
| 130 |
+
" print(\"Successfully updated registry.py\")\n",
|
| 131 |
+
"else:\n",
|
| 132 |
+
" print(\"Entry already exists in registry.py, skipping modification\")\n",
|
| 133 |
+
"\n",
|
| 134 |
+
"print(\"All operations completed successfully!\")"
|
| 135 |
+
]
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"cell_type": "code",
|
| 139 |
+
"execution_count": 1,
|
| 140 |
+
"metadata": {},
|
| 141 |
+
"outputs": [
|
| 142 |
+
{
|
| 143 |
+
"name": "stdout",
|
| 144 |
+
"output_type": "stream",
|
| 145 |
+
"text": [
|
| 146 |
+
"failed to import ttsfrd, use WeTextProcessing instead\n"
|
| 147 |
+
]
|
| 148 |
+
},
|
| 149 |
+
{
|
| 150 |
+
"name": "stderr",
|
| 151 |
+
"output_type": "stream",
|
| 152 |
+
"text": [
|
| 153 |
+
"Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.\n",
|
| 154 |
+
"/opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/diffusers/models/lora.py:393: FutureWarning: `LoRACompatibleLinear` is deprecated and will be removed in version 1.0.0. Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`.\n",
|
| 155 |
+
" deprecate(\"LoRACompatibleLinear\", \"1.0.0\", deprecation_message)\n",
|
| 156 |
+
"2025-03-08 00:37:04,867 INFO input frame rate=25\n",
|
| 157 |
+
"/opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:115: UserWarning: Specified provider 'CUDAExecutionProvider' is not in available provider names.Available providers: 'AzureExecutionProvider, CPUExecutionProvider'\n",
|
| 158 |
+
" warnings.warn(\n",
|
| 159 |
+
"2025-03-08 00:37:06,103 WETEXT INFO found existing fst: /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/zh_tn_tagger.fst\n",
|
| 160 |
+
"2025-03-08 00:37:06,103 INFO found existing fst: /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/zh_tn_tagger.fst\n",
|
| 161 |
+
"2025-03-08 00:37:06,104 WETEXT INFO /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/zh_tn_verbalizer.fst\n",
|
| 162 |
+
"2025-03-08 00:37:06,104 INFO /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/zh_tn_verbalizer.fst\n",
|
| 163 |
+
"2025-03-08 00:37:06,104 WETEXT INFO skip building fst for zh_normalizer ...\n",
|
| 164 |
+
"2025-03-08 00:37:06,104 INFO skip building fst for zh_normalizer ...\n",
|
| 165 |
+
"2025-03-08 00:37:06,313 WETEXT INFO found existing fst: /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/en_tn_tagger.fst\n",
|
| 166 |
+
"2025-03-08 00:37:06,313 INFO found existing fst: /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/en_tn_tagger.fst\n",
|
| 167 |
+
"2025-03-08 00:37:06,314 WETEXT INFO /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/en_tn_verbalizer.fst\n",
|
| 168 |
+
"2025-03-08 00:37:06,314 INFO /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/en_tn_verbalizer.fst\n",
|
| 169 |
+
"2025-03-08 00:37:06,314 WETEXT INFO skip building fst for en_normalizer ...\n",
|
| 170 |
+
"2025-03-08 00:37:06,314 INFO skip building fst for en_normalizer ...\n"
|
| 171 |
+
]
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"name": "stdout",
|
| 175 |
+
"output_type": "stream",
|
| 176 |
+
"text": [
|
| 177 |
+
"INFO 03-08 00:37:07 __init__.py:207] Automatically detected platform cuda.\n",
|
| 178 |
+
"WARNING 03-08 00:37:07 registry.py:352] Model architecture CosyVoice2Model is already registered, and will be overwritten by the new model class <class 'cosyvoice.llm.vllm_use_cosyvoice2_model.CosyVoice2Model'>.\n",
|
| 179 |
+
"WARNING 03-08 00:37:07 config.py:2517] Casting torch.bfloat16 to torch.float16.\n",
|
| 180 |
+
"INFO 03-08 00:37:07 config.py:560] This model supports multiple tasks: {'embed', 'classify', 'reward', 'generate', 'score'}. Defaulting to 'generate'.\n",
|
| 181 |
+
"INFO 03-08 00:37:07 config.py:1624] Chunked prefill is enabled with max_num_batched_tokens=1024.\n",
|
| 182 |
+
"WARNING 03-08 00:37:08 utils.py:2164] CUDA was previously initialized. We must use the `spawn` multiprocessing start method. Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See https://docs.vllm.ai/en/latest/getting_started/troubleshooting.html#python-multiprocessing for more information.\n",
|
| 183 |
+
"INFO 03-08 00:37:10 __init__.py:207] Automatically detected platform cuda.\n",
|
| 184 |
+
"INFO 03-08 00:37:11 core.py:50] Initializing a V1 LLM engine (v0.7.3.dev213+gede41bc7.d20250219) with config: model='./pretrained_models/CosyVoice2-0.5B', speculative_config=None, tokenizer='./pretrained_models/CosyVoice2-0.5B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=1024, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=./pretrained_models/CosyVoice2-0.5B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={\"level\":3,\"custom_ops\":[\"none\"],\"splitting_ops\":[\"vllm.unified_attention\",\"vllm.unified_attention_with_output\"],\"use_inductor\":true,\"compile_sizes\":[],\"use_cudagraph\":true,\"cudagraph_num_of_warmups\":1,\"cudagraph_capture_sizes\":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],\"max_capture_size\":512}\n",
|
| 185 |
+
"WARNING 03-08 00:37:11 utils.py:2298] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,list_loras,load_config,pin_lora,remove_lora,scheduler_config not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x771e56fb9a50>\n",
|
| 186 |
+
"INFO 03-08 00:37:11 parallel_state.py:948] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0\n",
|
| 187 |
+
"INFO 03-08 00:37:11 gpu_model_runner.py:1055] Starting to load model ./pretrained_models/CosyVoice2-0.5B...\n",
|
| 188 |
+
"INFO 03-08 00:37:11 cuda.py:157] Using Flash Attention backend on V1 engine.\n",
|
| 189 |
+
"WARNING 03-08 00:37:11 topk_topp_sampler.py:46] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.\n",
|
| 190 |
+
"WARNING 03-08 00:37:11 rejection_sampler.py:47] FlashInfer is not available. Falling back to the PyTorch-native implementation of rejection sampling. For the best performance, please install FlashInfer.\n"
|
| 191 |
+
]
|
| 192 |
+
},
|
| 193 |
+
{
|
| 194 |
+
"name": "stderr",
|
| 195 |
+
"output_type": "stream",
|
| 196 |
+
"text": [
|
| 197 |
+
"/opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/torch/utils/_device.py:106: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
| 198 |
+
" return func(*args, **kwargs)\n",
|
| 199 |
+
"Loading pt checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]\n",
|
| 200 |
+
"Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 1.12it/s]\n",
|
| 201 |
+
"Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 1.12it/s]\n",
|
| 202 |
+
"\n"
|
| 203 |
+
]
|
| 204 |
+
},
|
| 205 |
+
{
|
| 206 |
+
"name": "stdout",
|
| 207 |
+
"output_type": "stream",
|
| 208 |
+
"text": [
|
| 209 |
+
"INFO 03-08 00:37:12 gpu_model_runner.py:1068] Loading model weights took 0.9532 GB and 1.023026 seconds\n",
|
| 210 |
+
"INFO 03-08 00:37:16 backends.py:408] Using cache directory: /home/qihua/.cache/vllm/torch_compile_cache/29f70599cb/rank_0 for vLLM's torch.compile\n",
|
| 211 |
+
"INFO 03-08 00:37:16 backends.py:418] Dynamo bytecode transform time: 3.62 s\n",
|
| 212 |
+
"INFO 03-08 00:37:16 backends.py:115] Directly load the compiled graph for shape None from the cache\n",
|
| 213 |
+
"INFO 03-08 00:37:19 monitor.py:33] torch.compile takes 3.62 s in total\n",
|
| 214 |
+
"INFO 03-08 00:37:20 kv_cache_utils.py:524] GPU KV cache size: 216,560 tokens\n",
|
| 215 |
+
"INFO 03-08 00:37:20 kv_cache_utils.py:527] Maximum concurrency for 1,024 tokens per request: 211.48x\n"
|
| 216 |
+
]
|
| 217 |
+
},
|
| 218 |
+
{
|
| 219 |
+
"name": "stderr",
|
| 220 |
+
"output_type": "stream",
|
| 221 |
+
"text": [
|
| 222 |
+
"2025-03-08 00:37:30,767 DEBUG Using selector: EpollSelector\n"
|
| 223 |
+
]
|
| 224 |
+
},
|
| 225 |
+
{
|
| 226 |
+
"name": "stdout",
|
| 227 |
+
"output_type": "stream",
|
| 228 |
+
"text": [
|
| 229 |
+
"INFO 03-08 00:37:30 gpu_model_runner.py:1375] Graph capturing finished in 11 secs, took 0.37 GiB\n",
|
| 230 |
+
"INFO 03-08 00:37:30 core.py:116] init engine (profile, create kv cache, warmup model) took 17.82 seconds\n",
|
| 231 |
+
"inference_processor\n",
|
| 232 |
+
"[03/08/2025-00:37:31] [TRT] [I] Loaded engine size: 158 MiB\n",
|
| 233 |
+
"[03/08/2025-00:37:31] [TRT] [I] [MS] Running engine with multi stream info\n",
|
| 234 |
+
"[03/08/2025-00:37:31] [TRT] [I] [MS] Number of aux streams is 1\n",
|
| 235 |
+
"[03/08/2025-00:37:31] [TRT] [I] [MS] Number of total worker streams is 2\n",
|
| 236 |
+
"[03/08/2025-00:37:31] [TRT] [I] [MS] The main stream provided by execute/enqueue calls is the first worker stream\n",
|
| 237 |
+
"[03/08/2025-00:37:32] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +4545, now: CPU 0, GPU 4681 (MiB)\n"
|
| 238 |
+
]
|
| 239 |
+
},
|
| 240 |
+
{
|
| 241 |
+
"name": "stdout",
|
| 242 |
+
"output_type": "stream",
|
| 243 |
+
"text": [
|
| 244 |
+
"inference_processor\n",
|
| 245 |
+
"inference_processor\n",
|
| 246 |
+
"inference_processor\n",
|
| 247 |
+
"inference_processor\n",
|
| 248 |
+
"inference_processor\n",
|
| 249 |
+
"inference_processor\n",
|
| 250 |
+
"inference_processor\n",
|
| 251 |
+
"inference_processor\n",
|
| 252 |
+
"inference_processor\n",
|
| 253 |
+
"inference_processor\n",
|
| 254 |
+
"inference_processor\n",
|
| 255 |
+
"inference_processor\n",
|
| 256 |
+
"inference_processor\n",
|
| 257 |
+
"inference_processor\n",
|
| 258 |
+
"inference_processor\n",
|
| 259 |
+
"inference_processor\n",
|
| 260 |
+
"inference_processor\n",
|
| 261 |
+
"inference_processor\n",
|
| 262 |
+
"inference_processor\n",
|
| 263 |
+
"inference_processor\n",
|
| 264 |
+
"inference_processor\n",
|
| 265 |
+
"inference_processor\n",
|
| 266 |
+
"inference_processor\n",
|
| 267 |
+
"inference_processor\n",
|
| 268 |
+
"inference_processor\n",
|
| 269 |
+
"inference_processor\n",
|
| 270 |
+
"inference_processor\n",
|
| 271 |
+
"inference_processor\n",
|
| 272 |
+
"inference_processor\n",
|
| 273 |
+
"inference_processor\n",
|
| 274 |
+
"inference_processor\n",
|
| 275 |
+
"inference_processor\n",
|
| 276 |
+
"inference_processor\n",
|
| 277 |
+
"inference_processor\n",
|
| 278 |
+
"inference_processor\n",
|
| 279 |
+
"inference_processor\n",
|
| 280 |
+
"inference_processor\n",
|
| 281 |
+
"inference_processor\n",
|
| 282 |
+
"inference_processor\n",
|
| 283 |
+
"inference_processor\n"
|
| 284 |
+
]
|
| 285 |
+
}
|
| 286 |
+
],
|
| 287 |
+
"source": [
|
| 288 |
+
"import time\n",
|
| 289 |
+
"import asyncio\n",
|
| 290 |
+
"import torchaudio\n",
|
| 291 |
+
"\n",
|
| 292 |
+
"import sys\n",
|
| 293 |
+
"sys.path.append('third_party/Matcha-TTS')\n",
|
| 294 |
+
"\n",
|
| 295 |
+
"from cosyvoice.cli.cosyvoice import CosyVoice2\n",
|
| 296 |
+
"from cosyvoice.utils.file_utils import load_wav\n",
|
| 297 |
+
"\n",
|
| 298 |
+
"prompt_text = '希望你以后能够做得比我还好哟'\n",
|
| 299 |
+
"prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)\n",
|
| 300 |
+
"\n",
|
| 301 |
+
"# cosyvoice = CosyVoice2(\n",
|
| 302 |
+
"# './pretrained_models/CosyVoice2-0.5B', \n",
|
| 303 |
+
"# load_jit=False, \n",
|
| 304 |
+
"# load_trt=False, \n",
|
| 305 |
+
"# fp16=True, \n",
|
| 306 |
+
"# use_vllm=True,\n",
|
| 307 |
+
"# )\n",
|
| 308 |
+
"cosyvoice = CosyVoice2(\n",
|
| 309 |
+
" './pretrained_models/CosyVoice2-0.5B', \n",
|
| 310 |
+
" load_jit=True, \n",
|
| 311 |
+
" load_trt=True, \n",
|
| 312 |
+
" fp16=True, \n",
|
| 313 |
+
" use_vllm=True,\n",
|
| 314 |
+
")"
|
| 315 |
+
]
|
| 316 |
+
},
|
| 317 |
+
{
|
| 318 |
+
"cell_type": "code",
|
| 319 |
+
"execution_count": 16,
|
| 320 |
+
"metadata": {},
|
| 321 |
+
"outputs": [
|
| 322 |
+
{
|
| 323 |
+
"name": "stderr",
|
| 324 |
+
"output_type": "stream",
|
| 325 |
+
"text": [
|
| 326 |
+
" 0%| | 0/1 [00:00<?, ?it/s]2025-03-08 00:38:59,777 INFO synthesis text 收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。\n",
|
| 327 |
+
"2025-03-08 00:39:00,917 INFO yield speech len 11.68, rtf 0.09757431402598342\n",
|
| 328 |
+
"100%|██████████| 1/1 [00:01<00:00, 1.47s/it]\n"
|
| 329 |
+
]
|
| 330 |
+
}
|
| 331 |
+
],
|
| 332 |
+
"source": [
|
| 333 |
+
"for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', prompt_text, prompt_speech_16k, stream=False)):\n",
|
| 334 |
+
" torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)"
|
| 335 |
+
]
|
| 336 |
+
},
|
| 337 |
+
{
|
| 338 |
+
"cell_type": "code",
|
| 339 |
+
"execution_count": 17,
|
| 340 |
+
"metadata": {},
|
| 341 |
+
"outputs": [
|
| 342 |
+
{
|
| 343 |
+
"name": "stderr",
|
| 344 |
+
"output_type": "stream",
|
| 345 |
+
"text": [
|
| 346 |
+
" 0%| | 0/1 [00:00<?, ?it/s]2025-03-08 00:39:01,208 INFO synthesis text 收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。\n",
|
| 347 |
+
"2025-03-08 00:39:01,587 INFO yield speech len 1.84, rtf 0.20591642545617145\n",
|
| 348 |
+
"2025-03-08 00:39:01,790 INFO yield speech len 2.0, rtf 0.10057318210601807\n",
|
| 349 |
+
"2025-03-08 00:39:02,116 INFO yield speech len 2.0, rtf 0.16271138191223145\n",
|
| 350 |
+
"2025-03-08 00:39:02,367 INFO yield speech len 2.0, rtf 0.1247786283493042\n",
|
| 351 |
+
"2025-03-08 00:39:02,640 INFO yield speech len 2.0, rtf 0.13561689853668213\n",
|
| 352 |
+
"2025-03-08 00:39:02,980 INFO yield speech len 1.88, rtf 0.1803158445561186\n",
|
| 353 |
+
"100%|██████████| 1/1 [00:02<00:00, 2.05s/it]\n"
|
| 354 |
+
]
|
| 355 |
+
}
|
| 356 |
+
],
|
| 357 |
+
"source": [
|
| 358 |
+
"for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', prompt_text, prompt_speech_16k, stream=True)):\n",
|
| 359 |
+
" torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)"
|
| 360 |
+
]
|
| 361 |
+
},
|
| 362 |
+
{
|
| 363 |
+
"cell_type": "code",
|
| 364 |
+
"execution_count": 18,
|
| 365 |
+
"metadata": {},
|
| 366 |
+
"outputs": [
|
| 367 |
+
{
|
| 368 |
+
"name": "stderr",
|
| 369 |
+
"output_type": "stream",
|
| 370 |
+
"text": [
|
| 371 |
+
"2025-03-08 00:39:02,990 INFO get tts_text generator, will skip text_normalize!\n",
|
| 372 |
+
" 0%| | 0/1 [00:00<?, ?it/s]2025-03-08 00:39:02,991 INFO get tts_text generator, will return _extract_text_token_generator!\n",
|
| 373 |
+
"2025-03-08 00:39:03,236 INFO synthesis text <generator object text_generator at 0x79c694dae340>\n",
|
| 374 |
+
"2025-03-08 00:39:03,237 INFO not enough text token to decode, wait for more\n",
|
| 375 |
+
"2025-03-08 00:39:03,252 INFO get fill token, need to append more text token\n",
|
| 376 |
+
"2025-03-08 00:39:03,253 INFO append 5 text token\n",
|
| 377 |
+
"2025-03-08 00:39:03,311 INFO get fill token, need to append more text token\n",
|
| 378 |
+
"2025-03-08 00:39:03,312 INFO append 5 text token\n",
|
| 379 |
+
"2025-03-08 00:39:03,456 INFO no more text token, decode until met eos\n",
|
| 380 |
+
"2025-03-08 00:39:04,861 INFO yield speech len 15.16, rtf 0.1072180145334128\n",
|
| 381 |
+
"100%|██████████| 1/1 [00:01<00:00, 1.88s/it]\n"
|
| 382 |
+
]
|
| 383 |
+
}
|
| 384 |
+
],
|
| 385 |
+
"source": [
|
| 386 |
+
"def text_generator():\n",
|
| 387 |
+
" yield '收到好友从远方寄来的生日礼物,'\n",
|
| 388 |
+
" yield '那份意外的惊喜与深深的祝福'\n",
|
| 389 |
+
" yield '让我心中充满了甜蜜的快乐,'\n",
|
| 390 |
+
" yield '��容如花儿般绽放。'\n",
|
| 391 |
+
"\n",
|
| 392 |
+
" \n",
|
| 393 |
+
"for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), prompt_text, prompt_speech_16k, stream=False)):\n",
|
| 394 |
+
" torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)"
|
| 395 |
+
]
|
| 396 |
+
},
|
| 397 |
+
{
|
| 398 |
+
"cell_type": "code",
|
| 399 |
+
"execution_count": 19,
|
| 400 |
+
"metadata": {},
|
| 401 |
+
"outputs": [
|
| 402 |
+
{
|
| 403 |
+
"name": "stderr",
|
| 404 |
+
"output_type": "stream",
|
| 405 |
+
"text": [
|
| 406 |
+
"2025-03-08 00:39:04,878 INFO get tts_text generator, will skip text_normalize!\n",
|
| 407 |
+
" 0%| | 0/1 [00:00<?, ?it/s]2025-03-08 00:39:04,880 INFO get tts_text generator, will return _extract_text_token_generator!\n",
|
| 408 |
+
"2025-03-08 00:39:05,151 INFO synthesis text <generator object text_generator at 0x79c694dad690>\n",
|
| 409 |
+
"2025-03-08 00:39:05,152 INFO not enough text token to decode, wait for more\n",
|
| 410 |
+
"2025-03-08 00:39:05,169 INFO get fill token, need to append more text token\n",
|
| 411 |
+
"2025-03-08 00:39:05,169 INFO append 5 text token\n",
|
| 412 |
+
"2025-03-08 00:39:05,292 INFO get fill token, need to append more text token\n",
|
| 413 |
+
"2025-03-08 00:39:05,293 INFO append 5 text token\n",
|
| 414 |
+
"2025-03-08 00:39:05,438 INFO no more text token, decode until met eos\n",
|
| 415 |
+
"2025-03-08 00:39:05,638 INFO yield speech len 1.84, rtf 0.26492670826289966\n",
|
| 416 |
+
"2025-03-08 00:39:05,841 INFO yield speech len 2.0, rtf 0.10065567493438721\n",
|
| 417 |
+
"2025-03-08 00:39:06,164 INFO yield speech len 2.0, rtf 0.16065263748168945\n",
|
| 418 |
+
"2025-03-08 00:39:06,422 INFO yield speech len 2.0, rtf 0.12791669368743896\n",
|
| 419 |
+
"2025-03-08 00:39:06,697 INFO yield speech len 2.0, rtf 0.13690149784088135\n",
|
| 420 |
+
"2025-03-08 00:39:06,998 INFO yield speech len 2.0, rtf 0.14957869052886963\n",
|
| 421 |
+
"2025-03-08 00:39:07,335 INFO yield speech len 1.0, rtf 0.3356931209564209\n",
|
| 422 |
+
"100%|██████████| 1/1 [00:02<00:00, 2.46s/it]\n"
|
| 423 |
+
]
|
| 424 |
+
}
|
| 425 |
+
],
|
| 426 |
+
"source": [
|
| 427 |
+
"def text_generator():\n",
|
| 428 |
+
" yield '收到好友从远方寄来的生日礼物,'\n",
|
| 429 |
+
" yield '那份意外的惊喜与深深的祝福'\n",
|
| 430 |
+
" yield '让我心中充满了甜蜜的快乐,'\n",
|
| 431 |
+
" yield '笑容如花儿般绽放。'\n",
|
| 432 |
+
"for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), prompt_text, prompt_speech_16k, stream=True)):\n",
|
| 433 |
+
" torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)"
|
| 434 |
+
]
|
| 435 |
+
},
|
| 436 |
+
{
|
| 437 |
+
"cell_type": "code",
|
| 438 |
+
"execution_count": 20,
|
| 439 |
+
"metadata": {},
|
| 440 |
+
"outputs": [
|
| 441 |
+
{
|
| 442 |
+
"name": "stderr",
|
| 443 |
+
"output_type": "stream",
|
| 444 |
+
"text": [
|
| 445 |
+
" 0%| | 0/1 [00:00<?, ?it/s]2025-03-08 00:39:07,592 INFO synthesis text 收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。\n",
|
| 446 |
+
"2025-03-08 00:39:08,925 INFO yield speech len 11.24, rtf 0.11861237342671567\n",
|
| 447 |
+
"100%|██████████| 1/1 [00:01<00:00, 1.58s/it]\n"
|
| 448 |
+
]
|
| 449 |
+
}
|
| 450 |
+
],
|
| 451 |
+
"source": [
|
| 452 |
+
"# instruct usage\n",
|
| 453 |
+
"for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):\n",
|
| 454 |
+
" torchaudio.save('instruct2_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)\n"
|
| 455 |
+
]
|
| 456 |
+
},
|
| 457 |
+
{
|
| 458 |
+
"cell_type": "code",
|
| 459 |
+
"execution_count": null,
|
| 460 |
+
"metadata": {},
|
| 461 |
+
"outputs": [],
|
| 462 |
+
"source": []
|
| 463 |
+
}
|
| 464 |
+
],
|
| 465 |
+
"metadata": {
|
| 466 |
+
"kernelspec": {
|
| 467 |
+
"display_name": "cosyvoice",
|
| 468 |
+
"language": "python",
|
| 469 |
+
"name": "python3"
|
| 470 |
+
},
|
| 471 |
+
"language_info": {
|
| 472 |
+
"codemirror_mode": {
|
| 473 |
+
"name": "ipython",
|
| 474 |
+
"version": 3
|
| 475 |
+
},
|
| 476 |
+
"file_extension": ".py",
|
| 477 |
+
"mimetype": "text/x-python",
|
| 478 |
+
"name": "python",
|
| 479 |
+
"nbconvert_exporter": "python",
|
| 480 |
+
"pygments_lexer": "ipython3",
|
| 481 |
+
"version": "3.10.16"
|
| 482 |
+
}
|
| 483 |
+
},
|
| 484 |
+
"nbformat": 4,
|
| 485 |
+
"nbformat_minor": 2
|
| 486 |
+
}
|
CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributor Covenant Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
In the interest of fostering an open and welcoming environment, we as
|
| 6 |
+
contributors and maintainers pledge to making participation in our project and
|
| 7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
| 8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
| 9 |
+
level of experience, education, socio-economic status, nationality, personal
|
| 10 |
+
appearance, race, religion, or sexual identity and orientation.
|
| 11 |
+
|
| 12 |
+
## Our Standards
|
| 13 |
+
|
| 14 |
+
Examples of behavior that contributes to creating a positive environment
|
| 15 |
+
include:
|
| 16 |
+
|
| 17 |
+
* Using welcoming and inclusive language
|
| 18 |
+
* Being respectful of differing viewpoints and experiences
|
| 19 |
+
* Gracefully accepting constructive criticism
|
| 20 |
+
* Focusing on what is best for the community
|
| 21 |
+
* Showing empathy towards other community members
|
| 22 |
+
|
| 23 |
+
Examples of unacceptable behavior by participants include:
|
| 24 |
+
|
| 25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
| 26 |
+
advances
|
| 27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
| 28 |
+
* Public or private harassment
|
| 29 |
+
* Publishing others' private information, such as a physical or electronic
|
| 30 |
+
address, without explicit permission
|
| 31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 32 |
+
professional setting
|
| 33 |
+
|
| 34 |
+
## Our Responsibilities
|
| 35 |
+
|
| 36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
| 37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
| 38 |
+
response to any instances of unacceptable behavior.
|
| 39 |
+
|
| 40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
| 41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
| 42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
| 43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
| 44 |
+
threatening, offensive, or harmful.
|
| 45 |
+
|
| 46 |
+
## Scope
|
| 47 |
+
|
| 48 |
+
This Code of Conduct applies both within project spaces and in public spaces
|
| 49 |
+
when an individual is representing the project or its community. Examples of
|
| 50 |
+
representing a project or community include using an official project e-mail
|
| 51 |
+
address, posting via an official social media account, or acting as an appointed
|
| 52 |
+
representative at an online or offline event. Representation of a project may be
|
| 53 |
+
further defined and clarified by project maintainers.
|
| 54 |
+
|
| 55 |
+
## Enforcement
|
| 56 |
+
|
| 57 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 58 |
+
reported by contacting the project team at mikelei@mobvoi.com. All
|
| 59 |
+
complaints will be reviewed and investigated and will result in a response that
|
| 60 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
| 61 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
| 62 |
+
Further details of specific enforcement policies may be posted separately.
|
| 63 |
+
|
| 64 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
| 65 |
+
faith may face temporary or permanent repercussions as determined by other
|
| 66 |
+
members of the project's leadership.
|
| 67 |
+
|
| 68 |
+
## Attribution
|
| 69 |
+
|
| 70 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
| 71 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
| 72 |
+
|
| 73 |
+
[homepage]: https://www.contributor-covenant.org
|
| 74 |
+
|
| 75 |
+
For answers to common questions about this code of conduct, see
|
| 76 |
+
https://www.contributor-covenant.org/faq
|
FAQ.md
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## ModuleNotFoundError: No module named 'matcha'
|
| 2 |
+
|
| 3 |
+
Matcha-TTS is a third_party module. Please check `third_party` directory. If there is no `Matcha-TTS`, execute `git submodule update --init --recursive`.
|
| 4 |
+
|
| 5 |
+
run `export PYTHONPATH=third_party/Matcha-TTS` if you want to use `from cosyvoice.cli.cosyvoice import CosyVoice` in python script.
|
| 6 |
+
|
| 7 |
+
## cannot find resource.zip or cannot unzip resource.zip
|
| 8 |
+
|
| 9 |
+
Please make sure you have git-lfs installed. Execute
|
| 10 |
+
|
| 11 |
+
```sh
|
| 12 |
+
git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
|
| 13 |
+
cd pretrained_models/CosyVoice-ttsfrd/
|
| 14 |
+
unzip resource.zip -d .
|
| 15 |
+
pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl
|
| 16 |
+
```
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,3 +1,237 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[](https://github.com/Akshay090/svg-banners)
|
| 2 |
+
|
| 3 |
+
## 👉🏻 CosyVoice 👈🏻
|
| 4 |
+
**CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/abs/2412.10117); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/spaces/FunAudioLLM/CosyVoice2-0.5B)
|
| 5 |
+
|
| 6 |
+
**CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice-300M)
|
| 7 |
+
|
| 8 |
+
## Highlight🔥
|
| 9 |
+
|
| 10 |
+
**CosyVoice 2.0** has been released! Compared to version 1.0, the new version offers more accurate, more stable, faster, and better speech generation capabilities.
|
| 11 |
+
### Multilingual
|
| 12 |
+
- **Supported Language**: Chinese, English, Japanese, Korean, Chinese dialects (Cantonese, Sichuanese, Shanghainese, Tianjinese, Wuhanese, etc.)
|
| 13 |
+
- **Crosslingual & Mixlingual**:Support zero-shot voice cloning for cross-lingual and code-switching scenarios.
|
| 14 |
+
### Ultra-Low Latency
|
| 15 |
+
- **Bidirectional Streaming Support**: CosyVoice 2.0 integrates offline and streaming modeling technologies.
|
| 16 |
+
- **Rapid First Packet Synthesis**: Achieves latency as low as 150ms while maintaining high-quality audio output.
|
| 17 |
+
### High Accuracy
|
| 18 |
+
- **Improved Pronunciation**: Reduces pronunciation errors by 30% to 50% compared to CosyVoice 1.0.
|
| 19 |
+
- **Benchmark Achievements**: Attains the lowest character error rate on the hard test set of the Seed-TTS evaluation set.
|
| 20 |
+
### Strong Stability
|
| 21 |
+
- **Consistency in Timbre**: Ensures reliable voice consistency for zero-shot and cross-language speech synthesis.
|
| 22 |
+
- **Cross-language Synthesis**: Marked improvements compared to version 1.0.
|
| 23 |
+
### Natural Experience
|
| 24 |
+
- **Enhanced Prosody and Sound Quality**: Improved alignment of synthesized audio, raising MOS evaluation scores from 5.4 to 5.53.
|
| 25 |
+
- **Emotional and Dialectal Flexibility**: Now supports more granular emotional controls and accent adjustments.
|
| 26 |
+
|
| 27 |
+
## Roadmap
|
| 28 |
+
|
| 29 |
+
- [x] 2024/12
|
| 30 |
+
|
| 31 |
+
- [x] 25hz cosyvoice 2.0 released
|
| 32 |
+
|
| 33 |
+
- [x] 2024/09
|
| 34 |
+
|
| 35 |
+
- [x] 25hz cosyvoice base model
|
| 36 |
+
- [x] 25hz cosyvoice voice conversion model
|
| 37 |
+
|
| 38 |
+
- [x] 2024/08
|
| 39 |
+
|
| 40 |
+
- [x] Repetition Aware Sampling(RAS) inference for llm stability
|
| 41 |
+
- [x] Streaming inference mode support, including kv cache and sdpa for rtf optimization
|
| 42 |
+
|
| 43 |
+
- [x] 2024/07
|
| 44 |
+
|
| 45 |
+
- [x] Flow matching training support
|
| 46 |
+
- [x] WeTextProcessing support when ttsfrd is not available
|
| 47 |
+
- [x] Fastapi server and client
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
## Install
|
| 51 |
+
|
| 52 |
+
**Clone and install**
|
| 53 |
+
|
| 54 |
+
- Clone the repo
|
| 55 |
+
``` sh
|
| 56 |
+
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
| 57 |
+
# If you failed to clone submodule due to network failures, please run following command until success
|
| 58 |
+
cd CosyVoice
|
| 59 |
+
git submodule update --init --recursive
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
- Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
|
| 63 |
+
- Create Conda env:
|
| 64 |
+
|
| 65 |
+
``` sh
|
| 66 |
+
conda create -n cosyvoice -y python=3.10
|
| 67 |
+
conda activate cosyvoice
|
| 68 |
+
# pynini is required by WeTextProcessing, use conda to install it as it can be executed on all platform.
|
| 69 |
+
conda install -y -c conda-forge pynini==2.1.5
|
| 70 |
+
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
| 71 |
+
|
| 72 |
+
# If you encounter sox compatibility issues
|
| 73 |
+
# ubuntu
|
| 74 |
+
sudo apt-get install sox libsox-dev
|
| 75 |
+
# centos
|
| 76 |
+
sudo yum install sox sox-devel
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
**Model download**
|
| 80 |
+
|
| 81 |
+
We strongly recommend that you download our pretrained `CosyVoice2-0.5B` `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
|
| 82 |
+
|
| 83 |
+
``` python
|
| 84 |
+
# SDK模型下载
|
| 85 |
+
from modelscope import snapshot_download
|
| 86 |
+
snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
|
| 87 |
+
snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
|
| 88 |
+
snapshot_download('iic/CosyVoice-300M-25Hz', local_dir='pretrained_models/CosyVoice-300M-25Hz')
|
| 89 |
+
snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
|
| 90 |
+
snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
|
| 91 |
+
snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
``` sh
|
| 95 |
+
# git模型下载,请确保已安装git lfs
|
| 96 |
+
mkdir -p pretrained_models
|
| 97 |
+
git clone https://www.modelscope.cn/iic/CosyVoice2-0.5B.git pretrained_models/CosyVoice2-0.5B
|
| 98 |
+
git clone https://www.modelscope.cn/iic/CosyVoice-300M.git pretrained_models/CosyVoice-300M
|
| 99 |
+
git clone https://www.modelscope.cn/iic/CosyVoice-300M-25Hz.git pretrained_models/CosyVoice-300M-25Hz
|
| 100 |
+
git clone https://www.modelscope.cn/iic/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
|
| 101 |
+
git clone https://www.modelscope.cn/iic/CosyVoice-300M-Instruct.git pretrained_models/CosyVoice-300M-Instruct
|
| 102 |
+
git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
Optionally, you can unzip `ttsfrd` resouce and install `ttsfrd` package for better text normalization performance.
|
| 106 |
+
|
| 107 |
+
Notice that this step is not necessary. If you do not install `ttsfrd` package, we will use WeTextProcessing by default.
|
| 108 |
+
|
| 109 |
+
``` sh
|
| 110 |
+
cd pretrained_models/CosyVoice-ttsfrd/
|
| 111 |
+
unzip resource.zip -d .
|
| 112 |
+
pip install ttsfrd_dependency-0.1-py3-none-any.whl
|
| 113 |
+
pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
**Basic Usage**
|
| 117 |
+
|
| 118 |
+
We strongly recommend using `CosyVoice2-0.5B` for better performance.
|
| 119 |
+
Follow code below for detailed usage of each model.
|
| 120 |
+
|
| 121 |
+
``` python
|
| 122 |
+
import sys
|
| 123 |
+
sys.path.append('third_party/Matcha-TTS')
|
| 124 |
+
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
| 125 |
+
from cosyvoice.utils.file_utils import load_wav
|
| 126 |
+
import torchaudio
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
**CosyVoice2 Usage**
|
| 130 |
+
```python
|
| 131 |
+
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False)
|
| 132 |
+
|
| 133 |
+
# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
|
| 134 |
+
# zero_shot usage
|
| 135 |
+
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
|
| 136 |
+
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
|
| 137 |
+
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
| 138 |
+
|
| 139 |
+
# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248
|
| 140 |
+
for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)):
|
| 141 |
+
torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
| 142 |
+
|
| 143 |
+
# instruct usage
|
| 144 |
+
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
|
| 145 |
+
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
| 146 |
+
|
| 147 |
+
# bistream usage, you can use generator as input, this is useful when using text llm model as input
|
| 148 |
+
# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length
|
| 149 |
+
def text_generator():
|
| 150 |
+
yield '收到好友从远方寄来的生日礼物,'
|
| 151 |
+
yield '那份意外的惊喜与深深的祝福'
|
| 152 |
+
yield '让我心中充满了甜蜜的快乐,'
|
| 153 |
+
yield '笑容如花儿般绽放。'
|
| 154 |
+
for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
|
| 155 |
+
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
**CosyVoice Usage**
|
| 159 |
+
```python
|
| 160 |
+
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_trt=False, fp16=False)
|
| 161 |
+
# sft usage
|
| 162 |
+
print(cosyvoice.list_available_spks())
|
| 163 |
+
# change stream=True for chunk stream inference
|
| 164 |
+
for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
|
| 165 |
+
torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
| 166 |
+
|
| 167 |
+
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M') # or change to pretrained_models/CosyVoice-300M-25Hz for 25Hz inference
|
| 168 |
+
# zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
|
| 169 |
+
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
|
| 170 |
+
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
|
| 171 |
+
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
| 172 |
+
# cross_lingual usage
|
| 173 |
+
prompt_speech_16k = load_wav('./asset/cross_lingual_prompt.wav', 16000)
|
| 174 |
+
for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
|
| 175 |
+
torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
| 176 |
+
# vc usage
|
| 177 |
+
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
|
| 178 |
+
source_speech_16k = load_wav('./asset/cross_lingual_prompt.wav', 16000)
|
| 179 |
+
for i, j in enumerate(cosyvoice.inference_vc(source_speech_16k, prompt_speech_16k, stream=False)):
|
| 180 |
+
torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
| 181 |
+
|
| 182 |
+
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
|
| 183 |
+
# instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
|
| 184 |
+
for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
|
| 185 |
+
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
**Start web demo**
|
| 189 |
+
|
| 190 |
+
You can use our web demo page to get familiar with CosyVoice quickly.
|
| 191 |
+
|
| 192 |
+
Please see the demo website for details.
|
| 193 |
+
|
| 194 |
+
``` python
|
| 195 |
+
# change iic/CosyVoice-300M-SFT for sft inference, or iic/CosyVoice-300M-Instruct for instruct inference
|
| 196 |
+
python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
**Advanced Usage**
|
| 200 |
+
|
| 201 |
+
For advanced user, we have provided train and inference scripts in `examples/libritts/cosyvoice/run.sh`.
|
| 202 |
+
|
| 203 |
+
**Build for deployment**
|
| 204 |
+
|
| 205 |
+
Optionally, if you want service deployment,
|
| 206 |
+
you can run following steps.
|
| 207 |
+
|
| 208 |
+
``` sh
|
| 209 |
+
cd runtime/python
|
| 210 |
+
docker build -t cosyvoice:v1.0 .
|
| 211 |
+
# change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
|
| 212 |
+
# for grpc usage
|
| 213 |
+
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
|
| 214 |
+
cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
| 215 |
+
# for fastapi usage
|
| 216 |
+
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && python3 server.py --port 50000 --model_dir iic/CosyVoice-300M && sleep infinity"
|
| 217 |
+
cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
| 218 |
+
```
|
| 219 |
+
|
| 220 |
+
## Discussion & Communication
|
| 221 |
+
|
| 222 |
+
You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
|
| 223 |
+
|
| 224 |
+
You can also scan the QR code to join our official Dingding chat group.
|
| 225 |
+
|
| 226 |
+
<img src="./asset/dingding.png" width="250px">
|
| 227 |
+
|
| 228 |
+
## Acknowledge
|
| 229 |
+
|
| 230 |
+
1. We borrowed a lot of code from [FunASR](https://github.com/modelscope/FunASR).
|
| 231 |
+
2. We borrowed a lot of code from [FunCodec](https://github.com/modelscope/FunCodec).
|
| 232 |
+
3. We borrowed a lot of code from [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS).
|
| 233 |
+
4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
|
| 234 |
+
5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
|
| 235 |
+
|
| 236 |
+
## Disclaimer
|
| 237 |
+
The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
|
Untitled.ipynb
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "7a91a72d-234a-4fe7-ba1a-247b598af957",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": []
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"cell_type": "code",
|
| 13 |
+
"execution_count": 7,
|
| 14 |
+
"id": "8fceda4d-f896-45c2-b7c5-df31c4225972",
|
| 15 |
+
"metadata": {},
|
| 16 |
+
"outputs": [],
|
| 17 |
+
"source": [
|
| 18 |
+
"from huggingface_hub import HfApi\n",
|
| 19 |
+
"api = HfApi()\n"
|
| 20 |
+
]
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"cell_type": "code",
|
| 24 |
+
"execution_count": null,
|
| 25 |
+
"id": "534e4d62-9cc2-4d04-bf9d-c073a047752a",
|
| 26 |
+
"metadata": {},
|
| 27 |
+
"outputs": [],
|
| 28 |
+
"source": []
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"cell_type": "code",
|
| 32 |
+
"execution_count": null,
|
| 33 |
+
"id": "20bcb46a-9c4e-4a36-aac2-a1ac1b3c1a90",
|
| 34 |
+
"metadata": {},
|
| 35 |
+
"outputs": [],
|
| 36 |
+
"source": []
|
| 37 |
+
}
|
| 38 |
+
],
|
| 39 |
+
"metadata": {
|
| 40 |
+
"kernelspec": {
|
| 41 |
+
"display_name": "Python3 (ipykernel)",
|
| 42 |
+
"language": "python",
|
| 43 |
+
"name": "python3"
|
| 44 |
+
},
|
| 45 |
+
"language_info": {
|
| 46 |
+
"codemirror_mode": {
|
| 47 |
+
"name": "ipython",
|
| 48 |
+
"version": 3
|
| 49 |
+
},
|
| 50 |
+
"file_extension": ".py",
|
| 51 |
+
"mimetype": "text/x-python",
|
| 52 |
+
"name": "python",
|
| 53 |
+
"nbconvert_exporter": "python",
|
| 54 |
+
"pygments_lexer": "ipython3",
|
| 55 |
+
"version": "3.12.3"
|
| 56 |
+
}
|
| 57 |
+
},
|
| 58 |
+
"nbformat": 4,
|
| 59 |
+
"nbformat_minor": 5
|
| 60 |
+
}
|
asset/dingding.png
ADDED
|
cosyvoice/__init__.py
ADDED
|
File without changes
|
cosyvoice/bin/average_model.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2020 Mobvoi Inc (Di Wu)
|
| 2 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import argparse
|
| 18 |
+
import glob
|
| 19 |
+
|
| 20 |
+
import yaml
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_args():
|
| 25 |
+
parser = argparse.ArgumentParser(description='average model')
|
| 26 |
+
parser.add_argument('--dst_model', required=True, help='averaged model')
|
| 27 |
+
parser.add_argument('--src_path',
|
| 28 |
+
required=True,
|
| 29 |
+
help='src model path for average')
|
| 30 |
+
parser.add_argument('--val_best',
|
| 31 |
+
action="store_true",
|
| 32 |
+
help='averaged model')
|
| 33 |
+
parser.add_argument('--num',
|
| 34 |
+
default=5,
|
| 35 |
+
type=int,
|
| 36 |
+
help='nums for averaged model')
|
| 37 |
+
|
| 38 |
+
args = parser.parse_args()
|
| 39 |
+
print(args)
|
| 40 |
+
return args
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def main():
|
| 44 |
+
args = get_args()
|
| 45 |
+
val_scores = []
|
| 46 |
+
if args.val_best:
|
| 47 |
+
yamls = glob.glob('{}/*.yaml'.format(args.src_path))
|
| 48 |
+
yamls = [
|
| 49 |
+
f for f in yamls
|
| 50 |
+
if not (os.path.basename(f).startswith('train')
|
| 51 |
+
or os.path.basename(f).startswith('init'))
|
| 52 |
+
]
|
| 53 |
+
for y in yamls:
|
| 54 |
+
with open(y, 'r') as f:
|
| 55 |
+
dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
|
| 56 |
+
loss = float(dic_yaml['loss_dict']['loss'])
|
| 57 |
+
epoch = int(dic_yaml['epoch'])
|
| 58 |
+
step = int(dic_yaml['step'])
|
| 59 |
+
tag = dic_yaml['tag']
|
| 60 |
+
val_scores += [[epoch, step, loss, tag]]
|
| 61 |
+
sorted_val_scores = sorted(val_scores,
|
| 62 |
+
key=lambda x: x[2],
|
| 63 |
+
reverse=False)
|
| 64 |
+
print("best val (epoch, step, loss, tag) = " +
|
| 65 |
+
str(sorted_val_scores[:args.num]))
|
| 66 |
+
path_list = [
|
| 67 |
+
args.src_path + '/epoch_{}_whole.pt'.format(score[0])
|
| 68 |
+
for score in sorted_val_scores[:args.num]
|
| 69 |
+
]
|
| 70 |
+
print(path_list)
|
| 71 |
+
avg = {}
|
| 72 |
+
num = args.num
|
| 73 |
+
assert num == len(path_list)
|
| 74 |
+
for path in path_list:
|
| 75 |
+
print('Processing {}'.format(path))
|
| 76 |
+
states = torch.load(path, map_location=torch.device('cpu'))
|
| 77 |
+
for k in states.keys():
|
| 78 |
+
if k not in avg.keys():
|
| 79 |
+
avg[k] = states[k].clone()
|
| 80 |
+
else:
|
| 81 |
+
avg[k] += states[k]
|
| 82 |
+
# average
|
| 83 |
+
for k in avg.keys():
|
| 84 |
+
if avg[k] is not None:
|
| 85 |
+
# pytorch 1.6 use true_divide instead of /=
|
| 86 |
+
avg[k] = torch.true_divide(avg[k], num)
|
| 87 |
+
print('Saving to {}'.format(args.dst_model))
|
| 88 |
+
torch.save(avg, args.dst_model)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if __name__ == '__main__':
|
| 92 |
+
main()
|
cosyvoice/bin/export_jit.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import print_function
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import logging
|
| 19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
import torch
|
| 23 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 24 |
+
sys.path.append('{}/../..'.format(ROOT_DIR))
|
| 25 |
+
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
| 26 |
+
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_args():
|
| 30 |
+
parser = argparse.ArgumentParser(description='export your model for deployment')
|
| 31 |
+
parser.add_argument('--model_dir',
|
| 32 |
+
type=str,
|
| 33 |
+
default='pretrained_models/CosyVoice-300M',
|
| 34 |
+
help='local path')
|
| 35 |
+
args = parser.parse_args()
|
| 36 |
+
print(args)
|
| 37 |
+
return args
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_optimized_script(model, preserved_attrs=[]):
|
| 41 |
+
script = torch.jit.script(model)
|
| 42 |
+
if preserved_attrs != []:
|
| 43 |
+
script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
|
| 44 |
+
else:
|
| 45 |
+
script = torch.jit.freeze(script)
|
| 46 |
+
script = torch.jit.optimize_for_inference(script)
|
| 47 |
+
return script
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def main():
|
| 51 |
+
args = get_args()
|
| 52 |
+
logging.basicConfig(level=logging.DEBUG,
|
| 53 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
| 54 |
+
|
| 55 |
+
torch._C._jit_set_fusion_strategy([('STATIC', 1)])
|
| 56 |
+
torch._C._jit_set_profiling_mode(False)
|
| 57 |
+
torch._C._jit_set_profiling_executor(False)
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
model = CosyVoice(args.model_dir)
|
| 61 |
+
except Exception:
|
| 62 |
+
try:
|
| 63 |
+
model = CosyVoice2(args.model_dir)
|
| 64 |
+
except Exception:
|
| 65 |
+
raise TypeError('no valid model_type!')
|
| 66 |
+
|
| 67 |
+
if not isinstance(model, CosyVoice2):
|
| 68 |
+
# 1. export llm text_encoder
|
| 69 |
+
llm_text_encoder = model.model.llm.text_encoder
|
| 70 |
+
script = get_optimized_script(llm_text_encoder)
|
| 71 |
+
script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
|
| 72 |
+
script = get_optimized_script(llm_text_encoder.half())
|
| 73 |
+
script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
|
| 74 |
+
|
| 75 |
+
# 2. export llm llm
|
| 76 |
+
llm_llm = model.model.llm.llm
|
| 77 |
+
script = get_optimized_script(llm_llm, ['forward_chunk'])
|
| 78 |
+
script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
|
| 79 |
+
script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
|
| 80 |
+
script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
|
| 81 |
+
|
| 82 |
+
# 3. export flow encoder
|
| 83 |
+
flow_encoder = model.model.flow.encoder
|
| 84 |
+
script = get_optimized_script(flow_encoder)
|
| 85 |
+
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
| 86 |
+
script = get_optimized_script(flow_encoder.half())
|
| 87 |
+
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
if __name__ == '__main__':
|
| 91 |
+
main()
|
cosyvoice/bin/export_onnx.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
|
| 2 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from __future__ import print_function
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import logging
|
| 20 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
| 21 |
+
import os
|
| 22 |
+
import sys
|
| 23 |
+
import onnxruntime
|
| 24 |
+
import random
|
| 25 |
+
import torch
|
| 26 |
+
from tqdm import tqdm
|
| 27 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 28 |
+
sys.path.append('{}/../..'.format(ROOT_DIR))
|
| 29 |
+
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
| 30 |
+
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_dummy_input(batch_size, seq_len, out_channels, device):
|
| 34 |
+
x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
| 35 |
+
mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
|
| 36 |
+
mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
| 37 |
+
t = torch.rand((batch_size), dtype=torch.float32, device=device)
|
| 38 |
+
spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
|
| 39 |
+
cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
| 40 |
+
return x, mask, mu, t, spks, cond
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_args():
|
| 44 |
+
parser = argparse.ArgumentParser(description='export your model for deployment')
|
| 45 |
+
parser.add_argument('--model_dir',
|
| 46 |
+
type=str,
|
| 47 |
+
default='pretrained_models/CosyVoice-300M',
|
| 48 |
+
help='local path')
|
| 49 |
+
args = parser.parse_args()
|
| 50 |
+
print(args)
|
| 51 |
+
return args
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def main():
|
| 55 |
+
args = get_args()
|
| 56 |
+
logging.basicConfig(level=logging.DEBUG,
|
| 57 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
model = CosyVoice(args.model_dir)
|
| 61 |
+
except Exception:
|
| 62 |
+
try:
|
| 63 |
+
model = CosyVoice2(args.model_dir)
|
| 64 |
+
except Exception:
|
| 65 |
+
raise TypeError('no valid model_type!')
|
| 66 |
+
|
| 67 |
+
# 1. export flow decoder estimator
|
| 68 |
+
estimator = model.model.flow.decoder.estimator
|
| 69 |
+
|
| 70 |
+
device = model.model.device
|
| 71 |
+
batch_size, seq_len = 2, 256
|
| 72 |
+
out_channels = model.model.flow.decoder.estimator.out_channels
|
| 73 |
+
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
|
| 74 |
+
torch.onnx.export(
|
| 75 |
+
estimator,
|
| 76 |
+
(x, mask, mu, t, spks, cond),
|
| 77 |
+
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
| 78 |
+
export_params=True,
|
| 79 |
+
opset_version=18,
|
| 80 |
+
do_constant_folding=True,
|
| 81 |
+
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
|
| 82 |
+
output_names=['estimator_out'],
|
| 83 |
+
dynamic_axes={
|
| 84 |
+
'x': {2: 'seq_len'},
|
| 85 |
+
'mask': {2: 'seq_len'},
|
| 86 |
+
'mu': {2: 'seq_len'},
|
| 87 |
+
'cond': {2: 'seq_len'},
|
| 88 |
+
'estimator_out': {2: 'seq_len'},
|
| 89 |
+
}
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# 2. test computation consistency
|
| 93 |
+
option = onnxruntime.SessionOptions()
|
| 94 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 95 |
+
option.intra_op_num_threads = 1
|
| 96 |
+
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
| 97 |
+
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
| 98 |
+
sess_options=option, providers=providers)
|
| 99 |
+
|
| 100 |
+
for _ in tqdm(range(10)):
|
| 101 |
+
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
|
| 102 |
+
output_pytorch = estimator(x, mask, mu, t, spks, cond)
|
| 103 |
+
ort_inputs = {
|
| 104 |
+
'x': x.cpu().numpy(),
|
| 105 |
+
'mask': mask.cpu().numpy(),
|
| 106 |
+
'mu': mu.cpu().numpy(),
|
| 107 |
+
't': t.cpu().numpy(),
|
| 108 |
+
'spks': spks.cpu().numpy(),
|
| 109 |
+
'cond': cond.cpu().numpy()
|
| 110 |
+
}
|
| 111 |
+
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
|
| 112 |
+
torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
if __name__ == "__main__":
|
| 116 |
+
main()
|
cosyvoice/bin/export_trt.sh
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Copyright 2024 Alibaba Inc. All Rights Reserved.
|
| 3 |
+
# download tensorrt from https://developer.nvidia.com/tensorrt/download/10x, check your system and cuda for compatibability
|
| 4 |
+
# for example for linux + cuda12.4, you can download https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.Linux.x86_64-gnu.cuda-12.4.tar.gz
|
| 5 |
+
TRT_DIR=<YOUR_TRT_DIR>
|
| 6 |
+
MODEL_DIR=<COSYVOICE2_MODEL_DIR>
|
| 7 |
+
|
| 8 |
+
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_DIR/lib:/usr/local/cuda/lib64
|
| 9 |
+
$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp32.mygpu.plan --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw --outputIOFormats=fp32:chw
|
| 10 |
+
$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw
|
cosyvoice/bin/inference.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import print_function
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import logging
|
| 19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
| 20 |
+
import os
|
| 21 |
+
import torch
|
| 22 |
+
from torch.utils.data import DataLoader
|
| 23 |
+
import torchaudio
|
| 24 |
+
from hyperpyyaml import load_hyperpyyaml
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
from cosyvoice.cli.model import CosyVoiceModel
|
| 27 |
+
from cosyvoice.dataset.dataset import Dataset
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_args():
|
| 31 |
+
parser = argparse.ArgumentParser(description='inference with your model')
|
| 32 |
+
parser.add_argument('--config', required=True, help='config file')
|
| 33 |
+
parser.add_argument('--prompt_data', required=True, help='prompt data file')
|
| 34 |
+
parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
|
| 35 |
+
parser.add_argument('--tts_text', required=True, help='tts input file')
|
| 36 |
+
parser.add_argument('--llm_model', required=True, help='llm model file')
|
| 37 |
+
parser.add_argument('--flow_model', required=True, help='flow model file')
|
| 38 |
+
parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
|
| 39 |
+
parser.add_argument('--gpu',
|
| 40 |
+
type=int,
|
| 41 |
+
default=-1,
|
| 42 |
+
help='gpu id for this rank, -1 for cpu')
|
| 43 |
+
parser.add_argument('--mode',
|
| 44 |
+
default='sft',
|
| 45 |
+
choices=['sft', 'zero_shot'],
|
| 46 |
+
help='inference mode')
|
| 47 |
+
parser.add_argument('--result_dir', required=True, help='asr result file')
|
| 48 |
+
args = parser.parse_args()
|
| 49 |
+
print(args)
|
| 50 |
+
return args
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def main():
|
| 54 |
+
args = get_args()
|
| 55 |
+
logging.basicConfig(level=logging.DEBUG,
|
| 56 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
| 57 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
|
| 58 |
+
|
| 59 |
+
# Init cosyvoice models from configs
|
| 60 |
+
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
|
| 61 |
+
device = torch.device('cuda' if use_cuda else 'cpu')
|
| 62 |
+
with open(args.config, 'r') as f:
|
| 63 |
+
configs = load_hyperpyyaml(f)
|
| 64 |
+
|
| 65 |
+
model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
|
| 66 |
+
model.load(args.llm_model, args.flow_model, args.hifigan_model)
|
| 67 |
+
|
| 68 |
+
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
|
| 69 |
+
tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
|
| 70 |
+
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
|
| 71 |
+
|
| 72 |
+
del configs
|
| 73 |
+
os.makedirs(args.result_dir, exist_ok=True)
|
| 74 |
+
fn = os.path.join(args.result_dir, 'wav.scp')
|
| 75 |
+
f = open(fn, 'w')
|
| 76 |
+
with torch.no_grad():
|
| 77 |
+
for _, batch in tqdm(enumerate(test_data_loader)):
|
| 78 |
+
utts = batch["utts"]
|
| 79 |
+
assert len(utts) == 1, "inference mode only support batchsize 1"
|
| 80 |
+
text_token = batch["text_token"].to(device)
|
| 81 |
+
text_token_len = batch["text_token_len"].to(device)
|
| 82 |
+
tts_index = batch["tts_index"]
|
| 83 |
+
tts_text_token = batch["tts_text_token"].to(device)
|
| 84 |
+
tts_text_token_len = batch["tts_text_token_len"].to(device)
|
| 85 |
+
speech_token = batch["speech_token"].to(device)
|
| 86 |
+
speech_token_len = batch["speech_token_len"].to(device)
|
| 87 |
+
speech_feat = batch["speech_feat"].to(device)
|
| 88 |
+
speech_feat_len = batch["speech_feat_len"].to(device)
|
| 89 |
+
utt_embedding = batch["utt_embedding"].to(device)
|
| 90 |
+
spk_embedding = batch["spk_embedding"].to(device)
|
| 91 |
+
if args.mode == 'sft':
|
| 92 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
| 93 |
+
'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
|
| 94 |
+
else:
|
| 95 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
| 96 |
+
'prompt_text': text_token, 'prompt_text_len': text_token_len,
|
| 97 |
+
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
| 98 |
+
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
| 99 |
+
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
| 100 |
+
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
|
| 101 |
+
tts_speeches = []
|
| 102 |
+
for model_output in model.tts(**model_input):
|
| 103 |
+
tts_speeches.append(model_output['tts_speech'])
|
| 104 |
+
tts_speeches = torch.concat(tts_speeches, dim=1)
|
| 105 |
+
tts_key = '{}_{}'.format(utts[0], tts_index[0])
|
| 106 |
+
tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
|
| 107 |
+
torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
|
| 108 |
+
f.write('{} {}\n'.format(tts_key, tts_fn))
|
| 109 |
+
f.flush()
|
| 110 |
+
f.close()
|
| 111 |
+
logging.info('Result wav.scp saved in {}'.format(fn))
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
if __name__ == '__main__':
|
| 115 |
+
main()
|
cosyvoice/bin/train.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import print_function
|
| 16 |
+
import argparse
|
| 17 |
+
import datetime
|
| 18 |
+
import logging
|
| 19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
| 20 |
+
from copy import deepcopy
|
| 21 |
+
import os
|
| 22 |
+
import torch
|
| 23 |
+
import torch.distributed as dist
|
| 24 |
+
import deepspeed
|
| 25 |
+
|
| 26 |
+
from hyperpyyaml import load_hyperpyyaml
|
| 27 |
+
|
| 28 |
+
from torch.distributed.elastic.multiprocessing.errors import record
|
| 29 |
+
|
| 30 |
+
from cosyvoice.utils.executor import Executor
|
| 31 |
+
from cosyvoice.utils.train_utils import (
|
| 32 |
+
init_distributed,
|
| 33 |
+
init_dataset_and_dataloader,
|
| 34 |
+
init_optimizer_and_scheduler,
|
| 35 |
+
init_summarywriter, save_model,
|
| 36 |
+
wrap_cuda_model, check_modify_and_save_config)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_args():
|
| 40 |
+
parser = argparse.ArgumentParser(description='training your network')
|
| 41 |
+
parser.add_argument('--train_engine',
|
| 42 |
+
default='torch_ddp',
|
| 43 |
+
choices=['torch_ddp', 'deepspeed'],
|
| 44 |
+
help='Engine for paralleled training')
|
| 45 |
+
parser.add_argument('--model', required=True, help='model which will be trained')
|
| 46 |
+
parser.add_argument('--config', required=True, help='config file')
|
| 47 |
+
parser.add_argument('--train_data', required=True, help='train data file')
|
| 48 |
+
parser.add_argument('--cv_data', required=True, help='cv data file')
|
| 49 |
+
parser.add_argument('--checkpoint', help='checkpoint model')
|
| 50 |
+
parser.add_argument('--model_dir', required=True, help='save model dir')
|
| 51 |
+
parser.add_argument('--tensorboard_dir',
|
| 52 |
+
default='tensorboard',
|
| 53 |
+
help='tensorboard log dir')
|
| 54 |
+
parser.add_argument('--ddp.dist_backend',
|
| 55 |
+
dest='dist_backend',
|
| 56 |
+
default='nccl',
|
| 57 |
+
choices=['nccl', 'gloo'],
|
| 58 |
+
help='distributed backend')
|
| 59 |
+
parser.add_argument('--num_workers',
|
| 60 |
+
default=0,
|
| 61 |
+
type=int,
|
| 62 |
+
help='num of subprocess workers for reading')
|
| 63 |
+
parser.add_argument('--prefetch',
|
| 64 |
+
default=100,
|
| 65 |
+
type=int,
|
| 66 |
+
help='prefetch number')
|
| 67 |
+
parser.add_argument('--pin_memory',
|
| 68 |
+
action='store_true',
|
| 69 |
+
default=False,
|
| 70 |
+
help='Use pinned memory buffers used for reading')
|
| 71 |
+
parser.add_argument('--use_amp',
|
| 72 |
+
action='store_true',
|
| 73 |
+
default=False,
|
| 74 |
+
help='Use automatic mixed precision training')
|
| 75 |
+
parser.add_argument('--deepspeed.save_states',
|
| 76 |
+
dest='save_states',
|
| 77 |
+
default='model_only',
|
| 78 |
+
choices=['model_only', 'model+optimizer'],
|
| 79 |
+
help='save model/optimizer states')
|
| 80 |
+
parser.add_argument('--timeout',
|
| 81 |
+
default=60,
|
| 82 |
+
type=int,
|
| 83 |
+
help='timeout (in seconds) of cosyvoice_join.')
|
| 84 |
+
parser = deepspeed.add_config_arguments(parser)
|
| 85 |
+
args = parser.parse_args()
|
| 86 |
+
return args
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@record
|
| 90 |
+
def main():
|
| 91 |
+
args = get_args()
|
| 92 |
+
logging.basicConfig(level=logging.DEBUG,
|
| 93 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
| 94 |
+
# gan train has some special initialization logic
|
| 95 |
+
gan = True if args.model == 'hifigan' else False
|
| 96 |
+
|
| 97 |
+
override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
|
| 98 |
+
if gan is True:
|
| 99 |
+
override_dict.pop('hift')
|
| 100 |
+
with open(args.config, 'r') as f:
|
| 101 |
+
configs = load_hyperpyyaml(f, overrides=override_dict)
|
| 102 |
+
if gan is True:
|
| 103 |
+
configs['train_conf'] = configs['train_conf_gan']
|
| 104 |
+
configs['train_conf'].update(vars(args))
|
| 105 |
+
|
| 106 |
+
# Init env for ddp
|
| 107 |
+
init_distributed(args)
|
| 108 |
+
|
| 109 |
+
# Get dataset & dataloader
|
| 110 |
+
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
|
| 111 |
+
init_dataset_and_dataloader(args, configs, gan)
|
| 112 |
+
|
| 113 |
+
# Do some sanity checks and save config to arsg.model_dir
|
| 114 |
+
configs = check_modify_and_save_config(args, configs)
|
| 115 |
+
|
| 116 |
+
# Tensorboard summary
|
| 117 |
+
writer = init_summarywriter(args)
|
| 118 |
+
|
| 119 |
+
# load checkpoint
|
| 120 |
+
model = configs[args.model]
|
| 121 |
+
start_step, start_epoch = 0, -1
|
| 122 |
+
if args.checkpoint is not None:
|
| 123 |
+
if os.path.exists(args.checkpoint):
|
| 124 |
+
state_dict = torch.load(args.checkpoint, map_location='cpu')
|
| 125 |
+
model.load_state_dict(state_dict, strict=False)
|
| 126 |
+
if 'step' in state_dict:
|
| 127 |
+
start_step = state_dict['step']
|
| 128 |
+
if 'epoch' in state_dict:
|
| 129 |
+
start_epoch = state_dict['epoch']
|
| 130 |
+
else:
|
| 131 |
+
logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
|
| 132 |
+
|
| 133 |
+
# Dispatch model from cpu to gpu
|
| 134 |
+
model = wrap_cuda_model(args, model)
|
| 135 |
+
|
| 136 |
+
# Get optimizer & scheduler
|
| 137 |
+
model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
|
| 138 |
+
scheduler.set_step(start_step)
|
| 139 |
+
if scheduler_d is not None:
|
| 140 |
+
scheduler_d.set_step(start_step)
|
| 141 |
+
|
| 142 |
+
# Save init checkpoints
|
| 143 |
+
info_dict = deepcopy(configs['train_conf'])
|
| 144 |
+
info_dict['step'] = start_step
|
| 145 |
+
info_dict['epoch'] = start_epoch
|
| 146 |
+
save_model(model, 'init', info_dict)
|
| 147 |
+
|
| 148 |
+
# Get executor
|
| 149 |
+
executor = Executor(gan=gan)
|
| 150 |
+
executor.step = start_step
|
| 151 |
+
|
| 152 |
+
# Init scaler, used for pytorch amp mixed precision training
|
| 153 |
+
scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
|
| 154 |
+
print('start step {} start epoch {}'.format(start_step, start_epoch))
|
| 155 |
+
# Start training loop
|
| 156 |
+
for epoch in range(start_epoch + 1, info_dict['max_epoch']):
|
| 157 |
+
executor.epoch = epoch
|
| 158 |
+
train_dataset.set_epoch(epoch)
|
| 159 |
+
dist.barrier()
|
| 160 |
+
group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
|
| 161 |
+
if gan is True:
|
| 162 |
+
executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
|
| 163 |
+
writer, info_dict, scaler, group_join)
|
| 164 |
+
else:
|
| 165 |
+
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
|
| 166 |
+
dist.destroy_process_group(group_join)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
if __name__ == '__main__':
|
| 170 |
+
main()
|
cosyvoice/cli/__init__.py
ADDED
|
File without changes
|
cosyvoice/cli/cosyvoice.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import os
|
| 15 |
+
import time
|
| 16 |
+
from typing import Generator
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from hyperpyyaml import load_hyperpyyaml
|
| 19 |
+
from modelscope import snapshot_download
|
| 20 |
+
import torch
|
| 21 |
+
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
| 22 |
+
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, VllmCosyVoice2Model
|
| 23 |
+
from cosyvoice.utils.file_utils import logging
|
| 24 |
+
from cosyvoice.utils.class_utils import get_model_type
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class CosyVoice:
|
| 28 |
+
|
| 29 |
+
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
|
| 30 |
+
self.instruct = True if '-Instruct' in model_dir else False
|
| 31 |
+
self.model_dir = model_dir
|
| 32 |
+
self.fp16 = fp16
|
| 33 |
+
if not os.path.exists(model_dir):
|
| 34 |
+
model_dir = snapshot_download(model_dir)
|
| 35 |
+
with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
|
| 36 |
+
configs = load_hyperpyyaml(f)
|
| 37 |
+
assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
|
| 38 |
+
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
| 39 |
+
configs['feat_extractor'],
|
| 40 |
+
'{}/campplus.onnx'.format(model_dir),
|
| 41 |
+
'{}/speech_tokenizer_v1.onnx'.format(model_dir),
|
| 42 |
+
'{}/spk2info.pt'.format(model_dir),
|
| 43 |
+
configs['allowed_special'])
|
| 44 |
+
self.sample_rate = configs['sample_rate']
|
| 45 |
+
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
|
| 46 |
+
load_jit, load_trt, fp16 = False, False, False
|
| 47 |
+
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
|
| 48 |
+
self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
|
| 49 |
+
self.model.load('{}/llm.pt'.format(model_dir),
|
| 50 |
+
'{}/flow.pt'.format(model_dir),
|
| 51 |
+
'{}/hift.pt'.format(model_dir))
|
| 52 |
+
if load_jit:
|
| 53 |
+
self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
| 54 |
+
'{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
| 55 |
+
'{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
| 56 |
+
if load_trt:
|
| 57 |
+
self.estimator_count = configs.get('estimator_count', 1)
|
| 58 |
+
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
| 59 |
+
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
| 60 |
+
self.fp16, self.estimator_count)
|
| 61 |
+
del configs
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def list_available_spks(self):
|
| 65 |
+
spks = list(self.frontend.spk2info.keys())
|
| 66 |
+
return spks
|
| 67 |
+
|
| 68 |
+
def add_spk_info(self, spk_id, spk_info):
|
| 69 |
+
self.frontend.add_spk_info(spk_id, spk_info)
|
| 70 |
+
|
| 71 |
+
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
|
| 72 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
| 73 |
+
model_input = self.frontend.frontend_sft(i, spk_id)
|
| 74 |
+
start_time = time.time()
|
| 75 |
+
logging.info('synthesis text {}'.format(i))
|
| 76 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 77 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 78 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 79 |
+
yield model_output
|
| 80 |
+
start_time = time.time()
|
| 81 |
+
|
| 82 |
+
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
| 83 |
+
prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
|
| 84 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
| 85 |
+
if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
|
| 86 |
+
logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
|
| 87 |
+
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
|
| 88 |
+
start_time = time.time()
|
| 89 |
+
logging.info('synthesis text {}'.format(i))
|
| 90 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 91 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 92 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 93 |
+
yield model_output
|
| 94 |
+
start_time = time.time()
|
| 95 |
+
|
| 96 |
+
def inference_zero_shot_by_spk_id(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
|
| 97 |
+
"""使用预定义的说话人执行 zero_shot 推理"""
|
| 98 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
| 99 |
+
model_input = self.frontend.frontend_zero_shot_by_spk_id(i, spk_id)
|
| 100 |
+
start_time = time.time()
|
| 101 |
+
last_time = start_time
|
| 102 |
+
chunk_index = 0
|
| 103 |
+
logging.info('synthesis text {}'.format(i))
|
| 104 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 105 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 106 |
+
logging.info('yield speech index:{}, len {:.2f}, rtf {:.3f}, cost {:.3f}s, all cost time {:.3f}s'.format(
|
| 107 |
+
chunk_index, speech_len, (time.time()-last_time)/speech_len, time.time()-last_time, time.time()-start_time))
|
| 108 |
+
yield model_output
|
| 109 |
+
last_time = time.time()
|
| 110 |
+
chunk_index += 1
|
| 111 |
+
|
| 112 |
+
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
| 113 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
| 114 |
+
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
|
| 115 |
+
start_time = time.time()
|
| 116 |
+
logging.info('synthesis text {}'.format(i))
|
| 117 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 118 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 119 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 120 |
+
yield model_output
|
| 121 |
+
start_time = time.time()
|
| 122 |
+
|
| 123 |
+
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
|
| 124 |
+
assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
|
| 125 |
+
if self.instruct is False:
|
| 126 |
+
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
|
| 127 |
+
instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
|
| 128 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
| 129 |
+
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
|
| 130 |
+
start_time = time.time()
|
| 131 |
+
logging.info('synthesis text {}'.format(i))
|
| 132 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 133 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 134 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 135 |
+
yield model_output
|
| 136 |
+
start_time = time.time()
|
| 137 |
+
|
| 138 |
+
def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
|
| 139 |
+
model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
|
| 140 |
+
start_time = time.time()
|
| 141 |
+
for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
|
| 142 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 143 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 144 |
+
yield model_output
|
| 145 |
+
start_time = time.time()
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class CosyVoice2(CosyVoice):
|
| 149 |
+
|
| 150 |
+
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_vllm=False):
|
| 151 |
+
self.instruct = True if '-Instruct' in model_dir else False
|
| 152 |
+
self.model_dir = model_dir
|
| 153 |
+
self.fp16 = fp16
|
| 154 |
+
if not os.path.exists(model_dir):
|
| 155 |
+
model_dir = snapshot_download(model_dir)
|
| 156 |
+
with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
|
| 157 |
+
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
|
| 158 |
+
assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
|
| 159 |
+
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
| 160 |
+
configs['feat_extractor'],
|
| 161 |
+
'{}/campplus.onnx'.format(model_dir),
|
| 162 |
+
'{}/speech_tokenizer_v2.onnx'.format(model_dir),
|
| 163 |
+
'{}/spk2info.pt'.format(model_dir),
|
| 164 |
+
configs['allowed_special'])
|
| 165 |
+
self.sample_rate = configs['sample_rate']
|
| 166 |
+
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
|
| 167 |
+
load_jit, load_trt, fp16 = False, False, False
|
| 168 |
+
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
|
| 169 |
+
if use_vllm:
|
| 170 |
+
try:
|
| 171 |
+
self.model = VllmCosyVoice2Model(model_dir, configs['flow'], configs['hift'], fp16)
|
| 172 |
+
except Exception as e:
|
| 173 |
+
logging.warning(f'use vllm inference failed. \n{e}')
|
| 174 |
+
raise e
|
| 175 |
+
else:
|
| 176 |
+
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
|
| 177 |
+
self.model.load('{}/llm.pt'.format(model_dir),
|
| 178 |
+
'{}/flow.pt'.format(model_dir),
|
| 179 |
+
'{}/hift.pt'.format(model_dir))
|
| 180 |
+
if load_jit:
|
| 181 |
+
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
| 182 |
+
if load_trt:
|
| 183 |
+
self.estimator_count = configs.get('estimator_count', 1)
|
| 184 |
+
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
| 185 |
+
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
| 186 |
+
self.fp16, self.estimator_count)
|
| 187 |
+
del configs
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def inference_instruct(self, *args, **kwargs):
|
| 191 |
+
raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
|
| 192 |
+
|
| 193 |
+
def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
| 194 |
+
assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
|
| 195 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
| 196 |
+
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
|
| 197 |
+
start_time = time.time()
|
| 198 |
+
logging.info('synthesis text {}'.format(i))
|
| 199 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 200 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 201 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 202 |
+
yield model_output
|
| 203 |
+
start_time = time.time()
|
| 204 |
+
|
| 205 |
+
def inference_instruct2_by_spk_id(self, tts_text, instruct_text, spk_id, stream=False, speed=1.0, text_frontend=True):
|
| 206 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
| 207 |
+
model_input = self.frontend.frontend_instruct2_by_spk_id(i, instruct_text, spk_id)
|
| 208 |
+
start_time = time.time()
|
| 209 |
+
logging.info('synthesis text {}'.format(i))
|
| 210 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 211 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 212 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 213 |
+
yield model_output
|
| 214 |
+
start_time = time.time()
|
cosyvoice/cli/frontend.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from functools import partial
|
| 15 |
+
from typing import Generator, Optional
|
| 16 |
+
import json
|
| 17 |
+
import onnxruntime
|
| 18 |
+
import torch
|
| 19 |
+
import numpy as np
|
| 20 |
+
import whisper
|
| 21 |
+
from typing import Callable
|
| 22 |
+
import torchaudio.compliance.kaldi as kaldi
|
| 23 |
+
import torchaudio
|
| 24 |
+
import os
|
| 25 |
+
import re
|
| 26 |
+
import inflect
|
| 27 |
+
from pydantic import BaseModel, ConfigDict
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
import ttsfrd
|
| 31 |
+
use_ttsfrd = True
|
| 32 |
+
except ImportError:
|
| 33 |
+
print("failed to import ttsfrd, use WeTextProcessing instead")
|
| 34 |
+
from tn.chinese.normalizer import Normalizer as ZhNormalizer
|
| 35 |
+
from tn.english.normalizer import Normalizer as EnNormalizer
|
| 36 |
+
use_ttsfrd = False
|
| 37 |
+
from cosyvoice.utils.file_utils import logging
|
| 38 |
+
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class SpeakerInfo(BaseModel):
|
| 42 |
+
model_config = ConfigDict(arbitrary_types_allowed=True)
|
| 43 |
+
|
| 44 |
+
name: Optional[str] = None
|
| 45 |
+
spk_id: str
|
| 46 |
+
prompt_text: str
|
| 47 |
+
prompt_text_token: torch.Tensor
|
| 48 |
+
speech_feat: torch.Tensor
|
| 49 |
+
speech_token: torch.Tensor
|
| 50 |
+
embedding: torch.Tensor
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class CosyVoiceFrontEnd:
|
| 54 |
+
|
| 55 |
+
def __init__(self,
|
| 56 |
+
get_tokenizer: Callable,
|
| 57 |
+
feat_extractor: Callable,
|
| 58 |
+
campplus_model: str,
|
| 59 |
+
speech_tokenizer_model: str,
|
| 60 |
+
spk2info: str = '',
|
| 61 |
+
allowed_special: str = 'all'):
|
| 62 |
+
self.tokenizer = get_tokenizer()
|
| 63 |
+
self.feat_extractor = feat_extractor
|
| 64 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 65 |
+
option = onnxruntime.SessionOptions()
|
| 66 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 67 |
+
option.intra_op_num_threads = 1
|
| 68 |
+
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
|
| 69 |
+
self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
|
| 70 |
+
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
|
| 71 |
+
"CPUExecutionProvider"])
|
| 72 |
+
self.spk2info_path = spk2info
|
| 73 |
+
if os.path.exists(spk2info):
|
| 74 |
+
self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=False)
|
| 75 |
+
else:
|
| 76 |
+
self.spk2info = {}
|
| 77 |
+
self.allowed_special = allowed_special
|
| 78 |
+
self.use_ttsfrd = use_ttsfrd
|
| 79 |
+
if self.use_ttsfrd:
|
| 80 |
+
self.frd = ttsfrd.TtsFrontendEngine()
|
| 81 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 82 |
+
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
|
| 83 |
+
'failed to initialize ttsfrd resource'
|
| 84 |
+
self.frd.set_lang_type('pinyinvg')
|
| 85 |
+
else:
|
| 86 |
+
# self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
|
| 87 |
+
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=False)
|
| 88 |
+
self.en_tn_model = EnNormalizer()
|
| 89 |
+
self.inflect_parser = inflect.engine()
|
| 90 |
+
|
| 91 |
+
def _extract_text_token(self, text):
|
| 92 |
+
if isinstance(text, Generator):
|
| 93 |
+
logging.info('get tts_text generator, will return _extract_text_token_generator!')
|
| 94 |
+
# NOTE add a dummy text_token_len for compatibility
|
| 95 |
+
return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
|
| 96 |
+
else:
|
| 97 |
+
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
|
| 98 |
+
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
|
| 99 |
+
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
|
| 100 |
+
return text_token, text_token_len
|
| 101 |
+
|
| 102 |
+
def _extract_text_token_generator(self, text_generator):
|
| 103 |
+
for text in text_generator:
|
| 104 |
+
text_token, _ = self._extract_text_token(text)
|
| 105 |
+
for i in range(text_token.shape[1]):
|
| 106 |
+
yield text_token[:, i: i + 1]
|
| 107 |
+
|
| 108 |
+
def _extract_speech_token(self, speech):
|
| 109 |
+
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
|
| 110 |
+
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
|
| 111 |
+
speech_token = self.speech_tokenizer_session.run(None,
|
| 112 |
+
{self.speech_tokenizer_session.get_inputs()[0].name:
|
| 113 |
+
feat.detach().cpu().numpy(),
|
| 114 |
+
self.speech_tokenizer_session.get_inputs()[1].name:
|
| 115 |
+
np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
|
| 116 |
+
speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
|
| 117 |
+
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
|
| 118 |
+
return speech_token, speech_token_len
|
| 119 |
+
|
| 120 |
+
def _extract_spk_embedding(self, speech):
|
| 121 |
+
feat = kaldi.fbank(speech,
|
| 122 |
+
num_mel_bins=80,
|
| 123 |
+
dither=0,
|
| 124 |
+
sample_frequency=16000)
|
| 125 |
+
feat = feat - feat.mean(dim=0, keepdim=True)
|
| 126 |
+
embedding = self.campplus_session.run(None,
|
| 127 |
+
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
| 128 |
+
embedding = torch.tensor([embedding]).to(self.device)
|
| 129 |
+
return embedding
|
| 130 |
+
|
| 131 |
+
def _extract_speech_feat(self, speech):
|
| 132 |
+
speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
|
| 133 |
+
speech_feat = speech_feat.unsqueeze(dim=0)
|
| 134 |
+
speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
|
| 135 |
+
return speech_feat, speech_feat_len
|
| 136 |
+
|
| 137 |
+
def text_normalize(self, text, split=True, text_frontend=True):
|
| 138 |
+
if isinstance(text, Generator):
|
| 139 |
+
logging.info('get tts_text generator, will skip text_normalize!')
|
| 140 |
+
return [text]
|
| 141 |
+
if text_frontend is False:
|
| 142 |
+
return [text] if split is True else text
|
| 143 |
+
text = text.strip()
|
| 144 |
+
if self.use_ttsfrd:
|
| 145 |
+
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
|
| 146 |
+
text = ''.join(texts)
|
| 147 |
+
else:
|
| 148 |
+
if contains_chinese(text):
|
| 149 |
+
text = self.zh_tn_model.normalize(text)
|
| 150 |
+
text = text.replace("\n", "")
|
| 151 |
+
text = replace_blank(text)
|
| 152 |
+
text = replace_corner_mark(text)
|
| 153 |
+
text = text.replace(".", "。")
|
| 154 |
+
text = text.replace(" - ", ",")
|
| 155 |
+
text = remove_bracket(text)
|
| 156 |
+
text = re.sub(r'[,,、]+$', '。', text)
|
| 157 |
+
if not split:
|
| 158 |
+
return text
|
| 159 |
+
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
|
| 160 |
+
token_min_n=60, merge_len=20, comma_split=False))
|
| 161 |
+
else:
|
| 162 |
+
text = self.en_tn_model.normalize(text)
|
| 163 |
+
text = spell_out_number(text, self.inflect_parser)
|
| 164 |
+
if not split:
|
| 165 |
+
return text
|
| 166 |
+
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
|
| 167 |
+
token_min_n=60, merge_len=20, comma_split=False))
|
| 168 |
+
texts = [i for i in texts if not is_only_punctuation(i)]
|
| 169 |
+
return texts if split is True else text
|
| 170 |
+
|
| 171 |
+
def frontend_sft(self, tts_text, spk_id):
|
| 172 |
+
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
| 173 |
+
embedding = self.spk2info[spk_id]['embedding']
|
| 174 |
+
assert embedding is not None
|
| 175 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
|
| 176 |
+
return model_input
|
| 177 |
+
|
| 178 |
+
def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
|
| 179 |
+
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
| 180 |
+
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
|
| 181 |
+
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
|
| 182 |
+
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
|
| 183 |
+
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
| 184 |
+
if resample_rate == 24000:
|
| 185 |
+
# cosyvoice2, force speech_feat % speech_token = 2
|
| 186 |
+
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
|
| 187 |
+
speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
|
| 188 |
+
speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
|
| 189 |
+
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
| 190 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
| 191 |
+
'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
|
| 192 |
+
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
| 193 |
+
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
| 194 |
+
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
| 195 |
+
'llm_embedding': embedding, 'flow_embedding': embedding}
|
| 196 |
+
return model_input
|
| 197 |
+
|
| 198 |
+
def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
|
| 199 |
+
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
|
| 200 |
+
# in cross lingual mode, we remove prompt in llm
|
| 201 |
+
del model_input['prompt_text']
|
| 202 |
+
del model_input['prompt_text_len']
|
| 203 |
+
del model_input['llm_prompt_speech_token']
|
| 204 |
+
del model_input['llm_prompt_speech_token_len']
|
| 205 |
+
return model_input
|
| 206 |
+
|
| 207 |
+
def frontend_instruct(self, tts_text, spk_id, instruct_text):
|
| 208 |
+
model_input = self.frontend_sft(tts_text, spk_id)
|
| 209 |
+
# in instruct mode, we remove spk_embedding in llm due to information leakage
|
| 210 |
+
del model_input['llm_embedding']
|
| 211 |
+
instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
|
| 212 |
+
model_input['prompt_text'] = instruct_text_token
|
| 213 |
+
model_input['prompt_text_len'] = instruct_text_token_len
|
| 214 |
+
return model_input
|
| 215 |
+
|
| 216 |
+
def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
|
| 217 |
+
model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate)
|
| 218 |
+
del model_input['llm_prompt_speech_token']
|
| 219 |
+
del model_input['llm_prompt_speech_token_len']
|
| 220 |
+
return model_input
|
| 221 |
+
|
| 222 |
+
def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
|
| 223 |
+
prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
| 224 |
+
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
|
| 225 |
+
prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
|
| 226 |
+
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
| 227 |
+
source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
|
| 228 |
+
model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
|
| 229 |
+
'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
|
| 230 |
+
'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
|
| 231 |
+
'flow_embedding': embedding}
|
| 232 |
+
return model_input
|
| 233 |
+
|
| 234 |
+
def generate_spk_info(self, spk_id: str, prompt_text: str, prompt_speech_16k: torch.Tensor, resample_rate:int=24000, name: str=None):
|
| 235 |
+
assert isinstance(spk_id, str)
|
| 236 |
+
assert spk_id not in self.spk2info, "spk_id already exists"
|
| 237 |
+
prompt_text_token, _ = self._extract_text_token(prompt_text)
|
| 238 |
+
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
|
| 239 |
+
speech_feat, _ = self._extract_speech_feat(prompt_speech_resample)
|
| 240 |
+
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
| 241 |
+
if resample_rate == 24000:
|
| 242 |
+
# cosyvoice2, force speech_feat % speech_token = 2
|
| 243 |
+
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
|
| 244 |
+
speech_feat = speech_feat[:, :2 * token_len]
|
| 245 |
+
speech_token = speech_token[:, :token_len]
|
| 246 |
+
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
| 247 |
+
spk_info = SpeakerInfo(
|
| 248 |
+
name=name,
|
| 249 |
+
spk_id=spk_id,
|
| 250 |
+
prompt_text=prompt_text,
|
| 251 |
+
prompt_text_token=prompt_text_token,
|
| 252 |
+
speech_feat=speech_feat,
|
| 253 |
+
speech_token=speech_token,
|
| 254 |
+
embedding=embedding,
|
| 255 |
+
)
|
| 256 |
+
self.add_spk_info(spk_id, spk_info)
|
| 257 |
+
|
| 258 |
+
def add_spk_info(self, spk_id: str, spk_info: dict|SpeakerInfo):
|
| 259 |
+
if isinstance(spk_info, BaseModel):
|
| 260 |
+
spk_info = spk_info.model_dump()
|
| 261 |
+
self.spk2info[spk_id] = spk_info
|
| 262 |
+
if self.spk2info_path:
|
| 263 |
+
torch.save(self.spk2info, self.spk2info_path)
|
| 264 |
+
|
| 265 |
+
def frontend_instruct2_by_spk_id(self, tts_text, instruct_text, spk_id):
|
| 266 |
+
assert spk_id in self.spk2info
|
| 267 |
+
tts_text_token, _ = self._extract_text_token(tts_text)
|
| 268 |
+
prompt_text_token, _ = self._extract_text_token(instruct_text + '<|endofprompt|>')
|
| 269 |
+
model_input = {'text': tts_text_token,
|
| 270 |
+
'prompt_text': prompt_text_token,
|
| 271 |
+
'flow_prompt_speech_token': self.spk2info[spk_id]['speech_token'],
|
| 272 |
+
'prompt_speech_feat': self.spk2info[spk_id]['speech_feat'],
|
| 273 |
+
'llm_embedding': self.spk2info[spk_id]['embedding'],
|
| 274 |
+
'flow_embedding': self.spk2info[spk_id]['embedding'],
|
| 275 |
+
}
|
| 276 |
+
return model_input
|
| 277 |
+
|
| 278 |
+
def frontend_zero_shot_by_spk_id(self, tts_text, spk_id):
|
| 279 |
+
assert spk_id in self.spk2info
|
| 280 |
+
tts_text_token, _ = self._extract_text_token(tts_text)
|
| 281 |
+
model_input = {'text': tts_text_token,
|
| 282 |
+
'prompt_text': self.spk2info[spk_id]['prompt_text_token'],
|
| 283 |
+
'llm_prompt_speech_token': self.spk2info[spk_id]['speech_token'],
|
| 284 |
+
'flow_prompt_speech_token': self.spk2info[spk_id]['speech_token'],
|
| 285 |
+
'prompt_speech_feat': self.spk2info[spk_id]['speech_feat'],
|
| 286 |
+
'llm_embedding': self.spk2info[spk_id]['embedding'],
|
| 287 |
+
'flow_embedding': self.spk2info[spk_id]['embedding']
|
| 288 |
+
}
|
| 289 |
+
return model_input
|
cosyvoice/cli/model.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import os
|
| 15 |
+
from typing import Generator
|
| 16 |
+
import torch
|
| 17 |
+
import numpy as np
|
| 18 |
+
import threading
|
| 19 |
+
import time
|
| 20 |
+
from torch.nn import functional as F
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
import uuid
|
| 23 |
+
from cosyvoice.utils.common import fade_in_out
|
| 24 |
+
from cosyvoice.utils.file_utils import convert_onnx_to_trt
|
| 25 |
+
from cosyvoice.flow.flow_matching import EstimatorWrapper
|
| 26 |
+
import queue
|
| 27 |
+
|
| 28 |
+
class CosyVoiceModel:
|
| 29 |
+
|
| 30 |
+
def __init__(self,
|
| 31 |
+
llm: torch.nn.Module,
|
| 32 |
+
flow: torch.nn.Module,
|
| 33 |
+
hift: torch.nn.Module,
|
| 34 |
+
fp16: bool):
|
| 35 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 36 |
+
self.llm = llm
|
| 37 |
+
self.flow = flow
|
| 38 |
+
self.hift = hift
|
| 39 |
+
self.fp16 = fp16
|
| 40 |
+
self.llm.fp16 = fp16
|
| 41 |
+
self.flow.fp16 = fp16
|
| 42 |
+
if self.fp16 is True:
|
| 43 |
+
self.llm.half()
|
| 44 |
+
self.flow.half()
|
| 45 |
+
self.token_min_hop_len = 2 * self.flow.input_frame_rate
|
| 46 |
+
self.token_max_hop_len = 4 * self.flow.input_frame_rate
|
| 47 |
+
self.token_overlap_len = 20
|
| 48 |
+
# here we fix set flow.decoder.estimator.static_chunk_size = 0 for compatibability
|
| 49 |
+
self.flow.decoder.estimator.static_chunk_size = 0
|
| 50 |
+
# mel fade in out
|
| 51 |
+
self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
|
| 52 |
+
self.mel_window = np.hamming(2 * self.mel_overlap_len)
|
| 53 |
+
# hift cache
|
| 54 |
+
self.mel_cache_len = 20
|
| 55 |
+
self.source_cache_len = int(self.mel_cache_len * 256)
|
| 56 |
+
# speech fade in out
|
| 57 |
+
self.speech_window = np.hamming(2 * self.source_cache_len)
|
| 58 |
+
# rtf and decoding related
|
| 59 |
+
self.stream_scale_factor = 1
|
| 60 |
+
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
|
| 61 |
+
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
| 62 |
+
self.lock = threading.Lock()
|
| 63 |
+
# dict used to store session related variable
|
| 64 |
+
self.tts_speech_token_dict = {}
|
| 65 |
+
self.llm_end_dict = {}
|
| 66 |
+
self.mel_overlap_dict = {}
|
| 67 |
+
self.flow_cache_dict = {}
|
| 68 |
+
self.hift_cache_dict = {}
|
| 69 |
+
|
| 70 |
+
self.stream_context_pool = queue.Queue()
|
| 71 |
+
for _ in range(10):
|
| 72 |
+
self.stream_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
|
| 73 |
+
|
| 74 |
+
self.is_cuda_available = torch.cuda.is_available()
|
| 75 |
+
|
| 76 |
+
def load(self, llm_model, flow_model, hift_model):
|
| 77 |
+
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
|
| 78 |
+
self.llm.to(self.device).eval()
|
| 79 |
+
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
|
| 80 |
+
self.flow.to(self.device).eval()
|
| 81 |
+
# in case hift_model is a hifigan model
|
| 82 |
+
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
|
| 83 |
+
self.hift.load_state_dict(hift_state_dict, strict=True)
|
| 84 |
+
self.hift.to(self.device).eval()
|
| 85 |
+
|
| 86 |
+
def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
|
| 87 |
+
llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
|
| 88 |
+
self.llm.text_encoder = llm_text_encoder
|
| 89 |
+
llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
|
| 90 |
+
self.llm.llm = llm_llm
|
| 91 |
+
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
| 92 |
+
self.flow.encoder = flow_encoder
|
| 93 |
+
|
| 94 |
+
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16, estimator_count=8): # use 8 estimators
|
| 95 |
+
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
| 96 |
+
if not os.path.exists(flow_decoder_estimator_model):
|
| 97 |
+
convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
|
| 98 |
+
if os.path.getsize(flow_decoder_estimator_model) == 0:
|
| 99 |
+
raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model))
|
| 100 |
+
del self.flow.decoder.estimator
|
| 101 |
+
import tensorrt as trt
|
| 102 |
+
with open(flow_decoder_estimator_model, 'rb') as f:
|
| 103 |
+
self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
| 104 |
+
if self.flow.decoder.estimator_engine is None:
|
| 105 |
+
raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
|
| 106 |
+
self.flow.decoder.estimator = EstimatorWrapper(self.flow.decoder.estimator_engine, estimator_count=estimator_count)
|
| 107 |
+
|
| 108 |
+
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
| 109 |
+
with self.llm_context:
|
| 110 |
+
if isinstance(text, Generator):
|
| 111 |
+
assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
|
| 112 |
+
for i in self.llm.inference_bistream(text=text,
|
| 113 |
+
prompt_text=prompt_text.to(self.device),
|
| 114 |
+
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
| 115 |
+
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
| 116 |
+
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
| 117 |
+
embedding=llm_embedding.to(self.device)):
|
| 118 |
+
self.tts_speech_token_dict[uuid].append(i)
|
| 119 |
+
else:
|
| 120 |
+
for i in self.llm.inference(text=text.to(self.device),
|
| 121 |
+
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
|
| 122 |
+
prompt_text=prompt_text.to(self.device),
|
| 123 |
+
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
| 124 |
+
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
| 125 |
+
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
| 126 |
+
embedding=llm_embedding.to(self.device)):
|
| 127 |
+
self.tts_speech_token_dict[uuid].append(i)
|
| 128 |
+
self.llm_end_dict[uuid] = True
|
| 129 |
+
|
| 130 |
+
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
|
| 131 |
+
tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
|
| 132 |
+
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
| 133 |
+
prompt_token=prompt_token.to(self.device),
|
| 134 |
+
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
| 135 |
+
prompt_feat=prompt_feat.to(self.device),
|
| 136 |
+
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
| 137 |
+
embedding=embedding.to(self.device),
|
| 138 |
+
flow_cache=self.flow_cache_dict[uuid])
|
| 139 |
+
self.flow_cache_dict[uuid] = flow_cache
|
| 140 |
+
|
| 141 |
+
# mel overlap fade in out
|
| 142 |
+
if self.mel_overlap_dict[uuid].shape[2] != 0:
|
| 143 |
+
tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
|
| 144 |
+
# append hift cache
|
| 145 |
+
if self.hift_cache_dict[uuid] is not None:
|
| 146 |
+
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
| 147 |
+
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
| 148 |
+
else:
|
| 149 |
+
hift_cache_source = torch.zeros(1, 1, 0)
|
| 150 |
+
# keep overlap mel and hift cache
|
| 151 |
+
if finalize is False:
|
| 152 |
+
self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
|
| 153 |
+
tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
|
| 154 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
| 155 |
+
if self.hift_cache_dict[uuid] is not None:
|
| 156 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
| 157 |
+
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
| 158 |
+
'source': tts_source[:, :, -self.source_cache_len:],
|
| 159 |
+
'speech': tts_speech[:, -self.source_cache_len:]}
|
| 160 |
+
tts_speech = tts_speech[:, :-self.source_cache_len]
|
| 161 |
+
else:
|
| 162 |
+
if speed != 1.0:
|
| 163 |
+
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
| 164 |
+
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
| 165 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
| 166 |
+
if self.hift_cache_dict[uuid] is not None:
|
| 167 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
| 168 |
+
return tts_speech
|
| 169 |
+
|
| 170 |
+
def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
| 171 |
+
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
| 172 |
+
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
| 173 |
+
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
| 174 |
+
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
|
| 175 |
+
# this_uuid is used to track variables related to this inference thread
|
| 176 |
+
|
| 177 |
+
stream_context = self.stream_context_pool.get()
|
| 178 |
+
with stream_context:
|
| 179 |
+
|
| 180 |
+
this_uuid = str(uuid.uuid1())
|
| 181 |
+
with self.lock:
|
| 182 |
+
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
| 183 |
+
self.hift_cache_dict[this_uuid] = None
|
| 184 |
+
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
|
| 185 |
+
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
|
| 186 |
+
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
| 187 |
+
p.start()
|
| 188 |
+
if stream is True:
|
| 189 |
+
token_hop_len = self.token_min_hop_len
|
| 190 |
+
while True:
|
| 191 |
+
time.sleep(0.1)
|
| 192 |
+
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
|
| 193 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
|
| 194 |
+
.unsqueeze(dim=0)
|
| 195 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 196 |
+
prompt_token=flow_prompt_speech_token,
|
| 197 |
+
prompt_feat=prompt_speech_feat,
|
| 198 |
+
embedding=flow_embedding,
|
| 199 |
+
uuid=this_uuid,
|
| 200 |
+
finalize=False)
|
| 201 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 202 |
+
with self.lock:
|
| 203 |
+
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
|
| 204 |
+
# increase token_hop_len for better speech quality
|
| 205 |
+
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
|
| 206 |
+
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
|
| 207 |
+
break
|
| 208 |
+
p.join()
|
| 209 |
+
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
| 210 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
| 211 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 212 |
+
prompt_token=flow_prompt_speech_token,
|
| 213 |
+
prompt_feat=prompt_speech_feat,
|
| 214 |
+
embedding=flow_embedding,
|
| 215 |
+
uuid=this_uuid,
|
| 216 |
+
finalize=True)
|
| 217 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 218 |
+
else:
|
| 219 |
+
# deal with all tokens
|
| 220 |
+
p.join()
|
| 221 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
| 222 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 223 |
+
prompt_token=flow_prompt_speech_token,
|
| 224 |
+
prompt_feat=prompt_speech_feat,
|
| 225 |
+
embedding=flow_embedding,
|
| 226 |
+
uuid=this_uuid,
|
| 227 |
+
finalize=True,
|
| 228 |
+
speed=speed)
|
| 229 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 230 |
+
with self.lock:
|
| 231 |
+
self.tts_speech_token_dict.pop(this_uuid)
|
| 232 |
+
self.llm_end_dict.pop(this_uuid)
|
| 233 |
+
self.mel_overlap_dict.pop(this_uuid)
|
| 234 |
+
self.hift_cache_dict.pop(this_uuid)
|
| 235 |
+
self.flow_cache_dict.pop(this_uuid)
|
| 236 |
+
|
| 237 |
+
self.synchronize_stream()
|
| 238 |
+
self.stream_context_pool.put(stream_context)
|
| 239 |
+
torch.cuda.empty_cache()
|
| 240 |
+
|
| 241 |
+
def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
|
| 242 |
+
# this_uuid is used to track variables related to this inference thread
|
| 243 |
+
this_uuid = str(uuid.uuid1())
|
| 244 |
+
with self.lock:
|
| 245 |
+
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
|
| 246 |
+
self.hift_cache_dict[this_uuid] = None
|
| 247 |
+
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
|
| 248 |
+
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
|
| 249 |
+
if stream is True:
|
| 250 |
+
token_hop_len = self.token_min_hop_len
|
| 251 |
+
while True:
|
| 252 |
+
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
|
| 253 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
|
| 254 |
+
.unsqueeze(dim=0)
|
| 255 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 256 |
+
prompt_token=flow_prompt_speech_token,
|
| 257 |
+
prompt_feat=prompt_speech_feat,
|
| 258 |
+
embedding=flow_embedding,
|
| 259 |
+
uuid=this_uuid,
|
| 260 |
+
finalize=False)
|
| 261 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 262 |
+
with self.lock:
|
| 263 |
+
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
|
| 264 |
+
# increase token_hop_len for better speech quality
|
| 265 |
+
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
|
| 266 |
+
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
|
| 267 |
+
break
|
| 268 |
+
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
| 269 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
| 270 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 271 |
+
prompt_token=flow_prompt_speech_token,
|
| 272 |
+
prompt_feat=prompt_speech_feat,
|
| 273 |
+
embedding=flow_embedding,
|
| 274 |
+
uuid=this_uuid,
|
| 275 |
+
finalize=True)
|
| 276 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 277 |
+
else:
|
| 278 |
+
# deal with all tokens
|
| 279 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
| 280 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 281 |
+
prompt_token=flow_prompt_speech_token,
|
| 282 |
+
prompt_feat=prompt_speech_feat,
|
| 283 |
+
embedding=flow_embedding,
|
| 284 |
+
uuid=this_uuid,
|
| 285 |
+
finalize=True,
|
| 286 |
+
speed=speed)
|
| 287 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 288 |
+
with self.lock:
|
| 289 |
+
self.tts_speech_token_dict.pop(this_uuid)
|
| 290 |
+
self.llm_end_dict.pop(this_uuid)
|
| 291 |
+
self.mel_overlap_dict.pop(this_uuid)
|
| 292 |
+
self.hift_cache_dict.pop(this_uuid)
|
| 293 |
+
torch.cuda.empty_cache()
|
| 294 |
+
|
| 295 |
+
def synchronize_stream(self):
|
| 296 |
+
if self.is_cuda_available:
|
| 297 |
+
torch.cuda.current_stream().synchronize()
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class CosyVoice2Model(CosyVoiceModel):
|
| 301 |
+
|
| 302 |
+
def __init__(self,
|
| 303 |
+
llm: torch.nn.Module,
|
| 304 |
+
flow: torch.nn.Module,
|
| 305 |
+
hift: torch.nn.Module,
|
| 306 |
+
fp16: bool):
|
| 307 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 308 |
+
self.llm = llm
|
| 309 |
+
self.flow = flow
|
| 310 |
+
self.hift = hift
|
| 311 |
+
self.fp16 = fp16
|
| 312 |
+
self.llm.fp16 = fp16
|
| 313 |
+
self.flow.fp16 = fp16
|
| 314 |
+
if self.fp16 is True:
|
| 315 |
+
self.llm.half()
|
| 316 |
+
self.flow.half()
|
| 317 |
+
self.token_hop_len = 2 * self.flow.input_frame_rate
|
| 318 |
+
# here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
|
| 319 |
+
self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
|
| 320 |
+
self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
|
| 321 |
+
# hift cache
|
| 322 |
+
self.mel_cache_len = 8
|
| 323 |
+
self.source_cache_len = int(self.mel_cache_len * 480)
|
| 324 |
+
# speech fade in out
|
| 325 |
+
self.speech_window = np.hamming(2 * self.source_cache_len)
|
| 326 |
+
# rtf and decoding related
|
| 327 |
+
self.stream_scale_factor = 1
|
| 328 |
+
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
| 329 |
+
self.lock = threading.Lock()
|
| 330 |
+
# dict used to store session related variable
|
| 331 |
+
self.tts_speech_token_dict = {}
|
| 332 |
+
self.llm_end_dict = {}
|
| 333 |
+
self.hift_cache_dict = {}
|
| 334 |
+
|
| 335 |
+
self.stream_context_pool = queue.Queue()
|
| 336 |
+
for _ in range(10):
|
| 337 |
+
self.stream_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
|
| 338 |
+
|
| 339 |
+
self.is_cuda_available = torch.cuda.is_available()
|
| 340 |
+
|
| 341 |
+
def load_jit(self, flow_encoder_model):
|
| 342 |
+
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
| 343 |
+
self.flow.encoder = flow_encoder
|
| 344 |
+
|
| 345 |
+
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
|
| 346 |
+
|
| 347 |
+
tts_mel, _ = self.flow.inference(token=token.to(self.device),
|
| 348 |
+
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
| 349 |
+
prompt_token=prompt_token.to(self.device),
|
| 350 |
+
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
| 351 |
+
prompt_feat=prompt_feat.to(self.device),
|
| 352 |
+
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
| 353 |
+
embedding=embedding.to(self.device),
|
| 354 |
+
finalize=finalize)
|
| 355 |
+
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
|
| 356 |
+
# append hift cache
|
| 357 |
+
if self.hift_cache_dict[uuid] is not None:
|
| 358 |
+
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
| 359 |
+
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
| 360 |
+
else:
|
| 361 |
+
hift_cache_source = torch.zeros(1, 1, 0)
|
| 362 |
+
# keep overlap mel and hift cache
|
| 363 |
+
if finalize is False:
|
| 364 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
| 365 |
+
if self.hift_cache_dict[uuid] is not None:
|
| 366 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
| 367 |
+
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
| 368 |
+
'source': tts_source[:, :, -self.source_cache_len:],
|
| 369 |
+
'speech': tts_speech[:, -self.source_cache_len:]}
|
| 370 |
+
tts_speech = tts_speech[:, :-self.source_cache_len]
|
| 371 |
+
else:
|
| 372 |
+
if speed != 1.0:
|
| 373 |
+
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
| 374 |
+
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
| 375 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
| 376 |
+
if self.hift_cache_dict[uuid] is not None:
|
| 377 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
| 378 |
+
return tts_speech
|
| 379 |
+
|
| 380 |
+
def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
| 381 |
+
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
| 382 |
+
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
| 383 |
+
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
| 384 |
+
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
|
| 385 |
+
# this_uuid is used to track variables related to this inference thread
|
| 386 |
+
self.synchronize_stream()
|
| 387 |
+
stream_context = self.stream_context_pool.get()
|
| 388 |
+
with stream_context:
|
| 389 |
+
|
| 390 |
+
this_uuid = str(uuid.uuid1())
|
| 391 |
+
with self.lock:
|
| 392 |
+
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
| 393 |
+
self.hift_cache_dict[this_uuid] = None
|
| 394 |
+
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
| 395 |
+
p.start()
|
| 396 |
+
if stream is True:
|
| 397 |
+
token_offset = 0
|
| 398 |
+
while True:
|
| 399 |
+
time.sleep(0.1)
|
| 400 |
+
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
|
| 401 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
|
| 402 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 403 |
+
prompt_token=flow_prompt_speech_token,
|
| 404 |
+
prompt_feat=prompt_speech_feat,
|
| 405 |
+
embedding=flow_embedding,
|
| 406 |
+
uuid=this_uuid,
|
| 407 |
+
token_offset=token_offset,
|
| 408 |
+
finalize=False)
|
| 409 |
+
token_offset += self.token_hop_len
|
| 410 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 411 |
+
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
|
| 412 |
+
break
|
| 413 |
+
p.join()
|
| 414 |
+
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
| 415 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
| 416 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 417 |
+
prompt_token=flow_prompt_speech_token,
|
| 418 |
+
prompt_feat=prompt_speech_feat,
|
| 419 |
+
embedding=flow_embedding,
|
| 420 |
+
uuid=this_uuid,
|
| 421 |
+
token_offset=token_offset,
|
| 422 |
+
finalize=True)
|
| 423 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 424 |
+
else:
|
| 425 |
+
# deal with all tokens
|
| 426 |
+
p.join()
|
| 427 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
| 428 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 429 |
+
prompt_token=flow_prompt_speech_token,
|
| 430 |
+
prompt_feat=prompt_speech_feat,
|
| 431 |
+
embedding=flow_embedding,
|
| 432 |
+
uuid=this_uuid,
|
| 433 |
+
token_offset=0,
|
| 434 |
+
finalize=True,
|
| 435 |
+
speed=speed)
|
| 436 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 437 |
+
with self.lock:
|
| 438 |
+
self.tts_speech_token_dict.pop(this_uuid)
|
| 439 |
+
self.llm_end_dict.pop(this_uuid)
|
| 440 |
+
|
| 441 |
+
self.synchronize_stream()
|
| 442 |
+
self.stream_context_pool.put(stream_context)
|
| 443 |
+
torch.cuda.empty_cache()
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class VllmCosyVoice2Model(CosyVoice2Model):
|
| 447 |
+
def __init__(self,
|
| 448 |
+
model_dir: str,
|
| 449 |
+
flow: torch.nn.Module,
|
| 450 |
+
hift: torch.nn.Module,
|
| 451 |
+
fp16: bool):
|
| 452 |
+
try:
|
| 453 |
+
from cosyvoice.llm.llm_vllm import VllmQwen2LM
|
| 454 |
+
except Exception as e:
|
| 455 |
+
raise e
|
| 456 |
+
llm = VllmQwen2LM(model_dir)
|
| 457 |
+
super().__init__(llm,flow,hift,fp16)
|
| 458 |
+
|
| 459 |
+
def load(self, llm_model, flow_model, hift_model):
|
| 460 |
+
self.flow.load_state_dict(torch.load(flow_model, weights_only=True, map_location=self.device), strict=True)
|
| 461 |
+
self.flow.to(self.device).eval()
|
| 462 |
+
# in case hift_model is a hifigan model
|
| 463 |
+
hift_state_dict = {k.replace('generator.', ''): v for k, v in
|
| 464 |
+
torch.load(hift_model, weights_only=True, map_location=self.device).items()}
|
| 465 |
+
self.hift.load_state_dict(hift_state_dict, strict=True)
|
| 466 |
+
self.hift.to(self.device).eval()
|
cosyvoice/dataset/dataset.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
| 2 |
+
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import random
|
| 17 |
+
import json
|
| 18 |
+
import math
|
| 19 |
+
from functools import partial
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.distributed as dist
|
| 23 |
+
from torch.utils.data import IterableDataset
|
| 24 |
+
from cosyvoice.utils.file_utils import read_lists, read_json_lists
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Processor(IterableDataset):
|
| 28 |
+
|
| 29 |
+
def __init__(self, source, f, *args, **kw):
|
| 30 |
+
assert callable(f)
|
| 31 |
+
self.source = source
|
| 32 |
+
self.f = f
|
| 33 |
+
self.args = args
|
| 34 |
+
self.kw = kw
|
| 35 |
+
|
| 36 |
+
def set_epoch(self, epoch):
|
| 37 |
+
self.source.set_epoch(epoch)
|
| 38 |
+
|
| 39 |
+
def __iter__(self):
|
| 40 |
+
""" Return an iterator over the source dataset processed by the
|
| 41 |
+
given processor.
|
| 42 |
+
"""
|
| 43 |
+
assert self.source is not None
|
| 44 |
+
assert callable(self.f)
|
| 45 |
+
return self.f(iter(self.source), *self.args, **self.kw)
|
| 46 |
+
|
| 47 |
+
def apply(self, f):
|
| 48 |
+
assert callable(f)
|
| 49 |
+
return Processor(self, f, *self.args, **self.kw)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class DistributedSampler:
|
| 53 |
+
|
| 54 |
+
def __init__(self, shuffle=True, partition=True):
|
| 55 |
+
self.epoch = -1
|
| 56 |
+
self.update()
|
| 57 |
+
self.shuffle = shuffle
|
| 58 |
+
self.partition = partition
|
| 59 |
+
|
| 60 |
+
def update(self):
|
| 61 |
+
assert dist.is_available()
|
| 62 |
+
if dist.is_initialized():
|
| 63 |
+
self.rank = dist.get_rank()
|
| 64 |
+
self.world_size = dist.get_world_size()
|
| 65 |
+
else:
|
| 66 |
+
self.rank = 0
|
| 67 |
+
self.world_size = 1
|
| 68 |
+
worker_info = torch.utils.data.get_worker_info()
|
| 69 |
+
if worker_info is None:
|
| 70 |
+
self.worker_id = 0
|
| 71 |
+
self.num_workers = 1
|
| 72 |
+
else:
|
| 73 |
+
self.worker_id = worker_info.id
|
| 74 |
+
self.num_workers = worker_info.num_workers
|
| 75 |
+
return dict(rank=self.rank,
|
| 76 |
+
world_size=self.world_size,
|
| 77 |
+
worker_id=self.worker_id,
|
| 78 |
+
num_workers=self.num_workers)
|
| 79 |
+
|
| 80 |
+
def set_epoch(self, epoch):
|
| 81 |
+
self.epoch = epoch
|
| 82 |
+
|
| 83 |
+
def sample(self, data):
|
| 84 |
+
""" Sample data according to rank/world_size/num_workers
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
data(List): input data list
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
List: data list after sample
|
| 91 |
+
"""
|
| 92 |
+
data = list(range(len(data)))
|
| 93 |
+
# force datalist even
|
| 94 |
+
if self.partition:
|
| 95 |
+
if self.shuffle:
|
| 96 |
+
random.Random(self.epoch).shuffle(data)
|
| 97 |
+
if len(data) < self.world_size:
|
| 98 |
+
data = data * math.ceil(self.world_size / len(data))
|
| 99 |
+
data = data[:self.world_size]
|
| 100 |
+
data = data[self.rank::self.world_size]
|
| 101 |
+
if len(data) < self.num_workers:
|
| 102 |
+
data = data * math.ceil(self.num_workers / len(data))
|
| 103 |
+
data = data[:self.num_workers]
|
| 104 |
+
data = data[self.worker_id::self.num_workers]
|
| 105 |
+
return data
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class DataList(IterableDataset):
|
| 109 |
+
|
| 110 |
+
def __init__(self, lists, shuffle=True, partition=True):
|
| 111 |
+
self.lists = lists
|
| 112 |
+
self.sampler = DistributedSampler(shuffle, partition)
|
| 113 |
+
|
| 114 |
+
def set_epoch(self, epoch):
|
| 115 |
+
self.sampler.set_epoch(epoch)
|
| 116 |
+
|
| 117 |
+
def __iter__(self):
|
| 118 |
+
sampler_info = self.sampler.update()
|
| 119 |
+
indexes = self.sampler.sample(self.lists)
|
| 120 |
+
for index in indexes:
|
| 121 |
+
data = dict(src=self.lists[index])
|
| 122 |
+
data.update(sampler_info)
|
| 123 |
+
yield data
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def Dataset(data_list_file,
|
| 127 |
+
data_pipeline,
|
| 128 |
+
mode='train',
|
| 129 |
+
gan=False,
|
| 130 |
+
shuffle=True,
|
| 131 |
+
partition=True,
|
| 132 |
+
tts_file='',
|
| 133 |
+
prompt_utt2data=''):
|
| 134 |
+
""" Construct dataset from arguments
|
| 135 |
+
|
| 136 |
+
We have two shuffle stage in the Dataset. The first is global
|
| 137 |
+
shuffle at shards tar/raw file level. The second is global shuffle
|
| 138 |
+
at training samples level.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
data_type(str): raw/shard
|
| 142 |
+
tokenizer (BaseTokenizer): tokenizer to tokenize
|
| 143 |
+
partition(bool): whether to do data partition in terms of rank
|
| 144 |
+
"""
|
| 145 |
+
assert mode in ['train', 'inference']
|
| 146 |
+
lists = read_lists(data_list_file)
|
| 147 |
+
if mode == 'inference':
|
| 148 |
+
with open(tts_file) as f:
|
| 149 |
+
tts_data = json.load(f)
|
| 150 |
+
utt2lists = read_json_lists(prompt_utt2data)
|
| 151 |
+
# filter unnecessary file in inference mode
|
| 152 |
+
lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
|
| 153 |
+
dataset = DataList(lists,
|
| 154 |
+
shuffle=shuffle,
|
| 155 |
+
partition=partition)
|
| 156 |
+
if mode == 'inference':
|
| 157 |
+
# map partial arg to parquet_opener func in inference mode
|
| 158 |
+
data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
|
| 159 |
+
if gan is True:
|
| 160 |
+
# map partial arg to padding func in gan mode
|
| 161 |
+
data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
|
| 162 |
+
for func in data_pipeline:
|
| 163 |
+
dataset = Processor(dataset, func, mode=mode)
|
| 164 |
+
return dataset
|
cosyvoice/hifigan/__pycache__/generator.cpython-310.pyc
ADDED
|
Binary file (11.8 kB). View file
|
|
|
cosyvoice/llm/llm.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import Dict, Optional, Callable, List, Generator
|
| 15 |
+
import torch
|
| 16 |
+
from torch import nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from transformers import Qwen2ForCausalLM
|
| 19 |
+
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
|
| 20 |
+
from cosyvoice.utils.common import IGNORE_ID
|
| 21 |
+
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
|
| 22 |
+
from cosyvoice.utils.common import th_accuracy
|
| 23 |
+
from cosyvoice.utils.file_utils import logging
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TransformerLM(torch.nn.Module):
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
text_encoder_input_size: int,
|
| 30 |
+
llm_input_size: int,
|
| 31 |
+
llm_output_size: int,
|
| 32 |
+
text_token_size: int,
|
| 33 |
+
speech_token_size: int,
|
| 34 |
+
text_encoder: torch.nn.Module,
|
| 35 |
+
llm: torch.nn.Module,
|
| 36 |
+
sampling: Callable,
|
| 37 |
+
length_normalized_loss: bool = True,
|
| 38 |
+
lsm_weight: float = 0.0,
|
| 39 |
+
spk_embed_dim: int = 192,
|
| 40 |
+
):
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.llm_input_size = llm_input_size
|
| 43 |
+
self.speech_token_size = speech_token_size
|
| 44 |
+
# 1. build text token inputs related modules
|
| 45 |
+
self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
|
| 46 |
+
self.text_encoder = text_encoder
|
| 47 |
+
self.text_encoder_affine_layer = nn.Linear(
|
| 48 |
+
self.text_encoder.output_size(),
|
| 49 |
+
llm_input_size
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
# 2. build speech token language model related modules
|
| 53 |
+
self.sos_eos = 0
|
| 54 |
+
self.task_id = 1
|
| 55 |
+
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
|
| 56 |
+
self.llm = llm
|
| 57 |
+
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
|
| 58 |
+
self.criterion_ce = LabelSmoothingLoss(
|
| 59 |
+
size=speech_token_size + 1,
|
| 60 |
+
padding_idx=IGNORE_ID,
|
| 61 |
+
smoothing=lsm_weight,
|
| 62 |
+
normalize_length=length_normalized_loss,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# 3. [Optional] build speech token related modules
|
| 66 |
+
self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
|
| 67 |
+
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
|
| 68 |
+
|
| 69 |
+
# 4. sampling method
|
| 70 |
+
self.sampling = sampling
|
| 71 |
+
|
| 72 |
+
def encode(
|
| 73 |
+
self,
|
| 74 |
+
text: torch.Tensor,
|
| 75 |
+
text_lengths: torch.Tensor,
|
| 76 |
+
):
|
| 77 |
+
encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
|
| 78 |
+
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
|
| 79 |
+
encoder_out = self.text_encoder_affine_layer(encoder_out)
|
| 80 |
+
return encoder_out, encoder_out_lens
|
| 81 |
+
|
| 82 |
+
def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
|
| 83 |
+
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
|
| 84 |
+
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
|
| 85 |
+
lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
|
| 86 |
+
for i in range(len(text_token))]
|
| 87 |
+
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
|
| 88 |
+
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
|
| 89 |
+
return lm_input, lm_input_len
|
| 90 |
+
|
| 91 |
+
def forward(
|
| 92 |
+
self,
|
| 93 |
+
batch: dict,
|
| 94 |
+
device: torch.device,
|
| 95 |
+
) -> Dict[str, Optional[torch.Tensor]]:
|
| 96 |
+
"""
|
| 97 |
+
Args:
|
| 98 |
+
text: (B, L, D)
|
| 99 |
+
text_lengths: (B,)
|
| 100 |
+
audio: (B, T, N) or (B, T)
|
| 101 |
+
audio_lengths: (B,)
|
| 102 |
+
"""
|
| 103 |
+
text_token = batch['text_token'].to(device)
|
| 104 |
+
text_token_len = batch['text_token_len'].to(device)
|
| 105 |
+
speech_token = batch['speech_token'].to(device)
|
| 106 |
+
speech_token_len = batch['speech_token_len'].to(device)
|
| 107 |
+
embedding = batch['embedding'].to(device)
|
| 108 |
+
|
| 109 |
+
# 1. prepare llm_target
|
| 110 |
+
lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
|
| 111 |
+
[self.speech_token_size]) for i in range(text_token.size(0))]
|
| 112 |
+
lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
|
| 113 |
+
|
| 114 |
+
# 1. encode text_token
|
| 115 |
+
text_token = self.text_embedding(text_token)
|
| 116 |
+
text_token, text_token_len = self.encode(text_token, text_token_len)
|
| 117 |
+
|
| 118 |
+
# 2. embedding projection
|
| 119 |
+
embedding = F.normalize(embedding, dim=1)
|
| 120 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
| 121 |
+
embedding = embedding.unsqueeze(1)
|
| 122 |
+
|
| 123 |
+
# 3. eos and task_id
|
| 124 |
+
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
| 125 |
+
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
| 126 |
+
|
| 127 |
+
# 4. encode speech_token
|
| 128 |
+
speech_token = self.speech_embedding(speech_token)
|
| 129 |
+
|
| 130 |
+
# 5. unpad and pad
|
| 131 |
+
lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
|
| 132 |
+
task_id_emb, speech_token, speech_token_len)
|
| 133 |
+
|
| 134 |
+
# 6. run lm forward
|
| 135 |
+
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
|
| 136 |
+
logits = self.llm_decoder(lm_output)
|
| 137 |
+
loss = self.criterion_ce(logits, lm_target)
|
| 138 |
+
acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
|
| 139 |
+
return {'loss': loss, 'acc': acc}
|
| 140 |
+
|
| 141 |
+
def sampling_ids(
|
| 142 |
+
self,
|
| 143 |
+
weighted_scores: torch.Tensor,
|
| 144 |
+
decoded_tokens: List,
|
| 145 |
+
sampling: int,
|
| 146 |
+
ignore_eos: bool = True,
|
| 147 |
+
):
|
| 148 |
+
num_trials, max_trials = 0, 100
|
| 149 |
+
while True:
|
| 150 |
+
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
|
| 151 |
+
if (not ignore_eos) or (self.speech_token_size not in top_ids):
|
| 152 |
+
break
|
| 153 |
+
num_trials += 1
|
| 154 |
+
if num_trials > max_trials:
|
| 155 |
+
raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
|
| 156 |
+
return top_ids
|
| 157 |
+
|
| 158 |
+
@torch.inference_mode()
|
| 159 |
+
def inference(
|
| 160 |
+
self,
|
| 161 |
+
text: torch.Tensor,
|
| 162 |
+
text_len: torch.Tensor,
|
| 163 |
+
prompt_text: torch.Tensor,
|
| 164 |
+
prompt_text_len: torch.Tensor,
|
| 165 |
+
prompt_speech_token: torch.Tensor,
|
| 166 |
+
prompt_speech_token_len: torch.Tensor,
|
| 167 |
+
embedding: torch.Tensor,
|
| 168 |
+
sampling: int = 25,
|
| 169 |
+
max_token_text_ratio: float = 20,
|
| 170 |
+
min_token_text_ratio: float = 2,
|
| 171 |
+
) -> Generator[torch.Tensor, None, None]:
|
| 172 |
+
if self.fp16 is True:
|
| 173 |
+
embedding = embedding.half()
|
| 174 |
+
|
| 175 |
+
device = text.device
|
| 176 |
+
text = torch.concat([prompt_text, text], dim=1)
|
| 177 |
+
text_len += prompt_text_len
|
| 178 |
+
text = self.text_embedding(text)
|
| 179 |
+
|
| 180 |
+
# 1. encode text
|
| 181 |
+
text, text_len = self.encode(text, text_len)
|
| 182 |
+
|
| 183 |
+
# 2. encode embedding
|
| 184 |
+
if embedding.shape[0] != 0:
|
| 185 |
+
embedding = F.normalize(embedding, dim=1)
|
| 186 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
| 187 |
+
embedding = embedding.unsqueeze(dim=1)
|
| 188 |
+
else:
|
| 189 |
+
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
|
| 190 |
+
|
| 191 |
+
# 3. concat llm_input
|
| 192 |
+
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
| 193 |
+
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
| 194 |
+
if prompt_speech_token_len != 0:
|
| 195 |
+
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
| 196 |
+
else:
|
| 197 |
+
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
| 198 |
+
lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
| 199 |
+
|
| 200 |
+
# 4. cal min/max_length
|
| 201 |
+
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
| 202 |
+
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
|
| 203 |
+
|
| 204 |
+
# 5. step by step decode
|
| 205 |
+
out_tokens = []
|
| 206 |
+
offset = 0
|
| 207 |
+
att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
|
| 208 |
+
for i in range(max_len):
|
| 209 |
+
y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
|
| 210 |
+
att_cache=att_cache, cnn_cache=cnn_cache,
|
| 211 |
+
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
|
| 212 |
+
device=lm_input.device)).to(torch.bool))
|
| 213 |
+
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
| 214 |
+
# force continue decode first token
|
| 215 |
+
if i == 0:
|
| 216 |
+
logp[:, self.speech_token_size] = -float('inf')
|
| 217 |
+
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
|
| 218 |
+
if top_ids == self.speech_token_size:
|
| 219 |
+
break
|
| 220 |
+
# in stream mode, yield token one by one
|
| 221 |
+
yield top_ids
|
| 222 |
+
out_tokens.append(top_ids)
|
| 223 |
+
offset += lm_input.size(1)
|
| 224 |
+
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class Qwen2Encoder(torch.nn.Module):
|
| 228 |
+
def __init__(self, pretrain_path):
|
| 229 |
+
super().__init__()
|
| 230 |
+
self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
|
| 231 |
+
|
| 232 |
+
def forward_one_step(self, xs, masks, cache=None):
|
| 233 |
+
input_masks = masks[:, -1, :]
|
| 234 |
+
outs = self.model(
|
| 235 |
+
inputs_embeds=xs,
|
| 236 |
+
attention_mask=input_masks,
|
| 237 |
+
output_hidden_states=True,
|
| 238 |
+
return_dict=True,
|
| 239 |
+
use_cache=True,
|
| 240 |
+
past_key_values=cache,
|
| 241 |
+
)
|
| 242 |
+
xs = outs.hidden_states[-1]
|
| 243 |
+
new_cache = outs.past_key_values
|
| 244 |
+
return xs, new_cache
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class Qwen2LM(TransformerLM):
|
| 248 |
+
def __init__(
|
| 249 |
+
self,
|
| 250 |
+
llm_input_size: int,
|
| 251 |
+
llm_output_size: int,
|
| 252 |
+
speech_token_size: int,
|
| 253 |
+
llm: torch.nn.Module,
|
| 254 |
+
sampling: Callable,
|
| 255 |
+
length_normalized_loss: bool = True,
|
| 256 |
+
lsm_weight: float = 0.0,
|
| 257 |
+
mix_ratio: List[int] = [5, 15],
|
| 258 |
+
):
|
| 259 |
+
torch.nn.Module.__init__(self)
|
| 260 |
+
self.llm_input_size = llm_input_size
|
| 261 |
+
self.llm_output_size = llm_output_size
|
| 262 |
+
self.speech_token_size = speech_token_size
|
| 263 |
+
|
| 264 |
+
# 2. build speech token language model related modules
|
| 265 |
+
self.sos_eos = 0
|
| 266 |
+
self.task_id = 1
|
| 267 |
+
self.fill_token = 2
|
| 268 |
+
|
| 269 |
+
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
|
| 270 |
+
self.llm = llm
|
| 271 |
+
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
|
| 272 |
+
self.criterion_ce = LabelSmoothingLoss(
|
| 273 |
+
size=speech_token_size + 3,
|
| 274 |
+
padding_idx=IGNORE_ID,
|
| 275 |
+
smoothing=lsm_weight,
|
| 276 |
+
normalize_length=length_normalized_loss,
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
# 3. [Optional] build speech token related modules
|
| 280 |
+
self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
|
| 281 |
+
|
| 282 |
+
# 4. sampling method
|
| 283 |
+
self.sampling = sampling
|
| 284 |
+
self.mix_ratio = mix_ratio
|
| 285 |
+
|
| 286 |
+
@torch.inference_mode()
|
| 287 |
+
def inference(
|
| 288 |
+
self,
|
| 289 |
+
text: torch.Tensor,
|
| 290 |
+
text_len: torch.Tensor,
|
| 291 |
+
prompt_text: torch.Tensor,
|
| 292 |
+
prompt_text_len: torch.Tensor,
|
| 293 |
+
prompt_speech_token: torch.Tensor,
|
| 294 |
+
prompt_speech_token_len: torch.Tensor,
|
| 295 |
+
embedding: torch.Tensor,
|
| 296 |
+
sampling: int = 25,
|
| 297 |
+
max_token_text_ratio: float = 20,
|
| 298 |
+
min_token_text_ratio: float = 2,
|
| 299 |
+
) -> Generator[torch.Tensor, None, None]:
|
| 300 |
+
device = text.device
|
| 301 |
+
text = torch.concat([prompt_text, text], dim=1)
|
| 302 |
+
text_len += prompt_text_len
|
| 303 |
+
text = self.llm.model.model.embed_tokens(text)
|
| 304 |
+
|
| 305 |
+
# 3. concat llm_input
|
| 306 |
+
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
| 307 |
+
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
| 308 |
+
if prompt_speech_token_len != 0:
|
| 309 |
+
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
| 310 |
+
else:
|
| 311 |
+
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
| 312 |
+
lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
| 313 |
+
|
| 314 |
+
# 4. cal min/max_length
|
| 315 |
+
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
| 316 |
+
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
|
| 317 |
+
|
| 318 |
+
# 5. step by step decode
|
| 319 |
+
out_tokens = []
|
| 320 |
+
cache = None
|
| 321 |
+
for i in range(max_len):
|
| 322 |
+
y_pred, cache = self.llm.forward_one_step(lm_input,
|
| 323 |
+
masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
|
| 324 |
+
cache=cache)
|
| 325 |
+
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
| 326 |
+
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
|
| 327 |
+
if top_ids == self.speech_token_size:
|
| 328 |
+
break
|
| 329 |
+
if top_ids > self.speech_token_size:
|
| 330 |
+
continue
|
| 331 |
+
# in stream mode, yield token one by one
|
| 332 |
+
yield top_ids
|
| 333 |
+
out_tokens.append(top_ids)
|
| 334 |
+
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
| 335 |
+
|
| 336 |
+
@torch.inference_mode()
|
| 337 |
+
def inference_bistream(
|
| 338 |
+
self,
|
| 339 |
+
text: Generator,
|
| 340 |
+
prompt_text: torch.Tensor,
|
| 341 |
+
prompt_text_len: torch.Tensor,
|
| 342 |
+
prompt_speech_token: torch.Tensor,
|
| 343 |
+
prompt_speech_token_len: torch.Tensor,
|
| 344 |
+
embedding: torch.Tensor,
|
| 345 |
+
sampling: int = 25,
|
| 346 |
+
max_token_text_ratio: float = 20,
|
| 347 |
+
min_token_text_ratio: float = 2,
|
| 348 |
+
) -> Generator[torch.Tensor, None, None]:
|
| 349 |
+
|
| 350 |
+
device = prompt_text.device
|
| 351 |
+
# 1. prepare input
|
| 352 |
+
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
| 353 |
+
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
| 354 |
+
if prompt_speech_token_len != 0:
|
| 355 |
+
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
| 356 |
+
else:
|
| 357 |
+
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
|
| 358 |
+
lm_input = torch.concat([sos_eos_emb], dim=1)
|
| 359 |
+
|
| 360 |
+
# 2. iterate text
|
| 361 |
+
out_tokens = []
|
| 362 |
+
cache = None
|
| 363 |
+
# NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
|
| 364 |
+
text_cache = self.llm.model.model.embed_tokens(prompt_text)
|
| 365 |
+
next_fill_index = -1
|
| 366 |
+
for this_text in text:
|
| 367 |
+
text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
|
| 368 |
+
# prompt_speech_token_emb not empty, try append to lm_input
|
| 369 |
+
while prompt_speech_token_emb.size(1) != 0:
|
| 370 |
+
if text_cache.size(1) >= self.mix_ratio[0]:
|
| 371 |
+
lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
|
| 372 |
+
logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
|
| 373 |
+
lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
|
| 374 |
+
text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
|
| 375 |
+
else:
|
| 376 |
+
logging.info('not enough text token to decode, wait for more')
|
| 377 |
+
break
|
| 378 |
+
# no prompt_speech_token_emb remain, can decode some speech token
|
| 379 |
+
if prompt_speech_token_emb.size(1) == 0:
|
| 380 |
+
if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
|
| 381 |
+
logging.info('get fill token, need to append more text token')
|
| 382 |
+
if text_cache.size(1) >= self.mix_ratio[0]:
|
| 383 |
+
lm_input_text = text_cache[:, :self.mix_ratio[0]]
|
| 384 |
+
logging.info('append {} text token'.format(lm_input_text.size(1)))
|
| 385 |
+
if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
|
| 386 |
+
lm_input = lm_input_text
|
| 387 |
+
else:
|
| 388 |
+
lm_input = torch.concat([lm_input, lm_input_text], dim=1)
|
| 389 |
+
text_cache = text_cache[:, self.mix_ratio[0]:]
|
| 390 |
+
else:
|
| 391 |
+
logging.info('not enough text token to decode, wait for more')
|
| 392 |
+
continue
|
| 393 |
+
while True:
|
| 394 |
+
seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
|
| 395 |
+
y_pred, cache = self.llm.forward_one_step(lm_input,
|
| 396 |
+
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
|
| 397 |
+
cache=cache)
|
| 398 |
+
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
| 399 |
+
if next_fill_index != -1 and len(out_tokens) == next_fill_index:
|
| 400 |
+
top_ids = self.speech_token_size + 2
|
| 401 |
+
next_fill_index += (self.mix_ratio[1] + 1)
|
| 402 |
+
else:
|
| 403 |
+
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
|
| 404 |
+
if top_ids == self.speech_token_size + 2:
|
| 405 |
+
next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
|
| 406 |
+
logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
|
| 407 |
+
out_tokens.append(top_ids)
|
| 408 |
+
if top_ids >= self.speech_token_size:
|
| 409 |
+
if top_ids == self.speech_token_size + 2:
|
| 410 |
+
break
|
| 411 |
+
else:
|
| 412 |
+
raise ValueError('should not get token {}'.format(top_ids))
|
| 413 |
+
yield top_ids
|
| 414 |
+
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
| 415 |
+
|
| 416 |
+
# 3. final decode
|
| 417 |
+
lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
|
| 418 |
+
logging.info('no more text token, decode until met eos')
|
| 419 |
+
while True:
|
| 420 |
+
seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
|
| 421 |
+
y_pred, cache = self.llm.forward_one_step(lm_input,
|
| 422 |
+
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
|
| 423 |
+
cache=cache)
|
| 424 |
+
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
| 425 |
+
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
|
| 426 |
+
out_tokens.append(top_ids)
|
| 427 |
+
if top_ids >= self.speech_token_size:
|
| 428 |
+
if top_ids == self.speech_token_size:
|
| 429 |
+
break
|
| 430 |
+
else:
|
| 431 |
+
raise ValueError('should not get token {}'.format(top_ids))
|
| 432 |
+
# in stream mode, yield token one by one
|
| 433 |
+
yield top_ids
|
| 434 |
+
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
cosyvoice/llm/llm_vllm.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import time
|
| 15 |
+
import queue
|
| 16 |
+
import asyncio
|
| 17 |
+
import threading
|
| 18 |
+
from typing import List, Generator, AsyncGenerator
|
| 19 |
+
import torch
|
| 20 |
+
from cosyvoice.utils.file_utils import logging
|
| 21 |
+
from cosyvoice.llm.llm import Qwen2LM
|
| 22 |
+
|
| 23 |
+
# 启用vllm V1版本
|
| 24 |
+
import os
|
| 25 |
+
os.environ["VLLM_USE_V1"] = '1'
|
| 26 |
+
from vllm import ModelRegistry
|
| 27 |
+
from vllm import LLMEngine, AsyncLLMEngine, CompletionOutput
|
| 28 |
+
from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs
|
| 29 |
+
from vllm.sampling_params import SamplingParams
|
| 30 |
+
|
| 31 |
+
from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM
|
| 32 |
+
ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM)
|
| 33 |
+
|
| 34 |
+
# EngineArgs
|
| 35 |
+
ENGINE_ARGS = {
|
| 36 |
+
"block_size": 16,
|
| 37 |
+
"swap_space": 0,
|
| 38 |
+
# "enforce_eager": True,
|
| 39 |
+
"gpu_memory_utilization": 0.4,
|
| 40 |
+
"max_num_batched_tokens": 1024,
|
| 41 |
+
"max_model_len": 1024,
|
| 42 |
+
"max_num_seqs": 256,
|
| 43 |
+
"disable_log_requests": True,
|
| 44 |
+
"disable_log_stats": True,
|
| 45 |
+
"dtype": "float16"
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
from vllm.sampling_params import RequestOutputKind
|
| 49 |
+
# SamplingParams
|
| 50 |
+
SAMPLING_PARAMS = {
|
| 51 |
+
"temperature": 1, # 不能低于0.8, 否则会生成非常多的空音频,或者无法正常生成语音Token
|
| 52 |
+
"top_p": 1, # 不能低于0.8, 否则会生成非常多的空音频,或者无法正常生成语音Token
|
| 53 |
+
"top_k": 25,
|
| 54 |
+
# "min_tokens": 80, # 不支持设置最小的tokens数量设置,开启后vllm直接崩溃,无法启动
|
| 55 |
+
# "presence_penalty": 1.0, # 不支持设置
|
| 56 |
+
# "frequency_penalty": 0.0, # 不支持设置
|
| 57 |
+
"max_tokens": 1024,
|
| 58 |
+
"detokenize": False, # 目前 vllm 0.7.3 v1版本中设置无效,待后续版本更新后减少计算
|
| 59 |
+
"ignore_eos": False,
|
| 60 |
+
"output_kind": RequestOutputKind.DELTA # 设置为DELTA,如调整该参数,请同时调整llm_inference的处理代码
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
def tensor_to_list(tensor: torch.tensor):
|
| 64 |
+
return tensor.view(-1).cpu().numpy().tolist()
|
| 65 |
+
|
| 66 |
+
class VllmQwen2LM(Qwen2LM):
|
| 67 |
+
def __init__(
|
| 68 |
+
self,
|
| 69 |
+
model_dir,
|
| 70 |
+
mix_ratio: List[int] = [5, 15],
|
| 71 |
+
):
|
| 72 |
+
self.fp16 = False
|
| 73 |
+
self.half = lambda: None
|
| 74 |
+
self.mix_ratio = mix_ratio
|
| 75 |
+
# ---------------------------------------------
|
| 76 |
+
# vllm engine 的参数配置
|
| 77 |
+
engine_args = AsyncEngineArgs(
|
| 78 |
+
model=model_dir,
|
| 79 |
+
**ENGINE_ARGS,
|
| 80 |
+
)
|
| 81 |
+
self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
|
| 82 |
+
|
| 83 |
+
self.speech_token_size = 6564 # 6561 + 3
|
| 84 |
+
self.llm_token_size = 151936 # llm vocab_size
|
| 85 |
+
self.sos_eos_token_id = self.speech_token_size + self.llm_token_size + 1
|
| 86 |
+
self.task_token_id = self.sos_eos_token_id + 1
|
| 87 |
+
self.zero_token_id = self.task_token_id + 1
|
| 88 |
+
|
| 89 |
+
# vllm 的推理任务需要在一个固定的事件循环中,因此启动一个后台线程运行转用于推理任务
|
| 90 |
+
self.loop = asyncio.new_event_loop()
|
| 91 |
+
self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True)
|
| 92 |
+
self.loop_thread.start()
|
| 93 |
+
|
| 94 |
+
def _run_event_loop(self):
|
| 95 |
+
asyncio.set_event_loop(self.loop)
|
| 96 |
+
self.loop.run_forever()
|
| 97 |
+
|
| 98 |
+
async def async_llm_inference(self, out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens):
|
| 99 |
+
sampling_params = SamplingParams(**SAMPLING_PARAMS)
|
| 100 |
+
sampling_params.stop_token_ids = stop_token_ids or [6561]
|
| 101 |
+
if max_tokens:
|
| 102 |
+
sampling_params.max_tokens = max_tokens
|
| 103 |
+
async for output in self.llm_engine.generate(
|
| 104 |
+
{
|
| 105 |
+
"prompt_token_ids": prompt_token_ids,
|
| 106 |
+
},
|
| 107 |
+
sampling_params=sampling_params,
|
| 108 |
+
request_id=request_id or f"{time.time()}",
|
| 109 |
+
):
|
| 110 |
+
out_queue.put((output.outputs[0], output.finished))
|
| 111 |
+
|
| 112 |
+
def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None):
|
| 113 |
+
out_queue = queue.Queue()
|
| 114 |
+
asyncio.run_coroutine_threadsafe(
|
| 115 |
+
self.async_llm_inference(out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens), self.loop
|
| 116 |
+
)
|
| 117 |
+
# 接收 out_queue 返回的结果
|
| 118 |
+
finished = False
|
| 119 |
+
while not finished:
|
| 120 |
+
(output, finished) = out_queue.get_nowait() if not out_queue.empty() else out_queue.get()
|
| 121 |
+
yield output
|
| 122 |
+
|
| 123 |
+
def inference(
|
| 124 |
+
self,
|
| 125 |
+
text: torch.Tensor,
|
| 126 |
+
text_len: torch.Tensor,
|
| 127 |
+
prompt_text: torch.Tensor,
|
| 128 |
+
prompt_text_len: torch.Tensor,
|
| 129 |
+
prompt_speech_token: torch.Tensor,
|
| 130 |
+
prompt_speech_token_len: torch.Tensor,
|
| 131 |
+
embedding: torch.Tensor,
|
| 132 |
+
sampling: int = 25,
|
| 133 |
+
max_token_text_ratio: float = 20,
|
| 134 |
+
min_token_text_ratio: float = 2,
|
| 135 |
+
) -> Generator[torch.Tensor|int, None, None]:
|
| 136 |
+
prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
|
| 137 |
+
prompt_speech_token = tensor_to_list(prompt_speech_token)
|
| 138 |
+
|
| 139 |
+
text = tensor_to_list(text + torch.tensor(6564))
|
| 140 |
+
prompt_token_ids = [self.sos_eos_token_id] + prompt_text + text + \
|
| 141 |
+
[self.task_token_id] + prompt_speech_token
|
| 142 |
+
max_tokens = len(text) * 20
|
| 143 |
+
for output in self.llm_inference(
|
| 144 |
+
prompt_token_ids,
|
| 145 |
+
stop_token_ids=[6561],
|
| 146 |
+
max_tokens=max_tokens,
|
| 147 |
+
):
|
| 148 |
+
if output.token_ids[-1] == 6561:
|
| 149 |
+
need_add_tokens = output.token_ids[:-1]
|
| 150 |
+
else:
|
| 151 |
+
need_add_tokens = output.token_ids
|
| 152 |
+
for token in need_add_tokens:
|
| 153 |
+
yield token
|
| 154 |
+
|
| 155 |
+
def inference_bistream(
|
| 156 |
+
self,
|
| 157 |
+
text: Generator,
|
| 158 |
+
prompt_text: torch.Tensor,
|
| 159 |
+
prompt_text_len: torch.Tensor,
|
| 160 |
+
prompt_speech_token: torch.Tensor,
|
| 161 |
+
prompt_speech_token_len: torch.Tensor,
|
| 162 |
+
embedding: torch.Tensor,
|
| 163 |
+
sampling: int = 25,
|
| 164 |
+
max_token_text_ratio: float = 20,
|
| 165 |
+
min_token_text_ratio: float = 2,
|
| 166 |
+
) -> Generator[torch.Tensor, None, None]:
|
| 167 |
+
prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
|
| 168 |
+
prompt_speech_token = tensor_to_list(prompt_speech_token)
|
| 169 |
+
|
| 170 |
+
last_tokens = []
|
| 171 |
+
prompt_token_ids = [self.sos_eos_token_id]
|
| 172 |
+
text_tokens_cache = prompt_text
|
| 173 |
+
for this_text in text:
|
| 174 |
+
this_text = tensor_to_list(this_text + torch.tensor(6564))
|
| 175 |
+
# text need tokens
|
| 176 |
+
assert isinstance(this_text, list), "text need token ids List[int]."
|
| 177 |
+
text_tokens_cache += this_text
|
| 178 |
+
while len(prompt_speech_token) != 0:
|
| 179 |
+
if len(text_tokens_cache) >= self.mix_ratio[0]:
|
| 180 |
+
text_input_token = text_tokens_cache[:self.mix_ratio[0]]
|
| 181 |
+
speech_input_token = prompt_speech_token[:self.mix_ratio[1]]
|
| 182 |
+
prompt_token_ids += text_input_token + speech_input_token
|
| 183 |
+
# reset the last cache
|
| 184 |
+
text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
|
| 185 |
+
prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:]
|
| 186 |
+
else:
|
| 187 |
+
break
|
| 188 |
+
if len(prompt_speech_token) == 0:
|
| 189 |
+
if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1:
|
| 190 |
+
if len(text_tokens_cache) >= self.mix_ratio[0]:
|
| 191 |
+
text_tokens_temp = text_tokens_cache[:self.mix_ratio[0]]
|
| 192 |
+
prompt_token_ids += text_tokens_temp
|
| 193 |
+
text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
|
| 194 |
+
else:
|
| 195 |
+
continue
|
| 196 |
+
for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6563]):
|
| 197 |
+
last_tokens = output.token_ids
|
| 198 |
+
if last_tokens[-1] == 6563:
|
| 199 |
+
need_add_tokens = last_tokens[:-1]
|
| 200 |
+
else:
|
| 201 |
+
need_add_tokens = last_tokens
|
| 202 |
+
for token in need_add_tokens:
|
| 203 |
+
yield token
|
| 204 |
+
prompt_token_ids.extend(need_add_tokens)
|
| 205 |
+
prompt_token_ids += text_tokens_cache + [self.task_token_id]
|
| 206 |
+
for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6561]):
|
| 207 |
+
if output.token_ids[-1] == 6561:
|
| 208 |
+
need_add_tokens = output.token_ids[:-1]
|
| 209 |
+
else:
|
| 210 |
+
need_add_tokens = output.token_ids
|
| 211 |
+
for token in need_add_tokens:
|
| 212 |
+
yield token
|
cosyvoice/llm/vllm_use_cosyvoice2_model.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
|
| 5 |
+
# Copyright 2024 The Qwen team.
|
| 6 |
+
# Copyright 2023 The vLLM team.
|
| 7 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 10 |
+
# and OPT implementations in this library. It has been modified from its
|
| 11 |
+
# original forms to accommodate minor architectural differences compared
|
| 12 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 13 |
+
#
|
| 14 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 15 |
+
# you may not use this file except in compliance with the License.
|
| 16 |
+
# You may obtain a copy of the License at
|
| 17 |
+
#
|
| 18 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 19 |
+
#
|
| 20 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 21 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 22 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 23 |
+
# See the License for the specific language governing permissions and
|
| 24 |
+
# limitations under the License.
|
| 25 |
+
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
| 26 |
+
from typing import Iterable, List, Optional, Set, Tuple, Union, Iterator, overload, TypedDict, Mapping, Any
|
| 27 |
+
from typing_extensions import TypeVar
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
from torch import nn
|
| 31 |
+
|
| 32 |
+
from vllm.attention import AttentionMetadata
|
| 33 |
+
from vllm.config import VllmConfig
|
| 34 |
+
from vllm.logger import init_logger
|
| 35 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 36 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 37 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
| 38 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 39 |
+
from vllm.sequence import IntermediateTensors
|
| 40 |
+
|
| 41 |
+
from vllm.model_executor.models.interfaces import T
|
| 42 |
+
from vllm.model_executor.models.qwen2 import Qwen2Model
|
| 43 |
+
|
| 44 |
+
from vllm.model_executor.models.utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
|
| 45 |
+
|
| 46 |
+
logger = init_logger(__name__)
|
| 47 |
+
|
| 48 |
+
IGNORE_ID = -1
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class CosyVoice2Model(nn.Module):
|
| 52 |
+
|
| 53 |
+
packed_modules_mapping = {
|
| 54 |
+
"qkv_proj": [
|
| 55 |
+
"q_proj",
|
| 56 |
+
"k_proj",
|
| 57 |
+
"v_proj",
|
| 58 |
+
],
|
| 59 |
+
"gate_up_proj": [
|
| 60 |
+
"gate_proj",
|
| 61 |
+
"up_proj",
|
| 62 |
+
],
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 66 |
+
super().__init__()
|
| 67 |
+
config = vllm_config.model_config.hf_config
|
| 68 |
+
quant_config = vllm_config.quant_config
|
| 69 |
+
lora_config = vllm_config.lora_config
|
| 70 |
+
|
| 71 |
+
self.config = config
|
| 72 |
+
self.lora_config = lora_config
|
| 73 |
+
self.quant_config = quant_config
|
| 74 |
+
|
| 75 |
+
self.llm_input_size = 896
|
| 76 |
+
self.llm_output_size = 896
|
| 77 |
+
|
| 78 |
+
self.speech_token_size = 6561+3
|
| 79 |
+
self.llm_token_size = config.vocab_size
|
| 80 |
+
|
| 81 |
+
# 2. build speech token language model related modules
|
| 82 |
+
self.sos_eos = 0
|
| 83 |
+
self.task_id = 1
|
| 84 |
+
self.fill_token = 2
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
self.allow_patterns_overrides = ["llm.*"]
|
| 88 |
+
self.llm_embedding = torch.nn.Embedding(2, self.llm_input_size)
|
| 89 |
+
self.model = Qwen2Model(vllm_config=vllm_config,
|
| 90 |
+
prefix=maybe_prefix(prefix, "model"))
|
| 91 |
+
|
| 92 |
+
# self.llm_decoder = nn.Linear(self.llm_output_size, self.speech_token_size)
|
| 93 |
+
self.llm_decoder = ParallelLMHead(self.speech_token_size,
|
| 94 |
+
self.llm_output_size,
|
| 95 |
+
bias=True,
|
| 96 |
+
quant_config=quant_config,
|
| 97 |
+
prefix=maybe_prefix(
|
| 98 |
+
prefix, "llm_decoder"))
|
| 99 |
+
self.logits_processor = LogitsProcessor(self.speech_token_size)
|
| 100 |
+
|
| 101 |
+
# length_normalized_loss: bool = True,
|
| 102 |
+
# lsm_weight: float = 0.0,
|
| 103 |
+
# self.criterion_ce = LabelSmoothingLoss(
|
| 104 |
+
# size=self.speech_token_size,
|
| 105 |
+
# padding_idx=IGNORE_ID,
|
| 106 |
+
# smoothing=lsm_weight,
|
| 107 |
+
# normalize_length=length_normalized_loss,
|
| 108 |
+
# )
|
| 109 |
+
|
| 110 |
+
# 3. [Optional] build speech token related modules
|
| 111 |
+
self.speech_embedding = torch.nn.Embedding(self.speech_token_size, self.llm_input_size)
|
| 112 |
+
|
| 113 |
+
# 4. sampling method
|
| 114 |
+
## use vllm sampling method
|
| 115 |
+
self.sampler = get_sampler()
|
| 116 |
+
self.make_empty_intermediate_tensors = (
|
| 117 |
+
self.model.make_empty_intermediate_tensors)
|
| 118 |
+
|
| 119 |
+
self.mix_ratio: List[int] = [5, 15]
|
| 120 |
+
|
| 121 |
+
# 定义特殊token常量
|
| 122 |
+
self.llm_token_id_delta = torch.tensor(self.speech_token_size, dtype=torch.int32)
|
| 123 |
+
self.sos_eos_token_id = torch.tensor((self.llm_token_id_delta + self.llm_token_size + 1), dtype=torch.int32) # 163840 + 6564 = 170404
|
| 124 |
+
self.task_token_id = self.sos_eos_token_id + torch.tensor(1, dtype=torch.int32) # 170405
|
| 125 |
+
self.zero_token_id = self.task_token_id + torch.tensor(1, dtype=torch.int32)
|
| 126 |
+
|
| 127 |
+
self.zero_embed_buffer = torch.zeros(
|
| 128 |
+
(vllm_config.scheduler_config.max_num_seqs, self.llm_input_size),
|
| 129 |
+
dtype=self.llm_embedding.weight.dtype,
|
| 130 |
+
device=self.llm_embedding.weight.device
|
| 131 |
+
)
|
| 132 |
+
self.inputs_embed_buffer = torch.zeros(
|
| 133 |
+
(vllm_config.scheduler_config.max_num_batched_tokens, self.llm_input_size),
|
| 134 |
+
dtype=self.llm_embedding.weight.dtype,
|
| 135 |
+
device=self.llm_embedding.weight.device,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
def get_sos_eos_emb(self):
|
| 139 |
+
return self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
| 140 |
+
|
| 141 |
+
def get_task_id_emb(self):
|
| 142 |
+
return self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
| 143 |
+
|
| 144 |
+
def get_input_embeddings(
|
| 145 |
+
self,
|
| 146 |
+
input_ids: torch.Tensor,
|
| 147 |
+
multimodal_embeddings: Optional[T] = None,
|
| 148 |
+
attn_metadata: Optional["AttentionMetadata"] = None,
|
| 149 |
+
) -> torch.Tensor:
|
| 150 |
+
"""
|
| 151 |
+
Returns the input embeddings merged from the text embeddings from
|
| 152 |
+
input_ids and the multimodal embeddings generated from multimodal
|
| 153 |
+
kwargs.
|
| 154 |
+
"""
|
| 155 |
+
# 创建掩码,标记哪些 token_id 属于音频 Token
|
| 156 |
+
mask = input_ids < self.speech_token_size
|
| 157 |
+
|
| 158 |
+
# 获取 input_ids 的原始形状
|
| 159 |
+
input_shape = input_ids.shape
|
| 160 |
+
# 展平 input_ids 和掩码以便统一处理
|
| 161 |
+
flat_input_ids = input_ids.view(-1)
|
| 162 |
+
flat_mask = mask.view(-1)
|
| 163 |
+
|
| 164 |
+
inputs_embeds = self.inputs_embed_buffer[:flat_input_ids.shape[0]]
|
| 165 |
+
inputs_embeds.zero_()
|
| 166 |
+
|
| 167 |
+
# Process speech tokens
|
| 168 |
+
if flat_mask.any():
|
| 169 |
+
speech_token_ids = flat_input_ids[flat_mask]
|
| 170 |
+
inputs_embeds[flat_mask] = self.speech_embedding(speech_token_ids)
|
| 171 |
+
|
| 172 |
+
# 处理大于 delta 的 token_id
|
| 173 |
+
if (~flat_mask).any():
|
| 174 |
+
llm_token_ids = flat_input_ids[~flat_mask]
|
| 175 |
+
llm_embeds = torch.zeros_like(inputs_embeds[~flat_mask])
|
| 176 |
+
|
| 177 |
+
sos_eos_mask = llm_token_ids == self.sos_eos_token_id
|
| 178 |
+
task_mask = llm_token_ids == self.task_token_id
|
| 179 |
+
zero_mask = llm_token_ids == self.zero_token_id
|
| 180 |
+
normal_mask = ~(sos_eos_mask | task_mask | zero_mask)
|
| 181 |
+
|
| 182 |
+
# 分层处理逻辑
|
| 183 |
+
# 第一优先级:SOS/EOS标记
|
| 184 |
+
if sos_eos_mask.any():
|
| 185 |
+
llm_embeds[sos_eos_mask] = self.llm_embedding.weight[self.sos_eos].unsqueeze(0)
|
| 186 |
+
|
| 187 |
+
# 第二优先级:任务标记
|
| 188 |
+
if task_mask.any():
|
| 189 |
+
llm_embeds[task_mask] = self.llm_embedding.weight[self.task_id].unsqueeze(0)
|
| 190 |
+
|
| 191 |
+
# 第二优先级:空音频标记
|
| 192 |
+
if zero_mask.any():
|
| 193 |
+
llm_embeds[zero_mask] = self.zero_embed_buffer[:len(llm_embeds[zero_mask])]
|
| 194 |
+
|
| 195 |
+
# 常规LLM token
|
| 196 |
+
if normal_mask.any():
|
| 197 |
+
original_ids = llm_token_ids[normal_mask] - self.llm_token_id_delta
|
| 198 |
+
# print('original_ids: ',original_ids)
|
| 199 |
+
llm_embeds[normal_mask] = self.model.get_input_embeddings(original_ids)
|
| 200 |
+
|
| 201 |
+
inputs_embeds[~flat_mask] = llm_embeds
|
| 202 |
+
|
| 203 |
+
inputs_embeds = inputs_embeds.view(*input_shape, self.llm_input_size)
|
| 204 |
+
|
| 205 |
+
# 合并多模态嵌入(如果有)
|
| 206 |
+
if multimodal_embeddings is not None:
|
| 207 |
+
inputs_embeds = merge_multimodal_embeddings(
|
| 208 |
+
input_ids, inputs_embeds, multimodal_embeddings,
|
| 209 |
+
self.config.audio_token_index
|
| 210 |
+
)
|
| 211 |
+
return inputs_embeds
|
| 212 |
+
|
| 213 |
+
def forward(
|
| 214 |
+
self,
|
| 215 |
+
input_ids: torch.Tensor,
|
| 216 |
+
positions: torch.Tensor,
|
| 217 |
+
kv_caches: List[torch.Tensor],
|
| 218 |
+
attn_metadata: AttentionMetadata,
|
| 219 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 220 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 221 |
+
) -> Union[torch.Tensor, IntermediateTensors]:
|
| 222 |
+
if inputs_embeds is None:
|
| 223 |
+
inputs_embeds = self.get_input_embeddings(
|
| 224 |
+
input_ids,
|
| 225 |
+
attn_metadata=attn_metadata,
|
| 226 |
+
)
|
| 227 |
+
return self.model(input_ids, positions, kv_caches,
|
| 228 |
+
attn_metadata, intermediate_tensors,
|
| 229 |
+
inputs_embeds)
|
| 230 |
+
|
| 231 |
+
def compute_logits(
|
| 232 |
+
self,
|
| 233 |
+
hidden_states: torch.Tensor,
|
| 234 |
+
sampling_metadata: SamplingMetadata,
|
| 235 |
+
) -> Optional[torch.Tensor]:
|
| 236 |
+
logits = self.logits_processor(self.llm_decoder, hidden_states,
|
| 237 |
+
sampling_metadata)
|
| 238 |
+
return logits
|
| 239 |
+
|
| 240 |
+
def sample(
|
| 241 |
+
self,
|
| 242 |
+
logits: torch.Tensor,
|
| 243 |
+
sampling_metadata: SamplingMetadata,
|
| 244 |
+
) -> Optional[SamplerOutput]:
|
| 245 |
+
next_tokens = self.sampler(logits, sampling_metadata)
|
| 246 |
+
return next_tokens
|
| 247 |
+
|
| 248 |
+
@staticmethod
|
| 249 |
+
def convert_weights(weights: Iterable[Tuple[str, torch.Tensor]]) -> Iterable[Tuple[str, torch.Tensor]]:
|
| 250 |
+
for name, param in weights:
|
| 251 |
+
# 处理Qwen2Model核心参数
|
| 252 |
+
if name.startswith("llm."):
|
| 253 |
+
if name.startswith("llm.model.model."):
|
| 254 |
+
name = name.replace("llm.model.model.", "model.")
|
| 255 |
+
else:
|
| 256 |
+
continue
|
| 257 |
+
# print('weights name: ', name)
|
| 258 |
+
yield name, param
|
| 259 |
+
|
| 260 |
+
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
| 261 |
+
weights = self.convert_weights(weights)
|
| 262 |
+
loader = AutoWeightsLoader(self)
|
| 263 |
+
loader.load_weights(weights)
|
cosyvoice/tokenizer/__pycache__/tokenizer.cpython-310.pyc
ADDED
|
Binary file (7.89 kB). View file
|
|
|
cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cosyvoice/tokenizer/tokenizer.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import os
|
| 3 |
+
from functools import lru_cache
|
| 4 |
+
from typing import Optional
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import AutoTokenizer
|
| 7 |
+
from whisper.tokenizer import Tokenizer
|
| 8 |
+
|
| 9 |
+
import tiktoken
|
| 10 |
+
|
| 11 |
+
LANGUAGES = {
|
| 12 |
+
"en": "english",
|
| 13 |
+
"zh": "chinese",
|
| 14 |
+
"de": "german",
|
| 15 |
+
"es": "spanish",
|
| 16 |
+
"ru": "russian",
|
| 17 |
+
"ko": "korean",
|
| 18 |
+
"fr": "french",
|
| 19 |
+
"ja": "japanese",
|
| 20 |
+
"pt": "portuguese",
|
| 21 |
+
"tr": "turkish",
|
| 22 |
+
"pl": "polish",
|
| 23 |
+
"ca": "catalan",
|
| 24 |
+
"nl": "dutch",
|
| 25 |
+
"ar": "arabic",
|
| 26 |
+
"sv": "swedish",
|
| 27 |
+
"it": "italian",
|
| 28 |
+
"id": "indonesian",
|
| 29 |
+
"hi": "hindi",
|
| 30 |
+
"fi": "finnish",
|
| 31 |
+
"vi": "vietnamese",
|
| 32 |
+
"he": "hebrew",
|
| 33 |
+
"uk": "ukrainian",
|
| 34 |
+
"el": "greek",
|
| 35 |
+
"ms": "malay",
|
| 36 |
+
"cs": "czech",
|
| 37 |
+
"ro": "romanian",
|
| 38 |
+
"da": "danish",
|
| 39 |
+
"hu": "hungarian",
|
| 40 |
+
"ta": "tamil",
|
| 41 |
+
"no": "norwegian",
|
| 42 |
+
"th": "thai",
|
| 43 |
+
"ur": "urdu",
|
| 44 |
+
"hr": "croatian",
|
| 45 |
+
"bg": "bulgarian",
|
| 46 |
+
"lt": "lithuanian",
|
| 47 |
+
"la": "latin",
|
| 48 |
+
"mi": "maori",
|
| 49 |
+
"ml": "malayalam",
|
| 50 |
+
"cy": "welsh",
|
| 51 |
+
"sk": "slovak",
|
| 52 |
+
"te": "telugu",
|
| 53 |
+
"fa": "persian",
|
| 54 |
+
"lv": "latvian",
|
| 55 |
+
"bn": "bengali",
|
| 56 |
+
"sr": "serbian",
|
| 57 |
+
"az": "azerbaijani",
|
| 58 |
+
"sl": "slovenian",
|
| 59 |
+
"kn": "kannada",
|
| 60 |
+
"et": "estonian",
|
| 61 |
+
"mk": "macedonian",
|
| 62 |
+
"br": "breton",
|
| 63 |
+
"eu": "basque",
|
| 64 |
+
"is": "icelandic",
|
| 65 |
+
"hy": "armenian",
|
| 66 |
+
"ne": "nepali",
|
| 67 |
+
"mn": "mongolian",
|
| 68 |
+
"bs": "bosnian",
|
| 69 |
+
"kk": "kazakh",
|
| 70 |
+
"sq": "albanian",
|
| 71 |
+
"sw": "swahili",
|
| 72 |
+
"gl": "galician",
|
| 73 |
+
"mr": "marathi",
|
| 74 |
+
"pa": "punjabi",
|
| 75 |
+
"si": "sinhala",
|
| 76 |
+
"km": "khmer",
|
| 77 |
+
"sn": "shona",
|
| 78 |
+
"yo": "yoruba",
|
| 79 |
+
"so": "somali",
|
| 80 |
+
"af": "afrikaans",
|
| 81 |
+
"oc": "occitan",
|
| 82 |
+
"ka": "georgian",
|
| 83 |
+
"be": "belarusian",
|
| 84 |
+
"tg": "tajik",
|
| 85 |
+
"sd": "sindhi",
|
| 86 |
+
"gu": "gujarati",
|
| 87 |
+
"am": "amharic",
|
| 88 |
+
"yi": "yiddish",
|
| 89 |
+
"lo": "lao",
|
| 90 |
+
"uz": "uzbek",
|
| 91 |
+
"fo": "faroese",
|
| 92 |
+
"ht": "haitian creole",
|
| 93 |
+
"ps": "pashto",
|
| 94 |
+
"tk": "turkmen",
|
| 95 |
+
"nn": "nynorsk",
|
| 96 |
+
"mt": "maltese",
|
| 97 |
+
"sa": "sanskrit",
|
| 98 |
+
"lb": "luxembourgish",
|
| 99 |
+
"my": "myanmar",
|
| 100 |
+
"bo": "tibetan",
|
| 101 |
+
"tl": "tagalog",
|
| 102 |
+
"mg": "malagasy",
|
| 103 |
+
"as": "assamese",
|
| 104 |
+
"tt": "tatar",
|
| 105 |
+
"haw": "hawaiian",
|
| 106 |
+
"ln": "lingala",
|
| 107 |
+
"ha": "hausa",
|
| 108 |
+
"ba": "bashkir",
|
| 109 |
+
"jw": "javanese",
|
| 110 |
+
"su": "sundanese",
|
| 111 |
+
"yue": "cantonese",
|
| 112 |
+
"minnan": "minnan",
|
| 113 |
+
"wuyu": "wuyu",
|
| 114 |
+
"dialect": "dialect",
|
| 115 |
+
"zh/en": "zh/en",
|
| 116 |
+
"en/zh": "en/zh",
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
# language code lookup by name, with a few language aliases
|
| 120 |
+
TO_LANGUAGE_CODE = {
|
| 121 |
+
**{language: code for code, language in LANGUAGES.items()},
|
| 122 |
+
"burmese": "my",
|
| 123 |
+
"valencian": "ca",
|
| 124 |
+
"flemish": "nl",
|
| 125 |
+
"haitian": "ht",
|
| 126 |
+
"letzeburgesch": "lb",
|
| 127 |
+
"pushto": "ps",
|
| 128 |
+
"panjabi": "pa",
|
| 129 |
+
"moldavian": "ro",
|
| 130 |
+
"moldovan": "ro",
|
| 131 |
+
"sinhalese": "si",
|
| 132 |
+
"castilian": "es",
|
| 133 |
+
"mandarin": "zh",
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
AUDIO_EVENT = {
|
| 137 |
+
"ASR": "ASR",
|
| 138 |
+
"AED": "AED",
|
| 139 |
+
"SER": "SER",
|
| 140 |
+
"Speech": "Speech",
|
| 141 |
+
"/Speech": "/Speech",
|
| 142 |
+
"BGM": "BGM",
|
| 143 |
+
"/BGM": "/BGM",
|
| 144 |
+
"Laughter": "Laughter",
|
| 145 |
+
"/Laughter": "/Laughter",
|
| 146 |
+
"Applause": "Applause",
|
| 147 |
+
"/Applause": "/Applause",
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
EMOTION = {
|
| 151 |
+
"HAPPY": "HAPPY",
|
| 152 |
+
"SAD": "SAD",
|
| 153 |
+
"ANGRY": "ANGRY",
|
| 154 |
+
"NEUTRAL": "NEUTRAL",
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
TTS_Vocal_Token = {
|
| 158 |
+
"TTS/B": "TTS/B",
|
| 159 |
+
"TTS/O": "TTS/O",
|
| 160 |
+
"TTS/Q": "TTS/Q",
|
| 161 |
+
"TTS/A": "TTS/A",
|
| 162 |
+
"TTS/CO": "TTS/CO",
|
| 163 |
+
"TTS/CL": "TTS/CL",
|
| 164 |
+
"TTS/H": "TTS/H",
|
| 165 |
+
**{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
@lru_cache(maxsize=None)
|
| 170 |
+
def get_encoding(name: str = "gpt2", num_languages: int = 99):
|
| 171 |
+
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
| 172 |
+
ranks = {
|
| 173 |
+
base64.b64decode(token): int(rank)
|
| 174 |
+
for token, rank in (line.split() for line in open(vocab_path) if line)
|
| 175 |
+
}
|
| 176 |
+
n_vocab = len(ranks)
|
| 177 |
+
special_tokens = {}
|
| 178 |
+
|
| 179 |
+
specials = [
|
| 180 |
+
"<|endoftext|>",
|
| 181 |
+
"<|startoftranscript|>",
|
| 182 |
+
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
|
| 183 |
+
*[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
|
| 184 |
+
*[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
|
| 185 |
+
"<|translate|>",
|
| 186 |
+
"<|transcribe|>",
|
| 187 |
+
"<|startoflm|>",
|
| 188 |
+
"<|startofprev|>",
|
| 189 |
+
"<|nospeech|>",
|
| 190 |
+
"<|notimestamps|>",
|
| 191 |
+
*[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)], # register special tokens for ASR
|
| 192 |
+
*[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())], # register special tokens for TTS
|
| 193 |
+
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
|
| 194 |
+
]
|
| 195 |
+
|
| 196 |
+
for token in specials:
|
| 197 |
+
special_tokens[token] = n_vocab
|
| 198 |
+
n_vocab += 1
|
| 199 |
+
|
| 200 |
+
return tiktoken.Encoding(
|
| 201 |
+
name=os.path.basename(vocab_path),
|
| 202 |
+
explicit_n_vocab=n_vocab,
|
| 203 |
+
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
| 204 |
+
mergeable_ranks=ranks,
|
| 205 |
+
special_tokens=special_tokens,
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
@lru_cache(maxsize=None)
|
| 210 |
+
def get_tokenizer(
|
| 211 |
+
multilingual: bool,
|
| 212 |
+
*,
|
| 213 |
+
num_languages: int = 99,
|
| 214 |
+
language: Optional[str] = None,
|
| 215 |
+
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
| 216 |
+
) -> Tokenizer:
|
| 217 |
+
if language is not None:
|
| 218 |
+
language = language.lower()
|
| 219 |
+
if language not in LANGUAGES:
|
| 220 |
+
if language in TO_LANGUAGE_CODE:
|
| 221 |
+
language = TO_LANGUAGE_CODE[language]
|
| 222 |
+
else:
|
| 223 |
+
raise ValueError(f"Unsupported language: {language}")
|
| 224 |
+
|
| 225 |
+
if multilingual:
|
| 226 |
+
encoding_name = "multilingual_zh_ja_yue_char_del"
|
| 227 |
+
language = language or "en"
|
| 228 |
+
task = task or "transcribe"
|
| 229 |
+
else:
|
| 230 |
+
encoding_name = "gpt2"
|
| 231 |
+
language = None
|
| 232 |
+
task = None
|
| 233 |
+
|
| 234 |
+
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
|
| 235 |
+
|
| 236 |
+
return Tokenizer(
|
| 237 |
+
encoding=encoding, num_languages=num_languages, language=language, task=task
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class QwenTokenizer():
|
| 242 |
+
def __init__(self, token_path, skip_special_tokens=True):
|
| 243 |
+
super().__init__()
|
| 244 |
+
# NOTE: non-chat model, all these special tokens keep randomly initialized.
|
| 245 |
+
special_tokens = {
|
| 246 |
+
'eos_token': '<|endoftext|>',
|
| 247 |
+
'pad_token': '<|endoftext|>',
|
| 248 |
+
'additional_special_tokens': [
|
| 249 |
+
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
|
| 250 |
+
'[breath]', '<strong>', '</strong>', '[noise]',
|
| 251 |
+
'[laughter]', '[cough]', '[clucking]', '[accent]',
|
| 252 |
+
'[quick_breath]',
|
| 253 |
+
"<laughter>", "</laughter>",
|
| 254 |
+
"[hissing]", "[sigh]", "[vocalized-noise]",
|
| 255 |
+
"[lipsmack]", "[mn]"
|
| 256 |
+
]
|
| 257 |
+
}
|
| 258 |
+
self.special_tokens = special_tokens
|
| 259 |
+
self.tokenizer = AutoTokenizer.from_pretrained(token_path)
|
| 260 |
+
self.tokenizer.add_special_tokens(special_tokens)
|
| 261 |
+
self.skip_special_tokens = skip_special_tokens
|
| 262 |
+
|
| 263 |
+
def encode(self, text, **kwargs):
|
| 264 |
+
tokens = self.tokenizer([text], return_tensors="pt")
|
| 265 |
+
tokens = tokens["input_ids"][0].cpu().tolist()
|
| 266 |
+
return tokens
|
| 267 |
+
|
| 268 |
+
def decode(self, tokens):
|
| 269 |
+
tokens = torch.tensor(tokens, dtype=torch.int64)
|
| 270 |
+
text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
|
| 271 |
+
return text
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
@lru_cache(maxsize=None)
|
| 275 |
+
def get_qwen_tokenizer(
|
| 276 |
+
token_path: str,
|
| 277 |
+
skip_special_tokens: bool
|
| 278 |
+
) -> QwenTokenizer:
|
| 279 |
+
return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
|
cosyvoice/transformer/__init__.py
ADDED
|
File without changes
|
cosyvoice/transformer/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (184 Bytes). View file
|
|
|
cosyvoice/transformer/__pycache__/activation.cpython-310.pyc
ADDED
|
Binary file (2.51 kB). View file
|
|
|
cosyvoice/transformer/__pycache__/attention.cpython-310.pyc
ADDED
|
Binary file (9.33 kB). View file
|
|
|
cosyvoice/transformer/__pycache__/convolution.cpython-310.pyc
ADDED
|
Binary file (3.09 kB). View file
|
|
|
cosyvoice/transformer/__pycache__/embedding.cpython-310.pyc
ADDED
|
Binary file (9.55 kB). View file
|
|
|
cosyvoice/transformer/__pycache__/encoder_layer.cpython-310.pyc
ADDED
|
Binary file (7.36 kB). View file
|
|
|
cosyvoice/transformer/__pycache__/label_smoothing_loss.cpython-310.pyc
ADDED
|
Binary file (2.91 kB). View file
|
|
|
cosyvoice/transformer/__pycache__/positionwise_feed_forward.cpython-310.pyc
ADDED
|
Binary file (3.82 kB). View file
|
|
|
cosyvoice/transformer/__pycache__/subsampling.cpython-310.pyc
ADDED
|
Binary file (9.86 kB). View file
|
|
|
cosyvoice/transformer/__pycache__/upsample_encoder.cpython-310.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
cosyvoice/transformer/activation.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
|
| 2 |
+
# 2020 Northwestern Polytechnical University (Pengcheng Guo)
|
| 3 |
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
| 4 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
"""Swish() activation function for Conformer."""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from torch import nn, sin, pow
|
| 21 |
+
from torch.nn import Parameter
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Swish(torch.nn.Module):
|
| 25 |
+
"""Construct an Swish object."""
|
| 26 |
+
|
| 27 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 28 |
+
"""Return Swish activation function."""
|
| 29 |
+
return x * torch.sigmoid(x)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
| 33 |
+
# LICENSE is in incl_licenses directory.
|
| 34 |
+
class Snake(nn.Module):
|
| 35 |
+
'''
|
| 36 |
+
Implementation of a sine-based periodic activation function
|
| 37 |
+
Shape:
|
| 38 |
+
- Input: (B, C, T)
|
| 39 |
+
- Output: (B, C, T), same shape as the input
|
| 40 |
+
Parameters:
|
| 41 |
+
- alpha - trainable parameter
|
| 42 |
+
References:
|
| 43 |
+
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
| 44 |
+
https://arxiv.org/abs/2006.08195
|
| 45 |
+
Examples:
|
| 46 |
+
>>> a1 = snake(256)
|
| 47 |
+
>>> x = torch.randn(256)
|
| 48 |
+
>>> x = a1(x)
|
| 49 |
+
'''
|
| 50 |
+
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
| 51 |
+
'''
|
| 52 |
+
Initialization.
|
| 53 |
+
INPUT:
|
| 54 |
+
- in_features: shape of the input
|
| 55 |
+
- alpha: trainable parameter
|
| 56 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
| 57 |
+
alpha will be trained along with the rest of your model.
|
| 58 |
+
'''
|
| 59 |
+
super(Snake, self).__init__()
|
| 60 |
+
self.in_features = in_features
|
| 61 |
+
|
| 62 |
+
# initialize alpha
|
| 63 |
+
self.alpha_logscale = alpha_logscale
|
| 64 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
| 65 |
+
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
| 66 |
+
else: # linear scale alphas initialized to ones
|
| 67 |
+
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
| 68 |
+
|
| 69 |
+
self.alpha.requires_grad = alpha_trainable
|
| 70 |
+
|
| 71 |
+
self.no_div_by_zero = 0.000000001
|
| 72 |
+
|
| 73 |
+
def forward(self, x):
|
| 74 |
+
'''
|
| 75 |
+
Forward pass of the function.
|
| 76 |
+
Applies the function to the input elementwise.
|
| 77 |
+
Snake ∶= x + 1/a * sin^2 (xa)
|
| 78 |
+
'''
|
| 79 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
| 80 |
+
if self.alpha_logscale:
|
| 81 |
+
alpha = torch.exp(alpha)
|
| 82 |
+
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
| 83 |
+
|
| 84 |
+
return x
|
cosyvoice/transformer/attention.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2019 Shigeki Karita
|
| 2 |
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
| 3 |
+
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
|
| 4 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
"""Multi-Head Attention layer definition."""
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
from typing import Tuple
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from torch import nn
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class MultiHeadedAttention(nn.Module):
|
| 27 |
+
"""Multi-Head Attention layer.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
n_head (int): The number of heads.
|
| 31 |
+
n_feat (int): The number of features.
|
| 32 |
+
dropout_rate (float): Dropout rate.
|
| 33 |
+
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(self,
|
| 37 |
+
n_head: int,
|
| 38 |
+
n_feat: int,
|
| 39 |
+
dropout_rate: float,
|
| 40 |
+
key_bias: bool = True):
|
| 41 |
+
"""Construct an MultiHeadedAttention object."""
|
| 42 |
+
super().__init__()
|
| 43 |
+
assert n_feat % n_head == 0
|
| 44 |
+
# We assume d_v always equals d_k
|
| 45 |
+
self.d_k = n_feat // n_head
|
| 46 |
+
self.h = n_head
|
| 47 |
+
self.linear_q = nn.Linear(n_feat, n_feat)
|
| 48 |
+
self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
|
| 49 |
+
self.linear_v = nn.Linear(n_feat, n_feat)
|
| 50 |
+
self.linear_out = nn.Linear(n_feat, n_feat)
|
| 51 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
| 52 |
+
|
| 53 |
+
def forward_qkv(
|
| 54 |
+
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
|
| 55 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 56 |
+
"""Transform query, key and value.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
| 60 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
| 61 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
torch.Tensor: Transformed query tensor, size
|
| 65 |
+
(#batch, n_head, time1, d_k).
|
| 66 |
+
torch.Tensor: Transformed key tensor, size
|
| 67 |
+
(#batch, n_head, time2, d_k).
|
| 68 |
+
torch.Tensor: Transformed value tensor, size
|
| 69 |
+
(#batch, n_head, time2, d_k).
|
| 70 |
+
|
| 71 |
+
"""
|
| 72 |
+
n_batch = query.size(0)
|
| 73 |
+
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
| 74 |
+
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
| 75 |
+
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
| 76 |
+
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
| 77 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
| 78 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
| 79 |
+
|
| 80 |
+
return q, k, v
|
| 81 |
+
|
| 82 |
+
def forward_attention(
|
| 83 |
+
self,
|
| 84 |
+
value: torch.Tensor,
|
| 85 |
+
scores: torch.Tensor,
|
| 86 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
|
| 87 |
+
) -> torch.Tensor:
|
| 88 |
+
"""Compute attention context vector.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
value (torch.Tensor): Transformed value, size
|
| 92 |
+
(#batch, n_head, time2, d_k).
|
| 93 |
+
scores (torch.Tensor): Attention score, size
|
| 94 |
+
(#batch, n_head, time1, time2).
|
| 95 |
+
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
|
| 96 |
+
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
torch.Tensor: Transformed value (#batch, time1, d_model)
|
| 100 |
+
weighted by the attention score (#batch, time1, time2).
|
| 101 |
+
|
| 102 |
+
"""
|
| 103 |
+
n_batch = value.size(0)
|
| 104 |
+
# NOTE(xcsong): When will `if mask.size(2) > 0` be True?
|
| 105 |
+
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
|
| 106 |
+
# 1st chunk to ease the onnx export.]
|
| 107 |
+
# 2. pytorch training
|
| 108 |
+
if mask.size(2) > 0: # time2 > 0
|
| 109 |
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
| 110 |
+
# For last chunk, time2 might be larger than scores.size(-1)
|
| 111 |
+
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
|
| 112 |
+
scores = scores.masked_fill(mask, -float('inf'))
|
| 113 |
+
attn = torch.softmax(scores, dim=-1).masked_fill(
|
| 114 |
+
mask, 0.0) # (batch, head, time1, time2)
|
| 115 |
+
# NOTE(xcsong): When will `if mask.size(2) > 0` be False?
|
| 116 |
+
# 1. onnx(16/-1, -1/-1, 16/0)
|
| 117 |
+
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
|
| 118 |
+
else:
|
| 119 |
+
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
| 120 |
+
|
| 121 |
+
p_attn = self.dropout(attn)
|
| 122 |
+
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
| 123 |
+
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
|
| 124 |
+
self.h * self.d_k)
|
| 125 |
+
) # (batch, time1, d_model)
|
| 126 |
+
|
| 127 |
+
return self.linear_out(x) # (batch, time1, d_model)
|
| 128 |
+
|
| 129 |
+
def forward(
|
| 130 |
+
self,
|
| 131 |
+
query: torch.Tensor,
|
| 132 |
+
key: torch.Tensor,
|
| 133 |
+
value: torch.Tensor,
|
| 134 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
| 135 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
| 136 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
| 137 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 138 |
+
"""Compute scaled dot product attention.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
| 142 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
| 143 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
| 144 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
| 145 |
+
(#batch, time1, time2).
|
| 146 |
+
1.When applying cross attention between decoder and encoder,
|
| 147 |
+
the batch padding mask for input is in (#batch, 1, T) shape.
|
| 148 |
+
2.When applying self attention of encoder,
|
| 149 |
+
the mask is in (#batch, T, T) shape.
|
| 150 |
+
3.When applying self attention of decoder,
|
| 151 |
+
the mask is in (#batch, L, L) shape.
|
| 152 |
+
4.If the different position in decoder see different block
|
| 153 |
+
of the encoder, such as Mocha, the passed in mask could be
|
| 154 |
+
in (#batch, L, T) shape. But there is no such case in current
|
| 155 |
+
CosyVoice.
|
| 156 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
| 157 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
| 158 |
+
and `head * d_k == size`
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
| 163 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
| 164 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
| 165 |
+
and `head * d_k == size`
|
| 166 |
+
|
| 167 |
+
"""
|
| 168 |
+
q, k, v = self.forward_qkv(query, key, value)
|
| 169 |
+
|
| 170 |
+
# NOTE(xcsong):
|
| 171 |
+
# when export onnx model, for 1st chunk, we feed
|
| 172 |
+
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
|
| 173 |
+
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
|
| 174 |
+
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
|
| 175 |
+
# and we will always do splitting and
|
| 176 |
+
# concatnation(this will simplify onnx export). Note that
|
| 177 |
+
# it's OK to concat & split zero-shaped tensors(see code below).
|
| 178 |
+
# when export jit model, for 1st chunk, we always feed
|
| 179 |
+
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
|
| 180 |
+
# >>> a = torch.ones((1, 2, 0, 4))
|
| 181 |
+
# >>> b = torch.ones((1, 2, 3, 4))
|
| 182 |
+
# >>> c = torch.cat((a, b), dim=2)
|
| 183 |
+
# >>> torch.equal(b, c) # True
|
| 184 |
+
# >>> d = torch.split(a, 2, dim=-1)
|
| 185 |
+
# >>> torch.equal(d[0], d[1]) # True
|
| 186 |
+
if cache.size(0) > 0:
|
| 187 |
+
key_cache, value_cache = torch.split(cache,
|
| 188 |
+
cache.size(-1) // 2,
|
| 189 |
+
dim=-1)
|
| 190 |
+
k = torch.cat([key_cache, k], dim=2)
|
| 191 |
+
v = torch.cat([value_cache, v], dim=2)
|
| 192 |
+
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
|
| 193 |
+
# non-trivial to calculate `next_cache_start` here.
|
| 194 |
+
new_cache = torch.cat((k, v), dim=-1)
|
| 195 |
+
|
| 196 |
+
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
| 197 |
+
return self.forward_attention(v, scores, mask), new_cache
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
| 201 |
+
"""Multi-Head Attention layer with relative position encoding.
|
| 202 |
+
Paper: https://arxiv.org/abs/1901.02860
|
| 203 |
+
Args:
|
| 204 |
+
n_head (int): The number of heads.
|
| 205 |
+
n_feat (int): The number of features.
|
| 206 |
+
dropout_rate (float): Dropout rate.
|
| 207 |
+
"""
|
| 208 |
+
|
| 209 |
+
def __init__(self,
|
| 210 |
+
n_head: int,
|
| 211 |
+
n_feat: int,
|
| 212 |
+
dropout_rate: float,
|
| 213 |
+
key_bias: bool = True):
|
| 214 |
+
"""Construct an RelPositionMultiHeadedAttention object."""
|
| 215 |
+
super().__init__(n_head, n_feat, dropout_rate, key_bias)
|
| 216 |
+
# linear transformation for positional encoding
|
| 217 |
+
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
| 218 |
+
# these two learnable bias are used in matrix c and matrix d
|
| 219 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 220 |
+
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
| 221 |
+
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
| 222 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
| 223 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
| 224 |
+
|
| 225 |
+
def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
|
| 226 |
+
"""Compute relative positional encoding.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
|
| 230 |
+
time1 means the length of query vector.
|
| 231 |
+
|
| 232 |
+
Returns:
|
| 233 |
+
torch.Tensor: Output tensor.
|
| 234 |
+
|
| 235 |
+
"""
|
| 236 |
+
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
|
| 237 |
+
device=x.device,
|
| 238 |
+
dtype=x.dtype)
|
| 239 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
| 240 |
+
|
| 241 |
+
x_padded = x_padded.view(x.size()[0],
|
| 242 |
+
x.size()[1],
|
| 243 |
+
x.size(3) + 1, x.size(2))
|
| 244 |
+
x = x_padded[:, :, 1:].view_as(x)[
|
| 245 |
+
:, :, :, : x.size(-1) // 2 + 1
|
| 246 |
+
] # only keep the positions from 0 to time2
|
| 247 |
+
return x
|
| 248 |
+
|
| 249 |
+
def forward(
|
| 250 |
+
self,
|
| 251 |
+
query: torch.Tensor,
|
| 252 |
+
key: torch.Tensor,
|
| 253 |
+
value: torch.Tensor,
|
| 254 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
| 255 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
| 256 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
| 257 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 258 |
+
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
| 259 |
+
Args:
|
| 260 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
| 261 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
| 262 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
| 263 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
| 264 |
+
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
| 265 |
+
pos_emb (torch.Tensor): Positional embedding tensor
|
| 266 |
+
(#batch, time2, size).
|
| 267 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
| 268 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
| 269 |
+
and `head * d_k == size`
|
| 270 |
+
Returns:
|
| 271 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
| 272 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
| 273 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
| 274 |
+
and `head * d_k == size`
|
| 275 |
+
"""
|
| 276 |
+
q, k, v = self.forward_qkv(query, key, value)
|
| 277 |
+
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
| 278 |
+
|
| 279 |
+
# NOTE(xcsong):
|
| 280 |
+
# when export onnx model, for 1st chunk, we feed
|
| 281 |
+
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
|
| 282 |
+
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
|
| 283 |
+
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
|
| 284 |
+
# and we will always do splitting and
|
| 285 |
+
# concatnation(this will simplify onnx export). Note that
|
| 286 |
+
# it's OK to concat & split zero-shaped tensors(see code below).
|
| 287 |
+
# when export jit model, for 1st chunk, we always feed
|
| 288 |
+
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
|
| 289 |
+
# >>> a = torch.ones((1, 2, 0, 4))
|
| 290 |
+
# >>> b = torch.ones((1, 2, 3, 4))
|
| 291 |
+
# >>> c = torch.cat((a, b), dim=2)
|
| 292 |
+
# >>> torch.equal(b, c) # True
|
| 293 |
+
# >>> d = torch.split(a, 2, dim=-1)
|
| 294 |
+
# >>> torch.equal(d[0], d[1]) # True
|
| 295 |
+
if cache.size(0) > 0:
|
| 296 |
+
key_cache, value_cache = torch.split(cache,
|
| 297 |
+
cache.size(-1) // 2,
|
| 298 |
+
dim=-1)
|
| 299 |
+
k = torch.cat([key_cache, k], dim=2)
|
| 300 |
+
v = torch.cat([value_cache, v], dim=2)
|
| 301 |
+
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
|
| 302 |
+
# non-trivial to calculate `next_cache_start` here.
|
| 303 |
+
new_cache = torch.cat((k, v), dim=-1)
|
| 304 |
+
|
| 305 |
+
n_batch_pos = pos_emb.size(0)
|
| 306 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
| 307 |
+
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
| 308 |
+
|
| 309 |
+
# (batch, head, time1, d_k)
|
| 310 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
| 311 |
+
# (batch, head, time1, d_k)
|
| 312 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
| 313 |
+
|
| 314 |
+
# compute attention score
|
| 315 |
+
# first compute matrix a and matrix c
|
| 316 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 317 |
+
# (batch, head, time1, time2)
|
| 318 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
| 319 |
+
|
| 320 |
+
# compute matrix b and matrix d
|
| 321 |
+
# (batch, head, time1, time2)
|
| 322 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
| 323 |
+
# NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
|
| 324 |
+
if matrix_ac.shape != matrix_bd.shape:
|
| 325 |
+
matrix_bd = self.rel_shift(matrix_bd)
|
| 326 |
+
|
| 327 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
| 328 |
+
self.d_k) # (batch, head, time1, time2)
|
| 329 |
+
|
| 330 |
+
return self.forward_attention(v, scores, mask), new_cache
|
cosyvoice/transformer/convolution.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
|
| 2 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
| 16 |
+
"""ConvolutionModule definition."""
|
| 17 |
+
|
| 18 |
+
from typing import Tuple
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from torch import nn
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ConvolutionModule(nn.Module):
|
| 25 |
+
"""ConvolutionModule in Conformer model."""
|
| 26 |
+
|
| 27 |
+
def __init__(self,
|
| 28 |
+
channels: int,
|
| 29 |
+
kernel_size: int = 15,
|
| 30 |
+
activation: nn.Module = nn.ReLU(),
|
| 31 |
+
norm: str = "batch_norm",
|
| 32 |
+
causal: bool = False,
|
| 33 |
+
bias: bool = True):
|
| 34 |
+
"""Construct an ConvolutionModule object.
|
| 35 |
+
Args:
|
| 36 |
+
channels (int): The number of channels of conv layers.
|
| 37 |
+
kernel_size (int): Kernel size of conv layers.
|
| 38 |
+
causal (int): Whether use causal convolution or not
|
| 39 |
+
"""
|
| 40 |
+
super().__init__()
|
| 41 |
+
|
| 42 |
+
self.pointwise_conv1 = nn.Conv1d(
|
| 43 |
+
channels,
|
| 44 |
+
2 * channels,
|
| 45 |
+
kernel_size=1,
|
| 46 |
+
stride=1,
|
| 47 |
+
padding=0,
|
| 48 |
+
bias=bias,
|
| 49 |
+
)
|
| 50 |
+
# self.lorder is used to distinguish if it's a causal convolution,
|
| 51 |
+
# if self.lorder > 0: it's a causal convolution, the input will be
|
| 52 |
+
# padded with self.lorder frames on the left in forward.
|
| 53 |
+
# else: it's a symmetrical convolution
|
| 54 |
+
if causal:
|
| 55 |
+
padding = 0
|
| 56 |
+
self.lorder = kernel_size - 1
|
| 57 |
+
else:
|
| 58 |
+
# kernel_size should be an odd number for none causal convolution
|
| 59 |
+
assert (kernel_size - 1) % 2 == 0
|
| 60 |
+
padding = (kernel_size - 1) // 2
|
| 61 |
+
self.lorder = 0
|
| 62 |
+
self.depthwise_conv = nn.Conv1d(
|
| 63 |
+
channels,
|
| 64 |
+
channels,
|
| 65 |
+
kernel_size,
|
| 66 |
+
stride=1,
|
| 67 |
+
padding=padding,
|
| 68 |
+
groups=channels,
|
| 69 |
+
bias=bias,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
assert norm in ['batch_norm', 'layer_norm']
|
| 73 |
+
if norm == "batch_norm":
|
| 74 |
+
self.use_layer_norm = False
|
| 75 |
+
self.norm = nn.BatchNorm1d(channels)
|
| 76 |
+
else:
|
| 77 |
+
self.use_layer_norm = True
|
| 78 |
+
self.norm = nn.LayerNorm(channels)
|
| 79 |
+
|
| 80 |
+
self.pointwise_conv2 = nn.Conv1d(
|
| 81 |
+
channels,
|
| 82 |
+
channels,
|
| 83 |
+
kernel_size=1,
|
| 84 |
+
stride=1,
|
| 85 |
+
padding=0,
|
| 86 |
+
bias=bias,
|
| 87 |
+
)
|
| 88 |
+
self.activation = activation
|
| 89 |
+
|
| 90 |
+
def forward(
|
| 91 |
+
self,
|
| 92 |
+
x: torch.Tensor,
|
| 93 |
+
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
| 94 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0)),
|
| 95 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 96 |
+
"""Compute convolution module.
|
| 97 |
+
Args:
|
| 98 |
+
x (torch.Tensor): Input tensor (#batch, time, channels).
|
| 99 |
+
mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
|
| 100 |
+
(0, 0, 0) means fake mask.
|
| 101 |
+
cache (torch.Tensor): left context cache, it is only
|
| 102 |
+
used in causal convolution (#batch, channels, cache_t),
|
| 103 |
+
(0, 0, 0) meas fake cache.
|
| 104 |
+
Returns:
|
| 105 |
+
torch.Tensor: Output tensor (#batch, time, channels).
|
| 106 |
+
"""
|
| 107 |
+
# exchange the temporal dimension and the feature dimension
|
| 108 |
+
x = x.transpose(1, 2) # (#batch, channels, time)
|
| 109 |
+
|
| 110 |
+
# mask batch padding
|
| 111 |
+
if mask_pad.size(2) > 0: # time > 0
|
| 112 |
+
x.masked_fill_(~mask_pad, 0.0)
|
| 113 |
+
|
| 114 |
+
if self.lorder > 0:
|
| 115 |
+
if cache.size(2) == 0: # cache_t == 0
|
| 116 |
+
x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
|
| 117 |
+
else:
|
| 118 |
+
assert cache.size(0) == x.size(0) # equal batch
|
| 119 |
+
assert cache.size(1) == x.size(1) # equal channel
|
| 120 |
+
x = torch.cat((cache, x), dim=2)
|
| 121 |
+
assert (x.size(2) > self.lorder)
|
| 122 |
+
new_cache = x[:, :, -self.lorder:]
|
| 123 |
+
else:
|
| 124 |
+
# It's better we just return None if no cache is required,
|
| 125 |
+
# However, for JIT export, here we just fake one tensor instead of
|
| 126 |
+
# None.
|
| 127 |
+
new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
| 128 |
+
|
| 129 |
+
# GLU mechanism
|
| 130 |
+
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
|
| 131 |
+
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
|
| 132 |
+
|
| 133 |
+
# 1D Depthwise Conv
|
| 134 |
+
x = self.depthwise_conv(x)
|
| 135 |
+
if self.use_layer_norm:
|
| 136 |
+
x = x.transpose(1, 2)
|
| 137 |
+
x = self.activation(self.norm(x))
|
| 138 |
+
if self.use_layer_norm:
|
| 139 |
+
x = x.transpose(1, 2)
|
| 140 |
+
x = self.pointwise_conv2(x)
|
| 141 |
+
# mask batch padding
|
| 142 |
+
if mask_pad.size(2) > 0: # time > 0
|
| 143 |
+
x.masked_fill_(~mask_pad, 0.0)
|
| 144 |
+
|
| 145 |
+
return x.transpose(1, 2), new_cache
|
cosyvoice/transformer/decoder.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
|
| 2 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
| 16 |
+
"""Decoder definition."""
|
| 17 |
+
from typing import Tuple, List, Optional
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.utils.checkpoint as ckpt
|
| 21 |
+
import logging
|
| 22 |
+
|
| 23 |
+
from cosyvoice.transformer.decoder_layer import DecoderLayer
|
| 24 |
+
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
|
| 25 |
+
from cosyvoice.utils.class_utils import (
|
| 26 |
+
COSYVOICE_EMB_CLASSES,
|
| 27 |
+
COSYVOICE_ATTENTION_CLASSES,
|
| 28 |
+
COSYVOICE_ACTIVATION_CLASSES,
|
| 29 |
+
)
|
| 30 |
+
from cosyvoice.utils.mask import (subsequent_mask, make_pad_mask)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class TransformerDecoder(torch.nn.Module):
|
| 34 |
+
"""Base class of Transfomer decoder module.
|
| 35 |
+
Args:
|
| 36 |
+
vocab_size: output dim
|
| 37 |
+
encoder_output_size: dimension of attention
|
| 38 |
+
attention_heads: the number of heads of multi head attention
|
| 39 |
+
linear_units: the hidden units number of position-wise feedforward
|
| 40 |
+
num_blocks: the number of decoder blocks
|
| 41 |
+
dropout_rate: dropout rate
|
| 42 |
+
self_attention_dropout_rate: dropout rate for attention
|
| 43 |
+
input_layer: input layer type
|
| 44 |
+
use_output_layer: whether to use output layer
|
| 45 |
+
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
|
| 46 |
+
normalize_before:
|
| 47 |
+
True: use layer_norm before each sub-block of a layer.
|
| 48 |
+
False: use layer_norm after each sub-block of a layer.
|
| 49 |
+
src_attention: if false, encoder-decoder cross attention is not
|
| 50 |
+
applied, such as CIF model
|
| 51 |
+
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
| 52 |
+
gradient_checkpointing: rerunning a forward-pass segment for each
|
| 53 |
+
checkpointed segment during backward.
|
| 54 |
+
tie_word_embedding: Tie or clone module weights depending of whether we are
|
| 55 |
+
using TorchScript or not
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
def __init__(
|
| 59 |
+
self,
|
| 60 |
+
vocab_size: int,
|
| 61 |
+
encoder_output_size: int,
|
| 62 |
+
attention_heads: int = 4,
|
| 63 |
+
linear_units: int = 2048,
|
| 64 |
+
num_blocks: int = 6,
|
| 65 |
+
dropout_rate: float = 0.1,
|
| 66 |
+
positional_dropout_rate: float = 0.1,
|
| 67 |
+
self_attention_dropout_rate: float = 0.0,
|
| 68 |
+
src_attention_dropout_rate: float = 0.0,
|
| 69 |
+
input_layer: str = "embed",
|
| 70 |
+
use_output_layer: bool = True,
|
| 71 |
+
normalize_before: bool = True,
|
| 72 |
+
src_attention: bool = True,
|
| 73 |
+
key_bias: bool = True,
|
| 74 |
+
activation_type: str = "relu",
|
| 75 |
+
gradient_checkpointing: bool = False,
|
| 76 |
+
tie_word_embedding: bool = False,
|
| 77 |
+
):
|
| 78 |
+
super().__init__()
|
| 79 |
+
attention_dim = encoder_output_size
|
| 80 |
+
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
|
| 81 |
+
|
| 82 |
+
self.embed = torch.nn.Sequential(
|
| 83 |
+
torch.nn.Identity() if input_layer == "no_pos" else
|
| 84 |
+
torch.nn.Embedding(vocab_size, attention_dim),
|
| 85 |
+
COSYVOICE_EMB_CLASSES[input_layer](attention_dim,
|
| 86 |
+
positional_dropout_rate),
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
self.normalize_before = normalize_before
|
| 90 |
+
self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
|
| 91 |
+
self.use_output_layer = use_output_layer
|
| 92 |
+
if use_output_layer:
|
| 93 |
+
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
|
| 94 |
+
else:
|
| 95 |
+
self.output_layer = torch.nn.Identity()
|
| 96 |
+
self.num_blocks = num_blocks
|
| 97 |
+
self.decoders = torch.nn.ModuleList([
|
| 98 |
+
DecoderLayer(
|
| 99 |
+
attention_dim,
|
| 100 |
+
COSYVOICE_ATTENTION_CLASSES["selfattn"](
|
| 101 |
+
attention_heads, attention_dim,
|
| 102 |
+
self_attention_dropout_rate, key_bias),
|
| 103 |
+
COSYVOICE_ATTENTION_CLASSES["selfattn"](
|
| 104 |
+
attention_heads, attention_dim, src_attention_dropout_rate,
|
| 105 |
+
key_bias) if src_attention else None,
|
| 106 |
+
PositionwiseFeedForward(attention_dim, linear_units,
|
| 107 |
+
dropout_rate, activation),
|
| 108 |
+
dropout_rate,
|
| 109 |
+
normalize_before,
|
| 110 |
+
) for _ in range(self.num_blocks)
|
| 111 |
+
])
|
| 112 |
+
|
| 113 |
+
self.gradient_checkpointing = gradient_checkpointing
|
| 114 |
+
self.tie_word_embedding = tie_word_embedding
|
| 115 |
+
|
| 116 |
+
def forward(
|
| 117 |
+
self,
|
| 118 |
+
memory: torch.Tensor,
|
| 119 |
+
memory_mask: torch.Tensor,
|
| 120 |
+
ys_in_pad: torch.Tensor,
|
| 121 |
+
ys_in_lens: torch.Tensor,
|
| 122 |
+
r_ys_in_pad: torch.Tensor = torch.empty(0),
|
| 123 |
+
reverse_weight: float = 0.0,
|
| 124 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 125 |
+
"""Forward decoder.
|
| 126 |
+
Args:
|
| 127 |
+
memory: encoded memory, float32 (batch, maxlen_in, feat)
|
| 128 |
+
memory_mask: encoder memory mask, (batch, 1, maxlen_in)
|
| 129 |
+
ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
|
| 130 |
+
ys_in_lens: input lengths of this batch (batch)
|
| 131 |
+
r_ys_in_pad: not used in transformer decoder, in order to unify api
|
| 132 |
+
with bidirectional decoder
|
| 133 |
+
reverse_weight: not used in transformer decoder, in order to unify
|
| 134 |
+
api with bidirectional decode
|
| 135 |
+
Returns:
|
| 136 |
+
(tuple): tuple containing:
|
| 137 |
+
x: decoded token score before softmax (batch, maxlen_out,
|
| 138 |
+
vocab_size) if use_output_layer is True,
|
| 139 |
+
torch.tensor(0.0), in order to unify api with bidirectional decoder
|
| 140 |
+
olens: (batch, )
|
| 141 |
+
NOTE(xcsong):
|
| 142 |
+
We pass the `__call__` method of the modules instead of `forward` to the
|
| 143 |
+
checkpointing API because `__call__` attaches all the hooks of the module.
|
| 144 |
+
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
| 145 |
+
"""
|
| 146 |
+
tgt = ys_in_pad
|
| 147 |
+
maxlen = tgt.size(1)
|
| 148 |
+
# tgt_mask: (B, 1, L)
|
| 149 |
+
tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
|
| 150 |
+
tgt_mask = tgt_mask.to(tgt.device)
|
| 151 |
+
# m: (1, L, L)
|
| 152 |
+
m = subsequent_mask(tgt_mask.size(-1),
|
| 153 |
+
device=tgt_mask.device).unsqueeze(0)
|
| 154 |
+
# tgt_mask: (B, L, L)
|
| 155 |
+
tgt_mask = tgt_mask & m
|
| 156 |
+
x, _ = self.embed(tgt)
|
| 157 |
+
if self.gradient_checkpointing and self.training:
|
| 158 |
+
x = self.forward_layers_checkpointed(x, tgt_mask, memory,
|
| 159 |
+
memory_mask)
|
| 160 |
+
else:
|
| 161 |
+
x = self.forward_layers(x, tgt_mask, memory, memory_mask)
|
| 162 |
+
if self.normalize_before:
|
| 163 |
+
x = self.after_norm(x)
|
| 164 |
+
if self.use_output_layer:
|
| 165 |
+
x = self.output_layer(x)
|
| 166 |
+
olens = tgt_mask.sum(1)
|
| 167 |
+
return x, torch.tensor(0.0), olens
|
| 168 |
+
|
| 169 |
+
def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
|
| 170 |
+
memory: torch.Tensor,
|
| 171 |
+
memory_mask: torch.Tensor) -> torch.Tensor:
|
| 172 |
+
for layer in self.decoders:
|
| 173 |
+
x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
|
| 174 |
+
memory_mask)
|
| 175 |
+
return x
|
| 176 |
+
|
| 177 |
+
@torch.jit.unused
|
| 178 |
+
def forward_layers_checkpointed(self, x: torch.Tensor,
|
| 179 |
+
tgt_mask: torch.Tensor,
|
| 180 |
+
memory: torch.Tensor,
|
| 181 |
+
memory_mask: torch.Tensor) -> torch.Tensor:
|
| 182 |
+
for layer in self.decoders:
|
| 183 |
+
x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
|
| 184 |
+
layer.__call__, x, tgt_mask, memory, memory_mask)
|
| 185 |
+
return x
|
| 186 |
+
|
| 187 |
+
def forward_one_step(
|
| 188 |
+
self,
|
| 189 |
+
memory: torch.Tensor,
|
| 190 |
+
memory_mask: torch.Tensor,
|
| 191 |
+
tgt: torch.Tensor,
|
| 192 |
+
tgt_mask: torch.Tensor,
|
| 193 |
+
cache: Optional[List[torch.Tensor]] = None,
|
| 194 |
+
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
| 195 |
+
"""Forward one step.
|
| 196 |
+
This is only used for decoding.
|
| 197 |
+
Args:
|
| 198 |
+
memory: encoded memory, float32 (batch, maxlen_in, feat)
|
| 199 |
+
memory_mask: encoded memory mask, (batch, 1, maxlen_in)
|
| 200 |
+
tgt: input token ids, int64 (batch, maxlen_out)
|
| 201 |
+
tgt_mask: input token mask, (batch, maxlen_out)
|
| 202 |
+
dtype=torch.uint8 in PyTorch 1.2-
|
| 203 |
+
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
|
| 204 |
+
cache: cached output list of (batch, max_time_out-1, size)
|
| 205 |
+
Returns:
|
| 206 |
+
y, cache: NN output value and cache per `self.decoders`.
|
| 207 |
+
y.shape` is (batch, maxlen_out, token)
|
| 208 |
+
"""
|
| 209 |
+
x, _ = self.embed(tgt)
|
| 210 |
+
new_cache = []
|
| 211 |
+
for i, decoder in enumerate(self.decoders):
|
| 212 |
+
if cache is None:
|
| 213 |
+
c = None
|
| 214 |
+
else:
|
| 215 |
+
c = cache[i]
|
| 216 |
+
x, tgt_mask, memory, memory_mask = decoder(x,
|
| 217 |
+
tgt_mask,
|
| 218 |
+
memory,
|
| 219 |
+
memory_mask,
|
| 220 |
+
cache=c)
|
| 221 |
+
new_cache.append(x)
|
| 222 |
+
if self.normalize_before:
|
| 223 |
+
y = self.after_norm(x[:, -1])
|
| 224 |
+
else:
|
| 225 |
+
y = x[:, -1]
|
| 226 |
+
if self.use_output_layer:
|
| 227 |
+
y = torch.log_softmax(self.output_layer(y), dim=-1)
|
| 228 |
+
return y, new_cache
|
| 229 |
+
|
| 230 |
+
def tie_or_clone_weights(self, jit_mode: bool = True):
|
| 231 |
+
"""Tie or clone module weights (between word_emb and output_layer)
|
| 232 |
+
depending of whether we are using TorchScript or not"""
|
| 233 |
+
if not self.use_output_layer:
|
| 234 |
+
return
|
| 235 |
+
if jit_mode:
|
| 236 |
+
logging.info("clone emb.weight to output.weight")
|
| 237 |
+
self.output_layer.weight = torch.nn.Parameter(
|
| 238 |
+
self.embed[0].weight.clone())
|
| 239 |
+
else:
|
| 240 |
+
logging.info("tie emb.weight with output.weight")
|
| 241 |
+
self.output_layer.weight = self.embed[0].weight
|
| 242 |
+
|
| 243 |
+
if getattr(self.output_layer, "bias", None) is not None:
|
| 244 |
+
self.output_layer.bias.data = torch.nn.functional.pad(
|
| 245 |
+
self.output_layer.bias.data,
|
| 246 |
+
(
|
| 247 |
+
0,
|
| 248 |
+
self.output_layer.weight.shape[0] -
|
| 249 |
+
self.output_layer.bias.shape[0],
|
| 250 |
+
),
|
| 251 |
+
"constant",
|
| 252 |
+
0,
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class BiTransformerDecoder(torch.nn.Module):
|
| 257 |
+
"""Base class of Transfomer decoder module.
|
| 258 |
+
Args:
|
| 259 |
+
vocab_size: output dim
|
| 260 |
+
encoder_output_size: dimension of attention
|
| 261 |
+
attention_heads: the number of heads of multi head attention
|
| 262 |
+
linear_units: the hidden units number of position-wise feedforward
|
| 263 |
+
num_blocks: the number of decoder blocks
|
| 264 |
+
r_num_blocks: the number of right to left decoder blocks
|
| 265 |
+
dropout_rate: dropout rate
|
| 266 |
+
self_attention_dropout_rate: dropout rate for attention
|
| 267 |
+
input_layer: input layer type
|
| 268 |
+
use_output_layer: whether to use output layer
|
| 269 |
+
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
|
| 270 |
+
normalize_before:
|
| 271 |
+
True: use layer_norm before each sub-block of a layer.
|
| 272 |
+
False: use layer_norm after each sub-block of a layer.
|
| 273 |
+
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
| 274 |
+
"""
|
| 275 |
+
|
| 276 |
+
def __init__(
|
| 277 |
+
self,
|
| 278 |
+
vocab_size: int,
|
| 279 |
+
encoder_output_size: int,
|
| 280 |
+
attention_heads: int = 4,
|
| 281 |
+
linear_units: int = 2048,
|
| 282 |
+
num_blocks: int = 6,
|
| 283 |
+
r_num_blocks: int = 0,
|
| 284 |
+
dropout_rate: float = 0.1,
|
| 285 |
+
positional_dropout_rate: float = 0.1,
|
| 286 |
+
self_attention_dropout_rate: float = 0.0,
|
| 287 |
+
src_attention_dropout_rate: float = 0.0,
|
| 288 |
+
input_layer: str = "embed",
|
| 289 |
+
use_output_layer: bool = True,
|
| 290 |
+
normalize_before: bool = True,
|
| 291 |
+
key_bias: bool = True,
|
| 292 |
+
gradient_checkpointing: bool = False,
|
| 293 |
+
tie_word_embedding: bool = False,
|
| 294 |
+
):
|
| 295 |
+
|
| 296 |
+
super().__init__()
|
| 297 |
+
self.tie_word_embedding = tie_word_embedding
|
| 298 |
+
self.left_decoder = TransformerDecoder(
|
| 299 |
+
vocab_size,
|
| 300 |
+
encoder_output_size,
|
| 301 |
+
attention_heads,
|
| 302 |
+
linear_units,
|
| 303 |
+
num_blocks,
|
| 304 |
+
dropout_rate,
|
| 305 |
+
positional_dropout_rate,
|
| 306 |
+
self_attention_dropout_rate,
|
| 307 |
+
src_attention_dropout_rate,
|
| 308 |
+
input_layer,
|
| 309 |
+
use_output_layer,
|
| 310 |
+
normalize_before,
|
| 311 |
+
key_bias=key_bias,
|
| 312 |
+
gradient_checkpointing=gradient_checkpointing,
|
| 313 |
+
tie_word_embedding=tie_word_embedding)
|
| 314 |
+
|
| 315 |
+
self.right_decoder = TransformerDecoder(
|
| 316 |
+
vocab_size,
|
| 317 |
+
encoder_output_size,
|
| 318 |
+
attention_heads,
|
| 319 |
+
linear_units,
|
| 320 |
+
r_num_blocks,
|
| 321 |
+
dropout_rate,
|
| 322 |
+
positional_dropout_rate,
|
| 323 |
+
self_attention_dropout_rate,
|
| 324 |
+
src_attention_dropout_rate,
|
| 325 |
+
input_layer,
|
| 326 |
+
use_output_layer,
|
| 327 |
+
normalize_before,
|
| 328 |
+
key_bias=key_bias,
|
| 329 |
+
gradient_checkpointing=gradient_checkpointing,
|
| 330 |
+
tie_word_embedding=tie_word_embedding)
|
| 331 |
+
|
| 332 |
+
def forward(
|
| 333 |
+
self,
|
| 334 |
+
memory: torch.Tensor,
|
| 335 |
+
memory_mask: torch.Tensor,
|
| 336 |
+
ys_in_pad: torch.Tensor,
|
| 337 |
+
ys_in_lens: torch.Tensor,
|
| 338 |
+
r_ys_in_pad: torch.Tensor,
|
| 339 |
+
reverse_weight: float = 0.0,
|
| 340 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 341 |
+
"""Forward decoder.
|
| 342 |
+
Args:
|
| 343 |
+
memory: encoded memory, float32 (batch, maxlen_in, feat)
|
| 344 |
+
memory_mask: encoder memory mask, (batch, 1, maxlen_in)
|
| 345 |
+
ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
|
| 346 |
+
ys_in_lens: input lengths of this batch (batch)
|
| 347 |
+
r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
|
| 348 |
+
used for right to left decoder
|
| 349 |
+
reverse_weight: used for right to left decoder
|
| 350 |
+
Returns:
|
| 351 |
+
(tuple): tuple containing:
|
| 352 |
+
x: decoded token score before softmax (batch, maxlen_out,
|
| 353 |
+
vocab_size) if use_output_layer is True,
|
| 354 |
+
r_x: x: decoded token score (right to left decoder)
|
| 355 |
+
before softmax (batch, maxlen_out, vocab_size)
|
| 356 |
+
if use_output_layer is True,
|
| 357 |
+
olens: (batch, )
|
| 358 |
+
"""
|
| 359 |
+
l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
|
| 360 |
+
ys_in_lens)
|
| 361 |
+
r_x = torch.tensor(0.0)
|
| 362 |
+
if reverse_weight > 0.0:
|
| 363 |
+
r_x, _, olens = self.right_decoder(memory, memory_mask,
|
| 364 |
+
r_ys_in_pad, ys_in_lens)
|
| 365 |
+
return l_x, r_x, olens
|
| 366 |
+
|
| 367 |
+
def forward_one_step(
|
| 368 |
+
self,
|
| 369 |
+
memory: torch.Tensor,
|
| 370 |
+
memory_mask: torch.Tensor,
|
| 371 |
+
tgt: torch.Tensor,
|
| 372 |
+
tgt_mask: torch.Tensor,
|
| 373 |
+
cache: Optional[List[torch.Tensor]] = None,
|
| 374 |
+
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
| 375 |
+
"""Forward one step.
|
| 376 |
+
This is only used for decoding.
|
| 377 |
+
Args:
|
| 378 |
+
memory: encoded memory, float32 (batch, maxlen_in, feat)
|
| 379 |
+
memory_mask: encoded memory mask, (batch, 1, maxlen_in)
|
| 380 |
+
tgt: input token ids, int64 (batch, maxlen_out)
|
| 381 |
+
tgt_mask: input token mask, (batch, maxlen_out)
|
| 382 |
+
dtype=torch.uint8 in PyTorch 1.2-
|
| 383 |
+
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
|
| 384 |
+
cache: cached output list of (batch, max_time_out-1, size)
|
| 385 |
+
Returns:
|
| 386 |
+
y, cache: NN output value and cache per `self.decoders`.
|
| 387 |
+
y.shape` is (batch, maxlen_out, token)
|
| 388 |
+
"""
|
| 389 |
+
return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
|
| 390 |
+
tgt_mask, cache)
|
| 391 |
+
|
| 392 |
+
def tie_or_clone_weights(self, jit_mode: bool = True):
|
| 393 |
+
"""Tie or clone module weights (between word_emb and output_layer)
|
| 394 |
+
depending of whether we are using TorchScript or not"""
|
| 395 |
+
self.left_decoder.tie_or_clone_weights(jit_mode)
|
| 396 |
+
self.right_decoder.tie_or_clone_weights(jit_mode)
|
cosyvoice/transformer/decoder_layer.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2019 Shigeki Karita
|
| 2 |
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Decoder self-attention layer definition."""
|
| 16 |
+
from typing import Optional, Tuple
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from torch import nn
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class DecoderLayer(nn.Module):
|
| 23 |
+
"""Single decoder layer module.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
size (int): Input dimension.
|
| 27 |
+
self_attn (torch.nn.Module): Self-attention module instance.
|
| 28 |
+
`MultiHeadedAttention` instance can be used as the argument.
|
| 29 |
+
src_attn (torch.nn.Module): Inter-attention module instance.
|
| 30 |
+
`MultiHeadedAttention` instance can be used as the argument.
|
| 31 |
+
If `None` is passed, Inter-attention is not used, such as
|
| 32 |
+
CIF, GPT, and other decoder only model.
|
| 33 |
+
feed_forward (torch.nn.Module): Feed-forward module instance.
|
| 34 |
+
`PositionwiseFeedForward` instance can be used as the argument.
|
| 35 |
+
dropout_rate (float): Dropout rate.
|
| 36 |
+
normalize_before (bool):
|
| 37 |
+
True: use layer_norm before each sub-block.
|
| 38 |
+
False: to use layer_norm after each sub-block.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
size: int,
|
| 44 |
+
self_attn: nn.Module,
|
| 45 |
+
src_attn: Optional[nn.Module],
|
| 46 |
+
feed_forward: nn.Module,
|
| 47 |
+
dropout_rate: float,
|
| 48 |
+
normalize_before: bool = True,
|
| 49 |
+
):
|
| 50 |
+
"""Construct an DecoderLayer object."""
|
| 51 |
+
super().__init__()
|
| 52 |
+
self.size = size
|
| 53 |
+
self.self_attn = self_attn
|
| 54 |
+
self.src_attn = src_attn
|
| 55 |
+
self.feed_forward = feed_forward
|
| 56 |
+
self.norm1 = nn.LayerNorm(size, eps=1e-5)
|
| 57 |
+
self.norm2 = nn.LayerNorm(size, eps=1e-5)
|
| 58 |
+
self.norm3 = nn.LayerNorm(size, eps=1e-5)
|
| 59 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 60 |
+
self.normalize_before = normalize_before
|
| 61 |
+
|
| 62 |
+
def forward(
|
| 63 |
+
self,
|
| 64 |
+
tgt: torch.Tensor,
|
| 65 |
+
tgt_mask: torch.Tensor,
|
| 66 |
+
memory: torch.Tensor,
|
| 67 |
+
memory_mask: torch.Tensor,
|
| 68 |
+
cache: Optional[torch.Tensor] = None
|
| 69 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 70 |
+
"""Compute decoded features.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
|
| 74 |
+
tgt_mask (torch.Tensor): Mask for input tensor
|
| 75 |
+
(#batch, maxlen_out).
|
| 76 |
+
memory (torch.Tensor): Encoded memory
|
| 77 |
+
(#batch, maxlen_in, size).
|
| 78 |
+
memory_mask (torch.Tensor): Encoded memory mask
|
| 79 |
+
(#batch, maxlen_in).
|
| 80 |
+
cache (torch.Tensor): cached tensors.
|
| 81 |
+
(#batch, maxlen_out - 1, size).
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
torch.Tensor: Output tensor (#batch, maxlen_out, size).
|
| 85 |
+
torch.Tensor: Mask for output tensor (#batch, maxlen_out).
|
| 86 |
+
torch.Tensor: Encoded memory (#batch, maxlen_in, size).
|
| 87 |
+
torch.Tensor: Encoded memory mask (#batch, maxlen_in).
|
| 88 |
+
|
| 89 |
+
"""
|
| 90 |
+
residual = tgt
|
| 91 |
+
if self.normalize_before:
|
| 92 |
+
tgt = self.norm1(tgt)
|
| 93 |
+
|
| 94 |
+
if cache is None:
|
| 95 |
+
tgt_q = tgt
|
| 96 |
+
tgt_q_mask = tgt_mask
|
| 97 |
+
else:
|
| 98 |
+
# compute only the last frame query keeping dim: max_time_out -> 1
|
| 99 |
+
assert cache.shape == (
|
| 100 |
+
tgt.shape[0],
|
| 101 |
+
tgt.shape[1] - 1,
|
| 102 |
+
self.size,
|
| 103 |
+
), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
|
| 104 |
+
tgt_q = tgt[:, -1:, :]
|
| 105 |
+
residual = residual[:, -1:, :]
|
| 106 |
+
tgt_q_mask = tgt_mask[:, -1:, :]
|
| 107 |
+
|
| 108 |
+
x = residual + self.dropout(
|
| 109 |
+
self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
|
| 110 |
+
if not self.normalize_before:
|
| 111 |
+
x = self.norm1(x)
|
| 112 |
+
|
| 113 |
+
if self.src_attn is not None:
|
| 114 |
+
residual = x
|
| 115 |
+
if self.normalize_before:
|
| 116 |
+
x = self.norm2(x)
|
| 117 |
+
x = residual + self.dropout(
|
| 118 |
+
self.src_attn(x, memory, memory, memory_mask)[0])
|
| 119 |
+
if not self.normalize_before:
|
| 120 |
+
x = self.norm2(x)
|
| 121 |
+
|
| 122 |
+
residual = x
|
| 123 |
+
if self.normalize_before:
|
| 124 |
+
x = self.norm3(x)
|
| 125 |
+
x = residual + self.dropout(self.feed_forward(x))
|
| 126 |
+
if not self.normalize_before:
|
| 127 |
+
x = self.norm3(x)
|
| 128 |
+
|
| 129 |
+
if cache is not None:
|
| 130 |
+
x = torch.cat([cache, x], dim=1)
|
| 131 |
+
|
| 132 |
+
return x, tgt_mask, memory, memory_mask
|
cosyvoice/transformer/embedding.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
|
| 2 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
| 16 |
+
"""Positonal Encoding Module."""
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
from typing import Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
import numpy as np
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class PositionalEncoding(torch.nn.Module):
|
| 27 |
+
"""Positional encoding.
|
| 28 |
+
|
| 29 |
+
:param int d_model: embedding dim
|
| 30 |
+
:param float dropout_rate: dropout rate
|
| 31 |
+
:param int max_len: maximum input length
|
| 32 |
+
|
| 33 |
+
PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
|
| 34 |
+
PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(self,
|
| 38 |
+
d_model: int,
|
| 39 |
+
dropout_rate: float,
|
| 40 |
+
max_len: int = 5000,
|
| 41 |
+
reverse: bool = False):
|
| 42 |
+
"""Construct an PositionalEncoding object."""
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.d_model = d_model
|
| 45 |
+
self.xscale = math.sqrt(self.d_model)
|
| 46 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
| 47 |
+
self.max_len = max_len
|
| 48 |
+
|
| 49 |
+
self.pe = torch.zeros(self.max_len, self.d_model)
|
| 50 |
+
position = torch.arange(0, self.max_len,
|
| 51 |
+
dtype=torch.float32).unsqueeze(1)
|
| 52 |
+
div_term = torch.exp(
|
| 53 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32) *
|
| 54 |
+
-(math.log(10000.0) / self.d_model))
|
| 55 |
+
self.pe[:, 0::2] = torch.sin(position * div_term)
|
| 56 |
+
self.pe[:, 1::2] = torch.cos(position * div_term)
|
| 57 |
+
self.pe = self.pe.unsqueeze(0)
|
| 58 |
+
|
| 59 |
+
def forward(self,
|
| 60 |
+
x: torch.Tensor,
|
| 61 |
+
offset: Union[int, torch.Tensor] = 0) \
|
| 62 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
| 63 |
+
"""Add positional encoding.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
x (torch.Tensor): Input. Its shape is (batch, time, ...)
|
| 67 |
+
offset (int, torch.tensor): position offset
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
|
| 71 |
+
torch.Tensor: for compatibility to RelPositionalEncoding
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
self.pe = self.pe.to(x.device)
|
| 75 |
+
pos_emb = self.position_encoding(offset, x.size(1), False)
|
| 76 |
+
x = x * self.xscale + pos_emb
|
| 77 |
+
return self.dropout(x), self.dropout(pos_emb)
|
| 78 |
+
|
| 79 |
+
def position_encoding(self,
|
| 80 |
+
offset: Union[int, torch.Tensor],
|
| 81 |
+
size: int,
|
| 82 |
+
apply_dropout: bool = True) -> torch.Tensor:
|
| 83 |
+
""" For getting encoding in a streaming fashion
|
| 84 |
+
|
| 85 |
+
Attention!!!!!
|
| 86 |
+
we apply dropout only once at the whole utterance level in a none
|
| 87 |
+
streaming way, but will call this function several times with
|
| 88 |
+
increasing input size in a streaming scenario, so the dropout will
|
| 89 |
+
be applied several times.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
offset (int or torch.tensor): start offset
|
| 93 |
+
size (int): required size of position encoding
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
torch.Tensor: Corresponding encoding
|
| 97 |
+
"""
|
| 98 |
+
# How to subscript a Union type:
|
| 99 |
+
# https://github.com/pytorch/pytorch/issues/69434
|
| 100 |
+
if isinstance(offset, int):
|
| 101 |
+
assert offset + size <= self.max_len
|
| 102 |
+
pos_emb = self.pe[:, offset:offset + size]
|
| 103 |
+
elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
|
| 104 |
+
assert offset + size <= self.max_len
|
| 105 |
+
pos_emb = self.pe[:, offset:offset + size]
|
| 106 |
+
else: # for batched streaming decoding on GPU
|
| 107 |
+
assert torch.max(offset) + size <= self.max_len
|
| 108 |
+
index = offset.unsqueeze(1) + \
|
| 109 |
+
torch.arange(0, size).to(offset.device) # B X T
|
| 110 |
+
flag = index > 0
|
| 111 |
+
# remove negative offset
|
| 112 |
+
index = index * flag
|
| 113 |
+
pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
|
| 114 |
+
|
| 115 |
+
if apply_dropout:
|
| 116 |
+
pos_emb = self.dropout(pos_emb)
|
| 117 |
+
return pos_emb
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class RelPositionalEncoding(PositionalEncoding):
|
| 121 |
+
"""Relative positional encoding module.
|
| 122 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
| 123 |
+
Args:
|
| 124 |
+
d_model (int): Embedding dimension.
|
| 125 |
+
dropout_rate (float): Dropout rate.
|
| 126 |
+
max_len (int): Maximum input length.
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
|
| 130 |
+
"""Initialize class."""
|
| 131 |
+
super().__init__(d_model, dropout_rate, max_len, reverse=True)
|
| 132 |
+
|
| 133 |
+
def forward(self,
|
| 134 |
+
x: torch.Tensor,
|
| 135 |
+
offset: Union[int, torch.Tensor] = 0) \
|
| 136 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
| 137 |
+
"""Compute positional encoding.
|
| 138 |
+
Args:
|
| 139 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
| 140 |
+
Returns:
|
| 141 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
| 142 |
+
torch.Tensor: Positional embedding tensor (1, time, `*`).
|
| 143 |
+
"""
|
| 144 |
+
self.pe = self.pe.to(x.device)
|
| 145 |
+
x = x * self.xscale
|
| 146 |
+
pos_emb = self.position_encoding(offset, x.size(1), False)
|
| 147 |
+
return self.dropout(x), self.dropout(pos_emb)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class WhisperPositionalEncoding(PositionalEncoding):
|
| 151 |
+
""" Sinusoids position encoding used in openai-whisper.encoder
|
| 152 |
+
"""
|
| 153 |
+
|
| 154 |
+
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
|
| 155 |
+
super().__init__(d_model, dropout_rate, max_len)
|
| 156 |
+
self.xscale = 1.0
|
| 157 |
+
log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
|
| 158 |
+
inv_timescales = torch.exp(-log_timescale_increment *
|
| 159 |
+
torch.arange(d_model // 2))
|
| 160 |
+
scaled_time = torch.arange(max_len)[:, np.newaxis] * \
|
| 161 |
+
inv_timescales[np.newaxis, :]
|
| 162 |
+
pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
| 163 |
+
delattr(self, "pe")
|
| 164 |
+
self.register_buffer("pe", pe.unsqueeze(0))
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class LearnablePositionalEncoding(PositionalEncoding):
|
| 168 |
+
""" Learnable position encoding used in openai-whisper.decoder
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
|
| 172 |
+
super().__init__(d_model, dropout_rate, max_len)
|
| 173 |
+
# NOTE(xcsong): overwrite self.pe & self.xscale
|
| 174 |
+
self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
|
| 175 |
+
self.xscale = 1.0
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class NoPositionalEncoding(torch.nn.Module):
|
| 179 |
+
""" No position encoding
|
| 180 |
+
"""
|
| 181 |
+
|
| 182 |
+
def __init__(self, d_model: int, dropout_rate: float):
|
| 183 |
+
super().__init__()
|
| 184 |
+
self.d_model = d_model
|
| 185 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
| 186 |
+
|
| 187 |
+
def forward(self,
|
| 188 |
+
x: torch.Tensor,
|
| 189 |
+
offset: Union[int, torch.Tensor] = 0) \
|
| 190 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
| 191 |
+
""" Just return zero vector for interface compatibility
|
| 192 |
+
"""
|
| 193 |
+
pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
|
| 194 |
+
return self.dropout(x), pos_emb
|
| 195 |
+
|
| 196 |
+
def position_encoding(self, offset: Union[int, torch.Tensor],
|
| 197 |
+
size: int) -> torch.Tensor:
|
| 198 |
+
return torch.zeros(1, size, self.d_model)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class EspnetRelPositionalEncoding(torch.nn.Module):
|
| 202 |
+
"""Relative positional encoding module (new implementation).
|
| 203 |
+
|
| 204 |
+
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
| 205 |
+
|
| 206 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
| 207 |
+
|
| 208 |
+
Args:
|
| 209 |
+
d_model (int): Embedding dimension.
|
| 210 |
+
dropout_rate (float): Dropout rate.
|
| 211 |
+
max_len (int): Maximum input length.
|
| 212 |
+
|
| 213 |
+
"""
|
| 214 |
+
|
| 215 |
+
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
|
| 216 |
+
"""Construct an PositionalEncoding object."""
|
| 217 |
+
super(EspnetRelPositionalEncoding, self).__init__()
|
| 218 |
+
self.d_model = d_model
|
| 219 |
+
self.xscale = math.sqrt(self.d_model)
|
| 220 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
| 221 |
+
self.pe = None
|
| 222 |
+
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
| 223 |
+
|
| 224 |
+
def extend_pe(self, x: torch.Tensor):
|
| 225 |
+
"""Reset the positional encodings."""
|
| 226 |
+
if self.pe is not None:
|
| 227 |
+
# self.pe contains both positive and negative parts
|
| 228 |
+
# the length of self.pe is 2 * input_len - 1
|
| 229 |
+
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
| 230 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
| 231 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
| 232 |
+
return
|
| 233 |
+
# Suppose `i` means to the position of query vecotr and `j` means the
|
| 234 |
+
# position of key vector. We use position relative positions when keys
|
| 235 |
+
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
| 236 |
+
pe_positive = torch.zeros(x.size(1), self.d_model)
|
| 237 |
+
pe_negative = torch.zeros(x.size(1), self.d_model)
|
| 238 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
| 239 |
+
div_term = torch.exp(
|
| 240 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
| 241 |
+
* -(math.log(10000.0) / self.d_model)
|
| 242 |
+
)
|
| 243 |
+
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
| 244 |
+
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
| 245 |
+
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
| 246 |
+
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
| 247 |
+
|
| 248 |
+
# Reserve the order of positive indices and concat both positive and
|
| 249 |
+
# negative indices. This is used to support the shifting trick
|
| 250 |
+
# as in https://arxiv.org/abs/1901.02860
|
| 251 |
+
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
| 252 |
+
pe_negative = pe_negative[1:].unsqueeze(0)
|
| 253 |
+
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
| 254 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
| 255 |
+
|
| 256 |
+
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
|
| 257 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
| 258 |
+
"""Add positional encoding.
|
| 259 |
+
|
| 260 |
+
Args:
|
| 261 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
| 262 |
+
|
| 263 |
+
Returns:
|
| 264 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
| 265 |
+
|
| 266 |
+
"""
|
| 267 |
+
self.extend_pe(x)
|
| 268 |
+
x = x * self.xscale
|
| 269 |
+
pos_emb = self.position_encoding(size=x.size(1), offset=offset)
|
| 270 |
+
return self.dropout(x), self.dropout(pos_emb)
|
| 271 |
+
|
| 272 |
+
def position_encoding(self,
|
| 273 |
+
offset: Union[int, torch.Tensor],
|
| 274 |
+
size: int) -> torch.Tensor:
|
| 275 |
+
""" For getting encoding in a streaming fashion
|
| 276 |
+
|
| 277 |
+
Attention!!!!!
|
| 278 |
+
we apply dropout only once at the whole utterance level in a none
|
| 279 |
+
streaming way, but will call this function several times with
|
| 280 |
+
increasing input size in a streaming scenario, so the dropout will
|
| 281 |
+
be applied several times.
|
| 282 |
+
|
| 283 |
+
Args:
|
| 284 |
+
offset (int or torch.tensor): start offset
|
| 285 |
+
size (int): required size of position encoding
|
| 286 |
+
|
| 287 |
+
Returns:
|
| 288 |
+
torch.Tensor: Corresponding encoding
|
| 289 |
+
"""
|
| 290 |
+
pos_emb = self.pe[
|
| 291 |
+
:,
|
| 292 |
+
self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
|
| 293 |
+
]
|
| 294 |
+
return pos_emb
|
cosyvoice/transformer/encoder.py
ADDED
|
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
| 2 |
+
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
|
| 3 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
| 17 |
+
"""Encoder definition."""
|
| 18 |
+
from typing import Tuple
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.utils.checkpoint as ckpt
|
| 22 |
+
|
| 23 |
+
from cosyvoice.transformer.convolution import ConvolutionModule
|
| 24 |
+
from cosyvoice.transformer.encoder_layer import TransformerEncoderLayer
|
| 25 |
+
from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
|
| 26 |
+
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
|
| 27 |
+
from cosyvoice.utils.class_utils import (
|
| 28 |
+
COSYVOICE_EMB_CLASSES,
|
| 29 |
+
COSYVOICE_SUBSAMPLE_CLASSES,
|
| 30 |
+
COSYVOICE_ATTENTION_CLASSES,
|
| 31 |
+
COSYVOICE_ACTIVATION_CLASSES,
|
| 32 |
+
)
|
| 33 |
+
from cosyvoice.utils.mask import make_pad_mask
|
| 34 |
+
from cosyvoice.utils.mask import add_optional_chunk_mask
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class BaseEncoder(torch.nn.Module):
|
| 38 |
+
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
input_size: int,
|
| 42 |
+
output_size: int = 256,
|
| 43 |
+
attention_heads: int = 4,
|
| 44 |
+
linear_units: int = 2048,
|
| 45 |
+
num_blocks: int = 6,
|
| 46 |
+
dropout_rate: float = 0.1,
|
| 47 |
+
positional_dropout_rate: float = 0.1,
|
| 48 |
+
attention_dropout_rate: float = 0.0,
|
| 49 |
+
input_layer: str = "conv2d",
|
| 50 |
+
pos_enc_layer_type: str = "abs_pos",
|
| 51 |
+
normalize_before: bool = True,
|
| 52 |
+
static_chunk_size: int = 0,
|
| 53 |
+
use_dynamic_chunk: bool = False,
|
| 54 |
+
global_cmvn: torch.nn.Module = None,
|
| 55 |
+
use_dynamic_left_chunk: bool = False,
|
| 56 |
+
gradient_checkpointing: bool = False,
|
| 57 |
+
):
|
| 58 |
+
"""
|
| 59 |
+
Args:
|
| 60 |
+
input_size (int): input dim
|
| 61 |
+
output_size (int): dimension of attention
|
| 62 |
+
attention_heads (int): the number of heads of multi head attention
|
| 63 |
+
linear_units (int): the hidden units number of position-wise feed
|
| 64 |
+
forward
|
| 65 |
+
num_blocks (int): the number of decoder blocks
|
| 66 |
+
dropout_rate (float): dropout rate
|
| 67 |
+
attention_dropout_rate (float): dropout rate in attention
|
| 68 |
+
positional_dropout_rate (float): dropout rate after adding
|
| 69 |
+
positional encoding
|
| 70 |
+
input_layer (str): input layer type.
|
| 71 |
+
optional [linear, conv2d, conv2d6, conv2d8]
|
| 72 |
+
pos_enc_layer_type (str): Encoder positional encoding layer type.
|
| 73 |
+
opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
|
| 74 |
+
normalize_before (bool):
|
| 75 |
+
True: use layer_norm before each sub-block of a layer.
|
| 76 |
+
False: use layer_norm after each sub-block of a layer.
|
| 77 |
+
static_chunk_size (int): chunk size for static chunk training and
|
| 78 |
+
decoding
|
| 79 |
+
use_dynamic_chunk (bool): whether use dynamic chunk size for
|
| 80 |
+
training or not, You can only use fixed chunk(chunk_size > 0)
|
| 81 |
+
or dyanmic chunk size(use_dynamic_chunk = True)
|
| 82 |
+
global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
|
| 83 |
+
use_dynamic_left_chunk (bool): whether use dynamic left chunk in
|
| 84 |
+
dynamic chunk training
|
| 85 |
+
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
| 86 |
+
gradient_checkpointing: rerunning a forward-pass segment for each
|
| 87 |
+
checkpointed segment during backward.
|
| 88 |
+
"""
|
| 89 |
+
super().__init__()
|
| 90 |
+
self._output_size = output_size
|
| 91 |
+
|
| 92 |
+
self.global_cmvn = global_cmvn
|
| 93 |
+
self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
|
| 94 |
+
input_size,
|
| 95 |
+
output_size,
|
| 96 |
+
dropout_rate,
|
| 97 |
+
COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
|
| 98 |
+
positional_dropout_rate),
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
self.normalize_before = normalize_before
|
| 102 |
+
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
|
| 103 |
+
self.static_chunk_size = static_chunk_size
|
| 104 |
+
self.use_dynamic_chunk = use_dynamic_chunk
|
| 105 |
+
self.use_dynamic_left_chunk = use_dynamic_left_chunk
|
| 106 |
+
self.gradient_checkpointing = gradient_checkpointing
|
| 107 |
+
|
| 108 |
+
def output_size(self) -> int:
|
| 109 |
+
return self._output_size
|
| 110 |
+
|
| 111 |
+
def forward(
|
| 112 |
+
self,
|
| 113 |
+
xs: torch.Tensor,
|
| 114 |
+
xs_lens: torch.Tensor,
|
| 115 |
+
decoding_chunk_size: int = 0,
|
| 116 |
+
num_decoding_left_chunks: int = -1,
|
| 117 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 118 |
+
"""Embed positions in tensor.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
xs: padded input tensor (B, T, D)
|
| 122 |
+
xs_lens: input length (B)
|
| 123 |
+
decoding_chunk_size: decoding chunk size for dynamic chunk
|
| 124 |
+
0: default for training, use random dynamic chunk.
|
| 125 |
+
<0: for decoding, use full chunk.
|
| 126 |
+
>0: for decoding, use fixed chunk size as set.
|
| 127 |
+
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
| 128 |
+
the chunk size is decoding_chunk_size.
|
| 129 |
+
>=0: use num_decoding_left_chunks
|
| 130 |
+
<0: use all left chunks
|
| 131 |
+
Returns:
|
| 132 |
+
encoder output tensor xs, and subsampled masks
|
| 133 |
+
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
|
| 134 |
+
masks: torch.Tensor batch padding mask after subsample
|
| 135 |
+
(B, 1, T' ~= T/subsample_rate)
|
| 136 |
+
NOTE(xcsong):
|
| 137 |
+
We pass the `__call__` method of the modules instead of `forward` to the
|
| 138 |
+
checkpointing API because `__call__` attaches all the hooks of the module.
|
| 139 |
+
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
| 140 |
+
"""
|
| 141 |
+
T = xs.size(1)
|
| 142 |
+
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
| 143 |
+
if self.global_cmvn is not None:
|
| 144 |
+
xs = self.global_cmvn(xs)
|
| 145 |
+
xs, pos_emb, masks = self.embed(xs, masks)
|
| 146 |
+
mask_pad = masks # (B, 1, T/subsample_rate)
|
| 147 |
+
chunk_masks = add_optional_chunk_mask(xs, masks,
|
| 148 |
+
self.use_dynamic_chunk,
|
| 149 |
+
self.use_dynamic_left_chunk,
|
| 150 |
+
decoding_chunk_size,
|
| 151 |
+
self.static_chunk_size,
|
| 152 |
+
num_decoding_left_chunks)
|
| 153 |
+
if self.gradient_checkpointing and self.training:
|
| 154 |
+
xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
|
| 155 |
+
mask_pad)
|
| 156 |
+
else:
|
| 157 |
+
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
|
| 158 |
+
if self.normalize_before:
|
| 159 |
+
xs = self.after_norm(xs)
|
| 160 |
+
# Here we assume the mask is not changed in encoder layers, so just
|
| 161 |
+
# return the masks before encoder layers, and the masks will be used
|
| 162 |
+
# for cross attention with decoder later
|
| 163 |
+
return xs, masks
|
| 164 |
+
|
| 165 |
+
def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
| 166 |
+
pos_emb: torch.Tensor,
|
| 167 |
+
mask_pad: torch.Tensor) -> torch.Tensor:
|
| 168 |
+
for layer in self.encoders:
|
| 169 |
+
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
| 170 |
+
return xs
|
| 171 |
+
|
| 172 |
+
@torch.jit.unused
|
| 173 |
+
def forward_layers_checkpointed(self, xs: torch.Tensor,
|
| 174 |
+
chunk_masks: torch.Tensor,
|
| 175 |
+
pos_emb: torch.Tensor,
|
| 176 |
+
mask_pad: torch.Tensor) -> torch.Tensor:
|
| 177 |
+
for layer in self.encoders:
|
| 178 |
+
xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
|
| 179 |
+
chunk_masks, pos_emb,
|
| 180 |
+
mask_pad)
|
| 181 |
+
return xs
|
| 182 |
+
|
| 183 |
+
@torch.jit.export
|
| 184 |
+
def forward_chunk(
|
| 185 |
+
self,
|
| 186 |
+
xs: torch.Tensor,
|
| 187 |
+
offset: int,
|
| 188 |
+
required_cache_size: int,
|
| 189 |
+
att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
| 190 |
+
cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
| 191 |
+
att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
| 192 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 193 |
+
""" Forward just one chunk
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
|
| 197 |
+
where `time == (chunk_size - 1) * subsample_rate + \
|
| 198 |
+
subsample.right_context + 1`
|
| 199 |
+
offset (int): current offset in encoder output time stamp
|
| 200 |
+
required_cache_size (int): cache size required for next chunk
|
| 201 |
+
compuation
|
| 202 |
+
>=0: actual cache size
|
| 203 |
+
<0: means all history cache is required
|
| 204 |
+
att_cache (torch.Tensor): cache tensor for KEY & VALUE in
|
| 205 |
+
transformer/conformer attention, with shape
|
| 206 |
+
(elayers, head, cache_t1, d_k * 2), where
|
| 207 |
+
`head * d_k == hidden-dim` and
|
| 208 |
+
`cache_t1 == chunk_size * num_decoding_left_chunks`.
|
| 209 |
+
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
|
| 210 |
+
(elayers, b=1, hidden-dim, cache_t2), where
|
| 211 |
+
`cache_t2 == cnn.lorder - 1`
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
torch.Tensor: output of current input xs,
|
| 215 |
+
with shape (b=1, chunk_size, hidden-dim).
|
| 216 |
+
torch.Tensor: new attention cache required for next chunk, with
|
| 217 |
+
dynamic shape (elayers, head, ?, d_k * 2)
|
| 218 |
+
depending on required_cache_size.
|
| 219 |
+
torch.Tensor: new conformer cnn cache required for next chunk, with
|
| 220 |
+
same shape as the original cnn_cache.
|
| 221 |
+
|
| 222 |
+
"""
|
| 223 |
+
assert xs.size(0) == 1
|
| 224 |
+
# tmp_masks is just for interface compatibility
|
| 225 |
+
tmp_masks = torch.ones(1,
|
| 226 |
+
xs.size(1),
|
| 227 |
+
device=xs.device,
|
| 228 |
+
dtype=torch.bool)
|
| 229 |
+
tmp_masks = tmp_masks.unsqueeze(1)
|
| 230 |
+
if self.global_cmvn is not None:
|
| 231 |
+
xs = self.global_cmvn(xs)
|
| 232 |
+
# NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
|
| 233 |
+
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
|
| 234 |
+
# NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
|
| 235 |
+
elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
|
| 236 |
+
chunk_size = xs.size(1)
|
| 237 |
+
attention_key_size = cache_t1 + chunk_size
|
| 238 |
+
pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
|
| 239 |
+
size=attention_key_size)
|
| 240 |
+
if required_cache_size < 0:
|
| 241 |
+
next_cache_start = 0
|
| 242 |
+
elif required_cache_size == 0:
|
| 243 |
+
next_cache_start = attention_key_size
|
| 244 |
+
else:
|
| 245 |
+
next_cache_start = max(attention_key_size - required_cache_size, 0)
|
| 246 |
+
r_att_cache = []
|
| 247 |
+
r_cnn_cache = []
|
| 248 |
+
for i, layer in enumerate(self.encoders):
|
| 249 |
+
# NOTE(xcsong): Before layer.forward
|
| 250 |
+
# shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
|
| 251 |
+
# shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
|
| 252 |
+
xs, _, new_att_cache, new_cnn_cache = layer(
|
| 253 |
+
xs,
|
| 254 |
+
att_mask,
|
| 255 |
+
pos_emb,
|
| 256 |
+
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
|
| 257 |
+
cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache)
|
| 258 |
+
# NOTE(xcsong): After layer.forward
|
| 259 |
+
# shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
|
| 260 |
+
# shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
|
| 261 |
+
r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
|
| 262 |
+
r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
|
| 263 |
+
if self.normalize_before:
|
| 264 |
+
xs = self.after_norm(xs)
|
| 265 |
+
|
| 266 |
+
# NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
|
| 267 |
+
# ? may be larger than cache_t1, it depends on required_cache_size
|
| 268 |
+
r_att_cache = torch.cat(r_att_cache, dim=0)
|
| 269 |
+
# NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
|
| 270 |
+
r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
|
| 271 |
+
|
| 272 |
+
return (xs, r_att_cache, r_cnn_cache)
|
| 273 |
+
|
| 274 |
+
@torch.jit.unused
|
| 275 |
+
def forward_chunk_by_chunk(
|
| 276 |
+
self,
|
| 277 |
+
xs: torch.Tensor,
|
| 278 |
+
decoding_chunk_size: int,
|
| 279 |
+
num_decoding_left_chunks: int = -1,
|
| 280 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 281 |
+
""" Forward input chunk by chunk with chunk_size like a streaming
|
| 282 |
+
fashion
|
| 283 |
+
|
| 284 |
+
Here we should pay special attention to computation cache in the
|
| 285 |
+
streaming style forward chunk by chunk. Three things should be taken
|
| 286 |
+
into account for computation in the current network:
|
| 287 |
+
1. transformer/conformer encoder layers output cache
|
| 288 |
+
2. convolution in conformer
|
| 289 |
+
3. convolution in subsampling
|
| 290 |
+
|
| 291 |
+
However, we don't implement subsampling cache for:
|
| 292 |
+
1. We can control subsampling module to output the right result by
|
| 293 |
+
overlapping input instead of cache left context, even though it
|
| 294 |
+
wastes some computation, but subsampling only takes a very
|
| 295 |
+
small fraction of computation in the whole model.
|
| 296 |
+
2. Typically, there are several covolution layers with subsampling
|
| 297 |
+
in subsampling module, it is tricky and complicated to do cache
|
| 298 |
+
with different convolution layers with different subsampling
|
| 299 |
+
rate.
|
| 300 |
+
3. Currently, nn.Sequential is used to stack all the convolution
|
| 301 |
+
layers in subsampling, we need to rewrite it to make it work
|
| 302 |
+
with cache, which is not preferred.
|
| 303 |
+
Args:
|
| 304 |
+
xs (torch.Tensor): (1, max_len, dim)
|
| 305 |
+
chunk_size (int): decoding chunk size
|
| 306 |
+
"""
|
| 307 |
+
assert decoding_chunk_size > 0
|
| 308 |
+
# The model is trained by static or dynamic chunk
|
| 309 |
+
assert self.static_chunk_size > 0 or self.use_dynamic_chunk
|
| 310 |
+
subsampling = self.embed.subsampling_rate
|
| 311 |
+
context = self.embed.right_context + 1 # Add current frame
|
| 312 |
+
stride = subsampling * decoding_chunk_size
|
| 313 |
+
decoding_window = (decoding_chunk_size - 1) * subsampling + context
|
| 314 |
+
num_frames = xs.size(1)
|
| 315 |
+
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
|
| 316 |
+
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
|
| 317 |
+
outputs = []
|
| 318 |
+
offset = 0
|
| 319 |
+
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
|
| 320 |
+
|
| 321 |
+
# Feed forward overlap input step by step
|
| 322 |
+
for cur in range(0, num_frames - context + 1, stride):
|
| 323 |
+
end = min(cur + decoding_window, num_frames)
|
| 324 |
+
chunk_xs = xs[:, cur:end, :]
|
| 325 |
+
(y, att_cache,
|
| 326 |
+
cnn_cache) = self.forward_chunk(chunk_xs, offset,
|
| 327 |
+
required_cache_size, att_cache,
|
| 328 |
+
cnn_cache)
|
| 329 |
+
outputs.append(y)
|
| 330 |
+
offset += y.size(1)
|
| 331 |
+
ys = torch.cat(outputs, 1)
|
| 332 |
+
masks = torch.ones((1, 1, ys.size(1)),
|
| 333 |
+
device=ys.device,
|
| 334 |
+
dtype=torch.bool)
|
| 335 |
+
return ys, masks
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
class TransformerEncoder(BaseEncoder):
|
| 339 |
+
"""Transformer encoder module."""
|
| 340 |
+
|
| 341 |
+
def __init__(
|
| 342 |
+
self,
|
| 343 |
+
input_size: int,
|
| 344 |
+
output_size: int = 256,
|
| 345 |
+
attention_heads: int = 4,
|
| 346 |
+
linear_units: int = 2048,
|
| 347 |
+
num_blocks: int = 6,
|
| 348 |
+
dropout_rate: float = 0.1,
|
| 349 |
+
positional_dropout_rate: float = 0.1,
|
| 350 |
+
attention_dropout_rate: float = 0.0,
|
| 351 |
+
input_layer: str = "conv2d",
|
| 352 |
+
pos_enc_layer_type: str = "abs_pos",
|
| 353 |
+
normalize_before: bool = True,
|
| 354 |
+
static_chunk_size: int = 0,
|
| 355 |
+
use_dynamic_chunk: bool = False,
|
| 356 |
+
global_cmvn: torch.nn.Module = None,
|
| 357 |
+
use_dynamic_left_chunk: bool = False,
|
| 358 |
+
key_bias: bool = True,
|
| 359 |
+
selfattention_layer_type: str = "selfattn",
|
| 360 |
+
activation_type: str = "relu",
|
| 361 |
+
gradient_checkpointing: bool = False,
|
| 362 |
+
):
|
| 363 |
+
""" Construct TransformerEncoder
|
| 364 |
+
|
| 365 |
+
See Encoder for the meaning of each parameter.
|
| 366 |
+
"""
|
| 367 |
+
super().__init__(input_size, output_size, attention_heads,
|
| 368 |
+
linear_units, num_blocks, dropout_rate,
|
| 369 |
+
positional_dropout_rate, attention_dropout_rate,
|
| 370 |
+
input_layer, pos_enc_layer_type, normalize_before,
|
| 371 |
+
static_chunk_size, use_dynamic_chunk, global_cmvn,
|
| 372 |
+
use_dynamic_left_chunk, gradient_checkpointing)
|
| 373 |
+
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
|
| 374 |
+
self.encoders = torch.nn.ModuleList([
|
| 375 |
+
TransformerEncoderLayer(
|
| 376 |
+
output_size,
|
| 377 |
+
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](attention_heads,
|
| 378 |
+
output_size,
|
| 379 |
+
attention_dropout_rate,
|
| 380 |
+
key_bias),
|
| 381 |
+
PositionwiseFeedForward(output_size, linear_units,
|
| 382 |
+
dropout_rate, activation),
|
| 383 |
+
dropout_rate, normalize_before) for _ in range(num_blocks)
|
| 384 |
+
])
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class ConformerEncoder(BaseEncoder):
|
| 388 |
+
"""Conformer encoder module."""
|
| 389 |
+
|
| 390 |
+
def __init__(
|
| 391 |
+
self,
|
| 392 |
+
input_size: int,
|
| 393 |
+
output_size: int = 256,
|
| 394 |
+
attention_heads: int = 4,
|
| 395 |
+
linear_units: int = 2048,
|
| 396 |
+
num_blocks: int = 6,
|
| 397 |
+
dropout_rate: float = 0.1,
|
| 398 |
+
positional_dropout_rate: float = 0.1,
|
| 399 |
+
attention_dropout_rate: float = 0.0,
|
| 400 |
+
input_layer: str = "conv2d",
|
| 401 |
+
pos_enc_layer_type: str = "rel_pos",
|
| 402 |
+
normalize_before: bool = True,
|
| 403 |
+
static_chunk_size: int = 0,
|
| 404 |
+
use_dynamic_chunk: bool = False,
|
| 405 |
+
global_cmvn: torch.nn.Module = None,
|
| 406 |
+
use_dynamic_left_chunk: bool = False,
|
| 407 |
+
positionwise_conv_kernel_size: int = 1,
|
| 408 |
+
macaron_style: bool = True,
|
| 409 |
+
selfattention_layer_type: str = "rel_selfattn",
|
| 410 |
+
activation_type: str = "swish",
|
| 411 |
+
use_cnn_module: bool = True,
|
| 412 |
+
cnn_module_kernel: int = 15,
|
| 413 |
+
causal: bool = False,
|
| 414 |
+
cnn_module_norm: str = "batch_norm",
|
| 415 |
+
key_bias: bool = True,
|
| 416 |
+
gradient_checkpointing: bool = False,
|
| 417 |
+
):
|
| 418 |
+
"""Construct ConformerEncoder
|
| 419 |
+
|
| 420 |
+
Args:
|
| 421 |
+
input_size to use_dynamic_chunk, see in BaseEncoder
|
| 422 |
+
positionwise_conv_kernel_size (int): Kernel size of positionwise
|
| 423 |
+
conv1d layer.
|
| 424 |
+
macaron_style (bool): Whether to use macaron style for
|
| 425 |
+
positionwise layer.
|
| 426 |
+
selfattention_layer_type (str): Encoder attention layer type,
|
| 427 |
+
the parameter has no effect now, it's just for configure
|
| 428 |
+
compatibility.
|
| 429 |
+
activation_type (str): Encoder activation function type.
|
| 430 |
+
use_cnn_module (bool): Whether to use convolution module.
|
| 431 |
+
cnn_module_kernel (int): Kernel size of convolution module.
|
| 432 |
+
causal (bool): whether to use causal convolution or not.
|
| 433 |
+
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
| 434 |
+
"""
|
| 435 |
+
super().__init__(input_size, output_size, attention_heads,
|
| 436 |
+
linear_units, num_blocks, dropout_rate,
|
| 437 |
+
positional_dropout_rate, attention_dropout_rate,
|
| 438 |
+
input_layer, pos_enc_layer_type, normalize_before,
|
| 439 |
+
static_chunk_size, use_dynamic_chunk, global_cmvn,
|
| 440 |
+
use_dynamic_left_chunk, gradient_checkpointing)
|
| 441 |
+
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
|
| 442 |
+
|
| 443 |
+
# self-attention module definition
|
| 444 |
+
encoder_selfattn_layer_args = (
|
| 445 |
+
attention_heads,
|
| 446 |
+
output_size,
|
| 447 |
+
attention_dropout_rate,
|
| 448 |
+
key_bias,
|
| 449 |
+
)
|
| 450 |
+
# feed-forward module definition
|
| 451 |
+
positionwise_layer_args = (
|
| 452 |
+
output_size,
|
| 453 |
+
linear_units,
|
| 454 |
+
dropout_rate,
|
| 455 |
+
activation,
|
| 456 |
+
)
|
| 457 |
+
# convolution module definition
|
| 458 |
+
convolution_layer_args = (output_size, cnn_module_kernel, activation,
|
| 459 |
+
cnn_module_norm, causal)
|
| 460 |
+
|
| 461 |
+
self.encoders = torch.nn.ModuleList([
|
| 462 |
+
ConformerEncoderLayer(
|
| 463 |
+
output_size,
|
| 464 |
+
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
|
| 465 |
+
*encoder_selfattn_layer_args),
|
| 466 |
+
PositionwiseFeedForward(*positionwise_layer_args),
|
| 467 |
+
PositionwiseFeedForward(
|
| 468 |
+
*positionwise_layer_args) if macaron_style else None,
|
| 469 |
+
ConvolutionModule(
|
| 470 |
+
*convolution_layer_args) if use_cnn_module else None,
|
| 471 |
+
dropout_rate,
|
| 472 |
+
normalize_before,
|
| 473 |
+
) for _ in range(num_blocks)
|
| 474 |
+
])
|
cosyvoice/transformer/encoder_layer.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
| 2 |
+
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
| 16 |
+
"""Encoder self-attention layer definition."""
|
| 17 |
+
|
| 18 |
+
from typing import Optional, Tuple
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from torch import nn
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class TransformerEncoderLayer(nn.Module):
|
| 25 |
+
"""Encoder layer module.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
size (int): Input dimension.
|
| 29 |
+
self_attn (torch.nn.Module): Self-attention module instance.
|
| 30 |
+
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
|
| 31 |
+
instance can be used as the argument.
|
| 32 |
+
feed_forward (torch.nn.Module): Feed-forward module instance.
|
| 33 |
+
`PositionwiseFeedForward`, instance can be used as the argument.
|
| 34 |
+
dropout_rate (float): Dropout rate.
|
| 35 |
+
normalize_before (bool):
|
| 36 |
+
True: use layer_norm before each sub-block.
|
| 37 |
+
False: to use layer_norm after each sub-block.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
size: int,
|
| 43 |
+
self_attn: torch.nn.Module,
|
| 44 |
+
feed_forward: torch.nn.Module,
|
| 45 |
+
dropout_rate: float,
|
| 46 |
+
normalize_before: bool = True,
|
| 47 |
+
):
|
| 48 |
+
"""Construct an EncoderLayer object."""
|
| 49 |
+
super().__init__()
|
| 50 |
+
self.self_attn = self_attn
|
| 51 |
+
self.feed_forward = feed_forward
|
| 52 |
+
self.norm1 = nn.LayerNorm(size, eps=1e-12)
|
| 53 |
+
self.norm2 = nn.LayerNorm(size, eps=1e-12)
|
| 54 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 55 |
+
self.size = size
|
| 56 |
+
self.normalize_before = normalize_before
|
| 57 |
+
|
| 58 |
+
def forward(
|
| 59 |
+
self,
|
| 60 |
+
x: torch.Tensor,
|
| 61 |
+
mask: torch.Tensor,
|
| 62 |
+
pos_emb: torch.Tensor,
|
| 63 |
+
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
| 64 |
+
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
| 65 |
+
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
| 66 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 67 |
+
"""Compute encoded features.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
x (torch.Tensor): (#batch, time, size)
|
| 71 |
+
mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
|
| 72 |
+
(0, 0, 0) means fake mask.
|
| 73 |
+
pos_emb (torch.Tensor): just for interface compatibility
|
| 74 |
+
to ConformerEncoderLayer
|
| 75 |
+
mask_pad (torch.Tensor): does not used in transformer layer,
|
| 76 |
+
just for unified api with conformer.
|
| 77 |
+
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
|
| 78 |
+
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
|
| 79 |
+
cnn_cache (torch.Tensor): Convolution cache in conformer layer
|
| 80 |
+
(#batch=1, size, cache_t2), not used here, it's for interface
|
| 81 |
+
compatibility to ConformerEncoderLayer.
|
| 82 |
+
Returns:
|
| 83 |
+
torch.Tensor: Output tensor (#batch, time, size).
|
| 84 |
+
torch.Tensor: Mask tensor (#batch, time, time).
|
| 85 |
+
torch.Tensor: att_cache tensor,
|
| 86 |
+
(#batch=1, head, cache_t1 + time, d_k * 2).
|
| 87 |
+
torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2).
|
| 88 |
+
|
| 89 |
+
"""
|
| 90 |
+
residual = x
|
| 91 |
+
if self.normalize_before:
|
| 92 |
+
x = self.norm1(x)
|
| 93 |
+
x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
|
| 94 |
+
x = residual + self.dropout(x_att)
|
| 95 |
+
if not self.normalize_before:
|
| 96 |
+
x = self.norm1(x)
|
| 97 |
+
|
| 98 |
+
residual = x
|
| 99 |
+
if self.normalize_before:
|
| 100 |
+
x = self.norm2(x)
|
| 101 |
+
x = residual + self.dropout(self.feed_forward(x))
|
| 102 |
+
if not self.normalize_before:
|
| 103 |
+
x = self.norm2(x)
|
| 104 |
+
|
| 105 |
+
fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
| 106 |
+
return x, mask, new_att_cache, fake_cnn_cache
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class ConformerEncoderLayer(nn.Module):
|
| 110 |
+
"""Encoder layer module.
|
| 111 |
+
Args:
|
| 112 |
+
size (int): Input dimension.
|
| 113 |
+
self_attn (torch.nn.Module): Self-attention module instance.
|
| 114 |
+
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
|
| 115 |
+
instance can be used as the argument.
|
| 116 |
+
feed_forward (torch.nn.Module): Feed-forward module instance.
|
| 117 |
+
`PositionwiseFeedForward` instance can be used as the argument.
|
| 118 |
+
feed_forward_macaron (torch.nn.Module): Additional feed-forward module
|
| 119 |
+
instance.
|
| 120 |
+
`PositionwiseFeedForward` instance can be used as the argument.
|
| 121 |
+
conv_module (torch.nn.Module): Convolution module instance.
|
| 122 |
+
`ConvlutionModule` instance can be used as the argument.
|
| 123 |
+
dropout_rate (float): Dropout rate.
|
| 124 |
+
normalize_before (bool):
|
| 125 |
+
True: use layer_norm before each sub-block.
|
| 126 |
+
False: use layer_norm after each sub-block.
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
def __init__(
|
| 130 |
+
self,
|
| 131 |
+
size: int,
|
| 132 |
+
self_attn: torch.nn.Module,
|
| 133 |
+
feed_forward: Optional[nn.Module] = None,
|
| 134 |
+
feed_forward_macaron: Optional[nn.Module] = None,
|
| 135 |
+
conv_module: Optional[nn.Module] = None,
|
| 136 |
+
dropout_rate: float = 0.1,
|
| 137 |
+
normalize_before: bool = True,
|
| 138 |
+
):
|
| 139 |
+
"""Construct an EncoderLayer object."""
|
| 140 |
+
super().__init__()
|
| 141 |
+
self.self_attn = self_attn
|
| 142 |
+
self.feed_forward = feed_forward
|
| 143 |
+
self.feed_forward_macaron = feed_forward_macaron
|
| 144 |
+
self.conv_module = conv_module
|
| 145 |
+
self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
|
| 146 |
+
self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
|
| 147 |
+
if feed_forward_macaron is not None:
|
| 148 |
+
self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
|
| 149 |
+
self.ff_scale = 0.5
|
| 150 |
+
else:
|
| 151 |
+
self.ff_scale = 1.0
|
| 152 |
+
if self.conv_module is not None:
|
| 153 |
+
self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module
|
| 154 |
+
self.norm_final = nn.LayerNorm(
|
| 155 |
+
size, eps=1e-12) # for the final output of the block
|
| 156 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 157 |
+
self.size = size
|
| 158 |
+
self.normalize_before = normalize_before
|
| 159 |
+
|
| 160 |
+
def forward(
|
| 161 |
+
self,
|
| 162 |
+
x: torch.Tensor,
|
| 163 |
+
mask: torch.Tensor,
|
| 164 |
+
pos_emb: torch.Tensor,
|
| 165 |
+
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
| 166 |
+
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
| 167 |
+
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
| 168 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 169 |
+
"""Compute encoded features.
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
x (torch.Tensor): (#batch, time, size)
|
| 173 |
+
mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
|
| 174 |
+
(0, 0, 0) means fake mask.
|
| 175 |
+
pos_emb (torch.Tensor): positional encoding, must not be None
|
| 176 |
+
for ConformerEncoderLayer.
|
| 177 |
+
mask_pad (torch.Tensor): batch padding mask used for conv module.
|
| 178 |
+
(#batch, 1,time), (0, 0, 0) means fake mask.
|
| 179 |
+
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
|
| 180 |
+
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
|
| 181 |
+
cnn_cache (torch.Tensor): Convolution cache in conformer layer
|
| 182 |
+
(#batch=1, size, cache_t2)
|
| 183 |
+
Returns:
|
| 184 |
+
torch.Tensor: Output tensor (#batch, time, size).
|
| 185 |
+
torch.Tensor: Mask tensor (#batch, time, time).
|
| 186 |
+
torch.Tensor: att_cache tensor,
|
| 187 |
+
(#batch=1, head, cache_t1 + time, d_k * 2).
|
| 188 |
+
torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
# whether to use macaron style
|
| 192 |
+
if self.feed_forward_macaron is not None:
|
| 193 |
+
residual = x
|
| 194 |
+
if self.normalize_before:
|
| 195 |
+
x = self.norm_ff_macaron(x)
|
| 196 |
+
x = residual + self.ff_scale * self.dropout(
|
| 197 |
+
self.feed_forward_macaron(x))
|
| 198 |
+
if not self.normalize_before:
|
| 199 |
+
x = self.norm_ff_macaron(x)
|
| 200 |
+
|
| 201 |
+
# multi-headed self-attention module
|
| 202 |
+
residual = x
|
| 203 |
+
if self.normalize_before:
|
| 204 |
+
x = self.norm_mha(x)
|
| 205 |
+
x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
|
| 206 |
+
att_cache)
|
| 207 |
+
x = residual + self.dropout(x_att)
|
| 208 |
+
if not self.normalize_before:
|
| 209 |
+
x = self.norm_mha(x)
|
| 210 |
+
|
| 211 |
+
# convolution module
|
| 212 |
+
# Fake new cnn cache here, and then change it in conv_module
|
| 213 |
+
new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
| 214 |
+
if self.conv_module is not None:
|
| 215 |
+
residual = x
|
| 216 |
+
if self.normalize_before:
|
| 217 |
+
x = self.norm_conv(x)
|
| 218 |
+
x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
|
| 219 |
+
x = residual + self.dropout(x)
|
| 220 |
+
|
| 221 |
+
if not self.normalize_before:
|
| 222 |
+
x = self.norm_conv(x)
|
| 223 |
+
|
| 224 |
+
# feed forward module
|
| 225 |
+
residual = x
|
| 226 |
+
if self.normalize_before:
|
| 227 |
+
x = self.norm_ff(x)
|
| 228 |
+
|
| 229 |
+
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
|
| 230 |
+
if not self.normalize_before:
|
| 231 |
+
x = self.norm_ff(x)
|
| 232 |
+
|
| 233 |
+
if self.conv_module is not None:
|
| 234 |
+
x = self.norm_final(x)
|
| 235 |
+
|
| 236 |
+
return x, mask, new_att_cache, new_cnn_cache
|
cosyvoice/transformer/label_smoothing_loss.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2019 Shigeki Karita
|
| 2 |
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Label smoothing module."""
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch import nn
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class LabelSmoothingLoss(nn.Module):
|
| 22 |
+
"""Label-smoothing loss.
|
| 23 |
+
|
| 24 |
+
In a standard CE loss, the label's data distribution is:
|
| 25 |
+
[0,1,2] ->
|
| 26 |
+
[
|
| 27 |
+
[1.0, 0.0, 0.0],
|
| 28 |
+
[0.0, 1.0, 0.0],
|
| 29 |
+
[0.0, 0.0, 1.0],
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
In the smoothing version CE Loss,some probabilities
|
| 33 |
+
are taken from the true label prob (1.0) and are divided
|
| 34 |
+
among other labels.
|
| 35 |
+
|
| 36 |
+
e.g.
|
| 37 |
+
smoothing=0.1
|
| 38 |
+
[0,1,2] ->
|
| 39 |
+
[
|
| 40 |
+
[0.9, 0.05, 0.05],
|
| 41 |
+
[0.05, 0.9, 0.05],
|
| 42 |
+
[0.05, 0.05, 0.9],
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
size (int): the number of class
|
| 47 |
+
padding_idx (int): padding class id which will be ignored for loss
|
| 48 |
+
smoothing (float): smoothing rate (0.0 means the conventional CE)
|
| 49 |
+
normalize_length (bool):
|
| 50 |
+
normalize loss by sequence length if True
|
| 51 |
+
normalize loss by batch size if False
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(self,
|
| 55 |
+
size: int,
|
| 56 |
+
padding_idx: int,
|
| 57 |
+
smoothing: float,
|
| 58 |
+
normalize_length: bool = False):
|
| 59 |
+
"""Construct an LabelSmoothingLoss object."""
|
| 60 |
+
super(LabelSmoothingLoss, self).__init__()
|
| 61 |
+
self.criterion = nn.KLDivLoss(reduction="none")
|
| 62 |
+
self.padding_idx = padding_idx
|
| 63 |
+
self.confidence = 1.0 - smoothing
|
| 64 |
+
self.smoothing = smoothing
|
| 65 |
+
self.size = size
|
| 66 |
+
self.normalize_length = normalize_length
|
| 67 |
+
|
| 68 |
+
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 69 |
+
"""Compute loss between x and target.
|
| 70 |
+
|
| 71 |
+
The model outputs and data labels tensors are flatten to
|
| 72 |
+
(batch*seqlen, class) shape and a mask is applied to the
|
| 73 |
+
padding part which should not be calculated for loss.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
x (torch.Tensor): prediction (batch, seqlen, class)
|
| 77 |
+
target (torch.Tensor):
|
| 78 |
+
target signal masked with self.padding_id (batch, seqlen)
|
| 79 |
+
Returns:
|
| 80 |
+
loss (torch.Tensor) : The KL loss, scalar float value
|
| 81 |
+
"""
|
| 82 |
+
assert x.size(2) == self.size
|
| 83 |
+
batch_size = x.size(0)
|
| 84 |
+
x = x.view(-1, self.size)
|
| 85 |
+
target = target.view(-1)
|
| 86 |
+
# use zeros_like instead of torch.no_grad() for true_dist,
|
| 87 |
+
# since no_grad() can not be exported by JIT
|
| 88 |
+
true_dist = torch.zeros_like(x)
|
| 89 |
+
true_dist.fill_(self.smoothing / (self.size - 1))
|
| 90 |
+
ignore = target == self.padding_idx # (B,)
|
| 91 |
+
total = len(target) - ignore.sum().item()
|
| 92 |
+
target = target.masked_fill(ignore, 0) # avoid -1 index
|
| 93 |
+
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
|
| 94 |
+
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
|
| 95 |
+
denom = total if self.normalize_length else batch_size
|
| 96 |
+
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
|
cosyvoice/transformer/positionwise_feed_forward.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2019 Shigeki Karita
|
| 2 |
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Positionwise feed forward layer definition."""
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class PositionwiseFeedForward(torch.nn.Module):
|
| 21 |
+
"""Positionwise feed forward layer.
|
| 22 |
+
|
| 23 |
+
FeedForward are appied on each position of the sequence.
|
| 24 |
+
The output dim is same with the input dim.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
idim (int): Input dimenstion.
|
| 28 |
+
hidden_units (int): The number of hidden units.
|
| 29 |
+
dropout_rate (float): Dropout rate.
|
| 30 |
+
activation (torch.nn.Module): Activation function
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
idim: int,
|
| 36 |
+
hidden_units: int,
|
| 37 |
+
dropout_rate: float,
|
| 38 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
| 39 |
+
):
|
| 40 |
+
"""Construct a PositionwiseFeedForward object."""
|
| 41 |
+
super(PositionwiseFeedForward, self).__init__()
|
| 42 |
+
self.w_1 = torch.nn.Linear(idim, hidden_units)
|
| 43 |
+
self.activation = activation
|
| 44 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
| 45 |
+
self.w_2 = torch.nn.Linear(hidden_units, idim)
|
| 46 |
+
|
| 47 |
+
def forward(self, xs: torch.Tensor) -> torch.Tensor:
|
| 48 |
+
"""Forward function.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
xs: input tensor (B, L, D)
|
| 52 |
+
Returns:
|
| 53 |
+
output tensor, (B, L, D)
|
| 54 |
+
"""
|
| 55 |
+
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class MoEFFNLayer(torch.nn.Module):
|
| 59 |
+
"""
|
| 60 |
+
Mixture of expert with Positionwise feed forward layer
|
| 61 |
+
See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
|
| 62 |
+
The output dim is same with the input dim.
|
| 63 |
+
|
| 64 |
+
Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
|
| 65 |
+
https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
|
| 66 |
+
Args:
|
| 67 |
+
n_expert: number of expert.
|
| 68 |
+
n_expert_per_token: The actual number of experts used for each frame
|
| 69 |
+
idim (int): Input dimenstion.
|
| 70 |
+
hidden_units (int): The number of hidden units.
|
| 71 |
+
dropout_rate (float): Dropout rate.
|
| 72 |
+
activation (torch.nn.Module): Activation function
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
def __init__(
|
| 76 |
+
self,
|
| 77 |
+
n_expert: int,
|
| 78 |
+
n_expert_per_token: int,
|
| 79 |
+
idim: int,
|
| 80 |
+
hidden_units: int,
|
| 81 |
+
dropout_rate: float,
|
| 82 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
| 83 |
+
):
|
| 84 |
+
super(MoEFFNLayer, self).__init__()
|
| 85 |
+
self.gate = torch.nn.Linear(idim, n_expert, bias=False)
|
| 86 |
+
self.experts = torch.nn.ModuleList(
|
| 87 |
+
PositionwiseFeedForward(idim, hidden_units, dropout_rate,
|
| 88 |
+
activation) for _ in range(n_expert))
|
| 89 |
+
self.n_expert_per_token = n_expert_per_token
|
| 90 |
+
|
| 91 |
+
def forward(self, xs: torch.Tensor) -> torch.Tensor:
|
| 92 |
+
"""Foward function.
|
| 93 |
+
Args:
|
| 94 |
+
xs: input tensor (B, L, D)
|
| 95 |
+
Returns:
|
| 96 |
+
output tensor, (B, L, D)
|
| 97 |
+
|
| 98 |
+
"""
|
| 99 |
+
B, L, D = xs.size(
|
| 100 |
+
) # batch size, sequence length, embedding dimension (idim)
|
| 101 |
+
xs = xs.view(-1, D) # (B*L, D)
|
| 102 |
+
router = self.gate(xs) # (B*L, n_expert)
|
| 103 |
+
logits, indices = torch.topk(
|
| 104 |
+
router, self.n_expert_per_token
|
| 105 |
+
) # probs:(B*L, n_expert), indices: (B*L, n_expert)
|
| 106 |
+
weights = torch.nn.functional.softmax(
|
| 107 |
+
logits, dim=1,
|
| 108 |
+
dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
|
| 109 |
+
output = torch.zeros_like(xs) # (B*L, D)
|
| 110 |
+
for i, expert in enumerate(self.experts):
|
| 111 |
+
mask = indices == i
|
| 112 |
+
batch_idx, ith_expert = torch.where(mask)
|
| 113 |
+
output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
|
| 114 |
+
xs[batch_idx])
|
| 115 |
+
return output.view(B, L, D)
|
cosyvoice/transformer/subsampling.py
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
| 2 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
| 16 |
+
"""Subsampling layer definition."""
|
| 17 |
+
|
| 18 |
+
from typing import Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class BaseSubsampling(torch.nn.Module):
|
| 24 |
+
|
| 25 |
+
def __init__(self):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.right_context = 0
|
| 28 |
+
self.subsampling_rate = 1
|
| 29 |
+
|
| 30 |
+
def position_encoding(self, offset: Union[int, torch.Tensor],
|
| 31 |
+
size: int) -> torch.Tensor:
|
| 32 |
+
return self.pos_enc.position_encoding(offset, size)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class EmbedinigNoSubsampling(BaseSubsampling):
|
| 36 |
+
"""Embedding input without subsampling
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
| 40 |
+
pos_enc_class: torch.nn.Module):
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.embed = torch.nn.Embedding(idim, odim)
|
| 43 |
+
self.pos_enc = pos_enc_class
|
| 44 |
+
|
| 45 |
+
def forward(
|
| 46 |
+
self,
|
| 47 |
+
x: torch.Tensor,
|
| 48 |
+
x_mask: torch.Tensor,
|
| 49 |
+
offset: Union[int, torch.Tensor] = 0
|
| 50 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 51 |
+
"""Input x.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
| 55 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
torch.Tensor: linear input tensor (#batch, time', odim),
|
| 59 |
+
where time' = time .
|
| 60 |
+
torch.Tensor: linear input mask (#batch, 1, time'),
|
| 61 |
+
where time' = time .
|
| 62 |
+
|
| 63 |
+
"""
|
| 64 |
+
x = self.embed(x)
|
| 65 |
+
x, pos_emb = self.pos_enc(x, offset)
|
| 66 |
+
return x, pos_emb, x_mask
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class LinearNoSubsampling(BaseSubsampling):
|
| 70 |
+
"""Linear transform the input without subsampling
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
idim (int): Input dimension.
|
| 74 |
+
odim (int): Output dimension.
|
| 75 |
+
dropout_rate (float): Dropout rate.
|
| 76 |
+
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
| 80 |
+
pos_enc_class: torch.nn.Module):
|
| 81 |
+
"""Construct an linear object."""
|
| 82 |
+
super().__init__()
|
| 83 |
+
self.out = torch.nn.Sequential(
|
| 84 |
+
torch.nn.Linear(idim, odim),
|
| 85 |
+
torch.nn.LayerNorm(odim, eps=1e-5),
|
| 86 |
+
torch.nn.Dropout(dropout_rate),
|
| 87 |
+
)
|
| 88 |
+
self.pos_enc = pos_enc_class
|
| 89 |
+
self.right_context = 0
|
| 90 |
+
self.subsampling_rate = 1
|
| 91 |
+
|
| 92 |
+
def forward(
|
| 93 |
+
self,
|
| 94 |
+
x: torch.Tensor,
|
| 95 |
+
x_mask: torch.Tensor,
|
| 96 |
+
offset: Union[int, torch.Tensor] = 0
|
| 97 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 98 |
+
"""Input x.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
| 102 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
| 103 |
+
|
| 104 |
+
Returns:
|
| 105 |
+
torch.Tensor: linear input tensor (#batch, time', odim),
|
| 106 |
+
where time' = time .
|
| 107 |
+
torch.Tensor: linear input mask (#batch, 1, time'),
|
| 108 |
+
where time' = time .
|
| 109 |
+
|
| 110 |
+
"""
|
| 111 |
+
x = self.out(x)
|
| 112 |
+
x, pos_emb = self.pos_enc(x, offset)
|
| 113 |
+
return x, pos_emb, x_mask
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class Conv1dSubsampling2(BaseSubsampling):
|
| 117 |
+
"""Convolutional 1D subsampling (to 1/2 length).
|
| 118 |
+
It is designed for Whisper, ref:
|
| 119 |
+
https://github.com/openai/whisper/blob/main/whisper/model.py
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
idim (int): Input dimension.
|
| 123 |
+
odim (int): Output dimension.
|
| 124 |
+
dropout_rate (float): Dropout rate.
|
| 125 |
+
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
| 129 |
+
pos_enc_class: torch.nn.Module):
|
| 130 |
+
"""Construct an Conv1dSubsampling2 object."""
|
| 131 |
+
super().__init__()
|
| 132 |
+
self.conv = torch.nn.Sequential(
|
| 133 |
+
torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
|
| 134 |
+
torch.nn.GELU(),
|
| 135 |
+
torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
|
| 136 |
+
torch.nn.GELU(),
|
| 137 |
+
)
|
| 138 |
+
self.pos_enc = pos_enc_class
|
| 139 |
+
# The right context for every conv layer is computed by:
|
| 140 |
+
# (kernel_size - 1) * frame_rate_of_this_layer
|
| 141 |
+
self.subsampling_rate = 2
|
| 142 |
+
# 4 = (3 - 1) * 1 + (3 - 1) * 1
|
| 143 |
+
self.right_context = 4
|
| 144 |
+
|
| 145 |
+
def forward(
|
| 146 |
+
self,
|
| 147 |
+
x: torch.Tensor,
|
| 148 |
+
x_mask: torch.Tensor,
|
| 149 |
+
offset: Union[int, torch.Tensor] = 0
|
| 150 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 151 |
+
"""Subsample x.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
| 155 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
| 159 |
+
where time' = time // 2.
|
| 160 |
+
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
| 161 |
+
where time' = time // 2.
|
| 162 |
+
torch.Tensor: positional encoding
|
| 163 |
+
|
| 164 |
+
"""
|
| 165 |
+
time = x.size(1)
|
| 166 |
+
x = x.transpose(1, 2) # (b, f, t)
|
| 167 |
+
x = self.conv(x)
|
| 168 |
+
x = x.transpose(1, 2) # (b, t, f)
|
| 169 |
+
x, pos_emb = self.pos_enc(x, offset)
|
| 170 |
+
return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class Conv2dSubsampling4(BaseSubsampling):
|
| 174 |
+
"""Convolutional 2D subsampling (to 1/4 length).
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
idim (int): Input dimension.
|
| 178 |
+
odim (int): Output dimension.
|
| 179 |
+
dropout_rate (float): Dropout rate.
|
| 180 |
+
|
| 181 |
+
"""
|
| 182 |
+
|
| 183 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
| 184 |
+
pos_enc_class: torch.nn.Module):
|
| 185 |
+
"""Construct an Conv2dSubsampling4 object."""
|
| 186 |
+
super().__init__()
|
| 187 |
+
self.conv = torch.nn.Sequential(
|
| 188 |
+
torch.nn.Conv2d(1, odim, 3, 2),
|
| 189 |
+
torch.nn.ReLU(),
|
| 190 |
+
torch.nn.Conv2d(odim, odim, 3, 2),
|
| 191 |
+
torch.nn.ReLU(),
|
| 192 |
+
)
|
| 193 |
+
self.out = torch.nn.Sequential(
|
| 194 |
+
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
|
| 195 |
+
self.pos_enc = pos_enc_class
|
| 196 |
+
# The right context for every conv layer is computed by:
|
| 197 |
+
# (kernel_size - 1) * frame_rate_of_this_layer
|
| 198 |
+
self.subsampling_rate = 4
|
| 199 |
+
# 6 = (3 - 1) * 1 + (3 - 1) * 2
|
| 200 |
+
self.right_context = 6
|
| 201 |
+
|
| 202 |
+
def forward(
|
| 203 |
+
self,
|
| 204 |
+
x: torch.Tensor,
|
| 205 |
+
x_mask: torch.Tensor,
|
| 206 |
+
offset: Union[int, torch.Tensor] = 0
|
| 207 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 208 |
+
"""Subsample x.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
| 212 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
| 213 |
+
|
| 214 |
+
Returns:
|
| 215 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
| 216 |
+
where time' = time // 4.
|
| 217 |
+
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
| 218 |
+
where time' = time // 4.
|
| 219 |
+
torch.Tensor: positional encoding
|
| 220 |
+
|
| 221 |
+
"""
|
| 222 |
+
x = x.unsqueeze(1) # (b, c=1, t, f)
|
| 223 |
+
x = self.conv(x)
|
| 224 |
+
b, c, t, f = x.size()
|
| 225 |
+
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
| 226 |
+
x, pos_emb = self.pos_enc(x, offset)
|
| 227 |
+
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class Conv2dSubsampling6(BaseSubsampling):
|
| 231 |
+
"""Convolutional 2D subsampling (to 1/6 length).
|
| 232 |
+
Args:
|
| 233 |
+
idim (int): Input dimension.
|
| 234 |
+
odim (int): Output dimension.
|
| 235 |
+
dropout_rate (float): Dropout rate.
|
| 236 |
+
pos_enc (torch.nn.Module): Custom position encoding layer.
|
| 237 |
+
"""
|
| 238 |
+
|
| 239 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
| 240 |
+
pos_enc_class: torch.nn.Module):
|
| 241 |
+
"""Construct an Conv2dSubsampling6 object."""
|
| 242 |
+
super().__init__()
|
| 243 |
+
self.conv = torch.nn.Sequential(
|
| 244 |
+
torch.nn.Conv2d(1, odim, 3, 2),
|
| 245 |
+
torch.nn.ReLU(),
|
| 246 |
+
torch.nn.Conv2d(odim, odim, 5, 3),
|
| 247 |
+
torch.nn.ReLU(),
|
| 248 |
+
)
|
| 249 |
+
self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
|
| 250 |
+
odim)
|
| 251 |
+
self.pos_enc = pos_enc_class
|
| 252 |
+
# 10 = (3 - 1) * 1 + (5 - 1) * 2
|
| 253 |
+
self.subsampling_rate = 6
|
| 254 |
+
self.right_context = 10
|
| 255 |
+
|
| 256 |
+
def forward(
|
| 257 |
+
self,
|
| 258 |
+
x: torch.Tensor,
|
| 259 |
+
x_mask: torch.Tensor,
|
| 260 |
+
offset: Union[int, torch.Tensor] = 0
|
| 261 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 262 |
+
"""Subsample x.
|
| 263 |
+
Args:
|
| 264 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
| 265 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
| 266 |
+
|
| 267 |
+
Returns:
|
| 268 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
| 269 |
+
where time' = time // 6.
|
| 270 |
+
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
| 271 |
+
where time' = time // 6.
|
| 272 |
+
torch.Tensor: positional encoding
|
| 273 |
+
"""
|
| 274 |
+
x = x.unsqueeze(1) # (b, c, t, f)
|
| 275 |
+
x = self.conv(x)
|
| 276 |
+
b, c, t, f = x.size()
|
| 277 |
+
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
| 278 |
+
x, pos_emb = self.pos_enc(x, offset)
|
| 279 |
+
return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class Conv2dSubsampling8(BaseSubsampling):
|
| 283 |
+
"""Convolutional 2D subsampling (to 1/8 length).
|
| 284 |
+
|
| 285 |
+
Args:
|
| 286 |
+
idim (int): Input dimension.
|
| 287 |
+
odim (int): Output dimension.
|
| 288 |
+
dropout_rate (float): Dropout rate.
|
| 289 |
+
|
| 290 |
+
"""
|
| 291 |
+
|
| 292 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
| 293 |
+
pos_enc_class: torch.nn.Module):
|
| 294 |
+
"""Construct an Conv2dSubsampling8 object."""
|
| 295 |
+
super().__init__()
|
| 296 |
+
self.conv = torch.nn.Sequential(
|
| 297 |
+
torch.nn.Conv2d(1, odim, 3, 2),
|
| 298 |
+
torch.nn.ReLU(),
|
| 299 |
+
torch.nn.Conv2d(odim, odim, 3, 2),
|
| 300 |
+
torch.nn.ReLU(),
|
| 301 |
+
torch.nn.Conv2d(odim, odim, 3, 2),
|
| 302 |
+
torch.nn.ReLU(),
|
| 303 |
+
)
|
| 304 |
+
self.linear = torch.nn.Linear(
|
| 305 |
+
odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
|
| 306 |
+
self.pos_enc = pos_enc_class
|
| 307 |
+
self.subsampling_rate = 8
|
| 308 |
+
# 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
|
| 309 |
+
self.right_context = 14
|
| 310 |
+
|
| 311 |
+
def forward(
|
| 312 |
+
self,
|
| 313 |
+
x: torch.Tensor,
|
| 314 |
+
x_mask: torch.Tensor,
|
| 315 |
+
offset: Union[int, torch.Tensor] = 0
|
| 316 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 317 |
+
"""Subsample x.
|
| 318 |
+
|
| 319 |
+
Args:
|
| 320 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
| 321 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
| 322 |
+
|
| 323 |
+
Returns:
|
| 324 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
| 325 |
+
where time' = time // 8.
|
| 326 |
+
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
| 327 |
+
where time' = time // 8.
|
| 328 |
+
torch.Tensor: positional encoding
|
| 329 |
+
"""
|
| 330 |
+
x = x.unsqueeze(1) # (b, c, t, f)
|
| 331 |
+
x = self.conv(x)
|
| 332 |
+
b, c, t, f = x.size()
|
| 333 |
+
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
| 334 |
+
x, pos_emb = self.pos_enc(x, offset)
|
| 335 |
+
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
class LegacyLinearNoSubsampling(BaseSubsampling):
|
| 339 |
+
"""Linear transform the input without subsampling
|
| 340 |
+
|
| 341 |
+
Args:
|
| 342 |
+
idim (int): Input dimension.
|
| 343 |
+
odim (int): Output dimension.
|
| 344 |
+
dropout_rate (float): Dropout rate.
|
| 345 |
+
|
| 346 |
+
"""
|
| 347 |
+
|
| 348 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
| 349 |
+
pos_enc_class: torch.nn.Module):
|
| 350 |
+
"""Construct an linear object."""
|
| 351 |
+
super().__init__()
|
| 352 |
+
self.out = torch.nn.Sequential(
|
| 353 |
+
torch.nn.Linear(idim, odim),
|
| 354 |
+
torch.nn.LayerNorm(odim, eps=1e-5),
|
| 355 |
+
torch.nn.Dropout(dropout_rate),
|
| 356 |
+
torch.nn.ReLU(),
|
| 357 |
+
)
|
| 358 |
+
self.pos_enc = pos_enc_class
|
| 359 |
+
self.right_context = 0
|
| 360 |
+
self.subsampling_rate = 1
|
| 361 |
+
|
| 362 |
+
def forward(
|
| 363 |
+
self,
|
| 364 |
+
x: torch.Tensor,
|
| 365 |
+
x_mask: torch.Tensor,
|
| 366 |
+
offset: Union[int, torch.Tensor] = 0
|
| 367 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 368 |
+
"""Input x.
|
| 369 |
+
|
| 370 |
+
Args:
|
| 371 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
| 372 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
| 373 |
+
|
| 374 |
+
Returns:
|
| 375 |
+
torch.Tensor: linear input tensor (#batch, time', odim),
|
| 376 |
+
where time' = time .
|
| 377 |
+
torch.Tensor: linear input mask (#batch, 1, time'),
|
| 378 |
+
where time' = time .
|
| 379 |
+
|
| 380 |
+
"""
|
| 381 |
+
x = self.out(x)
|
| 382 |
+
x, pos_emb = self.pos_enc(x, offset)
|
| 383 |
+
return x, pos_emb, x_mask
|