# Hugging Face Space status at the time this file was captured: Runtime error
import os
from typing import Any, List, Optional

import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from PIL import Image
# --- Model Configuration ---
# Hugging Face Hub repository IDs of the two checkpoints compared by this demo.
MODEL_V1 = "CompVis/stable-diffusion-v1-4"
MODEL_V2 = "Manojb/stable-diffusion-2-1-base"
DEVICE = "cuda"
# bfloat16 halves weight memory versus float32 and is well supported on recent
# NVIDIA data-center GPUs (A100 / H100 / H200).
DTYPE = torch.bfloat16

# Fallback prompts used by `generate` when the user leaves the prompt empty.
DEFAULT_PROMPT_V1 = "A stunning photorealistic image of a golden retriever wearing a crown, in a grand hall, cinematic lighting, masterpiece, 4k"
DEFAULT_PROMPT_V2 = "A detailed matte painting of an ancient ruined city overgrown with vines, dramatic sunset, fantasy art, 8k, cinematic"


def _load_pipeline(model_id: str) -> DiffusionPipeline:
    """Download ``model_id`` from the Hub in ``DTYPE`` and move it to ``DEVICE``.

    The safety checker is disabled (``safety_checker=None``), and
    ``requires_safety_checker=False`` is passed alongside it.

    To load a local ``.ckpt`` / ``.safetensors`` file instead, use a concrete
    pipeline class's ``from_single_file`` constructor (for example
    ``StableDiffusionPipeline.from_single_file``) rather than ``from_pretrained``.
    """
    pipeline = DiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=DTYPE,
        safety_checker=None,
        requires_safety_checker=False,
    )
    return pipeline.to(DEVICE)


print("Loading Models...")
pipe_v1 = _load_pipeline(MODEL_V1)
pipe_v2 = _load_pipeline(MODEL_V2)
print("Models Loaded.")
@spaces.GPU(duration=1500)  # compilation is slow: request a long GPU slot at startup
def _compile_unet(model_id: str):
    """Capture, export and AoT-compile the UNet of the pipeline for ``model_id``.

    This must run under ``spaces.GPU``: on ZeroGPU hardware a GPU is only attached
    while a decorated function executes, and the capture performs a real pipeline
    call. The compiled artifact is returned; applying it to the UNet is done by the
    caller, outside the GPU context.

    Args:
        model_id: Either ``MODEL_V1`` or ``MODEL_V2``. A plain string is passed,
            rather than the pipeline object, to keep the GPU-call arguments simple.

    Returns:
        The object produced by ``spaces.aoti_compile``.

    Raises:
        KeyError: If ``model_id`` is not one of the two known models.
    """
    pipe = {MODEL_V1: pipe_v1, MODEL_V2: pipe_v2}[model_id]
    with spaces.aoti_capture(pipe.unet) as call:
        # The compiled UNet only accepts the exact tensor shapes captured here, so
        # this call must mirror `generate`: 512x512 and 4 images per prompt.
        # guidance_scale > 1 enables classifier-free guidance, which doubles the
        # UNet batch. It is passed explicitly to match the UI default (7.5).
        pipe(
            prompt="compilation test",
            num_inference_steps=2,
            height=512,
            width=512,
            num_images_per_prompt=4,
            guidance_scale=7.5,
        )
    exported = torch.export.export(pipe.unet, args=call.args, kwargs=call.kwargs)
    return spaces.aoti_compile(exported)


def compile_optimized_models():
    """Perform Ahead-of-Time (AoT) compilation of both UNets for ZeroGPU.

    Compilation is best-effort. If it fails for a model, a warning is printed and
    that pipeline keeps its regular (uncompiled) UNet.
    """
    targets = (
        (MODEL_V1, "SD 1.4", pipe_v1),
        (MODEL_V2, "SD 2.1 Base", pipe_v2),
    )
    for model_id, label, pipe in targets:
        print(f"Compiling UNet for {model_id} ({label})...")
        try:
            compiled = _compile_unet(model_id)
            # Applied outside the GPU-decorated function, as in the ZeroGPU AoT recipe.
            spaces.aoti_apply(compiled, pipe.unet)
        except Exception as e:  # deliberate best-effort boundary: fall back to eager
            print(f"Warning: AoT compilation failed for {label}. Running unoptimized. Error: {e}")
        else:
            print(f"Compilation for {model_id} complete.")


# Run compilation once at startup (module import time).
compile_optimized_models()
@spaces.GPU  # on ZeroGPU, a GPU is only attached while a decorated function runs
def generate(
    model_choice: str,
    prompt: str,
    guidance_scale: float,
    num_inference_steps: int,
) -> List[Image.Image]:
    """Generate four 512x512 images with the selected Stable Diffusion model.

    Args:
        model_choice: Hub repo ID of the model; must equal ``MODEL_V1`` or ``MODEL_V2``.
        prompt: Text prompt. ``None``, empty or whitespace-only input falls back to
            the selected model's default prompt.
        guidance_scale: Classifier-free guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps. It is coerced to ``int``
            because a Gradio slider may deliver a float.

    Returns:
        The list of PIL images produced by the pipeline (four per call).

    Raises:
        gr.Error: If ``model_choice`` is not one of the two supported models.
    """
    if model_choice == MODEL_V1:
        pipe, default_prompt = pipe_v1, DEFAULT_PROMPT_V1
    elif model_choice == MODEL_V2:
        pipe, default_prompt = pipe_v2, DEFAULT_PROMPT_V2
    else:
        raise gr.Error("Invalid model selection.")

    if not prompt or not prompt.strip():
        prompt = default_prompt

    # Resolution and images-per-prompt must match the shapes the UNet was
    # AoT-compiled with. A compiled UNet rejects other input shapes, so this is
    # a correctness requirement, not only a performance one.
    result = pipe(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        num_images_per_prompt=4,  # the output gallery shows 4 images
        height=512,
        width=512,
    ).images
    return result
def display_uploads(files: Optional[List[Any]]) -> List[str]:
    """Convert the value of a multi-file ``gr.File`` into gallery-displayable paths.

    Depending on the Gradio version and the component's ``type``, each entry may be:

    - a plain path string or ``os.PathLike`` (``type="filepath"``, the default
      in recent Gradio);
    - an object exposing ``.path`` (a FileData-like object);
    - a tempfile-style wrapper exposing ``.name``.

    Args:
        files: The uploaded entries, or ``None``/empty when nothing is uploaded.

    Returns:
        The file-system paths, in upload order ([] when there are no files).

    Raises:
        TypeError: If an entry carries no recognizable path.
    """
    if not files:
        return []
    paths = []
    for item in files:
        # Check str/PathLike first: pathlib.Path has a `.name` that is only the basename.
        if isinstance(item, (str, os.PathLike)):
            paths.append(os.fspath(item))
            continue
        path = getattr(item, "path", None) or getattr(item, "name", None)
        if not isinstance(path, str):
            raise TypeError(f"Cannot determine a file path for upload entry {item!r}")
        paths.append(path)
    return paths
# --- Gradio Interface ---
# gr.File has no `max_files` parameter, so the upload limit advertised in the
# upload label is enforced when the uploaded files are displayed.
MAX_REFERENCE_UPLOADS = 4


def _show_reference_uploads(files):
    """Return display paths for at most MAX_REFERENCE_UPLOADS uploaded files."""
    paths = display_uploads(files)
    if len(paths) > MAX_REFERENCE_UPLOADS:
        gr.Warning(f"Only the first {MAX_REFERENCE_UPLOADS} reference images are shown.")
    return paths[:MAX_REFERENCE_UPLOADS]


with gr.Blocks(title="Stable Diffusion Models Demo") as demo:
    gr.HTML(
        """
        <div style='text-align: center; max-width: 800px; margin: 0 auto;'>
        <h1>Stable Diffusion v1.4 vs 2.1 Base</h1>
        <p>Select a model and enter a prompt to generate up to 4 images. Empty prompts use a powerful default prompt.</p>
        <p><a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">Built with anycoder</a></p>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Radio(
                choices=[MODEL_V1, MODEL_V2],
                value=MODEL_V2,
                label="Model Selection",
                info="Select the base Stable Diffusion version to use.",
            )
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here (or leave empty for default demo prompt)",
            )
            with gr.Accordion("Generation Parameters", open=True):
                # Minimum kept above 1.0: diffusers disables classifier-free guidance
                # at guidance_scale <= 1, which changes the UNet batch size away from
                # the shape an AoT-compiled UNet expects.
                guidance_scale = gr.Slider(
                    minimum=1.5, maximum=15.0, value=7.5, step=0.5, label="Guidance Scale",
                    info="Higher values push the generation closer to the prompt.",
                )
                num_inference_steps = gr.Slider(
                    minimum=10, maximum=100, value=50, step=5, label="Inference Steps",
                    info="More steps lead to higher quality, but slower generation.",
                )
            run_btn = gr.Button("Generate 4 Images", variant="primary")
            # Reference uploads are only displayed; they are not fed to the generator.
            uploaded_files = gr.File(
                label=f"Upload Reference Images (Max {MAX_REFERENCE_UPLOADS})",
                file_count="multiple",
                file_types=["image"],
                interactive=True,
            )
            upload_display = gr.Gallery(
                label="Uploaded Images for Reference",
                columns=4,
                object_fit="contain",
                height=150,
                allow_preview=False,
            )
            uploaded_files.change(_show_reference_uploads, uploaded_files, upload_display)
        with gr.Column(scale=3):
            output_gallery = gr.Gallery(
                label="Generated Images (512x512)",
                columns=2,
                object_fit="contain",
                height=512,
                preview=True,
            )
    run_btn.click(
        fn=generate,
        inputs=[
            model_choice,
            prompt,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=output_gallery,
    )

if __name__ == "__main__":
    demo.launch()