| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import argparse |
| import torch |
| import torchaudio |
| import torch.nn as nn |
| import librosa |
| from models.stfts import mag_phase_stft, mag_phase_istft |
| from models.generator_SEMamba_time_d4 import SEMamba |
| from utils.util import load_config, pad_or_trim_to_match |
|
|
| RELU = nn.ReLU() |
|
|
| def get_filepaths(directory, file_type=None): |
| file_paths = [] |
| |
| for root, directories, files in os.walk(directory): |
| for filename in files: |
| |
| filepath = os.path.join(root, filename) |
| if file_type is not None: |
| if filepath.split('.')[-1] == file_type: |
| file_paths.append(filepath) |
| else: |
| file_paths.append(filepath) |
| return file_paths |
|
|
| def make_even(value): |
| value = int(round(value)) |
| return value if value % 2 == 0 else value + 1 |
|
|
| def inference(args, device): |
| cfg = load_config(args.config) |
| n_fft, hop_size, win_size = cfg['stft_cfg']['n_fft'], cfg['stft_cfg']['hop_size'], cfg['stft_cfg']['win_size'] |
| compress_factor = cfg['model_cfg']['compress_factor'] |
| sampling_rate = cfg['stft_cfg']['sampling_rate'] |
|
|
| SE_model = SEMamba(cfg).to(device) |
| state_dict = torch.load(args.checkpoint_file, map_location=device) |
| SE_model.load_state_dict(state_dict['generator']) |
| SE_model.eval() |
|
|
| os.makedirs(args.output_folder, exist_ok=True) |
| with torch.no_grad(): |
| for i, fname in enumerate(get_filepaths(args.input_folder)): |
| print(fname) |
| try: |
| os.makedirs(args.output_folder + fname[0:fname.rfind('/')].replace(args.input_folder,''), exist_ok=True) |
| noisy_wav, noisy_sr = torchaudio.load(fname) |
| except Exception as e: |
| print(f"Warning: cannot read {fname}, skipping. ({e})") |
| continue |
| |
| if args.BWE is not None: |
| opts = {"res_type": "kaiser_best"} |
| noisy_wav = librosa.resample(noisy_wav.cpu().numpy(), orig_sr=noisy_sr, target_sr=int(args.BWE), **opts) |
| noisy_sr = int(args.BWE) |
|
|
| noisy_wav = torch.FloatTensor(noisy_wav).to(device) |
| n_fft_scaled = make_even(n_fft * noisy_sr // sampling_rate) |
| hop_size_scaled = make_even(hop_size * noisy_sr // sampling_rate) |
| win_size_scaled = make_even(win_size * noisy_sr // sampling_rate) |
|
|
| noisy_mag, noisy_pha, noisy_com = mag_phase_stft( |
| noisy_wav, |
| n_fft=n_fft_scaled, |
| hop_size=hop_size_scaled, |
| win_size=win_size_scaled, |
| compress_factor=compress_factor, |
| center=True, |
| addeps=False |
| ) |
| amp_g, pha_g, _ = SE_model(noisy_mag, noisy_pha) |
| |
| mag = torch.expm1(RELU(amp_g)) |
| zero_portion = torch.sum(mag==0, 1)/mag.shape[1] |
| amp_g[:,:,(zero_portion>0.5)[0]] = 0 |
|
|
| audio_g = mag_phase_istft(amp_g, pha_g, n_fft_scaled, hop_size_scaled, win_size_scaled, compress_factor) |
| audio_g = pad_or_trim_to_match(noisy_wav.detach(), audio_g, pad_value=1e-8) |
| assert audio_g.shape == noisy_wav.shape, audio_g.shape |
|
|
| output_file = os.path.join(args.output_folder + fname.replace(args.input_folder,'').split('.')[0]+'.flac') |
| torchaudio.save(output_file, audio_g.cpu(), noisy_sr) |
| |
| def main(): |
| print('Initializing Inference Process...') |
| parser = argparse.ArgumentParser() |
| parser.add_argument('--input_folder') |
| parser.add_argument('--output_folder') |
| parser.add_argument('--config') |
| parser.add_argument('--checkpoint_file', required=True) |
| parser.add_argument('--BWE', default=None) |
| args = parser.parse_args() |
|
|
| global device |
| if torch.cuda.is_available(): |
| device = torch.device('cuda') |
| else: |
| raise RuntimeError("Currently, CPU mode is not supported.") |
| |
| inference(args, device) |
|
|
|
|
| if __name__ == '__main__': |
| main() |
|
|