|
|
import torch |
|
|
import torchaudio |
|
|
|
|
|
|
|
|
# Best-effort selection of the soundfile backend. set_audio_backend was
# deprecated in torchaudio 2.x (and later removed), so this may legitimately
# fail on newer versions — in that case we simply keep the default backend.
# NOTE: the original used a bare `except:`, which also swallows
# KeyboardInterrupt/SystemExit; `except Exception` is the narrowest safe net here.
try:
    torchaudio.set_audio_backend("soundfile")
except Exception:
    pass
|
|
|
|
|
from demucs.apply import apply_model |
|
|
from demucs.pretrained import get_model |
|
|
import os |
|
|
import pathlib |
|
|
|
|
|
|
|
|
|
|
|
# Maps each separation mode to the Demucs checkpoint that backs it.
# "2stem" reuses the 4-stem model: the two-stem output is produced later by
# summing the non-vocal stems into a single "instruments" track.
MODELS = {
    "2stem": "htdemucs",     # vocals + merged instrumental
    "4stem": "htdemucs",     # vocals / drums / bass / other
    "6stem": "htdemucs_6s",  # adds guitar and piano stems
}
|
|
|
|
|
class AudioSeparator:
    """Separate audio files into stems using pretrained Demucs models.

    All models referenced in MODELS are loaded once at construction and kept
    resident on GPU when available, otherwise CPU.
    """

    def __init__(self):
        # Cache of loaded demucs models, keyed by checkpoint name.
        self.models = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Several modes can share one checkpoint (2stem/4stem both use
        # htdemucs), so load each unique model exactly once.
        unique_models = set(MODELS.values())
        for model_name in unique_models:
            print(f"Loading Demucs Model: {model_name}...")
            model = get_model(model_name)
            model.to(self.device)
            self.models[model_name] = model
        print(f"All models loaded on {self.device}")

    def separate(self, audio_path: str, output_dir: str, callback=None, mode="4stem"):
        """Separate an audio file into stems.

        Modes: 2stem, 4stem, 6stem

        Args:
            audio_path: Path to the input audio file.
            output_dir: Directory for the stem files (created if missing).
            callback: Optional ``callback(message, percent)`` progress hook.
            mode: "2stem" (vocals/instruments), "4stem", or "6stem".

        Returns:
            Tuple ``(results, input_duration)`` where ``results`` maps stem
            name to output file path and ``input_duration`` is the length of
            the input in seconds.
        """
        # Unknown modes fall back to the 4-stem default checkpoint.
        model_name = MODELS.get(mode, "htdemucs")
        model = self.models[model_name]

        import soundfile as sf
        wav_np, sr = sf.read(audio_path)

        # BUGFIX: compute the duration from the ORIGINAL sample rate, before
        # `sr` can be reassigned to 44100 by the resampling branch below.
        # The old code divided the original frame count by the new rate,
        # mis-reporting the duration of any non-44.1kHz file.
        input_duration = len(wav_np) / sr

        wav = torch.from_numpy(wav_np).float()

        # soundfile returns (frames,) for mono and (frames, channels) for
        # multi-channel audio; Demucs wants (channels, frames).
        if wav.ndim == 1:
            wav = wav.unsqueeze(0)
        else:
            wav = wav.t()

        # BUGFIX/robustness: htdemucs checkpoints are stereo-trained and
        # reject single-channel input; duplicate a mono channel instead of
        # crashing.
        if wav.shape[0] == 1:
            wav = wav.repeat(2, 1)

        # Demucs expects 44.1 kHz input.
        if sr != 44100:
            if callback: callback("Resampling audio...", 15)
            resampler = torchaudio.transforms.Resample(sr, 44100)
            wav = resampler(wav)
            sr = 44100

        # Add the batch dimension: (1, channels, frames).
        wav = wav.unsqueeze(0).to(self.device)

        # Standard Demucs normalization: zero mean / unit std over the mix.
        ref = wav.mean(0)
        std = ref.std()
        # BUGFIX: guard against a silent input (std == 0) producing NaNs.
        wav = (wav - ref.mean()) / (std if std > 0 else 1.0)

        if callback: callback("Running Demucs Inference...", 20)
        print(f"Starting separation with {model_name} (mode: {mode})...")

        # [0] drops the batch dimension: sources is (num_sources, C, T).
        sources = apply_model(model, wav, shifts=1, split=True, overlap=0.25, progress=True)[0]
        source_names = model.sources

        results = {}
        os.makedirs(output_dir, exist_ok=True)

        total_sources = len(source_names)
        source_tensors = {name: source for name, source in zip(source_names, sources)}

        if mode == "2stem":
            if callback: callback("Merging to 2 stems...", 45)

            vocals = source_tensors.get('vocals')
            # Everything that is not vocals is summed into one backing track.
            instruments = None
            for name, src in source_tensors.items():
                if name != 'vocals':
                    if instruments is None:
                        instruments = src.clone()
                    else:
                        instruments += src

            if vocals is not None:
                self._save_audio(vocals, sr, os.path.join(output_dir, "vocals.mp3"))
                results['vocals'] = os.path.join(output_dir, "vocals.mp3")
            if instruments is not None:
                self._save_audio(instruments, sr, os.path.join(output_dir, "instruments.mp3"))
                results['instruments'] = os.path.join(output_dir, "instruments.mp3")

        elif mode == "6stem":
            for i, (name, source) in enumerate(source_tensors.items()):
                progress = 30 + int((i / total_sources) * 20)
                if callback: callback(f"Saving stem: {name}", progress)

                if name == 'guitar':
                    # The guitar stem gets extra rhythm/lead post-processing.
                    results.update(self._process_guitar(source, sr, output_dir))
                else:
                    stem_path = os.path.join(output_dir, f"{name}.mp3")
                    self._save_audio(source, sr, stem_path)
                    results[name] = stem_path
        else:
            # 4stem (and any unrecognized mode): save every stem as-is.
            for i, (name, source) in enumerate(source_tensors.items()):
                progress = 30 + int((i / total_sources) * 20)
                if callback: callback(f"Saving stem: {name}", progress)

                stem_path = os.path.join(output_dir, f"{name}.mp3")
                self._save_audio(source, sr, stem_path)
                results[name] = stem_path

        return results, input_duration

    def _process_guitar(self, source, sr, output_dir):
        """Split the guitar stem into Lead and Rhythm via Mid-Side processing.

        - Mid (center) = Rhythm (power chords / strumming usually sit center)
        - Side (stereo difference) = Lead (usually panned or stereo-effected)

        Returns a dict mapping result stem name(s) to the saved file path.
        """
        # BUGFIX: move to CPU up front — the .numpy() conversions below raise
        # on CUDA tensors, which previously tripped the silent fallback path.
        source = source.cpu()

        # Mono guitar cannot be mid-side split; save it once and point both
        # keys at the same file.
        if source.shape[0] < 2:
            print("Warning: Guitar stem is Mono. Cannot split Rhythm/Lead.")
            path = os.path.join(output_dir, "guitar.mp3")
            self._save_audio(source, sr, path)
            return {"guitar_rhythm": path, "guitar_lead": path}

        # Keep channel dims as (1, T) slices so cat/mean keep working below.
        left = source[0:1, :]
        right = source[1:2, :]

        # Pearson correlation between L and R decides the split strategy.
        mean_l = left.mean()
        mean_r = right.mean()
        var_l = ((left - mean_l)**2).mean()
        var_r = ((right - mean_r)**2).mean()
        cov = ((left - mean_l) * (right - mean_r)).mean()

        correlation = 0.0
        if var_l > 0 and var_r > 0:
            correlation = cov / torch.sqrt(var_l * var_r)

        print(f"Guitar Stereo Correlation: {correlation:.4f}")

        if abs(correlation) < 0.6:
            # Decorrelated channels: treat each side as its own part.
            print("Detected Wide Stereo Guitar (Math Rock Style). Using Spatial Split (L=Rhythm, R=Lead).")
            rhythm_stereo = torch.cat([left, left], dim=0)
            lead_stereo = torch.cat([right, right], dim=0)
        else:
            print("Detected Narrow/Mono Guitar. Using Mid-Side Frequency Split.")
            mid = (left + right) / 2.0
            side = (left - right) / 2.0

            try:
                import scipy.signal as signal
                nyquist = sr / 2

                # Rhythm band: 80 Hz – 1.2 kHz (chords / low strums).
                rhythm_low = 80 / nyquist
                rhythm_high = 1200 / nyquist
                b_r, a_r = signal.butter(4, [rhythm_low, rhythm_high], btype='band')

                # Lead band: 1 kHz – 8 kHz (melody / solos).
                lead_low = 1000 / nyquist
                lead_high = 8000 / nyquist
                b_l, a_l = signal.butter(4, [lead_low, lead_high], btype='band')

                # Zero-phase band-pass of the mid signal for each part.
                rhythm_from_mid = signal.filtfilt(b_r, a_r, mid.numpy())
                lead_from_mid = signal.filtfilt(b_l, a_l, mid.numpy())

                side_np = side.numpy()

                rhythm_final = rhythm_from_mid
                # Boost the side signal into the lead bucket.
                lead_final = lead_from_mid + (side_np * 1.5)

                # BUGFIX: .copy() — filtfilt can return negative-stride
                # arrays, which torch.from_numpy rejects with a ValueError;
                # the old code then silently fell into the raw fallback.
                rhythm_stereo = torch.from_numpy(rhythm_final.copy()).float()
                rhythm_stereo = torch.cat([rhythm_stereo, rhythm_stereo], dim=0)

                lead_stereo = torch.from_numpy(lead_final.copy()).float()
                lead_stereo = torch.cat([lead_stereo, lead_stereo], dim=0)

            except Exception as e:
                # Best-effort: fall back to the raw spatial split.
                print(f"Filter failed: {e}. Fallback to raw.")
                rhythm_stereo = torch.cat([left, left], dim=0)
                lead_stereo = torch.cat([right, right], dim=0)

        def normalize(tensor):
            # Peak-normalize to 0.89 (~-1 dBFS of headroom before encoding).
            peak = tensor.abs().max()
            if peak > 0:
                target_peak = 0.89
                return tensor * (target_peak / peak)
            return tensor

        rhythm_stereo = normalize(rhythm_stereo)
        lead_stereo = normalize(lead_stereo)

        # Collapse each part to mono and pack them as one 2-channel file:
        # channel 0 = rhythm, channel 1 = lead.
        rhythm_mono = rhythm_stereo.mean(dim=0, keepdim=True)
        lead_mono = lead_stereo.mean(dim=0, keepdim=True)

        guitar_split = torch.cat([rhythm_mono, lead_mono], dim=0)
        guitar_split = normalize(guitar_split)

        path = os.path.join(output_dir, "guitar.mp3")
        self._save_audio(guitar_split, sr, path)

        # NOTE(review): the stereo path returns only "guitar", while the mono
        # fallback above returns "guitar_rhythm"/"guitar_lead" — callers must
        # handle both key sets; unifying them would change the public result
        # shape, so it is left as-is.
        return {
            "guitar": path
        }

    def _save_audio(self, source, sr, path):
        """Write a (channels, frames) tensor to `path` via soundfile."""
        source = source.cpu()

        # Limit peaks to 0.89 to avoid clipping during lossy encoding.
        peak = source.abs().max()
        if peak > 0.89:
            source = source / peak * 0.89

        import soundfile as sf
        # soundfile expects (frames, channels), hence the transpose.
        sf.write(path, source.t().numpy(), sr)
|
|
|