Demusics / separator.py
Kremon96's picture
Create separator.py
ebe7355 verified
# separator.py
import torch
from demucs.pretrained import get_model, ModelLoadingError
from demucs.apply import apply_model
import time
from config import DEVICE, MODEL_NAME, SEGMENT_LENGTH
import gc
class MusicSeparator:
    """Lazy-loading wrapper around a Demucs source-separation model.

    Typical use: construct, call ``load_model()`` once, then call
    ``separate_audio()`` for each input tensor.
    """

    def __init__(self):
        # Device and model name come from the project's config module.
        self.device = DEVICE
        self.model_name = MODEL_NAME
        # Set only after load_model() has fully succeeded.
        self.model = None
        # Source (stem) names reported by the loaded model.
        self.sources = []
        # Seconds spent in the last successful load_model() call.
        self.load_time = None

    def load_model(self):
        """Load the separation model with error handling.

        The model is fetched, moved to ``self.device`` and switched to eval
        mode. It is stored on ``self`` only after all of those steps succeed.
        A failed load therefore leaves the instance unloaded, and a later
        call retries instead of reusing a half-initialised model.

        Returns:
            True once the model is loaded (immediately if it already was).

        Raises:
            Exception: wrapping the underlying error (chained as ``__cause__``)
                when the model cannot be loaded.
        """
        if self.model is not None:
            return True

        print(f"Loading {self.model_name} model on {self.device}...")
        start_time = time.time()

        try:
            model = get_model(self.model_name)
            model.to(self.device)
            model.eval()
            sources = model.sources
        except ModelLoadingError as e:
            raise Exception(f"Failed to load model {self.model_name}: {e}") from e
        except Exception as e:
            raise Exception(f"Unexpected error loading model: {e}") from e

        # Commit state only now that every loading step succeeded.
        self.model = model
        self.sources = sources
        self.load_time = time.time() - start_time
        print(f"Model loaded successfully in {self.load_time:.2f}s")
        print(f"Available sources: {', '.join(self.sources)}")
        return True

    def separate_audio(self, audio_tensor, progress_callback=None):
        """Separate audio into sources with progress tracking.

        Args:
            audio_tensor: A single, un-batched input tensor. A batch axis is
                added with ``unsqueeze(0)``. Presumably laid out as
                (channels, samples) -- TODO confirm against the caller.
            progress_callback: Only its truthiness is used. A truthy value
                enables ``apply_model``'s built-in progress display. The
                callable itself is never invoked.

        Returns:
            Dict mapping each name in ``self.sources`` to the matching entry
            of ``apply_model``'s first (only) batch element.

        Raises:
            Exception: if the model is not loaded, or wrapping any error
                raised during separation (chained as ``__cause__``).
        """
        if self.model is None:
            raise Exception("Model not loaded. Call load_model() first.")

        print("Starting audio separation...")
        start_time = time.time()

        try:
            audio_tensor = audio_tensor.to(self.device)
            # NOTE(review): config.SEGMENT_LENGTH is imported by this module
            # but is not passed to apply_model here; confirm whether
            # segment=SEGMENT_LENGTH was intended.
            with torch.no_grad():
                separated = apply_model(
                    self.model,
                    audio_tensor.unsqueeze(0),  # add batch axis
                    device=self.device,
                    progress=bool(progress_callback),
                )[0]  # drop the batch axis again

            sources_dict = {name: separated[idx] for idx, name in enumerate(self.sources)}

            processing_time = time.time() - start_time
            print(f"Audio separation completed in {processing_time:.2f}s")
            return sources_dict
        except RuntimeError as e:
            if "out of memory" in str(e):
                # Release what we can before surfacing a friendlier error.
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                raise Exception(
                    "GPU out of memory. Try using smaller audio files or CPU processing."
                ) from e
            raise Exception(f"Separation error: {e}") from e
        except Exception as e:
            raise Exception(f"Unexpected error during separation: {e}") from e

    def get_model_info(self):
        """Return model information.

        Returns:
            The string ``"Model not loaded"`` before a successful load.
            After a load, a dict with keys ``name``, ``sources``, ``device``
            (stringified) and ``load_time``.
        """
        if self.model is None:
            return "Model not loaded"
        return {
            'name': self.model_name,
            'sources': self.sources,
            'device': str(self.device),
            'load_time': self.load_time,
        }