Spaces:
Running
Running
Update ensemble.py
Browse files- ensemble.py +65 -74
ensemble.py
CHANGED
|
@@ -6,23 +6,26 @@ import librosa
|
|
| 6 |
import soundfile as sf
|
| 7 |
import numpy as np
|
| 8 |
import argparse
|
| 9 |
-
import
|
| 10 |
import gc
|
| 11 |
|
|
|
|
|
|
|
|
|
|
| 12 |
def stft(wave, nfft, hl):
|
| 13 |
-
wave_left = np.
|
| 14 |
-
wave_right = np.
|
| 15 |
spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
|
| 16 |
spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
|
| 17 |
-
spec = np.
|
| 18 |
return spec
|
| 19 |
|
| 20 |
def istft(spec, hl, length):
|
| 21 |
-
spec_left = np.
|
| 22 |
-
spec_right = np.
|
| 23 |
wave_left = librosa.istft(spec_left, hop_length=hl, length=length)
|
| 24 |
wave_right = librosa.istft(spec_right, hop_length=hl, length=length)
|
| 25 |
-
wave = np.
|
| 26 |
return wave
|
| 27 |
|
| 28 |
def absmax(a, *, axis):
|
|
@@ -72,7 +75,7 @@ def average_waveforms(pred_track, weights, algorithm):
|
|
| 72 |
:param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
|
| 73 |
:return: averaged waveform in shape (channels, length)
|
| 74 |
"""
|
| 75 |
-
pred_track = np.
|
| 76 |
final_length = pred_track.shape[-1]
|
| 77 |
|
| 78 |
mod_track = []
|
|
@@ -83,103 +86,91 @@ def average_waveforms(pred_track, weights, algorithm):
|
|
| 83 |
mod_track.append(pred_track[i])
|
| 84 |
elif algorithm in ['avg_fft', 'min_fft', 'max_fft', 'median_fft']:
|
| 85 |
spec = stft(pred_track[i], nfft=2048, hl=1024)
|
| 86 |
-
if algorithm
|
| 87 |
mod_track.append(spec * weights[i])
|
| 88 |
else:
|
| 89 |
mod_track.append(spec)
|
| 90 |
del spec
|
| 91 |
gc.collect()
|
| 92 |
-
|
| 93 |
|
| 94 |
-
if algorithm
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
pred_track = np.median(pred_track, axis=0)
|
| 115 |
-
pred_track = istft(pred_track, 1024, final_length)
|
| 116 |
|
| 117 |
gc.collect()
|
| 118 |
-
return
|
| 119 |
|
| 120 |
def ensemble_files(args):
|
| 121 |
-
parser = argparse.ArgumentParser()
|
| 122 |
-
parser.add_argument(
|
| 123 |
-
parser.add_argument(
|
| 124 |
-
parser.add_argument(
|
| 125 |
-
parser.add_argument(
|
| 126 |
|
| 127 |
-
|
| 128 |
-
args = parser.parse_args(args) if isinstance(args, list) else parser.parse_args()
|
| 129 |
-
except SystemExit:
|
| 130 |
-
print("Error: Invalid command-line arguments. Check --files, --type, --weights, and --output.")
|
| 131 |
-
return None
|
| 132 |
-
|
| 133 |
-
print('Ensemble type: {}'.format(args.type))
|
| 134 |
-
print('Number of input files: {}'.format(len(args.files)))
|
| 135 |
-
if args.weights is not None:
|
| 136 |
-
weights = args.weights
|
| 137 |
-
if len(weights) != len(args.files):
|
| 138 |
-
print('Error: Number of weights must match number of audio files.')
|
| 139 |
-
return None
|
| 140 |
-
else:
|
| 141 |
-
weights = np.ones(len(args.files))
|
| 142 |
-
print('Weights: {}'.format(weights))
|
| 143 |
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
data = []
|
| 151 |
sr = None
|
| 152 |
for f in args.files:
|
| 153 |
if not os.path.isfile(f):
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
try:
|
| 158 |
wav, curr_sr = librosa.load(f, sr=None, mono=False)
|
| 159 |
if sr is None:
|
| 160 |
sr = curr_sr
|
| 161 |
elif sr != curr_sr:
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
data.append(wav)
|
| 166 |
del wav
|
| 167 |
gc.collect()
|
| 168 |
except Exception as e:
|
| 169 |
-
|
| 170 |
-
|
| 171 |
|
| 172 |
try:
|
| 173 |
-
data = np.
|
| 174 |
res = average_waveforms(data, weights, args.type)
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
| 178 |
except Exception as e:
|
| 179 |
-
|
| 180 |
-
|
| 181 |
finally:
|
| 182 |
gc.collect()
|
| 183 |
|
| 184 |
if __name__ == "__main__":
|
| 185 |
-
ensemble_files(
|
|
|
|
| 6 |
import soundfile as sf
|
| 7 |
import numpy as np
|
| 8 |
import argparse
|
| 9 |
+
import logging
|
| 10 |
import gc
|
| 11 |
|
| 12 |
+
logging.basicConfig(level=logging.INFO)
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
def stft(wave, nfft, hl):
|
| 16 |
+
wave_left = np.ascontiguousarray(wave[0])
|
| 17 |
+
wave_right = np.ascontiguousarray(wave[1])
|
| 18 |
spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
|
| 19 |
spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
|
| 20 |
+
spec = np.stack([spec_left, spec_right])
|
| 21 |
return spec
|
| 22 |
|
| 23 |
def istft(spec, hl, length):
|
| 24 |
+
spec_left = np.ascontiguousarray(spec[0])
|
| 25 |
+
spec_right = np.ascontiguousarray(spec[1])
|
| 26 |
wave_left = librosa.istft(spec_left, hop_length=hl, length=length)
|
| 27 |
wave_right = librosa.istft(spec_right, hop_length=hl, length=length)
|
| 28 |
+
wave = np.stack([wave_left, wave_right])
|
| 29 |
return wave
|
| 30 |
|
| 31 |
def absmax(a, *, axis):
|
|
|
|
| 75 |
:param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
|
| 76 |
:return: averaged waveform in shape (channels, length)
|
| 77 |
"""
|
| 78 |
+
pred_track = np.asarray(pred_track) # NumPy 2.0+ compatibility
|
| 79 |
final_length = pred_track.shape[-1]
|
| 80 |
|
| 81 |
mod_track = []
|
|
|
|
| 86 |
mod_track.append(pred_track[i])
|
| 87 |
elif algorithm in ['avg_fft', 'min_fft', 'max_fft', 'median_fft']:
|
| 88 |
spec = stft(pred_track[i], nfft=2048, hl=1024)
|
| 89 |
+
if algorithm == 'avg_fft':
|
| 90 |
mod_track.append(spec * weights[i])
|
| 91 |
else:
|
| 92 |
mod_track.append(spec)
|
| 93 |
del spec
|
| 94 |
gc.collect()
|
| 95 |
+
mod_track = np.asarray(mod_track) # NumPy 2.0+ compatibility
|
| 96 |
|
| 97 |
+
if algorithm == 'avg_wave':
|
| 98 |
+
result = mod_track.sum(axis=0) / np.sum(weights)
|
| 99 |
+
elif algorithm == 'median_wave':
|
| 100 |
+
result = np.median(mod_track, axis=0)
|
| 101 |
+
elif algorithm == 'min_wave':
|
| 102 |
+
result = lambda_min(mod_track, axis=0, key=np.abs)
|
| 103 |
+
elif algorithm == 'max_wave':
|
| 104 |
+
result = lambda_max(mod_track, axis=0, key=np.abs)
|
| 105 |
+
elif algorithm == 'avg_fft':
|
| 106 |
+
result = mod_track.sum(axis=0) / np.sum(weights)
|
| 107 |
+
result = istft(result, 1024, final_length)
|
| 108 |
+
elif algorithm == 'min_fft':
|
| 109 |
+
result = lambda_min(mod_track, axis=0, key=np.abs)
|
| 110 |
+
result = istft(result, 1024, final_length)
|
| 111 |
+
elif algorithm == 'max_fft':
|
| 112 |
+
result = absmax(mod_track, axis=0)
|
| 113 |
+
result = istft(result, 1024, final_length)
|
| 114 |
+
elif algorithm == 'median_fft':
|
| 115 |
+
result = np.median(mod_track, axis=0)
|
| 116 |
+
result = istft(result, 1024, final_length)
|
|
|
|
|
|
|
| 117 |
|
| 118 |
gc.collect()
|
| 119 |
+
return result
|
| 120 |
|
| 121 |
def ensemble_files(args):
|
| 122 |
+
parser = argparse.ArgumentParser(description="Ensemble audio files")
|
| 123 |
+
parser.add_argument('--files', nargs='+', required=True, help="Input audio files")
|
| 124 |
+
parser.add_argument('--type', required=True, choices=['avg_wave', 'median_wave', 'max_wave', 'min_wave', 'avg_fft', 'median_fft', 'max_fft', 'min_fft'], help="Ensemble type")
|
| 125 |
+
parser.add_argument('--weights', nargs='+', type=float, default=None, help="Weights for each file")
|
| 126 |
+
parser.add_argument('--output', required=True, help="Output file path")
|
| 127 |
|
| 128 |
+
args = parser.parse_args(args) if isinstance(args, list) else args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
+
logger.info(f"Ensemble type: {args.type}")
|
| 131 |
+
logger.info(f"Number of input files: {len(args.files)}")
|
| 132 |
+
weights = args.weights if args.weights else [1.0] * len(args.files)
|
| 133 |
+
if len(weights) != len(args.files):
|
| 134 |
+
logger.error("Number of weights must match number of audio files")
|
| 135 |
+
raise ValueError("Number of weights must match number of audio files")
|
| 136 |
+
logger.info(f"Weights: {weights}")
|
| 137 |
+
logger.info(f"Output file: {args.output}")
|
| 138 |
|
| 139 |
data = []
|
| 140 |
sr = None
|
| 141 |
for f in args.files:
|
| 142 |
if not os.path.isfile(f):
|
| 143 |
+
logger.error(f"Cannot find file: {f}")
|
| 144 |
+
raise FileNotFoundError(f"Cannot find file: {f}")
|
| 145 |
+
logger.info(f"Reading file: {f}")
|
| 146 |
try:
|
| 147 |
wav, curr_sr = librosa.load(f, sr=None, mono=False)
|
| 148 |
if sr is None:
|
| 149 |
sr = curr_sr
|
| 150 |
elif sr != curr_sr:
|
| 151 |
+
logger.error("All audio files must have the same sample rate")
|
| 152 |
+
raise ValueError("All audio files must have the same sample rate")
|
| 153 |
+
logger.info(f"Waveform shape: {wav.shape} sample rate: {sr}")
|
| 154 |
data.append(wav)
|
| 155 |
del wav
|
| 156 |
gc.collect()
|
| 157 |
except Exception as e:
|
| 158 |
+
logger.error(f"Error reading audio file {f}: {str(e)}")
|
| 159 |
+
raise RuntimeError(f"Error reading audio file {f}: {str(e)}")
|
| 160 |
|
| 161 |
try:
|
| 162 |
+
data = np.asarray(data) # NumPy 2.0+ compatibility
|
| 163 |
res = average_waveforms(data, weights, args.type)
|
| 164 |
+
logger.info(f"Result shape: {res.shape}")
|
| 165 |
+
os.makedirs(os.path.dirname(args.output), exist_ok=True)
|
| 166 |
+
sf.write(args.output, res.T, sr, 'FLOAT')
|
| 167 |
+
logger.info(f"Output written to: {args.output}")
|
| 168 |
+
return args.output
|
| 169 |
except Exception as e:
|
| 170 |
+
logger.error(f"Error during ensemble processing: {str(e)}")
|
| 171 |
+
raise RuntimeError(f"Error during ensemble processing: {str(e)}")
|
| 172 |
finally:
|
| 173 |
gc.collect()
|
| 174 |
|
| 175 |
if __name__ == "__main__":
|
| 176 |
+
ensemble_files(sys.argv[1:])
|