MHasanUnical committed
Commit 5f54b5b (verified) · Parent(s): 3a0de6b

Upload model.py with huggingface_hub

Files changed (1):
  1. model.py +27 -14
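The commit message indicates the file was pushed with the huggingface_hub client. A minimal sketch of such an upload is shown below; the repository id is a placeholder, since the target repo name is not visible on this page.

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="model.py",           # local file to push
    path_in_repo="model.py",              # destination path inside the repo
    repo_id="MHasanUnical/<repo-name>",   # placeholder repo id (assumption)
    commit_message="Upload model.py with huggingface_hub",
)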
model.py CHANGED
@@ -5,6 +5,19 @@ import torchvision
 import torch.nn.functional as F
 from torchvision.transforms import functional as FF
 
+######################################################################
+# Configuration File
+######################################################################
+class VesselSegmentConfig:
+    def __init__(self, num_classes, input_channels, image_size, features, attention_dims, patch_size, batch_size):
+        self.num_classes = num_classes
+        self.input_channels = input_channels
+        self.image_size = image_size
+        self.features = features
+        self.attention_dims = attention_dims
+        self.patch_size = patch_size
+        self.batch_size = batch_size
+
 ######################################################################
 # IMAGE DOWN SAMPLING
 ######################################################################
@@ -208,33 +221,33 @@ class Decoder(nn.Module):
         return x
 
 class VesselSegmentModel(nn.Module):
-    def __init__(self, input_channel: int, feature: list, attention_dim: list, output_channel: int=1):
+    def __init__(self, config: VesselSegmentConfig):
         super().__init__()
-
+        self.config = config
         # image patch
-        self.img_patch = ImagePatching(patch_size=PATCH_SIZE)
+        self.img_patch = ImagePatching(patch_size=config.patch_size)
 
         # image downsampling
-        self.img_down_sampling_1 = ImageDownSampling(height=PATCH_SIZE, width=PATCH_SIZE, scale=2)
-        self.img_down_sampling_2 = ImageDownSampling(height=PATCH_SIZE, width=PATCH_SIZE, scale=4)
+        self.img_down_sampling_1 = ImageDownSampling(height=config.patch_size, width=config.patch_size, scale=2)
+        self.img_down_sampling_2 = ImageDownSampling(height=config.patch_size, width=config.patch_size, scale=4)
 
         # encoder layers
-        self.encoder_layer_1 = Encoder(input_channel, feature[0], enc_fet_ch=feature[0], max_pool_size=2, is_concate=False)
-        self.encoder_layer_2 = Encoder(input_channel, feature[1], enc_fet_ch=feature[0]*2, max_pool_size=2, is_concate=True)
-        self.encoder_layer_3 = Encoder(input_channel, feature[2], enc_fet_ch=feature[0]*4, max_pool_size=2, is_concate=True)
+        self.encoder_layer_1 = Encoder(config.input_channels, config.features[0], enc_fet_ch=config.features[0], max_pool_size=2, is_concate=False)
+        self.encoder_layer_2 = Encoder(config.input_channels, config.features[1], enc_fet_ch=config.features[0]*2, max_pool_size=2, is_concate=True)
+        self.encoder_layer_3 = Encoder(config.input_channels, config.features[2], enc_fet_ch=config.features[0]*4, max_pool_size=2, is_concate=True)
 
         # bottle-neck layer
-        self.bottleneck = BottleNeck(in_ch=feature[2]*2, out_ch=feature[2]*4)
+        self.bottleneck = BottleNeck(in_ch=config.features[2]*2, out_ch=config.features[2]*4)
 
         # decoder layers
-        self.decoder_layer_1 = Decoder(tensor_dim_decoder=feature[-1]*4, tensor_dim_encoder=feature[-1]*2, tensor_dim_mid=attention_dim[0], up_conv_in_ch=feature[-1]*4, up_conv_out_ch=feature[-1]*2, up_conv_scale=2, dconv_in_feature=feature[-1]*4, dconv_out_feature=feature[-1]*2, is_concat=True)
-        self.decoder_layer_2 = Decoder(tensor_dim_decoder=feature[-1]*2, tensor_dim_encoder=feature[-1], tensor_dim_mid=attention_dim[1], up_conv_in_ch=feature[-1]*2, up_conv_out_ch=feature[-1], up_conv_scale=2, dconv_in_feature=feature[-1]*2, dconv_out_feature=feature[-1], is_concat=True)
-        self.decoder_layer_3 = Decoder(tensor_dim_decoder=feature[-1], tensor_dim_encoder=feature[-2], tensor_dim_mid=attention_dim[2], up_conv_in_ch=feature[-1], up_conv_out_ch=feature[-2], up_conv_scale=2, dconv_in_feature=feature[-1], dconv_out_feature=feature[-2], is_concat=True)
+        self.decoder_layer_1 = Decoder(tensor_dim_decoder=config.features[-1]*4, tensor_dim_encoder=config.features[-1]*2, tensor_dim_mid=config.features[0], up_conv_in_ch=config.features[-1]*4, up_conv_out_ch=config.features[-1]*2, up_conv_scale=2, dconv_in_feature=config.features[-1]*4, dconv_out_feature=config.features[-1]*2, is_concat=True)
+        self.decoder_layer_2 = Decoder(tensor_dim_decoder=config.features[-1]*2, tensor_dim_encoder=config.features[-1], tensor_dim_mid=config.features[1], up_conv_in_ch=config.features[-1]*2, up_conv_out_ch=config.features[-1], up_conv_scale=2, dconv_in_feature=config.features[-1]*2, dconv_out_feature=config.features[-1], is_concat=True)
+        self.decoder_layer_3 = Decoder(tensor_dim_decoder=config.features[-1], tensor_dim_encoder=config.features[-2], tensor_dim_mid=config.features[2], up_conv_in_ch=config.features[-1], up_conv_out_ch=config.features[-2], up_conv_scale=2, dconv_in_feature=config.features[-1], dconv_out_feature=config.features[-2], is_concat=True)
 
         # Segmentation Head
         self.segmenation_head = nn.Sequential(
-            nn.Conv2d(in_channels=feature[-3], out_channels=output_channel, kernel_size=1, padding=0, stride=1),
-            ImageFolding(image_size=IMAGE_SIZE[0], patch_size=PATCH_SIZE, batch_size=BATCH_SIZE)
+            nn.Conv2d(in_channels=config.features[-3], out_channels=config.num_classes, kernel_size=1, padding=0, stride=1),
+            ImageFolding(image_size=IMAGE_SIZE[0], patch_size=config.patch_size, batch_size=config.batch_size)
         )
 
     def forward(self, x):
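With this change, the model is constructed from a single config object instead of loose positional arguments. A minimal usage sketch follows; every concrete value is an assumption, since the commit does not show the actual image size, patch size, batch size, or feature widths used by the surrounding code.

config = VesselSegmentConfig(
    num_classes=1,                 # assumed; the old signature defaulted output_channel to 1
    input_channels=3,              # assumed RGB input
    image_size=(512, 512),         # assumed resolution
    features=[64, 128, 256],       # assumed encoder widths
    attention_dims=[32, 64, 128],  # assumed; stored on the config but not read in the hunks above
    patch_size=128,                # assumed patch size
    batch_size=4,                  # assumed batch size
)
model = VesselSegmentModel(config)
# forward pass: model(x), with x presumably of shape (batch_size, input_channels, H, W)

Note that in the hunks shown here the decoders take tensor_dim_mid from config.features[0], [1], [2] rather than from config.attention_dims, and ImageFolding still reads the module-level IMAGE_SIZE rather than config.image_size, so attention_dims and image_size are carried on the config without being consumed in this diff.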