#!/usr/bin/python3
# -*- coding: utf-8 -*-
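"""
Complex-valued building blocks for the FRCRN model: complex UniDeepFsmn
layers, complex 2-D convolution / transposed convolution, and complex
batch normalization. All layers expect the last tensor dimension to hold
the real and imaginary parts, i.e. shape [..., 2].
"""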
from typing import Union, Tuple

import torch
import torch.nn as nn

from toolbox.torchaudio.models.frcrn.uni_deep_fsmn import UniDeepFsmn


class ComplexUniDeepFsmn(nn.Module):
    def __init__(self, input_dim: int, hidden_size: int, lorder: int = 20):
        super().__init__()

        self.fsmn_re_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_im_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_re_l2 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_im_l2 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
    def forward(self, x: torch.Tensor):
        """
        :param x: torch.Tensor, shape: [b, c, h, t, 2]
        :return: torch.Tensor, shape: [b, c, h, t, 2]
        """
        b, c, h, t, d = x.size()

        x = torch.reshape(x, shape=(b, c * h, t, d))
        # x shape: [b, h', t, 2], where h' = c * h
        x = torch.transpose(x, dim0=1, dim1=2)
        # x shape: [b, t, h', 2]

        real_l1 = self.fsmn_re_l1(x[..., 0]) - self.fsmn_im_l1(x[..., 1])
        imaginary_l1 = self.fsmn_re_l1(x[..., 1]) + self.fsmn_im_l1(x[..., 0])
        # real, imaginary shape: [b, t, h']

        real = self.fsmn_re_l2(real_l1) - self.fsmn_im_l2(imaginary_l1)
        imaginary = self.fsmn_re_l2(imaginary_l1) + self.fsmn_im_l2(real_l1)
        # real, imaginary shape: [b, t, h']

        output = torch.stack(tensors=(real, imaginary), dim=-1)
        # output shape: [b, t, h', 2]
        output = torch.transpose(output, dim0=1, dim1=2)
        # output shape: [b, h', t, 2]
        output = torch.reshape(output, shape=(b, c, h, t, d))
        # output shape: [b, c, h, t, 2]
        return output


class ComplexUniDeepFsmnL1(nn.Module):
    def __init__(self, input_dim: int, hidden_size: int, lorder: int = 20):
        super().__init__()

        self.fsmn_re_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_im_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
    def forward(self, x: torch.Tensor):
        """
        :param x: torch.Tensor, shape: [b, c, h, t, 2]
        :return: torch.Tensor, shape: [b, c, h, t, 2]
        """
        b, c, h, t, d = x.size()

        x = torch.transpose(x, dim0=1, dim1=3)
        # x shape: [b, t, h, c, 2]
        x = torch.reshape(x, shape=(b * t, h, c, d))
        # x shape: [b*t, h, c, 2]

        real = self.fsmn_re_l1(x[..., 0]) - self.fsmn_im_l1(x[..., 1])
        imaginary = self.fsmn_re_l1(x[..., 1]) + self.fsmn_im_l1(x[..., 0])
        # real, imaginary shape: [b*t, h, c]

        output = torch.stack(tensors=(real, imaginary), dim=-1)
        # output shape: [b*t, h, c, 2]
        output = torch.reshape(output, shape=(b, t, h, c, d))
        # output shape: [b, t, h, c, 2]
        output = torch.transpose(output, dim0=1, dim1=3)
        # output shape: [b, c, h, t, 2]
        return output


class ComplexConv2d(nn.Module):
    # https://github.com/litcoderr/ComplexCNN/blob/master/complexcnn/modules.py
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True,
                 **kwargs
                 ):
        super().__init__()

        # Model components
        self.conv_re = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            **kwargs
        )
        self.conv_im = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            **kwargs
        )
    def forward(self, x: torch.Tensor):
        """
        :param x: torch.Tensor, shape: [b, c, h, w, 2]
        :return: torch.Tensor, shape: [b, c', h', w', 2], where c' is
                 out_channels and h', w' follow the usual Conv2d arithmetic.
        """
        real = self.conv_re(x[..., 0]) - self.conv_im(x[..., 1])
        imaginary = self.conv_re(x[..., 1]) + self.conv_im(x[..., 0])

        output = torch.stack((real, imaginary), dim=-1)
        return output


class ComplexConvTranspose2d(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 output_padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True,
                 **kwargs
                 ):
        super().__init__()

        # Model components
        self.tconv_re = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            **kwargs
        )
        self.tconv_im = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            **kwargs
        )
    def forward(self, x: torch.Tensor):
        """
        :param x: torch.Tensor, shape: [b, c, h, w, 2]
        :return: torch.Tensor, shape: [b, c', h', w', 2], where c' is
                 out_channels and h', w' follow the usual ConvTranspose2d
                 arithmetic.
        """
        real = self.tconv_re(x[..., 0]) - self.tconv_im(x[..., 1])
        imaginary = self.tconv_re(x[..., 1]) + self.tconv_im(x[..., 0])

        output = torch.stack((real, imaginary), dim=-1)
        return output


class ComplexBatchNorm2d(nn.Module):
    def __init__(self,
                 num_features: int,
                 eps: float = 1e-5,
                 momentum: float = 0.1,
                 affine: bool = True,
                 track_running_stats: bool = True,
                 **kwargs
                 ):
        super().__init__()
        self.bn_re = nn.BatchNorm2d(
            num_features=num_features,
            momentum=momentum,
            affine=affine,
            eps=eps,
            track_running_stats=track_running_stats,
            **kwargs
        )
        self.bn_im = nn.BatchNorm2d(
            num_features=num_features,
            momentum=momentum,
            affine=affine,
            eps=eps,
            track_running_stats=track_running_stats,
            **kwargs
        )
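
    # Note: this is the "naive" complex batch norm that normalizes the real
    # and imaginary parts independently; it is not the whitening-based
    # complex batch norm of Trabelsi et al., "Deep Complex Networks".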
    def forward(self, x: torch.Tensor):
        real = self.bn_re(x[..., 0])
        imag = self.bn_im(x[..., 1])
        output = torch.stack((real, imag), dim=-1)
        return output


def main():
    # x = torch.rand(size=(1, 1, 32, 200, 2))
    # fsmn = ComplexUniDeepFsmn(
    #     input_dim=32,
    #     hidden_size=64,
    # )
    # result = fsmn.forward(x)
    # print(result.shape)

    # x = torch.rand(size=(1, 32, 32, 200, 2))
    # fsmn = ComplexUniDeepFsmnL1(
    #     input_dim=32,
    #     hidden_size=64,
    # )
    # result = fsmn.forward(x)
    # print(result.shape)

    # x = torch.rand(size=(1, 32, 200, 200, 2))
    x = torch.rand(size=(1, 1, 320, 200, 2))
    conv2d = ComplexConv2d(
        in_channels=1,
        out_channels=128,
        kernel_size=(5, 2),
        stride=(2, 1),
        padding=(0, 1),
    )
    result = conv2d.forward(x)
    print(result.shape)
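    # expected: torch.Size([1, 128, 158, 201, 2]), since
    # h' = (320 - 5) // 2 + 1 = 158 and w' = (200 + 2*1 - 2) // 1 + 1 = 201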

    # x = torch.rand(size=(1, 32, 200, 200, 2))
    # x = torch.rand(size=(1, 64, 15, 2000, 2))
    # tconv = ComplexConvTranspose2d(
    #     in_channels=64,
    #     out_channels=32,
    #     kernel_size=(3, 3),
    #     stride=(2, 1),
    #     padding=(0, 1),
    # )
    # result = tconv.forward(x)
    # print(result.shape)
    return


if __name__ == "__main__":
    main()