# aimusic-attribution / ml / stem_separation.py
# (uploaded by emresar via huggingface_hub, commit 6678fa1, verified)
"""
Stem Separation using Demucs
Separates audio into individual stems (vocals, drums, bass, other)
using Facebook/Meta's Demucs model.
Requires: demucs package (pip install demucs)
"""
import os
import sys
import subprocess
import json
from pathlib import Path
from typing import Optional, List
import tempfile
# Available Demucs models.
# Maps CLI model name -> {"stems": list of stem file names the model emits
# (one output file per entry), "description": human-readable summary}.
# The "stems" list drives output collection in separate_stems().
DEMUCS_MODELS: dict = {
    "htdemucs": {
        "stems": ["vocals", "drums", "bass", "other"],
        "description": "Hybrid Transformer Demucs (recommended)"
    },
    "htdemucs_ft": {
        "stems": ["vocals", "drums", "bass", "other"],
        "description": "Fine-tuned Hybrid Transformer Demucs"
    },
    "htdemucs_6s": {
        # 6-stem variant: additionally isolates guitar and piano.
        "stems": ["vocals", "drums", "bass", "guitar", "piano", "other"],
        "description": "6-stem Hybrid Transformer Demucs"
    },
    "mdx_extra": {
        "stems": ["vocals", "drums", "bass", "other"],
        "description": "MDX-Net architecture"
    }
}
def get_best_device() -> str:
    """Auto-detect the best available device for ML processing.

    Preference order: Apple Silicon GPU ("mps") > NVIDIA GPU ("cuda") > "cpu".

    Returns:
        Device string suitable for demucs' --device flag.
    """
    try:
        import torch
        # getattr guard: torch.backends.mps does not exist on older torch
        # builds, where the bare attribute access would raise AttributeError.
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            return "mps"  # Apple Silicon GPU
        if torch.cuda.is_available():
            return "cuda"  # NVIDIA GPU
    except ImportError:
        # torch not installed: fall through to CPU.
        pass
    return "cpu"
def _collect_stems(stems_dir: Path, expected_stems: List[str]) -> List[dict]:
    """Collect info dicts for the stem files that exist under stems_dir.

    Checks mp3 first (the default with the --mp3 flag), then wav. Stems
    missing on disk are silently skipped.
    """
    stems = []
    for stem_type in expected_stems:
        stem_path = stems_dir / f"{stem_type}.mp3"
        if not stem_path.exists():
            stem_path = stems_dir / f"{stem_type}.wav"
        if stem_path.exists():
            stems.append({
                "type": stem_type,
                "path": str(stem_path),
                "duration": get_audio_duration(str(stem_path)),
            })
    return stems


def separate_stems(
    input_path: str,
    output_dir: str,
    model: str = "htdemucs",
    device: Optional[str] = None,
    shifts: int = 1,
    overlap: float = 0.25,
    timeout: int = 600,
) -> dict:
    """
    Separate audio into stems using Demucs (run as a subprocess).

    Args:
        input_path: Path to input audio file
        output_dir: Directory to save separated stems
        model: Demucs model to use (default: htdemucs)
        device: Processing device (cuda, cpu, mps). Auto-detected if None.
        shifts: Number of random shifts for better quality (more = slower)
        overlap: Overlap between prediction windows
        timeout: Max seconds to wait for the demucs subprocess (default 600)
    Returns:
        dict with:
            - success: bool
            - stems: list of {type, path, duration}
            - model: str (model used)
            - output_dir: str (directory holding the stem files)
            - error: str (if failed)
    """
    input_path = Path(input_path)
    output_dir = Path(output_dir)

    # Validate inputs before doing any work.
    if not input_path.exists():
        return {
            "success": False,
            "error": f"Input file not found: {input_path}"
        }
    if model not in DEMUCS_MODELS:
        return {
            "success": False,
            "error": f"Unknown model: {model}. Available: {list(DEMUCS_MODELS.keys())}"
        }

    output_dir.mkdir(parents=True, exist_ok=True)

    if device is None:
        device = get_best_device()

    try:
        # Build demucs command using the current Python interpreter so the
        # subprocess runs in the same environment/venv as this process.
        cmd = [
            sys.executable, "-m", "demucs",
            "--name", model,
            "--out", str(output_dir),
            "--shifts", str(shifts),
            "--overlap", str(overlap),
            "--mp3",  # Use mp3 output to avoid torchcodec dependency issues
            "--device", device,  # Use detected or specified device
            str(input_path),
        ]
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        if result.returncode != 0:
            return {
                "success": False,
                "error": f"Demucs failed: {result.stderr}"
            }

        # Demucs outputs to: output_dir/model_name/track_name/stem.<ext>
        stems_dir = output_dir / model / input_path.stem
        if not stems_dir.exists():
            return {
                "success": False,
                "error": f"Stems directory not found: {stems_dir}"
            }

        stems = _collect_stems(stems_dir, DEMUCS_MODELS[model]["stems"])
        if not stems:
            return {
                "success": False,
                "error": f"No stems found in {stems_dir}"
            }
        return {
            "success": True,
            "stems": stems,
            "model": model,
            "output_dir": str(stems_dir)
        }
    except subprocess.TimeoutExpired:
        return {
            "success": False,
            "error": f"Stem separation timed out (>{timeout} seconds)"
        }
    except Exception as e:
        return {
            "success": False,
            "error": f"Stem separation failed: {str(e)}"
        }
def get_audio_duration(audio_path: str) -> Optional[float]:
    """Get audio duration in seconds, or None if it cannot be determined.

    Tries three backends in order: soundfile (header read, fastest),
    librosa, then ffprobe. Unlike a plain ImportError chain, a backend
    that is installed but fails on this file (e.g. unsupported codec)
    still falls through to the next backend.
    """
    # 1) soundfile: reads only the header.
    try:
        import soundfile as sf
        return sf.info(audio_path).duration
    except Exception:
        pass  # not installed, or could not read this file — try next backend

    # 2) librosa.
    try:
        import librosa
        return librosa.get_duration(path=audio_path)
    except Exception:
        pass

    # 3) ffprobe, if available on PATH.
    try:
        result = subprocess.run(
            ["ffprobe", "-v", "quiet", "-show_entries",
             "format=duration", "-of", "json", audio_path],
            capture_output=True,
            text=True
        )
        if result.returncode == 0:
            data = json.loads(result.stdout)
            return float(data["format"]["duration"])
    except (OSError, ValueError, KeyError):
        # OSError: ffprobe missing; ValueError/KeyError: unparseable output.
        # (json.JSONDecodeError is a ValueError subclass.)
        pass
    return None
def list_available_models() -> dict:
    """Return the catalog of supported Demucs models.

    Returns:
        dict with success=True and a "models" mapping of model name to
        its {"stems", "description"} entry.
    """
    response = {"success": True}
    response["models"] = DEMUCS_MODELS
    return response
if __name__ == "__main__":
    # CLI entry point for quick manual testing:
    #   python stem_separation.py <input_audio> <output_dir>
    # `sys` and `json` are already imported at module level, so the
    # redundant local `import sys` is dropped.
    if len(sys.argv) > 2:
        result = separate_stems(sys.argv[1], sys.argv[2])
        print(json.dumps(result, indent=2))
    else:
        print("Usage: python stem_separation.py <input_audio> <output_dir>")
        print("\nAvailable models:")
        for name, info in DEMUCS_MODELS.items():
            print(f"  {name}: {info['description']}")
            print(f"    Stems: {', '.join(info['stems'])}")