"""
Inference utilities.
"""

import warnings
from typing import Any, Dict, List

import numpy as np
import torch

from mapanything.utils.geometry import (
    depth_edge,
    get_rays_in_camera_frame,
    normals_edge,
    points_to_normals,
    quaternion_to_rotation_matrix,
    recover_pinhole_intrinsics_from_ray_directions,
    rotation_matrix_to_quaternion,
)
from mapanything.utils.image import rgb


# Keys that may appear in an input view dictionary.
ALLOWED_VIEW_KEYS = {
    "img",
    "data_norm_type",
    "depth_z",
    "ray_directions",
    "intrinsics",
    "camera_poses",
    "is_metric_scale",
    "true_shape",
    "idx",
    "instance",
}

# Keys that every input view must provide.
REQUIRED_KEYS = {"img", "data_norm_type"}

# Key pairs that are mutually exclusive within a single view.
CONFLICTING_KEYS = [
    ("intrinsics", "ray_directions"),
]
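
# Illustrative sketch of a minimal valid view dictionary under the contract
# above (the shapes and the "dinov2" normalization name are assumptions for
# the example, not requirements of this module):
#
#   view = {
#       "img": torch.rand(1, 3, 518, 518),        # required
#       "data_norm_type": ["dinov2"],             # required
#       "intrinsics": torch.eye(3).unsqueeze(0),  # optional; excludes "ray_directions"
#   }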


def loss_of_one_batch_multi_view(
    batch,
    model,
    criterion,
    device,
    use_amp=False,
    amp_dtype="bf16",
    ret=None,
    ignore_keys=None,
):
    """
    Calculate the loss for a batch with multiple views.

    Args:
        batch (list): List of view dictionaries containing the input data.
        model (torch.nn.Module): Model to run inference with.
        criterion (callable, optional): Loss function; if None, no loss is computed.
        device (torch.device): Device to run the computation on.
        use_amp (bool, optional): Whether to use automatic mixed precision. Defaults to False.
        amp_dtype (str, optional): Floating point type to use for automatic mixed precision.
            Options: ["fp32", "fp16", "bf16"]. Defaults to "bf16".
        ret (str, optional): If provided, return only the specified key from the result dictionary.
        ignore_keys (set, optional): Set of keys to skip when moving tensors to the device.
            Defaults to {"depthmap", "dataset", "label", "instance", "idx", "true_shape",
            "rng", "data_norm_type", "scene_flow_compute_type"}.

    Returns:
        dict or Any: If ret is None, returns a dictionary containing views, predictions, and loss.
            Otherwise, returns the value associated with the ret key.
    """
    if ignore_keys is None:
        ignore_keys = {
            "depthmap",
            "dataset",
            "label",
            "instance",
            "idx",
            "true_shape",
            "rng",
            "data_norm_type",
            "scene_flow_compute_type",
        }

    # Move all tensor inputs to the target device, skipping metadata keys
    for view in batch:
        for name in view.keys():
            if name in ignore_keys:
                continue
            view[name] = view[name].to(device, non_blocking=True)

    # Resolve the autocast dtype, falling back to fp16 when bf16 is unsupported
    if use_amp:
        if amp_dtype == "fp16":
            amp_dtype = torch.float16
        elif amp_dtype == "bf16":
            if torch.cuda.is_bf16_supported():
                amp_dtype = torch.bfloat16
            else:
                warnings.warn(
                    "bf16 is not supported on this device. Using fp16 instead."
                )
                amp_dtype = torch.float16
        elif amp_dtype == "fp32":
            amp_dtype = torch.float32
    else:
        amp_dtype = torch.float32

    # Run the forward pass under autocast; compute the loss in full precision
    with torch.autocast("cuda", enabled=bool(use_amp), dtype=amp_dtype):
        preds = model(batch)
        with torch.autocast("cuda", enabled=False):
            loss = criterion(batch, preds) if criterion is not None else None

    result = {f"view{i + 1}": view for i, view in enumerate(batch)}
    result.update({f"pred{i + 1}": pred for i, pred in enumerate(preds)})
    result["loss"] = loss

    return result[ret] if ret else result
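
# Illustrative usage sketch (the model, criterion, and views below are
# placeholders, not part of this module):
#
#   batch = [view1, view2]  # one dict of input tensors per view
#   out = loss_of_one_batch_multi_view(
#       batch, model, criterion, torch.device("cuda"), use_amp=True
#   )
#   loss = out["loss"]    # None when criterion is None
#   pred1 = out["pred1"]  # model prediction for the first view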


def validate_input_views_for_inference(
    views: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """
    Strictly validate the input views.

    Args:
        views: List of view dictionaries

    Returns:
        The validated views (unchanged)

    Raises:
        ValueError: For invalid keys, missing required keys, conflicting inputs,
            or invalid camera pose constraints
    """
    if not views:
        raise ValueError("At least one view must be provided")

    views_with_poses = []

    for view_idx, view in enumerate(views):
        # Reject keys outside the allowed set
        provided_keys = set(view.keys())
        invalid_keys = provided_keys - ALLOWED_VIEW_KEYS
        if invalid_keys:
            raise ValueError(
                f"View {view_idx} contains invalid keys: {invalid_keys}. "
                f"Allowed keys are: {sorted(ALLOWED_VIEW_KEYS)}"
            )

        # Ensure all required keys are present
        missing_keys = REQUIRED_KEYS - provided_keys
        if missing_keys:
            raise ValueError(f"View {view_idx} missing required keys: {missing_keys}")

        # Reject mutually exclusive key combinations
        for conflict_set in CONFLICTING_KEYS:
            present_conflicts = [key for key in conflict_set if key in provided_keys]
            if len(present_conflicts) > 1:
                raise ValueError(
                    f"View {view_idx} contains conflicting keys: {present_conflicts}. "
                    f"Only one of {conflict_set} can be provided at a time."
                )

        # Z-depth is only meaningful together with camera calibration
        if "depth_z" in provided_keys:
            if (
                "intrinsics" not in provided_keys
                and "ray_directions" not in provided_keys
            ):
                raise ValueError(
                    f"View {view_idx} depth constraint violation: if 'depth_z' is provided, "
                    f"then 'intrinsics' or 'ray_directions' must also be provided. "
                    f"Z depth values require camera calibration information to be meaningful for an image."
                )

        if "camera_poses" in provided_keys:
            views_with_poses.append(view_idx)

    # When any view provides poses, view 0 must provide one to act as the reference frame
    if views_with_poses and 0 not in views_with_poses:
        raise ValueError(
            f"Camera pose constraint violation: Views {views_with_poses} have camera_poses, "
            f"but view 0 (reference view) does not. When using camera_poses, the first view "
            f"must also provide camera_poses to serve as the reference frame."
        )

    return views
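
# Illustrative sketch of the reference-view pose constraint (img0, img1, and
# pose are placeholders):
#
#   views = [
#       {"img": img0, "data_norm_type": ["dinov2"]},                        # no pose
#       {"img": img1, "data_norm_type": ["dinov2"], "camera_poses": pose},  # pose
#   ]
#   validate_input_views_for_inference(views)  # raises ValueError, because
#   # view 0 (the reference frame) must also carry camera_poses when any view does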


def preprocess_input_views_for_inference(
    views: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """
    Pre-process input views to match the expected internal input format.

    The following steps are performed:
    1. Convert intrinsics to ray directions when required. If ray directions are
       already provided, unit-normalize them.
    2. Convert depth_z to depth_along_ray.
    3. Convert camera_poses to the expected input keys (camera_pose_quats and
       camera_pose_trans).
    4. Default is_metric_scale to True when not provided.

    Args:
        views: List of view dictionaries

    Returns:
        Preprocessed views with a consistent internal format
    """
    processed_views = []

    for view_idx, view in enumerate(views):
        processed_view = dict(view)

        # Step 1: derive unit-norm ray directions from intrinsics, or unit-normalize
        # the provided ray directions
        if "intrinsics" in view:
            images = view["img"]
            height, width = images.shape[-2:]
            intrinsics = view["intrinsics"]
            _, ray_directions = get_rays_in_camera_frame(
                intrinsics=intrinsics,
                height=height,
                width=width,
                normalize_to_unit_sphere=True,
            )
            processed_view["ray_directions"] = ray_directions
            del processed_view["intrinsics"]
        elif "ray_directions" in view:
            ray_directions = view["ray_directions"]
            ray_norm = torch.norm(ray_directions, dim=-1, keepdim=True)
            processed_view["ray_directions"] = ray_directions / (ray_norm + 1e-8)

        # Step 2: convert z-depth to depth along the ray. Rescaling each ray so
        # its z-component is 1 makes depth_z * ray a camera-frame point, whose
        # norm is the distance along the ray.
        if "depth_z" in view:
            depth_z = view["depth_z"]
            ray_directions = processed_view["ray_directions"]
            ray_directions_unit_plane = ray_directions / ray_directions[..., 2:3]
            pts3d_cam = depth_z * ray_directions_unit_plane
            depth_along_ray = torch.norm(pts3d_cam, dim=-1, keepdim=True)
            processed_view["depth_along_ray"] = depth_along_ray
            del processed_view["depth_z"]

        # Step 3: split camera_poses into quaternion and translation components
        if "camera_poses" in view:
            camera_poses = view["camera_poses"]
            if isinstance(camera_poses, tuple) and len(camera_poses) == 2:
                quats, trans = camera_poses
                processed_view["camera_pose_quats"] = quats
                processed_view["camera_pose_trans"] = trans
            elif torch.is_tensor(camera_poses) and camera_poses.shape[-2:] == (4, 4):
                rotation_matrices = camera_poses[:, :3, :3]
                translation_vectors = camera_poses[:, :3, 3]
                quats = rotation_matrix_to_quaternion(rotation_matrices)
                processed_view["camera_pose_quats"] = quats
                processed_view["camera_pose_trans"] = translation_vectors
            else:
                raise ValueError(
                    f"View {view_idx}: camera_poses must be either a tuple of (quats, trans) "
                    f"or a tensor of (B, 4, 4) transformation matrices."
                )
            del processed_view["camera_poses"]

        # Step 4: default is_metric_scale to True for the whole batch
        if "is_metric_scale" not in processed_view:
            batch_size = view["img"].shape[0]
            processed_view["is_metric_scale"] = torch.ones(
                batch_size, dtype=torch.bool, device=view["img"].device
            )

        # Rename ray_directions to the internal key expected downstream
        if "ray_directions" in processed_view:
            processed_view["ray_directions_cam"] = processed_view["ray_directions"]
            del processed_view["ray_directions"]

        processed_views.append(processed_view)

    return processed_views
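
# Illustrative sketch of the key rewriting performed above (K and the shapes
# are assumptions for the example only):
#
#   views = [{
#       "img": torch.rand(1, 3, 518, 518),
#       "data_norm_type": ["dinov2"],
#       "intrinsics": K,                        # (1, 3, 3) pinhole intrinsics
#       "depth_z": torch.rand(1, 518, 518, 1),  # per-pixel z-depth
#   }]
#   out = preprocess_input_views_for_inference(views)[0]
#   # "intrinsics" -> "ray_directions_cam", "depth_z" -> "depth_along_ray",
#   # and "is_metric_scale" defaults to a True bool tensor of shape (1,)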


def postprocess_model_outputs_for_inference(
    raw_outputs: List[Dict[str, torch.Tensor]],
    input_views: List[Dict[str, Any]],
    apply_mask: bool = True,
    mask_edges: bool = True,
    edge_normal_threshold: float = 5.0,
    edge_depth_threshold: float = 0.03,
    apply_confidence_mask: bool = False,
    confidence_percentile: float = 10,
) -> List[Dict[str, torch.Tensor]]:
    """
    Post-process raw model outputs by copying them and adding essential derived fields.

    This function simplifies the raw model outputs by:
    1. Copying all raw outputs as-is
    2. Adding denormalized images (img_no_norm)
    3. Adding Z depth (depth_z) from camera frame points
    4. Recovering pinhole camera intrinsics from ray directions
    5. Adding camera pose matrices (camera_poses) if pose data is available
    6. Applying a mask to dense geometry outputs if requested (supports edge
       masking and confidence masking)

    Args:
        raw_outputs: List of raw model output dictionaries, one per view
        input_views: List of original input view dictionaries, one per view
        apply_mask: Whether to apply the non-ambiguous mask to dense outputs. Defaults to True.
        mask_edges: Whether to compute an edge mask based on normals and depth and apply it
            to the output. Defaults to True.
        edge_normal_threshold: Tolerance for the normal-based edge detection. Defaults to 5.0.
        edge_depth_threshold: Relative tolerance for the depth-based edge detection.
            Defaults to 0.03.
        apply_confidence_mask: Whether to apply the confidence mask to the output. Defaults to False.
        confidence_percentile: The percentile to use for the confidence threshold. Defaults to 10.

    Returns:
        List of processed output dictionaries containing:
            - All original raw outputs (after masking dense geometry outputs if requested)
            - 'img_no_norm': Denormalized RGB images (B, H, W, 3)
            - 'depth_z': Z depth from camera frame (B, H, W, 1) if points in camera frame are available
            - 'intrinsics': Recovered pinhole camera intrinsics (B, 3, 3) if ray directions are available
            - 'camera_poses': 4x4 pose matrices (B, 4, 4) if pose data is available
            - 'mask': Comprehensive mask for dense geometry outputs (B, H, W, 1) if requested
    """
    processed_outputs = []

    for view_idx, (raw_output, original_view) in enumerate(
        zip(raw_outputs, input_views)
    ):
        processed_output = dict(raw_output)

        # Denormalize the input image back to RGB in (B, H, W, 3) layout
        img = original_view["img"]
        data_norm_type = original_view["data_norm_type"][0]
        img_hwc = rgb(img, data_norm_type)

        if isinstance(img_hwc, np.ndarray):
            img_hwc = torch.from_numpy(img_hwc).to(img.device)

        processed_output["img_no_norm"] = img_hwc

        # Z depth is the z-component of the camera-frame points
        if "pts3d_cam" in processed_output:
            processed_output["depth_z"] = processed_output["pts3d_cam"][..., 2:3]

        # Recover pinhole intrinsics from the predicted ray directions
        if "ray_directions" in processed_output:
            intrinsics = recover_pinhole_intrinsics_from_ray_directions(
                processed_output["ray_directions"]
            )
            processed_output["intrinsics"] = intrinsics

        # Assemble 4x4 pose matrices from predicted quaternions and translations
        if "cam_trans" in processed_output and "cam_quats" in processed_output:
            cam_trans = processed_output["cam_trans"]
            cam_quats = processed_output["cam_quats"]
            batch_size = cam_trans.shape[0]

            rotation_matrices = quaternion_to_rotation_matrix(cam_quats)

            pose_matrices = (
                torch.eye(4, device=img.device).unsqueeze(0).repeat(batch_size, 1, 1)
            )
            pose_matrices[:, :3, :3] = rotation_matrices
            pose_matrices[:, :3, 3] = cam_trans

            processed_output["camera_poses"] = pose_matrices

        if apply_mask:
            final_mask = None

            # Start from the model's non-ambiguous mask when available
            if "non_ambiguous_mask" in processed_output:
                non_ambiguous_mask = (
                    processed_output["non_ambiguous_mask"].cpu().numpy()
                )
                final_mask = non_ambiguous_mask

            # Optionally intersect with a per-sample confidence mask that keeps
            # values above the given percentile
            if apply_confidence_mask and "conf" in processed_output:
                confidences = processed_output["conf"].cpu()

                batch_size = confidences.shape[0]
                percentile_threshold = (
                    torch.quantile(
                        confidences.reshape(batch_size, -1),
                        confidence_percentile / 100.0,
                        dim=1,
                    )
                    .unsqueeze(-1)
                    .unsqueeze(-1)
                )

                conf_mask = confidences > percentile_threshold
                conf_mask = conf_mask.numpy()

                if final_mask is not None:
                    final_mask = final_mask & conf_mask
                else:
                    final_mask = conf_mask

            # Optionally remove pixels lying on depth/normal discontinuities
            if mask_edges and final_mask is not None and "pts3d" in processed_output:
                pred_pts3d = processed_output["pts3d"].cpu().numpy()
                batch_size, height, width = final_mask.shape

                edge_masks = []
                for b in range(batch_size):
                    batch_final_mask = final_mask[b]
                    batch_pts3d = pred_pts3d[b]

                    if batch_final_mask.any():
                        # Edges where the surface normals change sharply
                        normals, normals_mask = points_to_normals(
                            batch_pts3d, mask=batch_final_mask
                        )
                        normal_edges = normals_edge(
                            normals, tol=edge_normal_threshold, mask=normals_mask
                        )

                        # Edges where the z-depth changes sharply
                        depth_z = (
                            processed_output["depth_z"][b].squeeze(-1).cpu().numpy()
                        )
                        depth_edges = depth_edge(
                            depth_z, rtol=edge_depth_threshold, mask=batch_final_mask
                        )

                        # Keep pixels that are not flagged by both edge detectors
                        edge_mask = ~(depth_edges & normal_edges)
                        edge_masks.append(edge_mask)
                    else:
                        # Nothing valid in this sample; mask everything out
                        edge_masks.append(np.zeros_like(batch_final_mask, dtype=bool))

                edge_mask = np.stack(edge_masks, axis=0)
                final_mask = final_mask & edge_mask

            if final_mask is not None:
                # Zero out the dense geometry outputs outside the final mask
                final_mask_torch = torch.from_numpy(final_mask).to(
                    processed_output["pts3d"].device
                )
                final_mask_torch = final_mask_torch.unsqueeze(-1)

                dense_geometry_keys = [
                    "pts3d",
                    "pts3d_cam",
                    "depth_along_ray",
                    "depth_z",
                ]
                for key in dense_geometry_keys:
                    if key in processed_output:
                        processed_output[key] = processed_output[key] * final_mask_torch

                processed_output["mask"] = final_mask_torch

        processed_outputs.append(processed_output)

    return processed_outputs
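
# Illustrative end-to-end sketch tying the helpers together (the model call is
# a placeholder; this module does not define the model interface):
#
#   views = validate_input_views_for_inference(views)
#   views = preprocess_input_views_for_inference(views)
#   raw_outputs = model(views)  # assumed: one output dict per view
#   outputs = postprocess_model_outputs_for_inference(raw_outputs, views)
#   pts3d = outputs[0]["pts3d"]  # masked world-frame points for view 0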