Spaces:
Running
Running
Update ensemble.py
Browse files- ensemble.py +50 -13
ensemble.py
CHANGED
|
@@ -6,6 +6,8 @@ import librosa
|
|
| 6 |
import soundfile as sf
|
| 7 |
import numpy as np
|
| 8 |
import argparse
|
|
|
|
|
|
|
| 9 |
|
| 10 |
def stft(wave, nfft, hl):
|
| 11 |
wave_left = np.asfortranarray(wave[0])
|
|
@@ -70,7 +72,7 @@ def average_waveforms(pred_track, weights, algorithm):
|
|
| 70 |
:param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
|
| 71 |
:return: averaged waveform in shape (channels, length)
|
| 72 |
"""
|
| 73 |
-
pred_track = np.array(pred_track)
|
| 74 |
final_length = pred_track.shape[-1]
|
| 75 |
|
| 76 |
mod_track = []
|
|
@@ -85,11 +87,13 @@ def average_waveforms(pred_track, weights, algorithm):
|
|
| 85 |
mod_track.append(spec * weights[i])
|
| 86 |
else:
|
| 87 |
mod_track.append(spec)
|
| 88 |
-
|
|
|
|
|
|
|
| 89 |
|
| 90 |
if algorithm in ['avg_wave']:
|
| 91 |
pred_track = pred_track.sum(axis=0)
|
| 92 |
-
pred_track /= np.array(weights).sum()
|
| 93 |
elif algorithm in ['median_wave']:
|
| 94 |
pred_track = np.median(pred_track, axis=0)
|
| 95 |
elif algorithm in ['min_wave']:
|
|
@@ -109,6 +113,8 @@ def average_waveforms(pred_track, weights, algorithm):
|
|
| 109 |
elif algorithm in ['median_fft']:
|
| 110 |
pred_track = np.median(pred_track, axis=0)
|
| 111 |
pred_track = istft(pred_track, 1024, final_length)
|
|
|
|
|
|
|
| 112 |
return pred_track
|
| 113 |
|
| 114 |
def ensemble_files(args):
|
|
@@ -117,32 +123,63 @@ def ensemble_files(args):
|
|
| 117 |
parser.add_argument("--type", type=str, default='avg_wave', help="One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft")
|
| 118 |
parser.add_argument("--weights", type=float, nargs='+', help="Weights to create ensemble. Number of weights must be equal to number of files")
|
| 119 |
parser.add_argument("--output", default="res.wav", type=str, help="Path to wav file where ensemble result will be stored")
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
print('Ensemble type: {}'.format(args.type))
|
| 123 |
print('Number of input files: {}'.format(len(args.files)))
|
| 124 |
if args.weights is not None:
|
| 125 |
weights = args.weights
|
|
|
|
|
|
|
|
|
|
| 126 |
else:
|
| 127 |
weights = np.ones(len(args.files))
|
| 128 |
print('Weights: {}'.format(weights))
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
data = []
|
|
|
|
| 132 |
for f in args.files:
|
| 133 |
if not os.path.isfile(f):
|
| 134 |
print('Error. Can\'t find file: {}. Check paths.'.format(f))
|
| 135 |
return None
|
| 136 |
print('Reading file: {}'.format(f))
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
if __name__ == "__main__":
|
| 148 |
ensemble_files(None)
|
|
|
|
| 6 |
import soundfile as sf
|
| 7 |
import numpy as np
|
| 8 |
import argparse
|
| 9 |
+
import uuid
|
| 10 |
+
import gc
|
| 11 |
|
| 12 |
def stft(wave, nfft, hl):
|
| 13 |
wave_left = np.asfortranarray(wave[0])
|
|
|
|
| 72 |
:param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
|
| 73 |
:return: averaged waveform in shape (channels, length)
|
| 74 |
"""
|
| 75 |
+
pred_track = np.array(pred_track, copy=False)
|
| 76 |
final_length = pred_track.shape[-1]
|
| 77 |
|
| 78 |
mod_track = []
|
|
|
|
| 87 |
mod_track.append(spec * weights[i])
|
| 88 |
else:
|
| 89 |
mod_track.append(spec)
|
| 90 |
+
del spec
|
| 91 |
+
gc.collect()
|
| 92 |
+
pred_track = np.array(mod_track, copy=False)
|
| 93 |
|
| 94 |
if algorithm in ['avg_wave']:
|
| 95 |
pred_track = pred_track.sum(axis=0)
|
| 96 |
+
pred_track /= np.array(weights).sum()
|
| 97 |
elif algorithm in ['median_wave']:
|
| 98 |
pred_track = np.median(pred_track, axis=0)
|
| 99 |
elif algorithm in ['min_wave']:
|
|
|
|
| 113 |
elif algorithm in ['median_fft']:
|
| 114 |
pred_track = np.median(pred_track, axis=0)
|
| 115 |
pred_track = istft(pred_track, 1024, final_length)
|
| 116 |
+
|
| 117 |
+
gc.collect()
|
| 118 |
return pred_track
|
| 119 |
|
| 120 |
def ensemble_files(args):
|
|
|
|
| 123 |
parser.add_argument("--type", type=str, default='avg_wave', help="One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft")
|
| 124 |
parser.add_argument("--weights", type=float, nargs='+', help="Weights to create ensemble. Number of weights must be equal to number of files")
|
| 125 |
parser.add_argument("--output", default="res.wav", type=str, help="Path to wav file where ensemble result will be stored")
|
| 126 |
+
|
| 127 |
+
try:
|
| 128 |
+
args = parser.parse_args(args) if isinstance(args, list) else parser.parse_args()
|
| 129 |
+
except SystemExit:
|
| 130 |
+
print("Error: Invalid command-line arguments. Check --files, --type, --weights, and --output.")
|
| 131 |
+
return None
|
| 132 |
|
| 133 |
print('Ensemble type: {}'.format(args.type))
|
| 134 |
print('Number of input files: {}'.format(len(args.files)))
|
| 135 |
if args.weights is not None:
|
| 136 |
weights = args.weights
|
| 137 |
+
if len(weights) != len(args.files):
|
| 138 |
+
print('Error: Number of weights must match number of audio files.')
|
| 139 |
+
return None
|
| 140 |
else:
|
| 141 |
weights = np.ones(len(args.files))
|
| 142 |
print('Weights: {}'.format(weights))
|
| 143 |
+
|
| 144 |
+
# Validate output name
|
| 145 |
+
if not args.output.endswith('.wav'):
|
| 146 |
+
args.output += '.wav'
|
| 147 |
+
output_path = os.path.join('/tmp', str(uuid.uuid4()) + '_' + args.output)
|
| 148 |
+
print('Output file: {}'.format(output_path))
|
| 149 |
|
| 150 |
data = []
|
| 151 |
+
sr = None
|
| 152 |
for f in args.files:
|
| 153 |
if not os.path.isfile(f):
|
| 154 |
print('Error. Can\'t find file: {}. Check paths.'.format(f))
|
| 155 |
return None
|
| 156 |
print('Reading file: {}'.format(f))
|
| 157 |
+
try:
|
| 158 |
+
wav, curr_sr = librosa.load(f, sr=None, mono=False)
|
| 159 |
+
if sr is None:
|
| 160 |
+
sr = curr_sr
|
| 161 |
+
elif sr != curr_sr:
|
| 162 |
+
print('Error: All audio files must have the same sample rate.')
|
| 163 |
+
return None
|
| 164 |
+
print("Waveform shape: {} sample rate: {}".format(wav.shape, sr))
|
| 165 |
+
data.append(wav)
|
| 166 |
+
del wav
|
| 167 |
+
gc.collect()
|
| 168 |
+
except Exception as e:
|
| 169 |
+
print(f'Error reading audio file {f}: {str(e)}')
|
| 170 |
+
return None
|
| 171 |
|
| 172 |
+
try:
|
| 173 |
+
data = np.array(data, copy=False)
|
| 174 |
+
res = average_waveforms(data, weights, args.type)
|
| 175 |
+
print('Result shape: {}'.format(res.shape))
|
| 176 |
+
sf.write(output_path, res.T, sr, 'FLOAT')
|
| 177 |
+
return output_path
|
| 178 |
+
except Exception as e:
|
| 179 |
+
print(f'Error during ensemble processing: {str(e)}')
|
| 180 |
+
return None
|
| 181 |
+
finally:
|
| 182 |
+
gc.collect()
|
| 183 |
|
| 184 |
if __name__ == "__main__":
|
| 185 |
ensemble_files(None)
|