File size: 3,240 Bytes
ebe7355
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# separator.py
import torch
from demucs.pretrained import get_model, ModelLoadingError
from demucs.apply import apply_model
import time
from config import DEVICE, MODEL_NAME, SEGMENT_LENGTH
import gc

class MusicSeparator:
    """Wraps a Demucs source-separation model: lazy loading, inference, and info.

    Configuration (device, model name) comes from the project-level `config`
    module; the model itself is loaded on first call to `load_model()`.
    """

    def __init__(self):
        self.device = DEVICE
        self.model_name = MODEL_NAME
        self.model = None        # set by load_model(); None means "not loaded"
        self.sources = []        # source names, populated from the model
        self.load_time = None    # seconds spent loading, set by load_model()

    def load_model(self):
        """Load the separation model onto ``self.device``.

        Idempotent: returns True immediately if a model is already loaded.

        Returns:
            True on success.

        Raises:
            Exception: with a descriptive message if loading fails. On
                failure ``self.model`` stays None, so a retry is possible.
        """
        if self.model is not None:
            return True

        print(f"Loading {self.model_name} model on {self.device}...")
        start_time = time.time()

        try:
            # Build on a local first: assign to self.model only once the
            # model is fully moved to the device and in eval mode, so a
            # partial failure never leaves a half-initialized model behind.
            model = get_model(self.model_name)
            model.to(self.device)
            model.eval()

            self.model = model
            self.sources = model.sources

            self.load_time = time.time() - start_time
            print(f"Model loaded successfully in {self.load_time:.2f}s")
            print(f"Available sources: {', '.join(self.sources)}")
            return True

        except ModelLoadingError as e:
            # Chain the cause so the original traceback is preserved.
            raise Exception(f"Failed to load model {self.model_name}: {str(e)}") from e
        except Exception as e:
            raise Exception(f"Unexpected error loading model: {str(e)}") from e

    def separate_audio(self, audio_tensor, progress_callback=None):
        """Separate ``audio_tensor`` into its component sources.

        Args:
            audio_tensor: waveform tensor; moved to ``self.device`` and given
                a batch dimension before inference. Presumably shaped
                (channels, samples) — TODO confirm against callers.
            progress_callback: NOTE(review) — only used as a flag to enable
                Demucs' built-in progress display; it is never invoked.

        Returns:
            dict mapping source name -> separated waveform tensor.

        Raises:
            Exception: if the model is not loaded or separation fails
                (out-of-memory gets a dedicated, friendlier message).
        """
        if self.model is None:
            raise Exception("Model not loaded. Call load_model() first.")

        print("Starting audio separation...")
        start_time = time.time()

        try:
            audio_tensor = audio_tensor.to(self.device)

            # Inference only: no gradients needed.
            with torch.no_grad():
                sources = apply_model(
                    self.model,
                    audio_tensor.unsqueeze(0),  # add batch dimension
                    device=self.device,
                    progress=progress_callback is not None,
                )[0]  # drop the batch dimension again

            # Demucs returns sources in the order listed by self.sources.
            sources_dict = {source: sources[i] for i, source in enumerate(self.sources)}

            processing_time = time.time() - start_time
            print(f"Audio separation completed in {processing_time:.2f}s")

            return sources_dict

        except RuntimeError as e:
            if "out of memory" in str(e):
                # Best-effort cleanup before surfacing a friendlier error.
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                raise Exception("GPU out of memory. Try using smaller audio files or CPU processing.") from e
            raise Exception(f"Separation error: {str(e)}") from e
        except Exception as e:
            raise Exception(f"Unexpected error during separation: {str(e)}") from e

    def get_model_info(self):
        """Return model metadata.

        Returns:
            dict with 'name', 'sources', 'device', 'load_time' keys when a
            model is loaded; otherwise the string "Model not loaded".
            NOTE(review): the mixed str/dict return type is kept for
            backward compatibility with existing callers.
        """
        if self.model is None:
            return "Model not loaded"

        return {
            'name': self.model_name,
            'sources': self.sources,
            'device': str(self.device),
            'load_time': self.load_time
        }