# CFPVesselSeg/models/unet.py
# Author: farrell236 (commit e99a83c, "add src")
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBNReLU(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    The convolution is bias-free because the following BatchNorm layer
    has its own learnable shift, making a conv bias redundant.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=False,
        )
        self.block = nn.Sequential(
            conv,
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        out = self.block(x)
        return out
class ResidualBlock(nn.Module):
    """
    Basic residual block for ResUNet.

    Main branch: ConvBNReLU followed by Conv+BN. The skip path is an
    identity when channel counts match; otherwise a 1x1 conv + BN
    projects the input so the two branches can be summed.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = ConvBNReLU(in_channels, out_channels)
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # Project the input only when the channel counts disagree.
        needs_projection = in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = self.shortcut(x)
        out = self.conv2(self.conv1(x))
        return self.relu(out + identity)
class EncoderBlock(nn.Module):
    """Residual feature extraction followed by 2x2 max-pool downsampling.

    forward returns (skip, pooled): the pre-pool features serve as the
    skip connection for the matching decoder stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.res_block = ResidualBlock(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.res_block(x)
        return features, self.pool(features)
class DecoderBlock(nn.Module):
    """Upsample via transposed conv, concatenate the encoder skip, then refine.

    A bilinear resize aligns the upsampled map with the skip tensor when
    odd input resolutions produce a one-pixel size mismatch.
    """

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.res_block = ResidualBlock(out_channels + skip_channels, out_channels)

    def forward(self, x, skip):
        upsampled = self.up(x)
        target_size = skip.shape[-2:]
        # 512/1024 inputs already match; this only fires for odd sizes.
        if upsampled.shape[-2:] != target_size:
            upsampled = F.interpolate(
                upsampled,
                size=target_size,
                mode="bilinear",
                align_corners=False,
            )
        merged = torch.cat([upsampled, skip], dim=1)
        return self.res_block(merged)
class ResUNet(nn.Module):
    """
    ResUNet for binary or multi-class retinal segmentation.

    Output:
        Raw logits of shape [B, num_classes, H, W]

    For vessel segmentation:
        num_classes=1
        loss=BCEWithLogits/Dice/Tversky/etc.
    """

    def __init__(
        self,
        in_channels=3,
        num_classes=1,
        base_channels=32,
        dropout=0.0,
    ):
        super().__init__()
        # Channel width doubles at each of the five depth levels.
        c1, c2, c3, c4, c5 = (base_channels * (2 ** i) for i in range(5))

        self.enc1 = EncoderBlock(in_channels, c1)
        self.enc2 = EncoderBlock(c1, c2)
        self.enc3 = EncoderBlock(c2, c3)
        self.enc4 = EncoderBlock(c3, c4)

        # Deepest stage; Dropout2d is a no-op at the default dropout=0.0.
        self.bottleneck = nn.Sequential(
            ResidualBlock(c4, c5),
            nn.Dropout2d(dropout),
        )

        self.dec4 = DecoderBlock(c5, c4, c4)
        self.dec3 = DecoderBlock(c4, c3, c3)
        self.dec2 = DecoderBlock(c3, c2, c2)
        self.dec1 = DecoderBlock(c2, c1, c1)

        # 1x1 conv maps the finest features to per-class logits.
        self.out_conv = nn.Conv2d(c1, num_classes, kernel_size=1)

    def forward(self, x):
        # Contracting path: keep each stage's pre-pool features as skips.
        skip1, x = self.enc1(x)
        skip2, x = self.enc2(x)
        skip3, x = self.enc3(x)
        skip4, x = self.enc4(x)

        x = self.bottleneck(x)

        # Expanding path: consume the skips deepest-first.
        x = self.dec4(x, skip4)
        x = self.dec3(x, skip3)
        x = self.dec2(x, skip2)
        x = self.dec1(x, skip1)
        return self.out_conv(x)
def build_resunet(
    in_channels=3,
    num_classes=1,
    base_channels=32,
    dropout=0.0,
):
    """Factory wrapper: construct a ResUNet from plain keyword arguments."""
    model = ResUNet(
        in_channels=in_channels,
        num_classes=num_classes,
        base_channels=base_channels,
        dropout=dropout,
    )
    return model
if __name__ == "__main__":
    # Smoke test — run directly with: python models/unet.py
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = build_resunet(
        in_channels=3,
        num_classes=1,
        base_channels=32,
        dropout=0.0,
    ).to(device)

    x = torch.randn(2, 3, 512, 512).to(device)
    # Inference only — no gradients needed for a shape check.
    with torch.no_grad():
        y = model(x)

    print("Input shape:", x.shape)
    print("Output shape:", y.shape)
    print("Output min/max:", y.min().item(), y.max().item())
    assert y.shape == (2, 1, 512, 512)
    print("Smoke test passed.")