"""
Base classes and output interfaces of the UniFlowMatch training system.
"""

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import torch

from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT
from uniflowmatch.utils.flow_resizing import (
    AutomaticShapeSelection,
    ResizeToFixedManipulation,
    unmap_predicted_channels,
    unmap_predicted_flow,
)


@dataclass
class UFMFlowFieldOutput:
    """
    Output interface of the flow field prediction network.
    """

    flow_output: torch.Tensor
    flow_covariance: Optional[torch.Tensor] = None
    flow_covariance_inv: Optional[torch.Tensor] = None
    flow_covariance_log_det: Optional[torch.Tensor] = None


@dataclass
class UFMMaskFieldOutput:
    """
    Output interface of the mask prediction network.
    """

    mask: torch.Tensor
    logits: Optional[torch.Tensor] = None


@dataclass
class UFMClassificationRefinementOutput:
    """
    Output interface of the classification refinement network.
    """

    # The flow output of the regression step, with shape [B, 2, H, W].
    # It is the initial flow output, used to gather the local feature maps for the residual.
    regression_flow_output: torch.Tensor

    # The residual output of the refinement step, with shape [B, 2, H, W].
    # It is added to the initial flow output to obtain the final flow output.
    residual: torch.Tensor

    # log_softmax is the logarithm of the softmax over the similarities between a
    # pixel's feature and the features in the neighborhood of its flow prediction
    # in the other image.
    # It has shape [B, H, W, P, P]: the similarity of the pixel at [b, h, w] to the
    # [P, P] neighborhood centered at regression_flow_output[b, h, w].
    log_softmax: torch.Tensor
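    # e.g. with a hypothetical window size P = 7, log_softmax[b, h, w, 3, 3] is the
    # log-probability that the match lies exactly at the regressed flow location
    # (the center of the neighborhood window).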

    feature_map_0: torch.Tensor
    feature_map_1: torch.Tensor


@dataclass
class UFMOutputInterface:
    """
    Output interface of the UniFlowMatch training system.
    """

    flow: Optional[UFMFlowFieldOutput] = None

    # Refinement output (for training and visualization)
    classification_refinement: Optional[UFMClassificationRefinementOutput] = None

    # auxiliary outputs
    covisibility: Optional[UFMMaskFieldOutput] = None



class UniFlowMatchModelsBase(torch.nn.Module):
    def __init__(self, inference_resolution: Optional[Union[List[Tuple[int, int]], Tuple[int, int]]] = None):
        super().__init__()

        if inference_resolution is None:
            inference_resolution = [(560, 420)]

        if isinstance(inference_resolution[0], int):  # a single (W, H) tuple was passed; wrap it in a list
            inference_resolution = [inference_resolution]

        self.inference_resolution = inference_resolution

        self.image_scaler = AutomaticShapeSelection(
            # inference_resolution entries are (W, H); ResizeToFixedManipulation takes (H, W)
            *[ResizeToFixedManipulation((resolution[1], resolution[0])) for resolution in inference_resolution],
            strategy="closest_aspect",  # run inference at the trained aspect ratio closest to the input image's
        )
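
        # e.g. with hypothetical inference_resolution=[(512, 384), (640, 360)],
        # a 1920x1080 (16:9) input would be resized to 640x360, the trained
        # resolution whose aspect ratio is closest to the input's.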

    def forward(self, view1, view2) -> UFMOutputInterface:
        """
        Forward interface of correspondence prediction networks.

        Args:
        - view1 (Dict[str, Any]): Input view 1
          - img (torch.Tensor): BCHW image tensor normalized according to encoder's data_norm_type
          - instance (List[int]): List of instance indices, or id of the input image
          - data_norm_type (str): Data normalization type, see uniception.models.encoders.image_normalizations.IMAGE_NORMALIZATION_DICT
        - view2 (Dict[str, Any]): Input view 2
          - (same structure as view1)
        Returns:
        - UFMOutputInterface: Output results
          - flow [Required] (UFMFlowFieldOutput): Flow output
            - [Required] flow_output (torch.Tensor): Flow output tensor, BCHW
            - [Optional] flow_covariance
            - [Optional] flow_covariance_inv
            - [Optional] flow_covariance_log_det
          - covisibility [Optional] (UFMMaskFieldOutput): Covisibility output
            - [Optional] mask
            - [Optional] logits
        """
        raise NotImplementedError("Implement this method in derived classes")

    def get_parameter_groups(self) -> Dict[str, torch.nn.ParameterList]:
        """
        Get parameter groups for the optimizer. This method guides the optimizer
        to apply the correct learning rate to different parts of the model.

        Returns:
        - Dict[str, torch.nn.ParameterList]: Parameter groups for optimizer
        """

        raise NotImplementedError("Implement this method in derived classes")

    def predict_correspondences_batched(
        self,
        source_image: torch.Tensor,
        target_image: torch.Tensor,
        data_norm_type: Optional[str] = None,
    ) -> UFMOutputInterface:
        """
        Predict correspondences between source and target images.

        This method generates random correspondences for demonstration purposes.

        Args:
            source_image (torch.Tensor): Tensor of shape BCHW/BHWC/CHW/HWC, dtype of uint8 or float32 The source image tensor.
            target_image (torch.Tensor): Tensor of shape BCHW/BHWC/CHW/HWC, dtype of uint8 or float32 The target image tensor.

        Returns:
            UFMOutputInterface:
                - flow
                    - flow_output       (torch.Tensor): Tensor of shape (B, 2, H, W) representing the flow output in the original image space.
                - occlusion
                    - mask              (torch.Tensor): Tensor of shape (B, H, W) representing the covisibility in range [0, 1]. 1 = fully covisible, 0 = fully occluded or out of range.
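
        Example (minimal sketch; ``MyUFMModel`` stands in for a hypothetical concrete subclass):
            >>> model = MyUFMModel()
            >>> src = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8)
            >>> tgt = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8)
            >>> out = model.predict_correspondences_batched(src, tgt)
            >>> tuple(out.flow.flow_output.shape)
            (1, 2, 480, 640)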
        """

        assert isinstance(source_image, torch.Tensor) and isinstance(
            target_image, torch.Tensor
        ), "source_image and target_image must be torch.Tensors"
        assert source_image.dim() in [3, 4], "source_image must have dimensions 3 or 4"
        assert target_image.dim() in [3, 4], "target_image must have dimensions 3 or 4"

        batched = source_image.dim() == 4

        if not batched:
            # add batch dimension
            source_image = source_image.unsqueeze(0)
            target_image = target_image.unsqueeze(0)

        # check the channel
        if source_image.shape[1] == 3 and target_image.shape[1] == 3:
            pass  # do nothing because the image is in BCHW format
        elif source_image.shape[-1] == 3 and target_image.shape[-1] == 3:
            # convert to BCHW
            source_image = source_image.permute(0, 3, 1, 2)
            target_image = target_image.permute(0, 3, 1, 2)
        else:
            raise ValueError("source_image and target_image must have 3 channels in either BCHW or BHWC format")

        required_data_norm_type = self.encoder.data_norm_type  # derived classes are expected to define `self.encoder`

        image_device = source_image.device

        if source_image.dtype == torch.float32:
            assert data_norm_type is not None, "data_norm_type must be provided for float32 images"
            assert (
                data_norm_type in IMAGE_NORMALIZATION_DICT
            ), f"data_norm_type must be one of {list(IMAGE_NORMALIZATION_DICT.keys())}"

            if data_norm_type != required_data_norm_type:
                # re-normalize from the provided normalization to the encoder's expected one
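                # Given x_a = (x - mean_a) / std_a, the encoder expects
                # x_b = (x - mean_b) / std_b = x_a * (std_a / std_b) + (mean_a - mean_b) / std_b,
                # which is exactly the affine map applied below.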
                prev_mean = (
                    IMAGE_NORMALIZATION_DICT[data_norm_type].mean.view(1, 3, 1, 1).to(image_device, non_blocking=True)
                )
                prev_std = (
                    IMAGE_NORMALIZATION_DICT[data_norm_type].std.view(1, 3, 1, 1).to(image_device, non_blocking=True)
                )
                mean = (
                    IMAGE_NORMALIZATION_DICT[required_data_norm_type]
                    .mean.view(1, 3, 1, 1)
                    .to(image_device, non_blocking=True)
                )
                std = (
                    IMAGE_NORMALIZATION_DICT[required_data_norm_type]
                    .std.view(1, 3, 1, 1)
                    .to(image_device, non_blocking=True)
                )

                source_image = source_image * (prev_std / std) + (prev_mean - mean) / std
                target_image = target_image * (prev_std / std) + (prev_mean - mean) / std

        elif source_image.dtype == torch.uint8:
            # convert into float32 and apply normalization
            mean = (
                IMAGE_NORMALIZATION_DICT[required_data_norm_type]
                .mean.view(1, 3, 1, 1)
                .to(image_device, non_blocking=True)
            )
            std = (
                IMAGE_NORMALIZATION_DICT[required_data_norm_type]
                .std.view(1, 3, 1, 1)
                .to(image_device, non_blocking=True)
            )

            source_image = (source_image.float() / 255.0 - mean) / std
            target_image = (target_image.float() / 255.0 - mean) / std
        else:
            raise ValueError("source_image and target_image must be of type torch.float32 or torch.uint8")

        # Now all the inputs are normalized according to the model's encoder and organized in BCHW format
        return self._predict_correspondences_batched(source_image, target_image)

    def _predict_correspondences_batched(
        self,
        source_image: torch.Tensor,
        target_image: torch.Tensor,
    ) -> UFMOutputInterface:
        assert isinstance(source_image, torch.Tensor), "source_image must be a torch.Tensor"
        assert isinstance(target_image, torch.Tensor), "target_image must be a torch.Tensor"

        assert source_image.dim() == 4, "source_image must be of shape (B, 3, H, W)"
        assert target_image.dim() == 4, "target_image must be of shape (B, 3, H, W)"
        assert source_image.shape[1] == 3, "source_image must be of shape (B, 3, H, W)"
        assert target_image.shape[1] == 3, "target_image must be of shape (B, 3, H, W)"

        assert source_image.dtype == torch.float32, "source_image must be of dtype torch.float32"
        assert target_image.dtype == torch.float32, "target_image must be of dtype torch.float32"

        source_shape_hw = source_image.shape[2:]
        target_shape_hw = target_image.shape[2:]

        # Scale images to one of the model's trained resolution.
        (
            scaled_img0,  # The scaled source image
            scaled_img1,  # The scaled target image
            img0_region_source,  # region of the original source image captured in the scaled image
            img1_region_source,  # region of the original target image captured in the scaled image
            img0_region_representation,  # region of the scaled image that represents the source region above
            img1_region_representation,  # same as above, but for the target image
        ) = self.image_scaler(source_image.permute(0, 2, 3, 1), target_image.permute(0, 2, 3, 1))

        scaled_img0 = scaled_img0.permute(0, 3, 1, 2)
        scaled_img1 = scaled_img1.permute(0, 3, 1, 2)

        # Run a forward pass
        view1 = {"img": scaled_img0, "symmetrized": False, "data_norm_type": self.encoder.data_norm_type}
        view2 = {"img": scaled_img1, "symmetrized": False, "data_norm_type": self.encoder.data_norm_type}

        # Inference only: run without gradients, under bfloat16 autocast (assumes a CUDA device)
        with torch.no_grad():
            with torch.autocast("cuda", torch.bfloat16):
                result = self(view1, view2)

        rescaled_ufm_result = UFMOutputInterface()

        # rescale flow
        flow_output = result.flow.flow_output
        flow_unmapped, flow_unmap_validity = unmap_predicted_flow(
            flow_output,
            img0_region_representation,
            img1_region_representation,
            img0_region_source,
            img1_region_source,
            source_shape_hw,
            target_shape_hw,
        )

        rescaled_ufm_result.flow = UFMFlowFieldOutput(
            flow_output=flow_unmapped,
        )

        # rescale covariance if it exists
        if result.flow.flow_covariance is not None:
            flow_covariance = result.flow.flow_covariance
            flow_covariance_unmapped, _ = unmap_predicted_channels(
                flow_covariance,
                img0_region_representation,
                img0_region_source,
                source_shape_hw,
            )

            # rescale the covariance from the inference resolution back to the original resolution
            w_pred = scaled_img0.shape[3]
            h_pred = scaled_img0.shape[2]

            w_final = source_shape_hw[1]
            h_final = source_shape_hw[0]

            w_ratio, h_ratio = w_final / w_pred, h_final / h_pred
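
            # Scaling the flow components (u, v) by (w_ratio, h_ratio) scales the
            # covariance as Var(a*u) = a^2 * Var(u) and Cov(a*u, b*v) = a*b * Cov(u, v);
            # the three channels are assumed to be ordered (var_x, var_y, cov_xy).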

            flow_covariance_unmapped *= (
                torch.tensor([w_ratio**2, h_ratio**2, w_ratio * h_ratio])
                .view(1, 3, 1, 1)
                .to(flow_covariance_unmapped.device)
            )

            rescaled_ufm_result.flow.flow_covariance = flow_covariance_unmapped

        # rescale occlusion if it exists
        if result.covisibility is not None:
            occlusion_mask = result.covisibility.mask
            covisibility_unmapped, _ = unmap_predicted_channels(
                occlusion_mask,
                img0_region_representation,
                img0_region_source,
                source_shape_hw,
            )

            covisibility_unmapped = covisibility_unmapped.squeeze(1)
            rescaled_ufm_result.covisibility = UFMMaskFieldOutput(mask=covisibility_unmapped, logits=None)

        return rescaled_ufm_result
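

# --- Illustrative subclass sketch (an assumption for documentation, not shipped API) ---
# A concrete model implements `forward` (and exposes `self.encoder` with a
# `data_norm_type` attribute) so the batched prediction helpers above work:
#
#     class ToyUFMModel(UniFlowMatchModelsBase):
#         def __init__(self, encoder):
#             super().__init__(inference_resolution=[(560, 420)])
#             self.encoder = encoder  # must expose `data_norm_type`
#
#         def forward(self, view1, view2) -> UFMOutputInterface:
#             b, _, h, w = view1["img"].shape
#             flow = torch.zeros(b, 2, h, w, device=view1["img"].device)
#             return UFMOutputInterface(flow=UFMFlowFieldOutput(flow_output=flow))
#
#         def get_parameter_groups(self) -> Dict[str, torch.nn.ParameterList]:
#             return {"default": torch.nn.ParameterList(self.parameters())}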