123 / app.py
shenyugan's picture
Update app.py
0f0f3dc verified
import gradio as gr
import openai
import os
import torch
from diffusers import StableDiffusionPipeline
from transformers import pipeline
import time
import warnings
from PIL import Image
# --- 配置和模型加载 ---
# 忽略一些不必要的警告
warnings.filterwarnings("ignore")
# 检查 OpenAI API Key
api_key = os.getenv("OPENAI_API_KEY")
openai_client = None
openai_available = False
if api_key:
try:
openai_client = openai.OpenAI(api_key=api_key)
# 尝试进行一次小的 API 调用以验证密钥
openai_client.models.list()
openai_available = True
print("OpenAI API Key 已加载并验证成功。")
except Exception as e:
print(f"加载 OpenAI API Key 时出错或验证失败: {e}")
print("提示词增强功能将不可用。请在 Hugging Face Secrets 中设置 OPENAI_API_KEY。")
else:
print("警告: 环境变量 OPENAI_API_KEY 未设置。")
print("提示词增强功能将不可用。请在 Hugging Face Secrets 中设置 OPENAI_API_KEY。")
# 全局加载模型 (这对于Hugging Face Spaces性能至关重要,避免每次调用都加载)
# 使用 try-except 块,以便在模型加载失败时应用仍能部分运行
sd_pipe = None
asr_pipe = None
sd_model_id = "runwayml/stable-diffusion-v1-5"
asr_model_id = "openai/whisper-base.en" # 使用英文基础模型,更快
try:
print(f"开始加载 Stable Diffusion 模型: {sd_model_id} (这可能需要一些时间)...")
# 明确指定使用 CPU 和 float32
sd_pipe = StableDiffusionPipeline.from_pretrained(
sd_model_id,
torch_dtype=torch.float32 # CPU 推理使用 float32
)
# 不需要 .to("cuda"),默认在 CPU
print("Stable Diffusion 模型加载完成。")
except Exception as e:
print(f"错误: 加载 Stable Diffusion 模型失败: {e}")
sd_pipe = None # 标记为不可用
try:
print(f"开始加载 ASR (语音识别) 模型: {asr_model_id} ...")
# 为 ASR pipeline 指定设备为 CPU
asr_pipe = pipeline("automatic-speech-recognition", model=asr_model_id, device=-1) # device=-1 表示 CPU
print("ASR 模型加载完成。")
except Exception as e:
print(f"错误: 加载 ASR 模型失败: {e}")
asr_pipe = None # 标记为不可用
# --- 核心功能函数 ---
# Step 1: Prompt-to-Prompt
def enhance_prompt_with_llm(short_prompt):
"""使用 OpenAI LLM 增强简短描述为详细的 Stable Diffusion 提示词"""
if not openai_available or not openai_client:
print("OpenAI 不可用,跳过提示词增强。")
# 如果 OpenAI 不可用,直接返回原始输入或一个默认结构
return f"{short_prompt}, photorealistic, high quality"
system_message = """你是一个创意助手,擅长将简单的想法扩展成生动、详细、结构化的图像生成提示词(Prompt),特别适用于 Stable Diffusion。
请根据用户输入的简短描述,生成一个更丰富、包含细节、风格和构图建议的英文提示词。
提示词应该包含:
1. 主要主体和动作。
2. 重要的细节描述(材质、光线、环境)。
3. 艺术风格(例如:photorealistic, cartoon, illustration, oil painting, watercolor, cyberpunk, fantasy art)。
4. 构图和视角(例如:wide angle shot, close-up shot, aerial view)。
5. 画面质量词语(例如:masterpiece, high resolution, ultra detailed, 8k)。
6. (可选)一个简单的负面提示词(Negative Prompt),放在 "Negative Prompt: " 之后,用来排除不想要的元素(例如:low quality, blurry, text, watermark, signature)。
请直接输出增强后的英文提示词和负面提示词。
格式示例:
[增强后的英文提示词主体], [风格], [构图], [质量词语] Negative Prompt: [负面提示词]
"""
user_message = f"请将以下简短描述扩展成一个详细的 Stable Diffusion 提示词: '{short_prompt}'"
try:
response = openai_client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": system_message},
{"role": "user", "content": user_message},
],
temperature=0.7, # 增加一点创造性
max_tokens=150
)
enhanced_prompt = response.choices[0].message.content.strip()
# 简单的后处理,确保有 Negative Prompt 部分
if "Negative Prompt:" not in enhanced_prompt:
enhanced_prompt += " Negative Prompt: low quality, blurry, worst quality, text, watermark"
# 分离主提示词和负面提示词
if "Negative Prompt:" in enhanced_prompt:
parts = enhanced_prompt.split("Negative Prompt:", 1)
main_prompt = parts[0].strip()
negative_prompt = parts[1].strip()
else:
main_prompt = enhanced_prompt
negative_prompt = "low quality, blurry, worst quality, text, watermark" # 默认负面提示词
return main_prompt, negative_prompt
except Exception as e:
print(f"调用 OpenAI API 时出错: {e}")
# 出错时提供一个基础版本
return f"{short_prompt}, photorealistic, high quality", "low quality, blurry, worst quality"
# Step 2: Prompt-to-Image
def generate_image_from_prompt(prompt, negative_prompt, guidance, steps):
"""使用 Stable Diffusion v1.5 生成图像"""
if not sd_pipe:
print("Stable Diffusion 模型不可用,无法生成图像。")
# 返回一个占位符或错误信息图像
# 创建一个简单的黑色图像表示错误
img = Image.new('RGB', (512, 512), color = 'black')
# 在图像上添加错误信息可能比较复杂,这里只返回黑图
return img, "错误: Stable Diffusion 模型未加载"
print(f"开始生成图像,提示词: {prompt}, 负面提示词: {negative_prompt}")
print(f"参数: guidance_scale={guidance}, num_inference_steps={steps}")
start_time = time.time()
try:
# 在 CPU 上生成图像
# 注意:CPU 生成会非常慢!
with torch.no_grad(): # 确保在推理模式下不计算梯度
# 检查 prompt 和 negative_prompt 是否为空
prompt = prompt if prompt else "default prompt"
negative_prompt = negative_prompt if negative_prompt else ""
image = sd_pipe(
prompt=prompt,
negative_prompt=negative_prompt,
guidance_scale=float(guidance),
num_inference_steps=int(steps),
generator=torch.Generator("cpu").manual_seed(42) # 固定种子以获得可复现结果 (可选)
).images[0]
end_time = time.time()
print(f"图像生成完成,耗时: {end_time - start_time:.2f} 秒")
return image, f"生成成功 (耗时: {end_time - start_time:.2f} 秒)"
except Exception as e:
print(f"生成图像时出错: {e}")
img = Image.new('RGB', (512, 512), color = 'red') # 红色图像表示错误
return img, f"错误: {e}"
# 语音转文字函数
def transcribe_audio(audio_path):
"""将音频文件转录为文字"""
if not asr_pipe:
print("ASR 模型不可用,无法处理语音输入。")
return "错误: 语音识别模型未加载", "" # 返回错误信息和空状态
if audio_path is None:
return "没有检测到音频输入。", "" # 返回提示信息
print(f"开始处理音频文件: {audio_path}")
try:
start_time = time.time()
# 使用 ASR pipeline 处理本地文件路径
transcription = asr_pipe(audio_path)
text = transcription["text"]
end_time = time.time()
print(f"语音识别完成,耗时: {end_time - start_time:.2f} 秒")
print(f"识别结果: {text}")
# 将识别的文本更新到文本输入框
return text, f"语音识别成功 (耗时: {end_time - start_time:.2f} 秒)"
except Exception as e:
print(f"语音识别过程中出错: {e}")
return f"语音识别错误: {e}", "错误"
# --- Gradio 界面 ---
# 定义主处理流程函数
def process_input(input_text, guidance, steps, audio_input):
"""处理文本或语音输入,生成提示词和图像 (修正版)"""
# 初始化变量以存储最终结果
final_text_input = input_text
enhanced_prompt = ""
negative_prompt = ""
generated_image = None # Placeholder for the final image
status = "等待输入"
transcription_status = ""
# 1. 初始状态
yield None, None, None, final_text_input, "开始处理..." # Image, Prompts=None, Text=Current, Status
# 2. 处理语音输入 (如果提供)
if audio_input is not None:
status = "语音识别中..."
yield None, None, None, final_text_input, status
transcribed_text, transcription_status = transcribe_audio(audio_input)
if "错误" not in transcription_status:
final_text_input = transcribed_text
status = f"语音识别完成: {transcription_status}"
yield None, None, None, final_text_input, status # Update input text and status
else:
status = f"语音识别失败: {transcribed_text}. 使用文本框内容。"
# 即使失败,也要 yield 5 个值
yield None, None, None, final_text_input, status
if not final_text_input:
status = "请输入描述或提供语音输入。"
yield None, None, None, final_text_input, status
return # 结束执行
# 3. 增强提示词
status = "增强提示词中..."
yield None, None, None, final_text_input, status # Image, Prompts=None
if openai_available:
try:
enhanced_prompt, negative_prompt = enhance_prompt_with_llm(final_text_input)
status = "提示词增强完成。"
except Exception as e:
status = f"提示词增强错误: {e}. 使用基础提示词。"
# 使用基础提示词
enhanced_prompt, negative_prompt = f"{final_text_input}, photorealistic, high quality", "low quality, blurry, worst quality"
else:
# OpenAI 不可用,使用基础提示词
enhanced_prompt, negative_prompt = f"{final_text_input}, photorealistic, high quality", "low quality, blurry, worst quality"
status = "跳过提示词增强 (OpenAI不可用)。"
# Yield a state update *before* starting image generation
yield None, enhanced_prompt, negative_prompt, final_text_input, status
# 4. 生成图像
status = "生成图像中 (CPU可能较慢)..."
yield None, enhanced_prompt, negative_prompt, final_text_input, status # Update prompts and status, image still None
generated_image, image_status = generate_image_from_prompt(enhanced_prompt, negative_prompt, guidance, steps)
status = f"图像生成完成. {image_status}" # Combine statuses
# 5. 最终 Yield: 返回所有 5 个最终值
# Make sure the order matches the `outputs` list in the .click() handler:
# [output_image, enhanced_prompt_display, negative_prompt_display, input_text, status_textbox]
yield generated_image, enhanced_prompt, negative_prompt, final_text_input, status
# 构建 Gradio Blocks 界面
with gr.Blocks(css="footer {visibility: hidden}") as demo:
gr.Markdown("# AI 图像生成器 (文字/语音 -> 提示词 -> 图像)")
gr.Markdown("输入简短描述或使用麦克风录音,应用将自动 (可选) 增强提示词并使用 Stable Diffusion v1.5 生成图像。")
# 显示模型加载状态和警告
if not openai_available:
gr.Markdown("<p style='color:orange;'>⚠️ **警告:** OpenAI API Key 未正确配置,提示词增强功能将使用默认模式。</p>")
if not sd_pipe:
gr.Markdown("<p style='color:red;'>❌ **错误:** Stable Diffusion 模型加载失败,图像生成功能不可用。</p>")
if not asr_pipe:
gr.Markdown("<p style='color:orange;'>⚠️ **警告:** 语音识别模型加载失败,语音输入功能不可用。</p>")
status_textbox = gr.Textbox(label="处理状态", value="等待输入", interactive=False)
with gr.Row():
with gr.Column(scale=1):
# 输入控件
input_text = gr.Textbox(label="输入简短描述", placeholder="例如:空中的魔法树屋")
# 语音输入控件 (加分项)
audio_input = gr.Audio(sources=["microphone"], type="filepath", label="或者,使用语音输入", visible=asr_pipe is not None) # 仅当ASR加载成功时可见
# Stable Diffusion 参数控件 (类型1: 滑块)
guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, value=7.5, step=0.5, label="引导系数 (Guidance Scale)")
# (类型2: 滑块) - 注意: 步数越多越慢
num_inference_steps = gr.Slider(minimum=10, maximum=50, value=25, step=1, label="推理步数 (Steps)")
# (类型3: 按钮)
generate_button = gr.Button("✨ 生成图像 ✨", variant="primary")
with gr.Column(scale=1):
# 输出控件
enhanced_prompt_display = gr.Textbox(label="增强后的提示词 (Prompt)", interactive=False)
negative_prompt_display = gr.Textbox(label="负面提示词 (Negative Prompt)", interactive=False)
output_image = gr.Image(label="生成的图像", type="pil") # 使用 PIL 格式
# 绑定事件
# 将文本输入、语音输入和参数控件连接到处理函数
generate_button.click(
fn=process_input,
inputs=[input_text, guidance_scale, num_inference_steps, audio_input],
outputs=[output_image, enhanced_prompt_display, negative_prompt_display, input_text, status_textbox] # 输出图像、增强提示词、负面提示词、更新后的文本输入框(用于显示语音识别结果)、状态文本框
)
# 当音频录制完成时,可以先调用语音识别更新文本框(可选,但更流畅)
# 但为了简化流程,我们将所有处理放在按钮点击后
# audio_input.change( # 或者 .stop_recording
# fn=lambda audio: transcribe_audio(audio)[0], # 只取文本部分
# inputs=[audio_input],
# outputs=[input_text] # 直接更新文本输入框
# )
# --- 启动应用 ---
# 使用 queue() 使其能够处理并发请求和长时间运行的任务
# 设置 share=True 可以获得一个公开链接(如果在本地运行)
# 在 Hugging Face Spaces 上运行时不需要 share=True
demo.queue().launch(debug=True) # 开启 debug 模式可以在控制台看到更详细的日志