txh17 commited on
Commit
12d7ae3
·
verified ·
1 Parent(s): fecdc58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -349
app.py CHANGED
@@ -1,232 +1,104 @@
1
  import gradio as gr
2
- from transformers import AutoProcessor, AutoModelForTextToSpectrogram, pipeline, StableDiffusionPipeline
3
  import torch
4
  import numpy as np
5
  import plotly.graph_objects as go
6
  import os
7
  import scipy.io.wavfile # For saving/reading wav files
8
- from diffusers import StableDiffusionImg2ImgPipeline # For Riffusion's image generation part
9
- from PIL import Image
10
- import io
11
- import librosa # For spectrogram to audio conversion
12
 
13
  # --- 全局变量和模型加载 ---
14
  HF_HOME = os.getenv("HF_HOME", "./hf_cache")
15
  os.makedirs(HF_HOME, exist_ok=True)
16
 
17
  # 定义模型路径和名称
18
- MODEL_CHAT_TTS = "2Noise/ChatTTS"
19
- MODEL_STABLE_AUDIO = "stabilityai/stable-audio-open-1.0"
20
- MODEL_RIFFUSION_SD = "runwayml/stable-diffusion-v1-5" # Riffusion uses Stable Diffusion internally for spectrogram
21
- MODEL_RIFFUSION_CONTROLNET = "riffusion/riffusion-model-v1" # This is the "model" for Riffusion's special handling
 
22
 
23
  # 确定加载模型的设备
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
25
  print(f"当前使用的设备: {device}")
26
 
27
  # --- 模型加载函数 ---
28
- chattts_model, chattts_tokenizer = None, None
29
- stable_audio_pipe = None
30
- riffusion_pipeline = None
31
 
32
- # ChatTTS 加载 (可能需要额外步骤,此处简化为pipeline)
33
- def load_chattts_model():
34
- """加载 2Noise/ChatTTS 模型。"""
35
- print(f"正在加载 ChatTTS 模型: {MODEL_CHAT_TTS} 到 {device}...")
36
  try:
37
- # ChatTTS 模型的实际加载能需要更复杂的代码,此处使用 pipeline 模拟
38
- # 实际使用 ChatTTS 可能需要手动下载模型权重和推理代码
39
- # 参考:https://huggingface.co/2Noise/ChatTTS
40
- # 这里为了简化Gradio部署,我们尝试用AutoModelForSpeechSeq2Seq
41
- from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
42
- model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_CHAT_TTS).to(device)
43
- processor = AutoProcessor.from_pretrained(MODEL_CHAT_TTS)
44
- print("ChatTTS 模型加载成功。")
45
- return model, processor
46
- except Exception as e:
47
- print(f"加载 ChatTTS 时出错: {e}. 请确保模型在当前Transformers版本下兼容,或尝试手动加载。", e)
48
- return None, None
49
-
50
- def load_stable_audio_model():
51
- """加载 stabilityai/stable-audio-open-1.0 模型。"""
52
- print(f"正在加载 Stable Audio 模型: {MODEL_STABLE_AUDIO} 到 {device}...")
53
- try:
54
- stable_audio_pipe = pipeline("text-to-audio", model=MODEL_STABLE_AUDIO, device=device)
55
- print("Stable Audio 模型加载成功。")
56
- return stable_audio_pipe
57
  except Exception as e:
58
- print(f"加载 Stable Audio 时出错: {e}")
59
  return None
60
 
61
- # Riffusion 模型加载 (特殊处理)
62
- def load_riffusion_model():
63
- """加载 Riffusion 模型。它实际上是基于 Stable Diffusion 的图像到图像 pipeline。"""
64
- print(f"正在加载 Riffusion (Stable Diffusion部分): {MODEL_RIFFUSION_SD} 到 {device}...")
65
  try:
66
- # Riffusion 实际上是 Stable Diffusion 的一个变种,用于生成频谱图图像
67
- # 它通常需要一个StableDiffusionPipeline,并加载特定的ControlNet或LoRA
68
- # 这里简化为直接加载riffusion模型作为一个img2img pipeline
69
- # 实际部署可能需要更复杂的ControlNet集成
70
- pipe = StableDiffusionImg2ImgPipeline.from_pretrained(MODEL_RIFFUSION_CONTROLNET, torch_dtype=torch.float16 if device == "cuda" else torch.float32).to(device)
71
- print("Riffusion (Stable Diffusion) 模型加载成功。")
72
  return pipe
73
  except Exception as e:
74
- print(f"加载 Riffusion 时出错: {e}")
75
  return None
76
 
77
- # 全局模型实例
78
- chattts_model_instance, chattts_processor_instance = None, None
79
- stable_audio_pipe_instance = None
80
- riffusion_pipeline_instance = None
81
-
82
  def initialize_all_models():
83
  """在 Gradio 界面加载时调用,用于初始化所有模型。"""
84
- global chattts_model_instance, chattts_processor_instance, stable_audio_pipe_instance, riffusion_pipeline_instance
85
-
86
- # 尝试加载 ChatTTS
87
- chattts_model_instance, chattts_processor_instance = load_chattts_model()
88
-
89
- # 尝试加载 Stable Audio
90
- stable_audio_pipe_instance = load_stable_audio_model()
91
-
92
- # 尝试加载 Riffusion
93
- riffusion_pipeline_instance = load_riffusion_model()
94
-
95
 
96
  # --- 音频生成推理函数 ---
97
 
98
- def generate_audio_chattts(prompt):
99
- """使用 ChatTTS 生成语音。"""
100
- if chattts_model_instance is None or chattts_processor_instance is None:
101
- return (16000, np.zeros(0))
102
  try:
103
- # ChatTTS 的实际推理过程可能更复杂,此处为示意
104
- # 通常涉及文本到token,token到语音
105
- # 示例:
106
- # inputs = chattts_processor_instance(text=prompt, return_tensors="pt").to(device)
107
- # speech = chattts_model_instance.generate_speech(inputs["input_ids"], sampling_rate=16000)
108
- # return (16000, speech.cpu().numpy())
109
-
110
- # 为了 Gradio pipeline 兼容性,如果 ChatTTS 不直接支持 pipeline,这里用伪代码或简化处理
111
- # 假设 ChatTTS 可以通过特定的`generate`方法生成原始音频 numpy 数组
112
- # 实际部署时,你需要根据 ChatTTS 的官方示例来编写
113
- # 这里用一个简单的TTS pipeline作为fallback或者示意
114
- if hasattr(chattts_model_instance, 'generate_speech'): # 如果是真正的ChatTTS模型实例
115
- # Dummy sampling for demonstration if actual ChatTTS setup is complex
116
- rand_audio = np.random.rand(16000 * 3).astype(np.float32) * 0.5 - 0.25 # 3秒随机噪音
117
- return (16000, rand_audio)
118
- else: # Fallback to a generic TTS if ChatTTS isn't fully loaded/compatible
119
- # A more robust solution for ChatTTS might involve:
120
- # from ChatTTS import ChatTTS # Assuming ChatTTS is installed
121
- # chat = ChatTTS.ChatTTS()
122
- # chat.load_models(source="huggingface", device="cuda")
123
- # wavs = chat.infer(texts=[prompt])
124
- # return (chat.sr, wavs[0][0])
125
-
126
- # For robust Gradio demo, use a generic TTS if ChatTTS is tricky
127
- # This is a placeholder, you'd replace with actual ChatTTS inference
128
- print("ChatTTS model not fully initialized or compatible for direct generation, using dummy output.")
129
- dummy_pipe = pipeline("text-to-speech", model="facebook/mms-tts-eng", device=device) # Smaller TTS for fallback
130
- result = dummy_pipe(prompt)
131
- sr = result['sampling_rate']
132
- audio_data = result['audio']
133
- return (sr, audio_data)
134
-
135
  except Exception as e:
136
- print(f"使用 ChatTTS 生成音频出错: {e}")
137
  return (16000, np.zeros(0))
138
 
139
- def generate_audio_stable_audio(prompt):
140
- """使用 Stable Audio 生成音频。"""
141
- if stable_audio_pipe_instance is None:
142
- return (44100, np.zeros(0))
143
  try:
144
- # Stable Audio pipeline 通常返回 Dict with 'audio' key
145
- audio_output = stable_audio_pipe_instance(prompt, sampling_rate=44100) # Stable Audio uses 44.1kHz
146
  audio_array = audio_output['audio'][0].cpu().numpy() if isinstance(audio_output['audio'], torch.Tensor) else audio_output['audio']
147
- return (44100, audio_array)
148
- except Exception as e:
149
- print(f"使用 Stable Audio 生成音频出错: {e}")
150
- return (44100, np.zeros(0))
151
-
152
- def generate_audio_riffusion(prompt):
153
- """使用 Riffusion 生成音乐 (通过频谱图)。"""
154
- if riffusion_pipeline_instance is None:
155
- return (44100, np.zeros(0))
156
- try:
157
- # Riffusion 的核心是文本到频谱图图像生成
158
- # 它需要一个 Stable Diffusion pipeline 来生成图像
159
- # 然后将图像转换回音频
160
-
161
- # 1. 生成频谱图图像 (这���简化为直接使用 Riffusion pipeline)
162
- # 实际 Riffusion 的pipeline会直接返回PIL Image
163
- generated_image = riffusion_pipeline_instance(prompt, negative_prompt="bad quality, blurry", num_inference_steps=50).images[0]
164
-
165
- # 2. 将频谱图图像转换为音频
166
- # Riffusion 的转换逻辑通常是它自己的库
167
- # 这里需要一个函数来处理这个转换
168
- sr = 44100 # Riffusion 常用采样率
169
- audio_array = image_to_audio_riffusion(generated_image, sr=sr)
170
- return (sr, audio_array)
171
  except Exception as e:
172
- print(f"使用 Riffusion 生成音频出错: {e}")
173
- return (44100, np.zeros(0))
174
-
175
- # Helper function for Riffusion's image-to-audio
176
- # This is a simplified version; real Riffusion uses a specific library/method
177
- def image_to_audio_riffusion(image: Image.Image, sr: int = 44100):
178
- """
179
- 将 Riffusion 生成的频谱图图像转换为音频。
180
- 此函数是简化的,实际 Riffusion 转换涉及复杂的 STFT 逆变换。
181
- 为了演示,我们假设图像可以某种方式映射到频谱图,然后逆变换。
182
- 实际使用时,你需要集成 Riffusion 提供的 spectrogram_to_audio 方法。
183
- 例如: from riffusion.spectrogram_converter import SpectrogramConverter
184
- """
185
- # 转换为灰度图以表示幅度
186
- gray_image = image.convert("L")
187
- spectrogram_array = np.array(gray_image).astype(np.float32) / 255.0 * 2.0 - 1.0 # Normalize to -1 to 1
188
-
189
- # Riffusion 通常需要特定的形状,这里为了演示简化
190
- # 模拟一个逆短时傅里叶变换 (ISTFT)
191
- # 这是一个非常简化的模拟,实际转换很复杂且特定于 Riffusion
192
- # 你可能需要安装 riffusion 库并使用其提供的函数
193
- try:
194
- # 假设频谱图是幅度谱的对数表示,需要反转换
195
- # 这里只是一个占位符,因为librosa.istft直接接受复数频谱
196
- # 真实情况需要从图像恢复幅度和相位
197
- # For a true Riffusion conversion, you'd do:
198
- # from riffusion.spectrogram_converter import SpectrogramConverter
199
- # converter = SpectrogramConverter()
200
- # audio_data = converter.spectrogram_image_to_audio(image)
201
- # return audio_data.numpy()
202
-
203
- # Fallback dummy audio generation from image for demo if full Riffusion library isn't integrated
204
- # Create a dummy phase for ISTFT
205
- # This is NOT how Riffusion actually converts, it's just to make it runnable
206
- dummy_phase = np.exp(1j * np.random.uniform(-np.pi, np.pi, size=spectrogram_array.shape))
207
- complex_spectrogram = spectrogram_array * dummy_phase
208
-
209
- audio_data = librosa.istft(complex_spectrogram, hop_length=512) # Example hop_length
210
- return audio_data
211
- except Exception as e:
212
- print(f"Riffusion 图像到音频转换出错: {e}")
213
- return np.zeros(sr * 3) # 返回 3 秒静音
214
 
215
  # --- Arena 选项卡逻辑 ---
216
  def arena_predict(prompt):
217
  """Arena 选项卡的主预测函数,并行调用所有模型。"""
218
  print(f"收到提示词: {prompt}")
219
 
220
- # ChatTTS 生成 (高质量语音)
221
- chattts_sr, chattts_audio = generate_audio_chattts(prompt)
222
 
223
- # Stable Audio 生成 (通用声音/音乐)
224
- stable_audio_sr, stable_audio_audio = generate_audio_stable_audio(prompt)
225
-
226
- # Riffusion 生成 (音乐)
227
- riffusion_sr, riffusion_audio = generate_audio_riffusion(prompt)
228
 
229
- return (chattts_sr, chattts_audio), (stable_audio_sr, stable_audio_audio), (riffusion_sr, riffusion_audio)
230
 
231
  # --- 模型对比/GRACE评估逻辑 ---
232
 
@@ -234,23 +106,17 @@ def arena_predict(prompt):
234
  # 这些分数是主观评估,你需要根据实际模型表现来调整
235
  # 评分范围:1-5,5 为最佳
236
  grace_data = {
237
- "ChatTTS": {
238
- "Generalization": 3.0, # 主要用于语音,对非语音泛化能力弱
239
- "Relevance": 4.5, # 语音合成与文本高度相关,情感还原度高
240
- "Artistry": 4.5, # 语音自然度表现力极高,逼真
241
- "Efficiency": 3.0 # 生成高质量语音通常计算成本较高,速度中等
242
  },
243
- "Stable Audio Open 1.0": {
244
- "Generalization": 4.0, # 能生成各种音效和音乐片段,泛化能力较强
245
- "Relevance": 3.8, # 对复杂音频描述的理解有时有偏差,但通常抓住核心
246
- "Artistry": 4.0, # 音质良好音乐性和创意性较强
247
- "Efficiency": 3.0 # 模型较,生成速度相对慢,资源消耗
248
- },
249
- "Riffusion v1": {
250
- "Generalization": 3.5, # 主要针对音乐生成,对语音或纯音效泛化能力弱
251
- "Relevance": 3.5, # 通过文本描述控制音乐生成可能不如直接的音乐模型精确,有时有偏差
252
- "Artistry": 4.0, # 音乐创意性强,能生成独特风格的音乐,但有时可能不那么“传统”
253
- "Efficiency": 2.5 # 生成频谱图和转换都比较耗时,效率最低
254
  }
255
  }
256
 
@@ -269,8 +135,7 @@ def create_grace_radar_chart():
269
  theta=categories + [categories[0]],
270
  fill='toself',
271
  name=model_name,
272
- line_color='blue' if model_name == "ChatTTS" else \
273
- ('red' if model_name == "Stable Audio Open 1.0" else 'green')
274
  ))
275
 
276
  fig.update_layout(
@@ -298,165 +163,50 @@ def create_grace_radar_chart():
298
  report_content = """
299
  # 文本到音频生成模型对比实验报告
300
 
301
- ## 1. 引言
302
-
303
- 本实验旨在对当前流行的**文本到音频(Text-to-Audio)生成模型**进行横向对比分析。随着人工智能技术的飞速发展,AI生成音频的能力日益增强,在音乐创作、语音合成、游戏音效、有声读物等领域展现出巨大的潜力。本次实验选取了 Hugging Face 平台上具有代表性的 **2Noise/ChatTTS**、**stabilityai/stable-audio-open-1.0** 和 **riffusion/riffusion-model-v1** 三个模型,通过 Gradio 构建交互式界面,使用 **GRACE 框架**对它们的性能进行多维度评估。
304
-
305
- ---
306
-
307
- ## 2. 模型介绍
308
-
309
- ### 2.1 2Noise/ChatTTS
310
-
311
- * **模型类型:** 文本到**语音**生成模型(Text-to-Speech, TTS)。
312
- * **特点:** 以生成**高质量、富有表现力、自然且具有情感的语音**而闻名。它能够处理复杂的文本,并生成逼真的对话语音,非常适合有声读物、客服机器人或任何需要自然人声的场景。它的主要优势在于语音的自然度和情感丰富度。
313
-
314
- ### 2.2 stabilityai/stable-audio-open-1.0
315
 
316
- * **模型类型:** 通用文本到**音频/音乐**生成模型。
317
- * **特点:** 由 Stability AI 发布,能够根据文本描述生成多种类的音频,包括**环境音效、短音乐片段、乐器音色**等。它是一个较为通用的模型,旨在为创作者提供丰富的音频素材,支持生成长达90秒的音频。
 
 
318
 
319
- ### 2.3 riffusion/riffusion-model-v1
 
 
320
 
321
- * **模型类型:** 文本到**音乐**生成模型(基于扩散模型和频谱图)。
322
- * **特:** 这款模型非常独特,它首先将文本提示转化为**音乐的频谱图图像**(类似 Stable Diffusion 生成图像),然后将这个频谱图逆转换为实际的音频。这意味着它结合了图像生成和音频处理技术,能够生成具有**视觉-听觉对应**的创意音乐,特别适合探索不同风格和氛围的音乐生成。
 
 
 
323
 
324
- ---
325
-
326
- ## 3. 实验设计
327
-
328
- 本次实验通过 Hugging Face Gradio Space 实现了一个交互式平台,包含“Arena”和“模型对比”两个核心选项卡。
329
-
330
- ### 3.1 Arena 选项卡
331
-
332
- * **功能:** 用户在此选项卡中输入一个文本提示词,系统将同时调用 ChatTTS、Stable Audio Open 1.0 和 Riffusion 三个模型,并并排展示它们各自生成的音频。
333
- * **目的:** 允许用户直观地比较不同模型在相同输入下的音频生成质量、风格和对提示词的理解程度。用户可以根据提示词的性质(例如,语音、音效、音乐)来观察哪个模型表现最佳。
334
-
335
- ### 3.2 模型对比选项卡 (基于 GRACE 框架)
336
-
337
- * **功能:** 此选项卡展示了基于 GRACE 框架对三个模型的评估结果。**GRACE 框架**从**泛化性(Generalization)**、**相关性(Relevance)**、**创新表现力(Artistry)**和**效率性(Efficiency)**四个维度对模型进行评估。
338
- * **目的:** 提供一个结构化的、多维度的模型性能分析视图,帮助用户更全面地理解每个模型的优缺点。评估结果以雷达图形式直观呈现,并辅以文字说明。
339
-
340
- ---
341
-
342
- ## 4. GRACE 框架分析与结果
343
-
344
- 以下是对 ChatTTS、Stable Audio Open 1.0 和 Riffusion v1 在 GRACE 各维度上的评估和对比。评分范围为 1-5,分数越高表示表现越好。
345
-
346
- ### 4.1 泛化性 (Generalization)
347
-
348
- * **ChatTTS:** (3.0/5) 主要专注于高质量语音合成,对非语音(如音乐或环境音)的生成能力有限,泛化范围相对较窄。
349
- * **Stable Audio Open 1.0:** (4.0/5) 能够生成多种类型的音频,包括不同的音乐流派、音效和环境声音,泛化能力较强。
350
- * **Riffusion v1:** (3.5/5) 专注于音乐生成,能够通过文本控制多种音乐风格,但其核心是音乐,对纯语音或复杂环境音的泛化能力有限。
351
-
352
- ### 4.2 相关性 (Relevance)
353
-
354
- * **ChatTTS:** (4.5/5) 在语音合成方面与文本输入高度相关,能够准确捕捉文本的含义和情感表达。
355
- * **Stable Audio Open 1.0:** (3.8/5) 对文本描述的理解尚可,但在处理非常具体或复杂的音频描述时,有时可能与用户期望存在一定偏差。
356
- * **Riffusion v1:** (3.5/5) 通过文本提示控制音乐生成可能不如直接的音乐模型那样精确,有时生成的音乐与文本描述的关联性会稍弱,需要更精细的提示词工程。
357
-
358
- ### 4.3 创新表现力 (Artistry)
359
-
360
- * **ChatTTS:** (4.5/5) 语音的自然度、情感表达和音质极高,具有出色的艺术表现力,使得合成语音听起来非常逼真且富有感染力。
361
- * **Stable Audio Open 1.0:** (4.0/5) 音质良好,能够生成具有一定创意和美感的音乐片段及音效,在通用音频生成领域具有较强的艺术表现力。
362
- * **Riffusion v1:** (4.0/5) 能够生成独特且具有实验性的音乐风格,创意性强。其视觉到听觉的转换机制带来了独特的艺术呈现,但有时可能不那么“传统”或预期。
363
-
364
- ### 4.4 效率性 (Efficiency)
365
-
366
- * **ChatTTS:** (3.0/5) 生成高质量、富有情感的语音通常计算成本较高,推理速度中等。
367
- * **Stable Audio Open 1.0:** (3.0/5) 模型参数量较大,生成速度相对较慢,对计算资源的需求较高。
368
- * **Riffusion v1:** (2.5/5) 生成频谱图图像和随后将其逆转换为音频的过程都比较耗时,是这三个模型中效率最低的。
369
-
370
- ### GRACE 评估雷达图
371
-
372
- ![GRACE Framework Model Comparison](https://i.imgur.com/your-radar-chart-image-link.png)
373
- *(注意:此处将插入使用 Plotly 生成的雷达图,通过 `gr.Plot` 组件展示)*
374
 
375
  ---
376
 
377
- ## 5. 结论与展望
378
-
379
- 本次实验对比了 ChatTTS、Stable Audio Open 1.0 和 Riffusion v1 三个文本到音频模型,它们各自在不同的细分领域展现了独特的优势:
380
 
381
- * **ChatTTS** 在**高质量、富有情感的语音合成**方面表现卓越是需要逼真人声首选
382
- * **Stable Audio Open 1.0** 是一款**通用的音频生成器**,适用于生成多样化的音效和音乐片段。
383
- * **Riffusion v1** 提供了一种**创新性的音乐创作方式**,适合寻求独特音乐风格或探索视觉-听觉关联的创作者。
384
-
385
- 未来的研究可以进一步探索:
386
-
387
- 1. **多模态输入:** 结合文本���外的其他模态(如图像、视频)来生成更丰富的音频。
388
- 2. **可控性与精细化:** 开发更精细的控制机制,让用户能更准确地指导模型的生成过程。
389
- 3. **客观评估指标:** 建立更完善的客观评估指标来补充主观的 GRACE 评估,以量化模型在特定任务上的表现。
390
-
391
- ---
392
- """
393
-
394
- # --- Gradio 界面定义 ---
395
- with gr.Blocks(title="文本到音频模型对比实验") as demo:
396
- gr.Markdown("# 文本到音频模型对比实验")
397
- gr.Markdown("本项目对比了 **2Noise/ChatTTS**、**stabilityai/stable-audio-open-1.0** 和 **riffusion/riffusion-model-v1** 三个文本到音频生成模型。")
398
-
399
- with gr.Tab("Arena"):
400
- gr.Markdown("### Arena: 统一输入,对比模型输出")
401
- gr.Markdown("在下方输入框中描述你希望生成的音频(例如:'一段平静的钢琴旋律', '一个机器人说话的声音', '一段悲伤的旁白'),然后点击生成按钮。")
402
-
403
- arena_input = gr.Textbox(label="输入文本提示词", placeholder="例如:A calm piano melody in a quiet room.")
404
- generate_button = gr.Button("生成音频")
405
-
406
- with gr.Row():
407
- with gr.Column():
408
- gr.Markdown("#### ChatTTS (语音)")
409
- output_chattts = gr.Audio(label="ChatTTS 输出", type="numpy", interactive=False)
410
- with gr.Column():
411
- gr.Markdown("#### Stable Audio Open 1.0 (音效/音乐)")
412
- output_stable_audio = gr.Audio(label="Stable Audio 输出", type="numpy", interactive=False)
413
- with gr.Column():
414
- gr.Markdown("#### Riffusion v1 (音乐)")
415
- output_riffusion = gr.Audio(label="Riffusion 输出", type="numpy", interactive=False)
416
-
417
- # 绑定生成按钮的点击事件
418
- generate_button.click(
419
- arena_predict,
420
- inputs=[arena_input],
421
- outputs=[output_chattts, output_stable_audio, output_riffusion]
422
- )
423
-
424
- # 示例提示词
425
- gr.Examples(
426
- [
427
- ["A gentle rain falling outside a window."],
428
- ["An epic orchestral track with drums and violins."],
429
- ["A robotic voice saying 'Hello world, how are you today?'"],
430
- ["A cat meowing loudly."],
431
- ["A short, happy tune on a flute."],
432
- ["A melancholic narration about lost dreams."] # For ChatTTS specific focus
433
- ],
434
- inputs=arena_input,
435
- label="尝试这些示例提示词"
436
- )
437
-
438
- with gr.Tab("模型对比"):
439
- gr.Markdown("### 模型对比: 基于 GRACE 框架的评估")
440
- gr.Markdown("此选项卡展示了 ChatTTS、Stable Audio Open 1.0 和 Riffusion v1 三个模型在泛化性、相关性、创新表现力和效率性四个维度上的表现对比。")
441
-
442
- # 嵌入 Plotly 雷达图
443
- grace_plot = gr.Plot(create_grace_radar_chart(), label="GRACE 框架模型对比雷达图", show_label=False)
444
- gr.Markdown("""
445
- **GRACE 维度解释:**
446
- * **G (Generalization) 泛化性:** 模型是否能适配多种输入任务、生成多样化音频。
447
- * **R (Relevance) 相关性:** 输出音频是否紧扣输入主题、与用户期望匹配。
448
- * **A (Artistry) 创新表现力:** 输出内容是否有创意、音质高、表现力强。
449
- * **E (Efficiency) 效率性:** 是否能在较少人工干预下快速输出高质量结果。
450
 
451
- *注意:上述评估是基于主观测试的定性分析,仅供参考。实际表现可能因提示词和模型版本而异。*
452
- """)
453
 
454
- with gr.Tab("报告"):
455
- gr.Markdown(report_content)
456
 
457
- # Gradio 启动时预加载模型
458
- demo.load(initialize_all_models, None, None)
459
 
460
- # 启动 Gradio 应用
461
- if __name__ == "__main__":
462
- demo.launch(share=False, debug=True, enable_queue=True, max_load_time=600) # 延长加载时间到 10 分钟
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq, MusicgenForConditionalGeneration
3
  import torch
4
  import numpy as np
5
  import plotly.graph_objects as go
6
  import os
7
  import scipy.io.wavfile # For saving/reading wav files
 
 
 
 
8
 
9
  # --- 全局变量和模型加载 ---
10
  HF_HOME = os.getenv("HF_HOME", "./hf_cache")
11
  os.makedirs(HF_HOME, exist_ok=True)
12
 
13
  # 定义模型路径和名称
14
+ # 模型 A: 文本到语音 (Text-to-Speech)
15
+ MODEL_TTS = "facebook/mms-tts-eng" # 可以替换为 mms-tts-zh 如果主要测试中文
16
+
17
+ # 模型 B: 文本到音效/简单音频 (Text-to-Audio/Sound Generation)
18
+ MODEL_AUDIOGEN = "facebook/audiogen-small"
19
 
20
  # 确定加载模型的设备
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
  print(f"当前使用的设备: {device}")
23
 
24
  # --- 模型加载函数 ---
25
+ tts_pipeline = None
26
+ audiogen_pipeline = None
 
27
 
28
+ # 假设同学 A 负责集成 MMS-TTS
29
+ def load_tts_model():
30
+ """加载 MMS-TTS 模型。"""
31
+ print(f"正在加载 TTS 模型: {MODEL_TTS} 到 {device}...")
32
  try:
33
+ # MMS-TTS以通过 text-to-speech pipeline 轻松加载
34
+ pipe = pipeline("text-to-speech", model=MODEL_TTS, device=device)
35
+ print("MMS-TTS 模型加载成功。")
36
+ return pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  except Exception as e:
38
+ print(f"加载 MMS-TTS 时出错: {e}")
39
  return None
40
 
41
+ # 假设同学 B 负责集成 AudioGen-small
42
+ def load_audiogen_model():
43
+ """加载 AudioGen-small 模型。"""
44
+ print(f"正在加载 AudioGen 模型: {MODEL_AUDIOGEN} 到 {device}...")
45
  try:
46
+ # AudioGen 也可以通过 text-to-audio pipeline 加载
47
+ pipe = pipeline("text-to-audio", model=MODEL_AUDIOGEN, device=device)
48
+ print("AudioGen 模型加载成功。")
 
 
 
49
  return pipe
50
  except Exception as e:
51
+ print(f"加载 AudioGen 时出错: {e}")
52
  return None
53
 
 
 
 
 
 
54
  def initialize_all_models():
55
  """在 Gradio 界面加载时调用,用于初始化所有模型。"""
56
+ global tts_pipeline, audiogen_pipeline
57
+ tts_pipeline = load_tts_model()
58
+ audiogen_pipeline = load_audiogen_model()
 
 
 
 
 
 
 
 
59
 
60
  # --- 音频生成推理函数 ---
61
 
62
+ def generate_audio_tts(prompt):
63
+ """使用 MMS-TTS 生成语音。"""
64
+ if tts_pipeline is None:
65
+ return (16000, np.zeros(0)) # 返回空音频
66
  try:
67
+ # TTS pipeline 通常返回 dict 包含 'audio' (numpy array) 和 'sampling_rate'
68
+ result = tts_pipeline(prompt)
69
+ audio_array = result['audio']
70
+ sampling_rate = result['sampling_rate']
71
+ return (sampling_rate, audio_array)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  except Exception as e:
73
+ print(f"使用 MMS-TTS 生成音频出错: {e}")
74
  return (16000, np.zeros(0))
75
 
76
+ def generate_audio_audiogen(prompt):
77
+ """使用 AudioGen-small 生成音频。"""
78
+ if audiogen_pipeline is None:
79
+ return (44100, np.zeros(0)) # 返回空音频,AudioGen通常用44.1kHz
80
  try:
81
+ # AudioGen pipeline 返回 dict 包含 'audio' 和 'sampling_rate'
82
+ audio_output = audiogen_pipeline(prompt, sampling_rate=16000) # 指定一个采样率
83
  audio_array = audio_output['audio'][0].cpu().numpy() if isinstance(audio_output['audio'], torch.Tensor) else audio_output['audio']
84
+ sampling_rate = audio_output['sampling_rate'] if 'sampling_rate' in audio_output else 16000 # 确保拿到采样率
85
+ return (sampling_rate, audio_array)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  except Exception as e:
87
+ print(f"使用 AudioGen 生成音频出错: {e}")
88
+ return (16000, np.zeros(0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  # --- Arena 选项卡逻辑 ---
91
  def arena_predict(prompt):
92
  """Arena 选项卡的主预测函数,并行调用所有模型。"""
93
  print(f"收到提示词: {prompt}")
94
 
95
+ # MMS-TTS 生成
96
+ tts_sr, tts_audio = generate_audio_tts(prompt)
97
 
98
+ # AudioGen-small 生成
99
+ audiogen_sr, audiogen_audio = generate_audio_audiogen(prompt)
 
 
 
100
 
101
+ return (tts_sr, tts_audio), (audiogen_sr, audiogen_audio)
102
 
103
  # --- 模型对比/GRACE评估逻辑 ---
104
 
 
106
  # 这些分数是主观评估,你需要根据实际模型表现来调整
107
  # 评分范围:1-5,5 为最佳
108
  grace_data = {
109
+ "MMS-TTS-Eng": {
110
+ "Generalization": 2.5, # 专注于语音,对音乐/音效泛化能力弱
111
+ "Relevance": 4.5, # 语音合成与文本高度相关,发音准确
112
+ "Artistry": 4.0, # 语音自然度高,但情感和表现力不如ChatTTS
113
+ "Efficiency": 4.5 # 模型小,生成速度快,资源消耗低
114
  },
115
+ "AudioGen-small": {
116
+ "Generalization": 3.5, # 能生成音效和简单音乐,比TTS模型泛化
117
+ "Relevance": 3.5, # 对音频描述的理解中等,有时不完全符合预期
118
+ "Artistry": 3.5, # 音质尚可,创意性一般,但能满足基本音效需求
119
+ "Efficiency": 4.0 # 模型相对,生成速度较快,资源消耗中等
 
 
 
 
 
 
120
  }
121
  }
122
 
 
135
  theta=categories + [categories[0]],
136
  fill='toself',
137
  name=model_name,
138
+ line_color='blue' if model_name == "MMS-TTS-Eng" else 'red'
 
139
  ))
140
 
141
  fig.update_layout(
 
163
  report_content = """
164
  # 文本到音频生成模型对比实验报告
165
 
166
+ ## 1. 模型及类别选择
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
+ ### 1.1 所选模型类型及背景说明
169
+ 本次实验聚焦于**文本到音频(Text-to-Audio)生成型**,具体选择了两在功能上有所侧重且模型尺寸相对的模型进行对比:
170
+ 1. **`facebook/mms-tts-eng` (MMS-TTS-Eng)**:这是一个由 Meta AI (Facebook) 开发的**文本到语音(Text-to-Speech, TTS)**模型。它属于 Massively Multilingual Speech (MMS) 项目的一部分,旨在支持大量语言的语音合成。我们选择的是英语版本 (`eng`),它以其高效、轻量级和清晰的语音输出而闻名。
171
+ 2. **`facebook/audiogen-small` (AudioGen-small)**:这款模型是 Meta AI AudioCraft 框架下的一个较小版本,专注于**文本到音效或简单音频片段**的生成。它能够根据文本描述创建环境音、乐器声或短小的声音事件。
172
 
173
+ ### 1.2 模型用途对比简述
174
+ * **MMS-TTS-Eng** 的主要用途是将书面文本转换为自然发音的**人类语音**,适用于有声读物、语音助手、导航系统等场景。
175
+ * **AudioGen-small** 的主要用途是根据文本描述生成各种**非语音的音效或短音乐片段**,例如“雨声”、“鸟鸣”或“简单的吉他旋律”,适用于游戏音效、视频配乐或环境音模拟。
176
 
177
+ ### 1.3 两个模型异同点分析
178
+ **相同**
179
+ * 都属于文本到音频生成领域,将文本作��输入。
180
+ * 都由 Meta AI 开发,并在 Hugging Face Hub 上提供。
181
+ * 都相对轻量级,相比大型文生音/乐模型(如 Stable Audio 或 MusicGen-large)对资源要求更低。
182
 
183
+ **不同点:**
184
+ * **核心功能:** MMS-TTS 专注于**语音合成**,目标是还原人类语音的自然度;AudioGen 专注于**非语音音频生成**,目标是模拟真实世界的音效或创作音乐片段。
185
+ * **输出类型:** MMS-TTS 输出的是人类语言的声音;AudioGen 输出的是环境音、乐器音、抽象音效等。
186
+ * **内部机制:** 尽管都基于 Transformer 架构,但它们的内部训练数据、任务目标和具体架构优化不同,以适应各自的生成任务。
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  ---
189
 
190
+ ## 2. 系统实现细节
 
 
191
 
192
+ 本系统通过 Gradio 构建了一个交互式 Hugging Face Space实现了模型统一输入和多模型输出展示
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
+ ### 2.1 Gradio 交互界面截图
 
195
 
196
+ (此处应放置你的 Gradio Space 的截图,显示 "Arena" 和 "模型对比" 选项卡的用户界面)
197
+ ![Gradio Interface Screenshot](https://i.imgur.com/your-space-screenshot-link.png)
198
 
199
+ ### 2.2 输入与输出流程图 (Mermaid 语法)
 
200
 
201
+ ```mermaid
202
+ graph TD
203
+ A[用户输入文本提示词] --> B{Gradio UI};
204
+ B --> C1[MMS-TTS 模型];
205
+ B --> C2[AudioGen-small 模型];
206
+ C1 -- 生成语音 --> D1[MMS-TTS 音频输出];
207
+ C2 -- 生成音效 --> D2[AudioGen-small 音频输出];
208
+ D1 --> E[Arena 选项卡展示];
209
+ D2 --> E;
210
+ E -- 用户评估 --> F[GRACE 评估数据];
211
+ F --> G[模型对比选项卡];
212
+ G -- 生成雷达图 --> H[报告选项卡嵌入雷达图];