parkjihye commited on
Commit
495bed3
·
1 Parent(s): 7e006f9

Add file CosyVoice

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. model/.ipynb_checkpoints/README-checkpoint.md +24 -0
  2. model/.ipynb_checkpoints/cosyvoice_notebook-checkpoint.ipynb +113 -0
  3. model/.ipynb_checkpoints/gitignore-checkpoint.txt +7 -0
  4. model/.ipynb_checkpoints/requirements-checkpoint.txt +6 -0
  5. model/README.md +24 -0
  6. model/cosyvoice/cli/.ipynb_checkpoints/cosyvoice-checkpoint.py +179 -0
  7. model/cosyvoice/cli/.ipynb_checkpoints/frontend-checkpoint.py +211 -0
  8. model/cosyvoice/cli/__init__.py +0 -0
  9. model/cosyvoice/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  10. model/cosyvoice/cli/__pycache__/__init__.cpython-311.pyc +0 -0
  11. model/cosyvoice/cli/__pycache__/__init__.cpython-312.pyc +0 -0
  12. model/cosyvoice/cli/__pycache__/cosyvoice.cpython-310.pyc +0 -0
  13. model/cosyvoice/cli/__pycache__/cosyvoice.cpython-311.pyc +0 -0
  14. model/cosyvoice/cli/__pycache__/cosyvoice.cpython-312.pyc +0 -0
  15. model/cosyvoice/cli/__pycache__/frontend.cpython-310.pyc +0 -0
  16. model/cosyvoice/cli/__pycache__/frontend.cpython-311.pyc +0 -0
  17. model/cosyvoice/cli/__pycache__/frontend.cpython-312.pyc +0 -0
  18. model/cosyvoice/cli/__pycache__/model.cpython-310.pyc +0 -0
  19. model/cosyvoice/cli/__pycache__/model.cpython-311.pyc +0 -0
  20. model/cosyvoice/cli/cosyvoice.py +179 -0
  21. model/cosyvoice/cli/frontend.py +211 -0
  22. model/cosyvoice/cli/model.py +461 -0
  23. model/cosyvoice/dataset/__init__.py +0 -0
  24. model/cosyvoice/dataset/__pycache__/__init__.cpython-310.pyc +0 -0
  25. model/cosyvoice/dataset/__pycache__/__init__.cpython-311.pyc +0 -0
  26. model/cosyvoice/dataset/__pycache__/processor.cpython-310.pyc +0 -0
  27. model/cosyvoice/dataset/__pycache__/processor.cpython-311.pyc +0 -0
  28. model/cosyvoice/dataset/dataset.py +164 -0
  29. model/cosyvoice/dataset/processor.py +435 -0
  30. model/cosyvoice/flow/__pycache__/decoder.cpython-310.pyc +0 -0
  31. model/cosyvoice/flow/__pycache__/decoder.cpython-311.pyc +0 -0
  32. model/cosyvoice/flow/__pycache__/flow.cpython-310.pyc +0 -0
  33. model/cosyvoice/flow/__pycache__/flow.cpython-311.pyc +0 -0
  34. model/cosyvoice/flow/__pycache__/flow_matching.cpython-310.pyc +0 -0
  35. model/cosyvoice/flow/__pycache__/flow_matching.cpython-311.pyc +0 -0
  36. model/cosyvoice/flow/decoder.py +902 -0
  37. model/cosyvoice/flow/flow.py +289 -0
  38. model/cosyvoice/flow/flow_matching.py +344 -0
  39. model/cosyvoice/flow/length_regulator.py +70 -0
  40. model/cosyvoice/hifigan/__pycache__/discriminator.cpython-310.pyc +0 -0
  41. model/cosyvoice/hifigan/__pycache__/discriminator.cpython-311.pyc +0 -0
  42. model/cosyvoice/hifigan/__pycache__/f0_predictor.cpython-310.pyc +0 -0
  43. model/cosyvoice/hifigan/__pycache__/f0_predictor.cpython-311.pyc +0 -0
  44. model/cosyvoice/hifigan/__pycache__/generator.cpython-310.pyc +0 -0
  45. model/cosyvoice/hifigan/__pycache__/generator.cpython-311.pyc +0 -0
  46. model/cosyvoice/hifigan/__pycache__/hifigan.cpython-310.pyc +0 -0
  47. model/cosyvoice/hifigan/__pycache__/hifigan.cpython-311.pyc +0 -0
  48. model/cosyvoice/hifigan/discriminator.py +230 -0
  49. model/cosyvoice/hifigan/f0_predictor.py +58 -0
  50. model/cosyvoice/hifigan/generator.py +414 -0
model/.ipynb_checkpoints/README-checkpoint.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ```bash
2
+ # 1. 리포지토리 클론 및 이동
3
+ git clone https://github.com/yourusername/cosyvoice.git
4
+ cd cosyvoice
5
+
6
+ # 2. 의존성 설치
7
+ pip install -r requirements.txt
8
+
9
+ # 3. 사전학습 모델 다운로드 (Python 코드 실행)
10
+ from modelscope import snapshot_download
11
+
12
+ # CosyVoice2 TTS 모델
13
+ snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
14
+
15
+ # CosyVoice 전처리기 (frontend)
16
+ snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
17
+
18
+ # 4. frontend 리소스 압축 해제
19
+ cd pretrained_models/CosyVoice-ttsfrd/
20
+ unzip resource.zip -d .
21
+
22
+ # 5. ttsfrd 의존성 설치
23
+ pip install ttsfrd_dependency-0.1-py3-none-any.whl
24
+ pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
model/.ipynb_checkpoints/cosyvoice_notebook-checkpoint.ipynb ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "8ae68495-7e0f-4308-b63a-f1d09f2850c7",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/home/pasong/miniconda3/envs/cosyvoice/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
14
+ " from .autonotebook import tqdm as notebook_tqdm\n"
15
+ ]
16
+ }
17
+ ],
18
+ "source": [
19
+ "import sys\n",
20
+ "sys.path.append('third_party/Matcha-TTS')\n",
21
+ "from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2\n",
22
+ "from cosyvoice.utils.file_utils import load_wav\n",
23
+ "import torchaudio"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 2,
29
+ "id": "6a38d09e-c3a2-43cf-a684-0fd3b47e4b09",
30
+ "metadata": {},
31
+ "outputs": [
32
+ {
33
+ "name": "stderr",
34
+ "output_type": "stream",
35
+ "text": [
36
+ "/home/pasong/miniconda3/envs/cosyvoice/lib/python3.10/site-packages/diffusers/models/lora.py:393: FutureWarning: `LoRACompatibleLinear` is deprecated and will be removed in version 1.0.0. Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`.\n",
37
+ " deprecate(\"LoRACompatibleLinear\", \"1.0.0\", deprecation_message)\n",
38
+ "2025-04-19 19:46:33,344 INFO input frame rate=25\n",
39
+ "/home/pasong/miniconda3/envs/cosyvoice/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\n",
40
+ " warnings.warn(\"torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\")\n",
41
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
42
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
43
+ "\u001b[0;93m2025-04-19 19:46:35.693712543 [W:onnxruntime:, transformer_memcpy.cc:74 ApplyImpl] 8 Memcpy nodes are added to the graph main_graph for CUDAExecutionProvider. It might have negative impact on performance (including unable to run CUDA graph). Set session_options.log_severity_level=1 to see the detail logs before this message.\u001b[m\n",
44
+ "\u001b[0;93m2025-04-19 19:46:35.697676173 [W:onnxruntime:, session_state.cc:1166 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.\u001b[m\n",
45
+ "\u001b[0;93m2025-04-19 19:46:35.697686892 [W:onnxruntime:, session_state.cc:1168 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.\u001b[m\n",
46
+ " 0%| | 0/1 [00:00<?, ?it/s]2025-04-19 19:46:41,474 INFO synthesis text 공룡이 밤양갱을 몰래 먹고 도망쳤어요。\n"
47
+ ]
48
+ },
49
+ {
50
+ "name": "stdout",
51
+ "output_type": "stream",
52
+ "text": [
53
+ "text.cc: festival_Text_init\n",
54
+ "open voice lang map failed\n"
55
+ ]
56
+ },
57
+ {
58
+ "name": "stderr",
59
+ "output_type": "stream",
60
+ "text": [
61
+ "2025-04-19 19:46:46,661 INFO yield speech len 5.92, rtf 0.8762226314158054\n",
62
+ "100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:06<00:00, 6.36s/it]\n"
63
+ ]
64
+ }
65
+ ],
66
+ "source": [
67
+ "cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False)\n",
68
+ "\n",
69
+ "# 프롬프트로 사용할 음성 파일 (화자의 목소리 담긴 wav, 16kHz로)\n",
70
+ "prompt_speech_16k = load_wav('./asset/tts_test.wav', 16000)\n",
71
+ "\n",
72
+ "for i, j in enumerate(\n",
73
+ " cosyvoice.inference_zero_shot(\n",
74
+ " '공룡이 밤양갱을 몰래 먹고 도망쳤어요.', #tts할 문장\n",
75
+ " prompt_text='오느른 커피 안 마실 꺼야', #tts_test의 발음문장\n",
76
+ " prompt_speech_16k=prompt_speech_16k,\n",
77
+ " text_frontend=True, \n",
78
+ " )\n",
79
+ "):\n",
80
+ " torchaudio.save(f'korean_tts_{i}.wav', j['tts_speech'], cosyvoice.sample_rate)"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": null,
86
+ "id": "899101b0-3556-4ab5-8477-d182da48357d",
87
+ "metadata": {},
88
+ "outputs": [],
89
+ "source": []
90
+ }
91
+ ],
92
+ "metadata": {
93
+ "kernelspec": {
94
+ "display_name": "Python (cosyvoice)",
95
+ "language": "python",
96
+ "name": "cosyvoice"
97
+ },
98
+ "language_info": {
99
+ "codemirror_mode": {
100
+ "name": "ipython",
101
+ "version": 3
102
+ },
103
+ "file_extension": ".py",
104
+ "mimetype": "text/x-python",
105
+ "name": "python",
106
+ "nbconvert_exporter": "python",
107
+ "pygments_lexer": "ipython3",
108
+ "version": "3.10.16"
109
+ }
110
+ },
111
+ "nbformat": 4,
112
+ "nbformat_minor": 5
113
+ }
model/.ipynb_checkpoints/gitignore-checkpoint.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ .ipynb_checkpoints/
3
+ *.pt
4
+ *.pth
5
+ *.onnx
6
+ pretrained_models/CosyVoice2-0.5B/
7
+ pretrained_models/CosyVoice-ttsfrd/resource/
model/.ipynb_checkpoints/requirements-checkpoint.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ HyperPyYAML==1.2.2
2
+ modelscope==1.20.0
3
+ tensorrt_cu12==10.0.1
4
+ torch==2.3.1+cu121
5
+ torchaudio==2.3.1+cu121
6
+ tqdm==4.67.1
model/README.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ```bash
2
+ # 1. 리포지토리 클론 및 이동
3
+ git clone https://github.com/yourusername/cosyvoice.git
4
+ cd cosyvoice
5
+
6
+ # 2. 의존성 설치
7
+ pip install -r requirements.txt
8
+
9
+ # 3. 사전학습 모델 다운로드 (Python 코드 실행)
10
+ from modelscope import snapshot_download
11
+
12
+ # CosyVoice2 TTS 모델
13
+ snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
14
+
15
+ # CosyVoice 전처리기 (frontend)
16
+ snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
17
+
18
+ # 4. frontend 리소스 압축 해제
19
+ cd pretrained_models/CosyVoice-ttsfrd/
20
+ unzip resource.zip -d .
21
+
22
+ # 5. ttsfrd 의존성 설치
23
+ pip install ttsfrd_dependency-0.1-py3-none-any.whl
24
+ pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
model/cosyvoice/cli/.ipynb_checkpoints/cosyvoice-checkpoint.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import time
16
+ from typing import Generator
17
+ from tqdm import tqdm
18
+ from hyperpyyaml import load_hyperpyyaml
19
+ from modelscope import snapshot_download
20
+ import torch
21
+ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
22
+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
23
+ from cosyvoice.utils.file_utils import logging
24
+ from cosyvoice.utils.class_utils import get_model_type
25
+
26
+
27
+ class CosyVoice:
28
+
29
+ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
30
+ self.instruct = True if '-Instruct' in model_dir else False
31
+ self.model_dir = model_dir
32
+ self.fp16 = fp16
33
+ if not os.path.exists(model_dir):
34
+ model_dir = snapshot_download(model_dir)
35
+ hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir)
36
+ if not os.path.exists(hyper_yaml_path):
37
+ raise ValueError('{} not found!'.format(hyper_yaml_path))
38
+ with open(hyper_yaml_path, 'r') as f:
39
+ configs = load_hyperpyyaml(f)
40
+ assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
41
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
42
+ configs['feat_extractor'],
43
+ '{}/campplus.onnx'.format(model_dir),
44
+ '{}/speech_tokenizer_v1.onnx'.format(model_dir),
45
+ '{}/spk2info.pt'.format(model_dir),
46
+ configs['allowed_special'])
47
+ self.sample_rate = configs['sample_rate']
48
+ if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
49
+ load_jit, load_trt, fp16 = False, False, False
50
+ logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
51
+ self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
52
+ self.model.load('{}/llm.pt'.format(model_dir),
53
+ '{}/flow.pt'.format(model_dir),
54
+ '{}/hift.pt'.format(model_dir))
55
+ if load_jit:
56
+ self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
57
+ '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
58
+ '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
59
+ if load_trt:
60
+ self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
61
+ '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
62
+ self.fp16)
63
+ del configs
64
+
65
+ def list_available_spks(self):
66
+ spks = list(self.frontend.spk2info.keys())
67
+ return spks
68
+
69
+ def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
70
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
71
+ model_input = self.frontend.frontend_sft(i, spk_id)
72
+ start_time = time.time()
73
+ logging.info('synthesis text {}'.format(i))
74
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
75
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
76
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
77
+ yield model_output
78
+ start_time = time.time()
79
+
80
+ def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
81
+ prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
82
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
83
+ if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
84
+ logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
85
+ model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
86
+ start_time = time.time()
87
+ logging.info('synthesis text {}'.format(i))
88
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
89
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
90
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
91
+ yield model_output
92
+ start_time = time.time()
93
+
94
+ def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
95
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
96
+ model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
97
+ start_time = time.time()
98
+ logging.info('synthesis text {}'.format(i))
99
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
100
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
101
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
102
+ yield model_output
103
+ start_time = time.time()
104
+
105
+ def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
106
+ assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
107
+ if self.instruct is False:
108
+ raise ValueError('{} do not support instruct inference'.format(self.model_dir))
109
+ instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
110
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
111
+ model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
112
+ start_time = time.time()
113
+ logging.info('synthesis text {}'.format(i))
114
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
115
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
116
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
117
+ yield model_output
118
+ start_time = time.time()
119
+
120
+ def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
121
+ model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
122
+ start_time = time.time()
123
+ for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
124
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
125
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
126
+ yield model_output
127
+ start_time = time.time()
128
+
129
+
130
+ class CosyVoice2(CosyVoice):
131
+
132
+ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False):
133
+ self.instruct = True if '-Instruct' in model_dir else False
134
+ self.model_dir = model_dir
135
+ self.fp16 = fp16
136
+ if not os.path.exists(model_dir):
137
+ model_dir = snapshot_download(model_dir)
138
+ hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
139
+ if not os.path.exists(hyper_yaml_path):
140
+ raise ValueError('{} not found!'.format(hyper_yaml_path))
141
+ with open(hyper_yaml_path, 'r') as f:
142
+ configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
143
+ assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
144
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
145
+ configs['feat_extractor'],
146
+ '{}/campplus.onnx'.format(model_dir),
147
+ '{}/speech_tokenizer_v2.onnx'.format(model_dir),
148
+ '{}/spk2info.pt'.format(model_dir),
149
+ configs['allowed_special'])
150
+ self.sample_rate = configs['sample_rate']
151
+ if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
152
+ load_jit, load_trt, fp16 = False, False, False
153
+ logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
154
+ self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache)
155
+ self.model.load('{}/llm.pt'.format(model_dir),
156
+ '{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir),
157
+ '{}/hift.pt'.format(model_dir))
158
+ if load_jit:
159
+ self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
160
+ if load_trt:
161
+ self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
162
+ '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
163
+ self.fp16)
164
+ del configs
165
+
166
+ def inference_instruct(self, *args, **kwargs):
167
+ raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
168
+
169
+ def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
170
+ assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
171
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
172
+ model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
173
+ start_time = time.time()
174
+ logging.info('synthesis text {}'.format(i))
175
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
176
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
177
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
178
+ yield model_output
179
+ start_time = time.time()
model/cosyvoice/cli/.ipynb_checkpoints/frontend-checkpoint.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from functools import partial
15
+ from typing import Generator
16
+ import json
17
+ import onnxruntime
18
+ import torch
19
+ import numpy as np
20
+ import whisper
21
+ from typing import Callable
22
+ import torchaudio.compliance.kaldi as kaldi
23
+ import torchaudio
24
+ import os
25
+ import re
26
+ import inflect
27
+ try:
28
+ import ttsfrd
29
+ use_ttsfrd = True
30
+ except ImportError:
31
+ print("failed to import ttsfrd, use WeTextProcessing instead")
32
+ from tn.chinese.normalizer import Normalizer as ZhNormalizer
33
+ from tn.english.normalizer import Normalizer as EnNormalizer
34
+ use_ttsfrd = False
35
+ from cosyvoice.utils.file_utils import logging
36
+ from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
37
+
38
+
39
+ class CosyVoiceFrontEnd:
40
+
41
+ def __init__(self,
42
+ get_tokenizer: Callable,
43
+ feat_extractor: Callable,
44
+ campplus_model: str,
45
+ speech_tokenizer_model: str,
46
+ spk2info: str = '',
47
+ allowed_special: str = 'all'):
48
+ self.tokenizer = get_tokenizer()
49
+ self.feat_extractor = feat_extractor
50
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
51
+ option = onnxruntime.SessionOptions()
52
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
53
+ option.intra_op_num_threads = 1
54
+ self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
55
+ self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
56
+ providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
57
+ "CPUExecutionProvider"])
58
+ if os.path.exists(spk2info):
59
+ self.spk2info = torch.load(spk2info, map_location=self.device)
60
+ else:
61
+ self.spk2info = {}
62
+ self.allowed_special = allowed_special
63
+ self.use_ttsfrd = use_ttsfrd
64
+ if self.use_ttsfrd:
65
+ self.frd = ttsfrd.TtsFrontendEngine()
66
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
67
+ assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
68
+ 'failed to initialize ttsfrd resource'
69
+ self.frd.set_lang_type('pinyinvg')
70
+ else:
71
+ self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
72
+ self.en_tn_model = EnNormalizer()
73
+ self.inflect_parser = inflect.engine()
74
+
75
+ def _extract_text_token(self, text):
76
+ if isinstance(text, Generator):
77
+ logging.info('get tts_text generator, will return _extract_text_token_generator!')
78
+ # NOTE add a dummy text_token_len for compatibility
79
+ return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
80
+ else:
81
+ text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
82
+ text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
83
+ text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
84
+ return text_token, text_token_len
85
+
86
+ def _extract_text_token_generator(self, text_generator):
87
+ for text in text_generator:
88
+ text_token, _ = self._extract_text_token(text)
89
+ for i in range(text_token.shape[1]):
90
+ yield text_token[:, i: i + 1]
91
+
92
+ def _extract_speech_token(self, speech):
93
+ assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
94
+ feat = whisper.log_mel_spectrogram(speech, n_mels=128)
95
+ speech_token = self.speech_tokenizer_session.run(None,
96
+ {self.speech_tokenizer_session.get_inputs()[0].name:
97
+ feat.detach().cpu().numpy(),
98
+ self.speech_tokenizer_session.get_inputs()[1].name:
99
+ np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
100
+ speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
101
+ speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
102
+ return speech_token, speech_token_len
103
+
104
+ def _extract_spk_embedding(self, speech):
105
+ feat = kaldi.fbank(speech,
106
+ num_mel_bins=80,
107
+ dither=0,
108
+ sample_frequency=16000)
109
+ feat = feat - feat.mean(dim=0, keepdim=True)
110
+ embedding = self.campplus_session.run(None,
111
+ {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
112
+ embedding = torch.tensor([embedding]).to(self.device)
113
+ return embedding
114
+
115
+ def _extract_speech_feat(self, speech):
116
+ speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
117
+ speech_feat = speech_feat.unsqueeze(dim=0)
118
+ speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
119
+ return speech_feat, speech_feat_len
120
+
121
+ def text_normalize(self, text, split=True, text_frontend=True):
122
+ if isinstance(text, Generator):
123
+ logging.info('get tts_text generator, will skip text_normalize!')
124
+ return [text]
125
+ if text_frontend is False:
126
+ return [text] if split is True else text
127
+ text = text.strip()
128
+ if self.use_ttsfrd:
129
+ texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
130
+ text = ''.join(texts)
131
+ else:
132
+ if contains_chinese(text):
133
+ text = self.zh_tn_model.normalize(text)
134
+ text = text.replace("\n", "")
135
+ text = replace_blank(text)
136
+ text = replace_corner_mark(text)
137
+ text = text.replace(".", "。")
138
+ text = text.replace(" - ", ",")
139
+ text = remove_bracket(text)
140
+ text = re.sub(r'[,,、]+$', '。', text)
141
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
142
+ token_min_n=60, merge_len=20, comma_split=False))
143
+ else:
144
+ text = self.en_tn_model.normalize(text)
145
+ text = spell_out_number(text, self.inflect_parser)
146
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
147
+ token_min_n=60, merge_len=20, comma_split=False))
148
+ texts = [i for i in texts if not is_only_punctuation(i)]
149
+ return texts if split is True else text
150
+
151
+ def frontend_sft(self, tts_text, spk_id):
152
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
153
+ embedding = self.spk2info[spk_id]['embedding']
154
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
155
+ return model_input
156
+
157
+ def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
158
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
159
+ prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
160
+ prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
161
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
162
+ speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
163
+ if resample_rate == 24000:
164
+ # cosyvoice2, force speech_feat % speech_token = 2
165
+ token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
166
+ speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
167
+ speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
168
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
169
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
170
+ 'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
171
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
172
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
173
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
174
+ 'llm_embedding': embedding, 'flow_embedding': embedding}
175
+ return model_input
176
+
177
+ def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
178
+ model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
179
+ # in cross lingual mode, we remove prompt in llm
180
+ del model_input['prompt_text']
181
+ del model_input['prompt_text_len']
182
+ del model_input['llm_prompt_speech_token']
183
+ del model_input['llm_prompt_speech_token_len']
184
+ return model_input
185
+
186
+ def frontend_instruct(self, tts_text, spk_id, instruct_text):
187
+ model_input = self.frontend_sft(tts_text, spk_id)
188
+ # in instruct mode, we remove spk_embedding in llm due to information leakage
189
+ del model_input['llm_embedding']
190
+ instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
191
+ model_input['prompt_text'] = instruct_text_token
192
+ model_input['prompt_text_len'] = instruct_text_token_len
193
+ return model_input
194
+
195
+ def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
196
+ model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate)
197
+ del model_input['llm_prompt_speech_token']
198
+ del model_input['llm_prompt_speech_token_len']
199
+ return model_input
200
+
201
+ def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
202
+ prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
203
+ prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
204
+ prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
205
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
206
+ source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
207
+ model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
208
+ 'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
209
+ 'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
210
+ 'flow_embedding': embedding}
211
+ return model_input
model/cosyvoice/cli/__init__.py ADDED
File without changes
model/cosyvoice/cli/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (145 Bytes). View file
 
model/cosyvoice/cli/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (161 Bytes). View file
 
model/cosyvoice/cli/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (149 Bytes). View file
 
model/cosyvoice/cli/__pycache__/cosyvoice.cpython-310.pyc ADDED
Binary file (7.38 kB). View file
 
model/cosyvoice/cli/__pycache__/cosyvoice.cpython-311.pyc ADDED
Binary file (16.2 kB). View file
 
model/cosyvoice/cli/__pycache__/cosyvoice.cpython-312.pyc ADDED
Binary file (14.9 kB). View file
 
model/cosyvoice/cli/__pycache__/frontend.cpython-310.pyc ADDED
Binary file (8.58 kB). View file
 
model/cosyvoice/cli/__pycache__/frontend.cpython-311.pyc ADDED
Binary file (16.4 kB). View file
 
model/cosyvoice/cli/__pycache__/frontend.cpython-312.pyc ADDED
Binary file (15 kB). View file
 
model/cosyvoice/cli/__pycache__/model.cpython-310.pyc ADDED
Binary file (14.6 kB). View file
 
model/cosyvoice/cli/__pycache__/model.cpython-311.pyc ADDED
Binary file (37 kB). View file
 
model/cosyvoice/cli/cosyvoice.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import time
16
+ from typing import Generator
17
+ from tqdm import tqdm
18
+ from hyperpyyaml import load_hyperpyyaml
19
+ from modelscope import snapshot_download
20
+ import torch
21
+ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
22
+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
23
+ from cosyvoice.utils.file_utils import logging
24
+ from cosyvoice.utils.class_utils import get_model_type
25
+
26
+
27
class CosyVoice:
    """High-level inference wrapper for CosyVoice (v1) models.

    Builds the text/audio frontend and the three-stage synthesis model
    (llm -> flow -> hift) from a local directory or a modelscope repo id,
    and exposes generator-based synthesis entry points that yield audio
    chunks as dicts with a 'tts_speech' tensor.
    """

    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
        """Load configs, frontend and model weights from ``model_dir``.

        Args:
            model_dir: local path or modelscope model id (downloaded if the
                local path does not exist).
            load_jit: replace llm/flow sub-modules with TorchScript exports.
            load_trt: replace the flow decoder estimator with a TensorRT engine.
            fp16: run llm/flow in half precision.

        Raises:
            ValueError: if ``cosyvoice.yaml`` is missing in the model dir.
        """
        # '-Instruct' in the repo name marks instruct-capable checkpoints
        self.instruct = True if '-Instruct' in model_dir else False
        self.model_dir = model_dir
        self.fp16 = fp16
        if not os.path.exists(model_dir):
            model_dir = snapshot_download(model_dir)
        hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir)
        if not os.path.exists(hyper_yaml_path):
            raise ValueError('{} not found!'.format(hyper_yaml_path))
        with open(hyper_yaml_path, 'r') as f:
            configs = load_hyperpyyaml(f)
        # guard against loading a CosyVoice2 checkpoint through this class
        assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                          configs['feat_extractor'],
                                          '{}/campplus.onnx'.format(model_dir),
                                          '{}/speech_tokenizer_v1.onnx'.format(model_dir),
                                          '{}/spk2info.pt'.format(model_dir),
                                          configs['allowed_special'])
        self.sample_rate = configs['sample_rate']
        # jit/trt/fp16 all require CUDA; silently fall back on CPU-only hosts
        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
            load_jit, load_trt, fp16 = False, False, False
            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
        self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
        self.model.load('{}/llm.pt'.format(model_dir),
                        '{}/flow.pt'.format(model_dir),
                        '{}/hift.pt'.format(model_dir))
        if load_jit:
            self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
        if load_trt:
            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                self.fp16)
        del configs

    def list_available_spks(self):
        """Return the list of speaker ids available for SFT inference."""
        spks = list(self.frontend.spk2info.keys())
        return spks

    def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
        """Synthesize ``tts_text`` with a built-in speaker; yields audio chunks."""
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_sft(i, spk_id)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                yield model_output
                # reset timer so rtf is reported per yielded chunk
                start_time = time.time()

    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
        """Zero-shot voice cloning: clone the voice of ``prompt_speech_16k``
        (whose transcript is ``prompt_text``) onto ``tts_text``; yields audio chunks."""
        prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            # very short tts segments relative to the prompt tend to degrade quality
            if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
                logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                yield model_output
                start_time = time.time()

    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
        """Clone a voice across languages (no prompt transcript needed); yields audio chunks."""
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                yield model_output
                start_time = time.time()

    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
        """Instruction-controlled synthesis with a built-in speaker; yields audio chunks.

        Raises:
            ValueError: if this checkpoint does not support instruct inference.
        """
        assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
        if self.instruct is False:
            raise ValueError('{} do not support instruct inference'.format(self.model_dir))
        instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                yield model_output
                start_time = time.time()

    def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
        """Voice conversion: re-speak ``source_speech_16k`` in the voice of
        ``prompt_speech_16k``; yields audio chunks."""
        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
        start_time = time.time()
        for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
            speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
            logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
            yield model_output
            start_time = time.time()
129
+
130
class CosyVoice2(CosyVoice):
    """Inference wrapper for CosyVoice2 models.

    Reuses the CosyVoice synthesis entry points but loads the v2 config
    (``cosyvoice2.yaml``), the v2 speech tokenizer and ``CosyVoice2Model``;
    v1-style instruct inference is replaced by ``inference_instruct2``.
    """

    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False):
        """Load configs, frontend and model weights from ``model_dir``.

        Args:
            model_dir: local path or modelscope model id (downloaded if the
                local path does not exist).
            load_jit: replace the flow encoder with its TorchScript export.
            load_trt: replace the flow decoder estimator with a TensorRT engine.
            fp16: run llm/flow in half precision.
            use_flow_cache: load the cache-enabled flow checkpoint for
                streaming inference.

        Raises:
            ValueError: if ``cosyvoice2.yaml`` is missing in the model dir.
        """
        self.instruct = True if '-Instruct' in model_dir else False
        self.model_dir = model_dir
        self.fp16 = fp16
        if not os.path.exists(model_dir):
            model_dir = snapshot_download(model_dir)
        hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
        if not os.path.exists(hyper_yaml_path):
            raise ValueError('{} not found!'.format(hyper_yaml_path))
        with open(hyper_yaml_path, 'r') as f:
            # the qwen tokenizer/backbone path is resolved relative to the model dir
            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
        # guard against loading a v1 checkpoint through this class
        assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                          configs['feat_extractor'],
                                          '{}/campplus.onnx'.format(model_dir),
                                          '{}/speech_tokenizer_v2.onnx'.format(model_dir),
                                          '{}/spk2info.pt'.format(model_dir),
                                          configs['allowed_special'])
        self.sample_rate = configs['sample_rate']
        # jit/trt/fp16 all require CUDA; silently fall back on CPU-only hosts
        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
            load_jit, load_trt, fp16 = False, False, False
            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache)
        self.model.load('{}/llm.pt'.format(model_dir),
                        '{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir),
                        '{}/hift.pt'.format(model_dir))
        if load_jit:
            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
        if load_trt:
            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                self.fp16)
        del configs

    def inference_instruct(self, *args, **kwargs):
        """v1 instruct API is intentionally disabled; use ``inference_instruct2``."""
        raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')

    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
        """Instruction-controlled synthesis cloning ``prompt_speech_16k``; yields audio chunks."""
        assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                yield model_output
                # reset timer so rtf is reported per yielded chunk
                start_time = time.time()
model/cosyvoice/cli/frontend.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from functools import partial
15
+ from typing import Generator
16
+ import json
17
+ import onnxruntime
18
+ import torch
19
+ import numpy as np
20
+ import whisper
21
+ from typing import Callable
22
+ import torchaudio.compliance.kaldi as kaldi
23
+ import torchaudio
24
+ import os
25
+ import re
26
+ import inflect
27
+ try:
28
+ import ttsfrd
29
+ use_ttsfrd = True
30
+ except ImportError:
31
+ print("failed to import ttsfrd, use WeTextProcessing instead")
32
+ from tn.chinese.normalizer import Normalizer as ZhNormalizer
33
+ from tn.english.normalizer import Normalizer as EnNormalizer
34
+ use_ttsfrd = False
35
+ from cosyvoice.utils.file_utils import logging
36
+ from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
37
+
38
+
39
class CosyVoiceFrontEnd:
    """Shared frontend for CosyVoice / CosyVoice2 inference.

    Owns the text tokenizer, the mel feature extractor, the ONNX sessions for
    the speech tokenizer and the campplus speaker-embedding model, plus text
    normalization (ttsfrd when available, otherwise WeTextProcessing).
    The ``frontend_*`` methods assemble the keyword-argument dicts consumed
    by ``CosyVoiceModel.tts`` / ``.vc``.
    """

    def __init__(self,
                 get_tokenizer: Callable,
                 feat_extractor: Callable,
                 campplus_model: str,
                 speech_tokenizer_model: str,
                 spk2info: str = '',
                 allowed_special: str = 'all'):
        """Build tokenizer, ONNX sessions and text normalizers.

        Args:
            get_tokenizer: zero-arg factory returning the text tokenizer.
            feat_extractor: callable mapping a waveform tensor to mel features.
            campplus_model: path to the campplus speaker-embedding ONNX model.
            speech_tokenizer_model: path to the speech tokenizer ONNX model.
            spk2info: optional path to a torch-saved speaker-info dict.
            allowed_special: passed through to ``tokenizer.encode``.
        """
        self.tokenizer = get_tokenizer()
        self.feat_extractor = feat_extractor
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        # campplus is light-weight, always run it on CPU
        self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
                                                                     providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
                                                                                "CPUExecutionProvider"])
        if os.path.exists(spk2info):
            self.spk2info = torch.load(spk2info, map_location=self.device)
        else:
            self.spk2info = {}
        self.allowed_special = allowed_special
        # module-level flag set at import time depending on ttsfrd availability
        self.use_ttsfrd = use_ttsfrd
        if self.use_ttsfrd:
            self.frd = ttsfrd.TtsFrontendEngine()
            ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
                'failed to initialize ttsfrd resource'
            self.frd.set_lang_type('pinyinvg')
        else:
            self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
            self.en_tn_model = EnNormalizer()
            self.inflect_parser = inflect.engine()

    def _extract_text_token(self, text):
        """Tokenize ``text`` into a (1, T) int32 tensor plus its length tensor.

        If ``text`` is a generator (streaming input), returns a token generator
        and a dummy length instead.
        """
        if isinstance(text, Generator):
            logging.info('get tts_text generator, will return _extract_text_token_generator!')
            # NOTE add a dummy text_token_len for compatibility
            return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
        else:
            text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
            text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
            text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
            return text_token, text_token_len

    def _extract_text_token_generator(self, text_generator):
        """Yield (1, 1) token tensors one at a time from a text generator."""
        for text in text_generator:
            text_token, _ = self._extract_text_token(text)
            for i in range(text_token.shape[1]):
                yield text_token[:, i: i + 1]

    def _extract_speech_token(self, speech):
        """Run the ONNX speech tokenizer on a 16 kHz waveform.

        Returns a (1, T) int32 token tensor and its length tensor.
        """
        assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
        feat = whisper.log_mel_spectrogram(speech, n_mels=128)
        speech_token = self.speech_tokenizer_session.run(None,
                                                         {self.speech_tokenizer_session.get_inputs()[0].name:
                                                          feat.detach().cpu().numpy(),
                                                          self.speech_tokenizer_session.get_inputs()[1].name:
                                                          np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
        speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
        return speech_token, speech_token_len

    def _extract_spk_embedding(self, speech):
        """Compute the campplus speaker embedding, shape (1, D), for a 16 kHz waveform."""
        feat = kaldi.fbank(speech,
                           num_mel_bins=80,
                           dither=0,
                           sample_frequency=16000)
        # per-utterance cepstral mean normalization
        feat = feat - feat.mean(dim=0, keepdim=True)
        embedding = self.campplus_session.run(None,
                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
        embedding = torch.tensor([embedding]).to(self.device)
        return embedding

    def _extract_speech_feat(self, speech):
        """Extract mel features as a (1, T, n_mels) tensor plus its length tensor."""
        speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
        return speech_feat, speech_feat_len

    def text_normalize(self, text, split=True, text_frontend=True):
        """Normalize ``text`` and optionally split it into synthesis chunks.

        Returns a list of text segments when ``split`` is True, otherwise the
        normalized string. Generator inputs bypass normalization entirely.
        """
        if isinstance(text, Generator):
            logging.info('get tts_text generator, will skip text_normalize!')
            return [text]
        if text_frontend is False:
            return [text] if split is True else text
        text = text.strip()
        if self.use_ttsfrd:
            texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
            text = ''.join(texts)
        else:
            if contains_chinese(text):
                text = self.zh_tn_model.normalize(text)
                text = text.replace("\n", "")
                text = replace_blank(text)
                text = replace_corner_mark(text)
                # unify sentence-final punctuation to fullwidth for Chinese text
                text = text.replace(".", "。")
                text = text.replace(" - ", "，")
                text = remove_bracket(text)
                text = re.sub(r'[，,、]+$', '。', text)
                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
                                             token_min_n=60, merge_len=20, comma_split=False))
            else:
                text = self.en_tn_model.normalize(text)
                text = spell_out_number(text, self.inflect_parser)
                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
                                             token_min_n=60, merge_len=20, comma_split=False))
        # punctuation-only segments produce no useful speech
        texts = [i for i in texts if not is_only_punctuation(i)]
        return texts if split is True else text

    def frontend_sft(self, tts_text, spk_id):
        """Build model inputs for SFT synthesis with a pre-registered speaker."""
        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
        embedding = self.spk2info[spk_id]['embedding']
        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
        return model_input

    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
        """Build model inputs for zero-shot voice cloning.

        Extracts prompt text tokens, prompt speech tokens, mel features (at
        ``resample_rate``) and the speaker embedding from the 16 kHz prompt.
        """
        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
        prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
        if resample_rate == 24000:
            # cosyvoice2, force speech_feat % speech_token = 2
            token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
            speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
            speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
                       'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
                       'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
                       'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                       'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                       'llm_embedding': embedding, 'flow_embedding': embedding}
        return model_input

    def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
        """Build model inputs for cross-lingual cloning (no llm-side prompt)."""
        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
        # in cross lingual mode, we remove prompt in llm
        del model_input['prompt_text']
        del model_input['prompt_text_len']
        del model_input['llm_prompt_speech_token']
        del model_input['llm_prompt_speech_token_len']
        return model_input

    def frontend_instruct(self, tts_text, spk_id, instruct_text):
        """Build model inputs for v1 instruct synthesis (instruct text as llm prompt)."""
        model_input = self.frontend_sft(tts_text, spk_id)
        # in instruct mode, we remove spk_embedding in llm due to information leakage
        del model_input['llm_embedding']
        instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
        model_input['prompt_text'] = instruct_text_token
        model_input['prompt_text_len'] = instruct_text_token_len
        return model_input

    def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
        """Build model inputs for CosyVoice2 instruct synthesis."""
        model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate)
        # prompt speech tokens are only used on the flow side in instruct2 mode
        del model_input['llm_prompt_speech_token']
        del model_input['llm_prompt_speech_token_len']
        return model_input

    def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
        """Build model inputs for voice conversion (no text involved)."""
        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
        model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
                       'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
                       'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
                       'flow_embedding': embedding}
        return model_input
model/cosyvoice/cli/model.py ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ from typing import Generator
16
+ import torch
17
+ import numpy as np
18
+ import threading
19
+ import time
20
+ from torch.nn import functional as F
21
+ from contextlib import nullcontext
22
+ import uuid
23
+ from cosyvoice.utils.common import fade_in_out
24
+ from cosyvoice.utils.file_utils import convert_onnx_to_trt
25
+
26
+
27
+ class CosyVoiceModel:
28
+
29
+ def __init__(self,
30
+ llm: torch.nn.Module,
31
+ flow: torch.nn.Module,
32
+ hift: torch.nn.Module,
33
+ fp16: bool):
34
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
35
+ self.llm = llm
36
+ self.flow = flow
37
+ self.hift = hift
38
+ self.fp16 = fp16
39
+ if self.fp16 is True:
40
+ self.llm.half()
41
+ self.flow.half()
42
+ self.token_min_hop_len = 2 * self.flow.input_frame_rate
43
+ self.token_max_hop_len = 4 * self.flow.input_frame_rate
44
+ self.token_overlap_len = 20
45
+ # mel fade in out
46
+ self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
47
+ self.mel_window = np.hamming(2 * self.mel_overlap_len)
48
+ # hift cache
49
+ self.mel_cache_len = 20
50
+ self.source_cache_len = int(self.mel_cache_len * 256)
51
+ # speech fade in out
52
+ self.speech_window = np.hamming(2 * self.source_cache_len)
53
+ # rtf and decoding related
54
+ self.stream_scale_factor = 1
55
+ assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
56
+ self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
57
+ self.lock = threading.Lock()
58
+ # dict used to store session related variable
59
+ self.tts_speech_token_dict = {}
60
+ self.llm_end_dict = {}
61
+ self.mel_overlap_dict = {}
62
+ self.flow_cache_dict = {}
63
+ self.hift_cache_dict = {}
64
+
65
+ def load(self, llm_model, flow_model, hift_model):
66
+ self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
67
+ self.llm.to(self.device).eval()
68
+ self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
69
+ self.flow.to(self.device).eval()
70
+ # in case hift_model is a hifigan model
71
+ hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
72
+ self.hift.load_state_dict(hift_state_dict, strict=True)
73
+ self.hift.to(self.device).eval()
74
+
75
+ def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
76
+ llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
77
+ self.llm.text_encoder = llm_text_encoder
78
+ llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
79
+ self.llm.llm = llm_llm
80
+ flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
81
+ self.flow.encoder = flow_encoder
82
+
83
    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
        """Install a TensorRT engine as the flow decoder estimator.

        Builds the TRT plan from the ONNX export if the plan file does not
        exist yet, deserializes it, and replaces the torch estimator module
        with the engine's execution context. CUDA is required.

        Args:
            flow_decoder_estimator_model: path of the serialized TRT plan.
            flow_decoder_onnx_model: ONNX export used to build the plan.
            fp16: build the engine with fp16 precision.

        Raises:
            ValueError: if the plan file exists but is empty (aborted export).
        """
        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
        if not os.path.exists(flow_decoder_estimator_model):
            convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
        if os.path.getsize(flow_decoder_estimator_model) == 0:
            raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model))
        # drop the torch estimator before binding the TRT execution context
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(flow_decoder_estimator_model, 'rb') as f:
            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert self.flow.decoder.estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
96
+ def get_trt_kwargs(self):
97
+ min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
98
+ opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200)]
99
+ max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
100
+ input_names = ["x", "mask", "mu", "cond"]
101
+ return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
102
+
103
    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
        """Producer side of one synthesis request (runs in a worker thread).

        Streams speech tokens out of the llm and appends them to
        ``self.tts_speech_token_dict[uuid]``; sets ``self.llm_end_dict[uuid]``
        to True when generation is complete so the consumer in ``tts`` knows
        no more tokens will arrive.

        ``text`` is either a (1, T) token tensor, or a generator of token
        chunks (streaming text input — CosyVoice2 only).
        """
        with self.llm_context, torch.cuda.amp.autocast(self.fp16):
            if isinstance(text, Generator):
                assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
                for i in self.llm.inference_bistream(text=text,
                                                     prompt_text=prompt_text.to(self.device),
                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                     embedding=llm_embedding.to(self.device)):
                    self.tts_speech_token_dict[uuid].append(i)
            else:
                for i in self.llm.inference(text=text.to(self.device),
                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_text=prompt_text.to(self.device),
                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                            embedding=llm_embedding.to(self.device)):
                    self.tts_speech_token_dict[uuid].append(i)
        self.llm_end_dict[uuid] = True
125
    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
        """Convert speech tokens to a waveform chunk via flow (mel) + hift (audio).

        Maintains per-request (``uuid``) caches so consecutive streaming chunks
        join smoothly: a flow cache, a mel-overlap buffer cross-faded between
        chunks, and a hift cache (mel tail, source excitation tail, speech
        tail) cross-faded on the waveform.

        Args:
            token: (1, T) speech tokens for this chunk.
            prompt_token / prompt_feat / embedding: flow conditioning inputs.
            uuid: request id keying the streaming caches.
            finalize: True for the last chunk of a request; disables
                overlap/cache trimming so the full tail is emitted.
            speed: playback speed factor; only allowed in non-stream mode.

        Returns:
            The synthesized waveform chunk, shape (1, S).
        """
        with torch.cuda.amp.autocast(self.fp16):
            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
                                                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_token=prompt_token.to(self.device),
                                                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_feat=prompt_feat.to(self.device),
                                                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                                      embedding=embedding.to(self.device),
                                                                      flow_cache=self.flow_cache_dict[uuid])

        # mel overlap fade in out: cross-fade with the tail kept from the previous chunk
        if self.mel_overlap_dict[uuid].shape[2] != 0:
            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
        # append hift cache: prepend the cached mel tail so hift sees continuous input
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            # hold back the mel tail for cross-fading with the next chunk
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            # cache mel/source/speech tails for the next chunk, then trim them from the output
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                # time-stretch the mel to change speaking speed
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech
164
+
165
+ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
166
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32),
167
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
168
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
169
+ prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
170
+ # this_uuid is used to track variables related to this inference thread
171
+ this_uuid = str(uuid.uuid1())
172
+ with self.lock:
173
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
174
+ self.hift_cache_dict[this_uuid] = None
175
+ self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
176
+ self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
177
+ p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
178
+ p.start()
179
+ if stream is True:
180
+ token_hop_len = self.token_min_hop_len
181
+ while True:
182
+ time.sleep(0.1)
183
+ if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
184
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
185
+ .unsqueeze(dim=0)
186
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
187
+ prompt_token=flow_prompt_speech_token,
188
+ prompt_feat=prompt_speech_feat,
189
+ embedding=flow_embedding,
190
+ uuid=this_uuid,
191
+ finalize=False)
192
+ yield {'tts_speech': this_tts_speech.cpu()}
193
+ with self.lock:
194
+ self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
195
+ # increase token_hop_len for better speech quality
196
+ token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
197
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
198
+ break
199
+ p.join()
200
+ # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
201
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
202
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
203
+ prompt_token=flow_prompt_speech_token,
204
+ prompt_feat=prompt_speech_feat,
205
+ embedding=flow_embedding,
206
+ uuid=this_uuid,
207
+ finalize=True)
208
+ yield {'tts_speech': this_tts_speech.cpu()}
209
+ else:
210
+ # deal with all tokens
211
+ p.join()
212
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
213
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
214
+ prompt_token=flow_prompt_speech_token,
215
+ prompt_feat=prompt_speech_feat,
216
+ embedding=flow_embedding,
217
+ uuid=this_uuid,
218
+ finalize=True,
219
+ speed=speed)
220
+ yield {'tts_speech': this_tts_speech.cpu()}
221
+ with self.lock:
222
+ self.tts_speech_token_dict.pop(this_uuid)
223
+ self.llm_end_dict.pop(this_uuid)
224
+ self.mel_overlap_dict.pop(this_uuid)
225
+ self.hift_cache_dict.pop(this_uuid)
226
+ self.flow_cache_dict.pop(this_uuid)
227
+ torch.cuda.empty_cache()
228
+
229
+ def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
230
+ # this_uuid is used to track variables related to this inference thread
231
+ this_uuid = str(uuid.uuid1())
232
+ with self.lock:
233
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
234
+ self.hift_cache_dict[this_uuid] = None
235
+ self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
236
+ self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
237
+ if stream is True:
238
+ token_hop_len = self.token_min_hop_len
239
+ while True:
240
+ if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
241
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
242
+ .unsqueeze(dim=0)
243
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
244
+ prompt_token=flow_prompt_speech_token,
245
+ prompt_feat=prompt_speech_feat,
246
+ embedding=flow_embedding,
247
+ uuid=this_uuid,
248
+ finalize=False)
249
+ yield {'tts_speech': this_tts_speech.cpu()}
250
+ with self.lock:
251
+ self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
252
+ # increase token_hop_len for better speech quality
253
+ token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
254
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
255
+ break
256
+ # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
257
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
258
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
259
+ prompt_token=flow_prompt_speech_token,
260
+ prompt_feat=prompt_speech_feat,
261
+ embedding=flow_embedding,
262
+ uuid=this_uuid,
263
+ finalize=True)
264
+ yield {'tts_speech': this_tts_speech.cpu()}
265
+ else:
266
+ # deal with all tokens
267
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
268
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
269
+ prompt_token=flow_prompt_speech_token,
270
+ prompt_feat=prompt_speech_feat,
271
+ embedding=flow_embedding,
272
+ uuid=this_uuid,
273
+ finalize=True,
274
+ speed=speed)
275
+ yield {'tts_speech': this_tts_speech.cpu()}
276
+ with self.lock:
277
+ self.tts_speech_token_dict.pop(this_uuid)
278
+ self.llm_end_dict.pop(this_uuid)
279
+ self.mel_overlap_dict.pop(this_uuid)
280
+ self.hift_cache_dict.pop(this_uuid)
281
+ self.flow_cache_dict.pop(this_uuid)
282
+ torch.cuda.empty_cache()
283
+
284
+
285
class CosyVoice2Model(CosyVoiceModel):
    """CosyVoice2 inference model: LLM token generation -> cached flow decoder -> HiFT vocoder.

    Differs from CosyVoiceModel by replacing mel-overlap fading with explicit
    encoder/decoder KV caches (``use_flow_cache``) for chunk streaming.
    """

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool,
                 use_flow_cache: bool):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        self.use_flow_cache = use_flow_cache
        if self.fp16 is True:
            self.llm.half()
            self.flow.half()
        # stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
        self.token_hop_len = 25
        # -1 disables trimming (non-cache mode); otherwise keep one hop of decoder KV
        self.flow_decoder_required_cache_size = -1 if use_flow_cache is False else 1 * self.token_hop_len
        # hift cache
        self.mel_cache_len = 8
        self.source_cache_len = int(self.mel_cache_len * 480)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dicts used to store per-session variables, keyed by a uuid string
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.flow_cache_dict = {}
        self.hift_cache_dict = {}

    def init_flow_cache(self):
        """Return zero-initialized encoder/decoder caches for a fresh streaming session."""
        encoder_cache = {'offset': 0,
                         'pre_lookahead_layer_conv2_cache': torch.zeros(1, 512, 2).to(self.device),
                         'encoders_kv_cache': torch.zeros(6, 1, 8, 0, 64 * 2).to(self.device),
                         'upsample_offset': 0,
                         'upsample_conv_cache': torch.zeros(1, 512, 4).to(self.device),
                         'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)}
        decoder_cache = {'offset': 0,
                         'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device),
                         'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device),
                         'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device),
                         'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, 0, 512, 2).to(self.device),
                         'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device),
                         'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device),
                         'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)}
        if self.fp16 is True:
            # half-precision only applies to tensor entries; offsets stay ints
            encoder_cache = {k: v.half() if isinstance(v, torch.Tensor) else v for k, v in encoder_cache.items()}
            decoder_cache = {k: v.half() if isinstance(v, torch.Tensor) else v for k, v in decoder_cache.items()}
        return {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache}

    def trim_flow_cache(self, cache):
        """Trim decoder KV caches to the required streaming window; no-op when size <= 0."""
        keep = self.flow_decoder_required_cache_size
        if keep > 0:
            decoder = cache['decoder_cache']
            for key in ('down_blocks_kv_cache', 'mid_blocks_kv_cache', 'up_blocks_kv_cache'):
                decoder[key] = decoder[key][:, :, :, :, -keep:]
        return cache

    def load_jit(self, flow_encoder_model):
        """Replace the flow encoder with a TorchScript-compiled version."""
        self.flow.encoder = torch.jit.load(flow_encoder_model, map_location=self.device)

    def get_trt_kwargs(self):
        """Return TensorRT min/opt/max shapes and input names for the cached flow decoder."""
        assert self.use_flow_cache is True, "get_trt_kwargs is set for flow cache mode. If you want to use trt with use_flow_cache=False, please set higher max_shape"
        return {
            'min_shape': [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (1, 4, 2, 0, 512, 2), (12, 4, 2, 0, 512, 2), (1, 4, 2, 0, 512, 2)],
            'opt_shape': [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200), (1, 4, 2, 100, 512, 2), (12, 4, 2, 100, 512, 2), (1, 4, 2, 100, 512, 2)],
            'max_shape': [(2, 80, 1500), (2, 1, 1500), (2, 80, 1500), (2, 80, 1500), (1, 4, 2, 200, 512, 2), (12, 4, 2, 200, 512, 2), (1, 4, 2, 200, 512, 2)],
            'input_names': ["x", "mask", "mu", "cond", 'down_blocks_kv_cache', 'mid_blocks_kv_cache', 'up_blocks_kv_cache'],
        }

    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
        """Vocode one chunk of speech tokens for session ``uuid`` into a waveform tensor."""
        with torch.cuda.amp.autocast(self.fp16):
            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
                                                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_token=prompt_token.to(self.device),
                                                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_feat=prompt_feat.to(self.device),
                                                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                                      embedding=embedding.to(self.device),
                                                                      cache=self.flow_cache_dict[uuid],
                                                                      finalize=finalize)
        self.flow_cache_dict[uuid] = self.trim_flow_cache(self.flow_cache_dict[uuid])
        # prepend cached mel / source so the vocoder sees continuous context
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel = self.hift_cache_dict[uuid]['mel']
            hift_cache_source = self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        if finalize is False:
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            # stash the tail of this chunk for the next chunk's crossfade
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech

    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
        """Synthesize speech for ``text``, yielding dicts of the form {'tts_speech': Tensor}.

        Stream mode requires ``use_flow_cache=True`` and emits fixed-size hops of
        ``token_hop_len`` tokens (plus the flow's lookahead); non-stream mode
        requires ``use_flow_cache=False`` and vocodes everything at once.
        """
        session_id = str(uuid.uuid1())  # keys all per-session state for this call
        with self.lock:
            self.tts_speech_token_dict[session_id] = []
            self.llm_end_dict[session_id] = False
            self.hift_cache_dict[session_id] = None
            self.flow_cache_dict[session_id] = self.init_flow_cache()
        llm_thread = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, session_id))
        llm_thread.start()
        if stream is True:
            assert self.use_flow_cache is True, "set use_flow_cache=True if you want to use stream inference to avoid OOM"
            # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
            flow_prompt_speech_token = flow_prompt_speech_token[:, -self.flow_decoder_required_cache_size:]
            prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size * 2:]
            chunk_size = self.token_hop_len + self.flow.pre_lookahead_len
            while True:
                time.sleep(0.1)
                if len(self.tts_speech_token_dict[session_id]) >= chunk_size:
                    chunk = torch.tensor(self.tts_speech_token_dict[session_id][:chunk_size]).unsqueeze(dim=0)
                    speech_out = self.token2wav(token=chunk,
                                                prompt_token=flow_prompt_speech_token,
                                                prompt_feat=prompt_speech_feat,
                                                embedding=flow_embedding,
                                                uuid=session_id,
                                                finalize=False)
                    # NOTE in cache inference mode, the prompt is only fed with the first chunk
                    flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device)
                    prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)
                    yield {'tts_speech': speech_out.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[session_id] = self.tts_speech_token_dict[session_id][self.token_hop_len:]
                if self.llm_end_dict[session_id] is True and len(self.tts_speech_token_dict[session_id]) < chunk_size:
                    break
            llm_thread.join()
            # flush whatever tokens remain in a final chunk
            tail = torch.tensor(self.tts_speech_token_dict[session_id]).unsqueeze(dim=0)
            speech_out = self.token2wav(token=tail,
                                        prompt_token=flow_prompt_speech_token,
                                        prompt_feat=prompt_speech_feat,
                                        embedding=flow_embedding,
                                        uuid=session_id,
                                        finalize=True)
            yield {'tts_speech': speech_out.cpu()}
        else:
            assert self.use_flow_cache is False, "set use_flow_cache=False for nonstream inference"
            llm_thread.join()
            tail = torch.tensor(self.tts_speech_token_dict[session_id]).unsqueeze(dim=0)
            speech_out = self.token2wav(token=tail,
                                        prompt_token=flow_prompt_speech_token,
                                        prompt_feat=prompt_speech_feat,
                                        embedding=flow_embedding,
                                        uuid=session_id,
                                        finalize=True,
                                        speed=speed)
            yield {'tts_speech': speech_out.cpu()}
        with self.lock:
            for session_dict in (self.tts_speech_token_dict, self.llm_end_dict,
                                 self.hift_cache_dict, self.flow_cache_dict):
                session_dict.pop(session_id)
        torch.cuda.empty_cache()
model/cosyvoice/dataset/__init__.py ADDED
File without changes
model/cosyvoice/dataset/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (149 Bytes). View file
 
model/cosyvoice/dataset/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (165 Bytes). View file
 
model/cosyvoice/dataset/__pycache__/processor.cpython-310.pyc ADDED
Binary file (12.8 kB). View file
 
model/cosyvoice/dataset/__pycache__/processor.cpython-311.pyc ADDED
Binary file (22.8 kB). View file
 
model/cosyvoice/dataset/dataset.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import random
17
+ import json
18
+ import math
19
+ from functools import partial
20
+
21
+ import torch
22
+ import torch.distributed as dist
23
+ from torch.utils.data import IterableDataset
24
+ from cosyvoice.utils.file_utils import read_lists, read_json_lists
25
+
26
+
27
class Processor(IterableDataset):
    """A pipeline stage that lazily applies a callable ``f`` to an upstream iterable dataset."""

    def __init__(self, source, f, *args, **kw):
        assert callable(f)
        self.source = source
        self.f = f
        self.args = args
        self.kw = kw

    def set_epoch(self, epoch):
        # epoch bookkeeping is delegated all the way down to the root DataList
        self.source.set_epoch(epoch)

    def __iter__(self):
        """Return an iterator over the source dataset processed by ``self.f``."""
        assert self.source is not None
        assert callable(self.f)
        return self.f(iter(self.source), *self.args, **self.kw)

    def apply(self, f):
        """Chain another processing stage ``f`` on top of this one."""
        assert callable(f)
        return Processor(self, f, *self.args, **self.kw)
50
+
51
+
52
class DistributedSampler:
    """Tracks distributed rank and dataloader-worker identity and slices data accordingly."""

    def __init__(self, shuffle=True, partition=True):
        self.epoch = -1
        self.update()
        self.shuffle = shuffle
        self.partition = partition

    def update(self):
        """Refresh rank/world_size and worker id/count; return them as a dict."""
        assert dist.is_available()
        if dist.is_initialized():
            self.rank, self.world_size = dist.get_rank(), dist.get_world_size()
        else:
            self.rank, self.world_size = 0, 1
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id, self.num_workers = 0, 1
        else:
            self.worker_id, self.num_workers = worker_info.id, worker_info.num_workers
        return {'rank': self.rank,
                'world_size': self.world_size,
                'worker_id': self.worker_id,
                'num_workers': self.num_workers}

    def set_epoch(self, epoch):
        # epoch seeds the deterministic shuffle in sample()
        self.epoch = epoch

    def sample(self, data):
        """Sample index list according to rank/world_size/num_workers.

        Args:
            data(List): input data list

        Returns:
            List: index list after rank + worker partitioning
        """
        indices = list(range(len(data)))
        if self.partition:
            if self.shuffle:
                random.Random(self.epoch).shuffle(indices)
            # pad by repetition so every rank gets at least one element
            if len(indices) < self.world_size:
                indices = (indices * math.ceil(self.world_size / len(indices)))[:self.world_size]
            indices = indices[self.rank::self.world_size]
        # pad by repetition so every dataloader worker gets at least one element
        if len(indices) < self.num_workers:
            indices = (indices * math.ceil(self.num_workers / len(indices)))[:self.num_workers]
        return indices[self.worker_id::self.num_workers]
106
+
107
+
108
class DataList(IterableDataset):
    """Iterable over raw list entries, each enriched with sampler placement info."""

    def __init__(self, lists, shuffle=True, partition=True):
        self.lists = lists
        self.sampler = DistributedSampler(shuffle, partition)

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def __iter__(self):
        # refresh rank/worker info at iteration time (workers fork after __init__)
        placement = self.sampler.update()
        for index in self.sampler.sample(self.lists):
            yield dict(src=self.lists[index], **placement)
124
+
125
+
126
def Dataset(data_list_file,
            data_pipeline,
            mode='train',
            gan=False,
            shuffle=True,
            partition=True,
            tts_file='',
            prompt_utt2data=''):
    """Construct a streaming dataset from a shard list and a processing pipeline.

    There are two shuffle stages: a global shuffle at the shard (tar/raw file)
    level inside DataList, and a local shuffle at the sample level performed by
    the pipeline stages themselves.

    Args:
        data_list_file(str): file listing shard paths, one per line
        data_pipeline(List[callable]): processing stages chained via Processor
        mode(str): 'train' or 'inference'
        gan(bool): bind gan=True into the final (padding) stage
        partition(bool): whether to partition shards across ranks
    """
    assert mode in ['train', 'inference']
    lists = read_lists(data_list_file)
    tts_data = None
    if mode == 'inference':
        with open(tts_file) as f:
            tts_data = json.load(f)
        utt2lists = read_json_lists(prompt_utt2data)
        # keep only the shards actually referenced by the requested utterances
        lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
    dataset = DataList(lists, shuffle=shuffle, partition=partition)
    if mode == 'inference':
        # bind the tts text table into the opener stage
        data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
    if gan is True:
        # the padding stage behaves differently when feeding the GAN
        data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
    for func in data_pipeline:
        dataset = Processor(dataset, func, mode=mode)
    return dataset
model/cosyvoice/dataset/processor.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+
17
+ import pyarrow.parquet as pq
18
+ from io import BytesIO
19
+ import torch
20
+ import torchaudio
21
+ from torch.nn.utils.rnn import pad_sequence
22
+ import torch.nn.functional as F
23
+ import pyworld as pw
24
+
25
+
26
+ AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
27
+
28
+
29
def parquet_opener(data, mode='train', tts_data=None):
    """ Given url or local parquet file, yield one sample dict per row.
        Inplace operation.

        Args:
            data(Iterable[{src}]): dicts carrying a parquet url/path under 'src'
            mode(str): 'train' yields every row; 'inference' yields one item
                per tts text of each requested utterance
            tts_data(dict | None): utt -> list of tts texts (inference only)

        Returns:
            Iterable[{src, ...row columns...}]
    """
    # NOTE: was a mutable default `tts_data={}`; use the None sentinel instead
    if tts_data is None:
        tts_data = {}
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            for df in pq.ParquetFile(url).iter_batches(batch_size=64):
                df = df.to_pandas()
                for i in range(len(df)):
                    if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
                        continue
                    sample.update(dict(df.loc[i]))
                    if mode == 'train':
                        # NOTE do not return sample directly, must initialize a new dict
                        yield {**sample}
                    else:
                        for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
                            yield {**sample, 'tts_index': index, 'tts_text': text}
        except Exception as ex:
            # best-effort: a broken shard is logged and skipped, not fatal
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))
57
+
58
+
59
def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Decode audio and filter samples by feature and label length.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterance longer than max_length (10ms frames)
            min_length: drop utterance shorter than min_length (10ms frames)
            token_max_length: drop utterance with more text tokens
            token_min_length: drop utterance with fewer text tokens
            min_output_input_ratio: minimal token_length / feats_length(10ms)
            max_output_input_ratio: maximal token_length / feats_length(10ms)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        # decode raw bytes and downmix to mono
        sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
        del sample['audio_data']
        # 100 frames per second of audio
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        num_tokens = len(sample['text_token'])
        if not (min_length <= num_frames <= max_length):
            continue
        if not (token_min_length <= num_tokens <= token_max_length):
            continue
        if len(sample['speech_token']) == 0:
            continue
        if num_frames != 0:
            ratio = num_tokens / num_frames
            if ratio < min_output_input_ratio or ratio > max_output_input_ratio:
                continue
        yield sample
109
+
110
+
111
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """ Resample waveforms to ``resample_rate`` and peak-normalize when clipping.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target sample rate
            min_sample_rate: drop samples recorded below this rate

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        orig_rate = sample['sample_rate']
        if orig_rate != resample_rate:
            if orig_rate < min_sample_rate:
                # too low-quality to upsample; drop
                continue
            sample['sample_rate'] = resample_rate
            sample['speech'] = torchaudio.transforms.Resample(
                orig_freq=orig_rate, new_freq=resample_rate)(sample['speech'])
        # rescale only when the waveform exceeds [-1, 1]
        peak = sample['speech'].abs().max()
        if peak > 1:
            sample['speech'] /= peak
        yield sample
137
+
138
+
139
def truncate(data, truncate_length=24576, mode='train'):
    """ Truncate or zero-pad each waveform to exactly ``truncate_length`` samples.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            truncate_length: target length in samples

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        waveform = sample['speech']
        num_samples = waveform.shape[1]
        if num_samples > truncate_length:
            # random crop acts as light data augmentation
            start = random.randint(0, num_samples - truncate_length)
            waveform = waveform[:, start: start + truncate_length]
        else:
            # pad with zeros; use the actual channel count and dtype instead of
            # assuming mono float32 (the original hard-coded torch.zeros(1, ...))
            pad = torch.zeros(waveform.shape[0], truncate_length - num_samples,
                              dtype=waveform.dtype)
            waveform = torch.concat([waveform, pad], dim=1)
        sample['speech'] = waveform
        yield sample
158
+
159
+
160
def compute_fbank(data,
                  feat_extractor,
                  mode='train'):
    """ Extract fbank features.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            feat_extractor: callable mapping waveform -> (1, n_mels, T) features

        Returns:
            Iterable[{key, feat, label}] with 'speech_feat' of shape (T, n_mels)
    """
    for sample in data:
        for key in ('sample_rate', 'speech', 'utt', 'text_token'):
            assert key in sample
        # (1, n_mels, T) -> (T, n_mels)
        sample['speech_feat'] = feat_extractor(sample['speech']).squeeze(dim=0).transpose(0, 1)
        yield sample
180
+
181
+
182
def compute_f0(data, sample_rate, hop_size, mode='train'):
    """ Extract f0 (pitch) aligned to the mel feature length.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            sample_rate: waveform sample rate
            hop_size: hop size in samples used to derive the frame period

        Returns:
            Iterable[{key, feat, label}] with 'pitch_feat' added
    """
    frame_period = hop_size * 1000 / sample_rate
    for sample in data:
        for key in ('sample_rate', 'speech', 'utt', 'text_token'):
            assert key in sample
        audio = sample['speech'].squeeze(dim=0).numpy().astype('double')
        _f0, t = pw.harvest(audio, sample_rate, frame_period=frame_period)
        if sum(_f0 != 0) < 5:
            # harvest occasionally fails; fall back to dio
            _f0, t = pw.dio(audio, sample_rate, frame_period=frame_period)
        f0 = pw.stonemask(audio, _f0, t, sample_rate)
        # stretch f0 to match the mel frame count
        f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1),
                           size=sample['speech_feat'].shape[0], mode='linear').view(-1)
        sample['pitch_feat'] = f0
        yield sample
205
+
206
+
207
def parse_embedding(data, normalize, mode='train'):
    """ Convert utt_embedding/spk_embedding lists to float tensors, optionally L2-normalized.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            normalize(bool): apply F.normalize along dim 0

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        for key in ('utt_embedding', 'spk_embedding'):
            emb = torch.tensor(sample[key], dtype=torch.float32)
            if normalize:
                emb = F.normalize(emb, dim=0)
            sample[key] = emb
        yield sample
223
+
224
+
225
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """ Encode text to token ids with a lazily-constructed tokenizer.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, txt, sample_rate}]
            get_tokenizer: zero-arg factory returning the tokenizer
            allowed_special: passed through to tokenizer.encode

        Returns:
            Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    tok = get_tokenizer()  # built once, shared by every sample

    def encode(text):
        return tok.encode(text, allowed_special=allowed_special)

    for sample in data:
        assert 'text' in sample
        sample['text_token'] = encode(sample['text'])
        if mode == 'inference':
            sample['tts_text_token'] = encode(sample['tts_text'])
        yield sample
242
+
243
+
244
def shuffle(data, shuffle_size=10000, mode='train'):
    """ Locally shuffle the data with a bounded buffer.

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    pool = []
    for sample in data:
        pool.append(sample)
        if len(pool) >= shuffle_size:
            random.shuffle(pool)
            yield from pool
            pool = []
    # flush the leftover tail
    random.shuffle(pool)
    yield from pool
266
+
267
+
268
def sort(data, sort_size=500, mode='train'):
    """ Sort buffered samples by feature length.
        Sort is used after shuffle and before batch, so that utts with similar
        lengths land in the same batch; ``sort_size`` should stay below
        ``shuffle_size``.

        Args:
            data: Iterable[{key, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{key, feat, label}]
    """
    length_of = lambda x: x['speech_feat'].size(0)
    pool = []
    for sample in data:
        pool.append(sample)
        if len(pool) >= sort_size:
            yield from sorted(pool, key=length_of)
            pool = []
    # flush the leftover tail, also sorted
    yield from sorted(pool, key=length_of)
294
+
295
+
296
def static_batch(data, batch_size=16):
    """ Static batch the data by `batch_size`

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    current = []
    for sample in data:
        current.append(sample)
        if len(current) >= batch_size:
            yield current
            current = []
    # Emit the final, possibly smaller, batch.
    if current:
        yield current
314
+
315
+
316
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """ Dynamic batch the data until the total frames in batch
        reach `max_frames_in_batch`

        The batch cost is estimated as (longest utterance) * (batch size),
        i.e. the padded size of the batch.

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max_frames in one batch

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    current, longest = [], 0
    for sample in data:
        assert 'speech_feat' in sample
        assert isinstance(sample['speech_feat'], torch.Tensor)
        frames = sample['speech_feat'].size(0)
        longest = max(longest, frames)
        # Frames the batch would occupy after padding if we added this sample.
        if longest * (len(current) + 1) > max_frames_in_batch:
            yield current
            current = [sample]
            longest = frames
        else:
            current.append(sample)
    if len(current) > 0:
        yield current
343
+
344
+
345
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """ Wrapper for static/dynamic batch

        Args:
            data: Iterable of samples.
            batch_type: 'static' (fixed batch_size) or 'dynamic' (frame budget).
            batch_size: batch size used for static batching.
            max_frames_in_batch: padded-frame budget used for dynamic batching.
            mode: in 'inference' mode, always batch one sample at a time.

        Returns:
            Iterable[List[sample]]

        Raises:
            ValueError: if batch_type is neither 'static' nor 'dynamic'.
    """
    if mode == 'inference':
        return static_batch(data, 1)
    if batch_type == 'static':
        return static_batch(data, batch_size)
    if batch_type == 'dynamic':
        return dynamic_batch(data, max_frames_in_batch)
    # Previously this only logged and fell through, implicitly returning None;
    # the caller then failed later with an opaque TypeError. Fail fast instead.
    logging.fatal('Unsupported batch type {}'.format(batch_type))
    raise ValueError('Unsupported batch type {}'.format(batch_type))
357
+
358
+
359
def padding(data, use_spk_embedding, mode='train', gan=False):
    """ Padding the data into training data

        Collates a list of per-utterance dicts into one padded batch dict,
        sorted by mel-feature length (descending).

        Args:
            data: Iterable[List[{key, feat, label}]]
            use_spk_embedding: if True expose the speaker-level embedding as
                batch['embedding'], otherwise the utterance-level one.
            mode: 'train' or 'inference'; inference additionally collates the
                tts_text fields.
            gan: if True keep raw speech and add pitch features needed by the
                GAN training stage.

        Returns:
            Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        # NOTE(review): the sort key reads size(1) while every length below
        # reads size(0) — confirm the expected speech_feat layout.
        speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
                                       dtype=torch.int32)
        # Descending-length order so padding waste is minimized and models
        # that expect sorted batches (e.g. packed sequences) work.
        order = torch.argsort(speech_feat_len, descending=True)

        utts = [sample[i]['utt'] for i in order]
        speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
        speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
        speech = pad_sequence(speech, batch_first=True, padding_value=0)
        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
        speech_token = pad_sequence(speech_token,
                                    batch_first=True,
                                    padding_value=0)
        speech_feat = [sample[i]['speech_feat'] for i in order]
        # Re-bound on purpose: from here on speech_feat_len holds the
        # per-utterance frame counts in sorted order.
        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
        speech_feat = pad_sequence(speech_feat,
                                   batch_first=True,
                                   padding_value=0)
        text = [sample[i]['text'] for i in order]
        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
        batch = {
            "utts": utts,
            "speech": speech,
            "speech_len": speech_len,
            "speech_token": speech_token,
            "speech_token_len": speech_token_len,
            "speech_feat": speech_feat,
            "speech_feat_len": speech_feat_len,
            "text": text,
            "text_token": text_token,
            "text_token_len": text_token_len,
            "utt_embedding": utt_embedding,
            "spk_embedding": spk_embedding,
        }
        if gan is True:
            # in gan train, we need pitch_feat
            pitch_feat = [sample[i]['pitch_feat'] for i in order]
            pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
            pitch_feat = pad_sequence(pitch_feat,
                                      batch_first=True,
                                      padding_value=0)
            batch["pitch_feat"] = pitch_feat
            batch["pitch_feat_len"] = pitch_feat_len
        else:
            # only gan train needs speech, delete it to save memory
            del batch["speech"]
            del batch["speech_len"]
        if mode == 'inference':
            tts_text = [sample[i]['tts_text'] for i in order]
            tts_index = [sample[i]['tts_index'] for i in order]
            tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
            tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
            # -1 padding distinguishes pad positions from real token id 0.
            tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
            batch.update({'tts_text': tts_text,
                          'tts_index': tts_index,
                          'tts_text_token': tts_text_token,
                          'tts_text_token_len': tts_text_token_len})
        if use_spk_embedding is True:
            batch["embedding"] = batch["spk_embedding"]
        else:
            batch["embedding"] = batch["utt_embedding"]
        yield batch
model/cosyvoice/flow/__pycache__/decoder.cpython-310.pyc ADDED
Binary file (22.6 kB). View file
 
model/cosyvoice/flow/__pycache__/decoder.cpython-311.pyc ADDED
Binary file (48.9 kB). View file
 
model/cosyvoice/flow/__pycache__/flow.cpython-310.pyc ADDED
Binary file (7.68 kB). View file
 
model/cosyvoice/flow/__pycache__/flow.cpython-311.pyc ADDED
Binary file (17.1 kB). View file
 
model/cosyvoice/flow/__pycache__/flow_matching.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
model/cosyvoice/flow/__pycache__/flow_matching.cpython-311.pyc ADDED
Binary file (22.8 kB). View file
 
model/cosyvoice/flow/decoder.py ADDED
@@ -0,0 +1,902 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Tuple, Optional, Dict, Any
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from einops import pack, rearrange, repeat
19
+ from diffusers.models.attention_processor import Attention, AttnProcessor2_0, inspect, logger, deprecate
20
+ from cosyvoice.utils.common import mask_to_bias
21
+ from cosyvoice.utils.mask import add_optional_chunk_mask
22
+ from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
23
+ from matcha.models.components.transformer import BasicTransformerBlock, maybe_allow_in_graph
24
+
25
+
26
class Transpose(torch.nn.Module):
    """Module wrapper around torch.transpose so it can sit in nn.Sequential."""

    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fixed return annotation: a single tensor is returned, not a tuple.
        x = torch.transpose(x, self.dim0, self.dim1)
        return x
35
+
36
+
37
class CausalConv1d(torch.nn.Conv1d):
    """Conv1d with left-only (causal) padding and streaming-cache support.

    Output frame t never depends on inputs after t: the receptive field's
    left context comes either from zero padding (offline / first chunk) or
    from the `cache` tensor carrying the tail of the previous chunk.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding=0, dilation=dilation, groups=groups,
                         bias=bias, padding_mode=padding_mode,
                         device=device, dtype=dtype)
        # Streaming cache bookkeeping only supports stride-1 convolutions.
        assert stride == 1
        # Number of left-context frames a kernel of this size requires.
        self.causal_padding = kernel_size - 1

    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
        if cache.size(2) != 0:
            # Streaming: prepend the cached tail of the previous chunk.
            assert cache.size(2) == self.causal_padding
            x = torch.concat([cache, x], dim=2)
        else:
            # First chunk / offline: zero-pad on the left only.
            x = F.pad(x, (self.causal_padding, 0), value=0.0)
        new_cache = x[:, :, -self.causal_padding:]
        out = super().forward(x)
        return out, new_cache
69
+
70
+
71
class CausalBlock1D(Block1D):
    """Block1D variant whose convolution is causal and carries a stream cache."""

    def __init__(self, dim: int, dim_out: int):
        super(CausalBlock1D, self).__init__(dim, dim_out)
        # Rebuild the block with a causal conv; the rest of the stack
        # (LayerNorm over channels via Transpose, then Mish) is pointwise
        # in time, so only the conv needs a cache.
        self.block = torch.nn.Sequential(
            CausalConv1d(dim, dim_out, 3),
            Transpose(1, 2),
            nn.LayerNorm(dim_out),
            Transpose(1, 2),
            nn.Mish(),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
        conv = self.block[0]
        out, cache = conv(x * mask, cache)
        # Remaining layers are stateless; apply them in order.
        for layer in self.block[1:]:
            out = layer(out)
        return out * mask, cache
87
+
88
+
89
class CausalResnetBlock1D(ResnetBlock1D):
    """ResnetBlock1D built from causal blocks, with a stream cache per block."""

    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
        # Swap both sub-blocks for causal ones; mlp/res_conv come from the base.
        self.block1 = CausalBlock1D(dim, dim_out)
        self.block2 = CausalBlock1D(dim_out, dim_out)

    def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor,
                block1_cache: torch.Tensor = torch.zeros(0, 0, 0), block2_cache: torch.Tensor = torch.zeros(0, 0, 0)
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        hidden, block1_cache = self.block1(x, mask, block1_cache)
        # Inject the timestep embedding as a per-channel bias.
        hidden = hidden + self.mlp(time_emb).unsqueeze(-1)
        hidden, block2_cache = self.block2(hidden, mask, block2_cache)
        # Residual connection through a (masked) 1x1 projection.
        return hidden + self.res_conv(x * mask), block1_cache, block2_cache
103
+
104
+
105
class CausalAttnProcessor2_0(AttnProcessor2_0):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

    Extends the stock diffusers processor with a key/value `cache` argument
    for chunk-wise streaming: cached K/V from earlier chunks are prepended
    before attention, and this chunk's K/V are returned as the new cache.
    """

    def __init__(self):
        super(CausalAttnProcessor2_0, self).__init__()

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        # cache layout: (batch, cached_len, inner_dim, 2) — [..., 0] keys, [..., 1] values
        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        *args,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, torch.Tensor]:
        # `scale` was removed upstream in diffusers; warn callers still passing it.
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. \
                `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # 4-D (image-style) inputs are flattened to (batch, seq, channels) and
        # restored at the end.
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            # NOTE do not use attn.prepare_attention_mask as we have already provided the correct attention_mask
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, attn.heads, 1, 1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            # self-attention: keys/values come from the same sequence
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key_cache = attn.to_k(encoder_hidden_states)
        value_cache = attn.to_v(encoder_hidden_states)
        # NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2)
        if cache.size(0) != 0:
            key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
            value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
        else:
            key, value = key_cache, value_cache
        # The returned cache holds only this chunk's K/V; the caller decides
        # how much history to keep across chunks.
        cache = torch.stack([key_cache, value_cache], dim=3)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states, cache
200
+
201
+
202
@maybe_allow_in_graph
class CausalAttention(Attention):
    """diffusers Attention that always runs CausalAttnProcessor2_0 and
    threads a streaming KV `cache` through forward."""

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block: bool = False,
        processor: Optional["AttnProcessor2_0"] = None,
        out_dim: int = None,
    ):
        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax,
                                              cross_attention_norm, cross_attention_norm_num_groups, qk_norm, added_kv_proj_dim, norm_num_groups,
                                              spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection,
                                              _from_deprecated_attn_block, processor, out_dim)
        # NOTE(review): any `processor` passed by the caller is discarded here
        # and replaced with the causal one — presumably intentional; confirm.
        processor = CausalAttnProcessor2_0()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        **cross_attention_kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        The forward method of the `Attention` class.

        Args:
            hidden_states (`torch.Tensor`):
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):
                The attention mask to use. If `None`, no mask is applied.
            cache (`torch.Tensor`):
                Streaming key/value cache forwarded to the processor.
            **cross_attention_kwargs:
                Additional keyword arguments to pass along to the cross attention.

        Returns:
            `torch.Tensor`: The output of the attention layer, plus the
            updated cache.
        """
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty

        # Drop kwargs the processor's __call__ does not accept (mirrors the
        # upstream diffusers forward), warning about anything discarded.
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}

        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cache=cache,
            **cross_attention_kwargs,
        )
281
+
282
+
283
@maybe_allow_in_graph
class CausalBasicTransformerBlock(BasicTransformerBlock):
    """BasicTransformerBlock whose self-attention is a CausalAttention, so a
    streaming KV cache can be threaded through forward."""

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout,
                                                          cross_attention_dim, activation_fn, num_embeds_ada_norm,
                                                          attention_bias, only_cross_attention, double_self_attention,
                                                          upcast_attention, norm_elementwise_affine, norm_type, final_dropout)
        # Replace the stock self-attention (attn1) with the cache-aware one;
        # cross-attention (attn2), norms and FF come from the base class.
        self.attn1 = CausalAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            # ada-layer-norm-zero also yields the gates/shifts reused in
            # steps 1 and 3 below.
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        # Only the self-attention carries/returns the streaming cache.
        attn_output, cache = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
            cache=cache,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: \
                    {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.")

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states, cache
391
+
392
+
393
class ConditionalDecoder(nn.Module):
    """1-D U-Net used as the flow-matching estimator: down blocks, mid blocks
    and up blocks, each a ResNet + transformer stack, conditioned on the
    timestep (and optionally speaker embedding / extra condition)."""

    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
    ):
        """
        This decoder requires an input with the same shape of the target. So, if your text content
        is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # ---- down path: resnet + transformer stack + downsample per level
        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            # last level keeps the time resolution (plain conv instead of stride-2)
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        # ---- middle blocks at the bottleneck resolution
        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            # NOTE(review): this rebinds the local `out_channels` (the ctor
            # argument) and is never read afterwards — looks vestigial; the
            # resnet below uses `output_channel` from the down loop. Confirm.
            out_channels = channels[-1]
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        # ---- up path: mirrored channels; input doubled by skip concatenation
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def initialize_weights(self):
        """Kaiming init for conv/linear weights, unit/zero init for GroupNorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (_type_): shape (batch_size, 1, time)
            mu (_type_): conditioning features, packed with x along channels
            t (_type_): shape (batch_size)
            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (_type_, optional): placeholder for future use. Defaults to None.
            streaming (bool): unused in this (non-causal) decoder.

        Raises:
            ValueError: _description_
            ValueError: _description_

        Returns:
            _type_: _description_
        """

        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)

        # Concatenate conditioning along the channel axis.
        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            # Broadcast the speaker embedding over time before concatenating.
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # Full (non-chunked) attention mask over the valid frames.
            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            # Time resolution halves after each downsample; stride the mask too.
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            # Trim to the skip length (upsampling can overshoot by one frame).
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
597
+
598
+
599
class CausalConditionalDecoder(ConditionalDecoder):
    """Causal (streamable) variant of ``ConditionalDecoder``.

    Same 1D U-Net estimator topology — down / mid / up stages, each a
    (ResNet block, transformer blocks, resample) triple — but rebuilt from
    causal submodules (``CausalResnetBlock1D``, ``CausalBasicTransformerBlock``,
    ``CausalConv1d``) so inference can run chunk-by-chunk with conv/KV caches
    via :meth:`forward_chunk`.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        static_chunk_size=50,
        num_decoding_left_chunks=2,
    ):
        """
        This decoder requires an input with the same shape as the target. So, if your text content
        is shorter or longer than the outputs, please re-sample it before feeding it to the decoder.

        Args:
            in_channels: number of input feature channels (after packing x/mu/spks/cond).
            out_channels: number of output (mel) channels.
            channels: per-down-stage channel widths; also determines up-stage widths.
            dropout: dropout rate inside the transformer blocks.
            attention_head_dim: per-head dimension of the attention layers.
            n_blocks: transformer blocks per stage.
            num_mid_blocks: number of (resnet, transformer) pairs in the middle.
            num_heads: attention heads per transformer block.
            act_fn: activation name for the transformer feed-forward.
            static_chunk_size: chunk length (frames) used for the streaming attention mask.
            num_decoding_left_chunks: how many past chunks each chunk may attend to.
        """
        # NOTE: deliberately bypasses ConditionalDecoder.__init__ and rebuilds
        # every submodule with its causal counterpart.
        torch.nn.Module.__init__(self)
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.static_chunk_size = static_chunk_size
        self.num_decoding_left_chunks = num_decoding_left_chunks
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # Down path: halve time resolution at every stage except the last
        # (the last uses a stride-1 causal conv instead of a downsample).
        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    CausalBasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        # Middle stages: constant width channels[-1].
        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            # NOTE(review): this assignment shadows the `out_channels` parameter
            # and is never read (the resnet below uses `output_channel`, which
            # already equals channels[-1] after the down loop). Kept as-is to
            # match the upstream implementation.
            out_channels = channels[-1]
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    CausalBasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        # Up path mirrors the down path in reverse; input width is doubled by
        # the skip connection concatenated in forward().
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = CausalResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    CausalBasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else CausalConv1d(output_channel, output_channel, 3)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = CausalBlock1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (_type_): shape (batch_size, 1, time)
            mu (torch.Tensor): conditioning from the encoder, shape (batch_size, in_channels, time)
            t (_type_): shape (batch_size)
            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (_type_, optional): placeholder for future use. Defaults to None.
            streaming (bool): when True, use a chunked causal attention mask
                (static_chunk_size / num_decoding_left_chunks) instead of the
                full bidirectional mask.

        Returns:
            torch.Tensor: estimated vector field, shape (batch_size, out_channels, time),
            zeroed outside `mask`.
        """

        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)

        # Concatenate inputs along the channel axis: x ++ mu (++ spks ++ cond).
        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            # Broadcast the per-utterance speaker embedding over time.
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            # Causal resnet returns (out, conv_cache1, conv_cache2); caches are
            # unused in full (non-chunked) mode.
            x, _, _ = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                # Chunked causal mask: each frame attends within its chunk plus
                # num_decoding_left_chunks past chunks.
                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
            else:
                # Full attention over valid frames.
                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x, _ = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x, _ = downsample(x * mask_down)
            # Downsampling halves time; subsample the mask to match.
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x, _, _ = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x, _ = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            # Trim to the skip's length (upsampling can overshoot by a frame),
            # then concatenate the skip connection on the channel axis.
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x, _, _ = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x, _ = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x, _ = upsample(x * mask_up)
        x, _ = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask

    def forward_chunk(self, x, mask, mu, t, spks=None, cond=None,
                      down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
                      down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
                      mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
                      mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
                      up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
                      up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
                      final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0)
                      ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Chunked forward pass for streaming inference.

        Like :meth:`forward`, but each stage consumes and refreshes per-block
        convolution caches and attention key/value caches so successive chunks
        are equivalent to one long causal pass.

        NOTE(review): the zero-sized tensor defaults are module-level mutable
        defaults (created once at import). They only serve as empty-cache
        placeholders; real caches are expected from the caller.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, chunk_time)
            mask (_type_): shape (batch_size, 1, chunk_time)
            mu (torch.Tensor): conditioning, shape (batch_size, in_channels, chunk_time)
            t (_type_): shape (batch_size)
            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (_type_, optional): placeholder for future use. Defaults to None.
            *_conv_cache / *_kv_cache: per-stage caches from the previous chunk.

        Returns:
            Tuple of (output, updated conv caches in place, and newly built kv caches)
            in the same order as the cache arguments.
        """

        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]

        # Fresh KV caches for this chunk. Shapes (n_blocks_in_stage, n_layers,
        # k/v, time, dim, 2) are hard-coded for the released model
        # configuration — assumes 1 down/up block, 12 mid blocks, 4 transformer
        # layers each, head_dim*heads = 512. TODO(review): confirm against config.
        down_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
        mid_blocks_kv_cache_new = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x.device)
        up_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
        for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks):
            mask_down = masks[-1]
            # Conv caches for one down stage live in a single tensor, split by
            # channel offsets: [:320] and [320:576] are the resnet's two convs,
            # [576:] the downsample conv. Offsets presume the default channel
            # widths — TODO(review): confirm.
            x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = \
                resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576])
            x = rearrange(x, "b c t -> b t c").contiguous()
            # Attend over the cached past plus the current chunk (all valid).
            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + down_blocks_kv_cache.size(3), device=x.device).bool()
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for i, transformer_block in enumerate(transformer_blocks):
                x, down_blocks_kv_cache_new[index, i] = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                    cache=down_blocks_kv_cache[index, i],
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x, down_blocks_conv_cache[index][:, 576:] = downsample(x * mask_down, down_blocks_conv_cache[index][:, 576:])
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]

        for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks):
            # Mid-stage conv cache split: [:256] / [256:] for the resnet's two convs.
            x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = \
                resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:])
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + mid_blocks_kv_cache.size(3), device=x.device).bool()
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for i, transformer_block in enumerate(transformer_blocks):
                x, mid_blocks_kv_cache_new[index, i] = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                    cache=mid_blocks_kv_cache[index, i]
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for index, (resnet, transformer_blocks, upsample) in enumerate(self.up_blocks):
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            # Up-stage conv cache split: [:512] / [512:768] resnet convs,
            # [768:] upsample conv.
            x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = \
                resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768])
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + up_blocks_kv_cache.size(3), device=x.device).bool()
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for i, transformer_block in enumerate(transformer_blocks):
                x, up_blocks_kv_cache_new[index, i] = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                    cache=up_blocks_kv_cache[index, i]
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:])
        x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache)
        output = self.final_proj(x * mask_up)
        return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, \
            up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache
model/cosyvoice/flow/flow.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+ from typing import Dict, Optional
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import functional as F
20
+ from omegaconf import DictConfig
21
+ from cosyvoice.utils.mask import make_pad_mask
22
+
23
+
24
class MaskedDiffWithXvec(torch.nn.Module):
    """Flow-matching acoustic model: speech tokens + x-vector -> mel spectrogram.

    Pipeline: embed speech tokens, encode, project to mel width, length-regulate
    to the mel frame rate, then train/sample a conditional flow-matching decoder
    conditioned on the speaker embedding and a (partially revealed) mel prompt.
    """

    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 encoder: torch.nn.Module = None,
                 length_regulator: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
        """
        Args:
            input_size: token embedding / encoder input width.
            output_size: mel channel count.
            spk_embed_dim: raw x-vector dimension (projected to output_size).
            output_type: output feature type label (informational).
            vocab_size: speech-token vocabulary size.
            input_frame_rate: speech-token rate in tokens/second.
            only_mask_loss: compute loss only on valid (unmasked) frames.
            encoder / length_regulator / decoder: injected submodules.
            decoder_conf / mel_feat_conf: stored configuration dicts.
                NOTE(review): dict/DictConfig defaults are mutable defaults —
                shared across instances; safe only if never mutated.
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        # Projects the normalized x-vector down to the mel width.
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.length_regulator = length_regulator
        self.only_mask_loss = only_mask_loss

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Training step: returns {'loss': flow-matching loss} for a batch with
        keys speech_token(_len), speech_feat(_len), embedding."""
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        # clamp guards against negative padding token ids in the embedding lookup
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, feat_len)

        # get conditions: with prob 0.5 reveal a random prefix (up to 30% of the
        # utterance) of the target mel as a prompt; otherwise condition on zeros.
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(feat_len)).to(h)
        # NOTE this is unnecessary, feat/h already same shape
        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  flow_cache):
        """Generate mel frames for `token`, conditioned on a prompt token/mel
        pair and a speaker embedding.

        Returns:
            (mel, flow_cache): mel of shape (1, output_size, mel_len2) for the
            non-prompt part only, plus the updated decoder flow cache.
        """
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat speech token and prompt speech token
        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        # mel_len2: tokens -> seconds -> mel frames (22050 Hz, hop 256).
        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)

        # get conditions: the prompt mel is fixed, the rest is zeros.
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, flow_cache = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            prompt_len=mel_len1,
            cache=flow_cache
        )
        # Strip the prompt region; return only newly generated frames.
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), flow_cache
150
+
151
+
152
class CausalMaskedDiffWithXvec(torch.nn.Module):
    """Streaming-capable variant of ``MaskedDiffWithXvec``.

    Uses a causal encoder (chunked attention) and a causal flow-matching
    decoder; instead of a length regulator, the encoder upsamples tokens by a
    fixed ``token_mel_ratio``, and ``inference`` threads encoder/decoder caches
    through a dict so audio can be produced chunk by chunk.
    """

    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 token_mel_ratio: int = 2,
                 pre_lookahead_len: int = 3,
                 encoder: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
        """
        Args:
            token_mel_ratio: mel frames produced per speech token by the encoder upsampler.
            pre_lookahead_len: number of future tokens the encoder may peek at
                in streaming mode (withheld as `context` until finalize).
            Remaining args as in MaskedDiffWithXvec.
                NOTE(review): dict/DictConfig defaults are mutable defaults —
                shared across instances; safe only if never mutated.
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.only_mask_loss = only_mask_loss
        self.token_mel_ratio = token_mel_ratio
        self.pre_lookahead_len = pre_lookahead_len

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Training step: returns {'loss': flow-matching loss}; randomly trains
        in streaming (chunked-mask) or non-streaming mode per batch."""
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # NOTE unified training, static_chunk_size > 0 or = 0
        streaming = True if random.random() < 0.5 else False

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode (encoder upsamples by token_mel_ratio internally)
        h, h_lengths = self.encoder(token, token_len, streaming=streaming)
        h = self.encoder_proj(h)

        # get conditions: resample target mel to the encoder's output length,
        # then with prob 0.5 reveal a random prefix (up to 30%) as a prompt.
        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds,
            streaming=streaming,
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  cache,
                  finalize):
        """Generate one chunk of mel frames.

        Args:
            cache: dict with 'encoder_cache' (kwargs for encoder.forward_chunk)
                and 'decoder_cache'; mutated/replaced in place each call.
            finalize: True on the last chunk — consume all tokens instead of
                withholding the pre-lookahead context.

        Returns:
            (mel, cache): mel shape (1, output_size, mel_len2) for the
            non-prompt part, plus the updated cache dict.
        """
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        if finalize is True:
            h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, **cache['encoder_cache'])
        else:
            # Hold back the last pre_lookahead_len tokens as right-context;
            # they will be encoded in a later chunk.
            token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
            h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, context=context, **cache['encoder_cache'])
        # Unpack the positional cache tuple back into named kwargs for the next call.
        cache['encoder_cache']['offset'] = encoder_cache[0]
        cache['encoder_cache']['pre_lookahead_layer_conv2_cache'] = encoder_cache[1]
        cache['encoder_cache']['encoders_kv_cache'] = encoder_cache[2]
        cache['encoder_cache']['upsample_offset'] = encoder_cache[3]
        cache['encoder_cache']['upsample_conv_cache'] = encoder_cache[4]
        cache['encoder_cache']['upsample_kv_cache'] = encoder_cache[5]
        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
        h = self.encoder_proj(h)

        # get conditions: prompt mel fixed, remainder zeros.
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, cache['decoder_cache'] = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            cache=cache['decoder_cache']
        )
        # Strip the prompt region; return only newly generated frames.
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), cache
model/cosyvoice/flow/flow_matching.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import threading
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from matcha.models.components.flow_matching import BASECFM
18
+
19
+
20
class ConditionalCFM(BASECFM):
    """Conditional flow matching (OT-CFM) wrapper around an estimator network.

    Training: regress the estimator onto the straight-line conditional vector
    field between noise and target (``compute_loss``). Inference: integrate
    that field with a fixed-step Euler solver plus classifier-free guidance
    (``forward`` / ``solve_euler``). The estimator may be a ``torch.nn.Module``
    or a TensorRT engine (see ``forward_estimator``).
    """

    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
        """
        Args:
            in_channels: feature channels of the flow (mel bins).
            cfm_params: config with sigma_min, t_scheduler, training_cfg_rate,
                inference_cfg_rate, etc.
            n_spks / spk_emb_dim: speaker conditioning setup (BASECFM args).
            estimator: the vector-field network (or TRT engine set later).
        """
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )
        self.t_scheduler = cfm_params.t_scheduler
        self.training_cfg_rate = cfm_params.training_cfg_rate
        self.inference_cfg_rate = cfm_params.inference_cfg_rate
        in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
        # Just change the architecture of the estimator here
        self.estimator = estimator
        # Serializes access to the (stateful) TRT engine across threads.
        self.lock = threading.Lock()

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
            prompt_len (int): number of leading prompt frames to pin in the cache.
            cache (torch.Tensor): (1, n_feats, cached_len, 2) stacked (z, mu)
                from the previous chunk, used to keep noise/conditioning
                consistent across streaming chunks.
                NOTE(review): the zero-length tensor default is a module-level
                mutable default (created once at import); it is only read, so
                this is safe as long as callers never mutate it.

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """

        z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
        cache_size = cache.shape[2]
        # fix prompt and overlap part mu and z
        if cache_size != 0:
            # Replay the cached noise/conditioning so overlapping frames are
            # re-generated identically to the previous chunk.
            z[:, :, :cache_size] = cache[:, :, :, 0]
            mu[:, :, :cache_size] = cache[:, :, :, 1]
        # New cache: the prompt region plus the last 34 frames of overlap.
        # NOTE(review): 34 presumably matches the chunk-overlap length used by
        # the caller — confirm before changing.
        z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
        mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
        cache = torch.stack([z_cache, mu_cache], dim=-1)

        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            # Cosine warping concentrates steps near t=0.
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(dim=0)

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        # Do not use concat, it may cause memory format changed and trt infer with wrong results!
        # Batch of 2 for classifier-free guidance: row 0 conditional, row 1
        # unconditional (mu/spks/cond left zero).
        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
        spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        for step in range(1, len(t_span)):
            # Classifier-Free Guidance inference introduced in VoiceBox
            x_in[:] = x
            mask_in[:] = mask
            mu_in[0] = mu
            t_in[:] = t.unsqueeze(0)
            spks_in[0] = spks
            cond_in[0] = cond
            dphi_dt = self.forward_estimator(
                x_in, mask_in,
                mu_in, t_in,
                spks_in,
                cond_in
            )
            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
            # Guided field: (1+w)*cond - w*uncond.
            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1].float()

    def forward_estimator(self, x, mask, mu, t, spks, cond):
        """Dispatch one estimator call to either the PyTorch module or a
        TensorRT engine (writing the result back into `x` in place)."""
        if isinstance(self.estimator, torch.nn.Module):
            return self.estimator(x, mask, mu, t, spks, cond)
        else:
            with self.lock:
                self.estimator.set_input_shape('x', (2, 80, x.size(2)))
                self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
                self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
                self.estimator.set_input_shape('t', (2,))
                self.estimator.set_input_shape('spks', (2, 80))
                self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
                # run trt engine; output buffer aliases x (last pointer).
                assert self.estimator.execute_v2([x.contiguous().data_ptr(),
                                                  mask.contiguous().data_ptr(),
                                                  mu.contiguous().data_ptr(),
                                                  t.contiguous().data_ptr(),
                                                  spks.contiguous().data_ptr(),
                                                  cond.contiguous().data_ptr(),
                                                  x.data_ptr()]) is True
            return x

    def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            streaming (bool): passed through to the estimator to select the
                chunked attention mask.

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t = 1 - torch.cos(t * 0.5 * torch.pi)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        # OT-CFM straight path: y = (1-(1-sigma_min)t)z + t*x1, target field u.
        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        # during training, we randomly drop condition to trade off mode coverage and sample fidelity
        if self.training_cfg_rate > 0:
            cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
            mu = mu * cfg_mask.view(-1, 1, 1)
            spks = spks * cfg_mask.view(-1, 1)
            cond = cond * cfg_mask.view(-1, 1, 1)

        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
        # Masked MSE, normalized by valid frames times channel count.
        loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
        return loss, y
185
+
186
+
187
class CausalConditionalCFM(ConditionalCFM):
    """Causal (chunk-streaming) conditional flow matching decoder.

    Differs from ConditionalCFM in that it draws noise from a fixed,
    pre-sampled buffer sliced by a running frame ``offset`` (so successive
    chunks see a consistent noise sequence) and threads per-step estimator
    caches (conv caches and attention KV caches) through ``solve_euler``.
    """

    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
        super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
        # Fixed noise buffer: 80 mel bins x (50 * 300) frames, sampled once at
        # construction so every forward call reuses the same noise sequence.
        self.rand_noise = torch.randn([1, 80, 50 * 300])

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, cache={}):
        """Forward diffusion for one streaming chunk.

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
            cache (dict): streaming state; must contain 'offset' plus the
                per-step conv/KV caches consumed by solve_euler.
                NOTE(review): mutable default ``cache={}`` is shared across
                calls — callers appear to always pass an explicit cache; confirm.

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """

        # Slice the fixed noise buffer at the chunk's running offset so the
        # noise for frame k is identical regardless of chunking.
        offset = cache.pop('offset')
        z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature
        z = z[:, :, offset:]
        offset += mu.size(2)
        # fix prompt and overlap part mu and z
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        mel, cache = self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, cache=cache)
        cache['offset'] = offset
        return mel, cache

    def solve_euler(self, x, t_span, mu, mask, spks, cond, cache):
        """
        Fixed euler solver for ODEs, threading estimator caches per step.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
            cache (dict): per-step conv caches (updated in place) and KV
                caches (extended along the time axis on return).
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(dim=0)

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        # estimator cache for each step
        # Leading dim 10 presumably bounds the number of solver steps; the
        # remaining dims match the estimator's KV layout — TODO confirm.
        down_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x)
        mid_blocks_kv_cache_new = torch.zeros(10, 12, 4, 2, x.size(2), 512, 2).to(x)
        up_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x)

        # Do not use concat, it may cause memory format changed and trt infer with wrong results!
        # Pre-allocated batch-2 buffers: row 0 is the conditional pass, row 1
        # the unconditional (mu/spks/cond rows left at zero) for CFG.
        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
        spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        for step in range(1, len(t_span)):
            # Classifier-Free Guidance inference introduced in VoiceBox
            x_in[:] = x
            mask_in[:] = mask
            mu_in[0] = mu
            t_in[:] = t.unsqueeze(0)
            spks_in[0] = spks
            cond_in[0] = cond
            # Select this step's slice from every cached tensor.
            cache_step = {k: v[step - 1] for k, v in cache.items()}
            dphi_dt, cache_step = self.forward_estimator(
                x_in, mask_in,
                mu_in, t_in,
                spks_in,
                cond_in,
                cache_step
            )
            # Write conv caches back in place; collect new KV slices to append
            # to the running KV caches after the loop.
            cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
            down_blocks_kv_cache_new[step - 1] = cache_step[1]
            cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
            mid_blocks_kv_cache_new[step - 1] = cache_step[3]
            cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
            up_blocks_kv_cache_new[step - 1] = cache_step[5]
            cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
            # CFG: combine conditional and unconditional velocity estimates.
            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
        # Extend KV caches along the time axis (dim 4) with this chunk's keys/values.
        cache['down_blocks_kv_cache'] = torch.concat([cache['down_blocks_kv_cache'], down_blocks_kv_cache_new], dim=4)
        cache['mid_blocks_kv_cache'] = torch.concat([cache['mid_blocks_kv_cache'], mid_blocks_kv_cache_new], dim=4)
        cache['up_blocks_kv_cache'] = torch.concat([cache['up_blocks_kv_cache'], up_blocks_kv_cache_new], dim=4)
        return sol[-1].float(), cache

    def forward_estimator(self, x, mask, mu, t, spks, cond, cache):
        """Run one estimator step, via PyTorch (`forward_chunk`) or a TensorRT engine.

        Returns (dphi_dt, cache) where cache is the 7-tuple
        (down_conv, down_kv, mid_conv, mid_kv, up_conv, up_kv, final_conv).
        """
        if isinstance(self.estimator, torch.nn.Module):
            x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache)
            cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7)
        else:
            # TensorRT path: engine context is not thread-safe, hence the lock.
            with self.lock:
                self.estimator.set_input_shape('x', (2, 80, x.size(2)))
                self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
                self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
                self.estimator.set_input_shape('t', (2,))
                self.estimator.set_input_shape('spks', (2, 80))
                self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
                self.estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape)
                self.estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape)
                self.estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape)
                self.estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape)
                self.estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape)
                self.estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape)
                self.estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape)
                # run trt engine
                # Output KV buffers are freshly allocated; conv caches are
                # overwritten in place (their data_ptr is passed as output too).
                down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
                mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x)
                up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
                assert self.estimator.execute_v2([x.contiguous().data_ptr(),
                                                  mask.contiguous().data_ptr(),
                                                  mu.contiguous().data_ptr(),
                                                  t.contiguous().data_ptr(),
                                                  spks.contiguous().data_ptr(),
                                                  cond.contiguous().data_ptr(),
                                                  cache['down_blocks_conv_cache'].contiguous().data_ptr(),
                                                  cache['down_blocks_kv_cache'].contiguous().data_ptr(),
                                                  cache['mid_blocks_conv_cache'].contiguous().data_ptr(),
                                                  cache['mid_blocks_kv_cache'].contiguous().data_ptr(),
                                                  cache['up_blocks_conv_cache'].contiguous().data_ptr(),
                                                  cache['up_blocks_kv_cache'].contiguous().data_ptr(),
                                                  cache['final_blocks_conv_cache'].contiguous().data_ptr(),
                                                  x.data_ptr(),
                                                  cache['down_blocks_conv_cache'].data_ptr(),
                                                  down_blocks_kv_cache_out.data_ptr(),
                                                  cache['mid_blocks_conv_cache'].data_ptr(),
                                                  mid_blocks_kv_cache_out.data_ptr(),
                                                  cache['up_blocks_conv_cache'].data_ptr(),
                                                  up_blocks_kv_cache_out.data_ptr(),
                                                  cache['final_blocks_conv_cache'].data_ptr()]) is True
                cache = (cache['down_blocks_conv_cache'],
                         down_blocks_kv_cache_out,
                         cache['mid_blocks_conv_cache'],
                         mid_blocks_kv_cache_out,
                         cache['up_blocks_conv_cache'],
                         up_blocks_kv_cache_out,
                         cache['final_blocks_conv_cache'])
        return x, cache
model/cosyvoice/flow/length_regulator.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Tuple
15
+ import torch.nn as nn
16
+ import torch
17
+ from torch.nn import functional as F
18
+ from cosyvoice.utils.mask import make_pad_mask
19
+
20
+
21
class InterpolateRegulator(nn.Module):
    """Length regulator: linearly interpolates token features to mel length,
    then refines them with a small Conv1d/GroupNorm/Mish stack."""

    def __init__(
        self,
        channels: int,
        sampling_ratios: Tuple,
        out_channels: int = None,
        groups: int = 1,
    ):
        """
        Args:
            channels: feature dimension of the input tokens.
            sampling_ratios: one conv/norm/act stage is created per entry;
                an empty tuple yields only the final 1x1 projection.
            out_channels: output feature dimension (defaults to ``channels``).
            groups: GroupNorm group count.
        """
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels
        model = nn.ModuleList([])
        if len(sampling_ratios) > 0:
            # Note: only the *number* of entries matters here; the ratio
            # values themselves are not used by these layers.
            for _ in sampling_ratios:
                module = nn.Conv1d(channels, channels, 3, 1, 1)
                norm = nn.GroupNorm(groups, channels)
                act = nn.Mish()
                model.extend([module, norm, act])
        # Final 1x1 projection to the output dimension.
        model.append(
            nn.Conv1d(channels, out_channels, 1, 1)
        )
        self.model = nn.Sequential(*model)

    def forward(self, x, ylens=None):
        """Stretch x (B, T, D) to the max target length and zero out padding.

        Returns (out, olens) with out of shape (B, max(ylens), out_channels).
        """
        # x in (B, T, D)
        mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
        x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
        out = self.model(x).transpose(1, 2).contiguous()
        olens = ylens
        return out * mask, olens

    def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
        """Interpolate prompt tokens (x1) and generated tokens (x2) separately.

        Head/mid/tail of x2 are interpolated independently so the mel frames
        at the 20-token overlap boundaries line up exactly with chunked
        synthesis. The factor 22050 / 256 presumably converts tokens to mel
        frames (sample rate / hop length) — TODO confirm against the config.
        """
        # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
        # NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py
        # x in (B, T, D)
        if x2.shape[1] > 40:
            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
                                   mode='linear')
            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
            x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
        else:
            # Too short to split: interpolate the whole segment at once.
            x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
        if x1.shape[1] != 0:
            x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
            x = torch.concat([x1, x2], dim=2)
        else:
            # No prompt tokens: use the generated part alone.
            x = x2
        out = self.model(x).transpose(1, 2).contiguous()
        return out, mel_len1 + mel_len2
model/cosyvoice/hifigan/__pycache__/discriminator.cpython-310.pyc ADDED
Binary file (8.74 kB). View file
 
model/cosyvoice/hifigan/__pycache__/discriminator.cpython-311.pyc ADDED
Binary file (16.1 kB). View file
 
model/cosyvoice/hifigan/__pycache__/f0_predictor.cpython-310.pyc ADDED
Binary file (1.44 kB). View file
 
model/cosyvoice/hifigan/__pycache__/f0_predictor.cpython-311.pyc ADDED
Binary file (2.82 kB). View file
 
model/cosyvoice/hifigan/__pycache__/generator.cpython-310.pyc ADDED
Binary file (11.8 kB). View file
 
model/cosyvoice/hifigan/__pycache__/generator.cpython-311.pyc ADDED
Binary file (23.4 kB). View file
 
model/cosyvoice/hifigan/__pycache__/hifigan.cpython-310.pyc ADDED
Binary file (2.58 kB). View file
 
model/cosyvoice/hifigan/__pycache__/hifigan.cpython-311.pyc ADDED
Binary file (4.69 kB). View file
 
model/cosyvoice/hifigan/discriminator.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ try:
5
+ from torch.nn.utils.parametrizations import weight_norm, spectral_norm
6
+ except ImportError:
7
+ from torch.nn.utils import weight_norm, spectral_norm
8
+ from typing import List, Optional, Tuple
9
+ from einops import rearrange
10
+ from torchaudio.transforms import Spectrogram
11
+
12
+ LRELU_SLOPE = 0.1
13
+
14
+
15
class MultipleDiscriminator(nn.Module):
    """Combines a multi-period (``mpd``) and a multi-resolution (``mrd``)
    discriminator and concatenates their score/feature-map lists."""

    def __init__(
        self, mpd: nn.Module, mrd: nn.Module
    ):
        super().__init__()
        self.mpd = mpd
        self.mrd = mrd

    def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
        """Return concatenated (real scores, fake scores, real fmaps, fake fmaps)."""
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        # mpd expects an explicit channel dimension; mrd consumes the raw waveform.
        results = [
            self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1)),
            self.mrd(y, y_hat),
        ]
        for real_scores, fake_scores, real_fmaps, fake_fmaps in results:
            y_d_rs += real_scores
            y_d_gs += fake_scores
            fmap_rs += real_fmaps
            fmap_gs += fake_fmaps
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
36
+
37
+
38
class MultiResolutionDiscriminator(nn.Module):
    def __init__(
        self,
        fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
        num_embeddings: Optional[int] = None,
    ):
        """
        Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
        Additionally, it allows incorporating conditional information with a learned embeddings table.

        Args:
            fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
            num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
                Defaults to None.
        """

        super().__init__()
        # One sub-discriminator per FFT resolution.
        self.discriminators = nn.ModuleList(
            [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
        )

    def forward(
        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
        """Score real (y) and generated (y_hat) audio at every resolution."""
        real_scores, fake_scores = [], []
        real_fmaps, fake_fmaps = [], []

        for disc in self.discriminators:
            # Real pass first, then fake pass, per resolution (matches the
            # original call order).
            score_r, feats_r = disc(x=y, cond_embedding_id=bandwidth_id)
            score_g, feats_g = disc(x=y_hat, cond_embedding_id=bandwidth_id)
            real_scores.append(score_r)
            real_fmaps.append(feats_r)
            fake_scores.append(score_g)
            fake_fmaps.append(feats_g)

        return real_scores, fake_scores, real_fmaps, fake_fmaps
76
+
77
+
78
class DiscriminatorR(nn.Module):
    """Single-resolution sub-band spectrogram discriminator
    (adapted from descript-audio-codec via Vocos)."""

    def __init__(
        self,
        window_length: int,
        num_embeddings: Optional[int] = None,
        channels: int = 32,
        hop_factor: float = 0.25,
        bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
    ):
        """
        Args:
            window_length: FFT window length (hop = window_length * hop_factor).
            num_embeddings: size of an optional conditional embedding table;
                None disables conditioning.
            channels: conv channel width.
            hop_factor: hop length as a fraction of the window.
            bands: relative (start, end) frequency-band edges in [0, 1].
        """
        super().__init__()
        self.window_length = window_length
        self.hop_factor = hop_factor
        # power=None keeps the complex STFT; real/imag become the 2 input channels.
        self.spec_fn = Spectrogram(
            n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
        )
        n_fft = window_length // 2 + 1
        # Convert relative band edges to absolute frequency-bin indices.
        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
        self.bands = bands
        # One identical conv stack per frequency band.
        convs = lambda: nn.ModuleList(
            [
                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
            ]
        )
        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])

        if num_embeddings is not None:
            # Conditional embedding, zero-initialised so conditioning starts neutral.
            self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
            torch.nn.init.zeros_(self.emb.weight)

        self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))

    def spectrogram(self, x):
        """Return the complex spectrogram of x split into frequency bands,
        each shaped (B, 2, T, band_bins)."""
        # Remove DC offset
        x = x - x.mean(dim=-1, keepdims=True)
        # Peak normalize the volume of input audio
        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
        x = self.spec_fn(x)
        x = torch.view_as_real(x)
        x = rearrange(x, "b f t c -> b c t f")
        # Split into bands
        x_bands = [x[..., b[0]: b[1]] for b in self.bands]
        return x_bands

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
        """Return (score map, feature maps) for waveform x."""
        x_bands = self.spectrogram(x)
        fmap = []
        x = []
        # Run each band through its own conv stack; collect intermediate
        # activations (after the first layer) as feature maps.
        for band, stack in zip(x_bands, self.band_convs):
            for i, layer in enumerate(stack):
                band = layer(band)
                band = torch.nn.functional.leaky_relu(band, 0.1)
                if i > 0:
                    fmap.append(band)
            x.append(band)
        # Re-join the processed bands along the frequency axis.
        x = torch.cat(x, dim=-1)
        if cond_embedding_id is not None:
            # Project features onto the conditional embedding (scalar bias map).
            emb = self.emb(cond_embedding_id)
            h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
        else:
            h = 0
        x = self.conv_post(x)
        # NOTE: fmap gets the conv_post output BEFORE the conditional bias h
        # is added; the in-place += below mutates the same tensor object.
        fmap.append(x)
        x += h

        return x, fmap
147
+
148
+
149
class MultiResSpecDiscriminator(torch.nn.Module):
    """Multi-resolution magnitude-spectrogram discriminator.

    Runs one SpecDiscriminator per (fft_size, hop_size, win_length) triple.
    Generalized from the original hard-coded three resolutions: any number of
    triples is now supported, as long as the three lists have equal length.
    """

    def __init__(self,
                 fft_sizes=[1024, 2048, 512],
                 hop_sizes=[120, 240, 50],
                 win_lengths=[600, 1200, 240],
                 window="hann_window"):
        """
        Args:
            fft_sizes: FFT size per resolution.
            hop_sizes: hop size per resolution.
            win_lengths: window length per resolution.
            window: name of a torch window function (e.g. "hann_window").

        Raises:
            ValueError: if the three lists differ in length.
        """
        super(MultiResSpecDiscriminator, self).__init__()
        if not (len(fft_sizes) == len(hop_sizes) == len(win_lengths)):
            raise ValueError("fft_sizes, hop_sizes and win_lengths must have the same length")
        # One sub-discriminator per resolution (previously indexed [0], [1], [2]
        # explicitly, which silently ignored any extra entries).
        self.discriminators = nn.ModuleList([
            SpecDiscriminator(n_fft, hop, win, window)
            for n_fft, hop, win in zip(fft_sizes, hop_sizes, win_lengths)
        ])

    def forward(self, y, y_hat):
        """Score real ``y`` and generated ``y_hat``.

        Returns:
            (y_d_rs, y_d_gs, fmap_rs, fmap_gs): per-resolution real/fake scores
            and feature-map lists.
        """
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for d in self.discriminators:
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
177
+
178
+
179
+ def stft(x, fft_size, hop_size, win_length, window):
180
+ """Perform STFT and convert to magnitude spectrogram.
181
+ Args:
182
+ x (Tensor): Input signal tensor (B, T).
183
+ fft_size (int): FFT size.
184
+ hop_size (int): Hop size.
185
+ win_length (int): Window length.
186
+ window (str): Window function type.
187
+ Returns:
188
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
189
+ """
190
+ x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
191
+
192
+ # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
193
+ return torch.abs(x_stft).transpose(2, 1)
194
+
195
+
196
class SpecDiscriminator(nn.Module):
    """Single-resolution magnitude-spectrogram discriminator."""

    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
        super(SpecDiscriminator, self).__init__()
        # Keep the exact original selection semantics (identity check on False).
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        # e.g. torch.hann_window(win_length) for window="hann_window".
        self.window = getattr(torch, window)(win_length)
        # (in, out, kernel, stride, padding) per conv layer; creation order
        # matches the original module list.
        layer_specs = [
            (1, 32, (3, 9), (1, 1), (1, 4)),
            (32, 32, (3, 9), (1, 2), (1, 4)),
            (32, 32, (3, 9), (1, 2), (1, 4)),
            (32, 32, (3, 9), (1, 2), (1, 4)),
            (32, 32, (3, 3), (1, 1), (1, 1)),
        ]
        self.discriminators = nn.ModuleList([
            norm_f(nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=p))
            for c_in, c_out, k, s, p in layer_specs
        ])

        self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))

    def forward(self, y):
        """Return (flattened score map, list of intermediate feature maps)."""
        fmap = []
        spec = stft(y.squeeze(1), self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
        h = spec.unsqueeze(1)
        for layer in self.discriminators:
            h = F.leaky_relu(layer(h), LRELU_SLOPE)
            fmap.append(h)

        h = self.out(h)
        fmap.append(h)

        return torch.flatten(h, 1, -1), fmap
model/cosyvoice/hifigan/f0_predictor.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ try:
17
+ from torch.nn.utils.parametrizations import weight_norm
18
+ except ImportError:
19
+ from torch.nn.utils import weight_norm
20
+
21
+
22
class ConvRNNF0Predictor(nn.Module):
    """Frame-level F0 predictor: five weight-normalised Conv1d + ELU stages
    followed by a linear classifier; outputs are made non-negative with abs."""

    def __init__(self,
                 num_class: int = 1,
                 in_channels: int = 80,
                 cond_channels: int = 512
                 ):
        super().__init__()

        self.num_class = num_class
        # Build the conv stack in a loop; module creation order is identical
        # to the original explicit Sequential.
        stages = []
        width = in_channels
        for _ in range(5):
            stages.append(weight_norm(
                nn.Conv1d(width, cond_channels, kernel_size=3, padding=1)
            ))
            stages.append(nn.ELU())
            width = cond_channels
        self.condnet = nn.Sequential(*stages)
        self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map mel features (B, in_channels, T) to non-negative F0 values (B, T)."""
        hidden = self.condnet(x).transpose(1, 2)
        return torch.abs(self.classifier(hidden).squeeze(-1))
model/cosyvoice/hifigan/generator.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """HIFI-GAN"""
16
+
17
+ from typing import Dict, Optional, List
18
+ import numpy as np
19
+ from scipy.signal import get_window
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torch.nn import Conv1d
24
+ from torch.nn import ConvTranspose1d
25
+ from torch.nn.utils import remove_weight_norm
26
+ try:
27
+ from torch.nn.utils.parametrizations import weight_norm
28
+ except ImportError:
29
+ from torch.nn.utils import weight_norm
30
+ from torch.distributions.uniform import Uniform
31
+
32
+ from cosyvoice.transformer.activation import Snake
33
+ from cosyvoice.utils.common import get_padding
34
+ from cosyvoice.utils.common import init_weights
35
+
36
+
37
+ """hifigan based generator implementation.
38
+
39
+ This code is modified from https://github.com/jik876/hifi-gan
40
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
41
+ https://github.com/NVIDIA/BigVGAN
42
+
43
+ """
44
+
45
+
46
class ResBlock(torch.nn.Module):
    """HiFiGAN/BigVGAN residual block: Snake -> dilated conv -> Snake -> conv,
    with an additive skip connection per dilation."""

    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: List[int] = [1, 3, 5],
    ):
        super(ResBlock, self).__init__()
        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()

        # One (dilated conv, plain conv) pair per dilation; the creation order
        # convs1[i] then convs2[i] matches the original initialisation.
        for d in dilations:
            self.convs1.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d)
                    )
                )
            )
            self.convs2.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1)
                    )
                )
            )
        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)
        # Snake activations, one per conv.
        self.activations1 = nn.ModuleList([
            Snake(channels, alpha_logscale=False) for _ in self.convs1
        ])
        self.activations2 = nn.ModuleList([
            Snake(channels, alpha_logscale=False) for _ in self.convs2
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for act1, conv1, act2, conv2 in zip(self.activations1, self.convs1,
                                            self.activations2, self.convs2):
            residual = conv2(act2(conv1(act1(x))))
            x = residual + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from all convs (for inference export)."""
        for conv1, conv2 in zip(self.convs1, self.convs2):
            remove_weight_norm(conv1)
            remove_weight_norm(conv2)
107
+
108
+
109
class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine-wavefrom (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # voiced/unvoiced flag: 1.0 where f0 exceeds the threshold, else 0.0
        return (f0 > self.voiced_threshold).type(torch.float32)

    @torch.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: (sine_waves, uv, noise), each broadcast over the harmonic axis
        """
        # Normalised frequency per harmonic: f0 * k / fs for k = 1..harmonic_num+1.
        harmonic_idx = torch.arange(1, self.harmonic_num + 2, device=f0.device).view(1, -1, 1)
        F_mat = f0 * harmonic_idx / self.sampling_rate

        # Integrate frequency into phase; mod 1 keeps the cumulative sum bounded.
        theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
        # Random initial phase per harmonic (fundamental fixed at phase 0).
        u_dist = Uniform(low=-np.pi, high=np.pi)
        phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
        phase_vec[:, 0, :] = 0

        # generate sine waveforms
        sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)

        # generate uv signal
        uv = self._f02uv(f0)

        # Noise amplitude: small (noise_std) where voiced, ~sine_amp/3 where unvoiced.
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # Zero the sine in unvoiced regions, then add the noise floor.
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
172
+
173
+
174
class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonic above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided by sine_amp
    voiced_threshold: threhold to set U/V given F0 (default: 0)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # Sine source generator for F0 plus its harmonic overtones.
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # Trainable projection merging all harmonics into a single excitation channel.
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length 1)
        """
        # Sine generation is non-trainable; only the merge projection learns.
        with torch.no_grad():
            harmonics, voicing, _ = self.l_sin_gen(x.transpose(1, 2))
            harmonics = harmonics.transpose(1, 2)
            voicing = voicing.transpose(1, 2)
        merged = self.l_tanh(self.l_linear(harmonics))

        # Noise branch, same shape as the voicing flag.
        noise_source = torch.randn_like(voicing) * self.sine_amp / 3
        return merged, noise_source, voicing
224
+
225
+
226
+ class HiFTGenerator(nn.Module):
227
+ """
228
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
229
+ https://arxiv.org/abs/2309.09493
230
+ """
231
+ def __init__(
232
+ self,
233
+ in_channels: int = 80,
234
+ base_channels: int = 512,
235
+ nb_harmonics: int = 8,
236
+ sampling_rate: int = 22050,
237
+ nsf_alpha: float = 0.1,
238
+ nsf_sigma: float = 0.003,
239
+ nsf_voiced_threshold: float = 10,
240
+ upsample_rates: List[int] = [8, 8],
241
+ upsample_kernel_sizes: List[int] = [16, 16],
242
+ istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
243
+ resblock_kernel_sizes: List[int] = [3, 7, 11],
244
+ resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
245
+ source_resblock_kernel_sizes: List[int] = [7, 11],
246
+ source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
247
+ lrelu_slope: float = 0.1,
248
+ audio_limit: float = 0.99,
249
+ f0_predictor: torch.nn.Module = None,
250
+ ):
251
+ super(HiFTGenerator, self).__init__()
252
+
253
+ self.out_channels = 1
254
+ self.nb_harmonics = nb_harmonics
255
+ self.sampling_rate = sampling_rate
256
+ self.istft_params = istft_params
257
+ self.lrelu_slope = lrelu_slope
258
+ self.audio_limit = audio_limit
259
+
260
+ self.num_kernels = len(resblock_kernel_sizes)
261
+ self.num_upsamples = len(upsample_rates)
262
+ self.m_source = SourceModuleHnNSF(
263
+ sampling_rate=sampling_rate,
264
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
265
+ harmonic_num=nb_harmonics,
266
+ sine_amp=nsf_alpha,
267
+ add_noise_std=nsf_sigma,
268
+ voiced_threshod=nsf_voiced_threshold)
269
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
270
+
271
+ self.conv_pre = weight_norm(
272
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
273
+ )
274
+
275
+ # Up
276
+ self.ups = nn.ModuleList()
277
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
278
+ self.ups.append(
279
+ weight_norm(
280
+ ConvTranspose1d(
281
+ base_channels // (2**i),
282
+ base_channels // (2**(i + 1)),
283
+ k,
284
+ u,
285
+ padding=(k - u) // 2,
286
+ )
287
+ )
288
+ )
289
+
290
+ # Down
291
+ self.source_downs = nn.ModuleList()
292
+ self.source_resblocks = nn.ModuleList()
293
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
294
+ downsample_cum_rates = np.cumprod(downsample_rates)
295
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
296
+ if u == 1:
297
+ self.source_downs.append(
298
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
299
+ )
300
+ else:
301
+ self.source_downs.append(
302
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
303
+ )
304
+
305
+ self.source_resblocks.append(
306
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
307
+ )
308
+
309
+ self.resblocks = nn.ModuleList()
310
+ for i in range(len(self.ups)):
311
+ ch = base_channels // (2**(i + 1))
312
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
313
+ self.resblocks.append(ResBlock(ch, k, d))
314
+
315
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
316
+ self.ups.apply(init_weights)
317
+ self.conv_post.apply(init_weights)
318
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
319
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
320
+ self.f0_predictor = f0_predictor
321
+
322
+ def remove_weight_norm(self):
323
+ print('Removing weight norm...')
324
+ for l in self.ups:
325
+ remove_weight_norm(l)
326
+ for l in self.resblocks:
327
+ l.remove_weight_norm()
328
+ remove_weight_norm(self.conv_pre)
329
+ remove_weight_norm(self.conv_post)
330
+ self.m_source.remove_weight_norm()
331
+ for l in self.source_downs:
332
+ remove_weight_norm(l)
333
+ for l in self.source_resblocks:
334
+ l.remove_weight_norm()
335
+
336
+ def _stft(self, x):
337
+ spec = torch.stft(
338
+ x,
339
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
340
+ return_complex=True)
341
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
342
+ return spec[..., 0], spec[..., 1]
343
+
344
+ def _istft(self, magnitude, phase):
345
+ magnitude = torch.clip(magnitude, max=1e2)
346
+ real = magnitude * torch.cos(phase)
347
+ img = magnitude * torch.sin(phase)
348
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
349
+ self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
350
+ return inverse_transform
351
+
352
+ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
353
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
354
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
355
+
356
+ x = self.conv_pre(x)
357
+ for i in range(self.num_upsamples):
358
+ x = F.leaky_relu(x, self.lrelu_slope)
359
+ x = self.ups[i](x)
360
+
361
+ if i == self.num_upsamples - 1:
362
+ x = self.reflection_pad(x)
363
+
364
+ # fusion
365
+ si = self.source_downs[i](s_stft)
366
+ si = self.source_resblocks[i](si)
367
+ x = x + si
368
+
369
+ xs = None
370
+ for j in range(self.num_kernels):
371
+ if xs is None:
372
+ xs = self.resblocks[i * self.num_kernels + j](x)
373
+ else:
374
+ xs += self.resblocks[i * self.num_kernels + j](x)
375
+ x = xs / self.num_kernels
376
+
377
+ x = F.leaky_relu(x)
378
+ x = self.conv_post(x)
379
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
380
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
381
+
382
+ x = self._istft(magnitude, phase)
383
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
384
+ return x
385
+
386
+ def forward(
387
+ self,
388
+ batch: dict,
389
+ device: torch.device,
390
+ ) -> Dict[str, Optional[torch.Tensor]]:
391
+ speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
392
+ # mel->f0
393
+ f0 = self.f0_predictor(speech_feat)
394
+ # f0->source
395
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
396
+ s, _, _ = self.m_source(s)
397
+ s = s.transpose(1, 2)
398
+ # mel+source->speech
399
+ generated_speech = self.decode(x=speech_feat, s=s)
400
+ return generated_speech, f0
401
+
402
+ @torch.inference_mode()
403
+ def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
404
+ # mel->f0
405
+ f0 = self.f0_predictor(speech_feat)
406
+ # f0->source
407
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
408
+ s, _, _ = self.m_source(s)
409
+ s = s.transpose(1, 2)
410
+ # use cache_source to avoid glitch
411
+ if cache_source.shape[2] != 0:
412
+ s[:, :, :cache_source.shape[2]] = cache_source
413
+ generated_speech = self.decode(x=speech_feat, s=s)
414
+ return generated_speech, s