Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .complexmodule import ComplexConv1D, ComplexConv2d, ComplexBatchNorm, ComplexTranspose1D, cPReLU, ComplexConvTranspose2d | |
class ComplexLinearLayer(nn.Module):
    """Pointwise (1x1) complex convolution followed by batch norm and PReLU.

    A 1x1 convolution changes the number of feature maps without mixing
    neighbouring positions. In a U-Net it can serve as a linear projection
    of feature maps learned in earlier layers ("channel-wise pooling").

    Args:
        in_chan: number of input channels handed to ``ComplexConv2d``.
        out_chan: number of output channels; also the size given to
            ``ComplexBatchNorm``.
    """

    def __init__(self, in_chan, out_chan):
        super().__init__()
        unit = (1, 1)  # used for both kernel size and stride
        self.conv = ComplexConv2d(
            in_channels=in_chan,
            out_channels=out_chan,
            kernel_size=unit,
            stride=unit,
        )
        self.bn = ComplexBatchNorm(out_chan)
        self.act = cPReLU()

    def forward(self, x):
        """Apply convolution, then batch norm, then the activation."""
        projected = self.conv(x)
        normalised = self.bn(projected)
        return self.act(normalised)
class ComplexDConvBlock(nn.Module):
    """Dilated complex 2-D convolution -> batch norm -> PReLU, optional dropout.

    Args:
        in_chan: input channels of the convolution.
        out_chan: output channels of the convolution and batch-norm size.
        kernel_size: side length of the square kernel.
        stride: stride used along both spatial axes.
        dilation: dilation factor passed to ``ComplexConv2d``.
    """

    def __init__(self, in_chan, out_chan, kernel_size=3, stride=1, dilation=2):
        super().__init__()
        # Half of the dilated kernel extent, applied on both spatial axes.
        # Presumably chosen to keep the spatial size for odd kernels at
        # stride 1 -- confirm against ComplexConv2d.
        half_span = (kernel_size - 1) * dilation // 2
        self.conv1 = ComplexConv2d(
            in_channels=in_chan,
            out_channels=out_chan,
            kernel_size=(kernel_size, kernel_size),
            stride=(stride, stride),
            padding=(half_span, half_span),
            dilation=dilation,
            complex_axis=1,
        )
        self.act = cPReLU()
        self.bn1 = ComplexBatchNorm(out_chan)
        self.drop1 = nn.Dropout(0.2)

    def forward(self, x, dropout=False):
        """Run conv -> BN -> activation; apply dropout only if ``dropout`` is truthy."""
        out = self.act(self.bn1(self.conv1(x)))
        return self.drop1(out) if dropout else out
class ComplexConv1x1Block(nn.Module):
    """Residual block inspired by the TasNet temporal block.

    Despite the name this is not a plain 1x1 block (TODO: rename across the
    whole project). The pipeline is:

        1x1 projection (in_chan -> out_chan)
        -> grouped, dilated convolution (out_chan -> out_chan)
        -> PReLU -> batch norm
        -> 1x1 pointwise convolution (out_chan -> in_chan)
        -> added to the block input

    Args:
        in_chan: channels of the block input and output.
        out_chan: channels used inside the block.
        kernel_size: side length of the square dilated kernel.
        dilation: dilation factor of the grouped convolution.
    """

    def __init__(self, in_chan, out_chan, kernel_size, dilation):
        super().__init__()
        # Linear projection into the working channel count.
        self.conv1x1 = ComplexLinearLayer(in_chan, out_chan)

        # Symmetric padding; do not halve it if a causal system is wanted.
        pad = (kernel_size - 1) * dilation // 2

        # Grouped, dilated convolution over the projected features.
        # NOTE(review): groups is in_chan although the conv works on out_chan
        # channels; a strictly depthwise conv would use out_chan -- confirm.
        self.dconv = ComplexConv2d(
            in_channels=out_chan,
            out_channels=out_chan,
            kernel_size=(kernel_size, kernel_size),
            groups=in_chan,
            padding=(pad, pad),
            dilation=dilation,
        )
        self.prelu = cPReLU()
        self.bn = ComplexBatchNorm(out_chan)

        # Pointwise convolution mapping back to the input channel count.
        self.pconv = ComplexConv2d(out_chan, in_chan, (1, 1))

    def forward(self, x):
        """Return ``x`` plus the features computed by the block."""
        features = self.conv1x1(x)
        features = self.dconv(features)
        # Activation is applied before normalisation here.
        features = self.bn(self.prelu(features))
        # Project back so the result can be summed with the input. The
        # residual sum follows the TasNet paper; it may matter little unless
        # many of these blocks are stacked.
        return x + self.pconv(features)
class ComplexConvBlock(nn.Module):
    """Two stacked complex conv -> batch norm -> PReLU stages for a U-Net.

    Both convolutions share ``kernel_size``, ``stride`` and ``padding``; the
    first maps ``in_channels`` to ``out_channels`` and the second keeps
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)):
        super().__init__()
        self.conv1 = ComplexConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            complex_axis=1,
            padding=padding,
        )
        self.conv2 = ComplexConv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            complex_axis=1,
            padding=padding,
        )
        self.bn1 = ComplexBatchNorm(out_channels)
        self.bn2 = ComplexBatchNorm(out_channels)
        self.act1 = cPReLU()
        self.act2 = cPReLU()
        self.drop1 = nn.Dropout(0.2)

    def forward(self, x, pool_size=(2, 2), pool_type=None, dropout=False):
        """Run both conv stages, then optional pooling and optional dropout.

        Args:
            x: input tensor for the first convolution.
            pool_size: kernel size used when pooling is requested.
            pool_type: ``'max'`` or ``'avg'`` selects the pooling operator;
                any other value skips pooling.
            dropout: when truthy, ``nn.Dropout(0.2)`` is applied last.
        """
        stage1 = self.act1(self.bn1(self.conv1(x)))
        out = self.act2(self.bn2(self.conv2(stage1)))

        pool_fn = None
        if pool_type == 'max':
            pool_fn = F.max_pool2d
        elif pool_type == 'avg':
            pool_fn = F.avg_pool2d
        if pool_fn is not None:
            out = pool_fn(out, kernel_size=pool_size)

        return self.drop1(out) if dropout else out
class ComplexUpsampleUnet(nn.Module):
    """U-Net decoder stage: upsample, optionally merge a skip tensor, refine.

    Args:
        in_size: channel count given to the upsampling layer.
        out_size: channel count produced by the refining conv block.
        kz: side length of the square transposed-conv kernel ('conv' mode).
        mode: ``'conv'`` upsamples with ``ComplexConvTranspose2d``
            (in_size -> in_size // 2, stride 2); ``'bilinear'`` uses
            ``nn.Upsample`` followed by a real 1x1 ``nn.Conv2d``
            (in_size -> out_size).
        dilation, padding, output_padding: forwarded to the transposed conv.
        complex_axis: axis along which real and imaginary halves are stacked.

    Raises:
        ValueError: if ``mode`` is neither ``'conv'`` nor ``'bilinear'``.
            Previously such a mode left ``self.up`` undefined and only failed
            later, inside ``forward``.
    """

    _MODES = ('conv', 'bilinear')

    def __init__(self, in_size, out_size, kz=3,
                 mode='conv', dilation=1, padding=(0, 0), output_padding=(0, 0), complex_axis=1):
        super().__init__()
        if mode not in self._MODES:
            raise ValueError(
                "mode must be one of {}, got {!r}".format(self._MODES, mode)
            )
        self.complex_axis = complex_axis
        if mode == 'conv':
            self.up = ComplexConvTranspose2d(
                in_size, in_size // 2, kernel_size=(kz, kz), stride=(2, 2),
                dilation=dilation, padding=padding, output_padding=output_padding,
                complex_axis=self.complex_axis,
            )
        else:
            # NOTE(review): this branch emits out_size channels through a real
            # convolution, while conv1 / conv_noskip below are sized for the
            # 'conv' branch (in_size // 2 channels) -- confirm intended use.
            self.up = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_size, out_size, kernel_size=1),
            )
        # With a skip tensor the merged input has in_size channels ...
        self.conv1 = ComplexConvBlock(in_size, out_size, padding=(1, 1))
        # ... without one only the upsampled in_size // 2 channels remain.
        self.conv_noskip = ComplexConvBlock(in_size // 2, out_size, padding=(1, 1))

    def _center_crop(self, x1, x2):
        """Pad (or crop, for negative differences) ``x1`` to ``x2``'s size.

        Only dims 2 and 3 are adjusted; the difference is split as evenly as
        possible between both sides. ``F.pad`` expects the pads for the LAST
        dimension first, so the dim-3 pair precedes the dim-2 pair (the
        previous version passed them in the opposite order and therefore
        swapped the two axes).

        Returns:
            ``(x1_resized, x2)`` with ``x2`` passed through unchanged.
        """
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
                        diff_h // 2, diff_h - diff_h // 2])
        return x1, x2

    def forward(self, x, residual):
        """Upsample ``x`` and refine it, merging ``residual`` when given.

        Args:
            x: complex tensor (B, 2*C, H, W) or equivalent.
            residual: skip tensor in the same stacked layout, or ``None``.
        """
        up = self.up(x)
        if residual is None:
            return self.conv_noskip(up, pool_type=None)

        # Merge real with real and imaginary with imaginary so the result
        # keeps the "[all real | all imaginary]" layout along complex_axis.
        axis = self.complex_axis
        up_real, up_imag = torch.chunk(up, 2, dim=axis)
        res_real, res_imag = torch.chunk(residual, 2, dim=axis)
        merged = torch.cat([up_real, res_real, up_imag, res_imag], axis)
        return self.conv1(merged, pool_type=None)
class ComplexUpConstantUNet(nn.Module):
    """Upsampling stage that keeps the number of channels constant.

    A stride-2 ``ComplexConvTranspose2d`` maps ``nb_chan`` to ``nb_chan``
    channels (intended to double the T and F dimensions). A residual/lateral
    tensor is optional: when one is supplied it is merged with the upsampled
    tensor and reduced back to ``nb_chan`` channels by ``out_conv``; when it
    is ``None`` the upsampled tensor is returned as is.
    """

    def __init__(self, nb_chan, kz=3, dilation=(1, 1), padding=(0, 0), output_padding=(0, 0), complex_axis=1):
        """
        :param nb_chan: Int, number of channels of the input and output.
        :param kz: Int, side length of the square transposed-conv kernel.
        :param dilation: forwarded to the transposed convolution.
        :param padding: forwarded to the transposed convolution AND to the
            refining conv block.
        :param output_padding: forwarded to the transposed convolution.
        :param complex_axis: axis along which real/imaginary halves are stacked.
        """
        super().__init__()
        self.nb_chan = nb_chan
        self.complex_axis = complex_axis
        # Transposed convolution for upsampling.
        self.up = ComplexConvTranspose2d(
            nb_chan, nb_chan, kernel_size=(kz, kz), stride=(2, 2),
            dilation=dilation, padding=padding, output_padding=output_padding,
            complex_axis=self.complex_axis,
        )
        # Reduces the merged (2 * nb_chan) tensor back to nb_chan channels.
        # NOTE(review): reuses the transposed-conv `padding`; confirm that is
        # intended for the 3x3 convolutions inside the block.
        self.out_conv = ComplexConvBlock(nb_chan * 2, nb_chan, padding=padding)

    def _center_crop(self, x1, x2):
        """Pad (or crop, for negative differences) ``x1`` to ``x2``'s size.

        Only dims 2 and 3 are adjusted; the difference is split as evenly as
        possible between both sides. ``F.pad`` expects the pads for the LAST
        dimension first, so the dim-3 pair precedes the dim-2 pair (the
        previous version passed them in the opposite order and therefore
        swapped the two axes).

        Returns:
            ``(x1_resized, x2)`` with ``x2`` passed through unchanged.
        """
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
                        diff_h // 2, diff_h - diff_h // 2])
        return x1, x2

    def forward(self, x, residual):
        """Upsample ``x``; merge and refine with ``residual`` when it is given."""
        up = self.up(x)
        if residual is None:
            return up

        # Merge real with real and imaginary with imaginary so the result
        # keeps the "[all real | all imaginary]" layout along complex_axis.
        axis = self.complex_axis
        up_real, up_imag = torch.chunk(up, 2, dim=axis)
        res_real, res_imag = torch.chunk(residual, 2, dim=axis)
        merged = torch.cat([up_real, res_real, up_imag, res_imag], axis)
        return self.out_conv(merged, pool_type=None)
class ComplexUConvBlock(nn.Module):
    """Single complex conv -> batch norm -> PReLU stage with optional dropout.

    The convolution uses a square kernel, the same stride on both spatial
    axes and a fixed padding of ``(1, 1)``.

    Args:
        in_chan: input channels of the convolution.
        out_chan: output channels of the convolution and batch-norm size.
        kernel_size: side length of the square kernel.
        stride: stride used along both spatial axes.
    """

    def __init__(self, in_chan, out_chan, kernel_size=3, stride=1):
        super().__init__()
        self.conv1 = ComplexConv2d(
            in_channels=in_chan,
            out_channels=out_chan,
            kernel_size=(kernel_size,) * 2,
            stride=(stride,) * 2,
            padding=(1, 1),
            complex_axis=1,
        )
        self.act = cPReLU()
        self.bn1 = ComplexBatchNorm(out_chan)
        self.drop1 = nn.Dropout(0.2)

    def forward(self, x, dropout=False):
        """Run conv -> BN -> activation; apply dropout only if ``dropout`` is truthy."""
        out = self.act(self.bn1(self.conv1(x)))
        if not dropout:
            return out
        return self.drop1(out)
class CVConvNeXtBlock(nn.Module):
    """Complex-valued ConvNeXt-style residual block for 1D audio signals.

    Adapted from https://github.com/facebookresearch/ConvNeXt. A
    ``ComplexConv1D`` is followed by LayerNorm and a two-layer MLP; the same
    LayerNorm/MLP weights are applied separately to the real and the
    imaginary half, and the result is added to the block input.

    Args:
        dim (int): Number of channels per real/imaginary half.
        intermediate_dim (int): Width of the hidden MLP layer.
        layer_scale_init_value (float, optional): Initial value of the
            learnable per-channel scale. ``None`` or a non-positive value
            disables scaling. Defaults to 0.125.
        complex_axis (int): Axis on which real/imaginary halves are stacked.
            NOTE(review): the convolution and the transposes in ``forward``
            are hard-wired to axis 1 -- other values look unsupported.
    """

    def __init__(
        self,
        dim: int = 512,
        intermediate_dim: int = 1536,
        layer_scale_init_value: float = 0.125,
        complex_axis: int = 1
    ):
        super().__init__()
        self.dwconv = ComplexConv1D(in_channels=dim, out_channels=dim, kernel_size=3, padding=1, complex_axis=1)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # Pointwise/1x1 convs, implemented with linear layers.
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = self._init_layer_scale(dim, layer_scale_init_value)
        self.complex_axis = complex_axis

    @staticmethod
    def _init_layer_scale(dim, init_value):
        """Return the layer-scale parameter, or ``None`` when scaling is off.

        ``None`` is accepted explicitly: comparing ``None > 0`` raises
        ``TypeError``, which made the documented "None means no scaling"
        option crash.
        """
        if init_value is not None and init_value > 0:
            return nn.Parameter(init_value * torch.ones(dim), requires_grad=True)
        return None

    def _channel_mlp(self, part: torch.Tensor) -> torch.Tensor:
        """LayerNorm + MLP (+ optional layer scale) over the channel axis."""
        part = part.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        part = self.norm(part)
        part = self.pwconv2(self.act(self.pwconv1(part)))
        if self.gamma is not None:
            part = self.gamma * part
        return part.transpose(1, 2)  # (B, T, C) -> (B, C, T)

    def forward(self, x: torch.Tensor, laterial: torch.Tensor = None) -> torch.Tensor:
        """Return ``x`` plus the block's update.

        ``laterial`` is accepted for signature parity with
        ``CVConvNeXtDBlock`` but is not used here.
        """
        residual = x
        x = self.dwconv(x)
        # Split into real and imaginary halves and process each one with the
        # shared LayerNorm/MLP weights.
        real, imag = torch.chunk(x, 2, dim=self.complex_axis)
        real = self._channel_mlp(real)
        imag = self._channel_mlp(imag)
        x = torch.cat([real, imag], dim=self.complex_axis)
        return residual + x
class CVConvNeXtDBlock(nn.Module):
    """Complex-valued ConvNeXt-style block built on a transposed convolution.

    Adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio
    signals. A ``ComplexTranspose1D`` is followed by LayerNorm and a
    two-layer MLP applied separately (with shared weights) to the real and
    imaginary halves. Unlike ``CVConvNeXtBlock`` there is no residual sum;
    instead an optional lateral tensor can be merged and reduced back to
    ``dim`` channels by ``out_conv``.

    Args:
        dim (int): Number of channels per real/imaginary half.
        intermediate_dim (int): Width of the hidden MLP layer.
        layer_scale_init_value (float, optional): Initial value of the
            learnable per-channel scale. ``None`` or a non-positive value
            disables scaling. Defaults to 0.125.
        complex_axis (int): Axis on which real/imaginary halves are stacked.
            NOTE(review): the convolutions and the transposes in ``forward``
            are hard-wired to axis 1 -- other values look unsupported.
    """

    def __init__(
        self,
        dim: int = 512,
        intermediate_dim: int = 1536,
        layer_scale_init_value: float = 0.125,
        complex_axis: int = 1
    ):
        super().__init__()
        self.dwconv = ComplexTranspose1D(in_channels=dim, out_channels=dim, kernel_size=3, padding=1, complex_axis=1)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # Pointwise/1x1 convs, implemented with linear layers.
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = self._init_layer_scale(dim, layer_scale_init_value)
        self.complex_axis = complex_axis
        # Reduces the merged (2 * dim) tensor back to dim channels.
        self.out_conv = ComplexConv1D(in_channels=dim*2, out_channels=dim, kernel_size=3, padding=1, complex_axis=1)

    @staticmethod
    def _init_layer_scale(dim, init_value):
        """Return the layer-scale parameter, or ``None`` when scaling is off.

        ``None`` is accepted explicitly: comparing ``None > 0`` raises
        ``TypeError``, which made the documented "None means no scaling"
        option crash.
        """
        if init_value is not None and init_value > 0:
            return nn.Parameter(init_value * torch.ones(dim), requires_grad=True)
        return None

    def _channel_mlp(self, part: torch.Tensor) -> torch.Tensor:
        """LayerNorm + MLP (+ optional layer scale) over the channel axis."""
        part = part.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        part = self.norm(part)
        part = self.pwconv2(self.act(self.pwconv1(part)))
        if self.gamma is not None:
            part = self.gamma * part
        return part.transpose(1, 2)  # (B, T, C) -> (B, C, T)

    def forward(self, x: torch.Tensor, laterial: torch.Tensor = None) -> torch.Tensor:
        """Transform ``x`` and, if given, merge the lateral tensor.

        Args:
            x: input in stacked real/imaginary layout.
            laterial: optional lateral/skip tensor in the same layout.
        """
        x = self.dwconv(x)
        real, imag = torch.chunk(x, 2, dim=self.complex_axis)
        real = self._channel_mlp(real)
        imag = self._channel_mlp(imag)

        if laterial is None:
            return torch.cat([real, imag], dim=self.complex_axis)

        # Merge real with real and imaginary with imaginary so the result
        # keeps the "[all real | all imaginary]" layout, then reduce the
        # doubled channel count. out_conv expects 2 * dim channels, so it is
        # only applicable on this branch.
        lat_real, lat_imag = torch.chunk(laterial, 2, dim=self.complex_axis)
        merged = torch.cat([real, lat_real, imag, lat_imag], self.complex_axis)
        return self.out_conv(merged)