File size: 3,277 Bytes
aefe97d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""Model factory for deterministic segmentation baselines."""


def get_model(model_name, in_channels=1, num_classes=1):
    """
    Get a segmentation model by name using official libraries.

    Each backend library (segmentation_models_pytorch, MONAI,
    dynamic_network_architectures) is imported only when a model that needs
    it is requested, so backends for unused models need not be installed.

    Args:
        model_name: One of 'unet', 'attention_unet', 'unetpp', 'transunet',
            'nnunet'
        in_channels: Number of input channels (1 for grayscale)
        num_classes: Number of output classes (1 for binary)

    Returns:
        PyTorch model. The smp-based models ('unet', 'unetpp', 'transunet')
        are built with ``activation=None``, i.e. they return raw logits and
        the sigmoid is applied in the loss/prediction code.

    Raises:
        ValueError: If ``model_name`` is not one of the names listed above.
    """
    if model_name == "unet":
        # From segmentation_models_pytorch
        return _build_smp_model("Unet", "resnet34", in_channels, num_classes)

    if model_name == "unetpp":
        # From segmentation_models_pytorch
        return _build_smp_model("UnetPlusPlus", "resnet34", in_channels, num_classes)

    if model_name == "attention_unet":
        # From MONAI
        return _build_attention_unet(in_channels, num_classes)

    if model_name == "transunet":
        # Stand-in for TransUNet: we use smp's MAnet with a MiT (Mix
        # Transformer) encoder, which is the closest official
        # transformer-encoder architecture to TransUNet available in smp.
        return _build_smp_model("MAnet", "mit_b2", in_channels, num_classes)

    if model_name == "nnunet":
        # nnU-Net architecture (PlainConvUNet) from dynamic_network_architectures
        return _build_nnunet(in_channels, num_classes)

    raise ValueError(f"Unknown model: {model_name}. Choose from: unet, attention_unet, unetpp, transunet, nnunet")


def _build_smp_model(arch_name, encoder_name, in_channels, num_classes):
    """Build ``segmentation_models_pytorch.<arch_name>`` with an ImageNet-pretrained encoder."""
    import segmentation_models_pytorch as smp

    arch = getattr(smp, arch_name)
    return arch(
        encoder_name=encoder_name,
        encoder_weights="imagenet",
        in_channels=in_channels,
        classes=num_classes,
        activation=None,  # We apply sigmoid in loss/prediction
    )


def _build_attention_unet(in_channels, num_classes):
    """Build MONAI's 2D ``AttentionUnet`` with channel widths 64-128-256-512."""
    from monai.networks.nets import AttentionUnet

    return AttentionUnet(
        spatial_dims=2,
        in_channels=in_channels,
        out_channels=num_classes,
        channels=(64, 128, 256, 512),
        strides=(2, 2, 2),
        dropout=0.1,
    )


def _build_nnunet(in_channels, num_classes):
    """Build nnU-Net's ``PlainConvUNet`` from dynamic_network_architectures.

    Uses the same config that nnU-Net auto-planned for our 128x128 dataset:
    6 stages, two convs per stage in both encoder and decoder, instance norm
    and LeakyReLU.
    """
    import torch.nn as nn
    from dynamic_network_architectures.architectures.unet import PlainConvUNet

    n_stages = 6
    return PlainConvUNet(
        input_channels=in_channels,
        n_stages=n_stages,
        features_per_stage=(32, 64, 128, 256, 512, 512),
        conv_op=nn.Conv2d,
        kernel_sizes=[[3, 3]] * n_stages,
        # No downsampling in the first stage, then 2x per stage afterwards
        # (a 128x128 input reaches 4x4 at the deepest stage).
        strides=[[1, 1]] + [[2, 2]] * (n_stages - 1),
        n_conv_per_stage=[2] * n_stages,
        num_classes=num_classes,
        # The decoder has one stage fewer than the encoder.
        n_conv_per_stage_decoder=[2] * (n_stages - 1),
        conv_bias=True,
        norm_op=nn.InstanceNorm2d,
        norm_op_kwargs={"eps": 1e-5, "affine": True},
        dropout_op=None,
        nonlin=nn.LeakyReLU,
        nonlin_kwargs={"inplace": True},
        deep_supervision=False,
    )