File size: 2,152 Bytes
646f45c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
from torch import nn
from typing import Optional, Union, Tuple


class ConvLayer2d(nn.Module):
    """2-D convolution optionally followed by a norm layer and an activation.

    The sub-modules live, in this order, inside ``self.block`` (an
    ``nn.Sequential``): the ``nn.Conv2d`` at index 0, then the norm layer
    when ``use_norm`` is true, then the activation when ``use_act`` is true.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded unchanged (as keywords) to ``nn.Conv2d``.
        use_norm: whether to append a normalization module after the conv.
        use_act: whether to append an activation module at the end.
        norm_layer: module instance used as the norm; when ``None`` a fresh
            ``nn.BatchNorm2d(out_channels)`` is created. Ignored when
            ``use_norm`` is false.
        act_layer: module instance used as the activation; when ``None`` a
            fresh ``nn.ReLU(inplace=True)`` is created. Ignored when
            ``use_act`` is false.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = False,
        use_norm: bool = True,
        use_act: bool = True,
        norm_layer: Optional[nn.Module] = None,
        act_layer: Optional[nn.Module] = None,
    ):
        super().__init__()
        conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        stages = [conv]

        if use_norm:
            norm = nn.BatchNorm2d(out_channels) if norm_layer is None else norm_layer
            stages.append(norm)

        if use_act:
            act = nn.ReLU(inplace=True) if act_layer is None else act_layer
            stages.append(act)

        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through conv -> (norm) -> (activation)."""
        return self.block(x)

# PEG (Positional Encoding Generator) from CPVT: https://arxiv.org/abs/2102.10882


class PosCNN(nn.Module):
    """Convolutional positional encoding (PEG) applied to a token sequence.

    The (B, N, C) token tensor is reshaped to a (B, C, H, W) grid and passed
    through a single 3x3 convolution with ``groups=embed_dim`` and padding 1.
    With stride ``s == 1`` the conv output is added back to the grid
    (residual). With any other stride only the conv output is kept. The result
    is flattened back to a (B, N', embed_dim) token tensor.

    Args:
        in_chans: number of input channels, i.e. the token width ``C``.
        embed_dim: number of output channels, also used as ``groups``.
            Defaults to ``in_chans`` (a depthwise conv) when ``None``. The
            residual branch (``s == 1``) requires ``embed_dim == in_chans``.
        s: stride of the convolution.
    """

    def __init__(self, in_chans, embed_dim=None, s=1):
        super().__init__()
        if embed_dim is None:
            # Passing None through to nn.Conv2d (out_channels/groups) cannot
            # work; default to a depthwise conv that preserves the width.
            embed_dim = in_chans
        self.proj = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim, 3, s, 1, bias=True, groups=embed_dim),
        )
        self.s = s

    def forward(self, x, H, W):
        """Add (or, for ``s != 1``, produce) the positional encoding.

        Args:
            x: token tensor unpacked as (B, N, C). ``N`` must equal ``H * W``,
                otherwise the ``view`` below raises.
            H, W: height and width of the grid the tokens are laid out on.
                Token ``n`` maps to grid cell ``(n // W, n % W)``.

        Returns:
            Tensor of shape (B, H' * W', embed_dim), where ``H'``/``W'`` are
            the conv's output size (equal to ``H``/``W`` when ``s == 1``).
        """
        B, _, C = x.shape
        # (B, N, C) -> (B, C, N) -> (B, C, H, W)
        cnn_feat = x.transpose(1, 2).view(B, C, H, W)
        out = self.proj(cnn_feat)
        if self.s == 1:
            # Residual connection: only valid when spatial size and channels match.
            out = out + cnn_feat
        # (B, embed_dim, H', W') -> (B, H' * W', embed_dim)
        return out.flatten(2).transpose(1, 2)

    def no_weight_decay(self):
        """Return the names of this module's conv weights.

        Names are relative to this module (e.g. ``"proj.0.weight"``). They
        are derived from the layers actually present in ``self.proj``, so
        every returned name refers to an existing parameter.
        """
        return [
            f"proj.{i}.weight"
            for i, layer in enumerate(self.proj)
            if isinstance(layer, nn.Conv2d)
        ]