File size: 3,277 Bytes
aefe97d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 | """Model factory for deterministic segmentation baselines."""
# Public names accepted by get_model, in the order shown in the error message.
_SUPPORTED_MODELS = ("unet", "attention_unet", "unetpp", "transunet", "nnunet")


def _smp_model(arch_name, encoder_name, in_channels, num_classes):
    """Build segmentation_models_pytorch architecture ``arch_name`` with an ImageNet encoder."""
    import segmentation_models_pytorch as smp

    arch = getattr(smp, arch_name)  # e.g. smp.Unet, smp.UnetPlusPlus, smp.MAnet
    return arch(
        encoder_name=encoder_name,
        encoder_weights="imagenet",
        in_channels=in_channels,
        classes=num_classes,
        activation=None,  # raw logits; sigmoid is applied in loss/prediction
    )


def _monai_attention_unet(in_channels, num_classes):
    """Build MONAI's 2D AttentionUnet baseline configuration."""
    from monai.networks.nets import AttentionUnet

    return AttentionUnet(
        spatial_dims=2,
        in_channels=in_channels,
        out_channels=num_classes,
        channels=(64, 128, 256, 512),
        strides=(2, 2, 2),  # one stride per transition between consecutive channel levels
        dropout=0.1,
    )


def _nnunet_plain_conv_unet(in_channels, num_classes):
    """Build nnU-Net's PlainConvUNet with the config nnU-Net auto-planned for our 128x128 dataset."""
    import torch.nn as nn
    from dynamic_network_architectures.architectures.unet import PlainConvUNet

    n_stages = 6
    return PlainConvUNet(
        input_channels=in_channels,
        n_stages=n_stages,
        features_per_stage=(32, 64, 128, 256, 512, 512),
        conv_op=nn.Conv2d,
        kernel_sizes=[[3, 3]] * n_stages,
        # First stage keeps full resolution; the remaining five each halve it.
        strides=[[1, 1]] + [[2, 2]] * (n_stages - 1),
        n_conv_per_stage=[2] * n_stages,
        num_classes=num_classes,
        # The decoder has one stage fewer than the encoder.
        n_conv_per_stage_decoder=[2] * (n_stages - 1),
        conv_bias=True,
        norm_op=nn.InstanceNorm2d,
        norm_op_kwargs={"eps": 1e-5, "affine": True},
        dropout_op=None,
        nonlin=nn.LeakyReLU,
        nonlin_kwargs={"inplace": True},
        deep_supervision=False,
    )


def get_model(model_name, in_channels=1, num_classes=1):
    """Get a segmentation model by name using official library implementations.

    Third-party libraries are imported only inside the branch that needs them,
    so only the package for the requested architecture must be installed.

    Args:
        model_name: One of 'unet', 'attention_unet', 'unetpp', 'transunet',
            'nnunet'.
            - 'unet' / 'unetpp': smp Unet / UnetPlusPlus, resnet34 encoder.
            - 'attention_unet': MONAI AttentionUnet (2D).
            - 'transunet': smp MAnet with a MiT-B2 (Mix Transformer) encoder,
              used as the closest smp stand-in; it is not the original
              TransUNet architecture.
            - 'nnunet': PlainConvUNet from dynamic_network_architectures.
        in_channels: Number of input channels (1 for grayscale).
        num_classes: Number of output classes (1 for binary).

    Returns:
        PyTorch model. The smp-based models are built with ``activation=None``
        and therefore return raw logits.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == "unet":
        return _smp_model("Unet", "resnet34", in_channels, num_classes)
    if model_name == "unetpp":
        return _smp_model("UnetPlusPlus", "resnet34", in_channels, num_classes)
    if model_name == "attention_unet":
        return _monai_attention_unet(in_channels, num_classes)
    if model_name == "transunet":
        # NOTE(review): some smp releases reject in_channels != 3 for mit_*
        # encoders; confirm the installed version accepts the default of 1.
        return _smp_model("MAnet", "mit_b2", in_channels, num_classes)
    if model_name == "nnunet":
        return _nnunet_plain_conv_unet(in_channels, num_classes)
    raise ValueError(
        f"Unknown model: {model_name}. Choose from: {', '.join(_SUPPORTED_MODELS)}"
    )
|