File size: 2,730 Bytes
30f8290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch

from torch import nn
from einops.layers.torch import Rearrange

def init_layer(layer):
    """Re-initialise a layer's weight with Xavier-uniform and zero its bias.

    ``layer.weight`` is overwritten in place via ``nn.init.xavier_uniform_``.
    If the layer exposes a ``bias`` attribute that is not ``None`` it is
    filled with zeros; layers without a bias (e.g. ``bias=False``) are
    otherwise left untouched.
    """
    nn.init.xavier_uniform_(layer.weight)

    bias = getattr(layer, "bias", None)
    if bias is None:
        return
    bias.data.fill_(0.0)

def init_bn(bn):
    """Reset a batch-norm layer to an identity affine map with fresh statistics.

    Sets ``bias`` and ``running_mean`` to 0 and ``weight`` and
    ``running_var`` to 1. The writes go through ``.data`` so autograd does not
    record them. ``num_batches_tracked`` is not touched.

    NOTE(review): assumes ``bn`` was built with ``affine=True`` and
    ``track_running_stats=True``; otherwise one of these tensors is ``None``
    and the call raises ``AttributeError``.
    """
    resets = (
        (bn.bias, 0.0),
        (bn.weight, 1.0),
        (bn.running_mean, 0.0),
        (bn.running_var, 1.0),
    )
    for tensor, value in resets:
        tensor.data.fill_(value)

class BiGRU(nn.Module):
    """Bidirectional GRU run over non-overlapping patches of a 4-D input.

    The input ``(b, c, W, H)`` is split into ``patch_width x patch_height``
    patches. Each patch is flattened into one sequence element, so the GRU
    sees ``(W // patch_width) * (H // patch_height)`` steps of width
    ``patch_dim = c * patch_width * patch_height``.

    Args:
        patch_size: ``(patch_width, patch_height)``. ``patch_width`` must
            divide the input's third axis and ``patch_height`` its fourth.
        channels: number of input channels ``c``.
        depth: number of stacked GRU layers.
    """

    def __init__(
        self,
        patch_size,
        channels,
        depth
    ):
        super().__init__()
        patch_width, patch_height = patch_size
        patch_dim = channels * patch_height * patch_width

        # (b, c, W, H) -> (b, num_patches, patch_dim)
        self.to_patch_embedding = nn.Sequential(
            Rearrange(
                'b c (w p1) (h p2) -> b (w h) (p1 p2 c)',
                p1=patch_width,
                p2=patch_height
            )
        )

        # Each direction gets half the patch width, so the concatenated
        # bidirectional output has 2 * (patch_dim // 2) features.
        self.gru = nn.GRU(
            patch_dim,
            patch_dim // 2,
            num_layers=depth,
            batch_first=True,
            bidirectional=True
        )

    def forward(self, x):
        """Return the GRU output sequence for ``x``.

        Shape: ``(b, c, W, H)`` ->
        ``(b, (W // patch_width) * (H // patch_height), 2 * (patch_dim // 2))``.

        If the GRU raises ``RuntimeError`` on a CUDA input while cuDNN is
        enabled, the call is retried once with cuDNN disabled for that call
        only. ``torch.backends.cudnn.enabled`` is restored afterwards. Any
        other failure propagates unchanged.
        """
        x = self.to_patch_embedding(x)

        try:
            return self.gru(x)[0]
        except RuntimeError:
            # A retry without cuDNN can only help if cuDNN could have been
            # used for this call; otherwise surface the original error.
            if not (x.is_cuda and torch.backends.cudnn.enabled):
                raise
            # Fall back to the non-cuDNN GRU for this call only. Flipping the
            # flag permanently would also disable cuDNN for every later op
            # in the process (e.g. the convolutions) and never be undone.
            previous = torch.backends.cudnn.enabled
            torch.backends.cudnn.enabled = False
            try:
                return self.gru(x)[0]
            finally:
                torch.backends.cudnn.enabled = previous

class ResConvBlock(nn.Module):
    """Pre-activation residual block: two (BatchNorm -> PReLU -> 3x3 conv) stages.

    Both 3x3 convolutions use padding 1 and the default stride, so the
    spatial size is unchanged. The skip path is the identity when
    ``in_planes == out_planes``. Otherwise it is a 1x1 convolution
    (``self.shortcut``) that maps the input to ``out_planes`` channels.
    """

    def __init__(
        self,
        in_planes,
        out_planes
    ):
        super().__init__()
        # Registration order (and hence RNG consumption of the conv layers)
        # matches: bn1, bn2, act1, act2, conv1, conv2, [shortcut].
        self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.01)
        self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.01)
        self.act1 = nn.PReLU()
        self.act2 = nn.PReLU()
        self.conv1 = nn.Conv2d(
            in_planes, out_planes, kernel_size=(3, 3), padding=(1, 1), bias=False
        )
        self.conv2 = nn.Conv2d(
            out_planes, out_planes, kernel_size=(3, 3), padding=(1, 1), bias=False
        )

        # A projection is only needed when the channel count changes.
        self.is_shortcut = bool(in_planes != out_planes)
        if self.is_shortcut:
            self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=(1, 1))

        self.init_weights()

    def init_weights(self):
        """Reset both norms, then Xavier-initialise every convolution."""
        for norm in (self.bn1, self.bn2):
            init_bn(norm)

        convs = [self.conv1, self.conv2]
        if self.is_shortcut:
            convs.append(self.shortcut)
        for conv in convs:
            init_layer(conv)

    def forward(self, x):
        """Return ``skip(x) + residual(x)`` with the same spatial size as ``x``."""
        h = self.conv1(self.act1(self.bn1(x)))
        h = self.conv2(self.act2(self.bn2(h)))

        if not self.is_shortcut:
            return h + x
        return self.shortcut(x) + h