"""Six-stem audio separation service built around the BS-RoFormer model.

The checkpoint and YAML config are fetched from the Hugging Face Hub on first
use and cached under ``MODEL_DIR``.  Inference is chunked with Hann-window
overlap-add so arbitrarily long inputs fit in memory.
"""

import logging
import os
from typing import Callable

import numpy as np
import soundfile as sf
import torch
import torchaudio
import yaml
from huggingface_hub import hf_hub_download
from pydub import AudioSegment

from audio_separator.separator.uvr_lib_v5.roformer.bs_roformer import BSRoformer

logger = logging.getLogger(__name__)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info("Using device: %s", DEVICE)

if DEVICE.type == "cpu":
    # CPU-only host: use every core for intra-op math, half for inter-op
    # scheduling.  NOTE(review): set_num_interop_threads raises if torch has
    # already run parallel work in this process — this module must be imported
    # before any other torch usage; confirm import order in the app entrypoint.
    _cpu_count = os.cpu_count() or 1
    torch.set_num_threads(_cpu_count)
    torch.set_num_interop_threads(max(1, _cpu_count // 2))
    logger.info(
        "CPU mode: set torch threads=%d, interop=%d", _cpu_count, max(1, _cpu_count // 2)
    )

MODEL_REPO = "jarredou/BS-ROFO-SW-Fixed"
MODEL_FILENAME = "BS-Rofo-SW-Fixed.ckpt"
MODEL_CONFIG = "BS-Rofo-SW-Fixed.yaml"
MODEL_DIR = "/tmp/models"

# Stem order matches the model's training config
STEM_ORDER = ["bass", "drums", "other", "vocals", "guitar", "piano"]
STEM_NAME_MAP = {
    "bass": "Bass",
    "drums": "Drums",
    "other": "Other",
    "vocals": "Vocals",
    "guitar": "Guitar",
    "piano": "Piano",
}


class StemSeparatorService:
    """Singleton that lazily loads BS-RoFormer and splits audio into stems."""

    _instance = None
    _model_loaded = False

    def __new__(cls):
        # Classic singleton: every instantiation returns the same object, so
        # the (large) model weights are held in memory at most once per process.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def load_model(self) -> None:
        """Download (if needed) and initialise the BS-RoFormer model.

        Idempotent: returns immediately once the model has been loaded.
        NOTE(review): not thread-safe — two concurrent first calls could both
        attempt the load; add a lock if called from multiple workers.
        """
        if self._model_loaded:
            return

        os.makedirs(MODEL_DIR, exist_ok=True)
        logger.info("Downloading model from HF Hub: %s", MODEL_REPO)
        hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME, local_dir=MODEL_DIR)
        hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_CONFIG, local_dir=MODEL_DIR)

        # Parse config.  FullLoader (not safe_load) because roformer configs
        # can carry non-plain YAML tags; the repo is pinned above, so the file
        # is trusted — do not point this at arbitrary user-supplied YAML.
        with open(os.path.join(MODEL_DIR, MODEL_CONFIG)) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)

        model_cfg = config["model"]
        audio_cfg = config.get("audio", {})
        inference_cfg = config.get("inference", {})

        self.sample_rate = audio_cfg.get("sample_rate", 44100)
        self.chunk_size = audio_cfg.get("chunk_size", 588800)
        self.num_overlap = inference_cfg.get("num_overlap", 2)

        # Use flash_attn only on CUDA where it's supported
        use_flash = DEVICE.type == "cuda"

        # Create model directly — bypass audio-separator's Separator wrapper entirely
        self.model = BSRoformer(
            dim=model_cfg["dim"],
            depth=model_cfg["depth"],
            stereo=model_cfg.get("stereo", True),
            num_stems=model_cfg.get("num_stems", 6),
            time_transformer_depth=model_cfg.get("time_transformer_depth", 1),
            freq_transformer_depth=model_cfg.get("freq_transformer_depth", 1),
            linear_transformer_depth=model_cfg.get("linear_transformer_depth", 0),
            freqs_per_bands=tuple(model_cfg["freqs_per_bands"]),
            dim_head=model_cfg.get("dim_head", 64),
            heads=model_cfg.get("heads", 8),
            attn_dropout=model_cfg.get("attn_dropout", 0.1),
            ff_dropout=model_cfg.get("ff_dropout", 0.1),
            flash_attn=use_flash,
            dim_freqs_in=model_cfg.get("dim_freqs_in", 1025),
            stft_n_fft=model_cfg.get("stft_n_fft", 2048),
            stft_hop_length=model_cfg.get("stft_hop_length", 512),
            stft_win_length=model_cfg.get("stft_win_length", 2048),
            stft_normalized=model_cfg.get("stft_normalized", False),
            mask_estimator_depth=model_cfg.get("mask_estimator_depth", 2),
            multi_stft_resolution_loss_weight=model_cfg.get(
                "multi_stft_resolution_loss_weight", 1.0
            ),
            multi_stft_resolutions_window_sizes=tuple(
                model_cfg.get(
                    "multi_stft_resolutions_window_sizes", (4096, 2048, 1024, 512, 256)
                )
            ),
            multi_stft_hop_size=model_cfg.get("multi_stft_hop_size", 147),
            multi_stft_normalized=model_cfg.get("multi_stft_normalized", False),
            mlp_expansion_factor=model_cfg.get("mlp_expansion_factor", 4),
            use_torch_checkpoint=model_cfg.get("use_torch_checkpoint", False),
            skip_connection=model_cfg.get("skip_connection", False),
        )

        # Load checkpoint weights.  weights_only=True avoids arbitrary pickle
        # code execution on newer torch; the TypeError fallback keeps older
        # torch versions (no weights_only kwarg) working.
        ckpt_path = os.path.join(MODEL_DIR, MODEL_FILENAME)
        try:
            state_dict = torch.load(ckpt_path, map_location=DEVICE, weights_only=True)
        except TypeError:
            state_dict = torch.load(ckpt_path, map_location=DEVICE)

        # Unwrap common checkpoint layouts ("state_dict" from Lightning-style
        # trainers, or a plain "model" key).
        if isinstance(state_dict, dict) and "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        elif isinstance(state_dict, dict) and "model" in state_dict:
            state_dict = state_dict["model"]

        self.model.load_state_dict(state_dict)
        self.model.to(DEVICE)
        self.model.eval()
        logger.info("BS-RoFormer model loaded successfully on %s", DEVICE)
        self._model_loaded = True

    def _process_audio(
        self,
        audio_tensor: torch.Tensor,
        progress_callback: Callable[[str, float], None],
    ) -> torch.Tensor:
        """Run inference with chunking and Hann-window overlap-add.

        Args:
            audio_tensor: ``(channels, samples)`` float32 waveform.
            progress_callback: called with ``(stage, fraction)`` per chunk.

        Returns:
            ``(num_stems, channels, samples)`` tensor on CPU, trimmed to the
            input length.
        """
        chunk_size = self.chunk_size
        step = chunk_size // self.num_overlap
        channels, total_samples = audio_tensor.shape
        num_stems = len(STEM_ORDER)

        # Pad on the right so every sample is covered by a full-size chunk.
        pad_needed = max(0, chunk_size - total_samples)
        if total_samples > chunk_size:
            remainder = (total_samples - chunk_size) % step
            if remainder != 0:
                pad_needed = step - remainder
        if pad_needed > 0:
            audio_tensor = torch.nn.functional.pad(audio_tensor, (0, pad_needed))
        padded_len = audio_tensor.shape[1]

        # Move input to device
        audio_tensor = audio_tensor.to(DEVICE)

        # Output accumulators (keep on CPU to save GPU memory)
        result = torch.zeros(num_stems, channels, padded_len)
        weight = torch.zeros(padded_len)

        # Hann window for smooth crossfading.  Built once, on CPU, because it
        # is only ever multiplied against the CPU-side accumulators.
        window = torch.hann_window(chunk_size)

        # Build chunk positions
        starts = list(range(0, padded_len - chunk_size + 1, step))
        total_chunks = len(starts)

        for i, start in enumerate(starts):
            chunk = audio_tensor[:, start : start + chunk_size]
            with torch.no_grad():
                # BSRoformer: (batch, channels, time) -> (batch, stems, channels, time)
                output = self.model(chunk.unsqueeze(0)).squeeze(0)

            # Accumulate on CPU, weighted by the window for smooth overlap.
            output_cpu = output.cpu()
            result[:, :, start : start + chunk_size] += output_cpu * window
            weight[start : start + chunk_size] += window

            progress_callback("separating", 0.2 + (i + 1) / total_chunks * 0.7)

        # Normalize by accumulated window weight; the clamp guards the Hann
        # window's zero endpoints (first/last sample) against division by zero.
        result = result / weight.clamp(min=1e-8).unsqueeze(0).unsqueeze(0)

        # Remove padding
        return result[:, :, :total_samples]

    def separate(
        self,
        input_path: str,
        output_dir: str,
        stems: list[str],
        output_format: str,
        progress_callback: Callable[[str, float], None],
    ) -> dict[str, str]:
        """Separate *input_path* into the requested stems and write them out.

        Args:
            input_path: audio file readable by libsndfile.
            output_dir: existing directory that receives the stem files.
            stems: canonical stem names to keep (e.g. ``["Vocals", "Drums"]``).
            output_format: ``"wav"``, ``"mp3"`` or ``"aac"``.
            progress_callback: called with ``(stage, fraction in [0, 1])``.

        Returns:
            Mapping of canonical stem name -> output file name (not full path).
        """
        progress_callback("loading_model", 0.05)
        self.load_model()
        progress_callback("separating", 0.15)

        # Load audio as (samples, channels).
        audio, sr = sf.read(input_path)
        if audio.ndim == 1:
            audio = np.stack([audio, audio], axis=1)  # Mono to stereo
        elif audio.shape[1] > 2:
            # The model is stereo-only; keep the first two channels rather
            # than crashing on surround-sound input.
            audio = audio[:, :2]
        audio_tensor = torch.tensor(audio.T, dtype=torch.float32)  # (channels, samples)

        # Resample if needed
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            audio_tensor = resampler(audio_tensor)

        # Run inference
        separated = self._process_audio(audio_tensor, progress_callback)
        progress_callback("finalizing", 0.92)

        # Save requested stems
        result: dict[str, str] = {}
        for i, stem_key in enumerate(STEM_ORDER):
            canonical = STEM_NAME_MAP[stem_key]
            if canonical not in stems:
                continue
            stem_audio = separated[i].numpy().T  # (samples, channels)
            # Clamp to [-1, 1] so the int16 conversion on lossy export can't wrap.
            stem_audio = np.clip(stem_audio, -1.0, 1.0)
            clean_name = f"{canonical}.{output_format}"
            out_path = os.path.join(output_dir, clean_name)
            self._write_output(out_path, stem_audio, output_format)
            result[canonical] = clean_name

        progress_callback("done", 1.0)
        return result

    def _write_output(
        self, output_path: str, stem_audio: np.ndarray, output_format: str
    ) -> None:
        """Write one stem: float32 WAV directly, or 16-bit PCM through pydub.

        ``stem_audio`` is ``(samples, channels)`` float in [-1, 1].
        """
        if output_format == "wav":
            sf.write(output_path, stem_audio, self.sample_rate, subtype="FLOAT")
            return

        # Lossy paths go through pydub, which expects interleaved 16-bit PCM.
        pcm = (stem_audio * 32767.0).astype(np.int16)
        segment = AudioSegment(
            data=pcm.tobytes(),
            sample_width=2,
            frame_rate=self.sample_rate,
            channels=pcm.shape[1] if pcm.ndim > 1 else 1,
        )
        # ffmpeg writes raw AAC in an ADTS container; mp3 maps directly.
        export_format = "mp3" if output_format == "mp3" else "adts"
        export_kwargs = {"format": export_format}
        if output_format in {"mp3", "aac"}:
            export_kwargs["bitrate"] = "192k"
        segment.export(output_path, **export_kwargs)