# HuggingFace Space "audio-ldm2" — app.py (author: Adam-512, commit bdaac6f "fix")
# app.py
import torch
import spaces
import gradio as gr
from diffusers import AudioLDM2Pipeline
import scipy.io.wavfile as wavfile
import tempfile
import os
# Module-level cache for the lazily loaded pipeline (populated by get_pipeline).
_pipe = None
# Redirect the Hugging Face cache to a writable location on the Space.
# NOTE(review): HF_HOME is set *after* diffusers/huggingface_hub were imported
# above; those libraries may have already resolved their cache paths at import
# time — confirm this takes effect, or set the variable before the imports.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.makedirs("/tmp/huggingface", exist_ok=True)
def get_pipeline():
    """Lazily load and cache the AudioLDM2 pipeline (CPU only).

    Loading happens on first call rather than at import time so that CUDA
    is never initialized in the main process; the caller is responsible for
    moving the returned pipeline onto the GPU.
    """
    global _pipe
    if _pipe is not None:
        return _pipe

    pipeline = AudioLDM2Pipeline.from_pretrained(
        "cvssp/audioldm2",
        torch_dtype=torch.float16,
        revision="0f5395520e81196e2edb657c0ea85aac026b0599",
    )
    # Reduce peak memory during inference at a small speed cost.
    pipeline.enable_attention_slicing()
    pipeline.vae.enable_slicing()
    print("模型加载完毕,后续调用秒响应!")
    _pipe = pipeline
    return _pipe
# Generation function (all CUDA-related work must happen inside @spaces.GPU).
@spaces.GPU(duration=120)
def text_to_audio(
    prompt: str,
    negative_prompt: str = "",
    duration: float = 5.0,
    guidance_scale: float = 3.5,
    num_inference_steps: int = 100,
    num_waveforms: int = 1,
    seed: int = -1,
):
    """Generate audio from a text prompt with AudioLDM2.

    Args:
        prompt: Text description of the desired audio.
        negative_prompt: Traits to steer away from; empty string means none.
        duration: Length of the generated clip in seconds.
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Diffusion sampling steps.
        num_waveforms: Candidates generated per prompt (first one is returned).
        seed: RNG seed; -1 means use a fresh random seed.

    Returns:
        Path to a 16 kHz WAV file written under /tmp.
    """
    # Fetch the cached pipeline and move it to CUDA inside the GPU worker.
    pipe = get_pipeline().to("cuda")
    # gr.Number delivers floats; torch.Generator.manual_seed requires an int,
    # so coerce before comparing/seeding (fixes TypeError on non-default seeds).
    seed = int(seed)
    generator = None if seed == -1 else torch.Generator(device="cuda").manual_seed(seed)
    with torch.autocast(device_type="cuda"):
        audio = pipe(
            prompt,
            negative_prompt=negative_prompt or None,
            num_inference_steps=num_inference_steps,
            audio_length_in_s=duration,
            num_waveforms_per_prompt=num_waveforms,
            guidance_scale=guidance_scale,
            generator=generator,
        ).audios[0]
    # mkstemp + close instead of NamedTemporaryFile(delete=False): the previous
    # code leaked the open file descriptor, since wavfile.write reopens the path.
    fd, wav_path = tempfile.mkstemp(suffix=".wav", dir="/tmp")
    os.close(fd)
    # AudioLDM2 produces audio at a 16 kHz sample rate.
    wavfile.write(wav_path, rate=16000, data=audio)
    return wav_path
# ==================== Gradio UI ====================
# Custom CSS: center the app at a 900px max width and hide the Gradio footer.
# Injected below through a raw <style> tag via gr.HTML.
css = """
.gradio-container {max-width: 900px !important; margin: auto !important;}
footer {display: none !important;}
"""
# Two-column layout: prompt controls on the left, audio player on the right.
with gr.Blocks(title="AudioLDM2-Large Text-to-Audio") as demo:
    # Inject the custom CSS defined above.
    gr.HTML(f"<style>{css}</style>")
    gr.Markdown("""
# AudioLDM2-Large
文本生成音频模型
""")
    with gr.Row():
        with gr.Column(scale=2):
            # Main text prompt describing the desired audio.
            prompt = gr.Textbox(
                label="描述你想要的音频(越详细越好)",
                placeholder="例如:A dog barking angrily on a busy city street with car horns",
                lines=3
            )
            # Optional negative prompt; defaults steer away from common artifacts.
            negative = gr.Textbox(
                label="负面提示(可选)",
                placeholder="low quality, noise, distortion, echo",
                value="low quality,music,noise",
                lines=3
            )
            with gr.Row():
                # NOTE(review): the slider allows up to 120 s of audio while
                # @spaces.GPU caps the call at 120 s wall-clock — long clips
                # with many steps may hit the timeout; confirm the limits.
                duration = gr.Slider(2.0, 120.0, value=30, step=0.5, label="时长(秒)")
                steps = gr.Slider(50, 200, value=100, step=25, label="采样步数(越高越精细但越慢)")
            with gr.Row():
                guidance = gr.Slider(1.0, 10.0, value=3.5, step=0.5, label="引导尺度(Guidance Scale)")
                num = gr.Slider(1, 4, value=1, step=1, label="生成数量(同时生成多个候选)")
            # -1 means "random seed"; any other value makes output reproducible.
            seed = gr.Number(value=-1, label="随机种子(相同种子+相同提示 = 可复现,填 -1 随机)")
            btn = gr.Button("Generate Audio 🎵", variant="primary", size="lg")
        with gr.Column(scale=1):
            # Filepath mode: text_to_audio returns the path of the written WAV.
            output_audio = gr.Audio(label="生成的音频", type="filepath", interactive=False)
    # Positional wiring — this list must match text_to_audio's parameter order:
    # (prompt, negative_prompt, duration, guidance_scale, steps, num, seed).
    btn.click(
        fn=text_to_audio,
        inputs=[prompt, negative, duration, guidance, steps, num, seed],
        outputs=output_audio,
        show_progress=True
    )
    # One-click examples filling prompt, negative prompt, and duration.
    gr.Examples(
        examples=[
            ["A beautiful piano melody with soft strings in the background", "", 8.0],
            ["Thunderstorm with heavy rain and strong wind blowing through trees", "", 7.0],
            ["A cat meowing and then purring while being petted", "", 5.0],
            ["80s synthwave music with retro drums and electric guitar solo", "", 10.0],
            ["Fire crackling in a cozy fireplace on a winter night", "", 6.0],
        ],
        inputs=[prompt, negative, duration],
        label="点击示例一键生成"
    )
    gr.Markdown("""
### Tips
- 生成一次大约需要 20~60 秒(取决于步数和时长)
- 推荐 200 步 + Guidance 3.5~4.5 获得最佳质量
""")
if __name__ == "__main__":
    # Enable request queueing (at most 20 pending jobs), then start the server.
    app = demo.queue(max_size=20)
    app.launch()