Georg Claude Sonnet 4.5 committed on
Commit
16d53ca
·
1 Parent(s): 42ce71e

Add depth image support to FoundationPose API

Browse files

Changes:
- Add depth image upload field to Gradio UI (below query image)
- Update gradio_estimate() to accept and process depth images
- Support 16-bit PNG depth (converts mm to meters)
- Handle depth/RGB size mismatches with automatic resizing
- Update test suite to load RGB+depth test images
- Replace old reference images with single RGB/depth pair

Test updates:
- Load specific rgb_001.jpg and depth_001.png files
- Auto-resize depth to match RGB dimensions
- Print depth statistics (shape, dtype, range)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

app.py CHANGED
@@ -262,12 +262,40 @@ def gradio_initialize_model_free(object_id: str, reference_files: List, fx: floa
262
  return f"Error: {str(e)}"
263
 
264
 
265
- def gradio_estimate(object_id: str, query_image: np.ndarray, fx: float, fy: float, cx: float, cy: float):
266
  """Gradio wrapper for pose estimation."""
267
  try:
268
  if query_image is None:
269
  return "Error: No query image provided", None, None
270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  # Prepare camera intrinsics
272
  camera_intrinsics = {
273
  "fx": fx,
@@ -280,6 +308,7 @@ def gradio_estimate(object_id: str, query_image: np.ndarray, fx: float, fy: floa
280
  result = pose_estimator.estimate_pose(
281
  object_id=object_id,
282
  query_image=query_image,
 
283
  camera_intrinsics=camera_intrinsics
284
  )
285
 
@@ -486,7 +515,12 @@ with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as demo
486
  )
487
 
488
  est_query_image = gr.Image(
489
- label="Query Image",
 
 
 
 
 
490
  type="numpy"
491
  )
492
 
@@ -511,7 +545,7 @@ with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as demo
511
 
512
  est_button.click(
513
  fn=gradio_estimate,
514
- inputs=[est_object_id, est_query_image, est_fx, est_fy, est_cx, est_cy],
515
  outputs=[est_output, est_viz, est_mask]
516
  )
517
 
 
262
  return f"Error: {str(e)}"
263
 
264
 
265
+ def gradio_estimate(object_id: str, query_image: np.ndarray, depth_image: np.ndarray, fx: float, fy: float, cx: float, cy: float):
266
  """Gradio wrapper for pose estimation."""
267
  try:
268
  if query_image is None:
269
  return "Error: No query image provided", None, None
270
 
271
+ # Process depth image if provided
272
+ depth = None
273
+ if depth_image is not None:
274
+ # Check if depth needs resizing to match RGB
275
+ if depth_image.shape[:2] != query_image.shape[:2]:
276
+ logger.warning(f"Depth {depth_image.shape[:2]} and RGB {query_image.shape[:2]} sizes don't match, resizing depth")
277
+ depth_image = cv2.resize(depth_image, (query_image.shape[1], query_image.shape[0]), interpolation=cv2.INTER_NEAREST)
278
+
279
+ # Convert to float32 if needed
280
+ if depth_image.dtype == np.uint16:
281
+ # Assume 16-bit depth in millimeters
282
+ depth = depth_image.astype(np.float32) / 1000.0
283
+ logger.info(f"Converted 16-bit depth to float32, range: [{depth.min():.3f}, {depth.max():.3f}]m")
284
+ elif depth_image.dtype == np.uint8:
285
+ # 8-bit depth (encoded), need to decode based on format
286
+ # For now, assume linear scaling to reasonable depth range
287
+ depth = depth_image.astype(np.float32) / 255.0 * 5.0 # Map to 0-5m
288
+ logger.info(f"Converted 8-bit depth to float32, range: [{depth.min():.3f}, {depth.max():.3f}]m")
289
+ else:
290
+ # Already float, use as-is
291
+ depth = depth_image.astype(np.float32)
292
+ logger.info(f"Using provided depth (dtype={depth_image.dtype}), range: [{depth.min():.3f}, {depth.max():.3f}]m")
293
+
294
+ # Handle color depth images (H, W, 3) - take first channel
295
+ if len(depth.shape) == 3:
296
+ logger.warning("Depth image has 3 channels, using first channel")
297
+ depth = depth[:, :, 0]
298
+
299
  # Prepare camera intrinsics
300
  camera_intrinsics = {
301
  "fx": fx,
 
308
  result = pose_estimator.estimate_pose(
309
  object_id=object_id,
310
  query_image=query_image,
311
+ depth_image=depth,
312
  camera_intrinsics=camera_intrinsics
313
  )
314
 
 
515
  )
516
 
517
  est_query_image = gr.Image(
518
+ label="Query Image (RGB)",
519
+ type="numpy"
520
+ )
521
+
522
+ est_depth_image = gr.Image(
523
+ label="Depth Image (Optional, 16-bit PNG)",
524
  type="numpy"
525
  )
526
 
 
545
 
546
  est_button.click(
547
  fn=gradio_estimate,
548
+ inputs=[est_object_id, est_query_image, est_depth_image, est_fx, est_fy, est_cx, est_cy],
549
  outputs=[est_output, est_viz, est_mask]
550
  )
551
 
tests/reference/t_shape/README.md DELETED
@@ -1,63 +0,0 @@
1
- # T-Shaped Object Mesh
2
-
3
- This directory contains a 3D mesh of the T-shaped pushing object from the MuJoCo scene `nova-sim/robots/ur5/model/scene_t_push.xml`.
4
-
5
- ## Files
6
-
7
- - `t_shape.obj` - 3D mesh in Wavefront OBJ format
8
-
9
- ## Dimensions
10
-
11
- The T-shape consists of two rectangular boxes:
12
-
13
- ### Stem (vertical part)
14
- - Dimensions: 40mm × 140mm × 60mm (width × height × depth)
15
- - Position: centered at (0, -50mm, 0)
16
-
17
- ### Cap (horizontal part)
18
- - Dimensions: 160mm × 40mm × 60mm
19
- - Position: centered at (0, 30mm, 0)
20
-
21
- ### Overall Bounds
22
- - X: [-80mm, 80mm] (160mm total width)
23
- - Y: [-120mm, 50mm] (170mm total height)
24
- - Z: [-30mm, 30mm] (60mm total depth)
25
-
26
- ## Usage
27
-
28
- This mesh can be used with FoundationPose's CAD-based initialization mode for 6D pose estimation of the T-shaped object in the nova-sim push manipulation task.
29
-
30
- ### Example Usage
31
-
32
- ```python
33
- from gradio_client import Client, handle_file
34
-
35
- client = Client("https://gpue-foundationpose.hf.space")
36
-
37
- # Initialize with T-shape mesh
38
- result = client.predict(
39
- object_id="t_shape",
40
- mesh_file=handle_file("t_shape.obj"),
41
- reference_files=[], # Optional reference images
42
- fx=500.0, fy=500.0, cx=320.0, cy=240.0,
43
- api_name="/gradio_initialize_cad"
44
- )
45
-
46
- # Estimate pose in query image
47
- result = client.predict(
48
- object_id="t_shape",
49
- query_image=handle_file("camera_frame.jpg"),
50
- fx=500.0, fy=500.0, cx=320.0, cy=240.0,
51
- api_name="/gradio_estimate"
52
- )
53
- ```
54
-
55
- ## Material Properties (from MuJoCo)
56
-
57
- - Mass: 5.0 kg total (stem: 3.0 kg, cap: 2.0 kg)
58
- - Friction: 0.3 (sliding), 0.005 (torsional), 0.005 (rolling)
59
- - Color: Light blue (rgba: 0.55, 0.65, 0.98, 1.0)
60
-
61
- ## Generation
62
-
63
- This mesh was automatically generated from the MuJoCo scene definition using a Python script that extracts the box geometries and creates a combined mesh.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/reference/t_shape/depth_001.png ADDED
tests/reference/t_shape/image_001.jpg DELETED
Binary file (4.79 kB)
 
tests/reference/t_shape/image_002.jpg DELETED
Binary file (4.96 kB)
 
tests/reference/t_shape/image_004.jpg DELETED
Binary file (4.43 kB)
 
tests/reference/t_shape/image_005.jpg DELETED
Binary file (4.91 kB)
 
tests/reference/t_shape/image_006.jpg DELETED
Binary file (4.69 kB)
 
tests/reference/t_shape/image_007.jpg DELETED
Binary file (4.67 kB)
 
tests/reference/t_shape/image_008.jpg DELETED
Binary file (4.86 kB)
 
tests/reference/t_shape/image_009.jpg DELETED
Binary file (4.49 kB)
 
tests/reference/t_shape/image_010.jpg DELETED
Binary file (4.9 kB)
 
tests/reference/t_shape/image_011.jpg DELETED
Binary file (4.3 kB)
 
tests/reference/t_shape/image_012.jpg DELETED
Binary file (4.56 kB)
 
tests/reference/t_shape/image_013.jpg DELETED
Binary file (4.97 kB)
 
tests/reference/t_shape/image_014.jpg DELETED
Binary file (4.64 kB)
 
tests/reference/t_shape/image_015.jpg DELETED
Binary file (4.74 kB)
 
tests/reference/t_shape/{image_003.jpg → rgb_001.jpg} RENAMED
File without changes
tests/test_estimator.py CHANGED
@@ -8,7 +8,7 @@ This test verifies that the API can:
8
 
9
  import sys
10
  from pathlib import Path
11
- import random
12
  import cv2
13
  from gradio_client import Client, handle_file
14
 
@@ -18,24 +18,38 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
18
  from client import FoundationPoseClient
19
 
20
 
21
- def load_reference_images(reference_dir: Path):
22
- """Load all reference images from directory."""
23
- # Get all jpg and png files, excluding mesh files
24
- image_files = sorted([
25
- f for f in reference_dir.glob("*")
26
- if f.suffix.lower() in ['.jpg', '.png']
27
- ])
28
- images = []
29
 
30
- for img_path in image_files:
31
- # Use cv2 to load images (same as client.py)
32
- img = cv2.imread(str(img_path))
33
- if img is None:
34
- continue
35
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
36
- images.append(img)
37
 
38
- return images, image_files
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
 
41
  def test_client_initialization():
@@ -178,11 +192,16 @@ def main():
178
  return
179
 
180
  print(f"\nUsing T-shape mesh: {mesh_path}")
181
- print(f"Using query images from: {reference_dir}")
182
 
183
- # Load reference images (will be used as query images)
184
- reference_images, image_files = load_reference_images(reference_dir)
185
- print(f"✓ Loaded {len(reference_images)} query images")
 
 
 
 
 
186
 
187
  # Test 1: Initialize API client
188
  client = test_client_initialization()
@@ -200,12 +219,8 @@ def main():
200
  print("=" * 60)
201
  return
202
 
203
- # Test 3: Estimate pose on a random query image
204
- random_idx = random.randint(0, len(reference_images) - 1)
205
- query_image = reference_images[random_idx]
206
- query_name = image_files[random_idx].name
207
-
208
- success = test_pose_estimation(client, query_image, query_name)
209
 
210
  # Print final results
211
  print("\n" + "=" * 60)
 
8
 
9
  import sys
10
  from pathlib import Path
11
+ import numpy as np
12
  import cv2
13
  from gradio_client import Client, handle_file
14
 
 
18
  from client import FoundationPoseClient
19
 
20
 
21
+ def load_test_data(reference_dir: Path):
22
+ """Load RGB and depth test images from t_shape directory."""
23
+ rgb_path = reference_dir / "rgb_001.jpg"
24
+ depth_path = reference_dir / "depth_001.png"
 
 
 
 
25
 
26
+ # Load RGB image
27
+ print(f"Loading RGB: {rgb_path}")
28
+ rgb = cv2.imread(str(rgb_path))
29
+ if rgb is None:
30
+ raise FileNotFoundError(f"Could not load RGB image: {rgb_path}")
31
+ rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
 
32
 
33
+ # Load depth image (16-bit PNG)
34
+ print(f"Loading depth: {depth_path}")
35
+ depth = cv2.imread(str(depth_path), cv2.IMREAD_ANYDEPTH)
36
+ if depth is None:
37
+ raise FileNotFoundError(f"Could not load depth image: {depth_path}")
38
+
39
+ # Check if depth needs resizing to match RGB
40
+ if depth.shape[:2] != rgb.shape[:2]:
41
+ print(f"⚠ Warning: Depth ({depth.shape[:2]}) and RGB ({rgb.shape[:2]}) sizes don't match")
42
+ print(f" Resizing depth to match RGB...")
43
+ depth = cv2.resize(depth, (rgb.shape[1], rgb.shape[0]), interpolation=cv2.INTER_NEAREST)
44
+
45
+ # Convert depth to meters (assuming it's in mm or similar)
46
+ # Depth should be float32 in meters for FoundationPose
47
+ depth = depth.astype(np.float32) / 1000.0 # Convert mm to meters
48
+
49
+ print(f"✓ RGB loaded: shape={rgb.shape}, dtype={rgb.dtype}")
50
+ print(f"✓ Depth loaded: shape={depth.shape}, dtype={depth.dtype}, range=[{depth.min():.3f}, {depth.max():.3f}]m")
51
+
52
+ return rgb, depth
53
 
54
 
55
  def test_client_initialization():
 
192
  return
193
 
194
  print(f"\nUsing T-shape mesh: {mesh_path}")
195
+ print(f"Using test data from: {reference_dir}")
196
 
197
+ # Load test RGB and depth images
198
+ try:
199
+ rgb_image, depth_image = load_test_data(reference_dir)
200
+ except FileNotFoundError as e:
201
+ print(f"✗ {e}")
202
+ return
203
+
204
+ print(f"\n⚠ Note: API currently only supports RGB (depth support coming soon)")
205
 
206
  # Test 1: Initialize API client
207
  client = test_client_initialization()
 
219
  print("=" * 60)
220
  return
221
 
222
+ # Test 3: Estimate pose using RGB image
223
+ success = test_pose_estimation(client, rgb_image, "rgb_001.jpg")
 
 
 
 
224
 
225
  # Print final results
226
  print("\n" + "=" * 60)