Spaces:
Runtime error
Runtime error
| from functools import partial | |
| import torch | |
| from timm.models.efficientnet import tf_efficientnet_b3_ns, tf_efficientnet_b5_ns | |
| from torch import nn | |
| from torch.nn import Dropout2d, Conv2d | |
| from torch.nn.modules.dropout import Dropout | |
| from torch.nn.modules.linear import Linear | |
| from torch.nn.modules.pooling import AdaptiveAvgPool2d | |
| from torch.nn.modules.upsampling import UpsamplingBilinear2d | |
# Per-backbone configuration, keyed by the timm model name.
#   features        - in_features of the classifier's Linear head
#   filters         - channel counts of the five encoder feature maps, as
#                     consumed by Decoder (last entry feeds the deepest stage)
#   decoder_filters - output channels of the four decoder stages
#   init_op         - zero-argument factory that builds the timm backbone with
#                     pretrained=True and drop_path_rate=0.2
encoder_params = {
    "tf_efficientnet_b3_ns": dict(
        features=1536,
        filters=[40, 32, 48, 136, 1536],
        decoder_filters=[64, 128, 256, 256],
        init_op=partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2),
    ),
    "tf_efficientnet_b5_ns": dict(
        features=2048,
        filters=[48, 40, 64, 176, 2048],
        decoder_filters=[64, 128, 256, 256],
        init_op=partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2),
    ),
}
class DecoderBlock(nn.Module):
    """Decoder stage: 2x upsampling, then a 3x3 convolution and ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        activation = nn.ReLU(inplace=True)
        # Attribute name and layer order are kept so that state_dict keys
        # ("layer.1.weight", "layer.1.bias") stay the same.
        self.layer = nn.Sequential(upsample, conv, activation)

    def forward(self, x):
        """Return ``x`` passed through upsample -> conv -> ReLU."""
        upsampled = self.layer(x)
        return upsampled
class ConcatBottleneck(nn.Module):
    """Fuse a decoder feature map with an encoder skip connection.

    The two inputs are concatenated along the channel axis (dim 1) and passed
    through a 3x3 convolution followed by ReLU.

    Args:
        in_channels: total channels after concatenation.
        out_channels: channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.seq = nn.Sequential(conv, nn.ReLU(inplace=True))

    def forward(self, dec, enc):
        """Return conv+ReLU applied to ``dec`` and ``enc`` stacked on dim 1."""
        merged = torch.cat((dec, enc), dim=1)
        return self.seq(merged)
class Decoder(nn.Module):
    """U-Net style decoder built from upsampling stages and skip bottlenecks.

    Args:
        decoder_filters: output channels of each decoder stage; index 0 is the
            last stage applied in ``forward``.
        filters: channel counts of the encoder feature maps; the last entry is
            the input width of the first stage applied in ``forward``.
        upsample_filters: if truthy, the final block is a ``decoder_block``
            producing this many channels; otherwise a bilinear 2x upsampling.
        decoder_block: factory ``(in_channels, out_channels) -> nn.Module``
            used for every decoder stage.
        bottleneck: factory ``(in_channels, out_channels) -> nn.Module`` whose
            module is called with ``(decoder_tensor, encoder_tensor)``.
        dropout: ``Dropout2d`` probability applied to the output; no dropout
            module is created unless it is greater than 0.
    """

    def __init__(self, decoder_filters, filters, upsample_filters=None,
                 decoder_block=DecoderBlock, bottleneck=ConcatBottleneck, dropout=0):
        super().__init__()
        self.decoder_filters = decoder_filters
        self.filters = filters
        self.decoder_block = decoder_block

        # One upsampling stage per decoder width, indexed like decoder_filters.
        stage_count = len(decoder_filters)
        self.decoder_stages = nn.ModuleList(
            [self._get_decoder(level) for level in range(stage_count)]
        )

        # Bottlenecks are stored in application order (first applied first):
        # each one sees the stage output plus the matching encoder skip.
        self.bottlenecks = nn.ModuleList(
            [
                bottleneck(self.filters[-step - 2] + width, width)
                for step, width in enumerate(reversed(decoder_filters))
            ]
        )

        if dropout > 0:
            self.dropout = Dropout2d(dropout)
        else:
            self.dropout = None

        if upsample_filters:
            final_stage = decoder_block(decoder_filters[0], out_channels=upsample_filters)
        else:
            final_stage = UpsamplingBilinear2d(scale_factor=2)
        self.last_block = final_stage

    def forward(self, encoder_results: list):
        """Decode a list of encoder feature maps into a single tensor.

        ``encoder_results[0]`` is the tensor decoding starts from;
        ``encoder_results[k]`` is the skip connection fused at step ``k``.
        """
        x = encoder_results[0]
        for depth, fuse in enumerate(self.bottlenecks, start=1):
            # Stages are stored in decoder_filters order, so walk them from
            # the end while the skips are walked from the front.
            x = self.decoder_stages[-depth](x)
            x = fuse(x, encoder_results[depth])

        if self.last_block:
            x = self.last_block(x)

        if self.dropout:
            x = self.dropout(x)

        return x

    def _get_decoder(self, layer):
        """Build the decoder stage whose output width is ``decoder_filters[layer]``."""
        following = layer + 1
        widths = self.decoder_filters
        # The stage past the end of decoder_filters has no previous decoder
        # stage, so its input width is read from the encoder filters instead.
        source = self.filters if following == len(widths) else widths
        in_channels = source[following]
        return self.decoder_block(in_channels, widths[max(layer, 0)])
def _initialize_weights(module):
    """Re-initialise conv, linear and batch-norm layers found under ``module``.

    ``Conv2d``, ``ConvTranspose2d`` and ``Linear`` weights are drawn with
    ``nn.init.kaiming_normal_`` (default arguments) and their biases, when
    present, are set to zero. ``BatchNorm2d`` layers get weight 1 and bias 0.

    Args:
        module: root ``nn.Module``; every sub-module returned by
            ``module.modules()`` is visited, and matching layers are modified
            in place.
    """
    for m in module.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            # BatchNorm2d(affine=False) has weight and bias set to None;
            # skip them instead of failing on ``None.data``.
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
class EfficientUnetClassifier(nn.Module):
    """EfficientNet encoder with a U-Net style decoder and a classifier head.

    ``forward`` returns ``(logit, seg)``: ``logit`` is the output of a single
    unit linear layer applied to the globally pooled deepest encoder feature
    map, and ``seg`` is a one-channel map produced by a 1x1 convolution over
    the decoder output.

    Args:
        encoder: key into ``encoder_params`` selecting the backbone.
        dropout_rate: dropout probability applied before the linear head.

    Raises:
        KeyError: if ``encoder`` is not a key of ``encoder_params``.
    """

    def __init__(self, encoder, dropout_rate=0.5) -> None:
        super().__init__()
        params = encoder_params[encoder]
        self.decoder = Decoder(decoder_filters=params["decoder_filters"],
                               filters=params["filters"])
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(params["features"], 1)
        self.final = Conv2d(params["decoder_filters"][0], out_channels=1, kernel_size=1, bias=False)
        # Initialise the layers created above *before* the encoder is
        # attached, so the encoder's own weights are not re-initialised.
        _initialize_weights(self)
        self.encoder = params["init_op"]()

    @staticmethod
    def _apply_if_present(backbone, name, x):
        """Apply ``backbone.<name>`` to ``x`` when that attribute exists."""
        layer = getattr(backbone, name, None)
        return x if layer is None else layer(x)

    def get_encoder_features(self, x):
        """Run the encoder and collect five feature maps, deepest first.

        The maps are taken after the stem, after blocks ``[:2]``, ``[2:3]``,
        ``[3:5]`` and after the head (blocks ``[5:]`` + ``conv_head`` +
        ``bn2``), then reversed so index 0 is the head output.
        """
        backbone = self.encoder
        encoder_results = []

        x = backbone.conv_stem(x)
        x = backbone.bn1(x)
        # NOTE(review): some timm releases fuse the activation into bn1/bn2
        # and expose no separate act1/act2; apply them only when they exist.
        # Confirm against the timm version pinned for this project.
        x = self._apply_if_present(backbone, "act1", x)
        encoder_results.append(x)

        x = backbone.blocks[:2](x)
        encoder_results.append(x)
        x = backbone.blocks[2:3](x)
        encoder_results.append(x)
        x = backbone.blocks[3:5](x)
        encoder_results.append(x)

        x = backbone.blocks[5:](x)
        x = backbone.conv_head(x)
        x = backbone.bn2(x)
        x = self._apply_if_present(backbone, "act2", x)
        encoder_results.append(x)

        encoder_results.reverse()
        return encoder_results

    def forward(self, x):
        """Return ``(logit, seg)`` for the input batch ``x``."""
        encoder_results = self.get_encoder_features(x)
        seg = self.final(self.decoder(encoder_results))
        # The classifier head works on the deepest encoder feature map.
        x = encoder_results[0]
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        x = self.fc(x)
        return x, seg
if __name__ == '__main__':
    # Smoke test: build the B5 variant and run one random batch through it
    # in eval mode without tracking gradients.
    net = EfficientUnetClassifier("tf_efficientnet_b5_ns")
    net.eval()
    batch = torch.rand(4, 3, 224, 224)
    with torch.no_grad():
        outputs = net(batch)
    print(outputs)