# Hugging Face Space: two-speaker audio source separation demo
# (the Space status page reported "Runtime error" at capture time)
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| from asteroid.models import ConvTasNet | |
| from speechbrain.pretrained import SepformerSeparation | |
| from scipy.io import wavfile | |
| from scipy import signal | |
| import librosa | |
| import librosa.display | |
| import matplotlib.pyplot as plt | |
| import noisereduce as nr | |
| import io | |
# Run on GPU when one is available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Pretrained separation models -------------------------------------------
# ConvTasNet trained on Libri2Mix (clean separation, 16 kHz).
print("Loading ConvTasNet...")
convtasnet_model = ConvTasNet.from_pretrained(
    "JorisCos/ConvTasNet_Libri2Mix_sepclean_16k"
).to(DEVICE)
convtasnet_model.eval()

# SepFormer trained on WSJ0-2Mix (8 kHz); SpeechBrain handles device placement
# through run_opts and caches the checkpoint under savedir.
print("Loading SepFormer...")
sepformer_model = SepformerSeparation.from_hparams(
    source="speechbrain/sepformer-wsj02mix",
    savedir="pretrained_models/sepformer-wsj02mix",
    run_opts={"device": DEVICE},
)
def apply_highpass_filter(audio, sr, cutoff=80):
    """Apply a zero-phase Butterworth high-pass filter to remove rumble.

    Args:
        audio: 1-D float array of samples, nominally in [-1, 1].
        sr: Sample rate in Hz.
        cutoff: High-pass cutoff frequency in Hz (default 80, below speech).

    Returns:
        The filtered signal (same length as the input), or the input
        unchanged when the signal is too short for filtfilt's edge padding
        or filtering fails for any reason (best-effort enhancement).
    """
    if len(audio) < 18:
        # Too few samples for filtfilt's padding requirements.
        return audio
    try:
        nyquist = sr / 2
        normalized_cutoff = cutoff / nyquist
        # Keep the order small for short signals so the padded filter fits.
        filter_order = min(4, max(1, len(audio) // 10))
        b, a = signal.butter(filter_order, normalized_cutoff, btype='high', analog=False)
        # filtfilt requires len(audio) > padlen; cap it below a third of the signal.
        padlen = min(len(audio) // 3, 3 * max(len(a), len(b)))
        return signal.filtfilt(b, a, audio, padlen=padlen)
    except Exception:
        # Narrowed from a bare `except:` (which also swallowed SystemExit /
        # KeyboardInterrupt). On any failure — e.g. cutoff >= Nyquist or
        # sr == 0 — fall back to the unfiltered audio.
        return audio
def normalize_audio(audio, target_level=-20):
    """Normalize *audio* to *target_level* dBFS RMS, then hard-clip to [-1, 1].

    Silent input (zero RMS) is left unscaled and only clipped.
    """
    current_rms = np.sqrt(np.mean(np.square(audio)))
    if current_rms > 0:
        gain = (10.0 ** (target_level / 20.0)) / current_rms
        audio = audio * gain
    return np.clip(audio, -1.0, 1.0)
def apply_gate(audio, threshold=-40):
    """Noise gate: attenuate samples whose envelope is below *threshold* dB.

    Computes the amplitude envelope via the Hilbert analytic signal, builds
    an on/off mask against the linear threshold, and smooths the mask with a
    short moving average (0.1% of the signal length) to soften gate edges.

    Args:
        audio: 1-D float array of samples.
        threshold: Gate threshold in dB relative to full scale.

    Returns:
        The gated signal, or the input unchanged when it is shorter than
        10 samples or gating fails (best-effort enhancement).
    """
    if len(audio) < 10:
        return audio
    try:
        threshold_linear = 10 ** (threshold / 20)
        # Amplitude envelope from the analytic signal.
        envelope = np.abs(signal.hilbert(audio))
        gate_mask = envelope > threshold_linear
        # Smooth the binary mask to avoid hard on/off clicks.
        window_size = max(1, int(len(audio) * 0.001))
        if 1 < window_size < len(gate_mask):
            gate_mask = signal.convolve(gate_mask.astype(float),
                                        np.ones(window_size) / window_size,
                                        mode='same')
        return audio * gate_mask
    except Exception:
        # Narrowed from a bare `except:`; never let enhancement crash the app.
        return audio
def reduce_musical_noise(audio, sr):
    """Suppress "musical noise" artifacts with spectral noise reduction.

    Uses non-stationary noise reduction at moderate strength
    (prop_decrease=0.6) so speech content is not over-suppressed.

    Args:
        audio: 1-D float array of samples.
        sr: Sample rate in Hz.

    Returns:
        The denoised signal, or the input unchanged when it is shorter than
        100 samples or reduction fails (best-effort enhancement).
    """
    if len(audio) < 100:
        return audio
    try:
        return nr.reduce_noise(y=audio, sr=sr, stationary=False, prop_decrease=0.6)
    except Exception:
        # Narrowed from a bare `except:`; fall back to the unprocessed audio.
        return audio
def enhance_separation(audio, sr, is_convtasnet=True):
    """Post-process one separated source: filter, denoise, gate, normalize.

    ConvTasNet output additionally gets musical-noise reduction and a
    slightly higher gate threshold (-40 dB vs. -45 dB for SepFormer).
    Signals shorter than 100 samples are returned untouched.
    """
    if len(audio) < 100:
        return audio
    processed = apply_highpass_filter(audio, sr, cutoff=80)
    if is_convtasnet:
        processed = reduce_musical_noise(processed, sr)
    gate_threshold = -40 if is_convtasnet else -45
    processed = apply_gate(processed, threshold=gate_threshold)
    return normalize_audio(processed, target_level=-20)
def separate_audio(audio_file, model_choice):
    """Separate a two-speaker mixture into two individual tracks.

    Args:
        audio_file: Path to the uploaded mixture (any format torchaudio reads),
            or None when nothing was uploaded.
        model_choice: "ConvTasNet" (16 kHz) or "SepFormer" (8 kHz).

    Returns:
        (speaker1_wav_path, speaker2_wav_path, status_message); the paths are
        None when no file was provided.
    """
    import tempfile

    # Guard: Gradio passes None when the user hits Submit without a file.
    if not audio_file:
        return None, None, "Please upload an audio file first."

    waveform, sample_rate = torchaudio.load(audio_file)

    # Downmix to mono: both models expect single-channel input.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample to the selected model's native rate.
    target_sr = 16000 if model_choice == "ConvTasNet" else 8000
    if sample_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
        waveform = resampler(waveform)
        sample_rate = target_sr

    if model_choice == "ConvTasNet":
        with torch.no_grad():
            waveform = waveform.to(DEVICE)
            separated = convtasnet_model(waveform.unsqueeze(0))
        separated = separated.squeeze(0).cpu()
        source1 = separated[0].numpy()
        source2 = separated[1].numpy()
    else:
        # SepFormer loads and resamples from the file path itself.
        separated = sepformer_model.separate_file(path=audio_file)
        separated = separated.squeeze()
        # Output may come back as (time, n_src); put sources on axis 0.
        if len(separated.shape) == 2:
            if separated.shape[1] == 2 and separated.shape[0] > separated.shape[1]:
                separated = separated.T
        source1 = separated[0].cpu().numpy() if isinstance(separated[0], torch.Tensor) else separated[0]
        source2 = separated[1].cpu().numpy() if isinstance(separated[1], torch.Tensor) else separated[1]

    # Post-process each source (filtering, gating, normalization).
    is_convtasnet = (model_choice == "ConvTasNet")
    source1 = enhance_separation(source1, sample_rate, is_convtasnet)
    source2 = enhance_separation(source2, sample_rate, is_convtasnet)

    # Write each source to a unique temp file so concurrent requests cannot
    # overwrite each other (the old fixed "speaker1.wav"/"speaker2.wav" in
    # the CWD raced between sessions).
    output_paths = []
    for source in (source1, source2):
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp.close()
        # enhance_separation clips to [-1, 1], so int16 scaling cannot wrap.
        wavfile.write(tmp.name, sample_rate, (source * 32767).astype(np.int16))
        output_paths.append(tmp.name)

    return output_paths[0], output_paths[1], f"Separated using {model_choice} with enhancement"
# --- Gradio UI ---------------------------------------------------------------
# Two inputs (audio file + model picker), three outputs (two stems + status).
_inputs = [
    gr.Audio(type="filepath", label="Upload Mixed Audio"),
    gr.Radio(["ConvTasNet", "SepFormer"], value="ConvTasNet", label="Model"),
]
_outputs = [
    gr.Audio(label="Speaker 1"),
    gr.Audio(label="Speaker 2"),
    gr.Textbox(label="Status"),
]
demo = gr.Interface(
    fn=separate_audio,
    inputs=_inputs,
    outputs=_outputs,
    title="🎵 Audio Source Separator",
    description="Upload mixed audio to separate it into individual speakers. Enhancement is automatically applied.",
    theme="soft",
)

if __name__ == "__main__":
    demo.launch()