"""ResUNet building blocks and model for retinal image segmentation."""
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class ConvBNReLU(nn.Module): | |
| def __init__(self, in_channels, out_channels, kernel_size=3, padding=1): | |
| super().__init__() | |
| self.block = nn.Sequential( | |
| nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=False), | |
| nn.BatchNorm2d(out_channels), | |
| nn.ReLU(inplace=True), | |
| ) | |
| def forward(self, x): | |
| return self.block(x) | |
class ResidualBlock(nn.Module):
    """Two-convolution residual unit used throughout the ResUNet.

    The main path is ConvBNReLU followed by Conv+BN; the skip path is an
    identity when the channel counts already match, otherwise a 1x1
    conv + BN projection. ReLU is applied after the addition.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = ConvBNReLU(in_channels, out_channels)
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # Project the input on the skip path only when channels differ.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        identity = self.shortcut(x)
        out = self.conv2(self.conv1(x))
        return self.relu(out + identity)
class EncoderBlock(nn.Module):
    """Encoder stage: residual block, then 2x2 max-pool downsampling.

    forward() returns both the pre-pool features (for the decoder skip
    connection) and the pooled, half-resolution features.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.res_block = ResidualBlock(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(skip_features, pooled_features)``."""
        features = self.res_block(x)
        return features, self.pool(features)
class DecoderBlock(nn.Module):
    """Decoder stage: upsample, concatenate the encoder skip, refine.

    Args:
        in_channels: channels entering from the previous decoder stage.
        skip_channels: channels of the encoder skip tensor.
        out_channels: channels produced by this stage.
    """

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.res_block = ResidualBlock(out_channels + skip_channels, out_channels)

    def forward(self, x, skip):
        """Upsample *x*, align it to *skip*, concat along channels, refine."""
        upsampled = self.up(x)
        # The transposed conv can miss the skip's size for odd inputs;
        # 512/1024 inputs already match, so this is only a safety net.
        if upsampled.shape[-2:] != skip.shape[-2:]:
            upsampled = F.interpolate(
                upsampled,
                size=skip.shape[-2:],
                mode="bilinear",
                align_corners=False,
            )
        merged = torch.cat([upsampled, skip], dim=1)
        return self.res_block(merged)
class ResUNet(nn.Module):
    """ResUNet for binary or multi-class retinal segmentation.

    Four encoder stages halve the resolution while doubling channels, a
    bottleneck residual block (with optional 2D dropout) sits at the
    lowest resolution, and four decoder stages mirror the encoders via
    skip connections.

    Output:
        Raw logits of shape ``[B, num_classes, H, W]``.

    For vessel segmentation use ``num_classes=1`` with a
    BCEWithLogits/Dice/Tversky-style loss.
    """

    def __init__(
        self,
        in_channels=3,
        num_classes=1,
        base_channels=32,
        dropout=0.0,
    ):
        super().__init__()
        # Channel widths: base * (1, 2, 4, 8, 16) down the encoder.
        c1, c2, c3, c4, c5 = (base_channels * m for m in (1, 2, 4, 8, 16))
        self.enc1 = EncoderBlock(in_channels, c1)
        self.enc2 = EncoderBlock(c1, c2)
        self.enc3 = EncoderBlock(c2, c3)
        self.enc4 = EncoderBlock(c3, c4)
        self.bottleneck = nn.Sequential(
            ResidualBlock(c4, c5),
            nn.Dropout2d(dropout),
        )
        self.dec4 = DecoderBlock(c5, c4, c4)
        self.dec3 = DecoderBlock(c4, c3, c3)
        self.dec2 = DecoderBlock(c3, c2, c2)
        self.dec1 = DecoderBlock(c2, c1, c1)
        self.out_conv = nn.Conv2d(c1, num_classes, kernel_size=1)

    def forward(self, x):
        """Run the full encoder-bottleneck-decoder pass; return logits."""
        skip1, down = self.enc1(x)
        skip2, down = self.enc2(down)
        skip3, down = self.enc3(down)
        skip4, down = self.enc4(down)
        feat = self.bottleneck(down)
        feat = self.dec4(feat, skip4)
        feat = self.dec3(feat, skip3)
        feat = self.dec2(feat, skip2)
        feat = self.dec1(feat, skip1)
        return self.out_conv(feat)
def build_resunet(
    in_channels=3,
    num_classes=1,
    base_channels=32,
    dropout=0.0,
):
    """Factory: construct a :class:`ResUNet` with the given configuration."""
    return ResUNet(
        in_channels=in_channels,
        num_classes=num_classes,
        base_channels=base_channels,
        dropout=dropout,
    )
if __name__ == "__main__":
    # Smoke test — run directly with:
    #   python models/unet.py
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = build_resunet(
        in_channels=3,
        num_classes=1,
        base_channels=32,
        dropout=0.0,
    )
    net = net.to(device)
    sample = torch.randn(2, 3, 512, 512).to(device)
    with torch.no_grad():
        out = net(sample)
    print("Input shape:", sample.shape)
    print("Output shape:", out.shape)
    print("Output min/max:", out.min().item(), out.max().item())
    assert out.shape == (2, 1, 512, 512)
    print("Smoke test passed.")