szuweifu committed on
Commit
dfc9065
·
verified ·
1 Parent(s): 33dc378

Upload 2 files

Browse files
Files changed (2) hide show
  1. inference.py +7 -0
  2. inference_chunk.py +6 -0
inference.py CHANGED
@@ -10,11 +10,14 @@ import os
10
  import argparse
11
  import torch
12
  import torchaudio
 
13
  import librosa
14
  from models.stfts import mag_phase_stft, mag_phase_istft
15
  from models.generator_SEMamba_time_d4 import SEMamba
16
  from utils.util import load_config, pad_or_trim_to_match
17
 
 
 
18
  def get_filepaths(directory, file_type=None):
19
  file_paths = [] # List which will store all of the full filepaths.
20
  # Walk the tree.
@@ -75,6 +78,10 @@ def inference(args, device):
75
  addeps=False
76
  )
77
  amp_g, pha_g, _ = SE_model(noisy_mag, noisy_pha)
 
 
 
 
78
 
79
  audio_g = mag_phase_istft(amp_g, pha_g, n_fft_scaled, hop_size_scaled, win_size_scaled, compress_factor)
80
  audio_g = pad_or_trim_to_match(noisy_wav.detach(), audio_g, pad_value=1e-8) # Align lengths using epsilon padding
 
10
  import argparse
11
  import torch
12
  import torchaudio
13
+ import torch.nn as nn
14
  import librosa
15
  from models.stfts import mag_phase_stft, mag_phase_istft
16
  from models.generator_SEMamba_time_d4 import SEMamba
17
  from utils.util import load_config, pad_or_trim_to_match
18
 
19
+ RELU = nn.ReLU()
20
+
21
  def get_filepaths(directory, file_type=None):
22
  file_paths = [] # List which will store all of the full filepaths.
23
  # Walk the tree.
 
78
  addeps=False
79
  )
80
  amp_g, pha_g, _ = SE_model(noisy_mag, noisy_pha)
81
+ # Remove the "strange sweep artifact": zero out every time frame in which more than half of the frequency bins have zero magnitude
82
+ mag = torch.expm1(RELU(amp_g)) # [1, F, T]
83
+ zero_portion = torch.sum(mag==0, 1)/mag.shape[1]
84
+ amp_g[:,:,(zero_portion>0.5)[0]] = 0
85
 
86
  audio_g = mag_phase_istft(amp_g, pha_g, n_fft_scaled, hop_size_scaled, win_size_scaled, compress_factor)
87
  audio_g = pad_or_trim_to_match(noisy_wav.detach(), audio_g, pad_value=1e-8) # Align lengths using epsilon padding
inference_chunk.py CHANGED
@@ -10,12 +10,14 @@ import os
10
  import argparse
11
  import torch
12
  import torchaudio
 
13
  import librosa
14
  import math
15
  from models.stfts import mag_phase_stft, mag_phase_istft
16
  from models.generator_SEMamba_time_d4 import SEMamba
17
  from utils.util import load_config, pad_or_trim_to_match
18
 
 
19
 
20
  def get_filepaths(directory, file_type=None):
21
  file_paths = [] # List which will store all of the full filepaths.
@@ -89,6 +91,10 @@ def inference(args, device):
89
  addeps=False
90
  )
91
  amp_g, pha_g, _ = SE_model(noisy_mag, noisy_pha)
 
 
 
 
92
 
93
  audio_g = mag_phase_istft(amp_g, pha_g, n_fft_scaled, hop_size_scaled, win_size_scaled, compress_factor)
94
  audio_g = pad_or_trim_to_match(noisy_wav_chunk.detach(), audio_g, pad_value=1e-8) # Align lengths using epsilon padding
 
10
  import argparse
11
  import torch
12
  import torchaudio
13
+ import torch.nn as nn
14
  import librosa
15
  import math
16
  from models.stfts import mag_phase_stft, mag_phase_istft
17
  from models.generator_SEMamba_time_d4 import SEMamba
18
  from utils.util import load_config, pad_or_trim_to_match
19
 
20
+ RELU = nn.ReLU()
21
 
22
  def get_filepaths(directory, file_type=None):
23
  file_paths = [] # List which will store all of the full filepaths.
 
91
  addeps=False
92
  )
93
  amp_g, pha_g, _ = SE_model(noisy_mag, noisy_pha)
94
+ # Remove the "strange sweep artifact": zero out every time frame in which more than half of the frequency bins have zero magnitude
95
+ mag = torch.expm1(RELU(amp_g)) # [1, F, T]
96
+ zero_portion = torch.sum(mag==0, 1)/mag.shape[1]
97
+ amp_g[:,:,(zero_portion>0.5)[0]] = 0
98
 
99
  audio_g = mag_phase_istft(amp_g, pha_g, n_fft_scaled, hop_size_scaled, win_size_scaled, compress_factor)
100
  audio_g = pad_or_trim_to_match(noisy_wav_chunk.detach(), audio_g, pad_value=1e-8) # Align lengths using epsilon padding