tunebase/services/audio_separator.py
github-actions[bot]
Sync from GitHub: c41325319a184f2e41b4253f10027acd4a889425
d4e97e7
import torch
import torchaudio

# Force the backend to soundfile to avoid the TorchCodec requirement on Windows.
# Newer torchaudio releases removed/deprecated set_audio_backend, so tolerate
# failure — but catch Exception, not a bare except, so SystemExit and
# KeyboardInterrupt still propagate.
try:
    torchaudio.set_audio_backend("soundfile")
except Exception:
    pass  # In newer versions backend selection is automatic; this is best-effort

from demucs.apply import apply_model
from demucs.pretrained import get_model
import os
import pathlib
# Model configuration.
# Maps user-facing separation modes to demucs.pretrained model names; several
# modes may share one underlying checkpoint (it is loaded only once).
MODELS = {
    "2stem": "htdemucs",     # Will merge to 2 stems after
    "4stem": "htdemucs",     # Default 4 stem (drums, bass, vocals, other)
    "6stem": "htdemucs_6s",  # Full 6 stem with guitar and piano
}
class AudioSeparator:
    """Separate audio files into stems with Demucs (2-, 4- or 6-stem modes)."""

    def __init__(self):
        """Load every unique Demucs checkpoint referenced by MODELS onto the best device."""
        self.models = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Several modes can share one checkpoint; load each unique model once.
        for model_name in set(MODELS.values()):
            print(f"Loading Demucs Model: {model_name}...")
            model = get_model(model_name)
            model.to(self.device)
            model.eval()  # inference only
            self.models[model_name] = model
        print(f"All models loaded on {self.device}")

    def separate(self, audio_path: str, output_dir: str, callback=None, mode="4stem"):
        """
        Separate an audio file into stems.

        Args:
            audio_path: Path to the input file (any format soundfile can read).
            output_dir: Directory for the stem files (created if missing).
            callback: Optional ``callback(message: str, percent: int)`` progress hook.
            mode: "2stem", "4stem" (default) or "6stem"; unknown modes fall
                back to the 4-stem model.

        Returns:
            Tuple of (results, input_duration) where ``results`` maps stem
            name -> saved file path and ``input_duration`` is seconds of input.
        """
        model_name = MODELS.get(mode, "htdemucs")
        model = self.models[model_name]

        # Load audio using soundfile directly (bypass torchaudio backend issues).
        import soundfile as sf
        wav_np, sr = sf.read(audio_path)

        # BUGFIX: duration must use the ORIGINAL sample rate. The old code
        # divided by `sr` after it had been overwritten to 44100, reporting a
        # wrong duration for any file that needed resampling.
        input_duration = wav_np.shape[0] / sr

        # soundfile returns (samples,) or (samples, channels); demucs expects
        # (channels, samples).
        wav = torch.from_numpy(wav_np).float()
        if wav.ndim == 1:
            wav = wav.unsqueeze(0)
        else:
            wav = wav.t()

        # BUGFIX: htdemucs checkpoints are stereo — duplicate a mono signal to
        # two channels instead of crashing inside apply_model.
        if wav.shape[0] == 1:
            wav = wav.repeat(2, 1)

        # Resample if needed (Demucs operates at 44.1 kHz).
        if sr != 44100:
            if callback: callback("Resampling audio...", 15)
            wav = torchaudio.transforms.Resample(sr, 44100)(wav)
            sr = 44100

        wav = wav.unsqueeze(0).to(self.device)

        # Demucs-style normalization; the epsilon guards against NaN on a
        # perfectly silent input (std == 0).
        ref = wav.mean(0)
        wav = (wav - ref.mean()) / (ref.std() + 1e-8)

        if callback: callback("Running Demucs Inference...", 20)
        print(f"Starting separation with {model_name} (mode: {mode})...")
        sources = apply_model(model, wav, shifts=1, split=True, overlap=0.25, progress=True)[0]
        source_names = model.sources

        results = {}
        os.makedirs(output_dir, exist_ok=True)
        total_sources = len(source_names)
        source_tensors = dict(zip(source_names, sources))

        if mode == "2stem":
            # Merge everything that is not vocals into a single instrumental bed.
            if callback: callback("Merging to 2 stems...", 45)
            vocals = source_tensors.get('vocals')
            instruments = None
            for name, src in source_tensors.items():
                if name != 'vocals':
                    if instruments is None:
                        instruments = src.clone()
                    else:
                        instruments += src
            if vocals is not None:
                self._save_audio(vocals, sr, os.path.join(output_dir, "vocals.mp3"))
                results['vocals'] = os.path.join(output_dir, "vocals.mp3")
            if instruments is not None:
                self._save_audio(instruments, sr, os.path.join(output_dir, "instruments.mp3"))
                results['instruments'] = os.path.join(output_dir, "instruments.mp3")
        elif mode == "6stem":
            # Full separation; the guitar stem gets rhythm/lead post-processing.
            for i, (name, source) in enumerate(source_tensors.items()):
                progress = 30 + int((i / total_sources) * 20)
                if callback: callback(f"Saving stem: {name}", progress)
                if name == 'guitar':
                    results.update(self._process_guitar(source, sr, output_dir))
                else:
                    stem_path = os.path.join(output_dir, f"{name}.mp3")
                    self._save_audio(source, sr, stem_path)
                    results[name] = stem_path
        else:
            # 4stem - standard separation
            for i, (name, source) in enumerate(source_tensors.items()):
                progress = 30 + int((i / total_sources) * 20)
                if callback: callback(f"Saving stem: {name}", progress)
                stem_path = os.path.join(output_dir, f"{name}.mp3")
                self._save_audio(source, sr, stem_path)
                results[name] = stem_path

        return results, input_duration

    def _process_guitar(self, source, sr, output_dir):
        """
        Split the guitar stem into Lead and Rhythm using Mid-Side processing.

        - Mid (center) = Rhythm (power chords / strumming usually sit center)
        - Side (stereo difference) = Lead (often panned or stereo-effected)

        Returns a dict mapping stem name -> saved file path.
        """
        # BUGFIX: move to CPU first. The stem may live on CUDA (apply_model runs
        # on self.device) and .numpy()/scipy below cannot handle CUDA tensors —
        # previously the mid-side path always hit the exception fallback on GPU.
        source = source.detach().cpu()

        # source shape: (channels, samples); real stereo is required for a split.
        if source.shape[0] < 2:
            print("Warning: Guitar stem is Mono. Cannot split Rhythm/Lead.")
            path = os.path.join(output_dir, "guitar.mp3")
            self._save_audio(source, sr, path)
            return {"guitar_rhythm": path, "guitar_lead": path}

        left = source[0:1, :]
        right = source[1:2, :]

        # 1. Smart Spatial Split Check: L/R correlation at lag 0.
        # Low correlation indicates hard panning (math-rock style) — the two
        # channels play different parts, so keep them as separate takes.
        mean_l = left.mean()
        mean_r = right.mean()
        var_l = ((left - mean_l) ** 2).mean()
        var_r = ((right - mean_r) ** 2).mean()
        cov = ((left - mean_l) * (right - mean_r)).mean()
        correlation = 0.0
        if var_l > 0 and var_r > 0:
            correlation = cov / torch.sqrt(var_l * var_r)
        print(f"Guitar Stereo Correlation: {correlation:.4f}")

        # Threshold for "Wide Stereo"
        if abs(correlation) < 0.6:
            print("Detected Wide Stereo Guitar (Math Rock Style). Using Spatial Split (L=Rhythm, R=Lead).")
            # Hard split: Rhythm = Left channel, Lead = Right channel, each
            # duplicated to stereo.
            rhythm_stereo = torch.cat([left, left], dim=0)
            lead_stereo = torch.cat([right, right], dim=0)
        else:
            print("Detected Narrow/Mono Guitar. Using Mid-Side Frequency Split.")
            mid = (left + right) / 2.0
            side = (left - right) / 2.0
            try:
                import scipy.signal as signal
                nyquist = sr / 2
                # Rhythm: low-mid band (80-1200 Hz) with a tight top end so it
                # overlaps the lead band (which often starts around 800 Hz+) less.
                b_r, a_r = signal.butter(4, [80 / nyquist, 1200 / nyquist], btype='band')
                # Lead: high-mid band (1000-8000 Hz).
                b_l, a_l = signal.butter(4, [1000 / nyquist, 8000 / nyquist], btype='band')
                # Filter the Mid signal — the center carries most of the content.
                rhythm_from_mid = signal.filtfilt(b_r, a_r, mid.numpy())
                lead_from_mid = signal.filtfilt(b_l, a_l, mid.numpy())
                # Stereo width usually belongs to the lead; fold Side into it.
                side_np = side.numpy()
                rhythm_final = rhythm_from_mid
                lead_final = lead_from_mid + (side_np * 1.5)  # Boost side for lead
                # .copy() forces a contiguous, positively-strided array;
                # filtfilt output can otherwise trip torch.from_numpy.
                rhythm_stereo = torch.from_numpy(rhythm_final.copy()).float()
                rhythm_stereo = torch.cat([rhythm_stereo, rhythm_stereo], dim=0)
                lead_stereo = torch.from_numpy(lead_final.copy()).float()
                lead_stereo = torch.cat([lead_stereo, lead_stereo], dim=0)
            except Exception as e:
                print(f"Filter failed: {e}. Fallback to raw.")
                rhythm_stereo = torch.cat([left, left], dim=0)
                lead_stereo = torch.cat([right, right], dim=0)

        def normalize(tensor):
            # Peak-normalize to ~-1 dBFS; leave pure silence untouched.
            peak = tensor.abs().max()
            if peak > 0:
                target_peak = 0.89  # -1dB
                return tensor * (target_peak / peak)
            return tensor

        rhythm_stereo = normalize(rhythm_stereo)
        lead_stereo = normalize(lead_stereo)

        # Merge into ONE stereo file: channel 0 = Rhythm, channel 1 = Lead,
        # using the mono mix of each processed pair.
        rhythm_mono = rhythm_stereo.mean(dim=0, keepdim=True)
        lead_mono = lead_stereo.mean(dim=0, keepdim=True)
        guitar_split = normalize(torch.cat([rhythm_mono, lead_mono], dim=0))

        # Save as a single file named "guitar.mp3" (special split).
        path = os.path.join(output_dir, "guitar.mp3")
        self._save_audio(guitar_split, sr, path)
        return {
            "guitar": path
        }

    def _save_audio(self, source, sr, path):
        """Write a (channels, samples) tensor to ``path`` via soundfile."""
        # detach() guards against tensors still carrying autograd history;
        # .cpu() is mandatory before .numpy().
        source = source.detach().cpu()
        # Limit to roughly -1 dBFS peak to prevent clipping.
        peak = source.abs().max()
        if peak > 0.89:  # approx -1dB
            source = source / peak * 0.89
        # soundfile expects (samples, channels), hence the transpose.
        # NOTE(review): writing .mp3 through soundfile requires a libsndfile
        # build with MP3 support — verify on the deployment target.
        import soundfile as sf
        sf.write(path, source.t().numpy(), sr)