Upload model.py with huggingface_hub
Browse files
model.py
CHANGED
|
@@ -5,6 +5,19 @@ import torchvision
|
|
| 5 |
import torch.nn.functional as F
|
| 6 |
from torchvision.transforms import functional as FF
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
######################################################################
|
| 9 |
# IMAGE DOWN SAMPLING
|
| 10 |
######################################################################
|
|
@@ -208,33 +221,33 @@ class Decoder(nn.Module):
|
|
| 208 |
return x
|
| 209 |
|
| 210 |
class VesselSegmentModel(nn.Module):
|
| 211 |
-
def __init__(self,
|
| 212 |
super().__init__()
|
| 213 |
-
|
| 214 |
# image patch
|
| 215 |
-
self.img_patch = ImagePatching(patch_size=
|
| 216 |
|
| 217 |
# image downsampling
|
| 218 |
-
self.img_down_sampling_1 = ImageDownSampling(height=
|
| 219 |
-
self.img_down_sampling_2 = ImageDownSampling(height=
|
| 220 |
|
| 221 |
# encoder layers
|
| 222 |
-
self.encoder_layer_1 = Encoder(
|
| 223 |
-
self.encoder_layer_2 = Encoder(
|
| 224 |
-
self.encoder_layer_3 = Encoder(
|
| 225 |
|
| 226 |
# bottle-neck layer
|
| 227 |
-
self.bottleneck = BottleNeck(in_ch=
|
| 228 |
|
| 229 |
# decoder layers
|
| 230 |
-
self.decoder_layer_1 = Decoder(tensor_dim_decoder=
|
| 231 |
-
self.decoder_layer_2 = Decoder(tensor_dim_decoder=
|
| 232 |
-
self.decoder_layer_3 = Decoder(tensor_dim_decoder=
|
| 233 |
|
| 234 |
# Segmentation Head
|
| 235 |
self.segmenation_head = nn.Sequential(
|
| 236 |
-
nn.Conv2d(in_channels=
|
| 237 |
-
ImageFolding(image_size=IMAGE_SIZE[0], patch_size=
|
| 238 |
)
|
| 239 |
|
| 240 |
def forward(self, x):
|
|
|
|
| 5 |
import torch.nn.functional as F
|
| 6 |
from torchvision.transforms import functional as FF
|
| 7 |
|
| 8 |
+
######################################################################
|
| 9 |
+
# Configuration File
|
| 10 |
+
######################################################################
|
| 11 |
+
class VesselSegmentConfig:
    """Hyper-parameter container for ``VesselSegmentModel``.

    Pure data holder: values are stored as plain attributes, no validation
    is performed.

    Args:
        num_classes: number of output segmentation classes (out channels of
            the segmentation head's 1x1 conv).
        input_channels: channel count of the input images.
        image_size: spatial size of the input images — presumably an
            ``(H, W)`` pair, since the model indexes ``IMAGE_SIZE[0]``;
            TODO confirm.
        features: per-stage feature widths; the model indexes it from both
            ends and scales entries by 2/4.
        attention_dims: dimensionality reserved for attention layers
            (stored but not read in the visible model code).
        patch_size: side length of the square patches used by
            ``ImagePatching`` / ``ImageFolding``.
        batch_size: batch size baked into ``ImageFolding``.
    """

    def __init__(self, num_classes, input_channels, image_size, features, attention_dims, patch_size, batch_size):
        self.num_classes = num_classes
        self.input_channels = input_channels
        self.image_size = image_size
        self.features = features
        self.attention_dims = attention_dims
        self.patch_size = patch_size
        self.batch_size = batch_size

    def __repr__(self):
        # Debug-friendly representation listing every stored hyper-parameter.
        attrs = ", ".join(f"{k}={v!r}" for k, v in vars(self).items())
        return f"{type(self).__name__}({attrs})"
|
| 20 |
+
|
| 21 |
######################################################################
|
| 22 |
# IMAGE DOWN SAMPLING
|
| 23 |
######################################################################
|
|
|
|
| 221 |
return x
|
| 222 |
|
| 223 |
class VesselSegmentModel(nn.Module):
    """U-Net-style vessel segmentation network driven by a VesselSegmentConfig.

    Layers built below: image patching, two downsampled copies of the image
    (fed to the deeper encoder stages via ``is_concate=True``), three encoder
    stages, a bottleneck, three decoder stages, and a segmentation head
    (1x1 conv to ``num_classes`` channels followed by patch folding).
    """

    def __init__(self, config: VesselSegmentConfig):
        """Build all submodules from *config*; stores the config on the instance."""
        super().__init__()
        self.config = config

        # image patch: split the input into patch_size x patch_size tiles
        self.img_patch = ImagePatching(patch_size=config.patch_size)

        # image downsampling: 2x and 4x reduced copies for encoder stages 2 and 3
        self.img_down_sampling_1 = ImageDownSampling(height=config.patch_size, width=config.patch_size, scale=2)
        self.img_down_sampling_2 = ImageDownSampling(height=config.patch_size, width=config.patch_size, scale=4)

        # encoder layers — stage 1 takes the raw patches (is_concate=False);
        # stages 2 and 3 additionally concatenate the downsampled image copies
        self.encoder_layer_1 = Encoder(config.input_channels, config.features[0], enc_fet_ch=config.features[0], max_pool_size=2, is_concate=False)
        self.encoder_layer_2 = Encoder(config.input_channels, config.features[1], enc_fet_ch=config.features[0]*2, max_pool_size=2, is_concate=True)
        self.encoder_layer_3 = Encoder(config.input_channels, config.features[2], enc_fet_ch=config.features[0]*4, max_pool_size=2, is_concate=True)

        # bottle-neck layer
        self.bottleneck = BottleNeck(in_ch=config.features[2]*2, out_ch=config.features[2]*4)

        # decoder layers — channel widths mirror the encoder, halving each stage
        self.decoder_layer_1 = Decoder(tensor_dim_decoder=config.features[-1]*4, tensor_dim_encoder=config.features[-1]*2, tensor_dim_mid=config.features[0], up_conv_in_ch=config.features[-1]*4, up_conv_out_ch=config.features[-1]*2, up_conv_scale=2, dconv_in_feature=config.features[-1]*4, dconv_out_feature=config.features[-1]*2, is_concat=True)
        self.decoder_layer_2 = Decoder(tensor_dim_decoder=config.features[-1]*2, tensor_dim_encoder=config.features[-1], tensor_dim_mid=config.features[1], up_conv_in_ch=config.features[-1]*2, up_conv_out_ch=config.features[-1], up_conv_scale=2, dconv_in_feature=config.features[-1]*2, dconv_out_feature=config.features[-1], is_concat=True)
        self.decoder_layer_3 = Decoder(tensor_dim_decoder=config.features[-1], tensor_dim_encoder=config.features[-2], tensor_dim_mid=config.features[2], up_conv_in_ch=config.features[-1], up_conv_out_ch=config.features[-2], up_conv_scale=2, dconv_in_feature=config.features[-1], dconv_out_feature=config.features[-2], is_concat=True)

        # Segmentation Head
        # (attribute keeps the original spelling "segmenation_head": renaming
        # it would break forward()/saved state_dicts that reference it)
        self.segmenation_head = nn.Sequential(
            nn.Conv2d(in_channels=config.features[-3], out_channels=config.num_classes, kernel_size=1, padding=0, stride=1),
            # FIX: previously read the module-level global IMAGE_SIZE[0] even
            # though the config carries image_size — use the config value so the
            # model is fully determined by its config.
            # NOTE(review): assumes config.image_size is an (H, W) pair like the
            # IMAGE_SIZE global it replaces — confirm with callers.
            ImageFolding(image_size=config.image_size[0], patch_size=config.patch_size, batch_size=config.batch_size)
        )
|
| 252 |
|
| 253 |
def forward(self, x):
|