File size: 6,218 Bytes
d78f08c
 
 
 
 
 
 
 
 
 
 
 
dfc9065
d78f08c
 
 
 
 
51a5645
d78f08c
dfc9065
d78f08c
efab889
51a5645
d78f08c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cad78e
d78f08c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfc9065
 
 
 
d78f08c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import argparse
import torch
import torchaudio
import torch.nn as nn
import librosa
import math
from models.stfts import mag_phase_stft, mag_phase_istft
from models.generator_SEMamba_time_d4 import SEMamba
from utils.util import load_config, pad_or_trim_to_match
from huggingface_hub import hf_hub_download

# Shared ReLU module; applied at inference time to the model's predicted
# magnitudes (see the artifact-removal step in `inference`).
RELU = nn.ReLU()

# Downloaded at import time from the Hugging Face Hub (requires network).
# NOTE(review): `config_path` is not referenced by the visible code —
# `inference` loads `args.config` instead; confirm whether this is still needed.
config_path = hf_hub_download(repo_id="nvidia/RE-USE", filename="config.json")

def get_filepaths(directory, file_type=None):
    """Recursively collect full paths of files under *directory*.

    Args:
        directory: Root directory to walk.
        file_type: Optional extension without the leading dot (e.g. ``"wav"``).
            When given, only files with that extension are returned.

    Returns:
        list[str]: Full file paths in ``os.walk`` traversal order.
    """
    file_paths = []
    for root, _directories, files in os.walk(directory):
        for filename in files:
            filepath = os.path.join(root, filename)
            # os.path.splitext handles dot-less filenames correctly, whereas the
            # naive filepath.split('.')[-1] would treat the whole name as an
            # "extension" and could spuriously match file_type.
            if file_type is None or os.path.splitext(filepath)[1][1:] == file_type:
                file_paths.append(filepath)
    return file_paths

def make_even(value):
    """Round *value* to the nearest integer and round odd results up to even.

    Used to keep derived STFT sizes (n_fft, hop, window) even after scaling.
    """
    rounded = int(round(value))
    # Odd numbers gain 1 (rounded % 2 == 1); even numbers are unchanged.
    return rounded + (rounded % 2)

def inference(args, device):
    """Enhance every audio file under ``args.input_folder`` and save FLAC output.

    Each file is optionally resampled (bandwidth extension via ``--BWE``),
    split into overlapping chunks, enhanced per-chunk by the pretrained
    SEMamba model, and recombined with Hann-window overlap-add before being
    written to ``args.output_folder`` with the input's directory structure.

    Args:
        args: Parsed CLI namespace — expects input_folder, output_folder,
            config, chunk_size_in_seconds, hop_length_portion, and BWE.
        device: torch device the model and tensors are placed on (CUDA only,
            per ``main``).
    """
    cfg = load_config(args.config)
    # STFT parameters from the training config; these correspond to the
    # config's sampling_rate and are rescaled below for other input rates.
    n_fft, hop_size, win_size = cfg['stft_cfg']['n_fft'], cfg['stft_cfg']['hop_size'], cfg['stft_cfg']['win_size']
    compress_factor = cfg['model_cfg']['compress_factor']
    sampling_rate = cfg['stft_cfg']['sampling_rate']

    SE_model = SEMamba.from_pretrained("nvidia/RE-USE", cfg=cfg).to(device)
    SE_model.eval()

    os.makedirs(args.output_folder, exist_ok=True)
    with torch.no_grad():
        for fname in get_filepaths(args.input_folder):
            print(fname)
            try:
                # Mirror the input subdirectory layout under output_folder.
                # NOTE(review): uses '/' and string concat rather than
                # os.path.join — presumably POSIX-only; confirm if Windows matters.
                os.makedirs(args.output_folder + fname[0:fname.rfind('/')].replace(args.input_folder,''), exist_ok=True)
                Noisy_wav, noisy_sr = torchaudio.load(fname)
            except Exception as e:
                # Best-effort batch processing: unreadable files are skipped.
                print(f"Warning: cannot read {fname}, skipping. ({e})")
                continue

            if args.BWE is not None:
                # Bandwidth extension: resample the input to the requested rate
                # before enhancement (librosa operates on numpy arrays).
                opts = {"res_type": "kaiser_best"}
                Noisy_wav = librosa.resample(Noisy_wav.cpu().numpy(), orig_sr=noisy_sr, target_sr=int(args.BWE), **opts)
                noisy_sr = int(args.BWE)
            
            chunk_size = int(args.chunk_size_in_seconds*noisy_sr) # (in samples)
            hop_length = int(args.hop_length_portion*chunk_size) # (in samples)
            # Hann window used to cross-fade overlapping chunks in overlap-add.
            window = torch.hann_window(chunk_size).to(device)

            # Rescale STFT sizes proportionally to the actual sample rate so the
            # time/frequency resolution matches training; forced even by make_even.
            n_fft_scaled = make_even(n_fft * noisy_sr // sampling_rate)
            hop_size_scaled = make_even(hop_size * noisy_sr // sampling_rate)
            win_size_scaled = make_even(win_size * noisy_sr // sampling_rate)

            Noisy_wav = torch.FloatTensor(Noisy_wav).to(device)
            audio_enhanced = torch.zeros_like(Noisy_wav).to(device)
            #norm = torch.zeros_like(Noisy_wav).to(device)
            # Accumulates the summed window weights for overlap-add normalization.
            window_sum = torch.zeros_like(Noisy_wav).to(device)
            for c in range(Noisy_wav.shape[0]): # for multi-channel speech
                noisy_wav = Noisy_wav[c:c+1,:]
                # Number of hops needed to cover the signal; at least one chunk.
                for i in range(max(1, math.ceil((noisy_wav.shape[1]-chunk_size)/hop_length)+1)):
                    noisy_wav_chunk = noisy_wav[:, i*hop_length : i*hop_length+chunk_size]

                    noisy_mag, noisy_pha, noisy_com = mag_phase_stft(
                        noisy_wav_chunk,
                        n_fft=n_fft_scaled,
                        hop_size=hop_size_scaled,
                        win_size=win_size_scaled,
                        compress_factor=compress_factor,
                        center=True,
                        addeps=False
                    )
                    amp_g, pha_g, _ = SE_model(noisy_mag, noisy_pha)
                    # To remove "strange sweep artifact": frames where more than
                    # half the frequency bins have zero (decompressed) magnitude
                    # are presumed artifacts and zeroed out entirely.
                    mag = torch.expm1(RELU(amp_g)) # [1, F, T]
                    zero_portion = torch.sum(mag==0, 1)/mag.shape[1]
                    amp_g[:,:,(zero_portion>0.5)[0]] = 0

                    audio_g = mag_phase_istft(amp_g, pha_g, n_fft_scaled, hop_size_scaled, win_size_scaled, compress_factor)
                    audio_g = pad_or_trim_to_match(noisy_wav_chunk.detach(), audio_g, pad_value=1e-8)  # Align lengths using epsilon padding

                    # Overlap-add: weight each chunk by the (possibly truncated,
                    # for the final chunk) Hann window and accumulate weights.
                    audio_enhanced[c:c+1,i*hop_length:i*hop_length+chunk_size] += audio_g*window[0:audio_g.shape[1]] 
                    window_sum[c:c+1,i*hop_length:i*hop_length+chunk_size] += window[0:audio_g.shape[1]]
                    #norm[c:c+1,i*hop_length:i*hop_length+chunk_size] += 1.0
            # Normalize by the accumulated window weights, skipping samples the
            # windows never touched (avoids division by ~0 at chunk edges).
            nonzero_indices = (window_sum > 1e-8)
            audio_enhanced[:,nonzero_indices[0]] = audio_enhanced[:,nonzero_indices[0]]/window_sum[:,nonzero_indices[0]]
            assert audio_enhanced.shape == Noisy_wav.shape, audio_enhanced.shape
            output_file = os.path.join(args.output_folder + fname.replace(args.input_folder,'').split('.')[0]+'.flac') # save to .flac format
            torchaudio.save(output_file, audio_enhanced.cpu(), noisy_sr)

def main():
    """Parse CLI arguments, require a CUDA device, and launch inference."""
    print('Initializing Inference Process..')

    ap = argparse.ArgumentParser()
    # Plain string options.
    for flag in ('--input_folder', '--output_folder', '--config', '--checkpoint_file'):
        ap.add_argument(flag)
    # Chunking controls (seconds / fraction of chunk).
    ap.add_argument('--chunk_size_in_seconds', type=float)
    ap.add_argument('--hop_length_portion', type=float)
    # Optional bandwidth-extension target sample rate.
    ap.add_argument('--BWE', default=None)
    parsed = ap.parse_args()

    global device
    if not torch.cuda.is_available():
        raise RuntimeError("Currently, CPU mode is not supported.")
    device = torch.device('cuda')

    inference(parsed, device)

if __name__ == '__main__':
    main()