phxdev commited on
Commit
79dcd91
·
verified ·
1 Parent(s): 6c51bd3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +287 -289
app.py CHANGED
@@ -1,290 +1,288 @@
1
- import gradio as gr
2
- import numpy as np
3
- import random
4
- import spaces
5
- import torch
6
- from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL, StableDiffusionUpscalePipeline
7
- from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
8
- from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
9
- from huggingface_hub import hf_hub_download
10
- import os
11
- import requests
12
-
13
- dtype = torch.bfloat16
14
- device = "cuda" if torch.cuda.is_available() else "cpu"
15
-
16
- taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
17
- good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
18
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)
19
-
20
- # Performance optimizations
21
- if hasattr(pipe, "enable_model_cpu_offload"):
22
- pipe.enable_model_cpu_offload()
23
- if hasattr(pipe, "enable_attention_slicing"):
24
- pipe.enable_attention_slicing(1)
25
- if hasattr(pipe, "enable_vae_slicing"):
26
- pipe.enable_vae_slicing()
27
- if hasattr(pipe, "enable_vae_tiling"):
28
- pipe.enable_vae_tiling()
29
-
30
- # Compile transformer for faster inference (if supported)
31
- try:
32
- pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
33
- print("✓ Transformer compiled for faster inference")
34
- except Exception as e:
35
- print(f"Warning: Could not compile transformer: {e}")
36
-
37
- # Load upscaler pipeline with optimizations
38
- upscaler = StableDiffusionUpscalePipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", torch_dtype=dtype).to(device)
39
- if hasattr(upscaler, "enable_model_cpu_offload"):
40
- upscaler.enable_model_cpu_offload()
41
- if hasattr(upscaler, "enable_attention_slicing"):
42
- upscaler.enable_attention_slicing(1)
43
- if hasattr(upscaler, "enable_vae_slicing"):
44
- upscaler.enable_vae_slicing()
45
-
46
- # Available LoRAs
47
- LORAS = {
48
- "None": None,
49
- "AntiBlur": "Shakker-Labs/FLUX.1-dev-LoRA-AntiBlur",
50
- "Add Details": "Shakker-Labs/FLUX.1-dev-LoRA-add-details",
51
- "Ultra Realism": "https://huggingface.co/its-magick/merlin-test-loras/resolve/main/Canopus-LoRA-Flux-UltraRealism.safetensors",
52
- "Face Realism": "https://huggingface.co/its-magick/merlin-test-loras/resolve/main/Canopus-LoRA-Flux-FaceRealism.safetensors"
53
- }
54
-
55
- # Store loaded LoRA paths
56
- loaded_loras = {}
57
-
58
- def download_lora_from_url(url, filename):
59
- """Download LoRA file from direct URL"""
60
- if not os.path.exists(filename):
61
- print(f"Downloading {filename}...")
62
- response = requests.get(url)
63
- with open(filename, 'wb') as f:
64
- f.write(response.content)
65
- print(f"Downloaded {filename}")
66
- return filename
67
-
68
- def preload_and_apply_all_loras():
69
- """Download and apply all LoRAs simultaneously at startup"""
70
- global loaded_loras
71
-
72
- print("Downloading and applying all LoRAs...")
73
-
74
- for lora_name, lora_path in LORAS.items():
75
- if lora_name == "None" or lora_path is None:
76
- continue
77
-
78
- # Handle direct URL downloads
79
- if lora_path.startswith('http'):
80
- filename = f"{lora_name.lower().replace(' ', '_')}_lora.safetensors"
81
- lora_path = download_lora_from_url(lora_path, filename)
82
-
83
- loaded_loras[lora_name] = lora_path
84
- print(f"Downloaded {lora_name}")
85
-
86
- # Apply each LoRA with optimal scaling
87
- try:
88
- optimal_scale = get_optimal_lora_scale(lora_name)
89
- pipe.load_lora_weights(lora_path, adapter_name=lora_name.lower().replace(' ', '_'))
90
- print(f"Applied {lora_name} with scale {optimal_scale}")
91
- except Exception as e:
92
- print(f"Failed to apply {lora_name}: {e}")
93
-
94
- print(f"All {len(loaded_loras)} LoRAs downloaded and applied!")
95
-
96
- def get_optimal_lora_scale(lora_name):
97
- """Return optimal LoRA scale based on LoRA type for better quality/speed balance"""
98
- lora_scales = {
99
- "AntiBlur": 0.8, # Slightly lower for better balance
100
- "Add Details": 1.2, # Higher for more detail enhancement
101
- "Ultra Realism": 0.9, # Balanced for realism
102
- "Face Realism": 1.1, # Optimized for facial features
103
- }
104
- return lora_scales.get(lora_name, 1.0)
105
-
106
- # Download and apply all LoRAs at startup
107
- preload_and_apply_all_loras()
108
-
109
- torch.cuda.empty_cache()
110
-
111
- MAX_SEED = np.iinfo(np.int32).max
112
- MAX_IMAGE_SIZE = 2048
113
-
114
- pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
115
-
116
- @spaces.GPU(duration=75)
117
- def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, enable_upscale=False, progress=gr.Progress(track_tqdm=True)):
118
- if randomize_seed:
119
- seed = random.randint(0, MAX_SEED)
120
- generator = torch.Generator().manual_seed(seed)
121
-
122
- # All LoRAs are already loaded and active
123
-
124
- try:
125
- final_img = None
126
- for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
127
- prompt=prompt,
128
- guidance_scale=guidance_scale,
129
- num_inference_steps=num_inference_steps,
130
- width=width,
131
- height=height,
132
- generator=generator,
133
- output_type="pil",
134
- good_vae=good_vae,
135
- ):
136
- final_img = img
137
- yield img, seed
138
-
139
- # Apply upscaling if enabled with optimized settings
140
- if enable_upscale and final_img is not None:
141
- try:
142
- # Use fewer steps for faster upscaling with minimal quality loss
143
- upscaled_img = upscaler(
144
- prompt=prompt,
145
- image=final_img,
146
- num_inference_steps=15, # Reduced from 20 for speed
147
- guidance_scale=6.0, # Slightly lower for faster convergence
148
- generator=generator,
149
- ).images[0]
150
- yield upscaled_img, seed
151
- except Exception as e:
152
- print(f"Error during upscaling: {e}")
153
- yield final_img, seed
154
-
155
- except Exception as e:
156
- print(f"Error during generation: {e}")
157
- # Fallback to basic generation
158
- img = pipe(
159
- prompt=prompt,
160
- guidance_scale=guidance_scale,
161
- num_inference_steps=num_inference_steps,
162
- width=width,
163
- height=height,
164
- generator=generator,
165
- ).images[0]
166
-
167
- # Apply upscaling if enabled
168
- if enable_upscale:
169
- try:
170
- img = upscaler(
171
- prompt=prompt,
172
- image=img,
173
- num_inference_steps=20,
174
- guidance_scale=7.5,
175
- generator=generator,
176
- ).images[0]
177
- except Exception as e:
178
- print(f"Error during upscaling: {e}")
179
-
180
- yield img, seed
181
-
182
- examples = [
183
- "a tiny astronaut hatching from an egg on the moon",
184
- "a cat holding a sign that says hello world",
185
- "an anime illustration of a wiener schnitzel",
186
- ]
187
-
188
- css="""
189
- #col-container {
190
- margin: 0 auto;
191
- max-width: 520px;
192
- }
193
- """
194
-
195
- with gr.Blocks(css=css) as demo:
196
-
197
- with gr.Column(elem_id="col-container"):
198
- gr.Markdown(f"""# FLUX.1 [dev]
199
- 12B param rectified flow transformer guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
200
- [[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
201
- """)
202
-
203
- with gr.Row():
204
-
205
- prompt = gr.Text(
206
- label="Prompt",
207
- show_label=False,
208
- max_lines=1,
209
- placeholder="Enter your prompt",
210
- container=False,
211
- )
212
-
213
- run_button = gr.Button("Run", scale=0)
214
-
215
- result = gr.Image(label="Result", show_label=False)
216
-
217
- with gr.Accordion("Advanced Settings", open=False):
218
-
219
- gr.Markdown("**LoRAs Active:** All LoRAs are loaded and active simultaneously")
220
-
221
- enable_upscale = gr.Checkbox(
222
- label="Enable 4x Upscaling",
223
- value=False,
224
- info="Upscale final image using Stable Diffusion 4x upscaler"
225
- )
226
-
227
- seed = gr.Slider(
228
- label="Seed",
229
- minimum=0,
230
- maximum=MAX_SEED,
231
- step=1,
232
- value=0,
233
- )
234
-
235
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
236
-
237
- with gr.Row():
238
-
239
- width = gr.Slider(
240
- label="Width",
241
- minimum=256,
242
- maximum=MAX_IMAGE_SIZE,
243
- step=32,
244
- value=1024,
245
- )
246
-
247
- height = gr.Slider(
248
- label="Height",
249
- minimum=256,
250
- maximum=MAX_IMAGE_SIZE,
251
- step=32,
252
- value=1024,
253
- )
254
-
255
- with gr.Row():
256
-
257
- guidance_scale = gr.Slider(
258
- label="Guidance Scale",
259
- minimum=1,
260
- maximum=15,
261
- step=0.1,
262
- value=3.5,
263
- info="Lower values = faster generation, higher values = more prompt adherence"
264
- )
265
-
266
- num_inference_steps = gr.Slider(
267
- label="Number of inference steps",
268
- minimum=4,
269
- maximum=50,
270
- step=1,
271
- value=20,
272
- info="Lower values = faster generation, higher values = better quality"
273
- )
274
-
275
- gr.Examples(
276
- examples = examples,
277
- fn = infer,
278
- inputs = [prompt],
279
- outputs = [result, seed],
280
- cache_examples="lazy"
281
- )
282
-
283
- gr.on(
284
- triggers=[run_button.click, prompt.submit],
285
- fn = infer,
286
- inputs = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, enable_upscale],
287
- outputs = [result, seed]
288
- )
289
-
290
  demo.launch(share=True)
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import random
4
+ import spaces
5
+ import torch
6
+ from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL, StableDiffusionUpscalePipeline
7
+ from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
8
+ from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
9
+ from huggingface_hub import hf_hub_download
10
+ import os
11
+ import requests
12
+
13
+ dtype = torch.bfloat16
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+
16
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
17
+ good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
18
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)
19
+
20
+ # Performance optimizations
21
+ if hasattr(pipe, "enable_model_cpu_offload"):
22
+ pipe.enable_model_cpu_offload()
23
+ if hasattr(pipe, "enable_attention_slicing"):
24
+ pipe.enable_attention_slicing(1)
25
+ if hasattr(pipe, "enable_vae_slicing"):
26
+ pipe.enable_vae_slicing()
27
+ if hasattr(pipe, "enable_vae_tiling"):
28
+ pipe.enable_vae_tiling()
29
+
30
+ # Compile transformer for faster inference (if supported)
31
+ try:
32
+ pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
33
+ print("✓ Transformer compiled for faster inference")
34
+ except Exception as e:
35
+ print(f"Warning: Could not compile transformer: {e}")
36
+
37
+ # Load upscaler pipeline with optimizations
38
+ upscaler = StableDiffusionUpscalePipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", torch_dtype=dtype).to(device)
39
+ if hasattr(upscaler, "enable_model_cpu_offload"):
40
+ upscaler.enable_model_cpu_offload()
41
+ if hasattr(upscaler, "enable_attention_slicing"):
42
+ upscaler.enable_attention_slicing(1)
43
+ if hasattr(upscaler, "enable_vae_slicing"):
44
+ upscaler.enable_vae_slicing()
45
+
46
+ # Available LoRAs
47
+ LORAS = {
48
+ "None": None,
49
+ "AntiBlur": "Shakker-Labs/FLUX.1-dev-LoRA-AntiBlur",
50
+ "Add Details": "Shakker-Labs/FLUX.1-dev-LoRA-add-details"
51
+ }
52
+
53
+ # Store loaded LoRA paths
54
+ loaded_loras = {}
55
+
56
+ def download_lora_from_url(url, filename):
57
+ """Download LoRA file from direct URL"""
58
+ if not os.path.exists(filename):
59
+ print(f"Downloading {filename}...")
60
+ response = requests.get(url)
61
+ with open(filename, 'wb') as f:
62
+ f.write(response.content)
63
+ print(f"Downloaded {filename}")
64
+ return filename
65
+
66
+ def preload_and_apply_all_loras():
67
+ """Download and apply all LoRAs simultaneously at startup"""
68
+ global loaded_loras
69
+
70
+ print("Downloading and applying all LoRAs...")
71
+
72
+ for lora_name, lora_path in LORAS.items():
73
+ if lora_name == "None" or lora_path is None:
74
+ continue
75
+
76
+ # Handle direct URL downloads
77
+ if lora_path.startswith('http'):
78
+ filename = f"{lora_name.lower().replace(' ', '_')}_lora.safetensors"
79
+ lora_path = download_lora_from_url(lora_path, filename)
80
+
81
+ loaded_loras[lora_name] = lora_path
82
+ print(f"Downloaded {lora_name}")
83
+
84
+ # Apply each LoRA with optimal scaling
85
+ try:
86
+ optimal_scale = get_optimal_lora_scale(lora_name)
87
+ pipe.load_lora_weights(lora_path, adapter_name=lora_name.lower().replace(' ', '_'))
88
+ print(f"Applied {lora_name} with scale {optimal_scale}")
89
+ except Exception as e:
90
+ print(f"Failed to apply {lora_name}: {e}")
91
+
92
+ print(f"All {len(loaded_loras)} LoRAs downloaded and applied!")
93
+
94
+ def get_optimal_lora_scale(lora_name):
95
+ """Return optimal LoRA scale based on LoRA type for better quality/speed balance"""
96
+ lora_scales = {
97
+ "AntiBlur": 0.8, # Slightly lower for better balance
98
+ "Add Details": 1.2, # Higher for more detail enhancement
99
+ "Ultra Realism": 0.9, # Balanced for realism
100
+ "Face Realism": 1.1, # Optimized for facial features
101
+ }
102
+ return lora_scales.get(lora_name, 1.0)
103
+
104
+ # Download and apply all LoRAs at startup
105
+ preload_and_apply_all_loras()
106
+
107
+ torch.cuda.empty_cache()
108
+
109
+ MAX_SEED = np.iinfo(np.int32).max
110
+ MAX_IMAGE_SIZE = 2048
111
+
112
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
113
+
114
+ @spaces.GPU(duration=75)
115
+ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, enable_upscale=False, progress=gr.Progress(track_tqdm=True)):
116
+ if randomize_seed:
117
+ seed = random.randint(0, MAX_SEED)
118
+ generator = torch.Generator().manual_seed(seed)
119
+
120
+ # All LoRAs are already loaded and active
121
+
122
+ try:
123
+ final_img = None
124
+ for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
125
+ prompt=prompt,
126
+ guidance_scale=guidance_scale,
127
+ num_inference_steps=num_inference_steps,
128
+ width=width,
129
+ height=height,
130
+ generator=generator,
131
+ output_type="pil",
132
+ good_vae=good_vae,
133
+ ):
134
+ final_img = img
135
+ yield img, seed
136
+
137
+ # Apply upscaling if enabled with optimized settings
138
+ if enable_upscale and final_img is not None:
139
+ try:
140
+ # Use fewer steps for faster upscaling with minimal quality loss
141
+ upscaled_img = upscaler(
142
+ prompt=prompt,
143
+ image=final_img,
144
+ num_inference_steps=15, # Reduced from 20 for speed
145
+ guidance_scale=6.0, # Slightly lower for faster convergence
146
+ generator=generator,
147
+ ).images[0]
148
+ yield upscaled_img, seed
149
+ except Exception as e:
150
+ print(f"Error during upscaling: {e}")
151
+ yield final_img, seed
152
+
153
+ except Exception as e:
154
+ print(f"Error during generation: {e}")
155
+ # Fallback to basic generation
156
+ img = pipe(
157
+ prompt=prompt,
158
+ guidance_scale=guidance_scale,
159
+ num_inference_steps=num_inference_steps,
160
+ width=width,
161
+ height=height,
162
+ generator=generator,
163
+ ).images[0]
164
+
165
+ # Apply upscaling if enabled
166
+ if enable_upscale:
167
+ try:
168
+ img = upscaler(
169
+ prompt=prompt,
170
+ image=img,
171
+ num_inference_steps=20,
172
+ guidance_scale=7.5,
173
+ generator=generator,
174
+ ).images[0]
175
+ except Exception as e:
176
+ print(f"Error during upscaling: {e}")
177
+
178
+ yield img, seed
179
+
180
+ examples = [
181
+ "a tiny astronaut hatching from an egg on the moon",
182
+ "a cat holding a sign that says hello world",
183
+ "an anime illustration of a wiener schnitzel",
184
+ ]
185
+
186
+ css="""
187
+ #col-container {
188
+ margin: 0 auto;
189
+ max-width: 520px;
190
+ }
191
+ """
192
+
193
+ with gr.Blocks(css=css) as demo:
194
+
195
+ with gr.Column(elem_id="col-container"):
196
+ gr.Markdown(f"""# FLUX.1 [dev]
197
+ 12B param rectified flow transformer guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
198
+ [[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
199
+ """)
200
+
201
+ with gr.Row():
202
+
203
+ prompt = gr.Text(
204
+ label="Prompt",
205
+ show_label=False,
206
+ max_lines=1,
207
+ placeholder="Enter your prompt",
208
+ container=False,
209
+ )
210
+
211
+ run_button = gr.Button("Run", scale=0)
212
+
213
+ result = gr.Image(label="Result", show_label=False)
214
+
215
+ with gr.Accordion("Advanced Settings", open=False):
216
+
217
+ gr.Markdown("**LoRAs Active:** All LoRAs are loaded and active simultaneously")
218
+
219
+ enable_upscale = gr.Checkbox(
220
+ label="Enable 4x Upscaling",
221
+ value=False,
222
+ info="Upscale final image using Stable Diffusion 4x upscaler"
223
+ )
224
+
225
+ seed = gr.Slider(
226
+ label="Seed",
227
+ minimum=0,
228
+ maximum=MAX_SEED,
229
+ step=1,
230
+ value=0,
231
+ )
232
+
233
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
234
+
235
+ with gr.Row():
236
+
237
+ width = gr.Slider(
238
+ label="Width",
239
+ minimum=256,
240
+ maximum=MAX_IMAGE_SIZE,
241
+ step=32,
242
+ value=1024,
243
+ )
244
+
245
+ height = gr.Slider(
246
+ label="Height",
247
+ minimum=256,
248
+ maximum=MAX_IMAGE_SIZE,
249
+ step=32,
250
+ value=1024,
251
+ )
252
+
253
+ with gr.Row():
254
+
255
+ guidance_scale = gr.Slider(
256
+ label="Guidance Scale",
257
+ minimum=1,
258
+ maximum=15,
259
+ step=0.1,
260
+ value=3.5,
261
+ info="Lower values = faster generation, higher values = more prompt adherence"
262
+ )
263
+
264
+ num_inference_steps = gr.Slider(
265
+ label="Number of inference steps",
266
+ minimum=4,
267
+ maximum=50,
268
+ step=1,
269
+ value=20,
270
+ info="Lower values = faster generation, higher values = better quality"
271
+ )
272
+
273
+ gr.Examples(
274
+ examples = examples,
275
+ fn = infer,
276
+ inputs = [prompt],
277
+ outputs = [result, seed],
278
+ cache_examples="lazy"
279
+ )
280
+
281
+ gr.on(
282
+ triggers=[run_button.click, prompt.submit],
283
+ fn = infer,
284
+ inputs = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, enable_upscale],
285
+ outputs = [result, seed]
286
+ )
287
+
 
 
288
  demo.launch(share=True)