"""Attention U-Net (Oktay et al., 2018) — additive attention gates on skips. Reimplemented cleanly inside the framework. The original sfczekalski/attention_unet repo is unmaintained (2020), unlicensed, has DRIVE-specific hardcoded padding and a commented-out training loop, so only the architecture idea is reused here. This is the faithful additive gating attention (NOT SMP's scSE), configurable for arbitrary in_channels / num_classes. """ from __future__ import annotations import torch import torch.nn as nn class ConvBlock(nn.Module): def __init__(self, cin, cout): super().__init__() self.block = nn.Sequential( nn.Conv2d(cin, cout, 3, padding=1, bias=False), nn.BatchNorm2d(cout), nn.ReLU(inplace=True), nn.Conv2d(cout, cout, 3, padding=1, bias=False), nn.BatchNorm2d(cout), nn.ReLU(inplace=True), ) def forward(self, x): return self.block(x) class UpConv(nn.Module): def __init__(self, cin, cout): super().__init__() self.up = nn.Sequential( nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), nn.Conv2d(cin, cout, 3, padding=1, bias=False), nn.BatchNorm2d(cout), nn.ReLU(inplace=True), ) def forward(self, x): return self.up(x) class AttentionGate(nn.Module): """Additive attention gate: gating signal g (coarser) modulates skip x.""" def __init__(self, f_g, f_l, f_int): super().__init__() self.w_g = nn.Sequential(nn.Conv2d(f_g, f_int, 1, bias=True), nn.BatchNorm2d(f_int)) self.w_x = nn.Sequential(nn.Conv2d(f_l, f_int, 1, bias=True), nn.BatchNorm2d(f_int)) self.psi = nn.Sequential(nn.Conv2d(f_int, 1, 1, bias=True), nn.BatchNorm2d(1), nn.Sigmoid()) self.relu = nn.ReLU(inplace=True) def forward(self, g, x): a = self.relu(self.w_g(g) + self.w_x(x)) a = self.psi(a) return x * a class AttentionUNet(nn.Module): def __init__(self, in_channels: int = 3, num_classes: int = 2, base: int = 64): super().__init__() c = [base, base * 2, base * 4, base * 8, base * 16] self.pool = nn.MaxPool2d(2, 2) self.e1 = ConvBlock(in_channels, c[0]) self.e2 = ConvBlock(c[0], c[1]) self.e3 = ConvBlock(c[1], c[2]) self.e4 = ConvBlock(c[2], c[3]) self.e5 = ConvBlock(c[3], c[4]) self.u5 = UpConv(c[4], c[3]); self.a5 = AttentionGate(c[3], c[3], c[2]); self.d5 = ConvBlock(c[4], c[3]) self.u4 = UpConv(c[3], c[2]); self.a4 = AttentionGate(c[2], c[2], c[1]); self.d4 = ConvBlock(c[3], c[2]) self.u3 = UpConv(c[2], c[1]); self.a3 = AttentionGate(c[1], c[1], c[0]); self.d3 = ConvBlock(c[2], c[1]) self.u2 = UpConv(c[1], c[0]); self.a2 = AttentionGate(c[0], c[0], c[0] // 2); self.d2 = ConvBlock(c[1], c[0]) self.head = nn.Conv2d(c[0], num_classes, 1) def forward(self, x): x1 = self.e1(x) x2 = self.e2(self.pool(x1)) x3 = self.e3(self.pool(x2)) x4 = self.e4(self.pool(x3)) x5 = self.e5(self.pool(x4)) d5 = self.u5(x5); d5 = self.d5(torch.cat([self.a5(d5, x4), d5], dim=1)) d4 = self.u4(d5); d4 = self.d4(torch.cat([self.a4(d4, x3), d4], dim=1)) d3 = self.u3(d4); d3 = self.d3(torch.cat([self.a3(d3, x2), d3], dim=1)) d2 = self.u2(d3); d2 = self.d2(torch.cat([self.a2(d2, x1), d2], dim=1)) return self.head(d2) def build_attention_unet(in_channels: int, num_classes: int, **_): return AttentionUNet(in_channels=in_channels, num_classes=num_classes)