File size: 4,134 Bytes
c6d5483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import torch.nn as nn
import torch
import albumentations as A


# CNN block (conv -> batch norm -> leaky ReLU) that is reused repeatedly later
class CNNBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.2) building block.

    The convolution uses a 4x4 kernel and has no bias term.

    :param in_channels: number of channels of the input feature map
    :param out_channels: number of channels produced by the convolution
    :param stride: stride of the convolution
    :param padding: padding added to each side of the input. The default of 0
        keeps the block's historical output size. NOTE: with 0 padding the
        ``padding_mode='reflect'`` setting below has no effect; pass e.g.
        ``padding=1`` to actually get reflect padding.
    """

    def __init__(self, in_channels, out_channels, stride=2, padding=0):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=stride,
                padding=padding,
                bias=False,
                padding_mode='reflect',
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Apply convolution, batch norm and activation to ``x``."""
        return self.conv(x)


class PatchGan(torch.nn.Module):
    """
    Patch GAN network: a 1x1 channel projection, four contracting stages
    and a 1x1 convolution producing a single output channel.
    """

    @staticmethod
    def create_contracting_block(in_channels: int, out_channels: int):
        """
        Build one contracting (encoder) stage.

        The stage is two 3x3 convolutions with padding 1, each followed by a
        ReLU, and then a 2x2 max pool with stride 2.
        :param in_channels: channels of the incoming feature map
        :param out_channels: channels produced by both convolutions
        :return: ``Sequential`` holding the conv stack followed by the pool stack
        """
        double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        downsample = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(double_conv, downsample)

    def __init__(self, input_channels: int, hidden_channels: int) -> None:
        """
        :param input_channels: channels of the concatenated ``(x, y)`` input
        :param hidden_channels: channels after the initial 1x1 projection;
            every contracting stage doubles this width
        """
        super().__init__()
        # Channel width after the projection and after each of the four stages.
        base, w1, w2, w3, w4 = (hidden_channels * m for m in (1, 2, 4, 8, 16))

        # 1x1 convolution that maps the input to the working width.
        self.resize_channels = nn.Conv2d(input_channels, base, kernel_size=1)

        self.enc1 = self.create_contracting_block(base, w1)
        self.enc2 = self.create_contracting_block(w1, w2)
        self.enc3 = self.create_contracting_block(w2, w3)
        self.enc4 = self.create_contracting_block(w3, w4)

        # 1x1 convolution collapsing the deepest features to one channel.
        self.final_layer = nn.Conv2d(w4, 1, kernel_size=1)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """ Concatenate ``x`` and ``y`` along dim 1 and run the network """
        features = self.resize_channels(torch.cat((x, y), dim=1))
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4):
            features = stage(features)
        return self.final_layer(features)


# x, y <- the input image and the generated image are concatenated so the discriminator can determine whether the generated image is real or not
class Discriminator(nn.Module):
    """Convolutional discriminator that scores a pair of images.

    ``forward`` concatenates its two inputs along dim 1, so the first
    convolution is built for ``in_channels * 2`` input channels.

    :param in_channels: channels of each of the two inputs (assumes ``x`` and
        ``y`` have the same channel count; together they must supply
        ``in_channels * 2`` channels)
    :param features: sequence of channel widths. ``features[0]`` is the width
        of the initial convolution; every following entry adds one
        ``CNNBlock``. All blocks use stride 2 except the last one, which uses
        stride 1.
    :raises ValueError: if ``features`` is empty
    """

    def __init__(self, in_channels: int = 3, features=(64, 128, 256, 512)):
        super().__init__()
        # Copy into a list: accepts tuples/lists alike and never aliases the
        # (now immutable) default argument.
        widths = list(features)
        if not widths:
            raise ValueError("features must contain at least one channel width")

        # Initial stage has no batch norm: conv followed directly by LeakyReLU.
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels * 2, widths[0], kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
            nn.LeakyReLU(.2)
        )

        # One CNNBlock per consecutive pair of widths. The stride-1 block is
        # chosen by POSITION (last block), not by comparing widths, so
        # repeated widths such as [64, 128, 128] are handled correctly.
        last_index = len(widths) - 1
        layers = [
            CNNBlock(c_in, c_out, stride=1 if index == last_index else 2)
            for index, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1)
        ]

        # Final convolution reduces the features to a single output channel.
        layers.append(
            nn.Conv2d(widths[-1], 1, kernel_size=4, stride=1, padding=1, padding_mode='reflect')
        )

        # create a model using the list of layers
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Concatenate ``x`` and ``y`` along dim 1 and return the score map."""
        x = torch.cat([x, y], dim=1)
        x = self.initial(x)
        return self.model(x)