| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "models", "audiosep")) |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "models", "flowsep")) |
|
|
| import gradio as gr |
| import torch |
| import numpy as np |
| import torchaudio |
| import librosa |
| import yaml |
| from huggingface_hub import hf_hub_download |
| from pytorch_lightning import seed_everything |
|
|
| try: |
| import spaces |
| except ImportError: |
| spaces = None |
|
|
# Lazily-initialized singletons: populated on first use by load_audiosep() /
# load_flowsep() and reused across requests so checkpoints are only
# downloaded and loaded once per process.
_audiosep_model = None
_flowsep_model = None
_flowsep_preprocessor = None
|
|
|
|
def get_runtime_device():
    """Return the torch device inference should run on: CUDA when available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
|
|
|
|
class FlowSepPreprocessor:
    """Audio loading and log-mel/STFT feature extraction for FlowSep.

    All settings come from the ``preprocessing`` section of the model's
    YAML config (sampling rate, clip duration, STFT geometry, mel bank).
    """

    def __init__(self, config):
        import utilities.audio as Audio

        pre = config["preprocessing"]
        audio_cfg = pre["audio"]
        stft_cfg = pre["stft"]
        mel_cfg = pre["mel"]

        self.sampling_rate = audio_cfg["sampling_rate"]
        self.duration = audio_cfg["duration"]
        self.hopsize = stft_cfg["hop_length"]
        # Number of STFT frames covering `duration` seconds of audio.
        self.target_length = int(self.duration * self.sampling_rate / self.hopsize)

        self.STFT = Audio.stft.TacotronSTFT(
            stft_cfg["filter_length"],
            stft_cfg["hop_length"],
            stft_cfg["win_length"],
            mel_cfg["n_mel_channels"],
            audio_cfg["sampling_rate"],
            mel_cfg["mel_fmin"],
            mel_cfg["mel_fmax"],
        )

    def read_wav_file(self, filename):
        """Load, normalize, and fit a file to exactly `duration` seconds.

        Returns a float array of shape (1, sampling_rate * duration);
        only the first channel of multi-channel input is kept.
        """
        waveform, sr = torchaudio.load(filename)
        # Trim to `duration` seconds at the file's native rate before resampling.
        native_limit = int(sr * self.duration)
        if waveform.shape[-1] > native_limit:
            waveform = waveform[:, :native_limit]
        if sr != self.sampling_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sampling_rate)
        mono = waveform.numpy()[0, ...]
        # Remove DC offset, then peak-normalize the centered signal to +/-0.5.
        mono = mono - np.mean(mono)
        mono = mono / (np.max(np.abs(mono)) + 1e-8) * 0.5
        mono = mono[None, ...]
        wanted = int(self.sampling_rate * self.duration)
        if mono.shape[-1] < wanted:
            padded = np.zeros((1, wanted), dtype=np.float32)
            padded[:, :mono.shape[-1]] = mono
            mono = padded
        return mono

    def wav_feature_extraction(self, waveform):
        """Compute (log-mel, stft-magnitude) features for a (1, T) waveform.

        Both outputs are (frames, bins) float tensors padded/trimmed to
        `target_length` frames via _pad_spec.
        """
        import utilities.audio as Audio

        signal = torch.FloatTensor(waveform[0, ...])
        log_mel_spec, stft, energy = Audio.tools.get_mel_from_wav(signal, self.STFT)
        mel = self._pad_spec(torch.FloatTensor(log_mel_spec.T))
        mag = self._pad_spec(torch.FloatTensor(stft.T))
        return mel, mag

    def _pad_spec(self, log_mel_spec):
        """Zero-pad or trim a (frames, bins) spectrogram to `target_length` frames.

        Also drops the last frequency bin when the bin count is odd so the
        width stays even.
        """
        deficit = self.target_length - log_mel_spec.shape[0]
        if deficit > 0:
            log_mel_spec = torch.nn.ZeroPad2d((0, 0, 0, deficit))(log_mel_spec)
        elif deficit < 0:
            log_mel_spec = log_mel_spec[:self.target_length, :]
        if log_mel_spec.size(-1) % 2 != 0:
            log_mel_spec = log_mel_spec[..., :-1]
        return log_mel_spec

    def load_full_audio(self, filename):
        """Load the entire file at `sampling_rate`; first channel only, no trimming."""
        waveform, sr = torchaudio.load(filename)
        if sr != self.sampling_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sampling_rate)
        return waveform.numpy()[0, ...]

    def preprocess_chunk(self, chunk):
        """Center a raw chunk and peak-normalize it to the +/-0.5 range."""
        chunk = chunk - np.mean(chunk)
        return chunk / (np.max(np.abs(chunk)) + 1e-8) * 0.5
|
|
|
|
def load_audiosep():
    """Build (or reuse) the AudioSep model and return it on the current device.

    The model and its CLAP text encoder are downloaded from the HF Hub on
    first call and cached in the module-level `_audiosep_model` singleton;
    later calls only move the cached model to the active device.
    """
    global _audiosep_model
    device = get_runtime_device()

    if _audiosep_model is None:
        from models.clap_encoder import CLAP_Encoder
        from utils import parse_yaml, load_ss_model

        # CLAP encoder that embeds the text query for conditioning.
        clap_ckpt = hf_hub_download(repo_id="ShandaAI/AudioSep-hive", filename="music_speech_audioset_epoch_15_esc_89.98.pt")
        query_encoder = CLAP_Encoder(pretrained_path=clap_ckpt).eval()

        config_file = hf_hub_download(repo_id="ShandaAI/AudioSep-hive", filename="config.yaml")
        checkpoint_file = hf_hub_download(repo_id="ShandaAI/AudioSep-hive", filename="audiosep_hive.ckpt")
        _audiosep_model = load_ss_model(
            configs=parse_yaml(config_file),
            checkpoint_path=checkpoint_file,
            query_encoder=query_encoder,
        )

    _audiosep_model = _audiosep_model.to(device).eval()
    return _audiosep_model
|
|
|
|
def load_flowsep():
    """Build (or reuse) the FlowSep model and its preprocessor.

    Returns:
        (model, preprocessor) tuple. Both are cached in module-level
        singletons on first call; later calls only move the cached model
        to the active device.
    """
    global _flowsep_model, _flowsep_preprocessor
    device = get_runtime_device()
    if _flowsep_model is not None:
        _flowsep_model = _flowsep_model.to(device).eval()
        return _flowsep_model, _flowsep_preprocessor

    seed_everything(0)
    from latent_diffusion.util import instantiate_from_config

    config_file = hf_hub_download(repo_id="ShandaAI/FlowSep-hive", filename="config.yaml")
    model_file = hf_hub_download(repo_id="ShandaAI/FlowSep-hive", filename="flowsep_hive.ckpt")

    # Fix: close the config file deterministically (the original passed a bare
    # open() to yaml.load and leaked the handle).
    with open(config_file, "r") as f:
        configs = yaml.load(f, Loader=yaml.FullLoader)
    # Don't reload first-stage weights from a separate checkpoint; they are
    # included in the full checkpoint loaded below.
    configs["model"]["params"]["first_stage_config"]["params"]["reload_from_ckpt"] = None

    preprocessor = FlowSepPreprocessor(configs)

    model = instantiate_from_config(configs["model"]).to(device)
    # NOTE(security): weights_only=False unpickles arbitrary objects; acceptable
    # only because this checkpoint comes from a trusted HF repo.
    try:
        ckpt = torch.load(model_file, map_location=device, weights_only=False)["state_dict"]
    except TypeError:
        # Older torch versions do not accept the weights_only argument.
        ckpt = torch.load(model_file, map_location=device)["state_dict"]
    model.load_state_dict(ckpt, strict=True)
    model.eval()

    _flowsep_model = model
    _flowsep_preprocessor = preprocessor
    return model, preprocessor
|
|
|
|
# Model-native sample rates and FlowSep chunking geometry.
AUDIOSEP_SR = 32000  # AudioSep operates on 32 kHz audio
FLOWSEP_CHUNK_IN = 163840  # samples fed per FlowSep chunk (10.24 s at 16 kHz; presumably 1024 frames x 160 hop — TODO confirm)
FLOWSEP_CHUNK_OUT = 160000  # samples kept from each chunk's output (10 s at 16 kHz)
FLOWSEP_SR = 16000  # FlowSep operates on 16 kHz audio
|
|
|
|
def separate_audiosep(audio_path, text):
    """Separate the sound described by `text` from `audio_path` with AudioSep.

    Args:
        audio_path: path to the mixture audio file.
        text: natural-language description of the target sound.

    Returns:
        (sample_rate, waveform) tuple suitable for a gradio Audio output.
    """
    device = get_runtime_device()
    model = load_audiosep()
    # AudioSep works on 32 kHz mono input.
    mixture, _ = librosa.load(audio_path, sr=AUDIOSEP_SR, mono=True)
    input_len = mixture.shape[0]

    with torch.no_grad():
        conditions = model.query_encoder.get_query_embed(
            modality='text', text=[text], device=device
        )
        input_dict = {
            "mixture": torch.Tensor(mixture)[None, None, :].to(device),
            "condition": conditions,
        }
        if input_len > AUDIOSEP_SR * 10:
            # Inputs longer than 10 s go through windowed inference to bound memory.
            sep_audio = model.ss_model.chunk_inference(input_dict)
            sep_audio = sep_audio.squeeze()
        else:
            sep_segment = model.ss_model(input_dict)["waveform"]
            sep_audio = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
        # Fix: trim model padding on BOTH paths — the original trimmed only the
        # short-input branch, so chunked output could exceed the input length.
        sep_audio = sep_audio[:input_len]

    return (AUDIOSEP_SR, sep_audio)
|
|
|
|
def _flowsep_process_chunk(model, preprocessor, chunk_wav, text):
    """Run FlowSep on one fixed-size chunk of audio.

    Args:
        model: loaded FlowSep model (see load_flowsep).
        preprocessor: FlowSepPreprocessor matching the model's config.
        chunk_wav: 1-D float numpy array; padded/trimmed to FLOWSEP_CHUNK_IN.
        text: text query describing the target sound.

    Returns:
        1-D numpy array of at most FLOWSEP_CHUNK_OUT separated samples.
    """
    device = get_runtime_device()
    # Center and peak-normalize, then zero-pad/trim to the model's fixed input size.
    chunk_wav = preprocessor.preprocess_chunk(chunk_wav)
    if len(chunk_wav) < FLOWSEP_CHUNK_IN:
        pad = np.zeros(FLOWSEP_CHUNK_IN - len(chunk_wav), dtype=np.float32)
        chunk_wav = np.concatenate([chunk_wav, pad])
    chunk_wav = chunk_wav[:FLOWSEP_CHUNK_IN]
    mixed_mel, stft = preprocessor.wav_feature_extraction(chunk_wav.reshape(1, -1))
    # Only "mixed_waveform"/"mixed_mel" and the text fields carry real data;
    # the torch.rand entries look like shape placeholders the model's batch
    # format expects but presumably never reads at inference — TODO confirm
    # against the FlowSep generate_sample implementation.
    batch = {
        "fname": ["temp"],
        "text": [text],
        "caption": [text],
        "waveform": torch.rand(1, 1, FLOWSEP_CHUNK_IN).to(device),
        "log_mel_spec": torch.rand(1, 1024, 64).to(device),
        "sampling_rate": torch.tensor([FLOWSEP_SR]).to(device),
        "label_vector": torch.rand(1, 527).to(device),
        "stft": torch.rand(1, 1024, 512).to(device),
        "mixed_waveform": torch.from_numpy(chunk_wav.reshape(1, 1, FLOWSEP_CHUNK_IN)).to(device),
        "mixed_mel": mixed_mel.reshape(1, mixed_mel.shape[0], mixed_mel.shape[1]).to(device),
    }
    result = model.generate_sample(
        [batch],
        name="temp_result",
        unconditional_guidance_scale=1.0,
        ddim_steps=20,
        n_gen=1,
        save=False,
        save_mixed=False,
    )
    # Normalize the output to numpy — generate_sample apparently may return
    # either an ndarray or a tensor; verify against the model implementation.
    if isinstance(result, np.ndarray):
        out = result.squeeze()
    else:
        out = result.squeeze().cpu().numpy()
    # Drop anything past the usable output window of the chunk.
    return out[:FLOWSEP_CHUNK_OUT]
|
|
|
|
def separate_flowsep(audio_path, text):
    """Separate the sound described by `text` from `audio_path` with FlowSep.

    Long inputs are processed in overlapping windows: each window reads up to
    FLOWSEP_CHUNK_IN samples but contributes only FLOWSEP_CHUNK_OUT samples,
    so consecutive windows advance by FLOWSEP_CHUNK_OUT.

    Returns:
        (sample_rate, waveform) tuple suitable for a gradio Audio output.
    """
    device = get_runtime_device()
    model, preprocessor = load_flowsep()
    full_wav = preprocessor.load_full_audio(audio_path)
    input_len = full_wav.shape[0]

    with torch.no_grad():
        if input_len <= FLOWSEP_CHUNK_IN:
            sep_audio = _flowsep_process_chunk(model, preprocessor, full_wav.copy(), text)
        else:
            pieces = []
            for start in range(0, input_len, FLOWSEP_CHUNK_OUT):
                # Slicing clamps at the end of the array automatically.
                window = full_wav[start:start + FLOWSEP_CHUNK_IN]
                separated = _flowsep_process_chunk(model, preprocessor, window.copy(), text)
                keep = min(FLOWSEP_CHUNK_OUT, input_len - start)
                pieces.append(separated[:keep])
            sep_audio = np.concatenate(pieces)

    # Force the output to exactly the input length.
    if len(sep_audio) > input_len:
        sep_audio = sep_audio[:input_len]
    elif len(sep_audio) < input_len:
        sep_audio = np.pad(sep_audio, (0, input_len - len(sep_audio)), mode="constant", constant_values=0)

    return (FLOWSEP_SR, sep_audio)
|
|
|
|
def inference(audio, text, model_choice):
    """Validate the request, then run the selected separation model.

    Raises:
        gr.Error: when the audio file or the text query is missing.
    """
    if audio is None:
        raise gr.Error("Please upload an audio file / 请上传音频文件")
    if not text or not text.strip():
        raise gr.Error("Please enter a text query / 请输入文本描述")

    separate = separate_audiosep if model_choice == "AudioSep-hive" else separate_flowsep
    return separate(audio, text)
|
|
|
|
def inference_entry(audio, text, model_choice):
    """Gradio entry point; GPU-decorated below when running on HF Spaces."""
    return inference(audio, text, model_choice)


# Fix: the original duplicated the whole function body in both branches of an
# if/else; define it once and conditionally apply the ZeroGPU decorator
# (requesting up to 120 s of GPU time per call) when `spaces` is available.
if spaces is not None:
    inference_entry = spaces.GPU(duration=120)(inference_entry)
|
|
|
|
# Markdown rendered at the top of the Gradio UI (user-facing text, including
# the Chinese-free English copy and external links — do not localize here).
DESCRIPTION = """
# Universal Sound Separation on HIVE

**Hive** is a high-quality synthetic dataset (2k hours) built via an automated pipeline that mines high-purity single-event segments and synthesizes semantically consistent mixtures. Despite using only ~0.2% of the data scale of million-hour baselines, models trained on Hive achieve competitive separation accuracy and strong zero-shot generalization.

This space provides two separation models trained on Hive:
- **AudioSep**: A foundation model for open-domain sound separation with natural language queries, based on [AudioSep](https://github.com/Audio-AGI/AudioSep).
- **FlowSep**: A flow-matching based separation model with text conditioning, based on [FlowSep](https://github.com/Audio-AGI/FlowSep).

**How to use:**
1. Upload an audio file (mix of sounds)
2. Describe what you want to separate (e.g., "piano", "speech", "dog barking")
3. Select a model and click Separate

[[Paper]](https://arxiv.org/abs/2601.22599) | [[Code]](https://github.com/ShandaAI/Hive) | [[Hive Dataset]](https://huggingface.co/datasets/ShandaAI/Hive) | [[Demo Page]](https://shandaai.github.io/Hive/)
"""


# (audio path, text query) pairs fed to gr.Examples; paths are relative to
# the app's working directory.
EXAMPLES = [
    ["examples/acoustic_guitar.wav", "acoustic guitar"],
    ["examples/laughing.wav", "laughing"],
    ["examples/ticktok_piano.wav", "A ticktock sound playing at the same rhythm with piano"],
    ["examples/water_drops.wav", "water drops"],
    ["examples/noisy_speech.wav", "speech"],
]
|
|
# Build the Gradio UI. Component creation order inside the context managers
# determines the on-screen layout, so statements here must not be reordered.
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="Universal Sound Separation on HIVE",
) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: inputs and the submit button.
        with gr.Column():
            audio_input = gr.Audio(label="Input Mixture Audio", type="filepath")
            text_input = gr.Textbox(
                label="Text Query",
                placeholder='e.g. "dog barking", "piano playing"',
            )
            model_choice = gr.Dropdown(
                choices=["AudioSep-hive", "FlowSep-hive"],
                value="AudioSep-hive",
                label="Select Model",
            )
            submit_btn = gr.Button("Separate", variant="primary")

        # Right column: separated-audio playback.
        with gr.Column():
            audio_output = gr.Audio(label="Separated Audio")

    # Wire the button to the (possibly GPU-decorated) inference entry point.
    submit_btn.click(
        fn=inference_entry,
        inputs=[audio_input, text_input, model_choice],
        outputs=audio_output,
    )

    gr.Markdown("## Examples")
    gr.Examples(examples=EXAMPLES, inputs=[audio_input, text_input])
|
|
| DEBUG = False |
|
|
def run_debug():
    """Smoke-test both separation models on the bundled example clip.

    Prints progress banners and output shapes; skips silently (with a
    message) when the example audio is missing.
    """
    examples_dir = os.path.join(os.path.dirname(__file__), "examples")
    test_path = os.path.join(examples_dir, "acoustic_guitar.wav")
    test_text = "acoustic guitar"
    heavy_rule = "=" * 50
    light_rule = "-" * 40

    print("\n" + heavy_rule)
    print("[DEBUG] Starting inference test for both models")
    print(heavy_rule)

    if not os.path.exists(test_path):
        print(f"[DEBUG] Skip: {test_path} not found")
        return

    print(f"\n[DEBUG] Using test audio: {test_path}")

    print("\n" + light_rule)
    print("[DEBUG] AudioSep inference")
    print(light_rule)
    print("[DEBUG] Loading AudioSep model...")
    sr_a, wav_a = separate_audiosep(test_path, test_text)
    print(f"[DEBUG] AudioSep done. Output sr={sr_a}, shape={np.array(wav_a).shape}")

    print("\n" + light_rule)
    print("[DEBUG] FlowSep inference")
    print(light_rule)
    print("[DEBUG] Loading FlowSep model...")
    sr_f, wav_f = separate_flowsep(test_path, test_text)
    print(f"[DEBUG] FlowSep done. Output sr={sr_f}, shape={np.array(wav_f).shape}")

    print("\n" + heavy_rule)
    print("[DEBUG] Both models passed inference test")
    print(heavy_rule + "\n")
|
|
|
|
if __name__ == "__main__":
    # Optional local smoke test before serving (see DEBUG flag above).
    if DEBUG:
        run_debug()
    # Queue incoming requests so concurrent inference calls are serialized.
    demo.queue()
    demo.launch()
|
|