File size: 2,785 Bytes
e99a83c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models.segmentation import (
    deeplabv3_resnet50,
    deeplabv3_resnet101,
)
from torchvision.models.segmentation.deeplabv3 import DeepLabHead


class DeepLabV3Wrapper(nn.Module):
    """
    DeepLabV3 wrapper for retinal vessel segmentation.

    Builds a torchvision DeepLabV3 model, replaces its classifier head with
    a fresh ``DeepLabHead`` emitting ``num_classes`` channels (and, when the
    auxiliary head exists, replaces that head's last layer the same way),
    and unwraps the dict returned by the torchvision model.

    Output:
        Raw logits [B, num_classes, H, W]

    For binary vessel segmentation:
        num_classes = 1

    Args:
        backbone: ``"resnet50"`` or ``"resnet101"``.
        num_classes: Number of output channels; must be at least 1.
        pretrained_backbone: If True, request torchvision's ``"DEFAULT"``
            backbone weights; otherwise no weights are requested. The
            segmentation weights are never requested (``weights=None``).
        aux_loss: Forwarded to the torchvision builder to enable the
            auxiliary classifier.

    Raises:
        ValueError: If ``backbone`` is not supported or ``num_classes < 1``.
    """

    def __init__(
        self,
        backbone: str = "resnet50",
        num_classes: int = 1,
        pretrained_backbone: bool = True,
        aux_loss: bool = False,
    ):
        super().__init__()

        # Validate up front so bad arguments fail before any model is built.
        if backbone not in ("resnet50", "resnet101"):
            raise ValueError(
                f"Unsupported backbone: {backbone}. "
                "Choose from: 'resnet50', 'resnet101'."
            )
        if num_classes < 1:
            raise ValueError(
                f"num_classes must be >= 1, got {num_classes}."
            )

        # Both builders take the same keyword arguments, so only the
        # callable differs between the supported backbones.
        if backbone == "resnet50":
            builder = deeplabv3_resnet50
        else:
            builder = deeplabv3_resnet101

        model = builder(
            weights=None,
            weights_backbone="DEFAULT" if pretrained_backbone else None,
            aux_loss=aux_loss,
        )

        # Both supported backbones feed 2048 channels into the head.
        model.classifier = DeepLabHead(
            in_channels=2048,
            num_classes=num_classes,
        )

        # Re-size only the final 1x1 conv of the auxiliary head so that it
        # also emits ``num_classes`` channels.
        if aux_loss and model.aux_classifier is not None:
            model.aux_classifier[-1] = nn.Conv2d(
                model.aux_classifier[-1].in_channels,
                num_classes,
                kernel_size=1,
            )

        self.model = model

    def forward(self, x, *, return_aux: bool = False):
        """
        Run the wrapped model on ``x``.

        Args:
            x: Input batch passed straight to the wrapped model.
            return_aux: If False (default), return only the main logits.
                If True, return the wrapped model's full output dict so the
                auxiliary logits can be used, e.g. for an auxiliary loss.

        Returns:
            ``output["out"]`` by default, or the whole output dict when
            ``return_aux`` is True.
        """
        output = self.model(x)

        # torchvision segmentation models return dict:
        # {"out": logits, "aux": optional aux logits}
        if return_aux:
            return output
        return output["out"]


def build_deeplabv3(
    backbone="resnet50",
    num_classes=1,
    pretrained_backbone=True,
    aux_loss=False,
):
    """
    Factory helper returning a ``DeepLabV3Wrapper``.

    All four arguments are forwarded unchanged, by keyword, to the
    ``DeepLabV3Wrapper`` constructor; see that class for their meaning.

    Returns:
        The newly constructed ``DeepLabV3Wrapper`` instance.
    """
    wrapper_kwargs = {
        "backbone": backbone,
        "num_classes": num_classes,
        "pretrained_backbone": pretrained_backbone,
        "aux_loss": aux_loss,
    }
    return DeepLabV3Wrapper(**wrapper_kwargs)


if __name__ == "__main__":
    # Smoke test:
    # python models/deeplabv3.py
    #
    # Builds the model without pretrained weights, pushes one random batch
    # through it and checks that the logits keep the input's batch size and
    # spatial size with a single output channel.

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # pretrained_backbone=False: no pretrained weights are requested.
    model = build_deeplabv3(
        backbone="resnet50",
        num_classes=1,
        pretrained_backbone=False,
    ).to(device)

    # no_grad() only turns off gradient tracking; eval() additionally puts
    # mode-dependent layers (e.g. batch norm, dropout) into inference mode.
    model.eval()

    # Allocate the batch directly on the target device.
    x = torch.randn(2, 3, 512, 512, device=device)

    with torch.no_grad():
        y = model(x)

    print("Input shape:", x.shape)
    print("Output shape:", y.shape)
    print("Output min/max:", y.min().item(), y.max().item())

    # Same batch size and H x W as the input, one channel (num_classes=1).
    expected_shape = (x.shape[0], 1, x.shape[2], x.shape[3])
    assert y.shape == expected_shape, (
        f"Unexpected output shape {tuple(y.shape)}, expected {expected_shape}"
    )

    print("Smoke test passed.")