siddharthdhara17 commited on
Commit
aefe97d
·
verified ·
1 Parent(s): fa5bb00

Upload baselines/models/__init__.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. baselines/models/__init__.py +93 -0
baselines/models/__init__.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model factory for deterministic segmentation baselines."""
2
+
3
+
4
+ def get_model(model_name, in_channels=1, num_classes=1):
5
+ """
6
+ Get a segmentation model by name using official libraries.
7
+
8
+ Args:
9
+ model_name: One of 'unet', 'attention_unet', 'unetpp', 'transunet'
10
+ in_channels: Number of input channels (1 for grayscale)
11
+ num_classes: Number of output classes (1 for binary)
12
+
13
+ Returns:
14
+ PyTorch model
15
+ """
16
+ if model_name == "unet":
17
+ # From segmentation_models_pytorch
18
+ import segmentation_models_pytorch as smp
19
+ model = smp.Unet(
20
+ encoder_name="resnet34",
21
+ encoder_weights="imagenet",
22
+ in_channels=in_channels,
23
+ classes=num_classes,
24
+ activation=None, # We apply sigmoid in loss/prediction
25
+ )
26
+ return model
27
+
28
+ elif model_name == "unetpp":
29
+ # From segmentation_models_pytorch
30
+ import segmentation_models_pytorch as smp
31
+ model = smp.UnetPlusPlus(
32
+ encoder_name="resnet34",
33
+ encoder_weights="imagenet",
34
+ in_channels=in_channels,
35
+ classes=num_classes,
36
+ activation=None,
37
+ )
38
+ return model
39
+
40
+ elif model_name == "attention_unet":
41
+ # From MONAI
42
+ from monai.networks.nets import AttentionUnet
43
+ model = AttentionUnet(
44
+ spatial_dims=2,
45
+ in_channels=in_channels,
46
+ out_channels=num_classes,
47
+ channels=(64, 128, 256, 512),
48
+ strides=(2, 2, 2),
49
+ dropout=0.1,
50
+ )
51
+ return model
52
+
53
+ elif model_name == "transunet":
54
+ # Custom TransUNet using smp with a ViT-style encoder
55
+ # We use smp's MAnet with a MiT (Mix Transformer) encoder
56
+ # which is the closest official arch to TransUNet in smp
57
+ import segmentation_models_pytorch as smp
58
+ model = smp.MAnet(
59
+ encoder_name="mit_b2",
60
+ encoder_weights="imagenet",
61
+ in_channels=in_channels,
62
+ classes=num_classes,
63
+ activation=None,
64
+ )
65
+ return model
66
+
67
+ elif model_name == "nnunet":
68
+ # nnU-Net architecture (PlainConvUNet) from dynamic_network_architectures
69
+ # Using the same config that nnU-Net auto-planned for our 128x128 dataset
70
+ from dynamic_network_architectures.architectures.unet import PlainConvUNet
71
+ import torch.nn as nn
72
+ model = PlainConvUNet(
73
+ input_channels=in_channels,
74
+ n_stages=6,
75
+ features_per_stage=(32, 64, 128, 256, 512, 512),
76
+ conv_op=nn.Conv2d,
77
+ kernel_sizes=[[3, 3]] * 6,
78
+ strides=[[1, 1]] + [[2, 2]] * 5,
79
+ n_conv_per_stage=[2, 2, 2, 2, 2, 2],
80
+ num_classes=num_classes,
81
+ n_conv_per_stage_decoder=[2, 2, 2, 2, 2],
82
+ conv_bias=True,
83
+ norm_op=nn.InstanceNorm2d,
84
+ norm_op_kwargs={"eps": 1e-5, "affine": True},
85
+ dropout_op=None,
86
+ nonlin=nn.LeakyReLU,
87
+ nonlin_kwargs={"inplace": True},
88
+ deep_supervision=False,
89
+ )
90
+ return model
91
+
92
+ else:
93
+ raise ValueError(f"Unknown model: {model_name}. Choose from: unet, attention_unet, unetpp, transunet, nnunet")