# HOMEWORK_4 / app.py
# longdiyao's picture
# Update app.py
# 9cc55e6 verified
import os
import torch
import gradio as gr
from PIL import Image
from diffusers import StableDiffusionPipeline
from openai import OpenAI
import speech_recognition as sr
# --- Model and API-key configuration ---
# OpenAI key is read from the environment; the client is created eagerly at import time.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
openai_client = OpenAI(api_key=OPENAI_API_KEY)
MODEL_ID = "runwayml/stable-diffusion-v1-5"
# --- Initialize Stable Diffusion ---
# CPU-only inference, so float32 is used (fp16 is unreliable on CPU).
device = "cpu"
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float32)
pipe = pipe.to(device)
# Prompt-to-Prompt 模块
def generate_prompt(description, provider, temperature):
    """Convert a natural-language description into a Stable Diffusion prompt via an LLM.

    Args:
        description: Free-form image description (typically Chinese) from the UI.
        provider: "OpenAI" or "DeepSeek" — selects both the model and the API endpoint.
        temperature: Sampling temperature forwarded to the chat completion API.

    Returns:
        The generated prompt string, or an error message prefixed with "❌" on failure.
    """
    try:
        if provider == "OpenAI":
            model = "gpt-3.5-turbo"
            client = openai_client
        else:
            # BUG FIX: previously the "deepseek-chat" model was sent to the
            # OpenAI endpoint, which always fails — OpenAI does not serve
            # DeepSeek models. DeepSeek exposes an OpenAI-compatible API at
            # its own base URL and requires its own key.
            model = "deepseek-chat"
            client = OpenAI(
                api_key=os.getenv("DEEPSEEK_API_KEY"),
                base_url="https://api.deepseek.com",
            )
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": f"将以下描述转换为一个适用于 Stable Diffusion 的图像生成提示词:'{description}'"}],
            temperature=temperature,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        # Surface the error in the UI textbox rather than crashing the app.
        return f"❌ 提示生成失败:{str(e)}"
# Prompt-to-Image 模块
def generate_image(prompt, negative_prompt, guidance_scale, steps):
    """Render an image with the module-level Stable Diffusion pipeline.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Text describing artifacts to avoid.
        guidance_scale: Classifier-free guidance strength.
        steps: Number of denoising steps (may arrive as a float from the Gradio slider).

    Returns:
        A PIL.Image of the first generated sample.
    """
    # no_grad: inference only — avoids building an autograd graph on CPU.
    with torch.no_grad():
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            # Gradio sliders can deliver floats; diffusers requires an int step count.
            num_inference_steps=int(steps),
        )
    return result.images[0]
# 语音转文字
def transcribe_audio(audio):
    """Transcribe a recorded/uploaded audio file to Chinese text via Google Speech API.

    Args:
        audio: Filesystem path to the audio file (Gradio `type="filepath"`),
            or None/"" when the user clicked the button without providing audio.

    Returns:
        The recognized text, or an error message prefixed with "❌".
    """
    # Gradio passes None when no recording/upload exists; sr.AudioFile(None) would crash.
    if not audio:
        return "❌ 未提供音频文件"
    recognizer = sr.Recognizer()
    with sr.AudioFile(audio) as source:
        audio_data = recognizer.record(source)
    try:
        return recognizer.recognize_google(audio_data, language="zh-CN")
    except sr.UnknownValueError:
        # Speech was captured but could not be understood.
        return "❌ 无法识别语音"
    except sr.RequestError as e:
        # Network / API failure while contacting the recognition service.
        return f"❌ 语音识别出错: {e}"
# --- Build the Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Prompt-to-Image 应用\n通过自然语言或语音描述生成图像")
    with gr.Row():
        # Left column: all inputs and controls.
        with gr.Column():
            description = gr.Textbox(label="请输入图像描述", placeholder="如:空中的魔法树屋")
            audio_input = gr.Audio(type="filepath", label="录音或上传音频", sources=["microphone", "upload"])
            transcribed_text = gr.Textbox(label="语音转文字结果", interactive=False)
            transcribe_btn = gr.Button("识别语音")
            model_provider = gr.Dropdown(["OpenAI", "DeepSeek"], label="选择模型提供者")
            temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="温度(Temperature)")
            generate_prompt_btn = gr.Button("生成提示词")
            generated_prompt = gr.Textbox(label="生成的提示词")
            negative_prompt = gr.Textbox(label="反向提示词", value="blurry, distorted, watermark")
            guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="Guidance Scale")
            steps = gr.Slider(10, 50, value=25, step=1, label="Inference Steps")
            generate_image_btn = gr.Button("🎨 生成图像")
        # Right column: generated image output.
        with gr.Column():
            output_image = gr.Image(label="生成图像", type="pil")
    # --- Event bindings: wire buttons to the handler functions above ---
    transcribe_btn.click(fn=transcribe_audio, inputs=audio_input, outputs=transcribed_text)
    generate_prompt_btn.click(fn=generate_prompt, inputs=[description, model_provider, temperature], outputs=generated_prompt)
    generate_image_btn.click(fn=generate_image, inputs=[generated_prompt, negative_prompt, guidance, steps], outputs=output_image)
# --- Entry point ---
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (the HuggingFace Spaces convention).
    demo.launch(server_name="0.0.0.0", server_port=7860)