111 / app.py
shenyugan's picture
Update app.py
75670ae verified
import gradio as gr
import torch
import os
import time
from openai import OpenAI, OpenAIError
from transformers import pipeline as hf_pipeline # Renamed to avoid conflict
from diffusers import StableDiffusionPipeline
# --- 配置 ---
# 1. OpenAI API Key (IMPORTANT: Set this as a Secret in your Hugging Face Space)
# Go to your Space > Settings > Secrets > Add secret
# Name: OPENAI_API_KEY
# Value: sk-your-openai-api-key
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
print("警告:未找到 OPENAI_API_KEY 环境变量。请在 Hugging Face Space Secrets 中设置它。")
# You might want to add a fallback or raise an error here for local testing
# For local testing only, uncomment and set your key here:
# openai_api_key = "sk-..."
client = OpenAI(api_key=openai_api_key) if openai_api_key else None
# 2. 选择设备 (GPU优先, 否则CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用的设备: {device}")
# 3. 模型加载 (放在全局以避免重复加载)
print("正在加载模型...")
# 语音识别模型 (Whisper - base 版本在速度和准确性之间取得良好平衡)
try:
asr_pipeline = hf_pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
print("Whisper 模型加载成功。")
except Exception as e:
print(f"加载 Whisper 模型失败: {e}")
asr_pipeline = None
# 图像生成模型 (Stable Diffusion v1.5 - 速度较快,资源占用相对较少)
model_id = "runwayml/stable-diffusion-v1-5"
try:
# 如果使用GPU,使用float16以节省显存和加速
dtype = torch.float16 if device == "cuda" else torch.float32
sd_pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
sd_pipeline = sd_pipeline.to(device)
print("Stable Diffusion v1.5 模型加载成功。")
except Exception as e:
print(f"加载 Stable Diffusion 模型失败: {e}")
sd_pipeline = None
print("模型加载完成。")
# --- 功能函数 ---
# Step 1: 语音转文字 (如果提供了音频)
def transcribe_audio(audio_path):
"""将音频文件转录为文字"""
if not asr_pipeline:
return "错误:语音识别模型未加载。", ""
if audio_path is None:
return "错误:未提供音频输入。", ""
try:
print(f"开始转录音频: {audio_path}")
# 确保 audio_path 是有效的文件路径字符串
if not isinstance(audio_path, str) or not os.path.exists(audio_path):
return f"错误:无效的音频文件路径 '{audio_path}'。", ""
# 使用 chunk_length_s 和 stride_length_s 可能有助于处理长音频
transcript = asr_pipeline(audio_path, chunk_length_s=30, stride_length_s=5)
text = transcript["text"]
print(f"转录完成: {text}")
return text, f"语音识别结果: {text}" # 返回文本和状态信息
except Exception as e:
print(f"语音转录失败: {e}")
return f"错误:语音转录失败 - {e}", f"错误:语音转录失败 - {e}"
# Step 2: 生成详细的 Prompt
def generate_detailed_prompt(short_prompt):
"""使用 LLM (OpenAI GPT) 将简短描述扩展为详细的 SD Prompt"""
if not client:
return "错误:OpenAI API Key 未配置或无效。", "错误:OpenAI API Key 未配置或无效。"
if not short_prompt or not short_prompt.strip():
return "错误:请输入有效的简短描述。", "错误:请输入有效的简短描述。"
print(f"开始为 '{short_prompt}' 生成详细提示词...")
system_message = """你是一位专业的 Stable Diffusion 提示词工程师。
你的任务是将用户提供的简短描述扩展为一个结构良好、细节丰富、视觉效果强烈的英文提示词。
提示词应该包含以下元素(如果适用):
1. **主体:** 清晰描述图像的核心内容。
2. **细节:** 添加具体的物体、特征、纹理、颜色等。
3. **场景/环境:** 描述背景、地点、氛围。
4. **风格:** 如 'photorealistic', 'anime', 'oil painting', 'cyberpunk', 'fantasy art' 等。
5. **构图/视角:** 如 'wide angle', 'close-up', 'portrait', 'overhead view'。
6. **光照:** 如 'cinematic lighting', 'soft light', 'studio lighting', 'volumetric lighting'。
7. **艺术家风格 (可选):** 如 'by greg rutkowski', 'by artgerm', 'by studio ghibli'。
8. **质量词:** 如 'masterpiece', 'high resolution', '4k', 'detailed', 'intricate details'。
请用逗号分隔不同的描述性词语或短语。确保输出的只是提示词本身,不要包含任何解释性文字或前缀。
例如,输入 "空中的魔法树屋",输出可能像:
magical treehouse floating in the sky, intricate details, fantasy art style, surrounded by glowing clouds, cinematic lighting, wide angle view, masterpiece, high resolution, by studio ghibli and greg rutkowski
"""
user_message = f"请为以下描述生成详细的 Stable Diffusion 提示词: \"{short_prompt}\""
try:
start_time = time.time()
response = client.chat.completions.create(
model="gpt-3.5-turbo", # 速度较快
messages=[
{"role": "system", "content": system_message},
{"role": "user", "content": user_message}
],
temperature=0.7, # 控制创造性
max_tokens=150 # 限制输出长度
)
detailed_prompt = response.choices[0].message.content.strip()
end_time = time.time()
print(f"详细提示词生成成功 (耗时: {end_time - start_time:.2f} 秒): {detailed_prompt}")
return detailed_prompt, f"详细提示词生成成功 (耗时: {end_time - start_time:.2f} 秒)"
except OpenAIError as e:
print(f"OpenAI API 调用失败: {e}")
return f"错误:OpenAI API 调用失败 - {e}", f"错误:OpenAI API 调用失败 - {e}"
except Exception as e:
print(f"生成详细提示词时发生未知错误: {e}")
return f"错误:生成详细提示词时发生未知错误 - {e}", f"错误:生成详细提示词时发生未知错误 - {e}"
# Step 3: 生成图像
def generate_image(prompt, negative_prompt, steps, scale):
"""使用 Stable Diffusion 生成图像"""
if not sd_pipeline:
return None, "错误:Stable Diffusion 模型未加载。"
if not prompt or not prompt.strip():
return None, "错误:无法使用空的详细提示词生成图像。"
print(f"开始生成图像,提示词: '{prompt}', Negative: '{negative_prompt}', Steps: {steps}, Scale: {scale}")
try:
start_time = time.time()
# 在 GPU 上运行时设置生成器以获得确定性结果 (可选)
generator = torch.Generator(device=device).manual_seed(int(time.time())) if device == "cuda" else None
with torch.inference_mode(): # 节省显存
image = sd_pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=int(steps),
guidance_scale=float(scale),
generator=generator
).images[0]
end_time = time.time()
print(f"图像生成成功 (耗时: {end_time - start_time:.2f} 秒)")
return image, f"图像生成成功 (耗时: {end_time - start_time:.2f} 秒)"
except Exception as e:
print(f"图像生成失败: {e}")
# 尝试释放一些显存 (如果是在GPU上)
if device == "cuda":
torch.cuda.empty_cache()
return None, f"错误:图像生成失败 - {e}"
# --- Gradio 应用主逻辑 ---
# --- Gradio 应用主逻辑 ---
def process_input_and_generate(short_prompt_text, short_prompt_audio, negative_prompt, steps, scale):
"""
处理文本或语音输入,生成详细提示词,然后生成图像。
这是一个生成器函数,可以逐步更新 Gradio UI 状态。
"""
yield "开始处理...", None, None, "状态:初始化..." # 清空输出, 更新状态
final_short_prompt = ""
status_updates = []
image_result = None # Initialize image result
detailed_prompt_result = None # Initialize detailed prompt result
# 1. 确定输入源 (优先使用语音)
if short_prompt_audio is not None:
yield "正在处理语音输入...", None, None, "状态:语音识别中..."
transcribed_text, status = transcribe_audio(short_prompt_audio)
status_updates.append(status)
if "错误" in transcribed_text:
# FIX: Separate f-string and newline concatenation
status_message = f"状态:处理失败\n{status}"
yield None, None, None, status_message
return
final_short_prompt = transcribed_text
yield f"语音识别结果: {final_short_prompt}", None, None, "状态:语音识别完成"
elif short_prompt_text and short_prompt_text.strip():
final_short_prompt = short_prompt_text.strip()
status_updates.append("使用文本输入。")
yield f"使用文本输入: {final_short_prompt}", None, None, "状态:获取文本输入"
else:
yield "错误:请输入文本描述或提供语音输入。", None, None, "状态:错误 - 无有效输入"
return
if not final_short_prompt:
yield "错误:无法获取有效的简短描述。", None, None, "状态:错误 - 无有效输入"
return
# --- Update the display of the short prompt used ---
# We yield the final_short_prompt along with other updates from now on
detailed_prompt_result = "正在生成详细提示词..."
yield final_short_prompt, detailed_prompt_result, image_result, f"状态:正在调用 LLM 生成详细提示词...\n{status_updates[-1] if status_updates else ''}" # Show last status
# 2. 生成详细 Prompt
detailed_prompt, status = generate_detailed_prompt(final_short_prompt)
status_updates.append(status)
detailed_prompt_result = detailed_prompt # Update the result holder
if "错误" in detailed_prompt:
# FIX: Separate f-string and newline concatenation
joined_updates = "\n".join(status_updates)
status_message = f"状态:处理失败\n{joined_updates}"
yield final_short_prompt, detailed_prompt_result, image_result, status_message
return
# FIX: Separate f-string and newline concatenation
joined_updates = "\n".join(status_updates)
status_message = f"状态:详细提示词生成成功,准备生成图像...\n{joined_updates}"
yield final_short_prompt, detailed_prompt_result, image_result, status_message # Keep image empty for now
# 3. 生成图像
image_result = "正在生成图像..." # Placeholder text while generating
# FIX: Separate f-string and newline concatenation
joined_updates = "\n".join(status_updates)
status_message = f"状态:正在调用 Stable Diffusion...\n{joined_updates}"
yield final_short_prompt, detailed_prompt_result, image_result, status_message
image, status = generate_image(detailed_prompt_result, negative_prompt, steps, scale)
status_updates.append(status)
image_result = image # Update with actual image or None
if image_result is None:
# FIX: Separate f-string and newline concatenation
joined_updates = "\n".join(status_updates)
status_message = f"状态:处理失败\n{joined_updates}"
# Yield the error message in the image spot temporarily? Or keep it None. Let's keep it None.
yield final_short_prompt, detailed_prompt_result, None, status_message
return
# 4. 显示最终结果
# FIX: Separate f-string and newline concatenation
joined_updates = "\n".join(status_updates)
status_message = f"状态:完成!\n{joined_updates}"
yield final_short_prompt, detailed_prompt_result, image_result, status_message
# --- Gradio 界面构建 ---
with gr.Blocks(theme=gr.themes.Soft()) as app:
gr.Markdown(
"""
# 🎨 Stable Diffusion 图像生成器 (文字/语音 -> Prompt -> 图像)
输入一个简短的描述(或使用麦克风说出来),系统会自动生成详细的英文提示词,并使用 Stable Diffusion v1.5 生成图像。
**注意:**
* 语音输入优先于文本输入。如果同时提供了两者,将使用语音输入。
* 模型加载和图像生成可能需要一些时间,尤其是在 CPU 或免费 Hugging Face Spaces 上。
* 请确保已在 Space Secrets 中配置 OpenAI API Key (`OPENAI_API_KEY`)。
* **提示**: 语音输入优先于文本;生成步数影响细节和时间;引导系数控制与提示词的符合度;负面提示词排除不想要的内容。
"""
) # Added essential info to the main markdown
with gr.Row():
with gr.Column(scale=1):
# 输入控件
# REMOVED info argument
inp_short_prompt_text = gr.Textbox(label="输入简短描述 (例如:空中的魔法树屋)")
# REMOVED info argument
inp_short_prompt_audio = gr.Audio(sources=["microphone"], type="filepath", label="或者,使用麦克风说出描述")
# Gradio 控件使用要求 (至少三种)
# 1. Textbox (上面已有)
# 2. Audio (上面已有)
# 3. Slider
# REMOVED info argument
inp_steps = gr.Slider(minimum=10, maximum=100, value=30, step=1, label="生成步数 (Steps)")
# 4. Slider (另一个)
# REMOVED info argument
inp_scale = gr.Slider(minimum=1.0, maximum=20.0, value=7.5, step=0.5, label="引导系数 (Guidance Scale)")
# 5. Textbox (负面提示词)
# REMOVED info argument
inp_negative_prompt = gr.Textbox(label="负面提示词 (Negative Prompt)", value="ugly, blurry, low quality, deformed, text, words, signature")
generate_button = gr.Button("✨ 生成图像", variant="primary")
# 状态显示区域
status_output = gr.Textbox(label="处理状态", lines=4, interactive=False)
with gr.Column(scale=1):
# 输出控件
out_final_short_prompt = gr.Textbox(label="使用的简短描述", interactive=False) # 显示最终使用的输入
out_detailed_prompt = gr.Textbox(label="生成的详细提示词 (用于 Stable Diffusion)", lines=5, interactive=False)
out_image = gr.Image(label="生成的图像", type="pil", interactive=False) # 使用 PIL 格式
# 将按钮点击连接到处理函数
generate_button.click(
fn=process_input_and_generate,
inputs=[inp_short_prompt_text, inp_short_prompt_audio, inp_negative_prompt, inp_steps, inp_scale],
outputs=[out_final_short_prompt, out_detailed_prompt, out_image, status_output]
)
# 添加一些示例 (确保 inputs/outputs 匹配修改后的组件)
gr.Examples(
examples=[
["a magical treehouse in the sky", None, "blurry, low quality", 30, 7.5],
["photo of a cute cat wearing sunglasses", None, "cartoon, drawing, sketch", 35, 8.0],
["cyberpunk city street at night, raining", None, "daytime, bright sun", 40, 7.0],
],
# Ensure the order of inputs here matches the inputs list in the .click() call
inputs=[inp_short_prompt_text, inp_short_prompt_audio, inp_negative_prompt, inp_steps, inp_scale],
# Ensure the order of outputs here matches the outputs list in the .click() call
outputs=[out_final_short_prompt, out_detailed_prompt, out_image, status_output],
fn=process_input_and_generate,
cache_examples=False,
)
# --- 启动 Gradio 应用 ---
# --- 启动 Gradio 应用 ---
if __name__ == "__main__":
# 简化启动块:移除启动前的检查,让应用尝试启动。
# 具体的模型/API Key是否可用的错误处理,放在 Gradio 事件触发的函数内部进行。
# (可选) 可以在这里打印一次警告信息,但不进行任何退出操作
if not client:
print("警告:OpenAI API Key 未配置或无效。需要 API Key 才能使用提示词生成功能。")
if not asr_pipeline:
print("警告:语音识别模型未能加载。语音输入功能将不可用。")
if not sd_pipeline:
print("警告:Stable Diffusion 模型未能加载。图像生成功能将不可用。")
print("--- 模型和配置加载检查完成 ---")
print("--- 准备启动 Gradio 应用 ---") # 添加日志点
# 在 Hugging Face Spaces 上运行时,不需要 share=True
# debug=True 可以在日志中提供更详细的 Gradio 输出,有助于排查问题
# server_name="0.0.0.0" 确保服务监听所有网络接口,这在容器化环境中通常是必要的
try:
app.launch(server_name="0.0.0.0") # 明确指定 server_name
# 如果 launch 成功启动并阻塞,下面的 print 不会执行
# 如果 launch 因为某种原因快速失败并返回,可能会看到下面的 print
print("--- Gradio launch() 函数已返回 (可能未成功启动) ---")
except Exception as e:
print(f"--- Gradio app.launch() 发生异常: {e} ---") # 捕获可能的启动异常
# 可以在这里添加更详细的错误处理或日志记录
raise # 重新抛出异常,以便 Spaces 能看到错误
# 注意:正常情况下,app.launch() 会阻塞进程,使 Web 服务持续运行。
# 如果脚本在这里之后还能继续执行并退出,说明 launch() 没有成功启动或保持运行。