krislette committed on
Commit
e26dafd
·
1 Parent(s): e1ee8d1

Auto-deploy from GitHub: af943986a2919fba83018f48c4261db4a72f4cee

Browse files
src/musiclime/factorization.py CHANGED
@@ -128,16 +128,22 @@ class OpenUnmixFactorization:
128
  # Specify targets
129
  targets = ["vocals", "bass", "drums", "other"]
130
 
 
 
 
 
131
  # Then load openunmix files to openunmix' method
132
  prediction = predict.separate(
133
  torch.as_tensor(waveform).float(),
134
  rate=44100,
135
  model_str_or_path=model_path,
136
  targets=targets,
 
137
  )
138
 
139
  components = [prediction[key][0].mean(dim=0).numpy() for key in prediction]
140
  names = list(prediction.keys())
 
141
  return components, names
142
 
143
  def _prepare_temporal_components(self):
 
128
  # Specify targets
129
  targets = ["vocals", "bass", "drums", "other"]
130
 
131
+ # Specify device based on availability
132
+ device = "cuda" if torch.cuda.is_available() else "cpu"
133
+ print(f"[MusicLIME] Using device for source separation: {device}")
134
+
135
  # Then load openunmix files to openunmix' method
136
  prediction = predict.separate(
137
  torch.as_tensor(waveform).float(),
138
  rate=44100,
139
  model_str_or_path=model_path,
140
  targets=targets,
141
+ device=device,
142
  )
143
 
144
  components = [prediction[key][0].mean(dim=0).numpy() for key in prediction]
145
  names = list(prediction.keys())
146
+
147
  return components, names
148
 
149
  def _prepare_temporal_components(self):
src/preprocessing/audio_preprocessor.py CHANGED
@@ -259,14 +259,14 @@ class AudioPreprocessor:
259
  """
260
  waveform, sample_rate = self.load_audio(file)
261
 
262
- # Resample the audio to 16kHz
263
- waveform = self.resample_audio(original_sr=sample_rate, waveform=waveform)
264
-
265
  # Convert the audio into mono
266
  if waveform.shape[0] > 1:
267
  # print("Current audio is stereo. Converting to mono.")
268
  waveform = waveform.mean(dim=0, keepdim=True)
269
 
 
 
 
270
  # If there is a skip value provided, trim it
271
  if skip_time is not None and skip_time > 0:
272
  # print(f"Skipping first {skip_time:.2f} seconds.")
 
259
  """
260
  waveform, sample_rate = self.load_audio(file)
261
 
 
 
 
262
  # Convert the audio into mono
263
  if waveform.shape[0] > 1:
264
  # print("Current audio is stereo. Converting to mono.")
265
  waveform = waveform.mean(dim=0, keepdim=True)
266
 
267
+ # Resample the audio to 16kHz
268
+ waveform = self.resample_audio(original_sr=sample_rate, waveform=waveform)
269
+
270
  # If there is a skip value provided, trim it
271
  if skip_time is not None and skip_time > 0:
272
  # print(f"Skipping first {skip_time:.2f} seconds.")
src/spectttra/spectttra_trainer.py CHANGED
@@ -4,7 +4,11 @@ import numpy as np
4
  from types import SimpleNamespace
5
 
6
  from src.spectttra.feature import FeatureExtractor
7
- from src.spectttra.spectttra import SpecTTTra, build_spectttra_from_cfg, load_frozen_spectttra
 
 
 
 
8
 
9
  # Shared variables for the model and setup, loaded only once and reused (cache)
10
  _PREDICTOR_LOCK = threading.Lock()
@@ -19,7 +23,9 @@ def build_spectttra(cfg, device):
19
  Wrapper that builds SpecTTTra + FeatureExtractor and loads frozen checkpoint.
20
  """
21
  feat_ext, model = build_spectttra_from_cfg(cfg, device)
22
- model = load_frozen_spectttra(model, "models/spectttra/spectttra_frozen.pth", device)
 
 
23
  return feat_ext, model
24
 
25
 
@@ -107,7 +113,7 @@ def spectttra_predict(audio_tensor):
107
  cfg = _CFG
108
 
109
  # Move waveform to device but keep float for mel extraction
110
- waveform = audio_tensor.to(device).float()
111
 
112
  with torch.no_grad():
113
  # Extract mel-spectrogram
@@ -162,17 +168,21 @@ def spectttra_train(audio_tensors):
162
 
163
  # Refactors the loop to be a much faster single-batch operation
164
  try:
165
- waveforms_batch = torch.cat(audio_tensors, dim=0).to(device).float()
 
 
166
  except Exception as e:
167
- print(f"[INFO] Error during tensor concatenation, falling back to loop. Fix preprocessing for speed. Error: {e}")
 
 
168
  batch_list = [spectttra_predict(w) for w in audio_tensors]
169
  return np.array(batch_list)
170
 
171
  with torch.no_grad():
172
  melspec = feat_ext(waveforms_batch)
173
 
174
- # Ensure melspec shape matches model's expectation
175
- expected_frames = model.input_temp_dim # expected_frames is 3744
176
  if melspec.shape[2] > expected_frames:
177
  melspec = melspec[:, :, :expected_frames]
178
  elif melspec.shape[2] < expected_frames:
@@ -187,4 +197,4 @@ def spectttra_train(audio_tensors):
187
  tokens = model(melspec)
188
  pooled = tokens.mean(dim=1)
189
 
190
- return pooled.cpu().numpy()
 
4
  from types import SimpleNamespace
5
 
6
  from src.spectttra.feature import FeatureExtractor
7
+ from src.spectttra.spectttra import (
8
+ SpecTTTra,
9
+ build_spectttra_from_cfg,
10
+ load_frozen_spectttra,
11
+ )
12
 
13
  # Shared variables for the model and setup, loaded only once and reused (cache)
14
  _PREDICTOR_LOCK = threading.Lock()
 
23
  Wrapper that builds SpecTTTra + FeatureExtractor and loads frozen checkpoint.
24
  """
25
  feat_ext, model = build_spectttra_from_cfg(cfg, device)
26
+ model = load_frozen_spectttra(
27
+ model, "models/spectttra/spectttra_frozen.pth", device
28
+ )
29
  return feat_ext, model
30
 
31
 
 
113
  cfg = _CFG
114
 
115
  # Move waveform to device but keep float for mel extraction
116
+ waveform = audio_tensor.to(device, dtype=torch.float32)
117
 
118
  with torch.no_grad():
119
  # Extract mel-spectrogram
 
168
 
169
  # Refactors the loop to be a much faster single-batch operation
170
  try:
171
+ waveforms_batch = torch.cat(audio_tensors, dim=0).to(
172
+ device, dtype=torch.float32
173
+ )
174
  except Exception as e:
175
+ print(
176
+ f"[INFO] Error during tensor concatenation, falling back to loop. Fix preprocessing for speed. Error: {e}"
177
+ )
178
  batch_list = [spectttra_predict(w) for w in audio_tensors]
179
  return np.array(batch_list)
180
 
181
  with torch.no_grad():
182
  melspec = feat_ext(waveforms_batch)
183
 
184
+ # Ensure melspec shape matches model's expectation
185
+ expected_frames = model.input_temp_dim # expected_frames is 3744
186
  if melspec.shape[2] > expected_frames:
187
  melspec = melspec[:, :, :expected_frames]
188
  elif melspec.shape[2] < expected_frames:
 
197
  tokens = model(melspec)
198
  pooled = tokens.mean(dim=1)
199
 
200
+ return pooled.cpu().numpy()