yeq6x committed on
Commit
2d9cf16
·
1 Parent(s): 2812ae1

Refactor the inference process in app.py to support stage2-only generation, and update the output structure to return both the stage2-only and combined results. Adjust the UI layout for improved result display, and create a freshly seeded generator for each pipeline call so that both generations are reproducible from the same seed.

Browse files
Files changed (1) hide show
  1. app.py +49 -23
app.py CHANGED
@@ -6,10 +6,10 @@ import spaces
6
 
7
  from PIL import Image
8
  from diffusers import FlowMatchEulerDiscreteScheduler, QwenImageEditPlusPipeline
9
- # from optimization import optimize_pipeline_
10
- # from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
11
- # from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
12
- # from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
13
 
14
  import math
15
  import os
@@ -68,12 +68,12 @@ pipe.load_lora_weights(STAGE1_LORA_REPO, weight_name=STAGE1_LORA_WEIGHT, adapter
68
  # Load Stage 2 LoRA
69
  pipe.load_lora_weights(STAGE2_LORA_REPO, weight_name=STAGE2_LORA_WEIGHT, adapter_name="stage2")
70
 
71
- # # Apply the same optimizations from the first version
72
- # pipe.transformer.__class__ = QwenImageTransformer2DModel
73
- # pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
74
 
75
- # # --- Ahead-of-time compilation ---
76
- # optimize_pipeline_(pipe, image=[Image.new("RGB", (1024, 1024)), Image.new("RGB", (1024, 1024))], prompt="prompt")
77
 
78
  # --- UI Constants ---
79
  MAX_SEED = np.iinfo(np.int32).max
@@ -93,7 +93,7 @@ def infer(
93
  progress=gr.Progress(track_tqdm=True),
94
  ):
95
  """
96
- Run single inference with combined LoRAs: Lightning + Stage1 + Stage2.
97
 
98
  Parameters:
99
  image: Input image (PIL Image or path string).
@@ -108,7 +108,7 @@ def infer(
108
  progress: Gradio progress callback.
109
 
110
  Returns:
111
- tuple: (result_image, seed_used)
112
  """
113
 
114
  # Hardcode the negative prompt
@@ -117,8 +117,8 @@ def infer(
117
  if randomize_seed:
118
  seed = random.randint(0, MAX_SEED)
119
 
120
- # Set up the generator for reproducibility
121
- generator = torch.Generator(device=device).manual_seed(seed)
122
 
123
  # Load input image into PIL Image
124
  pil_image = None
@@ -131,6 +131,26 @@ def infer(
131
  if height==256 and width==256:
132
  height, width = None, None
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  # Apply all LoRAs: Lightning + Stage1 + Stage2
135
  print(f"Generating with combined LoRAs...")
136
  print(f"Prompt: '{STAGE1_PROMPT}'")
@@ -147,7 +167,7 @@ def infer(
147
  width=width,
148
  negative_prompt=negative_prompt,
149
  num_inference_steps=num_inference_steps,
150
- generator=generator,
151
  true_cfg_scale=true_guidance_scale,
152
  num_images_per_prompt=1,
153
  ).images
@@ -159,10 +179,10 @@ def infer(
159
  if pil_image.size != generated_image.size:
160
  pil_image = pil_image.resize(generated_image.size, Image.Resampling.LANCZOS)
161
  blended_image = Image.blend(pil_image, generated_image, alpha=0.75)
162
- return blended_image, seed
163
 
164
  # Return first result image and seed
165
- return result_images[0] if result_images else None, seed
166
 
167
  # --- Examples and UI Layout ---
168
  examples = []
@@ -170,7 +190,7 @@ examples = []
170
  css = """
171
  #col-container {
172
  margin: 0 auto;
173
- max-width: 900px;
174
  }
175
  #logo-title {
176
  text-align: center;
@@ -192,7 +212,8 @@ with gr.Blocks(css=css) as demo:
192
  show_label=False,
193
  type="pil",
194
  interactive=True,
195
- elem_id="input-image")
 
196
 
197
  gr.HTML("""
198
  <script>
@@ -243,13 +264,18 @@ with gr.Blocks(css=css) as demo:
243
  </script>
244
  """)
245
 
246
- with gr.Column(scale=1):
247
- gr.Markdown("### 📤 Result")
248
- result = gr.Image(label="Result", show_label=False, type="pil", interactive=False)
 
 
 
 
 
249
 
250
  run_button = gr.Button("🚀 Generate", variant="primary", size="lg")
251
 
252
- with gr.Accordion("Advanced Settings", open=False):
253
  with gr.Row():
254
  seed = gr.Slider(
255
  label="Seed",
@@ -325,7 +351,7 @@ with gr.Blocks(css=css) as demo:
325
  stage1_weight,
326
  stage2_weight,
327
  ],
328
- outputs=[result, seed],
329
  )
330
 
331
  if __name__ == "__main__":
 
6
 
7
  from PIL import Image
8
  from diffusers import FlowMatchEulerDiscreteScheduler, QwenImageEditPlusPipeline
9
+ from optimization import optimize_pipeline_
10
+ from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
11
+ from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
12
+ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
13
 
14
  import math
15
  import os
 
68
  # Load Stage 2 LoRA
69
  pipe.load_lora_weights(STAGE2_LORA_REPO, weight_name=STAGE2_LORA_WEIGHT, adapter_name="stage2")
70
 
71
+ # Apply the same optimizations from the first version
72
+ pipe.transformer.__class__ = QwenImageTransformer2DModel
73
+ pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
74
 
75
+ # --- Ahead-of-time compilation ---
76
+ optimize_pipeline_(pipe, image=[Image.new("RGB", (1024, 1024)), Image.new("RGB", (1024, 1024))], prompt="prompt")
77
 
78
  # --- UI Constants ---
79
  MAX_SEED = np.iinfo(np.int32).max
 
93
  progress=gr.Progress(track_tqdm=True),
94
  ):
95
  """
96
+ Run stage2-only inference, then combined LoRAs: Lightning + Stage1 + Stage2.
97
 
98
  Parameters:
99
  image: Input image (PIL Image or path string).
 
108
  progress: Gradio progress callback.
109
 
110
  Returns:
111
+ tuple: (stage2_only_image, result_image, seed_used)
112
  """
113
 
114
  # Hardcode the negative prompt
 
117
  if randomize_seed:
118
  seed = random.randint(0, MAX_SEED)
119
 
120
+ def make_generator():
121
+ return torch.Generator(device=device).manual_seed(seed)
122
 
123
  # Load input image into PIL Image
124
  pil_image = None
 
131
  if height==256 and width==256:
132
  height, width = None, None
133
 
134
+ # Stage2-only generation
135
+ print("Generating with Stage2 LoRA only...")
136
+ print(f"Prompt: '{STAGE2_PROMPT}'")
137
+ print(f"Seed: {seed}, Steps: {num_inference_steps}, Guidance: {true_guidance_scale}, Size: {width}x{height}")
138
+ print("LoRA Weights - Stage2: 1.0")
139
+
140
+ pipe.set_adapters(["stage2"], adapter_weights=[1.0])
141
+ stage2_images = pipe(
142
+ image=[pil_image] if pil_image is not None else None,
143
+ prompt=STAGE2_PROMPT,
144
+ height=height,
145
+ width=width,
146
+ negative_prompt=negative_prompt,
147
+ num_inference_steps=num_inference_steps,
148
+ generator=make_generator(),
149
+ true_cfg_scale=true_guidance_scale,
150
+ num_images_per_prompt=1,
151
+ ).images
152
+ stage2_only_image = stage2_images[0] if stage2_images else None
153
+
154
  # Apply all LoRAs: Lightning + Stage1 + Stage2
155
  print(f"Generating with combined LoRAs...")
156
  print(f"Prompt: '{STAGE1_PROMPT}'")
 
167
  width=width,
168
  negative_prompt=negative_prompt,
169
  num_inference_steps=num_inference_steps,
170
+ generator=make_generator(),
171
  true_cfg_scale=true_guidance_scale,
172
  num_images_per_prompt=1,
173
  ).images
 
179
  if pil_image.size != generated_image.size:
180
  pil_image = pil_image.resize(generated_image.size, Image.Resampling.LANCZOS)
181
  blended_image = Image.blend(pil_image, generated_image, alpha=0.75)
182
+ return stage2_only_image, blended_image, seed
183
 
184
  # Return the stage2-only image, the first combined result image, and the seed
185
+ return stage2_only_image, result_images[0] if result_images else None, seed
186
 
187
  # --- Examples and UI Layout ---
188
  examples = []
 
190
  css = """
191
  #col-container {
192
  margin: 0 auto;
193
+ max-width: 1000px;
194
  }
195
  #logo-title {
196
  text-align: center;
 
212
  show_label=False,
213
  type="pil",
214
  interactive=True,
215
+ elem_id="input-image",
216
+ height=350)
217
 
218
  gr.HTML("""
219
  <script>
 
264
  </script>
265
  """)
266
 
267
+ with gr.Row(scale=2):
268
+ with gr.Column(scale=1):
269
+ gr.Markdown("### 🧪 Result1")
270
+ stage2_result = gr.Image(label="Result1", show_label=False, type="pil", interactive=False, height=350)
271
+
272
+ with gr.Column(scale=1):
273
+ gr.Markdown("### 📤 Result2")
274
+ result = gr.Image(label="Result2", show_label=False, type="pil", interactive=False, height=350)
275
 
276
  run_button = gr.Button("🚀 Generate", variant="primary", size="lg")
277
 
278
+ with gr.Accordion("Advanced Settings", open=False, visible=False):
279
  with gr.Row():
280
  seed = gr.Slider(
281
  label="Seed",
 
351
  stage1_weight,
352
  stage2_weight,
353
  ],
354
+ outputs=[stage2_result, result, seed],
355
  )
356
 
357
  if __name__ == "__main__":