File size: 3,791 Bytes
1dc7b3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import torch
import torch.nn as nn
import torch.nn.init as init
import torchvision.models as models
from torchvision.models import ResNet34_Weights


class ResNetEncoder(nn.Module):
    """ResNet-34 feature extractor whose stem accepts single-channel input.

    The ImageNet-pretrained torchvision ResNet-34 is loaded and its stages
    (``bn1``, ``relu``, ``maxpool``, ``layer1`` .. ``layer4``) are reused
    directly. Only the first convolution is rebuilt so that it takes one
    input channel instead of three.

    Args:
        freeze: When true (the default), ``requires_grad`` is switched off
            for every parameter of this module, including the rebuilt stem.
    """

    def __init__(self, freeze=True):
        super().__init__()
        backbone = models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)

        # One-channel stem; its filters are the pretrained filters averaged
        # over their input-channel axis.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            self.conv1.weight.copy_(backbone.conv1.weight.mean(dim=1, keepdim=True))

        # Adopt the remaining pretrained stages under their original names.
        for stage in ("bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4"):
            setattr(self, stage, getattr(backbone, stage))

        if freeze:
            self.requires_grad_(False)

    def forward(self, x):
        """Return five feature maps, shallowest first.

        The tuple holds the stem output (conv1 -> bn1 -> relu) followed by
        the outputs of ``layer1``, ``layer2``, ``layer3`` and ``layer4``.
        The max-pooled stem output feeds ``layer1`` but is not returned.
        """
        # NOTE: input normalisation `(x - 0.449) / 0.226` is currently disabled.
        stem = self.relu(self.bn1(self.conv1(x)))
        feat1 = self.layer1(self.maxpool(stem))
        feat2 = self.layer2(feat1)
        feat3 = self.layer3(feat2)
        feat4 = self.layer4(feat3)
        return stem, feat1, feat2, feat3, feat4


def icnr(tensor, scale=2, init_func=init.kaiming_normal_):
    """Fill a 4-D weight tensor in place using ICNR-style initialisation.

    A smaller kernel with ``tensor.shape[0] / scale**2`` entries along dim 0
    is initialised with ``init_func``; each of its entries is then repeated
    ``scale**2`` times consecutively along dim 0 and the result is copied
    into ``tensor``. Consecutive groups of ``scale**2`` slices along dim 0
    therefore start out identical.

    Args:
        tensor: 4-D tensor to overwrite (e.g. an ``nn.Conv2d`` weight).
        scale: Upscaling factor; ``scale**2`` is the repeat count.
        init_func: Callable that initialises a tensor and returns it.

    Returns:
        None. ``tensor`` is modified in place.
    """
    dim0, dim1, kernel_h, kernel_w = tensor.shape
    repeats = scale ** 2
    base = torch.zeros([int(dim0 / repeats), dim1, kernel_h, kernel_w])
    base = init_func(base)
    expanded = base.repeat_interleave(repeats, 0)
    # Write without recording the copy in autograd.
    with torch.no_grad():
        tensor.copy_(expanded)


class PixelShuffleICNR(nn.Module):
    """Upsampling head: 3x3 conv -> PixelShuffle -> BatchNorm -> ReLU.

    The convolution produces ``out_channels * scale**2`` channels, which
    ``nn.PixelShuffle(scale)`` rearranges into ``out_channels`` channels at
    ``scale`` times the spatial resolution. The convolution weight is
    initialised with ``icnr``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels after the pixel shuffle.
        scale: Spatial upscaling factor.
    """

    def __init__(self, in_channels, out_channels, scale=2):
        super().__init__()
        pre_shuffle_channels = out_channels * (scale ** 2)
        self.conv = nn.Conv2d(in_channels, pre_shuffle_channels, kernel_size=3, padding=1)
        icnr(self.conv.weight, scale=scale)
        self.pixel_shuffle = nn.PixelShuffle(scale)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv, pixel shuffle, batch norm and ReLU, in that order."""
        shuffled = self.pixel_shuffle(self.conv(x))
        return self.relu(self.bn(shuffled))


class DecoderBlock(nn.Module):
    """U-Net decoder stage: upsample, concatenate the skip, two 3x3 convs.

    Args:
        in_channels: Channels of the feature map to be upsampled.
        skip_channels: Channels of the skip connection concatenated after
            upsampling. The first convolution always expects
            ``in_channels + skip_channels`` inputs, so ``skip=None`` is only
            usable when this is 0.
        out_channels: Channels produced by both convolutions.
    """

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + skip_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        """Upsample ``x``, fuse it with ``skip`` (if given) and convolve.

        When the 2x-upsampled map and ``skip`` have different heights or
        widths (e.g. the skip has an odd spatial size), ``x`` is resampled
        straight to the skip's size instead, so the concatenation does not
        fail. Matching sizes take the plain 2x path unchanged.
        """
        up = self.upsample(x)
        if skip is not None:
            target_hw = tuple(skip.shape[-2:])
            if tuple(up.shape[-2:]) != target_hw:
                # Resample from the original input rather than from `up` to
                # avoid interpolating twice.
                up = nn.functional.interpolate(
                    x,
                    size=target_hw,
                    mode=self.upsample.mode,
                    align_corners=self.upsample.align_corners,
                )
            up = torch.cat([up, skip], dim=1)
        return self.conv(up)


class Decoder(nn.Module):
    """Decoder path: four skip-fusing stages, a pixel-shuffle head, tanh.

    Stage widths are fixed: ``dec4`` takes 512 + 256 channels down to 256,
    ``dec3`` 256 + 128 to 128, ``dec2`` 128 + 64 to 64 and ``dec1`` 64 + 64
    to 64. The head maps 64 channels to 16 while upscaling by 2, and a final
    3x3 convolution produces 2 output channels.
    """

    def __init__(self):
        super().__init__()
        self.dec4 = DecoderBlock(512, 256, 256)
        self.dec3 = DecoderBlock(256, 128, 128)
        self.dec2 = DecoderBlock(128, 64, 64)
        self.dec1 = DecoderBlock(64, 64, 64)
        self.pixel_shuffle = PixelShuffleICNR(64, 16, scale=2)
        self.final = nn.Conv2d(16, 2, kernel_size=3, padding=1)

    def forward(self, x5, x4, x3, x2, x1):
        """Decode from the deepest feature map ``x5`` using skips ``x4``..``x1``.

        Returns the 2-channel output after ``tanh``.
        """
        # A single running name is rebound at every stage so the previous
        # stage's output loses its local reference as soon as it is consumed;
        # each skip is likewise dropped right after use.
        feat = self.dec4(x5, x4)
        del x4
        feat = self.dec3(feat, x3)
        del x3
        feat = self.dec2(feat, x2)
        del x2
        feat = self.dec1(feat, x1)
        del x1
        feat = self.pixel_shuffle(feat)
        return torch.tanh(self.final(feat))


class UNet(nn.Module):
    """Encoder-decoder network pairing ``ResNetEncoder`` with ``Decoder``.

    Both sub-modules are built with their default arguments.
    """

    def __init__(self):
        super().__init__()
        self.encoder = ResNetEncoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x`` and decode the resulting feature pyramid."""
        # The encoder yields features shallowest-first; the decoder takes
        # them deepest-first, so the tuple is passed in reverse order.
        features = self.encoder(x)
        return self.decoder(*reversed(features))