import torch
from torch import nn


def generate_patches(input, fstride, tstride, fshape, tshape):
    r"""Function that extract patches from tensors and stacks them.

    See :class:`~kornia.contrib.ExtractTensorPatches` for details.

    Args:
        input: tensor image where to extract the patches with shape :math:`(B, C, H, W)`.

    Returns:
        the tensor with the extracted patches with shape :math:`(B, N, C, H_{out}, W_{out})`.

    Examples:
        >>> input = torch.arange(9.).view(1, 1, 3, 3)
        >>> patches = extract_tensor_patches(input, (2, 3))
        >>> input
        tensor([[[[0., 1., 2.],
                  [3., 4., 5.],
                  [6., 7., 8.]]]])
        >>> patches[:, -1]
        tensor([[[[3., 4., 5.],
                  [6., 7., 8.]]]])

    """
    batch_size, num_channels = input.size()[:2]
    dims = range(2, input.dim())
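    # Unfold height then width: (B, C, H, W) -> (B, C, nH, nW, fshape, tshape).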
    for dim, patch_size, stride in zip(dims, (fshape, tshape), (fstride, tstride)):
        input = input.unfold(dim, patch_size, stride)
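    # Bring the patch-grid dims after batch and move the window dims to the end:
    # (B, C, nH, nW, fshape, tshape) -> (B, nH, nW, C, fshape, tshape).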
    input = input.permute(0, *dims, 1, *(dim + len(dims) for dim in dims)).contiguous()
    return input.view(batch_size, -1, num_channels, fshape, tshape)


def combine_patches(
    patches,
    original_size,
    fstride,
    tstride,
    fshape,
    tshape,
    eps: float = 1e-8,
):
    r"""Restore input from patches.

    See :class:`~kornia.contrib.CombineTensorPatches` for details.

    Args:
        patches: patched tensor with shape :math:`(B, N, C, H_{out}, W_{out})`.

    Return:
        The combined patches in an image tensor with shape :math:`(B, C, H, W)`.

    Example:
        >>> out = extract_tensor_patches(torch.arange(16).view(1, 1, 4, 4), window_size=(2, 2), stride=(2, 2))
        >>> combine_tensor_patches(out, original_size=(4, 4), window_size=(2, 2), stride=(2, 2))
        tensor([[[[ 0,  1,  2,  3],
                  [ 4,  5,  6,  7],
                  [ 8,  9, 10, 11],
                  [12, 13, 14, 15]]]])

    .. note::
        This function is supposed to be used in conjunction with :func:`extract_tensor_patches`.

    """
    if patches.ndim != 5:
        raise ValueError(
            f"Invalid input shape, we expect BxNxCxHxW. Got: {patches.shape}"
        )
    ones = torch.ones(
        patches.shape[0],
        patches.shape[2],
        original_size[0],
        original_size[1],
        device=patches.device,
        dtype=patches.dtype,
    )
    restored_size = ones.shape[2:]

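    # Rearrange patches into the (B, C * fshape * tshape, N) layout expected by
    # torch.nn.functional.fold.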
    patches = patches.permute(0, 2, 3, 4, 1)
    patches = patches.reshape(patches.shape[0], -1, patches.shape[-1])
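    # fold/unfold only support floating point dtypes; cast and restore afterwards.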
    int_flag = 0
    if not torch.is_floating_point(patches):
        int_flag = 1
        dtype = patches.dtype
        patches = patches.float()
        ones = ones.float()

    # Normalization map: counts how many patches overlap at each output pixel
    unfold_ones = torch.nn.functional.unfold(
        ones, kernel_size=(fshape, tshape), stride=(fstride, tstride)
    )
    norm_map = torch.nn.functional.fold(
        input=unfold_ones,
        output_size=restored_size,
        kernel_size=(fshape, tshape),
        stride=(fstride, tstride),
    )
    # Restored tensor
    saturated_restored_tensor = torch.nn.functional.fold(
        input=patches,
        output_size=restored_size,
        kernel_size=(fshape, tshape),
        stride=(fstride, tstride),
    )
    # Remove saturation effect due to multiple summations of overlapping patches
    restored_tensor = saturated_restored_tensor / (norm_map + eps)
    if int_flag:
        restored_tensor = restored_tensor.to(dtype)
    return restored_tensor


def get_shape(fstride, tstride, input_fdim, input_tdim, fshape, tshape):
    """Get the (frequency, time) shape of the intermediate representation.

    Runs a throwaway convolution with the given kernel size and stride over a
    random input to measure the resulting spatial dimensions.
    """
    test_input = torch.randn(1, 2, input_fdim, input_tdim)
    test_proj = nn.Conv2d(
        2,
        2,
        kernel_size=(fshape, tshape),
        stride=(fstride, tstride),
    )
    test_out = test_proj(test_input)
    f_dim = test_out.shape[2]
    t_dim = test_out.shape[3]
    return f_dim, t_dim
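

# A minimal round-trip sanity check (illustrative sketch, not part of the API
# above): when the strides tile the input exactly, combine_patches should
# invert generate_patches, and get_shape should agree with the number of
# extracted patches.
if __name__ == "__main__":
    x = torch.arange(16.0).view(1, 1, 4, 4)
    patches = generate_patches(x, fstride=2, tstride=2, fshape=2, tshape=2)
    restored = combine_patches(
        patches, original_size=(4, 4), fstride=2, tstride=2, fshape=2, tshape=2
    )
    assert torch.allclose(x, restored)
    # N (number of patches) equals the patch-grid size f_dim * t_dim.
    f_dim, t_dim = get_shape(
        fstride=2, tstride=2, input_fdim=4, input_tdim=4, fshape=2, tshape=2
    )
    assert patches.shape[1] == f_dim * t_dim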