from typing import *
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import librosa
from torch import Tensor
from torch.nn import Parameter, init
from torch.nn.common_types import _size_1_t
from model.mamba.mamba import Mamba
from model.mamba.utils.generation import InferenceParams


class LinearGroup(nn.Module):
    """Group-wise linear layer: every group owns its own weight matrix and bias.

    Args:
        in_features: size of the last (feature) dim of the input.
        out_features: size of the last (feature) dim of the output.
        num_groups: number of independent linear maps (second-to-last input dim).
        bias: whether to learn a per-group additive bias.
    """

    def __init__(self, in_features: int, out_features: int, num_groups: int, bias: bool = True) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_groups = num_groups
        self.weight = Parameter(torch.empty((num_groups, out_features, in_features)))
        if bias:
            self.bias = Parameter(torch.empty(num_groups, out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise weight and bias with the same recipe nn.Linear uses."""
        # NOTE(review): the weight is 3-D here, so the fan-in PyTorch derives may
        # differ from that of a plain nn.Linear(in_features, ...) - confirm intended.
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, x: Tensor) -> Tensor:
        """Map ``x`` of shape [..., group, in_features] to [..., group, out_features]."""
        x = torch.einsum("...gh,gkh->...gk", x, self.weight)
        if self.bias is not None:
            x = x + self.bias
        return x

    def extra_repr(self) -> str:
        return f"{self.in_features}, {self.out_features}, num_groups={self.num_groups}, bias={self.bias is not None}"


class LayerNorm(nn.LayerNorm):
    """nn.LayerNorm that can also normalise a feature dim stored at index 1."""

    def __init__(self, seq_last: bool, **kwargs) -> None:
        """
        Args:
            seq_last (bool): whether the sequence dim is the last dim
            **kwargs: forwarded unchanged to ``nn.LayerNorm``
        """
        super().__init__(**kwargs)
        self.seq_last = seq_last

    def forward(self, input: Tensor) -> Tensor:
        """Normalise ``input``; with ``seq_last`` dims 1 and -1 are swapped around the norm."""
        if self.seq_last:
            input = input.transpose(-1, 1)  # [B, H, Seq] -> [B, Seq, H], or [B,H,w,h] -> [B,h,w,H]
        o = super().forward(input)
        if self.seq_last:
            o = o.transpose(-1, 1)
        return o


class CausalConv1d(nn.Conv1d):
    """Conv1d that pads the time axis itself and can run chunk-by-chunk.

    Without a cached state the input is padded with ``kernel_size - 1`` frames
    in total: ``look_ahead`` on the right and the rest on the left.

    Args:
        look_ahead: number of future frames the convolution may see; must not
            exceed ``kernel_size - 1``.
        (all other arguments are forwarded to ``nn.Conv1d``)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: Union[_size_1_t, str] = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None,
        look_ahead: int = 0,
    ) -> None:
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
        self.look_ahead = look_ahead
        assert look_ahead <= self.kernel_size[0] - 1, (look_ahead, self.kernel_size)

    def forward(self, x: Tensor, state: Optional[Dict[int, Any]] = None) -> Tensor:
        """Convolve ``x`` ([B,H,T]) along time.

        Args:
            x: input of shape [B,H,T].
            state: optional cache, keyed by ``id(self)``, holding the trailing
                input frames of the previous call. When given it is updated
                in place so the next chunk can be prepended with that context.

        Returns:
            The convolution output.
        """
        context = self.kernel_size[0] - 1  # frames of history one output frame needs
        if state is None or id(self) not in state:
            x = F.pad(x, pad=(context - self.look_ahead, self.look_ahead))
        else:
            x = torch.concat([state[id(self)], x], dim=-1)
        if state is not None:
            # nn.Conv1d stores kernel_size as a tuple, so index it before doing
            # arithmetic. Slicing from an explicit start keeps the cache empty
            # (instead of the whole tensor) when kernel_size == 1.
            state[id(self)] = x[..., x.shape[-1] - context:]
        x = super().forward(x)
        return x


class CleanMelLayer(nn.Module):
    """One cross-band + narrow-band block operating on [B,F,T,H] tensors.

    The cross-band part is: frequency convolution -> full-band linear module ->
    frequency convolution. The narrow-band part is a Mamba along the time axis
    (one causal Mamba when ``online``, otherwise a forward and a backward Mamba
    whose outputs are averaged). Every sub-module is applied residually.

    Args:
        dim_hidden: feature size H of the input/output.
        dim_squeeze: channel size used inside the full-band linear module.
        n_freqs: number of frequency bins F seen by the full-band linear module.
        dropout: dropout rates; index 0 is used after Mamba, index 2 inside the
            full-band module (index 1 is not read by this class).
        f_kernel_size: kernel size of the frequency convolutions.
        f_conv_groups: number of groups of the frequency convolutions.
        padding: ``padding_mode`` of the frequency convolutions.
        full: optional pre-built full-band module to share; a new
            ``LinearGroup`` is created when ``None``.
        mamba_state: ``d_state`` passed to Mamba.
        mamba_conv_kernel: ``d_conv`` passed to Mamba.
        online: use a single (causal) Mamba instead of a bidirectional pair.
    """

    def __init__(
        self,
        dim_hidden: int,
        dim_squeeze: int,
        n_freqs: int,
        dropout: Tuple[float, float, float] = (0, 0, 0),
        f_kernel_size: int = 5,
        f_conv_groups: int = 8,
        padding: str = 'zeros',
        full: Optional[nn.Module] = None,
        mamba_state: Optional[int] = None,
        mamba_conv_kernel: Optional[int] = None,
        online: bool = False,
    ) -> None:
        super().__init__()
        self.online = online

        # cross-band block
        # frequency-convolutional module
        self.fconv1 = nn.ModuleList([
            LayerNorm(seq_last=True, normalized_shape=dim_hidden),
            nn.Conv1d(in_channels=dim_hidden, out_channels=dim_hidden, kernel_size=f_kernel_size, groups=f_conv_groups, padding='same', padding_mode=padding),
            nn.PReLU(dim_hidden),
        ])

        # full-band linear module
        self.norm_full = LayerNorm(seq_last=False, normalized_shape=dim_hidden)
        self.full_share = full is not None
        self.squeeze = nn.Sequential(nn.Conv1d(in_channels=dim_hidden, out_channels=dim_squeeze, kernel_size=1), nn.SiLU())
        self.dropout_full = nn.Dropout2d(dropout[2]) if dropout[2] > 0 else None
        self.full = LinearGroup(n_freqs, n_freqs, num_groups=dim_squeeze) if full is None else full
        self.unsqueeze = nn.Sequential(nn.Conv1d(in_channels=dim_squeeze, out_channels=dim_hidden, kernel_size=1), nn.SiLU())

        # frequency-convolutional module
        self.fconv2 = nn.ModuleList([
            LayerNorm(seq_last=True, normalized_shape=dim_hidden),
            nn.Conv1d(in_channels=dim_hidden, out_channels=dim_hidden, kernel_size=f_kernel_size, groups=f_conv_groups, padding='same', padding_mode=padding),
            nn.PReLU(dim_hidden),
        ])

        # narrow-band block
        self.norm_mamba = LayerNorm(seq_last=False, normalized_shape=dim_hidden)
        if online:
            self.mamba = Mamba(d_model=dim_hidden, d_state=mamba_state, d_conv=mamba_conv_kernel, layer_idx=0)
        else:
            self.mamba = nn.ModuleList([
                Mamba(d_model=dim_hidden, d_state=mamba_state, d_conv=mamba_conv_kernel, layer_idx=0),
                Mamba(d_model=dim_hidden, d_state=mamba_state, d_conv=mamba_conv_kernel, layer_idx=1),
            ])
        self.dropout_mamba = nn.Dropout(dropout[0])

    def forward(self, x: Tensor, inference: bool = False) -> Tensor:
        """Run the block on ``x`` ([B,F,T,H]) and return a tensor of the same shape.

        Args:
            x: input of shape [B,F,T,H].
            inference: when True, Mamba is stepped one frame at a time.
        """
        x = x + self._fconv(self.fconv1, x)
        x = x + self._full(x)
        x = x + self._fconv(self.fconv2, x)
        if self.online:
            x = x + self._mamba(x, self.mamba, self.norm_mamba, self.dropout_mamba, inference)
        else:
            # Bidirectional: run one Mamba on the sequence and one on its
            # time-reversed copy, then average after undoing the reversal.
            x_rev = x.flip(dims=[2])
            x_fw = x + self._mamba(x, self.mamba[0], self.norm_mamba, self.dropout_mamba, inference)
            x_bw = x_rev + self._mamba(x_rev, self.mamba[1], self.norm_mamba, self.dropout_mamba, inference)
            x = (x_fw + x_bw.flip(dims=[2])) / 2
        return x

    def _mamba(self, x: Tensor, mamba: Mamba, norm: nn.Module, dropout: nn.Module, inference: bool = False):
        """Normalise ``x`` ([B,F,T,H]), run ``mamba`` along time per frequency, apply dropout."""
        B, n_freq, T, H = x.shape
        x = norm(x)
        x = x.reshape(B * n_freq, T, H)
        if inference:
            # Step through the sequence frame by frame, carrying the recurrent
            # state in ``inference_params``.
            inference_params = InferenceParams(T, B * n_freq)
            xs = []
            for i in range(T):
                inference_params.seqlen_offset = i
                xs.append(mamba(x[:, [i], :], inference_params))
            x = torch.concat(xs, dim=1)
        else:
            x = mamba(x)
        x = x.reshape(B, n_freq, T, H)
        return dropout(x)

    def _fconv(self, ml: nn.ModuleList, x: Tensor) -> Tensor:
        """Apply the modules in ``ml`` along the frequency axis of ``x`` ([B,F,T,H])."""
        B, n_freq, T, H = x.shape
        x = x.permute(0, 2, 3, 1)  # [B,T,H,F]
        x = x.reshape(B * T, H, n_freq)
        for m in ml:
            x = m(x)
        x = x.reshape(B, T, H, n_freq)
        x = x.permute(0, 3, 1, 2)  # [B,F,T,H]
        return x

    def _full(self, x: Tensor) -> Tensor:
        """Full-band linear module: squeeze channels, mix all frequencies, unsqueeze."""
        B, n_freq, T, H = x.shape
        x = self.norm_full(x)
        x = x.permute(0, 2, 3, 1)  # [B,T,H,F]
        x = x.reshape(B * T, H, n_freq)
        x = self.squeeze(x)  # [B*T,H',F]
        if self.dropout_full is not None:
            x = x.reshape(B, T, -1, n_freq)
            x = x.transpose(1, 3)  # [B,F,H',T]
            x = self.dropout_full(x)  # dropout some frequencies in one utterance
            x = x.transpose(1, 3)  # [B,T,H',F]
            x = x.reshape(B * T, -1, n_freq)

        x = self.full(x)  # [B*T,H',F]
        x = self.unsqueeze(x)  # [B*T,H,F]
        x = x.reshape(B, T, H, n_freq)
        x = x.permute(0, 3, 1, 2)  # [B,F,T,H]
        return x

    def extra_repr(self) -> str:
        return f"full_share={self.full_share}"


class CleanMel(nn.Module):
    """Encoder -> CleanMelLayers (linear frequency) -> mel filterbank -> CleanMelLayers (mel) -> decoder.

    Args:
        dim_input: the input dim for each time-frequency point.
        dim_output: the output dim for each time-frequency point.
        n_layers: total number of CleanMelLayers.
        n_freqs: number of linear-frequency bins of the input.
        n_mels: number of mel bands after the filterbank.
        layer_linear_freq: how many leading layers run in the linear-frequency domain.
        encoder_kernel_size: kernel size of the causal encoder convolution.
        dim_hidden: hidden feature size.
        dropout, f_kernel_size, f_conv_groups, padding, mamba_state,
        mamba_conv_kernel, online: forwarded to every CleanMelLayer.
        sr: sampling rate used to build the mel filterbank.
        n_fft: FFT size used to build the mel filterbank.
            NOTE(review): the filterbank's frequency axis presumably has to
            match ``n_freqs`` for the einsum in ``forward`` - confirm at call sites.
    """

    def __init__(
        self,
        dim_input: int,  # the input dim for each time-frequency point
        dim_output: int,  # the output dim for each time-frequency point
        n_layers: int,
        n_freqs: int,
        n_mels: int = 80,
        layer_linear_freq: int = 1,
        encoder_kernel_size: int = 5,
        dim_hidden: int = 192,
        dropout: Tuple[float, float, float] = (0, 0, 0),
        f_kernel_size: int = 5,
        f_conv_groups: int = 8,
        padding: str = 'zeros',
        mamba_state: int = 16,
        mamba_conv_kernel: int = 4,
        online: bool = True,
        sr: int = 16000,
        n_fft: int = 512,
    ):
        super().__init__()
        self.layer_linear_freq = layer_linear_freq
        self.online = online
        # encoder
        self.encoder = CausalConv1d(in_channels=dim_input, out_channels=dim_hidden, kernel_size=encoder_kernel_size, look_ahead=0)
        # cleanmel layers
        full = None
        layers = []
        for idx in range(n_layers):
            in_linear_domain = idx < layer_linear_freq
            layer = CleanMelLayer(
                dim_hidden=dim_hidden,
                dim_squeeze=8 if in_linear_domain else dim_hidden,
                n_freqs=n_freqs if in_linear_domain else n_mels,
                dropout=dropout,
                f_kernel_size=f_kernel_size,
                f_conv_groups=f_conv_groups,
                padding=padding,
                # The first mel-domain layer builds its own full-band module;
                # every later layer reuses the one from the layer before it.
                full=full if idx > layer_linear_freq else None,
                online=online,
                mamba_conv_kernel=mamba_conv_kernel,
                mamba_state=mamba_state,
            )
            full = layer.full
            layers.append(layer)
        self.layers = nn.ModuleList(layers)
        # Mel filterbank, stored transposed so the einsum in forward can contract
        # over the frequency axis. It is a fixed constant, so it is registered as
        # a plain (non-trainable) buffer rather than wrapped in nn.Parameter.
        linear2mel = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
        self.register_buffer("linear2mel", torch.tensor(linear2mel.T, dtype=torch.float32))
        # decoder
        self.decoder = nn.Linear(in_features=dim_hidden, out_features=dim_output)

    def forward(self, x: Tensor, inference: bool = False) -> Tensor:
        """Process ``x`` of shape [Batch, Freq, Time, Feature].

        Args:
            x: input of shape [Batch, Freq, Time, Feature].
            inference: forwarded to every layer (frame-by-frame Mamba stepping).

        Returns:
            The decoder output with the frequency axis replaced by the mel axis;
            the last dim is squeezed away when it has size 1.
        """
        B, n_freq, T, H0 = x.shape
        x = self.encoder(x.reshape(B * n_freq, T, H0).permute(0, 2, 1)).permute(0, 2, 1)
        H = x.shape[2]
        x = x.reshape(B, n_freq, T, H)

        # First Cross-Narrow band block in Linear Frequency
        for i in range(self.layer_linear_freq):
            m = self.layers[i]
            x = m(x, inference).contiguous()

        # Mel-filterbank
        x = torch.einsum("bfth,fm->bmth", x, self.linear2mel)

        for i in range(self.layer_linear_freq, len(self.layers)):
            m = self.layers[i]
            x = m(x, inference).contiguous()

        y = self.decoder(x).squeeze(-1)
        return y.contiguous()