|
|
|
|
|
import torch |
|
|
from demucs.pretrained import get_model, ModelLoadingError |
|
|
from demucs.apply import apply_model |
|
|
import time |
|
|
from config import DEVICE, MODEL_NAME, SEGMENT_LENGTH |
|
|
import gc |
|
|
|
|
|
class MusicSeparator:
    """Wrapper around a Demucs pretrained model: lazy loading, stem
    separation with OOM-aware error handling, and model metadata."""

    def __init__(self):
        self.device = DEVICE          # target device from config (e.g. "cuda"/"cpu")
        self.model_name = MODEL_NAME  # demucs pretrained model identifier
        self.model = None             # lazily populated by load_model()
        self.sources = []             # stem names (e.g. drums/bass/...) once loaded
        self.load_time = None         # seconds spent loading, for diagnostics

    def load_model(self):
        """Load the separation model with error handling.

        Idempotent: returns True immediately if a model is already loaded.

        Returns:
            True on successful load.

        Raises:
            Exception: with a descriptive message if loading fails; the
                original error is chained as __cause__.
        """
        if self.model is not None:
            return True

        print(f"Loading {self.model_name} model on {self.device}...")
        start_time = time.time()

        try:
            self.model = get_model(self.model_name)
            self.model.to(self.device)
            # Inference mode: freezes dropout/batch-norm behavior.
            self.model.eval()
            self.sources = self.model.sources

            self.load_time = time.time() - start_time
            print(f"Model loaded successfully in {self.load_time:.2f}s")
            print(f"Available sources: {', '.join(self.sources)}")
            return True

        except ModelLoadingError as e:
            # Chain the cause so the original traceback is not lost.
            raise Exception(f"Failed to load model {self.model_name}: {str(e)}") from e
        except Exception as e:
            raise Exception(f"Unexpected error loading model: {str(e)}") from e

    def separate_audio(self, audio_tensor, progress_callback=None):
        """Separate audio into sources with progress tracking.

        Args:
            audio_tensor: waveform tensor; unsqueezed to a batch of 1 before
                being passed to demucs.apply_model (assumes (channels, samples)
                layout — TODO confirm against callers).
            progress_callback: if truthy, demucs' built-in progress display is
                enabled. NOTE(review): the callback itself is never invoked —
                only its truthiness is used.

        Returns:
            dict mapping source name -> separated waveform tensor.

        Raises:
            Exception: if the model is not loaded, on GPU OOM, or on any
                other separation failure (original error chained).
        """
        if self.model is None:
            raise Exception("Model not loaded. Call load_model() first.")

        print("Starting audio separation...")
        start_time = time.time()

        try:
            audio_tensor = audio_tensor.to(self.device)

            with torch.no_grad():
                # apply_model returns a batch; [0] drops the batch dimension.
                sources = apply_model(
                    self.model,
                    audio_tensor.unsqueeze(0),
                    device=self.device,
                    progress=bool(progress_callback)
                )[0]

            # Pair each output channel with its source name, in model order.
            sources_dict = {source: sources[i] for i, source in enumerate(self.sources)}

            processing_time = time.time() - start_time
            print(f"Audio separation completed in {processing_time:.2f}s")

            return sources_dict

        except RuntimeError as e:
            if "out of memory" in str(e):
                # Best-effort cleanup so a retry on smaller input can succeed.
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                raise Exception("GPU out of memory. Try using smaller audio files or CPU processing.") from e
            raise Exception(f"Separation error: {str(e)}") from e
        except Exception as e:
            raise Exception(f"Unexpected error during separation: {str(e)}") from e

    def get_model_info(self):
        """Get model information.

        Returns:
            dict with name/sources/device/load_time when loaded; otherwise
            the string "Model not loaded". NOTE(review): the mixed str/dict
            return type is preserved for backward compatibility — callers
            should check before treating the result as a dict.
        """
        if self.model is None:
            return "Model not loaded"

        return {
            'name': self.model_name,
            'sources': self.sources,
            'device': str(self.device),
            'load_time': self.load_time
        }