import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBNReLU(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU. Conv bias is omitted since BN absorbs it."""

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class ResidualBlock(nn.Module):
    """
    Basic residual block for ResUNet.
    If in_channels != out_channels, the shortcut uses a 1x1 conv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = ConvBNReLU(in_channels, out_channels)
        # Second conv has no ReLU; the activation is applied after the add.
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = self.shortcut(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = x + residual
        x = self.relu(x)
        return x


class EncoderBlock(nn.Module):
    """Residual block followed by 2x2 max-pooling; returns both outputs."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.res_block = ResidualBlock(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        skip = self.res_block(x)   # full-resolution features for the decoder
        pooled = self.pool(skip)   # downsampled features for the next stage
        return skip, pooled


class DecoderBlock(nn.Module):
    """Transposed-conv upsampling, skip concatenation, then a residual block."""

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2,
        )
        self.res_block = ResidualBlock(
            out_channels + skip_channels,
            out_channels,
        )

    def forward(self, x, skip):
        x = self.up(x)
        # Handles odd image sizes, though 512/1024 should already match.
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(
                x,
                size=skip.shape[-2:],
                mode="bilinear",
                align_corners=False,
            )
        x = torch.cat([x, skip], dim=1)
        x = self.res_block(x)
        return x


class ResUNet(nn.Module):
    """
    ResUNet for binary or multi-class retinal segmentation.

    Output: raw logits of shape [B, num_classes, H, W].

    For vessel segmentation:
        num_classes=1
        loss=BCEWithLogits/Dice/Tversky/etc.
    """

    def __init__(
        self,
        in_channels=3,
        num_classes=1,
        base_channels=32,
        dropout=0.0,
    ):
        super().__init__()
        # Channel width doubles at each stage; with the default
        # base_channels=32 this gives 32, 64, 128, 256, 512.
        c1 = base_channels
        c2 = base_channels * 2
        c3 = base_channels * 4
        c4 = base_channels * 8
        c5 = base_channels * 16

        self.enc1 = EncoderBlock(in_channels, c1)
        self.enc2 = EncoderBlock(c1, c2)
        self.enc3 = EncoderBlock(c2, c3)
        self.enc4 = EncoderBlock(c3, c4)

        self.bottleneck = nn.Sequential(
            ResidualBlock(c4, c5),
            nn.Dropout2d(dropout),
        )

        self.dec4 = DecoderBlock(c5, c4, c4)
        self.dec3 = DecoderBlock(c4, c3, c3)
        self.dec2 = DecoderBlock(c3, c2, c2)
        self.dec1 = DecoderBlock(c2, c1, c1)

        self.out_conv = nn.Conv2d(c1, num_classes, kernel_size=1)

    def forward(self, x):
        s1, x = self.enc1(x)
        s2, x = self.enc2(x)
        s3, x = self.enc3(x)
        s4, x = self.enc4(x)

        x = self.bottleneck(x)

        x = self.dec4(x, s4)
        x = self.dec3(x, s3)
        x = self.dec2(x, s2)
        x = self.dec1(x, s1)

        logits = self.out_conv(x)
        return logits


def build_resunet(
    in_channels=3,
    num_classes=1,
    base_channels=32,
    dropout=0.0,
):
    return ResUNet(
        in_channels=in_channels,
        num_classes=num_classes,
        base_channels=base_channels,
        dropout=dropout,
    )


if __name__ == "__main__":
    # Smoke test:
    #   python models/unet.py
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = build_resunet(
        in_channels=3,
        num_classes=1,
        base_channels=32,
        dropout=0.0,
    ).to(device)
    x = torch.randn(2, 3, 512, 512).to(device)
    with torch.no_grad():
        y = model(x)
    print("Input shape:", x.shape)
    print("Output shape:", y.shape)
    print("Output min/max:", y.min().item(), y.max().item())
    assert y.shape == (2, 1, 512, 512)
    print("Smoke test passed.")