| | from __future__ import annotations
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from monai.networks.blocks.convolutions import Convolution
|
| | from monai.networks.blocks.segresnet_block import get_conv_layer, get_upsample_layer
|
| | from monai.networks.layers.factories import Dropout
|
| | from monai.networks.layers.utils import get_act_layer, get_norm_layer
|
| | from monai.utils import UpsampleMode
|
| | from einops import rearrange
|
| | from models.mamba_customer import ConvMamba, M3, PatchEmbed, PatchUnEmbed
|
| | from models.Blocks import CAB, SAB, VSSBlock, ShallowFusionAttnBlock
|
| | import warnings
|
| | warnings.filterwarnings("ignore")
|
| |
|
def get_dwconv_layer(
    spatial_dims: int, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1,
    bias: bool = False
):
    """Build a depthwise convolution followed by a 1x1 pointwise convolution.

    Args:
        spatial_dims: number of spatial dimensions forwarded to ``Convolution``.
        in_channels: channel count of the input; the depthwise stage keeps it
            unchanged and uses one group per input channel.
        out_channels: channel count produced by the pointwise stage.
        kernel_size: kernel size of the depthwise stage. Defaults to 3.
        stride: stride value handed to both stages. Defaults to 1.
        bias: whether the two convolutions carry a bias term.

    Returns:
        ``torch.nn.Sequential(depthwise, pointwise)``.
    """
    # Arguments that are identical for the two stages.
    # NOTE(review): `stride` is passed to the pointwise stage as well as the
    # depthwise one, so a stride != 1 is applied by both -- confirm intended.
    shared = dict(
        spatial_dims=spatial_dims,
        in_channels=in_channels,
        strides=stride,
        bias=bias,
        conv_only=True,
    )
    depthwise = Convolution(out_channels=in_channels, kernel_size=kernel_size, groups=in_channels, **shared)
    pointwise = Convolution(out_channels=out_channels, kernel_size=1, groups=1, **shared)
    return torch.nn.Sequential(depthwise, pointwise)
|
| |
|
| |
|
class SRCMLayer(nn.Module):
    """Token-mixing layer: LayerNorm -> ConvMamba (+ scaled skip) -> LayerNorm -> Linear.

    The input is treated as ``(batch, channels, *spatial)``. The spatial axes are
    flattened into a token axis, processed, and restored on the way out with
    ``output_dim`` channels.
    """

    def __init__(self, input_dim, output_dim, d_state=16, d_conv=4, expand=2, conv_mode='deepwise'):
        """
        Args:
            input_dim: channel count expected on the input (checked in ``forward``).
            output_dim: channel count produced by the final linear projection.
            d_state: forwarded to ``ConvMamba``.
            d_conv: forwarded to ``ConvMamba``.
            expand: forwarded to ``ConvMamba``.
            conv_mode: forwarded to ``ConvMamba``.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # A single LayerNorm instance; forward() applies it twice.
        self.norm = nn.LayerNorm(input_dim)
        mamba_kwargs = dict(
            d_model=input_dim,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            bimamba_type="v2",
            conv_mode=conv_mode,
        )
        self.convmamba = ConvMamba(**mamba_kwargs)
        self.proj = nn.Linear(input_dim, output_dim)
        # Learnable scalar weight on the skip connection, initialised to 1.
        self.skip_scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Map ``(B, input_dim, *spatial)`` to ``(B, output_dim, *spatial)``."""
        # Half-precision inputs are upcast to float32 before any computation.
        if x.dtype == torch.float16:
            x = x.float()

        batch, channels = x.shape[:2]
        assert channels == self.input_dim

        spatial_shape = x.shape[2:]
        # (B, C, *spatial) -> (B, n_tokens, C)
        tokens = x.reshape(batch, channels, spatial_shape.numel()).transpose(-1, -2)

        mixed = self.convmamba(self.norm(tokens)) + self.skip_scale * tokens
        projected = self.proj(self.norm(mixed))

        # (B, n_tokens, output_dim) -> (B, output_dim, *spatial)
        return projected.transpose(-1, -2).reshape(batch, self.output_dim, *spatial_shape)
|
| |
|
| |
|
def get_srcm_layer(
    spatial_dims: int, in_channels: int, out_channels: int, stride: int = 1, conv_mode: str = "deepwise"
):
    """Create an ``SRCMLayer``, optionally followed by max-pool downsampling.

    Args:
        spatial_dims: number of spatial dimensions of the data (1, 2 or 3).
            Only consulted when ``stride != 1``, to pick the pooling layer.
        in_channels: channel count of the input.
        out_channels: channel count of the output.
        stride: downsampling factor. With 1 (the default) the bare layer is
            returned; otherwise a max-pool with ``kernel_size == stride`` is
            appended after the layer.
        conv_mode: forwarded to ``SRCMLayer``.

    Returns:
        The ``SRCMLayer`` itself when ``stride == 1``, else
        ``nn.Sequential(SRCMLayer, MaxPoolNd)``.

    Raises:
        ValueError: if ``stride != 1`` and ``spatial_dims`` is not 1, 2 or 3.
    """
    srcm_layer = SRCMLayer(input_dim=in_channels, output_dim=out_channels, conv_mode=conv_mode)
    if stride == 1:
        return srcm_layer

    # Previously only the 2D case was handled and any other dimensionality
    # silently returned the layer without downsampling.
    pool_types = {1: nn.MaxPool1d, 2: nn.MaxPool2d, 3: nn.MaxPool3d}
    if spatial_dims not in pool_types:
        raise ValueError(
            f"`spatial_dims` must be 1, 2 or 3 when stride != 1, got {spatial_dims}."
        )
    pool = pool_types[spatial_dims](kernel_size=stride, stride=stride)
    return nn.Sequential(srcm_layer, pool)
|
| |
|
class SRCMBlock(nn.Module):
    """Residual block made of two (norm -> activation -> SRCM layer) stages.

    The channel count is preserved, and the block input is added back to the
    output of the second stage.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        norm: tuple | str,
        kernel_size: int = 3,
        conv_mode: str = "deepwise",
        act: tuple | str = ("RELU", {"inplace": True}),
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2 or 3.
            in_channels: number of input (and output) channels.
            norm: feature normalization type and arguments.
            kernel_size: must be odd; it is only validated here and is not
                passed on to the SRCM layers. Defaults to 3.
            conv_mode: forwarded to both SRCM layers.
            act: activation type and arguments. Defaults to ``RELU``.

        Raises:
            AssertionError: if ``kernel_size`` is not an odd number.
        """
        super().__init__()

        if not kernel_size % 2 == 1:
            raise AssertionError("kernel_size should be an odd number.")

        def _make_norm():
            return get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)

        def _make_srcm():
            return get_srcm_layer(
                spatial_dims, in_channels=in_channels, out_channels=in_channels, conv_mode=conv_mode
            )

        self.norm1 = _make_norm()
        self.norm2 = _make_norm()
        # One activation module, reused by both stages.
        self.act = get_act_layer(act)
        self.conv1 = _make_srcm()
        self.conv2 = _make_srcm()

    def forward(self, x):
        """Apply both stages and add the input back as a residual."""
        residual = x

        out = self.conv1(self.act(self.norm1(x)))
        out = self.conv2(self.act(self.norm2(out)))

        # In-place accumulation onto the second stage's output.
        out += residual
        return out
|
| |
|
| |
|
class CSI(nn.Module):
    """Fuse two same-level feature maps into a single feature map.

    Pipeline (see ``forward``): shallow fusion attention on the pair, absolute
    difference of the two results, patch embedding of all three maps, ``M3``
    mixing, and patch un-embedding back to the requested spatial size.
    """

    def __init__(self, dim):
        """
        Args:
            dim: channel/embedding width used for every sub-module.
        """
        super().__init__()
        self.shallow_fusion_attn = ShallowFusionAttnBlock(dim)
        self.m3 = M3(dim)
        # Constructed (and therefore part of the state dict) but not called in forward().
        self.vss = VSSBlock(hidden_dim=dim)
        self.patch_embed = PatchEmbed(in_chans=dim, embed_dim=dim)
        self.patch_unembed = PatchUnEmbed(in_chans=dim, embed_dim=dim)

    def forward(self, I1, I2, h, w):
        """Return the fused map for inputs ``I1`` / ``I2``.

        Args:
            I1: first feature map.
            I2: second feature map.
            h: height passed to the attention block and used for un-embedding.
            w: width passed to the attention block and used for un-embedding.
        """
        feat_a, feat_b = self.shallow_fusion_attn(I1, I2, h, w)
        diff = torch.abs(feat_a - feat_b)

        # The same patch embedding is shared by both branches and the difference map.
        tok_a = self.patch_embed(feat_a)
        tok_b = self.patch_embed(feat_b)
        tok_diff = self.patch_embed(diff)

        # M3 receives the spatial size read from the difference map itself.
        diff_h, diff_w = diff.shape[2], diff.shape[3]
        mixed_tokens, _ = self.m3(tok_a, tok_b, tok_diff, diff_h, diff_w)

        return self.patch_unembed(mixed_tokens, (h, w))
|
| |
|
class STNR(nn.Module):
    """Two-input encoder/decoder network built from SRCM blocks.

    Both inputs go through the same (weight-shared) encoder. At every encoder
    level the two feature maps are merged into one, either by a ``CSI`` fusion
    block or by a plain absolute difference. The merged pyramid is then decoded
    with upsampling, skip connections, CAB/SAB gating and SRCM blocks.
    """

    def __init__(
        self,
        spatial_dims: int = 2,
        init_filters: int = 16,
        in_channels: int = 1,
        out_channels: int = 2,
        conv_mode: str = "deepwise",
        local_query_model="orignal_dinner",
        dropout_prob: float | None = None,
        act: tuple | str = ("RELU", {"inplace": True}),
        norm: tuple | str = ("GROUP", {"num_groups": 8}),
        norm_name: str = "",
        num_groups: int = 8,
        use_conv_final: bool = True,
        blocks_down: tuple = (1, 2, 2, 4),
        blocks_up: tuple = (1, 1, 1),
        mode: str = "",
        up_mode="ResMamba",
        up_conv_mode="deepwise",
        resdiual=False,
        stage=4,
        diff_abs="later",
        mamba_act="silu",
        upsample_mode: UpsampleMode | str = UpsampleMode.NONTRAINABLE,
    ):
        """
        Args:
            spatial_dims: 2 or 3 (validated here).
            init_filters: channel count after the initial convolution; level
                ``i`` of the encoder uses ``init_filters * 2**i`` channels.
            in_channels: channel count of each input.
            out_channels: channel count produced by the final convolution.
            conv_mode: forwarded to the encoder SRCM layers/blocks.
            local_query_model: stored on the instance; not used in this class.
            dropout_prob: if not ``None``, dropout with this probability is
                applied after the initial convolution.
            act: activation type and arguments.
            norm: feature normalization type and arguments.
            norm_name: deprecated; only ``"group"`` (case-insensitive) is
                accepted, in which case ``norm`` is replaced by a group norm
                with ``num_groups`` groups.
            num_groups: used only together with ``norm_name``.
            use_conv_final: whether ``decode`` applies the final norm/act/conv head.
            blocks_down: number of SRCM blocks per encoder level.
            blocks_up: number of SRCM blocks per decoder level.
            mode: ``"FUSION"`` enables the ``CSI`` fusion blocks in ``forward``;
                any other value uses the absolute difference at every level.
            up_mode: decoder block type. Only ``"SRCM"`` is implemented, so the
                default value must be overridden by the caller.
            up_conv_mode: ``conv_mode`` forwarded to the decoder blocks.
            resdiual: stored on the instance; not used in this class.
            stage: number of ``CSI`` fusion blocks; levels ``i < stage`` are
                fused when ``mode == "FUSION"``.
            diff_abs: only ``"later"`` is handled by ``forward``.
            mamba_act: stored on the instance; not used in this class.
            upsample_mode: forwarded to ``get_upsample_layer``.

        Raises:
            ValueError: on an unsupported ``spatial_dims``, ``norm_name`` or
                ``up_mode``.
        """
        super().__init__()

        if spatial_dims not in (2, 3):
            raise ValueError("`spatial_dims` can only be 2 or 3.")
        self.mode = mode
        self.stage = stage
        self.up_conv_mode = up_conv_mode
        self.mamba_act = mamba_act
        self.resdiual = resdiual
        self.up_mode = up_mode
        self.diff_abs = diff_abs
        self.conv_mode = conv_mode
        self.local_query_model = local_query_model
        self.spatial_dims = spatial_dims
        self.init_filters = init_filters
        self.channels_list = [self.init_filters, self.init_filters * 2, self.init_filters * 4, self.init_filters * 8]
        self.in_channels = in_channels
        self.blocks_down = blocks_down
        self.blocks_up = blocks_up
        self.dropout_prob = dropout_prob
        self.act = act
        self.act_mod = get_act_layer(act)
        if norm_name:
            if norm_name.lower() != "group":
                raise ValueError(f"Deprecating option 'norm_name={norm_name}', please use 'norm' instead.")
            norm = ("group", {"num_groups": num_groups})
        self.norm = norm
        self.upsample_mode = UpsampleMode(upsample_mode)
        self.use_conv_final = use_conv_final
        self.convInit = get_conv_layer(spatial_dims, in_channels, init_filters)
        self.srcm_encoder_layers = self._make_srcm_encoder_layers()
        self.srcm_decoder_layers, self.up_samples = self._make_srcm_decoder_layers(up_mode=self.up_mode)
        self.conv_final = self._make_final_conv(out_channels)
        self.fusion_blocks = nn.ModuleList(
            [CSI(self.channels_list[i]) for i in range(self.stage)]
        )
        # One CAB per decoder level, sized for every level except the deepest.
        self.cab_layers = nn.ModuleList([
            CAB(ch) for ch in self.channels_list[::-1][1:]
        ])
        self.sab_layers = nn.ModuleList([
            SAB(kernel_size=7) for _ in range(len(self.blocks_up))
        ])
        # 1x1 convolutions that halve the channels of the concatenated decoder branches.
        # NOTE(review): hard-coded nn.Conv2d although spatial_dims=3 is accepted
        # above; forward() also assumes 4-D inputs -- confirm 3D is unsupported.
        self.conv_down_layers = nn.ModuleList([
            nn.Conv2d(ch * 2, ch, kernel_size=1, stride=1, padding=0) for ch in self.channels_list[::-1][1:]
        ])
        if dropout_prob is not None:
            self.dropout = Dropout[Dropout.DROPOUT, spatial_dims](dropout_prob)

    def _make_srcm_encoder_layers(self):
        """Build one ``nn.Sequential`` per encoder level.

        Level 0 starts with ``nn.Identity``; every later level starts with a
        stride-2 SRCM layer that doubles the channel count, followed by
        ``blocks_down[i]`` SRCM blocks.
        """
        srcm_encoder_layers = nn.ModuleList()
        blocks_down, spatial_dims, filters, norm, conv_mode = (
            self.blocks_down,
            self.spatial_dims,
            self.init_filters,
            self.norm,
            self.conv_mode,
        )
        for i, item in enumerate(blocks_down):
            layer_in_channels = filters * 2 ** i
            downsample_mamba = (
                get_srcm_layer(spatial_dims, layer_in_channels // 2, layer_in_channels, stride=2, conv_mode=conv_mode)
                if i > 0
                else nn.Identity()
            )
            down_layer = nn.Sequential(
                downsample_mamba,
                *[SRCMBlock(spatial_dims, layer_in_channels, norm=norm, act=self.act, conv_mode=conv_mode) for _ in range(item)]
            )
            srcm_encoder_layers.append(down_layer)
        return srcm_encoder_layers

    def _make_srcm_decoder_layers(self, up_mode):
        """Build the decoder blocks and their matching upsampling stages.

        Args:
            up_mode: decoder block type; only ``"SRCM"`` is available.

        Returns:
            ``(srcm_decoder_layers, up_samples)``, two ``nn.ModuleList`` objects
            with one entry per element of ``blocks_up``.

        Raises:
            ValueError: if ``up_mode`` is not ``"SRCM"``.
        """
        # Fail with a clear message instead of hitting an unbound `Block_up` below.
        if up_mode != 'SRCM':
            raise ValueError(f"Unsupported `up_mode` {up_mode!r}; only 'SRCM' is implemented.")
        Block_up = SRCMBlock

        srcm_decoder_layers, up_samples = nn.ModuleList(), nn.ModuleList()
        upsample_mode, blocks_up, spatial_dims, filters, norm = (
            self.upsample_mode,
            self.blocks_up,
            self.spatial_dims,
            self.init_filters,
            self.norm,
        )
        n_up = len(blocks_up)
        for i in range(n_up):
            # Channel count entering this decoder level; it is halved on the way up.
            sample_in_channels = filters * 2 ** (n_up - i)
            srcm_decoder_layers.append(
                nn.Sequential(
                    *[
                        Block_up(spatial_dims, sample_in_channels // 2, norm=norm, act=self.act, conv_mode=self.up_conv_mode)
                        for _ in range(blocks_up[i])
                    ]
                )
            )
            up_samples.append(
                nn.Sequential(
                    *[
                        get_conv_layer(spatial_dims, sample_in_channels, sample_in_channels // 2, kernel_size=1),
                        get_upsample_layer(spatial_dims, sample_in_channels // 2, upsample_mode=upsample_mode),
                    ]
                )
            )
        return srcm_decoder_layers, up_samples

    def _make_final_conv(self, out_channels: int):
        """Build the output head: norm -> activation -> 1x1 convolution."""
        return nn.Sequential(
            get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=self.init_filters),
            self.act_mod,
            get_conv_layer(self.spatial_dims, self.init_filters, out_channels, kernel_size=1, bias=True),
        )

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Run one input through the encoder.

        Returns:
            The deepest feature map and the list of all per-level feature maps
            (shallowest first).
        """
        x = self.convInit(x)
        if self.dropout_prob is not None:
            x = self.dropout(x)
        down_x = []

        for down in self.srcm_encoder_layers:
            x = down(x)
            down_x.append(x)

        return x, down_x

    def decode(self, x: torch.Tensor, down_x: list[torch.Tensor]) -> torch.Tensor:
        """Decode the merged feature pyramid.

        Args:
            x: the deepest merged feature map.
            down_x: merged feature maps ordered deepest first; entry ``i + 1``
                is the skip connection for decoder level ``i``.
        """
        for i, (up, upl) in enumerate(zip(self.up_samples, self.srcm_decoder_layers)):
            skip = down_x[i + 1]
            x_up = up(x) + skip
            # Attention branch: the CAB output gates x_up, then the SAB output gates that result.
            x_cab = self.cab_layers[i](x_up) * x_up
            x_sab = self.sab_layers[i](x_cab) * x_cab
            # SRCM branch runs on the same upsampled + skip features.
            x_srcm = upl(x_up)
            # Concatenate both branches on the channel axis and project back down.
            x = self.conv_down_layers[i](torch.cat([x_sab, x_srcm], dim=1))
        if self.use_conv_final:
            x = self.conv_final(x)
        return x

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Encode both inputs, merge them level by level and decode.

        Args:
            x1: first input, a 4-D tensor.
            x2: second input, processed by the same encoder as ``x1``.

        Raises:
            ValueError: if ``x1`` is not 4-D, or if ``diff_abs`` is not ``"later"``
                (the only merge strategy implemented).
        """
        if x1.dim() != 4:
            raise ValueError(f"Expected a 4-D input tensor, got {x1.dim()} dimension(s).")
        # Fail early and clearly instead of hitting an unbound `fusion` in the loop.
        if self.diff_abs != "later":
            raise ValueError(f"Unsupported `diff_abs` {self.diff_abs!r}; only 'later' is implemented.")

        x1, down_x1 = self.encode(x1)
        x2, down_x2 = self.encode(x2)

        use_fusion = self.mode == "FUSION"
        down_x = []
        for i, (x1_level, x2_level) in enumerate(zip(down_x1, down_x2)):
            if use_fusion and i < self.stage:
                H_i, W_i = x1_level.shape[2], x1_level.shape[3]
                fusion = self.fusion_blocks[i](x1_level, x2_level, H_i, W_i)
            else:
                fusion = torch.abs(x1_level - x2_level)
            down_x.append(fusion)

        # The decoder consumes the pyramid deepest level first.
        down_x.reverse()
        return self.decode(down_x[0], down_x)
|
| |
|
if __name__ == "__main__":
    # Smoke test: build the network, run one forward pass on random data with
    # the same tensor fed to both inputs, and print the output shape.
    device = "cuda:0"
    model_kwargs = dict(
        spatial_dims=2,
        in_channels=3,
        out_channels=2,
        init_filters=16,
        norm=("GROUP", {"num_groups": 8}),
        mode="FUSION",
        conv_mode='orignal',
        local_query_model="orignal_dinner",
        stage=4,
        mamba_act="silu",
        up_mode="SRCM",
        up_conv_mode='deepwise',
        blocks_down=(1, 2, 2, 4),
        blocks_up=(1, 1, 1),
        resdiual=False,
        diff_abs="later",
    )
    model = STNR(**model_kwargs).to(device)
    sample = torch.randn(1, 3, 256, 256).to(device)
    output = model(sample, sample)
    print(output.shape)