| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import sys |
| import argparse |
| import gradio as gr |
| import numpy as np |
| import torch |
| import torchaudio |
| import random |
| import librosa |
|
|
| ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) |
| sys.path.append(os.path.join(ROOT_DIR, 'third_party', 'Matcha-TTS')) |
|
|
| from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 |
| from cosyvoice.utils.file_utils import load_wav, logging |
| from cosyvoice.utils.common import set_all_random_seed |
|
|
| |
| |
| |
| |
|
|
# Peak-amplitude ceiling that postprocess() scales the waveform down to.
max_val = 0.8

# Sample rate (Hz) the prompt audio is loaded at and validated against.
prompt_sr = 16000

# Both are placeholders assigned by the __main__ block at startup: the loaded
# model object, and the clip returned when inference fails.
cosyvoice = None
default_data = None
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
def postprocess(speech, top_db=60, hop_length=220, win_length=440):
    """Trim silence from a waveform, cap its peak, and append trailing silence.

    Args:
        speech: ``torch.Tensor`` waveform, either ``(N,)`` or ``(C, N)``.
            For a multi-dimensional input only index 0 of the first axis
            is processed.
        top_db: Forwarded to ``librosa.effects.trim`` as ``top_db``.
        hop_length: Forwarded to ``librosa.effects.trim`` as ``hop_length``.
        win_length: Forwarded to ``librosa.effects.trim`` as ``frame_length``.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, N')`` on the same device as
        ``speech``: the trimmed (and, if needed, rescaled) signal followed by
        ``int(cosyvoice.sample_rate * 0.2)`` zero samples.
    """
    # librosa operates on NumPy arrays. detach() lets tensors that track
    # gradients be converted as well.
    samples = speech.detach().cpu().numpy()
    if samples.ndim > 1:
        samples = samples[0]

    trimmed_np, _ = librosa.effects.trim(
        samples,
        top_db=top_db,
        frame_length=win_length,
        hop_length=hop_length,
    )

    trimmed = torch.from_numpy(trimmed_np).to(speech.device).float()
    if trimmed.ndim == 1:
        trimmed = trimmed.unsqueeze(0)

    # max() raises on a tensor with no elements (e.g. zero-length input), so
    # only look at the peak when there is something to measure.
    if trimmed.numel() > 0:
        peak = trimmed.abs().max()
        if peak > max_val:
            # Scale down so the loudest sample sits exactly at max_val.
            trimmed = trimmed / peak * max_val

    # NOTE(review): the tail length is derived from the model's output sample
    # rate, not from the rate of `speech` -- confirm this is intended.
    tail = torch.zeros(
        1,
        int(cosyvoice.sample_rate * 0.2),
        device=trimmed.device,
        dtype=trimmed.dtype,
    )
    return torch.cat([trimmed, tail], dim=1)
|
|
|
|
| |
| |
| |
|
|
|
|
def generate_audio(
    tts_text: str,
    prompt_wav_upload: str,
    prompt_wav_record: str,
    prompt_text: str
):
    """Synthesize ``tts_text`` in the voice of a prompt recording.

    Only the zero-shot ("3s" cloning) mode is supported.

    Args:
        tts_text: Text to synthesize.
        prompt_wav_upload: File path of an uploaded prompt recording, or
            ``None``. Takes precedence over ``prompt_wav_record``.
        prompt_wav_record: File path of a microphone recording, or ``None``.
        prompt_text: Text associated with the prompt recording.

    Returns:
        ``None`` when the model is not loaded or the inputs fail validation
        or preprocessing (a message is shown through ``gr.Info``).
        Otherwise ``(cosyvoice.sample_rate, audio)`` with ``audio`` a 1-D
        NumPy array. If inference itself fails, ``audio`` is the
        module-level ``default_data``.
    """
    if cosyvoice is None:
        gr.Info("模型未初始化,请检查启动配置。")
        return None

    # An uploaded file wins over a microphone recording; the result is None
    # when neither was provided.
    prompt_wav = prompt_wav_upload if prompt_wav_upload is not None else prompt_wav_record
    if prompt_wav is None:
        gr.Info('prompt音频为空,您是否忘记输入prompt音频?')
        return None

    # Probe the file header first so unreadable files and recordings below
    # the required sample rate are rejected before any heavier work.
    try:
        info = torchaudio.info(prompt_wav)
    except Exception as e:
        gr.Info(f"无法读取 prompt 音频信息,请检查文件格式或损坏:{e}")
        return None
    if info.sample_rate < prompt_sr:
        gr.Info(f"prompt 音频采样率过低:{info.sample_rate} < {prompt_sr}")
        return None

    if not prompt_text:
        gr.Info('prompt文本为空,您是否忘记输入prompt文本?')
        return None

    try:
        prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
    except Exception as e:
        gr.Info(f"处理 prompt 音频时出错:{e}")
        return None

    # Fixed seed so repeated requests are reproducible.
    set_all_random_seed(0)
    logging.info("执行 3s 极速复刻 推理")

    try:
        # inference_zero_shot is consumed as an iterator. Taking only its
        # first item would drop any later items, so collect them all.
        # NOTE(review): upstream CosyVoice appears to yield one item per
        # normalized text segment -- confirm.
        segments = [
            item["tts_speech"].detach().cpu().numpy().flatten()
            for item in cosyvoice.inference_zero_shot(
                tts_text,
                prompt_text,
                prompt_speech_16k,
                stream=False,
                speed=1.0,
            )
        ]
        if not segments:
            raise RuntimeError("模型未返回任何音频")
        audio = np.concatenate(segments)
    except Exception as e:
        gr.Info(f"推理过程中发生错误:{e}")
        return cosyvoice.sample_rate, default_data

    return cosyvoice.sample_rate, audio
|
|
|
|
def main():
    """Build the Gradio page and start serving it.

    The listening port is read from the module-level ``args``; the server
    binds to all interfaces.
    """
    with gr.Blocks() as demo:
        # Page header.
        gr.Markdown("### SMIIP-NV finetune CosyVoice2")
        gr.Markdown("#### 上传一段 ≤30s 的 prompt 音频,填写对应文本,合成目标语音。")

        # Text to be synthesized.
        text_box = gr.Textbox(
            label="输入合成文本",
            lines=1,
            value="在这个孤独的夜晚<crying>,窗外的雨声让我想起了你,<crying>我真的好想你。",
        )

        # Two alternative prompt sources, shown side by side.
        with gr.Row():
            upload_widget = gr.Audio(
                sources=['upload'],
                type='filepath',
                label='选择prompt音频文件,注意采样率不低于16khz',
            )
            record_widget = gr.Audio(
                sources=['microphone'],
                type='filepath',
                label='录制prompt音频文件',
            )
        prompt_box = gr.Textbox(
            label="输入prompt文本",
            lines=1,
            placeholder="请输入prompt文本,需与prompt音频内容一致,暂时不支持自动识别...",
            value='',
        )

        run_button = gr.Button("生成音频")
        result_player = gr.Audio(label="合成音频", autoplay=True, streaming=True)

        # Wire the button to the synthesis callback.
        run_button.click(
            generate_audio,
            inputs=[text_box, upload_widget, record_widget, prompt_box],
            outputs=[result_player],
        )

    demo.queue(max_size=4, default_concurrency_limit=2)
    demo.launch(server_name='0.0.0.0', server_port=args.port)
|
|
|
|
if __name__ == '__main__':
    # Command-line options: listening port and model location.
    parser = argparse.ArgumentParser()
    parser.add_argument('--port',
                        type=int,
                        default=8000,
                        help="服务启动端口")
    parser.add_argument('--model_dir',
                        type=str,
                        default='pretrained_models/CosyVoice2-0.5B',
                        help='local path or modelscope repo id')
    args = parser.parse_args()

    # Try the CosyVoice loader first and fall back to CosyVoice2 if it
    # rejects the model directory.
    try:
        cosyvoice = CosyVoice(args.model_dir)
        print("CosyVoice 模型加载成功!")
    except Exception as e:
        print(f"加载 CosyVoice 模型失败:{e},尝试加载 CosyVoice2...")
        try:
            cosyvoice = CosyVoice2(args.model_dir)
            print("CosyVoice2 模型加载成功!")
        except Exception as e2:
            print(f"加载 CosyVoice2 模型也失败了:{e2}")
            # Chain explicitly so the traceback names the CosyVoice2 failure
            # as the direct cause.
            raise TypeError(
                'no valid model_type found for model_dir: ' + args.model_dir + f'\nError: {e2}'
            ) from e2

    # Fallback clip returned when inference fails: `sample_rate` zero
    # samples, i.e. one second of silence.
    default_data = np.zeros(cosyvoice.sample_rate, dtype=np.float32)

    main()
|
|