# SVS_oe3/model/model.py
# shivamkunkolikar
# May 5, 8:51 PM
# 4f75b6d
import torch
class fourierConv(torch.nn.Module):
    """Pointwise (1x1) convolution applied in the 2-D Fourier domain.

    The input is transformed with a real 2-D FFT, its real and imaginary
    parts are stacked along the channel axis and mixed by a 1x1 convolution
    (optionally followed by batch normalisation) and a ReLU, and the result
    is transformed back to the spatial domain at the input's height/width.

    Args:
        in_channels: Number of channels of the spatial input.
        out_channels: Number of channels of the spatial output.
        use_batch_norm: When true, a ``BatchNorm2d`` follows the convolution
            and the convolution is created without a bias term.
    """

    def __init__(self, in_channels, out_channels, use_batch_norm=True):
        super().__init__()
        # Real and imaginary parts live side by side on the channel axis, so
        # the convolution consumes and produces twice the nominal channels.
        conv_in_channels = in_channels * 2
        conv_out_channels = out_channels * 2
        layers = [
            torch.nn.Conv2d(
                conv_in_channels,
                conv_out_channels,
                kernel_size=1,
                padding=0,
                bias=not use_batch_norm,
            ),
        ]
        if use_batch_norm:
            layers.append(torch.nn.BatchNorm2d(conv_out_channels))
        layers.append(torch.nn.ReLU(inplace=True))
        self.linear = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the spectral convolution.

        Args:
            x: Real tensor laid out as ``(N, in_channels, H, W)``; the
                channel axis is taken to be dim 1.

        Returns:
            Real tensor of shape ``(N, out_channels, H, W)``.
        """
        orig_h, orig_w = x.shape[-2:]
        fft = torch.fft.rfft2(x, norm='ortho')
        fft_features = torch.cat([fft.real, fft.imag], dim=1)
        out_features = self.linear(fft_features)
        # First half of the channels is the real part, second half the
        # imaginary part, mirroring the concatenation order above.
        out_real, out_imag = torch.chunk(out_features, 2, dim=1)
        out_fft = torch.complex(out_real, out_imag)
        # Passing the size explicitly keeps odd widths intact: rfft2 stores
        # only W // 2 + 1 columns, which is ambiguous for the inverse.
        return torch.fft.irfft2(out_fft, s=(orig_h, orig_w), norm='ortho')
class UNetInpaint(torch.nn.Module):
    """U-Net style encoder/decoder whose output passes through a sigmoid.

    Four encoder stages (64, 128, 256, 512 channels), each followed by 2x2
    max pooling, feed a 1024-channel bottleneck. Four decoder stages then
    upsample by 2, concatenate the matching encoder output on the channel
    axis and convolve. A final 1x1 convolution maps to ``output_channels``
    and a sigmoid squashes the values to lie between 0 and 1.

    Args:
        input_channels: Channels of the input tensor (default 4 -- presumably
            an RGB image plus a mask channel; confirm against the caller).
        output_channels: Channels of the produced tensor (default 3).
    """

    def __init__(self, input_channels=4, output_channels=3):
        super().__init__()
        # Encoder layers
        self.enc1 = self.conv_block(input_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)
        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        self.bottleneck = self.conv_block(512, 1024)
        # Decoder: every dec* block sees the upsampled features concatenated
        # with the same-width encoder output, hence the doubled input width.
        self.upconv4 = self.up_conv_block(1024, 512)
        self.dec4 = self.conv_block(1024, 512)
        self.upconv3 = self.up_conv_block(512, 256)
        self.dec3 = self.conv_block(512, 256)
        self.upconv2 = self.up_conv_block(256, 128)
        self.dec2 = self.conv_block(256, 128)
        self.upconv1 = self.up_conv_block(128, 64)
        self.dec1 = self.conv_block(128, 64)
        # Final layer
        self.out_conv = torch.nn.Conv2d(64, output_channels, kernel_size=1)
        self.final_activation = torch.nn.Sigmoid()

    def conv_block(self, in_channels, out_channels):
        """Return two stacked 3x3 conv -> BatchNorm -> ReLU stages.

        Padding of 1 keeps the spatial size unchanged; the convolutions carry
        no bias because each is followed by batch normalisation.
        """
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(inplace=True),
        )

    def up_conv_block(self, in_channels, out_channels):
        """Return a 2x bilinear upsample followed by 3x3 conv -> BatchNorm -> ReLU."""
        return torch.nn.Sequential(
            torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(inplace=True),
        )

    @staticmethod
    def _merge_skip(upsampled, skip):
        """Concatenate decoder features with an encoder skip tensor.

        Max pooling floors odd sizes, so a 2x upsample can come out smaller
        than the skip tensor. In that case the upsampled map is resized to
        the skip's height/width first; matching sizes are left untouched.
        """
        if upsampled.shape[-2:] != skip.shape[-2:]:
            upsampled = torch.nn.functional.interpolate(
                upsampled,
                size=skip.shape[-2:],
                mode='bilinear',
                align_corners=True,
            )
        return torch.cat((upsampled, skip), dim=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor laid out as ``(N, input_channels, H, W)``. H and W must
                be at least 16 so that four poolings leave a non-empty map;
                they do not have to be multiples of 16.

        Returns:
            Tensor of shape ``(N, output_channels, H, W)`` holding sigmoid
            outputs.
        """
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        # Bottleneck
        b = self.bottleneck(self.pool(e4))
        # Decoder
        d4 = self.dec4(self._merge_skip(self.upconv4(b), e4))
        d3 = self.dec3(self._merge_skip(self.upconv3(d4), e3))
        d2 = self.dec2(self._merge_skip(self.upconv2(d3), e2))
        d1 = self.dec1(self._merge_skip(self.upconv1(d2), e1))
        # Output
        out = self.out_conv(d1)
        return self.final_activation(out)