# separator.py
import torch
from demucs.pretrained import get_model, ModelLoadingError
from demucs.apply import apply_model
import time
from config import DEVICE, MODEL_NAME, SEGMENT_LENGTH
import gc


class MusicSeparator:
    """Wrap a pretrained Demucs model for source separation.

    The model is loaded lazily via :meth:`load_model`; :meth:`separate_audio`
    refuses to run until that has succeeded.
    """

    def __init__(self):
        self.device = DEVICE  # target device, taken from config
        self.model_name = MODEL_NAME  # name passed to demucs get_model()
        self.model = None  # assigned only after a fully successful load
        self.sources = []  # source names reported by the loaded model
        self.load_time = None  # seconds spent in the last successful load

    def load_model(self):
        """Load the separation model with error handling.

        Returns:
            True once the model is loaded (immediately if it already was).

        Raises:
            Exception: if fetching or preparing the model fails; the
                underlying error is chained as ``__cause__``.
        """
        if self.model is not None:
            return True

        print(f"Loading {self.model_name} model on {self.device}...")
        start_time = time.time()

        try:
            # Prepare the model in a local first: if .to()/.eval() fails we
            # must not leave a half-initialised model on self, otherwise the
            # early return above would report success on every later call.
            model = get_model(self.model_name)
            model.to(self.device)
            model.eval()
            sources = model.sources
        except ModelLoadingError as e:
            raise Exception(f"Failed to load model {self.model_name}: {e}") from e
        except Exception as e:
            raise Exception(f"Unexpected error loading model: {e}") from e

        self.model = model
        self.sources = sources
        self.load_time = time.time() - start_time

        print(f"Model loaded successfully in {self.load_time:.2f}s")
        print(f"Available sources: {', '.join(self.sources)}")
        return True

    def separate_audio(self, audio_tensor, progress_callback=None):
        """Separate audio into sources with progress tracking.

        Args:
            audio_tensor: input audio tensor; a batch dimension is added via
                ``unsqueeze(0)`` before calling the model. Presumably shaped
                (channels, samples) — TODO confirm against callers.
            progress_callback: only its truthiness is used, to enable demucs'
                built-in progress display. NOTE(review): the callback itself
                is never invoked.

        Returns:
            dict mapping each name in ``self.sources`` to its separated
            tensor (on ``self.device``).

        Raises:
            Exception: if the model is not loaded, on out-of-memory, or on
                any other separation failure (original error chained).
        """
        if self.model is None:
            raise Exception("Model not loaded. Call load_model() first.")

        print("Starting audio separation...")
        start_time = time.time()

        try:
            audio_tensor = audio_tensor.to(self.device)

            # NOTE(review): SEGMENT_LENGTH is imported from config but not
            # passed to apply_model (segment=...); confirm whether intended.
            with torch.no_grad():
                separated = apply_model(
                    self.model,
                    audio_tensor.unsqueeze(0),
                    device=self.device,
                    progress=bool(progress_callback),
                )[0]  # drop the batch dimension added above

            # First dim of the output follows the order of model.sources.
            sources_dict = dict(zip(self.sources, separated))

            processing_time = time.time() - start_time
            print(f"Audio separation completed in {processing_time:.2f}s")
            return sources_dict

        except RuntimeError as e:
            if "out of memory" in str(e):
                # Best-effort attempt to release cached allocator memory.
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                raise Exception(
                    "GPU out of memory. Try using smaller audio files or CPU processing."
                ) from e
            raise Exception(f"Separation error: {e}") from e
        except Exception as e:
            raise Exception(f"Unexpected error during separation: {e}") from e

    def get_model_info(self):
        """Return model information.

        Returns:
            The string "Model not loaded" if no model is loaded; otherwise a
            dict with keys 'name', 'sources', 'device' and 'load_time'.
        """
        if self.model is None:
            return "Model not loaded"

        return {
            'name': self.model_name,
            'sources': self.sources,
            'device': str(self.device),
            'load_time': self.load_time,
        }