Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision.models.segmentation import ( | |
| deeplabv3_resnet50, | |
| deeplabv3_resnet101, | |
| ) | |
| from torchvision.models.segmentation.deeplabv3 import DeepLabHead | |
class DeepLabV3Wrapper(nn.Module):
    """
    DeepLabV3 wrapper for retinal vessel segmentation.

    Wraps a torchvision DeepLabV3 model (ResNet-50 or ResNet-101 backbone)
    and unwraps its dict output so callers receive a plain tensor.

    Output:
        Raw logits [B, num_classes, H, W] (no sigmoid/softmax applied).
        For binary vessel segmentation: num_classes = 1.

    When both ``aux_loss`` and ``return_aux`` are True, ``forward`` instead
    returns ``(logits, aux_logits)`` so an auxiliary loss can be computed.

    Only the backbone may be pretrained (torchvision "DEFAULT" ResNet
    weights); the segmentation heads are always freshly initialised because
    the full-model ``weights`` are ``None``.

    NOTE: torchvision's ASPP image-pooling branch applies BatchNorm to a 1x1
    feature map, so forward passes in training mode need batch size > 1.
    """

    def __init__(
        self,
        backbone="resnet50",
        num_classes=1,
        pretrained_backbone=True,
        aux_loss=False,
        return_aux=False,
    ):
        """
        Args:
            backbone: ``"resnet50"`` or ``"resnet101"``.
            num_classes: Number of output channels of the main (and aux)
                head; must be >= 1.
            pretrained_backbone: Load torchvision's DEFAULT backbone weights.
            aux_loss: Build the auxiliary FCN head on the ``layer3`` features.
            return_aux: If True (and ``aux_loss`` is True), ``forward``
                returns ``(logits, aux_logits)`` instead of only ``logits``.

        Raises:
            ValueError: If ``num_classes`` < 1 or ``backbone`` is unsupported.
        """
        super().__init__()
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}.")

        if backbone == "resnet50":
            builder = deeplabv3_resnet50
        elif backbone == "resnet101":
            builder = deeplabv3_resnet101
        else:
            raise ValueError(
                f"Unsupported backbone: {backbone}. "
                "Choose from: 'resnet50', 'resnet101'."
            )

        # With weights=None torchvision sizes both the DeepLabHead and the
        # optional FCN aux head for `num_classes` directly, so there is no
        # need to build 21-class heads and then swap their layers out.
        self.model = builder(
            weights=None,
            weights_backbone="DEFAULT" if pretrained_backbone else None,
            num_classes=num_classes,
            aux_loss=aux_loss,
        )
        # Aux logits can only be returned if the aux head actually exists.
        self.return_aux = bool(return_aux and aux_loss)

    def forward(self, x):
        """
        Run the segmentation model.

        Args:
            x: Image batch; the ResNet stem expects 3 input channels.

        Returns:
            ``logits`` of shape [B, num_classes, H, W], or
            ``(logits, aux_logits)`` when ``return_aux`` is enabled.
        """
        # torchvision segmentation models return a dict:
        # {"out": logits, "aux": aux logits (only when aux_loss=True)}
        output = self.model(x)
        logits = output["out"]
        if self.return_aux:
            return logits, output["aux"]
        return logits
def build_deeplabv3(
    backbone="resnet50",
    num_classes=1,
    pretrained_backbone=True,
    aux_loss=False,
):
    """
    Factory for ``DeepLabV3Wrapper``.

    Forwards every argument unchanged to ``DeepLabV3Wrapper`` and returns
    the constructed module.

    Args:
        backbone: Backbone name passed through to the wrapper.
        num_classes: Number of output channels of the segmentation head.
        pretrained_backbone: Whether to load pretrained backbone weights.
        aux_loss: Whether to build the auxiliary head.

    Returns:
        A new ``DeepLabV3Wrapper`` instance.
    """
    wrapper_kwargs = {
        "backbone": backbone,
        "num_classes": num_classes,
        "pretrained_backbone": pretrained_backbone,
        "aux_loss": aux_loss,
    }
    return DeepLabV3Wrapper(**wrapper_kwargs)
if __name__ == "__main__":
    # Smoke test:
    #   python models/deeplabv3.py
    # Builds a ResNet-50 DeepLabV3 without pretrained backbone weights and
    # checks that one forward pass returns full-resolution logits.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size, height, width = 2, 512, 512
    num_classes = 1

    model = build_deeplabv3(
        backbone="resnet50",
        num_classes=num_classes,
        pretrained_backbone=False,
    ).to(device)
    # Put BatchNorm/Dropout into inference behaviour: otherwise BN would use
    # batch statistics and overwrite its running buffers during the check.
    model.eval()

    # Allocate directly on the target device instead of copying from CPU.
    x = torch.randn(batch_size, 3, height, width, device=device)
    with torch.no_grad():
        y = model(x)

    print("Input shape:", x.shape)
    print("Output shape:", y.shape)
    print("Output min/max:", y.min().item(), y.max().item())

    expected_shape = (batch_size, num_classes, height, width)
    # Explicit check rather than `assert`, which `python -O` strips out.
    if tuple(y.shape) != expected_shape:
        raise AssertionError(
            f"Unexpected output shape {tuple(y.shape)}, "
            f"expected {expected_shape}"
        )
    print("Smoke test passed.")