Spaces:
Sleeping
Sleeping
Georg
Claude Sonnet 4.5
committed on
Commit
·
16d53ca
1
Parent(s):
42ce71e
Add depth image support to FoundationPose API
Browse files
Changes:
- Add depth image upload field to Gradio UI (below query image)
- Update gradio_estimate() to accept and process depth images
- Support 16-bit PNG depth (converts mm to meters)
- Handle depth/RGB size mismatches with automatic resizing
- Update test suite to load RGB+depth test images
- Replace old reference images with single RGB/depth pair
Test updates:
- Load specific rgb_001.jpg and depth_001.png files
- Auto-resize depth to match RGB dimensions
- Print depth statistics (shape, dtype, range)
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
- app.py +37 -3
- tests/reference/t_shape/README.md +0 -63
- tests/reference/t_shape/depth_001.png +0 -0
- tests/reference/t_shape/image_001.jpg +0 -0
- tests/reference/t_shape/image_002.jpg +0 -0
- tests/reference/t_shape/image_004.jpg +0 -0
- tests/reference/t_shape/image_005.jpg +0 -0
- tests/reference/t_shape/image_006.jpg +0 -0
- tests/reference/t_shape/image_007.jpg +0 -0
- tests/reference/t_shape/image_008.jpg +0 -0
- tests/reference/t_shape/image_009.jpg +0 -0
- tests/reference/t_shape/image_010.jpg +0 -0
- tests/reference/t_shape/image_011.jpg +0 -0
- tests/reference/t_shape/image_012.jpg +0 -0
- tests/reference/t_shape/image_013.jpg +0 -0
- tests/reference/t_shape/image_014.jpg +0 -0
- tests/reference/t_shape/image_015.jpg +0 -0
- tests/reference/t_shape/{image_003.jpg → rgb_001.jpg} +0 -0
- tests/test_estimator.py +42 -27
app.py
CHANGED
|
@@ -262,12 +262,40 @@ def gradio_initialize_model_free(object_id: str, reference_files: List, fx: floa
|
|
| 262 |
return f"Error: {str(e)}"
|
| 263 |
|
| 264 |
|
| 265 |
-
def gradio_estimate(object_id: str, query_image: np.ndarray, fx: float, fy: float, cx: float, cy: float):
|
| 266 |
"""Gradio wrapper for pose estimation."""
|
| 267 |
try:
|
| 268 |
if query_image is None:
|
| 269 |
return "Error: No query image provided", None, None
|
| 270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
# Prepare camera intrinsics
|
| 272 |
camera_intrinsics = {
|
| 273 |
"fx": fx,
|
|
@@ -280,6 +308,7 @@ def gradio_estimate(object_id: str, query_image: np.ndarray, fx: float, fy: floa
|
|
| 280 |
result = pose_estimator.estimate_pose(
|
| 281 |
object_id=object_id,
|
| 282 |
query_image=query_image,
|
|
|
|
| 283 |
camera_intrinsics=camera_intrinsics
|
| 284 |
)
|
| 285 |
|
|
@@ -486,7 +515,12 @@ with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as demo
|
|
| 486 |
)
|
| 487 |
|
| 488 |
est_query_image = gr.Image(
|
| 489 |
-
label="Query Image",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
type="numpy"
|
| 491 |
)
|
| 492 |
|
|
@@ -511,7 +545,7 @@ with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as demo
|
|
| 511 |
|
| 512 |
est_button.click(
|
| 513 |
fn=gradio_estimate,
|
| 514 |
-
inputs=[est_object_id, est_query_image, est_fx, est_fy, est_cx, est_cy],
|
| 515 |
outputs=[est_output, est_viz, est_mask]
|
| 516 |
)
|
| 517 |
|
|
|
|
| 262 |
return f"Error: {str(e)}"
|
| 263 |
|
| 264 |
|
| 265 |
+
def gradio_estimate(object_id: str, query_image: np.ndarray, depth_image: np.ndarray, fx: float, fy: float, cx: float, cy: float):
|
| 266 |
"""Gradio wrapper for pose estimation."""
|
| 267 |
try:
|
| 268 |
if query_image is None:
|
| 269 |
return "Error: No query image provided", None, None
|
| 270 |
|
| 271 |
+
# Process depth image if provided
|
| 272 |
+
depth = None
|
| 273 |
+
if depth_image is not None:
|
| 274 |
+
# Check if depth needs resizing to match RGB
|
| 275 |
+
if depth_image.shape[:2] != query_image.shape[:2]:
|
| 276 |
+
logger.warning(f"Depth {depth_image.shape[:2]} and RGB {query_image.shape[:2]} sizes don't match, resizing depth")
|
| 277 |
+
depth_image = cv2.resize(depth_image, (query_image.shape[1], query_image.shape[0]), interpolation=cv2.INTER_NEAREST)
|
| 278 |
+
|
| 279 |
+
# Convert to float32 if needed
|
| 280 |
+
if depth_image.dtype == np.uint16:
|
| 281 |
+
# Assume 16-bit depth in millimeters
|
| 282 |
+
depth = depth_image.astype(np.float32) / 1000.0
|
| 283 |
+
logger.info(f"Converted 16-bit depth to float32, range: [{depth.min():.3f}, {depth.max():.3f}]m")
|
| 284 |
+
elif depth_image.dtype == np.uint8:
|
| 285 |
+
# 8-bit depth (encoded), need to decode based on format
|
| 286 |
+
# For now, assume linear scaling to reasonable depth range
|
| 287 |
+
depth = depth_image.astype(np.float32) / 255.0 * 5.0 # Map to 0-5m
|
| 288 |
+
logger.info(f"Converted 8-bit depth to float32, range: [{depth.min():.3f}, {depth.max():.3f}]m")
|
| 289 |
+
else:
|
| 290 |
+
# Already float, use as-is
|
| 291 |
+
depth = depth_image.astype(np.float32)
|
| 292 |
+
logger.info(f"Using provided depth (dtype={depth_image.dtype}), range: [{depth.min():.3f}, {depth.max():.3f}]m")
|
| 293 |
+
|
| 294 |
+
# Handle color depth images (H, W, 3) - take first channel
|
| 295 |
+
if len(depth.shape) == 3:
|
| 296 |
+
logger.warning("Depth image has 3 channels, using first channel")
|
| 297 |
+
depth = depth[:, :, 0]
|
| 298 |
+
|
| 299 |
# Prepare camera intrinsics
|
| 300 |
camera_intrinsics = {
|
| 301 |
"fx": fx,
|
|
|
|
| 308 |
result = pose_estimator.estimate_pose(
|
| 309 |
object_id=object_id,
|
| 310 |
query_image=query_image,
|
| 311 |
+
depth_image=depth,
|
| 312 |
camera_intrinsics=camera_intrinsics
|
| 313 |
)
|
| 314 |
|
|
|
|
| 515 |
)
|
| 516 |
|
| 517 |
est_query_image = gr.Image(
|
| 518 |
+
label="Query Image (RGB)",
|
| 519 |
+
type="numpy"
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
est_depth_image = gr.Image(
|
| 523 |
+
label="Depth Image (Optional, 16-bit PNG)",
|
| 524 |
type="numpy"
|
| 525 |
)
|
| 526 |
|
|
|
|
| 545 |
|
| 546 |
est_button.click(
|
| 547 |
fn=gradio_estimate,
|
| 548 |
+
inputs=[est_object_id, est_query_image, est_depth_image, est_fx, est_fy, est_cx, est_cy],
|
| 549 |
outputs=[est_output, est_viz, est_mask]
|
| 550 |
)
|
| 551 |
|
tests/reference/t_shape/README.md
DELETED
|
@@ -1,63 +0,0 @@
|
|
| 1 |
-
# T-Shaped Object Mesh
|
| 2 |
-
|
| 3 |
-
This directory contains a 3D mesh of the T-shaped pushing object from the MuJoCo scene `nova-sim/robots/ur5/model/scene_t_push.xml`.
|
| 4 |
-
|
| 5 |
-
## Files
|
| 6 |
-
|
| 7 |
-
- `t_shape.obj` - 3D mesh in Wavefront OBJ format
|
| 8 |
-
|
| 9 |
-
## Dimensions
|
| 10 |
-
|
| 11 |
-
The T-shape consists of two rectangular boxes:
|
| 12 |
-
|
| 13 |
-
### Stem (vertical part)
|
| 14 |
-
- Dimensions: 40mm × 140mm × 60mm (width × height × depth)
|
| 15 |
-
- Position: centered at (0, -50mm, 0)
|
| 16 |
-
|
| 17 |
-
### Cap (horizontal part)
|
| 18 |
-
- Dimensions: 160mm × 40mm × 60mm
|
| 19 |
-
- Position: centered at (0, 30mm, 0)
|
| 20 |
-
|
| 21 |
-
### Overall Bounds
|
| 22 |
-
- X: [-80mm, 80mm] (160mm total width)
|
| 23 |
-
- Y: [-120mm, 50mm] (170mm total height)
|
| 24 |
-
- Z: [-30mm, 30mm] (60mm total depth)
|
| 25 |
-
|
| 26 |
-
## Usage
|
| 27 |
-
|
| 28 |
-
This mesh can be used with FoundationPose's CAD-based initialization mode for 6D pose estimation of the T-shaped object in the nova-sim push manipulation task.
|
| 29 |
-
|
| 30 |
-
### Example Usage
|
| 31 |
-
|
| 32 |
-
```python
|
| 33 |
-
from gradio_client import Client, handle_file
|
| 34 |
-
|
| 35 |
-
client = Client("https://gpue-foundationpose.hf.space")
|
| 36 |
-
|
| 37 |
-
# Initialize with T-shape mesh
|
| 38 |
-
result = client.predict(
|
| 39 |
-
object_id="t_shape",
|
| 40 |
-
mesh_file=handle_file("t_shape.obj"),
|
| 41 |
-
reference_files=[], # Optional reference images
|
| 42 |
-
fx=500.0, fy=500.0, cx=320.0, cy=240.0,
|
| 43 |
-
api_name="/gradio_initialize_cad"
|
| 44 |
-
)
|
| 45 |
-
|
| 46 |
-
# Estimate pose in query image
|
| 47 |
-
result = client.predict(
|
| 48 |
-
object_id="t_shape",
|
| 49 |
-
query_image=handle_file("camera_frame.jpg"),
|
| 50 |
-
fx=500.0, fy=500.0, cx=320.0, cy=240.0,
|
| 51 |
-
api_name="/gradio_estimate"
|
| 52 |
-
)
|
| 53 |
-
```
|
| 54 |
-
|
| 55 |
-
## Material Properties (from MuJoCo)
|
| 56 |
-
|
| 57 |
-
- Mass: 5.0 kg total (stem: 3.0 kg, cap: 2.0 kg)
|
| 58 |
-
- Friction: 0.3 (sliding), 0.005 (torsional), 0.005 (rolling)
|
| 59 |
-
- Color: Light blue (rgba: 0.55, 0.65, 0.98, 1.0)
|
| 60 |
-
|
| 61 |
-
## Generation
|
| 62 |
-
|
| 63 |
-
This mesh was automatically generated from the MuJoCo scene definition using a Python script that extracts the box geometries and creates a combined mesh.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/reference/t_shape/depth_001.png
ADDED
|
tests/reference/t_shape/image_001.jpg
DELETED
|
Binary file (4.79 kB)
|
|
|
tests/reference/t_shape/image_002.jpg
DELETED
|
Binary file (4.96 kB)
|
|
|
tests/reference/t_shape/image_004.jpg
DELETED
|
Binary file (4.43 kB)
|
|
|
tests/reference/t_shape/image_005.jpg
DELETED
|
Binary file (4.91 kB)
|
|
|
tests/reference/t_shape/image_006.jpg
DELETED
|
Binary file (4.69 kB)
|
|
|
tests/reference/t_shape/image_007.jpg
DELETED
|
Binary file (4.67 kB)
|
|
|
tests/reference/t_shape/image_008.jpg
DELETED
|
Binary file (4.86 kB)
|
|
|
tests/reference/t_shape/image_009.jpg
DELETED
|
Binary file (4.49 kB)
|
|
|
tests/reference/t_shape/image_010.jpg
DELETED
|
Binary file (4.9 kB)
|
|
|
tests/reference/t_shape/image_011.jpg
DELETED
|
Binary file (4.3 kB)
|
|
|
tests/reference/t_shape/image_012.jpg
DELETED
|
Binary file (4.56 kB)
|
|
|
tests/reference/t_shape/image_013.jpg
DELETED
|
Binary file (4.97 kB)
|
|
|
tests/reference/t_shape/image_014.jpg
DELETED
|
Binary file (4.64 kB)
|
|
|
tests/reference/t_shape/image_015.jpg
DELETED
|
Binary file (4.74 kB)
|
|
|
tests/reference/t_shape/{image_003.jpg → rgb_001.jpg}
RENAMED
|
File without changes
|
tests/test_estimator.py
CHANGED
|
@@ -8,7 +8,7 @@ This test verifies that the API can:
|
|
| 8 |
|
| 9 |
import sys
|
| 10 |
from pathlib import Path
|
| 11 |
-
import
|
| 12 |
import cv2
|
| 13 |
from gradio_client import Client, handle_file
|
| 14 |
|
|
@@ -18,24 +18,38 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
| 18 |
from client import FoundationPoseClient
|
| 19 |
|
| 20 |
|
| 21 |
-
def
|
| 22 |
-
"""Load
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
f for f in reference_dir.glob("*")
|
| 26 |
-
if f.suffix.lower() in ['.jpg', '.png']
|
| 27 |
-
])
|
| 28 |
-
images = []
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
images.append(img)
|
| 37 |
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
def test_client_initialization():
|
|
@@ -178,11 +192,16 @@ def main():
|
|
| 178 |
return
|
| 179 |
|
| 180 |
print(f"\nUsing T-shape mesh: {mesh_path}")
|
| 181 |
-
print(f"Using
|
| 182 |
|
| 183 |
-
# Load
|
| 184 |
-
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
# Test 1: Initialize API client
|
| 188 |
client = test_client_initialization()
|
|
@@ -200,12 +219,8 @@ def main():
|
|
| 200 |
print("=" * 60)
|
| 201 |
return
|
| 202 |
|
| 203 |
-
# Test 3: Estimate pose
|
| 204 |
-
|
| 205 |
-
query_image = reference_images[random_idx]
|
| 206 |
-
query_name = image_files[random_idx].name
|
| 207 |
-
|
| 208 |
-
success = test_pose_estimation(client, query_image, query_name)
|
| 209 |
|
| 210 |
# Print final results
|
| 211 |
print("\n" + "=" * 60)
|
|
|
|
| 8 |
|
| 9 |
import sys
|
| 10 |
from pathlib import Path
|
| 11 |
+
import numpy as np
|
| 12 |
import cv2
|
| 13 |
from gradio_client import Client, handle_file
|
| 14 |
|
|
|
|
| 18 |
from client import FoundationPoseClient
|
| 19 |
|
| 20 |
|
| 21 |
+
def load_test_data(reference_dir: Path):
|
| 22 |
+
"""Load RGB and depth test images from t_shape directory."""
|
| 23 |
+
rgb_path = reference_dir / "rgb_001.jpg"
|
| 24 |
+
depth_path = reference_dir / "depth_001.png"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
+
# Load RGB image
|
| 27 |
+
print(f"Loading RGB: {rgb_path}")
|
| 28 |
+
rgb = cv2.imread(str(rgb_path))
|
| 29 |
+
if rgb is None:
|
| 30 |
+
raise FileNotFoundError(f"Could not load RGB image: {rgb_path}")
|
| 31 |
+
rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
|
|
|
|
| 32 |
|
| 33 |
+
# Load depth image (16-bit PNG)
|
| 34 |
+
print(f"Loading depth: {depth_path}")
|
| 35 |
+
depth = cv2.imread(str(depth_path), cv2.IMREAD_ANYDEPTH)
|
| 36 |
+
if depth is None:
|
| 37 |
+
raise FileNotFoundError(f"Could not load depth image: {depth_path}")
|
| 38 |
+
|
| 39 |
+
# Check if depth needs resizing to match RGB
|
| 40 |
+
if depth.shape[:2] != rgb.shape[:2]:
|
| 41 |
+
print(f"⚠ Warning: Depth ({depth.shape[:2]}) and RGB ({rgb.shape[:2]}) sizes don't match")
|
| 42 |
+
print(f" Resizing depth to match RGB...")
|
| 43 |
+
depth = cv2.resize(depth, (rgb.shape[1], rgb.shape[0]), interpolation=cv2.INTER_NEAREST)
|
| 44 |
+
|
| 45 |
+
# Convert depth to meters (assuming it's in mm or similar)
|
| 46 |
+
# Depth should be float32 in meters for FoundationPose
|
| 47 |
+
depth = depth.astype(np.float32) / 1000.0 # Convert mm to meters
|
| 48 |
+
|
| 49 |
+
print(f"✓ RGB loaded: shape={rgb.shape}, dtype={rgb.dtype}")
|
| 50 |
+
print(f"✓ Depth loaded: shape={depth.shape}, dtype={depth.dtype}, range=[{depth.min():.3f}, {depth.max():.3f}]m")
|
| 51 |
+
|
| 52 |
+
return rgb, depth
|
| 53 |
|
| 54 |
|
| 55 |
def test_client_initialization():
|
|
|
|
| 192 |
return
|
| 193 |
|
| 194 |
print(f"\nUsing T-shape mesh: {mesh_path}")
|
| 195 |
+
print(f"Using test data from: {reference_dir}")
|
| 196 |
|
| 197 |
+
# Load test RGB and depth images
|
| 198 |
+
try:
|
| 199 |
+
rgb_image, depth_image = load_test_data(reference_dir)
|
| 200 |
+
except FileNotFoundError as e:
|
| 201 |
+
print(f"✗ {e}")
|
| 202 |
+
return
|
| 203 |
+
|
| 204 |
+
print(f"\n⚠ Note: API currently only supports RGB (depth support coming soon)")
|
| 205 |
|
| 206 |
# Test 1: Initialize API client
|
| 207 |
client = test_client_initialization()
|
|
|
|
| 219 |
print("=" * 60)
|
| 220 |
return
|
| 221 |
|
| 222 |
+
# Test 3: Estimate pose using RGB image
|
| 223 |
+
success = test_pose_estimation(client, rgb_image, "rgb_001.jpg")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
# Print final results
|
| 226 |
print("\n" + "=" * 60)
|