"""Gradio demo for OmniSep query-based audio source separation.

Loads a trained OmniSep checkpoint together with an ImageBind encoder and
exposes a web UI in which a mixed recording is separated according to
text / image / audio queries, with optional negative queries that push the
query embedding away from unwanted content.
"""

import os
import pathlib

import gradio as gr
import librosa
import numpy as np
import scipy.io.wavfile
import torch
import torch.nn.functional as F
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType

import omnisep
import utils

device = "cuda" if torch.cuda.is_available() else "cpu"

# Memo cache so repeated UI clicks do not reload the (large) checkpoints.
_MODEL_CACHE = {}


# ========== Configuration & Model Loading ==========
def setup_models(checkpoint_path, train_args_path):
    """Load (and memoize) the OmniSep separator and the ImageBind encoder.

    Args:
        checkpoint_path: path to the trained OmniSep state dict.
        train_args_path: path to the JSON file of training hyper-parameters.

    Returns:
        (separator, imagebind_net, train_args) — both models are wrapped in
        ``DataParallel``, moved to ``device`` and set to eval mode.
    """
    key = (checkpoint_path, train_args_path)
    if key in _MODEL_CACHE:
        return _MODEL_CACHE[key]

    train_args = utils.load_json(train_args_path)
    model = omnisep.OmniSep(
        train_args['n_mix'],
        train_args['layers'],
        train_args['channels'],
        use_log_freq=train_args['log_freq'],
        use_weighted_loss=train_args['weighted_loss'],
        use_binary_mask=train_args['binary_mask'],
        emb_dim=train_args.get('emb_dim', 512),
    )
    # The checkpoint was saved from a DataParallel-wrapped model, so wrap
    # before loading to keep the "module." key prefix consistent.
    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model.to(device)
    model.eval()

    imagebind_net = imagebind_model.imagebind_huge(pretrained=True)
    imagebind_net = torch.nn.DataParallel(imagebind_net)
    imagebind_net.to(device)
    imagebind_net.eval()

    _MODEL_CACHE[key] = (model, imagebind_net, train_args)
    return _MODEL_CACHE[key]


# ========== Audio Loading & Preprocessing ==========
def load_audio_and_spec(audio_file, audio_len, sample_rate, n_fft, hop_len, win_len):
    """Load a mono waveform, pad/crop to ``audio_len`` samples, and STFT it.

    Short clips are looped until they cover the analysis window; long clips
    are truncated. Samples are clipped to [-1, 1].

    Returns:
        (mag_mix, phase_mix, n_samples) where the spectrogram tensors have
        shape (1, 1, n_fft // 2 + 1, frames).

    Raises:
        ValueError: if the decoded audio contains no samples (would otherwise
        divide by zero when tiling).
    """
    y, _ = librosa.load(audio_file, sr=sample_rate, mono=True)
    if len(y) == 0:
        raise ValueError(f"Audio file {audio_file!r} contains no samples.")
    if len(y) < audio_len:
        # Loop short clips until they fill the fixed-length window.
        y = np.tile(y, audio_len // len(y) + 1)[:audio_len]
    else:
        y = y[:audio_len]
    y = np.clip(y, -1, 1)

    spec_mix = librosa.stft(y, n_fft=n_fft, hop_length=hop_len, win_length=win_len)
    mag_mix = torch.tensor(np.abs(spec_mix)).unsqueeze(0).unsqueeze(0)
    phase_mix = torch.tensor(np.angle(spec_mix)).unsqueeze(0).unsqueeze(0)
    return mag_mix, phase_mix, y.shape[0]


# ========== Embedding Construction ==========
def get_combined_embedding(imagebind_net, text=None, image=None, audio=None,
                           text_w=1.0, image_w=1.0, audio_w=1.0):
    """Encode the supplied queries with ImageBind and return their weighted mean.

    Any of ``text`` / ``image`` / ``audio`` may be None or empty and is then
    skipped. The per-modality embeddings are combined as a weighted average
    and L2-normalized (when the weight sum is positive).

    Returns:
        The combined embedding tensor, or None when no query was supplied.
    """
    inputs = {}
    if text:
        inputs[ModalityType.TEXT] = data.load_and_transform_text([text], device)
    if image:
        inputs[ModalityType.VISION] = data.load_and_transform_vision_data([image], device)
    if audio:
        inputs[ModalityType.AUDIO] = data.load_and_transform_audio_data([audio], device)
    if not inputs:
        return None

    with torch.no_grad():
        emb = imagebind_net(inputs)

    result = None
    denom = 0.0
    for modality, weight in ((ModalityType.TEXT, text_w),
                             (ModalityType.VISION, image_w),
                             (ModalityType.AUDIO, audio_w)):
        if modality not in emb:
            continue
        term = weight * emb[modality]
        result = term if result is None else result + term
        denom += weight

    if denom > 0:
        result = F.normalize(result / denom)
    return result


# ========== Waveform Recovery ==========
def recover_waveform(mag_mix, phase_mix, pred_mask, args):
    """Apply the predicted mask to the mixture spectrogram and invert it.

    When the model was trained with ``log_freq`` the mask lives on a warped
    (log-frequency) grid and is first resampled back onto the linear STFT
    frequency axis.

    Args:
        mag_mix:   (1, 1, F, T) magnitude spectrogram of the mixture.
        phase_mix: (1, 1, F, T) phase of the mixture.
        pred_mask: separator output mask (batched, 4-D).
        args:      training-args dict ('log_freq', 'n_fft', 'hop_len', 'win_len').

    Returns:
        The separated waveform as a 1-D numpy array.
    """
    B = mag_mix.size(0)
    if args['log_freq']:
        # Map the mask from the warped (log-frequency) grid back to the
        # linear frequency axis of the STFT.
        grid_unwarp = torch.from_numpy(
            utils.warpgrid(B, args['n_fft'] // 2 + 1, pred_mask.size(3), warp=False)
        ).to(pred_mask.device)
        pred_mask_linear = F.grid_sample(pred_mask, grid_unwarp, align_corners=True)
    else:
        # BUGFIX: this branch previously used `pred_mask[0]`, which dropped
        # the batch dimension so the `[0, 0]` indexing below selected a
        # single frequency row instead of the full (F, T) mask.
        pred_mask_linear = pred_mask

    mag_mix = mag_mix.detach().cpu().numpy()
    phase_mix = phase_mix.detach().cpu().numpy()
    pred_mask_linear = pred_mask_linear.detach().cpu().numpy()

    # Binarize the soft mask at 0.5 before applying it.
    pred_mask_linear = (pred_mask_linear > 0.5).astype(np.float32)

    pred_mag = mag_mix[0, 0] * pred_mask_linear[0, 0]
    return utils.istft_reconstruction(
        pred_mag,
        phase_mix[0, 0],
        hop_len=args['hop_len'],
        win_len=args['win_len'],
    )


# ========== Gradio Interface ==========
def run_inference(input_audio, text_pos, audio_pos, image_pos,
                  text_neg, audio_neg, image_neg,
                  text_w, image_w, audio_w, neg_w):
    """Separate ``input_audio`` according to the positive/negative queries.

    Returns:
        Path of the separated wav file written under /tmp.

    Raises:
        gr.Error: when no positive query (text, image or audio) is supplied.
    """
    model, imagebind_net, args = setup_models("./exp/checkpoints/best_model.pt",
                                              "./exp/train-args.json")
    audio_len = 65535
    mag_mix, phase_mix, out_len = load_audio_and_spec(
        input_audio, audio_len, args['audio_rate'],
        args['n_fft'], args['hop_len'], args['win_len'],
    )

    img_emb = get_combined_embedding(imagebind_net, text_pos, image_pos, audio_pos,
                                     text_w, image_w, audio_w)
    if img_emb is None:
        raise gr.Error("Provide at least one positive query (text, image or audio).")

    if any([text_neg, audio_neg, image_neg]):
        # Push the query embedding away from the negative-query direction.
        neg_emb = get_combined_embedding(imagebind_net, text_neg, image_neg, audio_neg,
                                         1.0, 1.0, 1.0)
        img_emb = (1 + neg_w) * img_emb - neg_w * neg_emb

    mag_mix = mag_mix.to(device)
    phase_mix = phase_mix.to(device)
    with torch.no_grad():
        pred_mask = model.module.infer(mag_mix, [img_emb])[0]

    pred_wav = recover_waveform(mag_mix, phase_mix, pred_mask, args)
    out_path = "/tmp/output.wav"
    scipy.io.wavfile.write(out_path, args['audio_rate'], pred_wav[:out_len])
    return out_path


with gr.Blocks(title="OmniSep UI") as iface:
    gr.Markdown("## 🎧 Upload Your Mixed Audio")
    mixed_audio = gr.Audio(type="filepath", label="Mixed Input Audio")

    gr.Markdown("### ✅ Positive Query")
    with gr.Row():
        pos_text = gr.Textbox(label="Text Query", placeholder="e.g. dog barking")
        pos_audio = gr.Audio(type="filepath", label="Audio Query")
        pos_image = gr.Image(type="filepath", label="Image Query")

    gr.Markdown("### ❌ Negative Query (Optional)")
    with gr.Row():
        neg_text = gr.Textbox(label="Negative Text Query")
        neg_audio = gr.Audio(type="filepath", label="Negative Audio Query")
        neg_image = gr.Image(type="filepath", label="Negative Image Query")

    gr.Markdown("### 🎚️ Modality Weights")
    with gr.Row():
        text_weight = gr.Slider(0, 5, value=1.0, step=0.1, label="Text Weight")
        image_weight = gr.Slider(0, 5, value=1.0, step=0.1, label="Image Weight")
        audio_weight = gr.Slider(0, 5, value=1.0, step=0.1, label="Audio Weight")
    neg_weight = gr.Slider(0, 2, value=0.5, step=0.1, label="Negative Embedding Weight")

    output_audio = gr.Audio(type="filepath", label="Separated Output Audio")
    btn = gr.Button("Run OmniSep Inference")
    btn.click(
        fn=run_inference,
        inputs=[mixed_audio, pos_text, pos_audio, pos_image,
                neg_text, neg_audio, neg_image,
                text_weight, image_weight, audio_weight, neg_weight],
        outputs=output_audio,
    )


if __name__ == "__main__":
    iface.launch(share=True)