File size: 4,501 Bytes
c20d7cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Contains modules for different types of alignment.

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

import math

import torch
import torch.nn.functional as F
from torch import nn

from sharp.models.decoders import UNetDecoder
from sharp.models.encoders import UNetEncoder
from sharp.utils import math as math_utils

from .params import AlignmentParams


def create_alignment(
    params: AlignmentParams, depth_decoder_dim: int | None = None
) -> nn.Module | None:
    """Create depth alignment.

    Args:
        params: Configuration for the alignment module.
        depth_decoder_dim: Dimension of the depth decoder features; required.

    Returns:
        The constructed ``LearnedAlignment`` module (frozen if requested).

    Raises:
        ValueError: If ``depth_decoder_dim`` is not provided.
    """
    if depth_decoder_dim is None:
        raise ValueError("Requires depth_decoder_dim for LearnedAlignment.")

    module = LearnedAlignment(
        depth_decoder_features=params.depth_decoder_features,
        depth_decoder_dim=depth_decoder_dim,
        steps=params.steps,
        stride=params.stride,
        base_width=params.base_width,
        activation_type=params.activation_type,
    )

    # Optionally freeze all parameters so the alignment is not trained.
    if params.frozen:
        module.requires_grad_(False)

    return module


class LearnedAlignment(nn.Module):
    """Aligns tensors using a UNet.

    Predicts a dense, per-pixel multiplicative alignment map from a source
    and a target tensor (typically depth maps).
    """

    def __init__(
        self,
        steps: int = 4,
        stride: int = 8,
        base_width: int = 16,
        depth_decoder_features: bool = False,
        depth_decoder_dim: int = 256,
        activation_type: math_utils.ActivationType = "exp",
    ) -> None:
        """Initialize LearnedAlignment.

        Args:
            steps: Number of steps in the UNet.
            stride: Effective downsampling of the alignment module.
            base_width: Base width of the UNet.
            depth_decoder_features: Whether to use depth decoder features.
            depth_decoder_dim: Dimension of the depth decoder features.
            activation_type: Activation type for the alignment output.

        Raises:
            ValueError: If ``stride`` is not a power of two, or if it is
                incompatible with ``steps`` (decoder would need < 1 step).
        """
        super().__init__()
        self.activation = math_utils.create_activation_pair(activation_type)
        # With zero-initialized conv weights, the bias alone determines the
        # initial output; choose it so the activated map starts at 1.0
        # (identity alignment).
        bias_value = self.activation.inverse(torch.tensor(1.0))

        self.depth_decoder_features = depth_decoder_features
        # Two input channels for the (inverted) source/target tensors, plus
        # optional depth decoder features.
        dim_in = 2 + depth_decoder_dim if depth_decoder_features else 2

        def is_power_of_two(n: int) -> bool:
            """Check if a number is a power of two."""
            if n <= 0:
                return False
            return (n & (n - 1)) == 0

        if not is_power_of_two(stride):
            raise ValueError(f"Stride {stride} is not a power of two.")

        # The decoder upsamples fewer times than the encoder downsamples so
        # the predicted map stays `stride` times smaller than the input.
        steps_decoder = steps - int(math.log2(stride))
        if steps_decoder < 1:
            raise ValueError(f"{steps_decoder} must be greater or equal to 1.")
        # Channel widths double per step, capped at 1024.
        widths = [min(base_width << i, 1024) for i in range(steps + 1)]
        self.encoder = UNetEncoder(dim_in=dim_in, width=widths, steps=steps, norm_num_groups=4)
        self.decoder = UNetDecoder(
            dim_out=widths[0], width=widths, steps=steps_decoder, norm_num_groups=4
        )
        self.conv_out = nn.Conv2d(widths[0], 1, 1, bias=True)
        nn.init.zeros_(self.conv_out.weight)
        nn.init.constant_(self.conv_out.bias, bias_value)

    def forward(
        self,
        tensor_src: torch.Tensor,
        tensor_tgt: torch.Tensor,
        depth_decoder_features: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute alignment map.

        Args:
            tensor_src: Source tensor of shape (B, 1, H, W).
            tensor_tgt: Target tensor of shape (B, 1, H, W).
            depth_decoder_features: Optional decoder features; required when
                the module was built with ``depth_decoder_features=True``.

        Returns:
            Alignment map with the same spatial size as ``tensor_src``.

        Raises:
            ValueError: If decoder features are required but not provided.
        """
        # Since the tensors are usually given by depth which is >= 1.0, we invert
        # the tensors to have them in a reasonable range.
        tensor_src = 1.0 / tensor_src.clamp(min=1e-4)
        tensor_tgt = 1.0 / tensor_tgt.clamp(min=1e-4)
        tensor_input = torch.cat([tensor_src, tensor_tgt], dim=1)
        if self.depth_decoder_features:
            # Fail early with a clear message instead of an opaque error from
            # F.interpolate(None, ...).
            if depth_decoder_features is None:
                raise ValueError(
                    "depth_decoder_features is required when the module was "
                    "configured with depth_decoder_features=True."
                )
            height, width = tensor_src.shape[-2:]
            upsampled_encodings = F.interpolate(
                depth_decoder_features,
                size=(height, width),
                mode="bilinear",
            )
            tensor_input = torch.cat([tensor_input, upsampled_encodings], dim=1)
        features = self.encoder(tensor_input)
        output = self.conv_out(self.decoder(features))
        alignment_map_lowres = self.activation.forward(output)
        # Bug fix: the original compared `shape[-2:]` (a tuple) against
        # `shape[-2]` (an int), which is never equal, and left `alignment_map`
        # unbound when the sizes matched. Compare spatial sizes correctly and
        # only upsample when needed.
        if alignment_map_lowres.shape[-2:] != tensor_src.shape[-2:]:
            alignment_map = F.interpolate(
                alignment_map_lowres,
                size=tensor_src.shape[-2:],
                mode="bilinear",
                align_corners=False,
            )
        else:
            alignment_map = alignment_map_lowres
        return alignment_map