Georg committed on
Commit
053c7f6
·
1 Parent(s): f7e2564

Optimized Docker build to fix OOM errors

Browse files
Files changed (3) hide show
  1. app.py +90 -4
  2. estimator.py +38 -24
  3. requirements.txt +2 -0
app.py CHANGED
@@ -16,6 +16,59 @@ import gradio as gr
16
  import numpy as np
17
  import torch
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  logging.basicConfig(
20
  level=logging.INFO,
21
  format="[%(asctime)s] %(levelname)s: %(message)s"
@@ -262,7 +315,16 @@ def gradio_initialize_model_free(object_id: str, reference_files: List, fx: floa
262
  return f"Error: {str(e)}"
263
 
264
 
265
- def gradio_estimate(object_id: str, query_image: np.ndarray, depth_image: np.ndarray, fx: float, fy: float, cx: float, cy: float):
 
 
 
 
 
 
 
 
 
266
  """Gradio wrapper for pose estimation."""
267
  try:
268
  if query_image is None:
@@ -304,12 +366,28 @@ def gradio_estimate(object_id: str, query_image: np.ndarray, depth_image: np.nda
304
  "cy": cy
305
  }
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  # Estimate pose
308
  result = pose_estimator.estimate_pose(
309
  object_id=object_id,
310
  query_image=query_image,
311
  depth_image=depth,
312
- camera_intrinsics=camera_intrinsics
 
313
  )
314
 
315
  if not result.get("success"):
@@ -318,7 +396,8 @@ def gradio_estimate(object_id: str, query_image: np.ndarray, depth_image: np.nda
318
 
319
  poses = result.get("poses", [])
320
  note = result.get("note", "")
321
- debug_mask = result.get("debug_mask", None)
 
322
 
323
  # Create mask visualization
324
  mask_vis = None
@@ -524,6 +603,12 @@ with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as demo
524
  type="numpy"
525
  )
526
 
 
 
 
 
 
 
527
  gr.Markdown("### Camera Intrinsics")
528
  with gr.Row():
529
  est_fx = gr.Number(label="fx (focal length x)", value=500.0)
@@ -545,7 +630,7 @@ with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as demo
545
 
546
  est_button.click(
547
  fn=gradio_estimate,
548
- inputs=[est_object_id, est_query_image, est_depth_image, est_fx, est_fy, est_cx, est_cy],
549
  outputs=[est_output, est_viz, est_mask]
550
  )
551
 
@@ -573,6 +658,7 @@ with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as demo
573
  object_id="target_cube",
574
  query_image=image,
575
  fx=500.0, fy=500.0, cx=320.0, cy=240.0,
 
576
  api_name="/gradio_estimate"
577
  )
578
  ```
 
16
  import numpy as np
17
  import torch
18
 
19
+ from estimator import generate_naive_mask
20
+
21
# Module-level cache so SlimSAM is loaded at most once per process.
_slimsam_model = None
_slimsam_processor = None
_slimsam_device = None


def _get_slimsam():
    """Return the cached (model, processor, device) triple, loading SlimSAM on first use.

    Lazy loading keeps app startup cheap; the model's memory cost is only
    paid when a SlimSAM mask is actually requested.
    """
    global _slimsam_model, _slimsam_processor, _slimsam_device

    # Fast path: everything already cached.
    if _slimsam_model is not None and _slimsam_processor is not None:
        return _slimsam_model, _slimsam_processor, _slimsam_device

    from transformers import SamModel, SamProcessor

    checkpoint = "nielsr/slimsam-50-uniform"
    _slimsam_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _slimsam_model = SamModel.from_pretrained(checkpoint).to(_slimsam_device)
    _slimsam_processor = SamProcessor.from_pretrained(checkpoint)
    logger.info("SlimSAM loaded on %s", _slimsam_device)

    return _slimsam_model, _slimsam_processor, _slimsam_device
38
+
39
+
40
+ def _box_from_mask(mask_bool: np.ndarray) -> List[int]:
41
+ ys, xs = np.where(mask_bool)
42
+ if len(xs) == 0:
43
+ return [0, 0, mask_bool.shape[1] - 1, mask_bool.shape[0] - 1]
44
+ x0, x1 = int(xs.min()), int(xs.max())
45
+ y0, y1 = int(ys.min()), int(ys.max())
46
+ return [x0, y0, x1, y1]
47
+
48
+
49
def generate_slimsam_mask(rgb_image: np.ndarray, box_prompt: List[int]) -> tuple[np.ndarray, np.ndarray, float]:
    """Generate a SlimSAM mask using a box prompt.

    Args:
        rgb_image: RGB image as an (H, W, 3) uint8 array (assumed — TODO confirm caller).
        box_prompt: [x0, y0, x1, y1] pixel box guiding the segmentation.

    Returns:
        mask_bool: Boolean mask (H, W) of the best-scoring prediction.
        debug_mask: uint8 (0/255) mask for visualization.
        best_score: IoU score of the selected mask.
    """
    from PIL import Image

    model, processor, device = _get_slimsam()
    raw_image = Image.fromarray(rgb_image).convert("RGB")
    inputs = processor(raw_image, input_boxes=[[box_prompt]], return_tensors="pt").to(device)

    # Inference only: disable autograd so no computation graph is kept.
    # Without this every forward pass holds activation memory for a backward
    # that never happens — a real GPU-memory cost for an app fighting OOM.
    with torch.no_grad():
        outputs = model(**inputs)

    # Map low-res predicted masks back to the original image resolution.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0]

    # Keep only the prediction with the highest predicted IoU.
    scores = outputs.iou_scores.squeeze().cpu()
    best_idx = int(scores.argmax().item())
    best_mask = masks[0, best_idx].numpy()
    best_score = float(scores[best_idx].item())

    mask_bool = best_mask.astype(bool)
    debug_mask = mask_bool.astype(np.uint8) * 255
    return mask_bool, debug_mask, best_score
71
+
72
  logging.basicConfig(
73
  level=logging.INFO,
74
  format="[%(asctime)s] %(levelname)s: %(message)s"
 
315
  return f"Error: {str(e)}"
316
 
317
 
318
+ def gradio_estimate(
319
+ object_id: str,
320
+ query_image: np.ndarray,
321
+ depth_image: np.ndarray,
322
+ fx: float,
323
+ fy: float,
324
+ cx: float,
325
+ cy: float,
326
+ mask_method: str
327
+ ):
328
  """Gradio wrapper for pose estimation."""
329
  try:
330
  if query_image is None:
 
366
  "cy": cy
367
  }
368
 
369
+ # Choose mask method
370
+ mask = None
371
+ debug_mask = None
372
+ if mask_method == "SlimSAM":
373
+ # Use Otsu mask as a box prompt to guide SlimSAM
374
+ naive_mask, _, _, _ = generate_naive_mask(query_image)
375
+ box_prompt = _box_from_mask(naive_mask)
376
+ mask, debug_mask, score = generate_slimsam_mask(query_image, box_prompt)
377
+ logger.info("SlimSAM mask generated (score=%.3f, box=%s)", score, box_prompt)
378
+ elif mask_method == "Otsu":
379
+ mask, debug_mask, mask_percentage, fallback_full_image = generate_naive_mask(query_image)
380
+ logger.info("Otsu mask coverage %.1f%%", mask_percentage)
381
+ if fallback_full_image:
382
+ logger.warning("Otsu mask fallback to full image due to unrealistic coverage")
383
+
384
  # Estimate pose
385
  result = pose_estimator.estimate_pose(
386
  object_id=object_id,
387
  query_image=query_image,
388
  depth_image=depth,
389
+ camera_intrinsics=camera_intrinsics,
390
+ mask=mask
391
  )
392
 
393
  if not result.get("success"):
 
396
 
397
  poses = result.get("poses", [])
398
  note = result.get("note", "")
399
+ if debug_mask is None:
400
+ debug_mask = result.get("debug_mask", None)
401
 
402
  # Create mask visualization
403
  mask_vis = None
 
603
  type="numpy"
604
  )
605
 
606
+ est_mask_method = gr.Radio(
607
+ choices=["SlimSAM", "Otsu"],
608
+ value="SlimSAM",
609
+ label="Mask Method"
610
+ )
611
+
612
  gr.Markdown("### Camera Intrinsics")
613
  with gr.Row():
614
  est_fx = gr.Number(label="fx (focal length x)", value=500.0)
 
630
 
631
  est_button.click(
632
  fn=gradio_estimate,
633
+ inputs=[est_object_id, est_query_image, est_depth_image, est_fx, est_fy, est_cx, est_cy, est_mask_method],
634
  outputs=[est_output, est_viz, est_mask]
635
  )
636
 
 
658
  object_id="target_cube",
659
  query_image=image,
660
  fx=500.0, fy=500.0, cx=320.0, cy=240.0,
661
+ mask_method="SlimSAM",
662
  api_name="/gradio_estimate"
663
  )
664
  ```
estimator.py CHANGED
@@ -33,6 +33,39 @@ except ImportError as e:
33
  FOUNDATIONPOSE_AVAILABLE = False
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  class FoundationPoseEstimator:
37
  """Wrapper for FoundationPose model."""
38
 
@@ -206,31 +239,12 @@ class FoundationPoseEstimator:
206
  # Use automatic foreground segmentation based on brightness
207
  # This works well for light objects on dark backgrounds
208
  logger.info("Generating automatic object mask from image")
209
- gray = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2GRAY)
210
-
211
- # Use Otsu's thresholding for automatic threshold selection
212
- _, mask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
213
-
214
- # Clean up mask with morphological operations
215
- kernel = np.ones((5, 5), np.uint8)
216
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) # Fill holes
217
- mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel) # Remove noise
218
-
219
- # Store visualization version (uint8) before converting to boolean
220
- debug_mask = mask.copy()
221
-
222
- # Convert to boolean
223
- mask = mask.astype(bool)
224
-
225
- # Log mask statistics
226
- mask_percentage = (mask.sum() / mask.size) * 100
227
  logger.info(f"Auto-generated mask covers {mask_percentage:.1f}% of image")
228
-
229
- # If mask is too large or too small, fall back to full image
230
- if mask_percentage < 1 or mask_percentage > 90:
231
- logger.warning(f"Mask coverage ({mask_percentage:.1f}%) seems unrealistic, using full image")
232
- mask = np.ones((rgb_image.shape[0], rgb_image.shape[1]), dtype=bool)
233
- debug_mask = np.ones((rgb_image.shape[0], rgb_image.shape[1]), dtype=np.uint8) * 255
234
 
235
  mask_was_generated = True
236
 
 
33
  FOUNDATIONPOSE_AVAILABLE = False
34
 
35
 
36
def generate_naive_mask(
    rgb_image: np.ndarray,
    min_percentage: float = 1.0,
    max_percentage: float = 90.0
) -> tuple[np.ndarray, np.ndarray, float, bool]:
    """Generate a naive foreground mask using brightness + Otsu thresholding.

    Args:
        rgb_image: RGB image as an (H, W, 3) uint8 array.
        min_percentage: Lowest plausible foreground coverage, in percent.
        max_percentage: Highest plausible foreground coverage, in percent.

    Returns:
        mask_bool: Boolean mask (H, W)
        debug_mask: uint8 mask for visualization (H, W)
        mask_percentage: % of pixels active in mask_bool
        fallback_full_image: True if the mask was replaced by full-image mask
    """
    # Otsu picks the binarization threshold automatically from the histogram.
    gray = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2GRAY)
    _, raw = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Morphological close fills holes, then open removes speckle noise.
    kernel = np.ones((5, 5), np.uint8)
    cleaned = cv2.morphologyEx(raw, cv2.MORPH_CLOSE, kernel)
    cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, kernel)

    debug_mask = cleaned.copy()
    mask_bool = cleaned.astype(bool)
    mask_percentage = (mask_bool.sum() / mask_bool.size) * 100

    # Implausible coverage usually means thresholding failed; hand back a
    # full-image mask rather than feeding garbage to the pose estimator.
    fallback_full_image = not (min_percentage <= mask_percentage <= max_percentage)
    if fallback_full_image:
        h, w = rgb_image.shape[:2]
        mask_bool = np.ones((h, w), dtype=bool)
        debug_mask = np.full((h, w), 255, dtype=np.uint8)

    return mask_bool, debug_mask, mask_percentage, fallback_full_image
67
+
68
+
69
  class FoundationPoseEstimator:
70
  """Wrapper for FoundationPose model."""
71
 
 
239
  # Use automatic foreground segmentation based on brightness
240
  # This works well for light objects on dark backgrounds
241
  logger.info("Generating automatic object mask from image")
242
+ mask, debug_mask, mask_percentage, fallback_full_image = generate_naive_mask(rgb_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  logger.info(f"Auto-generated mask covers {mask_percentage:.1f}% of image")
244
+ if fallback_full_image:
245
+ logger.warning(
246
+ f"Mask coverage ({mask_percentage:.1f}%) seems unrealistic, using full image"
247
+ )
 
 
248
 
249
  mask_was_generated = True
250
 
requirements.txt CHANGED
@@ -4,6 +4,8 @@ numpy>=1.24.0
4
  opencv-python-headless>=4.8.0 # Headless version saves ~400MB
5
  Pillow>=10.0.0
6
  huggingface-hub>=0.20.0
 
 
7
 
8
  # Note: torch and torchvision are installed separately with CUDA support
9
  # Note: FoundationPose C++ extensions built at runtime
 
4
  opencv-python-headless>=4.8.0 # Headless version saves ~400MB
5
  Pillow>=10.0.0
6
  huggingface-hub>=0.20.0
7
+ matplotlib>=3.8.0
8
+ transformers>=4.38.0
9
 
10
  # Note: torch and torchvision are installed separately with CUDA support
11
  # Note: FoundationPose C++ extensions built at runtime