| import torch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class fourierConv(torch.nn.Module):
    """Pointwise (1x1) convolution applied in the 2-D Fourier domain.

    The input is transformed with a real 2-D FFT over its last two axes. The
    real and imaginary parts are stacked along the channel axis and mixed by a
    1x1 convolution, optionally followed by batch norm, and a ReLU. The result
    is split back into real/imaginary halves and returned to the spatial
    domain with the inverse real FFT at the original height and width.

    NOTE: the convolution produces ``2 * out_channels`` maps (real + imaginary
    halves), and a ``BatchNorm2d`` follows it when ``use_batch_norm`` is true.
    State dicts saved from an implementation with a single ``out_channels``
    conv and no batch norm will therefore not load with ``strict=True``. That
    implementation returned its input unchanged, so those weights never
    affected the output.
    """

    def __init__(self, in_channels, out_channels, use_batch_norm=True):
        """Build the spectral mixing layers.

        Args:
            in_channels: number of channels of the input tensor.
            out_channels: number of channels of the returned tensor.
            use_batch_norm: if true, normalise the spectral features with
                ``BatchNorm2d`` after the convolution. The conv bias is then
                dropped, because the norm layer supplies its own shift.
        """
        super().__init__()

        # Real and imaginary parts travel as separate channels, so both the
        # input and the output channel counts are doubled.
        conv_in_channels = in_channels * 2
        conv_out_channels = out_channels * 2

        layers = [
            torch.nn.Conv2d(
                conv_in_channels,
                conv_out_channels,
                kernel_size=1,
                padding=0,
                bias=not use_batch_norm,
            ),
        ]
        if use_batch_norm:
            layers.append(torch.nn.BatchNorm2d(conv_out_channels))
        layers.append(torch.nn.ReLU(inplace=True))
        self.linear = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the spectral convolution.

        Args:
            x: real tensor, expected shape ``(B, in_channels, H, W)``. A
                batched 4-D input is assumed, because the real/imaginary parts
                are concatenated on dim 1.

        Returns:
            Real tensor of shape ``(B, out_channels, H, W)``.
        """
        orig_h, orig_w = x.shape[-2:]

        # Complex half-spectrum of shape (B, C, H, W // 2 + 1). 'ortho' uses
        # orthonormal scaling for both the forward and the inverse transform.
        fft = torch.fft.rfft2(x, norm='ortho')

        # Convolutions need real tensors: stack the real and imaginary parts
        # as separate channels -> (B, 2C, H, W // 2 + 1).
        fft_features = torch.cat([fft.real, fft.imag], dim=1)

        out_features = self.linear(fft_features)

        # The first half of the channels is the real part and the second half
        # the imaginary part, mirroring the concatenation order above.
        real, imag = torch.chunk(out_features, 2, dim=1)
        spectrum = torch.complex(real, imag)

        # The half-spectrum alone cannot tell an even width from an odd one,
        # so pass the original size explicitly.
        return torch.fft.irfft2(spectrum, s=(orig_h, orig_w), norm='ortho')
| |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UNetInpaint(torch.nn.Module):
    """Four-level U-Net with a sigmoid output head.

    Architecture:
        - The encoder halves the spatial size four times (64 -> 512 channels).
        - A 1024-channel bottleneck follows.
        - The decoder upsamples back, concatenating each stage with the
          matching encoder output (skip connection).
        - A 1x1 convolution and a sigmoid produce the output, so every output
          value lies in ``[0, 1]``.

    Heights and widths that are not multiples of 16 are supported. Each
    upsampled decoder map is zero-padded on the bottom/right to its skip
    connection's size before concatenation. Height and width must still be at
    least 16; otherwise the fourth pooling step produces an empty map.
    """

    def __init__(self, input_channels=4, output_channels=3):
        """Create the encoder, bottleneck, decoder and output head.

        Args:
            input_channels: channels of the input tensor. The default of 4 is
                presumably an RGB image plus a mask channel (confirm with
                callers).
            output_channels: channels of the predicted output.
        """
        super().__init__()

        # Encoder: each stage doubles the channels.
        self.enc1 = self.conv_block(input_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)

        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = self.conv_block(512, 1024)

        # Decoder: each upconv halves the channels. The following conv_block
        # takes twice that width, because the skip connection is concatenated.
        self.upconv4 = self.up_conv_block(1024, 512)
        self.dec4 = self.conv_block(1024, 512)

        self.upconv3 = self.up_conv_block(512, 256)
        self.dec3 = self.conv_block(512, 256)

        self.upconv2 = self.up_conv_block(256, 128)
        self.dec2 = self.conv_block(256, 128)

        self.upconv1 = self.up_conv_block(128, 64)
        self.dec1 = self.conv_block(128, 64)

        self.out_conv = torch.nn.Conv2d(64, output_channels, kernel_size=1)
        self.final_activation = torch.nn.Sigmoid()

    def conv_block(self, in_channels, out_channels):
        """Return two (3x3 conv -> BatchNorm -> ReLU) stages.

        ``padding=1`` preserves H and W. The conv bias is disabled because the
        following BatchNorm adds its own learned shift.
        """
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(inplace=True)
        )

    def up_conv_block(self, in_channels, out_channels):
        """Return bilinear 2x upsampling followed by 3x3 conv -> BatchNorm -> ReLU."""
        return torch.nn.Sequential(
            torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(inplace=True)
        )

    @staticmethod
    def _match_spatial(x, ref):
        """Zero-pad ``x`` on the bottom/right so its H and W equal ``ref``'s."""
        diff_h = ref.shape[-2] - x.shape[-2]
        diff_w = ref.shape[-1] - x.shape[-1]
        if diff_h == 0 and diff_w == 0:
            return x
        # MaxPool2d floors odd sizes by dropping the last row/column. After
        # 2x upsampling, the missing data is therefore at the bottom/right edge.
        return torch.nn.functional.pad(x, (0, diff_w, 0, diff_h))

    def _decode(self, x, skip, upconv, dec):
        """Run one decoder stage.

        Upsample ``x``, align it to ``skip``, concatenate the two along the
        channel axis, and refine the result with ``dec``.
        """
        up = self._match_spatial(upconv(x), skip)
        return dec(torch.cat((up, skip), dim=1))

    def forward(self, x):
        """Run the network.

        Args:
            x: tensor of shape ``(B, input_channels, H, W)`` with
                ``H, W >= 16``.

        Returns:
            Tensor of shape ``(B, output_channels, H, W)`` with values in
            ``[0, 1]``.
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        d4 = self._decode(b, e4, self.upconv4, self.dec4)
        d3 = self._decode(d4, e3, self.upconv3, self.dec3)
        d2 = self._decode(d3, e2, self.upconv2, self.dec2)
        d1 = self._decode(d2, e1, self.upconv1, self.dec1)

        out = self.out_conv(d1)
        return self.final_activation(out)
|
|
|
|