phxdev Claude commited on
Commit
990ef3a
·
1 Parent(s): 17afea8

Add 4x upscaling with stabilityai/stable-diffusion-x4-upscaler

Browse files

- Load StableDiffusionUpscalePipeline for 4x upscaling
- Add enable_upscale checkbox in Advanced Settings
- Apply upscaling as final step after image generation
- Maintain live preview during generation, then upscale final image
- Add error handling for upscaling failures

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +44 -3
app.py CHANGED
@@ -3,7 +3,7 @@ import numpy as np
3
  import random
4
  import spaces
5
  import torch
6
- from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
7
  from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
8
  from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
9
  from huggingface_hub import hf_hub_download
@@ -17,6 +17,9 @@ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).
17
  good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
18
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)
19
 
 
 
 
20
  # Available LoRAs
21
  LORAS = {
22
  "None": None,
@@ -69,7 +72,7 @@ MAX_IMAGE_SIZE = 2048
69
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
70
 
71
  @spaces.GPU(duration=75)
72
- def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, lora_selection="None", progress=gr.Progress(track_tqdm=True)):
73
  if randomize_seed:
74
  seed = random.randint(0, MAX_SEED)
75
  generator = torch.Generator().manual_seed(seed)
@@ -86,6 +89,7 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidan
86
  print(f"Failed to load LoRA {lora_selection}: {e}")
87
 
88
  try:
 
89
  for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
90
  prompt=prompt,
91
  guidance_scale=guidance_scale,
@@ -96,7 +100,24 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidan
96
  output_type="pil",
97
  good_vae=good_vae,
98
  ):
 
99
  yield img, seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  except Exception as e:
101
  print(f"Error during generation: {e}")
102
  # Fallback to basic generation
@@ -108,6 +129,20 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidan
108
  height=height,
109
  generator=generator,
110
  ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  yield img, seed
112
 
113
  examples = [
@@ -154,6 +189,12 @@ with gr.Blocks(css=css) as demo:
154
  info="Select a LoRA to enhance image generation"
155
  )
156
 
 
 
 
 
 
 
157
  seed = gr.Slider(
158
  label="Seed",
159
  minimum=0,
@@ -211,7 +252,7 @@ with gr.Blocks(css=css) as demo:
211
  gr.on(
212
  triggers=[run_button.click, prompt.submit],
213
  fn = infer,
214
- inputs = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, lora_selection],
215
  outputs = [result, seed]
216
  )
217
 
 
3
  import random
4
  import spaces
5
  import torch
6
+ from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL, StableDiffusionUpscalePipeline
7
  from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
8
  from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
9
  from huggingface_hub import hf_hub_download
 
17
  good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
18
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)
19
 
20
+ # Load upscaler pipeline
21
+ upscaler = StableDiffusionUpscalePipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", torch_dtype=dtype).to(device)
22
+
23
  # Available LoRAs
24
  LORAS = {
25
  "None": None,
 
72
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
73
 
74
  @spaces.GPU(duration=75)
75
+ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, lora_selection="None", enable_upscale=False, progress=gr.Progress(track_tqdm=True)):
76
  if randomize_seed:
77
  seed = random.randint(0, MAX_SEED)
78
  generator = torch.Generator().manual_seed(seed)
 
89
  print(f"Failed to load LoRA {lora_selection}: {e}")
90
 
91
  try:
92
+ final_img = None
93
  for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
94
  prompt=prompt,
95
  guidance_scale=guidance_scale,
 
100
  output_type="pil",
101
  good_vae=good_vae,
102
  ):
103
+ final_img = img
104
  yield img, seed
105
+
106
+ # Apply upscaling if enabled
107
+ if enable_upscale and final_img is not None:
108
+ try:
109
+ upscaled_img = upscaler(
110
+ prompt=prompt,
111
+ image=final_img,
112
+ num_inference_steps=20,
113
+ guidance_scale=7.5,
114
+ generator=generator,
115
+ ).images[0]
116
+ yield upscaled_img, seed
117
+ except Exception as e:
118
+ print(f"Error during upscaling: {e}")
119
+ yield final_img, seed
120
+
121
  except Exception as e:
122
  print(f"Error during generation: {e}")
123
  # Fallback to basic generation
 
129
  height=height,
130
  generator=generator,
131
  ).images[0]
132
+
133
+ # Apply upscaling if enabled
134
+ if enable_upscale:
135
+ try:
136
+ img = upscaler(
137
+ prompt=prompt,
138
+ image=img,
139
+ num_inference_steps=20,
140
+ guidance_scale=7.5,
141
+ generator=generator,
142
+ ).images[0]
143
+ except Exception as e:
144
+ print(f"Error during upscaling: {e}")
145
+
146
  yield img, seed
147
 
148
  examples = [
 
189
  info="Select a LoRA to enhance image generation"
190
  )
191
 
192
+ enable_upscale = gr.Checkbox(
193
+ label="Enable 4x Upscaling",
194
+ value=False,
195
+ info="Upscale final image using Stable Diffusion 4x upscaler"
196
+ )
197
+
198
  seed = gr.Slider(
199
  label="Seed",
200
  minimum=0,
 
252
  gr.on(
253
  triggers=[run_button.click, prompt.submit],
254
  fn = infer,
255
+ inputs = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, lora_selection, enable_upscale],
256
  outputs = [result, seed]
257
  )
258