# U-Past / modules/blocks/complexblock.py
# Author: lycaoduong
# Commit e8160b2 ("Initial space"), verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from .complexmodule import ComplexConv1D, ComplexConv2d, ComplexBatchNorm, ComplexTranspose1D, cPReLU, ComplexConvTranspose2d
class ComplexLinearLayer(nn.Module):
    """
    Complex 1x1 convolution followed by complex batch norm and a cPReLU activation.

    Handy for cheaply increasing/decreasing the number of feature maps, or, inside a
    U-Net, for building a channel-wise linear projection ("channel-wise pooling") of
    features learned by earlier layers.

    :param in_chan: Input channel count handed to ``ComplexConv2d``.
    :param out_chan: Output channel count handed to ``ComplexConv2d`` and ``ComplexBatchNorm``.
    """

    def __init__(self, in_chan, out_chan):
        super().__init__()
        # 1x1 kernel with unit stride: presumably a pure channel-mixing projection.
        self.conv = ComplexConv2d(
            in_channels=in_chan, out_channels=out_chan, kernel_size=(1, 1), stride=(1, 1)
        )
        self.bn = ComplexBatchNorm(out_chan)
        self.act = cPReLU()

    def forward(self, x):
        """Return ``cPReLU(BatchNorm(Conv1x1(x)))``."""
        projected = self.conv(x)
        normalized = self.bn(projected)
        return self.act(normalized)
class ComplexDConvBlock(nn.Module):
    """
    Dilated complex convolution -> complex batch norm -> cPReLU, with optional dropout.

    The padding ``dilation * (kernel_size - 1) // 2`` keeps the spatial size unchanged
    for odd kernel sizes at ``stride == 1`` (assuming ``ComplexConv2d`` follows the usual
    ``nn.Conv2d`` output-size arithmetic).

    :param in_chan: Input channel count.
    :param out_chan: Output channel count.
    :param kernel_size: Square kernel size.
    :param stride: Stride used for both spatial axes.
    :param dilation: Dilation of the convolution.
    """

    def __init__(self, in_chan, out_chan, kernel_size=3, stride=1, dilation=2):
        super().__init__()
        pad = (dilation * (kernel_size - 1)) // 2
        self.conv1 = ComplexConv2d(
            in_channels=in_chan,
            out_channels=out_chan,
            kernel_size=(kernel_size,) * 2,
            stride=(stride,) * 2,
            padding=(pad,) * 2,
            dilation=dilation,
            complex_axis=1,
        )
        self.act = cPReLU()
        self.bn1 = ComplexBatchNorm(out_chan)
        self.drop1 = nn.Dropout(0.2)

    def forward(self, x, dropout=False):
        """Apply conv -> BN -> cPReLU; apply dropout (p=0.2) only when ``dropout`` is truthy."""
        out = self.act(self.bn1(self.conv1(x)))
        return self.drop1(out) if dropout else out
class ComplexConv1x1Block(nn.Module):
    """
    Residual block inspired by the TasNet temporal block.

    Despite its name this is not a plain 1x1 block (TODO: rename across the whole
    project). Pipeline:
        1x1 projection (ComplexLinearLayer, in_chan -> out_chan)
        -> dilated kernel_size x kernel_size grouped conv (out_chan -> out_chan)
        -> cPReLU -> complex batch norm
        -> 1x1 conv back to in_chan channels
    and the result is added to the block input.

    :param in_chan: Channels of the input (and of the output, because of the residual sum).
    :param out_chan: Hidden channel count used inside the block.
    :param kernel_size: Square kernel size of the dilated conv.
    :param dilation: Dilation of the dilated conv.
    """

    def __init__(self, in_chan, out_chan, kernel_size, dilation):
        super().__init__()
        # Linear projection in_chan -> out_chan.
        self.conv1x1 = ComplexLinearLayer(in_chan, out_chan)
        # "Same" padding for odd kernels; a causal system would not halve this.
        same_pad = (dilation * (kernel_size - 1)) // 2
        # NOTE(review): described as depthwise, but groups=in_chan while the conv sees
        # out_chan input channels; a true depthwise conv would use groups=out_chan.
        # Changing it would presumably alter the weight shape (and checkpoint
        # compatibility) whenever in_chan != out_chan -- confirm intent before changing.
        self.dconv = ComplexConv2d(
            in_channels=out_chan,
            out_channels=out_chan,
            kernel_size=(kernel_size, kernel_size),
            groups=in_chan,
            padding=(same_pad, same_pad),
            dilation=dilation,
        )
        self.prelu = cPReLU()
        self.bn = ComplexBatchNorm(out_chan)
        # Pointwise conv out_chan -> in_chan so the residual sum below is shape-compatible.
        self.pconv = ComplexConv2d(out_chan, in_chan, (1, 1))

    def forward(self, x):
        """Return ``x + block(x)``; note activation precedes normalization here (TasNet order)."""
        hidden = self.conv1x1(x)
        hidden = self.bn(self.prelu(self.dconv(hidden)))
        features = self.pconv(hidden)
        # TasNet-style residual add; possibly only worthwhile in deep stacks of these blocks.
        return x + features
class ComplexConvBlock(nn.Module):
    """
    Two stacked (complex conv -> complex batch norm -> cPReLU) stages, as used at each
    U-Net level, optionally followed by 2-D pooling and dropout.

    Both convolutions share ``kernel_size``, ``stride`` and ``padding``.

    :param in_channels: Input channel count of the first convolution.
    :param out_channels: Output channel count of both convolutions.
    :param kernel_size: Kernel size of both convolutions.
    :param stride: Stride of both convolutions.
    :param padding: Padding of both convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)):
        super().__init__()
        shared = dict(kernel_size=kernel_size, stride=stride, complex_axis=1, padding=padding)
        self.conv1 = ComplexConv2d(in_channels=in_channels, out_channels=out_channels, **shared)
        self.conv2 = ComplexConv2d(in_channels=out_channels, out_channels=out_channels, **shared)
        self.bn1 = ComplexBatchNorm(out_channels)
        self.bn2 = ComplexBatchNorm(out_channels)
        self.act1 = cPReLU()
        self.act2 = cPReLU()
        self.drop1 = nn.Dropout(0.2)

    def forward(self, x, pool_size=(2, 2), pool_type=None, dropout=False):
        """
        Apply conv -> BN -> cPReLU twice, then optional pooling and dropout.

        :param pool_size: Kernel size passed to the pooling function.
        :param pool_type: ``'max'`` or ``'avg'`` to pool; any other value skips pooling.
        :param dropout: Apply dropout (p=0.2) after pooling when truthy.
        """
        out = x
        for conv, bn, act in ((self.conv1, self.bn1, self.act1),
                              (self.conv2, self.bn2, self.act2)):
            out = act(bn(conv(out)))
        for name, pool_fn in (('max', F.max_pool2d), ('avg', F.avg_pool2d)):
            if pool_type == name:
                out = pool_fn(out, kernel_size=pool_size)
                break
        if dropout:
            out = self.drop1(out)
        return out
class ComplexUpsampleUnet(nn.Module):
    """
    U-Net decoder stage for complex feature maps stored as stacked real/imag channels.

    ``x`` is upsampled (a stride-2 complex transposed conv in ``'conv'`` mode, or
    bilinear x2 interpolation + 1x1 conv in ``'bilinear'`` mode). If a skip tensor
    ``residual`` is given, the upsampled map is aligned to it spatially, and the real
    and imaginary halves of both tensors are concatenated part-wise (so the result is
    still ``[real..., imag...]`` along ``complex_axis``) and refined by ``conv1``.
    Without a skip tensor, the upsampled map is refined by ``conv_noskip``.

    :param in_size: Channels of ``x``; also the channel count ``conv1`` expects after
        concatenating the upsampled map with ``residual``.
    :param out_size: Channels produced by the refinement conv block.
    :param kz: Kernel size of the transposed convolution (``'conv'`` mode).
    :param mode: ``'conv'`` or ``'bilinear'``.
    :param dilation: Forwarded to the transposed convolution.
    :param padding: Forwarded to the transposed convolution.
    :param output_padding: Forwarded to the transposed convolution.
    :param complex_axis: Axis holding the stacked real/imag parts.
    :raises ValueError: If ``mode`` is neither ``'conv'`` nor ``'bilinear'``.
    """

    def __init__(self, in_size, out_size, kz=3,
                 mode='conv', dilation=1, padding=(0, 0), output_padding=(0, 0), complex_axis=1):
        super().__init__()
        self.complex_axis = complex_axis
        if mode == 'conv':
            self.up = ComplexConvTranspose2d(in_size, in_size // 2, kernel_size=(kz, kz), stride=(2, 2),
                                             dilation=dilation, padding=padding,
                                             output_padding=output_padding,
                                             complex_axis=self.complex_axis)
        elif mode == 'bilinear':
            # NOTE(review): this branch uses a real-valued 1x1 conv (it mixes the real and
            # imaginary channel halves) and emits out_size channels, whereas conv_noskip and
            # conv1 are sized for an in_size // 2-channel upsampled map as in 'conv' mode.
            # It typically only fits when out_size == in_size // 2. Left unchanged so
            # existing configs/checkpoints keep working -- confirm intent.
            self.up = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_size, out_size, kernel_size=1)
            )
        else:
            # Previously self.up was silently left undefined and forward() failed later.
            raise ValueError(f"Unsupported upsampling mode {mode!r}; expected 'conv' or 'bilinear'")
        self.conv1 = ComplexConvBlock(in_size, out_size, padding=(1, 1))
        self.conv_noskip = ComplexConvBlock(in_size // 2, out_size, padding=(1, 1))

    def _center_crop(self, x1, x2):
        """
        Pad ``x1`` (or shrink it, for negative differences) so dims 2 and 3 match ``x2``.

        The size difference is split as evenly as possible between both sides.
        Returns ``(x1, x2)``; ``x2`` is returned unchanged.
        """
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        # F.pad lists pads starting from the LAST dim: (w_left, w_right, h_top, h_bottom).
        # The previous version put the dim-2 difference first, padding the wrong axes.
        x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
                        diff_h // 2, diff_h - diff_h // 2])
        return x1, x2

    def forward(self, x, residual):
        """
        :param x: Stacked real/imag tensor to upsample (presumably (B, 2*C, H, W) -- confirm).
        :param residual: Skip tensor from the encoder, or ``None``.
        :return: Output of the refinement conv block.
        """
        up = self.up(x)
        if residual is None:
            return self.conv_noskip(up, pool_type=None)
        up_real, up_imag = torch.chunk(up, 2, dim=self.complex_axis)
        res_real, res_imag = torch.chunk(residual, 2, dim=self.complex_axis)
        # Align each upsampled part to the skip tensor. This is a no-op when sizes already
        # match, and otherwise avoids a torch.cat failure (e.g. after odd-sized inputs).
        up_real, res_real = self._center_crop(up_real, res_real)
        up_imag, res_imag = self._center_crop(up_imag, res_imag)
        real_total = torch.cat([up_real, res_real], self.complex_axis)
        imag_total = torch.cat([up_imag, res_imag], self.complex_axis)
        out = torch.cat([real_total, imag_total], self.complex_axis)
        return self.conv1(out, pool_type=None)
class ComplexUpConstantUNet(nn.Module):
    """
    Stride-2 complex upsampling that keeps the channel count constant, with an optional
    skip connection.

    The transposed conv has stride (2, 2). Assuming ``ComplexConvTranspose2d`` follows
    ``nn.ConvTranspose2d`` size arithmetic, T and F are exactly doubled only for suitable
    ``kz``/``padding``/``output_padding`` (e.g. kz=3, padding=(1, 1), output_padding=(1, 1)).

    Without ``residual`` the upsampled tensor is returned directly. With ``residual``,
    the real/imag halves of both tensors are concatenated part-wise (2 * nb_chan channels)
    and refined back to ``nb_chan`` channels by ``out_conv``.
    """

    def __init__(self, nb_chan, kz=3, dilation=(1, 1), padding=(0, 0), output_padding=(0, 0), complex_axis=1):
        """
        :param nb_chan: Int, number of channels of the input and output.
        :param kz: Kernel size of the transposed convolution.
        :param dilation: Forwarded to the transposed convolution.
        :param padding: Forwarded to the transposed convolution AND to ``out_conv``.
        :param output_padding: Forwarded to the transposed convolution.
        :param complex_axis: Axis holding the stacked real/imag parts.
        """
        super().__init__()
        self.nb_chan = nb_chan
        self.complex_axis = complex_axis
        # Transposed convolution for upsampling.
        self.up = ComplexConvTranspose2d(
            nb_chan, nb_chan, kernel_size=(kz, kz), stride=(2, 2),
            dilation=dilation, padding=padding, output_padding=output_padding, complex_axis=self.complex_axis
        )
        # Refinement after skip concatenation (2 * nb_chan -> nb_chan).
        # NOTE(review): padding is shared with the transposed conv; ComplexConvBlock's default
        # 3x3 kernels preserve spatial size only when padding == (1, 1). Kept for
        # compatibility -- confirm intent.
        self.out_conv = ComplexConvBlock(nb_chan * 2, nb_chan, padding=padding)

    def _center_crop(self, x1, x2):
        """
        Pad ``x1`` (or shrink it, for negative differences) so dims 2 and 3 match ``x2``.

        The size difference is split as evenly as possible between both sides.
        Returns ``(x1, x2)``; ``x2`` is returned unchanged.
        """
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        # F.pad lists pads starting from the LAST dim: (w_left, w_right, h_top, h_bottom).
        # The previous version put the dim-2 difference first, padding the wrong axes.
        x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
                        diff_h // 2, diff_h - diff_h // 2])
        return x1, x2

    def forward(self, x, residual):
        """
        :param x: Stacked real/imag tensor to upsample.
        :param residual: Skip tensor, or ``None`` to return the upsampled tensor as-is.
        """
        up = self.up(x)
        if residual is None:
            return up
        up_real, up_imag = torch.chunk(up, 2, dim=self.complex_axis)
        res_real, res_imag = torch.chunk(residual, 2, dim=self.complex_axis)
        # Align each upsampled part to the skip tensor; a no-op when sizes already match.
        up_real, res_real = self._center_crop(up_real, res_real)
        up_imag, res_imag = self._center_crop(up_imag, res_imag)
        real_total = torch.cat([up_real, res_real], self.complex_axis)
        imag_total = torch.cat([up_imag, res_imag], self.complex_axis)
        out = torch.cat([real_total, imag_total], self.complex_axis)
        return self.out_conv(out, pool_type=None)
class ComplexUConvBlock(nn.Module):
    """
    One complex conv -> complex batch norm -> cPReLU stage with optional dropout.

    Padding is fixed at (1, 1). With the default 3x3 kernel at stride 1, this preserves
    the spatial size (assuming ``ComplexConv2d`` follows ``nn.Conv2d`` size arithmetic).

    :param in_chan: Input channel count.
    :param out_chan: Output channel count.
    :param kernel_size: Square kernel size.
    :param stride: Stride used for both spatial axes.
    """

    def __init__(self, in_chan, out_chan, kernel_size=3, stride=1):
        super().__init__()
        kernel = (kernel_size, kernel_size)
        strides = (stride, stride)
        self.conv1 = ComplexConv2d(in_channels=in_chan, out_channels=out_chan,
                                   kernel_size=kernel, stride=strides,
                                   padding=(1, 1), complex_axis=1)
        self.act = cPReLU()
        self.bn1 = ComplexBatchNorm(out_chan)
        self.drop1 = nn.Dropout(0.2)

    def forward(self, x, dropout=False):
        """Conv -> BN -> cPReLU, then dropout (p=0.2) only if ``dropout`` is truthy."""
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.act(y)
        if not dropout:
            return y
        return self.drop1(y)
class CVConvNeXtBlock(nn.Module):
    """Complex-valued ConvNeXt block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio.

    Pipeline: ComplexConv1D (kernel 3, padding 1) -> split into real/imag halves ->
    for each half: LayerNorm -> Linear -> GELU -> Linear (-> optional layer scale) ->
    re-concatenate -> add the block input (residual connection). The LayerNorm, MLP and
    layer-scale weights are shared between the real and imaginary halves.

    Args:
        dim (int): Channel count given to the conv, LayerNorm and MLP (presumably
            per real/imag part -- confirm against ComplexConv1D).
        intermediate_dim (int): Hidden size of the MLP. Defaults to 1536.
        layer_scale_init_value (float, optional): Initial value of the layer scale.
            ``None`` or a value <= 0 disables scaling. Defaults to 0.125.
        complex_axis (int): Axis holding the stacked real/imag parts. The transposes in
            ``forward`` assume channels on axis 1. Defaults to 1.
    """

    def __init__(
        self,
        dim: int = 512,
        intermediate_dim: int = 1536,
        layer_scale_init_value: float = 0.125,
        complex_axis: int = 1
    ):
        super().__init__()
        # NOTE(review): named "dwconv" but no groups are set, so it is not depthwise here.
        self.dwconv = ComplexConv1D(in_channels=dim, out_channels=dim, kernel_size=3, padding=1, complex_axis=1)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        # Guard against None first: `None > 0` raises TypeError, although None is documented
        # as "no scaling".
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value is not None and layer_scale_init_value > 0
            else None
        )
        self.complex_axis = complex_axis

    def _channel_mlp(self, part: torch.Tensor) -> torch.Tensor:
        """Apply LayerNorm -> MLP (-> layer scale) over the channel dim of one real/imag part."""
        part = part.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        part = self.pwconv2(self.act(self.pwconv1(self.norm(part))))
        if self.gamma is not None:
            part = self.gamma * part
        return part.transpose(1, 2)  # (B, T, C) -> (B, C, T)

    def forward(self, x: torch.Tensor, laterial: torch.Tensor = None) -> torch.Tensor:
        """
        :param x: Stacked real/imag tensor; the output has the same shape (residual add).
        :param laterial: Unused; accepted for interface compatibility with CVConvNeXtDBlock.
        """
        residual = x
        real, imag = torch.chunk(self.dwconv(x), 2, dim=self.complex_axis)
        x = torch.cat([self._channel_mlp(real), self._channel_mlp(imag)], dim=self.complex_axis)
        return residual + x
class CVConvNeXtDBlock(nn.Module):
    """Decoder-side complex-valued ConvNeXt-style block (after https://github.com/facebookresearch/ConvNeXt).

    Pipeline: ComplexTranspose1D (kernel 3, padding 1) -> split into real/imag halves ->
    for each half: LayerNorm -> Linear -> GELU -> Linear (-> optional layer scale) ->
    re-concatenate. There is no residual add. If a lateral (skip) tensor is given, the
    real/imag halves of the result and of the lateral tensor are concatenated part-wise
    and projected by ``out_conv`` (2*dim -> dim channels). The LayerNorm, MLP and
    layer-scale weights are shared between the real and imaginary halves.

    Args:
        dim (int): Channel count given to the conv, LayerNorm and MLP (presumably
            per real/imag part -- confirm against ComplexTranspose1D/ComplexConv1D).
        intermediate_dim (int): Hidden size of the MLP. Defaults to 1536.
        layer_scale_init_value (float, optional): Initial value of the layer scale.
            ``None`` or a value <= 0 disables scaling. Defaults to 0.125.
        complex_axis (int): Axis holding the stacked real/imag parts. The transposes in
            ``forward`` assume channels on axis 1. Defaults to 1.
    """

    def __init__(
        self,
        dim: int = 512,
        intermediate_dim: int = 1536,
        layer_scale_init_value: float = 0.125,
        complex_axis: int = 1
    ):
        super().__init__()
        self.dwconv = ComplexTranspose1D(in_channels=dim, out_channels=dim, kernel_size=3, padding=1, complex_axis=1)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        # Guard against None first: `None > 0` raises TypeError, although None is documented
        # as "no scaling".
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value is not None and layer_scale_init_value > 0
            else None
        )
        self.complex_axis = complex_axis
        # Fuses the block output with the lateral tensor (2*dim -> dim).
        self.out_conv = ComplexConv1D(in_channels=dim*2, out_channels=dim, kernel_size=3, padding=1, complex_axis=1)

    def _channel_mlp(self, part: torch.Tensor) -> torch.Tensor:
        """Apply LayerNorm -> MLP (-> layer scale) over the channel dim of one real/imag part."""
        part = part.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        part = self.pwconv2(self.act(self.pwconv1(self.norm(part))))
        if self.gamma is not None:
            part = self.gamma * part
        return part.transpose(1, 2)  # (B, T, C) -> (B, C, T)

    def forward(self, x: torch.Tensor, laterial: torch.Tensor = None) -> torch.Tensor:
        """
        :param x: Stacked real/imag tensor.
        :param laterial: Optional lateral/skip tensor, also stacked real/imag.
        :return: The MLP output, or the ``out_conv`` fusion of it with ``laterial``.
        """
        real, imag = torch.chunk(self.dwconv(x), 2, dim=self.complex_axis)
        real = self._channel_mlp(real)
        imag = self._channel_mlp(imag)
        if laterial is None:
            return torch.cat([real, imag], dim=self.complex_axis)
        lat_real, lat_imag = torch.chunk(laterial, 2, dim=self.complex_axis)
        # Part-wise concatenation keeps the [real..., imag...] layout along complex_axis.
        fused = torch.cat([real, lat_real, imag, lat_imag], dim=self.complex_axis)
        return self.out_conv(fused)