# GenSeg-Baselines/code/framework/models/attention_unet.py
# Hugging Face file-viewer header (not part of the module): uploaded by MaybeRichard
# using huggingface_hub, commit b8fae22 (verified), raw size 3.61 kB.
"""Attention U-Net (Oktay et al., 2018) — additive attention gates on skips.
Reimplemented cleanly inside the framework. The original sfczekalski/attention_unet
repo is unmaintained (2020), unlicensed, has DRIVE-specific hardcoded padding and a
commented-out training loop, so only the architecture idea is reused here. This is
the faithful additive gating attention (NOT SMP's scSE), configurable for arbitrary
in_channels / num_classes.
"""
from __future__ import annotations
import torch
import torch.nn as nn
class ConvBlock(nn.Module):
    """Two consecutive 3x3 convolution -> BatchNorm -> ReLU stages.

    The first stage maps ``cin`` channels to ``cout``; the second keeps
    ``cout`` channels. Each convolution uses ``padding=1`` with the default
    stride, so height and width are unchanged. The convolutions carry no bias
    term because a BatchNorm layer follows each of them.
    """

    def __init__(self, cin, cout):
        super().__init__()
        stages = []
        # Stage 1 consumes `cin` channels, stage 2 consumes the `cout` it produced.
        for width_in in (cin, cout):
            stages.append(nn.Conv2d(width_in, cout, kernel_size=3, padding=1, bias=False))
            stages.append(nn.BatchNorm2d(cout))
            stages.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x`` and return the result."""
        features = self.block(x)
        return features
class UpConv(nn.Module):
    """Decoder up-sampling step: x2 bilinear upsample, then 3x3 conv -> BatchNorm -> ReLU.

    The upsample doubles height and width (``align_corners=True``); the
    convolution then maps ``cin`` channels to ``cout`` without changing the
    spatial size (``padding=1``, default stride).
    """

    def __init__(self, cin, cout):
        super().__init__()
        enlarge = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        project = nn.Conv2d(cin, cout, kernel_size=3, padding=1, bias=False)
        normalise = nn.BatchNorm2d(cout)
        activate = nn.ReLU(inplace=True)
        self.up = nn.Sequential(enlarge, project, normalise, activate)

    def forward(self, x):
        """Return ``x`` upsampled by two and projected to ``cout`` channels."""
        upsampled = self.up(x)
        return upsampled
class AttentionGate(nn.Module):
    """Additive attention gate: gating signal ``g`` re-weights skip features ``x``.

    Both inputs are projected to ``f_int`` channels with 1x1 conv + BatchNorm,
    summed, passed through ReLU, and reduced by ``psi`` (1x1 conv -> BatchNorm
    -> Sigmoid) to a single-channel map with values in (0, 1). That map is
    multiplied onto ``x``.

    Args:
        f_g: channel count of the gating signal ``g``.
        f_l: channel count of the skip features ``x``.
        f_int: channel count of the intermediate (summed) representation.

    Because the two projections are added elementwise and 1x1 convolutions do
    not change spatial size, ``g`` and ``x`` must arrive with matching height
    and width.
    """

    def __init__(self, f_g, f_l, f_int):
        super().__init__()

        def projection(channels_in):
            # 1x1 conv to the shared intermediate width, followed by BatchNorm.
            return nn.Sequential(
                nn.Conv2d(channels_in, f_int, kernel_size=1, bias=True),
                nn.BatchNorm2d(f_int),
            )

        self.w_g = projection(f_g)
        self.w_x = projection(f_l)
        self.psi = nn.Sequential(
            nn.Conv2d(f_int, 1, kernel_size=1, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` scaled by the attention coefficients derived from ``g`` and ``x``."""
        combined = self.w_g(g) + self.w_x(x)
        coefficients = self.psi(self.relu(combined))
        return x * coefficients
class AttentionUNet(nn.Module):
    """U-Net with additive attention gates on every skip connection.

    Encoder: five ``ConvBlock`` stages (``e1``..``e5``) separated by 2x2
    max-pooling, with widths ``base * (1, 2, 4, 8, 16)``.
    Decoder: four stages; each one upsamples the coarser map (``u*``), gates
    the matching encoder skip with it (``a*``), concatenates the gated skip
    with the upsampled map, and fuses them with a ``ConvBlock`` (``d*``).
    A final 1x1 convolution (``head``) maps to ``num_classes`` channels; no
    activation is applied to its output.

    Args:
        in_channels: channel count of the input tensor.
        num_classes: channel count of the output map.
        base: width of the first encoder stage.

    Input height/width do not have to be multiples of 16: when floor-pooling
    followed by x2 upsampling leaves a decoder map one or more pixels off from
    its skip, the decoder map is bilinearly resized to the skip's size before
    gating and concatenation. Each side must still be at least 16 pixels so
    that the four poolings do not reduce it to zero.
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 2, base: int = 64):
        super().__init__()
        c = [base, base * 2, base * 4, base * 8, base * 16]
        self.pool = nn.MaxPool2d(2, 2)

        # Encoder.
        self.e1 = ConvBlock(in_channels, c[0])
        self.e2 = ConvBlock(c[0], c[1])
        self.e3 = ConvBlock(c[1], c[2])
        self.e4 = ConvBlock(c[2], c[3])
        self.e5 = ConvBlock(c[3], c[4])

        # Decoder. Attribute names and registration order are kept stable so
        # existing checkpoints keep loading. Each fuse block takes twice the
        # stage width because the gated skip and upsampled map are concatenated.
        self.u5 = UpConv(c[4], c[3])
        self.a5 = AttentionGate(c[3], c[3], c[2])
        self.d5 = ConvBlock(c[4], c[3])

        self.u4 = UpConv(c[3], c[2])
        self.a4 = AttentionGate(c[2], c[2], c[1])
        self.d4 = ConvBlock(c[3], c[2])

        self.u3 = UpConv(c[2], c[1])
        self.a3 = AttentionGate(c[1], c[1], c[0])
        self.d3 = ConvBlock(c[2], c[1])

        self.u2 = UpConv(c[1], c[0])
        # Keep at least one intermediate channel so a tiny `base` still yields
        # a usable gate (identical to `c[0] // 2` for any base >= 2).
        self.a2 = AttentionGate(c[0], c[0], max(c[0] // 2, 1))
        self.d2 = ConvBlock(c[1], c[0])

        self.head = nn.Conv2d(c[0], num_classes, 1)

    @staticmethod
    def _decode(up, gate, fuse, coarse, skip):
        """Run one decoder stage: upsample, align to the skip, gate, concat, fuse."""
        d = up(coarse)
        if d.shape[-2:] != skip.shape[-2:]:
            # Odd-sized encoder maps lose a pixel when pooled; x2 upsampling
            # cannot recover it, so resize to the skip's exact size.
            d = nn.functional.interpolate(
                d, size=tuple(skip.shape[-2:]), mode="bilinear", align_corners=True
            )
        return fuse(torch.cat([gate(d, skip), d], dim=1))

    def forward(self, x):
        """Return the ``num_classes``-channel map for ``x`` at the input's spatial size."""
        x1 = self.e1(x)
        x2 = self.e2(self.pool(x1))
        x3 = self.e3(self.pool(x2))
        x4 = self.e4(self.pool(x3))
        x5 = self.e5(self.pool(x4))

        d = self._decode(self.u5, self.a5, self.d5, x5, x4)
        d = self._decode(self.u4, self.a4, self.d4, d, x3)
        d = self._decode(self.u3, self.a3, self.d3, d, x2)
        d = self._decode(self.u2, self.a2, self.d2, d, x1)
        return self.head(d)
def build_attention_unet(in_channels: int, num_classes: int, **_):
    """Factory returning an ``AttentionUNet`` for the given channel/class counts.

    Any additional keyword arguments are accepted and ignored, so the model is
    always built with the default ``base`` width.
    """
    settings = {"in_channels": in_channels, "num_classes": num_classes}
    return AttentionUNet(**settings)