# Page header captured together with the file (not part of the module):
#   Spaces:
#   Running
#   Running
#   File size: 2,785 Bytes
#   e99a83c
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.segmentation import (
deeplabv3_resnet50,
deeplabv3_resnet101,
)
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
class DeepLabV3Wrapper(nn.Module):
    """DeepLabV3 wrapper for retinal vessel segmentation.

    Builds a torchvision DeepLabV3 model on a ResNet backbone, replaces its
    classifier with a new ``DeepLabHead`` producing ``num_classes`` channels,
    and returns only the main logits from ``forward``.

    Output:
        Raw logits [B, num_classes, H, W]

    For binary vessel segmentation:
        num_classes = 1

    Args:
        backbone: ``"resnet50"`` or ``"resnet101"``.
        num_classes: Number of output channels of the segmentation head;
            must be at least 1.
        pretrained_backbone: When true, ``weights_backbone="DEFAULT"`` is
            passed to the torchvision factory; otherwise ``None``. Weights
            for the full segmentation model are never requested
            (``weights=None``).
        aux_loss: Forwarded to the torchvision factory. When true and the
            model has an auxiliary classifier, that classifier's last layer
            is replaced by a 1x1 convolution with ``num_classes`` outputs.

    Raises:
        ValueError: If ``backbone`` is not supported or ``num_classes`` is
            smaller than 1.
    """

    def __init__(
        self,
        backbone="resnet50",
        num_classes=1,
        pretrained_backbone=True,
        aux_loss=False,
    ):
        super().__init__()

        # Validate arguments before building anything so a bad configuration
        # fails fast and never triggers a weight download.
        if backbone not in ("resnet50", "resnet101"):
            raise ValueError(
                f"Unsupported backbone: {backbone}. "
                "Choose from: 'resnet50', 'resnet101'."
            )
        if num_classes < 1:
            raise ValueError(
                f"num_classes must be at least 1, got {num_classes}."
            )

        if backbone == "resnet50":
            factory = deeplabv3_resnet50
        else:
            factory = deeplabv3_resnet101

        model = factory(
            weights=None,
            weights_backbone="DEFAULT" if pretrained_backbone else None,
            aux_loss=aux_loss,
        )

        # Both supported backbones feed 2048 channels into the head.
        model.classifier = DeepLabHead(
            in_channels=2048,
            num_classes=num_classes,
        )

        if aux_loss and model.aux_classifier is not None:
            # Only the final projection depends on the class count, so swap
            # that single layer and keep the rest of the auxiliary head.
            model.aux_classifier[-1] = nn.Conv2d(
                model.aux_classifier[-1].in_channels,
                num_classes,
                kernel_size=1,
            )

        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its ``"out"`` entry.

        Any other entry in the returned dict (such as ``"aux"``) is discarded.
        """
        # torchvision segmentation models return dict:
        # {"out": logits, "aux": optional aux logits}
        output = self.model(x)
        return output["out"]
def build_deeplabv3(
    backbone="resnet50",
    num_classes=1,
    pretrained_backbone=True,
    aux_loss=False,
):
    """Construct a ``DeepLabV3Wrapper`` from the given options.

    Every argument is passed through unchanged to ``DeepLabV3Wrapper``.
    """
    wrapper_options = {
        "backbone": backbone,
        "num_classes": num_classes,
        "pretrained_backbone": pretrained_backbone,
        "aux_loss": aux_loss,
    }
    return DeepLabV3Wrapper(**wrapper_options)
def _smoke_test():
    """Build an untrained model, run one forward pass and check the shape."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = build_deeplabv3(
        backbone="resnet50",
        num_classes=1,
        pretrained_backbone=False,
    ).to(device)
    # This is an inference check: put layers with training-specific
    # behaviour (e.g. batch normalisation) into evaluation mode first.
    model.eval()

    x = torch.randn(2, 3, 512, 512, device=device)
    with torch.no_grad():
        y = model(x)

    print("Input shape:", x.shape)
    print("Output shape:", y.shape)
    print("Output min/max:", y.min().item(), y.max().item())
    assert y.shape == (2, 1, 512, 512), f"unexpected output shape: {tuple(y.shape)}"
    print("Smoke test passed.")


if __name__ == "__main__":
    # Smoke test:
    # python models/deeplabv3.py
    _smoke_test()