Spaces:
Running
Running
File size: 4,604 Bytes
ef5d3f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import gradio as gr
import torch
import spaces
import os
from diffusers import DiffusionPipeline
# --- Model Configuration and Loading ---
MODEL_ID = "Manojb/stable-diffusion-2-1-base"
DTYPE = torch.bfloat16

# Load the pipeline exactly once. Without a model the app cannot serve anything,
# so a failure here is left to propagate instead of being retried blindly.
pipe = DiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    use_safetensors=True,
)
pipe.to('cuda')


# --- Mandatory ZeroGPU AoT Compilation for Optimization ---
@spaces.GPU(duration=1500)  # Extended duration for startup compilation
def compile_unet():
    """Export and AoT-compile ``pipe.unet`` for the calls ``generate()`` makes.

    The UNet inputs are captured from a real pipeline call rather than from
    hand-built dummy tensors. That way the exported graph sees exactly the
    positional/keyword structure the pipeline passes at inference time
    (``sample`` and ``timestep`` positionally, ``encoder_hidden_states=...``,
    ``return_dict=False``, ...). A graph traced with a different call
    structure cannot be invoked by the pipeline.

    The batch dimension is exported as dynamic because classifier-free
    guidance concatenates the unconditional and conditional halves. One to
    four images per request therefore reach the UNet as a batch of 2 to 8.

    Returns:
        The compiled artifact from ``spaces.aoti_compile``, for use with
        ``spaces.aoti_apply``.
    """
    import inspect

    print("Starting AoT compilation for UNet...")
    with spaces.aoti_capture(pipe.unet) as call:
        # The capture intercepts the first UNet forward call, so the step count
        # is irrelevant. guidance_scale must stay > 1 (as in generate()) so the
        # captured call has the CFG-doubled batch layout.
        pipe("arbitrary example prompt", num_inference_steps=1, guidance_scale=9.0)

    # Up to 4 images per request (the UI slider maximum), doubled by CFG.
    batch = torch.export.Dim("batch", min=2, max=2 * 4)
    # torch.export's dict form of dynamic_shapes must name every bound input;
    # everything except the batch axis of the two batched tensors is static.
    bound = inspect.signature(pipe.unet.forward).bind(*call.args, **call.kwargs)
    dynamic_shapes = {name: None for name in bound.arguments}
    dynamic_shapes["sample"] = {0: batch}
    dynamic_shapes["encoder_hidden_states"] = {0: batch}

    exported = torch.export.export(
        pipe.unet,
        args=call.args,
        kwargs=call.kwargs,
        dynamic_shapes=dynamic_shapes,
    )
    compiled_model = spaces.aoti_compile(exported)
    print("AoT compilation successful.")
    return compiled_model


# Execute compilation during startup. AoT is purely an optimization: if it
# fails, keep serving with the eager UNet that was already loaded above.
try:
    spaces.aoti_apply(compile_unet(), pipe.unet)
except Exception as e:
    print(f"⚠️ Warning: AoT compilation failed ({e}). Running without optimization.")
    print("Model loaded successfully without AoT.")
@spaces.GPU(duration=60)  # Standard GPU allocation for inference
def generate(prompt: str, num_images: int):
    """Generate a batch of images for a single prompt with the SD pipeline.

    Args:
        prompt: Text prompt. Must contain at least one non-whitespace character.
        num_images: Number of images to generate, from 1 to 4. Number inputs
            coming from Gradio or the API may arrive as ``float`` (e.g. ``2.0``),
            so the value is coerced to ``int`` before it sizes the batch.

    Returns:
        The list of images from ``pipe(...).images``.

    Raises:
        gr.Error: If the prompt is blank or ``num_images`` is outside 1..4.
    """
    max_images = 4  # Matches the UI slider maximum.

    if not prompt or not prompt.strip():
        raise gr.Error("Prompt cannot be empty.")

    # `[prompt] * 2.0` raises TypeError, so normalize the count first.
    num_images = int(num_images)
    if not 1 <= num_images <= max_images:
        raise gr.Error(f"Number of images must be between 1 and {max_images}.")

    # One pipeline call over a batch of identical prompts.
    prompt_list = [prompt] * num_images
    output = pipe(
        prompt_list,
        num_inference_steps=25,
        guidance_scale=9.0,
    )
    return output.images
# --- Gradio Interface ---

# Static banner shown at the top of the page.
_HEADER_HTML = """
<div style="text-align: center; margin-bottom: 20px;">
<h1>Stable Diffusion 2.1 Base (512x512)</h1>
<p>Model: Manojb/stable-diffusion-2-1-base | Optimized with ZeroGPU AoT</p>
<p>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a></p>
</div>
"""

# (prompt, number of images) pairs offered as clickable examples.
_EXAMPLE_ROWS = [
    ["A photorealistic portrait of a golden retriever wearing sunglasses on a beach, cinematic lighting", 2],
    ["Steampunk owl on a bookshelf, detailed brass gears, oil painting", 4],
    ["High contrast black and white photograph of an old lighthouse during a storm", 1],
]

with gr.Blocks(theme=gr.themes.Soft(), title="SD 2.1 Base Generator") as demo:
    gr.HTML(_HEADER_HTML)

    with gr.Row():
        # Left column: generation controls.
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                lines=3,
                label="Prompt",
                placeholder="A detailed digital painting of a majestic dragon flying over a medieval castle, fantasy art",
            )
            num_images = gr.Slider(
                label="Number of Images to Generate (Max 4)",
                info="Generates multiple images in a single batch call.",
                minimum=1,
                maximum=4,
                value=2,
                step=1,
            )
            generate_btn = gr.Button("Generate Images", variant="primary")

        # Right column: the generated images.
        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Generated Images (512x512)",
                object_fit="contain",
                columns=2,
                rows=2,
                height=512,
            )

    # Button click -> generate(prompt, num_images) -> gallery.
    generate_btn.click(
        generate,
        inputs=[prompt, num_images],
        outputs=output_gallery,
    )

    # Examples share the same handler; results are cached eagerly.
    gr.Examples(
        examples=_EXAMPLE_ROWS,
        fn=generate,
        inputs=[prompt, num_images],
        outputs=output_gallery,
        cache_mode="eager",
        cache_examples=True,
    )

demo.queue()

if __name__ == "__main__":
    demo.launch()