CFPVesselSeg / models /deeplabv3.py
farrell236's picture
add src
e99a83c
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.segmentation import (
deeplabv3_resnet50,
deeplabv3_resnet101,
)
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
class DeepLabV3Wrapper(nn.Module):
    """
    DeepLabV3 wrapper for retinal vessel segmentation.

    Builds a torchvision DeepLabV3 model on a ResNet-50 or ResNet-101
    backbone, replaces the classification head (and the last layer of the
    auxiliary head, when present) so they emit ``num_classes`` channels, and
    unwraps the dict returned by the torchvision model.

    Output:
        Raw logits [B, num_classes, H, W]

    For binary vessel segmentation:
        num_classes = 1

    Args:
        backbone: Either ``"resnet50"`` or ``"resnet101"``.
        num_classes: Number of output channels of the head(s); must be >= 1.
        pretrained_backbone: When true the backbone is requested with
            ``weights_backbone="DEFAULT"``, otherwise with ``None``. The
            segmentation model itself is always built with ``weights=None``.
        aux_loss: Forwarded to the torchvision builder. When true and the
            built model has an auxiliary classifier, that classifier's last
            layer is replaced with a 1x1 convolution producing
            ``num_classes`` channels.

    Raises:
        ValueError: If ``backbone`` is not supported or ``num_classes`` is
            smaller than 1.
    """

    def __init__(
        self,
        backbone: str = "resnet50",
        num_classes: int = 1,
        pretrained_backbone: bool = True,
        aux_loss: bool = False,
    ) -> None:
        super().__init__()

        # Reject bad arguments before any (potentially slow, possibly
        # weight-downloading) model construction takes place.
        if backbone not in ("resnet50", "resnet101"):
            raise ValueError(
                f"Unsupported backbone: {backbone}. "
                "Choose from: 'resnet50', 'resnet101'."
            )
        if num_classes < 1:
            raise ValueError(
                f"num_classes must be a positive integer, got {num_classes}."
            )

        if backbone == "resnet50":
            build_model = deeplabv3_resnet50
        else:
            build_model = deeplabv3_resnet101

        model = build_model(
            weights=None,
            weights_backbone="DEFAULT" if pretrained_backbone else None,
            aux_loss=aux_loss,
        )

        # Both supported backbones feed 2048 channels into the head.
        model.classifier = DeepLabHead(
            in_channels=2048,
            num_classes=num_classes,
        )

        # Only the final projection of the auxiliary head depends on the
        # number of classes, so swapping that single layer is sufficient.
        if aux_loss and model.aux_classifier is not None:
            old_projection = model.aux_classifier[-1]
            model.aux_classifier[-1] = nn.Conv2d(
                old_projection.in_channels,
                num_classes,
                kernel_size=1,
            )

        self.model = model

    def forward(self, x: torch.Tensor, return_aux: bool = False):
        """
        Run the wrapped model and unwrap its dict output.

        Args:
            x: Input batch passed unchanged to the wrapped model.
            return_aux: When false (default) only the main logits are
                returned. When true a ``(out, aux)`` tuple is returned so a
                training loop can add an auxiliary loss; ``aux`` is ``None``
                if the wrapped model produced no ``"aux"`` entry.

        Returns:
            The ``"out"`` logits, or ``(out, aux)`` if ``return_aux`` is true.
        """
        output = self.model(x)
        # torchvision segmentation models return dict:
        # {"out": logits, "aux": optional aux logits}
        if return_aux:
            return output["out"], output.get("aux")
        return output["out"]
def build_deeplabv3(
    backbone="resnet50",
    num_classes=1,
    pretrained_backbone=True,
    aux_loss=False,
):
    """
    Factory returning a ``DeepLabV3Wrapper``.

    Every argument is handed to ``DeepLabV3Wrapper`` by keyword, unchanged.

    Args:
        backbone: Backbone name forwarded to the wrapper.
        num_classes: Number of output channels forwarded to the wrapper.
        pretrained_backbone: Forwarded to the wrapper.
        aux_loss: Forwarded to the wrapper.

    Returns:
        The newly constructed ``DeepLabV3Wrapper`` instance.
    """
    wrapper_kwargs = {
        "backbone": backbone,
        "num_classes": num_classes,
        "pretrained_backbone": pretrained_backbone,
        "aux_loss": aux_loss,
    }
    return DeepLabV3Wrapper(**wrapper_kwargs)
if __name__ == "__main__":
    # Smoke test:
    # python models/deeplabv3.py
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = build_deeplabv3(
        backbone="resnet50",
        num_classes=1,
        pretrained_backbone=False,
    ).to(device)
    # Switch to eval mode for the forward pass: otherwise normalisation
    # layers would use batch statistics and update their running stats even
    # though gradients are disabled below.
    model.eval()

    # Allocate the dummy batch directly on the target device.
    x = torch.randn(2, 3, 512, 512, device=device)

    with torch.no_grad():
        y = model(x)

    print("Input shape:", x.shape)
    print("Output shape:", y.shape)
    print("Output min/max:", y.min().item(), y.max().item())

    # The wrapper must return logits at input resolution with one channel.
    assert y.shape == (2, 1, 512, 512)
    print("Smoke test passed.")