| import torch | |
| import torch.nn as nn | |
class Discriminator(nn.Module):
    """PatchGAN discriminator.

    Rather than reducing an ``im_channels x H x W`` image all the way down to
    a single scalar, the network predicts a grid of values. Each grid cell is
    the discriminator's score for how likely the image patch corresponding to
    that cell is to be real. No sigmoid is applied, so the scores are raw,
    unbounded values.

    The network is a stack of ``len(conv_channels) + 1`` convolution layers
    whose channel progression is ``im_channels -> *conv_channels -> 1``.
    Layer ``i`` uses ``kernels[i]``, ``strides[i]`` and ``paddings[i]``:

    * first layer: conv with bias, no BatchNorm, LeakyReLU(0.2);
    * intermediate layers: conv without bias, BatchNorm2d, LeakyReLU(0.2);
    * last layer: conv without bias, no BatchNorm, no activation.

    Args:
        im_channels: Number of channels in the input image.
        conv_channels: Output channel count of each hidden conv layer
            (list or tuple).
        kernels: Kernel size for each conv layer. Must have at least
            ``len(conv_channels) + 1`` entries; extra entries are ignored.
        strides: Stride for each conv layer; same length rule as ``kernels``.
        paddings: Padding for each conv layer; same length rule as
            ``kernels``.

    Raises:
        ValueError: If ``kernels``, ``strides`` or ``paddings`` has fewer
            entries than there are conv layers.
    """

    def __init__(
        self,
        im_channels=3,
        conv_channels=(64, 128, 256),
        kernels=(4, 4, 4, 4),
        strides=(2, 2, 2, 1),
        paddings=(1, 1, 1, 1),
    ):
        super().__init__()
        self.im_channels = im_channels

        # Channel count at every layer boundary; unpacking accepts any
        # iterable, so lists and tuples both work for ``conv_channels``.
        layers_dim = [self.im_channels, *conv_channels, 1]
        num_layers = len(layers_dim) - 1
        last = num_layers - 1

        # Fail early with a clear message instead of an IndexError partway
        # through building the layers.
        for arg_name, values in (
            ("kernels", kernels),
            ("strides", strides),
            ("paddings", paddings),
        ):
            if len(values) < num_layers:
                raise ValueError(
                    f"{arg_name} needs at least {num_layers} entries "
                    f"(one per conv layer), got {len(values)}"
                )

        self.layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        layers_dim[i],
                        layers_dim[i + 1],
                        kernel_size=kernels[i],
                        stride=strides[i],
                        padding=paddings[i],
                        # Only the first conv has a bias; intermediate convs
                        # are followed by BatchNorm, which makes a bias
                        # redundant.
                        # NOTE(review): the final conv also has no bias even
                        # though no BatchNorm follows it. Kept as-is so that
                        # existing state_dicts (no "layers.<last>.0.bias")
                        # still load.
                        bias=(i == 0),
                    ),
                    # BatchNorm only on intermediate layers (not first/last).
                    nn.BatchNorm2d(layers_dim[i + 1])
                    if 0 < i < last
                    else nn.Identity(),
                    # The last layer emits raw patch scores: no activation.
                    nn.LeakyReLU(0.2) if i < last else nn.Identity(),
                )
                for i in range(num_layers)
            ]
        )

    def forward(self, x):
        """Score each patch of ``x``.

        Args:
            x: Image batch whose channel dimension equals ``im_channels``,
                e.g. shaped ``(N, im_channels, H, W)``.

        Returns:
            Single-channel grid of raw per-patch scores, e.g.
            ``(N, 1, H', W')``. With the default configuration a 256x256
            input yields a 31x31 grid.
        """
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
| # if __name__ == "__main__": | |
| # x = torch.randn((2, 3, 256, 256)) | |
| # prob = Discriminator(im_channels=3)(x) | |
| # print(prob.shape) | |