import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two 3x3 conv -> BatchNorm -> ReLU stages with channel dropout in between.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of channels produced by both convolutions.
        dropout: Drop probability passed to ``nn.Dropout2d`` between the stages.
    """

    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.1) -> None:
        super().__init__()
        # NOTE: a conv bias directly before BatchNorm is redundant, but it is
        # kept so that existing checkpoints (which contain these bias tensors)
        # still load with strict state_dict matching.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv stages; output has ``out_channels`` channels, same H/W."""
        return self.conv(x)


class AttentionGate(nn.Module):
    """Additive attention gate that re-weights skip-connection features.

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``F_int`` channels by 1x1 convolutions, summed, passed through ReLU, and
    reduced to a single-channel sigmoid map that scales ``x`` per pixel.

    Args:
        F_g: Channels of the gating signal ``g``.
        F_l: Channels of the skip features ``x``.
        F_int: Channels of the intermediate projection.
    """

    def __init__(self, F_g: int, F_l: int, F_int: int) -> None:
        super().__init__()
        self.W_g = nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True)
        self.W_x = nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled per pixel by attention coefficients in [0, 1].

        ``g`` and ``x`` must share the same spatial size because their
        projections are added element-wise.
        """
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        # In-place ReLU is safe here: g1 + x1 is a fresh tensor.
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        # psi has one channel, so it broadcasts across all channels of x.
        return x * psi


class AttentionUNet(nn.Module):
    """U-Net with attention-gated skip connections.

    Four encoder levels (each a ``DoubleConv`` followed by 2x2 max pooling), a
    ``DoubleConv`` bottleneck, and four decoder levels. At every decoder level
    the upsampled features act as the gating signal for the matching encoder
    features; the gated skip and the upsampled features are concatenated and
    fused by a ``DoubleConv``.

    Args:
        img_ch: Number of input channels.
        output_ch: Channels produced by the final 1x1 convolution. No
            activation is applied, so outputs are raw per-pixel scores.
        base_channels: Width of the first encoder level. Level ``i`` uses
            ``base_channels * 2**i`` channels and the bottleneck
            ``base_channels * 16``. The default of 64 reproduces the
            64-128-256-512-1024 layout.

    Raises:
        ValueError: If ``base_channels`` is smaller than 1.
    """

    def __init__(self, img_ch: int = 1, output_ch: int = 4, base_channels: int = 64) -> None:
        super().__init__()
        if base_channels < 1:
            raise ValueError(f"base_channels must be >= 1, got {base_channels}")

        enc_widths = [base_channels * 2**level for level in range(4)]  # 64..512 by default
        bottleneck_width = enc_widths[-1] * 2  # 1024 by default

        # Submodules are created in the same order and under the same names as
        # the original hand-written layout, so state_dict keys, parameter order
        # and seeded initialization are unchanged.
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        enc_inputs = [img_ch] + enc_widths[:-1]
        self.downs = nn.ModuleList(
            DoubleConv(c_in, c_out) for c_in, c_out in zip(enc_inputs, enc_widths)
        )
        self.bottleneck = DoubleConv(enc_widths[-1], bottleneck_width)

        # Decoder levels run from the deepest to the shallowest.
        dec_widths = enc_widths[::-1]  # 512..64 by default
        dec_inputs = [bottleneck_width] + dec_widths[:-1]
        self.ups = nn.ModuleList(
            nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2)
            for c_in, c_out in zip(dec_inputs, dec_widths)
        )
        self.attention_gates = nn.ModuleList(
            AttentionGate(F_g=c, F_l=c, F_int=max(c // 2, 1)) for c in dec_widths
        )
        # Each fuse block sees the gated skip (c channels) concatenated with
        # the upsampled features (c channels).
        self.up_convs = nn.ModuleList(DoubleConv(2 * c, c) for c in dec_widths)
        self.final_conv = nn.Conv2d(dec_widths[-1], output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Segment a 4-D ``(N, img_ch, H, W)`` batch.

        Returns a ``(N, output_ch, H, W)`` tensor. H and W must each be at
        least 16 so that the bottleneck input is non-empty after four 2x2
        poolings. Sizes that are not multiples of 16 are supported: see the
        resize step below. Note that BatchNorm in training mode needs more than
        one value per channel, so a batch of 1 at 16x16 (1x1 bottleneck) only
        works in eval mode.
        """
        skips = []
        for down in self.downs:
            x = down(x)
            skips.append(x)
            x = self.Maxpool(x)
        x = self.bottleneck(x)

        for up, gate, fuse, skip in zip(
            self.ups, self.attention_gates, self.up_convs, reversed(skips)
        ):
            x = up(x)
            # Max pooling floors odd sizes while the transposed conv exactly
            # doubles, so for H/W not divisible by 16 the upsampled map can be
            # smaller than the skip. Resize it so the gate's element-wise sum
            # and the channel concatenation line up. For multiples of 16 this
            # branch is never taken.
            if x.shape[-2:] != skip.shape[-2:]:
                x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
            gated = gate(g=x, x=skip)
            x = fuse(torch.cat((gated, x), dim=1))
        return self.final_conv(x)