|
|
import gradio as gr |
|
|
import torch |
|
|
import os |
|
|
import time |
|
|
from openai import OpenAI, OpenAIError |
|
|
from transformers import pipeline as hf_pipeline |
|
|
from diffusers import StableDiffusionPipeline |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- OpenAI configuration -------------------------------------------------
# The API key comes from the environment (set via Hugging Face Space
# Secrets). The app degrades gracefully when it is missing: `client` stays
# None and every caller must check it before using the API.
openai_api_key = os.getenv("OPENAI_API_KEY")

if not openai_api_key:
    print("警告:未找到 OPENAI_API_KEY 环境变量。请在 Hugging Face Space Secrets 中设置它。")

# None when no key is configured; prompt expansion is then disabled at call time.
client = OpenAI(api_key=openai_api_key) if openai_api_key else None
|
|
|
|
|
|
|
|
# --- Device selection and model loading -----------------------------------
# Prefer the GPU when available; the dtype chosen below depends on this.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用的设备: {device}")

print("正在加载模型...")

# Speech-to-text: Whisper (base) via the transformers ASR pipeline.
# On failure the handle is set to None so voice input is disabled but the
# rest of the app keeps working.
try:
    asr_pipeline = hf_pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
    print("Whisper 模型加载成功。")
except Exception as e:
    print(f"加载 Whisper 模型失败: {e}")
    asr_pipeline = None

# Text-to-image: Stable Diffusion v1.5 via diffusers.
model_id = "runwayml/stable-diffusion-v1-5"
try:
    # float16 halves GPU memory use; CPU inference requires float32.
    dtype = torch.float16 if device == "cuda" else torch.float32
    sd_pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
    sd_pipeline = sd_pipeline.to(device)
    print("Stable Diffusion v1.5 模型加载成功。")
except Exception as e:
    print(f"加载 Stable Diffusion 模型失败: {e}")
    # None disables image generation; checked by generate_image().
    sd_pipeline = None

print("模型加载完成。")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def transcribe_audio(audio_path):
    """Transcribe an audio file to text using the Whisper ASR pipeline.

    Args:
        audio_path: Filesystem path of the recorded audio clip.

    Returns:
        A ``(text, status)`` tuple. On success ``text`` is the transcript
        and ``status`` a human-readable summary; on failure both carry an
        error message (``status`` may be empty for input-validation errors).
    """
    if not asr_pipeline:
        return "错误:语音识别模型未加载。", ""
    if audio_path is None:
        return "错误:未提供音频输入。", ""
    try:
        print(f"开始转录音频: {audio_path}")
        # Accept only a string path that points at an existing file.
        path_ok = isinstance(audio_path, str) and os.path.exists(audio_path)
        if not path_ok:
            return f"错误:无效的音频文件路径 '{audio_path}'。", ""
        # Chunked decoding (30 s windows, 5 s stride) keeps memory bounded
        # on long recordings.
        result = asr_pipeline(audio_path, chunk_length_s=30, stride_length_s=5)
        recognized = result["text"]
        print(f"转录完成: {recognized}")
        return recognized, f"语音识别结果: {recognized}"
    except Exception as e:
        print(f"语音转录失败: {e}")
        message = f"错误:语音转录失败 - {e}"
        return message, message
|
|
|
|
|
|
|
|
def generate_detailed_prompt(short_prompt):
    """Expand a short description into a detailed Stable Diffusion prompt.

    Calls the OpenAI chat-completions API (gpt-3.5-turbo) with a prompt-
    engineering system message.

    Args:
        short_prompt: Brief scene description supplied by the user.

    Returns:
        A ``(prompt, status)`` tuple; on any failure both elements hold the
        same error message so the UI can show it in either slot.
    """
    if not client:
        err = "错误:OpenAI API Key 未配置或无效。"
        return err, err
    if not short_prompt or not short_prompt.strip():
        err = "错误:请输入有效的简短描述。"
        return err, err

    print(f"开始为 '{short_prompt}' 生成详细提示词...")
    # Runtime prompt text for the LLM — kept verbatim.
    system_message = """你是一位专业的 Stable Diffusion 提示词工程师。
你的任务是将用户提供的简短描述扩展为一个结构良好、细节丰富、视觉效果强烈的英文提示词。
提示词应该包含以下元素(如果适用):
1. **主体:** 清晰描述图像的核心内容。
2. **细节:** 添加具体的物体、特征、纹理、颜色等。
3. **场景/环境:** 描述背景、地点、氛围。
4. **风格:** 如 'photorealistic', 'anime', 'oil painting', 'cyberpunk', 'fantasy art' 等。
5. **构图/视角:** 如 'wide angle', 'close-up', 'portrait', 'overhead view'。
6. **光照:** 如 'cinematic lighting', 'soft light', 'studio lighting', 'volumetric lighting'。
7. **艺术家风格 (可选):** 如 'by greg rutkowski', 'by artgerm', 'by studio ghibli'。
8. **质量词:** 如 'masterpiece', 'high resolution', '4k', 'detailed', 'intricate details'。
请用逗号分隔不同的描述性词语或短语。确保输出的只是提示词本身,不要包含任何解释性文字或前缀。
例如,输入 "空中的魔法树屋",输出可能像:
magical treehouse floating in the sky, intricate details, fantasy art style, surrounded by glowing clouds, cinematic lighting, wide angle view, masterpiece, high resolution, by studio ghibli and greg rutkowski
"""
    user_message = f"请为以下描述生成详细的 Stable Diffusion 提示词: \"{short_prompt}\""

    try:
        started = time.time()
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message}
            ],
            # Moderate creativity; 150 tokens is plenty for a prompt string.
            temperature=0.7,
            max_tokens=150
        )
        detailed_prompt = response.choices[0].message.content.strip()
        elapsed = time.time() - started
        print(f"详细提示词生成成功 (耗时: {elapsed:.2f} 秒): {detailed_prompt}")
        return detailed_prompt, f"详细提示词生成成功 (耗时: {elapsed:.2f} 秒)"
    except OpenAIError as e:
        # API-level failures (auth, rate limit, network) are reported as-is.
        print(f"OpenAI API 调用失败: {e}")
        msg = f"错误:OpenAI API 调用失败 - {e}"
        return msg, msg
    except Exception as e:
        print(f"生成详细提示词时发生未知错误: {e}")
        msg = f"错误:生成详细提示词时发生未知错误 - {e}"
        return msg, msg
|
|
|
|
|
|
|
|
def generate_image(prompt, negative_prompt, steps, scale):
    """Generate an image with the Stable Diffusion pipeline.

    Args:
        prompt: Detailed (English) prompt describing the desired image.
        negative_prompt: Concepts the image should avoid.
        steps: Number of denoising inference steps (coerced to int).
        scale: Classifier-free guidance scale (coerced to float).

    Returns:
        A ``(image, status)`` tuple; ``image`` is a PIL image on success, or
        None together with an error status on failure.
    """
    if not sd_pipeline:
        return None, "错误:Stable Diffusion 模型未加载。"
    if not prompt or not prompt.strip():
        return None, "错误:无法使用空的详细提示词生成图像。"

    print(f"开始生成图像,提示词: '{prompt}', Negative: '{negative_prompt}', Steps: {steps}, Scale: {scale}")
    try:
        start_time = time.time()
        # Seed a fresh time-based generator on every call so repeated runs
        # differ. torch.Generator supports CPU as well as CUDA, so seed both
        # paths (previously CPU runs got no explicit generator at all).
        generator = torch.Generator(device=device).manual_seed(int(time.time()))
        # inference_mode skips autograd bookkeeping for speed and memory.
        with torch.inference_mode():
            image = sd_pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=int(steps),
                guidance_scale=float(scale),
                generator=generator
            ).images[0]
        end_time = time.time()
        print(f"图像生成成功 (耗时: {end_time - start_time:.2f} 秒)")
        return image, f"图像生成成功 (耗时: {end_time - start_time:.2f} 秒)"
    except Exception as e:
        print(f"图像生成失败: {e}")
        # Free cached VRAM so one OOM failure doesn't poison later attempts.
        if device == "cuda":
            torch.cuda.empty_cache()
        return None, f"错误:图像生成失败 - {e}"
|
|
|
|
|
|
|
|
|
|
|
def process_input_and_generate(short_prompt_text, short_prompt_audio, negative_prompt, steps, scale):
    """
    Process text or voice input, build a detailed prompt, then generate an image.

    Generator function: each ``yield`` is a 4-tuple
    (short prompt, detailed prompt, image, status text) matching the four
    Gradio output components, so the UI is updated incrementally as each
    stage completes. Voice input takes priority over text input.
    """
    yield "开始处理...", None, None, "状态:初始化..."

    final_short_prompt = ""
    status_updates = []  # accumulated per-stage status lines for the status box
    image_result = None
    detailed_prompt_result = None

    # Stage 1: resolve the short prompt — voice beats text when both given.
    if short_prompt_audio is not None:
        yield "正在处理语音输入...", None, None, "状态:语音识别中..."
        transcribed_text, status = transcribe_audio(short_prompt_audio)
        status_updates.append(status)
        if "错误" in transcribed_text:
            # Transcription failed: surface the error and abort the pipeline.
            status_message = f"状态:处理失败\n{status}"
            yield None, None, None, status_message
            return
        final_short_prompt = transcribed_text
        yield f"语音识别结果: {final_short_prompt}", None, None, "状态:语音识别完成"
    elif short_prompt_text and short_prompt_text.strip():
        final_short_prompt = short_prompt_text.strip()
        status_updates.append("使用文本输入。")
        yield f"使用文本输入: {final_short_prompt}", None, None, "状态:获取文本输入"
    else:
        yield "错误:请输入文本描述或提供语音输入。", None, None, "状态:错误 - 无有效输入"
        return

    # Defensive re-check; both branches above should have set a prompt.
    if not final_short_prompt:
        yield "错误:无法获取有效的简短描述。", None, None, "状态:错误 - 无有效输入"
        return

    # Stage 2: expand the short prompt into a detailed SD prompt via the LLM.
    detailed_prompt_result = "正在生成详细提示词..."
    yield final_short_prompt, detailed_prompt_result, image_result, f"状态:正在调用 LLM 生成详细提示词...\n{status_updates[-1] if status_updates else ''}"

    detailed_prompt, status = generate_detailed_prompt(final_short_prompt)
    status_updates.append(status)
    detailed_prompt_result = detailed_prompt
    if "错误" in detailed_prompt:
        # LLM step failed: report the accumulated status history and stop.
        joined_updates = "\n".join(status_updates)
        status_message = f"状态:处理失败\n{joined_updates}"
        yield final_short_prompt, detailed_prompt_result, image_result, status_message
        return

    joined_updates = "\n".join(status_updates)
    status_message = f"状态:详细提示词生成成功,准备生成图像...\n{joined_updates}"
    yield final_short_prompt, detailed_prompt_result, image_result, status_message

    # Stage 3: render the image with Stable Diffusion.
    # NOTE(review): a placeholder string is yielded into the gr.Image slot
    # here — Gradio appears to tolerate it while rendering; confirm.
    image_result = "正在生成图像..."

    joined_updates = "\n".join(status_updates)
    status_message = f"状态:正在调用 Stable Diffusion...\n{joined_updates}"
    yield final_short_prompt, detailed_prompt_result, image_result, status_message

    image, status = generate_image(detailed_prompt_result, negative_prompt, steps, scale)
    status_updates.append(status)
    image_result = image
    if image_result is None:
        # Image generation failed: show the full status history.
        joined_updates = "\n".join(status_updates)
        status_message = f"状态:处理失败\n{joined_updates}"
        yield final_short_prompt, detailed_prompt_result, None, status_message
        return

    # Success: final yield with every output populated.
    joined_updates = "\n".join(status_updates)
    status_message = f"状态:完成!\n{joined_updates}"
    yield final_short_prompt, detailed_prompt_result, image_result, status_message
|
|
|
|
|
|
|
|
# --- Gradio UI definition -------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as app:
    # Header and usage notes shown above the controls.
    gr.Markdown(
        """
# 🎨 Stable Diffusion 图像生成器 (文字/语音 -> Prompt -> 图像)

输入一个简短的描述(或使用麦克风说出来),系统会自动生成详细的英文提示词,并使用 Stable Diffusion v1.5 生成图像。
**注意:**
* 语音输入优先于文本输入。如果同时提供了两者,将使用语音输入。
* 模型加载和图像生成可能需要一些时间,尤其是在 CPU 或免费 Hugging Face Spaces 上。
* 请确保已在 Space Secrets 中配置 OpenAI API Key (`OPENAI_API_KEY`)。
* **提示**: 语音输入优先于文本;生成步数影响细节和时间;引导系数控制与提示词的符合度;负面提示词排除不想要的内容。
"""
    )

    with gr.Row():
        # Left column: inputs and generation controls.
        with gr.Column(scale=1):
            inp_short_prompt_text = gr.Textbox(label="输入简短描述 (例如:空中的魔法树屋)")
            # type="filepath" delivers the recording to transcribe_audio as a path.
            inp_short_prompt_audio = gr.Audio(sources=["microphone"], type="filepath", label="或者,使用麦克风说出描述")
            inp_steps = gr.Slider(minimum=10, maximum=100, value=30, step=1, label="生成步数 (Steps)")
            inp_scale = gr.Slider(minimum=1.0, maximum=20.0, value=7.5, step=0.5, label="引导系数 (Guidance Scale)")
            inp_negative_prompt = gr.Textbox(label="负面提示词 (Negative Prompt)", value="ugly, blurry, low quality, deformed, text, words, signature")
            generate_button = gr.Button("✨ 生成图像", variant="primary")
            status_output = gr.Textbox(label="处理状态", lines=4, interactive=False)

        # Right column: read-only outputs.
        with gr.Column(scale=1):
            out_final_short_prompt = gr.Textbox(label="使用的简短描述", interactive=False)
            out_detailed_prompt = gr.Textbox(label="生成的详细提示词 (用于 Stable Diffusion)", lines=5, interactive=False)
            out_image = gr.Image(label="生成的图像", type="pil", interactive=False)

    # Wire the button to the streaming pipeline; since the handler is a
    # generator, the four outputs update once per yield.
    generate_button.click(
        fn=process_input_and_generate,
        inputs=[inp_short_prompt_text, inp_short_prompt_audio, inp_negative_prompt, inp_steps, inp_scale],
        outputs=[out_final_short_prompt, out_detailed_prompt, out_image, status_output]
    )

    # Clickable example prompts; caching disabled because each run would
    # trigger a full (slow, key-dependent) generation at build time.
    gr.Examples(
        examples=[
            ["a magical treehouse in the sky", None, "blurry, low quality", 30, 7.5],
            ["photo of a cute cat wearing sunglasses", None, "cartoon, drawing, sketch", 35, 8.0],
            ["cyberpunk city street at night, raining", None, "daytime, bright sun", 40, 7.0],
        ],
        inputs=[inp_short_prompt_text, inp_short_prompt_audio, inp_negative_prompt, inp_steps, inp_scale],
        outputs=[out_final_short_prompt, out_detailed_prompt, out_image, status_output],
        fn=process_input_and_generate,
        cache_examples=False,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Entry point -----------------------------------------------------------
if __name__ == "__main__":
    # Startup sanity checks: warn about each component that failed to load
    # so the operator can diagnose a partially-functional deployment.
    if not client:
        print("警告:OpenAI API Key 未配置或无效。需要 API Key 才能使用提示词生成功能。")
    if not asr_pipeline:
        print("警告:语音识别模型未能加载。语音输入功能将不可用。")
    if not sd_pipeline:
        print("警告:Stable Diffusion 模型未能加载。图像生成功能将不可用。")

    print("--- 模型和配置加载检查完成 ---")
    print("--- 准备启动 Gradio 应用 ---")

    try:
        # Bind to all interfaces so the app is reachable from outside the
        # container (required on Hugging Face Spaces).
        app.launch(server_name="0.0.0.0")
        # launch() normally blocks; reaching this line means it returned early.
        print("--- Gradio launch() 函数已返回 (可能未成功启动) ---")
    except Exception as e:
        print(f"--- Gradio app.launch() 发生异常: {e} ---")
        # Re-raise so the hosting platform records the startup failure.
        raise
|
|
|
|
|
|
|
|
|
|
|
|