"""Gradio demo that turns a short description into a Stable Diffusion image."""

import traceback
from typing import Optional

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image

# Stable Diffusion model settings.
model_id: str = "runwayml/stable-diffusion-v1-5"
device: str = "cpu"  # force CPU usage for compatibility

# Stays None when loading fails so the UI can still be built and launched;
# generate_image_sd() then reports the missing pipeline to the user.
image_generator_pipe: Optional[StableDiffusionPipeline] = None
try:
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
    image_generator_pipe = pipe.to(device)
except Exception as e:
    print(f"Failed to load Stable Diffusion model: {e}")


# Prompt optimization function (simple version).
def optimize_prompt_simple(short_description: str) -> str:
    """Wrap a short description in a fixed instruction-style prompt.

    Args:
        short_description: Free-text description of the desired image.

    Returns:
        The fixed template sentence followed by ``short_description`` verbatim.
    """
    optimized_prompt = f"Generate a high-quality, detailed image based on the following description: {short_description}"
    return optimized_prompt


# Image generation function.
def generate_image_sd(
    short_description: str,
    negative_prompt: str,
    guidance_scale: float,
    num_inference_steps: int,
) -> Image.Image:
    """Generate one image from a short description with the loaded pipeline.

    Args:
        short_description: Text that is expanded by ``optimize_prompt_simple``
            and used as the prompt.
        negative_prompt: Passed to the pipeline as ``negative_prompt``.
        guidance_scale: Passed to the pipeline as ``guidance_scale``.
        num_inference_steps: Number of inference steps; coerced to ``int``
            before it is handed to the pipeline.

    Returns:
        The first image in the pipeline output.

    Raises:
        gr.Error: If the description is blank, the pipeline was not loaded,
            the pipeline raised, or it returned no image.
    """
    # Reject blank input up front instead of starting an inference run for it.
    if not short_description or not short_description.strip():
        raise gr.Error("Please enter a short description before generating an image.")

    # Known condition: report it directly rather than via a printed traceback.
    if image_generator_pipe is None:
        raise gr.Error("Image generation failed: Stable Diffusion pipeline is not available.")

    optimized_prompt = optimize_prompt_simple(short_description)
    try:
        with torch.no_grad():
            output = image_generator_pipe(
                prompt=optimized_prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                # The value comes from a UI slider; make sure it is an int.
                num_inference_steps=int(num_inference_steps),
            )
        image = output.images[0] if output.images else None
        if image is None:
            raise RuntimeError("No image was returned from the generation pipeline.")
        return image
    except Exception as e:
        # Log the full traceback on the server, show a short message in the UI.
        traceback.print_exc()
        raise gr.Error(f"Image generation failed: {str(e)}") from e


# Gradio interface.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            short_description = gr.Textbox(
                label="Short Description",
                placeholder="A magical treehouse in the sky",
            )
            optimized_prompt_display = gr.Textbox(label="Optimized Prompt", interactive=False)
            neg_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="blurry, distorted, watermark",
            )
            guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="Guidance Scale")
            steps = gr.Slider(10, 50, value=25, step=1, label="Inference Steps")
            generate_btn = gr.Button("Generate Image")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Image", type="pil")

    # When the user types a short description, automatically optimize the
    # prompt and display it.
    short_description.input(
        fn=optimize_prompt_simple,
        inputs=short_description,
        outputs=optimized_prompt_display,
    )

    generate_btn.click(
        fn=generate_image_sd,
        inputs=[short_description, neg_prompt, guidance, steps],
        outputs=output_image,
    )


if __name__ == "__main__":
    if image_generator_pipe is None:
        print("WARNING: Stable Diffusion pipeline is not available. UI will launch, but generation will fail.")
    demo.launch(server_name="0.0.0.0", server_port=7860)