import torch
import torch.nn as nn
import torch.nn.functional as Func  # aliased as Func because F is used as a shape variable below


class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Func.normalize divides by the L2 norm; rescaling by sqrt(dim) turns
        # this into division by the root-mean-square, i.e. RMS normalization.
        return Func.normalize(x, dim=-1) * self.scale * self.gamma
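
# Usage sketch (illustrative values): RMSNorm preserves the input shape.
#   norm = RMSNorm(dim=64)
#   y = norm(torch.randn(2, 10, 64))  # -> shape (2, 10, 64)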


class MambaModule(nn.Module):
    def __init__(self, d_model, d_state, d_conv, d_expand):
        super().__init__()
        # Imported here so that mamba_ssm remains an optional dependency.
        from mamba_ssm.modules.mamba_simple import Mamba

        self.norm = RMSNorm(dim=d_model)
        # Mamba's expansion-factor argument is named "expand".
        self.mamba = Mamba(
            d_model=d_model, d_state=d_state, d_conv=d_conv, expand=d_expand
        )

    def forward(self, x):
        # Pre-norm residual connection around the Mamba block.
        x = x + self.mamba(self.norm(x))
        return x


class RNNModule(nn.Module):
    """
    RNNModule class implements a recurrent neural network module with LSTM cells.

    Args:
    - input_dim (int): Dimensionality of the input features.
    - hidden_dim (int): Dimensionality of the hidden state of the LSTM.
    - bidirectional (bool, optional): If True, uses a bidirectional LSTM. Defaults to True.

    Shapes:
    - Input: (B, T, D) where
        B is batch size,
        T is sequence length,
        D is input dimensionality.
    - Output: (B, T, D) where
        B is batch size,
        T is sequence length,
        D is input dimensionality.
    """

    def __init__(self, input_dim: int, hidden_dim: int, bidirectional: bool = True):
        """
        Initializes RNNModule with input dimension, hidden dimension, and bidirectional flag.
        """
        super().__init__()
        # With num_groups=1, GroupNorm normalizes across all channels jointly.
        self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=input_dim)
        self.rnn = nn.LSTM(
            input_dim, hidden_dim, batch_first=True, bidirectional=bidirectional
        )
        # A bidirectional LSTM concatenates both directions, doubling the feature size.
        self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, input_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the RNNModule.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, T, D).

        Returns:
        - torch.Tensor: Output tensor of shape (B, T, D).
        """
        # GroupNorm expects (B, C, T), so move channels next to the batch axis.
        x = x.transpose(1, 2)
        x = self.groupnorm(x)
        x = x.transpose(1, 2)

        x, _ = self.rnn(x)  # the final hidden and cell states are not used
        x = self.fc(x)
        return x


class RFFTModule(nn.Module):
    """
    RFFTModule class implements a module for performing a real-valued Fast Fourier
    Transform (FFT) or its inverse on input tensors.

    Args:
    - inverse (bool, optional): If False, performs the forward FFT. If True, performs the inverse FFT. Defaults to False.

    Shapes:
    - Input: (B, F, T, D) where
        B is batch size,
        F is the number of features,
        T is sequence length,
        D is input dimensionality.
    - Output: (B, F, T // 2 + 1, D * 2) if performing the forward FFT.
              (B, F, time_dim, D // 2) if performing the inverse FFT, where
              time_dim is the original time length passed to forward().
    """

    def __init__(self, inverse: bool = False):
        """
        Initializes RFFTModule with inverse flag.
        """
        super().__init__()
        self.inverse = inverse

    def forward(self, x: torch.Tensor, time_dim: int) -> torch.Tensor:
        """
        Performs the forward or inverse FFT on the input tensor x.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, F, T, D).
        - time_dim (int): Original size of the time dimension, needed to invert the FFT exactly.

        Returns:
        - torch.Tensor: Output tensor after the FFT or its inverse.
        """
        dtype = x.dtype
        B, F, T, D = x.shape

        # Run the FFT in float32 (half-precision FFT support is limited).
        x = x.float()

        if not self.inverse:
            # rfft along time yields T // 2 + 1 complex bins; fold the trailing
            # real/imaginary pair produced by view_as_real into the feature axis.
            x = torch.fft.rfft(x, dim=2)
            x = torch.view_as_real(x)
            x = x.reshape(B, F, T // 2 + 1, D * 2)
        else:
            # Unfold the feature axis back into (D // 2) complex bins and invert.
            x = x.reshape(B, F, T, D // 2, 2)
            x = torch.view_as_complex(x)
            x = torch.fft.irfft(x, n=time_dim, dim=2)

        x = x.to(dtype)
        return x

    def extra_repr(self) -> str:
        """
        Returns extra representation string with module's configuration.
        """
        return f"inverse={self.inverse}"
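
# Shape round trip (illustrative): with x of shape (B, F, T, D),
#   RFFTModule(inverse=False)(x, T)  -> (B, F, T // 2 + 1, 2 * D)
#   RFFTModule(inverse=True)(y, T)   -> (B, F, T, D)
# This is why DualPathRNN below doubles its layer width on even layers.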


class DualPathRNN(nn.Module):
    """
    DualPathRNN class implements a neural network with alternating layers of RNNModule
    (or MambaModule) and RFFTModule.

    Args:
    - n_layers (int): Number of layers in the network. Expected to be even, so that
      every forward FFT layer (odd layers) is paired with an inverse FFT layer
      (even layers) and the output shape matches the input.
    - input_dim (int): Dimensionality of the input features.
    - hidden_dim (int): Dimensionality of the hidden state of the RNNModule.
    - use_mamba (bool, optional): If True, uses MambaModule instead of RNNModule. Defaults to False.
    - d_state (int, optional): Mamba state dimension. Defaults to 16.
    - d_conv (int, optional): Mamba convolution kernel size. Defaults to 4.
    - d_expand (int, optional): Mamba block expansion factor. Defaults to 2.

    Shapes:
    - Input: (B, F, T, D) where
        B is batch size,
        F is the number of features (frequency dimension),
        T is sequence length (time dimension),
        D is input dimensionality (channel dimension).
    - Output: (B, F, T, D) where
        B is batch size,
        F is the number of features (frequency dimension),
        T is sequence length (time dimension),
        D is input dimensionality (channel dimension).
    """

    def __init__(
        self,
        n_layers: int,
        input_dim: int,
        hidden_dim: int,
        use_mamba: bool = False,
        d_state: int = 16,
        d_conv: int = 4,
        d_expand: int = 2,
    ):
        """
        Initializes DualPathRNN with the specified number of layers, input dimension, and hidden dimension.
        """
        super().__init__()

        if use_mamba:
            # Fail fast here if the optional mamba_ssm dependency is missing;
            # MambaModule performs the actual import when it is constructed.
            from mamba_ssm.modules.mamba_simple import Mamba  # noqa: F401

            net = MambaModule
            dkwargs = {
                "d_model": input_dim,
                "d_state": d_state,
                "d_conv": d_conv,
                "d_expand": d_expand,
            }
            ukwargs = {
                "d_model": input_dim * 2,
                "d_state": d_state,
                "d_conv": d_conv,
                "d_expand": d_expand * 2,
            }
        else:
            net = RNNModule
            dkwargs = {"input_dim": input_dim, "hidden_dim": hidden_dim}
            ukwargs = {"input_dim": input_dim * 2, "hidden_dim": hidden_dim * 2}

        self.layers = nn.ModuleList()
        for i in range(1, n_layers + 1):
            # Odd layers run at the base width; even layers run at double width,
            # because the forward FFT in the preceding layer doubles the channels.
            kwargs = dkwargs if i % 2 == 1 else ukwargs
            layer = nn.ModuleList(
                [
                    net(**kwargs),
                    net(**kwargs),
                    RFFTModule(inverse=(i % 2 == 0)),
                ]
            )
            self.layers.append(layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the DualPathRNN.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, F, T, D).

        Returns:
        - torch.Tensor: Output tensor of shape (B, F, T, D).
        """
        time_dim = x.shape[2]

        for time_layer, freq_layer, rfft_layer in self.layers:
            B, F, T, D = x.shape

            # Time path: process each frequency band as an independent sequence.
            x = x.reshape((B * F), T, D)
            x = time_layer(x)
            x = x.reshape(B, F, T, D)
            x = x.permute(0, 2, 1, 3)

            # Frequency path: process each time frame as an independent sequence.
            x = x.reshape((B * T), F, D)
            x = freq_layer(x)
            x = x.reshape(B, T, F, D)
            x = x.permute(0, 2, 1, 3)

            # Alternate between the forward FFT (odd layers) and its inverse (even layers).
            x = rfft_layer(x, time_dim)

        return x
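

if __name__ == "__main__":
    # Minimal smoke test (illustrative sizes, not from the original source):
    # with an even n_layers, the output shape must match the input shape.
    B, F, T, D = 2, 5, 8, 16
    model = DualPathRNN(n_layers=2, input_dim=D, hidden_dim=32)
    x = torch.randn(B, F, T, D)
    y = model(x)
    assert y.shape == (B, F, T, D), y.shape
    print("DualPathRNN output shape:", tuple(y.shape))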