# web_nv.py — Gradio web demo for CosyVoice 3s zero-shot voice cloning
# (upstream commit 5fa7f1e, updated by xunyi)
import os
import sys
import argparse
import gradio as gr
import numpy as np
import torch
import torchaudio
import random  # kept even without a seed UI: set_all_random_seed may still rely on it
import librosa
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(ROOT_DIR, 'third_party', 'Matcha-TTS'))  # os.path.join is safer than string concatenation
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav, logging
from cosyvoice.utils.common import set_all_random_seed
# --- Module-level shared state (populated at startup in __main__) ---
max_val = 0.8 # peak-normalization ceiling applied in postprocess()
cosyvoice = None # CosyVoice/CosyVoice2 model instance; loaded in __main__
prompt_sr = 16000 # required minimum sample rate (Hz) for the prompt audio
default_data = None # silent fallback waveform; defined after cosyvoice is initialized
def postprocess(speech, top_db=60, hop_length=220, win_length=440):
    """Clean up a prompt waveform for inference.

    Converts to mono, trims leading/trailing silence, peak-limits the
    signal, and appends a 0.2 s silent tail.

    Args:
        speech: torch.Tensor of shape (N,) or (C, N).
        top_db: dB below peak that librosa treats as silence when trimming.
        hop_length: hop size (samples) for the trim analysis frames.
        win_length: frame length (samples) for the trim analysis frames.

    Returns:
        torch.Tensor of shape (1, N'), float32, on the input's device.
    """
    # librosa operates on numpy arrays; convert and collapse to mono first.
    mono = speech.cpu().numpy()
    if mono.ndim > 1:
        mono = mono[0]  # keep only the first channel of (C, N) input

    # Strip silence from both ends.
    trimmed, _ = librosa.effects.trim(
        mono,
        top_db=top_db,
        frame_length=win_length,
        hop_length=hop_length,
    )

    # Back to a float tensor on the original device, shaped (1, N).
    out = torch.from_numpy(trimmed).to(speech.device).float()
    if out.ndim == 1:
        out = out.unsqueeze(0)

    # Peak-normalize only when the signal exceeds the ceiling.
    peak = out.abs().max()
    if peak > max_val:
        out = out / peak * max_val

    # Append 0.2 s of silence at the model's output sample rate.
    tail = torch.zeros(
        1,
        int(cosyvoice.sample_rate * 0.2),
        device=out.device,
        dtype=out.dtype,
    )
    return torch.cat([out, tail], dim=1)
def generate_audio(
    tts_text: str,
    prompt_wav_upload: str,
    prompt_wav_record: str,
    prompt_text: str
):
    """Synthesize speech from text via 3-second zero-shot voice cloning.

    Validates the prompt audio/text, runs `cosyvoice.inference_zero_shot`,
    and returns a (sample_rate, waveform) tuple for the Gradio Audio output.
    Returns None on validation failure (clears the output component); on an
    inference error, returns silence so the Audio component stays valid.
    """
    global cosyvoice, default_data

    # The model is loaded once at startup; bail out if that failed.
    if cosyvoice is None:
        gr.Info("模型未初始化,请检查启动配置。")
        return None

    # Prefer the uploaded file; fall back to the microphone recording.
    prompt_wav = prompt_wav_upload if prompt_wav_upload is not None else prompt_wav_record
    if prompt_wav is None:
        gr.Info('prompt音频为空,您是否忘记输入prompt音频?')
        return None

    # Reject prompt audio whose sample rate is below the required minimum.
    try:
        # torchaudio.info returns AudioMetaData; read its sample_rate field.
        info = torchaudio.info(prompt_wav)
        if info.sample_rate < prompt_sr:
            gr.Info(f"prompt 音频采样率过低:{info.sample_rate} < {prompt_sr}")
            return None
    except Exception as e:
        gr.Info(f"无法读取 prompt 音频信息,请检查文件格式或损坏:{e}")
        return None

    if not prompt_text:
        gr.Info('prompt文本为空,您是否忘记输入prompt文本?')
        return None

    # Load and clean the prompt audio (load_wav returns a torch.Tensor).
    try:
        prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
    except Exception as e:
        gr.Info(f"处理 prompt 音频时出错:{e}")
        return None

    set_all_random_seed(0)  # fixed seed for reproducible synthesis
    logging.info("执行 3s 极速复刻 推理")
    try:
        result = next(cosyvoice.inference_zero_shot(
            tts_text,
            prompt_text,
            prompt_speech_16k,
            stream=False,
            speed=1.0
        ))
        return cosyvoice.sample_rate, result["tts_speech"].numpy().flatten()
    except Exception as e:
        gr.Info(f"推理过程中发生错误:{e}")
        # Hand back silence instead of None so the Audio component
        # does not error out.
        return cosyvoice.sample_rate, default_data
def main():
    """Build and launch the Gradio web demo for 3s zero-shot voice cloning.

    Reads the server port from the module-level ``args`` namespace that is
    populated in the ``__main__`` block before this function is called.
    """
    with gr.Blocks() as demo:
        gr.Markdown("### SMIIP-NV finetune CosyVoice2")
        gr.Markdown("#### 上传一段 ≤30s 的 prompt 音频,填写对应文本,合成目标语音。")
        tts_text = gr.Textbox(label="输入合成文本", lines=1, value="在这个孤独的夜晚<crying>,窗外的雨声让我想起了你,<crying>我真的好想你。")
        with gr.Row():
            # Gradio 4.x: `sources` takes a list of input sources.
            prompt_wav_upload = gr.Audio(sources=['upload'], type='filepath', label='选择prompt音频文件,注意采样率不低于16khz')
            prompt_wav_record = gr.Audio(sources=['microphone'], type='filepath', label='录制prompt音频文件')
        prompt_text = gr.Textbox(label="输入prompt文本", lines=1, placeholder="请输入prompt文本,需与prompt音频内容一致,暂时不支持自动识别...", value='')
        generate_button = gr.Button("生成音频")
        # FIX: generate_audio is a plain function that returns one
        # (sample_rate, ndarray) tuple; a streaming Audio output requires a
        # generator callback that yields chunks, so streaming must be off here.
        audio_output = gr.Audio(label="合成音频", autoplay=True)
        generate_button.click(generate_audio,
                              inputs=[tts_text, prompt_wav_upload, prompt_wav_record, prompt_text],
                              outputs=[audio_output])
    demo.queue(max_size=4, default_concurrency_limit=2)
    demo.launch(server_name='0.0.0.0', server_port=args.port)
if __name__ == '__main__':
    # Command-line configuration: port to serve on and which model to load.
    parser = argparse.ArgumentParser()
    parser.add_argument('--port', type=int, default=8000, help="服务启动端口")
    parser.add_argument('--model_dir', type=str,
                        default='pretrained_models/CosyVoice2-0.5B',
                        help='local path or modelscope repo id')
    args = parser.parse_args()

    # Try CosyVoice first; fall back to CosyVoice2. Abort if neither loads.
    try:
        cosyvoice = CosyVoice(args.model_dir)
        print("CosyVoice 模型加载成功!")
    except Exception as e:
        print(f"加载 CosyVoice 模型失败:{e},尝试加载 CosyVoice2...")
        try:
            cosyvoice = CosyVoice2(args.model_dir)
            print("CosyVoice2 模型加载成功!")
        except Exception as e2:
            print(f"加载 CosyVoice2 模型也失败了:{e2}")
            raise TypeError('no valid model_type found for model_dir: ' + args.model_dir + f'\nError: {e2}')

    # One second of silence, used as the fallback output on inference errors.
    default_data = np.zeros(cosyvoice.sample_rate, dtype=np.float32)
    main()