# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import os
import argparse

import torch
import torchaudio
import torch.nn as nn
import librosa

from models.stfts import mag_phase_stft, mag_phase_istft
from models.generator_SEMamba_time_d4 import SEMamba
from utils.util import load_config, pad_or_trim_to_match

# Shared ReLU used when post-processing generator magnitudes in `inference`.
RELU = nn.ReLU()


def get_filepaths(directory, file_type=None):
    """Recursively collect full file paths under *directory*.

    Args:
        directory: Root directory to walk.
        file_type: Optional extension (without the leading dot). When given,
            only paths whose text after the final '.' equals it are kept.
            NOTE(review): a dotless filename compares its whole path here,
            so it never matches a filter — presumably intended.

    Returns:
        list[str]: Every matching file path found in the tree.
    """
    collected = []
    for root, _dirs, names in os.walk(directory):
        for name in names:
            path = os.path.join(root, name)
            # No filter -> keep everything; otherwise compare the text
            # after the last dot against the requested extension.
            if file_type is None or path.split('.')[-1] == file_type:
                collected.append(path)
    return collected
def make_even(value):
    """Round *value* to the nearest integer, bumping odd results up by one.

    Scaled STFT parameters must stay even so the window/hop relationship
    remains valid after resampling.
    """
    value = int(round(value))
    return value if value % 2 == 0 else value + 1


def inference(args, device):
    """Run speech enhancement over every file found in ``args.input_folder``.

    Enhanced audio is written as .flac under ``args.output_folder``,
    mirroring the input directory layout. When ``args.BWE`` is set, input
    audio is first resampled to that rate (bandwidth-extension mode) and
    the STFT parameters are rescaled accordingly.

    Args:
        args: Parsed CLI namespace with ``input_folder``, ``output_folder``,
            ``config``, ``checkpoint_file`` and optional ``BWE``.
        device: ``torch.device`` the model runs on.

    Raises:
        RuntimeError: If the enhanced waveform's shape does not match the
            input after length alignment.
    """
    cfg = load_config(args.config)
    n_fft = cfg['stft_cfg']['n_fft']
    hop_size = cfg['stft_cfg']['hop_size']
    win_size = cfg['stft_cfg']['win_size']
    compress_factor = cfg['model_cfg']['compress_factor']
    sampling_rate = cfg['stft_cfg']['sampling_rate']

    SE_model = SEMamba(cfg).to(device)
    state_dict = torch.load(args.checkpoint_file, map_location=device)
    SE_model.load_state_dict(state_dict['generator'])
    SE_model.eval()

    os.makedirs(args.output_folder, exist_ok=True)
    # Hoist the target-rate parse out of the loop; None means "no resample".
    target_sr = int(args.BWE) if args.BWE is not None else None

    with torch.no_grad():
        for fname in get_filepaths(args.input_folder):
            print(fname)
            try:
                noisy_wav, noisy_sr = torchaudio.load(fname)
            except Exception as e:
                # Best-effort batch processing: skip unreadable files.
                print(f"Warning: cannot read {fname}, skipping. ({e})")
                continue

            # Mirror the input sub-directory structure under the output root
            # (os.path instead of string concat + hard-coded '/').
            rel_path = os.path.relpath(fname, args.input_folder)
            out_dir = os.path.join(args.output_folder, os.path.dirname(rel_path))
            os.makedirs(out_dir, exist_ok=True)

            if target_sr is not None:
                # Bandwidth extension: resample before enhancement.
                noisy_wav = librosa.resample(
                    noisy_wav.cpu().numpy(), orig_sr=noisy_sr,
                    target_sr=target_sr, res_type="kaiser_best")
                noisy_sr = target_sr

            noisy_wav = torch.FloatTensor(noisy_wav).to(device)

            # Scale STFT parameters to the (possibly resampled) rate; they
            # are forced even to keep the STFT configuration consistent.
            n_fft_scaled = make_even(n_fft * noisy_sr // sampling_rate)
            hop_size_scaled = make_even(hop_size * noisy_sr // sampling_rate)
            win_size_scaled = make_even(win_size * noisy_sr // sampling_rate)

            noisy_mag, noisy_pha, noisy_com = mag_phase_stft(
                noisy_wav, n_fft=n_fft_scaled, hop_size=hop_size_scaled,
                win_size=win_size_scaled, compress_factor=compress_factor,
                center=True, addeps=False)
            amp_g, pha_g, _ = SE_model(noisy_mag, noisy_pha)

            # To remove "strange sweep artifact": zero out time frames whose
            # decompressed magnitude is zero in more than half the freq bins.
            mag = torch.expm1(RELU(amp_g))  # [1, F, T]
            zero_portion = torch.sum(mag == 0, 1) / mag.shape[1]
            amp_g[:, :, (zero_portion > 0.5)[0]] = 0

            audio_g = mag_phase_istft(amp_g, pha_g, n_fft_scaled,
                                      hop_size_scaled, win_size_scaled,
                                      compress_factor)
            # Align lengths using epsilon padding.
            audio_g = pad_or_trim_to_match(noisy_wav.detach(), audio_g,
                                           pad_value=1e-8)
            if audio_g.shape != noisy_wav.shape:
                # Explicit check: an `assert` here would vanish under -O.
                raise RuntimeError(
                    f"shape mismatch after alignment: {audio_g.shape} vs "
                    f"{noisy_wav.shape}")

            # splitext on the basename (not split('.')[0] on the whole path)
            # so dots in directory names or multi-dot file names do not
            # truncate the output path.
            stem = os.path.splitext(os.path.basename(rel_path))[0]
            output_file = os.path.join(out_dir, stem + '.flac')
            # save to .flac format
            torchaudio.save(output_file, audio_g.cpu(), noisy_sr)


def main():
    """CLI entry point: parse arguments and run inference on CUDA."""
    print('Initializing Inference Process...')
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_folder')
    parser.add_argument('--output_folder')
    parser.add_argument('--config')
    parser.add_argument('--checkpoint_file', required=True)
    parser.add_argument('--BWE', default=None)
    args = parser.parse_args()

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        raise RuntimeError("Currently, CPU mode is not supported.")
    inference(args, device)


if __name__ == '__main__':
    main()