import torch
import torch.nn as nn
import torch.nn.functional as F

from .complexmodule import (
    ComplexBatchNorm,
    ComplexConv1D,
    ComplexConv2d,
    ComplexConvTranspose2d,
    ComplexTranspose1D,
    cPReLU,
)


def _pad_to_match(x1, x2):
    """Pad (or crop, for negative differences) ``x1`` so dims 2 and 3 match ``x2``.

    ``F.pad`` consumes its padding pairs starting from the LAST dimension, so
    the dim-3 amounts must come first and the dim-2 amounts second.
    """
    diff_d2 = x2.size(2) - x1.size(2)
    diff_d3 = x2.size(3) - x1.size(3)
    return F.pad(
        x1,
        [diff_d3 // 2, diff_d3 - diff_d3 // 2,
         diff_d2 // 2, diff_d2 - diff_d2 // 2],
    )


def _cat_complex_skip(up, skip, complex_axis):
    """Merge an upsampled tensor with a skip tensor, keeping real/imag halves grouped.

    Both tensors are split in two along ``complex_axis`` (first half treated as
    real, second half as imaginary). The result is laid out as
    ``[up_real, skip_real, up_imag, skip_imag]`` along that axis.
    """
    up_real, up_imag = torch.chunk(up, 2, dim=complex_axis)
    skip_real, skip_imag = torch.chunk(skip, 2, dim=complex_axis)
    real_total = torch.cat([up_real, skip_real], complex_axis)
    imag_total = torch.cat([up_imag, skip_imag], complex_axis)
    return torch.cat([real_total, imag_total], complex_axis)


class ComplexLinearLayer(nn.Module):
    """
    A 1x1 Convolution Layer, which can be used to efficiently increase/decrease Feature Maps or,
    in the context of the U-Net architecture, generate a linear projection of the feature maps
    learned in earlier layers ("channel-wise pooling").

    The convolution is followed by complex batch norm and a complex PReLU.
    """

    def __init__(self, in_chan, out_chan):
        super().__init__()
        self.conv = ComplexConv2d(in_channels=in_chan,
                                  out_channels=out_chan,
                                  kernel_size=(1, 1),
                                  stride=(1, 1))
        self.bn = ComplexBatchNorm(out_chan)
        self.act = cPReLU()

    def forward(self, x):
        """Apply conv -> batch norm -> activation."""
        return self.act(self.bn(self.conv(x)))


class ComplexDConvBlock(nn.Module):
    """Dilated complex 2D convolution followed by batch norm, PReLU and optional dropout.

    The padding is ``dilation * (kernel_size - 1) // 2`` on both spatial dims.
    """

    def __init__(self, in_chan, out_chan, kernel_size=3, stride=1, dilation=2):
        super().__init__()
        dconv_pad = (dilation * (kernel_size - 1)) // 2
        self.conv1 = ComplexConv2d(in_channels=in_chan,
                                   out_channels=out_chan,
                                   kernel_size=(kernel_size, kernel_size),
                                   stride=(stride, stride),
                                   padding=(dconv_pad, dconv_pad),
                                   dilation=dilation,
                                   complex_axis=1)
        self.act = cPReLU()
        self.bn1 = ComplexBatchNorm(out_chan)
        self.drop1 = nn.Dropout(0.2)

    def forward(self, x, dropout=False):
        """Apply conv -> batch norm -> activation; apply dropout only if ``dropout`` is truthy."""
        x = self.act(self.bn1(self.conv1(x)))
        if dropout:
            x = self.drop1(x)
        return x


class ComplexConv1x1Block(nn.Module):
    """
    Inspired by TasNet Temporal Block - not a 1x1 block, TODO rename across whole project

    Pipeline: 1x1 projection (in_chan -> out_chan), grouped dilated conv,
    PReLU, batch norm, 1x1 conv back to in_chan, then a residual add with the input.
    """

    def __init__(self, in_chan, out_chan, kernel_size, dilation):
        super().__init__()

        # Start with linear projection
        self.conv1x1 = ComplexLinearLayer(in_chan, out_chan)

        # Don't divide by 2 for a causal system
        dconv_pad = (dilation * (kernel_size - 1)) // 2

        # Follow up by grouped, dilated conv.
        # NOTE(review): groups=in_chan while the conv maps out_chan -> out_chan; a true
        # depthwise conv would use groups=out_chan. Left unchanged because altering it
        # changes parameter shapes (checkpoint compatibility) -- confirm intent.
        self.dconv = ComplexConv2d(in_channels=out_chan,
                                   out_channels=out_chan,
                                   kernel_size=(kernel_size, kernel_size),
                                   groups=in_chan,
                                   padding=(dconv_pad, dconv_pad),
                                   dilation=dilation)
        self.prelu = cPReLU()
        self.bn = ComplexBatchNorm(out_chan)

        # 1x1 across channel conv (pointwise conv) in=out, out=in
        self.pconv = ComplexConv2d(out_chan, in_chan, (1, 1))

    def forward(self, x):
        """Return ``x`` plus the features produced by the separable dilated conv path."""
        # Generate new features by using separable, dilated conv
        y = self.conv1x1(x)
        y = self.dconv(y)
        # Activation is applied BEFORE batch norm here (unlike the other blocks in this file)
        y = self.bn(self.prelu(y))

        # Map the new features to the same count of feature maps as the input
        y = self.pconv(y)

        # Residual add as in the TasNet paper; it might not be useful unless many
        # of these blocks are stacked into a very deep module.
        return x + y


class ComplexConvBlock(nn.Module):
    """Two stacked complex conv -> batch norm -> PReLU stages for a U-Net.

    Both convolutions share ``kernel_size``, ``stride`` and ``padding``; the
    first maps ``in_channels -> out_channels``, the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)):
        super().__init__()
        self.conv1 = ComplexConv2d(in_channels=in_channels,
                                   out_channels=out_channels,
                                   kernel_size=kernel_size,
                                   stride=stride,
                                   complex_axis=1,
                                   padding=padding)
        self.conv2 = ComplexConv2d(in_channels=out_channels,
                                   out_channels=out_channels,
                                   kernel_size=kernel_size,
                                   stride=stride,
                                   complex_axis=1,
                                   padding=padding)
        self.bn1 = ComplexBatchNorm(out_channels)
        self.bn2 = ComplexBatchNorm(out_channels)
        self.act1 = cPReLU()
        self.act2 = cPReLU()
        self.drop1 = nn.Dropout(0.2)

    def forward(self, x, pool_size=(2, 2), pool_type=None, dropout=False):
        """Conv -> BN -> PReLU, twice, then optional pooling and dropout.

        :param pool_size: kernel size handed to the pooling function.
        :param pool_type: ``'max'`` or ``'avg'``; any other value (including
            ``None``) skips pooling.
        :param dropout: apply the block's dropout layer when truthy.
        """
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x)))
        if pool_type == 'max':
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg':
            x = F.avg_pool2d(x, kernel_size=pool_size)
        if dropout:
            x = self.drop1(x)
        return x


class ComplexUpsampleUnet(nn.Module):
    """U-Net decoder stage: upsample, optionally merge a skip tensor, then refine.

    :param in_size: channel count handed to the upsampling layer.
    :param out_size: channel count produced by the refining conv block.
    :param kz: kernel size of the transposed convolution (``mode='conv'``).
    :param mode: ``'conv'`` (complex transposed conv, stride 2, ``in_size // 2``
        output channels) or ``'bilinear'`` (bilinear upsample + real 1x1 conv).
    :raises ValueError: if ``mode`` is neither ``'conv'`` nor ``'bilinear'``.
    """

    def __init__(self, in_size, out_size, kz=3, mode='conv', dilation=1, padding=(0, 0),
                 output_padding=(0, 0), complex_axis=1):
        super().__init__()
        self.complex_axis = complex_axis
        if mode == 'conv':
            self.up = ComplexConvTranspose2d(in_size,
                                             in_size // 2,
                                             kernel_size=(kz, kz),
                                             stride=(2, 2),
                                             dilation=dilation,
                                             padding=padding,
                                             output_padding=output_padding,
                                             complex_axis=self.complex_axis)
        elif mode == 'bilinear':
            # NOTE(review): this path uses a real-valued Conv2d producing out_size
            # channels, while conv_noskip below expects in_size // 2 -- confirm that
            # the bilinear mode is actually exercised/consistent.
            self.up = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_size, out_size, kernel_size=1)
            )
        else:
            # Previously an unknown mode left self.up undefined and only failed in forward()
            raise ValueError(f"Unsupported mode {mode!r}; expected 'conv' or 'bilinear'")

        self.conv1 = ComplexConvBlock(in_size, out_size, padding=(1, 1))
        self.conv_noskip = ComplexConvBlock(in_size // 2, out_size, padding=(1, 1))

    def _center_crop(self, x1, x2):
        """Pad/crop ``x1`` so that its dims 2 and 3 match those of ``x2``.

        Returns the adjusted ``x1`` together with the untouched ``x2``.
        """
        return _pad_to_match(x1, x2), x2

    def forward(self, x, residual):
        """Upsample ``x`` and refine it, merging ``residual`` when it is not ``None``.

        x: complex tensor (B, 2*C, H, W) or equivalent
        """
        up = self.up(x)
        if residual is None:
            return self.conv_noskip(up, pool_type=None)

        # NOTE: alignment via _center_crop is currently not applied, so `up` and
        # `residual` must already agree on every dim except the complex axis.
        merged = _cat_complex_skip(up, residual, self.complex_axis)
        return self.conv1(merged, pool_type=None)


class ComplexUpConstantUNet(nn.Module):
    """
    Upsampling stage that keeps the number of channels constant.
    Doubles dimensions of T and F via a stride-2 transposed convolution.
    When a residual/lateral tensor is supplied to ``forward`` it is concatenated
    and reduced back to ``nb_chan`` by a conv block; otherwise the upsampled
    tensor is returned directly.
    """

    def __init__(self, nb_chan, kz=3, dilation=(1, 1), padding=(0, 0), output_padding=(0, 0),
                 complex_axis=1):
        """
        :param nb_chan: Int, number of channels of the input and output.
        :param kz: kernel size of the transposed convolution.
        :param dilation: dilation of the transposed convolution.
        :param padding: padding used for the transposed convolution AND for the
            refining conv block.
        :param output_padding: output padding of the transposed convolution.
        :param complex_axis: axis along which real and imaginary halves are stacked.
        """
        super().__init__()
        self.nb_chan = nb_chan
        self.complex_axis = complex_axis

        # Transposed convolution for upsampling
        self.up = ComplexConvTranspose2d(
            nb_chan, nb_chan,
            kernel_size=(kz, kz),
            stride=(2, 2),
            dilation=dilation,
            padding=padding,
            output_padding=output_padding,
            complex_axis=self.complex_axis
        )

        # Convolution block that maps the concatenated 2 * nb_chan back to nb_chan.
        # NOTE(review): it reuses the transposed-conv `padding`; with the default
        # (0, 0) the two 3x3 convs shrink the spatial dims -- confirm this is intended.
        self.out_conv = ComplexConvBlock(nb_chan * 2, nb_chan, padding=padding)

    def _center_crop(self, x1, x2):
        """Pad/crop ``x1`` so that its dims 2 and 3 match those of ``x2``.

        Returns the adjusted ``x1`` together with the untouched ``x2``.
        """
        return _pad_to_match(x1, x2), x2

    def forward(self, x, residual):
        """Upsample ``x``; if ``residual`` is given, concatenate it and refine."""
        up = self.up(x)  # doubles time/freq dimensions
        if residual is None:
            return up

        # NOTE: alignment via _center_crop is currently not applied, so `up` and
        # `residual` must already agree on every dim except the complex axis.
        merged = _cat_complex_skip(up, residual, self.complex_axis)
        # Pass through conv block to refine
        return self.out_conv(merged, pool_type=None)


class ComplexUConvBlock(nn.Module):
    """Single complex conv (padding 1) -> batch norm -> PReLU with optional dropout."""

    def __init__(self, in_chan, out_chan, kernel_size=3, stride=1):
        super().__init__()
        self.conv1 = ComplexConv2d(in_channels=in_chan,
                                   out_channels=out_chan,
                                   kernel_size=(kernel_size, kernel_size),
                                   stride=(stride, stride),
                                   padding=(1, 1),
                                   complex_axis=1)
        self.act = cPReLU()
        self.bn1 = ComplexBatchNorm(out_chan)
        self.drop1 = nn.Dropout(0.2)

    def forward(self, x, dropout=False):
        """Apply conv -> batch norm -> activation; apply dropout only if ``dropout`` is truthy."""
        x = self.act(self.bn1(self.conv1(x)))
        if dropout:
            x = self.drop1(x)
        return x


class CVConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    The real and imaginary halves share the same LayerNorm, MLP and layer-scale
    parameters; the block output is added to its input (residual connection).

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float, optional): Initial value for the layer scale.
            None or a non-positive value means no scaling. Defaults to 0.125.
        complex_axis (int): Axis along which the tensor is split into real and
            imaginary halves. Defaults to 1.
    """

    def __init__(
        self,
        dim: int = 512,
        intermediate_dim: int = 1536,
        layer_scale_init_value: float = 0.125,
        complex_axis: int = 1
    ):
        super().__init__()
        # NOTE(review): the conv is built with complex_axis=1 regardless of the
        # `complex_axis` argument -- confirm whether that is deliberate.
        self.dwconv = ComplexConv1D(in_channels=dim, out_channels=dim, kernel_size=3, padding=1,
                                    complex_axis=1)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        # Guard against None explicitly: `None > 0` raises TypeError, although the
        # documented contract is that None disables layer scaling.
        if layer_scale_init_value is not None and layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
        else:
            self.gamma = None
        self.complex_axis = complex_axis

    def _channel_mlp(self, part: torch.Tensor) -> torch.Tensor:
        """Run LayerNorm -> Linear -> GELU -> Linear (-> layer scale) over the channel dim."""
        part = part.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        part = self.norm(part)
        part = self.pwconv2(self.act(self.pwconv1(part)))
        if self.gamma is not None:
            part = self.gamma * part
        return part.transpose(1, 2)  # (B, T, C) -> (B, C, T)

    def forward(self, x: torch.Tensor, laterial: torch.Tensor = None) -> torch.Tensor:
        """Apply the block and add the result to ``x``.

        ``laterial`` is accepted for interface compatibility but is not used.
        """
        residual = x
        x = self.dwconv(x)
        # Split real and imaginary parts, process each with the shared MLP
        real, imag = torch.chunk(x, 2, dim=self.complex_axis)
        real = self._channel_mlp(real)
        imag = self._channel_mlp(imag)
        # Concatenate real and imaginary parts back together
        x = torch.cat([real, imag], dim=self.complex_axis)
        return residual + x
class CVConvNeXtDBlock(nn.Module):
    """ConvNeXt-style decoder block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    A complex transposed 1D convolution is followed by a LayerNorm + two-layer
    MLP applied separately (with shared weights) to the real and imaginary
    halves. No residual connection is added. If a lateral tensor is passed to
    ``forward`` it is concatenated half-by-half with the result and reduced by
    ``out_conv``.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float, optional): Initial value for the layer scale.
            None or a non-positive value means no scaling. Defaults to 0.125.
        complex_axis (int): Axis along which tensors are split into real and
            imaginary halves. Defaults to 1.
    """

    def __init__(
        self,
        dim: int = 512,
        intermediate_dim: int = 1536,
        layer_scale_init_value: float = 0.125,
        complex_axis: int = 1
    ):
        super().__init__()
        # NOTE(review): both complex convs are built with complex_axis=1 regardless
        # of the `complex_axis` argument -- confirm whether that is deliberate.
        self.dwconv = ComplexTranspose1D(in_channels=dim, out_channels=dim, kernel_size=3, padding=1,
                                         complex_axis=1)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        # Guard against None explicitly: `None > 0` raises TypeError, although the
        # documented contract is that None disables layer scaling.
        if layer_scale_init_value is not None and layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
        else:
            self.gamma = None
        self.complex_axis = complex_axis
        # Reduces the lateral concatenation (2 * dim) back to dim
        self.out_conv = ComplexConv1D(in_channels=dim * 2, out_channels=dim, kernel_size=3, padding=1,
                                      complex_axis=1)

    def _channel_mlp(self, part: torch.Tensor) -> torch.Tensor:
        """Run LayerNorm -> Linear -> GELU -> Linear (-> layer scale) over the channel dim."""
        part = part.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        part = self.norm(part)
        part = self.pwconv2(self.act(self.pwconv1(part)))
        if self.gamma is not None:
            part = self.gamma * part
        return part.transpose(1, 2)  # (B, T, C) -> (B, C, T)

    def forward(self, x: torch.Tensor, laterial: torch.Tensor = None) -> torch.Tensor:
        """Apply the block; merge ``laterial`` and run ``out_conv`` when it is given.

        Without ``laterial`` the MLP output is returned as-is (``out_conv`` is skipped).
        """
        x = self.dwconv(x)
        # Split real and imaginary parts, process each with the shared MLP
        real, imag = torch.chunk(x, 2, dim=self.complex_axis)
        real = self._channel_mlp(real)
        imag = self._channel_mlp(imag)

        if laterial is None:
            # Concatenate real and imaginary parts back together
            return torch.cat([real, imag], dim=self.complex_axis)

        # Keep the halves grouped: [real, lateral_real, imag, lateral_imag]
        res_real, res_imag = torch.chunk(laterial, 2, dim=self.complex_axis)
        real_total = torch.cat([real, res_real], self.complex_axis)
        imag_total = torch.cat([imag, res_imag], self.complex_axis)
        x = torch.cat([real_total, imag_total], self.complex_axis)
        return self.out_conv(x)