import gradio as gr
import torch
import numpy as np
import soundfile as sf
from scipy.signal import resample
import matplotlib.pyplot as plt
from pathlib import Path
import librosa
import librosa.display
from matplotlib.colors import LinearSegmentedColormap
import io
from PIL import Image
import traceback
from model import SeparationModel, InputSpec
import json
# Global model storage
models = {}


def load_model(model_type="without_vad"):
    """Load the appropriate model based on user selection."""
    if model_type not in models:
        if model_type == "with_vad":
            # Load model with VAD capabilities
            model_path = "model_with_vad.pth"
            print(f"Loading model with VAD: {model_path}")
            with open('config_with_vad.json', 'r') as f:
                config = json.load(f)
            params = config.get("arch", {}).get("args", {})
        else:
            # Load standard model without VAD
            model_path = "model_without_vad.pth"
            print(f"Loading standard model: {model_path}")
            with open('config_without_vad.json', 'r') as f:
                config = json.load(f)
            params = config.get("arch", {}).get("args", {})

        try:
            model = SeparationModel(**params)
            checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
            state_dict = checkpoint['state_dict']
            model.load_state_dict(state_dict, strict=True)
            model.eval()
            models[model_type] = model
            print(f"Model {model_type} loaded successfully")
        except FileNotFoundError:
            print(f"Warning: model file {model_path} not found. Using dummy model.")
            # Create a dummy model for demonstration
            models[model_type] = create_dummy_model(with_vad=(model_type == "with_vad"))

    return models[model_type]
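

# The config files are expected to expose the SeparationModel constructor arguments
# under "arch" -> "args", as read in load_model above. A minimal, illustrative layout
# (the exact keys inside "args" depend on SeparationModel's signature and are not
# shown here):
#
# {
#   "arch": {
#     "type": "SeparationModel",
#     "args": { ... }
#   }
# }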


def create_dummy_model(with_vad=False):
    """Create a dummy model for demonstration purposes."""
    class DummyModel(torch.nn.Module):
        def __init__(self, with_vad=False):
            super().__init__()
            self.with_vad = with_vad

        def forward(self, x):
            # Simple dummy separation: two noisy copies of the input
            spk1 = x + 0.1 * torch.randn_like(x)
            spk2 = x + 0.1 * torch.randn_like(x) * -1
            if self.with_vad:
                # Dummy VAD: random activation patterns
                batch_size, seq_len = x.shape
                vad1 = torch.sigmoid(torch.randn(batch_size, seq_len // 1000))  # Downsampled VAD
                vad2 = torch.sigmoid(torch.randn(batch_size, seq_len // 1000))
                # Stack along a new speaker dimension so the output is (batch, 2, ...),
                # matching how separate_speakers() indexes the result
                return torch.stack([spk1, spk2], dim=1), torch.stack([vad1, vad2], dim=1)
            else:
                return torch.stack([spk1, spk2], dim=1)

    return DummyModel(with_vad=with_vad)


def separate_speakers(mixed_audio, model_type="without_vad"):
    """Separate speakers using the selected model."""
    model = load_model(model_type)
    print(f"The mixed audio shape is: {mixed_audio.shape}")
    audio_tensor = torch.from_numpy(mixed_audio).float().unsqueeze(0)

    with torch.no_grad():
        if model_type == "with_vad":
            separated, vad = model(audio_tensor)
            separated = separated.squeeze(0)  # Remove batch dimension
            vad = vad.squeeze(0)  # Remove batch dimension
            spk1 = separated[0].cpu().numpy()
            spk2 = separated[1].cpu().numpy()
            vad1 = vad[0].cpu().numpy()
            vad2 = vad[1].cpu().numpy()
            return spk1, spk2, vad1, vad2
        else:
            # The model may return only the separated signals or a (separated, vad) tuple;
            # accept either and ignore any VAD output in this mode
            output = model(audio_tensor)
            separated = output[0] if isinstance(output, tuple) else output
            separated = separated.squeeze(0)  # Remove batch dimension
            spk1 = separated[0].cpu().numpy()
            spk2 = separated[1].cpu().numpy()
            print(f"Separated speakers: {spk1.shape}, {spk2.shape}")
            return spk1, spk2, None, None
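

# Example (illustrative sketch, not used by the app): running separation directly on a
# 16 kHz mono wav file. "my_mixture.wav" and the output paths are placeholder names.
#
# mixture, sr = sf.read("my_mixture.wav", dtype="float32")
# if mixture.ndim > 1:
#     mixture = mixture[:, 0]          # keep the first channel
# s1, s2, v1, v2 = separate_speakers(mixture, model_type="without_vad")
# sf.write("speaker1.wav", s1, 16000)
# sf.write("speaker2.wav", s2, 16000)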


def create_spectrogram(audio, sr=16000, title="Spectrogram", vad_data=None, vad_threshold=0.7):
    """Create a styled spectrogram plot with an optional VAD overlay."""
    # Create figure with custom styling
    plt.style.use('dark_background')
    fig, ax = plt.subplots(figsize=(12, 6))

    # Custom colormap for better visual appeal
    colors = ['#0d1117', '#1f2937', '#3730a3', '#7c3aed', '#ec4899', '#f59e0b', '#eab308']
    n_bins = 256
    cmap = LinearSegmentedColormap.from_list('custom', colors, N=n_bins)

    # Parameters for InputSpec (can be adjusted)
    n_fft = 512
    hop_length = 256
    win_length = 512

    # Convert audio to torch tensor
    audio_tensor = torch.tensor(audio, dtype=torch.float32)
    if audio_tensor.dim() == 1:
        audio_tensor = audio_tensor.unsqueeze(0)  # (1, N)

    # Compute spectrogram using InputSpec
    spec_layer = InputSpec(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
    stft = spec_layer(audio_tensor)

    # Compute magnitude and convert to dB
    S = stft.abs().squeeze(0).numpy()
    S_db = librosa.amplitude_to_db(S, ref=np.max)

    # Plot using the custom colormap
    img = librosa.display.specshow(S_db, sr=sr, hop_length=hop_length, x_axis='time',
                                   y_axis='hz', ax=ax, cmap=cmap, vmin=-80, vmax=0)

    # Overlay VAD if provided
    if vad_data is not None:
        threshold = vad_threshold

        # Time axis aligned with the spectrogram frames
        vad_time_axis = librosa.frames_to_time(np.arange(len(vad_data)), sr=sr, hop_length=hop_length)

        # Convert VAD scores to binary (1 = voice, 0 = no voice)
        vad_mask = vad_data > threshold

        # Frequency range for highlighting (in Hz)
        freq_max = sr // 2
        vad_height_min = 0.15 * freq_max
        vad_height_max = freq_max

        # Set VAD values to high or low frequency based on activity
        vad_y = np.where(vad_mask, vad_height_max, vad_height_min)

        # Create a twin y-axis for the overlay
        ax2 = ax.twinx()
        # Option 1: shaded region
        # ax2.fill_between(vad_time_axis, vad_height_min, vad_y, color='#10b981', alpha=0.4)
        # Option 2 (alternative): line plot
        ax2.plot(vad_time_axis, vad_y, color="#F1F1F1", linewidth=3, alpha=0.9)

        # Adjust secondary axis
        ax2.set_ylim(0, freq_max)
        ax2.set_xlim(ax.get_xlim())
        ax2.set_ylabel('')
        ax2.set_yticks([])
        ax2.spines['right'].set_visible(False)
        ax2.spines['top'].set_visible(False)
        ax2.spines['left'].set_visible(False)

        # Add legend text
        ax2.text(0.02, 0.95, 'Voice Activity', transform=ax2.transAxes,
                 color="#F1F1F1", fontsize=10, fontweight='bold',
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7))

    # Styling
    ax.set_title(f'{title}', fontsize=16, fontweight='bold', color='white', pad=20)
    ax.set_xlabel('Time (seconds)', fontsize=12, color='white')
    ax.set_ylabel('Frequency (Hz)', fontsize=12, color='white')
    ax.grid(True, alpha=0.3)

    # Add colorbar
    cbar = fig.colorbar(img, ax=ax, format='%+2.0f dB')
    cbar.set_label('Amplitude (dB)', fontsize=12, color='white')
    cbar.ax.yaxis.set_tick_params(color='white')
    plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='white')

    # Set background
    fig.patch.set_facecolor('#0d1117')
    ax.set_facecolor('#0d1117')

    plt.tight_layout()

    # Convert to image for Gradio
    buf = io.BytesIO()
    plt.savefig(buf, format='png', facecolor='#0d1117', edgecolor='none', dpi=150)
    buf.seek(0)
    img = Image.open(buf)
    plt.close()

    return img
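

# Example (illustrative): create_spectrogram returns a PIL Image, so it can also be
# saved outside of the Gradio interface. "spectrogram.png" is a placeholder path.
#
# spec_img = create_spectrogram(mixture, sr=16000, title="Mixture")
# spec_img.save("spectrogram.png")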


def inference_gradio(audio, model_choice, vad_threshold=0.7):
    """Main inference function for the Gradio interface."""
    if audio is None:
        return None, None, None, None, None, "Please upload or record an audio file."

    try:
        print(audio)
        samplerate, audio_data = audio
        audio_array = np.array(audio_data, dtype=np.float32)

        # Handle multi-channel audio
        if audio_array.ndim > 1:
            status_msg = "Multi-channel audio detected, using first channel."
            if audio_array.shape[1] > audio_array.shape[0]:
                audio_array = audio_array[0]
            else:
                audio_array = audio_array[:, 0]
        else:
            status_msg = "Processing mono audio."

        # Resample to 16 kHz if necessary
        if samplerate != 16000:
            len_audio = len(audio_array)
            new_len = int(len_audio * 16000 / samplerate)
            audio_array = resample(audio_array, new_len)
            status_msg += f" Resampled from {samplerate}Hz to 16kHz."

        # Normalize audio to roughly [-0.9, 0.9]
        if audio_array.max() != audio_array.min():
            normalized_audio = 1.8 * (audio_array - audio_array.min()) / (audio_array.max() - audio_array.min()) - 0.9
        else:
            normalized_audio = audio_array

        # Determine model type
        print(f"Selected model: {model_choice}")
        model_type = "with_vad" if "VAD" in model_choice else "without_vad"

        # Separate speakers
        spk1, spk2, vad1, vad2 = separate_speakers(normalized_audio, model_type)

        # Create spectrograms, with VAD overlay if available
        mixed_spec = create_spectrogram(normalized_audio, title="Mixed Audio Spectrogram")
        if model_type == "with_vad" and vad1 is not None and vad2 is not None:
            spk1_spec = create_spectrogram(spk1, title="Speaker 1 Spectrogram + VAD", vad_data=vad1, vad_threshold=vad_threshold)
            spk2_spec = create_spectrogram(spk2, title="Speaker 2 Spectrogram + VAD", vad_data=vad2, vad_threshold=vad_threshold)
        else:
            spk1_spec = create_spectrogram(spk1, title="Speaker 1 Spectrogram")
            spk2_spec = create_spectrogram(spk2, title="Speaker 2 Spectrogram")

        status_msg += f" Successfully separated using {model_choice}!"

        # Return audio and visualizations
        return (
            (16000, spk1),
            (16000, spk2),
            mixed_spec,
            spk1_spec,
            spk2_spec,
            status_msg
        )
    except Exception as e:
        error_msg = f"Error during processing: {str(e)}"
        traceback.print_exc()
        return None, None, None, None, None, error_msg
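

# Example (illustrative): inference_gradio expects the Gradio "numpy" audio format,
# i.e. a (sample_rate, np.ndarray) tuple, plus the model label and the VAD threshold.
#
# spk1_out, spk2_out, mix_img, spk1_img, spk2_img, status = inference_gradio(
#     (16000, mixture), "Separation Only", vad_threshold=0.7
# )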


def list_example_audios():
    """Return a dict of example wav files in the current directory."""
    example_files = sorted(Path(".").glob("Mixed*.wav"))
    return {f.name: str(f) for f in example_files}


def load_example_audio_by_path(path):
    """Load a wav file by path for Gradio."""
    if Path(path).exists():
        audio, sr = sf.read(path)
        if audio.ndim > 1:
            audio = audio[:, 0]
        return (sr, audio)
    return None


def load_example_audio():
    """Load example audio file as a (sample_rate, np.ndarray) tuple for Gradio."""
    example_path = "example_mixed.wav"
    if Path(example_path).exists():
        audio, sr = sf.read(example_path)
        # Ensure mono for the demo
        if audio.ndim > 1:
            audio = audio[:, 0]
        return (sr, audio)
    return None


# Create the Gradio interface
def create_interface():
    example_files = list_example_audios()
    default_example = next(iter(example_files.values()), None)
    default_audio = load_example_audio_by_path(default_example) if default_example else None

    with gr.Blocks(
        css="""
        .centered {
            align-items: center;
            text-align: center;
        }
        .center-radio .gr-form {
            align-items: center;
        }
        .center-radio .gr-radio {
            display: flex;
            flex-direction: column;
            align-items: center;
        }
        """,
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="stone",
            neutral_hue="zinc"
        ).set(
            # Background colors - softer, warmer grays
            background_fill_primary="#f8f9fa",
            background_fill_secondary="#f1f3f4",
            block_background_fill="#ffffff",
            # Borders - very subtle
            border_color_primary="#e5e7eb",
            border_color_accent="#d1d5db",
            # Text colors - muted
            body_text_color="#374151",
            body_text_color_subdued="#6b7280",
            # Button colors - understated
            button_primary_background_fill="#f3f4f6",
            button_primary_background_fill_hover="#e5e7eb",
            button_primary_text_color="#374151",
            # Input fields - clean and minimal
            input_background_fill="#fafafa",
            input_background_fill_focus="#ffffff",
            input_border_color="#e5e7eb",
            input_border_color_focus="#d1d5db",
            # Accent colors - very muted
            color_accent="#8b5cf6",
            color_accent_soft="#f3f0ff",
            # Shadows - barely visible
            shadow_drop="0 1px 3px 0 rgba(0, 0, 0, 0.05)",
            shadow_drop_lg="0 4px 6px -1px rgba(0, 0, 0, 0.05)"
        )
    ) as demo:
        with gr.Column(elem_classes=["centered"]):
            gr.Markdown("""
            # Speaker Separation with Voice Activity Detection

            **Separate mixed audio into individual speakers**

            Choose between standard separation or separation with Voice Activity Detection (VAD).

            ---

            ### Example Audio
            Select an example audio file below, or upload your own!
            """)

        with gr.Row():
            with gr.Column(scale=1, elem_classes=["centered", "center-radio"]):
                gr.Markdown("### Input")

                example_selector = gr.Dropdown(
                    choices=list(example_files.keys()),
                    value=next(iter(example_files.keys()), None),
                    label="Select Example Audio",
                    interactive=True
                )

                audio_input = gr.Audio(
                    sources=["upload", "microphone"],
                    type="numpy",
                    label="Upload or Record Mixed Audio",
                    show_download_button=True,
                    value=default_audio
                )

                def update_audio_from_example(selected):
                    path = example_files.get(selected)
                    return load_example_audio_by_path(path) if path else None

                example_selector.change(
                    fn=update_audio_from_example,
                    inputs=example_selector,
                    outputs=audio_input
                )

                model_choice = gr.Radio(
                    choices=[
                        "Separation Only",
                        "Separation With VAD"
                    ],
                    value="Separation Only",
                    label="Model Selection",
                    info="VAD models provide voice activity detection for each speaker"
                )

                vad_threshold = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.01,
                    label="VAD Threshold",
                    visible=False
                )

                def toggle_vad_threshold(model_choice_val):
                    return gr.update(visible=(model_choice_val == "Separation With VAD"))

                model_choice.change(
                    fn=toggle_vad_threshold,
                    inputs=model_choice,
                    outputs=vad_threshold
                )

                process_btn = gr.Button("Separate Speakers", variant="primary", size="lg")

                status_output = gr.Textbox(
                    label="Processing Status",
                    interactive=False,
                    lines=2
                )
| gr.Markdown("### π§ Separated Audio Outputs") | |
| with gr.Row(): | |
| spk1_audio = gr.Audio(label="π€ Speaker 1", show_download_button=True) | |
| spk2_audio = gr.Audio(label="π€ Speaker 2", show_download_button=True) | |
| gr.Markdown("### π Audio Spectrograms") | |
| with gr.Row(): | |
| mixed_spec = gr.Image(label="π΅ Mixed Audio Spectrogram", height=300) | |
| with gr.Row(): | |
| spk1_spec = gr.Image(label="π€ Speaker 1 Spectrogram (with VAD overlay)", height=300) | |
| spk2_spec = gr.Image(label="π€ Speaker 2 Spectrogram (with VAD overlay)", height=300) | |
| # Process button click | |
| process_btn.click( | |
| fn=inference_gradio, | |
| inputs=[audio_input, model_choice, vad_threshold], | |
| outputs=[ | |
| spk1_audio, spk2_audio, | |
| mixed_spec, spk1_spec, spk2_spec, | |
| status_output | |
| ] | |
| ) | |
| gr.Markdown(""" | |
| --- | |
| ### π Instructions: | |
| 1. **Upload** an audio file or **record** directly using the microphone | |
| 2. **Select** your preferred model (with or without VAD) | |
| 3. **If using VAD**, adjust the threshold as needed | |
| 4. **Click** "Separate Speakers" to process | |
| 5. **Download** the separated audio files and view the spectrograms | |
| ### π§ Technical Notes: | |
| - Audio is automatically resampled to 16kHz | |
| - Multi-channel audio uses the first channel | |
| - **Spectrograms**: Show frequency content over time with VAD activity highlighted | |
| - **VAD Overlay**: A white line at the top indicates when the speaker is active | |
| """) | |
| gr.Markdown(""" | |
| --- | |
| <div style=" | |
| background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%); | |
| border-left: 4px solid #3b82f6; | |
| border-radius: 8px; | |
| padding: 20px; | |
| margin: 20px 0; | |
| box-shadow: 0 2px 8px rgba(0,0,0,0.1); | |
| "> | |
| <div style="margin-bottom: 15px;"> | |
| <h3 style="color: #1e40af; margin: 0 0 10px 0; font-size: 1.1em; display: flex; align-items: center;"> | |
| π <span style="margin-left: 8px;">Reference</span> | |
| </h3> | |
| <div style=" | |
| background: white; | |
| padding: 15px; | |
| border-radius: 6px; | |
| border: 1px solid #e2e8f0; | |
| line-height: 1.6; | |
| font-size: 0.95em; | |
| "> | |
| <strong>Opochinsky, R., Moradi, M., & Gannot, S.</strong> (2025).<br> | |
| <em style="color: #374151; font-size: 1.02em;">Single-microphone speaker separation and voice activity detection in noisy and reverberant environments</em>.<br> | |
| <span style="color: #6b7280;">EURASIP Journal on Audio, Speech, and Music Processing</span>, <strong>2025</strong>(1), 18. Springer. | |
| </div> | |
| </div> | |
| <details style="margin-top: 15px;"> | |
| <summary style=" | |
| cursor: pointer; | |
| color: #4f46e5; | |
| font-weight: 600; | |
| padding: 8px 0; | |
| border-bottom: 1px solid #e5e7eb; | |
| margin-bottom: 10px; | |
| user-select: none; | |
| ">π BibTeX Citation</summary> | |
| <div style=" | |
| background: #1f2937; | |
| color: #f9fafb; | |
| padding: 15px; | |
| border-radius: 6px; | |
| font-family: 'Courier New', monospace; | |
| font-size: 0.85em; | |
| line-height: 1.4; | |
| overflow-x: auto; | |
| margin-top: 10px; | |
| "> | |
| <pre style="margin: 0; white-space: pre-wrap;">@article{opochinsky2025single, | |
| title={Single-microphone speaker separation and voice activity detection in noisy and reverberant environments}, | |
| author={Opochinsky, Renana and Moradi, Mordehay and Gannot, Sharon}, | |
| journal={EURASIP Journal on Audio, Speech, and Music Processing}, | |
| volume={2025}, | |
| number={1}, | |
| pages={18}, | |
| year={2025}, | |
| publisher={Springer} | |
| }</pre> | |
| </div> | |
| </details> | |
| </div> | |
| """) | |
| return demo | |


if __name__ == "__main__":
    # Pre-load models to check availability
    print("Initializing Speaker Separation...")
    load_model("without_vad")
    # load_model("with_vad")

    # Launch interface
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True
    )