import gradio as gr
import torch
import numpy as np
import soundfile as sf
from scipy.signal import resample
import matplotlib.pyplot as plt
from pathlib import Path
import librosa
import librosa.display
from matplotlib.colors import LinearSegmentedColormap
import io
from PIL import Image
import traceback
from model import SeparationModel, InputSpec
import json
# Global model storage
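# Loaded models are cached by type ("with_vad" / "without_vad") so each checkpoint is read only once.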
models = {}
def load_model(model_type="without_vad"):
"""Load the appropriate model based on user selection"""
if model_type not in models:
if model_type == "with_vad":
# Load model with VAD capabilities
model_path = "model_with_vad.pth"
print(f"Loading model with VAD: {model_path}")
with open('config_with_vad.json', 'r') as f:
config = json.load(f)
params = config.get("arch", {}).get("args", {})
else:
# Load standard model without VAD
model_path = "model_without_vad.pth"
print(f"Loading standard model: {model_path}")
with open('config_without_vad.json', 'r') as f:
config = json.load(f)
params = config.get("arch", {}).get("args", {})
try:
model = SeparationModel(**params)
checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
state_dict = checkpoint['state_dict']
model.load_state_dict(state_dict, strict=True)
model.eval()
models[model_type] = model
print(f"βœ… Model {model_type} loaded successfully")
except FileNotFoundError:
print(f"⚠️ Model file {model_path} not found. Using dummy model.")
# Create a dummy model for demonstration
models[model_type] = create_dummy_model(with_vad=(model_type == "with_vad"))
return models[model_type]
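# Usage sketch (interface inferred from separate_speakers below; if a checkpoint file is missing,
# a dummy stand-in model with the same two-output interface is returned instead):
#   model = load_model("with_vad")
#   separated, vad = model(waveform)   # waveform: (batch, num_samples) float32 tensor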
def create_dummy_model(with_vad=False):
"""Create a dummy model for demonstration purposes"""
class DummyModel(torch.nn.Module):
def __init__(self, with_vad=False):
super().__init__()
self.with_vad = with_vad
        def forward(self, x):
            # x: (batch, num_samples) mixed waveform
            # Dummy "separation": two noisy copies of the mixture
            spk1 = x + 0.1 * torch.randn_like(x)
            spk2 = x - 0.1 * torch.randn_like(x)
            separated = torch.stack([spk1, spk2], dim=1)  # (batch, 2, num_samples)
            if self.with_vad:
                # Dummy VAD: random frame-level activity scores (roughly one frame per 1000 samples)
                batch_size, seq_len = x.shape
                vad1 = torch.sigmoid(torch.randn(batch_size, seq_len // 1000))
                vad2 = torch.sigmoid(torch.randn(batch_size, seq_len // 1000))
                vad = torch.stack([vad1, vad2], dim=1)  # (batch, 2, num_frames)
                return separated, vad
            # Return a placeholder second output to match the interface expected by separate_speakers()
            return separated, None
return DummyModel(with_vad=with_vad)
def separate_speakers(mixed_audio, model_type="without_vad"):
"""Separate speakers using the selected model"""
model = load_model(model_type)
print(f"The mixed audio shape is: {mixed_audio.shape}")
audio_tensor = torch.from_numpy(mixed_audio).float().unsqueeze(0)
    with torch.no_grad():
        separated, vad = model(audio_tensor)
        separated = separated.squeeze(0)  # (2, num_samples): one waveform per speaker
        spk1 = separated[0].cpu().numpy()
        spk2 = separated[1].cpu().numpy()
        print(f"Separated speakers: {spk1.shape}, {spk2.shape}")
        if model_type == "with_vad" and vad is not None:
            vad = vad.squeeze(0)  # (2, num_frames): frame-level speech-activity scores per speaker
            return spk1, spk2, vad[0].cpu().numpy(), vad[1].cpu().numpy()
        return spk1, spk2, None, None
def create_spectrogram(audio, sr=16000, title="Spectrogram", vad_data=None, vad_threshold=0.7):
"""Create a beautiful spectrogram plot with optional VAD overlay"""
# Create figure with custom styling
plt.style.use('dark_background')
fig, ax = plt.subplots(figsize=(12, 6))
# Custom colormap for better visual appeal
colors = ['#0d1117', '#1f2937', '#3730a3', '#7c3aed', '#ec4899', '#f59e0b', '#eab308']
n_bins = 256
cmap = LinearSegmentedColormap.from_list('custom', colors, N=n_bins)
# Parameters for InputSpec (can be adjusted)
n_fft = 512
hop_length = 256
win_length = 512
# Convert audio to torch tensor
audio_tensor = torch.tensor(audio, dtype=torch.float32)
if audio_tensor.dim() == 1:
audio_tensor = audio_tensor.unsqueeze(0) # (1, N)
# Compute spectrogram using InputSpec
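    # (InputSpec is imported from model.py; presumably the same STFT front-end the separation model uses,
    #  so the plotted spectrogram matches the model's time-frequency resolution.)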
spec_layer = InputSpec(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
stft = spec_layer(audio_tensor)
# Compute magnitude and convert to dB
S = stft.abs().squeeze(0).numpy()
S_db = librosa.amplitude_to_db(S, ref=np.max)
# Plot using the same colormap
img = librosa.display.specshow(S_db, sr=sr, hop_length=hop_length, x_axis='time',
y_axis='hz', ax=ax, cmap=cmap, vmin=-80, vmax=0)
# Overlay VAD if provided
if vad_data is not None:
        threshold = vad_threshold
        # Time axis aligned with the spectrogram (assumes the VAD frame rate matches hop_length)
        vad_time_axis = librosa.frames_to_time(np.arange(len(vad_data)), sr=sr, hop_length=hop_length)
# Convert VAD scores to binary (1=voice, 0=no voice)
vad_mask = vad_data > threshold
# Frequency range for highlighting (in Hz)
freq_max = sr // 2
vad_height_min = 0.15 * freq_max
vad_height_max = freq_max
        # Overlay line: sits at the top of the plot while the speaker is active, drops near the bottom when silent
vad_y = np.where(vad_mask, vad_height_max, vad_height_min)
# Create a twin y-axis
ax2 = ax.twinx()
# Option 1: shaded region
# ax2.fill_between(vad_time_axis, vad_height_min, vad_y, color='#10b981', alpha=0.4)
# Option 2 (alternative): line plot
ax2.plot(vad_time_axis, vad_y, color="#F1F1F1", linewidth=3, alpha=0.9)
# Adjust secondary axis
ax2.set_ylim(0, freq_max)
ax2.set_xlim(ax.get_xlim())
ax2.set_ylabel('')
ax2.set_yticks([])
ax2.spines['right'].set_visible(False)
ax2.spines['top'].set_visible(False)
ax2.spines['left'].set_visible(False)
# Add legend text
ax2.text(0.02, 0.95, 'Voice Activity', transform=ax2.transAxes,
color="#F1F1F1", fontsize=10, fontweight='bold',
bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7))
# Styling
ax.set_title(f'{title}', fontsize=16, fontweight='bold', color='white', pad=20)
ax.set_xlabel('Time (seconds)', fontsize=12, color='white')
ax.set_ylabel('Frequency (Hz)', fontsize=12, color='white')
ax.grid(True, alpha=0.3)
# Add colorbar
cbar = fig.colorbar(img, ax=ax, format='%+2.0f dB')
cbar.set_label('Amplitude (dB)', fontsize=12, color='white')
cbar.ax.yaxis.set_tick_params(color='white')
plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='white')
# Set background
fig.patch.set_facecolor('#0d1117')
ax.set_facecolor('#0d1117')
plt.tight_layout()
# Convert to image for Gradio
buf = io.BytesIO()
plt.savefig(buf, format='png', facecolor='#0d1117', edgecolor='none', dpi=150)
buf.seek(0)
img = Image.open(buf)
plt.close()
return img
def inference_gradio(audio, model_choice, vad_threshold=0.7):
"""Main inference function for Gradio interface"""
if audio is None:
        return None, None, None, None, None, "❌ Please upload or record an audio file."
try:
print(audio)
samplerate, audio_data = audio
audio_array = np.array(audio_data, dtype=np.float32)
        # Handle multi-channel audio: keep one channel, guessing whether the layout is (samples, channels) or (channels, samples)
if audio_array.ndim > 1:
status_msg = "πŸ”„ Multi-channel audio detected, using first channel."
if audio_array.shape[1] > audio_array.shape[0]:
audio_array = audio_array[0]
else:
audio_array = audio_array[:, 0]
else:
status_msg = "βœ… Processing mono audio."
# Resample to 16kHz if necessary
if samplerate != 16000:
len_audio = len(audio_array)
new_len = int(len_audio * 16000 / samplerate)
audio_array = resample(audio_array, new_len)
status_msg += f" Resampled from {samplerate}Hz to 16kHz."
        # Min-max normalize into [-0.9, 0.9] to avoid clipping
if audio_array.max() != audio_array.min():
normalized_audio = 1.8 * (audio_array - audio_array.min()) / (audio_array.max() - audio_array.min()) - 0.9
else:
normalized_audio = audio_array
# Determine model type
print(f"Selected model: {model_choice}")
model_type = "with_vad" if "VAD" in model_choice else "without_vad"
# Separate speakers
spk1, spk2, vad1, vad2 = separate_speakers(normalized_audio, model_type)
# Create spectrograms with VAD overlay if available
mixed_spec = create_spectrogram(normalized_audio, title="Mixed Audio Spectrogram")
if model_type == "with_vad" and vad1 is not None and vad2 is not None:
spk1_spec = create_spectrogram(spk1, title="Speaker 1 Spectrogram + VAD", vad_data=vad1, vad_threshold=vad_threshold)
spk2_spec = create_spectrogram(spk2, title="Speaker 2 Spectrogram + VAD", vad_data=vad2, vad_threshold=vad_threshold)
else:
spk1_spec = create_spectrogram(spk1, title="Speaker 1 Spectrogram")
spk2_spec = create_spectrogram(spk2, title="Speaker 2 Spectrogram")
status_msg += f" βœ… Successfully separated using {model_choice}!"
# Return audio and visualizations
return (
(16000, spk1),
(16000, spk2),
mixed_spec,
spk1_spec,
spk2_spec,
status_msg
)
except Exception as e:
error_msg = f"❌ Error during processing: {str(e)}"
traceback.print_exc()
return None, None, None, None, None, error_msg
def list_example_audios():
"""Return a dict of example wav files in the current directory."""
example_files = sorted(Path(".").glob("Mixed*.wav"))
return {f.name: str(f) for f in example_files}
def load_example_audio_by_path(path):
"""Load a wav file by path for Gradio."""
if Path(path).exists():
audio, sr = sf.read(path)
if audio.ndim > 1:
audio = audio[:, 0]
return (sr, audio)
return None
def load_example_audio():
"""Load example audio file as (sample_rate, np.ndarray) tuple for Gradio."""
example_path = "example_mixed.wav"
if Path(example_path).exists():
audio, sr = sf.read(example_path)
# Ensure mono for demo
if audio.ndim > 1:
audio = audio[:, 0]
return (sr, audio)
return None
# Create the Gradio interface
def create_interface():
example_files = list_example_audios()
default_example = next(iter(example_files.values()), None)
default_audio = load_example_audio_by_path(default_example) if default_example else None
with gr.Blocks(css="""
.centered {
align-items: center;
text-align: center;
}
.center-radio .gr-form {
align-items: center;
}
.center-radio .gr-radio {
display: flex;
flex-direction: column;
align-items: center;
}
""",
theme = gr.themes.Soft(
primary_hue="blue",
secondary_hue="stone",
neutral_hue="zinc"
).set(
# Background colors - softer, warmer grays
background_fill_primary="#f8f9fa",
background_fill_secondary="#f1f3f4",
block_background_fill="#ffffff",
# Borders - very subtle
border_color_primary="#e5e7eb",
border_color_accent="#d1d5db",
# Text colors - muted
body_text_color="#374151",
body_text_color_subdued="#6b7280",
# Button colors - understated
button_primary_background_fill="#f3f4f6",
button_primary_background_fill_hover="#e5e7eb",
button_primary_text_color="#374151",
# Input fields - clean and minimal
input_background_fill="#fafafa",
input_background_fill_focus="#ffffff",
input_border_color="#e5e7eb",
input_border_color_focus="#d1d5db",
# Accent colors - very muted
color_accent="#8b5cf6",
color_accent_soft="#f3f0ff",
# Shadows - barely visible
shadow_drop="0 1px 3px 0 rgba(0, 0, 0, 0.05)",
shadow_drop_lg="0 4px 6px -1px rgba(0, 0, 0, 0.05)"
)
) as demo:
with gr.Column(elem_classes=["centered"]):
gr.Markdown("""
# 🎀 Speaker Separation with Voice Activity Detection
**Separate mixed audio into individual speakers**
Choose between standard separation or separation with Voice Activity Detection (VAD).
---
### πŸ“ Example Audio
Select an example audio file below, or upload your own!
""")
with gr.Row():
with gr.Column(scale=1, elem_classes=["centered", "center-radio"]):
gr.Markdown("### 🎡 Input")
example_selector = gr.Dropdown(
choices=list(example_files.keys()),
value=next(iter(example_files.keys()), None),
label="Select Example Audio",
interactive=True
)
audio_input = gr.Audio(
sources=["upload", "microphone"],
type="numpy",
label="πŸ“ Upload or πŸŽ™οΈ Record Mixed Audio",
show_download_button=True,
value=default_audio
)
def update_audio_from_example(selected):
path = example_files.get(selected)
return load_example_audio_by_path(path) if path else None
example_selector.change(
fn=update_audio_from_example,
inputs=example_selector,
outputs=audio_input
)
model_choice = gr.Radio(
choices=[
"πŸ”§ Separation Only",
"πŸš€ Separation With VAD"
],
value="πŸ”§ Separation Only",
label="πŸ€– Model Selection",
info="VAD models provide voice activity detection for each speaker"
)
vad_threshold = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.7,
step=0.01,
label="VAD Threshold",
visible=False
)
def toggle_vad_threshold(model_choice_val):
return gr.update(visible=(model_choice_val == "πŸš€ Separation With VAD"))
model_choice.change(
fn=toggle_vad_threshold,
inputs=model_choice,
outputs=vad_threshold
)
process_btn = gr.Button("✨ Separate Speakers", variant="primary", size="lg")
status_output = gr.Textbox(
label="πŸ“Š Processing Status",
interactive=False,
lines=2
)
gr.Markdown("### 🎧 Separated Audio Outputs")
with gr.Row():
spk1_audio = gr.Audio(label="πŸ‘€ Speaker 1", show_download_button=True)
spk2_audio = gr.Audio(label="πŸ‘€ Speaker 2", show_download_button=True)
gr.Markdown("### πŸ“Š Audio Spectrograms")
with gr.Row():
mixed_spec = gr.Image(label="🎡 Mixed Audio Spectrogram", height=300)
with gr.Row():
spk1_spec = gr.Image(label="πŸ‘€ Speaker 1 Spectrogram (with VAD overlay)", height=300)
spk2_spec = gr.Image(label="πŸ‘€ Speaker 2 Spectrogram (with VAD overlay)", height=300)
# Process button click
process_btn.click(
fn=inference_gradio,
inputs=[audio_input, model_choice, vad_threshold],
outputs=[
spk1_audio, spk2_audio,
mixed_spec, spk1_spec, spk2_spec,
status_output
]
)
gr.Markdown("""
---
### πŸ“‹ Instructions:
1. **Upload** an audio file or **record** directly using the microphone
2. **Select** your preferred model (with or without VAD)
3. **If using VAD**, adjust the threshold as needed
4. **Click** "Separate Speakers" to process
5. **Download** the separated audio files and view the spectrograms
### πŸ”§ Technical Notes:
- Audio is automatically resampled to 16kHz
- Multi-channel audio uses the first channel
- **Spectrograms**: Show frequency content over time with VAD activity highlighted
- **VAD Overlay**: The white line rides along the top of the spectrogram while the speaker is active and drops toward the bottom during silence
""")
gr.Markdown("""
---
<div style="
background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
border-left: 4px solid #3b82f6;
border-radius: 8px;
padding: 20px;
margin: 20px 0;
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
">
<div style="margin-bottom: 15px;">
<h3 style="color: #1e40af; margin: 0 0 10px 0; font-size: 1.1em; display: flex; align-items: center;">
πŸ“– <span style="margin-left: 8px;">Reference</span>
</h3>
<div style="
background: white;
padding: 15px;
border-radius: 6px;
border: 1px solid #e2e8f0;
line-height: 1.6;
font-size: 0.95em;
">
<strong>Opochinsky, R., Moradi, M., & Gannot, S.</strong> (2025).<br>
<em style="color: #374151; font-size: 1.02em;">Single-microphone speaker separation and voice activity detection in noisy and reverberant environments</em>.<br>
<span style="color: #6b7280;">EURASIP Journal on Audio, Speech, and Music Processing</span>, <strong>2025</strong>(1), 18. Springer.
</div>
</div>
<details style="margin-top: 15px;">
<summary style="
cursor: pointer;
color: #4f46e5;
font-weight: 600;
padding: 8px 0;
border-bottom: 1px solid #e5e7eb;
margin-bottom: 10px;
user-select: none;
">πŸ“‹ BibTeX Citation</summary>
<div style="
background: #1f2937;
color: #f9fafb;
padding: 15px;
border-radius: 6px;
font-family: 'Courier New', monospace;
font-size: 0.85em;
line-height: 1.4;
overflow-x: auto;
margin-top: 10px;
">
<pre style="margin: 0; white-space: pre-wrap;">@article{opochinsky2025single,
title={Single-microphone speaker separation and voice activity detection in noisy and reverberant environments},
author={Opochinsky, Renana and Moradi, Mordehay and Gannot, Sharon},
journal={EURASIP Journal on Audio, Speech, and Music Processing},
volume={2025},
number={1},
pages={18},
year={2025},
publisher={Springer}
}</pre>
</div>
</details>
</div>
""")
return demo
if __name__ == "__main__":
# Pre-load models to check availability
print("πŸš€ Initializing Speaker Separation...")
load_model("without_vad")
# load_model("with_vad")
# Launch interface
demo = create_interface()
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True,
show_error=True
)
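# To run locally (assuming the checkpoint .pth and config .json files sit next to this file):
#   python app.py
# The Gradio UI is then served on http://localhost:7860 (plus a public share link, since share=True).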