work / app.py
txh17's picture
Create app.py
3320745 verified
import gradio as gr
from transformers import pipeline
import torch
from diffusers import StableDiffusionPipeline
import soundfile as sf
import speech_recognition as sr
import numpy as np
import os
# 初始化组件
# 使用较小的开源LLM进行提示增强
llm_pipe = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.1")
# 初始化Stable Diffusion
sd_pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16
).to("cuda" if torch.cuda.is_available() else "cpu")
# 语音识别初始化
recognizer = sr.Recognizer()
def enhance_prompt(basic_prompt, style, detail_level, artist_style):
"""使用LLM增强提示词"""
prompt_template = f"""
根据以下简短描述创建一个详细的Stable Diffusion提示:
原始描述: {basic_prompt}
风格: {style}
细节级别: {detail_level}
艺术家风格: {artist_style}
请生成一个包含以下元素的详细提示:
- 主体描述
- 环境/背景
- 光照条件
- 色彩风格
- 艺术媒介(如数字绘画、油画等)
- 质量描述(如4K、超详细等)
生成的提示:
"""
enhanced_prompt = llm_pipe(
prompt_template,
max_length=200,
num_return_sequences=1,
temperature=0.7
)[0]['generated_text']
# 清理生成的文本
enhanced_prompt = enhanced_prompt.replace(prompt_template, "").strip()
return enhanced_prompt
def generate_image(enhanced_prompt, steps, guidance_scale, seed):
"""使用Stable Diffusion生成图像"""
if seed == -1:
seed = torch.randint(0, 2**32, (1,)).item()
generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu").manual_seed(seed)
image = sd_pipe(
enhanced_prompt,
num_inference_steps=steps,
guidance_scale=guidance_scale,
generator=generator
).images[0]
return image, seed
def process_audio(audio):
"""处理语音输入"""
sr, audio_data = audio
audio_array = np.array(audio_data, dtype=np.float32)
# 保存临时文件供语音识别使用
temp_file = "temp_audio.wav"
sf.write(temp_file, audio_array, sr)
with sr.AudioFile(temp_file) as source:
audio_data = recognizer.record(source)
try:
text = recognizer.recognize_google(audio_data, language='en-US')
os.remove(temp_file)
return text
except Exception as e:
os.remove(temp_file)
return f"语音识别错误: {str(e)}"
def full_process(basic_prompt, style, detail_level, artist_style, steps, guidance_scale, seed, use_audio, audio_input):
"""完整处理流程"""
# 处理语音输入
if use_audio and audio_input is not None:
basic_prompt = process_audio(audio_input)
# 生成增强提示
enhanced_prompt = enhance_prompt(basic_prompt, style, detail_level, artist_style)
# 生成图像
image, used_seed = generate_image(enhanced_prompt, steps, guidance_scale, seed)
return enhanced_prompt, image, used_seed
# Gradio界面
with gr.Blocks(title="魔法树屋图像生成器") as demo:
gr.Markdown("# 🎨 魔法树屋图像生成器")
gr.Markdown("输入简短描述或使用语音输入,生成精美图像!")
with gr.Row():
with gr.Column():
# 输入部分
use_audio = gr.Checkbox(label="使用语音输入")
audio_input = gr.Audio(label="录音", visible=False)
basic_prompt = gr.Textbox(
label="简短描述",
placeholder="例如: 天空中的魔法树屋",
visible=True
)
# 当复选框变化时切换输入方式
def toggle_input(use_audio):
return {
basic_prompt: gr.update(visible=not use_audio),
audio_input: gr.update(visible=use_audio)
}
use_audio.change(
toggle_input,
inputs=use_audio,
outputs=[basic_prompt, audio_input]
)
# 风格选项
style = gr.Dropdown(
label="风格",
choices=["现实主义", "幻想艺术", "赛博朋克", "水墨画", "卡通", "极简主义"],
value="幻想艺术"
)
detail_level = gr.Slider(
label="细节级别",
minimum=1,
maximum=5,
step=1,
value=3
)
artist_style = gr.Dropdown(
label="艺术家风格",
choices=["无", "梵高", "毕加索", "莫奈", "达利", "宫崎骏"],
value="无"
)
# 高级选项
with gr.Accordion("高级选项", open=False):
steps = gr.Slider(
label="生成步数",
minimum=20,
maximum=100,
step=5,
value=50
)
guidance_scale = gr.Slider(
label="引导尺度",
minimum=1.0,
maximum=20.0,
step=0.5,
value=7.5
)
seed = gr.Number(
label="随机种子 (-1 表示随机)",
value=-1
)
submit_btn = gr.Button("生成图像", variant="primary")
with gr.Column():
# 输出部分
enhanced_prompt = gr.Textbox(
label="生成的提示",
interactive=False
)
image_output = gr.Image(
label="生成的图像",
height=512
)
used_seed = gr.Number(
label="使用的种子",
interactive=False
)
# 连接按钮
submit_btn.click(
fn=full_process,
inputs=[
basic_prompt, style, detail_level, artist_style,
steps, guidance_scale, seed, use_audio, audio_input
],
outputs=[enhanced_prompt, image_output, used_seed]
)
# 对于Hugging Face Spaces,我们需要设置队列
demo.queue()
# 启动应用
if __name__ == "__main__":
demo.launch()