linoyts HF Staff commited on
Commit
e83f3dc
·
verified ·
1 Parent(s): e129c9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -18
app.py CHANGED
@@ -187,18 +187,40 @@ def remove_background_from_image(image: Image.Image) -> Image.Image:
187
 
188
  return result_image
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  # --- Inference ---
191
  @spaces.GPU
192
  def infer(
193
- gallery_images,
194
  image_background,
195
  prompt="",
196
  seed=42,
197
  randomize_seed=True,
198
  true_guidance_scale=1,
199
  num_inference_steps=4,
200
- height=None,
201
- width=None,
202
  progress=gr.Progress(track_tqdm=True)
203
  ):
204
  if randomize_seed:
@@ -206,26 +228,27 @@ def infer(
206
  generator = torch.Generator(device=device).manual_seed(seed)
207
 
208
  processed_subjects = []
209
- if gallery_images:
210
- for img in gallery_images:
211
- image = img[0] # Extract PIL image from gallery format
212
 
213
- image = remove_background_from_image(image)
214
-
215
- # Always remove alpha channels to ensure RGB format
216
- image = remove_alpha_channel(image)
217
- processed_subjects.append(image)
218
 
219
  all_inputs = processed_subjects
220
  if image_background is not None:
221
  all_inputs.append(image_background)
222
 
 
 
223
  if not all_inputs:
224
  raise gr.Error("Please upload at least one image or a background image.")
225
 
226
  result = pipe(
227
  image=all_inputs,
228
  prompt=prompt,
 
 
229
  num_inference_steps=num_inference_steps,
230
  generator=generator,
231
  true_cfg_scale=true_guidance_scale,
@@ -247,9 +270,8 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
247
  with gr.Row():
248
  with gr.Column():
249
  with gr.Row():
250
- gallery = gr.Gallery(
251
- label="Product image (background auto removed)",
252
- columns=3, rows=2, height="auto", type="pil"
253
  )
254
  image_background = gr.Image(label="Background Image", type="pil", visible=True)
255
  prompt = gr.Textbox(label="Prompt")
@@ -260,8 +282,6 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
260
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
261
  true_guidance_scale = gr.Slider(label="True Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0)
262
  num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=40, step=1, value=4)
263
- height = gr.Slider(label="Height", minimum=256, maximum=2048, step=8, value=1024)
264
- width = gr.Slider(label="Width", minimum=256, maximum=2048, step=8, value=1024)
265
 
266
  with gr.Column():
267
  result = gr.ImageSlider(label="Output Image", interactive=False)
@@ -274,14 +294,14 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
274
  [[], "fusion_shoes.png", ""],
275
  [["product_3.png"], "background_3.png", ""],
276
  ],
277
- inputs=[gallery, image_background, prompt],
278
  outputs=[result, seed],
279
  fn=infer,
280
  cache_examples="lazy",
281
  elem_id="examples"
282
  )
283
 
284
- inputs = [gallery, image_background, prompt, seed, randomize_seed, true_guidance_scale, num_inference_steps, height, width]
285
  outputs = [result, seed]
286
 
287
  run_button.click(fn=infer, inputs=inputs, outputs=outputs)
 
187
 
188
  return result_image
189
 
190
+
191
+ def calculate_dimensions(image):
192
+ """Calculate output dimensions based on background image, keeping largest side at 1024."""
193
+ if image is None:
194
+ return 1024, 1024
195
+
196
+ original_width, original_height = image.size
197
+
198
+ if original_width > original_height:
199
+ new_width = 1024
200
+ aspect_ratio = original_height / original_width
201
+ new_height = int(new_width * aspect_ratio)
202
+ else:
203
+ new_height = 1024
204
+ aspect_ratio = original_width / original_height
205
+ new_width = int(new_height * aspect_ratio)
206
+
207
+ # Ensure dimensions are multiples of 8
208
+ new_width = (new_width // 8) * 8
209
+ new_height = (new_height // 8) * 8
210
+
211
+ return new_width, new_height
212
+
213
+
214
  # --- Inference ---
215
  @spaces.GPU
216
  def infer(
217
+ product_image,
218
  image_background,
219
  prompt="",
220
  seed=42,
221
  randomize_seed=True,
222
  true_guidance_scale=1,
223
  num_inference_steps=4,
 
 
224
  progress=gr.Progress(track_tqdm=True)
225
  ):
226
  if randomize_seed:
 
228
  generator = torch.Generator(device=device).manual_seed(seed)
229
 
230
  processed_subjects = []
231
+ if product_image:
232
+ image = remove_background_from_image(image)
 
233
 
234
+ # Always remove alpha channels to ensure RGB format
235
+ image = remove_alpha_channel(image)
236
+ processed_subjects.append(image)
 
 
237
 
238
  all_inputs = processed_subjects
239
  if image_background is not None:
240
  all_inputs.append(image_background)
241
 
242
+ width, height = calculate_dimensions(image_background)
243
+
244
  if not all_inputs:
245
  raise gr.Error("Please upload at least one image or a background image.")
246
 
247
  result = pipe(
248
  image=all_inputs,
249
  prompt=prompt,
250
+ width=width,
251
+ height=height,
252
  num_inference_steps=num_inference_steps,
253
  generator=generator,
254
  true_cfg_scale=true_guidance_scale,
 
270
  with gr.Row():
271
  with gr.Column():
272
  with gr.Row():
273
+ product_image = gr.Image(
274
+ label="Product image (background auto removed)" type="pil"
 
275
  )
276
  image_background = gr.Image(label="Background Image", type="pil", visible=True)
277
  prompt = gr.Textbox(label="Prompt")
 
282
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
283
  true_guidance_scale = gr.Slider(label="True Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0)
284
  num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=40, step=1, value=4)
 
 
285
 
286
  with gr.Column():
287
  result = gr.ImageSlider(label="Output Image", interactive=False)
 
294
  [[], "fusion_shoes.png", ""],
295
  [["product_3.png"], "background_3.png", ""],
296
  ],
297
+ inputs=[product_image, image_background, prompt],
298
  outputs=[result, seed],
299
  fn=infer,
300
  cache_examples="lazy",
301
  elem_id="examples"
302
  )
303
 
304
+ inputs = [product_image, image_background, prompt, seed, randomize_seed, true_guidance_scale, num_inference_steps]
305
  outputs = [result, seed]
306
 
307
  run_button.click(fn=infer, inputs=inputs, outputs=outputs)