File size: 10,769 Bytes
6319e2f
 
 
 
 
 
 
dc58474
 
 
 
 
 
6319e2f
 
dc58474
6319e2f
dc58474
 
 
6319e2f
 
 
 
dc58474
6319e2f
 
 
dc58474
6319e2f
 
 
 
 
 
 
 
 
dc58474
 
 
 
 
6319e2f
 
 
dc58474
 
6319e2f
dc58474
 
6319e2f
 
 
 
 
 
 
dc58474
6319e2f
 
 
 
 
 
 
dc58474
 
6319e2f
 
 
 
 
 
 
 
 
 
 
 
dc58474
6319e2f
 
dc58474
6319e2f
dc58474
6319e2f
dc58474
6319e2f
 
 
 
 
 
 
 
dc58474
6319e2f
 
 
 
 
 
 
 
dc58474
6319e2f
 
 
 
 
 
 
 
 
 
 
dc58474
6319e2f
 
 
 
 
 
 
 
207cecb
 
6319e2f
 
 
dc58474
 
 
6319e2f
 
 
 
 
 
 
 
 
 
d4e97e7
 
 
 
61f32ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6319e2f
61f32ab
 
 
 
 
 
 
 
 
 
 
 
 
6319e2f
61f32ab
 
6319e2f
dc58474
61f32ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc58474
61f32ab
 
 
 
 
 
 
 
dc58474
61f32ab
dc58474
 
6319e2f
61f32ab
dc58474
 
6319e2f
dc58474
 
 
ef24873
 
 
 
6319e2f
ef24873
 
 
 
 
 
 
 
 
 
 
6319e2f
 
ef24873
6319e2f
 
 
dc58474
 
 
207cecb
 
 
 
 
 
dc58474
 
 
6319e2f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
import torch
import torchaudio

# Force the "soundfile" backend to avoid the TorchCodec requirement on Windows.
# This is a best-effort call: a failure here is ignored on purpose.
try:
    torchaudio.set_audio_backend("soundfile")
except Exception:
    # Only ordinary errors are swallowed; KeyboardInterrupt and SystemExit
    # still propagate. In newer torchaudio versions backend selection might
    # be automatic or work differently, so a failure is not treated as fatal.
    pass

from demucs.apply import apply_model
from demucs.pretrained import get_model
import os
import pathlib

# Model configuration
# Maps each supported separation mode to the name of the Demucs model that
# is loaded for it. Several modes may share one model.
MODELS = {
    "2stem": "htdemucs",    # Same model as 4stem; stems are merged down to 2 afterwards
    "4stem": "htdemucs",    # Default 4 stems (drums, bass, vocals, other)
    "6stem": "htdemucs_6s", # Full 6 stems, adds guitar and piano
}

class AudioSeparator:
    def __init__(self):
        """
        Load every distinct model named in MODELS and keep it ready for use.

        The device is "cuda" when torch reports CUDA as available and "cpu"
        otherwise. Loaded models are stored in ``self.models`` keyed by
        model name.
        """
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.models = {}

        # Several modes can point at the same model name; load each name once.
        for name in set(MODELS.values()):
            print(f"Loading Demucs Model: {name}...")
            loaded = get_model(name)
            loaded.to(self.device)
            self.models[name] = loaded

        print(f"All models loaded on {self.device}")

    def separate(self, audio_path: str, output_dir: str, callback=None, mode="4stem"):
        """
        Separate an audio file into stems and write them to ``output_dir``.

        Args:
            audio_path: Path of the input audio file (read with soundfile).
            output_dir: Directory that receives the stem files; created when
                it does not exist.
            callback: Optional progress hook called as
                ``callback(message, percent)``.
            mode: "2stem" (vocals + instruments), "4stem" (one file per model
                source) or "6stem" (like 4stem, with the guitar stem passed
                through ``_process_guitar``). Any other value uses the
                "htdemucs" model and is exported like "4stem".

        Returns:
            A tuple ``(results, input_duration)``: ``results`` maps stem name
            to the written file path, ``input_duration`` is the length of the
            input file in seconds.

        Raises:
            ValueError: If the file holds no samples, or its channel count
                cannot be adapted to the model.
        """
        # Load audio using soundfile directly (bypass torchaudio backend issues)
        import soundfile as sf

        # Select model based on mode
        model_name = MODELS.get(mode, "htdemucs")
        model = self.models[model_name]

        wav_np, sr = sf.read(audio_path)
        if len(wav_np) == 0:
            raise ValueError(f"Audio file contains no samples: {audio_path}")

        # Duration must use the file's own sample rate, so compute it before
        # `sr` is replaced by the model rate during resampling.
        input_duration = len(wav_np) / sr

        # soundfile returns (frames,) for mono and (frames, channels)
        # otherwise; everything below works on (channels, frames).
        wav = torch.from_numpy(wav_np).float()
        wav = wav.unsqueeze(0) if wav.ndim == 1 else wav.t()

        # Adapt the channel count to what the model was trained on.
        model_channels = getattr(model, "audio_channels", 2)
        if wav.shape[0] != model_channels:
            if model_channels == 1:
                wav = wav.mean(dim=0, keepdim=True)
            elif wav.shape[0] == 1:
                wav = wav.repeat(model_channels, 1)
            elif wav.shape[0] > model_channels:
                wav = wav[:model_channels]
            else:
                raise ValueError(
                    f"Cannot convert {wav.shape[0]} channels to {model_channels}"
                )

        # Resample if needed
        model_sr = getattr(model, "samplerate", 44100)
        if sr != model_sr:
            if callback:
                callback("Resampling audio...", 15)
            wav = torchaudio.transforms.Resample(sr, model_sr)(wav)
            sr = model_sr

        wav = wav.to(self.device)

        # Normalise with the statistics of the mono mix. The same statistics
        # are used after inference to restore the original level.
        ref = wav.mean(0)
        offset = ref.mean()
        scale = ref.std()
        if not torch.isfinite(scale) or scale == 0:
            # Silent or single-sample input: avoid dividing by zero / NaN.
            scale = torch.ones_like(offset)
        wav = (wav - offset) / scale

        if callback:
            callback("Running Demucs Inference...", 20)
        print(f"Starting separation with {model_name} (mode: {mode})...")

        sources = apply_model(
            model, wav.unsqueeze(0), shifts=1, split=True, overlap=0.25, progress=True
        )[0]
        # Undo the input normalisation so the stems keep their real levels.
        sources = sources * scale + offset

        source_names = model.sources
        total_sources = len(source_names)
        source_tensors = dict(zip(source_names, sources))

        results = {}
        os.makedirs(output_dir, exist_ok=True)

        if mode == "2stem":
            # Merge to Vocals + Instruments
            if callback:
                callback("Merging to 2 stems...", 45)

            vocals = source_tensors.get('vocals')
            backing = [src for name, src in source_tensors.items() if name != 'vocals']

            if vocals is not None:
                vocals_path = os.path.join(output_dir, "vocals.mp3")
                self._save_audio(vocals, sr, vocals_path)
                results['vocals'] = vocals_path
            if backing:
                instruments_path = os.path.join(output_dir, "instruments.mp3")
                self._save_audio(torch.stack(backing).sum(dim=0), sr, instruments_path)
                results['instruments'] = instruments_path
        else:
            # 4stem and 6stem save one file per source; only 6stem sends the
            # guitar source through the rhythm/lead processing.
            for i, (name, source) in enumerate(source_tensors.items()):
                progress = 30 + int((i / total_sources) * 20)
                if callback:
                    callback(f"Saving stem: {name}", progress)

                if mode == "6stem" and name == 'guitar':
                    results.update(self._process_guitar(source, sr, output_dir))
                else:
                    stem_path = os.path.join(output_dir, f"{name}.mp3")
                    self._save_audio(source, sr, stem_path)
                    results[name] = stem_path

        return results, input_duration

    def _process_guitar(self, source, sr, output_dir):
        """
        Memisahkan stem gitar menjadi Lead dan Rhythm menggunakan Mid-Side processing.
        - Mid (center) = Rhythm (biasanya power chords, strumming di center)
        - Side (stereo difference) = Lead (biasanya di-pan atau dengan stereo effects)
        """
        # source shape: (2, samples)
        
        # Check integrity
        if source.shape[0] < 2:
             print("Warning: Guitar stem is Mono. Cannot split Rhythm/Lead.")
             path = os.path.join(output_dir, "guitar.mp3")
             self._save_audio(source, sr, path)
             return {"guitar_rhythm": path, "guitar_lead": path}

        # Define Left and Right channels for processing
        left = source[0:1, :]
        right = source[1:2, :]

        # 1. Smart Spatial Split Check
        # Calculate correlation between L and R to detect Hard Panning (Math Rock Style)
        # If correlation is low, it means L and R are playing different things.
        # We assume Left = Rhythm (often dropped D/lower), Right = Lead (often ornate/higher) - OR provide both as is.
        
        # Calculate cross-correlation at lag 0
        mean_l = left.mean()
        mean_r = right.mean()
        var_l = ((left - mean_l)**2).mean()
        var_r = ((right - mean_r)**2).mean()
        cov = ((left - mean_l) * (right - mean_r)).mean()
        
        correlation = 0.0
        if var_l > 0 and var_r > 0:
            correlation = cov / torch.sqrt(var_l * var_r)
            
        print(f"Guitar Stereo Correlation: {correlation:.4f}")
        
        # Threshold for "Wide Stereo"
        if abs(correlation) < 0.6:
            print("Detected Wide Stereo Guitar (Math Rock Style). Using Spatial Split (L=Rhythm, R=Lead).")
            # Force Hard Split
            # Rhythm = Left Channel (Duplicated to Stereo)
            # Lead = Right Channel (Duplicated to Stereo)
            rhythm_stereo = torch.cat([left, left], dim=0)
            lead_stereo = torch.cat([right, right], dim=0)
        else:
            print("Detected Narrow/Mono Guitar. Using Mid-Side Frequency Split.")
            # Standard Mid-Side with Tighter Filters
            
            mid = (left + right) / 2.0
            side = (left - right) / 2.0
            
            try:
                import scipy.signal as signal
                nyquist = sr / 2
                
                # Rhythm: Low-Mid focus (80-1500Hz) - tighter top end
                # To distinguish from lead which often occupies 800+
                rhythm_low = 80 / nyquist
                rhythm_high = 1200 / nyquist
                b_r, a_r = signal.butter(4, [rhythm_low, rhythm_high], btype='band')
                
                # Lead: High-Mid focus (1000-8000Hz)
                lead_low = 1000 / nyquist
                lead_high = 8000 / nyquist
                b_l, a_l = signal.butter(4, [lead_low, lead_high], btype='band')
                
                # Apply to Mid (Center info usually has both, but we try to separate by freq)
                # We interpret 'Mid' as the main source.
                rhythm_from_mid = signal.filtfilt(b_r, a_r, mid.numpy())
                lead_from_mid = signal.filtfilt(b_l, a_l, mid.numpy())
                
                # Reconstruct
                # If Side exists (some stereo), add it to Lead (often spatial effects are on lead)
                side_np = side.numpy()
                
                rhythm_final = rhythm_from_mid 
                lead_final = lead_from_mid + (side_np * 1.5) # Boost side for lead
                
                rhythm_stereo = torch.from_numpy(rhythm_final).float()
                rhythm_stereo = torch.cat([rhythm_stereo, rhythm_stereo], dim=0)
                
                lead_stereo = torch.from_numpy(lead_final).float()
                lead_stereo = torch.cat([lead_stereo, lead_stereo], dim=0)
                
            except Exception as e:
                print(f"Filter failed: {e}. Fallback to raw.")
                rhythm_stereo = torch.cat([left, left], dim=0)
                lead_stereo = torch.cat([right, right], dim=0)
        
        # Normalize
        def normalize(tensor):
            peak = tensor.abs().max()
            if peak > 0:
                target_peak = 0.89  # -1dB
                return tensor * (target_peak / peak)
            return tensor

        rhythm_stereo = normalize(rhythm_stereo)
        lead_stereo = normalize(lead_stereo)
        
        # MERGE TO SINGLE STEREO FILE (L=Rhythm, R=Lead)
        # We take the Left channel of the Rhythm stereo mix (which is mono-ish)
        # And the Right channel of the Lead stereo mix (which is mono-ish)
        # Or better: Just use the Mono mix of each.
        
        rhythm_mono = rhythm_stereo.mean(dim=0, keepdim=True)
        lead_mono = lead_stereo.mean(dim=0, keepdim=True)
        
        # Combine: Channel 0 = Rhythm, Channel 1 = Lead
        guitar_split = torch.cat([rhythm_mono, lead_mono], dim=0)
        
        guitar_split = normalize(guitar_split)
        
        # Save as single file named "guitar.mp3" (Special split)
        path = os.path.join(output_dir, "guitar.mp3")
        self._save_audio(guitar_split, sr, path)
        
        return {
            "guitar": path 
        }

    def _save_audio(self, source, sr, path):
        """
        Write a (channels, samples) tensor to ``path`` using soundfile.

        The tensor is moved to the CPU first. If its absolute peak exceeds
        0.89 (about -1 dB), it is scaled down so the peak equals 0.89;
        quieter audio is written unchanged.
        """
        ceiling = 0.89  # approx -1dB
        audio = source.cpu()

        loudest = audio.abs().max()
        if loudest > ceiling:
            audio = audio / loudest * ceiling

        # soundfile expects (samples, channels), so transpose before writing.
        frames = audio.t().numpy()

        import soundfile as sf
        sf.write(path, frames, sr)