# audiosep / app.py
# Provenance: Nekochu's HuggingFace Space — "Add api_name to separate endpoint" (commit fe21b4a, verified)
"""
AudioSep: Separate Anything You Describe
CPU inference for HuggingFace Spaces free tier
"""
import gc
import os
from pathlib import Path
import gradio as gr
import librosa
import numpy as np
import torch
from huggingface_hub import hf_hub_download
# Fix PyTorch 2.6+ weights_only default change
import numpy.core.multiarray
torch.serialization.add_safe_globals([numpy.core.multiarray.scalar])
# Force CPU
os.environ["CUDA_VISIBLE_DEVICES"] = ""
DEVICE = torch.device("cpu")
# Checkpoint repo on HuggingFace
HF_REPO = "audo/AudioSep"
CHECKPOINT_DIR = Path("checkpoint")
CHECKPOINT_DIR.mkdir(exist_ok=True)
# Global model (lazy loaded)
MODEL = None
def download_checkpoints():
    """Fetch the AudioSep model checkpoints from the HuggingFace Hub.

    Downloads each file into CHECKPOINT_DIR only if it is not already
    present, so restarts of the Space skip the (large) transfers.

    Side effects: network I/O and writes under CHECKPOINT_DIR; prints
    progress to stdout.
    """
    files = [
        "audiosep_base_4M_steps.ckpt",                      # separation model
        "music_speech_audioset_epoch_15_esc_89.98.pt",      # CLAP query encoder
    ]
    for f in files:
        local_path = CHECKPOINT_DIR / f
        if not local_path.exists():
            print(f"Downloading {f}...")
            # NOTE: local_dir_use_symlinks was removed here — it is
            # deprecated (ignored) in huggingface_hub >= 0.23 and only
            # triggered a FutureWarning.
            hf_hub_download(
                repo_id=HF_REPO,
                filename=f,
                local_dir=CHECKPOINT_DIR,
            )
    print("Checkpoints ready!")
def load_model():
    """Return the shared AudioSep model, constructing it on first use.

    The module-level MODEL acts as a cache so every request reuses one
    instance — essential on the memory-constrained free CPU tier.
    """
    global MODEL
    if MODEL is None:
        print("Loading AudioSep model...")
        from pipeline import build_audiosep

        MODEL = build_audiosep(
            config_yaml="config/audiosep_base.yaml",
            checkpoint_path=str(CHECKPOINT_DIR / "audiosep_base_4M_steps.ckpt"),
            device=DEVICE,
        )
        MODEL.eval()  # inference only — disable dropout/batch-norm updates
        gc.collect()
        print("Model loaded!")
    return MODEL
def inference(audio_file_path: str, text: str, progress=gr.Progress()):
    """Separate the sound described by ``text`` from an audio mixture.

    Args:
        audio_file_path: Path to the uploaded mixture (any format librosa
            can decode), or None when nothing was uploaded.
        text: Natural-language description of the target sound.
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        ``((32000, int16_waveform), status_message)`` on success, or
        ``(None, error_message)`` when an input is missing.
    """
    if audio_file_path is None:
        return None, "Please upload an audio file"
    if not text or text.strip() == "":
        return None, "Please enter a text query"
    progress(0.1, "Loading model...")
    model = load_model()
    progress(0.3, "Loading audio...")
    print(f"Separate audio from [{audio_file_path}] with query [{text}]")
    # AudioSep operates on 32 kHz mono audio; librosa resamples on load.
    mixture, _ = librosa.load(audio_file_path, sr=32000, mono=True)
    progress(0.5, "Processing...")
    with torch.no_grad():
        conditions = model.query_encoder.get_query_embed(
            modality="text",
            text=[text],
            device=DEVICE
        )
        input_dict = {
            # (batch, channel, samples) layout; from_numpy avoids a copy.
            "mixture": torch.from_numpy(mixture)[None, None, :].to(DEVICE),
            "condition": conditions,
        }
        progress(0.7, "Separating audio...")
        sep_segment = model.ss_model(input_dict)["waveform"]
        sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
    # Release intermediate tensors promptly (low-memory CPU tier).
    gc.collect()
    progress(1.0, "Done!")
    # Clip to [-1, 1] BEFORE the int16 conversion: model output can slightly
    # exceed full scale, and rounding then casting would wrap around int16,
    # turning mild clipping into loud clicks.
    pcm = np.round(np.clip(sep_segment, -1.0, 1.0) * 32767).astype(np.int16)
    return (32000, pcm), "Separation complete!"
# Download checkpoints on startup (idempotent: skips files already present).
print("Initializing AudioSep...")
download_checkpoints()
# Example (audio, query) pairs shown beneath the UI; paths are relative
# to the Space's working directory.
EXAMPLES = [
    ["examples/acoustic_guitar.wav", "acoustic guitar"],
    ["examples/laughing.wav", "laughing"],
    ["examples/ticktok_piano.wav", "A ticktock sound playing at the same rhythm with piano"],
    ["examples/water_drops.wav", "water drops"],
    ["examples/noisy_speech.wav", "speech"],
]
# Markdown header rendered at the top of the Gradio page.
description = """
# AudioSep: Separate Anything You Describe
AudioSep is a foundation model for open-domain sound separation with natural language queries.
**How to use:**
1. Upload an audio file (mix of sounds)
2. Describe what you want to separate (e.g., "piano", "speech", "dog barking")
3. Click Separate
[[Project Page]](https://audio-agi.github.io/Separate-Anything-You-Describe) | [[Paper]](https://arxiv.org/abs/2308.05037) | [[Code]](https://github.com/Audio-AGI/AudioSep)
"""
# Gradio UI: two-column layout — inputs on the left, results on the right.
with gr.Blocks(title="AudioSep") as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            # type="filepath" hands inference() a path rather than raw samples.
            input_audio = gr.Audio(label="Input Audio (Mixture)", type="filepath")
            text_query = gr.Textbox(
                label="Text Query",
                placeholder="Describe the sound to separate (e.g., 'piano', 'speech', 'dog barking')"
            )
            separate_btn = gr.Button("Separate", variant="primary", size="lg")
        with gr.Column():
            output_audio = gr.Audio(label="Separated Audio")
            status = gr.Textbox(label="Status", interactive=False)
    # api_name exposes this handler as POST /api/separate for programmatic use.
    separate_btn.click(
        fn=inference,
        inputs=[input_audio, text_query],
        outputs=[output_audio, status],
        api_name="separate",
    )
    gr.Markdown("## Examples")
    gr.Examples(
        examples=EXAMPLES,
        inputs=[input_audio, text_query],
    )
if __name__ == "__main__":
    # queue() serializes requests — required since CPU inference is slow.
    demo.queue().launch()