Audio-to-Audio
MambaSSM
Safetensors
speech-enhancement
universal speech enhancement
multiple input sampling rates
language-agnostic
Instructions to use nymphai/RE-USE with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MambaSSM
How to use nymphai/RE-USE with MambaSSM:
from mamba_ssm import MambaLMHeadModel
model = MambaLMHeadModel.from_pretrained("nymphai/RE-USE")
- Notebooks
- Google Colab
- Kaggle
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # NVIDIA CORPORATION and its licensors retain all intellectual property | |
| # and proprietary rights in and to this software, related documentation | |
| # and any modifications thereto. Any use, reproduction, disclosure or | |
| # distribution of this software and related documentation without an express | |
| # license agreement from NVIDIA CORPORATION is strictly prohibited. | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from einops import rearrange | |
def get_padding_2d(kernel_size, dilation=(1, 1)):
    """
    Compute the (height, width) padding for a dilated 2D convolution.

    For each axis the result is ``(kernel - 1) * dilation / 2`` truncated to
    an integer.

    Args:
    - kernel_size (tuple): Size of the convolutional kernel (height, width).
    - dilation (tuple, optional): Dilation rate of the convolution (height, width). Defaults to (1, 1).

    Returns:
    - tuple: Calculated padding size (height, width).
    """
    def _axis_padding(axis):
        # Dilated kernel extent minus one, halved and truncated toward zero.
        extent = kernel_size[axis] * dilation[axis] - dilation[axis]
        return int(extent / 2)

    pad_h = _axis_padding(0)
    pad_w = _axis_padding(1)
    return (pad_h, pad_w)
class SPConvTranspose2d(nn.Module):
    """
    Upsample the last axis by interleaving channel groups of a convolution.

    A stride-1 convolution produces ``out_channels * r`` channels. Those are
    split into ``r`` groups of ``out_channels`` channels, and the groups are
    interleaved along the last axis, so the output's last axis is ``r`` times
    the convolution's output width.

    Args:
    - in_channels (int): Number of input channels.
    - out_channels (int): Number of channels in the returned tensor.
    - kernel_size (tuple): Kernel size passed to the inner ``nn.Conv2d``.
    - r (int, optional): Expansion factor of the last axis. Defaults to 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size, r=1):
        super().__init__()
        # One zero on each side of the last axis, nothing on the axis before it.
        # With a kernel width of 3 this keeps the convolution output width equal
        # to the input width.
        self.pad1 = nn.ConstantPad2d((1, 1, 0, 0), value=0.)
        self.out_channels = out_channels
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * r,
            kernel_size=kernel_size,
            stride=(1, 1),
        )
        self.r = r

    def forward(self, x):
        """
        Args:
        - x (Tensor): 4D input of shape [batch, in_channels, H, W].

        Returns:
        - Tensor: Shape [batch, out_channels, H', W' * r], where H' and W' are
          the spatial sizes produced by the padded convolution.
        """
        feats = self.conv(self.pad1(x))
        n_batch, n_total, height, width = feats.shape
        n_out = n_total // self.r

        # [B, r * C, H, W] -> [B, r, C, H, W] -> [B, C, H, W, r]
        grouped = feats.view(n_batch, self.r, n_out, height, width)
        interleaved = grouped.permute(0, 2, 3, 4, 1).contiguous()

        # Fold the trailing group axis into the width: index = w * r + group.
        return interleaved.view(n_batch, n_out, height, -1)
class DenseBlock(nn.Module):
    """
    Stack of dilated convolutions with dense (concatenated) skip connections.

    Layer ``i`` is Conv2d -> InstanceNorm2d -> PReLU. Its convolution takes
    ``hid_feature * (i + 1)`` input channels, produces ``hid_feature`` output
    channels, and uses dilation ``(2 ** i, 1)`` with padding computed by
    ``get_padding_2d`` for that dilation.

    Args:
    - cfg (dict): Configuration; ``cfg['model_cfg']['hid_feature']`` is read.
    - kernel_size (tuple, optional): Kernel size of every convolution. Defaults to (3, 3).
    - depth (int, optional): Number of layers. Defaults to 4.
    """

    def __init__(self, cfg, kernel_size=(3, 3), depth=4):
        super().__init__()
        self.cfg = cfg
        self.depth = depth
        self.dense_block = nn.ModuleList()
        self.hid_feature = cfg['model_cfg']['hid_feature']

        for layer_idx in range(depth):
            # Dilation doubles per layer on the first spatial axis only.
            dilation = (2 ** layer_idx, 1)
            layer = nn.Sequential(
                nn.Conv2d(
                    self.hid_feature * (layer_idx + 1),
                    self.hid_feature,
                    kernel_size,
                    dilation=dilation,
                    padding=get_padding_2d(kernel_size, dilation),
                ),
                nn.InstanceNorm2d(self.hid_feature, affine=True),
                nn.PReLU(self.hid_feature),
            )
            self.dense_block.append(layer)

    def forward(self, x):
        """
        Run every layer on the channel-wise concatenation of the block input
        and all earlier layer outputs.

        Returns:
        - Tensor: Output of the last layer (``hid_feature`` channels); the
          input itself when the block has no layers.
        """
        stacked = x
        for layer in self.dense_block:
            x = layer(stacked)
            # Newest features go first, ahead of everything accumulated so far.
            stacked = torch.cat([x, stacked], dim=1)
        return x
class DenseEncoder(nn.Module):
    """
    Encoder: 1x1 convolution, a DenseBlock, then a strided (1, 3) convolution.

    Each convolution is followed by InstanceNorm2d and PReLU.

    Args:
    - cfg (dict): Configuration; ``cfg['model_cfg']['input_channel']`` and
      ``cfg['model_cfg']['hid_feature']`` are read.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.input_channel = cfg['model_cfg']['input_channel']
        self.hid_feature = cfg['model_cfg']['hid_feature']
        width = self.hid_feature

        self.dense_conv_1 = nn.Sequential(
            nn.Conv2d(self.input_channel, width, (1, 1)),
            nn.InstanceNorm2d(width, affine=True),
            nn.PReLU(width),
        )

        self.dense_block = DenseBlock(cfg, depth=4)

        # Stride (4, 2) with an unpadded (1, 3) kernel: dim 2 is subsampled by
        # 4 and dim 3 is roughly halved.
        self.dense_conv_2 = nn.Sequential(
            nn.Conv2d(width, width, (1, 3), stride=(4, 2)),
            nn.InstanceNorm2d(width, affine=True),
            nn.PReLU(width),
        )

    def forward(self, x):
        """
        Args:
        - x (Tensor): Input of shape [batch, input_channel, D2, D3]
          (presumably D2 = time and D3 = frequency; confirm against the caller).

        Returns:
        - Tensor: Shape [batch, hid_feature, (D2 - 1) // 4 + 1, (D3 - 3) // 2 + 1].
        """
        stages = (
            self.dense_conv_1,  # -> [batch, hid_feature, D2, D3]
            self.dense_block,   # -> [batch, hid_feature, D2, D3]
            self.dense_conv_2,  # -> [batch, hid_feature, (D2-1)//4 + 1, (D3-3)//2 + 1]
        )
        for stage in stages:
            x = stage(x)
        return x
class MagDecoder(nn.Module):
    """
    Decoder head for the magnitude branch.

    Pipeline: DenseBlock, an upsampling of the last axis by 2, an upsampling
    of dim 2 by 4, then a 1x1 convolution to ``output_channel`` channels.
    No activation is applied to the final convolution's output here.

    Args:
    - cfg (dict): Configuration; reads ``hid_feature``, ``output_channel`` and
      ``beta`` from ``cfg['model_cfg']`` and ``n_fft`` from ``cfg['stft_cfg']``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.dense_block = DenseBlock(cfg, depth=4)
        self.hid_feature = cfg['model_cfg']['hid_feature']
        self.output_channel = cfg['model_cfg']['output_channel']
        # n_fft and beta are stored but not used elsewhere in this class.
        self.n_fft = cfg['stft_cfg']['n_fft']
        self.beta = cfg['model_cfg']['beta']

        def make_upsampler(factor):
            # SPConvTranspose2d expands the last axis by `factor`.
            return nn.Sequential(
                SPConvTranspose2d(self.hid_feature, self.hid_feature, (1, 3), factor),
                nn.InstanceNorm2d(self.hid_feature, affine=True),
                nn.PReLU(self.hid_feature),
            )

        self.up_conv1 = make_upsampler(2)
        self.up_conv2 = make_upsampler(4)
        self.final_conv = nn.Conv2d(self.hid_feature, self.output_channel, (1, 1))

    def forward(self, x):
        """
        Args:
        - x (Tensor): 4D input of shape [batch, hid_feature, D2, D3].

        Returns:
        - Tensor: Shape [batch, output_channel, D2 * 4, D3 * 2].
        """
        swap_last_two = (0, 1, 3, 2)

        x = self.up_conv1(self.dense_block(x))

        # up_conv2 only expands the last axis, so swap dims 2 and 3 around the
        # call to make it act on dim 2, then swap back.
        x = x.permute(swap_last_two)
        x = self.up_conv2(x)
        x = x.permute(swap_last_two)

        return self.final_conv(x)
class PhaseDecoder(nn.Module):
    """
    Decoder head for the phase branch.

    Pipeline: DenseBlock, an upsampling of the last axis by 2, an upsampling
    of dim 2 by 4, then two parallel 1x1 convolutions whose outputs are
    combined with ``atan2`` into an angle.

    Args:
    - cfg (dict): Configuration; reads ``hid_feature`` and ``output_channel``
      from ``cfg['model_cfg']``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.dense_block = DenseBlock(cfg, depth=4)
        self.hid_feature = cfg['model_cfg']['hid_feature']
        self.output_channel = cfg['model_cfg']['output_channel']

        def make_upsampler(factor):
            # SPConvTranspose2d expands the last axis by `factor`.
            return nn.Sequential(
                SPConvTranspose2d(self.hid_feature, self.hid_feature, (1, 3), factor),
                nn.InstanceNorm2d(self.hid_feature, affine=True),
                nn.PReLU(self.hid_feature),
            )

        self.up_conv1 = make_upsampler(2)
        self.up_conv2 = make_upsampler(4)
        # Separate projections for the two atan2 arguments.
        self.phase_conv_r = nn.Conv2d(self.hid_feature, self.output_channel, (1, 1))
        self.phase_conv_i = nn.Conv2d(self.hid_feature, self.output_channel, (1, 1))

    def forward(self, x):
        """
        Args:
        - x (Tensor): 4D input of shape [batch, hid_feature, D2, D3].

        Returns:
        - Tensor: Angles in radians, shape [batch, output_channel, D2 * 4, D3 * 2],
          computed as ``atan2(phase_conv_i(x), phase_conv_r(x))``.
        """
        swap_last_two = (0, 1, 3, 2)

        x = self.up_conv1(self.dense_block(x))

        # up_conv2 only expands the last axis, so swap dims 2 and 3 around the
        # call to make it act on dim 2, then swap back.
        x = x.permute(swap_last_two)
        x = self.up_conv2(x)
        x = x.permute(swap_last_two)

        real_part = self.phase_conv_r(x)
        imag_part = self.phase_conv_i(x)
        return torch.atan2(imag_part, real_part)