| import random |
| from typing import Tuple, Union, List |
|
|
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch |
| from einops import rearrange, repeat, reduce |
| from positional_encodings.torch_encodings import PositionalEncoding3D |
| from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet |
| from dynamic_network_architectures.building_blocks.simple_conv_blocks import StackedConvBlocks |
| from dynamic_network_architectures.initialization.weight_init import InitWeights_He |
|
|
| from .transformer_decoder import TransformerDecoder,TransformerDecoderLayer |
| from .SwinUNETR import SwinUNETR |
|
|
class Maskformer(nn.Module):
    """MaskFormer-style 3D segmentation model.

    A visual backbone (UNet variants or SwinUNETR) produces multi-scale encoder
    features and per-pixel decoder embeddings. The encoder features are average-
    pooled to the bottleneck resolution, concatenated channel-wise, projected to
    the query dimension, and attended over by a transformer decoder driven by
    external queries (e.g. text-derived label embeddings). Segmentation logits
    are the inner product between per-pixel embeddings and the resulting mask
    embeddings.
    """

    def __init__(self, vision_backbone='UNET', input_channels=1, image_size=(288, 288, 96), patch_size=(32, 32, 32), deep_supervision=False):
        """
        Args:
            vision_backbone (str, optional): visual backbone, one of
                'UNET', 'UNET-L', 'UNET-H', 'UNET-Res', 'SwinUNETR'. Defaults to 'UNET'.
            input_channels (int, optional): input channels for the 'UNET' and
                'UNET-Res' configurations ('UNET-L'/'UNET-H'/'SwinUNETR' are
                3-channel presets). Defaults to 1.
            image_size (sequence, optional): (H, W, D) input size. Defaults to (288, 288, 96).
            patch_size (sequence, optional): maximum downsample ratio of the
                bottleneck feature map. Defaults to (32, 32, 32).
            deep_supervision (bool, optional): also produce seg results from mid
                layers of the decoder. Defaults to False.
        """
        super().__init__()
        image_height, image_width, frames = image_size
        self.hw_patch_size = patch_size[0]
        self.frame_patch_size = patch_size[-1]

        self.deep_supervision = deep_supervision

        # Lazy factories: only the selected backbone is ever constructed.
        # (A plain dict literal would eagerly build ALL five networks and
        # discard four of them — a large waste of time and memory.)
        # An unknown key still raises KeyError, as before.
        backbone_factories = {
            'SwinUNETR': lambda: SwinUNETR(
                img_size=image_size,  # follow the argument (was hard-coded [288, 288, 96])
                in_channels=3,
                feature_size=48,
                drop_rate=0.0,
                attn_drop_rate=0.0,
                dropout_path_rate=0.0,
                use_checkpoint=False,
            ),
            'UNET': lambda: PlainConvUNet(input_channels=input_channels,
                                          n_stages=6,
                                          features_per_stage=(64, 64, 128, 256, 512, 768),
                                          conv_op=nn.Conv3d,
                                          kernel_sizes=3,
                                          strides=(1, 2, 2, 2, 2, 2),
                                          n_conv_per_stage=(2, 2, 2, 2, 2, 2),
                                          n_conv_per_stage_decoder=(2, 2, 2, 2, 2),
                                          conv_bias=True,
                                          norm_op=nn.InstanceNorm3d,
                                          norm_op_kwargs={'eps': 1e-5, 'affine': True},
                                          dropout_op=None,
                                          dropout_op_kwargs=None,
                                          nonlin=nn.LeakyReLU,
                                          nonlin_kwargs=None,
                                          deep_supervision=deep_supervision,
                                          nonlin_first=False
                                          ),
            # NOTE(review): the larger presets hard-code 3 input channels
            # (presumably image + coarse-pred channels) — confirm against callers.
            'UNET-L': lambda: PlainConvUNet(input_channels=3,
                                            n_stages=6,
                                            features_per_stage=(128, 128, 256, 512, 1024, 1536),
                                            conv_op=nn.Conv3d,
                                            kernel_sizes=3,
                                            strides=(1, 2, 2, 2, 2, 2),
                                            n_conv_per_stage=(3, 3, 3, 3, 3, 3),
                                            n_conv_per_stage_decoder=(3, 3, 3, 3, 3),
                                            conv_bias=True,
                                            norm_op=nn.InstanceNorm3d,
                                            norm_op_kwargs={'eps': 1e-5, 'affine': True},
                                            dropout_op=None,
                                            dropout_op_kwargs=None,
                                            nonlin=nn.LeakyReLU,
                                            nonlin_kwargs=None,
                                            deep_supervision=deep_supervision,
                                            nonlin_first=False
                                            ),
            'UNET-H': lambda: PlainConvUNet(input_channels=3,
                                            n_stages=6,
                                            features_per_stage=(256, 256, 512, 1024, 1536, 2048),
                                            conv_op=nn.Conv3d,
                                            kernel_sizes=3,
                                            strides=(1, 2, 2, 2, 2, 2),
                                            n_conv_per_stage=(3, 3, 3, 3, 3, 3),
                                            n_conv_per_stage_decoder=(3, 3, 3, 3, 3),
                                            conv_bias=True,
                                            norm_op=nn.InstanceNorm3d,
                                            norm_op_kwargs={'eps': 1e-5, 'affine': True},
                                            dropout_op=None,
                                            dropout_op_kwargs=None,
                                            nonlin=nn.LeakyReLU,
                                            nonlin_kwargs=None,
                                            deep_supervision=deep_supervision,
                                            nonlin_first=False
                                            ),
            'UNET-Res': lambda: ResidualEncoderUNet(
                input_channels=input_channels,
                n_stages=6,
                features_per_stage=[32, 64, 128, 256, 320, 320],
                conv_op=nn.Conv3d,
                kernel_sizes=[[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]],
                strides=[[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
                n_blocks_per_stage=[1, 3, 4, 6, 6, 6],
                n_conv_per_stage_decoder=[1, 1, 1, 1, 1],
                conv_bias=True,
                norm_op=nn.InstanceNorm3d,
                norm_op_kwargs={"eps": 1e-5, "affine": True},
                nonlin=nn.LeakyReLU,
                nonlin_kwargs={"inplace": True},
                deep_supervision=deep_supervision,
            ),
        }
        self.backbone = backbone_factories[vision_backbone]()

        self.backbone.apply(InitWeights_He(1e-2))

        # Dimension of the transformer-decoder token space, tied to the
        # bottleneck width of the chosen backbone.
        if vision_backbone == 'UNET-H':
            query_dim = 1536
        elif vision_backbone == 'UNET-Res':
            query_dim = 320
        else:
            query_dim = 768

        # Pool each encoder stage down to the bottleneck resolution so the
        # multi-scale features can be concatenated along the channel axis.
        # nn.ModuleList (not a plain python list) so the poolers are registered
        # as submodules, consistent with self.mid_mask_embed_proj below.
        self.avg_pool_ls = nn.ModuleList([
            nn.AvgPool3d(32, 32),
            nn.AvgPool3d(16, 16),
            nn.AvgPool3d(8, 8),
            nn.AvgPool3d(4, 4),
            nn.AvgPool3d(2, 2),
        ])

        # Project the concatenated multi-scale channels down to query_dim.
        # The input width is the sum of the backbone's per-stage channel counts.
        # Lazy factories again: build only the Sequential we actually keep.
        projection_factories = {
            'SwinUNETR': lambda: nn.Sequential(
                nn.Linear(1536, 768),
                nn.GELU(),
                nn.Linear(768, query_dim),
                nn.GELU()
            ),
            'UNET': lambda: nn.Sequential(
                nn.Linear(1792, 768),
                nn.GELU(),
                nn.Linear(768, query_dim),
                nn.GELU()
            ),
            'UNET-L': lambda: nn.Sequential(
                nn.Linear(3584, 1536),
                nn.GELU(),
                nn.Linear(1536, query_dim),
                nn.GELU()
            ),
            'UNET-H': lambda: nn.Sequential(
                nn.Linear(5632, 3072),
                nn.GELU(),
                nn.Linear(3072, query_dim),
                nn.GELU()
            ),
            'UNET-Res': lambda: nn.Sequential(
                nn.Linear(1120, 320),
                nn.GELU(),
                nn.Linear(320, query_dim),
                nn.GELU()
            ),
        }
        self.projection_layer = projection_factories[vision_backbone]()

        # Fixed 3D sinusoidal position encoding for the bottleneck token grid.
        # Registered as a NON-persistent buffer: it follows .to()/.cuda() with
        # the module but is not written into checkpoints, so state_dict keys
        # match the original plain-attribute implementation.
        pos_embedding = PositionalEncoding3D(query_dim)(torch.zeros(1, (image_height//self.hw_patch_size), (image_width//self.hw_patch_size), (frames//self.frame_patch_size), query_dim))
        self.register_buffer('pos_embedding', rearrange(pos_embedding, 'b h w d c -> (h w d) b c'), persistent=False)

        decoder_layer = TransformerDecoderLayer(d_model=query_dim, nhead=8, normalize_before=True)
        decoder_norm = nn.LayerNorm(query_dim)
        self.transformer_decoder = TransformerDecoder(decoder_layer=decoder_layer, num_layers=6, norm=decoder_norm)

        # Incoming queries are 768-d; adapt them when the backbone's token
        # space differs.
        if query_dim != 768:
            self.query_proj = nn.Sequential(
                nn.Linear(768, query_dim),
                nn.GELU(),
                nn.Linear(query_dim, query_dim),
                nn.GELU()
            )
        else:
            self.query_proj = nn.Identity()

        if self.deep_supervision:
            # Per-pixel channel widths of the intermediate decoder stages ...
            feature_per_stage = {
                'SwinUNETR':[48, 96, 192],
                'UNET':[64, 128, 256],
                'UNET-L':[128, 256, 512],
                'UNET-H':[256, 512, 1024],
                'UNET-Res':[64, 128, 256]
            }[vision_backbone]
            # ... and the hidden widths of the matching mask-embedding MLPs.
            mid_dim = {
                'SwinUNETR':[256, 384, 512],
                'UNET':[256, 384, 512],
                'UNET-L':[384, 512, 512],
                'UNET-H':[768, 1024, 1024],
                'UNET-Res':[256, 320, 320]
            }[vision_backbone]
            self.mid_mask_embed_proj = []
            for hidden_dim, per_pixel_dim in zip(mid_dim, feature_per_stage):
                self.mid_mask_embed_proj.append(
                    nn.Sequential(
                        nn.Linear(query_dim, hidden_dim),
                        nn.GELU(),
                        nn.Linear(hidden_dim, per_pixel_dim),
                        nn.GELU(),
                    ),
                )
            self.mid_mask_embed_proj = nn.ModuleList(self.mid_mask_embed_proj)

        # Projection from query_dim down to the full-resolution per-pixel
        # channel width, for the final (highest-resolution) seg output.
        mid_dim, per_pixel_dim = {
            'SwinUNETR' : [256, 48],
            'UNET' : [256, 64],
            'UNET-L' : [384, 128],
            'UNET-H' : [768, 256],
            'UNET-Res' : [128, 32]
        }[vision_backbone]
        self.mask_embed_proj = nn.Sequential(
            nn.Linear(query_dim, mid_dim),
            nn.GELU(),
            nn.Linear(mid_dim, per_pixel_dim),
            nn.GELU(),
        )

        # Fuses [pixel_emb ; coarse-pred-weighted mask_emb] back to per_pixel_dim.
        self.fusion_conv = StackedConvBlocks(
            1, nn.Conv3d, 2 * per_pixel_dim, per_pixel_dim,
            3, 1, True, nn.InstanceNorm3d, {'eps': 1e-5, 'affine': True},
            None, None, nn.LeakyReLU, None, False)

    def enhance_with_coarse_pred(self, pixel_emb, mask_emb, coarse_pred):
        """
        Enhance pixel embeddings with coarse prediction information.

        Args:
            pixel_emb (torch.tensor): B,C,H,W,D per-pixel embeddings
            mask_emb (torch.tensor): B,N,C mask embeddings
            coarse_pred (torch.tensor): B,N,H,W,D coarse prediction probabilities

        Returns:
            torch.tensor: enhanced pixel embeddings B,C,H,W,D
        """
        # B,N,C -> B,C,N so the einsum contracts over the query axis N.
        mask_emb_transposed = mask_emb.permute(0, 2, 1)

        # At each voxel, mix the mask embeddings weighted by the coarse
        # prediction probabilities -> B,C,H,W,D.
        enhanced_emb = torch.einsum('bnhwd,bcn->bchwd', coarse_pred, mask_emb_transposed)

        # Concatenate along channels and fuse back down to C channels.
        combined = torch.cat([pixel_emb, enhanced_emb], dim=1)

        enhanced_pixel_emb = self.fusion_conv(combined)

        return enhanced_pixel_emb

    def vision_backbone_forward(self, image_input):
        """
        Visual backbone forward.

        Args:
            image_input (torch.tensor): B,C,H,W,D input volume

        Returns:
            image_embedding (torch.tensor): multiscale image features from encoder layers. N,B,d
            pos (torch.tensor): position encoding. N,B,d
            per_pixel_embedding_ls (List of torch.tensor): per-pixel embeddings from decoder layers. B,d,H,W,D
        """
        latent_embedding_ls, per_pixel_embedding_ls = self.backbone(image_input)

        # Pool every encoder stage down to the bottleneck resolution; the
        # bottleneck feature itself needs no pooling.
        image_embedding = []
        for latent_embedding, avg_pool in zip(latent_embedding_ls, self.avg_pool_ls):
            tmp = avg_pool(latent_embedding)
            image_embedding.append(tmp)
        image_embedding.append(latent_embedding_ls[-1])

        # Concatenate scales on channels, project to query_dim, flatten the
        # spatial grid into a token sequence (tokens-first for the decoder).
        image_embedding = torch.cat(image_embedding, dim=1)
        image_embedding = rearrange(image_embedding, 'b d h w depth -> b h w depth d')
        image_embedding = self.projection_layer(image_embedding)
        image_embedding = rearrange(image_embedding, 'b h w d dim -> (h w d) b dim')

        # pos_embedding is a non-persistent buffer and already lives on the
        # module's device; the .to() is kept as a cheap no-op safety net.
        pos = self.pos_embedding.to(latent_embedding_ls[-1].device)

        return image_embedding, pos, per_pixel_embedding_ls

    def infer_forward(self, q, image_embedding, pos, per_pixel_embedding_ls, simulated_lowres_mc_pred=None):
        """
        Infer a batch of queries on a batch of patches.

        Args:
            q (torch.tensor): N,d queries (shared across the patch batch)
            image_embedding (torch.tensor): N_tokens,B,d encoder tokens
            pos (torch.tensor): N_tokens,B,d position encoding
            per_pixel_embedding_ls (List of torch.tensor): B,d,H,W,D decoder features
            simulated_lowres_mc_pred (torch.tensor, optional): B,N,H,W,D low-res multi-channel prediction

        Returns:
            logits (torch.tensor): seg output of all queries. B,N,H,W,D
        """
        _, B, _ = image_embedding.shape

        # Broadcast the shared queries over the batch, then decode.
        N,_ = q.shape
        q = repeat(q, 'n dim -> n b dim', b=B)
        q = self.query_proj(q)
        mask_embedding,_ = self.transformer_decoder(q, image_embedding, pos = pos)
        mask_embedding = rearrange(mask_embedding, 'n b dim -> (b n) dim')

        # Project to the per-pixel channel width of the final decoder stage.
        mask_embedding = self.mask_embed_proj(mask_embedding)
        mask_embedding = rearrange(mask_embedding, '(b n) dim -> b n dim', b=B, n=N)
        per_pixel_embedding = per_pixel_embedding_ls[0]

        if simulated_lowres_mc_pred is not None:
            per_pixel_embedding = self.enhance_with_coarse_pred(
                per_pixel_embedding,
                mask_embedding,
                simulated_lowres_mc_pred)

        # Dot product of pixel and mask embeddings -> per-query logit maps.
        logits = torch.einsum('bchwd,bnc->bnhwd', per_pixel_embedding, mask_embedding)

        return logits

    def train_forward(self, queries, image_embedding, pos, per_pixel_embedding_ls, simulated_lowres_mc_pred=None):
        """
        Training forward: per-sample queries, optional deep supervision.

        Args:
            queries (torch.tensor): B,N,d
            image_embedding (torch.tensor): N_tokens,B,d encoder tokens
            pos (torch.tensor): N_tokens,B,d position encoding
            per_pixel_embedding_ls (List of torch.tensor): B,d,H,W,D decoder features
            simulated_lowres_mc_pred (torch.tensor, optional): B,N,H,W,D low-res multi-channel prediction

        Returns:
            logits (List of torch.tensor): list of seg results. B,N,H,W,D
        """
        _, B, _ = image_embedding.shape

        _, N, _ = queries.shape
        queries = rearrange(queries, 'b n dim -> n b dim')
        queries = self.query_proj(queries)
        mask_embedding,_ = self.transformer_decoder(queries, image_embedding, pos = pos)
        mask_embedding = rearrange(mask_embedding, 'n b dim -> (b n) dim')

        # Final (full-resolution) head.
        last_mask_embedding = self.mask_embed_proj(mask_embedding)
        last_mask_embedding = rearrange(last_mask_embedding, '(b n) dim -> b n dim', b=B, n=N)
        per_pixel_embedding = per_pixel_embedding_ls[0]

        if simulated_lowres_mc_pred is not None:
            per_pixel_embedding = self.enhance_with_coarse_pred(
                per_pixel_embedding,
                last_mask_embedding,
                simulated_lowres_mc_pred)

        logits = [torch.einsum('bchwd,bnc->bnhwd', per_pixel_embedding, last_mask_embedding)]

        # Deep supervision: project the SAME decoded mask embeddings to each
        # intermediate decoder stage's channel width and emit extra logit maps.
        if self.deep_supervision:
            for mask_embed_proj, per_pixel_embedding in zip(self.mid_mask_embed_proj, per_pixel_embedding_ls[1:]):
                mid_mask_embedding = mask_embed_proj(mask_embedding)
                mid_mask_embedding = rearrange(mid_mask_embedding, '(b n) dim -> b n dim', b=B, n=N)

                logits.append(torch.einsum('bchwd,bnc->bnhwd', per_pixel_embedding, mid_mask_embedding))

        return logits

    def forward(self, queries, image_input, simulated_lowres_sc_pred=None, simulated_lowres_mc_pred=None, train_mode=True):
        """
        Args:
            queries (torch.tensor): B,N,d (train mode) or N,d (infer mode)
            image_input (torch.tensor): B,C,H,W,D input volume
            simulated_lowres_sc_pred (torch.tensor, optional): B,1,H,W,D single-channel
                coarse prediction, concatenated to the image as an extra channel
            simulated_lowres_mc_pred (torch.tensor, optional): B,N,H,W,D multi-channel
                coarse prediction used to enhance per-pixel embeddings
            train_mode (bool, optional): selects train_forward (list output) vs
                infer_forward (single tensor). Defaults to True.

        Returns:
            List of torch.tensor in train mode, torch.tensor in infer mode.
        """
        if simulated_lowres_sc_pred is not None:
            # NOTE: this grows the channel count — the backbone must have been
            # configured with a matching input_channels.
            image_input = torch.cat([image_input, simulated_lowres_sc_pred], dim=1)

        image_embedding, pos, per_pixel_embedding_ls = self.vision_backbone_forward(image_input)

        if train_mode:
            logits = self.train_forward(queries, image_embedding, pos, per_pixel_embedding_ls, simulated_lowres_mc_pred)

        else:
            # Free the (possibly enlarged) input volume before decoding to
            # lower peak GPU memory during sliding-window inference.
            del image_input
            torch.cuda.empty_cache()
            logits = self.infer_forward(queries, image_embedding, pos, per_pixel_embedding_ls, simulated_lowres_mc_pred)

        return logits
|
|
if __name__ == '__main__':
    # Smoke test. Fixes vs. the previous version:
    #  - default Maskformer() builds a 1-channel 'UNET' backbone, so the test
    #    image must have 1 channel (was 3 -> conv channel-mismatch error);
    #  - the query batch must match the image batch (was 2 vs 1 -> decoder
    #    memory batch mismatch);
    #  - train-mode forward returns a *list* of logit tensors, so `.shape`
    #    on the raw return value would raise AttributeError.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Maskformer().to(device)
    image = torch.rand((1, 1, 288, 288, 96), device=device)
    query = torch.rand((1, 10, 768), device=device)
    segmentations = model(query, image)
    for seg in segmentations:
        print(seg.shape)