|
|
"""
|
|
|
AudioSep: Separate Anything You Describe
|
|
|
CPU inference for HuggingFace Spaces free tier
|
|
|
"""
|
|
|
import gc
|
|
|
import os
|
|
|
from pathlib import Path
|
|
|
|
|
|
import gradio as gr
|
|
|
import librosa
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
from huggingface_hub import hf_hub_download
|
|
|
|
|
|
|
|
|
import numpy.core.multiarray
|
|
|
torch.serialization.add_safe_globals([numpy.core.multiarray.scalar])
|
|
|
|
|
|
|
|
|
# Force CPU-only execution: hide any GPUs so torch never initializes CUDA
# (this app targets the HF Spaces free tier, which has no GPU).
os.environ["CUDA_VISIBLE_DEVICES"] = ""

DEVICE = torch.device("cpu")

# HuggingFace Hub repository hosting the pretrained AudioSep checkpoints.
HF_REPO = "audo/AudioSep"

# Local directory the checkpoints are downloaded into.
CHECKPOINT_DIR = Path("checkpoint")
CHECKPOINT_DIR.mkdir(exist_ok=True)

# Lazily-initialized model singleton; populated by load_model() on first use
# so import/startup stays fast and memory-light.
MODEL = None
|
|
|
|
|
|
|
|
|
def download_checkpoints():
    """Download the AudioSep model checkpoints from the HuggingFace Hub.

    Files already present in CHECKPOINT_DIR are skipped, so repeated app
    restarts do not re-download the (large) weights.
    """
    files = [
        "audiosep_base_4M_steps.ckpt",                    # separation model
        "music_speech_audioset_epoch_15_esc_89.98.pt",    # CLAP query encoder
    ]
    for filename in files:
        if not (CHECKPOINT_DIR / filename).exists():
            print(f"Downloading {filename}...")
            # With local_dir set, recent huggingface_hub versions always
            # materialize real files in the directory; the old
            # local_dir_use_symlinks flag is deprecated/ignored and removed
            # in newer releases, so it is intentionally not passed.
            hf_hub_download(
                repo_id=HF_REPO,
                filename=filename,
                local_dir=CHECKPOINT_DIR,
            )
    print("Checkpoints ready!")
|
|
|
|
|
|
|
|
|
def load_model():
    """Return the global AudioSep model, building it on first call.

    Lazy construction keeps startup memory low; every later call reuses
    the instance cached in the module-level MODEL singleton.
    """
    global MODEL
    if MODEL is None:
        print("Loading AudioSep model...")
        # Imported here so the heavy model code is only pulled in on demand.
        from pipeline import build_audiosep

        MODEL = build_audiosep(
            config_yaml="config/audiosep_base.yaml",
            checkpoint_path=str(CHECKPOINT_DIR / "audiosep_base_4M_steps.ckpt"),
            device=DEVICE,
        )
        MODEL.eval()
        # Reclaim any temporaries created during checkpoint loading.
        gc.collect()
        print("Model loaded!")
    return MODEL
|
|
|
|
|
|
|
|
|
def inference(audio_file_path: str, text: str, progress=gr.Progress()):
    """Separate the sound described by ``text`` out of the uploaded mixture.

    Args:
        audio_file_path: Path to the uploaded audio file (anything librosa
            can decode), or None if nothing was uploaded.
        text: Natural-language description of the target sound.
        progress: Gradio progress reporter (injected by the UI).

    Returns:
        ((sample_rate, int16 waveform), status message) on success, or
        (None, error message) when input validation fails.
    """
    # Guard clauses: report validation problems through the status textbox.
    if audio_file_path is None:
        return None, "Please upload an audio file"
    if not text or not text.strip():
        return None, "Please enter a text query"

    progress(0.1, "Loading model...")
    model = load_model()

    progress(0.3, "Loading audio...")
    print(f"Separate audio from [{audio_file_path}] with query [{text}]")
    # AudioSep operates on 32 kHz mono audio.
    mixture, _ = librosa.load(audio_file_path, sr=32000, mono=True)

    progress(0.5, "Processing...")
    with torch.no_grad():
        conditions = model.query_encoder.get_query_embed(
            modality="text",
            text=[text],
            device=DEVICE
        )

        input_dict = {
            # Shape (batch=1, channel=1, samples) — presumably what
            # ss_model expects; confirm against the pipeline code.
            "mixture": torch.Tensor(mixture)[None, None, :].to(DEVICE),
            "condition": conditions,
        }

        progress(0.7, "Separating audio...")
        sep_segment = model.ss_model(input_dict)["waveform"]
        sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()

    # Free intermediate tensors before returning (CPU memory is tight).
    gc.collect()

    progress(1.0, "Done!")
    # Clip to [-1, 1] before int16 conversion: model output can slightly
    # exceed full scale, and values outside the range would wrap around in
    # int16 and produce loud clicks instead of mild clipping.
    pcm = np.round(np.clip(sep_segment, -1.0, 1.0) * 32767).astype(np.int16)
    return (32000, pcm), "Separation complete!"
|
|
|
|
|
|
|
|
|
|
|
|
# Fetch model weights at import time so the first user request does not
# have to wait for the multi-GB download.
print("Initializing AudioSep...")
download_checkpoints()

# (audio path, text query) pairs rendered as clickable examples in the UI.
EXAMPLES = [
    ["examples/acoustic_guitar.wav", "acoustic guitar"],
    ["examples/laughing.wav", "laughing"],
    ["examples/ticktok_piano.wav", "A ticktock sound playing at the same rhythm with piano"],
    ["examples/water_drops.wav", "water drops"],
    ["examples/noisy_speech.wav", "speech"],
]
|
|
|
|
|
|
|
|
|
description = """
|
|
|
# AudioSep: Separate Anything You Describe
|
|
|
|
|
|
AudioSep is a foundation model for open-domain sound separation with natural language queries.
|
|
|
|
|
|
**How to use:**
|
|
|
1. Upload an audio file (mix of sounds)
|
|
|
2. Describe what you want to separate (e.g., "piano", "speech", "dog barking")
|
|
|
3. Click Separate
|
|
|
|
|
|
[[Project Page]](https://audio-agi.github.io/Separate-Anything-You-Describe) | [[Paper]](https://arxiv.org/abs/2308.05037) | [[Code]](https://github.com/Audio-AGI/AudioSep)
|
|
|
"""
|
|
|
|
|
|
# ---- Gradio UI wiring ----
with gr.Blocks(title="AudioSep") as demo:
    gr.Markdown(description)

    with gr.Row():
        # Left column: inputs + action button.
        with gr.Column():
            input_audio = gr.Audio(label="Input Audio (Mixture)", type="filepath")
            text_query = gr.Textbox(
                label="Text Query",
                placeholder="Describe the sound to separate (e.g., 'piano', 'speech', 'dog barking')"
            )
            separate_btn = gr.Button("Separate", variant="primary", size="lg")

        # Right column: outputs.
        with gr.Column():
            output_audio = gr.Audio(label="Separated Audio")
            status = gr.Textbox(label="Status", interactive=False)

    # api_name exposes the handler as the /separate endpoint for API clients.
    separate_btn.click(
        fn=inference,
        inputs=[input_audio, text_query],
        outputs=[output_audio, status],
        api_name="separate",
    )

    gr.Markdown("## Examples")
    gr.Examples(
        examples=EXAMPLES,
        inputs=[input_audio, text_query],
    )

if __name__ == "__main__":
    # queue() serializes concurrent requests onto the single CPU worker.
    demo.queue().launch()
|
|
|
|