|
|
import gradio as gr |
|
|
import torch |
|
|
import numpy as np |
|
|
import librosa |
|
|
import pathlib |
|
|
import scipy.io.wavfile |
|
|
import os |
|
|
|
|
|
from imagebind import data |
|
|
from imagebind.models import imagebind_model |
|
|
from imagebind.models.imagebind_model import ModalityType |
|
|
import torch.nn.functional as F |
|
|
|
|
|
import omnisep |
|
|
import utils |
|
|
# Run on GPU when available; every model and tensor below is moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
|
def setup_models(checkpoint_path, train_args_path):
    """Build the OmniSep separator and the ImageBind encoder, ready for inference.

    Args:
        checkpoint_path: path to the trained OmniSep state-dict checkpoint.
        train_args_path: path to the JSON file of training hyper-parameters.

    Returns:
        (model, imagebind_net, train_args) — both networks wrapped in
        DataParallel, moved to `device`, and switched to eval mode.
    """
    train_args = utils.load_json(train_args_path)

    # Re-create the network with the exact architecture used at training time.
    net = omnisep.OmniSep(
        train_args['n_mix'],
        train_args['layers'],
        train_args['channels'],
        use_log_freq=train_args['log_freq'],
        use_weighted_loss=train_args['weighted_loss'],
        use_binary_mask=train_args['binary_mask'],
        emb_dim=train_args.get('emb_dim', 512),
    )
    # The checkpoint was saved from a DataParallel model, so wrap before loading.
    model = torch.nn.DataParallel(net)
    state = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state)
    model.to(device)
    model.eval()

    # Frozen pretrained ImageBind encoder used to embed the queries.
    imagebind_net = torch.nn.DataParallel(
        imagebind_model.imagebind_huge(pretrained=True))
    imagebind_net.to(device)
    imagebind_net.eval()

    return model, imagebind_net, train_args
|
|
|
|
|
|
|
|
|
|
|
def load_audio_and_spec(audio_file, audio_len, sample_rate, n_fft, hop_len, win_len):
    """Load a mono waveform, fit it to `audio_len` samples, and compute its STFT.

    Args:
        audio_file: path to the input audio file.
        audio_len: target length in samples; shorter clips are looped, longer trimmed.
        sample_rate: resampling rate passed to librosa.
        n_fft, hop_len, win_len: STFT parameters.

    Returns:
        (mag_mix, phase_mix, n_samples) where mag_mix / phase_mix are
        (1, 1, F, T) float tensors (batch and channel dims expected by the
        model) and n_samples is the final waveform length.

    Raises:
        ValueError: if the decoded audio is empty (the original code would
            raise an opaque ZeroDivisionError from `audio_len // len(y)`).
    """
    y, sr = librosa.load(audio_file, sr=sample_rate, mono=True)
    if len(y) == 0:
        raise ValueError(f"Audio file {audio_file!r} contains no samples")
    if len(y) < audio_len:
        # Loop the clip until it is long enough, then trim to exactly audio_len.
        y = np.tile(y, audio_len // len(y) + 1)[:audio_len]
    else:
        y = y[:audio_len]
    y = np.clip(y, -1, 1)

    spec_mix = librosa.stft(y, n_fft=n_fft, hop_length=hop_len, win_length=win_len)
    mag_mix = torch.tensor(np.abs(spec_mix)).unsqueeze(0).unsqueeze(0)
    phase_mix = torch.tensor(np.angle(spec_mix)).unsqueeze(0).unsqueeze(0)
    return mag_mix, phase_mix, y.shape[0]
|
|
|
|
|
|
|
|
|
|
|
def get_combined_embedding(imagebind_net, text=None, image=None, audio=None,
                           text_w=1.0, image_w=1.0, audio_w=1.0):
    """Embed up to three query modalities with ImageBind and fuse them.

    Args:
        imagebind_net: the ImageBind model (callable on a modality->tensor dict).
        text, image, audio: optional query inputs (text string / file paths);
            falsy values are treated as "not provided".
        text_w, image_w, audio_w: per-modality fusion weights.

    Returns:
        The weight-averaged, L2-normalized query embedding, or None when no
        query is provided. (Fix: the original still ran a forward pass on an
        empty input dict before returning None.)
    """
    inputs = {}
    if text:
        inputs[ModalityType.TEXT] = data.load_and_transform_text([text], device)
    if image:
        inputs[ModalityType.VISION] = data.load_and_transform_vision_data([image], device)
    if audio:
        inputs[ModalityType.AUDIO] = data.load_and_transform_audio_data([audio], device)
    if not inputs:
        # No query at all — skip the (pointless) forward pass entirely.
        return None

    emb = imagebind_net(inputs)

    # Weighted sum over whichever modalities were supplied.
    result = None
    denom = 0
    for modality, weight in ((ModalityType.TEXT, text_w),
                             (ModalityType.VISION, image_w),
                             (ModalityType.AUDIO, audio_w)):
        if modality in inputs:
            term = weight * emb[modality]
            result = term if result is None else result + term
            denom += weight

    if denom > 0:
        # Average then L2-normalize. (Normalize makes the /denom a magnitude
        # no-op, but matches the original convention; if all weights are 0 the
        # unnormalized zero tensor is returned, as before.)
        result = F.normalize(result / denom)

    return result
|
|
|
|
|
|
|
|
|
|
|
def recover_waveform(mag_mix, phase_mix, pred_mask, args):
    """Apply the predicted mask to the mixture spectrogram and invert to audio.

    Args:
        mag_mix: (B, 1, F, T) magnitude spectrogram of the mixture (tensor).
        phase_mix: (B, 1, F, T) phase of the mixture (tensor).
        pred_mask: (B, 1, F', T) predicted mask; on a warped (log-frequency)
            grid when args['log_freq'] is set — TODO confirm shape against
            the model's infer() output.
        args: training-args dict with 'log_freq', 'n_fft', 'hop_len', 'win_len'.

    Returns:
        1-D numpy waveform reconstructed for the first item in the batch.
    """
    B = mag_mix.size(0)
    if args['log_freq']:
        # The model predicted on a log-frequency grid; unwarp the mask back to
        # the linear STFT frequency axis before applying it.
        grid_unwarp = torch.from_numpy(
            utils.warpgrid(B, args['n_fft'] // 2 + 1, pred_mask.size(3), warp=False)
        ).to(pred_mask.device)
        pred_mask_linear = F.grid_sample(pred_mask, grid_unwarp, align_corners=True)
    else:
        # BUG FIX: was `pred_mask[0]`, which dropped the batch dimension and
        # made the later [0, 0] indexing select a single frequency row instead
        # of the full (F, T) mask. The mask is already on the linear grid here.
        pred_mask_linear = pred_mask

    mag_mix = mag_mix.detach().cpu().numpy()
    phase_mix = phase_mix.detach().cpu().numpy()
    pred_mask = pred_mask.detach().cpu().numpy()
    pred_mask_linear = pred_mask_linear.detach().cpu().numpy()

    # Binarize the soft masks at 0.5.
    pred_mask = (pred_mask > 0.5).astype(np.float32)
    pred_mask_linear = (pred_mask_linear > 0.5).astype(np.float32)

    # Mask the mixture magnitude and invert using the mixture phase.
    pred_mag = mag_mix[0, 0] * pred_mask_linear[0, 0]
    pred_wav = utils.istft_reconstruction(
        pred_mag,
        phase_mix[0, 0],
        hop_len=args['hop_len'],
        win_len=args['win_len'],
    )

    return pred_wav
|
|
|
|
|
|
|
|
|
|
|
def run_inference(input_audio, text_pos, audio_pos, image_pos, text_neg, audio_neg, image_neg,
                  text_w, image_w, audio_w, neg_w):
    """Gradio callback: separate the queried source out of the mixed audio.

    Args:
        input_audio: path to the mixed input audio file.
        text_pos/audio_pos/image_pos: positive query (at least one required).
        text_neg/audio_neg/image_neg: optional negative query.
        text_w/image_w/audio_w: positive-query modality fusion weights.
        neg_w: strength of the negative-query extrapolation.

    Returns:
        Path to the separated output wav file.
    """
    # Load the heavy networks once and reuse them across clicks; the original
    # code rebuilt both models from disk on every inference call.
    if not hasattr(run_inference, "_models"):
        run_inference._models = setup_models(
            "./exp/checkpoints/best_model.pt", "./exp/train-args.json")
    model, imagebind_net, args = run_inference._models

    audio_len = 65535
    mag_mix, phase_mix, out_len = load_audio_and_spec(
        input_audio, audio_len,
        args['audio_rate'], args['n_fft'], args['hop_len'], args['win_len'])

    img_emb = get_combined_embedding(imagebind_net, text_pos, image_pos, audio_pos,
                                     text_w, image_w, audio_w)
    if img_emb is None:
        # Fail early with a clear UI message instead of crashing inside the model.
        raise gr.Error("Provide at least one positive query (text, image, or audio).")

    if any([text_neg, audio_neg, image_neg]):
        # Negative-query guidance: extrapolate the positive embedding away
        # from the negative one.
        neg_emb = get_combined_embedding(imagebind_net, text_neg, image_neg, audio_neg,
                                         1.0, 1.0, 1.0)
        img_emb = (1 + neg_w) * img_emb - neg_w * neg_emb

    mag_mix = mag_mix.to(device)
    phase_mix = phase_mix.to(device)

    pred_mask = model.module.infer(mag_mix, [img_emb])[0]
    pred_wav = recover_waveform(mag_mix, phase_mix, pred_mask, args)

    out_path = "/tmp/output.wav"
    scipy.io.wavfile.write(out_path, args['audio_rate'], pred_wav[:out_len])
    return out_path
|
|
|
|
|
# ---- Gradio UI -----------------------------------------------------------
# FIX: the original heading strings contained mojibake (UTF-8 emoji decoded
# with a Thai codec, e.g. "๐ง"), and the "Positive Query" heading literal was
# broken across two lines. Restored the intended emoji and a valid literal.
with gr.Blocks(title="OmniSep UI") as iface:
    gr.Markdown("## 🎧 Upload Your Mixed Audio")
    mixed_audio = gr.Audio(type="filepath", label="Mixed Input Audio")

    gr.Markdown("### ✅ Positive Query")
    with gr.Row():
        pos_text = gr.Textbox(label="Text Query", placeholder="e.g. dog barking")
        pos_audio = gr.Audio(type="filepath", label="Audio Query")
        pos_image = gr.Image(type="filepath", label="Image Query")

    gr.Markdown("### ❌ Negative Query (Optional)")
    with gr.Row():
        neg_text = gr.Textbox(label="Negative Text Query")
        neg_audio = gr.Audio(type="filepath", label="Negative Audio Query")
        neg_image = gr.Image(type="filepath", label="Negative Image Query")

    gr.Markdown("### 🎛️ Modality Weights")
    with gr.Row():
        text_weight = gr.Slider(0, 5, value=1.0, step=0.1, label="Text Weight")
        image_weight = gr.Slider(0, 5, value=1.0, step=0.1, label="Image Weight")
        audio_weight = gr.Slider(0, 5, value=1.0, step=0.1, label="Audio Weight")
        neg_weight = gr.Slider(0, 2, value=0.5, step=0.1, label="Negative Embedding Weight")

    output_audio = gr.Audio(type="filepath", label="Separated Output Audio")

    btn = gr.Button("Run OmniSep Inference")
    btn.click(fn=run_inference,
              inputs=[mixed_audio, pos_text, pos_audio, pos_image, neg_text, neg_audio, neg_image,
                      text_weight, image_weight, audio_weight, neg_weight],
              outputs=output_audio)

# share=True exposes a public gradio.live URL for the demo.
iface.launch(share=True)