Argus / argus /utils /normalization.py
lixi042
Initial commit: Argus metric panoramic 3D reconstruction demo
510e990
Raw
History Blame Contribute Delete
2.99 kB
import torch
from typing import Optional, Tuple
from argus.utils.geometry import closed_form_inverse_se3
def cal_scale_by_points(points: torch.Tensor, point_masks: torch.Tensor) -> torch.Tensor:
    """Return the per-sample mean Euclidean norm of the masked-in points.

    Args:
        points: Tensor whose last axis holds the point coordinates. The norm
            is taken over that axis and axes 1, 2 and 3 of the result are
            reduced, so the caller in this file passes (B, S, H, W, 3).
        point_masks: Mask (or weights) multiplied element-wise with the
            per-point norms; entries equal to zero/False are excluded.

    Returns:
        One value per leading-axis entry (shape (B,) for the input above):
        the masked sum of norms divided by ``mask count + 1e-3``, clamped to
        the range [1e-6, 1e6].
    """
    # Calculate average distance of valid 3D points (batch-wise)
    dist = points.norm(dim=-1)
    weighted = dist * point_masks
    # NaN * 0 and Inf * 0 are both NaN, so a non-finite coordinate at a
    # masked-out location would otherwise poison the whole sum. Force those
    # entries to an exact zero instead of relying on the multiplication.
    weighted = torch.where(point_masks.bool(), weighted, torch.zeros_like(weighted))
    dist_sum = weighted.sum(dim=[1, 2, 3])  # Shape: [B,]
    valid_count = point_masks.sum(dim=[1, 2, 3])
    # The 1e-3 keeps the division defined when no point is valid; the clamp
    # bounds the resulting scale on both sides.
    avg_scale = (dist_sum / (valid_count + 1e-3)).clamp(min=1e-6, max=1e6)
    return avg_scale
def normalize_camera_extrinsics_and_points_batch(
    extrinsics: torch.Tensor,
    cam_points: torch.Tensor,
    depths: torch.Tensor,
    point_masks: torch.Tensor,
    scale_mode: str = "none",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Re-base a batch of cameras and points on each sample's first frame.

    Every frame's extrinsic is left-multiplied by the result of
    ``closed_form_inverse_se3`` applied to the sample's first extrinsic
    (helper defined elsewhere; presumably the rigid-transform inverse).
    The re-based rotation and translation are then applied to the
    camera-frame points, and all lengths are optionally divided by a scale.

    Args:
        extrinsics: (B, S, 4, 4) transforms; the shape is asserted.
        cam_points: Per-pixel camera-frame points, expected (B, S, H, W, 3).
        depths: Depth maps, expected (B, S, H, W).
        point_masks: Validity mask forwarded to ``cal_scale_by_points``;
            only used when ``scale_mode == "avg_dist"``.
        scale_mode: ``"none"`` leaves lengths untouched, ``"abs"`` divides
            them by the constant 10.0, ``"avg_dist"`` divides them by the
            per-sample scale returned by ``cal_scale_by_points`` on the
            transformed points.

    Returns:
        ``(new_extrinsics, new_cam_points, new_rotated_points,
        new_world_points, new_depths)``. The inputs are not modified:
        ``cam_points`` and ``depths`` are cloned and the other outputs are
        freshly computed tensors.

    Raises:
        AssertionError: If ``extrinsics`` is not (B, S, 4, 4).
        ValueError: If ``scale_mode`` is not one of the three known modes.
    """
    # Basic input validation
    assert extrinsics.ndim == 4 and extrinsics.shape[2:] == (4, 4), \
        f"Extrinsics must be (B, S, 4, 4), got {extrinsics.shape}"
    # Reject an unknown mode before any tensor work is done.
    if scale_mode not in ("avg_dist", "abs", "none"):
        raise ValueError(f"Unknown scale_mode: {scale_mode}")

    # Step 1: Transform all extrinsics to reference frame (1st frame of each batch)
    ref_extrinsics = extrinsics[:, 0, :, :]  # (B, 4, 4)
    ref_extr_inv = closed_form_inverse_se3(ref_extrinsics)
    new_extrinsics = torch.matmul(ref_extr_inv.unsqueeze(1), extrinsics)  # (B, S, 4, 4)

    # Step 2: Clone tensors so the in-place scaling below never touches the inputs
    new_depths = depths.clone()
    new_cam_points = cam_points.clone()

    # Step 3: Compute rotated/world points from new extrinsics
    R_new = new_extrinsics[:, :, :3, :3]  # (B, S, 3, 3)
    t_new = new_extrinsics[:, :, :3, 3]  # (B, S, 3)
    # (B,S,1,1,3,3) x (B,S,H,W,3,1) -> (B,S,H,W,3)
    new_rotated_points = torch.matmul(
        R_new.unsqueeze(2).unsqueeze(3), new_cam_points.unsqueeze(-1)
    ).squeeze(-1)
    new_world_points = new_rotated_points + t_new.unsqueeze(2).unsqueeze(3)

    # Step 4: Pick the divisor for each output rank
    if scale_mode == "avg_dist":
        avg_scale = cal_scale_by_points(new_world_points, point_masks)  # (B,)
        # Reshape the per-sample scale so it broadcasts against each target.
        trans_div = avg_scale.view(-1, 1, 1)  # translations (B, S, 3)
        depth_div = avg_scale.view(-1, 1, 1, 1)  # depths (B, S, H, W)
        point_div = avg_scale.view(-1, 1, 1, 1, 1)  # 3D points (B, S, H, W, 3)
    elif scale_mode == "abs":
        metric_scale = 10.0
        trans_div = depth_div = point_div = metric_scale
    else:  # "none": nothing to rescale
        return new_extrinsics, new_cam_points, new_rotated_points, new_world_points, new_depths

    # Only the translation column of the extrinsics carries length units;
    # the rotation part is left as is.
    new_extrinsics[:, :, :3, 3] /= trans_div
    new_depths /= depth_div
    new_cam_points /= point_div
    new_rotated_points /= point_div
    new_world_points /= point_div
    return new_extrinsics, new_cam_points, new_rotated_points, new_world_points, new_depths