Spaces:
Sleeping
Sleeping
Georg committed on
Commit ·
053c7f6
1
Parent(s): f7e2564
Optimized Docker build to fix OOM errors
Browse files- app.py +90 -4
- estimator.py +38 -24
- requirements.txt +2 -0
app.py
CHANGED
|
@@ -16,6 +16,59 @@ import gradio as gr
|
|
| 16 |
import numpy as np
|
| 17 |
import torch
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
logging.basicConfig(
|
| 20 |
level=logging.INFO,
|
| 21 |
format="[%(asctime)s] %(levelname)s: %(message)s"
|
|
@@ -262,7 +315,16 @@ def gradio_initialize_model_free(object_id: str, reference_files: List, fx: floa
|
|
| 262 |
return f"Error: {str(e)}"
|
| 263 |
|
| 264 |
|
| 265 |
-
def gradio_estimate(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
"""Gradio wrapper for pose estimation."""
|
| 267 |
try:
|
| 268 |
if query_image is None:
|
|
@@ -304,12 +366,28 @@ def gradio_estimate(object_id: str, query_image: np.ndarray, depth_image: np.nda
|
|
| 304 |
"cy": cy
|
| 305 |
}
|
| 306 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
# Estimate pose
|
| 308 |
result = pose_estimator.estimate_pose(
|
| 309 |
object_id=object_id,
|
| 310 |
query_image=query_image,
|
| 311 |
depth_image=depth,
|
| 312 |
-
camera_intrinsics=camera_intrinsics
|
|
|
|
| 313 |
)
|
| 314 |
|
| 315 |
if not result.get("success"):
|
|
@@ -318,7 +396,8 @@ def gradio_estimate(object_id: str, query_image: np.ndarray, depth_image: np.nda
|
|
| 318 |
|
| 319 |
poses = result.get("poses", [])
|
| 320 |
note = result.get("note", "")
|
| 321 |
-
|
|
|
|
| 322 |
|
| 323 |
# Create mask visualization
|
| 324 |
mask_vis = None
|
|
@@ -524,6 +603,12 @@ with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as demo
|
|
| 524 |
type="numpy"
|
| 525 |
)
|
| 526 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
gr.Markdown("### Camera Intrinsics")
|
| 528 |
with gr.Row():
|
| 529 |
est_fx = gr.Number(label="fx (focal length x)", value=500.0)
|
|
@@ -545,7 +630,7 @@ with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as demo
|
|
| 545 |
|
| 546 |
est_button.click(
|
| 547 |
fn=gradio_estimate,
|
| 548 |
-
inputs=[est_object_id, est_query_image, est_depth_image, est_fx, est_fy, est_cx, est_cy],
|
| 549 |
outputs=[est_output, est_viz, est_mask]
|
| 550 |
)
|
| 551 |
|
|
@@ -573,6 +658,7 @@ with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as demo
|
|
| 573 |
object_id="target_cube",
|
| 574 |
query_image=image,
|
| 575 |
fx=500.0, fy=500.0, cx=320.0, cy=240.0,
|
|
|
|
| 576 |
api_name="/gradio_estimate"
|
| 577 |
)
|
| 578 |
```
|
|
|
|
| 16 |
import numpy as np
|
| 17 |
import torch
|
| 18 |
|
| 19 |
+
from estimator import generate_naive_mask
|
| 20 |
+
|
| 21 |
+
_slimsam_model = None
|
| 22 |
+
_slimsam_processor = None
|
| 23 |
+
_slimsam_device = None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _get_slimsam():
    """Lazy initializer for the SlimSAM segmentation model.

    Returns the cached ``(model, processor, device)`` triple. The checkpoint
    is downloaded and moved to GPU (when available) only on the first call,
    keeping module import and app startup cheap.
    """
    global _slimsam_model, _slimsam_processor, _slimsam_device

    # Fast path: everything already initialized.
    if _slimsam_model is not None and _slimsam_processor is not None:
        return _slimsam_model, _slimsam_processor, _slimsam_device

    # Deferred import so transformers is only paid for when SlimSAM is used.
    from transformers import SamModel, SamProcessor

    checkpoint = "nielsr/slimsam-50-uniform"
    _slimsam_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _slimsam_model = SamModel.from_pretrained(checkpoint).to(_slimsam_device)
    _slimsam_processor = SamProcessor.from_pretrained(checkpoint)
    logger.info("SlimSAM loaded on %s", _slimsam_device)

    return _slimsam_model, _slimsam_processor, _slimsam_device
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _box_from_mask(mask_bool: np.ndarray) -> List[int]:
|
| 41 |
+
ys, xs = np.where(mask_bool)
|
| 42 |
+
if len(xs) == 0:
|
| 43 |
+
return [0, 0, mask_bool.shape[1] - 1, mask_bool.shape[0] - 1]
|
| 44 |
+
x0, x1 = int(xs.min()), int(xs.max())
|
| 45 |
+
y0, y1 = int(ys.min()), int(ys.max())
|
| 46 |
+
return [x0, y0, x1, y1]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def generate_slimsam_mask(rgb_image: np.ndarray, box_prompt: List[int]) -> tuple[np.ndarray, np.ndarray, float]:
    """Generate a SlimSAM segmentation mask guided by a box prompt.

    Args:
        rgb_image: Query image as an (H, W, 3) uint8 RGB array.
        box_prompt: [x0, y0, x1, y1] pixel-coordinate box hinting at the object.

    Returns:
        mask_bool: Boolean mask (H, W) for the highest-IoU candidate.
        debug_mask: uint8 (0/255) version of the mask for visualization.
        best_score: Predicted IoU score of the selected mask.
    """
    from PIL import Image

    model, processor, device = _get_slimsam()
    raw_image = Image.fromarray(rgb_image).convert("RGB")
    inputs = processor(raw_image, input_boxes=[[box_prompt]], return_tensors="pt").to(device)

    # Inference only: disable autograd so activations are not retained for a
    # backward pass — this reduces peak memory per request (the original code
    # ran the forward pass with gradient tracking enabled).
    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0]

    # SAM emits several candidate masks; keep the one with the best IoU score.
    scores = outputs.iou_scores.squeeze().cpu()
    best_idx = int(scores.argmax().item())
    best_mask = masks[0, best_idx].numpy()
    best_score = float(scores[best_idx].item())

    mask_bool = best_mask.astype(bool)
    debug_mask = mask_bool.astype(np.uint8) * 255
    return mask_bool, debug_mask, best_score
|
| 71 |
+
|
| 72 |
logging.basicConfig(
|
| 73 |
level=logging.INFO,
|
| 74 |
format="[%(asctime)s] %(levelname)s: %(message)s"
|
|
|
|
| 315 |
return f"Error: {str(e)}"
|
| 316 |
|
| 317 |
|
| 318 |
+
def gradio_estimate(
|
| 319 |
+
object_id: str,
|
| 320 |
+
query_image: np.ndarray,
|
| 321 |
+
depth_image: np.ndarray,
|
| 322 |
+
fx: float,
|
| 323 |
+
fy: float,
|
| 324 |
+
cx: float,
|
| 325 |
+
cy: float,
|
| 326 |
+
mask_method: str
|
| 327 |
+
):
|
| 328 |
"""Gradio wrapper for pose estimation."""
|
| 329 |
try:
|
| 330 |
if query_image is None:
|
|
|
|
| 366 |
"cy": cy
|
| 367 |
}
|
| 368 |
|
| 369 |
+
# Choose mask method
|
| 370 |
+
mask = None
|
| 371 |
+
debug_mask = None
|
| 372 |
+
if mask_method == "SlimSAM":
|
| 373 |
+
# Use Otsu mask as a box prompt to guide SlimSAM
|
| 374 |
+
naive_mask, _, _, _ = generate_naive_mask(query_image)
|
| 375 |
+
box_prompt = _box_from_mask(naive_mask)
|
| 376 |
+
mask, debug_mask, score = generate_slimsam_mask(query_image, box_prompt)
|
| 377 |
+
logger.info("SlimSAM mask generated (score=%.3f, box=%s)", score, box_prompt)
|
| 378 |
+
elif mask_method == "Otsu":
|
| 379 |
+
mask, debug_mask, mask_percentage, fallback_full_image = generate_naive_mask(query_image)
|
| 380 |
+
logger.info("Otsu mask coverage %.1f%%", mask_percentage)
|
| 381 |
+
if fallback_full_image:
|
| 382 |
+
logger.warning("Otsu mask fallback to full image due to unrealistic coverage")
|
| 383 |
+
|
| 384 |
# Estimate pose
|
| 385 |
result = pose_estimator.estimate_pose(
|
| 386 |
object_id=object_id,
|
| 387 |
query_image=query_image,
|
| 388 |
depth_image=depth,
|
| 389 |
+
camera_intrinsics=camera_intrinsics,
|
| 390 |
+
mask=mask
|
| 391 |
)
|
| 392 |
|
| 393 |
if not result.get("success"):
|
|
|
|
| 396 |
|
| 397 |
poses = result.get("poses", [])
|
| 398 |
note = result.get("note", "")
|
| 399 |
+
if debug_mask is None:
|
| 400 |
+
debug_mask = result.get("debug_mask", None)
|
| 401 |
|
| 402 |
# Create mask visualization
|
| 403 |
mask_vis = None
|
|
|
|
| 603 |
type="numpy"
|
| 604 |
)
|
| 605 |
|
| 606 |
+
est_mask_method = gr.Radio(
|
| 607 |
+
choices=["SlimSAM", "Otsu"],
|
| 608 |
+
value="SlimSAM",
|
| 609 |
+
label="Mask Method"
|
| 610 |
+
)
|
| 611 |
+
|
| 612 |
gr.Markdown("### Camera Intrinsics")
|
| 613 |
with gr.Row():
|
| 614 |
est_fx = gr.Number(label="fx (focal length x)", value=500.0)
|
|
|
|
| 630 |
|
| 631 |
est_button.click(
|
| 632 |
fn=gradio_estimate,
|
| 633 |
+
inputs=[est_object_id, est_query_image, est_depth_image, est_fx, est_fy, est_cx, est_cy, est_mask_method],
|
| 634 |
outputs=[est_output, est_viz, est_mask]
|
| 635 |
)
|
| 636 |
|
|
|
|
| 658 |
object_id="target_cube",
|
| 659 |
query_image=image,
|
| 660 |
fx=500.0, fy=500.0, cx=320.0, cy=240.0,
|
| 661 |
+
mask_method="SlimSAM",
|
| 662 |
api_name="/gradio_estimate"
|
| 663 |
)
|
| 664 |
```
|
estimator.py
CHANGED
|
@@ -33,6 +33,39 @@ except ImportError as e:
|
|
| 33 |
FOUNDATIONPOSE_AVAILABLE = False
|
| 34 |
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
class FoundationPoseEstimator:
|
| 37 |
"""Wrapper for FoundationPose model."""
|
| 38 |
|
|
@@ -206,31 +239,12 @@ class FoundationPoseEstimator:
|
|
| 206 |
# Use automatic foreground segmentation based on brightness
|
| 207 |
# This works well for light objects on dark backgrounds
|
| 208 |
logger.info("Generating automatic object mask from image")
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
# Use Otsu's thresholding for automatic threshold selection
|
| 212 |
-
_, mask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
| 213 |
-
|
| 214 |
-
# Clean up mask with morphological operations
|
| 215 |
-
kernel = np.ones((5, 5), np.uint8)
|
| 216 |
-
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) # Fill holes
|
| 217 |
-
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel) # Remove noise
|
| 218 |
-
|
| 219 |
-
# Store visualization version (uint8) before converting to boolean
|
| 220 |
-
debug_mask = mask.copy()
|
| 221 |
-
|
| 222 |
-
# Convert to boolean
|
| 223 |
-
mask = mask.astype(bool)
|
| 224 |
-
|
| 225 |
-
# Log mask statistics
|
| 226 |
-
mask_percentage = (mask.sum() / mask.size) * 100
|
| 227 |
logger.info(f"Auto-generated mask covers {mask_percentage:.1f}% of image")
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
mask = np.ones((rgb_image.shape[0], rgb_image.shape[1]), dtype=bool)
|
| 233 |
-
debug_mask = np.ones((rgb_image.shape[0], rgb_image.shape[1]), dtype=np.uint8) * 255
|
| 234 |
|
| 235 |
mask_was_generated = True
|
| 236 |
|
|
|
|
| 33 |
FOUNDATIONPOSE_AVAILABLE = False
|
| 34 |
|
| 35 |
|
| 36 |
+
def generate_naive_mask(
    rgb_image: np.ndarray,
    min_percentage: float = 1.0,
    max_percentage: float = 90.0
) -> tuple[np.ndarray, np.ndarray, float, bool]:
    """Naive brightness-based foreground mask via Otsu thresholding.

    Args:
        rgb_image: (H, W, 3) RGB image.
        min_percentage: Coverage below this triggers the full-image fallback.
        max_percentage: Coverage above this triggers the full-image fallback.

    Returns:
        mask_bool: Boolean foreground mask (H, W).
        debug_mask: uint8 mask for visualization (H, W).
        mask_percentage: Percent of pixels active in the thresholded mask.
        fallback_full_image: True when coverage looked unrealistic and the
            mask was replaced by an all-ones full-image mask.
    """
    gray = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2GRAY)
    _, raw = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Morphological close fills small holes; open removes speckle noise.
    kernel = np.ones((5, 5), np.uint8)
    cleaned = cv2.morphologyEx(raw, cv2.MORPH_CLOSE, kernel)
    cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, kernel)

    debug_mask = cleaned.copy()
    mask_bool = cleaned.astype(bool)
    mask_percentage = (mask_bool.sum() / mask_bool.size) * 100

    # Coverage outside the plausible band means the threshold picked up the
    # whole frame or almost nothing — fall back to a full-image mask.
    fallback_full_image = not (min_percentage <= mask_percentage <= max_percentage)
    if fallback_full_image:
        height, width = rgb_image.shape[0], rgb_image.shape[1]
        mask_bool = np.ones((height, width), dtype=bool)
        debug_mask = np.full((height, width), 255, dtype=np.uint8)

    return mask_bool, debug_mask, mask_percentage, fallback_full_image
|
| 67 |
+
|
| 68 |
+
|
| 69 |
class FoundationPoseEstimator:
|
| 70 |
"""Wrapper for FoundationPose model."""
|
| 71 |
|
|
|
|
| 239 |
# Use automatic foreground segmentation based on brightness
|
| 240 |
# This works well for light objects on dark backgrounds
|
| 241 |
logger.info("Generating automatic object mask from image")
|
| 242 |
+
mask, debug_mask, mask_percentage, fallback_full_image = generate_naive_mask(rgb_image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
logger.info(f"Auto-generated mask covers {mask_percentage:.1f}% of image")
|
| 244 |
+
if fallback_full_image:
|
| 245 |
+
logger.warning(
|
| 246 |
+
f"Mask coverage ({mask_percentage:.1f}%) seems unrealistic, using full image"
|
| 247 |
+
)
|
|
|
|
|
|
|
| 248 |
|
| 249 |
mask_was_generated = True
|
| 250 |
|
requirements.txt
CHANGED
|
@@ -4,6 +4,8 @@ numpy>=1.24.0
|
|
| 4 |
opencv-python-headless>=4.8.0 # Headless version saves ~400MB
|
| 5 |
Pillow>=10.0.0
|
| 6 |
huggingface-hub>=0.20.0
|
|
|
|
|
|
|
| 7 |
|
| 8 |
# Note: torch and torchvision are installed separately with CUDA support
|
| 9 |
# Note: FoundationPose C++ extensions built at runtime
|
|
|
|
| 4 |
opencv-python-headless>=4.8.0 # Headless version saves ~400MB
|
| 5 |
Pillow>=10.0.0
|
| 6 |
huggingface-hub>=0.20.0
|
| 7 |
+
matplotlib>=3.8.0
|
| 8 |
+
transformers>=4.38.0
|
| 9 |
|
| 10 |
# Note: torch and torchvision are installed separately with CUDA support
|
| 11 |
# Note: FoundationPose C++ extensions built at runtime
|