# voice-tools/src/web/tabs/speaker_separation.py
"""
Gradio tab for speaker separation workflow.
Provides UI for separating speakers from multi-speaker audio files.
"""
import json
import logging
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple
import gradio as gr
from src.services.speaker_separation import SpeakerSeparationService
logger = logging.getLogger(__name__)
def create_speaker_separation_tab() -> gr.Tab:
    """
    Create the speaker separation tab for the Gradio interface.

    Builds a two-column layout (inputs/settings on the left, results and
    downloads on the right), a usage-tips section, and wires the
    "Separate Speakers" button to ``_separate_speakers_handler``.

    Returns:
        Configured Gradio Tab component
    """
    # NOTE: widget creation order inside the context managers defines the
    # on-screen layout — do not reorder these statements casually.
    with gr.Tab("Speaker Separation") as tab:
        gr.Markdown(
            """
            # πŸ‘₯ Speaker Separation
            Analyze multi-speaker audio files to automatically detect and separate
            individual speakers into separate audio streams.
            Upload an audio file with multiple speakers, and this tool will:
            - Detect all speakers automatically
            - Separate each speaker's audio
            - Export clean individual streams
            """
        )
        with gr.Row():
            with gr.Column(scale=1):
                # Input Section
                gr.Markdown("### πŸ“€ Input File")
                input_audio = gr.Audio(
                    label="Multi-Speaker Audio File",
                    type="filepath",  # handler receives a filesystem path, not raw samples
                    sources=["upload"],
                )
                # Configuration Section
                gr.Markdown("### βš™οΈ Speaker Detection Settings")
                with gr.Row():
                    # min/max bound the diarization search when auto-detecting
                    min_speakers = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=2,
                        step=1,
                        label="Minimum Speakers",
                        info="Minimum number of speakers expected",
                    )
                    max_speakers = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Maximum Speakers",
                        info="Maximum number of speakers expected",
                    )
                # A non-zero value here overrides min/max in the handler
                num_speakers = gr.Slider(
                    minimum=0,
                    maximum=10,
                    value=0,
                    step=1,
                    label="Exact Speaker Count (0 = auto-detect)",
                    info="Set to non-zero to specify exact number",
                )
                with gr.Accordion("Output Settings", open=True):
                    output_format = gr.Radio(
                        choices=["m4a", "wav", "mp3"],
                        value="m4a",
                        label="Output Format",
                    )
                    with gr.Row():
                        sample_rate = gr.Slider(
                            minimum=8000,
                            maximum=48000,
                            value=44100,
                            step=100,
                            label="Sample Rate (Hz)",
                        )
                        # Bitrate applies to the lossy formats (m4a/mp3)
                        bitrate = gr.Dropdown(
                            choices=["128k", "192k", "256k", "320k"],
                            value="192k",
                            label="Bitrate",
                        )
                # Action Buttons
                with gr.Row():
                    separate_btn = gr.Button("πŸš€ Separate Speakers", variant="primary", size="lg")
                    clear_btn = gr.ClearButton(components=[input_audio], value="πŸ—‘οΈ Clear")
            with gr.Column(scale=1):
                # Output Section
                gr.Markdown("### πŸ“Š Results")
                # Status
                status_output = gr.Textbox(
                    label="Status",
                    placeholder="Ready to process...",
                    interactive=False,
                    lines=3,
                )
                # Progress indicator (will be updated during processing)
                progress_bar = gr.Progress()
                # Results summary — hidden until a run completes successfully
                summary_output = gr.JSON(
                    label="Separation Summary",
                    visible=False,
                )
                # Speaker details — the handler toggles this accordion's visibility
                with gr.Accordion("Speaker Details", open=True, visible=False) as details_accordion:
                    speaker_table = gr.Dataframe(
                        headers=["Speaker", "Duration (s)", "Confidence"],
                        label="Detected Speakers",
                        interactive=False,
                    )
                # Download Section — files made visible only after separation completes
                gr.Markdown("### πŸ’Ύ Downloads")
                output_files = gr.File(
                    label="Separated Speaker Files",
                    file_count="multiple",
                    interactive=False,
                    visible=False,
                )
                report_file = gr.File(
                    label="Separation Report (JSON)",
                    interactive=False,
                    visible=False,
                )
        # Examples and Tips
        gr.Markdown("### πŸ“š Usage Tips")
        gr.Markdown(
            """
            **How to Use:**
            1. **Upload Audio**: Select an M4A, WAV, or MP3 file with multiple speakers
            2. **Configure Detection**:
            - Use min/max speakers for auto-detection (recommended)
            - Or set exact speaker count if you know it
            3. **Choose Output**: Select format, sample rate, and bitrate
            4. **Separate**: Click the button and wait for processing
            5. **Download**: Get individual speaker files and a detailed report
            **Best Practices:**
            - Clear audio with distinct speakers works best
            - If you know the exact speaker count, specify it for better results
            - Processing time scales with file duration (expect ~2x realtime)
            - M4A format provides best quality-to-size ratio
            - For long files (>1 hour), expect several minutes of processing
            **Troubleshooting:**
            - If fewer speakers detected than expected, try increasing max_speakers
            - If too many speakers detected, try increasing min_speakers
            - For overlapping speech, the tool will assign to the dominant speaker
            """
        )
        # Event Handler
        # The outputs list order must match the 6-tuple returned by
        # _separate_speakers_handler exactly.
        separate_btn.click(
            fn=_separate_speakers_handler,
            inputs=[
                input_audio,
                min_speakers,
                max_speakers,
                num_speakers,
                output_format,
                sample_rate,
                bitrate,
            ],
            outputs=[
                status_output,
                summary_output,
                speaker_table,
                output_files,
                report_file,
                details_accordion,
            ],
        )
    return tab
def _separate_speakers_handler(
    input_audio: Optional[str],
    min_speakers: int,
    max_speakers: int,
    num_speakers: int,
    output_format: str,
    sample_rate: int,
    bitrate: str,
    progress=gr.Progress(),  # Gradio injects a live tracker via this default
) -> Tuple[str, dict, list, list, Optional[str], dict]:
    """
    Handler function for speaker separation.

    Args:
        input_audio: Path to input audio file
        min_speakers: Minimum speakers to detect
        max_speakers: Maximum speakers to detect
        num_speakers: Exact speaker count (0 = auto)
        output_format: Output format (m4a, wav, mp3)
        sample_rate: Output sample rate
        bitrate: Output bitrate
        progress: Gradio progress tracker

    Returns:
        Tuple of (status, summary, speaker_data, output_files, report_file,
        accordion_visibility). On failure the result widgets are emptied and
        the details accordion is hidden.
    """

    def _failure(message: str, report_path: Optional[str] = None) -> tuple:
        # Uniform 6-tuple for every failure path: status text, empty
        # summary/table/files, optional error-report path, hidden accordion.
        return (message, {}, [], [], report_path, gr.update(visible=False))

    # Pre-bind so the except handler can tell whether the output directory
    # was ever created. Previously, an exception raised before mkdtemp left
    # `output_dir` unbound and the except branch itself crashed with
    # UnboundLocalError, masking the original error.
    output_dir: Optional[Path] = None
    try:
        # --- Validate inputs -------------------------------------------------
        if not input_audio:
            return _failure("❌ Error: Please upload an audio file")
        input_path = Path(input_audio)
        if not input_path.exists():
            return _failure(f"❌ Error: File not found: {input_audio}")
        # min/max ordering only matters in auto-detect mode (num_speakers == 0)
        if min_speakers > max_speakers and num_speakers == 0:
            return _failure(
                f"❌ Error: Minimum speakers ({min_speakers}) cannot exceed maximum ({max_speakers})"
            )
        # M4A (AAC) does not support sample rates above 48 kHz
        if output_format == "m4a" and sample_rate > 48000:
            return _failure(
                f"❌ Error: Sample rate {sample_rate} exceeds M4A limit of 48000 Hz"
            )

        # An exact speaker count overrides the auto-detection bounds
        if num_speakers > 0:
            min_speakers = num_speakers
            max_speakers = num_speakers

        # Temporary directory holding all artifacts served for download
        output_dir = Path(tempfile.mkdtemp(prefix="speaker_separation_"))

        progress(0.1, desc="Initializing speaker separation models...")
        service = SpeakerSeparationService()

        def progress_callback(stage: str, current: float, total: float) -> None:
            # The service reports either fractional progress (total == 1.0)
            # or item counts; both are mapped into the 10%-90% band, leaving
            # the outer 10% margins for this handler's own stages.
            if total == 1.0:
                pct = 0.1 + (current * 0.8)  # Scale float 0.0-1.0 to 10-90%
            else:
                pct = 0.1 + (current / total) * 0.8  # Scale integer to 10-90%
            progress(pct, desc=stage)

        progress(0.1, desc="Starting speaker separation...")
        # Run separation
        report = service.separate_and_export(
            input_file=str(input_path),
            output_dir=str(output_dir),
            min_speakers=min_speakers,
            max_speakers=max_speakers,
            output_format=output_format,
            sample_rate=sample_rate,
            bitrate=bitrate,
            progress_callback=progress_callback,
        )

        # The service signals handled failures via a report dict rather than
        # raising; surface the message and attach the saved error report.
        if report.get("status") == "failed":
            error_message = f"❌ **Error ({report['error_type']}):** {report['error']}"
            error_report_path = output_dir / "error_report.json"
            with open(error_report_path, "w") as f:
                json.dump(report, f, indent=2)
            return _failure(error_message, str(error_report_path))

        progress(0.9, desc="Preparing results...")

        # Rows for the "Detected Speakers" dataframe; confidence defaults to
        # 1.0 when the service does not report one.
        speaker_data = [
            [
                output_info["speaker_id"],
                f"{output_info['duration']:.1f}",
                f"{output_info.get('confidence', 1.0):.2f}",
            ]
            for output_info in report["output_files"]
        ]

        # Absolute paths of the per-speaker audio files for the download widget
        output_file_paths = [
            str(output_dir / output_info["file"]) for output_info in report["output_files"]
        ]

        # Save the full report alongside the audio files
        report_path = output_dir / "separation_report.json"
        with open(report_path, "w") as f:
            json.dump(report, f, indent=2)

        # Build status message shown in the status textbox
        status = f"""βœ… Separation Complete!
🎀 Detected {report["speakers_detected"]} speaker(s)
⏱️ Processed in {report["processing_time_seconds"]:.1f} seconds
πŸ“ Output saved to temporary directory
You can download the separated audio files and detailed report below.
"""

        # Compact summary for the JSON widget
        summary = {
            "speakers_detected": report["speakers_detected"],
            "processing_time": f"{report['processing_time_seconds']:.1f}s",
            "input_duration": f"{report['input_duration_seconds']:.1f}s",
            "output_format": output_format,
            "sample_rate": f"{sample_rate} Hz",
        }
        if "overlapping_segments" in report:
            summary["overlapping_segments"] = report["overlapping_segments"]

        progress(1.0, desc="Done!")
        # Re-emit components with visible=True so the download widgets and
        # summary appear only after a successful run.
        return (
            status,
            gr.JSON(value=summary, visible=True),
            speaker_data,
            gr.File(value=output_file_paths, visible=True),
            gr.File(value=str(report_path), visible=True),
            gr.update(visible=True),
        )
    except Exception as e:
        # Catch any unexpected errors not handled by the service
        logger.exception("Unexpected error in speaker separation")
        error_report = {
            "status": "failed",
            "error": f"Unexpected error: {str(e)}",
            "error_type": "processing",
        }
        # The failure may have happened before the output directory existed;
        # create one now so the error report can still be saved (fixes the
        # UnboundLocalError that previously masked the original exception).
        if output_dir is None:
            output_dir = Path(tempfile.mkdtemp(prefix="speaker_separation_"))
        error_report_path = output_dir / "error_report.json"
        with open(error_report_path, "w") as f:
            json.dump(error_report, f, indent=2)
        return _failure(f"❌ **Error:** {error_report['error']}", str(error_report_path))