# Copyright (c) Meta Platforms, Inc. and affiliates.
import torch
import numpy as np
import open3d as o3d
import trimesh
from pytorch3d.structures import Meshes
from pytorch3d.transforms import quaternion_to_matrix, Transform3d, matrix_to_quaternion
from sam3d_objects.data.dataset.tdfy.transforms_3d import compose_transform, decompose_transform
from sam3d_objects.data.dataset.tdfy.pose_target import PoseTargetConverter
from loguru import logger
from sam3d_objects.pipeline.layout_post_optimization_utils import (
    run_ICP,
    compute_iou,
    set_seed,
    apply_transform,
    get_mesh,
    get_mask_renderer,
    run_alignment,
    run_render_compare,
    check_occlusion,
)

SLAT_STD = torch.tensor(
    [
        2.377650737762451,
        2.386378288269043,
        2.124418020248413,
        2.1748552322387695,
        2.663944721221924,
        2.371192216873169,
        2.6217446327209473,
        2.684523105621338,
    ]
)
SLAT_MEAN = torch.tensor(
    [
        -2.1687545776367188,
        -0.004347046371549368,
        -0.13352349400520325,
        -0.08418072760105133,
        -0.5271206498146057,
        0.7238689064979553,
        -1.1414450407028198,
        1.2039363384246826,
    ]
)
ROTATION_6D_MEAN = torch.tensor(
    [
        -0.06366084883674913,
        0.008438224692279752,
        0.00017084786438302483,
        0.0007126610473540038,
        -0.0030916726538816417,
        0.5166093753457688,
    ]
)
ROTATION_6D_STD = torch.tensor(
    [
        0.6656971967514863,
        0.6787012271867754,
        0.30345010594844524,
        0.4394504420678794,
        0.39817973931717104,
        0.6176286868761914,
    ]
)


def layout_post_optimization(
    Mesh,
    Quaternion,
    Translation,
    Scale,
    Mask,
    Point_Map,
    Intrinsics,
    Enable_shape_ICP=True,
    Enable_rendering_optimization=True,
    min_size=512,
    device=None,
):
    set_seed(100)
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # init transform and process mesh
    Rotation = quaternion_to_matrix(Quaternion.squeeze(1))
    center = Translation[0].clone()
    tfm_ori = compose_transform(scale=Scale, rotation=Rotation, translation=Translation)
    mesh, faces_idx, textures = get_mesh(Mesh, tfm_ori, device)

    # get mask and renderer
    mask, renderer = get_mask_renderer(Mask, min_size, Intrinsics, device)

    # check occlusion
    if check_occlusion(mask[0, 0].cpu().numpy(), Point_Map.cpu().numpy()):
        return (
            Quaternion,
            Translation,
            Scale,
            -1.0,
            False,
            False,
        )

    # Step 1: Manual Alignment
    source_points, target_points, center, tfm1, mesh, ori_iou, final_iou, flag_notgt = (
        run_alignment(
            Point_Map, mask, mesh, center, faces_idx, textures, renderer, device
        )
    )
    # return original layout if no target points.
    if flag_notgt:
        return (
            Quaternion,
            Translation,
            Scale,
            -1.0,
            False,
            False,
        )

    # Step 2: Shape ICP
    if Enable_shape_ICP:
        Flag_ICP = True
        points_aligned_icp, transformation = run_ICP(
            mesh, source_points, target_points, threshold=0.05
        )
        mesh_ICP = Meshes(
            verts=[points_aligned_icp], faces=[faces_idx], textures=textures
        )
        rendered = renderer(mesh_ICP)
        ori_iou_shapeICP = compute_iou(
            rendered[..., 3][0][None, None], mask, threshold=0.5
        )
        # decide whether to accept the ICP result
        if ori_iou_shapeICP > ori_iou:
            mesh = mesh_ICP
            final_iou = ori_iou_shapeICP.cpu().item()
            T_o3d = torch.tensor(transformation, dtype=torch.float32, device=device)
            T_o3d = T_o3d.T
            A = T_o3d[:3, :3]
            t = T_o3d[3, :3]
            scale = A.norm(dim=1)
            R = A / scale[:, None]
            center = ((center[None] * scale) @ R + t)[0]  # transform center
            tfm2 = (
                Transform3d(device=device)
                .scale(scale[None])
                .rotate(R[None])
                .translate(t[None])
            )
        else:
            Flag_ICP = False
            scale_2, translation_2 = torch.tensor(1).to(device), torch.zeros([3]).to(
                device
            )
            tfm2 = (
                Transform3d(device=device)
                .scale(scale_2.expand(3)[None])
                .translate(translation_2[None])
            )
    else:
        Flag_ICP = False
        scale_2, translation_2 = torch.tensor(1).to(device), torch.zeros([3]).to(device)
        tfm2 = (
            Transform3d(device=device)
            .scale(scale_2.expand(3)[None])
            .translate(translation_2[None])
        )

    # Step 3: Render-and-Compare
    if not Enable_rendering_optimization:
        Flag_optim = False
        tfm = tfm_ori.compose(tfm1).compose(tfm2)
    else:
        quat, translation, scale, R = run_render_compare(
            mesh, center, renderer, mask, device
        )
        with torch.no_grad():
            transformed = apply_transform(mesh, center, quat, translation, scale)
            rendered = renderer(transformed)
            optimized_iou = compute_iou(
                rendered[..., 3][0][None, None], mask, threshold=0.5
            )
        # Criterion to decide whether to keep the layout optimization result
        if optimized_iou < 0.5 or optimized_iou <= ori_iou:
            Flag_optim = False
            tfm = tfm_ori  # reject manual alignment and ICP as well.
            # tfm = tfm_ori.compose(tfm1).compose(tfm2)  # only reject render-compare but keep manual alignment and ICP.
        else:
            Flag_optim = True
            final_iou = optimized_iou.detach().cpu().item()
            tfm3 = (
                Transform3d(device=device)
                .translate(-center[None])  # move to center
                .scale(scale.expand(3)[None])
                .rotate(R.T[None])
                .translate(center[None])  # move back
                .translate(translation[None])
            )
            tfm = tfm_ori.compose(tfm1).compose(tfm2).compose(tfm3)

    M = tfm.get_matrix()[0]
    T_final = M[3, :3][None]
    A = M[:3, :3]
    scale_final = A.norm(dim=1)[None]
    R_final = A / scale_final[:, None]
    quat_final = matrix_to_quaternion(R_final)
    return (
        quat_final,
        T_final,
        scale_final,
        round(float(final_iou), 4),
        Flag_ICP,
        Flag_optim,
    )


def pose_decoder(
    pose_target_convention,
):
    def decode(model_output_dict, scene_scale=None, scene_shift=None):
        x = model_output_dict
        # BEGIN: copied from generative.py
        key_mapping = {
            "shape": "x_shape_latent",
            "quaternion": "x_instance_rotation",
            "6drotation": "x_instance_rotation_6d",
            "6drotation_normalized": "x_instance_rotation_6d_normalized",
            "translation": "x_instance_translation",
            "scale": "x_instance_scale",
            "translation_scale": "x_translation_scale",
        }

        # Decodes for metrics
        pose_target_dict = {}
        for k, v in x.items():
            pose_target_dict[key_mapping.get(k, k)] = v

        # TODO: Hao & Bowen please do clean this up!
        # Convert 6D rotation to quaternion if needed
        if (
            "x_instance_rotation_6d" in pose_target_dict
            or "x_instance_rotation_6d_normalized" in pose_target_dict
        ):
            # Extract the two 3D vectors
            if "x_instance_rotation_6d_normalized" in pose_target_dict:
                rot_6d = pose_target_dict[
                    "x_instance_rotation_6d_normalized"
                ] * ROTATION_6D_STD.to(
                    pose_target_dict["x_instance_rotation_6d_normalized"].device
                ) + ROTATION_6D_MEAN.to(
                    pose_target_dict["x_instance_rotation_6d_normalized"].device
                )
            else:
                rot_6d = pose_target_dict["x_instance_rotation_6d"]
            a1 = rot_6d[..., 0:3]
            a2 = rot_6d[..., 3:6]

            # Normalize first vector
            b1 = torch.nn.functional.normalize(a1, dim=-1)

            # Make second vector orthogonal to first
            b2 = a2 - torch.sum(b1 * a2, dim=-1, keepdim=True) * b1
            b2 = torch.nn.functional.normalize(b2, dim=-1)

            # Compute third vector as cross product
            b3 = torch.cross(b1, b2, dim=-1)

            # Stack to create rotation matrix
            rotation_matrix = torch.stack([b1, b2, b3], dim=-1)

            # Convert to quaternion
            quaternion = matrix_to_quaternion(rotation_matrix)
            pose_target_dict["x_instance_rotation"] = quaternion

        if "x_instance_scale" in pose_target_dict:
            pose_target_dict["x_instance_scale"] = torch.exp(
                pose_target_dict["x_instance_scale"]
            )
        if "x_translation_scale" in pose_target_dict:
            pose_target_dict["x_translation_scale"] = torch.exp(
                pose_target_dict["x_translation_scale"]
            )

        pose_target_dict["pose_target_convention"] = [pose_target_convention] * x[
            "shape"
        ].shape[0]
        # END: copied from generative.py

        # Fake pointmap moments
        device = x["shape"].device
        _scene_scale = (
            scene_scale if scene_scale is not None else torch.tensor(1.0, device=device)
        )
        _scene_shift = (
            scene_shift
            if scene_shift is not None
            else torch.tensor([[0, 0, 0]], device=device)
        )
        pose_target_dict["x_scene_scale"] = _scene_scale
        pose_target_dict["x_scene_center"] = _scene_shift

        # Convert to instance pose
        pose_instance_dict = PoseTargetConverter.dicts_pose_target_to_instance_pose(
            pose_target_convention=pose_target_convention,
            x_instance_scale=pose_target_dict["x_instance_scale"],
            x_instance_translation=pose_target_dict["x_instance_translation"],
            x_instance_rotation=pose_target_dict["x_instance_rotation"],
            x_translation_scale=pose_target_dict["x_translation_scale"],
            x_scene_scale=pose_target_dict["x_scene_scale"],
            x_scene_center=pose_target_dict["x_scene_center"],
        )
        return {
            "translation": pose_instance_dict["instance_position_l2c"].squeeze(0),
            "rotation": pose_instance_dict["instance_quaternion_l2c"].squeeze(0),
            "scale": pose_instance_dict["instance_scale_l2c"]
            .squeeze(0)
            .mean(-1, keepdim=True)
            .expand(1, 3),
        }

    return decode
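
# Illustrative sanity check (an addition for documentation only; the helper name is
# made up and it is not used by the pipeline): the Gram-Schmidt step inside `decode`
# maps the canonical 6D rotation [1, 0, 0, 0, 1, 0] to the identity rotation, i.e. the
# unit quaternion (1, 0, 0, 0) in (w, x, y, z) order.
def _example_6d_rotation_to_quaternion_identity():
    rot_6d = torch.tensor([[1.0, 0.0, 0.0, 0.0, 1.0, 0.0]])
    a1, a2 = rot_6d[..., 0:3], rot_6d[..., 3:6]
    b1 = torch.nn.functional.normalize(a1, dim=-1)
    b2 = torch.nn.functional.normalize(
        a2 - torch.sum(b1 * a2, dim=-1, keepdim=True) * b1, dim=-1
    )
    b3 = torch.cross(b1, b2, dim=-1)
    quaternion = matrix_to_quaternion(torch.stack([b1, b2, b3], dim=-1))
    # Expected output: tensor([[1., 0., 0., 0.]])
    return quaternion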
def zero_prediction_decoder():
    def decode(model_output_dict, scene_scale=None, scene_shift=None):
        import copy

        from loguru import logger

        _pose_decoder = pose_decoder("ScaleShiftInvariant")
        model_output_dict = copy.deepcopy(model_output_dict)
        logger.warning("Overwriting predictions to zero prediction")
        model_output_dict["translation"] = torch.zeros_like(
            model_output_dict["translation"]
        )
        model_output_dict["translation_scale"] = torch.zeros_like(
            model_output_dict["translation_scale"]
        )
        model_output_dict["scale"] = (
            torch.zeros_like(model_output_dict["scale"]) + 1.337
        )  # Empirical average on R3
        return _pose_decoder(model_output_dict, scene_scale, scene_shift)

    return decode


def get_default_pose_decoder():
    def decode(model_output_dict, **kwargs):
        return {}

    return decode


POSE_DECODERS = {
    "default": get_default_pose_decoder(),
    "ApparentSize": pose_decoder("ApparentSize"),
    "DisparitySpace": pose_decoder("DisparitySpace"),
    "ScaleShiftInvariant": pose_decoder("ScaleShiftInvariant"),
    "ZeroPredictionScaleShiftInvariant": zero_prediction_decoder(),
}


def get_pose_decoder(name):
    if name not in POSE_DECODERS:
        raise NotImplementedError
    return POSE_DECODERS[name]


def prune_sparse_structure(
    coord_batch,
    max_neighbor_axes_dist=1,
):
    coords, batch = coord_batch[:, 1:], coord_batch[:, 0].unsqueeze(-1)
    device = coords.device

    # 1) shift coords so minimum is zero
    min_xyz = coords.min(0)[0]
    coords0 = coords - min_xyz

    # 2) build occupancy grid
    max_xyz = coords0.max(0)[0] + 1  # size in each dim
    D, H, W = max_xyz.tolist()
    # shape (1,1,D,H,W)
    occ = torch.zeros((1, 1, D, H, W), dtype=torch.uint8, device=device)
    x, y, z = coords0.unbind(1)
    occ[0, 0, x, y, z] = 1

    # 3) 3×3×3 convolution to count each voxel + neighbors
    kernel = torch.ones(
        (
            1,
            1,
            2 * max_neighbor_axes_dist + 1,
            2 * max_neighbor_axes_dist + 1,
            2 * max_neighbor_axes_dist + 1,
        ),
        dtype=torch.uint8,
        device=device,
    )
    # pad so output is same size
    pad = max_neighbor_axes_dist
    counts = torch.nn.functional.conv3d(occ.float(), kernel.float(), padding=pad)

    # interior voxels have count == (2*max_neighbor_axes_dist+1)**3
    full_count = (2 * max_neighbor_axes_dist + 1) ** 3

    # 4) lookup counts at each original coord
    counts_at_pts = counts[0, 0, x, y, z]  # (N,)
    is_surface = counts_at_pts < full_count

    # 5) return filtered batch+coords (shift back if you want original coords)
    kept = is_surface.nonzero(as_tuple=False).squeeze(1)
    out_batch = batch[kept]
    out_coords = coords[kept]
    return torch.cat([out_batch, out_coords], dim=1)
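
# Illustrative sanity check (an addition for documentation only; the helper name is
# made up and it is not used by the pipeline): pruning a fully occupied 3x3x3 block
# keeps the 26 surface voxels and drops the single interior voxel, whose 3x3x3
# neighbourhood is completely filled.
def _example_prune_solid_cube():
    xyz = torch.stack(
        torch.meshgrid(
            torch.arange(3), torch.arange(3), torch.arange(3), indexing="ij"
        ),
        dim=-1,
    ).reshape(-1, 3)
    coord_batch = torch.cat([torch.zeros(27, 1, dtype=torch.long), xyz], dim=1)
    pruned = prune_sparse_structure(coord_batch)
    # Expected: 26 rows remain; only the centre voxel (batch 0, coords 1, 1, 1) is removed.
    return pruned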
def downsample_sparse_structure(
    coord_batch,
    max_coords=42000,
    downsample_factor=2,
):
    """
    Downsample sparse structure coordinates when there are more than max_coords.

    Downsamples by rescaling coordinates, effectively shrinking the grid while
    preserving the structure. The downsampled grid is centered in the original space.

    Args:
        coord_batch: tensor of shape (N, 4) where [:, 0] is batch index and [:, 1:] are coords
        max_coords: maximum number of coordinates to keep.
            42000 should be a safe number. Calculation: max(int32) / (64*768) ~= 43691.
            Only needed for mesh decoding.
        downsample_factor: factor by which to downsample (e.g., 2 means half resolution)

    Returns:
        Tuple of (coord_batch, factor): the coordinates, rescaled if downsampling was
        needed, and the downsample factor that was applied (1 if left untouched).
    """
    if coord_batch.shape[0] <= max_coords:
        return coord_batch, 1

    # Extract coordinates and batch indices
    coords = coord_batch[:, 1:].float()  # Shape: (N, 3), convert to float for scaling
    batch_indices = coord_batch[:, 0:1]  # Shape: (N, 1)

    # Find the actual coordinate bounds
    coords_min = coords.min(dim=0)[0]  # Shape: (3,)
    coords_max = coords.max(dim=0)[0]  # Shape: (3,)
    original_size = coords_max - coords_min + 1  # Add 1 since coordinates are discrete

    # Calculate target size after downsampling
    target_size = original_size / downsample_factor

    # Calculate the offset to center the downsampled grid
    offset = (original_size - target_size) / 2
    target_min = coords_min + offset
    target_max = coords_min + offset + target_size - 1

    # Normalize coordinates to [0, 1] within their actual range
    coords_normalized = (coords - coords_min) / (coords_max - coords_min)

    # Scale to the target range
    coords_rescaled = coords_normalized * (target_size - 1) + target_min

    # Round to integers to get discrete grid coordinates
    coords_rescaled = torch.round(coords_rescaled).int()

    # Clamp to ensure we stay within bounds
    coords_rescaled = torch.clamp(coords_rescaled, target_min.int(), target_max.int())

    # Remove duplicates that may have been created by the downsampling
    # Concatenate batch and coords for duplicate removal
    combined = torch.cat([batch_indices, coords_rescaled], dim=1)
    unique_combined = torch.unique(combined, dim=0)

    # If still too many after deduplication, randomly subsample
    if unique_combined.shape[0] > max_coords:
        indices = torch.randperm(unique_combined.shape[0], device=coord_batch.device)[
            :max_coords
        ]
        unique_combined = unique_combined[indices]

    return unique_combined.int(), downsample_factor


def normalize_mesh_verts(verts):
    vmin = verts.min(axis=0)
    vmax = verts.max(axis=0)
    center = (vmax + vmin) / 2.0
    extent = vmax - vmin
    # largest side length
    max_extent = np.max(extent)
    if max_extent == 0:
        vertices = verts - center
        scale = 1
    else:
        scale = 1.0 / max_extent
        vertices = (verts - center) * scale
    return vertices, scale, center
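
# Illustrative sanity check (an addition for documentation only; the helper name is
# made up and it is not used by the pipeline): normalize_mesh_verts recenters the
# vertices and scales the longest side to unit length, so the result fits inside the
# [-0.5, 0.5] cube.
def _example_normalize_mesh_verts():
    verts = np.array([[0.0, 0.0, 0.0], [2.0, 1.0, 1.0]])
    vertices, scale, center = normalize_mesh_verts(verts)
    # Expected: scale == 0.5, center == [1.0, 0.5, 0.5], and
    # vertices == [[-0.5, -0.25, -0.25], [0.5, 0.25, 0.25]].
    return vertices, scale, center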
def voxelize_mesh(mesh, resolution=64):
    verts = np.asarray(mesh.vertices)
    # rotate mesh (from z-up to y-up)
    verts = verts @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]).T
    # skip normalization when the vertices already lie in [-0.5, 0.5], to avoid losing points
    if np.abs(verts.min() + 0.5) < 1e-3 and np.abs(verts.max() - 0.5) < 1e-3:
        vertices, scale, center = verts, None, None
    else:
        vertices, scale, center = normalize_mesh_verts(verts)
    vertices = np.clip(vertices, -0.5 + 1e-6, 0.5 - 1e-6)
    mesh.vertices = o3d.utility.Vector3dVector(vertices)
    voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
        mesh,
        voxel_size=1 / 64,
        min_bound=(-0.5, -0.5, -0.5),
        max_bound=(0.5, 0.5, 0.5),
    )
    vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()])
    vertices = (vertices + 0.5) / 64 - 0.5
    coords = ((torch.tensor(vertices) + 0.5) * resolution).int().contiguous()
    ss = torch.zeros(1, resolution, resolution, resolution, dtype=torch.long)
    ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
    return ss, scale, center


def preprocess_mesh(mesh: trimesh.Trimesh):
    verts = mesh.vertices
    if np.abs(verts.min() + 0.5) < 1e-3 and np.abs(verts.max() - 0.5) < 1e-3:
        return mesh
    vertices, _, _ = normalize_mesh_verts(verts)
    mesh.vertices = vertices
    return mesh


def trimesh2o3d_mesh(trimesh_mesh):
    verts = np.asarray(trimesh_mesh.vertices)
    faces = np.asarray(trimesh_mesh.faces)
    return o3d.geometry.TriangleMesh(
        o3d.utility.Vector3dVector(verts), o3d.utility.Vector3iVector(faces)
    )


def update_layout(pred_t, pred_s, pred_quat, center, scale, to_halo=True):
    if center is None and not to_halo:
        return pred_t, pred_s, pred_quat
    pred_transform = compose_transform(
        pred_s, quaternion_to_matrix(pred_quat[0]), pred_t
    )
    if center is None:
        comb_transform = pred_transform
    else:
        norm_transform = compose_transform(
            scale * torch.ones_like(pred_t),
            torch.eye(3, dtype=pred_t.dtype).to(pred_t.device)[None],
            scale * -torch.tensor(center, dtype=pred_t.dtype).to(pred_t.device)[None],
        )
        comb_transform = norm_transform.compose(pred_transform)
    comb_transform = convert_to_halo(comb_transform, pred_t.device, pred_t.dtype)
    decomposed = decompose_transform(comb_transform)
    quat = matrix_to_quaternion(decomposed.rotation)
    return decomposed.translation, decomposed.scale, quat


def convert_to_halo(pred_transform, device, dtype):
    on_mesh_transform = Transform3d(dtype=dtype, device=device).rotate(
        torch.tensor(
            [
                [1, 0, 0],
                [0, 0, 1],
                [0, -1, 0],
            ],
            dtype=dtype,
        )
    )
    on_pm_transform = Transform3d(dtype=dtype, device=device).rotate(
        torch.tensor(
            [
                [-1, 0, 0],
                [0, -1, 0],
                [0, 0, 1],
            ],
            dtype=dtype,
        )
    )
    return on_mesh_transform.compose(pred_transform).compose(on_pm_transform)


def quat_wxyz_to_euler_XYZ(q: torch.Tensor) -> torch.Tensor:
    """
    Convert PyTorch3D quaternions (w,x,y,z) to SciPy-style Euler angles
    with sequence 'XYZ' (extrinsic, radians). Works with batch dims.

    Args:
        q: (..., 4) tensor in w,x,y,z order. Doesn't need to be normalized.

    Returns:
        angles: (..., 3) tensor [alpha_X, beta_Y, gamma_Z] in radians.
    """
    q = q / q.norm(dim=-1, keepdim=True)  # normalize
    R = quaternion_to_matrix(q)  # (..., 3, 3)
    R = R.transpose(-1, -2)

    r00 = R[..., 0, 0]
    r10 = R[..., 1, 0]
    r20 = R[..., 2, 0]
    r21 = R[..., 2, 1]
    r22 = R[..., 2, 2]

    # For extrinsic XYZ (R = Rz(gamma) @ Ry(beta) @ Rx(alpha)):
    #   beta  = atan2(-r20, sqrt(r00^2 + r10^2))
    #   alpha = atan2(r21, r22)
    #   gamma = atan2(r10, r00)
    eps = torch.finfo(R.dtype).eps
    beta = torch.atan2(-r20, torch.clamp((r00 * r00 + r10 * r10).sqrt(), min=eps))
    alpha = torch.atan2(r21, r22)
    gamma = torch.atan2(r10, r00)

    return -torch.stack((alpha, beta, gamma), dim=-1)
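
# Illustrative sanity check (an addition for documentation only; the helper name is
# made up and it is not used by the pipeline): the identity quaternion
# (w, x, y, z) = (1, 0, 0, 0) corresponds to zero roll, pitch and yaw under the
# extrinsic XYZ convention used above.
def _example_identity_quaternion_to_euler():
    q = torch.tensor([1.0, 0.0, 0.0, 0.0])
    euler = quat_wxyz_to_euler_XYZ(q)
    # Expected output: tensor([0., 0., 0.]) (radians).
    return euler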
def format_to_halo(layout_output):
    json_out = {}
    quaternion = layout_output["quaternion"][0, 0]
    translation = layout_output["translation"][0]
    scale = list(layout_output["scale"][0])
    euler = quat_wxyz_to_euler_XYZ(quaternion)
    json_out["roll"] = float(euler[0])
    json_out["pitch"] = float(euler[1])
    json_out["yaw"] = float(euler[2])
    json_out["pred_scale"] = [float(s) for s in scale]

    rot_matrix = quaternion_to_matrix(quaternion)
    pred_transform = torch.eye(4, dtype=quaternion.dtype).to(quaternion.device)
    pred_transform[:3, :3] = rot_matrix
    pred_transform[:3, 3] = translation
    pred_transform_list = [
        [float(t) for t in trans_row] for trans_row in pred_transform
    ]
    json_out["pred_transform"] = pred_transform_list
    return json_out


def json_to_halo_payloads(target_data):
    pred_transform = target_data["pred_transform"]
    pred_scale = target_data["pred_scale"]
    roll = target_data.get("roll", 0)
    pitch = target_data.get("pitch", 0)
    yaw = target_data.get("yaw", 0)

    # Update positions, rotation, and scale in the payload
    item_attachments = {}
    item_attachments["positions"] = {
        "x": pred_transform[0][3],
        "y": pred_transform[1][3],
        "z": pred_transform[2][3] - 1,  # Adjust for Halo design
    }
    item_attachments["rotation"] = {"x": roll, "y": pitch, "z": yaw}
    item_attachments["scale"] = {
        "x": pred_scale[0],
        "y": pred_scale[1],
        "z": pred_scale[2],
    }
    return item_attachments
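
# Illustrative sanity check (an addition for documentation only; the helper name and
# values are made up and it is not used by the pipeline): a 4x4 transform whose last
# column holds the translation (1, 2, 3) yields positions x=1, y=2 and z=2 (the -1
# offset comes from the Halo design adjustment above), zero rotation and unit scale.
def _example_json_to_halo_payloads():
    target_data = {
        "pred_transform": [
            [1.0, 0.0, 0.0, 1.0],
            [0.0, 1.0, 0.0, 2.0],
            [0.0, 0.0, 1.0, 3.0],
            [0.0, 0.0, 0.0, 1.0],
        ],
        "pred_scale": [1.0, 1.0, 1.0],
    }
    return json_to_halo_payloads(target_data)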
def o3d_plane_estimation(points):
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    plane_model, inliers = pcd.segment_plane(0.02, 3, 1000)
    [a, b, c, d] = plane_model
    logger.info(f"Plane equation: {a:.2f}x + {b:.2f}y + {c:.2f}z + {d:.2f} = 0")

    # Get the inlier points from RANSAC
    inlier_points = np.asarray(pcd.points)[inliers]

    # Adaptive flying point removal based on Z-range
    z_range = np.max(inlier_points[:, 2]) - np.min(inlier_points[:, 2])
    if z_range > 6.0:  # Large range - likely flying points
        thresh = 0.90  # Remove 10%
    elif z_range > 2.0:  # Moderate range
        thresh = 0.93  # Remove 7%
    else:  # Small range - clean
        thresh = 0.95  # Remove 5%
    depth_quantile = np.quantile(inlier_points[:, 2], thresh)
    clean_points = inlier_points[inlier_points[:, 2] <= depth_quantile]
    logger.info(
        f"Flying point removal: {len(inlier_points)} -> {len(clean_points)} points "
        f"(z_range: {z_range:.2f}m, thresh: {thresh})"
    )
    logger.info(
        f"Clean points Z range: [{clean_points[:, 2].min():.3f}, {clean_points[:, 2].max():.3f}]"
    )

    # Get the normal vector of the plane
    normal = np.array([a, b, c])
    normal = normal / np.linalg.norm(normal)

    # Create two orthogonal vectors in the plane using a camera-aware approach.
    # Use the Z-axis as the primary tangent (depth direction in camera coords);
    # this helps align one plane axis with the camera's depth direction.
    if abs(normal[2]) < 0.9:  # Use Z-axis if normal isn't too close to Z
        tangent = np.array([0, 0, 1])
    else:
        tangent = np.array([1, 0, 0])  # Use X-axis otherwise
    v1 = np.cross(normal, tangent)
    v1 = v1 / np.linalg.norm(v1)
    v2 = np.cross(normal, v1)
    v2 = v2 / np.linalg.norm(v2)  # Explicit normalization for numerical stability

    # Ensure consistent right-handed coordinate system
    if np.dot(np.cross(v1, v2), normal) < 0:
        v2 = -v2
    logger.info(
        f"Plane basis vectors - v1: [{v1[0]:.3f}, {v1[1]:.3f}, {v1[2]:.3f}], "
        f"v2: [{v2[0]:.3f}, {v2[1]:.3f}, {v2[2]:.3f}]"
    )

    # Calculate centroid using bounding box center (more robust to density bias)
    min_vals = np.min(clean_points, axis=0)
    max_vals = np.max(clean_points, axis=0)
    centroid = (min_vals + max_vals) / 2
    logger.info(f"Bbox centroid: [{centroid[0]:.3f}, {centroid[1]:.3f}, {centroid[2]:.3f}]")

    # Project clean points onto the plane's coordinate system
    relative_points = clean_points - centroid
    u_coords = np.dot(relative_points, v1)  # coordinates along v1 direction
    v_coords = np.dot(relative_points, v2)  # coordinates along v2 direction

    # Since flying points are already removed, use the full [0, 100] percentile range
    u_min, u_max = np.percentile(u_coords, [0, 100])
    v_min, v_max = np.percentile(v_coords, [0, 100])

    # Calculate extents
    u_extent = u_max - u_min
    v_extent = v_max - v_min

    # Ensure minimum size
    u_extent = max(u_extent, 0.1)  # minimum 10cm
    v_extent = max(v_extent, 0.1)
    logger.info(f"Plane size: {u_extent:.3f}m x {v_extent:.3f}m")

    # Calculate direction away from camera center (at origin [0,0,0])
    camera_pos = np.array([0, 0, 0])  # Camera at origin
    camera_to_centroid = centroid - camera_pos  # Direction from camera to plane center
    camera_distance = np.linalg.norm(camera_to_centroid)
    away_direction = camera_to_centroid / camera_distance

    # Project away direction onto the plane (remove component normal to plane)
    away_in_plane = away_direction - np.dot(away_direction, normal) * normal
    away_in_plane_norm = np.linalg.norm(away_in_plane)

    # Create plane coordinate system based on camera direction
    if away_in_plane_norm > 1e-6:  # only if there's a meaningful in-plane component
        # Define plane axes directly based on camera direction
        away_axis = away_in_plane / away_in_plane_norm  # Away-from-camera direction (in plane)
        perp_axis = np.cross(normal, away_axis)  # Perpendicular to away direction (in plane)
        perp_axis = perp_axis / np.linalg.norm(perp_axis)
        logger.info("Camera-based plane axes:")
        logger.info(f"  Away axis: [{away_axis[0]:.3f}, {away_axis[1]:.3f}, {away_axis[2]:.3f}]")
        logger.info(f"  Perp axis: [{perp_axis[0]:.3f}, {perp_axis[1]:.3f}, {perp_axis[2]:.3f}]")

        # Project all points onto this camera-aligned coordinate system
        relative_points = clean_points - centroid
        away_coords = np.dot(relative_points, away_axis)  # coordinates along away direction
        perp_coords = np.dot(relative_points, perp_axis)  # coordinates perpendicular to away

        # Calculate extents in camera-aligned system
        away_min, away_max = np.percentile(away_coords, [0, 100])
        perp_min, perp_max = np.percentile(perp_coords, [0, 100])
        away_extent = max(away_max - away_min, 0.1)
        perp_extent = max(perp_max - perp_min, 0.1)

        # Asymmetric extension: 10% towards camera, 50% away from camera, 20% perpendicular both sides
        away_extent_extended = away_extent * 1.6  # 60% larger in away direction (10% + 50%)
        perp_extent_extended = perp_extent * 1.4  # 40% larger in perpendicular direction (20% each side)
        logger.info(f"Original extents: away={away_extent:.3f}m, perp={perp_extent:.3f}m")
        logger.info(
            f"Extended extents: away={away_extent_extended:.3f}m, perp={perp_extent_extended:.3f}m"
        )

        # Extension amounts for each direction
        away_extension_near = away_extent * 0.1  # 10% extension towards camera (near side)
        away_extension_far = away_extent * 0.5  # 50% extension away from camera (far side)
        perp_extension = perp_extent * 0.2  # 20% extension on each perpendicular side
        logger.info(
            f"Extensions: near={away_extension_near:.3f}m, far={away_extension_far:.3f}m, "
            f"perp={perp_extension:.3f}m per side"
        )
        logger.info(
            "Extending plane asymmetrically: 10% towards camera, 50% away from camera, "
            "20% perpendicular both sides"
        )

        corners = []
        for da in [-1, 1]:
            for dp in [-1, 1]:
                # Asymmetric extension in away direction
                if da == 1:  # Away from camera side - extend by 50%
                    away_distance = away_extent / 2 + away_extension_far
                else:  # Near camera side - extend by 10%
                    away_distance = da * (away_extent / 2 + away_extension_near)
                # Extend perpendicular direction by 20% on both sides
                perp_distance = dp * (perp_extent / 2 + perp_extension)
                corner = centroid + away_distance * away_axis + perp_distance * perp_axis
                corners.append(corner)
    else:
        # If plane is parallel to camera direction, use original v1/v2 system
        logger.info("Plane parallel to camera direction, using original coordinate system")
        corners = []
        for dx in [-1, 1]:
            for dy in [-1, 1]:
                corner = centroid + dx * (u_extent / 2) * v1 + dy * (v_extent / 2) * v2
                corners.append(corner)

    corners = np.array(corners)

    # Create a quad mesh using trimesh
    # Define vertices (4 corners)
    vertices = corners
    # Define a single quad face (indices of the 4 vertices)
    # Make sure the order is correct for proper orientation
    faces = np.array([[0, 1, 3, 2]])  # quad face

    # Create trimesh with quad faces
    # rotate mesh (from z-up to y-up)
    vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
    mesh = trimesh.Trimesh(
        vertices=vertices,
        faces=faces,
        process=False,  # Important: prevents automatic triangulation
    )

    # Optional: set face colors
    mesh.visual.face_colors = [128, 128, 128, 255]  # gray color (RGBA)
    return mesh
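
# Illustrative usage sketch (an addition for documentation only; the helper name and
# the synthetic point cloud are made up and it is not used by the pipeline): fitting a
# plane to a flat grid of points two metres in front of the camera should recover
# roughly the plane z - 2 = 0 and return a single-quad trimesh spanning the grid.
def _example_plane_from_synthetic_points():
    xs, ys = np.meshgrid(np.linspace(-1.0, 1.0, 50), np.linspace(-1.0, 1.0, 50))
    points = np.stack([xs.ravel(), ys.ravel(), np.full(xs.size, 2.0)], axis=1)
    return o3d_plane_estimation(points)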
def estimate_plane_area(mask):
    """
    Calculate the area covered by the mask's 2D bounding box as a fraction of the
    total image area.
    """
    if mask.numel() == 0:
        return 0.0

    # Find coordinates where mask > 0.5 (valid mask pixels)
    valid_mask = mask > 0.5

    # If no valid pixels, return 0
    if not torch.any(valid_mask):
        return 0.0

    # Get mask dimensions
    H, W = mask.shape
    total_area = H * W

    # Find bounding box coordinates
    # Get row and column indices of valid pixels
    valid_coords = torch.nonzero(valid_mask, as_tuple=False)  # Returns [N, 2] array of [row, col]
    if valid_coords.size(0) == 0:
        return 0.0

    # Find min/max coordinates to form bounding box
    min_row = torch.min(valid_coords[:, 0]).item()
    max_row = torch.max(valid_coords[:, 0]).item()
    min_col = torch.min(valid_coords[:, 1]).item()
    max_col = torch.max(valid_coords[:, 1]).item()

    # Calculate bounding box dimensions
    bbox_height = max_row - min_row + 1
    bbox_width = max_col - min_col + 1
    bbox_area = bbox_height * bbox_width

    # Return ratio of bounding box area to total image area
    return bbox_area / total_area
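
# Illustrative sanity check (an addition for documentation only; the helper name is
# made up and it is not used by the pipeline): a 2x3 block of foreground pixels inside
# a 10x10 mask has a bounding box covering 6 of the 100 pixels, so the estimated
# plane-area fraction is 0.06.
def _example_estimate_plane_area():
    mask = torch.zeros(10, 10)
    mask[2:4, 4:7] = 1.0
    area_fraction = estimate_plane_area(mask)
    # Expected output: 0.06
    return area_fraction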