| """Model factory for deterministic segmentation baselines.""" |
|
|
|
|
# Names accepted by ``get_model``; also used to build its error message so the
# two cannot drift apart.
SUPPORTED_MODELS = ("unet", "attention_unet", "unetpp", "transunet", "nnunet")


def _build_smp_model(arch, encoder_name, in_channels, num_classes):
    """Instantiate a segmentation_models_pytorch architecture by class name.

    All smp baselines share the same construction: ImageNet encoder weights
    and no output activation (``activation=None``).
    """
    import segmentation_models_pytorch as smp

    model_cls = getattr(smp, arch)
    return model_cls(
        encoder_name=encoder_name,
        encoder_weights="imagenet",
        in_channels=in_channels,
        classes=num_classes,
        activation=None,
    )


def _build_attention_unet(in_channels, num_classes):
    """Build MONAI's 2D AttentionUnet (channels 64-512, three stride-2 levels)."""
    from monai.networks.nets import AttentionUnet

    return AttentionUnet(
        spatial_dims=2,
        in_channels=in_channels,
        out_channels=num_classes,
        channels=(64, 128, 256, 512),
        strides=(2, 2, 2),
        dropout=0.1,
    )


def _build_nnunet(in_channels, num_classes):
    """Build a 2D PlainConvUNet with a fixed 6-stage nnU-Net-style configuration."""
    import torch.nn as nn
    from dynamic_network_architectures.architectures.unet import PlainConvUNet

    n_stages = 6
    return PlainConvUNet(
        input_channels=in_channels,
        n_stages=n_stages,
        features_per_stage=(32, 64, 128, 256, 512, 512),
        conv_op=nn.Conv2d,
        kernel_sizes=[[3, 3]] * n_stages,
        # First stage keeps full resolution; each later stage halves H and W
        # (5 halvings, total factor 32). NOTE(review): inputs presumably need
        # H, W divisible by 32 — confirm against the data pipeline.
        strides=[[1, 1]] + [[2, 2]] * (n_stages - 1),
        n_conv_per_stage=[2] * n_stages,
        num_classes=num_classes,
        n_conv_per_stage_decoder=[2] * (n_stages - 1),
        conv_bias=True,
        norm_op=nn.InstanceNorm2d,
        norm_op_kwargs={"eps": 1e-5, "affine": True},
        dropout_op=None,
        nonlin=nn.LeakyReLU,
        nonlin_kwargs={"inplace": True},
        deep_supervision=False,
    )


def get_model(model_name, in_channels=1, num_classes=1):
    """
    Get a segmentation model by name using official libraries.

    Each backend library is imported lazily, only when its model is
    requested, so only the package for the chosen model must be installed.

    Args:
        model_name: One of ``SUPPORTED_MODELS``:

            - 'unet': segmentation_models_pytorch ``Unet`` with a ResNet-34
              encoder (ImageNet weights, no output activation).
            - 'unetpp': segmentation_models_pytorch ``UnetPlusPlus`` with a
              ResNet-34 encoder (ImageNet weights, no output activation).
            - 'transunet': NOT the original TransUNet architecture. It is a
              transformer-encoder stand-in: segmentation_models_pytorch
              ``MAnet`` with the ``mit_b2`` encoder (ImageNet weights, no
              output activation).
            - 'attention_unet': MONAI 2D ``AttentionUnet``.
            - 'nnunet': ``PlainConvUNet`` from dynamic_network_architectures
              with a fixed 6-stage 2D configuration (no deep supervision).
        in_channels: Number of input channels (1 for grayscale)
        num_classes: Number of output classes (1 for binary)

    Returns:
        PyTorch model

    Raises:
        ValueError: If ``model_name`` is not one of ``SUPPORTED_MODELS``.
    """
    if model_name == "unet":
        return _build_smp_model("Unet", "resnet34", in_channels, num_classes)
    if model_name == "unetpp":
        return _build_smp_model("UnetPlusPlus", "resnet34", in_channels, num_classes)
    if model_name == "transunet":
        # Stand-in only: MAnet decoder on a Mix Vision Transformer encoder.
        return _build_smp_model("MAnet", "mit_b2", in_channels, num_classes)
    if model_name == "attention_unet":
        return _build_attention_unet(in_channels, num_classes)
    if model_name == "nnunet":
        return _build_nnunet(in_channels, num_classes)
    raise ValueError(
        f"Unknown model: {model_name}. Choose from: {', '.join(SUPPORTED_MODELS)}"
    )
|
|