Spaces:
Running
Running
| import logging | |
| import os | |
| from typing import Callable | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
| import torchaudio | |
| import yaml | |
| from pydub import AudioSegment | |
| from huggingface_hub import hf_hub_download | |
| from audio_separator.separator.uvr_lib_v5.roformer.bs_roformer import BSRoformer | |
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)

# Single process-wide device choice; the model and all tensors target this.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {DEVICE}")

if DEVICE.type == "cpu":
    # CPU-only host: give torch every core for intra-op math and half the
    # cores (at least one) for inter-op parallelism.
    cpu_count = os.cpu_count() or 1
    torch.set_num_threads(cpu_count)
    torch.set_num_interop_threads(max(1, cpu_count // 2))
    logger.info(f"CPU mode: set torch threads={cpu_count}, interop={max(1, cpu_count // 2)}")

# Hugging Face Hub coordinates for the BS-RoFormer checkpoint + its YAML config,
# cached locally under MODEL_DIR.
MODEL_REPO = "jarredou/BS-ROFO-SW-Fixed"
MODEL_FILENAME = "BS-Rofo-SW-Fixed.ckpt"
MODEL_CONFIG = "BS-Rofo-SW-Fixed.yaml"
MODEL_DIR = "/tmp/models"

# Stem order matches the model's training config
STEM_ORDER = ["bass", "drums", "other", "vocals", "guitar", "piano"]

# Lowercase model stem key -> canonical display name used for output filenames
# and the keys of the dict returned by StemSeparatorService.separate().
STEM_NAME_MAP = {
    "bass": "Bass",
    "drums": "Drums",
    "other": "Other",
    "vocals": "Vocals",
    "guitar": "Guitar",
    "piano": "Piano",
}
class StemSeparatorService:
    """Process-wide singleton that separates audio into instrument stems.

    The model is heavyweight, so a single shared instance (and a single
    model load) is enforced via ``__new__``.
    """

    # Cached singleton instance and one-time model-load flag, shared by all callers.
    _instance = None
    _model_loaded = False

    def __new__(cls):
        # Guard clause: hand back the cached instance when one already exists.
        existing = cls._instance
        if existing is not None:
            return existing
        created = super().__new__(cls)
        cls._instance = created
        return created
| def load_model(self): | |
| if self._model_loaded: | |
| return | |
| os.makedirs(MODEL_DIR, exist_ok=True) | |
| logger.info(f"Downloading model from HF Hub: {MODEL_REPO}") | |
| hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME, local_dir=MODEL_DIR) | |
| hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_CONFIG, local_dir=MODEL_DIR) | |
| # Parse config | |
| with open(os.path.join(MODEL_DIR, MODEL_CONFIG)) as f: | |
| config = yaml.load(f, Loader=yaml.FullLoader) | |
| model_cfg = config["model"] | |
| audio_cfg = config.get("audio", {}) | |
| inference_cfg = config.get("inference", {}) | |
| self.sample_rate = audio_cfg.get("sample_rate", 44100) | |
| self.chunk_size = audio_cfg.get("chunk_size", 588800) | |
| self.num_overlap = inference_cfg.get("num_overlap", 2) | |
| # Use flash_attn only on CUDA where it's supported | |
| use_flash = DEVICE.type == "cuda" | |
| # Create model directly — bypass audio-separator's Separator wrapper entirely | |
| self.model = BSRoformer( | |
| dim=model_cfg["dim"], | |
| depth=model_cfg["depth"], | |
| stereo=model_cfg.get("stereo", True), | |
| num_stems=model_cfg.get("num_stems", 6), | |
| time_transformer_depth=model_cfg.get("time_transformer_depth", 1), | |
| freq_transformer_depth=model_cfg.get("freq_transformer_depth", 1), | |
| linear_transformer_depth=model_cfg.get("linear_transformer_depth", 0), | |
| freqs_per_bands=tuple(model_cfg["freqs_per_bands"]), | |
| dim_head=model_cfg.get("dim_head", 64), | |
| heads=model_cfg.get("heads", 8), | |
| attn_dropout=model_cfg.get("attn_dropout", 0.1), | |
| ff_dropout=model_cfg.get("ff_dropout", 0.1), | |
| flash_attn=use_flash, | |
| dim_freqs_in=model_cfg.get("dim_freqs_in", 1025), | |
| stft_n_fft=model_cfg.get("stft_n_fft", 2048), | |
| stft_hop_length=model_cfg.get("stft_hop_length", 512), | |
| stft_win_length=model_cfg.get("stft_win_length", 2048), | |
| stft_normalized=model_cfg.get("stft_normalized", False), | |
| mask_estimator_depth=model_cfg.get("mask_estimator_depth", 2), | |
| multi_stft_resolution_loss_weight=model_cfg.get("multi_stft_resolution_loss_weight", 1.0), | |
| multi_stft_resolutions_window_sizes=tuple( | |
| model_cfg.get("multi_stft_resolutions_window_sizes", (4096, 2048, 1024, 512, 256)) | |
| ), | |
| multi_stft_hop_size=model_cfg.get("multi_stft_hop_size", 147), | |
| multi_stft_normalized=model_cfg.get("multi_stft_normalized", False), | |
| mlp_expansion_factor=model_cfg.get("mlp_expansion_factor", 4), | |
| use_torch_checkpoint=model_cfg.get("use_torch_checkpoint", False), | |
| skip_connection=model_cfg.get("skip_connection", False), | |
| ) | |
| # Load checkpoint weights | |
| ckpt_path = os.path.join(MODEL_DIR, MODEL_FILENAME) | |
| try: | |
| state_dict = torch.load(ckpt_path, map_location=DEVICE, weights_only=True) | |
| except TypeError: | |
| state_dict = torch.load(ckpt_path, map_location=DEVICE) | |
| if isinstance(state_dict, dict) and "state_dict" in state_dict: | |
| state_dict = state_dict["state_dict"] | |
| elif isinstance(state_dict, dict) and "model" in state_dict: | |
| state_dict = state_dict["model"] | |
| self.model.load_state_dict(state_dict) | |
| self.model.to(DEVICE) | |
| self.model.eval() | |
| logger.info(f"BS-RoFormer model loaded successfully on {DEVICE}") | |
| self._model_loaded = True | |
| def _process_audio(self, audio_tensor: torch.Tensor, progress_callback) -> torch.Tensor: | |
| """Run inference with chunking and overlap-add.""" | |
| chunk_size = self.chunk_size | |
| step = chunk_size // self.num_overlap | |
| channels, total_samples = audio_tensor.shape | |
| num_stems = 6 | |
| # Pad so we cover the full audio | |
| pad_needed = max(0, chunk_size - total_samples) | |
| if total_samples > chunk_size: | |
| remainder = (total_samples - chunk_size) % step | |
| if remainder != 0: | |
| pad_needed = step - remainder | |
| if pad_needed > 0: | |
| audio_tensor = torch.nn.functional.pad(audio_tensor, (0, pad_needed)) | |
| padded_len = audio_tensor.shape[1] | |
| # Move input to device | |
| audio_tensor = audio_tensor.to(DEVICE) | |
| # Output accumulators (keep on CPU to save GPU memory) | |
| result = torch.zeros(num_stems, channels, padded_len) | |
| weight = torch.zeros(padded_len) | |
| # Hann window for smooth crossfading | |
| window = torch.hann_window(chunk_size, device=DEVICE) | |
| # Build chunk positions | |
| starts = list(range(0, padded_len - chunk_size + 1, step)) | |
| total_chunks = len(starts) | |
| for i, start in enumerate(starts): | |
| chunk = audio_tensor[:, start : start + chunk_size] | |
| with torch.no_grad(): | |
| # BSRoformer: (batch, channels, time) -> (batch, stems, channels, time) | |
| output = self.model(chunk.unsqueeze(0)) | |
| output = output.squeeze(0) # (stems, channels, time) | |
| # Move output to CPU for accumulation | |
| output_cpu = output.cpu() | |
| window_cpu = window.cpu() | |
| result[:, :, start : start + chunk_size] += output_cpu * window_cpu | |
| weight[start : start + chunk_size] += window_cpu | |
| frac = (i + 1) / total_chunks | |
| progress_callback("separating", 0.2 + frac * 0.7) | |
| # Normalize by overlap weight | |
| result = result / weight.clamp(min=1e-8).unsqueeze(0).unsqueeze(0) | |
| # Remove padding | |
| return result[:, :, :total_samples] | |
| def separate( | |
| self, | |
| input_path: str, | |
| output_dir: str, | |
| stems: list[str], | |
| output_format: str, | |
| progress_callback: Callable[[str, float], None], | |
| ) -> dict[str, str]: | |
| progress_callback("loading_model", 0.05) | |
| self.load_model() | |
| progress_callback("separating", 0.15) | |
| # Load audio | |
| audio, sr = sf.read(input_path) | |
| if audio.ndim == 1: | |
| audio = np.stack([audio, audio], axis=1) # Mono to stereo | |
| audio_tensor = torch.tensor(audio.T, dtype=torch.float32) # (channels, samples) | |
| # Resample if needed | |
| if sr != self.sample_rate: | |
| resampler = torchaudio.transforms.Resample(sr, self.sample_rate) | |
| audio_tensor = resampler(audio_tensor) | |
| # Run inference | |
| separated = self._process_audio(audio_tensor, progress_callback) | |
| progress_callback("finalizing", 0.92) | |
| # Save requested stems | |
| result: dict[str, str] = {} | |
| for i, stem_key in enumerate(STEM_ORDER): | |
| canonical = STEM_NAME_MAP[stem_key] | |
| if canonical in stems: | |
| stem_audio = separated[i].numpy().T # (samples, channels) | |
| # Clip to prevent clipping artifacts | |
| stem_audio = np.clip(stem_audio, -1.0, 1.0) | |
| clean_name = f"{canonical}.{output_format}" | |
| out_path = os.path.join(output_dir, clean_name) | |
| self._write_output(out_path, stem_audio, output_format) | |
| result[canonical] = clean_name | |
| progress_callback("done", 1.0) | |
| return result | |
| def _write_output(self, output_path: str, stem_audio: np.ndarray, output_format: str): | |
| if output_format == "wav": | |
| sf.write(output_path, stem_audio, self.sample_rate, subtype="FLOAT") | |
| return | |
| pcm = (stem_audio * 32767.0).astype(np.int16) | |
| segment = AudioSegment( | |
| data=pcm.tobytes(), | |
| sample_width=2, | |
| frame_rate=self.sample_rate, | |
| channels=pcm.shape[1] if pcm.ndim > 1 else 1, | |
| ) | |
| export_format = "mp3" if output_format == "mp3" else "adts" | |
| export_kwargs = {"format": export_format} | |
| if output_format in {"mp3", "aac"}: | |
| export_kwargs["bitrate"] = "192k" | |
| segment.export(output_path, **export_kwargs) | |