Upload 2 files
Browse files
- inference.py +7 -0
- inference_chunk.py +6 -0
inference.py
CHANGED
|
@@ -10,11 +10,14 @@ import os
|
|
| 10 |
import argparse
|
| 11 |
import torch
|
| 12 |
import torchaudio
|
|
|
|
| 13 |
import librosa
|
| 14 |
from models.stfts import mag_phase_stft, mag_phase_istft
|
| 15 |
from models.generator_SEMamba_time_d4 import SEMamba
|
| 16 |
from utils.util import load_config, pad_or_trim_to_match
|
| 17 |
|
|
|
|
|
|
|
| 18 |
def get_filepaths(directory, file_type=None):
|
| 19 |
file_paths = [] # List which will store all of the full filepaths.
|
| 20 |
# Walk the tree.
|
|
@@ -75,6 +78,10 @@ def inference(args, device):
|
|
| 75 |
addeps=False
|
| 76 |
)
|
| 77 |
amp_g, pha_g, _ = SE_model(noisy_mag, noisy_pha)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
audio_g = mag_phase_istft(amp_g, pha_g, n_fft_scaled, hop_size_scaled, win_size_scaled, compress_factor)
|
| 80 |
audio_g = pad_or_trim_to_match(noisy_wav.detach(), audio_g, pad_value=1e-8) # Align lengths using epsilon padding
|
|
|
|
| 10 |
import argparse
|
| 11 |
import torch
|
| 12 |
import torchaudio
|
| 13 |
+
import torch.nn as nn
|
| 14 |
import librosa
|
| 15 |
from models.stfts import mag_phase_stft, mag_phase_istft
|
| 16 |
from models.generator_SEMamba_time_d4 import SEMamba
|
| 17 |
from utils.util import load_config, pad_or_trim_to_match
|
| 18 |
|
| 19 |
+
RELU = nn.ReLU()
|
| 20 |
+
|
| 21 |
def get_filepaths(directory, file_type=None):
|
| 22 |
file_paths = [] # List which will store all of the full filepaths.
|
| 23 |
# Walk the tree.
|
|
|
|
| 78 |
addeps=False
|
| 79 |
)
|
| 80 |
amp_g, pha_g, _ = SE_model(noisy_mag, noisy_pha)
|
| 81 |
+
# To remove the "strange sweep artifact": zero amp_g for time frames in which more than half of the frequency bins have zero magnitude
|
| 82 |
+
mag = torch.expm1(RELU(amp_g)) # [1, F, T]
|
| 83 |
+
zero_portion = torch.sum(mag==0, 1)/mag.shape[1]
|
| 84 |
+
amp_g[:,:,(zero_portion>0.5)[0]] = 0
|
| 85 |
|
| 86 |
audio_g = mag_phase_istft(amp_g, pha_g, n_fft_scaled, hop_size_scaled, win_size_scaled, compress_factor)
|
| 87 |
audio_g = pad_or_trim_to_match(noisy_wav.detach(), audio_g, pad_value=1e-8) # Align lengths using epsilon padding
|
inference_chunk.py
CHANGED
|
@@ -10,12 +10,14 @@ import os
|
|
| 10 |
import argparse
|
| 11 |
import torch
|
| 12 |
import torchaudio
|
|
|
|
| 13 |
import librosa
|
| 14 |
import math
|
| 15 |
from models.stfts import mag_phase_stft, mag_phase_istft
|
| 16 |
from models.generator_SEMamba_time_d4 import SEMamba
|
| 17 |
from utils.util import load_config, pad_or_trim_to_match
|
| 18 |
|
|
|
|
| 19 |
|
| 20 |
def get_filepaths(directory, file_type=None):
|
| 21 |
file_paths = [] # List which will store all of the full filepaths.
|
|
@@ -89,6 +91,10 @@ def inference(args, device):
|
|
| 89 |
addeps=False
|
| 90 |
)
|
| 91 |
amp_g, pha_g, _ = SE_model(noisy_mag, noisy_pha)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
audio_g = mag_phase_istft(amp_g, pha_g, n_fft_scaled, hop_size_scaled, win_size_scaled, compress_factor)
|
| 94 |
audio_g = pad_or_trim_to_match(noisy_wav_chunk.detach(), audio_g, pad_value=1e-8) # Align lengths using epsilon padding
|
|
|
|
| 10 |
import argparse
|
| 11 |
import torch
|
| 12 |
import torchaudio
|
| 13 |
+
import torch.nn as nn
|
| 14 |
import librosa
|
| 15 |
import math
|
| 16 |
from models.stfts import mag_phase_stft, mag_phase_istft
|
| 17 |
from models.generator_SEMamba_time_d4 import SEMamba
|
| 18 |
from utils.util import load_config, pad_or_trim_to_match
|
| 19 |
|
| 20 |
+
RELU = nn.ReLU()
|
| 21 |
|
| 22 |
def get_filepaths(directory, file_type=None):
|
| 23 |
file_paths = [] # List which will store all of the full filepaths.
|
|
|
|
| 91 |
addeps=False
|
| 92 |
)
|
| 93 |
amp_g, pha_g, _ = SE_model(noisy_mag, noisy_pha)
|
| 94 |
+
# To remove the "strange sweep artifact": zero amp_g for time frames in which more than half of the frequency bins have zero magnitude
|
| 95 |
+
mag = torch.expm1(RELU(amp_g)) # [1, F, T]
|
| 96 |
+
zero_portion = torch.sum(mag==0, 1)/mag.shape[1]
|
| 97 |
+
amp_g[:,:,(zero_portion>0.5)[0]] = 0
|
| 98 |
|
| 99 |
audio_g = mag_phase_istft(amp_g, pha_g, n_fft_scaled, hop_size_scaled, win_size_scaled, compress_factor)
|
| 100 |
audio_g = pad_or_trim_to_match(noisy_wav_chunk.detach(), audio_g, pad_value=1e-8) # Align lengths using epsilon padding
|