Spaces:
Sleeping
Sleeping
| import sys | |
| import os | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "models", "audiosep")) | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "models", "flowsep")) | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import torchaudio | |
| import librosa | |
| import yaml | |
| from huggingface_hub import hf_hub_download | |
| from pytorch_lightning import seed_everything | |
| try: | |
| import spaces | |
| except ImportError: | |
| spaces = None | |
# Process-wide caches, filled on first use by load_audiosep() / load_flowsep()
# so each model is only built once per process. All start out empty.
_audiosep_model = _flowsep_model = _flowsep_preprocessor = None
def get_runtime_device():
    """Return the torch device inference should run on.

    Returns:
        ``torch.device("cuda")`` when ``torch.cuda.is_available()`` is true,
        otherwise ``torch.device("cpu")``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
class FlowSepPreprocessor:
    """Audio loading and feature extraction helper for the FlowSep model.

    The constructor reads the ``preprocessing`` section of the model config
    (``audio``, ``stft`` and ``mel`` sub-sections) and builds the STFT/mel
    front-end from it.
    """

    def __init__(self, config):
        """Cache the audio parameters and build the STFT front-end.

        Args:
            config: Mapping with a ``"preprocessing"`` entry containing the
                ``"audio"``, ``"stft"`` and ``"mel"`` settings read below.
        """
        # Imported lazily; presumably provided by the model sources that the
        # top of the file adds to sys.path -- TODO confirm which directory.
        import utilities.audio as Audio

        audio_cfg = config["preprocessing"]["audio"]
        stft_cfg = config["preprocessing"]["stft"]
        mel_cfg = config["preprocessing"]["mel"]

        self.sampling_rate = audio_cfg["sampling_rate"]
        self.duration = audio_cfg["duration"]
        self.hopsize = stft_cfg["hop_length"]
        # Number of spectrogram frames every feature is padded/cropped to.
        self.target_length = int(self.duration * self.sampling_rate / self.hopsize)
        self.STFT = Audio.stft.TacotronSTFT(
            stft_cfg["filter_length"],
            stft_cfg["hop_length"],
            stft_cfg["win_length"],
            mel_cfg["n_mel_channels"],
            audio_cfg["sampling_rate"],
            mel_cfg["mel_fmin"],
            mel_cfg["mel_fmax"],
        )

    def read_wav_file(self, filename):
        """Load a file as a normalised clip of the configured duration.

        The audio is cropped to ``self.duration`` seconds, resampled to
        ``self.sampling_rate`` if needed, reduced to its first channel,
        normalised with :meth:`preprocess_chunk` and zero-padded up to
        ``int(self.sampling_rate * self.duration)`` samples when shorter.

        Args:
            filename: Path passed to ``torchaudio.load``.

        Returns:
            A numpy array with a leading axis of size 1.
        """
        waveform, sr = torchaudio.load(filename)
        # Crop at the source rate first so resampling only sees what is kept.
        max_source_samples = int(sr * self.duration)
        if waveform.shape[-1] > max_source_samples:
            waveform = waveform[:, :max_source_samples]
        if sr != self.sampling_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sampling_rate)
        # Only the first channel is used.
        waveform = self.preprocess_chunk(waveform.numpy()[0, ...])
        waveform = waveform[None, ...]
        target_samples = int(self.sampling_rate * self.duration)
        if waveform.shape[-1] < target_samples:
            padded = np.zeros((1, target_samples), dtype=np.float32)
            padded[:, :waveform.shape[-1]] = waveform
            waveform = padded
        return waveform

    def wav_feature_extraction(self, waveform):
        """Compute padded log-mel and STFT features for one waveform.

        Args:
            waveform: Array whose first row (``waveform[0, ...]``) holds the
                samples to analyse.

        Returns:
            ``(log_mel_spec, stft)`` as float tensors, each transposed and
            passed through :meth:`_pad_spec`.
        """
        import utilities.audio as Audio

        samples = torch.FloatTensor(waveform[0, ...])
        # The third return value (energy) is not needed here.
        log_mel_spec, stft, _energy = Audio.tools.get_mel_from_wav(samples, self.STFT)
        log_mel_spec = self._pad_spec(torch.FloatTensor(log_mel_spec.T))
        stft = self._pad_spec(torch.FloatTensor(stft.T))
        return log_mel_spec, stft

    def _pad_spec(self, log_mel_spec):
        """Force a spectrogram to ``self.target_length`` frames.

        Rows (axis 0) are zero-padded at the end or cropped to
        ``self.target_length``; if the last axis has odd size its final
        column is dropped.
        """
        missing_frames = self.target_length - log_mel_spec.shape[0]
        if missing_frames > 0:
            # Same padding layout as ZeroPad2d((0, 0, 0, p)): rows appended
            # at the bottom, last axis untouched.
            log_mel_spec = torch.nn.functional.pad(
                log_mel_spec, (0, 0, 0, missing_frames)
            )
        elif missing_frames < 0:
            log_mel_spec = log_mel_spec[:self.target_length, :]
        if log_mel_spec.size(-1) % 2 != 0:
            log_mel_spec = log_mel_spec[..., :-1]
        return log_mel_spec

    def load_full_audio(self, filename):
        """Load a whole file at ``self.sampling_rate`` without normalising.

        Args:
            filename: Path passed to ``torchaudio.load``.

        Returns:
            The first channel of the (optionally resampled) audio as a 1-D
            numpy array.
        """
        waveform, sr = torchaudio.load(filename)
        if sr != self.sampling_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sampling_rate)
        return waveform.numpy()[0, ...]

    def preprocess_chunk(self, chunk):
        """Remove the DC offset and peak-normalise ``chunk`` to +/-0.5.

        Args:
            chunk: Array of audio samples.

        Returns:
            The normalised samples. An empty input is returned unchanged,
            because ``np.max`` cannot reduce a zero-size array.
        """
        if np.size(chunk) == 0:
            return chunk
        chunk = chunk - np.mean(chunk)
        # The epsilon keeps an all-zero (silent) chunk from dividing by zero.
        chunk = chunk / (np.max(np.abs(chunk)) + 1e-8)
        return chunk * 0.5
def load_audiosep():
    """Return the AudioSep model on the current runtime device.

    On the first call the CLAP checkpoint, config and separation checkpoint
    are downloaded from the ``ShandaAI/AudioSep-hive`` hub repository and the
    model is assembled. The result is stored in ``_audiosep_model``; later
    calls reuse it and only move it to the current device again.

    Returns:
        The model, in eval mode, on ``get_runtime_device()``.
    """
    global _audiosep_model
    device = get_runtime_device()

    if _audiosep_model is None:
        # Project-local modules, imported lazily on first load.
        from models.clap_encoder import CLAP_Encoder
        from utils import parse_yaml, load_ss_model

        clap_ckpt = hf_hub_download(
            repo_id="ShandaAI/AudioSep-hive",
            filename="music_speech_audioset_epoch_15_esc_89.98.pt",
        )
        query_encoder = CLAP_Encoder(pretrained_path=clap_ckpt).eval()
        config_file = hf_hub_download(
            repo_id="ShandaAI/AudioSep-hive", filename="config.yaml"
        )
        checkpoint_file = hf_hub_download(
            repo_id="ShandaAI/AudioSep-hive", filename="audiosep_hive.ckpt"
        )
        configs = parse_yaml(config_file)
        model = load_ss_model(
            configs=configs,
            checkpoint_path=checkpoint_file,
            query_encoder=query_encoder,
        )
    else:
        model = _audiosep_model

    # The device is re-resolved on every call, so the cached model follows it.
    _audiosep_model = model.to(device).eval()
    return _audiosep_model
def load_flowsep():
    """Return the FlowSep model and its preprocessor, building them once.

    On the first call the config and checkpoint are downloaded from the
    ``ShandaAI/FlowSep-hive`` hub repository, the model is instantiated from
    the config and the checkpoint's ``"state_dict"`` is loaded strictly.
    Both objects are cached in module globals; later calls only move the
    cached model to the current runtime device.

    Returns:
        ``(model, preprocessor)`` with the model in eval mode.
    """
    global _flowsep_model, _flowsep_preprocessor
    device = get_runtime_device()

    if _flowsep_model is not None:
        _flowsep_model = _flowsep_model.to(device).eval()
        return _flowsep_model, _flowsep_preprocessor

    seed_everything(0)
    # Project-local module, imported lazily on first load.
    from latent_diffusion.util import instantiate_from_config

    config_file = hf_hub_download(repo_id="ShandaAI/FlowSep-hive", filename="config.yaml")
    model_file = hf_hub_download(repo_id="ShandaAI/FlowSep-hive", filename="flowsep_hive.ckpt")

    # Context manager so the config file handle is closed deterministically.
    # NOTE(review): FullLoader is kept because the config comes from the
    # project's own hub repository; yaml.safe_load would be the safer choice
    # if the config only uses plain YAML types -- confirm before switching.
    with open(config_file, "r", encoding="utf-8") as config_stream:
        configs = yaml.load(config_stream, Loader=yaml.FullLoader)

    # Presumably stops the first-stage model from reloading its own
    # checkpoint, since the full state dict is loaded below -- TODO confirm.
    configs["model"]["params"]["first_stage_config"]["params"]["reload_from_ckpt"] = None

    preprocessor = FlowSepPreprocessor(configs)
    model = instantiate_from_config(configs["model"]).to(device)

    # Only the torch.load call sits in the try block: the fallback is meant
    # for torch versions that reject the ``weights_only`` keyword, not for a
    # malformed checkpoint.
    # NOTE(review): weights_only=False unpickles arbitrary objects; this is
    # acceptable only while the checkpoint source is trusted.
    try:
        checkpoint = torch.load(model_file, map_location=device, weights_only=False)
    except TypeError:
        checkpoint = torch.load(model_file, map_location=device)

    model.load_state_dict(checkpoint["state_dict"], strict=True)
    model.eval()

    _flowsep_model = model
    _flowsep_preprocessor = preprocessor
    return model, preprocessor
# Sample rate (Hz) used to load audio for AudioSep and reported with its output.
AUDIOSEP_SR = 32_000
# Number of samples FlowSep receives per chunk (shorter chunks are zero-padded).
FLOWSEP_CHUNK_IN = 163_840
# Number of samples kept from each FlowSep chunk output; also the chunk stride.
FLOWSEP_CHUNK_OUT = 160_000
# Sample rate (Hz) reported with FlowSep output and passed in its batch.
FLOWSEP_SR = 16_000
def separate_audiosep(audio_path, text):
    """Separate the sound described by ``text`` from a file using AudioSep.

    Args:
        audio_path: Path of the mixture; loaded as mono at ``AUDIOSEP_SR``.
        text: Natural-language query used to condition the separation.

    Returns:
        ``(AUDIOSEP_SR, samples)`` with ``samples`` cut to the number of
        samples in the loaded mixture.
    """
    device = get_runtime_device()
    model = load_audiosep()

    mixture, _ = librosa.load(audio_path, sr=AUDIOSEP_SR, mono=True)
    n_samples = mixture.shape[0]
    # Mixtures longer than ten seconds go through the model's chunked path.
    needs_chunking = n_samples > AUDIOSEP_SR * 10

    with torch.no_grad():
        query_embedding = model.query_encoder.get_query_embed(
            modality='text', text=[text], device=device
        )
        model_input = {
            "mixture": torch.Tensor(mixture)[None, None, :].to(device),
            "condition": query_embedding,
        }
        if needs_chunking:
            separated = model.ss_model.chunk_inference(model_input).squeeze()
        else:
            waveform = model.ss_model(model_input)["waveform"]
            separated = waveform.squeeze(0).squeeze(0).data.cpu().numpy()

    return (AUDIOSEP_SR, separated[:n_samples])
def _flowsep_process_chunk(model, preprocessor, chunk_wav, text):
    """Run FlowSep on one chunk of audio.

    Args:
        model: FlowSep model exposing ``generate_sample``.
        preprocessor: Object providing ``preprocess_chunk`` and
            ``wav_feature_extraction``.
        chunk_wav: 1-D array of samples; normalised, then zero-padded or cut
            to exactly ``FLOWSEP_CHUNK_IN`` samples.
        text: Query describing the sound to extract.

    Returns:
        A numpy array holding at most ``FLOWSEP_CHUNK_OUT`` samples.
    """
    device = get_runtime_device()

    chunk_wav = preprocessor.preprocess_chunk(chunk_wav)
    missing = FLOWSEP_CHUNK_IN - len(chunk_wav)
    if missing > 0:
        silence = np.zeros(missing, dtype=np.float32)
        chunk_wav = np.concatenate([chunk_wav, silence])
    chunk_wav = chunk_wav[:FLOWSEP_CHUNK_IN]

    # Only the mel features are used; the STFT output is discarded.
    mixed_mel, _ = preprocessor.wav_feature_extraction(chunk_wav.reshape(1, -1))
    n_frames, n_mels = mixed_mel.shape[0], mixed_mel.shape[1]

    # The torch.rand entries are random placeholders. Their order is kept so
    # the random number generator is consumed in the same sequence.
    batch = {
        "fname": ["temp"],
        "text": [text],
        "caption": [text],
        "waveform": torch.rand(1, 1, FLOWSEP_CHUNK_IN).to(device),
        "log_mel_spec": torch.rand(1, 1024, 64).to(device),
        "sampling_rate": torch.tensor([FLOWSEP_SR]).to(device),
        "label_vector": torch.rand(1, 527).to(device),
        "stft": torch.rand(1, 1024, 512).to(device),
        "mixed_waveform": torch.from_numpy(
            chunk_wav.reshape(1, 1, FLOWSEP_CHUNK_IN)
        ).to(device),
        "mixed_mel": mixed_mel.reshape(1, n_frames, n_mels).to(device),
    }

    result = model.generate_sample(
        [batch],
        name="temp_result",
        unconditional_guidance_scale=1.0,
        ddim_steps=20,
        n_gen=1,
        save=False,
        save_mixed=False,
    )

    # generate_sample may hand back a numpy array or a tensor.
    separated = (
        result.squeeze()
        if isinstance(result, np.ndarray)
        else result.squeeze().cpu().numpy()
    )
    return separated[:FLOWSEP_CHUNK_OUT]
def separate_flowsep(audio_path, text):
    """Separate the sound described by ``text`` from a file using FlowSep.

    The audio is processed in windows of ``FLOWSEP_CHUNK_IN`` samples that
    advance by ``FLOWSEP_CHUNK_OUT`` samples, and up to ``FLOWSEP_CHUNK_OUT``
    samples are kept from each window's output. A single loop handles every
    length. Inputs between ``FLOWSEP_CHUNK_OUT`` and ``FLOWSEP_CHUNK_IN``
    samples therefore get their tail processed as a second window instead of
    being replaced by zero padding.

    Args:
        audio_path: Path of the mixture file.
        text: Query describing the sound to extract.

    Returns:
        ``(FLOWSEP_SR, samples)`` where ``samples`` has exactly as many
        samples as the loaded input.

    Raises:
        ValueError: If the loaded audio contains no samples.
    """
    model, preprocessor = load_flowsep()
    full_wav = preprocessor.load_full_audio(audio_path)
    input_len = full_wav.shape[0]
    if input_len == 0:
        # Fail with a clear message instead of a numpy reduction error
        # raised deep inside chunk normalisation.
        raise ValueError("Input audio is empty")

    pieces = []
    with torch.no_grad():
        for start in range(0, input_len, FLOWSEP_CHUNK_OUT):
            # Slicing clamps at the end of the array, so the last window may
            # be short; the chunk helper pads it.
            chunk = full_wav[start:start + FLOWSEP_CHUNK_IN]
            out_chunk = _flowsep_process_chunk(model, preprocessor, chunk.copy(), text)
            keep = min(FLOWSEP_CHUNK_OUT, input_len - start)
            pieces.append(out_chunk[:keep])
    sep_audio = np.concatenate(pieces)

    # Guarantee the output length even if the model returned fewer or more
    # samples than expected for some window.
    if len(sep_audio) > input_len:
        sep_audio = sep_audio[:input_len]
    elif len(sep_audio) < input_len:
        sep_audio = np.pad(
            sep_audio,
            (0, input_len - len(sep_audio)),
            mode="constant",
            constant_values=0,
        )
    return (FLOWSEP_SR, sep_audio)
def inference(audio, text, model_choice):
    """Validate the UI inputs and run the selected separation model.

    Args:
        audio: Path of the uploaded audio, or ``None`` when nothing was
            uploaded.
        text: Text query; must contain non-whitespace characters.
        model_choice: ``"AudioSep-hive"`` selects AudioSep; any other value
            selects FlowSep.

    Returns:
        The ``(sample_rate, samples)`` tuple produced by the chosen model.

    Raises:
        gr.Error: If the audio is missing or the query is blank.
    """
    if audio is None:
        raise gr.Error("Please upload an audio file / 请上传音频文件")
    if not (text and text.strip()):
        raise gr.Error("Please enter a text query / 请输入文本描述")

    separate = separate_audiosep if model_choice == "AudioSep-hive" else separate_flowsep
    return separate(audio, text)
def inference_entry(audio, text, model_choice):
    """Gradio click handler: forward the UI inputs to ``inference``."""
    return inference(audio, text, model_choice)


# The original conditional defined the same plain function in both branches,
# so the optional ``spaces`` import had no effect. When the package is
# available, wrap the handler with its GPU decorator so that a GPU is
# attached while the handler runs on ZeroGPU hardware.
# NOTE(review): relies on ``spaces.GPU`` being effect-free outside ZeroGPU
# environments, as described in the Hugging Face ZeroGPU docs -- confirm for
# the pinned ``spaces`` version.
if spaces is not None:
    inference_entry = spaces.GPU(inference_entry)
# Markdown rendered at the top of the page. Blank lines separate the heading,
# paragraphs and lists; without them Markdown folds "**How to use:**" and the
# links line into the preceding list items as continuation text.
DESCRIPTION = """
# Universal Sound Separation on HIVE

**Hive** is a high-quality synthetic dataset (2k hours) built via an automated pipeline that mines high-purity single-event segments and synthesizes semantically consistent mixtures. Despite using only ~0.2% of the data scale of million-hour baselines, models trained on Hive achieve competitive separation accuracy and strong zero-shot generalization.

This space provides two separation models trained on Hive:

- **AudioSep**: A foundation model for open-domain sound separation with natural language queries, based on [AudioSep](https://github.com/Audio-AGI/AudioSep).
- **FlowSep**: A flow-matching based separation model with text conditioning, based on [FlowSep](https://github.com/Audio-AGI/FlowSep).

**How to use:**

1. Upload an audio file (mix of sounds)
2. Describe what you want to separate (e.g., "piano", "speech", "dog barking")
3. Select a model and click Separate

[[Paper]](https://arxiv.org/abs/2601.22599) | [[Code]](https://github.com/ShandaAI/Hive) | [[Hive Dataset]](https://huggingface.co/datasets/ShandaAI/Hive) | [[Demo Page]](https://shandaai.github.io/Hive/)
"""
# Rows offered by the examples widget: [audio file path, text query].
EXAMPLES = [
    [audio_file, query]
    for audio_file, query in (
        ("examples/acoustic_guitar.wav", "acoustic guitar"),
        ("examples/laughing.wav", "laughing"),
        ("examples/ticktok_piano.wav", "A ticktock sound playing at the same rhythm with piano"),
        ("examples/water_drops.wav", "water drops"),
        ("examples/noisy_speech.wav", "speech"),
    )
]
# Page layout: description on top, inputs on the left, result on the right,
# then the clickable examples.
demo = gr.Blocks(
    theme=gr.themes.Soft(),
    title="Universal Sound Separation on HIVE",
)
with demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: everything the user provides.
        with gr.Column():
            audio_input = gr.Audio(label="Input Mixture Audio", type="filepath")
            text_input = gr.Textbox(
                label="Text Query",
                placeholder='e.g. "dog barking", "piano playing"',
            )
            model_choice = gr.Dropdown(
                choices=["AudioSep-hive", "FlowSep-hive"],
                value="AudioSep-hive",
                label="Select Model",
            )
            submit_btn = gr.Button("Separate", variant="primary")
        # Right column: the separated audio.
        with gr.Column():
            audio_output = gr.Audio(label="Separated Audio")

    # Wire the button to the inference entry point.
    submit_btn.click(
        fn=inference_entry,
        inputs=[audio_input, text_input, model_choice],
        outputs=audio_output,
    )

    gr.Markdown("## Examples")
    # Examples only fill the audio and text inputs; the model choice is kept.
    gr.Examples(examples=EXAMPLES, inputs=[audio_input, text_input])
# Set to True to run run_debug() before the app is launched.
DEBUG = False


def run_debug():
    """Smoke-test both models on the bundled acoustic guitar example.

    Progress is printed to stdout. When the example file is missing, a skip
    message is printed and the function returns without loading any model.
    """
    heavy_rule = "=" * 50
    light_rule = "-" * 40
    test_path = os.path.join(
        os.path.dirname(__file__), "examples", "acoustic_guitar.wav"
    )
    test_text = "acoustic guitar"

    print("\n" + heavy_rule)
    print("[DEBUG] Starting inference test for both models")
    print(heavy_rule)

    if not os.path.exists(test_path):
        print(f"[DEBUG] Skip: {test_path} not found")
        return
    print(f"\n[DEBUG] Using test audio: {test_path}")

    print("\n" + light_rule)
    print("[DEBUG] AudioSep inference")
    print(light_rule)
    print("[DEBUG] Loading AudioSep model...")
    audiosep_sr, audiosep_wav = separate_audiosep(test_path, test_text)
    print(f"[DEBUG] AudioSep done. Output sr={audiosep_sr}, shape={np.array(audiosep_wav).shape}")

    print("\n" + light_rule)
    print("[DEBUG] FlowSep inference")
    print(light_rule)
    print("[DEBUG] Loading FlowSep model...")
    flowsep_sr, flowsep_wav = separate_flowsep(test_path, test_text)
    print(f"[DEBUG] FlowSep done. Output sr={flowsep_sr}, shape={np.array(flowsep_wav).shape}")

    print("\n" + heavy_rule)
    print("[DEBUG] Both models passed inference test")
    print(heavy_rule + "\n")
if __name__ == "__main__":
    # Optional smoke test of both models before the UI starts.
    if DEBUG:
        run_debug()
    # Enable request queuing, then start the web server.
    demo.queue().launch()