MHasanUnical commited on
Commit
8b38ef0
·
verified ·
1 Parent(s): f177f11

Upload Model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. Model.py +23 -6
Model.py CHANGED
@@ -208,6 +208,21 @@ class Decoder(nn.Module):
208
  x = self.double_conv(x)
209
  return x
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  class VesselSegmentModel(PreTrainedModel):
212
  config_class = VesselSegmentConfig
213
  def __init__(self, config: VesselSegmentConfig=VesselSegmentConfig()):
@@ -233,12 +248,14 @@ class VesselSegmentModel(PreTrainedModel):
233
  self.decoder_layer_3 = Decoder(tensor_dim_decoder=config.features[-1], tensor_dim_encoder=config.features[-2], tensor_dim_mid=config.features[2], up_conv_in_ch=config.features[-1], up_conv_out_ch=config.features[-2], up_conv_scale=2, dconv_in_feature=config.features[-1], dconv_out_feature=config.features[-2], is_concat=True)
234
 
235
  # Segmentation Head
236
- self.segmenation_head = nn.Sequential(
237
- nn.Conv2d(in_channels=config.features[-3], out_channels=config.num_classes, kernel_size=1, padding=0, stride=1),
238
- ImageFolding(image_size=config.image_size[0], patch_size=config.patch_size, batch_size=config.batch_size)
239
- )
240
-
 
241
  def forward(self, x):
 
242
  IMG_1 = self.img_patch(x)
243
  IMG_2 = self.img_down_sampling_1(IMG_1)
244
  IMG_3 = self.img_down_sampling_2(IMG_2)
@@ -257,6 +274,6 @@ class VesselSegmentModel(PreTrainedModel):
257
  d3 = self.decoder_layer_3(sk1, d2)
258
 
259
  # head
260
- head = self.segmenation_head(d3)
261
 
262
  return head
 
208
  x = self.double_conv(x)
209
  return x
210
 
211
+ ######################################################################
212
+ # SEGMENTATION HEAD
213
+ ######################################################################
214
+ class SegmentationHead(nn.Module):
215
+ def __init__(self, feature_dim, num_classes, config:VesselSegmentConfig = VesselSegmentConfig()):
216
+ super().__init__()
217
+ self.config = config
218
+ self.conv = nn.Conv2d(in_channels=feature_dim, out_channels=num_classes, kernel_size=1, stride=1, padding=0)
219
+
220
+ def forward(self, x, batch_size):
221
+ x1 = self.conv(x)
222
+ x1 = ImageFolding(image_size=self.config.image_size[0], patch_size=self.config.patch_size, batch_size=batch_size)(x1)
223
+ return x1
224
+
225
+
226
  class VesselSegmentModel(PreTrainedModel):
227
  config_class = VesselSegmentConfig
228
  def __init__(self, config: VesselSegmentConfig=VesselSegmentConfig()):
 
248
  self.decoder_layer_3 = Decoder(tensor_dim_decoder=config.features[-1], tensor_dim_encoder=config.features[-2], tensor_dim_mid=config.features[2], up_conv_in_ch=config.features[-1], up_conv_out_ch=config.features[-2], up_conv_scale=2, dconv_in_feature=config.features[-1], dconv_out_feature=config.features[-2], is_concat=True)
249
 
250
  # Segmentation Head
251
+ self.segmenation_head = SegmentationHead(feature_dim=config.features[-3], num_classes=config.num_classes)
252
+ # self.segmenation_head = nn.Sequential(
253
+ # nn.Conv2d(in_channels=config.features[-3], out_channels=config.num_classes, kernel_size=1, padding=0, stride=1),
254
+ # ImageFolding(image_size=config.image_size[0], patch_size=config.patch_size, batch_size=config.batch_size)
255
+ # )
256
+
257
  def forward(self, x):
258
+ B,C,H,W = x.shape
259
  IMG_1 = self.img_patch(x)
260
  IMG_2 = self.img_down_sampling_1(IMG_1)
261
  IMG_3 = self.img_down_sampling_2(IMG_2)
 
274
  d3 = self.decoder_layer_3(sk1, d2)
275
 
276
  # head
277
+ head = self.segmenation_head(d3, B)
278
 
279
  return head