Update app.py
Browse files
app.py
CHANGED
|
@@ -1,232 +1,104 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
from transformers import
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
| 5 |
import plotly.graph_objects as go
|
| 6 |
import os
|
| 7 |
import scipy.io.wavfile # For saving/reading wav files
|
| 8 |
-
from diffusers import StableDiffusionImg2ImgPipeline # For Riffusion's image generation part
|
| 9 |
-
from PIL import Image
|
| 10 |
-
import io
|
| 11 |
-
import librosa # For spectrogram to audio conversion
|
| 12 |
|
| 13 |
# --- 全局变量和模型加载 ---
|
| 14 |
HF_HOME = os.getenv("HF_HOME", "./hf_cache")
|
| 15 |
os.makedirs(HF_HOME, exist_ok=True)
|
| 16 |
|
| 17 |
# 定义模型路径和名称
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
| 22 |
|
| 23 |
# 确定加载模型的设备
|
| 24 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 25 |
print(f"当前使用的设备: {device}")
|
| 26 |
|
| 27 |
# --- 模型加载函数 ---
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
riffusion_pipeline = None
|
| 31 |
|
| 32 |
-
#
|
| 33 |
-
def
|
| 34 |
-
"""加载
|
| 35 |
-
print(f"正在加载
|
| 36 |
try:
|
| 37 |
-
#
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
|
| 42 |
-
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_CHAT_TTS).to(device)
|
| 43 |
-
processor = AutoProcessor.from_pretrained(MODEL_CHAT_TTS)
|
| 44 |
-
print("ChatTTS 模型加载成功。")
|
| 45 |
-
return model, processor
|
| 46 |
-
except Exception as e:
|
| 47 |
-
print(f"加载 ChatTTS 时出错: {e}. 请确保模型在当前Transformers版本下兼容,或尝试手动加载。", e)
|
| 48 |
-
return None, None
|
| 49 |
-
|
| 50 |
-
def load_stable_audio_model():
|
| 51 |
-
"""加载 stabilityai/stable-audio-open-1.0 模型。"""
|
| 52 |
-
print(f"正在加载 Stable Audio 模型: {MODEL_STABLE_AUDIO} 到 {device}...")
|
| 53 |
-
try:
|
| 54 |
-
stable_audio_pipe = pipeline("text-to-audio", model=MODEL_STABLE_AUDIO, device=device)
|
| 55 |
-
print("Stable Audio 模型加载成功。")
|
| 56 |
-
return stable_audio_pipe
|
| 57 |
except Exception as e:
|
| 58 |
-
print(f"加载
|
| 59 |
return None
|
| 60 |
|
| 61 |
-
#
|
| 62 |
-
def
|
| 63 |
-
"""加载
|
| 64 |
-
print(f"正在加载
|
| 65 |
try:
|
| 66 |
-
#
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
# 实际部署可能需要更复杂的ControlNet集成
|
| 70 |
-
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(MODEL_RIFFUSION_CONTROLNET, torch_dtype=torch.float16 if device == "cuda" else torch.float32).to(device)
|
| 71 |
-
print("Riffusion (Stable Diffusion) 模型加载成功。")
|
| 72 |
return pipe
|
| 73 |
except Exception as e:
|
| 74 |
-
print(f"加载
|
| 75 |
return None
|
| 76 |
|
| 77 |
-
# 全局模型实例
|
| 78 |
-
chattts_model_instance, chattts_processor_instance = None, None
|
| 79 |
-
stable_audio_pipe_instance = None
|
| 80 |
-
riffusion_pipeline_instance = None
|
| 81 |
-
|
| 82 |
def initialize_all_models():
|
| 83 |
"""在 Gradio 界面加载时调用,用于初始化所有模型。"""
|
| 84 |
-
global
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
chattts_model_instance, chattts_processor_instance = load_chattts_model()
|
| 88 |
-
|
| 89 |
-
# 尝试加载 Stable Audio
|
| 90 |
-
stable_audio_pipe_instance = load_stable_audio_model()
|
| 91 |
-
|
| 92 |
-
# 尝试加载 Riffusion
|
| 93 |
-
riffusion_pipeline_instance = load_riffusion_model()
|
| 94 |
-
|
| 95 |
|
| 96 |
# --- 音频生成推理函数 ---
|
| 97 |
|
| 98 |
-
def
|
| 99 |
-
"""使用
|
| 100 |
-
if
|
| 101 |
-
return (16000, np.zeros(0))
|
| 102 |
try:
|
| 103 |
-
#
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
# return (16000, speech.cpu().numpy())
|
| 109 |
-
|
| 110 |
-
# 为了 Gradio pipeline 兼容性,如果 ChatTTS 不直接支持 pipeline,这里用伪代码或简化处理
|
| 111 |
-
# 假设 ChatTTS 可以通过特定的`generate`方法生成原始音频 numpy 数组
|
| 112 |
-
# 实际部署时,你需要根据 ChatTTS 的官方示例来编写
|
| 113 |
-
# 这里用一个简单的TTS pipeline作为fallback或者示意
|
| 114 |
-
if hasattr(chattts_model_instance, 'generate_speech'): # 如果是真正的ChatTTS模型实例
|
| 115 |
-
# Dummy sampling for demonstration if actual ChatTTS setup is complex
|
| 116 |
-
rand_audio = np.random.rand(16000 * 3).astype(np.float32) * 0.5 - 0.25 # 3秒随机噪音
|
| 117 |
-
return (16000, rand_audio)
|
| 118 |
-
else: # Fallback to a generic TTS if ChatTTS isn't fully loaded/compatible
|
| 119 |
-
# A more robust solution for ChatTTS might involve:
|
| 120 |
-
# from ChatTTS import ChatTTS # Assuming ChatTTS is installed
|
| 121 |
-
# chat = ChatTTS.ChatTTS()
|
| 122 |
-
# chat.load_models(source="huggingface", device="cuda")
|
| 123 |
-
# wavs = chat.infer(texts=[prompt])
|
| 124 |
-
# return (chat.sr, wavs[0][0])
|
| 125 |
-
|
| 126 |
-
# For robust Gradio demo, use a generic TTS if ChatTTS is tricky
|
| 127 |
-
# This is a placeholder, you'd replace with actual ChatTTS inference
|
| 128 |
-
print("ChatTTS model not fully initialized or compatible for direct generation, using dummy output.")
|
| 129 |
-
dummy_pipe = pipeline("text-to-speech", model="facebook/mms-tts-eng", device=device) # Smaller TTS for fallback
|
| 130 |
-
result = dummy_pipe(prompt)
|
| 131 |
-
sr = result['sampling_rate']
|
| 132 |
-
audio_data = result['audio']
|
| 133 |
-
return (sr, audio_data)
|
| 134 |
-
|
| 135 |
except Exception as e:
|
| 136 |
-
print(f"使用
|
| 137 |
return (16000, np.zeros(0))
|
| 138 |
|
| 139 |
-
def
|
| 140 |
-
"""使用
|
| 141 |
-
if
|
| 142 |
-
return (44100, np.zeros(0))
|
| 143 |
try:
|
| 144 |
-
#
|
| 145 |
-
audio_output =
|
| 146 |
audio_array = audio_output['audio'][0].cpu().numpy() if isinstance(audio_output['audio'], torch.Tensor) else audio_output['audio']
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
print(f"使用 Stable Audio 生成音频出错: {e}")
|
| 150 |
-
return (44100, np.zeros(0))
|
| 151 |
-
|
| 152 |
-
def generate_audio_riffusion(prompt):
|
| 153 |
-
"""使用 Riffusion 生成音乐 (通过频谱图)。"""
|
| 154 |
-
if riffusion_pipeline_instance is None:
|
| 155 |
-
return (44100, np.zeros(0))
|
| 156 |
-
try:
|
| 157 |
-
# Riffusion 的核心是文本到频谱图图像生成
|
| 158 |
-
# 它需要一个 Stable Diffusion pipeline 来生成图像
|
| 159 |
-
# 然后将图像转换回音频
|
| 160 |
-
|
| 161 |
-
# 1. 生成频谱图图像 (这���简化为直接使用 Riffusion pipeline)
|
| 162 |
-
# 实际 Riffusion 的pipeline会直接返回PIL Image
|
| 163 |
-
generated_image = riffusion_pipeline_instance(prompt, negative_prompt="bad quality, blurry", num_inference_steps=50).images[0]
|
| 164 |
-
|
| 165 |
-
# 2. 将频谱图图像转换为音频
|
| 166 |
-
# Riffusion 的转换逻辑通常是它自己的库
|
| 167 |
-
# 这里需要一个函数来处理这个转换
|
| 168 |
-
sr = 44100 # Riffusion 常用采样率
|
| 169 |
-
audio_array = image_to_audio_riffusion(generated_image, sr=sr)
|
| 170 |
-
return (sr, audio_array)
|
| 171 |
except Exception as e:
|
| 172 |
-
print(f"使用
|
| 173 |
-
return (
|
| 174 |
-
|
| 175 |
-
# Helper function for Riffusion's image-to-audio
|
| 176 |
-
# This is a simplified version; real Riffusion uses a specific library/method
|
| 177 |
-
def image_to_audio_riffusion(image: Image.Image, sr: int = 44100):
|
| 178 |
-
"""
|
| 179 |
-
将 Riffusion 生成的频谱图图像转换为音频。
|
| 180 |
-
此函数是简化的,实际 Riffusion 转换涉及复杂的 STFT 逆变换。
|
| 181 |
-
为了演示,我们假设图像可以某种方式映射到频谱图,然后逆变换。
|
| 182 |
-
实际使用时,你需要集成 Riffusion 提供的 spectrogram_to_audio 方法。
|
| 183 |
-
例如: from riffusion.spectrogram_converter import SpectrogramConverter
|
| 184 |
-
"""
|
| 185 |
-
# 转换为灰度图以表示幅度
|
| 186 |
-
gray_image = image.convert("L")
|
| 187 |
-
spectrogram_array = np.array(gray_image).astype(np.float32) / 255.0 * 2.0 - 1.0 # Normalize to -1 to 1
|
| 188 |
-
|
| 189 |
-
# Riffusion 通常需要特定的形状,这里为了演示简化
|
| 190 |
-
# 模拟一个逆短时傅里叶变换 (ISTFT)
|
| 191 |
-
# 这是一个非常简化的模拟,实际转换很复杂且特定于 Riffusion
|
| 192 |
-
# 你可能需要安装 riffusion 库并使用其提供的函数
|
| 193 |
-
try:
|
| 194 |
-
# 假设频谱图是幅度谱的对数表示,需要反转换
|
| 195 |
-
# 这里只是一个占位符,因为librosa.istft直接接受复数频谱
|
| 196 |
-
# 真实情况需要从图像恢复幅度和相位
|
| 197 |
-
# For a true Riffusion conversion, you'd do:
|
| 198 |
-
# from riffusion.spectrogram_converter import SpectrogramConverter
|
| 199 |
-
# converter = SpectrogramConverter()
|
| 200 |
-
# audio_data = converter.spectrogram_image_to_audio(image)
|
| 201 |
-
# return audio_data.numpy()
|
| 202 |
-
|
| 203 |
-
# Fallback dummy audio generation from image for demo if full Riffusion library isn't integrated
|
| 204 |
-
# Create a dummy phase for ISTFT
|
| 205 |
-
# This is NOT how Riffusion actually converts, it's just to make it runnable
|
| 206 |
-
dummy_phase = np.exp(1j * np.random.uniform(-np.pi, np.pi, size=spectrogram_array.shape))
|
| 207 |
-
complex_spectrogram = spectrogram_array * dummy_phase
|
| 208 |
-
|
| 209 |
-
audio_data = librosa.istft(complex_spectrogram, hop_length=512) # Example hop_length
|
| 210 |
-
return audio_data
|
| 211 |
-
except Exception as e:
|
| 212 |
-
print(f"Riffusion 图像到音频转换出错: {e}")
|
| 213 |
-
return np.zeros(sr * 3) # 返回 3 秒静音
|
| 214 |
|
| 215 |
# --- Arena 选项卡逻辑 ---
|
| 216 |
def arena_predict(prompt):
|
| 217 |
"""Arena 选项卡的主预测函数,并行调用所有模型。"""
|
| 218 |
print(f"收到提示词: {prompt}")
|
| 219 |
|
| 220 |
-
#
|
| 221 |
-
|
| 222 |
|
| 223 |
-
#
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
# Riffusion 生成 (音乐)
|
| 227 |
-
riffusion_sr, riffusion_audio = generate_audio_riffusion(prompt)
|
| 228 |
|
| 229 |
-
return (
|
| 230 |
|
| 231 |
# --- 模型对比/GRACE评估逻辑 ---
|
| 232 |
|
|
@@ -234,23 +106,17 @@ def arena_predict(prompt):
|
|
| 234 |
# 这些分数是主观评估,你需要根据实际模型表现来调整
|
| 235 |
# 评分范围:1-5,5 为最佳
|
| 236 |
grace_data = {
|
| 237 |
-
"
|
| 238 |
-
"Generalization":
|
| 239 |
-
"Relevance": 4.5, # 语音合成与文本高度相关,
|
| 240 |
-
"Artistry": 4.
|
| 241 |
-
"Efficiency":
|
| 242 |
},
|
| 243 |
-
"
|
| 244 |
-
"Generalization":
|
| 245 |
-
"Relevance": 3.
|
| 246 |
-
"Artistry":
|
| 247 |
-
"Efficiency":
|
| 248 |
-
},
|
| 249 |
-
"Riffusion v1": {
|
| 250 |
-
"Generalization": 3.5, # 主要针对音乐生成,对语音或纯音效泛化能力弱
|
| 251 |
-
"Relevance": 3.5, # 通过文本描述控制音乐生成可能不如直接的音乐模型精确,有时有偏差
|
| 252 |
-
"Artistry": 4.0, # 音乐创意性强,能生成独特风格的音乐,但有时可能不那么“传统”
|
| 253 |
-
"Efficiency": 2.5 # 生成频谱图和转换都比较耗时,效率最低
|
| 254 |
}
|
| 255 |
}
|
| 256 |
|
|
@@ -269,8 +135,7 @@ def create_grace_radar_chart():
|
|
| 269 |
theta=categories + [categories[0]],
|
| 270 |
fill='toself',
|
| 271 |
name=model_name,
|
| 272 |
-
line_color='blue' if model_name == "
|
| 273 |
-
('red' if model_name == "Stable Audio Open 1.0" else 'green')
|
| 274 |
))
|
| 275 |
|
| 276 |
fig.update_layout(
|
|
@@ -298,165 +163,50 @@ def create_grace_radar_chart():
|
|
| 298 |
report_content = """
|
| 299 |
# 文本到音频生成模型对比实验报告
|
| 300 |
|
| 301 |
-
## 1.
|
| 302 |
-
|
| 303 |
-
本实验旨在对当前流行的**文本到音频(Text-to-Audio)生成模型**进行横向对比分析。随着人工智能技术的飞速发展,AI生成音频的能力日益增强,在音乐创作、语音合成、游戏音效、有声读物等领域展现出巨大的潜力。本次实验选取了 Hugging Face 平台上具有代表性的 **2Noise/ChatTTS**、**stabilityai/stable-audio-open-1.0** 和 **riffusion/riffusion-model-v1** 三个模型,通过 Gradio 构建交互式界面,使用 **GRACE 框架**对它们的性能进行多维度评估。
|
| 304 |
-
|
| 305 |
-
---
|
| 306 |
-
|
| 307 |
-
## 2. 模型介绍
|
| 308 |
-
|
| 309 |
-
### 2.1 2Noise/ChatTTS
|
| 310 |
-
|
| 311 |
-
* **模型类型:** 文本到**语音**生成模型(Text-to-Speech, TTS)。
|
| 312 |
-
* **特点:** 以生成**高质量、富有表现力、自然且具有情感的语音**而闻名。它能够处理复杂的文本,并生成逼真的对话语音,非常适合有声读物、客服机器人或任何需要自然人声的场景。它的主要优势在于语音的自然度和情感丰富度。
|
| 313 |
-
|
| 314 |
-
### 2.2 stabilityai/stable-audio-open-1.0
|
| 315 |
|
| 316 |
-
|
| 317 |
-
*
|
|
|
|
|
|
|
| 318 |
|
| 319 |
-
###
|
|
|
|
|
|
|
| 320 |
|
| 321 |
-
|
| 322 |
-
*
|
|
|
|
|
|
|
|
|
|
| 323 |
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
本次实验通过 Hugging Face Gradio Space 实现了一个交互式平台,包含“Arena”和“模型对比”两个核心选项卡。
|
| 329 |
-
|
| 330 |
-
### 3.1 Arena 选项卡
|
| 331 |
-
|
| 332 |
-
* **功能:** 用户在此选项卡中输入一个文本提示词,系统将同时调用 ChatTTS、Stable Audio Open 1.0 和 Riffusion 三个模型,并并排展示它们各自生成的音频。
|
| 333 |
-
* **目的:** 允许用户直观地比较不同模型在相同输入下的音频生成质量、风格和对提示词的理解程度。用户可以根据提示词的性质(例如,语音、音效、音乐)来观察哪个模型表现最佳。
|
| 334 |
-
|
| 335 |
-
### 3.2 模型对比选项卡 (基于 GRACE 框架)
|
| 336 |
-
|
| 337 |
-
* **功能:** 此选项卡展示了基于 GRACE 框架对三个模型的评估结果。**GRACE 框架**从**泛化性(Generalization)**、**相关性(Relevance)**、**创新表现力(Artistry)**和**效率性(Efficiency)**四个维度对模型进行评估。
|
| 338 |
-
* **目的:** 提供一个结构化的、多维度的模型性能分析视图,帮助用户更全面地理解每个模型的优缺点。评估结果以雷达图形式直观呈现,并辅以文字说明。
|
| 339 |
-
|
| 340 |
-
---
|
| 341 |
-
|
| 342 |
-
## 4. GRACE 框架分析与结果
|
| 343 |
-
|
| 344 |
-
以下是对 ChatTTS、Stable Audio Open 1.0 和 Riffusion v1 在 GRACE 各维度上的评估和对比。评分范围为 1-5,分数越高表示表现越好。
|
| 345 |
-
|
| 346 |
-
### 4.1 泛化性 (Generalization)
|
| 347 |
-
|
| 348 |
-
* **ChatTTS:** (3.0/5) 主要专注于高质量语音合成,对非语音(如音乐或环境音)的生成能力有限,泛化范围相对较窄。
|
| 349 |
-
* **Stable Audio Open 1.0:** (4.0/5) 能够生成多种类型的音频,包括不同的音乐流派、音效和环境声音,泛化能力较强。
|
| 350 |
-
* **Riffusion v1:** (3.5/5) 专注于音乐生成,能够通过文本控制多种音乐风格,但其核心是音乐,对纯语音或复杂环境音的泛化能力有限。
|
| 351 |
-
|
| 352 |
-
### 4.2 相关性 (Relevance)
|
| 353 |
-
|
| 354 |
-
* **ChatTTS:** (4.5/5) 在语音合成方面与文本输入高度相关,能够准确捕捉文本的含义和情感表达。
|
| 355 |
-
* **Stable Audio Open 1.0:** (3.8/5) 对文本描述的理解尚可,但在处理非常具体或复杂的音频描述时,有时可能与用户期望存在一定偏差。
|
| 356 |
-
* **Riffusion v1:** (3.5/5) 通过文本提示控制音乐生成可能不如直接的音乐模型那样精确,有时生成的音乐与文本描述的关联性会稍弱,需要更精细的提示词工程。
|
| 357 |
-
|
| 358 |
-
### 4.3 创新表现力 (Artistry)
|
| 359 |
-
|
| 360 |
-
* **ChatTTS:** (4.5/5) 语音的自然度、情感表达和音质极高,具有出色的艺术表现力,使得合成语音听起来非常逼真且富有感染力。
|
| 361 |
-
* **Stable Audio Open 1.0:** (4.0/5) 音质良好,能够生成具有一定创意和美感的音乐片段及音效,在通用音频生成领域具有较强的艺术表现力。
|
| 362 |
-
* **Riffusion v1:** (4.0/5) 能够生成独特且具有实验性的音乐风格,创意性强。其视觉到听觉的转换机制带来了独特的艺术呈现,但有时可能不那么“传统”或预期。
|
| 363 |
-
|
| 364 |
-
### 4.4 效率性 (Efficiency)
|
| 365 |
-
|
| 366 |
-
* **ChatTTS:** (3.0/5) 生成高质量、富有情感的语音通常计算成本较高,推理速度中等。
|
| 367 |
-
* **Stable Audio Open 1.0:** (3.0/5) 模型参数量较大,生成速度相对较慢,对计算资源的需求较高。
|
| 368 |
-
* **Riffusion v1:** (2.5/5) 生成频谱图图像和随后将其逆转换为音频的过程都比较耗时,是这三个模型中效率最低的。
|
| 369 |
-
|
| 370 |
-
### GRACE 评估雷达图
|
| 371 |
-
|
| 372 |
-

|
| 373 |
-
*(注意:此处将插入使用 Plotly 生成的雷达图,通过 `gr.Plot` 组件展示)*
|
| 374 |
|
| 375 |
---
|
| 376 |
|
| 377 |
-
##
|
| 378 |
-
|
| 379 |
-
本次实验对比了 ChatTTS、Stable Audio Open 1.0 和 Riffusion v1 三个文本到音频模型,它们各自在不同的细分领域展现了独特的优势:
|
| 380 |
|
| 381 |
-
|
| 382 |
-
* **Stable Audio Open 1.0** 是一款**通用的音频生成器**,适用于生成多样化的音效和音乐片段。
|
| 383 |
-
* **Riffusion v1** 提供了一种**创新性的音乐创作方式**,适合寻求独特音乐风格或探索视觉-听觉关联的创作者。
|
| 384 |
-
|
| 385 |
-
未来的研究可以进一步探索:
|
| 386 |
-
|
| 387 |
-
1. **多模态输入:** 结合文本���外的其他模态(如图像、视频)来生成更丰富的音频。
|
| 388 |
-
2. **可控性与精细化:** 开发更精细的控制机制,让用户能更准确地指导模型的生成过程。
|
| 389 |
-
3. **客观评估指标:** 建立更完善的客观评估指标来补充主观的 GRACE 评估,以量化模型在特定任务上的表现。
|
| 390 |
-
|
| 391 |
-
---
|
| 392 |
-
"""
|
| 393 |
-
|
| 394 |
-
# --- Gradio 界面定义 ---
|
| 395 |
-
with gr.Blocks(title="文本到音频模型对比实验") as demo:
|
| 396 |
-
gr.Markdown("# 文本到音频模型对比实验")
|
| 397 |
-
gr.Markdown("本项目对比了 **2Noise/ChatTTS**、**stabilityai/stable-audio-open-1.0** 和 **riffusion/riffusion-model-v1** 三个文本到音频生成模型。")
|
| 398 |
-
|
| 399 |
-
with gr.Tab("Arena"):
|
| 400 |
-
gr.Markdown("### Arena: 统一输入,对比模型输出")
|
| 401 |
-
gr.Markdown("在下方输入框中描述你希望生成的音频(例如:'一段平静的钢琴旋律', '一个机器人说话的声音', '一段悲伤的旁白'),然后点击生成按钮。")
|
| 402 |
-
|
| 403 |
-
arena_input = gr.Textbox(label="输入文本提示词", placeholder="例如:A calm piano melody in a quiet room.")
|
| 404 |
-
generate_button = gr.Button("生成音频")
|
| 405 |
-
|
| 406 |
-
with gr.Row():
|
| 407 |
-
with gr.Column():
|
| 408 |
-
gr.Markdown("#### ChatTTS (语音)")
|
| 409 |
-
output_chattts = gr.Audio(label="ChatTTS 输出", type="numpy", interactive=False)
|
| 410 |
-
with gr.Column():
|
| 411 |
-
gr.Markdown("#### Stable Audio Open 1.0 (音效/音乐)")
|
| 412 |
-
output_stable_audio = gr.Audio(label="Stable Audio 输出", type="numpy", interactive=False)
|
| 413 |
-
with gr.Column():
|
| 414 |
-
gr.Markdown("#### Riffusion v1 (音乐)")
|
| 415 |
-
output_riffusion = gr.Audio(label="Riffusion 输出", type="numpy", interactive=False)
|
| 416 |
-
|
| 417 |
-
# 绑定生成按钮的点击事件
|
| 418 |
-
generate_button.click(
|
| 419 |
-
arena_predict,
|
| 420 |
-
inputs=[arena_input],
|
| 421 |
-
outputs=[output_chattts, output_stable_audio, output_riffusion]
|
| 422 |
-
)
|
| 423 |
-
|
| 424 |
-
# 示例提示词
|
| 425 |
-
gr.Examples(
|
| 426 |
-
[
|
| 427 |
-
["A gentle rain falling outside a window."],
|
| 428 |
-
["An epic orchestral track with drums and violins."],
|
| 429 |
-
["A robotic voice saying 'Hello world, how are you today?'"],
|
| 430 |
-
["A cat meowing loudly."],
|
| 431 |
-
["A short, happy tune on a flute."],
|
| 432 |
-
["A melancholic narration about lost dreams."] # For ChatTTS specific focus
|
| 433 |
-
],
|
| 434 |
-
inputs=arena_input,
|
| 435 |
-
label="尝试这些示例提示词"
|
| 436 |
-
)
|
| 437 |
-
|
| 438 |
-
with gr.Tab("模型对比"):
|
| 439 |
-
gr.Markdown("### 模型对比: 基于 GRACE 框架的评估")
|
| 440 |
-
gr.Markdown("此选项卡展示了 ChatTTS、Stable Audio Open 1.0 和 Riffusion v1 三个模型在泛化性、相关性、创新表现力和效率性四个维度上的表现对比。")
|
| 441 |
-
|
| 442 |
-
# 嵌入 Plotly 雷达图
|
| 443 |
-
grace_plot = gr.Plot(create_grace_radar_chart(), label="GRACE 框架模型对比雷达图", show_label=False)
|
| 444 |
-
gr.Markdown("""
|
| 445 |
-
**GRACE 维度解释:**
|
| 446 |
-
* **G (Generalization) 泛化性:** 模型是否能适配多种输入任务、生成多样化音频。
|
| 447 |
-
* **R (Relevance) 相关性:** 输出音频是否紧扣输入主题、与用户期望匹配。
|
| 448 |
-
* **A (Artistry) 创新表现力:** 输出内容是否有创意、音质高、表现力强。
|
| 449 |
-
* **E (Efficiency) 效率性:** 是否能在较少人工干预下快速输出高质量结果。
|
| 450 |
|
| 451 |
-
|
| 452 |
-
""")
|
| 453 |
|
| 454 |
-
|
| 455 |
-
|
| 456 |
|
| 457 |
-
#
|
| 458 |
-
demo.load(initialize_all_models, None, None)
|
| 459 |
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from transformers import pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq, MusicgenForConditionalGeneration
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
| 5 |
import plotly.graph_objects as go
|
| 6 |
import os
|
| 7 |
import scipy.io.wavfile # For saving/reading wav files
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
# --- 全局变量和模型加载 ---
|
| 10 |
HF_HOME = os.getenv("HF_HOME", "./hf_cache")
|
| 11 |
os.makedirs(HF_HOME, exist_ok=True)
|
| 12 |
|
| 13 |
# 定义模型路径和名称
|
| 14 |
+
# 模型 A: 文本到语音 (Text-to-Speech)
|
| 15 |
+
MODEL_TTS = "facebook/mms-tts-eng" # 可以替换为 mms-tts-zh 如果主要测试中文
|
| 16 |
+
|
| 17 |
+
# 模型 B: 文本到音效/简单音频 (Text-to-Audio/Sound Generation)
|
| 18 |
+
MODEL_AUDIOGEN = "facebook/audiogen-small"
|
| 19 |
|
| 20 |
# 确定加载模型的设备
|
| 21 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 22 |
print(f"当前使用的设备: {device}")
|
| 23 |
|
| 24 |
# --- 模型加载函数 ---
|
| 25 |
+
tts_pipeline = None
|
| 26 |
+
audiogen_pipeline = None
|
|
|
|
| 27 |
|
| 28 |
+
# 假设同学 A 负责集成 MMS-TTS
|
| 29 |
+
def load_tts_model():
|
| 30 |
+
"""加载 MMS-TTS 模型。"""
|
| 31 |
+
print(f"正在加载 TTS 模型: {MODEL_TTS} 到 {device}...")
|
| 32 |
try:
|
| 33 |
+
# MMS-TTS 可以通过 text-to-speech pipeline 轻松加载
|
| 34 |
+
pipe = pipeline("text-to-speech", model=MODEL_TTS, device=device)
|
| 35 |
+
print("MMS-TTS 模型加载成功。")
|
| 36 |
+
return pipe
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
except Exception as e:
|
| 38 |
+
print(f"加载 MMS-TTS 时出错: {e}")
|
| 39 |
return None
|
| 40 |
|
| 41 |
+
# 假设同学 B 负责集成 AudioGen-small
|
| 42 |
+
def load_audiogen_model():
|
| 43 |
+
"""加载 AudioGen-small 模型。"""
|
| 44 |
+
print(f"正在加载 AudioGen 模型: {MODEL_AUDIOGEN} 到 {device}...")
|
| 45 |
try:
|
| 46 |
+
# AudioGen 也可以通过 text-to-audio pipeline 加载
|
| 47 |
+
pipe = pipeline("text-to-audio", model=MODEL_AUDIOGEN, device=device)
|
| 48 |
+
print("AudioGen 模型加载成功。")
|
|
|
|
|
|
|
|
|
|
| 49 |
return pipe
|
| 50 |
except Exception as e:
|
| 51 |
+
print(f"加载 AudioGen 时出错: {e}")
|
| 52 |
return None
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
def initialize_all_models():
|
| 55 |
"""在 Gradio 界面加载时调用,用于初始化所有模型。"""
|
| 56 |
+
global tts_pipeline, audiogen_pipeline
|
| 57 |
+
tts_pipeline = load_tts_model()
|
| 58 |
+
audiogen_pipeline = load_audiogen_model()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
# --- 音频生成推理函数 ---
|
| 61 |
|
| 62 |
+
def generate_audio_tts(prompt):
|
| 63 |
+
"""使用 MMS-TTS 生成语音。"""
|
| 64 |
+
if tts_pipeline is None:
|
| 65 |
+
return (16000, np.zeros(0)) # 返回空音频
|
| 66 |
try:
|
| 67 |
+
# TTS pipeline 通常返回 dict 包含 'audio' (numpy array) 和 'sampling_rate'
|
| 68 |
+
result = tts_pipeline(prompt)
|
| 69 |
+
audio_array = result['audio']
|
| 70 |
+
sampling_rate = result['sampling_rate']
|
| 71 |
+
return (sampling_rate, audio_array)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
except Exception as e:
|
| 73 |
+
print(f"使用 MMS-TTS 生成音频出错: {e}")
|
| 74 |
return (16000, np.zeros(0))
|
| 75 |
|
| 76 |
+
def generate_audio_audiogen(prompt):
|
| 77 |
+
"""使用 AudioGen-small 生成音频。"""
|
| 78 |
+
if audiogen_pipeline is None:
|
| 79 |
+
return (44100, np.zeros(0)) # 返回空音频,AudioGen通常用44.1kHz
|
| 80 |
try:
|
| 81 |
+
# AudioGen pipeline 也返回 dict 包含 'audio' 和 'sampling_rate'
|
| 82 |
+
audio_output = audiogen_pipeline(prompt, sampling_rate=16000) # 指定一个采样率
|
| 83 |
audio_array = audio_output['audio'][0].cpu().numpy() if isinstance(audio_output['audio'], torch.Tensor) else audio_output['audio']
|
| 84 |
+
sampling_rate = audio_output['sampling_rate'] if 'sampling_rate' in audio_output else 16000 # 确保拿到采样率
|
| 85 |
+
return (sampling_rate, audio_array)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
except Exception as e:
|
| 87 |
+
print(f"使用 AudioGen 生成音频出错: {e}")
|
| 88 |
+
return (16000, np.zeros(0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
# --- Arena 选项卡逻辑 ---
|
| 91 |
def arena_predict(prompt):
|
| 92 |
"""Arena 选项卡的主预测函数,并行调用所有模型。"""
|
| 93 |
print(f"收到提示词: {prompt}")
|
| 94 |
|
| 95 |
+
# MMS-TTS 生成
|
| 96 |
+
tts_sr, tts_audio = generate_audio_tts(prompt)
|
| 97 |
|
| 98 |
+
# AudioGen-small 生成
|
| 99 |
+
audiogen_sr, audiogen_audio = generate_audio_audiogen(prompt)
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
+
return (tts_sr, tts_audio), (audiogen_sr, audiogen_audio)
|
| 102 |
|
| 103 |
# --- 模型对比/GRACE评估逻辑 ---
|
| 104 |
|
|
|
|
| 106 |
# 这些分数是主观评估,你需要根据实际模型表现来调整
|
| 107 |
# 评分范围:1-5,5 为最佳
|
| 108 |
grace_data = {
|
| 109 |
+
"MMS-TTS-Eng": {
|
| 110 |
+
"Generalization": 2.5, # 专注于语音,对音乐/音效泛化能力弱
|
| 111 |
+
"Relevance": 4.5, # 语音合成与文本高度相关,发音准确
|
| 112 |
+
"Artistry": 4.0, # 语音自然度高,但情感和表现力不如ChatTTS
|
| 113 |
+
"Efficiency": 4.5 # 模型小,生成速度快,资源消耗低
|
| 114 |
},
|
| 115 |
+
"AudioGen-small": {
|
| 116 |
+
"Generalization": 3.5, # 能生成音效和简单音乐,比TTS模型泛化好
|
| 117 |
+
"Relevance": 3.5, # 对音频描述的理解中等,有时可能不完全符合预期
|
| 118 |
+
"Artistry": 3.5, # 音质尚可,创意性一般,但能满足基本音效需求
|
| 119 |
+
"Efficiency": 4.0 # 模型相对较小,生成速度较快,资源消耗中等
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
}
|
| 121 |
}
|
| 122 |
|
|
|
|
| 135 |
theta=categories + [categories[0]],
|
| 136 |
fill='toself',
|
| 137 |
name=model_name,
|
| 138 |
+
line_color='blue' if model_name == "MMS-TTS-Eng" else 'red'
|
|
|
|
| 139 |
))
|
| 140 |
|
| 141 |
fig.update_layout(
|
|
|
|
| 163 |
report_content = """
|
| 164 |
# 文本到音频生成模型对比实验报告
|
| 165 |
|
| 166 |
+
## 1. 模型及类别选择
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
+
### 1.1 所选模型的类型及背景说明
|
| 169 |
+
本次实验聚焦于**文本到音频(Text-to-Audio)生成模型**,具体选择了两个在功能上有所侧重且模型尺寸相对较小的模型进行对比:
|
| 170 |
+
1. **`facebook/mms-tts-eng` (MMS-TTS-Eng)**:这是一个由 Meta AI (Facebook) 开发的**文本到语音(Text-to-Speech, TTS)**模型。它属于 Massively Multilingual Speech (MMS) 项目的一部分,旨在支持大量语言的语音合成。我们选择的是英语版本 (`eng`),它以其高效、轻量级和清晰的语音输出而闻名。
|
| 171 |
+
2. **`facebook/audiogen-small` (AudioGen-small)**:这款模型是 Meta AI AudioCraft 框架下的一个较小版本,专注于**文本到音效或简单音频片段**的生成。它能够根据文本描述创建环境音、乐器声或短小的声音事件。
|
| 172 |
|
| 173 |
+
### 1.2 模型用途对比简述
|
| 174 |
+
* **MMS-TTS-Eng** 的主要用途是将书面文本转换为自然发音的**人类语音**,适用于有声读物、语音助手、导航系统等场景。
|
| 175 |
+
* **AudioGen-small** 的主要用途是根据文本描述生成各种**非语音的音效或短音乐片段**,例如“雨声”、“鸟鸣”或“简单的吉他旋律”,适用于游戏音效、视频配乐或环境音模拟。
|
| 176 |
|
| 177 |
+
### 1.3 两个模型异同点分析
|
| 178 |
+
**相同点:**
|
| 179 |
+
* 都属于文本到音频生成领域,将文本作��输入。
|
| 180 |
+
* 都由 Meta AI 开发,并在 Hugging Face Hub 上提供。
|
| 181 |
+
* 都相对轻量级,相比大型文生音/乐模型(如 Stable Audio 或 MusicGen-large)对资源要求更低。
|
| 182 |
|
| 183 |
+
**不同点:**
|
| 184 |
+
* **核心功能:** MMS-TTS 专注于**语音合成**,目标是还原人类语音的自然度;AudioGen 专注于**非语音音频生成**,目标是模拟真实世界的音效或创作音乐片段。
|
| 185 |
+
* **输出类型:** MMS-TTS 输出的是人类语言的声音;AudioGen 输出的是环境音、乐器音、抽象音效等。
|
| 186 |
+
* **内部机制:** 尽管都基于 Transformer 架构,但它们的内部训练数据、任务目标和具体架构优化不同,以适应各自的生成任务。
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
---
|
| 189 |
|
| 190 |
+
## 2. 系统实现细节
|
|
|
|
|
|
|
| 191 |
|
| 192 |
+
本系统通过 Gradio 构建了一个交互式 Hugging Face Space,实现了模型的统一输入和多模型输出展示。
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
+
### 2.1 Gradio 交互界面截图
|
|
|
|
| 195 |
|
| 196 |
+
(此处应放置你的 Gradio Space 的截图,显示 "Arena" 和 "模型对比" 选项卡的用户界面)
|
| 197 |
+

|
| 198 |
|
| 199 |
+
### 2.2 输入与输出流程图 (Mermaid 语法)
|
|
|
|
| 200 |
|
| 201 |
+
```mermaid
|
| 202 |
+
graph TD
|
| 203 |
+
A[用户输入文本提示词] --> B{Gradio UI};
|
| 204 |
+
B --> C1[MMS-TTS 模型];
|
| 205 |
+
B --> C2[AudioGen-small 模型];
|
| 206 |
+
C1 -- 生成语音 --> D1[MMS-TTS 音频输出];
|
| 207 |
+
C2 -- 生成音效 --> D2[AudioGen-small 音频输出];
|
| 208 |
+
D1 --> E[Arena 选项卡展示];
|
| 209 |
+
D2 --> E;
|
| 210 |
+
E -- 用户评估 --> F[GRACE 评估数据];
|
| 211 |
+
F --> G[模型对比选项卡];
|
| 212 |
+
G -- 生成雷达图 --> H[报告选项卡嵌入雷达图];
|