# app.py — Hive universal sound separation demo (Gradio)
# Author: JusperLee — commit 859ee84 ("update all code")
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "models", "audiosep"))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "models", "flowsep"))
import gradio as gr
import torch
import numpy as np
import torchaudio
import librosa
import yaml
from huggingface_hub import hf_hub_download
from pytorch_lightning import seed_everything
try:
import spaces
except ImportError:
spaces = None
_audiosep_model = None
_flowsep_model = None
_flowsep_preprocessor = None
def get_runtime_device():
    """Return the torch device inference should run on (CUDA when available)."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
class FlowSepPreprocessor:
    """Prepares waveforms and mel/STFT features for the FlowSep model.

    All sizes (sampling rate, clip duration, STFT/mel parameters) come from
    the FlowSep YAML configuration dict passed to the constructor.
    """

    def __init__(self, config):
        import utilities.audio as Audio

        audio_cfg = config["preprocessing"]["audio"]
        stft_cfg = config["preprocessing"]["stft"]
        mel_cfg = config["preprocessing"]["mel"]
        self.sampling_rate = audio_cfg["sampling_rate"]
        self.duration = audio_cfg["duration"]
        self.hopsize = stft_cfg["hop_length"]
        # Number of STFT frames that `duration` seconds of audio produces.
        self.target_length = int(self.duration * self.sampling_rate / self.hopsize)
        self.STFT = Audio.stft.TacotronSTFT(
            stft_cfg["filter_length"],
            stft_cfg["hop_length"],
            stft_cfg["win_length"],
            mel_cfg["n_mel_channels"],
            audio_cfg["sampling_rate"],
            mel_cfg["mel_fmin"],
            mel_cfg["mel_fmax"],
        )

    def read_wav_file(self, filename):
        """Load, trim, resample, normalize and pad a wav file to (1, n) float."""
        waveform, sr = torchaudio.load(filename)
        # Trim to `duration` seconds in the file's native sample rate first,
        # then resample to the model rate.
        native_limit = int(sr * self.duration)
        if waveform.shape[-1] > native_limit:
            waveform = waveform[:, :native_limit]
        if sr != self.sampling_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sampling_rate)
        mono = waveform.numpy()[0, ...]
        # Remove DC offset and peak-normalize to +/-0.5.
        mono = mono - np.mean(mono)
        mono = 0.5 * (mono / (np.max(np.abs(mono)) + 1e-8))
        mono = mono[None, ...]
        wanted = int(self.sampling_rate * self.duration)
        if mono.shape[-1] < wanted:
            padded = np.zeros((1, wanted), dtype=np.float32)
            padded[:, :mono.shape[-1]] = mono
            mono = padded
        return mono

    def wav_feature_extraction(self, waveform):
        """Return (log-mel, stft) features padded/trimmed to target_length frames."""
        import utilities.audio as Audio

        wav_tensor = torch.FloatTensor(waveform[0, ...])
        log_mel_spec, stft, energy = Audio.tools.get_mel_from_wav(wav_tensor, self.STFT)
        log_mel_spec = self._pad_spec(torch.FloatTensor(log_mel_spec.T))
        stft = self._pad_spec(torch.FloatTensor(stft.T))
        return log_mel_spec, stft

    def _pad_spec(self, spec):
        """Zero-pad or truncate the frame axis; drop one bin if the last dim is odd."""
        deficit = self.target_length - spec.shape[0]
        if deficit > 0:
            spec = torch.nn.ZeroPad2d((0, 0, 0, deficit))(spec)
        elif deficit < 0:
            spec = spec[:self.target_length, :]
        # Keep the last dimension even (downstream model requirement,
        # mirrored from the original FlowSep preprocessing).
        if spec.size(-1) % 2 != 0:
            spec = spec[..., :-1]
        return spec

    def load_full_audio(self, filename):
        """Load an entire file as a mono numpy waveform at the model sample rate."""
        waveform, sr = torchaudio.load(filename)
        if sr != self.sampling_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sampling_rate)
        return waveform.numpy()[0, ...]

    def preprocess_chunk(self, chunk):
        """Remove DC offset and peak-normalize one chunk to +/-0.5."""
        centered = chunk - np.mean(chunk)
        return 0.5 * (centered / (np.max(np.abs(centered)) + 1e-8))
def load_audiosep():
    """Download, build and cache the AudioSep model on the current device.

    Returns the cached model on repeat calls, re-moving it to the active
    device in case availability changed between calls.
    """
    global _audiosep_model
    device = get_runtime_device()
    if _audiosep_model is not None:
        _audiosep_model = _audiosep_model.to(device).eval()
        return _audiosep_model
    from models.clap_encoder import CLAP_Encoder
    from utils import parse_yaml, load_ss_model

    repo = "ShandaAI/AudioSep-hive"
    # CLAP text/audio encoder checkpoint used for query embedding.
    clap_ckpt = hf_hub_download(repo_id=repo, filename="music_speech_audioset_epoch_15_esc_89.98.pt")
    query_encoder = CLAP_Encoder(pretrained_path=clap_ckpt).eval()
    config_path = hf_hub_download(repo_id=repo, filename="config.yaml")
    ckpt_path = hf_hub_download(repo_id=repo, filename="audiosep_hive.ckpt")
    model = load_ss_model(
        configs=parse_yaml(config_path),
        checkpoint_path=ckpt_path,
        query_encoder=query_encoder,
    )
    _audiosep_model = model.to(device).eval()
    return _audiosep_model
def load_flowsep():
    """Download, build and cache the FlowSep model and its preprocessor.

    Returns:
        (model, preprocessor) tuple; cached instances on repeat calls, with
        the model re-moved to the active device.
    """
    global _flowsep_model, _flowsep_preprocessor
    device = get_runtime_device()
    if _flowsep_model is not None:
        _flowsep_model = _flowsep_model.to(device).eval()
        return _flowsep_model, _flowsep_preprocessor
    seed_everything(0)
    from latent_diffusion.util import instantiate_from_config

    repo = "ShandaAI/FlowSep-hive"
    config_path = hf_hub_download(repo_id=repo, filename="config.yaml")
    ckpt_path = hf_hub_download(repo_id=repo, filename="flowsep_hive.ckpt")
    # Fix: the original left the config file handle open; use a context manager.
    with open(config_path, "r") as f:
        configs = yaml.load(f, Loader=yaml.FullLoader)
    # The VAE weights ship inside the main checkpoint; skip the separate reload.
    configs["model"]["params"]["first_stage_config"]["params"]["reload_from_ckpt"] = None
    preprocessor = FlowSepPreprocessor(configs)
    model = instantiate_from_config(configs["model"]).to(device)
    try:
        # torch >= 2.6 defaults weights_only=True; this trusted checkpoint
        # needs full unpickling.
        state = torch.load(ckpt_path, map_location=device, weights_only=False)["state_dict"]
    except TypeError:
        # Older torch versions do not accept the weights_only kwarg.
        state = torch.load(ckpt_path, map_location=device)["state_dict"]
    model.load_state_dict(state, strict=True)
    model.eval()
    _flowsep_model = model
    _flowsep_preprocessor = preprocessor
    return model, preprocessor
AUDIOSEP_SR = 32000
FLOWSEP_CHUNK_IN = 163840
FLOWSEP_CHUNK_OUT = 160000
FLOWSEP_SR = 16000
def separate_audiosep(audio_path, text):
    """Separate the `text`-described source from a mixture using AudioSep.

    Args:
        audio_path: path to the input mixture audio file.
        text: natural-language description of the target source.

    Returns:
        (sample_rate, waveform) tuple suitable for gr.Audio output.
    """
    device = get_runtime_device()
    model = load_audiosep()
    mixture, _ = librosa.load(audio_path, sr=AUDIOSEP_SR, mono=True)
    input_len = mixture.shape[0]
    with torch.no_grad():
        # Embed the text query with the CLAP encoder.
        conditions = model.query_encoder.get_query_embed(
            modality='text', text=[text], device=device
        )
        input_dict = {
            "mixture": torch.Tensor(mixture)[None, None, :].to(device),
            "condition": conditions,
        }
        if input_len > AUDIOSEP_SR * 10:
            # Inputs longer than 10 s go through chunked inference to bound memory.
            sep_audio = model.ss_model.chunk_inference(input_dict).squeeze()
        else:
            sep_segment = model.ss_model(input_dict)["waveform"]
            sep_audio = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
    # Trim any padding the model may have appended.
    sep_audio = sep_audio[:input_len]
    return (AUDIOSEP_SR, sep_audio)
def _flowsep_process_chunk(model, preprocessor, chunk_wav, text):
    """Run FlowSep on a single fixed-size window.

    Args:
        model: loaded FlowSep model (provides generate_sample).
        preprocessor: FlowSepPreprocessor used for normalization and features.
        chunk_wav: 1-D numpy waveform (16 kHz assumed — produced by
            preprocessor.load_full_audio), at most FLOWSEP_CHUNK_IN samples.
        text: text query describing the target source.

    Returns:
        1-D numpy waveform with at most FLOWSEP_CHUNK_OUT separated samples.
    """
    device = get_runtime_device()
    chunk_wav = preprocessor.preprocess_chunk(chunk_wav)
    # Zero-pad short chunks up to the fixed model input size, then hard-trim.
    if len(chunk_wav) < FLOWSEP_CHUNK_IN:
        pad = np.zeros(FLOWSEP_CHUNK_IN - len(chunk_wav), dtype=np.float32)
        chunk_wav = np.concatenate([chunk_wav, pad])
    chunk_wav = chunk_wav[:FLOWSEP_CHUNK_IN]
    mixed_mel, stft = preprocessor.wav_feature_extraction(chunk_wav.reshape(1, -1))
    # NOTE(review): the torch.rand entries below look like schema placeholders
    # that generate_sample's batch format requires; only mixed_waveform and
    # mixed_mel carry real signal here — confirm against the FlowSep model code.
    batch = {
        "fname": ["temp"],
        "text": [text],
        "caption": [text],
        "waveform": torch.rand(1, 1, FLOWSEP_CHUNK_IN).to(device),
        "log_mel_spec": torch.rand(1, 1024, 64).to(device),
        "sampling_rate": torch.tensor([FLOWSEP_SR]).to(device),
        "label_vector": torch.rand(1, 527).to(device),
        "stft": torch.rand(1, 1024, 512).to(device),
        "mixed_waveform": torch.from_numpy(chunk_wav.reshape(1, 1, FLOWSEP_CHUNK_IN)).to(device),
        "mixed_mel": mixed_mel.reshape(1, mixed_mel.shape[0], mixed_mel.shape[1]).to(device),
    }
    result = model.generate_sample(
        [batch],
        name="temp_result",
        unconditional_guidance_scale=1.0,
        ddim_steps=20,
        n_gen=1,
        save=False,
        save_mixed=False,
    )
    # generate_sample may return either a numpy array or a torch tensor.
    if isinstance(result, np.ndarray):
        out = result.squeeze()
    else:
        out = result.squeeze().cpu().numpy()
    # Only the first FLOWSEP_CHUNK_OUT samples of a window are usable output.
    return out[:FLOWSEP_CHUNK_OUT]
def separate_flowsep(audio_path, text):
    """Separate the `text`-described source from a mixture using FlowSep.

    Long inputs are processed in windows of FLOWSEP_CHUNK_IN samples whose
    starts advance by FLOWSEP_CHUNK_OUT (so consecutive windows overlap by
    FLOWSEP_CHUNK_IN - FLOWSEP_CHUNK_OUT samples); each window contributes up
    to FLOWSEP_CHUNK_OUT output samples.

    Args:
        audio_path: path to the input mixture audio file.
        text: natural-language description of the target source.

    Returns:
        (sample_rate, waveform) tuple for gr.Audio; the waveform length
        matches the (resampled) input length exactly.
    """
    # Fix: removed the unused `device` local the original assigned here.
    model, preprocessor = load_flowsep()
    full_wav = preprocessor.load_full_audio(audio_path)
    input_len = full_wav.shape[0]
    with torch.no_grad():
        if input_len <= FLOWSEP_CHUNK_IN:
            sep_audio = _flowsep_process_chunk(model, preprocessor, full_wav.copy(), text)
        else:
            out_list = []
            start = 0
            while start < input_len:
                end = min(start + FLOWSEP_CHUNK_IN, input_len)
                chunk = full_wav[start:end]
                out_chunk = _flowsep_process_chunk(model, preprocessor, chunk.copy(), text)
                # Keep only the samples that map back into the input range.
                need = min(FLOWSEP_CHUNK_OUT, input_len - start)
                out_list.append(out_chunk[:need])
                start += FLOWSEP_CHUNK_OUT
            sep_audio = np.concatenate(out_list)
    # Force the output length to match the input exactly.
    if len(sep_audio) > input_len:
        sep_audio = sep_audio[:input_len]
    elif len(sep_audio) < input_len:
        sep_audio = np.pad(sep_audio, (0, input_len - len(sep_audio)), mode="constant", constant_values=0)
    return (FLOWSEP_SR, sep_audio)
def inference(audio, text, model_choice):
    """Validate user inputs, then dispatch to the selected separation model.

    Raises:
        gr.Error: when no audio was uploaded or the text query is blank.
    """
    if audio is None:
        raise gr.Error("Please upload an audio file / 请上传音频文件")
    if not text or not text.strip():
        raise gr.Error("Please enter a text query / 请输入文本描述")
    # Any non-AudioSep choice (i.e. "FlowSep-hive") routes to FlowSep.
    if model_choice == "AudioSep-hive":
        return separate_audiosep(audio, text)
    return separate_flowsep(audio, text)
# Expose the entry point the UI binds to. On HF Spaces (ZeroGPU) the call is
# wrapped with spaces.GPU so a GPU is attached for up to 120 s per request;
# locally, the plain function is used directly — the original duplicated the
# wrapper body in the else branch for no reason.
if spaces is not None:
    @spaces.GPU(duration=120)
    def inference_entry(audio, text, model_choice):
        return inference(audio, text, model_choice)
else:
    inference_entry = inference
DESCRIPTION = """
# Universal Sound Separation on HIVE
**Hive** is a high-quality synthetic dataset (2k hours) built via an automated pipeline that mines high-purity single-event segments and synthesizes semantically consistent mixtures. Despite using only ~0.2% of the data scale of million-hour baselines, models trained on Hive achieve competitive separation accuracy and strong zero-shot generalization.
This space provides two separation models trained on Hive:
- **AudioSep**: A foundation model for open-domain sound separation with natural language queries, based on [AudioSep](https://github.com/Audio-AGI/AudioSep).
- **FlowSep**: A flow-matching based separation model with text conditioning, based on [FlowSep](https://github.com/Audio-AGI/FlowSep).
**How to use:**
1. Upload an audio file (mix of sounds)
2. Describe what you want to separate (e.g., "piano", "speech", "dog barking")
3. Select a model and click Separate
[[Paper]](https://arxiv.org/abs/2601.22599) | [[Code]](https://github.com/ShandaAI/Hive) | [[Hive Dataset]](https://huggingface.co/datasets/ShandaAI/Hive) | [[Demo Page]](https://shandaai.github.io/Hive/)
"""
EXAMPLES = [
["examples/acoustic_guitar.wav", "acoustic guitar"],
["examples/laughing.wav", "laughing"],
["examples/ticktok_piano.wav", "A ticktock sound playing at the same rhythm with piano"],
["examples/water_drops.wav", "water drops"],
["examples/noisy_speech.wav", "speech"],
]
with gr.Blocks(
theme=gr.themes.Soft(),
title="Universal Sound Separation on HIVE",
) as demo:
gr.Markdown(DESCRIPTION)
with gr.Row():
with gr.Column():
audio_input = gr.Audio(label="Input Mixture Audio", type="filepath")
text_input = gr.Textbox(
label="Text Query",
placeholder='e.g. "dog barking", "piano playing"',
)
model_choice = gr.Dropdown(
choices=["AudioSep-hive", "FlowSep-hive"],
value="AudioSep-hive",
label="Select Model",
)
submit_btn = gr.Button("Separate", variant="primary")
with gr.Column():
audio_output = gr.Audio(label="Separated Audio")
submit_btn.click(
fn=inference_entry,
inputs=[audio_input, text_input, model_choice],
outputs=audio_output,
)
gr.Markdown("## Examples")
gr.Examples(examples=EXAMPLES, inputs=[audio_input, text_input])
DEBUG = False
def run_debug():
    """Smoke-test both separation models on the bundled example clip.

    Prints progress banners and the output sample rate/shape for each model;
    silently returns if the example audio is missing.
    """
    examples_dir = os.path.join(os.path.dirname(__file__), "examples")
    test_path = os.path.join(examples_dir, "acoustic_guitar.wav")
    test_text = "acoustic guitar"
    print("\n" + "=" * 50)
    print("[DEBUG] Starting inference test for both models")
    print("=" * 50)
    if not os.path.exists(test_path):
        print(f"[DEBUG] Skip: {test_path} not found")
        return
    print(f"\n[DEBUG] Using test audio: {test_path}")
    # Run the two models through an identical banner/inference/report cycle.
    for model_name, separate_fn in (("AudioSep", separate_audiosep), ("FlowSep", separate_flowsep)):
        print("\n" + "-" * 40)
        print(f"[DEBUG] {model_name} inference")
        print("-" * 40)
        print(f"[DEBUG] Loading {model_name} model...")
        out_sr, out_wav = separate_fn(test_path, test_text)
        print(f"[DEBUG] {model_name} done. Output sr={out_sr}, shape={np.array(out_wav).shape}")
    print("\n" + "=" * 50)
    print("[DEBUG] Both models passed inference test")
    print("=" * 50 + "\n")
if __name__ == "__main__":
if DEBUG:
run_debug()
demo.queue()
demo.launch()