|
|
"""Contains modules for different types of alignment. |
|
|
|
|
|
For licensing see accompanying LICENSE file. |
|
|
Copyright (C) 2025 Apple Inc. All Rights Reserved. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import math |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch import nn |
|
|
|
|
|
from sharp.models.decoders import UNetDecoder |
|
|
from sharp.models.encoders import UNetEncoder |
|
|
from sharp.utils import math as math_utils |
|
|
|
|
|
from .params import AlignmentParams |
|
|
|
|
|
|
|
|
def create_alignment(
    params: AlignmentParams, depth_decoder_dim: int | None = None
) -> nn.Module | None:
    """Build a :class:`LearnedAlignment` module from configuration.

    Args:
        params: Alignment configuration (steps, stride, widths, etc.).
        depth_decoder_dim: Channel dimension of the depth decoder features;
            required because the UNet input width depends on it.

    Returns:
        The constructed alignment module.

    Raises:
        ValueError: If ``depth_decoder_dim`` is not provided.
    """
    if depth_decoder_dim is None:
        raise ValueError("Requires depth_decoder_dim for LearnedAlignment.")

    module = LearnedAlignment(
        depth_decoder_features=params.depth_decoder_features,
        depth_decoder_dim=depth_decoder_dim,
        steps=params.steps,
        stride=params.stride,
        base_width=params.base_width,
        activation_type=params.activation_type,
    )

    # Freezing disables gradient updates for every parameter of the module.
    if params.frozen:
        module.requires_grad_(False)

    return module
|
|
|
|
|
|
|
|
class LearnedAlignment(nn.Module):
    """Aligns tensors using a UNet.

    Predicts a per-pixel multiplicative alignment map from a source/target
    tensor pair (optionally augmented with depth decoder features).
    """

    def __init__(
        self,
        steps: int = 4,
        stride: int = 8,
        base_width: int = 16,
        depth_decoder_features: bool = False,
        depth_decoder_dim: int = 256,
        activation_type: math_utils.ActivationType = "exp",
    ) -> None:
        """Initialize LearnedAlignment.

        Args:
            steps: Number of steps in the UNet.
            stride: Effective downsampling of the alignment module.
            base_width: Base width of the UNet.
            depth_decoder_features: Whether to use depth decoder features.
            depth_decoder_dim: Dimension of the depth decoder features.
            activation_type: Activation type for the alignment output.

        Raises:
            ValueError: If ``stride`` is not a power of two, or if the
                implied number of decoder steps is smaller than 1.
        """
        super().__init__()
        self.activation = math_utils.create_activation_pair(activation_type)
        # Bias chosen so that activation(bias) == 1.0, i.e. the module
        # predicts an identity alignment at initialization (weights are
        # zeroed below).
        bias_value = self.activation.inverse(torch.tensor(1.0))

        self.depth_decoder_features = depth_decoder_features
        # Two input channels for the (inverted) src/tgt tensors, plus the
        # optional decoder features.
        if depth_decoder_features:
            dim_in = 2 + depth_decoder_dim
        else:
            dim_in = 2

        def is_power_of_two(n: int) -> bool:
            """Check if a number is a power of two."""
            if n <= 0:
                return False
            return (n & (n - 1)) == 0

        if not is_power_of_two(stride):
            raise ValueError(f"Stride {stride} is not a power of two.")

        # The decoder upsamples only part of the way back, leaving an
        # effective output stride of `stride` relative to the input.
        steps_decoder = steps - int(math.log2(stride))
        if steps_decoder < 1:
            raise ValueError(f"{steps_decoder} must be greater or equal to 1.")
        # Widths double per step, capped at 1024 channels.
        widths = [min(base_width << i, 1024) for i in range(steps + 1)]
        self.encoder = UNetEncoder(dim_in=dim_in, width=widths, steps=steps, norm_num_groups=4)
        self.decoder = UNetDecoder(
            dim_out=widths[0], width=widths, steps=steps_decoder, norm_num_groups=4
        )
        self.conv_out = nn.Conv2d(widths[0], 1, 1, bias=True)
        # Zero weights + identity bias => identity alignment before training.
        nn.init.zeros_(self.conv_out.weight)
        nn.init.constant_(self.conv_out.bias, bias_value)

    def forward(
        self,
        tensor_src: torch.Tensor,
        tensor_tgt: torch.Tensor,
        depth_decoder_features: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute alignment map.

        Args:
            tensor_src: Source tensor, shape (B, 1, H, W) — assumed strictly
                positive (e.g. depth); values are inverted before encoding.
            tensor_tgt: Target tensor with the same shape as ``tensor_src``.
            depth_decoder_features: Optional decoder features, required when
                the module was built with ``depth_decoder_features=True``.

        Returns:
            Alignment map at the full (H, W) resolution of the inputs.
        """
        # Work in inverse (disparity-like) space; the clamp guards against
        # division by zero for non-positive or tiny input values.
        tensor_src = 1.0 / tensor_src.clamp(min=1e-4)
        tensor_tgt = 1.0 / tensor_tgt.clamp(min=1e-4)
        tensor_input = torch.cat([tensor_src, tensor_tgt], dim=1)
        if self.depth_decoder_features:
            height, width = tensor_src.shape[-2:]
            # Bring decoder features to the input resolution before stacking.
            upsampled_encodings = F.interpolate(
                depth_decoder_features,
                size=(height, width),
                mode="bilinear",
            )
            tensor_input = torch.cat([tensor_input, upsampled_encodings], dim=1)
        features = self.encoder(tensor_input)
        output = self.conv_out(self.decoder(features))
        alignment_map_lowres = self.activation.forward(output)
        # Fix: the original compared shape[-2:] (a tuple) with shape[-2]
        # (an int), which is always unequal; and when the spatial sizes did
        # match, `alignment_map` was never assigned (NameError). Compare
        # the full spatial shapes and fall through to the low-res map when
        # no resizing is needed.
        if alignment_map_lowres.shape[-2:] != tensor_src.shape[-2:]:
            alignment_map = F.interpolate(
                alignment_map_lowres,
                size=tensor_src.shape[-2:],
                mode="bilinear",
                align_corners=False,
            )
        else:
            alignment_map = alignment_map_lowres
        return alignment_map
|
|
|