# voice-tools / src/web/tabs/speaker_extraction.py
# (commit 3fb465f — "fix: resolve ZeroGPU pickling errors across all audio
# processing services")
"""
Gradio tab for speaker extraction workflow
Provides web interface for extracting specific speaker from audio using reference clip.
"""
import json
import logging
from pathlib import Path
from typing import Optional, Tuple

import gradio as gr

from src.services.speaker_extraction import SpeakerExtractionService
def create_speaker_extraction_tab() -> gr.Tab:
    """
    Create Gradio tab for speaker extraction.

    Builds the full interface (uploads, parameter sliders, results panel)
    and wires its event handlers, which are defined as closures below.

    Returns:
        Gradio Tab component with speaker extraction interface
    """
    # Initialize service (lazy loading): held in this closure variable and
    # only constructed on first use via get_service(), so building the tab
    # itself stays cheap.
    service = None
def get_service():
"""Lazy-load the speaker extraction service"""
nonlocal service
if service is None:
service = SpeakerExtractionService()
return service
def validate_reference(reference_file) -> Tuple[str, str]:
"""
Validate reference clip.
Returns:
Tuple of (status_message, status_style)
"""
if reference_file is None:
return "Please upload a reference clip", "⚠️"
try:
svc = get_service()
is_valid, message = svc.validate_reference_clip(reference_file)
if is_valid:
if "warning" in message.lower():
return f"✓ Valid (with warning: {message})", "⚠️"
return "✓ Reference clip is valid", "✅"
else:
return f"❌ Invalid: {message}", "❌"
except Exception as e:
return f"❌ Error: {str(e)}", "❌"
def extract_speaker_handler(
reference_file,
target_file,
threshold,
min_confidence,
concatenate,
silence_duration,
crossfade_duration,
sample_rate,
bitrate,
progress=gr.Progress(),
):
"""
Handle speaker extraction request.
Returns:
Tuple of (output_audio, report_json, status_message, download_button)
"""
if reference_file is None:
return None, None, "❌ Please upload a reference clip", gr.update(visible=False)
if target_file is None:
return None, None, "❌ Please upload a target audio file", gr.update(visible=False)
try:
# Progress callback
def progress_callback(stage: str, current: float, total: float):
# Interpret float-based (0.0-1.0) vs integer-based formats
if total == 1.0:
progress_pct = current # Already normalized 0.0-1.0
else:
progress_pct = current / total if total > 0 else 0
progress(progress_pct, desc=stage)
# Create temporary output path
output_dir = Path("./temp_extraction_output")
output_dir.mkdir(parents=True, exist_ok=True)
if concatenate:
output_path = output_dir / "extracted_speaker.m4a"
else:
output_path = output_dir / "segments"
# Perform extraction
progress(0.1, desc="Initializing...")
svc = get_service()
# Note: progress_callback cannot be passed due to ZeroGPU pickling constraints
report = svc.extract_and_export(
reference_clip=reference_file,
target_file=target_file,
output_path=str(output_path),
threshold=threshold,
min_confidence=min_confidence,
concatenate=concatenate,
silence_duration_ms=silence_duration,
crossfade_duration_ms=crossfade_duration,
sample_rate=sample_rate,
bitrate=bitrate,
progress_callback=None, # Cannot pass callback to avoid pickling errors
)
# Check if result is an error report
if report.get("status") == "failed":
error_message = f"❌ **Error ({report['error_type']}):** {report['error']}"
error_report_text = f"""
# Error Report
**Status**: {report["status"]}
**Error Type**: {report["error_type"]}
**Error**: {report["error"]}
{json.dumps(report, indent=2)}
"""
return None, error_report_text, error_message, gr.update(visible=False)
progress(1.0, desc="Complete!")
# Format report for display
report_text = f"""
# Extraction Report
## Summary
- **Reference**: {Path(reference_file).name}
- **Target**: {Path(target_file).name}
- **Threshold**: {threshold:.2f}
- **Segments Found**: {report["segments_found"]}
- **Segments Included**: {report["segments_included"]}
- **Total Duration**: {report["total_duration_seconds"]:.1f}s
- **Average Confidence**: {report["average_confidence"]:.3f}
- **Low Confidence Segments**: {report.get("low_confidence_segments", 0)}
- **Processing Time**: {report["processing_time_seconds"]:.1f}s
## Output
- **File**: {report.get("output_file", "N/A")}
"""
# Return audio file for playback
output_audio = str(output_path) if concatenate and output_path.exists() else None
# Prepare report JSON
report_json = json.dumps(report, indent=2)
# Status message
if report["segments_included"] == 0:
status = f"⚠️ No matching segments found. Try lowering the threshold."
elif report.get("low_confidence_segments", 0) > 0:
status = f"✅ Extracted {report['segments_included']} segment(s) (⚠️ {report['low_confidence_segments']} low confidence)"
else:
status = f"✅ Successfully extracted {report['segments_included']} segment(s)"
# Make download button visible
download_btn = gr.update(visible=True)
return output_audio, report_text, status, download_btn
except Exception as e:
# Catch any unexpected errors not handled by the service
logger.exception("Unexpected error in speaker extraction")
error_report = {
"status": "failed",
"error": f"Unexpected error: {str(e)}",
"error_type": "processing",
}
error_report_text = f"""
# Error Report
**Status**: {error_report["status"]}
**Error Type**: {error_report["error_type"]}
**Error**: {error_report["error"]}
{json.dumps(error_report, indent=2)}
"""
return (
None,
error_report_text,
f"❌ **Error:** {error_report['error']}",
gr.update(visible=False),
)
    # Create Gradio interface: left column collects inputs and parameters,
    # right column shows results; handlers above are wired at the bottom.
    with gr.Tab("Speaker Extraction") as tab:
        gr.Markdown(
            "Extract a specific speaker from audio using a reference clip. "
            "Upload a short clip (3+ seconds) of the target speaker's voice, "
            "then upload the audio file to extract from."
        )
        with gr.Row():
            # Left column: uploads and extraction parameters.
            with gr.Column(scale=1):
                gr.Markdown("### Step 1: Upload Reference Clip")
                reference_audio = gr.Audio(
                    label="Reference Clip (3+ seconds of target speaker)",
                    type="filepath",
                    sources=["upload"],
                )
                reference_status = gr.Textbox(
                    label="Reference Validation", value="", interactive=False
                )
                validate_btn = gr.Button("Validate Reference", size="sm")
                gr.Markdown("### Step 2: Upload Target Audio")
                target_audio = gr.Audio(
                    label="Target Audio File", type="filepath", sources=["upload"]
                )
                gr.Markdown("### Step 3: Configure Parameters")
                threshold_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.40,
                    step=0.05,
                    label="Matching Threshold",
                    info="Lower = stricter matching (0.0-1.0)",
                )
                min_confidence_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.30,
                    step=0.05,
                    label="Minimum Confidence",
                    info="Minimum confidence to include segments (0.0-1.0)",
                )
                concatenate_checkbox = gr.Checkbox(
                    value=True,
                    label="Concatenate Segments",
                    info="Combine all matching segments into one file",
                )
                # Less commonly changed export settings, collapsed by default.
                with gr.Accordion("Advanced Options", open=False):
                    silence_slider = gr.Slider(
                        minimum=0,
                        maximum=500,
                        value=150,
                        step=50,
                        label="Silence Duration (ms)",
                        info="Silence between concatenated segments",
                    )
                    crossfade_slider = gr.Slider(
                        minimum=0,
                        maximum=200,
                        value=75,
                        step=25,
                        label="Crossfade Duration (ms)",
                        info="Crossfade for smooth transitions",
                    )
                    sample_rate_dropdown = gr.Dropdown(
                        choices=[16000, 22050, 44100, 48000],
                        value=44100,
                        label="Output Sample Rate (Hz)",
                    )
                    bitrate_dropdown = gr.Dropdown(
                        choices=["128k", "192k", "256k", "320k"],
                        value="192k",
                        label="Output Bitrate",
                    )
                extract_btn = gr.Button("Extract Speaker", variant="primary", size="lg")
            # Right column: status, playback, and extraction report.
            with gr.Column(scale=1):
                gr.Markdown("### Results")
                status_box = gr.Textbox(
                    label="Status",
                    value="Ready to extract. Upload files and click 'Extract Speaker'.",
                    interactive=False,
                    lines=2,
                )
                output_audio = gr.Audio(label="Extracted Audio", type="filepath")
                download_btn = gr.Button("Download Extracted Audio", visible=False, size="sm")
                with gr.Accordion("Extraction Report", open=True):
                    report_display = gr.Markdown("")
                with gr.Accordion("Technical Details", open=False):
                    # NOTE(review): no event handler below targets report_json
                    # as an output, so this panel appears never populated —
                    # confirm whether that is intentional.
                    report_json = gr.Code(label="Report JSON", language="json")
        # Examples
        # NOTE(review): each example row has 6 values but only 5 inputs are
        # listed; the trailing description string looks unmapped — confirm
        # gr.Examples accepts the extra column.
        gr.Markdown("## Examples")
        gr.Examples(
            examples=[
                [0.40, 0.30, True, 150, 75, "Standard extraction with default settings"],
                [0.25, 0.40, True, 150, 75, "Strict matching for high confidence"],
                [0.60, 0.20, True, 200, 100, "Permissive matching with more segments"],
                [0.40, 0.30, False, 0, 0, "Export segments separately"],
            ],
            inputs=[
                threshold_slider,
                min_confidence_slider,
                concatenate_checkbox,
                silence_slider,
                crossfade_slider,
            ],
            label="Parameter Presets",
        )
        # Event handlers
        validate_btn.click(
            fn=validate_reference, inputs=[reference_audio], outputs=[reference_status]
        )
        extract_btn.click(
            fn=extract_speaker_handler,
            inputs=[
                reference_audio,
                target_audio,
                threshold_slider,
                min_confidence_slider,
                concatenate_checkbox,
                silence_slider,
                crossfade_slider,
                sample_rate_dropdown,
                bitrate_dropdown,
            ],
            outputs=[output_audio, report_display, status_box, download_btn],
        )
        # Identity passthrough: re-emits the current audio value to the same
        # component so the browser offers it for download.
        download_btn.click(fn=lambda audio: audio, inputs=[output_audio], outputs=[output_audio])
    return tab