import gradio as gr
import torch
import torchaudio
import numpy as np
from asteroid.models import ConvTasNet
from speechbrain.pretrained import SepformerSeparation
from scipy.io import wavfile
from scipy import signal
import librosa
import librosa.display
import matplotlib.pyplot as plt
import noisereduce as nr
import io

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load both separation models once at import time so every Gradio request
# pays only inference cost, not download/initialisation cost.
print("Loading ConvTasNet...")
convtasnet_model = ConvTasNet.from_pretrained("JorisCos/ConvTasNet_Libri2Mix_sepclean_16k")
convtasnet_model = convtasnet_model.to(DEVICE)
convtasnet_model.eval()

print("Loading SepFormer...")
sepformer_model = SepformerSeparation.from_hparams(
    source="speechbrain/sepformer-wsj02mix",
    savedir="pretrained_models/sepformer-wsj02mix",
    run_opts={"device": DEVICE},
)


def apply_highpass_filter(audio, sr, cutoff=80):
    """Apply a zero-phase Butterworth high-pass filter to remove low-frequency rumble.

    Args:
        audio: 1-D numpy array of float samples.
        sr: sample rate in Hz.
        cutoff: high-pass cutoff frequency in Hz.

    Returns:
        Filtered audio; the input unchanged if it is too short or if filter
        design/application fails (enhancement is best-effort).
    """
    # filtfilt requires a minimum signal length; pass very short clips through.
    if len(audio) < 18:
        return audio
    try:
        nyquist = sr / 2
        normalized_cutoff = cutoff / nyquist
        # Cap the order on short signals so filtfilt's edge padding stays valid.
        filter_order = min(4, max(1, len(audio) // 10))
        b, a = signal.butter(filter_order, normalized_cutoff, btype='high', analog=False)
        padlen = min(len(audio) // 3, 3 * max(len(a), len(b)))
        return signal.filtfilt(b, a, audio, padlen=padlen)
    except Exception:
        # Deliberate best-effort: any filtering failure returns the input untouched.
        return audio


def normalize_audio(audio, target_level=-20):
    """Scale audio to a target RMS level (in dBFS) and hard-clip to [-1, 1].

    Silent (all-zero) or empty input is returned without rescaling.
    """
    if len(audio) == 0:
        # Guard: np.mean of an empty array is NaN and emits a RuntimeWarning.
        return audio
    rms = np.sqrt(np.mean(audio ** 2))
    if rms > 0:
        target_rms = 10 ** (target_level / 20)
        audio = audio * (target_rms / rms)
    return np.clip(audio, -1.0, 1.0)


def apply_gate(audio, threshold=-40):
    """Apply a simple noise gate based on the Hilbert envelope.

    Samples whose envelope falls below `threshold` (dB) are attenuated.
    Returns the input unchanged on very short clips or on any failure.
    """
    if len(audio) < 10:
        return audio
    try:
        threshold_linear = 10 ** (threshold / 20)
        envelope = np.abs(signal.hilbert(audio))
        gate_mask = envelope > threshold_linear
        # Smooth the binary mask to avoid clicks at gate transitions.
        # NOTE(review): window is 0.1% of the clip LENGTH, not a fixed duration
        # in seconds — confirm this matches the intended smoothing time.
        window_size = max(1, int(len(audio) * 0.001))
        if 1 < window_size < len(gate_mask):
            gate_mask = signal.convolve(
                gate_mask.astype(float),
                np.ones(window_size) / window_size,
                mode='same',
            )
        return audio * gate_mask
    except Exception:
        return audio


def reduce_musical_noise(audio, sr):
    """Run non-stationary spectral noise reduction over a separated source.

    Best-effort: returns the input unchanged on short clips or any failure.
    """
    if len(audio) < 100:
        return audio
    try:
        return nr.reduce_noise(y=audio, sr=sr, stationary=False, prop_decrease=0.6)
    except Exception:
        return audio


def enhance_separation(audio, sr, is_convtasnet=True):
    """Post-process one separated source: high-pass, denoise, gate, normalize.

    Args:
        audio: 1-D numpy array of float samples.
        sr: sample rate in Hz.
        is_convtasnet: True when the source came from ConvTasNet; enables the
            extra denoising pass and a slightly higher gate threshold.

    Returns:
        Enhanced audio clipped to [-1, 1]; short clips are returned as-is.
    """
    if len(audio) < 100:
        return audio
    audio = apply_highpass_filter(audio, sr, cutoff=80)
    if is_convtasnet:
        # Presumably to suppress separation artifacts specific to ConvTasNet output.
        audio = reduce_musical_noise(audio, sr)
    threshold = -40 if is_convtasnet else -45
    audio = apply_gate(audio, threshold=threshold)
    return normalize_audio(audio, target_level=-20)


def separate_audio(audio_file, model_choice):
    """Separate a two-speaker mixture into two enhanced WAV files.

    Args:
        audio_file: path to the uploaded mixture.
        model_choice: "ConvTasNet" (16 kHz) or "SepFormer" (8 kHz).

    Returns:
        (speaker1_path, speaker2_path, status_message)
    """
    waveform, sample_rate = torchaudio.load(audio_file)

    # Downmix multichannel input to mono.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Each model was trained at a fixed sample rate; outputs are saved at it.
    target_sr = 16000 if model_choice == "ConvTasNet" else 8000

    if model_choice == "ConvTasNet":
        # Only the ConvTasNet path consumes `waveform`, so resample here.
        if sample_rate != target_sr:
            resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
            waveform = resampler(waveform)
        sample_rate = target_sr
        with torch.no_grad():
            waveform = waveform.to(DEVICE)
            separated = convtasnet_model(waveform.unsqueeze(0))
        separated = separated.squeeze(0).cpu()
        source1 = separated[0].numpy()
        source2 = separated[1].numpy()
    else:
        # SepFormer reads (and resamples) the file itself; the loaded waveform
        # is not needed here. Its output is at the model's 8 kHz rate.
        sample_rate = target_sr
        separated = sepformer_model.separate_file(path=audio_file)
        separated = separated.squeeze()
        # Normalise orientation to (n_sources, time) if it came back transposed.
        if len(separated.shape) == 2:
            if separated.shape[1] == 2 and separated.shape[0] > separated.shape[1]:
                separated = separated.T
        source1 = separated[0].cpu().numpy() if isinstance(separated[0], torch.Tensor) else separated[0]
        source2 = separated[1].cpu().numpy() if isinstance(separated[1], torch.Tensor) else separated[1]

    # Post-process each source.
    is_convtasnet = (model_choice == "ConvTasNet")
    source1 = enhance_separation(source1, sample_rate, is_convtasnet)
    source2 = enhance_separation(source2, sample_rate, is_convtasnet)

    # Save as 16-bit PCM WAV (enhance_separation clips to [-1, 1] first).
    output1 = "speaker1.wav"
    output2 = "speaker2.wav"
    wavfile.write(output1, sample_rate, (source1 * 32767).astype(np.int16))
    wavfile.write(output2, sample_rate, (source2 * 32767).astype(np.int16))

    return output1, output2, f"Separated using {model_choice} with enhancement"


# Gradio Interface
demo = gr.Interface(
    fn=separate_audio,
    inputs=[
        gr.Audio(type="filepath", label="Upload Mixed Audio"),
        gr.Radio(["ConvTasNet", "SepFormer"], value="ConvTasNet", label="Model"),
    ],
    outputs=[
        gr.Audio(label="Speaker 1"),
        gr.Audio(label="Speaker 2"),
        gr.Textbox(label="Status"),
    ],
    title="🎵 Audio Source Separator",
    description="Upload mixed audio to separate it into individual speakers. Enhancement is automatically applied.",
    theme="soft",
)

if __name__ == "__main__":
    demo.launch()