# app.py — two-speaker audio source separation demo (Gradio).
# Provenance: Hugging Face Space by snowsafed, commit 396b3c3 ("Create app.py", verified).
import gradio as gr
import torch
import torchaudio
import numpy as np
from asteroid.models import ConvTasNet
from speechbrain.pretrained import SepformerSeparation
from scipy.io import wavfile
from scipy import signal
import librosa
import librosa.display
import matplotlib.pyplot as plt
import noisereduce as nr
import io
# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load models
# Both checkpoints are downloaded and moved to DEVICE once at startup, so each
# Gradio request only pays for inference.
print("Loading ConvTasNet...")
# ConvTasNet trained on Libri2Mix clean separation at 16 kHz.
convtasnet_model = ConvTasNet.from_pretrained("JorisCos/ConvTasNet_Libri2Mix_sepclean_16k")
convtasnet_model = convtasnet_model.to(DEVICE)
convtasnet_model.eval()  # inference mode: freeze dropout/batch-norm behavior
print("Loading SepFormer...")
# SepFormer trained on WSJ0-2mix (8 kHz); SpeechBrain caches files in savedir.
sepformer_model = SepformerSeparation.from_hparams(
    source="speechbrain/sepformer-wsj02mix",
    savedir="pretrained_models/sepformer-wsj02mix",
    run_opts={"device": DEVICE}
)
def apply_highpass_filter(audio, sr, cutoff=80):
    """Remove low-frequency rumble with a zero-phase Butterworth high-pass filter.

    Args:
        audio: 1-D float array of audio samples.
        sr: sample rate in Hz.
        cutoff: high-pass cutoff frequency in Hz (default 80, below speech).

    Returns:
        The filtered signal, or the input unchanged if it is too short to
        filter or filtering fails.
    """
    # filtfilt needs enough samples for its edge padding; bail out early.
    if len(audio) < 18:
        return audio
    try:
        nyquist = sr / 2
        normalized_cutoff = cutoff / nyquist
        # Reduce the filter order for very short signals to keep filtfilt stable.
        filter_order = min(4, max(1, len(audio) // 10))
        b, a = signal.butter(filter_order, normalized_cutoff, btype='high', analog=False)
        # Cap padlen so it is always valid for the signal length.
        padlen = min(len(audio) // 3, 3 * max(len(a), len(b)))
        return signal.filtfilt(b, a, audio, padlen=padlen)
    except Exception:
        # Was a bare `except:` (also swallowed KeyboardInterrupt/SystemExit);
        # narrowed. Best-effort: return the audio untouched on failure.
        return audio
def normalize_audio(audio, target_level=-20):
    """Scale *audio* so its RMS matches target_level dBFS, then clip to [-1, 1].

    Silent (zero-energy) input is left unscaled; only clipping is applied.
    """
    current_rms = np.sqrt(np.mean(np.square(audio)))
    if not current_rms > 0:
        return np.clip(audio, -1.0, 1.0)
    gain = (10.0 ** (target_level / 20)) / current_rms
    return np.clip(audio * gain, -1.0, 1.0)
def apply_gate(audio, threshold=-40):
    """Noise gate: attenuate samples whose Hilbert envelope is below threshold dBFS.

    The binary gate mask is smoothed with a short moving average (~0.1% of the
    signal length) so gating transitions fade instead of clicking.

    Args:
        audio: 1-D float array of samples.
        threshold: gate threshold in dBFS (default -40).

    Returns:
        The gated audio, or the input unchanged on very short input or failure.
    """
    # hilbert() needs a minimum number of samples to give a usable envelope.
    if len(audio) < 10:
        return audio
    try:
        threshold_linear = 10**(threshold/20)
        envelope = np.abs(signal.hilbert(audio))
        gate_mask = envelope > threshold_linear
        window_size = max(1, int(len(audio) * 0.001))
        if window_size > 1 and window_size < len(gate_mask):
            # Smooth the boolean mask into a soft 0..1 gain curve.
            gate_mask = signal.convolve(gate_mask.astype(float),
                                        np.ones(window_size)/window_size,
                                        mode='same')
        return audio * gate_mask
    except Exception:
        # Was a bare `except:` (also swallowed KeyboardInterrupt/SystemExit);
        # narrowed. Best-effort: keep the original audio if gating fails.
        return audio
def reduce_musical_noise(audio, sr):
    """Attenuate the 'musical noise' artifacts that separation models leave behind.

    Uses noisereduce's non-stationary spectral gating at moderate strength
    (prop_decrease=0.6) so speech content is largely preserved.

    Args:
        audio: 1-D float array of samples.
        sr: sample rate in Hz.

    Returns:
        The denoised audio, or the input unchanged on very short input or failure.
    """
    # Spectral gating needs a reasonable number of samples to estimate noise.
    if len(audio) < 100:
        return audio
    try:
        return nr.reduce_noise(y=audio, sr=sr, stationary=False, prop_decrease=0.6)
    except Exception:
        # Was a bare `except:` (also swallowed KeyboardInterrupt/SystemExit);
        # narrowed. Best-effort: fall back to the unprocessed audio.
        return audio
def enhance_separation(audio, sr, is_convtasnet=True):
    """Post-process one separated source: high-pass, denoise, gate, and normalize.

    Args:
        audio: 1-D separated source samples.
        sr: sample rate in Hz.
        is_convtasnet: ConvTasNet output additionally gets musical-noise
            reduction and a slightly higher gate threshold (-40 dB vs -45 dB
            for SepFormer).

    Returns:
        The enhanced audio at -20 dBFS RMS; inputs shorter than 100 samples
        are returned untouched.
    """
    if len(audio) < 100:
        return audio
    processed = apply_highpass_filter(audio, sr, cutoff=80)
    if is_convtasnet:
        processed = reduce_musical_noise(processed, sr)
    gate_threshold = -40 if is_convtasnet else -45
    processed = apply_gate(processed, threshold=gate_threshold)
    return normalize_audio(processed, target_level=-20)
def separate_audio(audio_file, model_choice):
    """Separate a two-speaker mixture into two enhanced WAV files.

    Args:
        audio_file: path to the uploaded mixed-audio file (Gradio filepath input).
        model_choice: "ConvTasNet" or "SepFormer".

    Returns:
        Tuple of (speaker1 WAV path, speaker2 WAV path, status message).
    """
    # Load audio
    waveform, sample_rate = torchaudio.load(audio_file)
    # Convert to mono by averaging channels.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # Resample to the rate each checkpoint expects (ConvTasNet 16 kHz,
    # SepFormer WSJ0-2mix 8 kHz).
    target_sr = 16000 if model_choice == "ConvTasNet" else 8000
    if sample_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
        waveform = resampler(waveform)
        sample_rate = target_sr
    # Separate
    if model_choice == "ConvTasNet":
        with torch.no_grad():
            waveform = waveform.to(DEVICE)
            # Model takes a batch dimension; output presumably (1, n_src, time)
            # given the squeeze/index pattern below.
            separated = convtasnet_model(waveform.unsqueeze(0))
            separated = separated.squeeze(0).cpu()
            source1 = separated[0].numpy()
            source2 = separated[1].numpy()
    else:
        # NOTE(review): separate_file() reads the file from disk itself, so the
        # resampled `waveform` above is unused on this branch — confirm intended.
        separated = sepformer_model.separate_file(path=audio_file)
        separated = separated.squeeze()
        # Ensure sources are on the first axis; transpose if the array looks
        # like (time, 2) rather than (2, time).
        if len(separated.shape) == 2:
            if separated.shape[1] == 2 and separated.shape[0] > separated.shape[1]:
                separated = separated.T
        source1 = separated[0].cpu().numpy() if isinstance(separated[0], torch.Tensor) else separated[0]
        source2 = separated[1].cpu().numpy() if isinstance(separated[1], torch.Tensor) else separated[1]
    # Enhance
    is_convtasnet = (model_choice == "ConvTasNet")
    source1 = enhance_separation(source1, sample_rate, is_convtasnet)
    source2 = enhance_separation(source2, sample_rate, is_convtasnet)
    # Save as 16-bit PCM WAV (enhance_separation clips long inputs to [-1, 1]
    # before the int16 scaling here).
    output1 = "speaker1.wav"
    output2 = "speaker2.wav"
    wavfile.write(output1, sample_rate, (source1 * 32767).astype(np.int16))
    wavfile.write(output2, sample_rate, (source2 * 32767).astype(np.int16))
    return output1, output2, f"Separated using {model_choice} with enhancement"
# Gradio Interface
# Wires separate_audio into a simple web UI: one audio upload plus a model
# picker in; two separated tracks plus a status line out.
demo = gr.Interface(
    fn=separate_audio,
    inputs=[
        gr.Audio(type="filepath", label="Upload Mixed Audio"),
        gr.Radio(["ConvTasNet", "SepFormer"], value="ConvTasNet", label="Model")
    ],
    outputs=[
        gr.Audio(label="Speaker 1"),
        gr.Audio(label="Speaker 2"),
        gr.Textbox(label="Status")
    ],
    title="🎵 Audio Source Separator",
    description="Upload mixed audio to separate it into individual speakers. Enhancement is automatically applied.",
    theme="soft"
)

# Launch the web server only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()