Spaces:
Running on Zero
Running on Zero
Upload 9 files
Browse files- app.py +150 -0
- exp/30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_peak_GAN_tel_mic/g_01134000.pth +3 -0
- models/codec_module_time_d4.py +168 -0
- models/generator_SEMamba_time_d4.py +91 -0
- models/mamba_block2_SEMamba.py +81 -0
- models/stfts.py +95 -0
- recipes/USEMamba_30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_001.yaml +44 -0
- requirements.txt +14 -0
- utils/util.py +37 -0
app.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import shlex
|
| 10 |
+
import subprocess
|
| 11 |
+
import spaces
|
| 12 |
+
import gradio as gr
|
| 13 |
+
|
| 14 |
+
def install_mamba():
    """Install the prebuilt mamba-ssm CUDA wheel at runtime.

    The Space image has no CUDA build toolchain, so a matching prebuilt wheel
    (cu122 / torch 2.3 / cp310) is installed instead of building from source.

    Raises:
        subprocess.CalledProcessError: if pip fails. Failing fast here is
        deliberate — otherwise the later `from mamba_ssm import Mamba` fails
        with a much more confusing error.
    """
    subprocess.run(
        shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"),
        check=True,
    )
install_mamba()
|
| 17 |
+
|
| 18 |
+
ABOUT = """
|
| 19 |
+
# RE-USE: A universal speech enhancement model for diverse degradations, sampling rates, and languages.
|
| 20 |
+
Upload or record a noisy clip, then click **Enhance** to listen to the result and view its spectrogram.
|
| 21 |
+
(ref: https://huggingface.co/spaces/rc19477/Speech_Enhancement_Mamba)
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import torch
import torchaudio
import librosa
import librosa.display  # explicit: bare `import librosa` does not expose `.display` on older librosa versions
import matplotlib.pyplot as plt
import numpy as np

from models.stfts import mag_phase_stft, mag_phase_istft
# BUG FIX: this commit ships models/generator_SEMamba_time_d4.py; the previous
# module name "generator_SEMamba_MPSEnet_time_d4" does not exist in the repo
# and crashed the app at import time.
from models.generator_SEMamba_time_d4 import SEMamba
from utils.util import load_config, pad_or_trim_to_match
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def make_even(value):
    """Round *value* to the nearest integer, then bump odd results up by one.

    Used to keep derived STFT sizes (n_fft / hop / win) even after scaling
    them to the input sampling rate.
    """
    rounded = int(round(value))
    if rounded % 2:
        rounded += 1
    return rounded
|
| 37 |
+
|
| 38 |
+
device = "cuda"

# Load the recipe used at training time; its STFT settings are the reference
# values that get rescaled per input sampling rate inside `enhance`.
cfg1 = load_config('recipes/USEMamba_30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_001.yaml')
stft_cfg = cfg1['stft_cfg']
n_fft, hop_size, win_size = stft_cfg['n_fft'], stft_cfg['hop_size'], stft_cfg['win_size']
compress_factor = cfg1['model_cfg']['compress_factor']
sampling_rate = stft_cfg['sampling_rate']

# Build the generator once at startup and restore the released checkpoint.
USE_model = SEMamba(cfg1).to(device)
checkpoint_file = "exp/30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_peak_GAN_tel_mic/g_01134000.pth"
state_dict = torch.load(checkpoint_file, map_location=device)
USE_model.load_state_dict(state_dict['generator'])
USE_model.eval()
|
| 49 |
+
|
| 50 |
+
@spaces.GPU
def enhance(filepath, low_pass_sampling_rate, target_sampling_rate):
    """Enhance one audio file and return paths + a spectrogram figure.

    Args:
        filepath: path to the input audio file.
        low_pass_sampling_rate: optional textbox string; when non-empty the
            signal is first resampled down to this rate (pre-low-pass) before
            bandwidth extension.
        target_sampling_rate: optional textbox string; when non-empty the
            signal is resampled to this rate before enhancement.

    Returns:
        ("original.wav", "enhanced.wav", matplotlib Figure) — file paths for
        the Gradio audio players plus the noisy/enhanced spectrogram plot.
    """
    # BUG FIX: `librosa.display` is not guaranteed to be importable as an
    # attribute of `librosa` on older versions; import it explicitly.
    import librosa.display

    with torch.no_grad():
        noisy_wav, noisy_sr = torchaudio.load(filepath)
        torchaudio.save("original.wav", noisy_wav.cpu(), noisy_sr)
        original_noisy_wav = noisy_wav
        original_sr = noisy_sr

        # ROBUSTNESS: truthiness instead of `!= ''` also guards against None.
        if target_sampling_rate:
            if low_pass_sampling_rate:
                opts = {"res_type": "kaiser_best"}
                noisy_wav = torch.tensor(librosa.resample(noisy_wav.cpu().numpy(), orig_sr=noisy_sr, target_sr=int(low_pass_sampling_rate), **opts))
                noisy_sr = int(low_pass_sampling_rate)
            opts = {"res_type": "kaiser_best"}
            noisy_wav = librosa.resample(noisy_wav.cpu().numpy(), orig_sr=noisy_sr, target_sr=int(target_sampling_rate), **opts)
            noisy_sr = int(target_sampling_rate)

        noisy_wav = torch.FloatTensor(noisy_wav).to(device)

        # Sampling-frequency-independent mode: rescale the training-time STFT
        # parameters to the current rate, keeping them even for the FFT.
        n_fft_scaled = make_even(n_fft * noisy_sr // sampling_rate)
        hop_size_scaled = make_even(hop_size * noisy_sr // sampling_rate)
        win_size_scaled = make_even(win_size * noisy_sr // sampling_rate)

        noisy_mag, noisy_pha, noisy_com = mag_phase_stft(
            noisy_wav,
            n_fft=n_fft_scaled,
            hop_size=hop_size_scaled,
            win_size=win_size_scaled,
            compress_factor=compress_factor,
            center=True,
            addeps=False
        )
        amp_g, pha_g, _ = USE_model(noisy_mag, noisy_pha)

        audio_g = mag_phase_istft(amp_g, pha_g, n_fft_scaled, hop_size_scaled, win_size_scaled, compress_factor)
        audio_g = pad_or_trim_to_match(noisy_wav.detach(), audio_g, pad_value=1e-8)  # Align lengths using epsilon padding
        assert audio_g.shape == noisy_wav.shape, audio_g.shape

        # write file
        torchaudio.save("enhanced.wav", audio_g.cpu(), noisy_sr)

        # spectrograms (fixed analysis settings, independent of the model STFT)
        fig, axs = plt.subplots(1, 2, figsize=(16, 4))

        # noisy (plotted at the ORIGINAL rate, before any resampling)
        D_noisy = librosa.stft(original_noisy_wav[0].cpu().numpy(), n_fft=512, hop_length=256)
        S_noisy = librosa.amplitude_to_db(np.abs(D_noisy), ref=np.max)
        librosa.display.specshow(S_noisy, sr=original_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[0], vmax=0)
        axs[0].set_title("Noisy Spectrogram")

        # enhanced
        D_clean = librosa.stft(audio_g.cpu().numpy(), n_fft=512, hop_length=256)
        S_clean = librosa.amplitude_to_db(np.abs(D_clean), ref=np.max)
        librosa.display.specshow(S_clean[0], sr=noisy_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
        axs[1].set_title("Enhanced Spectrogram")

        plt.tight_layout()
        return "original.wav", "enhanced.wav", fig
|
| 108 |
+
|
| 109 |
+
with gr.Blocks() as demo:
    gr.Markdown(ABOUT)
    gr.Markdown("**Note 1**: For bandwidth extension, the performance may be affected by the characteristics of the input data, particularly the cutoff pattern. A simple solution is to apply low-pass filtering beforehand.")
    gr.Markdown("**Note 2**: When processing long input audio, out-of-memory (OOM) errors may occur. To address this, use the chunk-wise inference implementation provided on the Hugging Face.")

    with gr.Row():
        with gr.Column():
            # Tabs separate the audio-upload and video-upload entry points.
            with gr.Tabs():
                with gr.TabItem("Audio Upload"):
                    # gr.Audio works great for standard audio files
                    input_audio = gr.Audio(label="Input Audio", type="filepath")

                with gr.TabItem("Video Upload (.mp4, .mov)"):
                    # gr.File handles .mp4 and .mov without errors
                    input_video = gr.File(label="Input Video", file_types=[".mp4", ".mov"])

            target_sampling_rate = gr.Textbox(label="(Optional) Enter target sampling rate for bandwidth extension:")
            low_pass_sampling_rate = gr.Textbox(label="(Optional) Enter target sampling rate for pre-low-pass filtering before bandwidth extension:")

            enhance_btn = gr.Button("Enhance")

    with gr.Row():
        input_audio_player = gr.Audio(label="Original Input Audio", type="filepath")
        output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
    plot_output = gr.Plot(label="Spectrograms")

    def unified_enhance(audio_path, video_path, lp_sr, target_sr):
        """Dispatch whichever tab was used (audio beats video) to `enhance`."""
        if audio_path:
            final_path = audio_path
        else:
            # BUG FIX: depending on the Gradio version, gr.File may deliver a
            # filepath string or a tempfile wrapper exposing `.name`; the old
            # code passed the wrapper straight through.
            final_path = getattr(video_path, "name", video_path)
        return enhance(final_path, lp_sr, target_sr)

    enhance_btn.click(
        fn=unified_enhance,
        inputs=[input_audio, input_video, low_pass_sampling_rate, target_sampling_rate],
        outputs=[input_audio_player, output_audio, plot_output]
    )

demo.queue().launch(share=True)
|
exp/30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_peak_GAN_tel_mic/g_01134000.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e27db9e1de904eb59fc627dea72c69da7ca25650a3e704b4096f89812b395fe5
|
| 3 |
+
size 38982886
|
models/codec_module_time_d4.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import numpy as np
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
|
| 14 |
+
def get_padding_2d(kernel_size, dilation=(1, 1)):
    """
    Compute "same"-style padding for a 2D dilated convolution.

    Args:
    - kernel_size (tuple): (height, width) of the convolutional kernel.
    - dilation (tuple, optional): (height, width) dilation rates. Defaults to (1, 1).

    Returns:
    - tuple: (pad_height, pad_width).
    """
    pad_h = (dilation[0] * (kernel_size[0] - 1)) // 2
    pad_w = (dilation[1] * (kernel_size[1] - 1)) // 2
    return (pad_h, pad_w)
|
| 27 |
+
|
| 28 |
+
class SPConvTranspose2d(nn.Module):
    """
    Sub-pixel "transposed" convolution: a regular conv produces r * out_channels
    feature maps, which are then interleaved along the width axis to upsample
    the last dimension by a factor of r (pixel-shuffle style).

    Attribute names (pad1, conv) are part of the checkpoint state_dict and
    must not change.
    """

    def __init__(self, in_channels, out_channels, kernel_size, r=1):
        super(SPConvTranspose2d, self).__init__()
        # Pad only the width axis by 1 on each side so the (1, 3) kernel keeps W.
        self.pad1 = nn.ConstantPad2d((1, 1, 0, 0), value=0.)
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels * r, kernel_size=kernel_size, stride=(1, 1))
        self.r = r

    def forward(self, x):
        conv_out = self.conv(self.pad1(x))
        bsz, total_ch, height, width = conv_out.shape
        groups = total_ch // self.r
        # Split channels into (r, groups) and interleave the r copies along width.
        conv_out = conv_out.reshape(bsz, self.r, groups, height, width)
        conv_out = conv_out.permute(0, 2, 3, 4, 1)
        return conv_out.reshape(bsz, groups, height, width * self.r)
|
| 44 |
+
|
| 45 |
+
class DenseBlock(nn.Module):
    """
    Densely connected stack of dilated convolutions: layer i receives the
    channel-wise concatenation of the input and all previous layer outputs,
    with time-axis dilation doubling each layer (1, 2, 4, ...).
    """

    def __init__(self, cfg, kernel_size=(3, 3), depth=4):
        super(DenseBlock, self).__init__()
        self.cfg = cfg
        self.depth = depth
        self.hid_feature = cfg['model_cfg']['hid_feature']
        hid = self.hid_feature

        layers = []
        for i in range(depth):
            dil = 2 ** i
            layers.append(nn.Sequential(
                # input channels grow by `hid` each layer due to dense concat
                nn.Conv2d(hid * (i + 1), hid, kernel_size,
                          dilation=(dil, 1), padding=get_padding_2d(kernel_size, (dil, 1))),
                nn.InstanceNorm2d(hid, affine=True),
                nn.PReLU(hid),
            ))
        self.dense_block = nn.ModuleList(layers)

    def forward(self, x):
        stacked = x
        out = x
        for layer in self.dense_block:
            out = layer(stacked)
            stacked = torch.cat([out, stacked], dim=1)
        return out
|
| 72 |
+
|
| 73 |
+
class DenseEncoder(nn.Module):
    """
    Encoder: 1x1 projection to the hidden width, a dense dilated block, then
    a strided conv that downsamples (stride (4, 2) on the last two axes).

    Attribute names (dense_conv_1, dense_block, dense_conv_2) are part of the
    checkpoint state_dict and must not change.
    """

    def __init__(self, cfg):
        super(DenseEncoder, self).__init__()
        self.cfg = cfg
        self.input_channel = cfg['model_cfg']['input_channel']
        self.hid_feature = cfg['model_cfg']['hid_feature']
        hid = self.hid_feature

        self.dense_conv_1 = nn.Sequential(
            nn.Conv2d(self.input_channel, hid, (1, 1)),
            nn.InstanceNorm2d(hid, affine=True),
            nn.PReLU(hid),
        )

        self.dense_block = DenseBlock(cfg, depth=4)

        self.dense_conv_2 = nn.Sequential(
            nn.Conv2d(hid, hid, (1, 3), stride=(4, 2)),
            nn.InstanceNorm2d(hid, affine=True),
            nn.PReLU(hid),
        )

    def forward(self, x):
        # project -> dense features -> strided downsample
        return self.dense_conv_2(self.dense_block(self.dense_conv_1(x)))
|
| 102 |
+
|
| 103 |
+
class MagDecoder(nn.Module):
    """
    Magnitude decoder: dense block, two sub-pixel upsampling stages (one per
    spatial axis), and a 1x1 output conv.

    Attribute names are part of the checkpoint state_dict and must not change.
    """

    def __init__(self, cfg):
        super(MagDecoder, self).__init__()
        self.dense_block = DenseBlock(cfg, depth=4)
        self.hid_feature = cfg['model_cfg']['hid_feature']
        self.output_channel = cfg['model_cfg']['output_channel']
        self.n_fft = cfg['stft_cfg']['n_fft']
        self.beta = cfg['model_cfg']['beta']
        hid = self.hid_feature

        self.up_conv1 = nn.Sequential(
            SPConvTranspose2d(hid, hid, (1, 3), 2),
            nn.InstanceNorm2d(hid, affine=True),
            nn.PReLU(hid),
        )

        self.up_conv2 = nn.Sequential(
            SPConvTranspose2d(hid, hid, (1, 3), 4),
            nn.InstanceNorm2d(hid, affine=True),
            nn.PReLU(hid),
        )

        self.final_conv = nn.Conv2d(hid, self.output_channel, (1, 1))

    def forward(self, x):
        feats = self.dense_block(x)
        feats = self.up_conv1(feats)
        # up_conv2 upsamples the other spatial axis: swap dims 2/3, apply, swap back
        feats = self.up_conv2(feats.transpose(2, 3)).transpose(2, 3)
        return self.final_conv(feats)
|
| 135 |
+
|
| 136 |
+
class PhaseDecoder(nn.Module):
    """
    Phase decoder: mirrors MagDecoder's upsampling path, but emits a phase
    spectrum via atan2 of separate real/imag 1x1 conv heads.

    Attribute names are part of the checkpoint state_dict and must not change.
    """

    def __init__(self, cfg):
        super(PhaseDecoder, self).__init__()
        self.dense_block = DenseBlock(cfg, depth=4)
        self.hid_feature = cfg['model_cfg']['hid_feature']
        self.output_channel = cfg['model_cfg']['output_channel']
        hid = self.hid_feature

        self.up_conv1 = nn.Sequential(
            SPConvTranspose2d(hid, hid, (1, 3), 2),
            nn.InstanceNorm2d(hid, affine=True),
            nn.PReLU(hid),
        )

        self.up_conv2 = nn.Sequential(
            SPConvTranspose2d(hid, hid, (1, 3), 4),
            nn.InstanceNorm2d(hid, affine=True),
            nn.PReLU(hid),
        )

        self.phase_conv_r = nn.Conv2d(hid, self.output_channel, (1, 1))
        self.phase_conv_i = nn.Conv2d(hid, self.output_channel, (1, 1))

    def forward(self, x):
        feats = self.dense_block(x)
        feats = self.up_conv1(feats)
        # up_conv2 upsamples the other spatial axis: swap dims 2/3, apply, swap back
        feats = self.up_conv2(feats.transpose(2, 3)).transpose(2, 3)
        real = self.phase_conv_r(feats)
        imag = self.phase_conv_i(feats)
        # atan2 keeps the output in (-pi, pi], a valid phase
        return torch.atan2(imag, real)
|
models/generator_SEMamba_time_d4.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
from .mamba_block2_SEMamba import TFMambaBlock
|
| 13 |
+
from .codec_module_time_d4 import DenseEncoder, MagDecoder, PhaseDecoder
|
| 14 |
+
|
| 15 |
+
class SEMamba(nn.Module):
    """
    SEMamba model for speech enhancement using Mamba blocks.

    This model uses a dense encoder, multiple Mamba blocks, and separate magnitude
    and phase decoders to process noisy magnitude and phase inputs.
    """
    def __init__(self, cfg):
        """
        Initialize the SEMamba model.

        Args:
        - cfg: Configuration object containing model parameters.
        """
        super(SEMamba, self).__init__()
        self.cfg = cfg
        self.num_tscblocks = cfg['model_cfg']['num_tfmamba'] if cfg['model_cfg']['num_tfmamba'] is not None else 4 # default tfmamba: 4

        # Initialize dense encoder
        self.dense_encoder = DenseEncoder(cfg)

        # Initialize Mamba blocks
        self.TSMamba = nn.ModuleList([TFMambaBlock(cfg) for _ in range(self.num_tscblocks)])

        # Initialize decoders
        self.mask_decoder = MagDecoder(cfg)
        self.phase_decoder = PhaseDecoder(cfg)

    def forward(self, noisy_mag, noisy_pha):
        """
        Forward pass for the SEMamba model.

        Args:
        - noisy_mag (torch.Tensor): Noisy magnitude input tensor [B, F, T].
        - noisy_pha (torch.Tensor): Noisy phase input tensor [B, F, T].

        Returns:
        - denoised_mag (torch.Tensor): Denoised magnitude tensor [B, F, T].
        - denoised_pha (torch.Tensor): Denoised phase tensor [B, F, T].
        - denoised_com (torch.Tensor): Denoised complex tensor [B, F, T, 2].
        """
        # Reshape inputs
        noisy_mag = rearrange(noisy_mag, 'b f t -> b t f').unsqueeze(1) # [B, 1, T, F]
        noisy_pha = rearrange(noisy_pha, 'b f t -> b t f').unsqueeze(1) # [B, 1, T, F]

        # Concatenate magnitude and phase inputs
        x = torch.cat((noisy_mag, noisy_pha), dim=1) # [B, 2, T, F]

        # Prevent unpredictable errors
        # NOTE(review): zero-pad 2 extra columns (freq) and 2 extra rows (time)
        # before the strided encoder so arbitrary T/F survive the down/upsample
        # round trip; the crop after the decoders removes the excess. The crop
        # below assumes the decoders return at least the padded T and F —
        # presumably guaranteed by the encoder/decoder stride pairing; confirm.
        B, C, T, F = x.shape
        zeros = torch.zeros(B, C, T, 2, device=x.device)
        x = torch.cat((x, zeros), dim=-1)
        zeros = torch.zeros(B, C, 2, F+2, device=x.device)
        x = torch.cat((x, zeros), dim=-2)

        # Encode input
        x = self.dense_encoder(x)

        # Apply Mamba blocks
        for block in self.TSMamba:
            x = block(x)

        # Decode output
        denoised_mag = rearrange(self.mask_decoder(x), 'b c t f -> b f t c').squeeze(-1)
        denoised_pha = rearrange(self.phase_decoder(x), 'b c t f -> b f t c').squeeze(-1)

        # Prevent unpredictable errors
        # Crop back to the caller's original (F, T) after the padding above.
        denoised_mag = denoised_mag[:, :F, :T]
        denoised_pha = denoised_pha[:, :F, :T]

        # Combine denoised magnitude and phase into a complex representation
        denoised_com = torch.stack(
            (denoised_mag * torch.cos(denoised_pha), denoised_mag * torch.sin(denoised_pha)),
            dim=-1
        )

        return denoised_mag, denoised_pha, denoised_com
|
models/mamba_block2_SEMamba.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch.nn import init
|
| 13 |
+
from torch.nn.parameter import Parameter
|
| 14 |
+
from functools import partial
|
| 15 |
+
from einops import rearrange
|
| 16 |
+
from mamba_ssm import Mamba
|
| 17 |
+
|
| 18 |
+
class MambaBlock(nn.Module):
    """
    Bidirectional Mamba over a [B, T, D] sequence: residual forward and
    backward passes are concatenated, projected back to D, and layer-normed.

    Attribute names (forward_blocks, backward_blocks, output_proj, norm) are
    part of the checkpoint state_dict and must not change.
    """

    def __init__(self, d_model, cfg):
        super(MambaBlock, self).__init__()

        model_cfg = cfg['model_cfg']
        d_state = model_cfg['d_state']    # 16
        d_conv = model_cfg['d_conv']      # 4
        expand = model_cfg['expand']      # 4

        self.forward_blocks = Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
        self.backward_blocks = Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
        self.output_proj = nn.Linear(2 * d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: [B, T, D]
        fwd = x + self.forward_blocks(x)

        # Backward direction: run on the time-reversed sequence, then undo the flip.
        x_rev = torch.flip(x, dims=[1])
        bwd = torch.flip(x_rev + self.backward_blocks(x_rev), dims=[1])

        combined = torch.cat((fwd, bwd), dim=-1)
        return self.norm(self.output_proj(combined))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class TFMambaBlock(nn.Module):
    """
    Temporal-Frequency Mamba block.

    Applies one bidirectional MambaBlock along the time axis (treating each
    frequency bin as an independent sequence) and one along the frequency axis
    (treating each frame as a sequence), each with a residual connection.
    """

    def __init__(self, cfg):
        super(TFMambaBlock, self).__init__()
        self.cfg = cfg
        self.hid_feature = cfg['model_cfg']['hid_feature']

        # One Mamba per axis; attribute names are checkpoint state_dict keys.
        self.time_mamba = MambaBlock(d_model=self.hid_feature, cfg=cfg)
        self.freq_mamba = MambaBlock(d_model=self.hid_feature, cfg=cfg)

    def forward(self, x):
        """
        Args:
            x (Tensor): input of shape (batch, channels, time, freq).

        Returns:
            Tensor: same shape, after time- then frequency-axis Mamba passes.
        """
        batch, ch, t_len, f_len = x.size()

        # Time axis: fold freq into the batch dimension -> [B*F, T, C]
        seq = x.permute(0, 3, 2, 1).reshape(batch * f_len, t_len, ch)
        seq = self.time_mamba(seq) + seq

        # Frequency axis: fold time into the batch dimension -> [B*T, F, C]
        seq = seq.reshape(batch, f_len, t_len, ch).permute(0, 2, 1, 3).reshape(batch * t_len, f_len, ch)
        seq = self.freq_mamba(seq) + seq

        # Restore (batch, channels, time, freq)
        return seq.reshape(batch, t_len, f_len, ch).permute(0, 3, 1, 2)
|
models/stfts.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
|
| 12 |
+
def decompress_signed_log1p(y):
    """Invert the signed log1p compression: sign(y) * (exp(|y|) - 1)."""
    return y.sign() * y.abs().expm1()
|
| 14 |
+
|
| 15 |
+
RELU = nn.ReLU()
|
| 16 |
+
|
| 17 |
+
def mag_phase_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True, addeps=False):
    """
    Compute magnitude and phase using STFT.

    Args:
        y (torch.Tensor): Input audio signal.
        n_fft (int): FFT size.
        hop_size (int): Hop size.
        win_size (int): Window size.
        compress_factor (float or str, optional): Magnitude compression. A float
            applies power-law compression mag**compress_factor; the strings
            'log1p', 'relu_log1p', 'signed_log1p' apply log1p compression.
            Defaults to 1.0 (no compression).
        center (bool, optional): Whether to center the signal before padding. Defaults to True.
        addeps (bool, optional): Whether to add a small epsilon when deriving
            magnitude and phase (numerically safer gradients). Defaults to False.

    Returns:
        tuple: Magnitude, phase, and complex representation of the STFT.
    """
    eps = 1e-10
    hann_window = torch.hann_window(win_size).to(y.device)
    stft_spec = torch.stft(
        y, n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode='reflect',
        normalized=False,
        return_complex=True)

    # IDIOM: `not addeps` instead of `addeps == False`
    if not addeps:
        mag = torch.abs(stft_spec)
        pha = torch.angle(stft_spec)
    else:
        real_part = stft_spec.real
        imag_part = stft_spec.imag
        mag = torch.sqrt(real_part.pow(2) + imag_part.pow(2) + eps)
        pha = torch.atan2(imag_part + eps, real_part + eps)

    # Compress the magnitude
    if compress_factor in ['log1p', 'relu_log1p', 'signed_log1p']:
        mag = torch.log1p(mag)
    else:
        mag = torch.pow(mag, compress_factor)

    com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1)
    return mag, pha, com
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def mag_phase_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
    """
    Inverse STFT: undo the magnitude compression applied by mag_phase_stft,
    rebuild the complex spectrum from magnitude and phase, and synthesize audio.

    Args:
        mag (torch.Tensor): (compressed) magnitude of the STFT.
        pha (torch.Tensor): Phase of the STFT.
        n_fft (int): FFT size.
        hop_size (int): Hop size.
        win_size (int): Window size.
        compress_factor (float or str, optional): Must match the value used for
            the forward transform. Defaults to 1.0.
        center (bool, optional): Whether the forward transform was centered. Defaults to True.

    Returns:
        torch.Tensor: Reconstructed audio signal.
    """
    # Undo magnitude compression (inverse of mag_phase_stft's compression step).
    if compress_factor == 'log1p':
        mag = torch.expm1(mag)
    elif compress_factor == 'signed_log1p':
        mag = decompress_signed_log1p(mag)
    elif compress_factor == 'relu_log1p':
        mag = torch.expm1(torch.relu(mag))
    else:
        # ReLU guards against small negative magnitudes from the network.
        mag = torch.pow(torch.relu(mag), 1.0 / compress_factor)

    spec = torch.complex(mag * torch.cos(pha), mag * torch.sin(pha))
    window = torch.hann_window(win_size).to(spec.device)
    return torch.istft(
        spec,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=window,
        center=center)
|
recipes/USEMamba_30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_001.yaml
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Environment Settings
|
| 2 |
+
# These settings specify the hardware and distributed setup for the model training.
|
| 3 |
+
# Adjust `num_gpus` and `dist_config` according to your distributed training environment.
|
| 4 |
+
env_setting:
|
| 5 |
+
num_gpus: 8 # Number of GPUs. Now we don't support CPU mode.
|
| 6 |
+
  num_workers: 20 # Number of worker threads for data loading.
|
| 7 |
+
  persistent_workers: True # If you have large RAM, set this to True.
|
| 8 |
+
prefetch_factor: 8 # null
|
| 9 |
+
seed: 1234 # Seed for random number generators to ensure reproducibility.
|
| 10 |
+
stdout_interval: 5000
|
| 11 |
+
checkpoint_interval: 5000 # save model to ckpt every N steps
|
| 12 |
+
validation_interval: 5000
|
| 13 |
+
dist_cfg:
|
| 14 |
+
dist_backend: nccl # Distributed training backend, 'nccl' for NVIDIA GPUs.
|
| 15 |
+
dist_url: tcp://localhost:19478 # URL for initializing distributed training.
|
| 16 |
+
world_size: 1 # Total number of processes in the distributed training.
|
| 17 |
+
pin_memory: True # If you have large RAM, turn this to be True
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# STFT Configuration
|
| 21 |
+
# Configuration for Short-Time Fourier Transform (STFT), crucial for audio processing models.
|
| 22 |
+
stft_cfg:
|
| 23 |
+
sampling_rate: 8000 # Audio sampling rate in Hz.
|
| 24 |
+
n_fft: 320 # FFT components for transforming audio signals.
|
| 25 |
+
hop_size: 40 # Samples between successive frames.
|
| 26 |
+
win_size: 320 # Window size used in FFT.
|
| 27 |
+
  sfi: True # Sampling Frequency Independent
|
| 28 |
+
|
| 29 |
+
# Model Configuration
|
| 30 |
+
# Defines the architecture specifics of the model, including layer configurations and feature compression.
|
| 31 |
+
model_cfg:
|
| 32 |
+
hid_feature: 64 # Channels in dense layers.
|
| 33 |
+
compress_factor: relu_log1p # Compression factor applied to extracted features.
|
| 34 |
+
num_tfmamba: 30 # Number of Time-Frequency Mamba (TFMamba) blocks in the model.
|
| 35 |
+
d_state: 16 # Dimensionality of the state vector in Mamba blocks.
|
| 36 |
+
d_conv: 4 # Convolutional layer dimensionality within Mamba blocks.
|
| 37 |
+
expand: 4 # Expansion factor for the layers within the Mamba blocks.
|
| 38 |
+
norm_epsilon: 0.00001 # Numerical stability in normalization layers within the Mamba blocks.
|
| 39 |
+
beta: 2.0 # Hyperparameter for the Learnable Sigmoid function.
|
| 40 |
+
input_channel: 2 # Magnitude and Phase
|
| 41 |
+
output_channel: 1 # Single Channel Speech Enhancement
|
| 42 |
+
inner_mamba_nlayer: 1 # Number of layer of Mamba in Bidirectional Mamba
|
| 43 |
+
nonlinear: None # last activation function for the mag encoder. 'softplus' or 'relu'
|
| 44 |
+
mapping: True # Otherwise, this should be masking model
|
requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
packaging
|
| 2 |
+
librosa
|
| 3 |
+
soundfile
|
| 4 |
+
pyyaml
|
| 5 |
+
argparse
|
| 6 |
+
tensorboard
|
| 7 |
+
pesq
|
| 8 |
+
einops
|
| 9 |
+
matplotlib
|
| 10 |
+
torch==2.6.0
|
| 11 |
+
torchaudio==2.6.0
|
| 12 |
+
numpy==1.26.4
|
| 13 |
+
resampy
|
| 14 |
+
transformers==4.33.3
|
utils/util.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import yaml
|
| 10 |
+
import torch
|
| 11 |
+
import os
|
| 12 |
+
import shutil
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
def load_config(config_path):
    """Read a YAML configuration file and return its parsed contents."""
    with open(config_path, 'r') as cfg_file:
        parsed = yaml.safe_load(cfg_file)
    return parsed
|
| 19 |
+
|
| 20 |
+
def pad_or_trim_to_match(reference: torch.Tensor, target: torch.Tensor, pad_value: float = 1e-6) -> torch.Tensor:
    """
    Make ``target`` match ``reference`` along dim=1 by trimming or right-padding.

    Padding is done with ``F.pad``, which is differentiable, so gradient flow
    through ``target`` is preserved (same guarantee as the original
    allocate-and-copy implementation, in a single idiomatic call).

    Args:
        reference (torch.Tensor): Tensor of shape (B, ref_len) defining the
            desired length along dim=1.
        target (torch.Tensor): Tensor of shape (B, tgt_len) to adjust.
        pad_value (float, optional): Fill value for appended samples.
            Defaults to 1e-6 (near-silence rather than exact zero).

    Returns:
        torch.Tensor: ``target`` itself when lengths already match, a trimmed
        slice when too long, or a new right-padded tensor when too short.
    """
    ref_len = reference.shape[1]
    tgt_len = target.shape[1]

    if tgt_len == ref_len:
        return target
    if tgt_len > ref_len:
        return target[:, :ref_len]

    # Right-pad the last dimension; F.pad keeps autograd tracking intact.
    return F.pad(target, (0, ref_len - tgt_len), value=pad_value)
|