File size: 4,604 Bytes
ef5d3f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import gradio as gr
import torch
import spaces
import os
from diffusers import DiffusionPipeline

# --- Model Configuration and Loading ---
# NOTE(review): everything below runs at import time (module level) — the
# pipeline is downloaded, moved to CUDA, and AoT-compiled before the UI is
# built. Assumes a CUDA device is available at startup — TODO confirm this
# holds on the target ZeroGPU hardware tier.
MODEL_ID = "Manojb/stable-diffusion-2-1-base"
DTYPE = torch.bfloat16

try:
    # Load pipeline
    pipe = DiffusionPipeline.from_pretrained(
        MODEL_ID, 
        torch_dtype=DTYPE, 
        use_safetensors=True
    )
    pipe.to('cuda')
    
    # --- Mandatory ZeroGPU AoT Compilation for Optimization ---
    
    @spaces.GPU(duration=1500)  # Extended duration for startup compilation
    def compile_unet():
        """Ahead-of-time compile the pipeline's UNet via the `spaces` AoT API.

        Returns the compiled artifact to be applied back onto `pipe.unet`
        with `spaces.aoti_apply`. Runs once at startup inside a long GPU
        allocation (duration=1500s) so inference calls get the fast path.
        """
        print("Starting AoT compilation for UNet...")
        
        # Dummy inputs for 512x512 generation (B=1, latents=64x64 for UNet)
        B, C, H, W = 1, 4, 64, 64
        sample = torch.randn(B, C, H, W, dtype=DTYPE, device='cuda')
        timestep = torch.tensor([999], dtype=torch.long, device='cuda')
        
        # Encoder Hidden States (text embeddings): (B, 77, 1024) for SD2.1
        # NOTE(review): 1024 is the OpenCLIP hidden size used by SD 2.x —
        # confirm it matches this checkpoint's text encoder.
        EHS_DIM = 77
        EHS_HIDDEN = 1024
        encoder_hidden_states = torch.randn(B, EHS_DIM, EHS_HIDDEN, dtype=DTYPE, device='cuda')

        inputs = (sample, timestep, encoder_hidden_states)
        
        # Capture the real call signature the pipeline uses for the UNet;
        # the captured args/kwargs (not `inputs` directly) feed the export.
        with spaces.aoti_capture(pipe.unet) as call:
            call(*inputs)
        
        exported = torch.export.export(pipe.unet, args=call.args, kwargs=call.kwargs)
        compiled_model = spaces.aoti_compile(exported)
        print("AoT compilation successful.")
        return compiled_model

    # Execute compilation during startup
    compiled_unet = compile_unet()
    spaces.aoti_apply(compiled_unet, pipe.unet)

except Exception as e:
    # Deliberate broad catch: AoT compilation is an optimization, not a
    # requirement — the app must still come up if it fails.
    print(f"⚠️ Warning: Model initialization or AoT compilation failed ({e}). Running without optimization or skipping initialization if severe.")
    # Fallback to loading the model without AoT if compilation fails
    # (only re-load when the failure happened before `pipe` was bound).
    if 'pipe' not in locals():
        pipe = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=DTYPE, use_safetensors=True)
        pipe.to('cuda')
        print("Model loaded successfully without AoT.")

@spaces.GPU(duration=60) # Standard GPU allocation for inference
def generate(prompt: str, num_images: int):
    """Generate a batch of images from a text prompt.

    Args:
        prompt: Text description of the desired images; must contain at
            least one non-whitespace character.
        num_images: Images to produce in a single batched pipeline call.
            Gradio sliders can deliver floats, so the value is coerced to
            int and clamped to the UI range [1, 4].

    Returns:
        The list of generated 512x512 PIL images from the pipeline output.

    Raises:
        gr.Error: If the prompt is empty or whitespace-only.
    """
    # `if not prompt` alone lets "   " through; strip first. Guard against
    # None as well, since Gradio may pass it for a cleared textbox.
    prompt = (prompt or "").strip()
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")

    # Coerce and clamp: a float slider value would make `[prompt] * n`
    # raise TypeError, and out-of-range values waste the GPU allocation.
    num_images = max(1, min(4, int(num_images)))

    # One batched call generates all images together.
    output = pipe(
        [prompt] * num_images,
        num_inference_steps=25,
        guidance_scale=9.0,
    )

    return output.images

# --- Gradio Interface ---

with gr.Blocks(theme=gr.themes.Soft(), title="SD 2.1 Base Generator") as demo:
    # Static header; HTML (not Markdown) so the inline styles apply.
    gr.HTML(
        """
        <div style="text-align: center; margin-bottom: 20px;">
            <h1>Stable Diffusion 2.1 Base (512x512)</h1>
            <p>Model: Manojb/stable-diffusion-2-1-base | Optimized with ZeroGPU AoT</p>
            <p>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a></p>
        </div>
        """
    )

    with gr.Row():
        # Left column: inputs (prompt, batch size, trigger button).
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="A detailed digital painting of a majestic dragon flying over a medieval castle, fantasy art",
                lines=3
            )
            num_images = gr.Slider(
                minimum=1,
                maximum=4,
                step=1,
                value=2,
                label="Number of Images to Generate (Max 4)",
                info="Generates multiple images in a single batch call."
            )
            generate_btn = gr.Button("Generate Images", variant="primary")
        
        # Right column (wider): output gallery sized for up to a 2x2 batch.
        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Generated Images (512x512)",
                height=512,
                columns=2,
                rows=2,
                object_fit="contain"
            )

    # Wire the button to the GPU-decorated inference function.
    generate_btn.click(
        fn=generate,
        inputs=[prompt, num_images],
        outputs=output_gallery
    )

    # NOTE(review): cache_examples=True with cache_mode="eager" runs
    # `generate` for every example at startup — confirm the ZeroGPU quota
    # covers 3 extra batched generations during launch.
    gr.Examples(
        examples=[
            ["A photorealistic portrait of a golden retriever wearing sunglasses on a beach, cinematic lighting", 2],
            ["Steampunk owl on a bookshelf, detailed brass gears, oil painting", 4],
            ["High contrast black and white photograph of an old lighthouse during a storm", 1]
        ],
        inputs=[prompt, num_images],
        outputs=output_gallery,
        fn=generate,
        cache_examples=True,
        cache_mode="eager"
    )

# Enable request queueing (required for fair GPU scheduling on Spaces).
demo.queue()
if __name__ == "__main__":
    demo.launch()