yeq6x committed on
Commit
1573e37
·
1 Parent(s): 426b3b0

Update inference output in app.py to yield stage2-only images, refining the generator's response structure for better integration with the UI.

Browse files
Files changed (1) hide show
  1. app.py +61 -27
app.py CHANGED
@@ -78,9 +78,9 @@ pipe.load_lora_weights(STAGE2_LORA_REPO, weight_name=STAGE2_LORA_WEIGHT, adapter
78
  # --- UI Constants ---
79
  MAX_SEED = np.iinfo(np.int32).max
80
 
81
- # --- Main Inference Function (Combined LoRA) ---
82
  @spaces.GPU()
83
- def infer(
84
  image,
85
  seed=42,
86
  randomize_seed=False,
@@ -88,29 +88,14 @@ def infer(
88
  num_inference_steps=4,
89
  height=None,
90
  width=None,
91
- stage1_weight=1.0,
92
- stage2_weight=1.0,
93
  progress=gr.Progress(track_tqdm=True),
94
  ):
95
  """
96
- Run stage2-only inference, then combined LoRAs: Lightning + Stage1 + Stage2.
97
-
98
- Parameters:
99
- image: Input image (PIL Image or path string).
100
- seed (int): Random seed for reproducibility.
101
- randomize_seed (bool): If True, overrides seed with a random value.
102
- true_guidance_scale (float): CFG scale used by Qwen-Image.
103
- num_inference_steps (int): Number of diffusion steps.
104
- height (int | None): Optional output height override.
105
- width (int | None): Optional output width override.
106
- stage1_weight (float): Weight for Stage1 LoRA.
107
- stage2_weight (float): Weight for Stage2 LoRA.
108
- progress: Gradio progress callback.
109
 
110
  Returns:
111
- generator: yields (stage2_only_image, result_image, seed_used)
112
  """
113
-
114
  # Hardcode the negative prompt
115
  negative_prompt = " "
116
 
@@ -150,7 +135,43 @@ def infer(
150
  num_images_per_prompt=1,
151
  ).images
152
  stage2_only_image = stage2_images[0] if stage2_images else None
153
- yield stage2_only_image, None, seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
  # --- Combined generation ---
156
  print(f"Generating with combined LoRAs...")
@@ -180,11 +201,10 @@ def infer(
180
  if pil_image.size != generated_image.size:
181
  pil_image = pil_image.resize(generated_image.size, Image.Resampling.LANCZOS)
182
  blended_image = Image.blend(pil_image, generated_image, alpha=0.75)
183
- yield gr.update(), blended_image, seed
184
- return
185
 
186
- # Return first result image and seed
187
- yield gr.update(), result_images[0] if result_images else None, seed
188
 
189
  # --- Examples and UI Layout ---
190
  examples = []
@@ -339,8 +359,9 @@ with gr.Blocks(css=css) as demo:
339
  value=None,
340
  )
341
 
342
- run_button.click(
343
- fn=infer,
 
344
  inputs=[
345
  input_image,
346
  seed,
@@ -349,10 +370,23 @@ with gr.Blocks(css=css) as demo:
349
  num_inference_steps,
350
  height,
351
  width,
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  stage1_weight,
353
  stage2_weight,
354
  ],
355
- outputs=[stage2_result, result, seed],
356
  )
357
 
358
  if __name__ == "__main__":
 
78
  # --- UI Constants ---
79
  MAX_SEED = np.iinfo(np.int32).max
80
 
81
+ # --- Main Inference Function (Split into two stages) ---
82
  @spaces.GPU()
83
+ def infer_stage2(
84
  image,
85
  seed=42,
86
  randomize_seed=False,
 
88
  num_inference_steps=4,
89
  height=None,
90
  width=None,
 
 
91
  progress=gr.Progress(track_tqdm=True),
92
  ):
93
  """
94
+ Run stage2-only inference.
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  Returns:
97
+ (stage2_only_image, image, seed, true_guidance_scale, num_inference_steps, height, width)
98
  """
 
99
  # Hardcode the negative prompt
100
  negative_prompt = " "
101
 
 
135
  num_images_per_prompt=1,
136
  ).images
137
  stage2_only_image = stage2_images[0] if stage2_images else None
138
+
139
+ return stage2_only_image, image, seed, true_guidance_scale, num_inference_steps, height, width
140
+
141
+ @spaces.GPU()
142
+ def infer_combined(
143
+ image,
144
+ seed,
145
+ true_guidance_scale,
146
+ num_inference_steps,
147
+ height,
148
+ width,
149
+ stage1_weight,
150
+ stage2_weight,
151
+ progress=gr.Progress(track_tqdm=True),
152
+ ):
153
+ """
154
+ Run combined LoRAs inference.
155
+
156
+ Returns:
157
+ result_image
158
+ """
159
+ # Hardcode the negative prompt
160
+ negative_prompt = " "
161
+
162
+ # Set up the generator for reproducibility
163
+ generator = torch.Generator(device=device).manual_seed(seed)
164
+
165
+ # Load input image into PIL Image
166
+ pil_image = None
167
+ if image is not None:
168
+ if isinstance(image, Image.Image):
169
+ pil_image = image.convert("RGB")
170
+ elif isinstance(image, str):
171
+ pil_image = Image.open(image).convert("RGB")
172
+
173
+ if height==256 and width==256:
174
+ height, width = None, None
175
 
176
  # --- Combined generation ---
177
  print(f"Generating with combined LoRAs...")
 
201
  if pil_image.size != generated_image.size:
202
  pil_image = pil_image.resize(generated_image.size, Image.Resampling.LANCZOS)
203
  blended_image = Image.blend(pil_image, generated_image, alpha=0.75)
204
+ return blended_image
 
205
 
206
+ # Return first result image
207
+ return result_images[0] if result_images else None
208
 
209
  # --- Examples and UI Layout ---
210
  examples = []
 
359
  value=None,
360
  )
361
 
362
+ # Chain two inference stages using .then()
363
+ stage2_event = run_button.click(
364
+ fn=infer_stage2,
365
  inputs=[
366
  input_image,
367
  seed,
 
370
  num_inference_steps,
371
  height,
372
  width,
373
+ ],
374
+ outputs=[stage2_result, gr.State(), gr.State(), gr.State(), gr.State(), gr.State(), gr.State()],
375
+ )
376
+
377
+ stage2_event.then(
378
+ fn=infer_combined,
379
+ inputs=[
380
+ stage2_event.outputs[1], # image
381
+ stage2_event.outputs[2], # seed
382
+ stage2_event.outputs[3], # true_guidance_scale
383
+ stage2_event.outputs[4], # num_inference_steps
384
+ stage2_event.outputs[5], # height
385
+ stage2_event.outputs[6], # width
386
  stage1_weight,
387
  stage2_weight,
388
  ],
389
+ outputs=[result],
390
  )
391
 
392
  if __name__ == "__main__":