Upload baselines/models/__init__.py with huggingface_hub
Browse files- baselines/models/__init__.py +93 -0
baselines/models/__init__.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model factory for deterministic segmentation baselines."""
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def get_model(model_name, in_channels=1, num_classes=1):
|
| 5 |
+
"""
|
| 6 |
+
Get a segmentation model by name using official libraries.
|
| 7 |
+
|
| 8 |
+
Args:
|
| 9 |
+
model_name: One of 'unet', 'attention_unet', 'unetpp', 'transunet'
|
| 10 |
+
in_channels: Number of input channels (1 for grayscale)
|
| 11 |
+
num_classes: Number of output classes (1 for binary)
|
| 12 |
+
|
| 13 |
+
Returns:
|
| 14 |
+
PyTorch model
|
| 15 |
+
"""
|
| 16 |
+
if model_name == "unet":
|
| 17 |
+
# From segmentation_models_pytorch
|
| 18 |
+
import segmentation_models_pytorch as smp
|
| 19 |
+
model = smp.Unet(
|
| 20 |
+
encoder_name="resnet34",
|
| 21 |
+
encoder_weights="imagenet",
|
| 22 |
+
in_channels=in_channels,
|
| 23 |
+
classes=num_classes,
|
| 24 |
+
activation=None, # We apply sigmoid in loss/prediction
|
| 25 |
+
)
|
| 26 |
+
return model
|
| 27 |
+
|
| 28 |
+
elif model_name == "unetpp":
|
| 29 |
+
# From segmentation_models_pytorch
|
| 30 |
+
import segmentation_models_pytorch as smp
|
| 31 |
+
model = smp.UnetPlusPlus(
|
| 32 |
+
encoder_name="resnet34",
|
| 33 |
+
encoder_weights="imagenet",
|
| 34 |
+
in_channels=in_channels,
|
| 35 |
+
classes=num_classes,
|
| 36 |
+
activation=None,
|
| 37 |
+
)
|
| 38 |
+
return model
|
| 39 |
+
|
| 40 |
+
elif model_name == "attention_unet":
|
| 41 |
+
# From MONAI
|
| 42 |
+
from monai.networks.nets import AttentionUnet
|
| 43 |
+
model = AttentionUnet(
|
| 44 |
+
spatial_dims=2,
|
| 45 |
+
in_channels=in_channels,
|
| 46 |
+
out_channels=num_classes,
|
| 47 |
+
channels=(64, 128, 256, 512),
|
| 48 |
+
strides=(2, 2, 2),
|
| 49 |
+
dropout=0.1,
|
| 50 |
+
)
|
| 51 |
+
return model
|
| 52 |
+
|
| 53 |
+
elif model_name == "transunet":
|
| 54 |
+
# Custom TransUNet using smp with a ViT-style encoder
|
| 55 |
+
# We use smp's MAnet with a MiT (Mix Transformer) encoder
|
| 56 |
+
# which is the closest official arch to TransUNet in smp
|
| 57 |
+
import segmentation_models_pytorch as smp
|
| 58 |
+
model = smp.MAnet(
|
| 59 |
+
encoder_name="mit_b2",
|
| 60 |
+
encoder_weights="imagenet",
|
| 61 |
+
in_channels=in_channels,
|
| 62 |
+
classes=num_classes,
|
| 63 |
+
activation=None,
|
| 64 |
+
)
|
| 65 |
+
return model
|
| 66 |
+
|
| 67 |
+
elif model_name == "nnunet":
|
| 68 |
+
# nnU-Net architecture (PlainConvUNet) from dynamic_network_architectures
|
| 69 |
+
# Using the same config that nnU-Net auto-planned for our 128x128 dataset
|
| 70 |
+
from dynamic_network_architectures.architectures.unet import PlainConvUNet
|
| 71 |
+
import torch.nn as nn
|
| 72 |
+
model = PlainConvUNet(
|
| 73 |
+
input_channels=in_channels,
|
| 74 |
+
n_stages=6,
|
| 75 |
+
features_per_stage=(32, 64, 128, 256, 512, 512),
|
| 76 |
+
conv_op=nn.Conv2d,
|
| 77 |
+
kernel_sizes=[[3, 3]] * 6,
|
| 78 |
+
strides=[[1, 1]] + [[2, 2]] * 5,
|
| 79 |
+
n_conv_per_stage=[2, 2, 2, 2, 2, 2],
|
| 80 |
+
num_classes=num_classes,
|
| 81 |
+
n_conv_per_stage_decoder=[2, 2, 2, 2, 2],
|
| 82 |
+
conv_bias=True,
|
| 83 |
+
norm_op=nn.InstanceNorm2d,
|
| 84 |
+
norm_op_kwargs={"eps": 1e-5, "affine": True},
|
| 85 |
+
dropout_op=None,
|
| 86 |
+
nonlin=nn.LeakyReLU,
|
| 87 |
+
nonlin_kwargs={"inplace": True},
|
| 88 |
+
deep_supervision=False,
|
| 89 |
+
)
|
| 90 |
+
return model
|
| 91 |
+
|
| 92 |
+
else:
|
| 93 |
+
raise ValueError(f"Unknown model: {model_name}. Choose from: unet, attention_unet, unetpp, transunet, nnunet")
|