szuweifu commited on
Commit
3c7312f
·
verified ·
1 Parent(s): 9b418f5

Upload 9 files

Browse files
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import shlex
10
+ import subprocess
11
+ import spaces
12
+ import gradio as gr
13
+
14
def install_mamba():
    """Install the pre-built mamba-ssm CUDA wheel at runtime.

    Hugging Face Spaces cannot compile the CUDA extension during build, so a
    matching pre-built wheel (cu122 / torch 2.3 / cp310) is installed on start.

    Raises:
        subprocess.CalledProcessError: if pip fails — failing fast here is
        clearer than the later opaque ``ImportError: mamba_ssm``.
    """
    wheel_url = (
        "https://github.com/state-spaces/mamba/releases/download/v2.2.2/"
        "mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
    )
    # check=True surfaces install failures immediately (original ignored them).
    subprocess.run(shlex.split(f"pip install {wheel_url}"), check=True)

install_mamba()
17
+
18
+ ABOUT = """
19
+ # RE-USE: A universal speech enhancement model for diverse degradations, sampling rates, and languages.
20
+ Upload or record a noisy clip, then click **Enhance** to listen to the result and view its spectrogram.
21
+ (ref: https://huggingface.co/spaces/rc19477/Speech_Enhancement_Mamba)
22
+ """
23
+
24
+ import torch
25
+ import torchaudio
26
+ import librosa
27
+ import matplotlib.pyplot as plt
28
+ import numpy as np
29
+ from models.stfts import mag_phase_stft, mag_phase_istft
30
+ from models.generator_SEMamba_MPSEnet_time_d4 import SEMamba
31
+ from utils.util import load_config, pad_or_trim_to_match
32
+
33
+
34
def make_even(value):
    """Round *value* to the nearest integer, bumping odd results up by one.

    Used to keep scaled STFT sizes (n_fft / hop / win) even.
    """
    rounded = int(round(value))
    if rounded % 2:
        rounded += 1
    return rounded
37
+
38
+ device = "cuda"
39
+ cfg1 = load_config('recipes/USEMamba_30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_001.yaml')
40
+ n_fft, hop_size, win_size = cfg1['stft_cfg']['n_fft'], cfg1['stft_cfg']['hop_size'], cfg1['stft_cfg']['win_size']
41
+ compress_factor = cfg1['model_cfg']['compress_factor']
42
+ sampling_rate = cfg1['stft_cfg']['sampling_rate']
43
+
44
+ USE_model = SEMamba(cfg1).to(device)
45
+ checkpoint_file = "exp/30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_peak_GAN_tel_mic/g_01134000.pth"
46
+ state_dict = torch.load(checkpoint_file, map_location=device)
47
+ USE_model.load_state_dict(state_dict['generator'])
48
+ USE_model.eval()
49
+
50
@spaces.GPU
def enhance(filepath, low_pass_sampling_rate, target_sampling_rate):
    """Enhance a noisy recording with the USE model.

    Args:
        filepath: Path to the input audio (any format torchaudio can read).
        low_pass_sampling_rate: Optional rate (string from the UI textbox) for
            a preliminary downsample acting as a low-pass filter; '' disables.
            Only applied when a target rate is also given.
        target_sampling_rate: Optional rate (string) to resample to before
            enhancement (bandwidth extension); '' keeps the native rate.

    Returns:
        Tuple of (original wav path, enhanced wav path, matplotlib Figure
        with noisy and enhanced spectrograms).
    """
    # BUGFIX: librosa.display is a lazy submodule — `import librosa` alone does
    # not guarantee it is loaded, so import it explicitly before specshow().
    import librosa.display

    with torch.no_grad():
        noisy_wav, noisy_sr = torchaudio.load(filepath)
        torchaudio.save("original.wav", noisy_wav.cpu(), noisy_sr)
        original_noisy_wav = noisy_wav
        original_sr = noisy_sr

        # Optional bandwidth extension. The low-pass step simulates a
        # band-limited input by resampling down before going up to the target.
        if target_sampling_rate != '':
            opts = {"res_type": "kaiser_best"}
            if low_pass_sampling_rate != '':
                noisy_wav = torch.tensor(librosa.resample(noisy_wav.cpu().numpy(), orig_sr=noisy_sr, target_sr=int(low_pass_sampling_rate), **opts))
                noisy_sr = int(low_pass_sampling_rate)
            noisy_wav = librosa.resample(noisy_wav.cpu().numpy(), orig_sr=noisy_sr, target_sr=int(target_sampling_rate), **opts)
            noisy_sr = int(target_sampling_rate)

        noisy_wav = torch.FloatTensor(noisy_wav).to(device)

        # Sampling-frequency-independent STFT: scale window/hop to the input
        # rate, rounded to even sizes.
        n_fft_scaled = make_even(n_fft * noisy_sr // sampling_rate)
        hop_size_scaled = make_even(hop_size * noisy_sr // sampling_rate)
        win_size_scaled = make_even(win_size * noisy_sr // sampling_rate)

        noisy_mag, noisy_pha, noisy_com = mag_phase_stft(
            noisy_wav,
            n_fft=n_fft_scaled,
            hop_size=hop_size_scaled,
            win_size=win_size_scaled,
            compress_factor=compress_factor,
            center=True,
            addeps=False
        )
        amp_g, pha_g, _ = USE_model(noisy_mag, noisy_pha)

        audio_g = mag_phase_istft(amp_g, pha_g, n_fft_scaled, hop_size_scaled, win_size_scaled, compress_factor)
        audio_g = pad_or_trim_to_match(noisy_wav.detach(), audio_g, pad_value=1e-8)  # Align lengths using epsilon padding
        assert audio_g.shape == noisy_wav.shape, audio_g.shape

        # write enhanced audio next to the original
        torchaudio.save("enhanced.wav", audio_g.cpu(), noisy_sr)

        # spectrograms: noisy at its original rate, enhanced at the model rate
        fig, axs = plt.subplots(1, 2, figsize=(16, 4))

        # noisy (first channel only)
        D_noisy = librosa.stft(original_noisy_wav[0].cpu().numpy(), n_fft=512, hop_length=256)
        S_noisy = librosa.amplitude_to_db(np.abs(D_noisy), ref=np.max)
        librosa.display.specshow(S_noisy, sr=original_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[0], vmax=0)
        axs[0].set_title("Noisy Spectrogram")

        # enhanced (stft of the 2D [channels, samples] array; plot channel 0)
        D_clean = librosa.stft(audio_g.cpu().numpy(), n_fft=512, hop_length=256)
        S_clean = librosa.amplitude_to_db(np.abs(D_clean), ref=np.max)
        librosa.display.specshow(S_clean[0], sr=noisy_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
        axs[1].set_title("Enhanced Spectrogram")

        plt.tight_layout()
        return "original.wav", "enhanced.wav", fig
108
+
109
with gr.Blocks() as demo:
    gr.Markdown(ABOUT)
    gr.Markdown("**Note 1**: For bandwidth extension, the performance may be affected by the characteristics of the input data, particularly the cutoff pattern. A simple solution is to apply low-pass filtering beforehand.")
    gr.Markdown("**Note 2**: When processing long input audio, out-of-memory (OOM) errors may occur. To address this, use the chunk-wise inference implementation provided on the Hugging Face.")

    with gr.Row():
        with gr.Column():
            # Create Tabs to separate Audio and Video sessions
            with gr.Tabs():
                with gr.TabItem("Audio Upload"):
                    # gr.Audio works great for standard audio files
                    input_audio = gr.Audio(label="Input Audio", type="filepath")

                with gr.TabItem("Video Upload (.mp4, .mov)"):
                    # gr.File handles .mp4 and .mov without errors
                    input_video = gr.File(label="Input Video", file_types=[".mp4", ".mov"])

            target_sampling_rate = gr.Textbox(label="(Optional) Enter target sampling rate for bandwidth extension:")
            low_pass_sampling_rate = gr.Textbox(label="(Optional) Enter target sampling rate for pre-low-pass filtering before bandwidth extension:")

    # Helper to unify the input: we use a hidden state to store which one was used
    active_input = gr.State()
    enhance_btn = gr.Button("Enhance")
    with gr.Row():
        input_audio_player = gr.Audio(label="Original Input Audio", type="filepath")
        output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
    plot_output = gr.Plot(label="Spectrograms")

    def unified_enhance(audio_path, video_path, lp_sr, target_sr):
        """Route whichever tab supplied a file into `enhance`.

        gr.Audio yields a filepath string; gr.File may yield a tempfile-like
        object, so fall back to its `.name` attribute for the real path
        (BUGFIX: the original comment noted this but never did it).
        """
        if audio_path:
            final_path = audio_path
        else:
            final_path = getattr(video_path, "name", video_path)
        return enhance(final_path, lp_sr, target_sr)

    enhance_btn.click(
        fn=unified_enhance,
        inputs=[input_audio, input_video, low_pass_sampling_rate, target_sampling_rate],
        outputs=[input_audio_player, output_audio, plot_output]
    )

demo.queue().launch(share=True)
exp/30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_peak_GAN_tel_mic/g_01134000.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e27db9e1de904eb59fc627dea72c69da7ca25650a3e704b4096f89812b395fe5
3
+ size 38982886
models/codec_module_time_d4.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ from einops import rearrange
13
+
14
def get_padding_2d(kernel_size, dilation=(1, 1)):
    """Compute 'same'-style padding for a dilated 2D convolution.

    Args:
        kernel_size: (height, width) of the convolution kernel.
        dilation: (height, width) dilation rates; defaults to (1, 1).

    Returns:
        (pad_h, pad_w) that keeps the spatial size unchanged for stride 1.
    """
    kh, kw = kernel_size[0], kernel_size[1]
    dh, dw = dilation[0], dilation[1]
    # d*(k-1)/2 is the classic "same" padding formula for dilated convs;
    # the quantity is non-negative so floor division matches int(x/2).
    return (dh * (kh - 1) // 2, dw * (kw - 1) // 2)
27
+
28
class SPConvTranspose2d(nn.Module):
    """Sub-pixel "transposed" convolution that upsamples the last axis by r.

    A stride-1 Conv2d produces r groups of channels which are then interleaved
    (pixel-shuffle style) along the width dimension.
    """
    def __init__(self, in_channels, out_channels, kernel_size, r=1):
        super().__init__()
        # Pad one column on each side of the last axis before the conv.
        self.pad1 = nn.ConstantPad2d((1, 1, 0, 0), value=0.)
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels * r, kernel_size=kernel_size, stride=(1, 1))
        self.r = r

    def forward(self, x):
        feat = self.conv(self.pad1(x))
        b, ch, h, w = feat.shape
        # Split channels into r interleave groups, move the group axis last,
        # then fold it into the width: [B, ch/r, H, W*r].
        feat = feat.view(b, self.r, ch // self.r, h, w)
        feat = feat.permute(0, 2, 3, 4, 1).contiguous()
        return feat.view(b, ch // self.r, h, w * self.r)
44
+
45
class DenseBlock(nn.Module):
    """Densely connected stack of dilated convolutions.

    Layer i receives the channel-concatenation of the input and all previous
    layer outputs, with time-axis dilation 2**i ("same" padded).
    """
    def __init__(self, cfg, kernel_size=(3, 3), depth=4):
        super().__init__()
        self.cfg = cfg
        self.depth = depth
        self.dense_block = nn.ModuleList()
        self.hid_feature = cfg['model_cfg']['hid_feature']

        for idx in range(depth):
            dilation = (2 ** idx, 1)  # dilate along the time axis only
            layer = nn.Sequential(
                nn.Conv2d(
                    self.hid_feature * (idx + 1),
                    self.hid_feature,
                    kernel_size,
                    dilation=dilation,
                    padding=get_padding_2d(kernel_size, dilation),
                ),
                nn.InstanceNorm2d(self.hid_feature, affine=True),
                nn.PReLU(self.hid_feature),
            )
            self.dense_block.append(layer)

    def forward(self, x):
        # `stacked` accumulates [latest_out, all_previous...] along channels.
        stacked = x
        out = x
        for layer in self.dense_block:
            out = layer(stacked)
            stacked = torch.cat([out, stacked], dim=1)
        return out
72
+
73
class DenseEncoder(nn.Module):
    """Encoder: 1x1 channel lift, dense dilated block, strided downsample.

    Input layout is assumed to be [batch, channels, time, freq] — the final
    conv strides (4, 2) over those last two axes. TODO confirm against caller.
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.input_channel = cfg['model_cfg']['input_channel']
        self.hid_feature = cfg['model_cfg']['hid_feature']

        # 1x1 conv lifting input channels (mag+phase) to the hidden width.
        self.dense_conv_1 = nn.Sequential(
            nn.Conv2d(self.input_channel, self.hid_feature, (1, 1)),
            nn.InstanceNorm2d(self.hid_feature, affine=True),
            nn.PReLU(self.hid_feature),
        )

        self.dense_block = DenseBlock(cfg, depth=4)

        # Strided conv: downsamples time by 4 and freq roughly by 2.
        self.dense_conv_2 = nn.Sequential(
            nn.Conv2d(self.hid_feature, self.hid_feature, (1, 3), stride=(4, 2)),
            nn.InstanceNorm2d(self.hid_feature, affine=True),
            nn.PReLU(self.hid_feature),
        )

    def forward(self, x):
        return self.dense_conv_2(self.dense_block(self.dense_conv_1(x)))
102
+
103
class MagDecoder(nn.Module):
    """Decode magnitude from encoder features.

    Mirrors the encoder's downsampling: sub-pixel upsampling by 2 along the
    last axis, then by 4 along the time axis (via a transpose), then a 1x1
    conv down to the output channel count.
    """
    def __init__(self, cfg):
        super().__init__()
        self.dense_block = DenseBlock(cfg, depth=4)
        self.hid_feature = cfg['model_cfg']['hid_feature']
        self.output_channel = cfg['model_cfg']['output_channel']
        # n_fft / beta are stored for config parity; neither is used in forward.
        self.n_fft = cfg['stft_cfg']['n_fft']
        self.beta = cfg['model_cfg']['beta']

        self.up_conv1 = nn.Sequential(
            SPConvTranspose2d(self.hid_feature, self.hid_feature, (1, 3), 2),
            nn.InstanceNorm2d(self.hid_feature, affine=True),
            nn.PReLU(self.hid_feature),
        )

        self.up_conv2 = nn.Sequential(
            SPConvTranspose2d(self.hid_feature, self.hid_feature, (1, 3), 4),
            nn.InstanceNorm2d(self.hid_feature, affine=True),
            nn.PReLU(self.hid_feature),
        )

        self.final_conv = nn.Conv2d(self.hid_feature, self.output_channel, (1, 1))

    def forward(self, x):
        feat = self.dense_block(x)
        feat = self.up_conv1(feat)  # upsample last (freq) axis x2
        # Transpose so the sub-pixel upsampler acts on the time axis (x4).
        feat = self.up_conv2(feat.transpose(2, 3)).transpose(2, 3)
        return self.final_conv(feat)
135
+
136
class PhaseDecoder(nn.Module):
    """Decode phase from encoder features.

    Same upsampling path as MagDecoder, but the head predicts a real and an
    imaginary map whose atan2 gives a phase in (-pi, pi].
    """
    def __init__(self, cfg):
        super().__init__()
        self.dense_block = DenseBlock(cfg, depth=4)
        self.hid_feature = cfg['model_cfg']['hid_feature']
        self.output_channel = cfg['model_cfg']['output_channel']

        self.up_conv1 = nn.Sequential(
            SPConvTranspose2d(self.hid_feature, self.hid_feature, (1, 3), 2),
            nn.InstanceNorm2d(self.hid_feature, affine=True),
            nn.PReLU(self.hid_feature),
        )

        self.up_conv2 = nn.Sequential(
            SPConvTranspose2d(self.hid_feature, self.hid_feature, (1, 3), 4),
            nn.InstanceNorm2d(self.hid_feature, affine=True),
            nn.PReLU(self.hid_feature),
        )

        self.phase_conv_r = nn.Conv2d(self.hid_feature, self.output_channel, (1, 1))
        self.phase_conv_i = nn.Conv2d(self.hid_feature, self.output_channel, (1, 1))

    def forward(self, x):
        feat = self.dense_block(x)
        feat = self.up_conv1(feat)  # upsample last (freq) axis x2
        # Transpose so the sub-pixel upsampler acts on the time axis (x4).
        feat = self.up_conv2(feat.transpose(2, 3)).transpose(2, 3)
        real = self.phase_conv_r(feat)
        imag = self.phase_conv_i(feat)
        return torch.atan2(imag, real)
models/generator_SEMamba_time_d4.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+ from .mamba_block2_SEMamba import TFMambaBlock
13
+ from .codec_module_time_d4 import DenseEncoder, MagDecoder, PhaseDecoder
14
+
15
class SEMamba(nn.Module):
    """
    SEMamba model for speech enhancement using Mamba blocks.

    This model uses a dense encoder, multiple Mamba blocks, and separate magnitude
    and phase decoders to process noisy magnitude and phase inputs.
    """
    def __init__(self, cfg):
        """
        Initialize the SEMamba model.

        Args:
        - cfg: Configuration dict; reads cfg['model_cfg']['num_tfmamba'] here
          and forwards cfg to the encoder, blocks, and decoders.
        """
        super(SEMamba, self).__init__()
        self.cfg = cfg
        # Number of stacked Time-Frequency Mamba blocks; falls back to 4 when
        # the recipe leaves num_tfmamba unset (None).
        self.num_tscblocks = cfg['model_cfg']['num_tfmamba'] if cfg['model_cfg']['num_tfmamba'] is not None else 4 # default tfmamba: 4

        # Initialize dense encoder
        self.dense_encoder = DenseEncoder(cfg)

        # Initialize Mamba blocks
        self.TSMamba = nn.ModuleList([TFMambaBlock(cfg) for _ in range(self.num_tscblocks)])

        # Initialize decoders
        # NOTE(review): named "mask_decoder" but it appears to predict the
        # magnitude directly (the recipe sets mapping: True) — confirm.
        self.mask_decoder = MagDecoder(cfg)
        self.phase_decoder = PhaseDecoder(cfg)

    def forward(self, noisy_mag, noisy_pha):
        """
        Forward pass for the SEMamba model.

        Args:
        - noisy_mag (torch.Tensor): Noisy magnitude input tensor [B, F, T].
        - noisy_pha (torch.Tensor): Noisy phase input tensor [B, F, T].

        Returns:
        - denoised_mag (torch.Tensor): Denoised magnitude tensor [B, F, T].
        - denoised_pha (torch.Tensor): Denoised phase tensor [B, F, T].
        - denoised_com (torch.Tensor): Denoised complex tensor [B, F, T, 2].
        """
        # Reshape inputs from [B, F, T] to channel-first [B, 1, T, F]
        noisy_mag = rearrange(noisy_mag, 'b f t -> b t f').unsqueeze(1) # [B, 1, T, F]
        noisy_pha = rearrange(noisy_pha, 'b f t -> b t f').unsqueeze(1) # [B, 1, T, F]

        # Concatenate magnitude and phase inputs
        x = torch.cat((noisy_mag, noisy_pha), dim=1) # [B, 2, T, F]

        # Zero-pad 2 frequency bins and 2 time frames so the encoder's strided
        # conv and the decoders' sub-pixel upsampling round-trip without shape
        # mismatches; the excess is trimmed off after decoding below.
        B, C, T, F = x.shape
        zeros = torch.zeros(B, C, T, 2, device=x.device)
        x = torch.cat((x, zeros), dim=-1)
        zeros = torch.zeros(B, C, 2, F+2, device=x.device)
        x = torch.cat((x, zeros), dim=-2)

        # Encode input
        x = self.dense_encoder(x)

        # Apply Mamba blocks
        for block in self.TSMamba:
            x = block(x)

        # Decode output back to [B, F', T'] (may exceed F, T due to padding)
        denoised_mag = rearrange(self.mask_decoder(x), 'b c t f -> b f t c').squeeze(-1)
        denoised_pha = rearrange(self.phase_decoder(x), 'b c t f -> b f t c').squeeze(-1)

        # Trim back to the original (unpadded) frequency/time extents
        denoised_mag = denoised_mag[:, :F, :T]
        denoised_pha = denoised_pha[:, :F, :T]

        # Combine denoised magnitude and phase into a complex representation
        denoised_com = torch.stack(
            (denoised_mag * torch.cos(denoised_pha), denoised_mag * torch.sin(denoised_pha)),
            dim=-1
        )

        return denoised_mag, denoised_pha, denoised_com
models/mamba_block2_SEMamba.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch.nn import init
13
+ from torch.nn.parameter import Parameter
14
+ from functools import partial
15
+ from einops import rearrange
16
+ from mamba_ssm import Mamba
17
+
18
class MambaBlock(nn.Module):
    """Bidirectional Mamba layer over [B, T, D] sequences.

    One Mamba SSM scans the sequence forward, a second scans it time-reversed;
    both residual streams are concatenated, projected back to d_model, and
    layer-normalized.
    """
    def __init__(self, d_model, cfg):
        super(MambaBlock, self).__init__()

        # SSM hyperparameters from the recipe (trailing comments: recipe values).
        d_state = cfg['model_cfg']['d_state'] # 16
        d_conv = cfg['model_cfg']['d_conv'] # 4
        expand = cfg['model_cfg']['expand'] # 4

        self.forward_blocks = Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
        self.backward_blocks = Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
        # Fuses the concatenated forward/backward streams back to d_model.
        self.output_proj = nn.Linear(2 * d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: [B, T, D]
        # Forward-direction scan with a residual connection.
        out_fw = self.forward_blocks(x) + x

        # Backward scan: flip time, scan, add residual, then flip back so both
        # streams are aligned in time.
        out_bw = self.backward_blocks(torch.flip(x, dims=[1])) + torch.flip(x, dims=[1])
        out_bw = torch.flip(out_bw, dims=[1])

        # Concatenate both directions along features and project to d_model.
        out = torch.cat([out_fw, out_bw], dim=-1)
        out = self.output_proj(out)

        # LayerNorm
        return self.norm(out)
43
+
44
+
45
class TFMambaBlock(nn.Module):
    """
    Temporal-Frequency Mamba block for sequence modeling.

    Attributes:
    cfg (Config): Configuration for the block.
    time_mamba (MambaBlock): Mamba block scanning along the time axis.
    freq_mamba (MambaBlock): Mamba block scanning along the frequency axis.
    """
    def __init__(self, cfg):
        super(TFMambaBlock, self).__init__()
        self.cfg = cfg
        self.hid_feature = cfg['model_cfg']['hid_feature']

        # Initialize Mamba blocks
        self.time_mamba = MambaBlock(d_model=self.hid_feature, cfg=cfg)
        self.freq_mamba = MambaBlock(d_model=self.hid_feature, cfg=cfg)

    def forward(self, x):
        """
        Forward pass of the TFMamba block.

        Parameters:
        x (Tensor): Input tensor with shape (batch, channels, time, freq).

        Returns:
        Tensor: Output tensor after applying temporal and frequency Mamba blocks.
        """
        b, c, t, f = x.size()
        # Time pass: fold frequency into the batch -> [B*F, T, C], scan over
        # time with an outer residual connection.
        x = x.permute(0, 3, 2, 1).contiguous().view(b*f, t, c)
        x = self.time_mamba(x) + x
        # Freq pass: unfold, then fold time into the batch -> [B*T, F, C].
        x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b*t, f, c)
        x = self.freq_mamba(x) + x
        # Restore the channel-first layout [B, C, T, F].
        x = x.view(b, t, f, c).permute(0, 3, 1, 2)
        return x
models/stfts.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
def decompress_signed_log1p(y):
    """Invert signed log1p compression: sign(y) * (exp(|y|) - 1)."""
    magnitude = torch.expm1(y.abs())
    return y.sign() * magnitude
14
+
15
+ RELU = nn.ReLU()
16
+
17
def mag_phase_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True, addeps=False):
    """
    Compute magnitude and phase of the STFT of *y*.

    Args:
        y (torch.Tensor): Input audio signal.
        n_fft (int): FFT size.
        hop_size (int): Hop size.
        win_size (int): Window size (Hann window).
        compress_factor: Either a float exponent applied to the magnitude, or
            one of 'log1p' / 'relu_log1p' / 'signed_log1p' to select log1p
            compression instead.
        center (bool): Whether to center frames (reflect padding).
        addeps (bool): If True, stabilize magnitude/phase with a small epsilon.

    Returns:
        tuple: (magnitude, phase, complex-as-real-pair) tensors.
    """
    window = torch.hann_window(win_size).to(y.device)
    spec = torch.stft(
        y, n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=window,
        center=center,
        pad_mode='reflect',
        normalized=False,
        return_complex=True)

    if addeps:
        # Epsilon-stabilized variants (useful when gradients flow through).
        eps = 1e-10
        re, im = spec.real, spec.imag
        mag = torch.sqrt(re.pow(2) + im.pow(2) + eps)
        pha = torch.atan2(im + eps, re + eps)
    else:
        mag = spec.abs()
        pha = spec.angle()

    # Compress the magnitude: log1p for the string modes, power law otherwise.
    if compress_factor in ('log1p', 'relu_log1p', 'signed_log1p'):
        mag = torch.log1p(mag)
    else:
        mag = torch.pow(mag, compress_factor)

    com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1)
    return mag, pha, com
60
+
61
+
62
def mag_phase_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
    """
    Reconstruct an audio signal from (compressed) magnitude and phase.

    Args:
        mag (torch.Tensor): Compressed magnitude of the STFT.
        pha (torch.Tensor): Phase of the STFT.
        n_fft (int): FFT size.
        hop_size (int): Hop size.
        win_size (int): Window size (Hann window).
        compress_factor: The compression used by mag_phase_stft — a float
            exponent or 'log1p' / 'signed_log1p' / 'relu_log1p'.
        center (bool): Whether frames were centered.

    Returns:
        torch.Tensor: Reconstructed waveform.
    """
    # Undo the magnitude compression; dispatch on the string modes, otherwise
    # invert the power law (clamping negatives via ReLU first).
    decompressors = {
        'log1p': torch.expm1,
        'signed_log1p': decompress_signed_log1p,
        'relu_log1p': lambda m: torch.expm1(RELU(m)),
    }
    undo = decompressors.get(compress_factor)
    if undo is not None:
        mag = undo(mag)
    else:
        mag = torch.pow(RELU(mag), 1.0 / compress_factor)

    com = torch.complex(mag * torch.cos(pha), mag * torch.sin(pha))
    window = torch.hann_window(win_size).to(com.device)
    return torch.istft(
        com,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=window,
        center=center)
recipes/USEMamba_30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_001.yaml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Environment Settings
2
+ # These settings specify the hardware and distributed setup for the model training.
3
+ # Adjust `num_gpus` and `dist_config` according to your distributed training environment.
4
+ env_setting:
5
+ num_gpus: 8 # Number of GPUs. Now we don't support CPU mode.
6
+ num_workers: 20 # Number of worker threads for data loading.
7
+ persistent_workers: True # If you have large RAM, set this to True to keep workers alive between epochs.
8
+ prefetch_factor: 8 # null
9
+ seed: 1234 # Seed for random number generators to ensure reproducibility.
10
+ stdout_interval: 5000
11
+ checkpoint_interval: 5000 # save model to ckpt every N steps
12
+ validation_interval: 5000
13
+ dist_cfg:
14
+ dist_backend: nccl # Distributed training backend, 'nccl' for NVIDIA GPUs.
15
+ dist_url: tcp://localhost:19478 # URL for initializing distributed training.
16
+ world_size: 1 # Total number of processes in the distributed training.
17
+ pin_memory: True # If you have large RAM, turn this to be True
18
+
19
+
20
+ # STFT Configuration
21
+ # Configuration for Short-Time Fourier Transform (STFT), crucial for audio processing models.
22
+ stft_cfg:
23
+ sampling_rate: 8000 # Audio sampling rate in Hz.
24
+ n_fft: 320 # FFT components for transforming audio signals.
25
+ hop_size: 40 # Samples between successive frames.
26
+ win_size: 320 # Window size used in FFT.
27
+ sfi: True # Sampling Frequency Independent
28
+
29
+ # Model Configuration
30
+ # Defines the architecture specifics of the model, including layer configurations and feature compression.
31
+ model_cfg:
32
+ hid_feature: 64 # Channels in dense layers.
33
+ compress_factor: relu_log1p # Compression factor applied to extracted features.
34
+ num_tfmamba: 30 # Number of Time-Frequency Mamba (TFMamba) blocks in the model.
35
+ d_state: 16 # Dimensionality of the state vector in Mamba blocks.
36
+ d_conv: 4 # Convolutional layer dimensionality within Mamba blocks.
37
+ expand: 4 # Expansion factor for the layers within the Mamba blocks.
38
+ norm_epsilon: 0.00001 # Numerical stability in normalization layers within the Mamba blocks.
39
+ beta: 2.0 # Hyperparameter for the Learnable Sigmoid function.
40
+ input_channel: 2 # Magnitude and Phase
41
+ output_channel: 1 # Single Channel Speech Enhancement
42
+ inner_mamba_nlayer: 1 # Number of layer of Mamba in Bidirectional Mamba
43
+ nonlinear: None # last activation function for the mag encoder. 'softplus' or 'relu'
44
+ mapping: True # Otherwise, this should be masking model
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ packaging
2
+ librosa
3
+ soundfile
4
+ pyyaml
5
+ argparse
6
+ tensorboard
7
+ pesq
8
+ einops
9
+ matplotlib
10
+ torch==2.6.0
11
+ torchaudio==2.6.0
12
+ numpy==1.26.4
13
+ resampy
14
+ transformers==4.33.3
utils/util.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import yaml
10
+ import torch
11
+ import os
12
+ import shutil
13
+ import torch.nn.functional as F
14
+
15
def load_config(config_path):
    """Read a YAML file at *config_path* and return the parsed configuration."""
    with open(config_path, 'r') as fh:
        cfg = yaml.safe_load(fh)
    return cfg
19
+
20
def pad_or_trim_to_match(reference: torch.Tensor, target: torch.Tensor, pad_value: float = 1e-6) -> torch.Tensor:
    """
    Make *target* the same length as *reference* along dim=1.

    Longer targets are truncated; shorter ones are right-padded with
    *pad_value* via a copy into a fresh tensor, which keeps autograd intact.
    """
    batch, ref_len = reference.shape
    tgt_len = target.shape[1]

    if tgt_len >= ref_len:
        # Exact match is returned as-is; excess samples are dropped.
        return target if tgt_len == ref_len else target[:, :ref_len]

    # new_full inherits target's dtype/device; the slice assignment below
    # preserves gradient tracking through the copied region.
    padded = target.new_full((batch, ref_len), pad_value)
    padded[:, :tgt_len] = target
    return padded