#!/usr/bin/env python3
"""
Gradio Web Interface for Real-Time VAD + Speaker Diarization
Interactive demo with visualizations
"""
import json
import os
from pathlib import Path
from typing import Optional, Tuple, List, Dict

import gradio as gr
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np

from src.pipeline import VADDiarizationPipeline

# Initialize pipeline
print("Initializing pipeline...")
HF_TOKEN = os.environ.get('HF_TOKEN', None)
if not HF_TOKEN:
    print("⚠️ No HF_TOKEN found. Set it with: export HF_TOKEN='your_token_here'")
    print("Pipeline will work with VAD only until a token is provided.")

try:
    pipeline = VADDiarizationPipeline(
        use_auth_token=HF_TOKEN,
        vad_threshold=0.5
    )
    PIPELINE_READY = True
except Exception as e:
    print(f"⚠️ Could not initialize full pipeline: {e}")
    print("Will use VAD-only mode")
    PIPELINE_READY = False
    pipeline = None  # keep the name defined; guarded by PIPELINE_READY checks below

def apply_speaker_names(segments: List[Dict], speaker_mapping: Dict[str, str]) -> List[Dict]:
    """Apply custom speaker names to segments."""
    if not speaker_mapping:
        return segments
    renamed_segments = []
    for seg in segments:
        new_seg = seg.copy()
        if seg['speaker'] in speaker_mapping and speaker_mapping[seg['speaker']]:
            new_seg['speaker'] = speaker_mapping[seg['speaker']]
        renamed_segments.append(new_seg)
    return renamed_segments
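
# Example with hypothetical data (the dict keys match those used throughout this
# file); labels missing from the mapping, or mapped to an empty name, are kept:
# >>> segs = [{'speaker': 'SPEAKER_00', 'start': 0.0, 'end': 2.5, 'duration': 2.5}]
# >>> apply_speaker_names(segs, {'SPEAKER_00': 'Alice'})
# [{'speaker': 'Alice', 'start': 0.0, 'end': 2.5, 'duration': 2.5}]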

def create_timeline_plot(segments: List[Dict], duration: float) -> plt.Figure:
    """Create a visual timeline plot of speaker segments."""
    fig, ax = plt.subplots(figsize=(12, 4))

    # Get unique speakers and assign colors
    speakers = sorted(set(seg['speaker'] for seg in segments))
    colors = plt.cm.Set3(np.linspace(0, 1, len(speakers)))
    speaker_colors = {speaker: colors[i] for i, speaker in enumerate(speakers)}

    # Plot segments
    for seg in segments:
        color = speaker_colors[seg['speaker']]
        ax.barh(
            0,
            seg['duration'],
            left=seg['start'],
            height=0.8,
            color=color,
            edgecolor='black',
            linewidth=0.5
        )
        # Add speaker label in the middle of long segments
        if seg['duration'] > 1.0:
            mid = seg['start'] + seg['duration'] / 2
            ax.text(
                mid, 0, seg['speaker'],
                ha='center', va='center',
                fontsize=8, fontweight='bold'
            )

    # Formatting
    ax.set_xlim(0, duration)
    ax.set_ylim(-0.5, 0.5)
    ax.set_xlabel('Time (seconds)', fontsize=12)
    ax.set_yticks([])
    ax.set_title('Speaker Timeline', fontsize=14, fontweight='bold')
    ax.grid(True, axis='x', alpha=0.3)

    # Legend
    legend_patches = [
        mpatches.Patch(color=speaker_colors[speaker], label=speaker)
        for speaker in speakers
    ]
    ax.legend(handles=legend_patches, loc='upper right')

    plt.tight_layout()
    return fig
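
# Standalone usage sketch with hypothetical data (normally this is called from
# process_audio below with segments produced by the pipeline):
# >>> segs = [
# ...     {'speaker': 'SPEAKER_00', 'start': 0.0, 'end': 3.0, 'duration': 3.0},
# ...     {'speaker': 'SPEAKER_01', 'start': 3.0, 'end': 5.0, 'duration': 2.0},
# ... ]
# >>> fig = create_timeline_plot(segs, duration=5.0)
# >>> fig.savefig('timeline.png')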

def process_audio(
    audio_file,
    audio_record,
    num_speakers: Optional[int] = None,
    vad_threshold: float = 0.5,
    speaker_names: str = "",
    progress=gr.Progress()
) -> Tuple[str, str, str, Optional[plt.Figure], Optional[str]]:
    """
    Process an audio file through the pipeline.
    Handles both uploaded files and recorded audio.

    Returns:
        Tuple of (summary_text, timeline_text, json_output, plot, download_path)
    """
    # Use recorded audio if available, otherwise use the uploaded file
    audio_source = audio_record if audio_record is not None else audio_file
    if audio_source is None:
        return "Please upload an audio file or record using the microphone.", "", "", None, None
    if not PIPELINE_READY:
        return "Pipeline not ready. Please set the HF_TOKEN environment variable.", "", "", None, None

    try:
        progress(0.1, desc="Loading audio...")

        # Update VAD threshold if changed
        pipeline.vad.threshold = vad_threshold

        progress(0.3, desc="Running VAD...")

        # Process file
        num_speakers_param = int(num_speakers) if num_speakers and num_speakers > 0 else None

        progress(0.5, desc="Running speaker diarization...")
        result = pipeline.process_file(
            audio_source,
            num_speakers=num_speakers_param,
            return_vad=True,
            return_stats=True
        )

        progress(0.8, desc="Generating visualizations...")

        # Parse speaker names ("SPEAKER_00: Alice", one mapping per line)
        speaker_mapping = {}
        if speaker_names.strip():
            lines = [line.strip() for line in speaker_names.strip().split('\n') if line.strip()]
            for line in lines:
                if ':' in line:
                    parts = line.split(':', 1)
                    speaker_id = parts[0].strip()
                    custom_name = parts[1].strip()
                    if custom_name:
                        speaker_mapping[speaker_id] = custom_name

        # Apply custom speaker names
        if speaker_mapping:
            result['speaker_segments'] = apply_speaker_names(result['speaker_segments'], speaker_mapping)
            # Update speaker statistics with the new names
            if 'speaker_statistics' in result:
                new_stats = {}
                for speaker, stats in result['speaker_statistics'].items():
                    new_name = speaker_mapping.get(speaker, speaker)
                    new_stats[new_name] = stats
                result['speaker_statistics'] = new_stats

        # Create summary
        summary_lines = []
        summary_lines.append("# Processing Results\n")

        # Determine source type for display
        source_type = "Recorded Audio" if audio_record is not None else "Uploaded File"
        file_name = Path(audio_source).name if audio_source else "Unknown"
        summary_lines.append(f"**Source:** {source_type}\n")
        summary_lines.append(f"**File:** {file_name}\n")
        summary_lines.append(f"**Speakers Detected:** {result['metadata']['num_speakers']}")
        summary_lines.append(f"**Speaker Segments:** {result['metadata']['num_segments']}")
        summary_lines.append(f"**Total Speech Time:** {result['metadata']['total_speech_time']:.2f}s\n")
        summary_lines.append("## Processing Time")
        summary_lines.append(f"- VAD: {result['processing_time']['vad_ms']:.2f}ms")
        summary_lines.append(f"- Diarization: {result['processing_time']['diarization_ms']:.2f}ms")
        summary_lines.append(f"- **Total: {result['processing_time']['total_ms']:.2f}ms**\n")

        # Speaker statistics
        if 'speaker_statistics' in result:
            summary_lines.append("## Speaker Statistics\n")
            for speaker, stats in result['speaker_statistics'].items():
                summary_lines.append(f"### {speaker}")
                summary_lines.append(f"- Total speaking time: {stats['total_time']:.2f}s")
                summary_lines.append(f"- Number of segments: {stats['num_segments']}")
                summary_lines.append(f"- Average segment duration: {stats['avg_segment_duration']:.2f}s\n")

        summary_text = "\n".join(summary_lines)

        # Create timeline text
        timeline_lines = ["# Speaker Timeline\n"]
        timeline_lines.append("```")
        for seg in result['speaker_segments']:
            timeline_lines.append(
                f"{seg['start']:7.2f}s - {seg['end']:7.2f}s: {seg['speaker']} ({seg['duration']:.2f}s)"
            )
        timeline_lines.append("```")
        timeline_text = "\n".join(timeline_lines)

        # JSON output
        json_output = json.dumps(result, indent=2, default=str)

        # Create plot (default=0.0 guards against an empty segment list)
        duration = max((seg['end'] for seg in result['speaker_segments']), default=0.0)
        plot = create_timeline_plot(result['speaker_segments'], duration)

        # Make the source audio available for download
        download_path = audio_source

        progress(1.0, desc="Complete!")
        return summary_text, timeline_text, json_output, plot, download_path

    except Exception as e:
        error_msg = f"Error processing audio: {e}\n\n"
        error_msg += "Make sure you have:\n"
        error_msg += "1. A valid HF_TOKEN environment variable\n"
        error_msg += "2. Accepted the model conditions at https://huggingface.co/pyannote/speaker-diarization-3.1"
        return error_msg, "", "", None, None
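
# Shape of `result` that process_audio assumes, inferred from the accesses above
# (the authoritative schema is whatever VADDiarizationPipeline.process_file returns):
# {
#     'speaker_segments': [{'speaker': str, 'start': float, 'end': float, 'duration': float}, ...],
#     'metadata': {'num_speakers': int, 'num_segments': int, 'total_speech_time': float},
#     'processing_time': {'vad_ms': float, 'diarization_ms': float, 'total_ms': float},
#     'speaker_statistics': {   # present when return_stats=True
#         '<speaker>': {'total_time': float, 'num_segments': int, 'avg_segment_duration': float},
#     },
# }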

def create_demo():
    """Create the Gradio interface."""
    with gr.Blocks(title="VAD + Speaker Diarization", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🎙️ Real-Time Voice Activity Detection + Speaker Diarization

        Upload an audio file to detect speech segments and identify different speakers.

        **Features:**
        - Voice Activity Detection (VAD) with <100ms latency
        - Speaker Diarization with state-of-the-art accuracy
        - Visual timeline of speaker segments
        - Detailed statistics and JSON export

        **Supported formats:** WAV, MP3, FLAC, OGG, M4A
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## Input")

                with gr.Tabs():
                    with gr.Tab("📁 Upload File"):
                        audio_input = gr.Audio(
                            label="Upload Audio File",
                            type="filepath",
                            sources=["upload"]
                        )
                    with gr.Tab("🎤 Record Live"):
                        audio_record = gr.Audio(
                            label="Record Audio",
                            type="filepath",
                            sources=["microphone"]
                        )
                        gr.Markdown("""
                        **Tips for recording:**
                        - Click the microphone icon to start recording
                        - Speak clearly and avoid background noise
                        - Click stop when finished
                        - Click the "🚀 Process Audio" button below to analyze

                        **Recording info:**
                        - Max duration: unlimited (browser-dependent)
                        - Format: WAV (converted automatically)
                        - Storage: temporary (deleted after the session)
                        - Download: available after processing
                        """)

                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    num_speakers = gr.Number(
                        label="Number of Speakers (0 for auto-detection)",
                        value=0,
                        precision=0,
                        minimum=0,
                        maximum=10,
                        info="Set to 0 for automatic speaker detection"
                    )
                    vad_threshold = gr.Slider(
                        label="VAD Sensitivity Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        value=0.5,
                        step=0.05,
                        info="Lower = more sensitive to speech"
                    )
                    gr.Markdown("### 👥 Custom Speaker Names")
                    gr.Markdown("""
                    Enter custom names for speakers (one per line) in the format
                    `SPEAKER_00: John Doe`. Example:
                    ```
                    SPEAKER_00: Alice
                    SPEAKER_01: Bob
                    SPEAKER_02: Charlie
                    ```
                    """)
                    speaker_names = gr.Textbox(
                        label="Speaker Name Mapping",
                        placeholder="SPEAKER_00: Alice\nSPEAKER_01: Bob",
                        lines=5,
                        info="Leave empty to use the default speaker labels"
                    )

                process_btn = gr.Button("🚀 Process Audio", variant="primary", size="lg")

                gr.Markdown("""
                ### Tips
                - For best results, use clear audio with minimal background noise
                - If the number of speakers is known, specify it for better accuracy
                - Adjust the VAD threshold if speech is not detected properly
                """)

            with gr.Column(scale=2):
                gr.Markdown("## Results")

                with gr.Tab("Summary"):
                    summary_output = gr.Markdown(label="Summary")
                with gr.Tab("Timeline"):
                    timeline_plot = gr.Plot(label="Visual Timeline")
                    timeline_output = gr.Markdown(label="Timeline Details")
                with gr.Tab("JSON Export"):
                    json_output = gr.Code(
                        label="Full Results (JSON)",
                        language="json",
                        lines=20
                    )
                with gr.Tab("📥 Download"):
                    gr.Markdown("### Download Processed Audio")
                    download_audio = gr.File(
                        label="Download Audio File",
                        interactive=False
                    )
                    gr.Markdown("""
                    The original audio file is available for download here.
                    You can use it together with the JSON results for further processing.
                    """)

        # Examples
        gr.Markdown("## 📝 Examples")
        gr.Markdown("""
        Try the demo with your own audio files or use sample data from the FEARLESS STEPS dataset.

        **Expected performance:**
        - VAD latency: <100ms per second of audio
        - Diarization Error Rate (DER): ~19-20% on benchmark datasets
        - Processing time: depends on audio length and hardware
        """)

        # Event handler for the process button (works with both upload and recording)
        process_btn.click(
            fn=process_audio,
            inputs=[audio_input, audio_record, num_speakers, vad_threshold, speaker_names],
            outputs=[summary_output, timeline_output, json_output, timeline_plot, download_audio]
        )

        # Optional: auto-process when recording stops.
        # Uncomment the following lines for automatic processing after recording:
        # audio_record.stop_recording(
        #     fn=process_audio,
        #     inputs=[audio_input, audio_record, num_speakers, vad_threshold, speaker_names],
        #     outputs=[summary_output, timeline_output, json_output, timeline_plot, download_audio]
        # )

        # Footer
        gr.Markdown("""
        ---
        **Tech Stack:** Silero VAD + Pyannote.audio 3.1 | **GPU:** CUDA 12.5+ supported

        **Note:** The first run may take longer due to model downloads (~1GB)
        """)

    return demo

if __name__ == "__main__":
    demo = create_demo()
    # Launch settings
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )
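
# To run locally (a sketch; assumes this file is saved as app.py and that the
# port matches server_port above):
#   export HF_TOKEN='your_token_here'
#   python app.py
# Then open http://localhost:7860 in a browser.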