# ml-sharp / src/sharp/models/alignment.py
# (Hugging Face page metadata: amael-apple — Initial commit — c20d7cc — 4.5 kB)
"""Contains modules for different types of alignment.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
import math
import torch
import torch.nn.functional as F
from torch import nn
from sharp.models.decoders import UNetDecoder
from sharp.models.encoders import UNetEncoder
from sharp.utils import math as math_utils
from .params import AlignmentParams
def create_alignment(
    params: AlignmentParams, depth_decoder_dim: int | None = None
) -> nn.Module | None:
    """Create depth alignment.

    Args:
        params: Alignment configuration (widths, steps, stride, activation,
            whether the module is frozen).
        depth_decoder_dim: Channel dimension of the depth decoder features;
            required to build a ``LearnedAlignment``.

    Returns:
        The constructed alignment module (with gradients disabled when
        ``params.frozen`` is set).

    Raises:
        ValueError: If ``depth_decoder_dim`` is ``None``.
    """
    if depth_decoder_dim is None:
        raise ValueError("Requires depth_decoder_dim for LearnedAlignment.")
    module = LearnedAlignment(
        steps=params.steps,
        stride=params.stride,
        base_width=params.base_width,
        depth_decoder_features=params.depth_decoder_features,
        depth_decoder_dim=depth_decoder_dim,
        activation_type=params.activation_type,
    )
    if params.frozen:
        # Freeze all parameters so the optimizer never updates this module.
        module.requires_grad_(False)
    return module
class LearnedAlignment(nn.Module):
    """Aligns tensors using a UNet.

    Predicts a per-pixel, strictly positive alignment map from a source and a
    target tensor (typically depth maps), optionally conditioned on depth
    decoder features.
    """

    def __init__(
        self,
        steps: int = 4,
        stride: int = 8,
        base_width: int = 16,
        depth_decoder_features: bool = False,
        depth_decoder_dim: int = 256,
        activation_type: math_utils.ActivationType = "exp",
    ) -> None:
        """Initialize LearnedAlignment.

        Args:
            steps: Number of steps in the UNet.
            stride: Effective downsampling of the alignment module. Must be a
                power of two.
            base_width: Base width of the UNet.
            depth_decoder_features: Whether to use depth decoder features.
            depth_decoder_dim: Dimension of the depth decoder features.
            activation_type: Activation type for the alignment output.

        Raises:
            ValueError: If ``stride`` is not a power of two, or if
                ``steps - log2(stride)`` is less than 1.
        """
        super().__init__()
        self.activation = math_utils.create_activation_pair(activation_type)
        # With zero-initialized conv weights (below), the pre-activation output
        # equals this bias, so an untrained module predicts an identity
        # alignment of exactly 1.0 after the activation.
        bias_value = self.activation.inverse(torch.tensor(1.0))
        self.depth_decoder_features = depth_decoder_features
        # Two input channels (inverted src and tgt), plus optional decoder features.
        if depth_decoder_features:
            dim_in = 2 + depth_decoder_dim
        else:
            dim_in = 2

        def is_power_of_two(n: int) -> bool:
            """Check if a number is a power of two."""
            if n <= 0:
                return False
            return (n & (n - 1)) == 0

        if not is_power_of_two(stride):
            raise ValueError(f"Stride {stride} is not a power of two.")
        # The decoder upsamples fewer times than the encoder downsamples,
        # leaving an effective output stride of `stride`.
        steps_decoder = steps - int(math.log2(stride))
        if steps_decoder < 1:
            raise ValueError(f"{steps_decoder} must be greater or equal to 1.")
        # Channel widths double per step, capped at 1024.
        widths = [min(base_width << i, 1024) for i in range(steps + 1)]
        self.encoder = UNetEncoder(dim_in=dim_in, width=widths, steps=steps, norm_num_groups=4)
        self.decoder = UNetDecoder(
            dim_out=widths[0], width=widths, steps=steps_decoder, norm_num_groups=4
        )
        self.conv_out = nn.Conv2d(widths[0], 1, 1, bias=True)
        nn.init.zeros_(self.conv_out.weight)
        nn.init.constant_(self.conv_out.bias, bias_value)

    def forward(
        self,
        tensor_src: torch.Tensor,
        tensor_tgt: torch.Tensor,
        depth_decoder_features: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute alignment map.

        Args:
            tensor_src: Source tensor, e.g. depth; assumed (B, 1, H, W) given
                the channel-concatenation below — TODO confirm with callers.
            tensor_tgt: Target tensor with the same spatial size as the source.
            depth_decoder_features: Optional decoder features; required when the
                module was built with ``depth_decoder_features=True``.

        Returns:
            Alignment map at the source tensor's spatial resolution.

        Raises:
            ValueError: If decoder features are required but not provided.
        """
        # Since the tensors are usually given by depth which is >= 1.0, we invert
        # the tensors to have them in a reasonable range.
        tensor_src = 1.0 / tensor_src.clamp(min=1e-4)
        tensor_tgt = 1.0 / tensor_tgt.clamp(min=1e-4)
        tensor_input = torch.cat([tensor_src, tensor_tgt], dim=1)
        if self.depth_decoder_features:
            # Fail with a clear message instead of an opaque error inside
            # F.interpolate when the conditioning features are missing.
            if depth_decoder_features is None:
                raise ValueError(
                    "depth_decoder_features is required when the module was "
                    "constructed with depth_decoder_features=True."
                )
            height, width = tensor_src.shape[-2:]
            upsampled_encodings = F.interpolate(
                depth_decoder_features,
                size=(height, width),
                mode="bilinear",
            )
            tensor_input = torch.cat([tensor_input, upsampled_encodings], dim=1)
        features = self.encoder(tensor_input)
        output = self.conv_out(self.decoder(features))
        alignment_map_lowres = self.activation.forward(output)
        # Bug fix: the original compared shape[-2:] (a tuple) to shape[-2] (an
        # int) — always unequal — and left `alignment_map` unbound whenever the
        # branch was skipped. Compare spatial shapes properly and return the
        # low-res map directly when no upsampling is needed.
        if alignment_map_lowres.shape[-2:] != tensor_src.shape[-2:]:
            return F.interpolate(
                alignment_map_lowres,
                size=tensor_src.shape[-2:],
                mode="bilinear",
                align_corners=False,
            )
        return alignment_map_lowres