# Source: U-Past/modules/blocks/complexmodule.py
# Hugging Face Space by lycaoduong — commit e8160b2 ("Initial space").
import torch
import torch.nn as nn
import torch.nn.functional as F
class ComplexConv2d(nn.Module):
    """2-D complex convolution.

    Real and imaginary parts are concatenated along ``complex_axis``.
    Two real ``nn.Conv2d`` layers play the roles of the real (``real_conv``)
    and imaginary (``imag_conv``) kernel. They are combined with the complex
    product rule ``(a + ib)(Wr + iWi) = (aWr - bWi) + i(aWi + bWr)``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(1, 1),
        stride=(1, 1),
        padding=(0, 0),
        dilation=1,
        groups=1,
        causal=False,
        complex_axis=1,
    ):
        '''
        in_channels: real+imag
        out_channels: real+imag
        kernel_size : input [B,C,D,T] kernel size in [D,T]
        padding : input [B,C,D,T] padding in [D,T]; an int pads both dims equally
        causal: if causal, only the left side of the time dimension is padded,
            otherwise both sides
        '''
        super(ComplexConv2d, self).__init__()
        if isinstance(padding, int):
            # Accept an int like nn.Conv2d does; forward() indexes padding[1].
            padding = (padding, padding)
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.causal = causal
        self.groups = groups
        self.dilation = dilation
        self.complex_axis = complex_axis
        # Only the D (frequency) padding is delegated to the convolutions;
        # T (time) padding is applied manually in forward() so that causal
        # mode can pad the left side only.
        self.real_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                                   padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)
        self.imag_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                                   padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)
        nn.init.normal_(self.real_conv.weight.data, std=0.05)
        nn.init.normal_(self.imag_conv.weight.data, std=0.05)
        nn.init.constant_(self.real_conv.bias, 0.)
        nn.init.constant_(self.imag_conv.bias, 0.)

    def forward(self, inputs):
        """Apply the complex convolution; returns real and imag parts concatenated on ``complex_axis``."""
        time_pad = self.padding[1]
        if time_pad != 0:
            left, right = (time_pad, 0) if self.causal else (time_pad, time_pad)
            inputs = F.pad(inputs, [left, right, 0, 0])
        if self.complex_axis == 0:
            # Real/imag halves presumably stacked along dim 0 (batch): convolve
            # the whole stack, then split each result back into its two halves.
            real = self.real_conv(inputs)
            imag = self.imag_conv(inputs)
            real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
            real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)
        else:
            real, imag = torch.chunk(inputs, 2, self.complex_axis)
            real2real = self.real_conv(real)
            imag2imag = self.imag_conv(imag)
            real2imag = self.imag_conv(real)
            imag2real = self.real_conv(imag)
        real = real2real - imag2imag
        imag = real2imag + imag2real
        return torch.cat([real, imag], self.complex_axis)
class ComplexGroupNorm(nn.Module):
    """Group normalization applied separately to the real and imaginary halves.

    The input is split in two along ``complex_axis``. Each half goes through
    its own ``nn.GroupNorm`` over ``num_channels // 2`` channels, and the two
    results are concatenated back in real-then-imag order.
    """

    def __init__(self, num_channels, num_groups, eps=1e-6, complex_axis=1):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps
        self.complex_axis = complex_axis
        per_part_channels = num_channels // 2
        self.real_norm = nn.GroupNorm(num_groups, per_part_channels, eps=eps)
        self.imag_norm = nn.GroupNorm(num_groups, per_part_channels, eps=eps)

    def forward(self, x):
        axis = self.complex_axis
        re_part, im_part = x.chunk(2, dim=axis)
        return torch.cat((self.real_norm(re_part), self.imag_norm(im_part)), dim=axis)
class ComplexBatchNorm(torch.nn.Module):
    """Complex-valued batch normalization based on 2x2 whitening.

    The input is split into real/imaginary halves along ``complex_axis``.
    For every channel (dim 1 of each half), the pair ``(xr, xi)`` is centred
    and multiplied by ``V^-1/2``. Here ``V`` is the 2x2 covariance matrix of
    the pair, regularized by ``eps`` on its diagonal. When ``affine`` is true,
    the whitened pair is then multiplied by the symmetric matrix
    ``[[Wrr, Wri], [Wri, Wii]]`` and shifted by ``(Br, Bi)``.

    ``num_features`` counts real + imaginary channels together. Every
    parameter and running statistic has ``num_features // 2`` entries.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, complex_axis=1):
        super(ComplexBatchNorm, self).__init__()
        self.num_features = num_features // 2
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        self.complex_axis = complex_axis
        if self.affine:
            self.Wrr = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Wri = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Wii = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Br = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Bi = torch.nn.Parameter(torch.Tensor(self.num_features))
        else:
            self.register_parameter('Wrr', None)
            self.register_parameter('Wri', None)
            self.register_parameter('Wii', None)
            self.register_parameter('Br', None)
            self.register_parameter('Bi', None)
        if self.track_running_stats:
            # RM*: running means; RV*: running 2x2 covariance entries.
            self.register_buffer('RMr', torch.zeros(self.num_features))
            self.register_buffer('RMi', torch.zeros(self.num_features))
            self.register_buffer('RVrr', torch.ones(self.num_features))
            self.register_buffer('RVri', torch.zeros(self.num_features))
            self.register_buffer('RVii', torch.ones(self.num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            # These are statistics, not learnable values: register them as
            # (absent) buffers, matching how they are registered when enabled.
            self.register_buffer('RMr', None)
            self.register_buffer('RMi', None)
            self.register_buffer('RVrr', None)
            self.register_buffer('RVri', None)
            self.register_buffer('RVii', None)
            self.register_buffer('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset running means to 0, the running covariance to identity, and the batch counter to 0."""
        if self.track_running_stats:
            self.RMr.zero_()
            self.RMi.zero_()
            self.RVrr.fill_(1)
            self.RVri.zero_()
            self.RVii.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        """Reset running statistics and (if affine) initialize W and B."""
        self.reset_running_stats()
        if self.affine:
            self.Br.data.zero_()
            self.Bi.data.zero_()
            self.Wrr.data.fill_(1)
            self.Wri.data.uniform_(-.9, +.9)  # W will be positive-definite
            self.Wii.data.fill_(1)

    def _check_input_dim(self, xr, xi):
        assert (xr.shape == xi.shape)
        assert (xr.size(1) == self.num_features)

    def forward(self, inputs):
        """Whiten (and optionally affine-transform) the complex input."""
        xr, xi = torch.chunk(inputs, 2, dim=self.complex_axis)

        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        # True: normalize with batch statistics and update running statistics
        #       if they are being collected.
        # False: normalize with running statistics, ignore batch statistics.
        training = self.training or not self.track_running_stats
        # Reduce over every dim except 1 (channels); vdim broadcasts per-channel values.
        redux = [i for i in reversed(range(xr.dim())) if i != 1]
        vdim = [1] * xr.dim()
        vdim[1] = xr.size(1)

        # Mean M computation and centering.
        if training:
            Mr, Mi = xr, xi
            for d in redux:
                Mr = Mr.mean(d, keepdim=True)
                Mi = Mi.mean(d, keepdim=True)
            if self.track_running_stats:
                # Without no_grad the buffers would be attached to the
                # autograd graph (requires_grad=True, growing grad_fn chain).
                with torch.no_grad():
                    self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
                    self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
        else:
            Mr = self.RMr.view(vdim)
            Mi = self.RMi.view(vdim)
        xr, xi = xr - Mr, xi - Mi

        # Covariance matrix V computation.
        if training:
            Vrr = xr * xr
            Vri = xr * xi
            Vii = xi * xi
            for d in redux:
                Vrr = Vrr.mean(d, keepdim=True)
                Vri = Vri.mean(d, keepdim=True)
                Vii = Vii.mean(d, keepdim=True)
            if self.track_running_stats:
                with torch.no_grad():
                    self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
                    self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
                    self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
        else:
            Vrr = self.RVrr.view(vdim)
            Vri = self.RVri.view(vdim)
            Vii = self.RVii.view(vdim)
        # Epsilon numerical stabilizer (Tikhonov regularizer) on the diagonal.
        Vrr = Vrr + self.eps
        Vii = Vii + self.eps

        # Matrix inverse square root U = V^-0.5 of the 2x2 matrix, see
        # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
        tau = Vrr + Vii
        delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1)  # det(V)
        s = delta.sqrt()
        t = (tau + 2 * s).sqrt()
        # Inverse of sqrt(V) = adj(sqrt(V) * t) / (s * t).
        rst = (s * t).reciprocal()
        Urr = (s + Vii) * rst
        Uii = (s + Vrr) * rst
        Uri = (- Vri) * rst

        # y = W U x + B, with symmetric W = [[Wrr, Wri], [Wri, Wii]] and
        # symmetric U = [[Urr, Uri], [Uri, Uii]].
        if self.affine:
            Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
            Zrr = (Wrr * Urr) + (Wri * Uri)
            Zri = (Wrr * Uri) + (Wri * Uii)
            Zir = (Wri * Urr) + (Wii * Uri)
            Zii = (Wri * Uri) + (Wii * Uii)
        else:
            Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
        yr = (Zrr * xr) + (Zri * xi)
        yi = (Zir * xr) + (Zii * xi)
        if self.affine:
            yr = yr + self.Br.view(vdim)
            yi = yi + self.Bi.view(vdim)
        return torch.cat([yr, yi], self.complex_axis)

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
            'track_running_stats={track_running_stats}'.format(**self.__dict__)
class cPReLU(nn.Module):
    """Complex PReLU: independent ``nn.PReLU`` on the real and imaginary halves.

    The input is split in two along ``complex_axis``. The first half goes
    through ``r_prelu``, the second through ``i_prelu``, and the halves are
    concatenated back in the same order.
    """

    def __init__(self, complex_axis=1):
        super().__init__()
        self.r_prelu = nn.PReLU()
        self.i_prelu = nn.PReLU()
        self.complex_axis = complex_axis

    def forward(self, inputs):
        axis = self.complex_axis
        re_part, im_part = inputs.chunk(2, dim=axis)
        return torch.cat((self.r_prelu(re_part), self.i_prelu(im_part)), dim=axis)
class ComplexConvTranspose2d(nn.Module):
    """2-D complex transposed convolution.

    Real and imaginary parts are concatenated along ``complex_axis``, or
    passed as a ``(real, imag)`` tuple/list. Two real ``nn.ConvTranspose2d``
    layers act as the real (``real_conv``) and imaginary (``imag_conv``)
    kernel, combined as ``(a + ib)(Wr + iWi) = (aWr - bWi) + i(aWi + bWr)``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(1, 1),
        stride=(1, 1),
        padding=(0, 0),
        output_padding=(0, 0),
        causal=False,
        complex_axis=1,
        groups=1,
        dilation=1
    ):
        '''
        in_channels: real+imag
        out_channels: real+imag
        causal: accepted for signature parity with ComplexConv2d; it has no
            effect on this layer.
        '''
        super(ComplexConvTranspose2d, self).__init__()
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.groups = groups
        self.dilation = dilation
        self.causal = causal
        self.real_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                                            padding=self.padding, output_padding=output_padding, groups=self.groups,
                                            dilation=self.dilation)
        self.imag_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                                            padding=self.padding, output_padding=output_padding, groups=self.groups,
                                            dilation=self.dilation)
        self.complex_axis = complex_axis
        nn.init.normal_(self.real_conv.weight, std=0.05)
        nn.init.normal_(self.imag_conv.weight, std=0.05)
        nn.init.constant_(self.real_conv.bias, 0.)
        nn.init.constant_(self.imag_conv.bias, 0.)

    def forward(self, inputs):
        """Apply the complex transposed convolution.

        inputs: a tensor holding both parts along ``complex_axis``, or a
            ``(real, imag)`` tuple/list.
        Returns real and imag parts concatenated along ``complex_axis``.
        Raises TypeError for any other input type.
        """
        if isinstance(inputs, torch.Tensor):
            real, imag = torch.chunk(inputs, 2, self.complex_axis)
            stacked = inputs
        elif isinstance(inputs, (tuple, list)):
            real, imag = inputs[0], inputs[1]
            stacked = None
        else:
            raise TypeError(
                f"expected a Tensor or a (real, imag) tuple/list, got {type(inputs).__name__}"
            )
        if self.complex_axis == 0:
            # Halves stacked along dim 0: convolve the whole stack, then split.
            if stacked is None:
                stacked = torch.cat([real, imag], 0)
            real_out = self.real_conv(stacked)
            imag_out = self.imag_conv(stacked)
            real2real, imag2real = torch.chunk(real_out, 2, self.complex_axis)
            real2imag, imag2imag = torch.chunk(imag_out, 2, self.complex_axis)
        else:
            real2real = self.real_conv(real)
            imag2imag = self.imag_conv(imag)
            real2imag = self.imag_conv(real)
            imag2real = self.real_conv(imag)
        real = real2real - imag2imag
        imag = real2imag + imag2real
        return torch.cat([real, imag], self.complex_axis)
class ComplexConv1D(nn.Module):
    """1-D complex convolution on ``[B, 2*C, T]`` inputs (real half first).

    Unlike ``ComplexConv2d``, ``in_channels`` and ``out_channels`` are
    per-component counts. Each half of the input must have ``in_channels``
    channels, and the output has ``2 * out_channels`` channels.
    """

    def __init__(
        self,
        in_channels: int = 256,
        out_channels: int = 512,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        complex_axis: int = 1
    ):
        super().__init__()
        conv_args = (in_channels, out_channels, kernel_size, stride, padding)
        self.real_conv = nn.Conv1d(*conv_args)
        self.imag_conv = nn.Conv1d(*conv_args)
        self.complex_axis = complex_axis

    def forward(self, inputs):
        axis = self.complex_axis
        re_in, im_in = inputs.chunk(2, dim=axis)
        # (a + ib)(Wr + iWi) = (aWr - bWi) + i(aWi + bWr)
        re_out = self.real_conv(re_in) - self.imag_conv(im_in)
        im_out = self.imag_conv(re_in) + self.real_conv(im_in)
        return torch.cat((re_out, im_out), dim=axis)
class ComplexTranspose1D(nn.Module):
    """1-D complex transposed convolution on ``[B, 2*C, T]`` inputs (real half first).

    ``in_channels`` and ``out_channels`` are per-component counts. Each half
    of the input must have ``in_channels`` channels, and the output has
    ``2 * out_channels`` channels.
    """

    def __init__(
        self,
        in_channels: int = 256,
        out_channels: int = 512,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        complex_axis: int = 1
    ):
        super().__init__()
        deconv_args = (in_channels, out_channels, kernel_size, stride, padding)
        self.real_conv = nn.ConvTranspose1d(*deconv_args)
        self.imag_conv = nn.ConvTranspose1d(*deconv_args)
        self.complex_axis = complex_axis

    def forward(self, inputs):
        axis = self.complex_axis
        re_in, im_in = inputs.chunk(2, dim=axis)
        # (a + ib)(Wr + iWi) = (aWr - bWi) + i(aWi + bWr)
        re_out = self.real_conv(re_in) - self.imag_conv(im_in)
        im_out = self.imag_conv(re_in) + self.real_conv(im_in)
        return torch.cat((re_out, im_out), dim=axis)
if __name__ == "__main__":
    # Smoke test for ComplexConv1D. Its channel counts are per component, so
    # the input carries 2 * in_channels channels ([real | imag]) and the
    # output carries 2 * out_channels channels.
    batch_size = 4
    in_channels = 256
    out_channels = 512
    length = 256
    x = torch.randn(batch_size, 2 * in_channels, length)
    tokenizer = ComplexConv1D(in_channels, out_channels)
    tokens = tokenizer(x)
    print("Input shape:", x.shape)
    print("Output shape:", tokens.shape)