Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from typing import Tuple, Union | |
| from monai.networks.blocks.dynunet_block import UnetOutBlock | |
| from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock | |
def build_sam_decoder_vit_h():
    """Build a SAM decoder sized for the ViT-H image encoder (embed dim 1280, 16 heads)."""
    return _build_sam_decoder(encoder_embed_dim=1280, encoder_num_heads=16)
def build_sam_decoder_vit_l():
    """Build a SAM decoder sized for the ViT-L image encoder (embed dim 1024, 16 heads)."""
    return _build_sam_decoder(encoder_embed_dim=1024, encoder_num_heads=16)
def build_sam_decoder_vit_b():
    """Build a SAM decoder sized for the ViT-B image encoder (embed dim 768, 12 heads)."""
    return _build_sam_decoder(encoder_embed_dim=768, encoder_num_heads=12)
# Registry mapping SAM backbone variant names to decoder builder functions.
# "default" aliases the largest (ViT-H) decoder.
sam_decoder_reg = dict(
    default=build_sam_decoder_vit_h,
    vit_h=build_sam_decoder_vit_h,
    vit_l=build_sam_decoder_vit_l,
    vit_b=build_sam_decoder_vit_b,
)
def _build_sam_decoder(
    encoder_embed_dim,
    encoder_num_heads,
    image_size=1024,
    vit_patch_size=16,
):
    """Construct an ``ImageDecoderViT`` matched to a SAM ViT image encoder.

    Args:
        encoder_embed_dim: hidden size of the encoder's token embeddings.
        encoder_num_heads: number of attention heads in the encoder
            (must divide ``encoder_embed_dim``).
        image_size: side length of the square encoder input image.
            Defaults to 1024, the resolution SAM encoders use.
        vit_patch_size: ViT patch size. Defaults to 16 (SAM's patch size).

    Returns:
        A new ``ImageDecoderViT`` configured for the given encoder geometry.
    """
    return ImageDecoderViT(
        hidden_size=encoder_embed_dim,
        img_size=image_size,
        num_heads=encoder_num_heads,
        patch_size=vit_patch_size,
    )
class ImageDecoderViT(nn.Module):
    """UNETR-style convolutional decoder over ViT encoder hidden states.

    The decoder consumes the raw input image, four intermediate ViT token
    tensors (shallow to deep), and an optional low-resolution mask, and fuses
    them through a cascade of upsampling blocks into an ``out_channels``-wide
    prediction at the input image resolution.
    """

    def __init__(
        self,
        in_channels: int = 3,
        feature_size: int = 64,
        hidden_size: int = 1280,
        conv_block: bool = True,
        res_block: bool = True,
        norm_name: Union[Tuple, str] = "instance",
        dropout_rate: float = 0.0,
        spatial_dims: int = 2,
        img_size: int = 1024,
        patch_size: int = 16,
        out_channels: int = 1,
        num_heads: int = 12,
    ) -> None:
        """
        Args:
            in_channels: channels of the raw input image.
            feature_size: base channel width of the decoder blocks.
            hidden_size: embedding dimension of the ViT encoder tokens.
            conv_block: forwarded to the ``UnetrPrUpBlock`` stages.
            res_block: use residual variants of the convolution blocks.
            norm_name: normalization layer name/config for all blocks.
            dropout_rate: validated only; not forwarded to any sub-block.
            spatial_dims: 2 for images (default), 3 for volumes.
            img_size: side length of the (square) encoder input image.
            patch_size: ViT patch size; the token grid side is
                ``img_size // patch_size``.
            out_channels: channels of the output map (and of ``low_res_mask``).
            num_heads: encoder attention heads; must divide ``hidden_size``.

        Raises:
            AssertionError: if ``dropout_rate`` is outside [0, 1] or
                ``hidden_size`` is not divisible by ``num_heads``.
        """
        super().__init__()
        # NOTE(review): dropout_rate and num_heads are only validated here;
        # neither value is used by any decoder sub-block below.
        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")
        if hidden_size % num_heads != 0:
            raise AssertionError("hidden size should be divisible by num_heads.")
        self.patch_size = patch_size
        # Spatial size of the ViT token grid: (img_size/p, img_size/p).
        self.feat_size = (
            img_size // self.patch_size,
            img_size // self.patch_size,
        )
        self.hidden_size = hidden_size
        self.classification = False
        # Embeds the optional low-resolution mask up to feature_size * 4
        # channels so it can be concatenated with the matching decoder stage.
        self.encoder_low_res_mask = nn.Sequential(
            UnetrBasicBlock(
                spatial_dims=spatial_dims,
                in_channels=out_channels,
                out_channels=feature_size,
                kernel_size=3,
                stride=1,
                norm_name=norm_name,
                res_block=res_block,
            ),
            UnetrBasicBlock(
                spatial_dims=spatial_dims,
                in_channels=feature_size,
                out_channels=feature_size * 4,
                kernel_size=3,
                stride=1,
                norm_name=norm_name,
                res_block=res_block,
            ),
        )
        # Reduces the concatenated (decoder features + mask features) tensor
        # from feature_size * 8 back to feature_size * 4 channels.
        self.decoder_fuse = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )
        # Skip path 1: raw image -> feature_size channels at full resolution.
        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )
        # Skip paths 2-4: progressively upsample ViT token maps (2, 1, 0
        # upsampling layers respectively) to match decoder stage resolutions.
        self.encoder2 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 2,
            num_layer=2,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder3 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 4,
            num_layer=1,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder4 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        # Decoder cascade: each stage upsamples 2x and merges one skip path.
        self.decoder5 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        # Final 1x1 projection to the requested number of output channels.
        self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)
        # Precomputed view/permute parameters for proj_feat:
        # (B, N, hidden) -> view (B, *feat_size, hidden) -> channels-first.
        self.proj_axes = (0, spatial_dims + 1) + tuple(d + 1 for d in range(spatial_dims))
        self.proj_view_shape = list(self.feat_size) + [self.hidden_size]

    def proj_feat(self, x):
        """Reshape ViT tokens (B, N, hidden) into a channels-first feature map.

        N must equal the product of ``self.feat_size``; the result has shape
        (B, hidden, *feat_size).
        """
        new_view = [x.size(0)] + self.proj_view_shape
        x = x.view(new_view)
        x = x.permute(self.proj_axes).contiguous()
        return x

    def forward(self, x_img, hidden_states_out, low_res_mask):
        """Decode ViT hidden states into a full-resolution output map.

        Args:
            x_img: raw input image batch.
            hidden_states_out: sequence of at least four ViT token tensors,
                shallow to deep; each is reshaped via ``proj_feat``.
            low_res_mask: optional mask with ``out_channels`` channels fused
                at the feature_size*4 decoder stage, or ``None`` to skip
                fusion. NOTE(review): its spatial size must match dec2's
                after encoding — confirm with the caller.

        Returns:
            Output tensor with ``out_channels`` channels at x_img resolution.
        """
        enc1 = self.encoder1(x_img)
        x2 = hidden_states_out[0]
        enc2 = self.encoder2(self.proj_feat(x2))
        x3 = hidden_states_out[1]
        enc3 = self.encoder3(self.proj_feat(x3))
        x4 = hidden_states_out[2]
        enc4 = self.encoder4(self.proj_feat(x4))
        dec4 = self.proj_feat(hidden_states_out[3])
        dec3 = self.decoder5(dec4, enc4)
        dec2 = self.decoder4(dec3, enc3)
        # Fix: identity comparison with None (was `low_res_mask != None`,
        # which invokes __ne__ and is non-idiomatic / unsafe for tensors).
        if low_res_mask is not None:
            enc_mask = self.encoder_low_res_mask(low_res_mask)
            fused_dec2 = torch.cat([dec2, enc_mask], dim=1)
            fused_dec2 = self.decoder_fuse(fused_dec2)
            dec1 = self.decoder3(fused_dec2, enc2)
        else:
            dec1 = self.decoder3(dec2, enc2)
        out = self.decoder2(dec1, enc1)
        return self.out(out)