# Copyright (c) Meta Platforms, Inc. and affiliates.
"""
SAM 3D Body (3DB) mesh alignment utilities.

Aligns a 3DB human mesh to SAM 3D Object outputs by rescaling and translating
it to match a MoGe point cloud predicted from the same image.
"""

import os
import math
import json

import numpy as np
import torch
import trimesh
from PIL import Image
import torch.nn.functional as F
from pytorch3d.structures import Meshes
from pytorch3d.renderer import PerspectiveCameras, RasterizationSettings, MeshRasterizer, TexturesVertex
from moge.model.v1 import MoGeModel


def load_3db_mesh(mesh_path, device='cuda'):
    """Load a 3DB mesh and convert it from OpenGL to PyTorch3D coordinates.

    Args:
        mesh_path: Path to a mesh file readable by ``trimesh.load``.
        device: Torch device for the returned tensors.

    Returns:
        tuple: (vertices, faces) -- float32 vertex tensor (in PyTorch3D
        coordinates) and int64 face-index tensor.
    """
    mesh = trimesh.load(mesh_path)
    # np.array copies the data, so the in-place flips below do not modify ``mesh``.
    vertices = np.array(mesh.vertices)
    faces = np.array(mesh.faces)

    # OpenGL -> PyTorch3D: flip the X and Z axes.
    vertices[:, 0] *= -1
    vertices[:, 2] *= -1

    vertices = torch.from_numpy(vertices).float().to(device)
    faces = torch.from_numpy(faces).long().to(device)
    return vertices, faces


def get_moge_pointcloud(image_tensor, device='cuda'):
    """Run MoGe on an image and return its raw output dict.

    The model is loaded from the ``Ruicheng/moge-vitl`` checkpoint on every
    call.

    Args:
        image_tensor: (3, H, W) float image in [0, 1], as built by
            ``process_3db_alignment``.
        device: Torch device to run the model on.

    Returns:
        dict: Output of ``MoGeModel.infer``. This module reads its ``'points'``
        (per-pixel point map indexed by an (H, W) mask) and ``'intrinsics'``
        (normalized 3x3 matrix) entries.
    """
    moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device)
    moge_model.eval()

    with torch.no_grad():
        moge_output = moge_model.infer(image_tensor)

    return moge_output


def denormalize_intrinsics(norm_K, height, width):
    """Convert normalized intrinsics to absolute pixel units.

    ``norm_K`` stores fx and cx as fractions of the image width and fy and cy
    as fractions of the image height.

    Note:
        fx is set equal to fy, so ``norm_K[0, 0]`` is ignored. The renderer in
        this module uses a single focal length for both axes.

    Args:
        norm_K: 3x3 normalized intrinsics matrix (array-like).
        height: Image height in pixels.
        width: Image width in pixels.

    Returns:
        np.ndarray: 3x3 intrinsics matrix in pixels.
    """
    fy_abs = norm_K[1, 1] * height
    fx_abs = fy_abs
    cx_abs = norm_K[0, 2] * width
    cy_abs = norm_K[1, 2] * height

    return np.array([
        [fx_abs, 0.0, cx_abs],
        [0.0, fy_abs, cy_abs],
        [0.0, 0.0, 1.0]
    ])


def crop_mesh_with_mask(vertices, faces, focal_length, mask, device='cuda'):
    """Keep only the mesh vertices whose faces are visible inside ``mask``.

    The mesh is rasterized with a pinhole camera (fx = fy = ``focal_length``
    in pixels, principal point at the image center) at the mask resolution.
    Every pixel that is both inside the mask and covered by a face selects
    that face's vertices.

    Args:
        vertices: (V, 3) float tensor in PyTorch3D coordinates.
        faces: (F, 3) long tensor of vertex indices.
        focal_length: Focal length in pixels.
        mask: Tensor whose last two dims are (H, W); values > 0 are foreground.
        device: Torch device for the camera.

    Returns:
        tuple: (verts_cropped, visible_mask). ``verts_cropped`` holds the unique
        vertices of the visible, in-mask faces. ``visible_mask`` is an (H, W)
        bool tensor marking in-mask pixels that hit a face.
    """
    # The rasterizer only needs geometry; the constant texture is a placeholder.
    textures = TexturesVertex(verts_features=torch.ones_like(vertices)[None])
    mesh = Meshes(verts=[vertices], faces=[faces], textures=textures)

    H, W = mask.shape[-2:]
    fx = fy = focal_length
    cx, cy = W / 2.0, H / 2.0

    # in_ndc=False: focal length and principal point are given in pixels.
    camera = PerspectiveCameras(
        focal_length=((fx, fy),),
        principal_point=((cx, cy),),
        image_size=((H, W),),
        in_ndc=False,
        device=device
    )

    raster_settings = RasterizationSettings(
        image_size=(H, W),
        blur_radius=0.0,
        faces_per_pixel=1,
        cull_backfaces=False,
        bin_size=0,
    )

    rasterizer = MeshRasterizer(cameras=camera, raster_settings=raster_settings)
    fragments = rasterizer(mesh)

    # pix_to_face indexes packed faces; with a single mesh these coincide with
    # row indices into ``faces``. -1 marks background pixels.
    face_indices = fragments.pix_to_face[0, ..., 0]  # (H, W)

    visible_mask = (mask > 0) & (face_indices >= 0)
    visible_face_ids = face_indices[visible_mask]
    visible_faces = faces[visible_face_ids]
    visible_vert_ids = torch.unique(visible_faces)
    verts_cropped = vertices[visible_vert_ids]

    return verts_cropped, visible_mask


def extract_target_points(pointmap, visible_mask):
    """Select MoGe points under ``visible_mask`` and remove outliers.

    Points are converted from MoGe to PyTorch3D coordinates (X and Y flipped).
    Non-finite points are dropped first. Then "flying" points beyond a depth
    quantile are removed; the quantile adapts to the depth range of the
    remaining points.

    Args:
        pointmap: Per-pixel point map whose first two dims match ``visible_mask``.
        visible_mask: (H, W) mask selecting the pixels to use.

    Returns:
        torch.Tensor: (N, 3) finite points. N may be 0.
    """
    # Boolean indexing returns a copy, so the in-place flips do not mutate ``pointmap``.
    target_points = pointmap[visible_mask.bool()]

    # MoGe -> PyTorch3D coordinates.
    target_points[:, 0] *= -1
    target_points[:, 1] *= -1

    # Drop non-finite points BEFORE the quantile filter. An inf depth would make
    # z_range infinite (always picking the strictest threshold) and skew the
    # quantile, and a NaN would make the quantile NaN and discard every point.
    finite_mask = torch.isfinite(target_points).all(dim=1)
    target_points = target_points[finite_mask]

    # torch.quantile / torch.max raise on empty input; let the caller handle it.
    if target_points.shape[0] == 0:
        return target_points

    # Remove flying points with an adaptive depth quantile: deeper scenes
    # (larger z range) use a stricter cut.
    depth = target_points[:, 2]
    z_range = torch.max(depth) - torch.min(depth)
    if z_range > 6.0:
        thresh = 0.90
    elif z_range > 2.0:
        thresh = 0.93
    else:
        thresh = 0.95

    depth_quantile = torch.quantile(depth, thresh)
    target_points = target_points[depth <= depth_quantile]

    return target_points


def align_mesh_to_pointcloud(vertices, target_points):
    """Fit a uniform scale and a translation mapping ``vertices`` onto ``target_points``.

    The scale is the ratio of the Y (height) extents. The translation then
    matches the centroids of the scaled vertices and the target points.

    Args:
        vertices: (V, 3) source vertices.
        target_points: (N, 3) target points in the same coordinate convention.

    Returns:
        tuple: (vertices_aligned, scale_factor, translation). ``scale_factor``
        is a 0-dim tensor and ``translation`` a (3,) tensor, both on the device
        and dtype of ``vertices``. If alignment is impossible (no target points,
        no source vertices, or zero source height), the input vertices are
        returned unchanged with identity scale and translation.
    """
    # Identity fallbacks live on the same device/dtype as the mesh, so callers
    # can apply them (vertices * s + t) without a device mismatch.
    identity_scale = torch.tensor(1.0, device=vertices.device, dtype=vertices.dtype)
    identity_translation = torch.zeros(3, device=vertices.device, dtype=vertices.dtype)

    if target_points.shape[0] == 0:
        print("[WARNING] No target points for alignment!")
        return vertices, identity_scale, identity_translation
    if vertices.shape[0] == 0:
        print("[WARNING] No source vertices for alignment!")
        return vertices, identity_scale, identity_translation

    # Scale alignment based on height (Y extent).
    height_src = torch.max(vertices[:, 1]) - torch.min(vertices[:, 1])
    if height_src <= 0:
        print("[WARNING] Source vertices have zero height; skipping alignment!")
        return vertices, identity_scale, identity_translation
    height_tgt = torch.max(target_points[:, 1]) - torch.min(target_points[:, 1])

    scale_factor = height_tgt / height_src
    vertices_scaled = vertices * scale_factor

    # Translation alignment based on centroids.
    center_src = torch.mean(vertices_scaled, dim=0)
    center_tgt = torch.mean(target_points, dim=0)
    translation = center_tgt - center_src

    vertices_aligned = vertices_scaled + translation

    return vertices_aligned, scale_factor, translation


def load_mask_for_alignment(mask_path):
    """Load a mask image as a grayscale float array in [0, 1].

    Args:
        mask_path: Path to the mask image.

    Returns:
        np.ndarray: (H, W) float array.
    """
    mask = Image.open(mask_path).convert('L')
    mask_array = np.array(mask) / 255.0
    return mask_array


def load_focal_length_from_json(json_path):
    """Read the ``focal_length`` entry from a JSON file.

    Args:
        json_path: Path to a JSON object containing a ``focal_length`` key.

    Returns:
        The stored focal length value (used as pixels by the renderer).

    Raises:
        ValueError: If the key is missing. Any I/O or JSON error is also
        re-raised after being printed.
    """
    try:
        with open(json_path, 'r') as f:
            data = json.load(f)
        focal_length = data.get('focal_length')
        if focal_length is None:
            raise ValueError("'focal_length' key not found in JSON file")
        print(f"[INFO] Loaded focal length from {json_path}: {focal_length}")
        return focal_length
    except Exception as e:
        print(f"[ERROR] Failed to load focal length from {json_path}: {e}")
        raise


def _load_mask_tensor(mask_path, height, width, device):
    """Load a [0, 1] mask, nearest-resize it to (height, width) if needed, and move it to ``device``."""
    mask = load_mask_for_alignment(mask_path)
    if mask.shape != (height, width):
        resized = Image.fromarray((mask * 255).astype(np.uint8)).resize((width, height), Image.NEAREST)
        mask = np.array(resized) / 255.0
    return torch.from_numpy(mask).float().to(device)


def process_3db_alignment(mesh_path, mask_path, image_path, device='cuda', focal_length_json_path=None):
    """Align a 3DB mesh to the scale of a MoGe point cloud from the same image.

    Pipeline:
        1. Render the mesh and keep the vertices visible inside the mask.
        2. Take the matching MoGe points and filter outliers.
        3. Fit scale and translation on that subset.
        4. Apply the fit to the full mesh.

    Args:
        mesh_path: Path to the 3DB mesh (OpenGL coordinates).
        mask_path: Path to the person mask image.
        image_path: Path to the RGB image.
        device: Torch device to use.
        focal_length_json_path: Optional JSON holding ``focal_length`` (pixels).
            When omitted, fy from MoGe's intrinsics is used.

    Returns:
        dict or None: None if no valid target points were found. Otherwise a
        dict with the following keys:
            - ``'aligned_vertices_opengl'`` (np.ndarray, OpenGL coordinates)
            - ``'faces'`` (np.ndarray)
            - ``'scale_factor'`` (float)
            - ``'translation'`` (np.ndarray)
            - ``'focal_length'``
            - ``'target_points_count'`` (int)
            - ``'cropped_vertices_count'`` (int)
    """
    print("[INFO] Processing alignment...")

    # Load input data
    vertices, faces = load_3db_mesh(mesh_path, device)

    # Load and preprocess image as (3, H, W) in [0, 1]
    image = Image.open(image_path).convert('RGB')
    image_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1) / 255.0
    image_tensor = image_tensor.to(device)

    # Load mask and resize to match image
    H, W = image_tensor.shape[1:]
    mask = _load_mask_tensor(mask_path, H, W, device)

    # Generate MoGe point cloud
    print("[INFO] Generating MoGe point cloud...")
    moge_output = get_moge_pointcloud(image_tensor, device)

    # Load focal length from JSON if provided, otherwise compute from MoGe intrinsics
    if focal_length_json_path is not None:
        focal_length = load_focal_length_from_json(focal_length_json_path)
    else:
        intrinsics = denormalize_intrinsics(moge_output['intrinsics'].cpu().numpy(), H, W)
        focal_length = intrinsics[1, 1]  # Use fy
        print(f"[INFO] Using computed focal length from MoGe: {focal_length}")

    # Crop mesh using mask
    print("[INFO] Cropping mesh with mask...")
    verts_cropped, visible_mask = crop_mesh_with_mask(vertices, faces, focal_length, mask, device)

    # Extract target points from MoGe
    print("[INFO] Extracting target points...")
    target_points = extract_target_points(moge_output['points'], visible_mask)

    if target_points.shape[0] == 0:
        print("[ERROR] No valid target points found!")
        return None

    # Fit alignment on the visible subset
    print("[INFO] Aligning mesh to point cloud...")
    _, scale_factor, translation = align_mesh_to_pointcloud(verts_cropped, target_points)

    # Apply alignment to full mesh
    full_aligned_vertices = (vertices * scale_factor) + translation

    # PyTorch3D -> OpenGL (flip X and Z back) for the saved output
    final_vertices_opengl = full_aligned_vertices.cpu().numpy()
    final_vertices_opengl[:, 0] *= -1
    final_vertices_opengl[:, 2] *= -1

    results = {
        'aligned_vertices_opengl': final_vertices_opengl,
        'faces': faces.cpu().numpy(),
        'scale_factor': scale_factor.item(),
        'translation': translation.cpu().numpy(),
        'focal_length': focal_length,
        'target_points_count': target_points.shape[0],
        'cropped_vertices_count': verts_cropped.shape[0]
    }

    print(f"[INFO] Alignment completed - Scale: {scale_factor.item():.4f}, Target points: {target_points.shape[0]}")
    return results


def process_and_save_alignment(mesh_path, mask_path, image_path, output_dir, device='cuda', focal_length_json_path=None):
    """
    Complete pipeline for processing 3DB alignment and saving the result.

    The aligned mesh is written to ``<output_dir>/human_aligned.ply``. Any
    exception is caught, printed with its traceback, and reported through the
    return value instead of being raised.

    Args:
        mesh_path: Path to input 3DB mesh (.ply)
        mask_path: Path to mask image (.png)
        image_path: Path to input image (.jpg)
        output_dir: Directory to save aligned mesh (created if missing)
        device: Device to use ('cuda' or 'cpu')
        focal_length_json_path: Optional path to focal length JSON file

    Returns:
        tuple: (success: bool, output_mesh_path: str or None, result_info: dict or None)
    """
    try:
        print("[INFO] Starting 3DB mesh alignment pipeline...")

        os.makedirs(output_dir, exist_ok=True)

        result = process_3db_alignment(
            mesh_path=mesh_path,
            mask_path=mask_path,
            image_path=image_path,
            device=device,
            focal_length_json_path=focal_length_json_path
        )

        if result is None:
            print(" ERROR: Failed to process mesh alignment")
            return False, None, None

        output_mesh_path = os.path.join(output_dir, 'human_aligned.ply')
        aligned_mesh = trimesh.Trimesh(
            vertices=result['aligned_vertices_opengl'],
            faces=result['faces']
        )
        aligned_mesh.export(output_mesh_path)
        print(f" SUCCESS! Saved aligned mesh to: {output_mesh_path}")
        return True, output_mesh_path, result

    except Exception as e:
        # Top-level boundary: report failure via the return value.
        print(f" ERROR: Exception during processing: {e}")
        import traceback
        traceback.print_exc()
        return False, None, None
    finally:
        print(" Processing complete!")


def _load_glb_as_mesh(glb_path):
    """Load a GLB/scene file and return its geometry as a single trimesh mesh.

    Scenes are flattened by taking (and, if several, concatenating) every
    geometry that has vertices.

    NOTE(review): scene-graph node transforms are not applied here; this
    matches the previous behavior. Confirm that is intended for 3Dfy GLBs.

    Raises:
        ValueError: If the scene contains no geometry with vertices.
    """
    loaded = trimesh.load(glb_path)
    if not hasattr(loaded, 'dump'):
        # Not a scene: already a single geometry.
        return loaded

    meshes = [geom for geom in loaded.geometry.values() if hasattr(geom, 'vertices')]
    if not meshes:
        raise ValueError("No valid meshes in GLB file")
    if len(meshes) == 1:
        return meshes[0]
    return trimesh.util.concatenate(meshes)


def _tinted_copy(mesh, rgba):
    """Return a copy of ``mesh`` whose every vertex has the single RGBA color ``rgba``."""
    tinted = mesh.copy()
    # Replace the visual outright rather than assigning ``visual.vertex_colors``,
    # so textured meshes (e.g. loaded from GLB) are recolored too.
    vertex_colors = np.tile(np.asarray(rgba, dtype=np.uint8), (len(tinted.vertices), 1))
    tinted.visual = trimesh.visual.ColorVisuals(mesh=tinted, vertex_colors=vertex_colors)
    return tinted


def visualize_meshes_interactive(aligned_mesh_path, dfy_mesh_path, output_dir=None, share=True, height=600):
    """
    Interactive Gradio-based 3D visualization of aligned human and object meshes.

    Both meshes are colored (red: aligned human, blue: 3Dfy object), exported
    together as ``combined_scene.glb`` and shown in a Gradio ``Model3D`` viewer.

    Args:
        aligned_mesh_path: Path to aligned mesh PLY file
        dfy_mesh_path: Path to 3Dfy GLB file
        output_dir: Directory to save combined GLB file (defaults to same dir as aligned_mesh_path)
        share: Whether to create a public shareable link (default: True)
        height: Height of the 3D viewer in pixels (default: 600)

    Returns:
        tuple: (demo, combined_glb_path) - Gradio demo object and path to combined GLB file,
        or (None, None) on error.
    """
    import gradio as gr

    print("Loading meshes for interactive visualization...")

    try:
        aligned_mesh = trimesh.load(aligned_mesh_path)
        print(f"Loaded aligned mesh: {len(aligned_mesh.vertices)} vertices")

        dfy_mesh = _load_glb_as_mesh(dfy_mesh_path)
        print(f"Loaded 3Dfy mesh: {len(dfy_mesh.vertices)} vertices")

        scene = trimesh.Scene()
        scene.add_geometry(_tinted_copy(aligned_mesh, [255, 0, 0, 200]), node_name="sam3d_aligned_human")
        scene.add_geometry(_tinted_copy(dfy_mesh, [0, 0, 255, 200]), node_name="dfy_object")

        if output_dir is None:
            output_dir = os.path.dirname(aligned_mesh_path)
        os.makedirs(output_dir, exist_ok=True)

        combined_glb_path = os.path.join(output_dir, 'combined_scene.glb')
        scene.export(combined_glb_path)
        print(f"Exported combined scene to: {combined_glb_path}")

        with gr.Blocks() as demo:
            gr.Markdown("# 3D Mesh Alignment Visualization")
            gr.Markdown("**Red**: SAM 3D Body Aligned Human | **Blue**: 3Dfy Object")
            gr.Model3D(
                value=combined_glb_path,
                label="Combined 3D Scene (Interactive)",
                height=height
            )

        print("Launching interactive 3D viewer...")
        demo.launch(share=share)

        return demo, combined_glb_path

    except Exception as e:
        print(f"ERROR in visualization: {e}")
        import traceback
        traceback.print_exc()
        return None, None


def visualize_meshes_comparison(aligned_mesh_path, dfy_mesh_path, use_interactive=False):
    """
    Simple visualization of both meshes in a single 3D plot.

    DEPRECATED: Use visualize_meshes_interactive() for better interactive visualization.

    Args:
        aligned_mesh_path: Path to aligned mesh PLY file
        dfy_mesh_path: Path to 3Dfy GLB file
        use_interactive: Whether to attempt trimesh scene viewer (default: False)

    Returns:
        tuple: (aligned_mesh, dfy_mesh) trimesh objects or (None, None) if failed
    """
    import matplotlib.pyplot as plt

    print("Loading meshes for visualization...")

    try:
        aligned_mesh = trimesh.load(aligned_mesh_path)
        print(f"Loaded aligned mesh: {len(aligned_mesh.vertices)} vertices")

        dfy_mesh = _load_glb_as_mesh(dfy_mesh_path)
        print(f"Loaded 3Dfy mesh: {len(dfy_mesh.vertices)} vertices")

        # Scatter both vertex sets in one shared 3D axis.
        fig = plt.figure(figsize=(12, 10))
        ax = fig.add_subplot(111, projection='3d')

        ax.scatter(dfy_mesh.vertices[:, 0], dfy_mesh.vertices[:, 1], dfy_mesh.vertices[:, 2],
                   c='blue', s=0.1, alpha=0.6, label='3Dfy Original')
        ax.scatter(aligned_mesh.vertices[:, 0], aligned_mesh.vertices[:, 1], aligned_mesh.vertices[:, 2],
                   c='red', s=0.1, alpha=0.6, label='SAM 3D Body Aligned')

        ax.set_title('Mesh Comparison: 3Dfy vs SAM 3D Body Aligned', fontsize=16, fontweight='bold')
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.legend()

        plt.tight_layout()
        plt.show()

        # Optional trimesh scene viewer; failures here are reported but not fatal.
        if use_interactive:
            try:
                print("Creating trimesh scene...")
                scene = trimesh.Scene()
                scene.add_geometry(_tinted_copy(aligned_mesh, [255, 0, 0, 200]), node_name="sam3d_aligned")
                scene.add_geometry(_tinted_copy(dfy_mesh, [0, 0, 255, 200]), node_name="dfy_original")

                print("Opening interactive trimesh viewer...")
                scene.show()
            except Exception as e:
                print(f"Trimesh viewer not available: {e}")

        print("Visualization complete")
        return aligned_mesh, dfy_mesh

    except Exception as e:
        print(f"ERROR in visualization: {e}")
        import traceback
        traceback.print_exc()
        return None, None