Gertie01 commited on
Commit
ef5d3f9
·
verified ·
1 Parent(s): 44e8877

Deploy Gradio app with multiple files

Browse files
Files changed (2) hide show
  1. app.py +137 -0
  2. requirements.txt +19 -0
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import spaces
4
+ import os
5
+ from diffusers import DiffusionPipeline
6
+
7
+ # --- Model Configuration and Loading ---
8
+ MODEL_ID = "Manojb/stable-diffusion-2-1-base"
9
+ DTYPE = torch.bfloat16
10
+
11
+ try:
12
+ # Load pipeline
13
+ pipe = DiffusionPipeline.from_pretrained(
14
+ MODEL_ID,
15
+ torch_dtype=DTYPE,
16
+ use_safetensors=True
17
+ )
18
+ pipe.to('cuda')
19
+
20
+ # --- Mandatory ZeroGPU AoT Compilation for Optimization ---
21
+
22
+ @spaces.GPU(duration=1500) # Extended duration for startup compilation
23
+ def compile_unet():
24
+ print("Starting AoT compilation for UNet...")
25
+
26
+ # Dummy inputs for 512x512 generation (B=1, latents=64x64 for UNet)
27
+ B, C, H, W = 1, 4, 64, 64
28
+ sample = torch.randn(B, C, H, W, dtype=DTYPE, device='cuda')
29
+ timestep = torch.tensor([999], dtype=torch.long, device='cuda')
30
+
31
+ # Encoder Hidden States (text embeddings): (B, 77, 1024) for SD2.1
32
+ EHS_DIM = 77
33
+ EHS_HIDDEN = 1024
34
+ encoder_hidden_states = torch.randn(B, EHS_DIM, EHS_HIDDEN, dtype=DTYPE, device='cuda')
35
+
36
+ inputs = (sample, timestep, encoder_hidden_states)
37
+
38
+ with spaces.aoti_capture(pipe.unet) as call:
39
+ call(*inputs)
40
+
41
+ exported = torch.export.export(pipe.unet, args=call.args, kwargs=call.kwargs)
42
+ compiled_model = spaces.aoti_compile(exported)
43
+ print("AoT compilation successful.")
44
+ return compiled_model
45
+
46
+ # Execute compilation during startup
47
+ compiled_unet = compile_unet()
48
+ spaces.aoti_apply(compiled_unet, pipe.unet)
49
+
50
+ except Exception as e:
51
+ print(f"⚠️ Warning: Model initialization or AoT compilation failed ({e}). Running without optimization or skipping initialization if severe.")
52
+ # Fallback to loading the model without AoT if compilation fails
53
+ if 'pipe' not in locals():
54
+ pipe = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=DTYPE, use_safetensors=True)
55
+ pipe.to('cuda')
56
+ print("Model loaded successfully without AoT.")
57
+
58
+ @spaces.GPU(duration=60) # Standard GPU allocation for inference
59
+ def generate(prompt: str, num_images: int):
60
+ """Generates images using the Stable Diffusion pipeline."""
61
+
62
+ if not prompt:
63
+ raise gr.Error("Prompt cannot be empty.")
64
+
65
+ # Prepare batch input
66
+ prompt_list = [prompt] * num_images
67
+
68
+ # Generate images
69
+ output = pipe(
70
+ prompt_list,
71
+ num_inference_steps=25,
72
+ guidance_scale=9.0,
73
+ )
74
+
75
+ return output.images
76
+
77
+ # --- Gradio Interface ---
78
+
79
+ with gr.Blocks(theme=gr.themes.Soft(), title="SD 2.1 Base Generator") as demo:
80
+ gr.HTML(
81
+ """
82
+ <div style="text-align: center; margin-bottom: 20px;">
83
+ <h1>Stable Diffusion 2.1 Base (512x512)</h1>
84
+ <p>Model: Manojb/stable-diffusion-2-1-base | Optimized with ZeroGPU AoT</p>
85
+ <p>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a></p>
86
+ </div>
87
+ """
88
+ )
89
+
90
+ with gr.Row():
91
+ with gr.Column(scale=1):
92
+ prompt = gr.Textbox(
93
+ label="Prompt",
94
+ placeholder="A detailed digital painting of a majestic dragon flying over a medieval castle, fantasy art",
95
+ lines=3
96
+ )
97
+ num_images = gr.Slider(
98
+ minimum=1,
99
+ maximum=4,
100
+ step=1,
101
+ value=2,
102
+ label="Number of Images to Generate (Max 4)",
103
+ info="Generates multiple images in a single batch call."
104
+ )
105
+ generate_btn = gr.Button("Generate Images", variant="primary")
106
+
107
+ with gr.Column(scale=2):
108
+ output_gallery = gr.Gallery(
109
+ label="Generated Images (512x512)",
110
+ height=512,
111
+ columns=2,
112
+ rows=2,
113
+ object_fit="contain"
114
+ )
115
+
116
+ generate_btn.click(
117
+ fn=generate,
118
+ inputs=[prompt, num_images],
119
+ outputs=output_gallery
120
+ )
121
+
122
+ gr.Examples(
123
+ examples=[
124
+ ["A photorealistic portrait of a golden retriever wearing sunglasses on a beach, cinematic lighting", 2],
125
+ ["Steampunk owl on a bookshelf, detailed brass gears, oil painting", 4],
126
+ ["High contrast black and white photograph of an old lighthouse during a storm", 1]
127
+ ],
128
+ inputs=[prompt, num_images],
129
+ outputs=output_gallery,
130
+ fn=generate,
131
+ cache_examples=True,
132
+ cache_mode="eager"
133
+ )
134
+
135
+ demo.queue()
136
+ if __name__ == "__main__":
137
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ diffusers
4
+ accelerate
5
+ safetensors
6
+ git+https://github.com/huggingface/spaces
7
+ Pillow
8
+ xformers
9
+ scipy
10
+ opencv-python
11
+ ftfy
12
+ transformers
13
+ regex
14
+ httpx
15
+ pydantic
16
+ typing-extensions
17
+ dataclasses_json
18
+ aiohttp
19
+ numpy