| """Attention U-Net (Oktay et al., 2018) — additive attention gates on skips. |
| |
| Reimplemented cleanly inside the framework. The original sfczekalski/attention_unet |
| repo is unmaintained (2020), unlicensed, has DRIVE-specific hardcoded padding and a |
| commented-out training loop, so only the architecture idea is reused here. This is |
| the faithful additive gating attention (NOT SMP's scSE), configurable for arbitrary |
| in_channels / num_classes. |
| """ |
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
class ConvBlock(nn.Module):
    """Two stacked 3x3 conv -> BatchNorm -> ReLU stages.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    kept; only the channel count changes (``cin`` -> ``cout``). The
    convolutions are created without a bias term; each one is immediately
    followed by a BatchNorm layer.
    """

    def __init__(self, cin, cout):
        super().__init__()
        layers = []
        # First stage maps cin -> cout, second stage keeps cout. The layers
        # land at Sequential indices 0..5 in conv, norm, activation order.
        for stage_in in (cin, cout):
            layers.append(
                nn.Conv2d(stage_in, cout, kernel_size=3, padding=1, bias=False)
            )
            layers.append(nn.BatchNorm2d(cout))
            layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through both conv stages and return the feature map."""
        features = self.block(x)
        return features
|
|
|
|
class UpConv(nn.Module):
    """Upsample by 2x (bilinear), then apply 3x3 conv -> BatchNorm -> ReLU.

    The convolution changes the channel count from ``cin`` to ``cout`` and
    uses ``padding=1``, so the output is exactly twice the input's height
    and width.
    """

    def __init__(self, cin, cout):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        conv = nn.Conv2d(cin, cout, kernel_size=3, padding=1, bias=False)
        norm = nn.BatchNorm2d(cout)
        activation = nn.ReLU(inplace=True)
        # Order matters: the layers sit at Sequential indices 0..3.
        self.up = nn.Sequential(upsample, conv, norm, activation)

    def forward(self, x):
        """Return the upsampled and re-projected version of ``x``."""
        upsampled = self.up(x)
        return upsampled
|
|
|
|
class AttentionGate(nn.Module):
    """Additive attention gate: gating signal ``g`` modulates skip features ``x``.

    ``g`` (``f_g`` channels) and ``x`` (``f_l`` channels) are each projected
    to ``f_int`` channels with a 1x1 conv + BatchNorm, summed, passed through
    ReLU, and reduced to a single-channel sigmoid map. That map multiplies
    ``x``, broadcasting over its channels.

    The two projections are added directly, so ``g`` and ``x`` must already
    have matching (or broadcastable) spatial sizes when ``forward`` is called.
    """

    def __init__(self, f_g, f_l, f_int):
        super().__init__()
        self.w_g = nn.Sequential(*self._conv_bn(f_g, f_int))
        self.w_x = nn.Sequential(*self._conv_bn(f_l, f_int))
        self.psi = nn.Sequential(*self._conv_bn(f_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _conv_bn(cin, cout):
        """Return ``[1x1 conv (with bias), BatchNorm]`` mapping cin -> cout."""
        return [nn.Conv2d(cin, cout, kernel_size=1, bias=True), nn.BatchNorm2d(cout)]

    def forward(self, g, x):
        """Return ``x`` scaled element-wise by the attention map built from ``g`` and ``x``."""
        joint = self.w_g(g) + self.w_x(x)
        # In-place ReLU is safe here: ``joint`` is a fresh temporary.
        attention = self.psi(self.relu(joint))
        return x * attention
|
|
|
|
class AttentionUNet(nn.Module):
    """Five-level U-Net whose skip connections pass through additive attention gates.

    Args:
        in_channels: Number of channels of the input image tensor.
        num_classes: Number of channels produced by the final 1x1 conv head.
            No activation is applied after the head.
        base: Channel width of the first encoder level; each deeper level
            doubles it (``base, 2*base, 4*base, 8*base, 16*base``).

    The encoder halves the resolution four times with floor-mode max pooling,
    so an odd-sized feature map loses a row/column that a plain 2x upsample
    cannot restore. Each decoder stage therefore resizes its upsampled map to
    the skip connection's spatial size when the two differ, which lets the
    network accept inputs whose height/width are not multiples of 16. Inputs
    that are multiples of 16 take exactly the same path as before. Each side
    must still be at least 16 pixels so the deepest feature map is non-empty.
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 2, base: int = 64):
        super().__init__()
        c = [base, base * 2, base * 4, base * 8, base * 16]
        self.pool = nn.MaxPool2d(2, 2)

        # Encoder: one ConvBlock per resolution level.
        self.e1 = ConvBlock(in_channels, c[0])
        self.e2 = ConvBlock(c[0], c[1])
        self.e3 = ConvBlock(c[1], c[2])
        self.e4 = ConvBlock(c[2], c[3])
        self.e5 = ConvBlock(c[3], c[4])

        # Decoder stage k: upsample (u), gate the skip (a), fuse the
        # concatenation of gated skip and upsampled features (d). The fuse
        # block sees twice the stage width because of that concatenation.
        self.u5 = UpConv(c[4], c[3])
        self.a5 = AttentionGate(c[3], c[3], c[2])
        self.d5 = ConvBlock(c[4], c[3])

        self.u4 = UpConv(c[3], c[2])
        self.a4 = AttentionGate(c[2], c[2], c[1])
        self.d4 = ConvBlock(c[3], c[2])

        self.u3 = UpConv(c[2], c[1])
        self.a3 = AttentionGate(c[1], c[1], c[0])
        self.d3 = ConvBlock(c[2], c[1])

        self.u2 = UpConv(c[1], c[0])
        self.a2 = AttentionGate(c[0], c[0], c[0] // 2)
        self.d2 = ConvBlock(c[1], c[0])

        self.head = nn.Conv2d(c[0], num_classes, 1)

    @staticmethod
    def _decode(up, gate, fuse, coarse, skip):
        """Upsample ``coarse``, gate ``skip`` with it, and fuse the two."""
        d = up(coarse)
        if d.shape[-2:] != skip.shape[-2:]:
            # Floor-mode pooling of an odd-sized map makes the 2x upsample
            # one pixel short; resize so the gate's addition and the concat
            # below see matching spatial dims. Same interpolation settings
            # as the nn.Upsample inside UpConv.
            d = nn.functional.interpolate(
                d, size=skip.shape[-2:], mode="bilinear", align_corners=True
            )
        return fuse(torch.cat([gate(d, skip), d], dim=1))

    def forward(self, x):
        """Return per-pixel class scores with the same height/width as ``x``."""
        x1 = self.e1(x)
        x2 = self.e2(self.pool(x1))
        x3 = self.e3(self.pool(x2))
        x4 = self.e4(self.pool(x3))
        x5 = self.e5(self.pool(x4))

        d5 = self._decode(self.u5, self.a5, self.d5, x5, x4)
        d4 = self._decode(self.u4, self.a4, self.d4, d5, x3)
        d3 = self._decode(self.u3, self.a3, self.d3, d4, x2)
        d2 = self._decode(self.u2, self.a2, self.d2, d3, x1)
        return self.head(d2)
|
|
|
|
def build_attention_unet(in_channels: int, num_classes: int, **_):
    """Factory that builds an ``AttentionUNet`` with the default base width.

    Only ``in_channels`` and ``num_classes`` are forwarded to the model; any
    additional keyword arguments are accepted and silently ignored.
    """
    model = AttentionUNet(in_channels=in_channels, num_classes=num_classes)
    return model
|
|