"""Model factory for deterministic segmentation baselines."""

# Every name accepted by ``get_model``, in the order reported in the error
# message. Kept in one place so the message cannot drift from the dispatch.
_SUPPORTED_MODELS = ("unet", "attention_unet", "unetpp", "transunet", "nnunet")


def _build_smp_model(arch_name, encoder_name, in_channels, num_classes):
    """Build a segmentation_models_pytorch architecture with no output activation."""
    # Imported lazily so that only the library backing the requested model
    # has to be installed.
    import segmentation_models_pytorch as smp

    architecture = getattr(smp, arch_name)
    return architecture(
        encoder_name=encoder_name,
        encoder_weights="imagenet",
        in_channels=in_channels,
        classes=num_classes,
        activation=None,  # We apply sigmoid in loss/prediction
    )


def _build_attention_unet(in_channels, num_classes):
    """Build the 2D MONAI AttentionUnet baseline."""
    from monai.networks.nets import AttentionUnet

    return AttentionUnet(
        spatial_dims=2,
        in_channels=in_channels,
        out_channels=num_classes,
        channels=(64, 128, 256, 512),
        strides=(2, 2, 2),
        dropout=0.1,
    )


def _build_nnunet(in_channels, num_classes):
    """Build the nnU-Net PlainConvUNet baseline from dynamic_network_architectures."""
    from dynamic_network_architectures.architectures.unet import PlainConvUNet
    import torch.nn as nn

    # Using the same config that nnU-Net auto-planned for our 128x128 dataset.
    return PlainConvUNet(
        input_channels=in_channels,
        n_stages=6,
        features_per_stage=(32, 64, 128, 256, 512, 512),
        conv_op=nn.Conv2d,
        kernel_sizes=[[3, 3]] * 6,
        strides=[[1, 1]] + [[2, 2]] * 5,
        n_conv_per_stage=[2, 2, 2, 2, 2, 2],
        num_classes=num_classes,
        n_conv_per_stage_decoder=[2, 2, 2, 2, 2],
        conv_bias=True,
        norm_op=nn.InstanceNorm2d,
        norm_op_kwargs={"eps": 1e-5, "affine": True},
        dropout_op=None,
        nonlin=nn.LeakyReLU,
        nonlin_kwargs={"inplace": True},
        deep_supervision=False,
    )


def get_model(model_name, in_channels=1, num_classes=1):
    """
    Get a segmentation model by name using official libraries.

    The backing library is imported only when its model is requested, so an
    unknown name is rejected without importing anything.

    Args:
        model_name: One of 'unet', 'attention_unet', 'unetpp', 'transunet',
            'nnunet'.
            - 'unet': smp ``Unet`` with a resnet34 encoder.
            - 'unetpp': smp ``UnetPlusPlus`` with a resnet34 encoder.
            - 'attention_unet': MONAI ``AttentionUnet`` (2D).
            - 'transunet': smp ``MAnet`` with a ``mit_b2`` (Mix Transformer)
              encoder, used as the closest official arch to TransUNet in smp.
            - 'nnunet': ``PlainConvUNet`` from dynamic_network_architectures.
        in_channels: Number of input channels (1 for grayscale)
        num_classes: Number of output classes (1 for binary)

    Returns:
        PyTorch model. The smp-based models are created with
        ``activation=None``; sigmoid is applied in loss/prediction.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Plain equality checks (rather than a dict lookup) so that any value,
    # including unhashable ones, falls through to the ValueError below.
    if model_name == "unet":
        return _build_smp_model("Unet", "resnet34", in_channels, num_classes)
    if model_name == "unetpp":
        return _build_smp_model("UnetPlusPlus", "resnet34", in_channels, num_classes)
    if model_name == "attention_unet":
        return _build_attention_unet(in_channels, num_classes)
    if model_name == "transunet":
        # NOTE(review): some smp releases reportedly reject in_channels != 3
        # for MiT encoders; confirm against the installed version, since the
        # default here is in_channels=1.
        return _build_smp_model("MAnet", "mit_b2", in_channels, num_classes)
    if model_name == "nnunet":
        return _build_nnunet(in_channels, num_classes)

    raise ValueError(
        f"Unknown model: {model_name}. Choose from: {', '.join(_SUPPORTED_MODELS)}"
    )