|
|
|
|
|
|
|
|
import gradio as gr |
|
|
import os |
|
|
import torch |
|
|
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler |
|
|
from transformers import pipeline |
|
|
from openai import OpenAI |
|
|
import time |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Read the API key from the environment / HF Space Secrets. The client stays
# None whenever the key is missing or construction fails, so downstream code
# can detect this and fall back to locally-built prompts.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
openai_client = None

if not OPENAI_API_KEY:
    print("警告:未在环境变量/Secrets中找到OpenAI API密钥。")
    print("因缺少API密钥,OpenAI客户端未初始化。")
else:
    try:
        openai_client = OpenAI(api_key=OPENAI_API_KEY)
    except Exception as e:
        print(f"初始化OpenAI客户端时出错: {e}")
        openai_client = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the Stable Diffusion v1.5 pipeline once at module import time.
# float32 is used because this app targets CPU inference (no fp16 support).
print("正在加载 Stable Diffusion 模型...")
start_time = time.time()
pipe = None
try:
    sd_pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float32,
    )
    # Swap in a faster multistep solver so fewer inference steps are needed.
    sd_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipeline.scheduler.config)
    pipe = sd_pipeline
    print(f"Stable Diffusion 模型加载完成,耗时 {time.time() - start_time:.2f} 秒。")
except Exception as e:
    print(f"加载 Stable Diffusion 模型时出错: {e}")
    pipe = None
|
|
|
|
|
|
|
|
|
|
|
# Load the Whisper speech-to-text pipeline; device=-1 forces CPU inference.
print("正在加载语音转文本模型...")
start_time = time.time()
asr_pipeline = None
try:
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-base",
        device=-1,
    )
    print(f"语音转文本模型加载完成,耗时 {time.time() - start_time:.2f} 秒。")
except Exception as e:
    print(f"加载语音转文本模型时出错: {e}")
    asr_pipeline = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_llm_prompt_output(generated_text, fallback_neg_prompt):
    """Split an LLM response into (prompt, negative_prompt).

    The model is asked to label its two sections with '### Prompt:' and
    '### Negative Prompt:'. When a marker is missing, the whole text (resp.
    ``fallback_neg_prompt``) is used. The user's own negative prompt is
    merged in if the model dropped it, and stray marker text is stripped.
    """
    prompt_marker = "### Prompt:"
    neg_prompt_marker = "### Negative Prompt:"

    final_prompt = generated_text
    final_neg_prompt = fallback_neg_prompt

    if prompt_marker in generated_text:
        start_idx = generated_text.find(prompt_marker) + len(prompt_marker)
        end_idx = generated_text.find(neg_prompt_marker)
        if end_idx != -1:
            final_prompt = generated_text[start_idx:end_idx].strip()
        else:
            final_prompt = generated_text[start_idx:].strip()

    if neg_prompt_marker in generated_text:
        start_idx = generated_text.find(neg_prompt_marker) + len(neg_prompt_marker)
        final_neg_prompt = generated_text[start_idx:].strip()

    # Preserve the user's own negative prompt if the model omitted it.
    if fallback_neg_prompt and fallback_neg_prompt not in final_neg_prompt:
        final_neg_prompt = f"{fallback_neg_prompt}, {final_neg_prompt}"

    # Remove any marker text the model may have echoed inside a section.
    final_prompt = final_prompt.replace(prompt_marker, "").strip()
    final_neg_prompt = final_neg_prompt.replace(neg_prompt_marker, "").strip()
    return final_prompt, final_neg_prompt


def generate_detailed_prompt(short_prompt, style_preference="电影感", neg_prompt=""):
    """Expand a short idea into a detailed Stable Diffusion prompt via the LLM.

    Args:
        short_prompt: The user's brief description of the desired image.
        style_preference: Desired visual style, folded into the prompt.
        neg_prompt: Optional user-supplied negative prompt to preserve/merge.

    Returns:
        A ``(prompt, negative_prompt)`` tuple. When the OpenAI client is
        unavailable or the API call fails, a simple locally-built prompt is
        returned instead (never raises).
    """
    if not openai_client:
        print("OpenAI 客户端不可用。")
        gr.Warning("OpenAI API Key 未配置。将使用基础提示词。")
        return f"{short_prompt},风格:{style_preference}", neg_prompt or "丑陋, 模糊, 低质量"

    system_message = "You are an expert prompt generator for Stable Diffusion."
    user_message = (
        f"Based on the user's simple idea: '{short_prompt}', "
        f"generate a detailed and structured prompt for Stable Diffusion v1.5. "
        f"Incorporate the desired style: '{style_preference}'. "
        f"Include details about the subject, setting, lighting, composition, and quality keywords (like 'photorealistic', 'highly detailed', '4k', 'masterpiece'). "
        f"Also, suggest a relevant negative prompt focusing on common image issues (like 'ugly, deformed, blurry, low quality, extra limbs, text, words'). "
        f"If the user provided a negative prompt ('{neg_prompt}'), incorporate its essence or add to it. "
        f"Format the output clearly, separating the main prompt and the negative prompt, perhaps using '### Prompt:' and '### Negative Prompt:' labels."
    )

    try:
        print(f"向OpenAI发送请求以生成提示词: {short_prompt}")
        response = openai_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message},
            ],
            # 150 tokens frequently truncated the reply before the negative
            # prompt section; 300 leaves room for both labelled sections.
            max_tokens=300,
            temperature=0.7,
        )
        generated_text = response.choices[0].message.content.strip()
        print(f"收到来自OpenAI的响应: {generated_text}")

        final_prompt, final_neg_prompt = parse_llm_prompt_output(generated_text, neg_prompt)

        # Guard against an empty parse result (e.g. model returned only markers).
        if not final_prompt:
            final_prompt = f"{short_prompt},风格:{style_preference},高度详细"
        if not final_neg_prompt:
            final_neg_prompt = "丑陋, 模糊, 低质量"

        return final_prompt, final_neg_prompt

    except Exception as e:
        print(f"调用OpenAI API时出错: {e}")
        gr.Warning(f"OpenAI API 错误: {e}。将使用基础提示词。")
        return f"{short_prompt},风格:{style_preference},高度详细", neg_prompt or "丑陋, 模糊, 低质量"
|
|
|
|
|
|
|
|
def generate_image(prompt, neg_prompt, guidance, steps):
    """Generate an image with the global Stable Diffusion pipeline.

    Args:
        prompt: Positive prompt text.
        neg_prompt: Negative prompt text.
        guidance: Classifier-free guidance scale (coerced to float).
        steps: Number of inference steps (coerced to int).

    Returns:
        ``(image, prompt_used)``; on failure a solid placeholder image is
        returned instead of raising, so the UI outputs stay populated.
    """
    if not pipe:
        # NOTE: gr.Error is an exception class — instantiating it without
        # `raise` is a silent no-op. Use gr.Warning (a callable toast, already
        # used elsewhere in this file) so the user actually sees the message,
        # and keep the placeholder-image fallback return.
        gr.Warning("Stable Diffusion 模型加载失败。无法生成图像。")
        return Image.new('RGB', (512, 512), color='grey'), "错误:SD 模型未加载。"

    print(f"正在使用提示词生成图像: {prompt}")
    print(f"反向提示词: {neg_prompt}, 引导系数: {guidance}, 步数: {steps}")
    start_time = time.time()
    try:
        # Inference only — disable autograd to save memory/time on CPU.
        with torch.no_grad():
            image = pipe(
                prompt,
                negative_prompt=neg_prompt,
                guidance_scale=float(guidance),
                num_inference_steps=int(steps),
                num_images_per_prompt=1,
            ).images[0]
        print(f"图像生成完成,耗时 {time.time() - start_time:.2f} 秒。")
        return image, prompt
    except Exception as e:
        print(f"图像生成过程中出错: {e}")
        # Same no-op gr.Error fix as above: warn, then return the red
        # placeholder so the click handler still yields a full output tuple.
        gr.Warning(f"图像生成失败: {e}")
        return Image.new('RGB', (512, 512), color='red'), prompt
|
|
|
|
|
|
|
|
|
|
|
def transcribe_audio(audio_filepath):
    """Transcribe an audio file with the global Whisper pipeline.

    Args:
        audio_filepath: Path to the recorded audio, or None.

    Returns:
        The transcribed text, an empty string when no audio was provided, or
        an error string containing "错误" — callers (process_input) check for
        that substring, so the error strings must not change.
    """
    if not asr_pipeline:
        gr.Warning("语音转文本模型未加载。")
        return "错误:ASR模型未加载。"
    if audio_filepath is None:
        gr.Warning("未提供音频输入。")
        return ""

    print(f"正在转录音频文件: {audio_filepath}")
    start_time = time.time()
    try:
        transcript = asr_pipeline(audio_filepath)
        print(f"音频转录完成,耗时 {time.time() - start_time:.2f} 秒。")
        print(f"转录结果: {transcript['text']}")
        return transcript["text"]
    except Exception as e:
        print(f"音频转录过程中出错: {e}")
        # NOTE: gr.Error must be raised to have any effect; the original
        # instantiated it as a no-op. Use gr.Warning and keep returning the
        # "错误..." string that process_input inspects.
        gr.Warning(f"音频转录失败: {e}")
        return f"音频转录过程中出错: {e}"
|
|
|
|
|
|
|
|
def process_input(short_prompt_text, audio_input, style, neg_prompt, guidance, steps):
    """Main handler for the generate button.

    Pipeline: optional speech-to-text -> LLM prompt expansion -> image
    generation. Audio input, when present and successfully transcribed, takes
    precedence over the text box.

    Returns:
        ``(status, detailed_prompt, image)``; on invalid input an error status
        with None placeholders is returned.

    NOTE: the original code instantiated gr.Error without raising it, which is
    a silent no-op; gr.Warning is used instead so the user actually sees the
    message while the error tuple is still returned to the UI outputs.
    """
    final_short_prompt = short_prompt_text
    status = "正在开始..."

    if audio_input is not None:
        status = "正在转录音频..."
        print(f"检测到音频输入: {audio_input}")
        transcription = transcribe_audio(audio_input)
        # transcribe_audio signals failure via a string containing "错误".
        if "错误" not in transcription and transcription.strip():
            final_short_prompt = transcription
            status = f"使用转录文本: '{transcription[:50]}...'"
            print(status)
        elif "错误" in transcription:
            status = "音频转录失败。如文本框有内容将使用文本输入。"
            print(status)
            if not final_short_prompt:
                gr.Warning("音频转录失败且文本框为空。")
                return "错误:无有效输入。", None, None
        else:
            status = "音频转录结果为空。如文本框有内容将使用文本输入。"
            print(status)
            if not final_short_prompt:
                gr.Warning("音频转录结果为空且文本框为空。")
                return "错误:无有效输入。", None, None

    # No usable input from either audio or the text box.
    if not final_short_prompt:
        gr.Warning("请输入简短描述或使用语音输入。")
        return "错误:输入为空。", None, None

    status = "正在生成详细提示词..."
    print(status)
    detailed_prompt, final_neg_prompt = generate_detailed_prompt(final_short_prompt, style, neg_prompt)

    status = "正在生成图像..."
    print(f"使用详细提示词: {detailed_prompt}")
    print(f"使用反向提示词: {final_neg_prompt}")
    image, used_prompt = generate_image(detailed_prompt, final_neg_prompt, guidance, steps)

    status = "图像生成完成!"
    print(status)
    return status, used_prompt, image
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # --- Header ---
    gr.Markdown("# Stable Diffusion 提示词优化与图像生成器")
    gr.Markdown(
        "输入简短描述(或使用语音输入!),选择风格,获取详细提示词和生成的图像。\n"
        "*注意:图像生成在 CPU 上运行,将会**很慢**(可能需要几分钟)。*\n"
        "*提示词优化功能需要在 Hugging Face Space Secrets 中设置 `OPENAI_API_KEY`。*"
    )

    with gr.Row():
        # --- Left column: all user inputs plus the status readout ---
        with gr.Column(scale=1):
            short_prompt_box = gr.Textbox(label="简短描述", placeholder="例如:天空中的魔法树屋")
            audio_recorder = gr.Audio(sources=["microphone"], type="filepath", label="或者录制语音输入(可选)")
            style_dropdown = gr.Dropdown(
                label="图像风格",
                choices=["照片写实", "电影感", "奇幻艺术", "动漫", "水彩", "像素艺术", "赛博朋克"],
                value="电影感"
            )
            neg_prompt_box = gr.Textbox(label="反向提示词(可选)", placeholder="例如:文字, 词语, 模糊, 变形")
            guidance_slider = gr.Slider(minimum=1, maximum=15, step=0.5, value=7.5, label="引导系数 (CFG)")
            steps_slider = gr.Slider(minimum=10, maximum=50, step=1, value=20, label="推理步数(越低越快,细节可能减少)")

            run_button = gr.Button("生成图像", variant="primary")

            status_box = gr.Textbox(label="状态", value="准备就绪", interactive=False)

        # --- Right column: generated prompt and image outputs ---
        with gr.Column(scale=1):
            detailed_prompt_box = gr.Textbox(label="生成的详细提示词", interactive=False, lines=5)
            image_output = gr.Image(label="生成的图像", type="pil")

    # Wire the button to the full speech -> prompt -> image pipeline.
    run_button.click(
        fn=process_input,
        inputs=[
            short_prompt_box,
            audio_recorder,
            style_dropdown,
            neg_prompt_box,
            guidance_slider,
            steps_slider
        ],
        outputs=[
            status_box,
            detailed_prompt_box,
            image_output
        ]
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Surface any startup problems on the console before launching the UI.
    startup_checks = (
        (pipe, "\n严重错误:Stable Diffusion 模型加载失败。Gradio 应用可能无法正常工作。"),
        (asr_pipeline, "\n警告:语音转文本模型加载失败。语音输入将无法工作。"),
        (openai_client, "\n警告:OpenAI客户端未初始化(检查API密钥)。提示词优化将是基础模式。"),
    )
    for component, message in startup_checks:
        if not component:
            print(message)

    # Queueing serializes requests — important since CPU generation is slow.
    demo.queue()
    demo.launch(debug=False)