File size: 4,128 Bytes
fadb92b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import einops
import torch
import torch.nn.functional as F
from torch import nn


class LayerNormProxy(nn.Module):
    """Apply ``nn.LayerNorm`` over the channel axis of a channels-first tensor.

    ``nn.LayerNorm(dim)`` normalizes the trailing axis, so the channel axis
    (axis 1) is moved to the end, normalized, and moved back.

    Args:
        dim: Number of channels, i.e. the size of axis 1 of the input.
    """

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Normalize ``x`` along axis 1.

        Args:
            x: Tensor of shape ``(B, C, *spatial)``; the usual case is
                ``(B, C, H, W)``, but any number of trailing axes is accepted.

        Returns:
            Tensor with the same shape as ``x``. As with the input permutation,
            the result is a permuted view and is generally not contiguous.
        """
        # movedim keeps the relative order of the remaining axes, which for a
        # 4-D input is exactly "b c h w -> b h w c" and back.
        x = self.norm(x.movedim(1, -1))
        return x.movedim(-1, 1)


class MSDeformablePoints(nn.Module):
    """Sample deformed feature points from flattened multi-scale feature maps.

    For every level the tokens are projected (``proj_q``), split into heads,
    and fed to a small strided convolutional head (``conv_offset``) that
    predicts one 2-D offset per head and per down-sampled location. The
    offsets are added to a regular grid of reference points and the level's
    features are bilinearly sampled at the resulting positions.
    """

    def __init__(
        self,
        embed_dim,
        n_levels,
        n_heads,
        offset_range_factor=-1,
    ):
        """Build the per-level projection and offset-prediction modules.

        Args:
            embed_dim: Channel size of the input tokens.
            n_levels: Number of feature levels; one ``proj_q`` and one
                ``conv_offset`` module is created per level.
            n_heads: Number of heads the channels are split into.
            offset_range_factor: If ``>= 0``, offsets are squashed with
                ``tanh`` and scaled by this factor times ``1 / Hk`` (resp.
                ``1 / Wk``). If negative, raw offsets are used and the final
                positions are clamped to ``[-1, 1]`` instead.

        Raises:
            ValueError: If ``embed_dim`` is not divisible by ``n_heads``.
        """
        super().__init__()
        if embed_dim % n_heads != 0:
            # Otherwise self.nc != embed_dim and proj_q cannot accept the input.
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by n_heads ({n_heads})")

        self.n_head_channels = embed_dim // n_heads
        self.scale = self.n_head_channels**-0.5
        self.n_heads = n_heads
        self.nc = self.n_head_channels * n_heads
        self.offset_range_factor = offset_range_factor

        # Larger kernels / strides for earlier levels.
        # For n_levels == 4: kernel_sizes == [7, 5, 3, 1], strides == [16, 8, 4, 2].
        self.kernel_sizes = [(n_levels - 1 - i) * 2 + 1 for i in range(n_levels)]
        self.strides = [2 ** (n_levels - i) for i in range(n_levels)]

        self.conv_offset = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        self.n_head_channels,
                        self.n_head_channels,
                        self.kernel_sizes[i],
                        self.strides[i],
                        self.kernel_sizes[i] // 2,
                        groups=self.n_heads,
                    ),
                    LayerNormProxy(self.n_head_channels),
                    nn.GELU(),
                    # Two output channels: the (y, x) offset.
                    nn.Conv2d(self.n_head_channels, 2, 1, 1, 0, bias=False),
                )
                for i in range(n_levels)
            ]
        )

        self.proj_q = nn.ModuleList(
            [nn.Conv2d(self.nc, self.nc, kernel_size=1, stride=1, padding=0) for _ in range(n_levels)]
        )

    @torch.no_grad()
    def _get_ref_points(self, H_key, W_key, B, dtype, device):
        """Return a regular grid of cell-centre reference points.

        Args:
            H_key: Number of grid rows.
            W_key: Number of grid columns.
            B: Batch size; the grid is expanded to ``B * n_heads`` entries.
            dtype: Floating dtype of the returned tensor.
            device: Device of the returned tensor.

        Returns:
            Tensor of shape ``[B * n_heads, H_key, W_key, 2]`` holding
            ``(y, x)`` coordinates, each mapped from cell centres
            ``0.5 .. size - 0.5`` to the range ``(-1, 1)``.
        """
        ys = torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device)
        xs = torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device)
        # Normalize each axis to (-1, 1) before building the grid.
        ys = ys.div(H_key).mul(2.0).sub(1.0)
        xs = xs.div(W_key).mul(2.0).sub(1.0)
        ref_y, ref_x = torch.meshgrid(ys, xs, indexing="ij")
        ref = torch.stack((ref_y, ref_x), -1)  # [H, W, 2], last axis is (y, x)
        return ref[None, ...].expand(B * self.n_heads, -1, -1, -1)  # [B*g, H, W, 2]

    def forward(self, x, spatial_shapes, level_start_index):
        """Sample deformed points from every feature level.

        Args:
            x: Tokens of all levels concatenated along axis 1, shape
                ``[B, sum(H_l * W_l), C]`` with ``C == self.nc``.
            spatial_shapes: Per-level ``(H_l, W_l)`` pairs (a sequence of
                pairs or a 2-D integer tensor).
            level_start_index: Unused; kept so the call signature stays
                compatible with existing callers.

        Returns:
            Tensor of shape ``[B, sum(Hk_l * Wk_l), C]`` where ``(Hk_l, Wk_l)``
            is the spatial size produced by the level's offset head.
        """
        B = x.size(0)
        dtype, device = x.dtype, x.device
        g, cg = self.n_heads, self.n_head_channels

        # Plain ints so that tensor-valued shapes work for split/reshape alike.
        shapes = [(int(h), int(w)) for h, w in spatial_shapes]
        x_list = x.split([h * w for h, w in shapes], dim=1)

        out = []
        for i, (cur_x, (H, W)) in enumerate(zip(x_list, shapes)):
            # [B, H*W, C] -> [B, C, H, W]
            feat = cur_x.transpose(1, 2).reshape(B, self.nc, H, W)
            q = self.proj_q[i](feat)
            q_off = q.reshape(B * g, cg, H, W)  # [B*g, Cg, H, W]
            offset = self.conv_offset[i](q_off)  # [B*g, 2, Hk, Wk]
            Hk, Wk = offset.size(2), offset.size(3)

            bounded = self.offset_range_factor >= 0
            if bounded:
                # Built in the offset's own dtype: a float32 range would promote
                # the grid and make grid_sample reject half-precision inputs.
                offset_range = torch.tensor(
                    [1.0 / Hk, 1.0 / Wk], dtype=offset.dtype, device=offset.device
                ).reshape(1, 2, 1, 1)
                offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)

            offset = offset.permute(0, 2, 3, 1)  # [B*g, Hk, Wk, 2]
            reference = self._get_ref_points(Hk, Wk, B, dtype, device)

            pos = offset + reference
            if not bounded:
                # Unbounded offsets: keep the sampling positions inside the map.
                pos = pos.clamp(-1.0, +1.0)

            # Heads must be split from the channels-first map. Reshaping the
            # channels-last tokens directly would reinterpret memory and mix
            # channels with spatial positions.
            x_sampled = F.grid_sample(
                input=feat.reshape(B * g, cg, H, W),  # [B*g, Cg, H, W]
                grid=pos[..., (1, 0)],  # y, x -> x, y: [B*g, Hk, Wk, 2]
                mode="bilinear",
                align_corners=True,
            )  # [B*g, Cg, Hk, Wk]

            # [B*g, Cg, Hk, Wk] -> [B, Hk*Wk, g*Cg]
            x_sampled = x_sampled.reshape(B, g, cg, Hk * Wk).permute(0, 3, 1, 2).reshape(B, Hk * Wk, self.nc)
            out.append(x_sampled)
        return torch.cat(out, dim=1)