Upload Model.py with huggingface_hub
Browse files
Model.py
CHANGED
|
@@ -208,6 +208,21 @@ class Decoder(nn.Module):
|
|
| 208 |
x = self.double_conv(x)
|
| 209 |
return x
|
| 210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
class VesselSegmentModel(PreTrainedModel):
|
| 212 |
config_class = VesselSegmentConfig
|
| 213 |
def __init__(self, config: VesselSegmentConfig=VesselSegmentConfig()):
|
|
@@ -233,12 +248,14 @@ class VesselSegmentModel(PreTrainedModel):
|
|
| 233 |
self.decoder_layer_3 = Decoder(tensor_dim_decoder=config.features[-1], tensor_dim_encoder=config.features[-2], tensor_dim_mid=config.features[2], up_conv_in_ch=config.features[-1], up_conv_out_ch=config.features[-2], up_conv_scale=2, dconv_in_feature=config.features[-1], dconv_out_feature=config.features[-2], is_concat=True)
|
| 234 |
|
| 235 |
# Segmentation Head
|
| 236 |
-
self.segmenation_head =
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
)
|
| 240 |
-
|
|
|
|
| 241 |
def forward(self, x):
|
|
|
|
| 242 |
IMG_1 = self.img_patch(x)
|
| 243 |
IMG_2 = self.img_down_sampling_1(IMG_1)
|
| 244 |
IMG_3 = self.img_down_sampling_2(IMG_2)
|
|
@@ -257,6 +274,6 @@ class VesselSegmentModel(PreTrainedModel):
|
|
| 257 |
d3 = self.decoder_layer_3(sk1, d2)
|
| 258 |
|
| 259 |
# head
|
| 260 |
-
head = self.segmenation_head(d3)
|
| 261 |
|
| 262 |
return head
|
|
|
|
| 208 |
x = self.double_conv(x)
|
| 209 |
return x
|
| 210 |
|
| 211 |
+
######################################################################
|
| 212 |
+
# SEGMENTATION HEAD
|
| 213 |
+
######################################################################
|
| 214 |
+
class SegmentationHead(nn.Module):
|
| 215 |
+
def __init__(self, feature_dim, num_classes, config:VesselSegmentConfig = VesselSegmentConfig()):
|
| 216 |
+
super().__init__()
|
| 217 |
+
self.config = config
|
| 218 |
+
self.conv = nn.Conv2d(in_channels=feature_dim, out_channels=num_classes, kernel_size=1, stride=1, padding=0)
|
| 219 |
+
|
| 220 |
+
def forward(self, x, batch_size):
|
| 221 |
+
x1 = self.conv(x)
|
| 222 |
+
x1 = ImageFolding(image_size=self.config.image_size[0], patch_size=self.config.patch_size, batch_size=batch_size)(x1)
|
| 223 |
+
return x1
|
| 224 |
+
|
| 225 |
+
|
| 226 |
class VesselSegmentModel(PreTrainedModel):
|
| 227 |
config_class = VesselSegmentConfig
|
| 228 |
def __init__(self, config: VesselSegmentConfig=VesselSegmentConfig()):
|
|
|
|
| 248 |
self.decoder_layer_3 = Decoder(tensor_dim_decoder=config.features[-1], tensor_dim_encoder=config.features[-2], tensor_dim_mid=config.features[2], up_conv_in_ch=config.features[-1], up_conv_out_ch=config.features[-2], up_conv_scale=2, dconv_in_feature=config.features[-1], dconv_out_feature=config.features[-2], is_concat=True)
|
| 249 |
|
| 250 |
# Segmentation Head
|
| 251 |
+
self.segmenation_head = SegmentationHead(feature_dim=config.features[-3], num_classes=config.num_classes)
|
| 252 |
+
# self.segmenation_head = nn.Sequential(
|
| 253 |
+
# nn.Conv2d(in_channels=config.features[-3], out_channels=config.num_classes, kernel_size=1, padding=0, stride=1),
|
| 254 |
+
# ImageFolding(image_size=config.image_size[0], patch_size=config.patch_size, batch_size=config.batch_size)
|
| 255 |
+
# )
|
| 256 |
+
|
| 257 |
def forward(self, x):
|
| 258 |
+
B,C,H,W = x.shape
|
| 259 |
IMG_1 = self.img_patch(x)
|
| 260 |
IMG_2 = self.img_down_sampling_1(IMG_1)
|
| 261 |
IMG_3 = self.img_down_sampling_2(IMG_2)
|
|
|
|
| 274 |
d3 = self.decoder_layer_3(sk1, d2)
|
| 275 |
|
| 276 |
# head
|
| 277 |
+
head = self.segmenation_head(d3, B)
|
| 278 |
|
| 279 |
return head
|