Spaces:
Running
Running
Upload ensemble.py with huggingface_hub
Browse files- ensemble.py +148 -0
ensemble.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding: utf-8
|
| 2 |
+
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import librosa
|
| 6 |
+
import soundfile as sf
|
| 7 |
+
import numpy as np
|
| 8 |
+
import argparse
|
| 9 |
+
|
| 10 |
+
def stft(wave, nfft, hl):
|
| 11 |
+
wave_left = np.asfortranarray(wave[0])
|
| 12 |
+
wave_right = np.asfortranarray(wave[1])
|
| 13 |
+
spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
|
| 14 |
+
spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
|
| 15 |
+
spec = np.asfortranarray([spec_left, spec_right])
|
| 16 |
+
return spec
|
| 17 |
+
|
| 18 |
+
def istft(spec, hl, length):
|
| 19 |
+
spec_left = np.asfortranarray(spec[0])
|
| 20 |
+
spec_right = np.asfortranarray(spec[1])
|
| 21 |
+
wave_left = librosa.istft(spec_left, hop_length=hl, length=length)
|
| 22 |
+
wave_right = librosa.istft(spec_right, hop_length=hl, length=length)
|
| 23 |
+
wave = np.asfortranarray([wave_left, wave_right])
|
| 24 |
+
return wave
|
| 25 |
+
|
| 26 |
+
def absmax(a, *, axis):
|
| 27 |
+
dims = list(a.shape)
|
| 28 |
+
dims.pop(axis)
|
| 29 |
+
indices = list(np.ogrid[tuple(slice(0, d) for d in dims)])
|
| 30 |
+
argmax = np.abs(a).argmax(axis=axis)
|
| 31 |
+
insert_pos = (len(a.shape) + axis) % len(a.shape)
|
| 32 |
+
indices.insert(insert_pos, argmax)
|
| 33 |
+
return a[tuple(indices)]
|
| 34 |
+
|
| 35 |
+
def absmin(a, *, axis):
|
| 36 |
+
dims = list(a.shape)
|
| 37 |
+
dims.pop(axis)
|
| 38 |
+
indices = list(np.ogrid[tuple(slice(0, d) for d in dims)])
|
| 39 |
+
argmax = np.abs(a).argmin(axis=axis)
|
| 40 |
+
insert_pos = (len(a.shape) + axis) % len(a.shape)
|
| 41 |
+
indices.insert(insert_pos, argmax)
|
| 42 |
+
return a[tuple(indices)]
|
| 43 |
+
|
| 44 |
+
def lambda_max(arr, axis=None, key=None, keepdims=False):
|
| 45 |
+
idxs = np.argmax(key(arr), axis)
|
| 46 |
+
if axis is not None:
|
| 47 |
+
idxs = np.expand_dims(idxs, axis)
|
| 48 |
+
result = np.take_along_axis(arr, idxs, axis)
|
| 49 |
+
if not keepdims:
|
| 50 |
+
result = np.squeeze(result, axis=axis)
|
| 51 |
+
return result
|
| 52 |
+
else:
|
| 53 |
+
return arr.flatten()[idxs]
|
| 54 |
+
|
| 55 |
+
def lambda_min(arr, axis=None, key=None, keepdims=False):
|
| 56 |
+
idxs = np.argmin(key(arr), axis)
|
| 57 |
+
if axis is not None:
|
| 58 |
+
idxs = np.expand_dims(idxs, axis)
|
| 59 |
+
result = np.take_along_axis(arr, idxs, axis)
|
| 60 |
+
if not keepdims:
|
| 61 |
+
result = np.squeeze(result, axis=axis)
|
| 62 |
+
return result
|
| 63 |
+
else:
|
| 64 |
+
return arr.flatten()[idxs]
|
| 65 |
+
|
| 66 |
+
def average_waveforms(pred_track, weights, algorithm):
|
| 67 |
+
"""
|
| 68 |
+
:param pred_track: shape = (num, channels, length)
|
| 69 |
+
:param weights: shape = (num, )
|
| 70 |
+
:param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
|
| 71 |
+
:return: averaged waveform in shape (channels, length)
|
| 72 |
+
"""
|
| 73 |
+
pred_track = np.array(pred_track)
|
| 74 |
+
final_length = pred_track.shape[-1]
|
| 75 |
+
|
| 76 |
+
mod_track = []
|
| 77 |
+
for i in range(pred_track.shape[0]):
|
| 78 |
+
if algorithm == 'avg_wave':
|
| 79 |
+
mod_track.append(pred_track[i] * weights[i])
|
| 80 |
+
elif algorithm in ['median_wave', 'min_wave', 'max_wave']:
|
| 81 |
+
mod_track.append(pred_track[i])
|
| 82 |
+
elif algorithm in ['avg_fft', 'min_fft', 'max_fft', 'median_fft']:
|
| 83 |
+
spec = stft(pred_track[i], nfft=2048, hl=1024)
|
| 84 |
+
if algorithm in ['avg_fft']:
|
| 85 |
+
mod_track.append(spec * weights[i])
|
| 86 |
+
else:
|
| 87 |
+
mod_track.append(spec)
|
| 88 |
+
pred_track = np.array(mod_track)
|
| 89 |
+
|
| 90 |
+
if algorithm in ['avg_wave']:
|
| 91 |
+
pred_track = pred_track.sum(axis=0)
|
| 92 |
+
pred_track /= np.array(weights).sum().T
|
| 93 |
+
elif algorithm in ['median_wave']:
|
| 94 |
+
pred_track = np.median(pred_track, axis=0)
|
| 95 |
+
elif algorithm in ['min_wave']:
|
| 96 |
+
pred_track = lambda_min(pred_track, axis=0, key=np.abs)
|
| 97 |
+
elif algorithm in ['max_wave']:
|
| 98 |
+
pred_track = lambda_max(pred_track, axis=0, key=np.abs)
|
| 99 |
+
elif algorithm in ['avg_fft']:
|
| 100 |
+
pred_track = pred_track.sum(axis=0)
|
| 101 |
+
pred_track /= np.array(weights).sum()
|
| 102 |
+
pred_track = istft(pred_track, 1024, final_length)
|
| 103 |
+
elif algorithm in ['min_fft']:
|
| 104 |
+
pred_track = lambda_min(pred_track, axis=0, key=np.abs)
|
| 105 |
+
pred_track = istft(pred_track, 1024, final_length)
|
| 106 |
+
elif algorithm in ['max_fft']:
|
| 107 |
+
pred_track = absmax(pred_track, axis=0)
|
| 108 |
+
pred_track = istft(pred_track, 1024, final_length)
|
| 109 |
+
elif algorithm in ['median_fft']:
|
| 110 |
+
pred_track = np.median(pred_track, axis=0)
|
| 111 |
+
pred_track = istft(pred_track, 1024, final_length)
|
| 112 |
+
return pred_track
|
| 113 |
+
|
| 114 |
+
def ensemble_files(args):
|
| 115 |
+
parser = argparse.ArgumentParser()
|
| 116 |
+
parser.add_argument("--files", type=str, required=True, nargs='+', help="Path to all audio-files to ensemble")
|
| 117 |
+
parser.add_argument("--type", type=str, default='avg_wave', help="One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft")
|
| 118 |
+
parser.add_argument("--weights", type=float, nargs='+', help="Weights to create ensemble. Number of weights must be equal to number of files")
|
| 119 |
+
parser.add_argument("--output", default="res.wav", type=str, help="Path to wav file where ensemble result will be stored")
|
| 120 |
+
args = parser.parse_args(args) if isinstance(args, list) else parser.parse_args()
|
| 121 |
+
|
| 122 |
+
print('Ensemble type: {}'.format(args.type))
|
| 123 |
+
print('Number of input files: {}'.format(len(args.files)))
|
| 124 |
+
if args.weights is not None:
|
| 125 |
+
weights = args.weights
|
| 126 |
+
else:
|
| 127 |
+
weights = np.ones(len(args.files))
|
| 128 |
+
print('Weights: {}'.format(weights))
|
| 129 |
+
print('Output file: {}'.format(args.output))
|
| 130 |
+
|
| 131 |
+
data = []
|
| 132 |
+
for f in args.files:
|
| 133 |
+
if not os.path.isfile(f):
|
| 134 |
+
print('Error. Can\'t find file: {}. Check paths.'.format(f))
|
| 135 |
+
return None
|
| 136 |
+
print('Reading file: {}'.format(f))
|
| 137 |
+
wav, sr = librosa.load(f, sr=None, mono=False)
|
| 138 |
+
print("Waveform shape: {} sample rate: {}".format(wav.shape, sr))
|
| 139 |
+
data.append(wav)
|
| 140 |
+
|
| 141 |
+
data = np.array(data)
|
| 142 |
+
res = average_waveforms(data, weights, args.type)
|
| 143 |
+
print('Result shape: {}'.format(res.shape))
|
| 144 |
+
sf.write(args.output, res.T, sr, 'FLOAT')
|
| 145 |
+
return args.output
|
| 146 |
+
|
| 147 |
+
if __name__ == "__main__":
|
| 148 |
+
ensemble_files(None)
|