Spaces:
Running on Zero
Running on Zero
| """ | |
| Gradio tab for speaker extraction workflow | |
| Provides a web interface for extracting a specific speaker from audio using a reference clip. | |
| """ | |
| import json | |
| from pathlib import Path | |
| from typing import Optional, Tuple | |
| import gradio as gr | |
| from src.services.speaker_extraction import SpeakerExtractionService | |
def _format_error_report(report: dict) -> str:
    """Render a failed-extraction report dict as Markdown.

    Missing ``status`` / ``error_type`` / ``error`` keys fall back to
    placeholder text so that rendering the error can never raise itself.
    """
    fence = "`" * 3
    return "\n".join(
        [
            "# Error Report",
            "",
            f"**Status**: {report.get('status', 'failed')}",
            "",
            f"**Error Type**: {report.get('error_type', 'unknown')}",
            "",
            f"**Error**: {report.get('error', 'Unknown error')}",
            "",
            # Fenced so Markdown does not collapse the indented JSON.
            f"{fence}json",
            json.dumps(report, indent=2),
            fence,
        ]
    )


def _format_extraction_report(report: dict, reference_file, target_file, threshold) -> str:
    """Render a successful extraction report dict as a Markdown summary.

    Raises:
        KeyError: if a required statistic is absent from ``report``.
    """
    return "\n".join(
        [
            "# Extraction Report",
            "",
            "## Summary",
            "",
            f"- **Reference**: {Path(reference_file).name}",
            f"- **Target**: {Path(target_file).name}",
            f"- **Threshold**: {threshold:.2f}",
            f"- **Segments Found**: {report['segments_found']}",
            f"- **Segments Included**: {report['segments_included']}",
            f"- **Total Duration**: {report['total_duration_seconds']:.1f}s",
            f"- **Average Confidence**: {report['average_confidence']:.3f}",
            f"- **Low Confidence Segments**: {report.get('low_confidence_segments', 0)}",
            f"- **Processing Time**: {report['processing_time_seconds']:.1f}s",
            "",
            "## Output",
            "",
            f"- **File**: {report.get('output_file', 'N/A')}",
        ]
    )


def _extraction_status(report: dict) -> str:
    """Build the one-line status message for a successful extraction report."""
    included = report["segments_included"]
    low_confidence = report.get("low_confidence_segments", 0)
    if included == 0:
        # The threshold slider documents "Lower = stricter matching", so a
        # run with no matches needs a *higher* threshold.
        return "⚠️ No matching segments found. Try raising the threshold."
    if low_confidence > 0:
        return f"✅ Extracted {included} segment(s) (⚠️ {low_confidence} low confidence)"
    return f"✅ Successfully extracted {included} segment(s)"


def create_speaker_extraction_tab() -> gr.Tab:
    """
    Create Gradio tab for speaker extraction.

    Builds the upload widgets, parameter controls, result panes and event
    wiring. The extraction service is instantiated lazily on first use.

    Returns:
        Gradio Tab component with speaker extraction interface
    """
    # Function-scope import: the module does not import logging at file level.
    import logging

    logger = logging.getLogger(__name__)

    # Initialize service (lazy loading)
    service = None

    def get_service():
        """Lazy-load the speaker extraction service"""
        nonlocal service
        if service is None:
            service = SpeakerExtractionService()
        return service

    def validate_reference(reference_file) -> Tuple[str, str]:
        """
        Validate reference clip.

        Returns:
            Tuple of (status_message, status_style)
        """
        if reference_file is None:
            return "Please upload a reference clip", "⚠️"
        try:
            is_valid, message = get_service().validate_reference_clip(reference_file)
        except Exception as e:
            # UI boundary: surface any service failure as a status message.
            return f"❌ Error: {str(e)}", "❌"
        if not is_valid:
            return f"❌ Invalid: {message}", "❌"
        if "warning" in message.lower():
            return f"✓ Valid (with warning: {message})", "⚠️"
        return "✓ Reference clip is valid", "✅"

    def extract_speaker_handler(
        reference_file,
        target_file,
        threshold,
        min_confidence,
        concatenate,
        silence_duration,
        crossfade_duration,
        sample_rate,
        bitrate,
        progress=gr.Progress(),
    ):
        """
        Handle speaker extraction request.

        Returns:
            Tuple of (output_audio, report_markdown, report_json,
            status_message, download_button_update)
        """
        hidden = gr.update(visible=False)
        if reference_file is None:
            return None, None, None, "❌ Please upload a reference clip", hidden
        if target_file is None:
            return None, None, None, "❌ Please upload a target audio file", hidden

        try:
            # Create temporary output path
            output_dir = Path("./temp_extraction_output")
            output_dir.mkdir(parents=True, exist_ok=True)
            if concatenate:
                output_path = output_dir / "extracted_speaker.m4a"
                # The path is fixed, so drop any file left by a previous run;
                # otherwise a run that writes nothing would replay old audio.
                if output_path.is_file():
                    output_path.unlink()
            else:
                output_path = output_dir / "segments"

            # Perform extraction
            progress(0.1, desc="Initializing...")
            # Note: progress_callback cannot be passed due to ZeroGPU pickling constraints
            report = get_service().extract_and_export(
                reference_clip=reference_file,
                target_file=target_file,
                output_path=str(output_path),
                threshold=threshold,
                min_confidence=min_confidence,
                concatenate=concatenate,
                silence_duration_ms=silence_duration,
                crossfade_duration_ms=crossfade_duration,
                sample_rate=sample_rate,
                bitrate=bitrate,
                progress_callback=None,  # Cannot pass callback to avoid pickling errors
            )

            # Check if result is an error report
            if report.get("status") == "failed":
                error_message = (
                    f"❌ **Error ({report.get('error_type', 'unknown')}):** "
                    f"{report.get('error', 'Unknown error')}"
                )
                return (
                    None,
                    _format_error_report(report),
                    json.dumps(report, indent=2),
                    error_message,
                    hidden,
                )

            progress(1.0, desc="Complete!")

            report_text = _format_extraction_report(
                report, reference_file, target_file, threshold
            )
            # Only a concatenated export yields a single playable file
            output_audio = str(output_path) if concatenate and output_path.exists() else None
            return (
                output_audio,
                report_text,
                json.dumps(report, indent=2),
                _extraction_status(report),
                gr.update(visible=True),
            )
        except Exception as e:
            # Catch any unexpected errors not handled by the service
            logger.exception("Unexpected error in speaker extraction")
            error_report = {
                "status": "failed",
                "error": f"Unexpected error: {str(e)}",
                "error_type": "processing",
            }
            return (
                None,
                _format_error_report(error_report),
                json.dumps(error_report, indent=2),
                f"❌ **Error:** {error_report['error']}",
                hidden,
            )

    # Create Gradio interface
    with gr.Tab("Speaker Extraction") as tab:
        gr.Markdown(
            "Extract a specific speaker from audio using a reference clip. "
            "Upload a short clip (3+ seconds) of the target speaker's voice, "
            "then upload the audio file to extract from."
        )
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Step 1: Upload Reference Clip")
                reference_audio = gr.Audio(
                    label="Reference Clip (3+ seconds of target speaker)",
                    type="filepath",
                    sources=["upload"],
                )
                reference_status = gr.Textbox(
                    label="Reference Validation", value="", interactive=False
                )
                validate_btn = gr.Button("Validate Reference", size="sm")

                gr.Markdown("### Step 2: Upload Target Audio")
                target_audio = gr.Audio(
                    label="Target Audio File", type="filepath", sources=["upload"]
                )

                gr.Markdown("### Step 3: Configure Parameters")
                threshold_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.40,
                    step=0.05,
                    label="Matching Threshold",
                    info="Lower = stricter matching (0.0-1.0)",
                )
                min_confidence_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.30,
                    step=0.05,
                    label="Minimum Confidence",
                    info="Minimum confidence to include segments (0.0-1.0)",
                )
                concatenate_checkbox = gr.Checkbox(
                    value=True,
                    label="Concatenate Segments",
                    info="Combine all matching segments into one file",
                )
                with gr.Accordion("Advanced Options", open=False):
                    silence_slider = gr.Slider(
                        minimum=0,
                        maximum=500,
                        value=150,
                        step=50,
                        label="Silence Duration (ms)",
                        info="Silence between concatenated segments",
                    )
                    crossfade_slider = gr.Slider(
                        minimum=0,
                        maximum=200,
                        value=75,
                        step=25,
                        label="Crossfade Duration (ms)",
                        info="Crossfade for smooth transitions",
                    )
                    sample_rate_dropdown = gr.Dropdown(
                        choices=[16000, 22050, 44100, 48000],
                        value=44100,
                        label="Output Sample Rate (Hz)",
                    )
                    bitrate_dropdown = gr.Dropdown(
                        choices=["128k", "192k", "256k", "320k"],
                        value="192k",
                        label="Output Bitrate",
                    )
                extract_btn = gr.Button("Extract Speaker", variant="primary", size="lg")

            with gr.Column(scale=1):
                gr.Markdown("### Results")
                status_box = gr.Textbox(
                    label="Status",
                    value="Ready to extract. Upload files and click 'Extract Speaker'.",
                    interactive=False,
                    lines=2,
                )
                output_audio = gr.Audio(label="Extracted Audio", type="filepath")
                download_btn = gr.Button("Download Extracted Audio", visible=False, size="sm")
                with gr.Accordion("Extraction Report", open=True):
                    report_display = gr.Markdown("")
                with gr.Accordion("Technical Details", open=False):
                    report_json = gr.Code(label="Report JSON", language="json")

        # Examples
        gr.Markdown("## Examples")
        # NOTE(review): each row carries a trailing description string with no
        # matching input component; it appears to be ignored by gr.Examples -
        # confirm against the installed Gradio version.
        gr.Examples(
            examples=[
                [0.40, 0.30, True, 150, 75, "Standard extraction with default settings"],
                [0.25, 0.40, True, 150, 75, "Strict matching for high confidence"],
                [0.60, 0.20, True, 200, 100, "Permissive matching with more segments"],
                [0.40, 0.30, False, 0, 0, "Export segments separately"],
            ],
            inputs=[
                threshold_slider,
                min_confidence_slider,
                concatenate_checkbox,
                silence_slider,
                crossfade_slider,
            ],
            label="Parameter Presets",
        )

        # Event handlers
        validate_btn.click(
            # validate_reference returns (message, style); with a single output
            # Gradio would display the whole tuple, so pass the message only.
            fn=lambda reference_file: validate_reference(reference_file)[0],
            inputs=[reference_audio],
            outputs=[reference_status],
        )
        extract_btn.click(
            fn=extract_speaker_handler,
            inputs=[
                reference_audio,
                target_audio,
                threshold_slider,
                min_confidence_slider,
                concatenate_checkbox,
                silence_slider,
                crossfade_slider,
                sample_rate_dropdown,
                bitrate_dropdown,
            ],
            outputs=[output_audio, report_display, report_json, status_box, download_btn],
        )
        download_btn.click(fn=lambda audio: audio, inputs=[output_audio], outputs=[output_audio])

    return tab