Spaces:
Build error
Build error
| import os | |
| import sys | |
| import argparse | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torchaudio | |
| import random # 即使没有随机种子UI,set_all_random_seed可能还用 | |
| import librosa | |
| ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| sys.path.append(os.path.join(ROOT_DIR, 'third_party', 'Matcha-TTS')) # 使用os.path.join更安全 | |
| from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 | |
| from cosyvoice.utils.file_utils import load_wav, logging | |
| from cosyvoice.utils.common import set_all_random_seed | |
max_val = 0.8  # peak-amplitude ceiling that postprocess() scales audio down to
cosyvoice = None  # model instance; assigned in the __main__ block once loading succeeds
prompt_sr = 16000  # sample rate (Hz) the prompt audio is loaded at and validated against
default_data = None  # silent fallback audio; defined after cosyvoice has been initialized
def postprocess(speech, top_db=60, hop_length=220, win_length=440):
    """
    Clean up a waveform: trim edge silence, cap the peak, append trailing silence.

    Args:
        speech: torch.Tensor waveform. When it has more than one dimension,
            only index 0 along the first axis is kept.
        top_db: silence threshold forwarded to librosa.effects.trim.
        hop_length: hop size forwarded to librosa.effects.trim.
        win_length: forwarded to librosa.effects.trim as frame_length.

    Returns:
        A float tensor on the same device as ``speech``. For 1-D or 2-D input
        the result is shaped (1, N'), where N' is the trimmed length plus
        int(cosyvoice.sample_rate * 0.2) zero samples.
    """
    device = speech.device

    # librosa operates on numpy arrays, so leave torch for the trimming step.
    samples = speech.cpu().numpy()
    if samples.ndim > 1:
        samples = samples[0]

    trimmed, _ = librosa.effects.trim(
        samples,
        top_db=top_db,
        frame_length=win_length,
        hop_length=hop_length,
    )

    wav = torch.from_numpy(trimmed).to(device).float()
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)

    # Only scale down; audio already below the ceiling is left untouched.
    peak = wav.abs().max()
    if peak > max_val:
        wav = wav / peak * max_val

    tail = torch.zeros(
        1,
        int(cosyvoice.sample_rate * 0.2),
        device=wav.device,
        dtype=wav.dtype,
    )
    return torch.cat([wav, tail], dim=1)
def generate_audio(
    tts_text: str,
    prompt_wav_upload: str,
    prompt_wav_record: str,
    prompt_text: str
):
    """
    Synthesize speech for ``tts_text`` in the voice of a prompt recording
    (zero-shot mode only).

    Args:
        tts_text: text to synthesize.
        prompt_wav_upload: path of an uploaded prompt audio file, or None.
        prompt_wav_record: path of a recorded prompt audio file, or None.
            Used only when no upload is given.
        prompt_text: transcript of the prompt audio.

    Returns:
        ``(cosyvoice.sample_rate, audio)`` where ``audio`` is a 1-D numpy
        array holding every synthesized segment back to back.
        ``None`` when the model is not loaded or an input fails validation
        (the user is notified through ``gr.Info``).
        ``(cosyvoice.sample_rate, default_data)`` when inference itself fails.
    """
    if cosyvoice is None:
        gr.Info("模型未初始化,请检查启动配置。")
        return None

    # An uploaded file takes precedence over a microphone recording.
    prompt_wav = prompt_wav_upload if prompt_wav_upload is not None else prompt_wav_record
    if prompt_wav is None:
        gr.Info('prompt音频为空,您是否忘记输入prompt音频?')
        return None

    # Only the metadata read can raise; keep the try body to that one call.
    try:
        info = torchaudio.info(prompt_wav)
    except Exception as e:
        gr.Info(f"无法读取 prompt 音频信息,请检查文件格式或损坏:{e}")
        return None
    if info.sample_rate < prompt_sr:
        gr.Info(f"prompt 音频采样率过低:{info.sample_rate} < {prompt_sr}")
        return None

    if not prompt_text:
        gr.Info('prompt文本为空,您是否忘记输入prompt文本?')
        return None

    # Reject blank synthesis text up front instead of letting it surface as an
    # opaque inference error.
    if not tts_text or not tts_text.strip():
        gr.Info('合成文本为空,您是否忘记输入合成文本?')
        return None

    try:
        # load_wav(path, sr) is used here as returning the waveform tensor directly.
        prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
    except Exception as e:
        gr.Info(f"处理 prompt 音频时出错:{e}")
        return None

    # Fixed seed so repeated requests with the same inputs are reproducible.
    set_all_random_seed(0)
    logging.info("执行 3s 极速复刻 推理")
    try:
        # inference_zero_shot is consumed as an iterator. Taking only its first
        # item would drop any further segments it yields, so collect them all.
        chunks = [
            segment["tts_speech"].detach().cpu().numpy().flatten()
            for segment in cosyvoice.inference_zero_shot(
                tts_text,
                prompt_text,
                prompt_speech_16k,
                stream=False,
                speed=1.0
            )
        ]
        if not chunks:
            raise RuntimeError("model produced no audio segments")
        return cosyvoice.sample_rate, np.concatenate(chunks)
    except Exception as e:
        gr.Info(f"推理过程中发生错误:{e}")
        # Return silence rather than None so the Audio output still gets valid data.
        return cosyvoice.sample_rate, default_data
def main():
    """
    Build the Gradio interface and start serving it.

    The page has a synthesis-text box, two prompt-audio inputs (file upload
    and microphone) side by side, a prompt-transcript box, a button that runs
    generate_audio, and an audio player for the result. The server listens on
    0.0.0.0 at the port taken from the module-level ``args``.
    """
    with gr.Blocks() as demo:
        gr.Markdown("### SMIIP-NV finetune CosyVoice2")
        gr.Markdown("#### 上传一段 ≤30s 的 prompt 音频,填写对应文本,合成目标语音。")

        text_box = gr.Textbox(
            label="输入合成文本",
            lines=1,
            value="在这个孤独的夜晚<crying>,窗外的雨声让我想起了你,<crying>我真的好想你。",
        )

        # Both prompt sources hand generate_audio a file path.
        with gr.Row():
            upload_input = gr.Audio(
                sources=['upload'],
                type='filepath',
                label='选择prompt音频文件,注意采样率不低于16khz',
            )
            record_input = gr.Audio(
                sources=['microphone'],
                type='filepath',
                label='录制prompt音频文件',
            )

        transcript_box = gr.Textbox(
            label="输入prompt文本",
            lines=1,
            placeholder="请输入prompt文本,需与prompt音频内容一致,暂时不支持自动识别...",
            value='',
        )
        run_button = gr.Button("生成音频")
        player = gr.Audio(label="合成音频", autoplay=True, streaming=True)

        run_button.click(
            generate_audio,
            inputs=[text_box, upload_input, record_input, transcript_box],
            outputs=[player],
        )

    demo.queue(max_size=4, default_concurrency_limit=2)
    demo.launch(server_name='0.0.0.0', server_port=args.port)
if __name__ == '__main__':
    # Command-line options: serving port and model location.
    parser = argparse.ArgumentParser()
    parser.add_argument('--port', type=int, default=8000, help="服务启动端口")
    parser.add_argument(
        '--model_dir',
        type=str,
        default='pretrained_models/CosyVoice2-0.5B',
        help='local path or modelscope repo id',
    )
    args = parser.parse_args()

    # Try the CosyVoice loader first and fall back to CosyVoice2. If both
    # fail, abort with the second loader's error.
    try:
        cosyvoice = CosyVoice(args.model_dir)
    except Exception as e:
        print(f"加载 CosyVoice 模型失败:{e},尝试加载 CosyVoice2...")
        try:
            cosyvoice = CosyVoice2(args.model_dir)
        except Exception as e2:
            print(f"加载 CosyVoice2 模型也失败了:{e2}")
            raise TypeError('no valid model_type found for model_dir: ' + args.model_dir + f'\nError: {e2}')
        else:
            print("CosyVoice2 模型加载成功!")
    else:
        print("CosyVoice 模型加载成功!")

    # One second of silence at the model's sample rate, returned by
    # generate_audio when inference fails.
    default_data = np.zeros(cosyvoice.sample_rate, dtype=np.float32)

    main()