snowsafed commited on
Commit
5602abd
·
verified ·
1 Parent(s): f488e69

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +213 -0
app.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchaudio
4
+ import numpy as np
5
+ from asteroid.models import ConvTasNet
6
+ from speechbrain.pretrained import SepformerSeparation
7
+ from scipy.io import wavfile
8
+ from scipy import signal
9
+ import noisereduce as nr
10
+ import warnings
11
+ warnings.filterwarnings('ignore')
12
+
13
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
+ print(f"Using device: {DEVICE}")
15
+
16
+ # Global model variables
17
+ convtasnet_model = None
18
+ sepformer_model = None
19
+
20
+ def load_convtasnet():
21
+ global convtasnet_model
22
+ if convtasnet_model is None:
23
+ print("Loading ConvTasNet model...")
24
+ convtasnet_model = ConvTasNet.from_pretrained("JorisCos/ConvTasNet_Libri2Mix_sepclean_16k")
25
+ convtasnet_model = convtasnet_model.to(DEVICE)
26
+ convtasnet_model.eval()
27
+ print("ConvTasNet loaded!")
28
+ return convtasnet_model
29
+
30
+ def load_sepformer():
31
+ global sepformer_model
32
+ if sepformer_model is None:
33
+ print("Loading SepFormer model...")
34
+ sepformer_model = SepformerSeparation.from_hparams(
35
+ source="speechbrain/sepformer-wsj02mix",
36
+ savedir="pretrained_models/sepformer-wsj02mix",
37
+ run_opts={"device": DEVICE}
38
+ )
39
+ print("SepFormer loaded!")
40
+ return sepformer_model
41
+
42
+ def apply_highpass_filter(audio, sr, cutoff=80):
43
+ if len(audio) < 18:
44
+ return audio
45
+ try:
46
+ nyquist = sr / 2
47
+ normalized_cutoff = cutoff / nyquist
48
+ filter_order = min(4, max(1, len(audio) // 10))
49
+ b, a = signal.butter(filter_order, normalized_cutoff, btype='high', analog=False)
50
+ padlen = min(len(audio) // 3, 3 * max(len(a), len(b)))
51
+ filtered = signal.filtfilt(b, a, audio, padlen=padlen)
52
+ return filtered
53
+ except:
54
+ return audio
55
+
56
+ def normalize_audio(audio, target_level=-20):
57
+ rms = np.sqrt(np.mean(audio**2))
58
+ if rms > 0:
59
+ target_rms = 10**(target_level/20)
60
+ audio = audio * (target_rms / rms)
61
+ return np.clip(audio, -1.0, 1.0)
62
+
63
+ def apply_gate(audio, threshold=-40):
64
+ if len(audio) < 10:
65
+ return audio
66
+ try:
67
+ threshold_linear = 10**(threshold/20)
68
+ envelope = np.abs(signal.hilbert(audio))
69
+ gate_mask = envelope > threshold_linear
70
+ window_size = max(1, int(len(audio) * 0.001))
71
+ if window_size > 1 and window_size < len(gate_mask):
72
+ gate_mask = signal.convolve(gate_mask.astype(float),
73
+ np.ones(window_size)/window_size,
74
+ mode='same')
75
+ return audio * gate_mask
76
+ except:
77
+ return audio
78
+
79
+ def reduce_musical_noise(audio, sr):
80
+ if len(audio) < 100:
81
+ return audio
82
+ try:
83
+ reduced = nr.reduce_noise(y=audio, sr=sr, stationary=False, prop_decrease=0.6)
84
+ return reduced
85
+ except:
86
+ return audio
87
+
88
+ def enhance_separation(audio, sr, is_convtasnet=True):
89
+ if len(audio) < 100:
90
+ return audio
91
+ audio = apply_highpass_filter(audio, sr, cutoff=80)
92
+ if is_convtasnet:
93
+ audio = reduce_musical_noise(audio, sr)
94
+ threshold = -40 if is_convtasnet else -45
95
+ audio = apply_gate(audio, threshold=threshold)
96
+ audio = normalize_audio(audio, target_level=-20)
97
+ return audio
98
+
99
+ def separate_audio(audio_file, model_choice):
100
+ try:
101
+ # Load audio
102
+ waveform, sample_rate = torchaudio.load(audio_file)
103
+
104
+ # Convert to mono
105
+ if waveform.shape[0] > 1:
106
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
107
+
108
+ # Resample
109
+ target_sr = 16000 if model_choice == "ConvTasNet" else 8000
110
+ if sample_rate != target_sr:
111
+ resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
112
+ waveform = resampler(waveform)
113
+ sample_rate = target_sr
114
+
115
+ # Separate based on model choice
116
+ if model_choice == "ConvTasNet":
117
+ model = load_convtasnet()
118
+ with torch.no_grad():
119
+ waveform_input = waveform.to(DEVICE)
120
+ separated = model(waveform_input.unsqueeze(0))
121
+ separated = separated.squeeze(0).cpu()
122
+ source1 = separated[0].numpy()
123
+ source2 = separated[1].numpy()
124
+ else: # SepFormer
125
+ model = load_sepformer()
126
+ separated = model.separate_file(path=audio_file)
127
+ separated = separated.squeeze()
128
+
129
+ # Handle shape
130
+ if len(separated.shape) == 2:
131
+ if separated.shape[1] == 2 and separated.shape[0] > separated.shape[1]:
132
+ separated = separated.T
133
+ source1 = separated[0].cpu().numpy() if isinstance(separated[0], torch.Tensor) else separated[0]
134
+ source2 = separated[1].cpu().numpy() if isinstance(separated[1], torch.Tensor) else separated[1]
135
+ else:
136
+ raise ValueError(f"Unexpected shape: {separated.shape}")
137
+
138
+ # Enhance audio (always on)
139
+ is_convtasnet = (model_choice == "ConvTasNet")
140
+ source1 = enhance_separation(source1, sample_rate, is_convtasnet)
141
+ source2 = enhance_separation(source2, sample_rate, is_convtasnet)
142
+
143
+ # Save as WAV files
144
+ output1 = "speaker1.wav"
145
+ output2 = "speaker2.wav"
146
+ wavfile.write(output1, sample_rate, (source1 * 32767).astype(np.int16))
147
+ wavfile.write(output2, sample_rate, (source2 * 32767).astype(np.int16))
148
+
149
+ status = f"✅ Separation complete using {model_choice} with audio enhancement"
150
+ return output1, output2, status
151
+
152
+ except Exception as e:
153
+ error_msg = f"❌ Error: {str(e)}"
154
+ print(error_msg)
155
+ import traceback
156
+ traceback.print_exc()
157
+ return None, None, error_msg
158
+
159
+ # Create Gradio Interface
160
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
161
+ gr.Markdown(
162
+ """
163
+ # 🎵 Audio Source Separator
164
+ Upload mixed audio to separate it into individual speakers using AI.
165
+ Enhancement is automatically applied for best quality.
166
+ """
167
+ )
168
+
169
+ with gr.Row():
170
+ with gr.Column():
171
+ audio_input = gr.Audio(
172
+ label="Upload Mixed Audio",
173
+ type="filepath"
174
+ )
175
+ model_choice = gr.Radio(
176
+ ["ConvTasNet", "SepFormer"],
177
+ label="Select Model",
178
+ value="ConvTasNet",
179
+ info="ConvTasNet: Faster | SepFormer: Higher Quality"
180
+ )
181
+ separate_btn = gr.Button("🚀 Separate Audio", variant="primary")
182
+
183
+ with gr.Column():
184
+ status_output = gr.Textbox(label="Status", interactive=False)
185
+
186
+ with gr.Row():
187
+ audio_output1 = gr.Audio(label="🎤 Speaker 1")
188
+ audio_output2 = gr.Audio(label="🎤 Speaker 2")
189
+
190
+ gr.Markdown(
191
+ """
192
+ ### 📝 How to Use:
193
+ 1. Upload your mixed audio file (MP3, WAV, etc.)
194
+ 2. Choose a model (ConvTasNet is faster, SepFormer is more accurate)
195
+ 3. Click "Separate Audio" and wait
196
+ 4. Download the separated audio files
197
+
198
+ **Note:** First separation takes longer as models load. Subsequent separations are faster!
199
+ """
200
+ )
201
+
202
+ separate_btn.click(
203
+ fn=separate_audio,
204
+ inputs=[audio_input, model_choice],
205
+ outputs=[audio_output1, audio_output2, status_output]
206
+ )
207
+
208
+ # Preload models on startup
209
+ print("Preloading ConvTasNet model...")
210
+ load_convtasnet()
211
+
212
+ if __name__ == "__main__":
213
+ demo.launch()