""" AudioSep: Separate Anything You Describe CPU inference for HuggingFace Spaces free tier """ import gc import os from pathlib import Path import gradio as gr import librosa import numpy as np import torch from huggingface_hub import hf_hub_download # Fix PyTorch 2.6+ weights_only default change import numpy.core.multiarray torch.serialization.add_safe_globals([numpy.core.multiarray.scalar]) # Force CPU os.environ["CUDA_VISIBLE_DEVICES"] = "" DEVICE = torch.device("cpu") # Checkpoint repo on HuggingFace HF_REPO = "audo/AudioSep" CHECKPOINT_DIR = Path("checkpoint") CHECKPOINT_DIR.mkdir(exist_ok=True) # Global model (lazy loaded) MODEL = None def download_checkpoints(): """Download checkpoints from HuggingFace Hub""" files = [ "audiosep_base_4M_steps.ckpt", "music_speech_audioset_epoch_15_esc_89.98.pt", ] for f in files: local_path = CHECKPOINT_DIR / f if not local_path.exists(): print(f"Downloading {f}...") hf_hub_download( repo_id=HF_REPO, filename=f, local_dir=CHECKPOINT_DIR, local_dir_use_symlinks=False, ) print("Checkpoints ready!") def load_model(): """Load model (lazy loading to save memory)""" global MODEL if MODEL is not None: return MODEL print("Loading AudioSep model...") from pipeline import build_audiosep MODEL = build_audiosep( config_yaml="config/audiosep_base.yaml", checkpoint_path=str(CHECKPOINT_DIR / "audiosep_base_4M_steps.ckpt"), device=DEVICE, ) MODEL.eval() gc.collect() print("Model loaded!") return MODEL def inference(audio_file_path: str, text: str, progress=gr.Progress()): """Separate audio based on text query""" if audio_file_path is None: return None, "Please upload an audio file" if not text or text.strip() == "": return None, "Please enter a text query" progress(0.1, "Loading model...") model = load_model() progress(0.3, "Loading audio...") print(f"Separate audio from [{audio_file_path}] with query [{text}]") mixture, _ = librosa.load(audio_file_path, sr=32000, mono=True) progress(0.5, "Processing...") with torch.no_grad(): conditions = model.query_encoder.get_query_embed( modality="text", text=[text], device=DEVICE ) input_dict = { "mixture": torch.Tensor(mixture)[None, None, :].to(DEVICE), "condition": conditions, } progress(0.7, "Separating audio...") sep_segment = model.ss_model(input_dict)["waveform"] sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy() # Clean up gc.collect() progress(1.0, "Done!") return (32000, np.round(sep_segment * 32767).astype(np.int16)), "Separation complete!" # Download checkpoints on startup print("Initializing AudioSep...") download_checkpoints() # Examples EXAMPLES = [ ["examples/acoustic_guitar.wav", "acoustic guitar"], ["examples/laughing.wav", "laughing"], ["examples/ticktok_piano.wav", "A ticktock sound playing at the same rhythm with piano"], ["examples/water_drops.wav", "water drops"], ["examples/noisy_speech.wav", "speech"], ] # Gradio UI description = """ # AudioSep: Separate Anything You Describe AudioSep is a foundation model for open-domain sound separation with natural language queries. **How to use:** 1. Upload an audio file (mix of sounds) 2. Describe what you want to separate (e.g., "piano", "speech", "dog barking") 3. Click Separate [[Project Page]](https://audio-agi.github.io/Separate-Anything-You-Describe) | [[Paper]](https://arxiv.org/abs/2308.05037) | [[Code]](https://github.com/Audio-AGI/AudioSep) """ with gr.Blocks(title="AudioSep") as demo: gr.Markdown(description) with gr.Row(): with gr.Column(): input_audio = gr.Audio(label="Input Audio (Mixture)", type="filepath") text_query = gr.Textbox( label="Text Query", placeholder="Describe the sound to separate (e.g., 'piano', 'speech', 'dog barking')" ) separate_btn = gr.Button("Separate", variant="primary", size="lg") with gr.Column(): output_audio = gr.Audio(label="Separated Audio") status = gr.Textbox(label="Status", interactive=False) separate_btn.click( fn=inference, inputs=[input_audio, text_query], outputs=[output_audio, status], api_name="separate", ) gr.Markdown("## Examples") gr.Examples( examples=EXAMPLES, inputs=[input_audio, text_query], ) if __name__ == "__main__": demo.queue().launch()