File size: 4,731 Bytes
d78f08c
 
 
 
 
 
 
 
 
 
 
 
dfc9065
d78f08c
 
 
 
 
dfc9065
 
d78f08c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfc9065
 
 
 
d78f08c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import argparse
import torch
import torchaudio
import torch.nn as nn
import librosa
from models.stfts import mag_phase_stft, mag_phase_istft
from models.generator_SEMamba_time_d4 import SEMamba
from utils.util import load_config, pad_or_trim_to_match

# Module-level ReLU used by inference() to clamp the model's magnitude output
# at zero before torch.expm1, so non-positive values map to exactly 0.
RELU = nn.ReLU()

def get_filepaths(directory, file_type=None):
    """Recursively collect file paths under ``directory``.

    Args:
        directory: Root folder to walk with ``os.walk``.
        file_type: If given, keep only paths whose text after the last '.'
            of the full path equals this value exactly (e.g. ``'wav'``,
            without a leading dot). If None, every file is returned.

    Returns:
        A list of full file paths, in ``os.walk`` traversal order.
    """
    wanted = []
    for parent, _subdirs, names in os.walk(directory):
        candidates = (os.path.join(parent, name) for name in names)
        wanted.extend(
            path for path in candidates
            if file_type is None or path.rsplit('.', 1)[-1] == file_type
        )
    return wanted

def make_even(value):
    """Round ``value`` to the nearest integer, bumping an odd result up by one.

    Uses Python's built-in ``round`` (ties go to the even integer), so for
    example ``make_even(3) == 4``, ``make_even(2.5) == 2`` and
    ``make_even(-3) == -2``.
    """
    rounded = int(round(value))
    # Python's % is non-negative for a positive divisor: 0 for even, 1 for odd.
    return rounded + (rounded % 2)

def _relative_output_path(fname, input_folder, output_folder, ext='.flac'):
    """Map an input file path to its output path, mirroring the sub-folder layout.

    The path of ``fname`` relative to ``input_folder`` is re-rooted under
    ``output_folder`` and its final extension is replaced with ``ext``.
    Using ``os.path.relpath``/``splitext``/``join`` (instead of string
    concatenation, ``str.replace`` and ``split('.')``) keeps this correct when
    the folder arguments differ in trailing slashes, when directory names
    contain dots, and when file names contain more than one dot.
    """
    rel_path = os.path.relpath(fname, input_folder)
    stem, _ = os.path.splitext(rel_path)
    return os.path.join(output_folder, stem + ext)

def inference(args, device):
    """Enhance every audio file under ``args.input_folder`` with a SEMamba model.

    Builds the model from ``args.config``, loads the ``'generator'`` weights
    from ``args.checkpoint_file``, and for each readable file runs
    STFT -> model -> iSTFT, writing the result as ``.flac`` under
    ``args.output_folder`` with the input sub-folder layout mirrored.

    Args:
        args: Parsed CLI namespace with ``input_folder``, ``output_folder``,
            ``config``, ``checkpoint_file`` and ``BWE`` (a sample rate to
            resample inputs to before enhancement, or None to keep the
            file's own rate).
        device: torch device the model and audio tensors are placed on.
    """
    cfg = load_config(args.config)
    stft_cfg = cfg['stft_cfg']
    n_fft, hop_size, win_size = stft_cfg['n_fft'], stft_cfg['hop_size'], stft_cfg['win_size']
    compress_factor = cfg['model_cfg']['compress_factor']
    sampling_rate = stft_cfg['sampling_rate']

    SE_model = SEMamba(cfg).to(device)
    state_dict = torch.load(args.checkpoint_file, map_location=device)
    SE_model.load_state_dict(state_dict['generator'])
    SE_model.eval()

    os.makedirs(args.output_folder, exist_ok=True)
    with torch.no_grad():
        for fname in get_filepaths(args.input_folder):
            print(fname)
            output_file = _relative_output_path(fname, args.input_folder, args.output_folder)
            try:
                os.makedirs(os.path.dirname(output_file), exist_ok=True)
                noisy_wav, noisy_sr = torchaudio.load(fname)
            except Exception as e:
                # Best-effort: unreadable / non-audio files are skipped, not fatal.
                print(f"Warning: cannot read {fname}, skipping. ({e})")
                continue

            if args.BWE is not None:
                # Resample to the rate requested on the command line before enhancement.
                target_sr = int(args.BWE)
                noisy_wav = librosa.resample(noisy_wav.cpu().numpy(), orig_sr=noisy_sr, target_sr=target_sr, res_type="kaiser_best")
                noisy_sr = target_sr

            noisy_wav = torch.as_tensor(noisy_wav, dtype=torch.float32, device=device)
            # Scale the STFT sizes proportionally to the file's sample rate so the
            # analysis windows span the same duration as at cfg's sampling_rate.
            n_fft_scaled = make_even(n_fft * noisy_sr // sampling_rate)
            hop_size_scaled = make_even(hop_size * noisy_sr // sampling_rate)
            win_size_scaled = make_even(win_size * noisy_sr // sampling_rate)

            noisy_mag, noisy_pha, noisy_com = mag_phase_stft(
                noisy_wav,
                n_fft=n_fft_scaled,
                hop_size=hop_size_scaled,
                win_size=win_size_scaled,
                compress_factor=compress_factor,
                center=True,
                addeps=False
            )
            amp_g, pha_g, _ = SE_model(noisy_mag, noisy_pha)
            # To remove "strange sweep artifact": silence every frame in which more
            # than half of the frequency bins of the clamped magnitude are zero.
            # The frame mask is taken from the first batch item ([0]).
            mag = torch.expm1(RELU(amp_g)) # [1, F, T]
            zero_portion = torch.sum(mag==0, 1)/mag.shape[1]
            amp_g[:,:,(zero_portion>0.5)[0]] = 0

            audio_g = mag_phase_istft(amp_g, pha_g, n_fft_scaled, hop_size_scaled, win_size_scaled, compress_factor)
            audio_g = pad_or_trim_to_match(noisy_wav.detach(), audio_g, pad_value=1e-8)  # Align lengths using epsilon padding
            assert audio_g.shape == noisy_wav.shape, audio_g.shape

            torchaudio.save(output_file, audio_g.cpu(), noisy_sr)
            
def main():
    """Parse command-line arguments and run inference on a CUDA device.

    Raises:
        RuntimeError: If CUDA is not available (CPU inference is unsupported).
    """
    print('Initializing Inference Process...')
    parser = argparse.ArgumentParser()
    # All paths are required: inference() cannot run without them, and failing
    # here gives a clear argparse error instead of a later TypeError on None.
    parser.add_argument('--input_folder', required=True,
                        help='Root folder searched recursively for input audio files.')
    parser.add_argument('--output_folder', required=True,
                        help='Root folder for the enhanced .flac outputs (sub-folders are mirrored).')
    parser.add_argument('--config', required=True,
                        help='Path to the configuration file passed to load_config.')
    parser.add_argument('--checkpoint_file', required=True,
                        help='Path to the checkpoint holding the "generator" state dict.')
    parser.add_argument('--BWE', type=int, default=None,
                        help='Optional sample rate to resample inputs to before enhancement.')
    args = parser.parse_args()

    # Kept as a module-level global so the original side effect is preserved.
    global device
    if not torch.cuda.is_available():
        raise RuntimeError("Currently, CPU mode is not supported.")
    device = torch.device('cuda')

    inference(args, device)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()