File size: 6,889 Bytes
6678fa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
"""
Stem Separation using Demucs

Separates audio into individual stems (vocals, drums, bass, other)
using Facebook/Meta's Demucs model.

Requires: demucs package (pip install demucs)
"""

import os
import sys
import subprocess
import json
from pathlib import Path
from typing import Optional, List
import tempfile


# Registry of the Demucs models this module accepts.
# Keys are the model names passed to ``demucs --name``; each value lists the
# stems that model produces (in the order separate_stems() reports them) and
# a short human-readable description.
_STANDARD_STEMS = ("vocals", "drums", "bass", "other")

DEMUCS_MODELS = {
    model_name: {"stems": list(stem_names), "description": blurb}
    for model_name, stem_names, blurb in (
        ("htdemucs", _STANDARD_STEMS,
         "Hybrid Transformer Demucs (recommended)"),
        ("htdemucs_ft", _STANDARD_STEMS,
         "Fine-tuned Hybrid Transformer Demucs"),
        ("htdemucs_6s",
         ("vocals", "drums", "bass", "guitar", "piano", "other"),
         "6-stem Hybrid Transformer Demucs"),
        ("mdx_extra", _STANDARD_STEMS,
         "MDX-Net architecture"),
    )
}


def get_best_device() -> str:
    """Auto-detect the best available device for ML processing.

    Preference order is Apple Silicon GPU ("mps"), then NVIDIA GPU
    ("cuda"), then "cpu". Falls back to "cpu" when torch is not installed.

    Returns:
        One of "mps", "cuda" or "cpu".
    """
    try:
        import torch
    except ImportError:
        return "cpu"

    # torch.backends.mps only exists in newer torch releases; probing it with
    # getattr avoids an AttributeError on older installs.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"  # Apple Silicon GPU
    if torch.cuda.is_available():
        return "cuda"  # NVIDIA GPU
    return "cpu"


def _build_demucs_command(
    input_path: Path,
    output_dir: Path,
    model: str,
    device: str,
    shifts: int,
    overlap: float,
) -> List[str]:
    """Build the ``python -m demucs`` argument list (run without a shell)."""
    return [
        sys.executable, "-m", "demucs",  # use the current Python interpreter
        "--name", model,
        "--out", str(output_dir),
        "--shifts", str(shifts),
        "--overlap", str(overlap),
        "--mp3",  # Use mp3 output to avoid torchcodec dependency issues
        "--device", device,
        str(input_path),
    ]


def _collect_stems(stems_dir: Path, expected_stems: List[str]) -> List[dict]:
    """Return {type, path, duration} for each expected stem file found in stems_dir."""
    stems = []
    for stem_type in expected_stems:
        # Prefer the .mp3 written because of --mp3; fall back to .wav.
        stem_path = stems_dir / f"{stem_type}.mp3"
        if not stem_path.exists():
            stem_path = stems_dir / f"{stem_type}.wav"
        if stem_path.exists():
            stems.append({
                "type": stem_type,
                "path": str(stem_path),
                "duration": get_audio_duration(str(stem_path)),
            })
    return stems


def separate_stems(
    input_path: str,
    output_dir: str,
    model: str = "htdemucs",
    device: Optional[str] = None,
    shifts: int = 1,
    overlap: float = 0.25,
    timeout: Optional[float] = 600,
) -> dict:
    """
    Separate audio into stems using Demucs.

    Runs ``python -m demucs`` in a subprocess with the current interpreter,
    requesting MP3 output, then collects the stem files it produced.

    Args:
        input_path: Path to input audio file
        output_dir: Directory to save separated stems. Stems are expected in
            ``output_dir/<model>/<input file name without extension>/``.
        model: Demucs model to use (default: htdemucs); must be a key of
            DEMUCS_MODELS
        device: Processing device (cuda, cpu, mps). Auto-detected if None.
        shifts: Number of random shifts for better quality (more = slower)
        overlap: Overlap between prediction windows
        timeout: Maximum number of seconds to wait for Demucs (default 600,
            i.e. 10 minutes). None waits indefinitely.

    Returns:
        dict with:
            - success: bool
            - stems: list of {type, path, duration} (on success)
            - model: str (model used, on success)
            - output_dir: str (directory holding the stems, on success)
            - error: str (if failed)

        Runtime failures (missing input, unknown model, filesystem errors,
        Demucs errors, timeouts) are reported through ``success``/``error``
        rather than raised.
    """
    input_path = Path(input_path)
    output_dir = Path(output_dir)

    # is_file() rather than exists(): a directory is not a valid input.
    if not input_path.is_file():
        return {
            "success": False,
            "error": f"Input file not found: {input_path}"
        }

    if model not in DEMUCS_MODELS:
        return {
            "success": False,
            "error": f"Unknown model: {model}. Available: {list(DEMUCS_MODELS.keys())}"
        }

    try:
        # Directory creation and device probing run inside the try so their
        # failures are reported in the result dict like every other error.
        output_dir.mkdir(parents=True, exist_ok=True)
        if device is None:
            device = get_best_device()

        result = subprocess.run(
            _build_demucs_command(
                input_path, output_dir, model, device, shifts, overlap
            ),
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        if result.returncode != 0:
            return {
                "success": False,
                "error": f"Demucs failed: {result.stderr}"
            }

        # Expected layout: output_dir/<model>/<track>/<stem>.mp3 (or .wav),
        # where <track> is the input file name without its extension.
        stems_dir = output_dir / model / input_path.stem
        if not stems_dir.exists():
            return {
                "success": False,
                "error": f"Stems directory not found: {stems_dir}"
            }

        stems = _collect_stems(stems_dir, DEMUCS_MODELS[model]["stems"])
        if not stems:
            return {
                "success": False,
                "error": f"No stems found in {stems_dir}"
            }

        return {
            "success": True,
            "stems": stems,
            "model": model,
            "output_dir": str(stems_dir)
        }

    except subprocess.TimeoutExpired as exc:
        # Report the limit actually used; 600 s renders as "10 minutes".
        return {
            "success": False,
            "error": f"Stem separation timed out (>{exc.timeout / 60:g} minutes)"
        }
    except Exception as e:
        return {
            "success": False,
            "error": f"Stem separation failed: {e}"
        }


def _duration_via_soundfile(audio_path: str) -> Optional[float]:
    """Read the duration from the file header with soundfile."""
    import soundfile as sf
    return sf.info(audio_path).duration


def _duration_via_librosa(audio_path: str) -> Optional[float]:
    """Compute the duration with librosa."""
    import librosa
    return librosa.get_duration(path=audio_path)


def _duration_via_wave(audio_path: str) -> Optional[float]:
    """Read the duration of an uncompressed WAV file with the stdlib wave module."""
    import wave
    with wave.open(audio_path, "rb") as wav:
        rate = wav.getframerate()
        return wav.getnframes() / rate if rate else None


def _duration_via_ffprobe(audio_path: str) -> Optional[float]:
    """Ask the ffprobe binary for the container duration."""
    result = subprocess.run(
        ["ffprobe", "-v", "quiet", "-show_entries",
         "format=duration", "-of", "json", audio_path],
        capture_output=True,
        text=True,
        timeout=30,  # don't hang forever on a stuck probe
    )
    if result.returncode != 0:
        return None
    return float(json.loads(result.stdout)["format"]["duration"])


def get_audio_duration(audio_path: str) -> Optional[float]:
    """Get audio duration in seconds.

    Backends are tried in order: soundfile, librosa, the stdlib ``wave``
    module (PCM WAV only), then the ``ffprobe`` command-line tool. A backend
    is skipped if it is not installed or fails to read the file, so for
    example an MP3 that an older libsndfile cannot decode still gets a
    duration from a later backend.

    Args:
        audio_path: Path to the audio file.

    Returns:
        Duration in seconds, or None if no backend could determine it.
    """
    probes = (
        _duration_via_soundfile,
        _duration_via_librosa,
        _duration_via_wave,
        _duration_via_ffprobe,
    )
    for probe in probes:
        try:
            duration = probe(audio_path)
        except Exception:
            # Deliberately best-effort: a missing module (ImportError), a
            # missing binary (FileNotFoundError) or an unreadable/unsupported
            # file just moves on to the next backend.
            continue
        if duration is not None:
            return float(duration)
    return None


def list_available_models() -> dict:
    """Return the model registry wrapped in the module's result envelope.

    Returns:
        dict with ``success`` (always True) and ``models`` (the
        DEMUCS_MODELS mapping itself, not a copy).
    """
    return dict(success=True, models=DEMUCS_MODELS)


if __name__ == "__main__":
    # Command-line usage:
    #   python stem_separation.py <input_audio> <output_dir> [model]
    # Prints the separate_stems() result as JSON and exits with status 1 when
    # separation fails, so shell scripts can detect the failure.
    if len(sys.argv) > 2:
        cli_model = sys.argv[3] if len(sys.argv) > 3 else "htdemucs"
        result = separate_stems(sys.argv[1], sys.argv[2], model=cli_model)
        print(json.dumps(result, indent=2))
        if not result.get("success"):
            sys.exit(1)
    else:
        print("Usage: python stem_separation.py <input_audio> <output_dir> [model]")
        print("\nAvailable models:")
        for name, info in DEMUCS_MODELS.items():
            print(f"  {name}: {info['description']}")
            print(f"    Stems: {', '.join(info['stems'])}")