foundationpose / estimator.py
Georg
Prepare job build context
19d8da0
"""
FoundationPose model wrapper for inference.
This module wraps the FoundationPose library for 6D object pose estimation.
"""
import logging
import os
import sys
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
import torch
import cv2
from masks import generate_naive_mask
logger = logging.getLogger(__name__)

# Add FoundationPose to the Python path so its top-level packages
# (estimater, learning, ...) can be imported below.
FOUNDATIONPOSE_ROOT = Path("/app/FoundationPose")
if FOUNDATIONPOSE_ROOT.exists():
    sys.path.insert(0, str(FOUNDATIONPOSE_ROOT))

# trimesh is only needed for loading meshes. Import it on its own so that a
# failed FoundationPose import does not also hide an installed trimesh (which
# would make register_object report trimesh as missing).
try:
    import trimesh
except ImportError as e:
    logger.warning(f"trimesh not available: {e}")
    trimesh = None

# Try to import FoundationPose modules. On failure the names below stay
# undefined; FoundationPoseEstimator checks FOUNDATIONPOSE_AVAILABLE before
# touching any of them.
try:
    from estimater import FoundationPose
    from learning.training.predict_score import ScorePredictor
    from learning.training.predict_pose_refine import PoseRefinePredictor
    import nvdiffrast.torch as dr
    # Model-based inference also needs trimesh meshes, so require it too.
    FOUNDATIONPOSE_AVAILABLE = trimesh is not None
except ImportError as e:
    logger.warning(f"FoundationPose modules not available: {e}")
    FOUNDATIONPOSE_AVAILABLE = False
class FoundationPoseEstimator:
"""Wrapper for FoundationPose model."""
def __init__(self, device: str = "cuda", weights_dir: str | None = None):
"""Initialize FoundationPose estimator.
Args:
device: Device to run inference on ('cuda' or 'cpu')
weights_dir: Directory containing model weights
"""
self.device = device
if weights_dir is None:
weights_dir = os.environ.get("FOUNDATIONPOSE_WEIGHTS_DIR", "/app/FoundationPose/weights")
self.weights_dir = Path(weights_dir)
self.registered_objects = {}
self.scorer = None
self.refiner = None
self.glctx = None
self.available = FOUNDATIONPOSE_AVAILABLE
# Check if FoundationPose is available
if not FOUNDATIONPOSE_ROOT.exists():
raise RuntimeError(
f"FoundationPose repository not found at {FOUNDATIONPOSE_ROOT}. "
"Clone it with: git clone https://github.com/NVlabs/FoundationPose.git"
)
if not FOUNDATIONPOSE_AVAILABLE:
logger.warning("FoundationPose modules not loaded - inference will not work")
return
# Check if weights exist
if not self.weights_dir.exists() or not any(self.weights_dir.glob("**/*.pth")):
logger.warning(f"No model weights found in {self.weights_dir}")
logger.warning("Model will not work without weights")
# Initialize predictors (lazy loading - only when needed)
logger.info(f"FoundationPose estimator initialized (device: {device})")
def register_object(
self,
object_id: str,
reference_images: List[np.ndarray],
camera_intrinsics: Optional[Dict] = None,
mesh_path: Optional[str] = None
) -> bool:
"""Register an object for tracking.
Args:
object_id: Unique identifier for the object
reference_images: List of RGB reference images (H, W, 3)
camera_intrinsics: Camera parameters {fx, fy, cx, cy}
mesh_path: Optional path to object mesh file
Returns:
True if registration successful
"""
try:
# Load mesh if provided
mesh = None
if mesh_path and Path(mesh_path).exists():
if trimesh is None:
logger.warning("trimesh not available, skipping mesh load")
else:
try:
mesh = trimesh.load(mesh_path)
logger.info(f"Loaded mesh for '{object_id}' from {mesh_path}")
except Exception as e:
logger.warning(f"Failed to load mesh: {e}")
# Store object registration
self.registered_objects[object_id] = {
"num_references": len(reference_images),
"camera_intrinsics": camera_intrinsics,
"mesh_path": mesh_path,
"mesh": mesh,
"reference_images": reference_images,
"estimator": None, # Will be created lazily
"pose_last": None # Track last pose for temporal tracking
}
logger.info(f"✓ Registered object '{object_id}' with {len(reference_images)} reference images")
return True
except Exception as e:
logger.error(f"Failed to register object '{object_id}': {e}", exc_info=True)
return False
def estimate_pose(
self,
object_id: str,
rgb_image: np.ndarray,
depth_image: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
camera_intrinsics: Optional[Dict] = None
) -> Optional[Dict]:
"""Estimate 6D pose of registered object in image.
Args:
object_id: ID of object to detect
rgb_image: RGB query image (H, W, 3)
depth_image: Optional depth image (H, W)
mask: Optional object mask (H, W)
camera_intrinsics: Camera parameters {fx, fy, cx, cy}
Returns:
Pose dictionary with position, orientation, confidence or None
"""
if object_id not in self.registered_objects:
logger.error(f"Object '{object_id}' not registered")
return None
if not FOUNDATIONPOSE_AVAILABLE:
logger.error("FoundationPose not available")
return None
try:
obj_data = self.registered_objects[object_id]
# Initialize predictors if not done yet
if self.scorer is None:
logger.info("Initializing score predictor...")
self.scorer = ScorePredictor()
logger.info("Initializing pose refiner...")
self.refiner = PoseRefinePredictor()
logger.info("Initializing CUDA rasterizer...")
self.glctx = dr.RasterizeCudaContext()
# Initialize object-specific estimator if not done yet
if obj_data["estimator"] is None:
logger.info(f"Creating FoundationPose estimator for '{object_id}'...")
mesh = obj_data["mesh"]
if mesh is not None:
# Model-based mode: use mesh
logger.info("Using model-based mode with mesh")
obj_data["estimator"] = FoundationPose(
model_pts=mesh.vertices,
model_normals=mesh.vertex_normals,
mesh=mesh,
scorer=self.scorer,
refiner=self.refiner,
glctx=self.glctx,
debug=0
)
else:
# Model-free mode: requires 3D reconstruction from reference images
# This would typically use structure-from-motion (SfM) to create a mesh
# For now, this is not implemented
logger.error("Model-free mode not yet implemented")
logger.error("To use FoundationPose, please provide a 3D mesh (.obj, .stl, .ply)")
logger.error("You can:")
logger.error(" 1. Use CAD-Based initialization with your object's 3D model")
logger.error(" 2. Create a mesh from photos using photogrammetry tools (e.g., Meshroom, COLMAP)")
logger.error(" 3. Scan the object with a 3D scanner")
return None
estimator = obj_data["estimator"]
# Prepare camera intrinsics matrix
K = self._get_camera_matrix(camera_intrinsics or obj_data["camera_intrinsics"])
if K is None:
logger.error("Camera intrinsics required")
return None
# Generate or use depth if not provided
if depth_image is None:
# Create dummy depth for model-based case
# Use a more realistic depth distribution centered at 0.5m with some variation
depth_image = np.ones((rgb_image.shape[0], rgb_image.shape[1]), dtype=np.float32) * 0.5
logger.warning("Using dummy depth image - for better results, provide actual depth data")
# Generate mask if not provided
mask_was_generated = False
debug_mask = None
if mask is None:
# Use automatic foreground segmentation based on brightness
# This works well for light objects on dark backgrounds
logger.info("Generating automatic object mask from image")
mask, debug_mask, mask_percentage, fallback_full_image = generate_naive_mask(rgb_image)
logger.info(f"Auto-generated mask covers {mask_percentage:.1f}% of image")
if fallback_full_image:
logger.warning(
f"Mask coverage ({mask_percentage:.1f}%) seems unrealistic, using full image"
)
mask_was_generated = True
# First frame or lost tracking: register
if obj_data["pose_last"] is None:
logger.info("Running registration (first frame)...")
pose = estimator.register(
K=K,
rgb=rgb_image,
depth=depth_image,
ob_mask=mask,
iteration=5 # Number of refinement iterations
)
else:
# Subsequent frames: track
pose = estimator.track_one(
rgb=rgb_image,
depth=depth_image,
K=K,
iteration=2 # Fewer iterations for tracking
)
# Store pose for next frame (move to CPU if it's a tensor)
if torch.is_tensor(pose):
pose = pose.detach().cpu().numpy()
obj_data["pose_last"] = pose
if pose is None:
logger.warning("Pose estimation returned None")
return None
# Convert pose to our format
# pose is a 4x4 transformation matrix
result = self._format_pose_output(pose)
# Add debug mask if it was auto-generated
if mask_was_generated and debug_mask is not None:
result["debug_mask"] = debug_mask
return result
except Exception as e:
logger.error(f"Pose estimation failed: {e}", exc_info=True)
import traceback
traceback.print_exc()
return None
def _get_camera_matrix(self, intrinsics: Optional[Dict]) -> Optional[np.ndarray]:
"""Convert intrinsics dict to camera matrix."""
if intrinsics is None:
return None
fx = intrinsics.get("fx")
fy = intrinsics.get("fy")
cx = intrinsics.get("cx")
cy = intrinsics.get("cy")
if None in [fx, fy, cx, cy]:
return None
K = np.array([
[fx, 0, cx],
[0, fy, cy],
[0, 0, 1]
], dtype=np.float64)
return K
def _format_pose_output(self, pose_matrix: np.ndarray) -> Dict:
"""Convert 4x4 pose matrix to output format.
Args:
pose_matrix: 4x4 transformation matrix
Returns:
Dictionary with position, orientation (quaternion), and confidence
"""
if torch.is_tensor(pose_matrix):
pose_matrix = pose_matrix.detach().cpu().numpy()
# Extract translation
translation = pose_matrix[:3, 3]
# Extract rotation matrix
rotation_matrix = pose_matrix[:3, :3]
# Convert rotation matrix to quaternion
# Using Shepperd's method for numerical stability
trace = np.trace(rotation_matrix)
if trace > 0:
s = np.sqrt(trace + 1.0) * 2
w = 0.25 * s
x = (rotation_matrix[2, 1] - rotation_matrix[1, 2]) / s
y = (rotation_matrix[0, 2] - rotation_matrix[2, 0]) / s
z = (rotation_matrix[1, 0] - rotation_matrix[0, 1]) / s
elif rotation_matrix[0, 0] > rotation_matrix[1, 1] and rotation_matrix[0, 0] > rotation_matrix[2, 2]:
s = np.sqrt(1.0 + rotation_matrix[0, 0] - rotation_matrix[1, 1] - rotation_matrix[2, 2]) * 2
w = (rotation_matrix[2, 1] - rotation_matrix[1, 2]) / s
x = 0.25 * s
y = (rotation_matrix[0, 1] + rotation_matrix[1, 0]) / s
z = (rotation_matrix[0, 2] + rotation_matrix[2, 0]) / s
elif rotation_matrix[1, 1] > rotation_matrix[2, 2]:
s = np.sqrt(1.0 + rotation_matrix[1, 1] - rotation_matrix[0, 0] - rotation_matrix[2, 2]) * 2
w = (rotation_matrix[0, 2] - rotation_matrix[2, 0]) / s
x = (rotation_matrix[0, 1] + rotation_matrix[1, 0]) / s
y = 0.25 * s
z = (rotation_matrix[1, 2] + rotation_matrix[2, 1]) / s
else:
s = np.sqrt(1.0 + rotation_matrix[2, 2] - rotation_matrix[0, 0] - rotation_matrix[1, 1]) * 2
w = (rotation_matrix[1, 0] - rotation_matrix[0, 1]) / s
x = (rotation_matrix[0, 2] + rotation_matrix[2, 0]) / s
y = (rotation_matrix[1, 2] + rotation_matrix[2, 1]) / s
z = 0.25 * s
return {
"position": {
"x": float(translation[0]),
"y": float(translation[1]),
"z": float(translation[2])
},
"orientation": {
"w": float(w),
"x": float(x),
"y": float(y),
"z": float(z)
},
"confidence": 1.0, # FoundationPose doesn't provide explicit confidence
"pose_matrix": pose_matrix.tolist()
}