Spaces:
Sleeping
Sleeping
Auto-deploy from GitHub: af943986a2919fba83018f48c4261db4a72f4cee
Browse files
src/musiclime/factorization.py
CHANGED
|
@@ -128,16 +128,22 @@ class OpenUnmixFactorization:
|
|
| 128 |
# Specify targets
|
| 129 |
targets = ["vocals", "bass", "drums", "other"]
|
| 130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
# Then load openunmix files to openunmix' method
|
| 132 |
prediction = predict.separate(
|
| 133 |
torch.as_tensor(waveform).float(),
|
| 134 |
rate=44100,
|
| 135 |
model_str_or_path=model_path,
|
| 136 |
targets=targets,
|
|
|
|
| 137 |
)
|
| 138 |
|
| 139 |
components = [prediction[key][0].mean(dim=0).numpy() for key in prediction]
|
| 140 |
names = list(prediction.keys())
|
|
|
|
| 141 |
return components, names
|
| 142 |
|
| 143 |
def _prepare_temporal_components(self):
|
|
|
|
| 128 |
# Specify targets
|
| 129 |
targets = ["vocals", "bass", "drums", "other"]
|
| 130 |
|
| 131 |
+
# Specify device based on availability
|
| 132 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 133 |
+
print(f"[MusicLIME] Using device for source separation: {device}")
|
| 134 |
+
|
| 135 |
# Then load openunmix files to openunmix' method
|
| 136 |
prediction = predict.separate(
|
| 137 |
torch.as_tensor(waveform).float(),
|
| 138 |
rate=44100,
|
| 139 |
model_str_or_path=model_path,
|
| 140 |
targets=targets,
|
| 141 |
+
device=device,
|
| 142 |
)
|
| 143 |
|
| 144 |
components = [prediction[key][0].mean(dim=0).numpy() for key in prediction]
|
| 145 |
names = list(prediction.keys())
|
| 146 |
+
|
| 147 |
return components, names
|
| 148 |
|
| 149 |
def _prepare_temporal_components(self):
|
src/preprocessing/audio_preprocessor.py
CHANGED
|
@@ -259,14 +259,14 @@ class AudioPreprocessor:
|
|
| 259 |
"""
|
| 260 |
waveform, sample_rate = self.load_audio(file)
|
| 261 |
|
| 262 |
-
# Resample the audio to 16kHz
|
| 263 |
-
waveform = self.resample_audio(original_sr=sample_rate, waveform=waveform)
|
| 264 |
-
|
| 265 |
# Convert the audio into mono
|
| 266 |
if waveform.shape[0] > 1:
|
| 267 |
# print("Current audio is stereo. Converting to mono.")
|
| 268 |
waveform = waveform.mean(dim=0, keepdim=True)
|
| 269 |
|
|
|
|
|
|
|
|
|
|
| 270 |
# If there is a skip value provided, trim it
|
| 271 |
if skip_time is not None and skip_time > 0:
|
| 272 |
# print(f"Skipping first {skip_time:.2f} seconds.")
|
|
|
|
| 259 |
"""
|
| 260 |
waveform, sample_rate = self.load_audio(file)
|
| 261 |
|
|
|
|
|
|
|
|
|
|
| 262 |
# Convert the audio into mono
|
| 263 |
if waveform.shape[0] > 1:
|
| 264 |
# print("Current audio is stereo. Converting to mono.")
|
| 265 |
waveform = waveform.mean(dim=0, keepdim=True)
|
| 266 |
|
| 267 |
+
# Resample the audio to 16kHz
|
| 268 |
+
waveform = self.resample_audio(original_sr=sample_rate, waveform=waveform)
|
| 269 |
+
|
| 270 |
# If there is a skip value provided, trim it
|
| 271 |
if skip_time is not None and skip_time > 0:
|
| 272 |
# print(f"Skipping first {skip_time:.2f} seconds.")
|
src/spectttra/spectttra_trainer.py
CHANGED
|
@@ -4,7 +4,11 @@ import numpy as np
|
|
| 4 |
from types import SimpleNamespace
|
| 5 |
|
| 6 |
from src.spectttra.feature import FeatureExtractor
|
| 7 |
-
from src.spectttra.spectttra import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
# Shared variables for the model and setup, loaded only once and reused (cache)
|
| 10 |
_PREDICTOR_LOCK = threading.Lock()
|
|
@@ -19,7 +23,9 @@ def build_spectttra(cfg, device):
|
|
| 19 |
Wrapper that builds SpecTTTra + FeatureExtractor and loads frozen checkpoint.
|
| 20 |
"""
|
| 21 |
feat_ext, model = build_spectttra_from_cfg(cfg, device)
|
| 22 |
-
model = load_frozen_spectttra(
|
|
|
|
|
|
|
| 23 |
return feat_ext, model
|
| 24 |
|
| 25 |
|
|
@@ -107,7 +113,7 @@ def spectttra_predict(audio_tensor):
|
|
| 107 |
cfg = _CFG
|
| 108 |
|
| 109 |
# Move waveform to device but keep float for mel extraction
|
| 110 |
-
waveform = audio_tensor.to(device
|
| 111 |
|
| 112 |
with torch.no_grad():
|
| 113 |
# Extract mel-spectrogram
|
|
@@ -162,17 +168,21 @@ def spectttra_train(audio_tensors):
|
|
| 162 |
|
| 163 |
# Refactors the loop to be a much faster single-batch operation
|
| 164 |
try:
|
| 165 |
-
waveforms_batch = torch.cat(audio_tensors, dim=0).to(
|
|
|
|
|
|
|
| 166 |
except Exception as e:
|
| 167 |
-
print(
|
|
|
|
|
|
|
| 168 |
batch_list = [spectttra_predict(w) for w in audio_tensors]
|
| 169 |
return np.array(batch_list)
|
| 170 |
|
| 171 |
with torch.no_grad():
|
| 172 |
melspec = feat_ext(waveforms_batch)
|
| 173 |
|
| 174 |
-
# Ensure melspec shape matches model's expectation
|
| 175 |
-
expected_frames = model.input_temp_dim
|
| 176 |
if melspec.shape[2] > expected_frames:
|
| 177 |
melspec = melspec[:, :, :expected_frames]
|
| 178 |
elif melspec.shape[2] < expected_frames:
|
|
@@ -187,4 +197,4 @@ def spectttra_train(audio_tensors):
|
|
| 187 |
tokens = model(melspec)
|
| 188 |
pooled = tokens.mean(dim=1)
|
| 189 |
|
| 190 |
-
return pooled.cpu().numpy()
|
|
|
|
| 4 |
from types import SimpleNamespace
|
| 5 |
|
| 6 |
from src.spectttra.feature import FeatureExtractor
|
| 7 |
+
from src.spectttra.spectttra import (
|
| 8 |
+
SpecTTTra,
|
| 9 |
+
build_spectttra_from_cfg,
|
| 10 |
+
load_frozen_spectttra,
|
| 11 |
+
)
|
| 12 |
|
| 13 |
# Shared variables for the model and setup, loaded only once and reused (cache)
|
| 14 |
_PREDICTOR_LOCK = threading.Lock()
|
|
|
|
| 23 |
Wrapper that builds SpecTTTra + FeatureExtractor and loads frozen checkpoint.
|
| 24 |
"""
|
| 25 |
feat_ext, model = build_spectttra_from_cfg(cfg, device)
|
| 26 |
+
model = load_frozen_spectttra(
|
| 27 |
+
model, "models/spectttra/spectttra_frozen.pth", device
|
| 28 |
+
)
|
| 29 |
return feat_ext, model
|
| 30 |
|
| 31 |
|
|
|
|
| 113 |
cfg = _CFG
|
| 114 |
|
| 115 |
# Move waveform to device but keep float for mel extraction
|
| 116 |
+
waveform = audio_tensor.to(device, dtype=torch.float32)
|
| 117 |
|
| 118 |
with torch.no_grad():
|
| 119 |
# Extract mel-spectrogram
|
|
|
|
| 168 |
|
| 169 |
# Refactors the loop to be a much faster single-batch operation
|
| 170 |
try:
|
| 171 |
+
waveforms_batch = torch.cat(audio_tensors, dim=0).to(
|
| 172 |
+
device, dtype=torch.float32
|
| 173 |
+
)
|
| 174 |
except Exception as e:
|
| 175 |
+
print(
|
| 176 |
+
f"[INFO] Error during tensor concatenation, falling back to loop. Fix preprocessing for speed. Error: {e}"
|
| 177 |
+
)
|
| 178 |
batch_list = [spectttra_predict(w) for w in audio_tensors]
|
| 179 |
return np.array(batch_list)
|
| 180 |
|
| 181 |
with torch.no_grad():
|
| 182 |
melspec = feat_ext(waveforms_batch)
|
| 183 |
|
| 184 |
+
# Ensure melspec shape matches model's expectation
|
| 185 |
+
expected_frames = model.input_temp_dim # expected_frames is 3744
|
| 186 |
if melspec.shape[2] > expected_frames:
|
| 187 |
melspec = melspec[:, :, :expected_frames]
|
| 188 |
elif melspec.shape[2] < expected_frames:
|
|
|
|
| 197 |
tokens = model(melspec)
|
| 198 |
pooled = tokens.mean(dim=1)
|
| 199 |
|
| 200 |
+
return pooled.cpu().numpy()
|