from __future__ import annotations

import torch
import torch.nn as nn
from addict import Dict
from omegaconf import DictConfig, OmegaConf

from depth_anything_3.cfg import create_object
from depth_anything_3.model.utils.transform import pose_encoding_to_extri_intri
from depth_anything_3.utils.alignment import (
    apply_metric_scaling,
    compute_alignment_mask,
    compute_sky_mask,
    least_squares_scale_scalar,
    sample_tensor_for_quantile,
    set_sky_regions_to_max_depth,
)
from depth_anything_3.utils.geometry import affine_inverse, as_homogeneous, map_pdf_to_opacity
from depth_anything_3.utils.ray_utils import get_extrinsic_from_camray


def _wrap_cfg(cfg_obj):
    """Wrap a plain config object in an OmegaConf container for create_object."""
    return OmegaConf.create(cfg_obj)
|
|
class DepthAnything3Net(nn.Module):
    """
    Depth Anything 3 network for depth estimation and camera pose estimation.

    This network consists of:
    - Backbone: DINOv2 feature extractor
    - Head: DPT or DualDPT for depth prediction
    - Optional camera encoder/decoder for pose estimation
    - Optional GSDPT head for 3DGS prediction

    Args:
        net: Backbone module, or a config from which to create it
        head: Depth head module, or a config from which to create it
        cam_dec: Optional camera pose decoder module or config
        cam_enc: Optional camera pose encoder module or config
        gs_head: Optional 3DGS head module or config
        gs_adapter: Optional 3DGS adapter module or config

    The forward pass returns a dictionary containing:
        - depth: Predicted depth map (B, H, W)
        - depth_conf: Depth confidence map (B, H, W)
        - extrinsics: Camera extrinsics (B, N, 4, 4)
        - intrinsics: Camera intrinsics (B, N, 3, 3)
        - gaussians: 3D Gaussian Splats (world space), type: model.gs_adapter.Gaussians
        - aux: Auxiliary features for the requested layers
    """

    PATCH_SIZE = 14
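    # Usage sketch (illustrative; the config objects here are assumptions, in the
    # repo they come from yaml presets):
    #   model = DepthAnything3Net(net=backbone_cfg, head=head_cfg)
    #   images = torch.rand(1, 2, 3, 518, 518)  # (B, N, 3, H, W); H, W divisible by PATCH_SIZE
    #   out = model(images)
    #   out.depth, out.depth_conf  # predicted depth and confidence maps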
|
    def __init__(self, net, head, cam_dec=None, cam_enc=None, gs_head=None, gs_adapter=None):
        """
        Initialize DepthAnything3Net from modules or their yaml-initialized configurations.
        """
        super().__init__()
        self.backbone = net if isinstance(net, nn.Module) else create_object(_wrap_cfg(net))
        self.head = head if isinstance(head, nn.Module) else create_object(_wrap_cfg(head))
        self.cam_dec, self.cam_enc = None, None
        if cam_dec is not None:
            self.cam_dec = (
                cam_dec if isinstance(cam_dec, nn.Module) else create_object(_wrap_cfg(cam_dec))
            )
            self.cam_enc = (
                cam_enc if isinstance(cam_enc, nn.Module) else create_object(_wrap_cfg(cam_enc))
            )
        self.gs_adapter, self.gs_head = None, None
        if gs_head is not None and gs_adapter is not None:
            self.gs_adapter = (
                gs_adapter
                if isinstance(gs_adapter, nn.Module)
                else create_object(_wrap_cfg(gs_adapter))
            )
            # The GS head must emit the adapter's raw Gaussian channels plus one
            # density channel (mapped to opacity in _process_gs_head).
            gs_out_dim = self.gs_adapter.d_in + 1
            if isinstance(gs_head, nn.Module):
                assert (
                    gs_head.out_dim == gs_out_dim
                ), f"gs_head.out_dim should be {gs_out_dim}, got {gs_head.out_dim}"
                self.gs_head = gs_head
            else:
                assert (
                    gs_head["output_dim"] == gs_out_dim
                ), f"gs_head output_dim should be set to {gs_out_dim}, got {gs_head['output_dim']}"
                self.gs_head = create_object(_wrap_cfg(gs_head))
|
    def forward(
        self,
        x: torch.Tensor,
        extrinsics: torch.Tensor | None = None,
        intrinsics: torch.Tensor | None = None,
        export_feat_layers: list[int] | None = None,
        infer_gs: bool = False,
        use_ray_pose: bool = False,
        ref_view_strategy: str = "saddle_balanced",
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the network.

        Args:
            x: Input images (B, N, 3, H, W)
            extrinsics: Camera extrinsics (B, N, 4, 4)
            intrinsics: Camera intrinsics (B, N, 3, 3)
            export_feat_layers: Layer indices to export auxiliary features from
            infer_gs: Enable the Gaussian Splatting branch
            use_ray_pose: Use ray-based pose estimation
            ref_view_strategy: Strategy for selecting the reference view

        Returns:
            Dictionary containing predictions and auxiliary features
        """
        # Normalize None to an empty list (avoids a shared mutable default argument).
        export_feat_layers = [] if export_feat_layers is None else export_feat_layers
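        # Encode ground-truth cameras, when given, into a token that conditions the
        # backbone; the encoder runs in fp32 (autocast disabled).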
        if extrinsics is not None:
            with torch.autocast(device_type=x.device.type, enabled=False):
                cam_token = self.cam_enc(extrinsics, intrinsics, x.shape[-2:])
        else:
            cam_token = None

        feats, aux_feats = self.backbone(
            x,
            cam_token=cam_token,
            export_feat_layers=export_feat_layers,
            ref_view_strategy=ref_view_strategy,
        )

        H, W = x.shape[-2], x.shape[-1]
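        # All prediction heads run with autocast disabled, presumably to keep
        # depth, pose, and Gaussian regression numerically stable in fp32.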
        with torch.autocast(device_type=x.device.type, enabled=False):
            output = self._process_depth_head(feats, H, W)
            if use_ray_pose:
                output = self._process_ray_pose_estimation(output, H, W)
            else:
                output = self._process_camera_estimation(feats, H, W, output)
            if infer_gs:
                output = self._process_gs_head(feats, H, W, output, x, extrinsics, intrinsics)

        output = self._process_mono_sky_estimation(output)
        output.aux = self._extract_auxiliary_features(aux_feats, export_feat_layers, H, W)
        return output
|
    def _process_mono_sky_estimation(
        self, output: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Clamp sky regions of the mono depth output to a finite maximum depth."""
        if "sky" not in output:
            return output
        non_sky_mask = compute_sky_mask(output.sky, threshold=0.3)
        # Skip when either class is nearly empty; the quantile below would be meaningless.
        if non_sky_mask.sum() <= 10:
            return output
        if (~non_sky_mask).sum() <= 10:
            return output

        # Robust maximum: 99th percentile of (subsampled) non-sky depth.
        non_sky_depth = output.depth[non_sky_mask]
        sampled_depth = sample_tensor_for_quantile(non_sky_depth, max_samples=100000)
        non_sky_max = torch.quantile(sampled_depth, 0.99)
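        # set_sky_regions_to_max_depth writes non_sky_max into the sky pixels; the
        # confidence map is not adjusted in the mono path (None is passed and the
        # second return value is discarded).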
        output.depth, _ = set_sky_regions_to_max_depth(
            output.depth, None, non_sky_mask, max_depth=non_sky_max
        )
        return output
|
    def _process_ray_pose_estimation(
        self, output: Dict[str, torch.Tensor], height: int, width: int
    ) -> Dict[str, torch.Tensor]:
        """Recover camera extrinsics and intrinsics from predicted ray maps."""
        if "ray" in output and "ray_conf" in output:
            pred_extrinsic, pred_focal_lengths, pred_principal_points = get_extrinsic_from_camray(
                output.ray,
                output.ray_conf,
                output.ray.shape[-3],
                output.ray.shape[-2],
            )
            # Invert the recovered pose so extrinsics are stored world-to-camera,
            # matching the camera-decoder path, and drop the homogeneous row.
            pred_extrinsic = affine_inverse(pred_extrinsic)
            pred_extrinsic = pred_extrinsic[:, :, :3, :]
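            # The focal lengths and principal points appear to be normalized to half
            # the image dimensions (a value of 1.0 maps to W/2 or H/2 pixels); the
            # assignments below convert them to a pixel-unit intrinsic matrix. This
            # convention is inferred from the scaling code here, not from
            # get_extrinsic_from_camray itself.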
            B, N = pred_extrinsic.shape[:2]
            pred_intrinsic = torch.eye(3, device=pred_extrinsic.device)
            pred_intrinsic = pred_intrinsic[None, None].repeat(B, N, 1, 1)
            pred_intrinsic[:, :, 0, 0] = pred_focal_lengths[:, :, 0] / 2 * width
            pred_intrinsic[:, :, 1, 1] = pred_focal_lengths[:, :, 1] / 2 * height
            pred_intrinsic[:, :, 0, 2] = pred_principal_points[:, :, 0] * width * 0.5
            pred_intrinsic[:, :, 1, 2] = pred_principal_points[:, :, 1] * height * 0.5
            del output.ray
            del output.ray_conf
            output.extrinsics = pred_extrinsic
            output.intrinsics = pred_intrinsic
        return output
|
    def _process_depth_head(
        self, feats: list[torch.Tensor], H: int, W: int
    ) -> Dict[str, torch.Tensor]:
        """Process features through the depth prediction head."""
        return self.head(feats, H, W, patch_start_idx=0)
|
    def _process_camera_estimation(
        self, feats: list[torch.Tensor], H: int, W: int, output: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Estimate camera poses with the camera decoder, if one is configured."""
        if self.cam_dec is not None:
            pose_enc = self.cam_dec(feats[-1][1])
            # The decoded pose supersedes any ray-based predictions.
            if "ray" in output:
                del output.ray
            if "ray_conf" in output:
                del output.ray_conf
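            # pose_encoding_to_extri_intri decodes the pose encoding into
            # camera-to-world matrices plus intrinsics; extrinsics are stored as
            # world-to-camera via affine_inverse.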
            c2w, ixt = pose_encoding_to_extri_intri(pose_enc, (H, W))
            output.extrinsics = affine_inverse(c2w)
            output.intrinsics = ixt
        return output
|
    def _process_gs_head(
        self,
        feats: list[torch.Tensor],
        H: int,
        W: int,
        output: Dict[str, torch.Tensor],
        in_images: torch.Tensor,
        extrinsics: torch.Tensor | None = None,
        intrinsics: torch.Tensor | None = None,
    ) -> Dict[str, torch.Tensor]:
        """Estimate 3DGS parameters if the 3DGS head and adapter are available."""
        if self.gs_head is None or self.gs_adapter is None:
            return output
        assert output.get("depth", None) is not None, "must provide MV depth for the GS head."

        # Use the predicted cameras as context; GT extrinsics, when given, are
        # handed to the adapter separately below.
        ctx_extr = output.get("extrinsics", None)
        ctx_intr = output.get("intrinsics", None)
        assert (
            ctx_extr is not None and ctx_intr is not None
        ), "must process camera info first if GT is not available"

        gt_extr = extrinsics
        ctx_extr = as_homogeneous(ctx_extr)
        if gt_extr is not None:
            gt_extr = as_homogeneous(gt_extr)

        gs_outs = self.gs_head(
            feats=feats,
            H=H,
            W=W,
            patch_start_idx=0,
            images=in_images,
        )
        raw_gaussians = gs_outs.raw_gs
        densities = gs_outs.raw_gs_conf
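        # The adapter lifts per-pixel depth through the context cameras into
        # world-space Gaussians; raw per-pixel densities are converted to
        # opacities via map_pdf_to_opacity.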
        gs_world = self.gs_adapter(
            extrinsics=ctx_extr,
            intrinsics=ctx_intr,
            depths=output.depth,
            opacities=map_pdf_to_opacity(densities),
            raw_gaussians=raw_gaussians,
            image_shape=(H, W),
            gt_extrinsics=gt_extr,
        )
        output.gaussians = gs_world
        return output
|
    def _extract_auxiliary_features(
        self, feats: list[torch.Tensor], feat_layers: list[int], H: int, W: int
    ) -> Dict[str, torch.Tensor]:
        """Extract auxiliary features from specified layers."""
        aux_features = Dict()
        assert len(feats) == len(feat_layers)
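        # Backbone features arrive with the patch grid flattened into one token
        # axis; restore the (B, N, H / PATCH_SIZE, W / PATCH_SIZE, C) layout so
        # consumers can index them spatially.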
        for feat, feat_layer in zip(feats, feat_layers):
            feat_reshaped = feat.reshape(
                [
                    feat.shape[0],
                    feat.shape[1],
                    H // self.PATCH_SIZE,
                    W // self.PATCH_SIZE,
                    feat.shape[-1],
                ]
            )
            aux_features[f"feat_layer_{feat_layer}"] = feat_reshaped

        return aux_features
|
|
class NestedDepthAnything3Net(nn.Module):
    """
    Nested Depth Anything 3 network with metric scaling capabilities.

    This network combines two DepthAnything3Net branches:
    - Main branch (anyview): standard multi-view depth estimation
    - Metric branch (metric): metric depth estimation used for scale alignment

    The network aligns the main branch's depth to metric scale using least
    squares scaling and handles sky region masking for improved depth estimation.

    Args:
        anyview: Configuration for the main depth estimation branch
        metric: Configuration for the metric depth branch
    """
|
    def __init__(self, anyview: DictConfig, metric: DictConfig):
        """
        Initialize NestedDepthAnything3Net with two branches.

        Args:
            anyview: Configuration for the main depth estimation branch
            metric: Configuration for the metric depth branch
        """
        super().__init__()
        self.da3 = create_object(anyview)
        self.da3_metric = create_object(metric)
|
    def forward(
        self,
        x: torch.Tensor,
        extrinsics: torch.Tensor | None = None,
        intrinsics: torch.Tensor | None = None,
        export_feat_layers: list[int] | None = None,
        infer_gs: bool = False,
        use_ray_pose: bool = False,
        ref_view_strategy: str = "saddle_balanced",
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through both branches with metric scaling alignment.

        Args:
            x: Input images (B, N, 3, H, W)
            extrinsics: Optional camera extrinsics (B, N, 4, 4), forwarded to the main branch
            intrinsics: Optional camera intrinsics (B, N, 3, 3), forwarded to the main branch
            export_feat_layers: Layer indices to export auxiliary features from
            infer_gs: Enable the Gaussian Splatting branch
            use_ray_pose: Use ray-based pose estimation
            ref_view_strategy: Strategy for selecting the reference view

        Returns:
            Dictionary containing aligned depth predictions and camera parameters
        """
        output = self.da3(
            x,
            extrinsics,
            intrinsics,
            export_feat_layers=export_feat_layers,
            infer_gs=infer_gs,
            use_ray_pose=use_ray_pose,
            ref_view_strategy=ref_view_strategy,
        )
        metric_output = self.da3_metric(x)
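        # Align the main branch to the metric branch: rescale the metric depth,
        # fit a scalar scale over trusted pixels, then clamp sky to a far plane.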
        output = self._apply_metric_scaling(output, metric_output)
        output = self._apply_depth_alignment(output, metric_output)
        output = self._handle_sky_regions(output, metric_output)
        return output
|
    def _apply_metric_scaling(
        self, output: Dict[str, torch.Tensor], metric_output: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Apply metric scaling to the metric branch's depth, in place."""
        metric_output.depth = apply_metric_scaling(
            metric_output.depth,
            output.intrinsics,
        )
        return output
|
    def _apply_depth_alignment(
        self, output: Dict[str, torch.Tensor], metric_output: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Align the main branch's depth to metric scale via least squares."""
        non_sky_mask = compute_sky_mask(metric_output.sky, threshold=0.3)
        assert non_sky_mask.sum() > 10, "Insufficient non-sky pixels for alignment"

        # Median confidence of non-sky pixels, used by compute_alignment_mask to
        # select trusted pixels.
        depth_conf_ns = output.depth_conf[non_sky_mask]
        depth_conf_sampled = sample_tensor_for_quantile(depth_conf_ns, max_samples=100000)
        median_conf = torch.quantile(depth_conf_sampled, 0.5)
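        # Fit a single scalar s so that s * depth approximates metric_depth over
        # the trusted pixels; in closed form the least squares minimizer of
        # ||s * d - m||^2 is presumably s = sum(d * m) / sum(d * d) (the exact
        # formulation lives in least_squares_scale_scalar).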
        align_mask = compute_alignment_mask(
            output.depth_conf, non_sky_mask, output.depth, metric_output.depth, median_conf
        )
        valid_depth = output.depth[align_mask]
        valid_metric_depth = metric_output.depth[align_mask]
        scale_factor = least_squares_scale_scalar(valid_metric_depth, valid_depth)

        # Rescale depth and camera translations consistently, and record the scale.
        output.depth *= scale_factor
        output.extrinsics[:, :, :3, 3] *= scale_factor
        output.is_metric = 1
        output.scale_factor = scale_factor.item()
        return output
|
    def _handle_sky_regions(
        self,
        output: Dict[str, torch.Tensor],
        metric_output: Dict[str, torch.Tensor],
        sky_depth_def: float = 200.0,
    ) -> Dict[str, torch.Tensor]:
        """Handle sky regions by setting them to a finite maximum depth."""
        non_sky_mask = compute_sky_mask(metric_output.sky, threshold=0.3)

        # Robust maximum: 99th percentile of (subsampled) non-sky depth, capped
        # at sky_depth_def.
        non_sky_depth = output.depth[non_sky_mask]
        sampled_depth = sample_tensor_for_quantile(non_sky_depth, max_samples=100000)
        non_sky_max = min(torch.quantile(sampled_depth, 0.99), sky_depth_def)
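        # Unlike the mono path, depth at this point is metrically aligned, so an
        # absolute cap of sky_depth_def (200.0 by default) is meaningful; here the
        # confidence map is rewritten for sky pixels as well.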
        output.depth, output.depth_conf = set_sky_regions_to_max_depth(
            output.depth, output.depth_conf, non_sky_mask, max_depth=non_sky_max
        )
        return output