# Author: MohidAbdullah
# Commit 3155589 (verified): "Add model architecture"
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
    """Two stacked ``3x3 conv -> BatchNorm -> ReLU`` stages with dropout between.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
        dropout: probability passed to the ``nn.Dropout2d`` placed after the
            first stage (channel-wise dropout, active only in training mode).
    """

    def __init__(self, in_channels, out_channels, dropout=0.1):
        super().__init__()
        # First stage, followed by channel dropout.
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
        ]
        # Second stage (no dropout after it).
        layers += [
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Kept as a single Sequential named ``conv`` so state_dict keys
        # (conv.0.*, conv.1.*, conv.4.*, conv.5.*) stay checkpoint-compatible.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv stages to ``x`` and return the result."""
        return self.conv(x)
class AttentionGate(nn.Module):
    """Additive attention gate that rescales skip features by a learned mask.

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``F_int`` channels with a 1x1 convolution, summed, passed through ReLU, then
    reduced to a single channel and squashed by a sigmoid. The resulting
    coefficients in (0, 1) multiply ``x`` (broadcast across its channels).

    Args:
        F_g: channels of the gating signal ``g``.
        F_l: channels of the skip features ``x``.
        F_int: channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # Construction order is kept so parameter init consumes the RNG identically.
        self.W_g = nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True)
        self.W_x = nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` scaled by the attention coefficients computed from ``g`` and ``x``.

        ``g`` and ``x`` must have matching spatial sizes, since their projections
        are added element-wise.
        """
        # In-place ReLU is safe: it acts on the freshly allocated sum.
        joint = self.relu(self.W_g(g) + self.W_x(x))
        mask = self.psi(joint)
        return x * mask
class AttentionUNet(nn.Module):
def __init__(self, img_ch=1, output_ch=4):
super().__init__()
self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.downs = nn.ModuleList([
DoubleConv(img_ch, 64),
DoubleConv(64, 128),
DoubleConv(128, 256),
DoubleConv(256, 512)
])
self.bottleneck = DoubleConv(512, 1024)
self.ups = nn.ModuleList([
nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),
nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
])
self.attention_gates = nn.ModuleList([
AttentionGate(F_g=512, F_l=512, F_int=256),
AttentionGate(F_g=256, F_l=256, F_int=128),
AttentionGate(F_g=128, F_l=128, F_int=64),
AttentionGate(F_g=64, F_l=64, F_int=32)
])
self.up_convs = nn.ModuleList([
DoubleConv(1024, 512),
DoubleConv(512, 256),
DoubleConv(256, 128),
DoubleConv(128, 64)
])
self.final_conv = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)
def forward(self, x):
e1 = self.downs[0](x)
e2 = self.downs[1](self.Maxpool(e1))
e3 = self.downs[2](self.Maxpool(e2))
e4 = self.downs[3](self.Maxpool(e3))
b = self.bottleneck(self.Maxpool(e4))
d4 = self.ups[0](b)
x4 = self.attention_gates[0](g=d4, x=e4)
d4 = self.up_convs[0](torch.cat((x4, d4), dim=1))
d3 = self.ups[1](d4)
x3 = self.attention_gates[1](g=d3, x=e3)
d3 = self.up_convs[1](torch.cat((x3, d3), dim=1))
d2 = self.ups[2](d3)
x2 = self.attention_gates[2](g=d2, x=e2)
d2 = self.up_convs[2](torch.cat((x2, d2), dim=1))
d1 = self.ups[3](d2)
x1 = self.attention_gates[3](g=d1, x=e1)
d1 = self.up_convs[3](torch.cat((x1, d1), dim=1))
return self.final_conv(d1)