"""
FoundationPose inference server for Hugging Face Spaces with ZeroGPU.
This version uses pure Gradio for ZeroGPU compatibility.
"""
import base64
import json
import logging
import os
from pathlib import Path
from typing import Dict, List
# Ensure OMP_NUM_THREADS is a valid integer to avoid libgomp warnings
if not os.environ.get("OMP_NUM_THREADS", "").isdigit():
os.environ["OMP_NUM_THREADS"] = "1"
import cv2
import gradio as gr
import numpy as np
import torch
from masks import generate_naive_mask

logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s: %(message)s"
)
logger = logging.getLogger(__name__)
DEFAULT_DATA_DIR = Path("/app/tests/reference/t_shape")
DEFAULT_MESH = DEFAULT_DATA_DIR / "t_shape.obj"
DEFAULT_RGB = DEFAULT_DATA_DIR / "rgb_001.png"
DEFAULT_DEPTH = DEFAULT_DATA_DIR / "depth_001.png"
DEFAULT_REF_IMAGES = [
DEFAULT_DATA_DIR / "rgb_001.png"
]
_slimsam_model = None
_slimsam_processor = None
_slimsam_device = None
def _get_slimsam():
"""Lazy-load SlimSAM to avoid heavy startup cost."""
global _slimsam_model, _slimsam_processor, _slimsam_device
if _slimsam_model is None or _slimsam_processor is None:
from transformers import SamModel, SamProcessor
_slimsam_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_slimsam_model = SamModel.from_pretrained("nielsr/slimsam-50-uniform").to(_slimsam_device)
_slimsam_processor = SamProcessor.from_pretrained("nielsr/slimsam-50-uniform")
logger.info("SlimSAM loaded on %s", _slimsam_device)
return _slimsam_model, _slimsam_processor, _slimsam_device
def _box_from_mask(mask_bool: np.ndarray) -> List[int]:
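    """Return an axis-aligned box [x0, y0, x1, y1] around the True pixels, or the full image if the mask is empty."""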
ys, xs = np.where(mask_bool)
if len(xs) == 0:
return [0, 0, mask_bool.shape[1] - 1, mask_bool.shape[0] - 1]
x0, x1 = int(xs.min()), int(xs.max())
y0, y1 = int(ys.min()), int(ys.max())
return [x0, y0, x1, y1]
def generate_slimsam_mask(rgb_image: np.ndarray, box_prompt: List[int]) -> tuple[np.ndarray, np.ndarray, float]:
"""Generate a SlimSAM mask using a box prompt."""
from PIL import Image
model, processor, device = _get_slimsam()
raw_image = Image.fromarray(rgb_image).convert("RGB")
enc = processor(raw_image, input_boxes=[[box_prompt]], return_tensors="np")
# Keep size tensors on CPU for post-processing
original_sizes = torch.as_tensor(enc["original_sizes"])
reshaped_sizes = torch.as_tensor(enc["reshaped_input_sizes"])
# Move model inputs to device
inputs = {
k: torch.as_tensor(v).to(device)
for k, v in enc.items()
if k not in {"original_sizes", "reshaped_input_sizes"}
}
    # Inference only; skip autograd bookkeeping
    with torch.no_grad():
        outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
original_sizes,
reshaped_sizes,
)[0]
scores = outputs.iou_scores.squeeze().cpu()
best_idx = int(scores.argmax().item())
best_mask = masks[0, best_idx].numpy()
best_score = float(scores[best_idx].item())
mask_bool = best_mask.astype(bool)
debug_mask = (mask_bool.astype(np.uint8) * 255)
return mask_bool, debug_mask, best_score
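# Example usage (hypothetical array and box, in pixel coordinates):
#   mask_bool, mask_vis, score = generate_slimsam_mask(rgb, [40, 30, 200, 180])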
# Always use real FoundationPose model
USE_REAL_MODEL = True
logger.info("Starting in REAL mode with FoundationPose")
class FoundationPoseInference:
"""Wrapper for FoundationPose model inference."""
def __init__(self):
self.model = None
self.device = None
self.initialized = False
self.tracked_objects = {}
self.use_real_model = USE_REAL_MODEL
def initialize_model(self):
"""Initialize the FoundationPose model on GPU."""
if self.initialized:
logger.info("Model already initialized")
return
logger.info("Initializing FoundationPose model...")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {self.device}")
if self.use_real_model:
try:
logger.info("Loading real FoundationPose model...")
from estimator import FoundationPoseEstimator
self.model = FoundationPoseEstimator(
device=str(self.device)
)
if getattr(self.model, "available", True):
logger.info("✓ Real FoundationPose model initialized successfully")
else:
raise RuntimeError("FoundationPose dependencies missing")
except Exception as e:
logger.error(f"Failed to initialize real model: {e}", exc_info=True)
logger.warning("Falling back to placeholder mode")
self.use_real_model = False
self.model = None
        else:
            logger.info("Using placeholder mode (FoundationPose model unavailable)")
            self.model = None
self.initialized = True
logger.info("FoundationPose inference ready")
def register_object(
self,
object_id: str,
reference_images: List[np.ndarray],
camera_intrinsics: Dict = None,
mesh_path: str = None
) -> bool:
"""Register an object for tracking with reference images."""
if not self.initialized:
self.initialize_model()
logger.info(f"Registering object '{object_id}' with {len(reference_images)} reference images")
if self.use_real_model and self.model is not None:
try:
success = self.model.register_object(
object_id=object_id,
reference_images=reference_images,
camera_intrinsics=camera_intrinsics,
mesh_path=mesh_path
)
if success:
self.tracked_objects[object_id] = {
"num_references": len(reference_images),
"camera_intrinsics": camera_intrinsics,
"mesh_path": mesh_path
}
return success
except Exception as e:
logger.error(f"Registration failed: {e}", exc_info=True)
return False
else:
self.tracked_objects[object_id] = {
"num_references": len(reference_images),
"camera_intrinsics": camera_intrinsics,
"mesh_path": mesh_path
}
logger.info(f"✓ Object '{object_id}' registered (placeholder mode)")
return True
def estimate_pose(
self,
object_id: str,
query_image: np.ndarray,
camera_intrinsics: Dict = None,
depth_image: np.ndarray = None,
mask: np.ndarray = None
) -> Dict:
"""Estimate 6D pose of an object in a query image."""
if not self.initialized:
return {"success": False, "error": "Model not initialized"}
if object_id not in self.tracked_objects:
return {"success": False, "error": f"Object '{object_id}' not registered"}
logger.info(f"Estimating pose for object '{object_id}'")
if self.use_real_model and self.model is not None:
try:
pose_result = self.model.estimate_pose(
object_id=object_id,
rgb_image=query_image,
depth_image=depth_image,
mask=mask,
camera_intrinsics=camera_intrinsics
)
if pose_result is None:
return {
"success": False,
"error": "Pose estimation returned None",
"poses": [],
"debug_mask": None
}
# Extract debug mask if present
debug_mask = pose_result.pop("debug_mask", None)
return {
"success": True,
"poses": [pose_result],
"debug_mask": debug_mask
}
except Exception as e:
logger.error(f"Pose estimation error: {e}", exc_info=True)
return {"success": False, "error": str(e), "poses": []}
else:
logger.info("Placeholder mode: returning empty pose result")
return {
"success": True,
"poses": [],
"note": "Placeholder mode - set USE_REAL_MODEL=true for real inference"
}
# Global model instance
pose_estimator = FoundationPoseInference()
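# Typical flow (sketch): register an object once, then estimate per frame, e.g.
#   pose_estimator.register_object("target_cube", [], intrinsics, mesh_path="t_shape.obj")
#   result = pose_estimator.estimate_pose("target_cube", rgb, intrinsics, depth, mask)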
# Gradio wrapper functions
def gradio_initialize_cad(object_id: str, mesh_file, reference_files: List, fx: float, fy: float, cx: float, cy: float):
"""Gradio wrapper for CAD-based object initialization."""
try:
if not mesh_file:
return "Error: No mesh file provided"
# Load reference images (optional for CAD mode)
reference_images = []
if reference_files:
for file in reference_files:
img = cv2.imread(file.name)
if img is None:
continue
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
reference_images.append(img)
# Prepare camera intrinsics
camera_intrinsics = {
"fx": fx,
"fy": fy,
"cx": cx,
"cy": cy
}
# Register object with mesh
success = pose_estimator.register_object(
object_id=object_id,
reference_images=reference_images if reference_images else [],
camera_intrinsics=camera_intrinsics,
mesh_path=mesh_file.name
)
if success:
ref_info = f" and {len(reference_images)} reference images" if reference_images else ""
return f"✓ Object '{object_id}' initialized with CAD model{ref_info}"
else:
return f"✗ Failed to initialize object '{object_id}'"
except Exception as e:
logger.error(f"CAD initialization error: {e}", exc_info=True)
return f"Error: {str(e)}"
def gradio_initialize_model_free(object_id: str, reference_files: List, fx: float, fy: float, cx: float, cy: float):
"""Gradio wrapper for model-free object initialization."""
try:
if not reference_files:
return "Error: No reference images provided"
# Load reference images
reference_images = []
for file in reference_files:
img = cv2.imread(file.name)
if img is None:
continue
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
reference_images.append(img)
if not reference_images:
return "Error: Could not load any reference images"
if len(reference_images) < 8:
return f"Warning: Only {len(reference_images)} images provided. 16-24 recommended for best results."
# Prepare camera intrinsics
camera_intrinsics = {
"fx": fx,
"fy": fy,
"cx": cx,
"cy": cy
}
# Register object without mesh (model-free)
success = pose_estimator.register_object(
object_id=object_id,
reference_images=reference_images,
camera_intrinsics=camera_intrinsics,
mesh_path=None
)
        if success:
            return f"✓ Object '{object_id}' initialized with {len(reference_images)} reference images (model-free mode){warning}"
        else:
            return f"✗ Failed to initialize object '{object_id}'"
except Exception as e:
logger.error(f"Model-free initialization error: {e}", exc_info=True)
return f"Error: {str(e)}"
def gradio_estimate(
object_id: str,
query_image: np.ndarray,
depth_image: np.ndarray,
fx: float,
fy: float,
cx: float,
cy: float,
mask_method: str,
mask_editor_data
):
"""Gradio wrapper for pose estimation."""
try:
if query_image is None:
return "Error: No query image provided", None, None
# Process depth image if provided
depth = None
if depth_image is not None:
# Check if depth needs resizing to match RGB
if depth_image.shape[:2] != query_image.shape[:2]:
logger.warning(f"Depth {depth_image.shape[:2]} and RGB {query_image.shape[:2]} sizes don't match, resizing depth")
depth_image = cv2.resize(depth_image, (query_image.shape[1], query_image.shape[0]), interpolation=cv2.INTER_NEAREST)
# Convert to float32 if needed
if depth_image.dtype == np.uint16:
# Assume 16-bit depth in millimeters
depth = depth_image.astype(np.float32) / 1000.0
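                # e.g., a raw 16-bit reading of 1500 becomes 1.5 m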
logger.info(f"Converted 16-bit depth to float32, range: [{depth.min():.3f}, {depth.max():.3f}]m")
elif depth_image.dtype == np.uint8:
# 8-bit depth (encoded), need to decode based on format
# For now, assume linear scaling to reasonable depth range
depth = depth_image.astype(np.float32) / 255.0 * 5.0 # Map to 0-5m
logger.info(f"Converted 8-bit depth to float32, range: [{depth.min():.3f}, {depth.max():.3f}]m")
else:
# Already float, use as-is
depth = depth_image.astype(np.float32)
logger.info(f"Using provided depth (dtype={depth_image.dtype}), range: [{depth.min():.3f}, {depth.max():.3f}]m")
# Handle color depth images (H, W, 3) - take first channel
if len(depth.shape) == 3:
logger.warning("Depth image has 3 channels, using first channel")
depth = depth[:, :, 0]
# Prepare camera intrinsics
camera_intrinsics = {
"fx": fx,
"fy": fy,
"cx": cx,
"cy": cy
}
# Choose mask method
mask = None
debug_mask = None
if mask_method == "SlimSAM":
# Use Otsu mask as a box prompt to guide SlimSAM
naive_mask, _, _, _ = generate_naive_mask(query_image)
box_prompt = _box_from_mask(naive_mask)
mask, debug_mask, score = generate_slimsam_mask(query_image, box_prompt)
logger.info("SlimSAM mask generated (score=%.3f, box=%s)", score, box_prompt)
elif mask_method == "Otsu":
mask, debug_mask, mask_percentage, fallback_full_image = generate_naive_mask(query_image)
logger.info("Otsu mask coverage %.1f%%", mask_percentage)
if fallback_full_image:
logger.warning("Otsu mask fallback to full image due to unrealistic coverage")
elif mask_method == "From editor":
editor_mask = None
if isinstance(mask_editor_data, dict):
layers = mask_editor_data.get("layers")
if isinstance(layers, list) and layers:
editor_mask = layers[-1]
else:
editor_mask = mask_editor_data.get("composite")
else:
editor_mask = mask_editor_data
if editor_mask is None:
return "Error: No editor mask provided", query_image, None
editor_mask = np.array(editor_mask)
if editor_mask.ndim == 3 and editor_mask.shape[2] >= 4:
alpha = editor_mask[:, :, 3]
mask = (alpha > 0).astype(np.uint8) * 255
elif editor_mask.ndim == 3:
gray = cv2.cvtColor(editor_mask, cv2.COLOR_RGB2GRAY)
mask = (gray > 0).astype(np.uint8) * 255
elif editor_mask.ndim == 2:
mask = (editor_mask > 0).astype(np.uint8) * 255
else:
return "Error: Unsupported editor mask format", query_image, None
debug_mask = mask
# Estimate pose
result = pose_estimator.estimate_pose(
object_id=object_id,
query_image=query_image,
depth_image=depth,
camera_intrinsics=camera_intrinsics,
mask=mask
)
if not result.get("success"):
error = result.get("error", "Unknown error")
# Still show mask output even on failure
mask_vis = None
if debug_mask is not None:
mask_vis = query_image.copy()
mask_overlay = np.zeros_like(query_image)
mask_overlay[:, :, 1] = debug_mask
mask_vis = cv2.addWeighted(mask_vis, 0.7, mask_overlay, 0.3, 0)
return f"✗ Estimation failed: {error}", query_image, mask_vis
poses = result.get("poses", [])
note = result.get("note", "")
if debug_mask is None:
debug_mask = result.get("debug_mask", None)
# Create mask visualization
mask_vis = None
if debug_mask is not None:
# Create an RGB visualization of the mask overlaid on the original image
mask_vis = query_image.copy()
# Create green overlay where mask is active
mask_overlay = np.zeros_like(query_image)
mask_overlay[:, :, 1] = debug_mask # Green channel
# Blend with original image
mask_vis = cv2.addWeighted(mask_vis, 0.7, mask_overlay, 0.3, 0)
# Format output
if not poses:
output = "⚠ No poses detected\n"
if note:
output += f"\nNote: {note}"
if debug_mask is not None:
mask_percentage = (debug_mask > 0).sum() / debug_mask.size * 100
output += f"\n\nMask Coverage: {mask_percentage:.1f}% of image"
return output, query_image, mask_vis
output = f"✓ Detected {len(poses)} pose(s):\n\n"
for i, pose in enumerate(poses):
output += f"Pose {i + 1}:\n"
output += f" Object ID: {pose.get('object_id', 'unknown')}\n"
if 'position' in pose:
pos = pose['position']
output += f" Position:\n"
output += f" x: {pos.get('x', 0):.4f} m\n"
output += f" y: {pos.get('y', 0):.4f} m\n"
output += f" z: {pos.get('z', 0):.4f} m\n"
if 'orientation' in pose:
ori = pose['orientation']
output += f" Orientation (quaternion):\n"
output += f" w: {ori.get('w', 0):.4f}\n"
output += f" x: {ori.get('x', 0):.4f}\n"
output += f" y: {ori.get('y', 0):.4f}\n"
output += f" z: {ori.get('z', 0):.4f}\n"
if 'confidence' in pose:
output += f" Confidence: {pose['confidence']:.2%}\n"
output += "\n"
if debug_mask is not None:
mask_percentage = (debug_mask > 0).sum() / debug_mask.size * 100
output += f"\nMask Coverage: {mask_percentage:.1f}% of image"
return output, query_image, mask_vis
except Exception as e:
logger.error(f"Gradio estimation error: {e}", exc_info=True)
return f"Error: {str(e)}", None, None
# Gradio UI
with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🎯 FoundationPose 6D Object Pose Estimation")
gr.Markdown("Project page: https://nvlabs.github.io/FoundationPose/")
mode_indicator = gr.Markdown(
"**Mode:** 🟢 Real FoundationPose with GPU",
elem_id="mode"
)
with gr.Tabs():
# Tab 1: Initialize Object
with gr.Tab("Initialize Object"):
gr.Markdown("""
Choose the initialization mode based on whether you have a 3D CAD model of your object.
""")
with gr.Tabs():
# Sub-tab 1.1: CAD-Based Init
with gr.Tab("CAD-Based (Model-Based)"):
gr.Markdown("""
**Model-Based Mode**: Use this if you have a 3D mesh/CAD model (.obj, .stl, .ply).
- Upload your 3D mesh file
- Optionally upload reference images for better initialization
- More accurate and robust
""")
with gr.Row():
with gr.Column():
cad_object_id = gr.Textbox(
label="Object ID",
placeholder="e.g., target_cube",
value="target_cube"
)
cad_mesh_file = gr.File(
label="3D Mesh File (.obj, .stl, .ply)",
file_count="single",
file_types=[".obj", ".stl", ".ply", ".mesh"],
value=str(DEFAULT_MESH) if DEFAULT_MESH.exists() else None
)
cad_ref_files = gr.File(
label="Reference Images (Optional)",
file_count="multiple",
file_types=["image"],
value=[str(p) for p in DEFAULT_REF_IMAGES if p.exists()]
)
gr.Markdown("### Camera Intrinsics")
with gr.Row():
cad_fx = gr.Number(label="fx", value=193.13708498984758)
cad_fy = gr.Number(label="fy", value=193.13708498984758)
with gr.Row():
cad_cx = gr.Number(label="cx", value=120.0)
cad_cy = gr.Number(label="cy", value=80.0)
cad_init_button = gr.Button("Initialize with CAD", variant="primary")
with gr.Column():
cad_init_output = gr.Textbox(
label="Initialization Result",
lines=5,
interactive=False
)
cad_init_button.click(
fn=gradio_initialize_cad,
inputs=[cad_object_id, cad_mesh_file, cad_ref_files, cad_fx, cad_fy, cad_cx, cad_cy],
outputs=cad_init_output
)
# Sub-tab 1.2: Model-Free Init (disabled)
# Tab 2: Estimate Pose
with gr.Tab("Estimate Pose"):
gr.Markdown("""
Upload a query image containing the initialized object.
The model will estimate the 6D pose (position + orientation).
""")
gr.Markdown("""
**Troubleshooting**
    - **Camera intrinsics**: make sure `fx`, `fy`, `cx`, `cy` match the query RGB resolution. If you resize the RGB image, scale the intrinsics by the same factor (see the sketch below).
- **Depth values**: verify units. 16-bit PNG is usually millimeters (converted to meters here). If values look clipped or too small/large, confirm the sensor's depth scale.
- **Image scales**: RGB and depth must be the same size. If they differ, depth is resized to match RGB (nearest-neighbor). Prefer exporting aligned pairs from the same stream.
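
    For example, halving a 640×480 image to 320×240 means scaling all four intrinsics by 0.5; a minimal sketch:
    ```python
    scale = 320 / 640  # new_width / old_width (assumes a uniform resize)
    fx, fy = fx * scale, fy * scale
    cx, cy = cx * scale, cy * scale
    ```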
""")
with gr.Row():
with gr.Column():
est_object_id = gr.Textbox(
label="Object ID",
placeholder="e.g., target_cube",
value="target_cube"
)
est_query_image = gr.Image(
label="Query Image (RGB)",
type="numpy",
value=str(DEFAULT_RGB) if DEFAULT_RGB.exists() else None
)
est_depth_image = gr.Image(
label="Depth Image (Optional, 16-bit PNG)",
type="numpy",
value=str(DEFAULT_DEPTH) if DEFAULT_DEPTH.exists() else None
)
est_mask_method = gr.Radio(
choices=["SlimSAM", "Otsu", "From editor"],
value="SlimSAM",
label="Mask Method"
)
est_mask_editor = gr.ImageEditor(
label="Mask Editor (paint mask)",
type="numpy",
visible=False
)
est_fx = gr.Number(label="fx (focal length x)", value=193.13708498984758, visible=False)
est_fy = gr.Number(label="fy (focal length y)", value=193.13708498984758, visible=False)
est_cx = gr.Number(label="cx (principal point x)", value=120.0, visible=False)
est_cy = gr.Number(label="cy (principal point y)", value=80.0, visible=False)
est_button = gr.Button("Estimate Pose", variant="primary")
with gr.Column():
est_mask = gr.Image(label="Auto-Generated Mask (green overlay)")
est_output = gr.Textbox(
label="Pose Estimation Result",
lines=15,
interactive=False
)
est_viz = gr.Image(label="Query Image")
def _toggle_editor(method: str):
return gr.update(visible=method == "From editor")
est_mask_method.change(
fn=_toggle_editor,
inputs=est_mask_method,
outputs=est_mask_editor
)
est_button.click(
fn=gradio_estimate,
inputs=[
est_object_id,
est_query_image,
est_depth_image,
est_fx,
est_fy,
est_cx,
est_cy,
est_mask_method,
est_mask_editor,
],
outputs=[est_output, est_viz, est_mask]
)
gr.Markdown("""
---
## API Documentation
This Space uses Gradio's built-in API. For programmatic access, use the `gradio_client` library:
```python
from gradio_client import Client, handle_file
client = Client("https://gpue-foundationpose.hf.space")
# Initialize an object (CAD-based). Endpoint names default to the
# wrapper function names defined in this file.
result = client.predict(
    object_id="target_cube",
    mesh_file=handle_file("t_shape.obj"),
    reference_files=[handle_file("rgb_001.png")],
    fx=193.13708498984758, fy=193.13708498984758, cx=120.0, cy=80.0,
    api_name="/gradio_initialize_cad"
)
# Estimate pose
result = client.predict(
    object_id="target_cube",
    query_image=handle_file("rgb_001.png"),
    depth_image=handle_file("depth_001.png"),
    fx=193.13708498984758, fy=193.13708498984758, cx=120.0, cy=80.0,
    mask_method="SlimSAM",
    mask_editor_data=None,
    api_name="/gradio_estimate"
)
```
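Tip: `client.view_api()` lists the exact endpoint names and parameter types exposed by this Space.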
See [client.py](https://huggingface.co/spaces/gpue/foundationpose/blob/main/client.py) for a complete example.
""")
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)