ASesYusuf1 commited on
Commit
cff3f6e
·
verified ·
1 Parent(s): 6db3d10

Update ensemble.py

Browse files
Files changed (1) hide show
  1. ensemble.py +50 -13
ensemble.py CHANGED
@@ -6,6 +6,8 @@ import librosa
6
  import soundfile as sf
7
  import numpy as np
8
  import argparse
 
 
9
 
10
  def stft(wave, nfft, hl):
11
  wave_left = np.asfortranarray(wave[0])
@@ -70,7 +72,7 @@ def average_waveforms(pred_track, weights, algorithm):
70
  :param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
71
  :return: averaged waveform in shape (channels, length)
72
  """
73
- pred_track = np.array(pred_track)
74
  final_length = pred_track.shape[-1]
75
 
76
  mod_track = []
@@ -85,11 +87,13 @@ def average_waveforms(pred_track, weights, algorithm):
85
  mod_track.append(spec * weights[i])
86
  else:
87
  mod_track.append(spec)
88
- pred_track = np.array(mod_track)
 
 
89
 
90
  if algorithm in ['avg_wave']:
91
  pred_track = pred_track.sum(axis=0)
92
- pred_track /= np.array(weights).sum().T
93
  elif algorithm in ['median_wave']:
94
  pred_track = np.median(pred_track, axis=0)
95
  elif algorithm in ['min_wave']:
@@ -109,6 +113,8 @@ def average_waveforms(pred_track, weights, algorithm):
109
  elif algorithm in ['median_fft']:
110
  pred_track = np.median(pred_track, axis=0)
111
  pred_track = istft(pred_track, 1024, final_length)
 
 
112
  return pred_track
113
 
114
  def ensemble_files(args):
@@ -117,32 +123,63 @@ def ensemble_files(args):
117
  parser.add_argument("--type", type=str, default='avg_wave', help="One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft")
118
  parser.add_argument("--weights", type=float, nargs='+', help="Weights to create ensemble. Number of weights must be equal to number of files")
119
  parser.add_argument("--output", default="res.wav", type=str, help="Path to wav file where ensemble result will be stored")
120
- args = parser.parse_args(args) if isinstance(args, list) else parser.parse_args()
 
 
 
 
 
121
 
122
  print('Ensemble type: {}'.format(args.type))
123
  print('Number of input files: {}'.format(len(args.files)))
124
  if args.weights is not None:
125
  weights = args.weights
 
 
 
126
  else:
127
  weights = np.ones(len(args.files))
128
  print('Weights: {}'.format(weights))
129
- print('Output file: {}'.format(args.output))
 
 
 
 
 
130
 
131
  data = []
 
132
  for f in args.files:
133
  if not os.path.isfile(f):
134
  print('Error. Can\'t find file: {}. Check paths.'.format(f))
135
  return None
136
  print('Reading file: {}'.format(f))
137
- wav, sr = librosa.load(f, sr=None, mono=False)
138
- print("Waveform shape: {} sample rate: {}".format(wav.shape, sr))
139
- data.append(wav)
 
 
 
 
 
 
 
 
 
 
 
140
 
141
- data = np.array(data)
142
- res = average_waveforms(data, weights, args.type)
143
- print('Result shape: {}'.format(res.shape))
144
- sf.write(args.output, res.T, sr, 'FLOAT')
145
- return args.output
 
 
 
 
 
 
146
 
147
  if __name__ == "__main__":
148
  ensemble_files(None)
 
6
  import soundfile as sf
7
  import numpy as np
8
  import argparse
9
+ import uuid
10
+ import gc
11
 
12
  def stft(wave, nfft, hl):
13
  wave_left = np.asfortranarray(wave[0])
 
72
  :param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
73
  :return: averaged waveform in shape (channels, length)
74
  """
75
+ pred_track = np.array(pred_track, copy=False)
76
  final_length = pred_track.shape[-1]
77
 
78
  mod_track = []
 
87
  mod_track.append(spec * weights[i])
88
  else:
89
  mod_track.append(spec)
90
+ del spec
91
+ gc.collect()
92
+ pred_track = np.array(mod_track, copy=False)
93
 
94
  if algorithm in ['avg_wave']:
95
  pred_track = pred_track.sum(axis=0)
96
+ pred_track /= np.array(weights).sum()
97
  elif algorithm in ['median_wave']:
98
  pred_track = np.median(pred_track, axis=0)
99
  elif algorithm in ['min_wave']:
 
113
  elif algorithm in ['median_fft']:
114
  pred_track = np.median(pred_track, axis=0)
115
  pred_track = istft(pred_track, 1024, final_length)
116
+
117
+ gc.collect()
118
  return pred_track
119
 
120
  def ensemble_files(args):
 
123
  parser.add_argument("--type", type=str, default='avg_wave', help="One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft")
124
  parser.add_argument("--weights", type=float, nargs='+', help="Weights to create ensemble. Number of weights must be equal to number of files")
125
  parser.add_argument("--output", default="res.wav", type=str, help="Path to wav file where ensemble result will be stored")
126
+
127
+ try:
128
+ args = parser.parse_args(args) if isinstance(args, list) else parser.parse_args()
129
+ except SystemExit:
130
+ print("Error: Invalid command-line arguments. Check --files, --type, --weights, and --output.")
131
+ return None
132
 
133
  print('Ensemble type: {}'.format(args.type))
134
  print('Number of input files: {}'.format(len(args.files)))
135
  if args.weights is not None:
136
  weights = args.weights
137
+ if len(weights) != len(args.files):
138
+ print('Error: Number of weights must match number of audio files.')
139
+ return None
140
  else:
141
  weights = np.ones(len(args.files))
142
  print('Weights: {}'.format(weights))
143
+
144
+ # Validate output name
145
+ if not args.output.endswith('.wav'):
146
+ args.output += '.wav'
147
+ output_path = os.path.join('/tmp', str(uuid.uuid4()) + '_' + args.output)
148
+ print('Output file: {}'.format(output_path))
149
 
150
  data = []
151
+ sr = None
152
  for f in args.files:
153
  if not os.path.isfile(f):
154
  print('Error. Can\'t find file: {}. Check paths.'.format(f))
155
  return None
156
  print('Reading file: {}'.format(f))
157
+ try:
158
+ wav, curr_sr = librosa.load(f, sr=None, mono=False)
159
+ if sr is None:
160
+ sr = curr_sr
161
+ elif sr != curr_sr:
162
+ print('Error: All audio files must have the same sample rate.')
163
+ return None
164
+ print("Waveform shape: {} sample rate: {}".format(wav.shape, sr))
165
+ data.append(wav)
166
+ del wav
167
+ gc.collect()
168
+ except Exception as e:
169
+ print(f'Error reading audio file {f}: {str(e)}')
170
+ return None
171
 
172
+ try:
173
+ data = np.array(data, copy=False)
174
+ res = average_waveforms(data, weights, args.type)
175
+ print('Result shape: {}'.format(res.shape))
176
+ sf.write(output_path, res.T, sr, 'FLOAT')
177
+ return output_path
178
+ except Exception as e:
179
+ print(f'Error during ensemble processing: {str(e)}')
180
+ return None
181
+ finally:
182
+ gc.collect()
183
 
184
  if __name__ == "__main__":
185
  ensemble_files(None)