import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.segmentation import (
    deeplabv3_resnet50,
    deeplabv3_resnet101,
)
from torchvision.models.segmentation.deeplabv3 import DeepLabHead


class DeepLabV3Wrapper(nn.Module):
    """
    DeepLabV3 wrapper for retinal vessel segmentation.

    Builds a torchvision DeepLabV3 model on a ResNet backbone, replaces its
    classification head with one sized for ``num_classes`` and exposes only
    the main logits from ``forward``.

    Output:
        Raw logits [B, num_classes, H, W]

    For binary vessel segmentation:
        num_classes = 1

    Args:
        backbone: Either ``"resnet50"`` or ``"resnet101"``.
        num_classes: Number of output channels of the segmentation head.
        pretrained_backbone: If true, the backbone is created with
            ``weights_backbone="DEFAULT"``; otherwise with no weights.
        aux_loss: Forwarded to the torchvision constructor. When an auxiliary
            classifier is created, its final conv is resized to
            ``num_classes`` as well.

    Raises:
        ValueError: If ``backbone`` is not supported or ``num_classes`` is
            smaller than 1.
    """

    def __init__(
        self,
        backbone: str = "resnet50",
        num_classes: int = 1,
        pretrained_backbone: bool = True,
        aux_loss: bool = False,
    ) -> None:
        super().__init__()

        if backbone == "resnet50":
            builder = deeplabv3_resnet50
        elif backbone == "resnet101":
            builder = deeplabv3_resnet101
        else:
            raise ValueError(
                f"Unsupported backbone: {backbone}. "
                "Choose from: 'resnet50', 'resnet101'."
            )

        if num_classes < 1:
            raise ValueError(
                f"num_classes must be a positive integer, got {num_classes}."
            )

        # Segmentation weights are never loaded (weights=None); only the
        # backbone is optionally initialised from pretrained weights.
        model = builder(
            weights=None,
            weights_backbone="DEFAULT" if pretrained_backbone else None,
            aux_loss=aux_loss,
        )

        # Both supported backbones feed 2048 channels into the head.
        in_channels = 2048
        model.classifier = DeepLabHead(
            in_channels=in_channels,
            num_classes=num_classes,
        )

        if aux_loss and model.aux_classifier is not None:
            # Only the last layer of the auxiliary head depends on the
            # number of classes, so just that conv is swapped out.
            model.aux_classifier[-1] = nn.Conv2d(
                model.aux_classifier[-1].in_channels,
                num_classes,
                kernel_size=1,
            )

        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the wrapped model and return its main logits.

        The auxiliary output (the "aux" entry, present when the model was
        built with ``aux_loss=True``) is discarded here; call ``self.model``
        directly if it is needed.
        """
        output = self.model(x)

        # torchvision segmentation models return dict:
        # {"out": logits, "aux": optional aux logits}
        return output["out"]


def build_deeplabv3(
    backbone: str = "resnet50",
    num_classes: int = 1,
    pretrained_backbone: bool = True,
    aux_loss: bool = False,
) -> DeepLabV3Wrapper:
    """
    Construct a :class:`DeepLabV3Wrapper`.

    All arguments are passed through unchanged; see the class docstring for
    their meaning and for the ``ValueError`` conditions.
    """
    return DeepLabV3Wrapper(
        backbone=backbone,
        num_classes=num_classes,
        pretrained_backbone=pretrained_backbone,
        aux_loss=aux_loss,
    )


def _smoke_test() -> None:
    """Build a small model, run one forward pass and check the output shape."""
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = build_deeplabv3(
        backbone="resnet50",
        num_classes=1,
        pretrained_backbone=False,
    ).to(device)
    # Inference check: switch dropout / batch-norm layers to eval behaviour so
    # the forward pass does not depend on or update batch statistics.
    model.eval()

    x = torch.randn(2, 3, 512, 512, device=device)

    with torch.no_grad():
        y = model(x)

    print("Input shape:", x.shape)
    print("Output shape:", y.shape)
    print("Output min/max:", y.min().item(), y.max().item())

    assert y.shape == (2, 1, 512, 512)
    print("Smoke test passed.")


if __name__ == "__main__":
    # Smoke test:
    # python models/deeplabv3.py
    _smoke_test()