+
+## Acknowledgements
+
+We borrowed a lot of code from the following projects:
+
+1. [FunASR](https://github.com/modelscope/FunASR)
+2. [FunCodec](https://github.com/modelscope/FunCodec)
+3. [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS)
+4. [AcademiCodec](https://github.com/yangdongchao/AcademiCodec)
+5. [WeNet](https://github.com/wenet-e2e/wenet)
+
+## Disclaimer
+The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..822757d1d44d27561795b33018cd434e7697e4c9
--- /dev/null
+++ b/app.py
@@ -0,0 +1,211 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Liu Yue)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+import argparse
+import gradio as gr
+import numpy as np
+import torch
+import torchaudio
+import random
+import librosa
+from funasr import AutoModel
+from funasr.utils.postprocess_utils import rich_transcription_postprocess
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
+
+from modelscope import snapshot_download
+snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
+snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
+os.system('cd pretrained_models/CosyVoice-ttsfrd/ && pip install ttsfrd_dependency-0.1-py3-none-any.whl && pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl && apt install -y unzip && unzip resource.zip -d .')
+
+from cosyvoice.cli.cosyvoice import CosyVoice2
+from cosyvoice.utils.file_utils import load_wav, logging
+from cosyvoice.utils.common import set_all_random_seed
+
+inference_mode_list = ['3s极速复刻', '自然语言控制']
+instruct_dict = {'3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
+ '自然语言控制': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入instruct文本\n3. 点击生成音频按钮'}
+stream_mode_list = [('否', False), ('是', True)]
+max_val = 0.8
+
+
+def generate_seed():
+ seed = random.randint(1, 100000000)
+ return {
+ "__type__": "update",
+ "value": seed
+ }
+
+
+def postprocess(speech, top_db=60, hop_length=220, win_length=440):
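+    # Trim leading/trailing silence, rescale the waveform if its peak exceeds max_val, then append 0.2s of trailing silence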
+ speech, _ = librosa.effects.trim(
+ speech, top_db=top_db,
+ frame_length=win_length,
+ hop_length=hop_length
+ )
+ if speech.abs().max() > max_val:
+ speech = speech / speech.abs().max() * max_val
+ speech = torch.concat([speech, torch.zeros(1, int(target_sr * 0.2))], dim=1)
+ return speech
+
+
+def change_instruction(mode_checkbox_group):
+ return instruct_dict[mode_checkbox_group]
+
+def prompt_wav_recognition(prompt_wav):
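+    # Transcribe the prompt audio with the SenseVoice ASR model and strip the leading language/event tags (everything up to the last '|>')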
+ res = asr_model.generate(input=prompt_wav,
+                             language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech"
+ use_itn=True,
+ )
+ text = res[0]["text"].split('|>')[-1]
+ return text
+
+def generate_audio(tts_text, mode_checkbox_group, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
+ seed, stream):
+ sft_dropdown, speed = '', 1.0
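+    # Prefer the uploaded prompt audio over the recorded one when both are provided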
+ if prompt_wav_upload is not None:
+ prompt_wav = prompt_wav_upload
+ elif prompt_wav_record is not None:
+ prompt_wav = prompt_wav_record
+ else:
+ prompt_wav = None
+ # if instruct mode, please make sure that model is iic/CosyVoice-300M-Instruct and not cross_lingual mode
+ if mode_checkbox_group in ['自然语言控制']:
+ if instruct_text == '':
+ gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本')
+ yield (target_sr, default_data)
+ if prompt_wav is None:
+ gr.Info('您正在使用自然语言控制模式, 请输入prompt音频')
+    # if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and that tts_text and prompt_text are in different languages
+ if mode_checkbox_group in ['跨语种复刻']:
+ if cosyvoice.frontend.instruct is True:
+ gr.Warning('您正在使用跨语种复刻模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M模型'.format(args.model_dir))
+ yield (target_sr, default_data)
+ if instruct_text != '':
+ gr.Info('您正在使用跨语种复刻模式, instruct文本会被忽略')
+ if prompt_wav is None:
+ gr.Warning('您正在使用跨语种复刻模式, 请提供prompt音频')
+ yield (target_sr, default_data)
+ gr.Info('您正在使用跨语种复刻模式, 请确保合成文本和prompt文本为不同语言')
+    # if in zero_shot or cross_lingual mode, please make sure that prompt_text and prompt_wav meet the requirements
+ if mode_checkbox_group in ['3s极速复刻', '跨语种复刻']:
+ if prompt_wav is None:
+ gr.Warning('prompt音频为空,您是否忘记输入prompt音频?')
+ yield (target_sr, default_data)
+ if torchaudio.info(prompt_wav).sample_rate < prompt_sr:
+ gr.Warning('prompt音频采样率{}低于{}'.format(torchaudio.info(prompt_wav).sample_rate, prompt_sr))
+ yield (target_sr, default_data)
+    # sft mode only uses sft_dropdown
+ if mode_checkbox_group in ['预训练音色']:
+ if instruct_text != '' or prompt_wav is not None or prompt_text != '':
+ gr.Info('您正在使用预训练音色模式,prompt文本/prompt音频/instruct文本会被忽略!')
+    # zero_shot mode only uses prompt_wav and prompt_text
+ if mode_checkbox_group in ['3s极速复刻']:
+ if prompt_text == '':
+ gr.Warning('prompt文本为空,您是否忘记输入prompt文本?')
+ yield (target_sr, default_data)
+ if instruct_text != '':
+ gr.Info('您正在使用3s极速复刻模式,预训练音色/instruct文本会被忽略!')
+ info = torchaudio.info(prompt_wav)
+ if info.num_frames / info.sample_rate > 10:
+ gr.Warning('请限制输入音频在10s内,避免推理效果过低')
+ yield (target_sr, default_data)
+
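+    # Dispatch to the selected inference mode and stream synthesized audio chunks back to Gradio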
+ if mode_checkbox_group == '预训练音色':
+ logging.info('get sft inference request')
+ set_all_random_seed(seed)
+ for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream, speed=speed):
+ yield (target_sr, i['tts_speech'].numpy().flatten())
+ elif mode_checkbox_group == '3s极速复刻':
+ logging.info('get zero_shot inference request')
+ prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
+ set_all_random_seed(seed)
+ for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed):
+ yield (target_sr, i['tts_speech'].numpy().flatten())
+ elif mode_checkbox_group == '跨语种复刻':
+ logging.info('get cross_lingual inference request')
+ prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
+ set_all_random_seed(seed)
+ for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
+ yield (target_sr, i['tts_speech'].numpy().flatten())
+ else:
+ logging.info('get instruct inference request')
+ prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
+ set_all_random_seed(seed)
+ for i in cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k, stream=stream, speed=speed):
+ yield (target_sr, i['tts_speech'].numpy().flatten())
+
+
+def main():
+ with gr.Blocks() as demo:
+ gr.Markdown("### 代码库 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) \
+ 预训练模型 [CosyVoice2-0.5B](https://www.modelscope.cn/models/iic/CosyVoice2-0.5B) \
+ [CosyVoice-300M](https://www.modelscope.cn/models/iic/CosyVoice-300M) \
+ [CosyVoice-300M-Instruct](https://www.modelscope.cn/models/iic/CosyVoice-300M-Instruct) \
+ [CosyVoice-300M-SFT](https://www.modelscope.cn/models/iic/CosyVoice-300M-SFT)")
+ gr.Markdown("#### 请输入需要合成的文本,选择推理模式,并按照提示步骤进行操作")
+
+ tts_text = gr.Textbox(label="输入合成文本", lines=1, value="CosyVoice迎来全面升级,提供更准、更稳、更快、 更好的语音生成能力。CosyVoice is undergoing a comprehensive upgrade, providing more accurate, stable, faster, and better voice generation capabilities.")
+ with gr.Row():
+ mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
+ instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
+ stream = gr.Radio(choices=stream_mode_list, label='是否流式推理', value=stream_mode_list[0][1])
+ with gr.Column(scale=0.25):
+ seed_button = gr.Button(value="\U0001F3B2")
+ seed = gr.Number(value=0, label="随机推理种子")
+
+ with gr.Row():
+ prompt_wav_upload = gr.Audio(sources='upload', type='filepath', label='选择prompt音频文件,注意采样率不低于16khz')
+ prompt_wav_record = gr.Audio(sources='microphone', type='filepath', label='录制prompt音频文件')
+ prompt_text = gr.Textbox(label="prompt文本", lines=1, placeholder="请输入prompt文本,支持自动识别,您可以自行修正识别结果...", value='')
+ instruct_text = gr.Textbox(label="输入instruct文本", lines=1, placeholder="请输入instruct文本.例如:用四川话说这句话。", value='')
+
+ generate_button = gr.Button("生成音频")
+
+ audio_output = gr.Audio(label="合成音频", autoplay=True, streaming=True)
+
+ seed_button.click(generate_seed, inputs=[], outputs=seed)
+ generate_button.click(generate_audio,
+ inputs=[tts_text, mode_checkbox_group, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
+ seed, stream],
+ outputs=[audio_output])
+ mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
+ prompt_wav_upload.change(fn=prompt_wav_recognition, inputs=[prompt_wav_upload], outputs=[prompt_text])
+ prompt_wav_record.change(fn=prompt_wav_recognition, inputs=[prompt_wav_record], outputs=[prompt_text])
+ demo.queue(max_size=4, default_concurrency_limit=2).launch(server_port=50000)
+
+
+if __name__ == '__main__':
+ load_jit = True if os.environ.get('jit') == '1' else False
+ load_onnx = True if os.environ.get('onnx') == '1' else False
+ load_trt = True if os.environ.get('trt') == '1' else False
+ logging.info('cosyvoice args load_jit {} load_onnx {} load_trt {}'.format(load_jit, load_onnx, load_trt))
+ cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=load_jit, load_onnx=load_onnx, load_trt=load_trt)
+ sft_spk = cosyvoice.list_avaliable_spks()
+ prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
+ for stream in [True, False]:
+ for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=stream)):
+ continue
+ prompt_sr, target_sr = 16000, 24000
+ default_data = np.zeros(target_sr)
+
+ model_dir = "iic/SenseVoiceSmall"
+ asr_model = AutoModel(
+ model=model_dir,
+ disable_update=True,
+ log_level='DEBUG',
+ device="cuda:0")
+ main()
diff --git a/cert.pem b/cert.pem
new file mode 100644
index 0000000000000000000000000000000000000000..4ffd9a21561f0823ba6454adff61e1e34af41048
--- /dev/null
+++ b/cert.pem
@@ -0,0 +1,32 @@
+-----BEGIN CERTIFICATE-----
+MIIFkTCCA3mgAwIBAgIUEO2zq0OQeuRFIFH4lfHLgcR5hTUwDQYJKoZIhvcNAQEL
+BQAwWDELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAlpKMQswCQYDVQQHDAJIWjEQMA4G
+A1UECgwHY29tcGFueTEMMAoGA1UECwwDYWxpMQ8wDQYDVQQDDAZncmFkaW8wHhcN
+MjQwNDI5MTIxNDQxWhcNMjUwNDI5MTIxNDQxWjBYMQswCQYDVQQGEwJDTjELMAkG
+A1UECAwCWkoxCzAJBgNVBAcMAkhaMRAwDgYDVQQKDAdjb21wYW55MQwwCgYDVQQL
+DANhbGkxDzANBgNVBAMMBmdyYWRpbzCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCC
+AgoCggIBAKohzP3V7VdDyMgfRO4+xzh/mWFPapQWJIIrhnHj8GRJ9tgFVXf71vcU
+PMo+/t+y0rjupw3WwWIj6kJP15t46xxmLzoJHZKHV7d1Y7XJTyN1hvRCzeGz6w/E
+VX0y6U+0y9m1HG0kvsfLwCKZPxEN21RfPukGN3qOIpjaRvE6fxg8DCUQN8qEpjQ9
+DQehq/g0B/wZFwIB2089+BeqesjaOinY2+z4YiMreIj2dy8XM6G59quS21oe0u5n
+6SW80ayf/yA6CHqblCHNfdi3vrzxMalNjT5EHKxQsLEDd2nWSndoPeXClXdSoIpE
+1+H86dWHZpzPLd6rOfa+FCZ3TQsZbL+p3ree2AIMIB7zWw59oKGE8UuZbtyCVWK6
+hufMOs703ZT97WeBEoOA72itUwCBqAakYNoULvYSOuXZT0LvJN1Z4YLNTkJXDA0u
+vMABPbRFXfFK67F/fLm/vges4dhhpQNeSxSuXEC7rMA5hCQRk3BccdEgxoBfNZcM
+HKo8CaB3wxbK7inXZb3JD4sFK64H5VjfJE8ibFzoIhiPICuC+0bzSKfc0+dcUNMb
+KsE5M3etmS1TcPKuebk9OTu8YUJiNMYgEInw7vCq004v4IOqQr0aX/LGRm21RB/i
+M3qFKCSHSw5/Z+o9sZ/kw3AeNnx5r5dq4OAswx3RhScPJtd6qesZAgMBAAGjUzBR
+MB0GA1UdDgQWBBSNZx2v1BNAGL4gGM4TUXIvn1OyFTAfBgNVHSMEGDAWgBSNZx2v
+1BNAGL4gGM4TUXIvn1OyFTAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUA
+A4ICAQA/khg91VtI/tDLCLyQ6ZMulfOzJHuGmIs4cvG5fIOvzYjQpvAGSgNeivKp
++5RIkpUabcwdUCq6VXeXheo+SaGgVdwpxQy/p/E+i+AengRB5Qm/hJ5lLU6CdNBq
+WCN/0Aa1GL/pM4HAzVQY81HeB46UaHWtW6J9hnBbVg2MF2GanAqfeODpZqIHEggt
+Vw2ivElV47JTFZsNU+JYG5ECsfTjNQYpoA6Hyb/d5ZW8YsfOjr8oIBM4QyZWq1Ke
+eAlytVwl9lj4AkAQIAgkrJHkLjj5yjZ7Hir5NjBuBx06FDAIFb2XWgNnq4ua/pSq
+9fL4cxx4cEJku1X/FYtUBbWsXe8uFGwTEGHuEZR3pj5VSFbuNlARLIsq8/gh8MRQ
+NjKQIlTVINkuOFuVmSrLC5nIwTPhlpEFwIQPGzFD2DbVNor9EXQ2b89WtHqZAZik
+qFDb76JM9jctf9n8l96oSKrwEaCoFmRojnyyYl9UByJxPRCeTJ//i2vxeTvLC3FT
+Rw2jFi/pwoqSVmJtuAFLT96/x2qKpgk+M1zG3oFiDV1lxY8sw1RA3Mm4s3Cm8H5A
+3E+6R34XZLifqhxLVcyDsRWPcqte3Pt6v/xXWN+EuOigK4tr69p8aU7WR5mskmzO
+tZFeEb0OxL1WjF/rmwCkd/SvSuWSiszMoX5hcOA7/GGw3pl3YQ==
+-----END CERTIFICATE-----
diff --git a/cosyvoice/__init__.py b/cosyvoice/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cosyvoice/bin/average_model.py b/cosyvoice/bin/average_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d095dcd99f915f0ffdbc3a0c14fcb6f8db900be0
--- /dev/null
+++ b/cosyvoice/bin/average_model.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2020 Mobvoi Inc (Di Wu)
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import glob
+
+import yaml
+import torch
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='average model')
+ parser.add_argument('--dst_model', required=True, help='averaged model')
+ parser.add_argument('--src_path',
+ required=True,
+ help='src model path for average')
+ parser.add_argument('--val_best',
+ action="store_true",
+                        help='average checkpoints with the best validation loss')
+ parser.add_argument('--num',
+ default=5,
+ type=int,
+                        help='number of checkpoints to average')
+
+ args = parser.parse_args()
+ print(args)
+ return args
+
+
+def main():
+ args = get_args()
+ val_scores = []
+ if args.val_best:
+ yamls = glob.glob('{}/*.yaml'.format(args.src_path))
+ yamls = [
+ f for f in yamls
+ if not (os.path.basename(f).startswith('train')
+ or os.path.basename(f).startswith('init'))
+ ]
+ for y in yamls:
+ with open(y, 'r') as f:
+ dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
+ loss = float(dic_yaml['loss_dict']['loss'])
+ epoch = int(dic_yaml['epoch'])
+ step = int(dic_yaml['step'])
+ tag = dic_yaml['tag']
+ val_scores += [[epoch, step, loss, tag]]
+ sorted_val_scores = sorted(val_scores,
+ key=lambda x: x[2],
+ reverse=False)
+ print("best val (epoch, step, loss, tag) = " +
+ str(sorted_val_scores[:args.num]))
+ path_list = [
+ args.src_path + '/epoch_{}_whole.pt'.format(score[0])
+ for score in sorted_val_scores[:args.num]
+ ]
+ print(path_list)
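+    # Sum the parameters of the selected checkpoints, then divide by their count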
+ avg = {}
+ num = args.num
+ assert num == len(path_list)
+ for path in path_list:
+ print('Processing {}'.format(path))
+ states = torch.load(path, map_location=torch.device('cpu'))
+ for k in states.keys():
+ if k not in avg.keys():
+ avg[k] = states[k].clone()
+ else:
+ avg[k] += states[k]
+ # average
+ for k in avg.keys():
+ if avg[k] is not None:
+            # For PyTorch >= 1.6, use true_divide instead of /=
+ avg[k] = torch.true_divide(avg[k], num)
+ print('Saving to {}'.format(args.dst_model))
+ torch.save(avg, args.dst_model)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cosyvoice/bin/convert.py b/cosyvoice/bin/convert.py
new file mode 100644
index 0000000000000000000000000000000000000000..789efe6e643c0d8723f6f216cf27ae2be1863986
--- /dev/null
+++ b/cosyvoice/bin/convert.py
@@ -0,0 +1,168 @@
+import sys
+import torch
+
+def convert_llm(state_dict):
+    # The LM structure was reorganized: codec_lm.encoder is used as llm, codec_lm.decoder as llm_decoder
+ keys = list(state_dict.keys())
+ for k in keys:
+ if k.startswith('codec_lm.encoder.'):
+ v = state_dict.pop(k)
+ k = k.replace('codec_lm.encoder.', 'llm.')
+ state_dict[k] = v
+ if k.startswith('codec_lm.decoder.'):
+ v = state_dict.pop(k)
+ k = k.replace('codec_lm.decoder.', 'llm_decoder.')
+ state_dict[k] = v
+    # Account for implementation differences between ESPnet and WeNet (embedding submodule naming)
+ keys = list(state_dict.keys())
+ for k in keys:
+ if k.startswith('text_encoder.embed.'):
+ v = state_dict.pop(k)
+ k = k.replace('text_encoder.embed.', 'text_encoder.embed.out.')
+ state_dict[k] = v
+ if k.startswith('llm.embed.'):
+ v = state_dict.pop(k)
+ k = k.replace('llm.embed.', 'llm.embed.out.')
+ state_dict[k] = v
+ keys = list(state_dict.keys())
+ for k in keys:
+ if k.startswith('text_enc_out_layer.'):
+ v = state_dict.pop(k)
+ k = k.replace('text_enc_out_layer.', 'text_encoder_affine_layer.')
+ state_dict[k] = v
+ if k.startswith('token_embedding.'):
+ v = state_dict.pop(k)
+ k = k.replace('token_embedding.', 'text_embedding.')
+ state_dict[k] = v
+ if k.startswith('xvec_proj.'):
+ v = state_dict.pop(k)
+ k = k.replace('xvec_proj.', 'spk_embed_affine_layer.')
+ state_dict[k] = v
+ if k.startswith('lm_embedding.'):
+ v = state_dict.pop(k)
+ k = k.replace('lm_embedding.', 'llm_embedding.')
+ state_dict[k] = v
+ if k.startswith('codec_embedder.'):
+ v = state_dict.pop(k)
+ k = k.replace('codec_embedder.', 'speech_embedding.')
+ state_dict[k] = v
+    # The instruct model is missing the spk embedding parameters, so add all-zero ones
+ keys = list(state_dict.keys())
+ if 'spk_embed_affine_layer.weight' not in keys:
+ print('no spk_embed_affine_layer.weight, should be instruct model')
+ state_dict['spk_embed_affine_layer.weight'] = torch.zeros(1024, 192)
+ if 'spk_embed_affine_layer.bias' not in keys:
+ print('no spk_embed_affine_layer.bias, should be instruct model')
+ state_dict['spk_embed_affine_layer.bias'] = torch.zeros(1024)
+ return state_dict
+
+def convert_hift(state_dict):
+    # The hifigan structure in cosyvoice was reorganized: the f0_predictor was moved into the generator
+ keys = list(state_dict.keys())
+ for k in keys:
+ if k.startswith('decoder.'):
+ v = state_dict.pop(k)
+ k = k.replace('decoder.', '')
+ state_dict[k] = v
+ if k.startswith('generator.'):
+ v = state_dict.pop(k)
+ k = k.replace('generator.', '')
+ state_dict[k] = v
+ return state_dict
+
+def convert_flow(state_dict):
+ keys = list(state_dict.keys())
+ for k in keys:
+ if k.startswith('encoder.embed.'):
+ v = state_dict.pop(k)
+ k = k.replace('encoder.embed.', 'encoder.embed.out.')
+ state_dict[k] = v
+ for k in keys:
+ if k.startswith('xvec_proj.'):
+ v = state_dict.pop(k)
+ k = k.replace('xvec_proj.', 'spk_embed_affine_layer.')
+ state_dict[k] = v
+ return state_dict
+
+def convert_llm2(state_dict):
+    # The LM structure was reorganized: codec_lm.encoder is used as llm, codec_lm.decoder as llm_decoder
+ keys = list(state_dict.keys())
+ for k in keys:
+ if k.startswith('codec_lm.encoder.'):
+ v = state_dict.pop(k)
+ k = k.replace('codec_lm.encoder.', 'llm.')
+ state_dict[k] = v
+ if k.startswith('codec_lm.decoder.'):
+ v = state_dict.pop(k)
+ k = k.replace('codec_lm.decoder.', 'llm_decoder.')
+ state_dict[k] = v
+ if k.startswith('lm_embedding.'):
+ v = state_dict.pop(k)
+ k = k.replace('lm_embedding.', 'llm_embedding.')
+ state_dict[k] = v
+ if k.startswith('codec_embedder.'):
+ v = state_dict.pop(k)
+ k = k.replace('codec_embedder.', 'speech_embedding.')
+ state_dict[k] = v
+ if k.startswith('text_enc_out_layer.'):
+ state_dict.pop(k)
+ if k.startswith('token_embedding.weight'):
+ state_dict.pop(k)
+ return state_dict
+
+def convert_flow2(state_dict):
+ keys = list(state_dict.keys())
+ for k in keys:
+ if k.startswith('encoder.embed.'):
+ v = state_dict.pop(k)
+ k = k.replace('encoder.embed.', 'encoder.embed.out.')
+ state_dict[k] = v
+ for k in keys:
+ if k.startswith('xvec_proj.'):
+ v = state_dict.pop(k)
+ k = k.replace('xvec_proj.', 'spk_embed_affine_layer.')
+ state_dict[k] = v
+ for k in keys:
+ if k.startswith('mel_extractor.'):
+ state_dict.pop(k)
+ for k in keys:
+ if k.startswith('encoder.upsample_blocks.0.0.'):
+ v = state_dict.pop(k)
+ k = k.replace('encoder.upsample_blocks.0.0.', 'encoder.up_layer.')
+ state_dict[k] = v
+ if k.startswith('encoder.upsample_blocks.0.1.'):
+ v = state_dict.pop(k)
+ k = k.replace('encoder.upsample_blocks.0.1.', 'encoder.up_embed.out.')
+ state_dict[k] = v
+ if k.startswith('encoder.upsample_blocks.0.2.'):
+ v = state_dict.pop(k)
+ k = k.replace('encoder.upsample_blocks.0.2.', 'encoder.up_encoders.')
+ state_dict[k] = v
+        # In CausalBlock1D, the Sequential index shifted from 1 to 2
+ if k.startswith('decoder.estimator.') and k.endswith('block.1.weight'):
+ v = state_dict.pop(k)
+ k = k.replace('block.1.weight', 'block.2.weight')
+ state_dict[k] = v
+ if k.startswith('decoder.estimator.') and k.endswith('block.1.bias'):
+ v = state_dict.pop(k)
+ k = k.replace('block.1.bias', 'block.2.bias')
+ state_dict[k] = v
+ return state_dict
+
+if __name__ == '__main__':
+    # Usage: python3 convert.py old_format_llm.pt llm normalize new_format_llm.pt
+    # or:    python3 convert.py new_format_llm.pt llm inverse_normalize old_format_llm.pt
+ state_dict = torch.load(sys.argv[1], map_location='cpu')
+ if sys.argv[2] == 'llm':
+ state_dict = convert_llm(state_dict)
+ elif sys.argv[2] == 'flow':
+ state_dict = convert_flow(state_dict)
+ elif sys.argv[2] == 'hift':
+ state_dict = convert_hift(state_dict)
+ elif sys.argv[2] == 'llm2':
+ state_dict = convert_llm2(state_dict)
+ elif sys.argv[2] == 'flow2':
+ state_dict = convert_flow2(state_dict)
+ else:
+ raise ValueError
+ torch.save(state_dict, sys.argv[4])
diff --git a/cosyvoice/bin/export_jit.py b/cosyvoice/bin/export_jit.py
new file mode 100644
index 0000000000000000000000000000000000000000..7587bd81d90558b24803e6724cad52ee11ea3b18
--- /dev/null
+++ b/cosyvoice/bin/export_jit.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import sys
+import torch
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+from cosyvoice.cli.cosyvoice import CosyVoice
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='export your model for deployment')
+ parser.add_argument('--model_dir',
+ type=str,
+ default='pretrained_models/CosyVoice-300M',
+ help='local path')
+ args = parser.parse_args()
+ print(args)
+ return args
+
+
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+
+ torch._C._jit_set_fusion_strategy([('STATIC', 1)])
+ torch._C._jit_set_profiling_mode(False)
+ torch._C._jit_set_profiling_executor(False)
+
+ cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
+
+ # 1. export llm text_encoder
+ llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
+ script = torch.jit.script(llm_text_encoder)
+ script = torch.jit.freeze(script)
+ script = torch.jit.optimize_for_inference(script)
+ script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
+
+ # 2. export llm llm
+ llm_llm = cosyvoice.model.llm.llm.half()
+ script = torch.jit.script(llm_llm)
+ script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
+ script = torch.jit.optimize_for_inference(script)
+ script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
+
+ # 3. export flow encoder
+ flow_encoder = cosyvoice.model.flow.encoder
+ script = torch.jit.script(flow_encoder)
+ script = torch.jit.freeze(script)
+ script = torch.jit.optimize_for_inference(script)
+ script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cosyvoice/bin/export_jit_cosyvoice2.py b/cosyvoice/bin/export_jit_cosyvoice2.py
new file mode 100644
index 0000000000000000000000000000000000000000..009b4d16abc4459454d8f247b6c2a012e5b3213a
--- /dev/null
+++ b/cosyvoice/bin/export_jit_cosyvoice2.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import sys
+import torch
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+from cosyvoice.cli.cosyvoice import CosyVoice2
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='export your model for deployment')
+ parser.add_argument('--model_dir',
+ type=str,
+ default='pretrained_models/CosyVoice-300M',
+ help='local path')
+ args = parser.parse_args()
+ print(args)
+ return args
+
+
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+
+ torch._C._jit_set_fusion_strategy([('STATIC', 1)])
+ torch._C._jit_set_profiling_mode(False)
+ torch._C._jit_set_profiling_executor(False)
+
+ cosyvoice = CosyVoice2(args.model_dir, load_jit=False, load_onnx=False)
+
+ # 3. export flow encoder
+ flow_encoder = cosyvoice.model.flow.encoder
+ script = torch.jit.script(flow_encoder)
+ script = torch.jit.freeze(script)
+ script = torch.jit.optimize_for_inference(script)
+ script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cosyvoice/bin/export_onnx.py b/cosyvoice/bin/export_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4051f64e7e49cb1054bc346b2b5f722c86a00c0
--- /dev/null
+++ b/cosyvoice/bin/export_onnx.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import sys
+import onnxruntime
+import random
+import torch
+from tqdm import tqdm
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+from cosyvoice.cli.cosyvoice import CosyVoice
+
+
+def get_dummy_input(batch_size, seq_len, out_channels, device):
+ x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+ mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
+ mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+ t = torch.rand((batch_size), dtype=torch.float32, device=device)
+ spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
+ cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+ return x, mask, mu, t, spks, cond
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='export your model for deployment')
+ parser.add_argument('--model_dir',
+ type=str,
+ default='pretrained_models/CosyVoice-300M',
+ help='local path')
+ args = parser.parse_args()
+ print(args)
+ return args
+
+
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+
+ cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
+
+ # 1. export flow decoder estimator
+ estimator = cosyvoice.model.flow.decoder.estimator
+
+ device = cosyvoice.model.device
+ batch_size, seq_len = 1, 256
+ out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
+ x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+ torch.onnx.export(
+ estimator,
+ (x, mask, mu, t, spks, cond),
+ '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+ export_params=True,
+ opset_version=18,
+ do_constant_folding=True,
+ input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
+ output_names=['estimator_out'],
+ dynamic_axes={
+ 'x': {0: 'batch_size', 2: 'seq_len'},
+ 'mask': {0: 'batch_size', 2: 'seq_len'},
+ 'mu': {0: 'batch_size', 2: 'seq_len'},
+ 'cond': {0: 'batch_size', 2: 'seq_len'},
+ 't': {0: 'batch_size'},
+ 'spks': {0: 'batch_size'},
+ 'estimator_out': {0: 'batch_size', 2: 'seq_len'},
+ }
+ )
+
+ # 2. test computation consistency
+ option = onnxruntime.SessionOptions()
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+ option.intra_op_num_threads = 1
+ providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+ estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+ sess_options=option, providers=providers)
+
+ for _ in tqdm(range(10)):
+ x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
+ output_pytorch = estimator(x, mask, mu, t, spks, cond)
+ ort_inputs = {
+ 'x': x.cpu().numpy(),
+ 'mask': mask.cpu().numpy(),
+ 'mu': mu.cpu().numpy(),
+ 't': t.cpu().numpy(),
+ 'spks': spks.cpu().numpy(),
+ 'cond': cond.cpu().numpy()
+ }
+ output_onnx = estimator_onnx.run(None, ort_inputs)[0]
+ torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/cosyvoice/bin/export_onnx_cosyvoice2.py b/cosyvoice/bin/export_onnx_cosyvoice2.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b092a8b38f8dd25833d48e44b6642f0df680aaa
--- /dev/null
+++ b/cosyvoice/bin/export_onnx_cosyvoice2.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import sys
+import onnxruntime
+import random
+import torch
+from tqdm import tqdm
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+from cosyvoice.cli.cosyvoice import CosyVoice2
+
+
+def get_dummy_input(batch_size, seq_len, out_channels, device):
+ x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+ mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
+ mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+ t = torch.rand((batch_size), dtype=torch.float32, device=device)
+ spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
+ cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+ return x, mask, mu, t, spks, cond
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='export your model for deployment')
+ parser.add_argument('--model_dir',
+ type=str,
+ default='pretrained_models/CosyVoice-300M',
+ help='local path')
+ args = parser.parse_args()
+ print(args)
+ return args
+
+
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+
+ cosyvoice = CosyVoice2(args.model_dir, load_jit=False, load_onnx=False)
+
+ # 1. export flow decoder estimator
+ estimator = cosyvoice.model.flow.decoder.estimator
+
+ device = cosyvoice.model.device
+ batch_size, seq_len = 2, 320
+ out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
+ x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+ torch.onnx.export(
+ estimator,
+ (x, mask, mu, t, spks, cond),
+ '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+ export_params=True,
+ opset_version=18,
+ do_constant_folding=True,
+ input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
+ output_names=['estimator_out'],
+ dynamic_axes={
+ 'x': {2: 'seq_len'},
+ 'mask': {2: 'seq_len'},
+ 'mu': {2: 'seq_len'},
+ 'cond': {2: 'seq_len'},
+ 'estimator_out': {2: 'seq_len'},
+ }
+ )
+
+ # 2. test computation consistency
+ option = onnxruntime.SessionOptions()
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+ option.intra_op_num_threads = 1
+ providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+ estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+ sess_options=option, providers=providers)
+
+ for _ in tqdm(range(10)):
+ x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
+ output_pytorch = estimator(x, mask, mu, t, spks, cond)
+ ort_inputs = {
+ 'x': x.cpu().numpy(),
+ 'mask': mask.cpu().numpy(),
+ 'mu': mu.cpu().numpy(),
+ 't': t.cpu().numpy(),
+ 'spks': spks.cpu().numpy(),
+ 'cond': cond.cpu().numpy()
+ }
+ output_onnx = estimator_onnx.run(None, ort_inputs)[0]
+ torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/cosyvoice/bin/export_trt_cosyvoce2.sh b/cosyvoice/bin/export_trt_cosyvoce2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a4eb227787ee18d83489953803a7bac743d51d5e
--- /dev/null
+++ b/cosyvoice/bin/export_trt_cosyvoce2.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
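+# Build a TensorRT engine for the exported flow decoder estimator ONNX model (the paths below are environment specific)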
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/mnt/lyuxiang.lx/data/TensorRT-10.0.1.6-cu124/TensorRT-10.0.1.6/lib:/usr/local/cuda-12.4/lib64
+/mnt/lyuxiang.lx/data/TensorRT-10.0.1.6-cu124/TensorRT-10.0.1.6/bin/trtexec --onnx=/mnt/lyuxiang.lx/CosyVoice_github/pretrained_models/CosyVoice2-0.5B/flow.decoder.estimator.fp32.onnx --saveEngine=/mnt/lyuxiang.lx/CosyVoice_github/pretrained_models/CosyVoice2-0.5B/flow.decoder.estimator.fp16.Volta.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw
diff --git a/cosyvoice/bin/inference.py b/cosyvoice/bin/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cb831a5f3e93ad2c8f5447920acecb8238818b9
--- /dev/null
+++ b/cosyvoice/bin/inference.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import torch
+from torch.utils.data import DataLoader
+import torchaudio
+from hyperpyyaml import load_hyperpyyaml
+from tqdm import tqdm
+from cosyvoice.cli.model import CosyVoiceModel
+from cosyvoice.dataset.dataset import Dataset
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='inference with your model')
+ parser.add_argument('--config', required=True, help='config file')
+ parser.add_argument('--prompt_data', required=True, help='prompt data file')
+    parser.add_argument('--prompt_utt2data', required=True, help='prompt utt2data file')
+ parser.add_argument('--tts_text', required=True, help='tts input file')
+ parser.add_argument('--llm_model', required=True, help='llm model file')
+ parser.add_argument('--flow_model', required=True, help='flow model file')
+ parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
+ parser.add_argument('--gpu',
+ type=int,
+ default=-1,
+ help='gpu id for this rank, -1 for cpu')
+ parser.add_argument('--mode',
+ default='sft',
+ choices=['sft', 'zero_shot'],
+ help='inference mode')
+    parser.add_argument('--result_dir', required=True, help='directory to save synthesized results')
+ args = parser.parse_args()
+ print(args)
+ return args
+
+
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+
+ # Init cosyvoice models from configs
+ use_cuda = args.gpu >= 0 and torch.cuda.is_available()
+ device = torch.device('cuda' if use_cuda else 'cpu')
+ with open(args.config, 'r') as f:
+ configs = load_hyperpyyaml(f)
+
+ model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
+ model.load(args.llm_model, args.flow_model, args.hifigan_model)
+
+ test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
+ tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
+ test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
+
+ del configs
+ os.makedirs(args.result_dir, exist_ok=True)
+ fn = os.path.join(args.result_dir, 'wav.scp')
+ f = open(fn, 'w')
+ with torch.no_grad():
+ for _, batch in tqdm(enumerate(test_data_loader)):
+ utts = batch["utts"]
+            assert len(utts) == 1, "inference mode only supports batch size 1"
+ text_token = batch["text_token"].to(device)
+ text_token_len = batch["text_token_len"].to(device)
+ tts_index = batch["tts_index"]
+ tts_text_token = batch["tts_text_token"].to(device)
+ tts_text_token_len = batch["tts_text_token_len"].to(device)
+ speech_token = batch["speech_token"].to(device)
+ speech_token_len = batch["speech_token_len"].to(device)
+ speech_feat = batch["speech_feat"].to(device)
+ speech_feat_len = batch["speech_feat_len"].to(device)
+ utt_embedding = batch["utt_embedding"].to(device)
+ spk_embedding = batch["spk_embedding"].to(device)
+ if args.mode == 'sft':
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+ 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
+ else:
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+ 'prompt_text': text_token, 'prompt_text_len': text_token_len,
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+ 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
+ tts_speeches = []
+ for model_output in model.tts(**model_input):
+ tts_speeches.append(model_output['tts_speech'])
+ tts_speeches = torch.concat(tts_speeches, dim=1)
+ tts_key = '{}_{}'.format(utts[0], tts_index[0])
+ tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
+ torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
+ f.write('{} {}\n'.format(tts_key, tts_fn))
+ f.flush()
+ f.close()
+ logging.info('Result wav.scp saved in {}'.format(fn))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cosyvoice/bin/train.py b/cosyvoice/bin/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b4710e4df9a13f3b6a41f8da6ae38b55af1d8ac
--- /dev/null
+++ b/cosyvoice/bin/train.py
@@ -0,0 +1,170 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import argparse
+import datetime
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+from copy import deepcopy
+import os
+import torch
+import torch.distributed as dist
+import deepspeed
+
+from hyperpyyaml import load_hyperpyyaml
+
+from torch.distributed.elastic.multiprocessing.errors import record
+
+from cosyvoice.utils.executor import Executor
+from cosyvoice.utils.train_utils import (
+ init_distributed,
+ init_dataset_and_dataloader,
+ init_optimizer_and_scheduler,
+ init_summarywriter, save_model,
+ wrap_cuda_model, check_modify_and_save_config)
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='training your network')
+ parser.add_argument('--train_engine',
+ default='torch_ddp',
+ choices=['torch_ddp', 'deepspeed'],
+                        help='Engine for parallel training')
+ parser.add_argument('--model', required=True, help='model which will be trained')
+ parser.add_argument('--config', required=True, help='config file')
+ parser.add_argument('--train_data', required=True, help='train data file')
+ parser.add_argument('--cv_data', required=True, help='cv data file')
+ parser.add_argument('--checkpoint', help='checkpoint model')
+ parser.add_argument('--model_dir', required=True, help='save model dir')
+ parser.add_argument('--tensorboard_dir',
+ default='tensorboard',
+ help='tensorboard log dir')
+ parser.add_argument('--ddp.dist_backend',
+ dest='dist_backend',
+ default='nccl',
+ choices=['nccl', 'gloo'],
+ help='distributed backend')
+ parser.add_argument('--num_workers',
+ default=0,
+ type=int,
+ help='num of subprocess workers for reading')
+ parser.add_argument('--prefetch',
+ default=100,
+ type=int,
+ help='prefetch number')
+ parser.add_argument('--pin_memory',
+ action='store_true',
+ default=False,
+                        help='Use pinned memory buffers for reading')
+ parser.add_argument('--use_amp',
+ action='store_true',
+ default=False,
+ help='Use automatic mixed precision training')
+ parser.add_argument('--deepspeed.save_states',
+ dest='save_states',
+ default='model_only',
+ choices=['model_only', 'model+optimizer'],
+ help='save model/optimizer states')
+ parser.add_argument('--timeout',
+ default=60,
+ type=int,
+ help='timeout (in seconds) of cosyvoice_join.')
+ parser = deepspeed.add_config_arguments(parser)
+ args = parser.parse_args()
+ return args
+
+
+@record
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+    # GAN training has some special initialization logic
+ gan = True if args.model == 'hifigan' else False
+
+ override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
+ if gan is True:
+ override_dict.pop('hift')
+ with open(args.config, 'r') as f:
+ configs = load_hyperpyyaml(f, overrides=override_dict)
+ if gan is True:
+ configs['train_conf'] = configs['train_conf_gan']
+ configs['train_conf'].update(vars(args))
+
+ # Init env for ddp
+ init_distributed(args)
+
+ # Get dataset & dataloader
+ train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
+ init_dataset_and_dataloader(args, configs, gan)
+
+    # Do some sanity checks and save config to args.model_dir
+ configs = check_modify_and_save_config(args, configs)
+
+ # Tensorboard summary
+ writer = init_summarywriter(args)
+
+ # load checkpoint
+ model = configs[args.model]
+ start_step, start_epoch = 0, -1
+ if args.checkpoint is not None:
+ if os.path.exists(args.checkpoint):
+ state_dict = torch.load(args.checkpoint, map_location='cpu')
+ model.load_state_dict(state_dict, strict=False)
+ if 'step' in state_dict:
+ start_step = state_dict['step']
+ if 'epoch' in state_dict:
+ start_epoch = state_dict['epoch']
+ else:
+            logging.warning('checkpoint {} does not exist!'.format(args.checkpoint))
+
+ # Dispatch model from cpu to gpu
+ model = wrap_cuda_model(args, model)
+
+ # Get optimizer & scheduler
+ model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
+ scheduler.set_step(start_step)
+ if scheduler_d is not None:
+ scheduler_d.set_step(start_step)
+
+ # Save init checkpoints
+ info_dict = deepcopy(configs['train_conf'])
+ info_dict['step'] = start_step
+ info_dict['epoch'] = start_epoch
+ save_model(model, 'init', info_dict)
+
+ # Get executor
+ executor = Executor(gan=gan)
+ executor.step = start_step
+
+ # Init scaler, used for pytorch amp mixed precision training
+ scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
+ print('start step {} start epoch {}'.format(start_step, start_epoch))
+ # Start training loop
+ for epoch in range(start_epoch + 1, info_dict['max_epoch']):
+ executor.epoch = epoch
+ train_dataset.set_epoch(epoch)
+ dist.barrier()
+ group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
+ if gan is True:
+ executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
+ writer, info_dict, scaler, group_join)
+ else:
+ executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
+ dist.destroy_process_group(group_join)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cosyvoice/cli/__init__.py b/cosyvoice/cli/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f642348ee0f2a0ce4d9b80a27fa309f03455134
--- /dev/null
+++ b/cosyvoice/cli/cosyvoice.py
@@ -0,0 +1,167 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import time
+from tqdm import tqdm
+from hyperpyyaml import load_hyperpyyaml
+from modelscope import snapshot_download
+import torch
+from cosyvoice.cli.frontend import CosyVoiceFrontEnd
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
+from cosyvoice.utils.file_utils import logging
+
+
+class CosyVoice:
+
+ def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
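+        # Download the model from ModelScope when model_dir is not a local path, then build the frontend and model from cosyvoice.yaml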
+ instruct = True if '-Instruct' in model_dir else False
+ self.model_dir = model_dir
+ if not os.path.exists(model_dir):
+ model_dir = snapshot_download(model_dir)
+ with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
+ configs = load_hyperpyyaml(f)
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+ configs['feat_extractor'],
+ '{}/campplus.onnx'.format(model_dir),
+ '{}/speech_tokenizer_v1.onnx'.format(model_dir),
+ '{}/spk2info.pt'.format(model_dir),
+ instruct,
+ configs['allowed_special'])
+ self.sample_rate = configs['sample_rate']
+ if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
+ load_jit = False
+ fp16 = False
+            logging.warning('cpu does not support fp16 and jit, forcing them to False')
+ self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
+ self.model.load('{}/llm.pt'.format(model_dir),
+ '{}/flow.pt'.format(model_dir),
+ '{}/hift.pt'.format(model_dir))
+ if load_jit:
+ self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
+ '{}/llm.llm.fp16.zip'.format(model_dir),
+ '{}/flow.encoder.fp32.zip'.format(model_dir))
+ if load_onnx:
+ self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+ del configs
+
+ def list_avaliable_spks(self):
+ spks = list(self.frontend.spk2info.keys())
+ return spks
+
+ def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0):
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
+ model_input = self.frontend.frontend_sft(i, spk_id)
+ start_time = time.time()
+ logging.info('synthesis text {}'.format(i))
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+ yield model_output
+ start_time = time.time()
+
+ def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0):
+ prompt_text = self.frontend.text_normalize(prompt_text, split=False)
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
+ if len(i) < 0.5 * len(prompt_text):
+                logging.warning('synthesis text {} is much shorter than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
+ model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
+ start_time = time.time()
+ logging.info('synthesis text {}'.format(i))
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+ logging.info('yield speech len {}, rtf {}, abs mean {}, std {}'.format(speech_len, (time.time() - start_time) / speech_len, model_output['tts_speech'].abs().mean(), model_output['tts_speech'].std()))
+ yield model_output
+ start_time = time.time()
+
+ def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0):
+ if self.frontend.instruct is True:
+            raise ValueError('{} does not support cross_lingual inference'.format(self.model_dir))
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
+ model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
+ start_time = time.time()
+ logging.info('synthesis text {}'.format(i))
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+ yield model_output
+ start_time = time.time()
+
+ def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0):
+ if self.frontend.instruct is False:
+            raise ValueError('{} does not support instruct inference'.format(self.model_dir))
+ instruct_text = self.frontend.text_normalize(instruct_text, split=False)
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
+ model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
+ start_time = time.time()
+ logging.info('synthesis text {}'.format(i))
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+ yield model_output
+ start_time = time.time()
+
+ def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0):
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
+ model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
+ start_time = time.time()
+ logging.info('synthesis text {}'.format(i))
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+ logging.info('yield speech len {}, rtf {}, abs mean {}, std {}'.format(speech_len, (time.time() - start_time) / speech_len, model_output['tts_speech'].abs().mean(), model_output['tts_speech'].std()))
+ yield model_output
+ start_time = time.time()
+
+ def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
+ model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
+ start_time = time.time()
+ for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+ yield model_output
+ start_time = time.time()
+
+class CosyVoice2(CosyVoice):
+
+ def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
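+        # CosyVoice2 uses the v2 speech tokenizer and resolves its LLM weights via the qwen_pretrain_path override (CosyVoice-BlankEN inside model_dir)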
+ instruct = True if '-Instruct' in model_dir else False
+ self.model_dir = model_dir
+ if not os.path.exists(model_dir):
+ model_dir = snapshot_download(model_dir)
+ with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
+ configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+ configs['feat_extractor'],
+ '{}/campplus.onnx'.format(model_dir),
+ '{}/speech_tokenizer_v2.onnx'.format(model_dir),
+ '{}/spk2info.pt'.format(model_dir),
+ instruct,
+ configs['allowed_special'])
+ self.sample_rate = configs['sample_rate']
+ if torch.cuda.is_available() is False and load_jit is True:
+ load_jit = False
+            logging.warning('cpu does not support jit, forcing it to False')
+ self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
+ self.model.load('{}/llm.pt'.format(model_dir),
+ '{}/flow.pt'.format(model_dir),
+ '{}/hift.pt'.format(model_dir))
+ if load_jit:
+ self.model.load_jit('{}/flow.encoder.fp32.zip'.format(model_dir))
+ if load_trt is True and load_onnx is True:
+ load_onnx = False
+ logging.warning('can not set both load_trt and load_onnx to True, force set load_onnx to False')
+ if load_onnx:
+ self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+ if load_trt:
+ self.model.load_trt('{}/flow.decoder.estimator.fp16.A10.plan'.format(model_dir))
+ del configs
\ No newline at end of file
diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py
new file mode 100644
index 0000000000000000000000000000000000000000..d312cccb8cfbb92a61e8fdd9b066b6ee1ae15da9
--- /dev/null
+++ b/cosyvoice/cli/frontend.py
@@ -0,0 +1,213 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+import json
+import onnxruntime
+import torch
+import numpy as np
+import whisper
+from typing import Callable
+import torchaudio.compliance.kaldi as kaldi
+import torchaudio
+import os
+import re
+import inflect
+try:
+ import ttsfrd
+ use_ttsfrd = True
+except ImportError:
+ print("failed to import ttsfrd, use WeTextProcessing instead")
+ from tn.chinese.normalizer import Normalizer as ZhNormalizer
+ from tn.english.normalizer import Normalizer as EnNormalizer
+ use_ttsfrd = False
+from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
+
+
+class CosyVoiceFrontEnd:
+
+ def __init__(self,
+ get_tokenizer: Callable,
+ feat_extractor: Callable,
+ campplus_model: str,
+ speech_tokenizer_model: str,
+ spk2info: str = '',
+ instruct: bool = False,
+ allowed_special: str = 'all'):
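+ # the frontend turns raw text and prompt audio into the tensors the TTS model consumes:
+ # text tokens, speech tokens, speech features and speaker embeddings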
+ self.tokenizer = get_tokenizer()
+ self.feat_extractor = feat_extractor
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ option = onnxruntime.SessionOptions()
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+ option.intra_op_num_threads = 1
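+ # campplus (speaker embedding) runs on CPU; the speech tokenizer uses CUDA when available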
+ self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
+ self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
+ providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
+ "CPUExecutionProvider"])
+ if os.path.exists(spk2info):
+ self.spk2info = torch.load(spk2info, map_location=self.device)
+ else:
+ self.spk2info = {}
+ self.instruct = instruct
+ self.allowed_special = allowed_special
+ self.inflect_parser = inflect.engine()
+ self.use_ttsfrd = use_ttsfrd
+ if self.use_ttsfrd:
+ self.frd = ttsfrd.TtsFrontendEngine()
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+ assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
+ 'failed to initialize ttsfrd resource'
+ self.frd.set_lang_type('pinyinvg')
+ else:
+ self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
+ self.en_tn_model = EnNormalizer()
+
+ def _extract_text_token(self, text):
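+ # encode text with the BPE tokenizer and return (token ids, length) as int32 tensors on self.device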
+ text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
+ text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
+ text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+ return text_token, text_token_len
+
+ def _extract_speech_token(self, speech):
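+ # compute 128-bin whisper log-mel features and run the speech tokenizer ONNX model;
+ # prompts longer than 30 s are rejected below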
+ assert speech.shape[1] / 16000 <= 30, 'speech token extraction does not support audio longer than 30s'
+ feat = whisper.log_mel_spectrogram(speech, n_mels=128)
+ speech_token = self.speech_tokenizer_session.run(None,
+ {self.speech_tokenizer_session.get_inputs()[0].name:
+ feat.detach().cpu().numpy(),
+ self.speech_tokenizer_session.get_inputs()[1].name:
+ np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+ speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
+ speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
+ return speech_token, speech_token_len
+
+ def _extract_spk_embedding(self, speech):
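+ # 80-dim fbank with per-utterance mean normalization, fed to the campplus ONNX model
+ # to obtain a speaker embedding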
+ feat = kaldi.fbank(speech,
+ num_mel_bins=80,
+ dither=0,
+ sample_frequency=16000)
+ feat = feat - feat.mean(dim=0, keepdim=True)
+ embedding = self.campplus_session.run(None,
+ {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+ embedding = torch.tensor([embedding]).to(self.device)
+ return embedding
+
+ def _extract_speech_feat(self, speech):
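+ # acoustic features from the configured feat_extractor, shaped (batch, time, n_mels),
+ # returned together with an explicit length tensor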
+ speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
+ speech_feat = speech_feat.unsqueeze(dim=0)
+ speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
+ return speech_feat, speech_feat_len
+
+ def text_normalize(self, text, split=True):
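+ # normalize text with ttsfrd (if available) or WeTextProcessing, apply language-dependent
+ # cleanup, then optionally split it into token-length-bounded segments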
+ text = text.strip()
+ if contains_chinese(text):
+ if self.use_ttsfrd:
+ texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
+ text = ''.join(texts)
+ else:
+ text = self.zh_tn_model.normalize(text)
+ text = text.replace("\n", "")
+ text = replace_blank(text)
+ text = replace_corner_mark(text)
+ text = text.replace(".", "。")
+ text = text.replace(" - ", ",")
+ text = remove_bracket(text)
+ text = re.sub(r'[,,、]+$', '。', text)
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
+ token_min_n=60, merge_len=20, comma_split=False))
+ else:
+ if self.use_ttsfrd:
+ texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
+ text = ''.join(texts)
+ else:
+ text = self.en_tn_model.normalize(text)
+ text = spell_out_number(text, self.inflect_parser)
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
+ token_min_n=60, merge_len=20, comma_split=False))
+ if split is False:
+ return text
+ return texts
+
+ def frontend_sft(self, tts_text, spk_id):
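+ # multi-speaker TTS with a pre-registered speaker: look up its embedding in spk2info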
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+ embedding = self.spk2info[spk_id]['embedding']
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
+ return model_input
+
+ def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
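+ # zero-shot cloning: tokenize tts/prompt text, extract speech tokens, features and a
+ # speaker embedding from the prompt audio, and pack everything into the model input dict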
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+ prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
+ prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+ speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+ if resample_rate == 24000:
+ # cosyvoice2: keep speech_feat exactly twice as long as speech_token
+ token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
+ speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
+ speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+ 'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+ 'llm_embedding': embedding, 'flow_embedding': embedding}
+ return model_input
+
+ def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
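+ # like frontend_zero_shot, but the prompt text is the instruction followed by the
+ # <|endofprompt|> marker, and the llm prompt speech tokens are omitted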
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+ prompt_text_token, prompt_text_token_len = self._extract_text_token(instruct_text + '<|endofprompt|>')
+ prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+ speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+ if resample_rate == 24000:
+ # cosyvoice2: keep speech_feat exactly twice as long as speech_token
+ token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
+ speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
+ speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+ 'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+ 'llm_embedding': embedding, 'flow_embedding': embedding}
+ return model_input
+
+ def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
+ model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
+ # in cross-lingual mode, we remove the text prompt for the llm
+ del model_input['prompt_text']
+ del model_input['prompt_text_len']
+ del model_input['llm_prompt_speech_token']
+ del model_input['llm_prompt_speech_token_len']
+ return model_input
+
+ def frontend_instruct(self, tts_text, spk_id, instruct_text):
+ model_input = self.frontend_sft(tts_text, spk_id)
+ # in instruct mode, we remove the llm spk_embedding to avoid information leakage
+ del model_input['llm_embedding']
+ instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
+
+