"""
Gradio tab for speaker extraction workflow

Provides web interface for extracting a specific speaker from audio
using a reference clip.
"""

import json
import logging
from pathlib import Path
from typing import Optional, Tuple

import gradio as gr

from src.services.speaker_extraction import SpeakerExtractionService

# Module-level logger. NOTE: the exception handler below previously referenced
# `logger` without it ever being defined, raising a NameError while trying to
# report an unexpected error.
logger = logging.getLogger(__name__)


def create_speaker_extraction_tab() -> gr.Tab:
    """
    Create Gradio tab for speaker extraction.

    Returns:
        Gradio Tab component with speaker extraction interface
    """
    # Initialize service lazily so model loading does not happen at import time.
    service = None

    def get_service():
        """Lazy-load the speaker extraction service (created on first use)."""
        nonlocal service
        if service is None:
            service = SpeakerExtractionService()
        return service

    def validate_reference(reference_file) -> Tuple[str, str]:
        """
        Validate reference clip.

        Returns:
            Tuple of (status_message, status_style)
        """
        if reference_file is None:
            return "Please upload a reference clip", "⚠️"
        try:
            svc = get_service()
            is_valid, message = svc.validate_reference_clip(reference_file)
            if is_valid:
                # Service may report a soft warning (e.g. short clip) while
                # still considering the reference usable.
                if "warning" in message.lower():
                    return f"✓ Valid (with warning: {message})", "⚠️"
                return "✓ Reference clip is valid", "✅"
            else:
                return f"❌ Invalid: {message}", "❌"
        except Exception as e:
            return f"❌ Error: {str(e)}", "❌"

    def extract_speaker_handler(
        reference_file,
        target_file,
        threshold,
        min_confidence,
        concatenate,
        silence_duration,
        crossfade_duration,
        sample_rate,
        bitrate,
        progress=gr.Progress(),
    ):
        """
        Handle speaker extraction request.

        Returns:
            Tuple of (output_audio, report_json, status_message, download_button)
        """
        # Guard clauses: both files are required before anything else runs.
        if reference_file is None:
            return None, None, "❌ Please upload a reference clip", gr.update(visible=False)
        if target_file is None:
            return None, None, "❌ Please upload a target audio file", gr.update(visible=False)

        try:
            # Create temporary output path.
            output_dir = Path("./temp_extraction_output")
            output_dir.mkdir(parents=True, exist_ok=True)
            if concatenate:
                output_path = output_dir / "extracted_speaker.m4a"
            else:
                # Segments mode: the service writes individual files into
                # this directory (it is expected to create it if missing).
                output_path = output_dir / "segments"

            # Perform extraction.
            progress(0.1, desc="Initializing...")
            svc = get_service()
            # Note: progress_callback cannot be passed due to ZeroGPU pickling
            # constraints, so the UI only shows coarse start/end progress.
            report = svc.extract_and_export(
                reference_clip=reference_file,
                target_file=target_file,
                output_path=str(output_path),
                threshold=threshold,
                min_confidence=min_confidence,
                concatenate=concatenate,
                silence_duration_ms=silence_duration,
                crossfade_duration_ms=crossfade_duration,
                sample_rate=sample_rate,
                bitrate=bitrate,
                progress_callback=None,  # Cannot pass callback to avoid pickling errors
            )

            # Check if result is an error report produced by the service.
            if report.get("status") == "failed":
                error_message = f"❌ **Error ({report['error_type']}):** {report['error']}"
                error_report_text = f"""
# Error Report

**Status**: {report["status"]}
**Error Type**: {report["error_type"]}
**Error**: {report["error"]}

{json.dumps(report, indent=2)}
"""
                return None, error_report_text, error_message, gr.update(visible=False)

            progress(1.0, desc="Complete!")

            # Format report for display.
            report_text = f"""
# Extraction Report

## Summary
- **Reference**: {Path(reference_file).name}
- **Target**: {Path(target_file).name}
- **Threshold**: {threshold:.2f}
- **Segments Found**: {report["segments_found"]}
- **Segments Included**: {report["segments_included"]}
- **Total Duration**: {report["total_duration_seconds"]:.1f}s
- **Average Confidence**: {report["average_confidence"]:.3f}
- **Low Confidence Segments**: {report.get("low_confidence_segments", 0)}
- **Processing Time**: {report["processing_time_seconds"]:.1f}s

## Output
- **File**: {report.get("output_file", "N/A")}
"""

            # Return audio file for playback (only meaningful in concatenate mode).
            output_audio = str(output_path) if concatenate and output_path.exists() else None

            # Prepare report JSON.
            report_json = json.dumps(report, indent=2)

            # Status message. The slider is documented as "Lower = stricter
            # matching", so with zero matches the user should RAISE the
            # threshold (the previous message advised the opposite).
            if report["segments_included"] == 0:
                status = "⚠️ No matching segments found. Try raising the threshold."
            elif report.get("low_confidence_segments", 0) > 0:
                status = f"✅ Extracted {report['segments_included']} segment(s) (⚠️ {report['low_confidence_segments']} low confidence)"
            else:
                status = f"✅ Successfully extracted {report['segments_included']} segment(s)"

            # Make download button visible.
            download_btn = gr.update(visible=True)

            return output_audio, report_text, status, download_btn

        except Exception as e:
            # Catch any unexpected errors not handled by the service.
            logger.exception("Unexpected error in speaker extraction")
            error_report = {
                "status": "failed",
                "error": f"Unexpected error: {str(e)}",
                "error_type": "processing",
            }
            error_report_text = f"""
# Error Report

**Status**: {error_report["status"]}
**Error Type**: {error_report["error_type"]}
**Error**: {error_report["error"]}

{json.dumps(error_report, indent=2)}
"""
            return (
                None,
                error_report_text,
                f"❌ **Error:** {error_report['error']}",
                gr.update(visible=False),
            )

    # Create Gradio interface.
    with gr.Tab("Speaker Extraction") as tab:
        gr.Markdown(
            "Extract a specific speaker from audio using a reference clip. "
            "Upload a short clip (3+ seconds) of the target speaker's voice, "
            "then upload the audio file to extract from."
        )

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Step 1: Upload Reference Clip")
                reference_audio = gr.Audio(
                    label="Reference Clip (3+ seconds of target speaker)",
                    type="filepath",
                    sources=["upload"],
                )
                reference_status = gr.Textbox(
                    label="Reference Validation", value="", interactive=False
                )
                validate_btn = gr.Button("Validate Reference", size="sm")

                gr.Markdown("### Step 2: Upload Target Audio")
                target_audio = gr.Audio(
                    label="Target Audio File", type="filepath", sources=["upload"]
                )

                gr.Markdown("### Step 3: Configure Parameters")
                threshold_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.40,
                    step=0.05,
                    label="Matching Threshold",
                    info="Lower = stricter matching (0.0-1.0)",
                )
                min_confidence_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.30,
                    step=0.05,
                    label="Minimum Confidence",
                    info="Minimum confidence to include segments (0.0-1.0)",
                )
                concatenate_checkbox = gr.Checkbox(
                    value=True,
                    label="Concatenate Segments",
                    info="Combine all matching segments into one file",
                )

                with gr.Accordion("Advanced Options", open=False):
                    silence_slider = gr.Slider(
                        minimum=0,
                        maximum=500,
                        value=150,
                        step=50,
                        label="Silence Duration (ms)",
                        info="Silence between concatenated segments",
                    )
                    crossfade_slider = gr.Slider(
                        minimum=0,
                        maximum=200,
                        value=75,
                        step=25,
                        label="Crossfade Duration (ms)",
                        info="Crossfade for smooth transitions",
                    )
                    sample_rate_dropdown = gr.Dropdown(
                        choices=[16000, 22050, 44100, 48000],
                        value=44100,
                        label="Output Sample Rate (Hz)",
                    )
                    bitrate_dropdown = gr.Dropdown(
                        choices=["128k", "192k", "256k", "320k"],
                        value="192k",
                        label="Output Bitrate",
                    )

                extract_btn = gr.Button("Extract Speaker", variant="primary", size="lg")

            with gr.Column(scale=1):
                gr.Markdown("### Results")
                status_box = gr.Textbox(
                    label="Status",
                    value="Ready to extract. Upload files and click 'Extract Speaker'.",
                    interactive=False,
                    lines=2,
                )
                output_audio = gr.Audio(label="Extracted Audio", type="filepath")
                download_btn = gr.Button("Download Extracted Audio", visible=False, size="sm")

                with gr.Accordion("Extraction Report", open=True):
                    report_display = gr.Markdown("")

                with gr.Accordion("Technical Details", open=False):
                    report_json = gr.Code(label="Report JSON", language="json")

        # Examples. Each row must have exactly as many values as `inputs`
        # (previously the rows carried a 6th description column, which makes
        # gr.Examples raise at build time); the descriptions now go into
        # `example_labels` (requires Gradio >= 4.44).
        gr.Markdown("## Examples")
        gr.Examples(
            examples=[
                [0.40, 0.30, True, 150, 75],
                [0.25, 0.40, True, 150, 75],
                [0.60, 0.20, True, 200, 100],
                [0.40, 0.30, False, 0, 0],
            ],
            example_labels=[
                "Standard extraction with default settings",
                "Strict matching for high confidence",
                "Permissive matching with more segments",
                "Export segments separately",
            ],
            inputs=[
                threshold_slider,
                min_confidence_slider,
                concatenate_checkbox,
                silence_slider,
                crossfade_slider,
            ],
            label="Parameter Presets",
        )

        # Event handlers.
        validate_btn.click(
            fn=validate_reference, inputs=[reference_audio], outputs=[reference_status]
        )

        extract_btn.click(
            fn=extract_speaker_handler,
            inputs=[
                reference_audio,
                target_audio,
                threshold_slider,
                min_confidence_slider,
                concatenate_checkbox,
                silence_slider,
                crossfade_slider,
                sample_rate_dropdown,
                bitrate_dropdown,
            ],
            outputs=[output_audio, report_display, status_box, download_btn],
        )

        # Re-emits the current audio path so the browser offers it for download.
        download_btn.click(fn=lambda audio: audio, inputs=[output_audio], outputs=[output_audio])

    return tab