# OmniSep / app.py
# Author: Exgc — commit b8dba48 (verified): "Update app.py"
import functools
import os
import pathlib
import tempfile

import gradio as gr
import librosa
import numpy as np
import scipy.io.wavfile
import torch
import torch.nn.functional as F
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType

import omnisep
import utils
device = "cuda" if torch.cuda.is_available() else "cpu"
# ========== Configuration & Model Loading ==========
@functools.lru_cache(maxsize=2)
def setup_models(checkpoint_path, train_args_path):
    """Load the OmniSep separator and the ImageBind query encoder.

    Results are memoized per (checkpoint_path, train_args_path) so repeated
    inference requests do not reload weights from disk on every call (the
    previous version did, which made each Gradio request pay the full model
    load cost).  Callers must treat the returned objects as read-only shared
    state.

    Args:
        checkpoint_path: Path to the OmniSep ``state_dict`` checkpoint.
        train_args_path: Path to the JSON file of training hyperparameters.

    Returns:
        (model, imagebind_net, train_args) — both models are wrapped in
        ``torch.nn.DataParallel``, moved to ``device`` and set to eval mode.
    """
    train_args = utils.load_json(train_args_path)
    model = omnisep.OmniSep(
        train_args['n_mix'], train_args['layers'], train_args['channels'],
        use_log_freq=train_args['log_freq'],
        use_weighted_loss=train_args['weighted_loss'],
        use_binary_mask=train_args['binary_mask'],
        emb_dim=train_args.get('emb_dim', 512),
    )
    # Checkpoint was saved from a DataParallel model ("module." prefixed keys),
    # so wrap before loading the state dict.
    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model.to(device)
    model.eval()

    imagebind_net = imagebind_model.imagebind_huge(pretrained=True)
    imagebind_net = torch.nn.DataParallel(imagebind_net)
    imagebind_net.to(device)
    imagebind_net.eval()
    return model, imagebind_net, train_args
# ========== Audio Loading & Preprocessing ==========
def load_audio_and_spec(audio_file, audio_len, sample_rate, n_fft, hop_len, win_len):
    """Load a mono audio file and compute its STFT magnitude and phase.

    Clips shorter than ``audio_len`` samples are looped (tiled) until they
    cover the requested length; longer clips are truncated.  Samples are
    clipped to [-1, 1] before the STFT.

    Args:
        audio_file: Path to the input audio file.
        audio_len: Target length in samples.
        sample_rate: Resampling rate passed to ``librosa.load``.
        n_fft / hop_len / win_len: STFT parameters.

    Returns:
        (mag_mix, phase_mix, n_samples) where ``mag_mix`` and ``phase_mix``
        are tensors of shape (1, 1, n_fft // 2 + 1, n_frames).

    Raises:
        ValueError: if the decoded audio is empty (the previous version
        raised ``ZeroDivisionError`` from the tiling arithmetic instead).
    """
    y, _ = librosa.load(audio_file, sr=sample_rate, mono=True)
    if len(y) == 0:
        raise ValueError(f"Empty or undecodable audio file: {audio_file}")
    if len(y) < audio_len:
        # Loop short clips so the spectrogram always covers audio_len samples.
        y = np.tile(y, audio_len // len(y) + 1)[:audio_len]
    else:
        y = y[:audio_len]
    y = np.clip(y, -1, 1)
    spec_mix = librosa.stft(y, n_fft=n_fft, hop_length=hop_len, win_length=win_len)
    # unsqueeze twice -> (batch=1, channel=1, freq, time) as the model expects.
    mag_mix = torch.tensor(np.abs(spec_mix)).unsqueeze(0).unsqueeze(0)
    phase_mix = torch.tensor(np.angle(spec_mix)).unsqueeze(0).unsqueeze(0)
    return mag_mix, phase_mix, y.shape[0]
# ========== Embedding Construction ==========
def get_combined_embedding(imagebind_net, text=None, image=None, audio=None,
                           text_w=1.0, image_w=1.0, audio_w=1.0):
    """Encode the supplied queries with ImageBind and fuse them.

    Each provided modality (text / image / audio) is embedded, scaled by its
    weight, summed, divided by the total weight, and L2-normalized.  Returns
    ``None`` when no query is provided.
    """
    inputs = {}
    if text:
        inputs[ModalityType.TEXT] = data.load_and_transform_text([text], device)
    if image:
        inputs[ModalityType.VISION] = data.load_and_transform_vision_data([image], device)
    if audio:
        inputs[ModalityType.AUDIO] = data.load_and_transform_audio_data([audio], device)
    emb = imagebind_net(inputs)

    # Collect weighted embeddings for whichever modalities were supplied.
    weighted_terms = []
    total_weight = 0
    for present, modality, weight in (
        (text, ModalityType.TEXT, text_w),
        (image, ModalityType.VISION, image_w),
        (audio, ModalityType.AUDIO, audio_w),
    ):
        if present:
            weighted_terms.append(weight * emb[modality])
            total_weight += weight

    if not weighted_terms:
        return None
    combined = weighted_terms[0]
    for term in weighted_terms[1:]:
        combined = combined + term
    if total_weight > 0:
        combined = F.normalize(combined / total_weight)
    return combined
# ========== Waveform Recovery ==========
def recover_waveform(mag_mix, phase_mix, pred_mask, args):
    """Reconstruct the separated waveform from the mixture STFT and mask.

    If the model was trained on a log-frequency grid (``args['log_freq']``),
    the predicted mask is first unwarped back to the linear-frequency grid.
    The mask is then binarized at 0.5 and applied to the mixture magnitude,
    and the waveform is recovered via inverse STFT with the mixture phase.

    Args:
        mag_mix: Mixture magnitude, tensor of shape (B, 1, F, T).
        phase_mix: Mixture phase, tensor of shape (B, 1, F, T).
        pred_mask: Predicted soft mask, tensor of shape (B, 1, F', T).
        args: Training-args dict ('log_freq', 'n_fft', 'hop_len', 'win_len').

    Returns:
        1-D numpy waveform for the first batch element.
    """
    B = mag_mix.size(0)
    if args['log_freq']:
        # Unwarp the mask from the log-frequency grid back to linear frequency.
        grid_unwarp = torch.from_numpy(
            utils.warpgrid(B, args['n_fft'] // 2 + 1, pred_mask.size(3), warp=False)
        ).to(pred_mask.device)
        pred_mask_linear = F.grid_sample(pred_mask, grid_unwarp, align_corners=True)
    else:
        # BUG FIX: keep the full (B, 1, F, T) tensor.  The previous code did
        # ``pred_mask[0]`` here, dropping the batch dimension, so the later
        # ``pred_mask_linear[0, 0]`` selected a single frequency row instead
        # of the (F, T) mask plane.
        pred_mask_linear = pred_mask

    # Move everything to numpy on the CPU for the reconstruction.
    mag_mix = mag_mix.detach().cpu().numpy()
    phase_mix = phase_mix.detach().cpu().numpy()
    pred_mask_linear = pred_mask_linear.detach().cpu().numpy()

    # Binarize the soft mask at 0.5.
    pred_mask_linear = (pred_mask_linear > 0.5).astype(np.float32)

    # Apply the mask to the mixture magnitude and invert with mixture phase.
    pred_mag = mag_mix[0, 0] * pred_mask_linear[0, 0]
    pred_wav = utils.istft_reconstruction(
        pred_mag,
        phase_mix[0, 0],
        hop_len=args['hop_len'],
        win_len=args['win_len'],
    )
    return pred_wav
# ========== Gradio Interface ==========
def run_inference(input_audio, text_pos, audio_pos, image_pos, text_neg, audio_neg, image_neg,
                  text_w, image_w, audio_w, neg_w):
    """Separate the queried source from a mixed audio file (Gradio callback).

    Builds a (possibly negative-query-adjusted) ImageBind embedding from the
    positive/negative text/image/audio queries, runs the OmniSep mask
    predictor, reconstructs the waveform, and writes it to a wav file.

    Returns:
        Path to the separated output wav file.
    """
    model, imagebind_net, args = setup_models("./exp/checkpoints/best_model.pt",
                                              "./exp/train-args.json")
    audio_len = 65535  # fixed input window in samples — TODO confirm vs training config
    mag_mix, phase_mix, out_len = load_audio_and_spec(
        input_audio, audio_len,
        args['audio_rate'], args['n_fft'], args['hop_len'], args['win_len'])

    # Inference only: no_grad avoids building autograd graphs per request.
    with torch.no_grad():
        query_emb = get_combined_embedding(imagebind_net, text_pos, image_pos, audio_pos,
                                           text_w, image_w, audio_w)
        if any([text_neg, audio_neg, image_neg]):
            # Negative query: push the embedding away from the unwanted source.
            neg_emb = get_combined_embedding(imagebind_net, text_neg, image_neg, audio_neg,
                                             1.0, 1.0, 1.0)
            query_emb = (1 + neg_w) * query_emb - neg_w * neg_emb

        mag_mix = mag_mix.to(device)
        phase_mix = phase_mix.to(device)
        pred_mask = model.module.infer(mag_mix, [query_emb])[0]
        pred_wav = recover_waveform(mag_mix, phase_mix, pred_mask, args)

    # BUG FIX: unique temp file per request — the previous fixed
    # "/tmp/output.wav" was clobbered by concurrent Gradio sessions.
    out_path = tempfile.NamedTemporaryFile(suffix=".wav", delete=False).name
    scipy.io.wavfile.write(out_path, args['audio_rate'], pred_wav[:out_len])
    return out_path
# ---- Gradio front-end -------------------------------------------------------
# BUG FIX: the Markdown headings contained mojibake (double-encoded UTF-8,
# e.g. "๐ŸŽง" for 🎧); restored the intended emoji.
with gr.Blocks(title="OmniSep UI") as iface:
    gr.Markdown("## 🎧 Upload Your Mixed Audio")
    mixed_audio = gr.Audio(type="filepath", label="Mixed Input Audio")

    # Positive query: describes the source to KEEP (any subset of modalities).
    gr.Markdown("### ✅ Positive Query")
    with gr.Row():
        pos_text = gr.Textbox(label="Text Query", placeholder="e.g. dog barking")
        pos_audio = gr.Audio(type="filepath", label="Audio Query")
        pos_image = gr.Image(type="filepath", label="Image Query")

    # Negative query: describes a source to SUPPRESS (optional).
    gr.Markdown("### ❌ Negative Query (Optional)")
    with gr.Row():
        neg_text = gr.Textbox(label="Negative Text Query")
        neg_audio = gr.Audio(type="filepath", label="Negative Audio Query")
        neg_image = gr.Image(type="filepath", label="Negative Image Query")

    # Per-modality fusion weights, plus the negative-embedding strength.
    gr.Markdown("### 🎚️ Modality Weights")
    with gr.Row():
        text_weight = gr.Slider(0, 5, value=1.0, step=0.1, label="Text Weight")
        image_weight = gr.Slider(0, 5, value=1.0, step=0.1, label="Image Weight")
        audio_weight = gr.Slider(0, 5, value=1.0, step=0.1, label="Audio Weight")
        neg_weight = gr.Slider(0, 2, value=0.5, step=0.1, label="Negative Embedding Weight")

    output_audio = gr.Audio(type="filepath", label="Separated Output Audio")
    btn = gr.Button("Run OmniSep Inference")
    btn.click(fn=run_inference,
              inputs=[mixed_audio, pos_text, pos_audio, pos_image, neg_text, neg_audio, neg_image,
                      text_weight, image_weight, audio_weight, neg_weight],
              outputs=output_audio)

iface.launch(share=True)