# Copyright (c) Meta Platforms, Inc. and affiliates.
import torch
import numpy as np
import open3d as o3d
import trimesh
from pytorch3d.structures import Meshes
from pytorch3d.transforms import quaternion_to_matrix, Transform3d, matrix_to_quaternion
from sam3d_objects.data.dataset.tdfy.transforms_3d import compose_transform, decompose_transform
from sam3d_objects.data.dataset.tdfy.pose_target import PoseTargetConverter
from loguru import logger
from sam3d_objects.pipeline.layout_post_optimization_utils import (
run_ICP,
compute_iou,
set_seed,
apply_transform,
get_mesh,
get_mask_renderer,
run_alignment,
run_render_compare,
check_occlusion,
)
SLAT_STD = torch.tensor(
[
2.377650737762451,
2.386378288269043,
2.124418020248413,
2.1748552322387695,
2.663944721221924,
2.371192216873169,
2.6217446327209473,
2.684523105621338,
]
)
SLAT_MEAN = torch.tensor(
[
-2.1687545776367188,
-0.004347046371549368,
-0.13352349400520325,
-0.08418072760105133,
-0.5271206498146057,
0.7238689064979553,
-1.1414450407028198,
1.2039363384246826,
]
)
ROTATION_6D_MEAN = torch.tensor(
[
-0.06366084883674913,
0.008438224692279752,
0.00017084786438302483,
0.0007126610473540038,
-0.0030916726538816417,
0.5166093753457688,
]
)
ROTATION_6D_STD = torch.tensor(
[
0.6656971967514863,
0.6787012271867754,
0.30345010594844524,
0.4394504420678794,
0.39817973931717104,
0.6176286868761914,
]
)
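# The *_MEAN/*_STD statistics above de-normalize model outputs: the ROTATION_6D_*
# pair is consumed by pose_decoder() below; the SLAT pair presumably de-normalizes
# the shape-latent channels elsewhere in the pipeline.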
def layout_post_optimization(
Mesh,
Quaternion,
Translation,
Scale,
Mask,
Point_Map,
Intrinsics,
Enable_shape_ICP=True,
Enable_rendering_optimization=True,
min_size=512,
device=None,
):
set_seed(100)
if device is None:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# init transform and process mesh
Rotation = quaternion_to_matrix(Quaternion.squeeze(1))
center = Translation[0].clone()
tfm_ori = compose_transform(scale=Scale, rotation=Rotation, translation=Translation)
mesh, faces_idx, textures = get_mesh(Mesh, tfm_ori, device)
# get mask and renderer
mask, renderer = get_mask_renderer(Mask, min_size, Intrinsics, device)
# check occlusion
if check_occlusion(mask[0, 0].cpu().numpy(), Point_Map.cpu().numpy()):
return (
Quaternion,
Translation,
Scale,
-1.0,
False,
False,
)
# Step 1: Manual Alignment
source_points, target_points, center, tfm1, mesh, ori_iou, final_iou, flag_notgt = (
run_alignment(
Point_Map, mask, mesh, center, faces_idx, textures, renderer, device
)
)
# return original layout if no target points.
if flag_notgt:
return (
Quaternion,
Translation,
Scale,
-1.0,
False,
False,
)
# Step 2: Shape ICP
if Enable_shape_ICP:
Flag_ICP = True
points_aligned_icp, transformation = run_ICP(
mesh, source_points, target_points, threshold=0.05
)
mesh_ICP = Meshes(
verts=[points_aligned_icp], faces=[faces_idx], textures=textures
)
rendered = renderer(mesh_ICP)
ori_iou_shapeICP = compute_iou(
rendered[..., 3][0][None, None], mask, threshold=0.5
)
# determine whether accept ICP
if ori_iou_shapeICP > ori_iou:
mesh = mesh_ICP
final_iou = ori_iou_shapeICP.cpu().item()
T_o3d = torch.tensor(transformation, dtype=torch.float32, device=device)
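            # Open3D's ICP result is a 4x4 matrix acting on column vectors; transpose
            # into PyTorch3D's row-vector convention before splitting it into
            # scale / rotation / translation.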
T_o3d = T_o3d.T
A = T_o3d[:3, :3]
t = T_o3d[3, :3]
scale = A.norm(dim=1)
R = A / scale[:, None]
center = ((center[None] * scale) @ R + t)[0] # transform center
tfm2 = (
Transform3d(device=device)
.scale(scale[None])
.rotate(R[None])
.translate(t[None])
)
else:
Flag_ICP = False
scale_2, translation_2 = torch.tensor(1).to(device), torch.zeros([3]).to(
device
)
tfm2 = (
Transform3d(device=device)
.scale(scale_2.expand(3)[None])
.translate(translation_2[None])
)
else:
Flag_ICP = False
scale_2, translation_2 = torch.tensor(1).to(device), torch.zeros([3]).to(device)
tfm2 = (
Transform3d(device=device)
.scale(scale_2.expand(3)[None])
.translate(translation_2[None])
)
# Step 3: Render-and-Compare
if not Enable_rendering_optimization:
Flag_optim = False
tfm = tfm_ori.compose(tfm1).compose(tfm2)
else:
quat, translation, scale, R = run_render_compare(
mesh, center, renderer, mask, device
)
with torch.no_grad():
transformed = apply_transform(mesh, center, quat, translation, scale)
rendered = renderer(transformed)
optimized_iou = compute_iou(
rendered[..., 3][0][None, None], mask, threshold=0.5
)
        # Criterion for accepting the render-and-compare optimization
if optimized_iou < 0.5 or optimized_iou <= ori_iou:
Flag_optim = False
tfm = tfm_ori # reject manual alignment and ICP as well.
# tfm = tfm_ori.compose(tfm1).compose(tfm2) # only reject render-compare but keep manual alignment and ICP.
else:
Flag_optim = True
final_iou = optimized_iou.detach().cpu().item()
tfm3 = (
Transform3d(device=device)
.translate(-center[None]) # move to center
.scale(scale.expand(3)[None])
.rotate(R.T[None])
.translate(center[None]) # move back
.translate(translation[None])
)
tfm = tfm_ori.compose(tfm1).compose(tfm2).compose(tfm3)
M = tfm.get_matrix()[0]
T_final = M[3, :3][None]
A = M[:3, :3]
scale_final = A.norm(dim=1)[None]
R_final = A / scale_final[:, None]
quat_final = matrix_to_quaternion(R_final)
return (
quat_final,
T_final,
scale_final,
round(float(final_iou), 4),
Flag_ICP,
Flag_optim,
)
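# Illustrative call sketch; the tensor shapes are assumptions inferred from the code
# above (Quaternion.squeeze(1), Translation[0]), not a documented contract:
#
#   quat, trans, scale, iou, used_icp, used_optim = layout_post_optimization(
#       Mesh=mesh,            # mesh input consumed by get_mesh()
#       Quaternion=quat0,     # (1, 1, 4), wxyz order
#       Translation=trans0,   # (1, 3)
#       Scale=scale0,         # (1, 3)
#       Mask=mask0,           # instance mask consumed by get_mask_renderer()
#       Point_Map=point_map,  # per-pixel 3D points
#       Intrinsics=K,         # camera intrinsics
#   )
#   # iou == -1.0 signals an early exit (occluded mask or no target points).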
def pose_decoder(
pose_target_convention,
):
def decode(model_output_dict, scene_scale=None, scene_shift=None):
x = model_output_dict
# BEGIN: copied from generative.py
key_mapping = {
"shape": "x_shape_latent",
"quaternion": "x_instance_rotation",
"6drotation": "x_instance_rotation_6d",
"6drotation_normalized": "x_instance_rotation_6d_normalized",
"translation": "x_instance_translation",
"scale": "x_instance_scale",
"translation_scale": "x_translation_scale",
}
        # Remap model-output keys to pose-target keys (used for decoding and metrics)
pose_target_dict = {}
for k, v in x.items():
pose_target_dict[key_mapping.get(k, k)] = v
        # TODO: Hao & Bowen, please clean this up!
# Convert 6D rotation to quaternion if needed
if (
"x_instance_rotation_6d" in pose_target_dict
or "x_instance_rotation_6d_normalized" in pose_target_dict
):
# Extract the two 3D vectors
if "x_instance_rotation_6d_normalized" in pose_target_dict:
rot_6d = pose_target_dict[
"x_instance_rotation_6d_normalized"
] * ROTATION_6D_STD.to(
pose_target_dict["x_instance_rotation_6d_normalized"].device
) + ROTATION_6D_MEAN.to(
pose_target_dict["x_instance_rotation_6d_normalized"].device
)
else:
rot_6d = pose_target_dict["x_instance_rotation_6d"]
a1 = rot_6d[..., 0:3]
a2 = rot_6d[..., 3:6]
# Normalize first vector
b1 = torch.nn.functional.normalize(a1, dim=-1)
# Make second vector orthogonal to first
b2 = a2 - torch.sum(b1 * a2, dim=-1, keepdim=True) * b1
b2 = torch.nn.functional.normalize(b2, dim=-1)
# Compute third vector as cross product
b3 = torch.cross(b1, b2, dim=-1)
# Stack to create rotation matrix
rotation_matrix = torch.stack([b1, b2, b3], dim=-1)
# Convert to quaternion
quaternion = matrix_to_quaternion(rotation_matrix)
pose_target_dict["x_instance_rotation"] = quaternion
if "x_instance_scale" in pose_target_dict:
pose_target_dict["x_instance_scale"] = torch.exp(
pose_target_dict["x_instance_scale"]
)
if "x_translation_scale" in pose_target_dict:
pose_target_dict["x_translation_scale"] = torch.exp(
pose_target_dict["x_translation_scale"]
)
pose_target_dict["pose_target_convention"] = [pose_target_convention] * x[
"shape"
].shape[0]
# END: copied from generative.py
# Fake pointmap moments
device = x["shape"].device
_scene_scale = (
scene_scale if scene_scale is not None else torch.tensor(1.0, device=device)
)
_scene_shift = (
scene_shift
if scene_shift is not None
else torch.tensor([[0, 0, 0]], device=device)
)
pose_target_dict["x_scene_scale"] = _scene_scale
pose_target_dict["x_scene_center"] = _scene_shift
# Convert to instance pose
pose_instance_dict = PoseTargetConverter.dicts_pose_target_to_instance_pose(
pose_target_convention=pose_target_convention,
x_instance_scale=pose_target_dict["x_instance_scale"],
x_instance_translation=pose_target_dict["x_instance_translation"],
x_instance_rotation=pose_target_dict["x_instance_rotation"],
x_translation_scale=pose_target_dict["x_translation_scale"],
x_scene_scale=pose_target_dict["x_scene_scale"],
x_scene_center=pose_target_dict["x_scene_center"],
)
return {
"translation": pose_instance_dict["instance_position_l2c"].squeeze(0),
"rotation": pose_instance_dict["instance_quaternion_l2c"].squeeze(0),
"scale": pose_instance_dict["instance_scale_l2c"].squeeze(0).mean(-1, keepdim=True).expand(1,3),
}
return decode
def zero_prediction_decoder():
def decode(model_output_dict, scene_scale=None, scene_shift=None):
        import copy
_pose_decoder = pose_decoder("ScaleShiftInvariant")
model_output_dict = copy.deepcopy(model_output_dict)
logger.warning("Overwriting predictions to zero prediction")
model_output_dict["translation"] = torch.zeros_like(model_output_dict["translation"])
model_output_dict["translation_scale"] = torch.zeros_like(model_output_dict["translation_scale"])
model_output_dict["scale"] = torch.zeros_like(model_output_dict["scale"]) + 1.337 # Empirical average on R3
return _pose_decoder(model_output_dict, scene_scale, scene_shift)
return decode
def get_default_pose_decoder():
def decode(model_output_dict, **kwargs):
return {}
return decode
POSE_DECODERS = {
"default": get_default_pose_decoder(),
"ApparentSize": pose_decoder("ApparentSize"),
"DisparitySpace": pose_decoder("DisparitySpace"),
"ScaleShiftInvariant": pose_decoder("ScaleShiftInvariant"),
"ZeroPredictionScaleShiftInvariant": zero_prediction_decoder(),
}
def get_pose_decoder(name):
if name not in POSE_DECODERS:
raise NotImplementedError
return POSE_DECODERS[name]
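# Example: look up a decoder by convention name and run it on a model output dict.
#
#   decode = get_pose_decoder("ScaleShiftInvariant")
#   pose = decode(model_output_dict, scene_scale=None, scene_shift=None)
#   # -> {"translation": ..., "rotation": ..., "scale": ...}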
def prune_sparse_structure(
coord_batch,
max_neighbor_axes_dist=1,
):
coords, batch = coord_batch[:, 1:], coord_batch[:, 0].unsqueeze(-1)
device = coords.device
# 1) shift coords so minimum is zero
min_xyz = coords.min(0)[0]
coords0 = coords - min_xyz
# 2) build occupancy grid
max_xyz = coords0.max(0)[0] + 1 # size in each dim
D, H, W = max_xyz.tolist()
# shape (1,1,D,H,W)
occ = torch.zeros((1, 1, D, H, W), dtype=torch.uint8, device=device)
x, y, z = coords0.unbind(1)
occ[0, 0, x, y, z] = 1
    # 3) (2k+1)^3 convolution (3×3×3 for the default max_neighbor_axes_dist=1)
    #    to count each voxel plus its neighbors
kernel = torch.ones(
(
1,
1,
2 * max_neighbor_axes_dist + 1,
2 * max_neighbor_axes_dist + 1,
2 * max_neighbor_axes_dist + 1,
),
dtype=torch.uint8,
device=device,
)
# pad so output is same size
pad = max_neighbor_axes_dist
counts = torch.nn.functional.conv3d(occ.float(), kernel.float(), padding=pad)
# interior voxels have count == (2*max_neighbor_axes_dist+1)**3
full_count = (2 * max_neighbor_axes_dist + 1) ** 3
# 4) lookup counts at each original coord
counts_at_pts = counts[0, 0, x, y, z] # (N,)
is_surface = counts_at_pts < full_count
    # 5) return the filtered batch+coords (coordinates stay in the original, unshifted frame)
    kept = is_surface.nonzero(as_tuple=False).squeeze(1)
    out_batch = batch[kept]
    out_coords = coords[kept]
    return torch.cat([out_batch, out_coords], dim=1)
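# Minimal sanity-check sketch: a filled 3x3x3 cube keeps its 26 surface voxels and
# drops the single interior voxel, whose 3x3x3 neighborhood is fully occupied.
#
#   xyz = torch.stack(torch.meshgrid(*[torch.arange(3)] * 3, indexing="ij"), -1).reshape(-1, 3)
#   coord_batch = torch.cat([torch.zeros(27, 1, dtype=torch.long), xyz], dim=1)  # (27, 4)
#   surface = prune_sparse_structure(coord_batch)  # (26, 4); voxel (1, 1, 1) is pruned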
def downsample_sparse_structure(
coord_batch,
max_coords=42000,
downsample_factor=2,
):
"""
Downsample sparse structure coordinates when there are more than max_coords.
Downsamples by rescaling coordinates, effectively shrinking the grid while preserving
the structure. The downsampled grid is centered in the original space.
Args:
coord_batch: tensor of shape (N, 4) where [:, 0] is batch index and [:, 1:] are coords
max_coords: maximum number of coordinates to keep
            42000 should be a safe number. Calculation: max(int32) / (64*768) ~= 43691
Only needed for mesh decoding.
downsample_factor: factor by which to downsample (e.g., 2 means half resolution)
Returns:
Downsampled coord_batch with coordinates rescaled if downsampling is needed
"""
if coord_batch.shape[0] <= max_coords:
return coord_batch, 1
# Extract coordinates and batch indices
coords = coord_batch[:, 1:].float() # Shape: (N, 3), convert to float for scaling
batch_indices = coord_batch[:, 0:1] # Shape: (N, 1)
# Find the actual coordinate bounds
coords_min = coords.min(dim=0)[0] # Shape: (3,)
coords_max = coords.max(dim=0)[0] # Shape: (3,)
original_size = coords_max - coords_min + 1 # Add 1 since coordinates are discrete
# Calculate target size after downsampling
target_size = original_size / downsample_factor
# Calculate the offset to center the downsampled grid
offset = (original_size - target_size) / 2
target_min = coords_min + offset
target_max = coords_min + offset + target_size - 1
    # Normalize coordinates to [0, 1] within their actual range; clamp the
    # denominator to guard against degenerate (zero-extent) axes
    coords_normalized = (coords - coords_min) / (coords_max - coords_min).clamp(min=1e-6)
# Scale to the target range
coords_rescaled = coords_normalized * (target_size - 1) + target_min
# Round to integers to get discrete grid coordinates
coords_rescaled = torch.round(coords_rescaled).int()
# Clamp to ensure we stay within bounds
coords_rescaled = torch.clamp(coords_rescaled, target_min.int(), target_max.int())
# Remove duplicates that may have been created by the downsampling
# Concatenate batch and coords for duplicate removal
combined = torch.cat([batch_indices, coords_rescaled], dim=1)
unique_combined = torch.unique(combined, dim=0)
# If still too many after deduplication, randomly subsample
if unique_combined.shape[0] > max_coords:
indices = torch.randperm(unique_combined.shape[0], device=coord_batch.device)[
:max_coords
]
unique_combined = unique_combined[indices]
return unique_combined.int(), downsample_factor
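# Behavior sketch: inputs under the cap pass through unchanged with factor 1; larger
# inputs are rescaled, deduplicated, and randomly subsampled if still over the cap.
#
#   coords_out, factor = downsample_sparse_structure(coord_batch)
#   # factor == 1 if coord_batch.shape[0] <= max_coords, else == downsample_factor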
def normalize_mesh_verts(verts):
vmin = verts.min(axis=0)
vmax = verts.max(axis=0)
center = (vmax + vmin) / 2.0
    extent = vmax - vmin  # per-axis side lengths
max_extent = np.max(extent)
if max_extent == 0:
vertices = verts - center
scale = 1
else:
scale = 1.0 / max_extent
vertices = (verts - center) * scale
return vertices, scale, center
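# The returned (scale, center) invert the normalization:
#
#   normalized, scale, center = normalize_mesh_verts(verts)
#   recovered = normalized / scale + center  # == verts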
def voxelize_mesh(mesh, resolution=64):
verts = np.asarray(mesh.vertices)
# rotate mesh (from z-up to y-up)
verts = verts @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]).T
# normalize vertices
    # skip normalization when the vertices already span [-0.5, 0.5], to avoid shifting points
if np.abs(verts.min() + 0.5) < 1e-3 and np.abs(verts.max() - 0.5) < 1e-3:
vertices, scale, center = verts, None, None
else:
vertices, scale, center = normalize_mesh_verts(verts)
vertices = np.clip(vertices, -0.5 + 1e-6, 0.5 - 1e-6)
mesh.vertices = o3d.utility.Vector3dVector(vertices)
voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
mesh,
        voxel_size=1 / resolution,
min_bound=(-0.5, -0.5, -0.5),
max_bound=(0.5, 0.5, 0.5),
)
vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()])
    vertices = (vertices + 0.5) / resolution - 0.5
coords = ((torch.tensor(vertices) + 0.5) * resolution).int().contiguous()
ss = torch.zeros(1, resolution, resolution, resolution, dtype=torch.long)
ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
return ss, scale, center
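# Example: voxelize an o3d.geometry.TriangleMesh into a 64^3 occupancy grid. scale and
# center are None when the input was already normalized to [-0.5, 0.5].
#
#   ss, scale, center = voxelize_mesh(o3d_mesh, resolution=64)
#   # ss: (1, 64, 64, 64) long tensor with 1s at occupied voxels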
def preprocess_mesh(mesh: trimesh.Trimesh):
verts = mesh.vertices
if np.abs(verts.min() + 0.5) < 1e-3 and np.abs(verts.max() - 0.5) < 1e-3:
return mesh
vertices, _, _ = normalize_mesh_verts(verts)
mesh.vertices = vertices
return mesh
def trimesh2o3d_mesh(trimesh_mesh):
verts = np.asarray(trimesh_mesh.vertices)
faces = np.asarray(trimesh_mesh.faces)
return o3d.geometry.TriangleMesh(
o3d.utility.Vector3dVector(verts), o3d.utility.Vector3iVector(faces)
)
def update_layout(pred_t, pred_s, pred_quat, center, scale, to_halo=True):
if center is None and not to_halo:
return pred_t, pred_s, pred_quat
pred_transform = compose_transform(
pred_s, quaternion_to_matrix(pred_quat[0]), pred_t
)
if center is None:
comb_transform = pred_transform
else:
norm_transform = compose_transform(
scale * torch.ones_like(pred_t),
torch.eye(3, dtype=pred_t.dtype).to(pred_t.device)[None],
scale * -torch.tensor(center, dtype=pred_t.dtype).to(pred_t.device)[None],
)
comb_transform = norm_transform.compose(pred_transform)
comb_transform = convert_to_halo(comb_transform, pred_t.device, pred_t.dtype)
decomposed = decompose_transform(comb_transform)
quat = matrix_to_quaternion(decomposed.rotation)
return decomposed.translation, decomposed.scale, quat
def convert_to_halo(pred_transform, device, dtype):
on_mesh_transform = Transform3d(dtype=dtype, device=device).rotate(
torch.tensor(
[
[1, 0, 0],
[0, 0, 1],
[0, -1, 0],
],
dtype=dtype,
)
)
on_pm_transform = Transform3d(dtype=dtype, device=device).rotate(
torch.tensor(
[
[-1, 0, 0],
[0, -1, 0],
[0, 0, 1],
],
dtype=dtype,
)
)
return on_mesh_transform.compose(pred_transform).compose(on_pm_transform)
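# Note: the two fixed rotations sandwich the predicted transform, converting between
# coordinate conventions (per the helper names: the mesh frame on one side, the
# pointmap/Halo frame on the other); the exact axis conventions are implied by the
# matrices above.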
def quat_wxyz_to_euler_XYZ(q: torch.Tensor) -> torch.Tensor:
"""
Convert PyTorch3D quaternions (w,x,y,z) to SciPy-style Euler angles
with sequence 'XYZ' (extrinsic, radians). Works with batch dims.
Args:
q: (..., 4) tensor in w,x,y,z order. Doesn't need to be normalized.
Returns:
angles: (..., 3) tensor [alpha_X, beta_Y, gamma_Z] in radians.
"""
q = q / q.norm(dim=-1, keepdim=True) # normalize
R = quaternion_to_matrix(q) # (..., 3, 3)
R = R.transpose(-1, -2)
r00 = R[..., 0, 0]
r10 = R[..., 1, 0]
r20 = R[..., 2, 0]
r21 = R[..., 2, 1]
r22 = R[..., 2, 2]
# For extrinsic XYZ (R = Rz(gamma) @ Ry(beta) @ Rx(alpha)):
# beta = atan2(-r20, sqrt(r00^2 + r10^2))
# alpha = atan2(r21, r22)
# gamma = atan2(r10, r00)
eps = torch.finfo(R.dtype).eps
beta = torch.atan2(-r20, torch.clamp((r00 * r00 + r10 * r10).sqrt(), min=eps))
alpha = torch.atan2(r21, r22)
gamma = torch.atan2(r10, r00)
return -torch.stack((alpha, beta, gamma), dim=-1)
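# Sanity-check sketch: a 90-degree rotation about X decodes to roll ~= pi/2.
#
#   q = torch.tensor([0.7071068, 0.7071068, 0.0, 0.0])  # (w, x, y, z)
#   quat_wxyz_to_euler_XYZ(q)  # ~= tensor([1.5708, 0.0000, 0.0000])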
def format_to_halo(layout_output):
json_out = {}
quaternion = layout_output["quaternion"][0, 0]
translation = layout_output["translation"][0]
scale = list(layout_output["scale"][0])
euler = quat_wxyz_to_euler_XYZ(quaternion)
json_out["roll"] = float(euler[0])
json_out["pitch"] = float(euler[1])
json_out["yaw"] = float(euler[2])
json_out["pred_scale"] = [float(s) for s in scale]
rot_matrix = quaternion_to_matrix(quaternion)
pred_transform = torch.eye(4, dtype=quaternion.dtype).to(quaternion.device)
pred_transform[:3, :3] = rot_matrix
pred_transform[:3, 3] = translation
pred_transform_list = [
[float(t) for t in trans_row] for trans_row in pred_transform
]
json_out["pred_transform"] = pred_transform_list
return json_out
def json_to_halo_payloads(target_data):
pred_transform = target_data["pred_transform"]
pred_scale = target_data["pred_scale"]
roll = target_data.get("roll", 0)
pitch = target_data.get("pitch", 0)
yaw = target_data.get("yaw", 0)
# Update positions, rotation, and scale in the payload
item_attachments = {}
item_attachments["positions"] = {
"x": pred_transform[0][3],
"y": pred_transform[1][3],
"z": pred_transform[2][3] - 1, # Adjust for Halo design
}
item_attachments["rotation"] = {"x": roll, "y": pitch, "z": yaw}
item_attachments["scale"] = {
"x": pred_scale[0],
"y": pred_scale[1],
"z": pred_scale[2],
}
return item_attachments
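# Typical chaining with format_to_halo() above:
#
#   attachments = json_to_halo_payloads(format_to_halo(layout_output))
#   # -> {"positions": {...}, "rotation": {...}, "scale": {...}}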
def o3d_plane_estimation(points):
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points)
    plane_model, inliers = pcd.segment_plane(
        distance_threshold=0.02, ransac_n=3, num_iterations=1000
    )
[a, b, c, d] = plane_model
logger.info(f"Plane equation: {a:.2f}x + {b:.2f}y + {c:.2f}z + {d:.2f} = 0")
# Get the inlier points from RANSAC
inlier_points = np.asarray(pcd.points)[inliers]
# Adaptive flying point removal based on Z-range
z_range = np.max(inlier_points[:, 2]) - np.min(inlier_points[:, 2])
if z_range > 6.0: # Large range - likely flying points
thresh = 0.90 # Remove 10%
elif z_range > 2.0: # Moderate range
thresh = 0.93 # Remove 7%
else: # Small range - clean
thresh = 0.95 # Remove 5%
depth_quantile = np.quantile(inlier_points[:, 2], thresh)
clean_points = inlier_points[inlier_points[:, 2] <= depth_quantile]
logger.info(f"Flying point removal: {len(inlier_points)} -> {len(clean_points)} points (z_range: {z_range:.2f}m, thresh: {thresh})")
logger.info(f"Clean points Z range: [{clean_points[:, 2].min():.3f}, {clean_points[:, 2].max():.3f}]")
# Get the normal vector of the plane
normal = np.array([a, b, c])
normal = normal / np.linalg.norm(normal)
# Create two orthogonal vectors in the plane using camera-aware approach
# Use Z-axis as primary tangent (depth direction in camera coords)
# This helps align one plane axis with the camera's depth direction
if abs(normal[2]) < 0.9: # Use Z-axis if normal isn't too close to Z
tangent = np.array([0, 0, 1])
else:
tangent = np.array([1, 0, 0]) # Use X-axis otherwise
v1 = np.cross(normal, tangent)
v1 = v1 / np.linalg.norm(v1)
v2 = np.cross(normal, v1)
v2 = v2 / np.linalg.norm(v2) # Explicit normalization for numerical stability
# Ensure consistent right-handed coordinate system
if np.dot(np.cross(v1, v2), normal) < 0:
v2 = -v2
logger.info(f"Plane basis vectors - v1: [{v1[0]:.3f}, {v1[1]:.3f}, {v1[2]:.3f}], v2: [{v2[0]:.3f}, {v2[1]:.3f}, {v2[2]:.3f}]")
# Calculate centroid using bounding box center (more robust to density bias)
min_vals = np.min(clean_points, axis=0)
max_vals = np.max(clean_points, axis=0)
centroid = (min_vals + max_vals) / 2
logger.info(f"Bbox centroid: [{centroid[0]:.3f}, {centroid[1]:.3f}, {centroid[2]:.3f}]")
# Project clean points onto the plane's coordinate system
relative_points = clean_points - centroid
u_coords = np.dot(relative_points, v1) # coordinates along v1 direction
v_coords = np.dot(relative_points, v2) # coordinates along v2 direction
    # Flying points are already removed, so take the full [0, 100] percentile range (min/max)
u_min, u_max = np.percentile(u_coords, [0, 100])
v_min, v_max = np.percentile(v_coords, [0, 100])
# Calculate extents
u_extent = u_max - u_min
v_extent = v_max - v_min
# Ensure minimum size
u_extent = max(u_extent, 0.1) # minimum 10cm
v_extent = max(v_extent, 0.1)
logger.info(f"Plane size: {u_extent:.3f}m x {v_extent:.3f}m")
# Calculate direction away from camera center (at origin [0,0,0])
camera_pos = np.array([0, 0, 0]) # Camera at origin
camera_to_centroid = centroid - camera_pos # Direction from camera to plane center
camera_distance = np.linalg.norm(camera_to_centroid)
away_direction = camera_to_centroid / camera_distance
# Project away direction onto the plane (remove component normal to plane)
away_in_plane = away_direction - np.dot(away_direction, normal) * normal
away_in_plane_norm = np.linalg.norm(away_in_plane)
# Create plane coordinate system based on camera direction
if away_in_plane_norm > 1e-6: # Only if there's a meaningful in-plane component
# Define plane axes directly based on camera direction
away_axis = away_in_plane / away_in_plane_norm # Away from camera direction (in plane)
perp_axis = np.cross(normal, away_axis) # Perpendicular to away direction (in plane)
perp_axis = perp_axis / np.linalg.norm(perp_axis)
logger.info(f"Camera-based plane axes:")
logger.info(f" Away axis: [{away_axis[0]:.3f}, {away_axis[1]:.3f}, {away_axis[2]:.3f}]")
logger.info(f" Perp axis: [{perp_axis[0]:.3f}, {perp_axis[1]:.3f}, {perp_axis[2]:.3f}]")
# Project all points onto this camera-aligned coordinate system
relative_points = clean_points - centroid
away_coords = np.dot(relative_points, away_axis) # coordinates along away direction
perp_coords = np.dot(relative_points, perp_axis) # coordinates perpendicular to away
# Calculate extents in camera-aligned system
away_min, away_max = np.percentile(away_coords, [0, 100])
perp_min, perp_max = np.percentile(perp_coords, [0, 100])
away_extent = max(away_max - away_min, 0.1)
perp_extent = max(perp_max - perp_min, 0.1)
# Asymmetric extension: 10% towards camera, 50% away from camera, 20% perpendicular both sides
away_extent_extended = away_extent * 1.6 # 60% larger in away direction (10% + 50%)
perp_extent_extended = perp_extent * 1.4 # 40% larger in perpendicular direction (20% each side)
logger.info(f"Original extents: away={away_extent:.3f}m, perp={perp_extent:.3f}m")
logger.info(f"Extended extents: away={away_extent_extended:.3f}m, perp={perp_extent_extended:.3f}m")
# Extension amounts for each direction
away_extension_near = away_extent * 0.1 # 10% extension towards camera (near side)
away_extension_far = away_extent * 0.5 # 50% extension away from camera (far side)
perp_extension = perp_extent * 0.2 # 20% extension on each perpendicular side
logger.info(f"Extensions: near={away_extension_near:.3f}m, far={away_extension_far:.3f}m, perp={perp_extension:.3f}m per side")
logger.info(f"Extending plane asymmetrically: 10% towards camera, 50% away from camera, 20% perpendicular both sides")
corners = []
for da in [-1, 1]:
for dp in [-1, 1]:
# Asymmetric extension in away direction
if da == 1: # Away from camera side - extend by 50%
away_distance = away_extent/2 + away_extension_far
else: # Near camera side - extend by 10%
away_distance = da * (away_extent/2 + away_extension_near)
# Extend perpendicular direction by 20% on both sides
perp_distance = dp * (perp_extent/2 + perp_extension)
corner = (centroid +
away_distance * away_axis +
perp_distance * perp_axis)
corners.append(corner)
else:
        # If the viewing direction is (nearly) parallel to the plane normal, there is
        # no usable in-plane "away" axis; fall back to the original v1/v2 system
        logger.info("Viewing direction parallel to plane normal, using original v1/v2 coordinate system")
corners = []
for dx in [-1, 1]:
for dy in [-1, 1]:
corner = centroid + dx * (u_extent/2) * v1 + dy * (v_extent/2) * v2
corners.append(corner)
corners = np.array(corners)
# Create a quad mesh using trimesh
# Define vertices (4 corners)
vertices = corners
# Define a single quad face (indices of the 4 vertices)
# Make sure the order is correct for proper orientation
faces = np.array([[0, 1, 3, 2]]) # quad face
# Create trimesh with quad faces
# rotate mesh (from z-up to y-up)
vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
mesh = trimesh.Trimesh(
vertices=vertices,
faces=faces,
        process=False  # skip trimesh's automatic vertex merging/cleanup
)
# Optional: set face colors
mesh.visual.face_colors = [128, 128, 128, 255] # gray color (RGBA)
return mesh
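# Usage sketch (point_map is a hypothetical (H, W, 3) array of camera-frame points;
# the camera is assumed at the origin with +Z as depth, per the code above):
#
#   plane_mesh = o3d_plane_estimation(point_map.reshape(-1, 3))
#   plane_mesh.export("plane.glb")  # single gray quad as a trimesh.Trimesh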
def estimate_plane_area(mask):
"""
Calculate the area covered by the mask's 2D bounding box as a fraction of total image area.
"""
if mask.numel() == 0:
return 0.0
# Find coordinates where mask > 0.5 (valid mask pixels)
valid_mask = mask > 0.5
# If no valid pixels, return 0
if not torch.any(valid_mask):
return 0.0
# Get mask dimensions
H, W = mask.shape
total_area = H * W
# Find bounding box coordinates
# Get row and column indices of valid pixels
valid_coords = torch.nonzero(valid_mask, as_tuple=False) # Returns [N, 2] array of [row, col]
if valid_coords.size(0) == 0:
return 0.0
# Find min/max coordinates to form bounding box
min_row = torch.min(valid_coords[:, 0]).item()
max_row = torch.max(valid_coords[:, 0]).item()
min_col = torch.min(valid_coords[:, 1]).item()
max_col = torch.max(valid_coords[:, 1]).item()
# Calculate bounding box dimensions
bbox_height = max_row - min_row + 1
bbox_width = max_col - min_col + 1
bbox_area = bbox_height * bbox_width
# Return ratio of bounding box area to total image area
return bbox_area / total_area
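# Example: nonzero pixels spanning a 50x50 box in a 100x100 mask give 0.25.
#
#   mask = torch.zeros(100, 100)
#   mask[10:60, 20:70] = 1.0
#   estimate_plane_area(mask)  # 2500 / 10000 = 0.25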