"""
paper: https://arxiv.org/abs/2004.08790
ref: https://github.com/ZJUGiveLab/UNet-Version/blob/master/models/UNet_3Plus.py
"""
import torch
from torch import nn
# NOTE(review): `from torch.functional import F` works only because
# torch/functional.py happens to import torch.nn.functional as F;
# the canonical form is `import torch.nn.functional as F` — confirm before changing.
from torch.functional import F
class UNetConv(nn.Module):
    """Stack of ``num_layers`` (Conv1d -> [BatchNorm1d] -> ReLU) stages.

    The first stage maps ``in_size`` -> ``out_size`` channels; every later
    stage keeps ``out_size`` channels. Stages are registered as attributes
    named ``conv0``, ``conv1``, ... so checkpoint keys stay stable.
    """

    def __init__(
        self,
        in_size,
        out_size,
        is_batchnorm=True,
        num_layers=2,
        kernel_size=3,
        stride=1,
        padding=1,
    ):
        super().__init__()
        self.num_layers = num_layers
        channels = in_size
        for idx in range(num_layers):
            stage = [nn.Conv1d(channels, out_size, kernel_size, stride, padding)]
            if is_batchnorm:
                stage.append(nn.BatchNorm1d(out_size))
            stage.append(nn.ReLU())
            # register as conv0, conv1, ... (keeps state-dict keys stable)
            setattr(self, f"conv{idx}", nn.Sequential(*stage))
            channels = out_size

    def forward(self, inputs):
        out = inputs
        for idx in range(self.num_layers):
            out = getattr(self, f"conv{idx}")(out)
        return out
class UNet3PlusDeepSup(nn.Module):
    """1-D UNet3+ with deep supervision and optional classification guidance.

    paper: https://arxiv.org/abs/2004.08790
    ref: https://github.com/ZJUGiveLab/UNet-Version/blob/master/models/UNet_3Plus.py
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        inplanes = int(config.inplanes)
        kernel_size = int(config.kernel_size)
        # "same" padding for odd kernel sizes
        padding = (kernel_size - 1) // 2
        num_encoder_layers = int(config.num_encoder_layers)
        encoder_batchnorm = bool(config.encoder_batchnorm)
        self.num_depths = int(config.num_depths)
        self.interpolate_mode = str(config.interpolate_mode)
        dropout = float(config.dropout)
        self.use_cgm = bool(config.use_cgm)
        # sum_of_sup == True: elementwise-sum every supervision output into one
        # dense map and compute a single loss against the label.
        # sum_of_sup == False: return each supervision output so a loss can be
        # computed per output.
        self.sum_of_sup = bool(config.sum_of_sup)
        # set by TrialSetup._init_network_params
        self.output_size: int = config.output_size

        # Encoder: each depth is (MaxPool -> UNetConv); the first depth is
        # UNetConv only (no pooling before the first stage).
        self.encoders = torch.nn.ModuleList()
        for i in range(self.num_depths):
            _encoders = []
            if i != 0:
                _encoders.append(nn.MaxPool1d(2))
            _encoders.append(
                UNetConv(
                    1 if i == 0 else (inplanes * (2 ** (i - 1))),
                    inplanes * (2**i),
                    is_batchnorm=encoder_batchnorm,
                    num_layers=num_encoder_layers,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=padding,
                )
            )
            self.encoders.append(nn.Sequential(*_encoders))

        # CGM: Classification-Guided Module
        if self.use_cgm:
            self.cls = nn.Sequential(
                nn.Dropout(dropout),
                nn.Conv1d(
                    inplanes * (2 ** (self.num_depths - 1)), 2 * self.output_size, 1
                ),
                nn.AdaptiveMaxPool1d(1),
                nn.Sigmoid(),
            )

        # Decoder
        self.up_channels = inplanes * self.num_depths
        self.decoders = torch.nn.ModuleList()
        for i in reversed(range(self.num_depths - 1)):
            # Each decoder stage holds num_depths branches: per branch the
            # source feature is MaxPool-ed (shallower encoder levels), used
            # as-is, or upsampled (deeper levels), then passed through
            # Conv/BN/ReLU; the branches are concatenated and fused by one
            # final Conv/BN/ReLU. The upsampling itself is done in forward()
            # via F.interpolate so the target length can be taken from the
            # matching encoder output.
            _decoders = torch.nn.ModuleList()
            for j in range(self.num_depths):
                _each_decoders = []
                if j < i:
                    _each_decoders.append(nn.MaxPool1d(2 ** (i - j), ceil_mode=True))
                if i < j < self.num_depths - 1:
                    # deeper DECODER outputs carry up_channels channels
                    _each_decoders.append(
                        nn.Conv1d(
                            inplanes * self.num_depths,
                            inplanes,
                            kernel_size,
                            padding=padding,
                        )
                    )
                else:
                    # ENCODER output at depth j carries inplanes * 2**j channels
                    _each_decoders.append(
                        nn.Conv1d(
                            inplanes * (2**j), inplanes, kernel_size, padding=padding
                        )
                    )
                _each_decoders.append(nn.BatchNorm1d(inplanes))
                _each_decoders.append(nn.ReLU())
                _decoders.append(nn.Sequential(*_each_decoders))
            _decoders.append(
                nn.Sequential(
                    nn.Conv1d(
                        self.up_channels, self.up_channels, kernel_size, padding=padding
                    ),
                    nn.BatchNorm1d(self.up_channels),
                    nn.ReLU(),
                )
            )
            self.decoders.append(_decoders)

        # Supervision heads: the first num_depths - 1 take decoder outputs
        # (in channels == up_channels, i.e. inplanes * num_depths — 320 in the
        # original paper); the last one matches the deepest encoder output.
        self.sup_conv = torch.nn.ModuleList()
        for i in range(self.num_depths - 1):
            self.sup_conv.append(
                nn.Sequential(
                    nn.Conv1d(
                        self.up_channels, self.output_size, kernel_size, padding=padding
                    ),
                    nn.BatchNorm1d(self.output_size),
                    nn.ReLU(),
                )
            )
        self.sup_conv.append(
            nn.Sequential(
                nn.Conv1d(
                    inplanes * (2 ** (self.num_depths - 1)),
                    self.output_size,
                    kernel_size,
                    padding=padding,
                ),
                nn.BatchNorm1d(self.output_size),
                nn.ReLU(),
            )
        )

    def forward(self, input: torch.Tensor, y=None):
        # Encoder
        output = input
        enc_features = []  # X1E, X2E, ..., X{num_depths}E
        for encoder in self.encoders:
            output = encoder(output)
            enc_features.append(output)
        # Decoder feature stack: must start with ONLY the deepest encoder
        # output and grow to [X{nd}E, X{nd-1}D, ..., X1D].
        # Bug fix: previously every encoder output was appended here, which
        # broke the dec_features[num_depths - j - 1] indexing below and fed
        # tensors with the wrong channel counts into the decoder and
        # supervision convolutions.
        dec_features = [enc_features[-1]]
        # CGM
        cls_branch_max = None
        if self.use_cgm:
            # (B, 2 * output_size, 1)
            cls_branch: torch.Tensor = self.cls(enc_features[-1])
            # (B, output_size). Bug fix: argmax yields int64, but torch.einsum
            # cannot mix integer and floating operands — cast back to the
            # feature dtype before the einsum below.
            cls_branch_max = (
                cls_branch.view(input.shape[0], self.output_size, 2)
                .argmax(2)
                .to(cls_branch.dtype)
            )
        # Decoder
        for i in reversed(range(self.num_depths - 1)):
            _each_dec_feature = []
            for j in range(self.num_depths):
                if j <= i:
                    # encoder feature at this or a shallower depth
                    # (shallower ones get MaxPool-ed inside the branch)
                    _each_enc = enc_features[j]
                else:
                    # deeper feature: upsample to this depth's sequence length
                    _each_enc = F.interpolate(
                        dec_features[self.num_depths - j - 1],
                        enc_features[i].shape[2],
                        mode=self.interpolate_mode,
                    )
                _each_dec_feature.append(
                    self.decoders[self.num_depths - i - 2][j](_each_enc)
                )
            dec_features.append(
                self.decoders[self.num_depths - i - 2][-1](
                    torch.cat(_each_dec_feature, dim=1)
                )
            )
        # Deep supervision: bring every map to the input length; the last
        # entry (shallowest decoder output) is already full length.
        sup = []
        for i, (dec_feature, sup_conv) in enumerate(
            zip(dec_features, reversed(self.sup_conv))
        ):
            if i < self.num_depths - 1:
                sup.append(
                    F.interpolate(
                        sup_conv(dec_feature),
                        input.shape[2],
                        mode=self.interpolate_mode,
                    )
                )
            else:
                sup.append(sup_conv(dec_feature))
        if self.use_cgm:
            if self.sum_of_sup:
                return torch.sigmoid(
                    sum(
                        torch.einsum("ijk,ij->ijk", [_sup, cls_branch_max])
                        for _sup in reversed(sup)
                    )
                )
            else:
                # Bug fix: the generator expression used to sit INSIDE
                # torch.sigmoid(...), which fails at runtime; apply sigmoid
                # per supervision map instead.
                return [
                    torch.sigmoid(
                        torch.einsum("ijk,ij->ijk", [_sup, cls_branch_max])
                    )
                    for _sup in reversed(sup)
                ]
        else:
            if self.sum_of_sup:
                return torch.sigmoid(sum(sup))
            else:
                return [torch.sigmoid(_sup) for _sup in reversed(sup)]