prithivMLmods committed on
Commit
8b92413
·
verified ·
1 Parent(s): 47b061d

update app

Browse files
Files changed (1) hide show
  1. app.py +21 -22
app.py CHANGED
@@ -15,7 +15,6 @@ from gradio.themes.utils import colors, fonts, sizes
15
  import rerun as rr
16
  from gradio_rerun import Rerun
17
 
18
- # --- Theme Configuration ---
19
  colors.orange_red = colors.Color(
20
  name="orange_red",
21
  c50="#FFF0E5",
@@ -84,7 +83,6 @@ class OrangeRedTheme(Soft):
84
 
85
  orange_red_theme = OrangeRedTheme()
86
 
87
- # --- Model & Device Setup ---
88
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
89
 
90
  print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
@@ -180,9 +178,7 @@ def infer(
180
  pil_images = []
181
  if images is not None:
182
  for item in images:
183
- # Gradio Gallery returns a list of tuples (filepath, label) or (image, label) depending on version/type
184
  try:
185
- # Check for tuple (standard Gradio Gallery output)
186
  if isinstance(item, tuple) or isinstance(item, list):
187
  path_or_img = item[0]
188
  else:
@@ -193,7 +189,6 @@ def infer(
193
  elif isinstance(path_or_img, Image.Image):
194
  pil_images.append(path_or_img.convert("RGB"))
195
  else:
196
- # Fallback for complex Gradio objects
197
  pil_images.append(Image.open(path_or_img.name).convert("RGB"))
198
  except Exception as e:
199
  print(f"Skipping invalid image item: {e}")
@@ -232,13 +227,11 @@ def infer(
232
  generator = torch.Generator(device=device).manual_seed(seed)
233
  negative_prompt = "worst quality, low quality, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry"
234
 
235
- # Use dimensions from the first image for the output
236
  width, height = update_dimensions_on_upload(pil_images[0])
237
 
238
  try:
239
  progress(0.4, desc="Generating Image...")
240
 
241
- # Pass the list of PIL images to the pipeline
242
  result_image = pipe(
243
  image=pil_images,
244
  prompt=prompt,
@@ -250,11 +243,14 @@ def infer(
250
  true_cfg_scale=guidance_scale,
251
  ).images[0]
252
 
 
 
 
 
 
253
  # --- Rerun Visualization Logic ---
254
  progress(0.9, desc="Preparing Rerun Visualization...")
255
 
256
- run_id = str(uuid.uuid4())
257
-
258
  # Handle different Rerun SDK versions
259
  rec = None
260
  if hasattr(rr, "new_recording"):
@@ -265,7 +261,7 @@ def infer(
265
  rr.init("Qwen-Image-Edit", recording_id=run_id, spawn=False)
266
  rec = rr
267
 
268
- # Log all input images
269
  for i, img in enumerate(pil_images):
270
  rec.log(f"images/input_{i}", rr.Image(np.array(img)))
271
 
@@ -276,7 +272,7 @@ def infer(
276
  rrd_path = os.path.join(TMP_DIR, f"{run_id}.rrd")
277
  rec.save(rrd_path)
278
 
279
- return rrd_path, seed
280
 
281
  except Exception as e:
282
  raise e
@@ -286,16 +282,13 @@ def infer(
286
 
287
  @spaces.GPU
288
  def infer_example(images, prompt, lora_adapter):
289
- # Wrapper for examples (images coming from gr.Examples are usually list of filepaths)
290
  if not images:
291
- return None, 0
292
 
293
- # Ensure input is treated as a list even if example passes single path string
294
  if isinstance(images, str):
295
  images = [images]
296
 
297
- # infer expects the gallery format or list of paths
298
- result_rrd, seed = infer(
299
  images=images,
300
  prompt=prompt,
301
  lora_adapter=lora_adapter,
@@ -304,7 +297,7 @@ def infer_example(images, prompt, lora_adapter):
304
  guidance_scale=1.0,
305
  steps=4
306
  )
307
- return result_rrd, seed
308
 
309
  css="""
310
  #col-container {
@@ -321,7 +314,6 @@ with gr.Blocks() as demo:
321
 
322
  with gr.Row(equal_height=True):
323
  with gr.Column():
324
- # Changed to Gallery to support multiple images
325
  images = gr.Gallery(
326
  label="Upload Images",
327
  type="filepath",
@@ -342,7 +334,7 @@ with gr.Blocks() as demo:
342
  with gr.Column():
343
  rerun_output = Rerun(
344
  label="Rerun Visualization",
345
- height=355
346
  )
347
 
348
  with gr.Row():
@@ -351,13 +343,20 @@ with gr.Blocks() as demo:
351
  choices=list(ADAPTER_SPECS.keys()),
352
  value="Photo-to-Anime"
353
  )
 
354
  with gr.Accordion("Advanced Settings", open=False, visible=False):
355
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
356
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
357
  guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0)
358
  steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=4)
 
 
 
 
 
 
 
359
 
360
- # Updated examples to use list of paths for Gallery input
361
  gr.Examples(
362
  examples=[
363
  [["examples/B.jpg"], "Transform into anime.", "Photo-to-Anime"],
@@ -365,7 +364,7 @@ with gr.Blocks() as demo:
365
  [["examples/A.jpeg"], "Rotate the camera 45 degrees to the right.", "Multiple-Angles"],
366
  ],
367
  inputs=[images, prompt, lora_adapter],
368
- outputs=[rerun_output, seed],
369
  fn=infer_example,
370
  cache_examples=False,
371
  label="Examples"
@@ -376,7 +375,7 @@ with gr.Blocks() as demo:
376
  run_button.click(
377
  fn=infer,
378
  inputs=[images, prompt, lora_adapter, seed, randomize_seed, guidance_scale, steps],
379
- outputs=[rerun_output, seed]
380
  )
381
 
382
  if __name__ == "__main__":
 
15
  import rerun as rr
16
  from gradio_rerun import Rerun
17
 
 
18
  colors.orange_red = colors.Color(
19
  name="orange_red",
20
  c50="#FFF0E5",
 
83
 
84
  orange_red_theme = OrangeRedTheme()
85
 
 
86
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
87
 
88
  print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
 
178
  pil_images = []
179
  if images is not None:
180
  for item in images:
 
181
  try:
 
182
  if isinstance(item, tuple) or isinstance(item, list):
183
  path_or_img = item[0]
184
  else:
 
189
  elif isinstance(path_or_img, Image.Image):
190
  pil_images.append(path_or_img.convert("RGB"))
191
  else:
 
192
  pil_images.append(Image.open(path_or_img.name).convert("RGB"))
193
  except Exception as e:
194
  print(f"Skipping invalid image item: {e}")
 
227
  generator = torch.Generator(device=device).manual_seed(seed)
228
  negative_prompt = "worst quality, low quality, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry"
229
 
 
230
  width, height = update_dimensions_on_upload(pil_images[0])
231
 
232
  try:
233
  progress(0.4, desc="Generating Image...")
234
 
 
235
  result_image = pipe(
236
  image=pil_images,
237
  prompt=prompt,
 
243
  true_cfg_scale=guidance_scale,
244
  ).images[0]
245
 
246
+ # --- Save Image for Download ---
247
+ run_id = str(uuid.uuid4())
248
+ output_image_path = os.path.join(TMP_DIR, f"{run_id}_output.png")
249
+ result_image.save(output_image_path)
250
+
251
  # --- Rerun Visualization Logic ---
252
  progress(0.9, desc="Preparing Rerun Visualization...")
253
 
 
 
254
  # Handle different Rerun SDK versions
255
  rec = None
256
  if hasattr(rr, "new_recording"):
 
261
  rr.init("Qwen-Image-Edit", recording_id=run_id, spawn=False)
262
  rec = rr
263
 
264
+ # Log inputs
265
  for i, img in enumerate(pil_images):
266
  rec.log(f"images/input_{i}", rr.Image(np.array(img)))
267
 
 
272
  rrd_path = os.path.join(TMP_DIR, f"{run_id}.rrd")
273
  rec.save(rrd_path)
274
 
275
+ return rrd_path, seed, gr.update(value=output_image_path, visible=True)
276
 
277
  except Exception as e:
278
  raise e
 
282
 
283
  @spaces.GPU
284
  def infer_example(images, prompt, lora_adapter):
 
285
  if not images:
286
+ return None, 0, gr.update(visible=False)
287
 
 
288
  if isinstance(images, str):
289
  images = [images]
290
 
291
+ result_rrd, seed, img_path = infer(
 
292
  images=images,
293
  prompt=prompt,
294
  lora_adapter=lora_adapter,
 
297
  guidance_scale=1.0,
298
  steps=4
299
  )
300
+ return result_rrd, seed, img_path
301
 
302
  css="""
303
  #col-container {
 
314
 
315
  with gr.Row(equal_height=True):
316
  with gr.Column():
 
317
  images = gr.Gallery(
318
  label="Upload Images",
319
  type="filepath",
 
334
  with gr.Column():
335
  rerun_output = Rerun(
336
  label="Rerun Visualization",
337
+ height=353
338
  )
339
 
340
  with gr.Row():
 
343
  choices=list(ADAPTER_SPECS.keys()),
344
  value="Photo-to-Anime"
345
  )
346
+
347
  with gr.Accordion("Advanced Settings", open=False, visible=False):
348
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
349
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
350
  guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0)
351
  steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=4)
352
+
353
+
354
+ with gr.Accordion("Download The Edited Image File", open=False, visible=True):
355
+ download_button = gr.DownloadButton(
356
+ label="Download Image",
357
+ visible=True
358
+ )
359
 
 
360
  gr.Examples(
361
  examples=[
362
  [["examples/B.jpg"], "Transform into anime.", "Photo-to-Anime"],
 
364
  [["examples/A.jpeg"], "Rotate the camera 45 degrees to the right.", "Multiple-Angles"],
365
  ],
366
  inputs=[images, prompt, lora_adapter],
367
+ outputs=[rerun_output, seed, download_button],
368
  fn=infer_example,
369
  cache_examples=False,
370
  label="Examples"
 
375
  run_button.click(
376
  fn=infer,
377
  inputs=[images, prompt, lora_adapter, seed, randomize_seed, guidance_scale, steps],
378
+ outputs=[rerun_output, seed, download_button]
379
  )
380
 
381
  if __name__ == "__main__":