Georg committed on
Commit
a2e4f10
·
1 Parent(s): a7b86b6

Prepare job build context

Browse files
Files changed (1) hide show
  1. app.py +56 -3
app.py CHANGED
@@ -352,7 +352,8 @@ def gradio_estimate(
352
  fy: float,
353
  cx: float,
354
  cy: float,
355
- mask_method: str
 
356
  ):
357
  """Gradio wrapper for pose estimation."""
358
  try:
@@ -409,6 +410,33 @@ def gradio_estimate(
409
  logger.info("Otsu mask coverage %.1f%%", mask_percentage)
410
  if fallback_full_image:
411
  logger.warning("Otsu mask fallback to full image due to unrealistic coverage")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
 
413
  # Estimate pose
414
  result = pose_estimator.estimate_pose(
@@ -593,11 +621,17 @@ with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as demo
593
  )
594
 
595
  est_mask_method = gr.Radio(
596
- choices=["SlimSAM", "Otsu"],
597
  value="SlimSAM",
598
  label="Mask Method"
599
  )
600
 
 
 
 
 
 
 
601
  est_fx = gr.Number(label="fx (focal length x)", value=193.13708498984758, visible=False)
602
  est_fy = gr.Number(label="fy (focal length y)", value=193.13708498984758, visible=False)
603
  est_cx = gr.Number(label="cx (principal point x)", value=120.0, visible=False)
@@ -614,9 +648,28 @@ with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as demo
614
  )
615
  est_viz = gr.Image(label="Query Image")
616
 
 
 
 
 
 
 
 
 
 
617
  est_button.click(
618
  fn=gradio_estimate,
619
- inputs=[est_object_id, est_query_image, est_depth_image, est_fx, est_fy, est_cx, est_cy, est_mask_method],
 
 
 
 
 
 
 
 
 
 
620
  outputs=[est_output, est_viz, est_mask]
621
  )
622
 
 
352
  fy: float,
353
  cx: float,
354
  cy: float,
355
+ mask_method: str,
356
+ mask_editor_data
357
  ):
358
  """Gradio wrapper for pose estimation."""
359
  try:
 
410
  logger.info("Otsu mask coverage %.1f%%", mask_percentage)
411
  if fallback_full_image:
412
  logger.warning("Otsu mask fallback to full image due to unrealistic coverage")
413
+ elif mask_method == "From editor":
414
+ editor_mask = None
415
+ if isinstance(mask_editor_data, dict):
416
+ layers = mask_editor_data.get("layers")
417
+ if isinstance(layers, list) and layers:
418
+ editor_mask = layers[-1]
419
+ else:
420
+ editor_mask = mask_editor_data.get("composite")
421
+ else:
422
+ editor_mask = mask_editor_data
423
+
424
+ if editor_mask is None:
425
+ return "Error: No editor mask provided", query_image, None
426
+
427
+ editor_mask = np.array(editor_mask)
428
+ if editor_mask.ndim == 3 and editor_mask.shape[2] >= 4:
429
+ alpha = editor_mask[:, :, 3]
430
+ mask = (alpha > 0).astype(np.uint8) * 255
431
+ elif editor_mask.ndim == 3:
432
+ gray = cv2.cvtColor(editor_mask, cv2.COLOR_RGB2GRAY)
433
+ mask = (gray > 0).astype(np.uint8) * 255
434
+ elif editor_mask.ndim == 2:
435
+ mask = (editor_mask > 0).astype(np.uint8) * 255
436
+ else:
437
+ return "Error: Unsupported editor mask format", query_image, None
438
+
439
+ debug_mask = mask
440
 
441
  # Estimate pose
442
  result = pose_estimator.estimate_pose(
 
621
  )
622
 
623
  est_mask_method = gr.Radio(
624
+ choices=["SlimSAM", "Otsu", "From editor"],
625
  value="SlimSAM",
626
  label="Mask Method"
627
  )
628
 
629
+ est_mask_editor = gr.ImageEditor(
630
+ label="Mask Editor (paint mask)",
631
+ type="numpy",
632
+ visible=False
633
+ )
634
+
635
  est_fx = gr.Number(label="fx (focal length x)", value=193.13708498984758, visible=False)
636
  est_fy = gr.Number(label="fy (focal length y)", value=193.13708498984758, visible=False)
637
  est_cx = gr.Number(label="cx (principal point x)", value=120.0, visible=False)
 
648
  )
649
  est_viz = gr.Image(label="Query Image")
650
 
651
+ def _toggle_editor(method: str):
652
+ return gr.update(visible=method == "From editor")
653
+
654
+ est_mask_method.change(
655
+ fn=_toggle_editor,
656
+ inputs=est_mask_method,
657
+ outputs=est_mask_editor
658
+ )
659
+
660
  est_button.click(
661
  fn=gradio_estimate,
662
+ inputs=[
663
+ est_object_id,
664
+ est_query_image,
665
+ est_depth_image,
666
+ est_fx,
667
+ est_fy,
668
+ est_cx,
669
+ est_cy,
670
+ est_mask_method,
671
+ est_mask_editor,
672
+ ],
673
  outputs=[est_output, est_viz, est_mask]
674
  )
675