Gertie01 committed on
Commit
185331b
·
verified ·
1 Parent(s): f723243

Deploy Gradio app with multiple files

Browse files
Files changed (2) hide show
  1. app.py +198 -0
  2. requirements.txt +15 -0
app.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gradio as gr
import torch
import spaces
from diffusers import DiffusionPipeline
from PIL import Image
from typing import List, Optional, Any

# --- Model Configuration ---
MODEL_V1 = "CompVis/stable-diffusion-v1-4"
MODEL_V2 = "Manojb/stable-diffusion-2-1-base"
DEVICE = "cuda"

# bfloat16 keeps memory low and runs fast on recent datacenter GPUs
# (H200/A100/H100).
DTYPE = torch.bfloat16

# Fallback prompts applied whenever the user submits an empty prompt.
DEFAULT_PROMPT_V1 = "A stunning photorealistic image of a golden retriever wearing a crown, in a grand hall, cinematic lighting, masterpiece, 4k"
DEFAULT_PROMPT_V2 = "A detailed matte painting of an ancient ruined city overgrown with vines, dramatic sunset, fantasy art, 8k, cinematic"

print("Loading Models...")


def _load_pipeline(model_id: str) -> DiffusionPipeline:
    """Load a diffusers pipeline with the shared dtype/device settings.

    The built-in safety checker is disabled for this demo.
    (Use from_single_file=True instead when loading .ckpt or .safetensors
    files directly.)
    """
    pipeline = DiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=DTYPE,
        safety_checker=None,
        requires_safety_checker=False,
    )
    return pipeline.to(DEVICE)


pipe_v1 = _load_pipeline(MODEL_V1)
pipe_v2 = _load_pipeline(MODEL_V2)
print("Models Loaded.")
@spaces.GPU(duration=1500)
def compile_optimized_models() -> None:
    """Ahead-of-Time (AoT) compile both UNets for improved ZeroGPU performance.

    For each pipeline, one cheap example call (512x512, 2 steps) is run under
    ``spaces.aoti_capture`` to record concrete UNet inputs; the UNet is then
    exported with ``torch.export``, compiled with ``spaces.aoti_compile``, and
    patched back in place via ``spaces.aoti_apply``.

    Compilation failures are non-fatal: a warning is printed and the affected
    pipeline keeps running its uncompiled UNet.
    """

    def _compile_unet(pipe, model_id: str, label: str) -> None:
        # One shared implementation instead of the previous copy-pasted
        # per-model blocks; the printed messages are unchanged.
        print(f"Compiling UNet for {model_id} ({label})...")
        try:
            with spaces.aoti_capture(pipe.unet) as call:
                # Run a quick example call (512x512, low steps) to capture inputs
                pipe(
                    prompt="compilation test",
                    num_inference_steps=2,
                    height=512, width=512
                )
            exported = torch.export.export(pipe.unet, args=call.args, kwargs=call.kwargs)
            compiled = spaces.aoti_compile(exported)
            spaces.aoti_apply(compiled, pipe.unet)
            print(f"Compilation for {model_id} complete.")
        except Exception as e:
            # Best-effort: generation still works, just without the AoT speedup.
            print(f"Warning: AoT compilation failed for {label}. Running unoptimized. Error: {e}")

    _compile_unet(pipe_v1, MODEL_V1, "SD 1.4")
    _compile_unet(pipe_v2, MODEL_V2, "SD 2.1 Base")


# Run compilation once at startup
compile_optimized_models()
+
80
+
81
+ @spaces.GPU
82
+ def generate(
83
+ model_choice: str,
84
+ prompt: str,
85
+ guidance_scale: float,
86
+ num_inference_steps: int
87
+ ) -> List[Image.Image]:
88
+ """Generates images using the selected Stable Diffusion model."""
89
+
90
+ if model_choice == MODEL_V1:
91
+ pipe = pipe_v1
92
+ if not prompt:
93
+ prompt = DEFAULT_PROMPT_V1
94
+ elif model_choice == MODEL_V2:
95
+ pipe = pipe_v2
96
+ if not prompt:
97
+ prompt = DEFAULT_PROMPT_V2
98
+ else:
99
+ raise gr.Error("Invalid model selection.")
100
+
101
+ # We must use the resolution used during AoT compilation (512x512)
102
+ # for best performance.
103
+ result = pipe(
104
+ prompt=prompt,
105
+ guidance_scale=guidance_scale,
106
+ num_inference_steps=num_inference_steps,
107
+ num_images_per_prompt=4, # Generate 4 images as implied by gallery output
108
+ height=512,
109
+ width=512
110
+ ).images
111
+
112
+ return result
113
+
114
+
115
+ def display_uploads(files: Optional[List[Any]]) -> List[str]:
116
+ """Converts uploaded FileData objects to displayable paths."""
117
+ if files:
118
+ # FileData objects have a .path attribute pointing to the temporary file location
119
+ return [f.path for f in files]
120
+ return []
121
+
122
+
123
+ # --- Gradio Interface ---
124
+ with gr.Blocks(title="Stable Diffusion Models Demo") as demo:
125
+ gr.HTML(
126
+ """
127
+ <div style='text-align: center; max-width: 800px; margin: 0 auto;'>
128
+ <h1>Stable Diffusion v1.4 vs 2.1 Base</h1>
129
+ <p>Select a model and enter a prompt to generate up to 4 images. Empty prompts use a powerful default prompt.</p>
130
+ <p><a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">Built with anycoder</a></p>
131
+ </div>
132
+ """
133
+ )
134
+
135
+ with gr.Row():
136
+ with gr.Column(scale=1):
137
+ model_choice = gr.Radio(
138
+ choices=[MODEL_V1, MODEL_V2],
139
+ value=MODEL_V2,
140
+ label="Model Selection",
141
+ info="Select the base Stable Diffusion version to use."
142
+ )
143
+ prompt = gr.Textbox(
144
+ label="Prompt",
145
+ placeholder="Enter your prompt here (or leave empty for default demo prompt)"
146
+ )
147
+
148
+ with gr.Accordion("Generation Parameters", open=True):
149
+ guidance_scale = gr.Slider(
150
+ minimum=1.0, maximum=15.0, value=7.5, step=0.5, label="Guidance Scale",
151
+ info="Higher values push the generation closer to the prompt."
152
+ )
153
+ num_inference_steps = gr.Slider(
154
+ minimum=10, maximum=100, value=50, step=5, label="Inference Steps",
155
+ info="More steps lead to higher quality, but slower generation."
156
+ )
157
+
158
+ run_btn = gr.Button("Generate 4 Images", variant="primary")
159
+
160
+ # Handling image uploads (for auxiliary display/reference)
161
+ uploaded_files = gr.File(
162
+ label="Upload Reference Images (Max 4)",
163
+ file_count="multiple",
164
+ file_types=['image'],
165
+ max_files=4,
166
+ interactive=True
167
+ )
168
+ upload_display = gr.Gallery(
169
+ label="Uploaded Images for Reference",
170
+ columns=4,
171
+ object_fit="contain",
172
+ height=150,
173
+ allow_preview=False
174
+ )
175
+ uploaded_files.change(display_uploads, uploaded_files, upload_display)
176
+
177
+ with gr.Column(scale=3):
178
+ output_gallery = gr.Gallery(
179
+ label="Generated Images (512x512)",
180
+ columns=2,
181
+ object_fit="contain",
182
+ height=512,
183
+ preview=True
184
+ )
185
+
186
+ run_btn.click(
187
+ fn=generate,
188
+ inputs=[
189
+ model_choice,
190
+ prompt,
191
+ guidance_scale,
192
+ num_inference_steps
193
+ ],
194
+ outputs=output_gallery
195
+ )
196
+
197
+ if __name__ == "__main__":
198
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio
torch
diffusers
numpy
accelerate
safetensors
pillow
git+https://github.com/huggingface/spaces@main
xformers
scipy
ftfy
opencv-python
tensorboard
clean-fid
huggingface-hub