"""U-Net style encoder/decoder model built from PyTorch modules.

Provides the stage factories ``conv`` (downsampling), ``deconv``
(upsampling) and ``final_conv`` (output head), plus the ``UNet`` model
that wires them together with skip connections.

Author's note: my first torch module, hopefully not the last one :p
"""

# NOTE: the docstring above must stay the first statement of the module;
# a string placed after another statement is not picked up as __doc__.
__author__ = "Mouad"
|
|
|
| import torch
|
| import torch.nn as nn
|
| import numpy as np
|
|
|
|
|
def conv(in_chan, out_chan):
    """Build one encoder (downsampling) stage.

    The stage is a 5x5 convolution with stride 2 and padding 2, then batch
    normalization, then LeakyReLU with negative slope 0.2. For an input of
    spatial size H x W the output is ceil(H/2) x ceil(W/2).

    Args:
        in_chan: number of input channels.
        out_chan: number of output channels.

    Returns:
        An ``nn.Sequential`` holding the three layers (indices 0, 1, 2).
    """
    stages = [
        # positional args: in_channels, out_channels, kernel_size, stride, padding
        nn.Conv2d(in_chan, out_chan, 5, 2, 2),
        nn.BatchNorm2d(out_chan),
        # in-place activation overwrites the batch-norm output instead of
        # allocating a new tensor
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
    ]
    return nn.Sequential(*stages)
|
|
|
def deconv(in_chan, out_chan, dropout=False, *, dropout_p=0.5):
    """Build one decoder (upsampling) stage.

    The stage is a 5x5 transposed convolution (stride 2, padding 2,
    output_padding 1), then batch normalization, then ReLU, optionally
    followed by channel-wise dropout. With these settings an H x W input
    becomes 2H x 2W: (H - 1) * 2 - 2 * 2 + (5 - 1) + 1 + 1 = 2H.

    Args:
        in_chan: number of input channels.
        out_chan: number of output channels.
        dropout: when truthy, append ``nn.Dropout2d`` after the activation.
            Falsy values (``False``, ``None``, ``0``) disable it.
        dropout_p: drop probability used when ``dropout`` is enabled.

    Returns:
        An ``nn.Sequential`` with layers at indices 0 (ConvTranspose2d),
        1 (BatchNorm2d), 2 (ReLU) and, if enabled, 3 (Dropout2d). This is
        the same layout as before, so existing state_dict keys still load.
    """
    layers = [
        nn.ConvTranspose2d(
            in_channels=in_chan,
            out_channels=out_chan,
            kernel_size=5,
            stride=2,
            padding=2,
            output_padding=1,
        ),
        nn.BatchNorm2d(num_features=out_chan),
        nn.ReLU(inplace=True),
    ]
    # Truthiness instead of `== False`: the old comparison treated None as
    # "dropout on" because `None == False` is False.
    if dropout:
        layers.append(nn.Dropout2d(p=dropout_p))
    return nn.Sequential(*layers)
|
|
|
|
|
|
|
def final_conv(in_chan, out_chan):
    """Build the output head: one last 2x upsampling step plus a sigmoid.

    A 5x5 transposed convolution (stride 2, padding 2, output_padding 1)
    doubles H and W. The sigmoid then squashes every value into [0, 1].

    Args:
        in_chan: number of input channels.
        out_chan: number of output channels.

    Returns:
        An ``nn.Sequential`` of (ConvTranspose2d, Sigmoid).
    """
    upsample = nn.ConvTranspose2d(
        in_chan,
        out_chan,
        kernel_size=5,
        stride=2,
        padding=2,
        output_padding=1,
    )
    squash = nn.Sigmoid()
    return nn.Sequential(upsample, squash)
|
|
|
|
|
class UNet(nn.Module):
    """Six-level U-Net built from the ``conv``/``deconv``/``final_conv`` stages.

    The encoder has six stride-2 stages, each halving H and W, and widens
    channels from ``base_channels`` to ``32 * base_channels``. The decoder
    mirrors it with five upsampling stages. The output of each decoder stage
    is concatenated along the channel axis with the encoder map of the same
    resolution. A final upsampling head with a sigmoid maps back to the
    input resolution. The first three decoder stages use ``Dropout2d``.

    Args:
        in_channels: channels of the input image (default 1, as before).
        out_channels: channels of the predicted output (default 1).
        base_channels: width of the first encoder stage (default 16). Deeper
            stages use 2x, 4x, ... 32x this value.

    With the defaults, the layer shapes, attribute names and registration
    order match the original hard-coded model, so existing checkpoints load
    unchanged.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=16):
        super().__init__()
        c = base_channels

        # Encoder: each stage halves H and W.
        self.down_conv1 = conv(in_channels, c)
        self.down_conv2 = conv(c, 2 * c)
        self.down_conv3 = conv(2 * c, 4 * c)
        self.down_conv4 = conv(4 * c, 8 * c)
        self.down_conv5 = conv(8 * c, 16 * c)
        self.down_conv6 = conv(16 * c, 32 * c)

        # Decoder: each stage doubles H and W. From deconv2 onward the input
        # is [previous decoder output, skip], so in_chan is the sum of both
        # channel counts (e.g. deconv2: 16c + 16c = 32c).
        self.deconv1 = deconv(32 * c, 16 * c, dropout=True)
        self.deconv2 = deconv(32 * c, 8 * c, dropout=True)
        self.deconv3 = deconv(16 * c, 4 * c, dropout=True)
        self.deconv4 = deconv(8 * c, 2 * c)
        self.deconv5 = deconv(4 * c, c)

        # Output head: input is [deconv5 output (c), down_conv1 output (c)].
        self.final_conv = final_conv(2 * c, out_channels)

    def forward(self, image):
        """Run the encoder, then the decoder with skip connections.

        Args:
            image: tensor of shape (N, in_channels, H, W). BatchNorm2d needs
                a 4D batch. H and W must be divisible by 64; otherwise a
                decoder output and its skip differ in size and
                ``torch.cat`` raises.

        Returns:
            Tensor of shape (N, out_channels, H, W) with values in [0, 1]
            (sigmoid output).
        """
        # Encoder path; x1..x5 are kept for the skip connections.
        x1 = self.down_conv1(image)
        x2 = self.down_conv2(x1)
        x3 = self.down_conv3(x2)
        x4 = self.down_conv4(x3)
        x5 = self.down_conv5(x4)
        x6 = self.down_conv6(x5)

        # Decoder path: upsample, then concatenate on dim 1 (channels) with
        # the encoder map of matching spatial size.
        x7 = self.deconv1(x6)
        x8 = self.deconv2(torch.cat((x7, x5), 1))
        x9 = self.deconv3(torch.cat((x8, x4), 1))
        x10 = self.deconv4(torch.cat((x9, x3), 1))
        x11 = self.deconv5(torch.cat((x10, x2), 1))

        return self.final_conv(torch.cat((x11, x1), 1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| |