|
|
import gradio as gr |
|
|
import openai |
|
|
import os |
|
|
import torch |
|
|
from diffusers import StableDiffusionPipeline |
|
|
from transformers import pipeline |
|
|
import time |
|
|
import warnings |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Silence Python warnings process-wide. NOTE(review): presumably meant to keep
# noisy third-party library output out of the console; this also hides
# deprecation warnings that matter when upgrading dependencies.
warnings.filterwarnings("ignore")


# --- OpenAI client (optional: only used for prompt enhancement) ---
# Strip surrounding whitespace: a key pasted into a secret store may carry a
# trailing newline, which would otherwise make authentication fail.
api_key = (os.getenv("OPENAI_API_KEY") or "").strip()

openai_client = None  # openai.OpenAI instance, kept only if the key validated
openai_available = False  # True once the key was verified with a live API call

if api_key:
    try:
        openai_client = openai.OpenAI(api_key=api_key)
        # Cheap authenticated request so a bad key is reported at startup
        # (the UI shows a warning when openai_available is False).
        openai_client.models.list()
        openai_available = True
        print("OpenAI API Key 已加载并验证成功。")
    except Exception as e:
        # Keep both flags consistent: never retain a client whose key failed.
        openai_client = None
        print(f"加载 OpenAI API Key 时出错或验证失败: {e}")
        print("提示词增强功能将不可用。请在 Hugging Face Secrets 中设置 OPENAI_API_KEY。")
else:
    print("警告: 环境变量 OPENAI_API_KEY 未设置。")
    print("提示词增强功能将不可用。请在 Hugging Face Secrets 中设置 OPENAI_API_KEY。")
|
|
|
|
|
|
|
|
|
|
|
# --- Local models, loaded once at import time ---
sd_pipe = None  # StableDiffusionPipeline, or None if loading failed
asr_pipe = None  # transformers ASR pipeline, or None if loading failed

# The original "runwayml/stable-diffusion-v1-5" repository was removed from the
# Hugging Face Hub; the same v1.5 weights are published under the
# "stable-diffusion-v1-5" organization. Both IDs can be overridden through
# environment variables without editing the code.
sd_model_id = os.getenv("SD_MODEL_ID", "stable-diffusion-v1-5/stable-diffusion-v1-5")
# NOTE(review): the ".en" Whisper checkpoint is English-only while the UI copy is
# Chinese; consider "openai/whisper-base" if Chinese speech input is expected.
asr_model_id = os.getenv("ASR_MODEL_ID", "openai/whisper-base.en")

try:
    print(f"开始加载 Stable Diffusion 模型: {sd_model_id} (这可能需要一些时间)...")
    # float32: the pipeline is never moved to a GPU here (no .to(...) call), and
    # image generation uses a CPU random generator.
    sd_pipe = StableDiffusionPipeline.from_pretrained(
        sd_model_id,
        torch_dtype=torch.float32,
    )
    print("Stable Diffusion 模型加载完成。")
except Exception as e:
    # Degrade gracefully: the UI reports the failure and generation returns a placeholder.
    print(f"错误: 加载 Stable Diffusion 模型失败: {e}")
    sd_pipe = None

try:
    print(f"开始加载 ASR (语音识别) 模型: {asr_model_id} ...")
    # device=-1 selects the CPU in transformers pipelines.
    asr_pipe = pipeline("automatic-speech-recognition", model=asr_model_id, device=-1)
    print("ASR 模型加载完成。")
except Exception as e:
    # Degrade gracefully: the microphone input is hidden when asr_pipe is None.
    print(f"错误: 加载 ASR 模型失败: {e}")
    asr_pipe = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def enhance_prompt_with_llm(short_prompt):
    """Expand a short description into a detailed Stable Diffusion prompt via the OpenAI LLM.

    Args:
        short_prompt: The user's short description (typed or transcribed).

    Returns:
        A ``(main_prompt, negative_prompt)`` tuple of strings, on every path.
        When OpenAI is unavailable, the API call fails, or the model returns no
        usable prompt text, a basic template prompt is returned instead.
    """
    fallback_prompt = f"{short_prompt}, photorealistic, high quality"
    fallback_negative = "low quality, blurry, worst quality"
    # Used when the model answers but omits (or leaves empty) the negative section.
    default_negative = "low quality, blurry, worst quality, text, watermark"

    if not openai_available or not openai_client:
        print("OpenAI 不可用,跳过提示词增强。")
        # Return a pair like every other path; a bare string here would break
        # `prompt, negative = enhance_prompt_with_llm(...)` in callers.
        return fallback_prompt, fallback_negative

    # System prompt (Chinese) asking for an English prompt optionally followed by
    # a "Negative Prompt: ..." section, which is parsed below.
    system_message = """你是一个创意助手,擅长将简单的想法扩展成生动、详细、结构化的图像生成提示词(Prompt),特别适用于 Stable Diffusion。
请根据用户输入的简短描述,生成一个更丰富、包含细节、风格和构图建议的英文提示词。
提示词应该包含:
1. 主要主体和动作。
2. 重要的细节描述(材质、光线、环境)。
3. 艺术风格(例如:photorealistic, cartoon, illustration, oil painting, watercolor, cyberpunk, fantasy art)。
4. 构图和视角(例如:wide angle shot, close-up shot, aerial view)。
5. 画面质量词语(例如:masterpiece, high resolution, ultra detailed, 8k)。
6. (可选)一个简单的负面提示词(Negative Prompt),放在 "Negative Prompt: " 之后,用来排除不想要的元素(例如:low quality, blurry, text, watermark, signature)。
请直接输出增强后的英文提示词和负面提示词。
格式示例:
[增强后的英文提示词主体], [风格], [构图], [质量词语] Negative Prompt: [负面提示词]
"""
    user_message = f"请将以下简短描述扩展成一个详细的 Stable Diffusion 提示词: '{short_prompt}'"

    try:
        response = openai_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message},
            ],
            temperature=0.7,
            max_tokens=150,
        )
        # The SDK types `message.content` as optional; treat None as an empty reply.
        enhanced_prompt = (response.choices[0].message.content or "").strip()
    except Exception as e:
        print(f"调用 OpenAI API 时出错: {e}")
        return fallback_prompt, fallback_negative

    # Split "<prompt> Negative Prompt: <negative>" at the first marker only.
    main_prompt, marker, negative_prompt = enhanced_prompt.partition("Negative Prompt:")
    main_prompt = main_prompt.strip()
    negative_prompt = negative_prompt.strip() if marker else ""

    if not main_prompt:
        # Empty reply, or nothing before the marker: keep the user's own idea.
        main_prompt = fallback_prompt
    return main_prompt, negative_prompt or default_negative
|
|
|
|
|
|
|
|
def generate_image_from_prompt(prompt, negative_prompt, guidance, steps, seed=42):
    """Render one image with the Stable Diffusion v1.5 pipeline.

    Args:
        prompt: Positive prompt; an empty value is replaced by "default prompt".
        negative_prompt: Negative prompt; an empty value is replaced by "".
        guidance: Guidance scale (converted with ``float``).
        steps: Number of inference steps (converted with ``int``).
        seed: Seed for the CPU random generator. The default of 42 keeps the
            previous behaviour (identical inputs reproduce the same image);
            pass None to draw a fresh non-deterministic seed on every call.

    Returns:
        ``(image, status_message)``. On failure a 512x512 placeholder is
        returned: black if the model is not loaded, red if generation raised,
        and the status message starts with "错误".
    """
    if sd_pipe is None:
        print("Stable Diffusion 模型不可用,无法生成图像。")
        placeholder = Image.new('RGB', (512, 512), color='black')
        return placeholder, "错误: Stable Diffusion 模型未加载"

    print(f"开始生成图像,提示词: {prompt}, 负面提示词: {negative_prompt}")
    print(f"参数: guidance_scale={guidance}, num_inference_steps={steps}")
    start_time = time.perf_counter()  # monotonic clock for measuring durations
    try:
        generator = torch.Generator("cpu")
        if seed is None:
            generator.seed()  # non-deterministic seed
        else:
            generator.manual_seed(int(seed))
        with torch.no_grad():
            image = sd_pipe(
                prompt=prompt or "default prompt",
                negative_prompt=negative_prompt or "",
                guidance_scale=float(guidance),
                num_inference_steps=int(steps),
                generator=generator,
            ).images[0]
    except Exception as e:
        print(f"生成图像时出错: {e}")
        return Image.new('RGB', (512, 512), color='red'), f"错误: {e}"

    elapsed = time.perf_counter() - start_time
    print(f"图像生成完成,耗时: {elapsed:.2f} 秒")
    return image, f"生成成功 (耗时: {elapsed:.2f} 秒)"
|
|
|
|
|
|
|
|
def transcribe_audio(audio_path):
    """Transcribe an audio file to text with the ASR pipeline.

    Args:
        audio_path: Path to the recorded audio file, or None if nothing was recorded.

    Returns:
        A ``(text, status)`` tuple. On success ``text`` is the stripped
        transcription and ``status`` a success message with timing. On every
        failure path ``status`` is "错误" and ``text`` holds a human-readable
        reason, so callers can tell failure from success by the status alone.
    """
    if asr_pipe is None:
        print("ASR 模型不可用,无法处理语音输入。")
        # The status must flag the failure: an empty status used to be read as
        # success, turning this error message into the image description.
        return "错误: 语音识别模型未加载", "错误"

    if audio_path is None:
        return "没有检测到音频输入。", "错误"

    print(f"开始处理音频文件: {audio_path}")
    try:
        start_time = time.perf_counter()
        transcription = asr_pipe(audio_path)
        # ASR output may carry surrounding whitespace; strip it.
        text = transcription["text"].strip()
        elapsed = time.perf_counter() - start_time
    except Exception as e:
        print(f"语音识别过程中出错: {e}")
        return f"语音识别错误: {e}", "错误"

    print(f"语音识别完成,耗时: {elapsed:.2f} 秒")
    print(f"识别结果: {text}")
    if not text:
        # Silence or unintelligible audio: report failure so typed text is kept.
        return "语音识别结果为空。", "错误"
    return text, f"语音识别成功 (耗时: {elapsed:.2f} 秒)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _basic_prompts(text):
    """Return the template ``(prompt, negative_prompt)`` used when LLM enhancement is skipped or fails."""
    return f"{text}, photorealistic, high quality", "low quality, blurry, worst quality"


def process_input(input_text, guidance, steps, audio_input):
    """Run the full pipeline: optional speech-to-text, prompt enhancement, image generation.

    This is a generator used as a Gradio event handler; every ``yield`` pushes an
    intermediate UI update so the user can follow progress.

    Args:
        input_text: Text typed into the description box.
        guidance: Guidance scale from the slider.
        steps: Number of inference steps from the slider.
        audio_input: Path of the recorded audio, or None if nothing was recorded.

    Yields:
        ``(image, enhanced_prompt, negative_prompt, text_input, status)`` tuples.
        ``text_input`` is written back to the description box, so a successful
        transcription replaces the typed text there.
    """
    final_text_input = input_text

    # Clear outputs from any previous run.
    yield None, None, None, final_text_input, "开始处理..."

    # 1. Speech input takes precedence over typed text when transcription succeeds.
    if audio_input is not None:
        yield None, None, None, final_text_input, "语音识别中..."
        transcribed_text, transcription_status = transcribe_audio(audio_input)
        # Require a non-empty, non-error status AND non-empty text, so that an
        # error message or an empty transcription never overwrites typed text.
        transcription_ok = (
            bool(transcription_status)
            and "错误" not in transcription_status
            and bool(transcribed_text and transcribed_text.strip())
        )
        if transcription_ok:
            final_text_input = transcribed_text
            status = f"语音识别完成: {transcription_status}"
        else:
            status = f"语音识别失败: {transcribed_text}. 使用文本框内容。"
        yield None, None, None, final_text_input, status

    # Whitespace-only input counts as empty.
    prompt_text = (final_text_input or "").strip()
    if not prompt_text:
        yield None, None, None, final_text_input, "请输入描述或提供语音输入。"
        return

    # 2. Prompt enhancement (falls back to a template prompt).
    yield None, None, None, final_text_input, "增强提示词中..."
    if openai_available:
        try:
            enhanced_prompt, negative_prompt = enhance_prompt_with_llm(prompt_text)
            status = "提示词增强完成。"
        except Exception as e:
            enhanced_prompt, negative_prompt = _basic_prompts(prompt_text)
            status = f"提示词增强错误: {e}. 使用基础提示词。"
    else:
        enhanced_prompt, negative_prompt = _basic_prompts(prompt_text)
        status = "跳过提示词增强 (OpenAI不可用)。"
    yield None, enhanced_prompt, negative_prompt, final_text_input, status

    # 3. Image generation.
    yield None, enhanced_prompt, negative_prompt, final_text_input, "生成图像中 (CPU可能较慢)..."
    generated_image, image_status = generate_image_from_prompt(
        enhanced_prompt, negative_prompt, guidance, steps
    )
    # generate_image_from_prompt prefixes failure messages with "错误".
    if image_status.startswith("错误"):
        status = f"图像生成失败. {image_status}"
    else:
        status = f"图像生成完成. {image_status}"
    yield generated_image, enhanced_prompt, negative_prompt, final_text_input, status
|
|
|
|
|
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown("# AI 图像生成器 (文字/语音 -> 提示词 -> 图像)")
    gr.Markdown("输入简短描述或使用麦克风录音,应用将自动 (可选) 增强提示词并使用 Stable Diffusion v1.5 生成图像。")

    # Startup diagnostics: surface which optional components failed to initialise.
    # <strong> is used instead of **...** because Markdown emphasis is not
    # processed inside a raw HTML <p> block.
    if not openai_available:
        gr.Markdown("<p style='color:orange;'>⚠️ <strong>警告:</strong> OpenAI API Key 未正确配置,提示词增强功能将使用默认模式。</p>")
    if sd_pipe is None:
        gr.Markdown("<p style='color:red;'>❌ <strong>错误:</strong> Stable Diffusion 模型加载失败,图像生成功能不可用。</p>")
    if asr_pipe is None:
        gr.Markdown("<p style='color:orange;'>⚠️ <strong>警告:</strong> 语音识别模型加载失败,语音输入功能不可用。</p>")

    status_textbox = gr.Textbox(label="处理状态", value="等待输入", interactive=False)

    with gr.Row():
        # Left column: inputs and generation parameters.
        with gr.Column(scale=1):
            input_text = gr.Textbox(label="输入简短描述", placeholder="例如:空中的魔法树屋")
            # Hidden when the ASR model failed to load.
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="或者,使用语音输入",
                visible=asr_pipe is not None,
            )
            guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, value=7.5, step=0.5, label="引导系数 (Guidance Scale)")
            num_inference_steps = gr.Slider(minimum=10, maximum=50, value=25, step=1, label="推理步数 (Steps)")
            generate_button = gr.Button("✨ 生成图像 ✨", variant="primary")

        # Right column: generated prompts and the resulting image.
        with gr.Column(scale=1):
            enhanced_prompt_display = gr.Textbox(label="增强后的提示词 (Prompt)", interactive=False)
            negative_prompt_display = gr.Textbox(label="负面提示词 (Negative Prompt)", interactive=False)
            output_image = gr.Image(label="生成的图像", type="pil")

    # process_input is a generator; each yield streams an update to these outputs.
    # input_text is also an output so a transcription is shown in the text box.
    generate_button.click(
        fn=process_input,
        inputs=[input_text, guidance_scale, num_inference_steps, audio_input],
        outputs=[output_image, enhanced_prompt_display, negative_prompt_display, input_text, status_textbox],
    )


if __name__ == "__main__":
    # queue() enables streaming of the generator handler's intermediate updates.
    demo.queue().launch(debug=True)