Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class ComplexConv2d(nn.Module):
    """Complex 2-D convolution assembled from two real-valued ``nn.Conv2d`` layers.

    The real and imaginary parts of the input are stacked along
    ``complex_axis``. ``real_conv`` holds the real part of the kernel and
    ``imag_conv`` the imaginary part; their outputs are combined following
    the complex product ``(a + ib)(c + id) = (ac - bd) + i(ad + bc)``.

    Args:
        in_channels: Input channel count, real + imaginary (halved internally).
        out_channels: Output channel count, real + imaginary (halved internally).
        kernel_size: Kernel size in ``[D, T]`` for an input laid out ``[B, C, D, T]``.
        stride: Stride forwarded to both convolutions.
        padding: Padding in ``[D, T]``. The ``D`` entry is applied by the
            convolutions themselves; the ``T`` entry is applied explicitly
            in ``forward``.
        dilation: Dilation forwarded to both convolutions.
        groups: Group count forwarded to both convolutions.
        causal: If true and ``padding[1]`` is non-zero, only the left side of
            the last dimension is padded; otherwise both sides are padded.
        complex_axis: Axis along which the real/imaginary halves are stacked.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(1, 1),
        stride=(1, 1),
        padding=(0, 0),
        dilation=1,
        groups=1,
        causal=False,
        complex_axis=1,
    ):
        super().__init__()
        # Each real-valued convolution only sees one half of the channels.
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.causal = causal
        self.groups = groups
        self.dilation = dilation
        self.complex_axis = complex_axis

        # The last-dimension padding is done by hand in forward(), so the
        # convolutions only pad the first spatial dimension.
        conv_kwargs = dict(
            padding=[self.padding[0], 0],
            dilation=self.dilation,
            groups=self.groups,
        )
        self.real_conv = nn.Conv2d(
            self.in_channels, self.out_channels, kernel_size, self.stride, **conv_kwargs
        )
        self.imag_conv = nn.Conv2d(
            self.in_channels, self.out_channels, kernel_size, self.stride, **conv_kwargs
        )

        # Weights are re-drawn only after both layers exist, real before imaginary.
        for conv in (self.real_conv, self.imag_conv):
            nn.init.normal_(conv.weight.data, std=0.05)
        for conv in (self.real_conv, self.imag_conv):
            nn.init.constant_(conv.bias, 0.)

    def forward(self, inputs):
        """Pad the last dimension, then apply the complex convolution.

        Returns a tensor with the real and imaginary results concatenated
        along ``complex_axis``.
        """
        time_pad = self.padding[1]
        left_only = time_pad != 0 and self.causal
        right_pad = 0 if left_only else time_pad
        inputs = F.pad(inputs, [time_pad, right_pad, 0, 0])

        axis = self.complex_axis
        if axis == 0:
            # Halves are stacked on axis 0: convolve everything once per
            # kernel part, then split the results.
            from_real_kernel = self.real_conv(inputs)
            from_imag_kernel = self.imag_conv(inputs)
            real2real, imag2real = torch.chunk(from_real_kernel, 2, axis)
            real2imag, imag2imag = torch.chunk(from_imag_kernel, 2, axis)
        else:
            # F.pad above already required a tensor, so it can be split directly.
            x_real, x_imag = torch.chunk(inputs, 2, axis)
            real2real = self.real_conv(x_real)
            imag2imag = self.imag_conv(x_imag)
            real2imag = self.imag_conv(x_real)
            imag2real = self.real_conv(x_imag)

        out_real = real2real - imag2imag
        out_imag = real2imag + imag2real
        return torch.cat([out_real, out_imag], axis)
class ComplexGroupNorm(nn.Module):
    """Group normalisation applied separately to the real and imaginary halves.

    Args:
        num_channels: Total channel count, real + imaginary; each internal
            ``nn.GroupNorm`` is built for ``num_channels // 2`` channels.
        num_groups: Number of groups used by each internal ``nn.GroupNorm``.
        eps: Numerical-stability constant passed to both norms.
        complex_axis: Axis along which the input is split into two halves
            and along which the results are concatenated again.
    """

    def __init__(self, num_channels, num_groups, eps=1e-6, complex_axis=1):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps
        self.complex_axis = complex_axis

        channels_per_part = num_channels // 2
        self.real_norm = nn.GroupNorm(num_groups, channels_per_part, eps=eps)
        self.imag_norm = nn.GroupNorm(num_groups, channels_per_part, eps=eps)

    def forward(self, x):
        """Normalise each half independently and stitch them back together."""
        axis = self.complex_axis
        real_part, imag_part = torch.chunk(x, 2, axis)
        normalised = [self.real_norm(real_part), self.imag_norm(imag_part)]
        return torch.cat(normalised, axis)
class ComplexBatchNorm(torch.nn.Module):
    """Batch normalisation for complex features stored as stacked real/imag halves.

    The input is split in two along ``complex_axis``. Each (real, imag) pair is
    centred and then whitened with the inverse square root of its 2x2
    covariance matrix, so that real and imaginary parts are decorrelated and
    have unit variance. An optional learnable 2x2 weight ``W`` (``Wrr``,
    ``Wri``, ``Wii``) and bias (``Br``, ``Bi``) are applied afterwards.

    Statistics are always reduced over every dimension except dimension 1,
    independently of ``complex_axis``.

    Args:
        num_features: Channel count, real + imaginary (halved internally).
        eps: Value added to the diagonal of the covariance matrix.
        momentum: Interpolation weight for the running statistics. ``None``
            selects a cumulative moving average instead.
        affine: Whether to create the learnable weight and bias.
        track_running_stats: Whether to keep running statistics. When false,
            batch statistics are used in evaluation mode as well.
        complex_axis: Axis along which the real/imaginary halves are stacked.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, complex_axis=1):
        super(ComplexBatchNorm, self).__init__()
        self.num_features = num_features // 2
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        self.complex_axis = complex_axis

        if self.affine:
            # Allocated uninitialised; reset_parameters() fills them in below.
            self.Wrr = torch.nn.Parameter(torch.empty(self.num_features))
            self.Wri = torch.nn.Parameter(torch.empty(self.num_features))
            self.Wii = torch.nn.Parameter(torch.empty(self.num_features))
            self.Br = torch.nn.Parameter(torch.empty(self.num_features))
            self.Bi = torch.nn.Parameter(torch.empty(self.num_features))
        else:
            self.register_parameter('Wrr', None)
            self.register_parameter('Wri', None)
            self.register_parameter('Wii', None)
            self.register_parameter('Br', None)
            self.register_parameter('Bi', None)

        if self.track_running_stats:
            self.register_buffer('RMr', torch.zeros(self.num_features))
            self.register_buffer('RMi', torch.zeros(self.num_features))
            self.register_buffer('RVrr', torch.ones(self.num_features))
            self.register_buffer('RVri', torch.zeros(self.num_features))
            self.register_buffer('RVii', torch.ones(self.num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('RMr', None)
            self.register_parameter('RMi', None)
            self.register_parameter('RVrr', None)
            self.register_parameter('RVri', None)
            self.register_parameter('RVii', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset running mean to 0, running covariance to identity, counter to 0."""
        if self.track_running_stats:
            self.RMr.zero_()
            self.RMi.zero_()
            self.RVrr.fill_(1)
            self.RVri.zero_()
            self.RVii.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        """Reset the running statistics and, if affine, the weight and bias."""
        self.reset_running_stats()
        if self.affine:
            self.Br.data.zero_()
            self.Bi.data.zero_()
            self.Wrr.data.fill_(1)
            self.Wri.data.uniform_(-.9, +.9)  # W will be positive-definite
            self.Wii.data.fill_(1)

    def _check_input_dim(self, xr, xi):
        """Sanity-check helper (not called by ``forward``)."""
        assert (xr.shape == xi.shape)
        assert (xr.size(1) == self.num_features)

    def forward(self, inputs):
        """Whiten ``inputs`` and return a tensor of the same layout."""
        xr, xi = torch.chunk(inputs, 2, self.complex_axis)

        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        # NOTE: The precise meaning of the "training flag" is:
        #   True:  Normalize using batch statistics, update running statistics
        #          if they are being collected.
        #   False: Normalize using running statistics, ignore batch statistics.
        training = self.training or not self.track_running_stats

        # Reduce over every dimension except dim 1; vdim is the broadcast
        # shape used to view per-channel vectors against the input.
        redux = [i for i in reversed(range(xr.dim())) if i != 1]
        vdim = [1] * xr.dim()
        vdim[1] = xr.size(1)

        # Mean M computation and centering.
        if training:
            Mr = xr.mean(dim=redux, keepdim=True)
            Mi = xi.mean(dim=redux, keepdim=True)
            if self.track_running_stats:
                # Detach before the in-place update: otherwise the buffers
                # become part of the autograd graph and every training step
                # chains its graph onto the previous one.
                with torch.no_grad():
                    self.RMr.lerp_(Mr.detach().reshape(-1), exponential_average_factor)
                    self.RMi.lerp_(Mi.detach().reshape(-1), exponential_average_factor)
        else:
            Mr = self.RMr.view(vdim)
            Mi = self.RMi.view(vdim)
        xr, xi = xr - Mr, xi - Mi

        # Covariance matrix V computation (biased estimate).
        if training:
            Vrr = (xr * xr).mean(dim=redux, keepdim=True)
            Vri = (xr * xi).mean(dim=redux, keepdim=True)
            Vii = (xi * xi).mean(dim=redux, keepdim=True)
            if self.track_running_stats:
                # Same reasoning as for the running means above.
                with torch.no_grad():
                    self.RVrr.lerp_(Vrr.detach().reshape(-1), exponential_average_factor)
                    self.RVri.lerp_(Vri.detach().reshape(-1), exponential_average_factor)
                    self.RVii.lerp_(Vii.detach().reshape(-1), exponential_average_factor)
        else:
            Vrr = self.RVrr.view(vdim)
            Vri = self.RVri.view(vdim)
            Vii = self.RVii.view(vdim)
        # Epsilon on the diagonal acts as numerical stabiliser / Tikhonov
        # regulariser; the off-diagonal term is left untouched.
        Vrr = Vrr + self.eps
        Vii = Vii + self.eps

        # Matrix inverse square root U = V^-0.5.
        # sqrt of a 2x2 matrix:
        #   https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
        tau = Vrr + Vii                                        # trace
        delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1)   # determinant
        s = delta.sqrt()
        t = (tau + 2 * s).sqrt()

        # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
        rst = (s * t).reciprocal()
        Urr = (s + Vii) * rst
        Uii = (s + Vrr) * rst
        Uri = (-Vri) * rst

        # Optionally left-multiply U by affine weights W to produce combined
        # weights Z, left-multiply the inputs by Z, then optionally bias them.
        #
        #     y = Zx + B
        #     y = WUx + B
        #     y = [Wrr Wri][Urr Uri] [xr] + [Br]
        #         [Wir Wii][Uir Uii] [xi]   [Bi]
        if self.affine:
            Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
            Zrr = (Wrr * Urr) + (Wri * Uri)
            Zri = (Wrr * Uri) + (Wri * Uii)
            Zir = (Wri * Urr) + (Wii * Uri)
            Zii = (Wri * Uri) + (Wii * Uii)
        else:
            Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii

        yr = (Zrr * xr) + (Zri * xi)
        yi = (Zir * xr) + (Zii * xi)

        if self.affine:
            yr = yr + self.Br.view(vdim)
            yi = yi + self.Bi.view(vdim)

        outputs = torch.cat([yr, yi], self.complex_axis)
        return outputs

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)
class cPReLU(nn.Module):
    """PReLU activation applied independently to the real and imaginary halves.

    Args:
        complex_axis: Axis along which the input is split into its two
            halves and along which the activated halves are re-joined.
    """

    def __init__(self, complex_axis=1):
        super().__init__()
        self.r_prelu = nn.PReLU()
        self.i_prelu = nn.PReLU()
        self.complex_axis = complex_axis

    def forward(self, inputs):
        """Activate each half with its own PReLU and concatenate the results."""
        axis = self.complex_axis
        real_part, imag_part = torch.chunk(inputs, 2, axis)
        activated = [self.r_prelu(real_part), self.i_prelu(imag_part)]
        return torch.cat(activated, axis)
class ComplexConvTranspose2d(nn.Module):
    """Complex 2-D transposed convolution built from two real ``nn.ConvTranspose2d``.

    ``real_conv`` holds the real part of the kernel and ``imag_conv`` the
    imaginary part; outputs are combined following the complex product
    ``(a + ib)(c + id) = (ac - bd) + i(ad + bc)``.

    Args:
        in_channels: Input channel count, real + imaginary (halved internally).
        out_channels: Output channel count, real + imaginary (halved internally).
        kernel_size: Kernel size forwarded to both transposed convolutions.
        stride: Stride forwarded to both transposed convolutions.
        padding: Padding forwarded to both transposed convolutions.
        output_padding: Output padding forwarded to both transposed convolutions.
        causal: Accepted but not used by this layer.
        complex_axis: Axis along which the real/imaginary halves are stacked.
        groups: Group count forwarded to both transposed convolutions.
        dilation: Dilation forwarded to both transposed convolutions.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(1, 1),
        stride=(1, 1),
        padding=(0, 0),
        output_padding=(0, 0),
        causal=False,
        complex_axis=1,
        groups=1,
        dilation=1
    ):
        super(ComplexConvTranspose2d, self).__init__()
        # Each real-valued layer only sees one half of the channels.
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.groups = groups
        self.dilation = dilation
        self.complex_axis = complex_axis

        self.real_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                                            padding=self.padding, output_padding=output_padding,
                                            groups=self.groups, dilation=self.dilation)
        self.imag_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                                            padding=self.padding, output_padding=output_padding,
                                            groups=self.groups, dilation=self.dilation)

        nn.init.normal_(self.real_conv.weight, std=0.05)
        nn.init.normal_(self.imag_conv.weight, std=0.05)
        nn.init.constant_(self.real_conv.bias, 0.)
        nn.init.constant_(self.imag_conv.bias, 0.)

    def forward(self, inputs):
        """Apply the complex transposed convolution.

        Args:
            inputs: Either a tensor whose real/imaginary halves are stacked
                along ``complex_axis``, or a ``(real, imag)`` tuple/list of
                tensors.

        Returns:
            A tensor with real and imaginary results concatenated along
            ``complex_axis``.

        Raises:
            TypeError: If ``inputs`` is neither a tensor nor a tuple/list.
        """
        if isinstance(inputs, torch.Tensor):
            real, imag = torch.chunk(inputs, 2, self.complex_axis)
            stacked_on_axis0 = self.complex_axis == 0
        elif isinstance(inputs, (tuple, list)):
            real, imag = inputs[0], inputs[1]
            # A (real, imag) pair has no stacked tensor to convolve in one
            # go, so it always takes the four-convolution path below.
            stacked_on_axis0 = False
        else:
            raise TypeError(
                "inputs must be a torch.Tensor or a (real, imag) tuple/list, "
                "got {}".format(type(inputs).__name__)
            )

        if stacked_on_axis0:
            # Convolve the stacked tensor once per kernel part and split the
            # results along axis 0 afterwards.
            from_real_kernel = self.real_conv(inputs)
            from_imag_kernel = self.imag_conv(inputs)
            real2real, imag2real = torch.chunk(from_real_kernel, 2, self.complex_axis)
            real2imag, imag2imag = torch.chunk(from_imag_kernel, 2, self.complex_axis)
        else:
            real2real = self.real_conv(real)
            imag2imag = self.imag_conv(imag)
            real2imag = self.imag_conv(real)
            imag2real = self.real_conv(imag)

        real = real2real - imag2imag
        imag = real2imag + imag2real
        out = torch.cat([real, imag], self.complex_axis)
        return out
class ComplexConv1D(nn.Module):
    """Complex 1-D convolution built from two real-valued ``nn.Conv1d`` layers.

    The channel arguments are passed to the underlying convolutions
    unchanged, while ``forward`` splits its input in two along
    ``complex_axis``; each half must therefore carry ``in_channels`` channels.

    Args:
        in_channels: Channels of each (real or imaginary) half of the input.
        out_channels: Channels of each half of the output.
        kernel_size: Kernel size of both convolutions.
        stride: Stride of both convolutions.
        padding: Padding of both convolutions.
        complex_axis: Axis along which the real/imaginary halves are stacked.
    """

    def __init__(
        self,
        in_channels: int = 256,
        out_channels: int = 512,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        complex_axis: int = 1
    ):
        super().__init__()
        conv_args = (in_channels, out_channels, kernel_size, stride, padding)
        self.real_conv = nn.Conv1d(*conv_args)
        self.imag_conv = nn.Conv1d(*conv_args)
        self.complex_axis = complex_axis

    def forward(self, inputs):
        """Convolve a ``[B, 2*C, T]`` input and return the stacked complex result."""
        axis = self.complex_axis
        x_real, x_imag = torch.chunk(inputs, 2, axis)

        real2real = self.real_conv(x_real)
        imag2imag = self.imag_conv(x_imag)
        real2imag = self.imag_conv(x_real)
        imag2real = self.real_conv(x_imag)

        # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
        out_real = real2real - imag2imag
        out_imag = real2imag + imag2real

        # Two per-half results are stacked back along the complex axis.
        return torch.cat([out_real, out_imag], axis)
class ComplexTranspose1D(nn.Module):
    """Complex 1-D transposed convolution built from two ``nn.ConvTranspose1d``.

    The channel arguments are passed to the underlying layers unchanged,
    while ``forward`` splits its input in two along ``complex_axis``; each
    half must therefore carry ``in_channels`` channels.

    Args:
        in_channels: Channels of each (real or imaginary) half of the input.
        out_channels: Channels of each half of the output.
        kernel_size: Kernel size of both transposed convolutions.
        stride: Stride of both transposed convolutions.
        padding: Padding of both transposed convolutions.
        complex_axis: Axis along which the real/imaginary halves are stacked.
    """

    def __init__(
        self,
        in_channels: int = 256,
        out_channels: int = 512,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        complex_axis: int = 1
    ):
        super().__init__()
        layer_args = (in_channels, out_channels, kernel_size, stride, padding)
        self.real_conv = nn.ConvTranspose1d(*layer_args)
        self.imag_conv = nn.ConvTranspose1d(*layer_args)
        self.complex_axis = complex_axis

    def forward(self, inputs):
        """Process a ``[B, 2*C, T]`` input and return the stacked complex result."""
        axis = self.complex_axis
        x_real, x_imag = torch.chunk(inputs, 2, axis)

        real2real = self.real_conv(x_real)
        imag2imag = self.imag_conv(x_imag)
        real2imag = self.imag_conv(x_real)
        imag2real = self.real_conv(x_imag)

        # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
        out_real = real2real - imag2imag
        out_imag = real2imag + imag2real

        # Two per-half results are stacked back along the complex axis.
        return torch.cat([out_real, out_imag], axis)
if __name__ == "__main__":
    # Smoke test for ComplexConv1D.
    # The layer splits its input in two along the channel axis and feeds each
    # half to an nn.Conv1d(in_channels, out_channels, ...), so the input must
    # be 3-D with 2 * in_channels channels: [B, 2*C, T].
    batch_size = 4
    in_channels = 256
    out_channels = 512
    width = 256
    x = torch.randn(batch_size, 2 * in_channels, width)
    tokenizer = ComplexConv1D(in_channels, out_channels)
    tokens = tokenizer(x)
    print("Input shape:", x.shape)
    print("Output shape:", tokens.shape)