Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| CLI command for speaker extraction | |
| Extracts specific speaker from audio using reference clip. | |
| """ | |
| import json | |
| import sys | |
| from pathlib import Path | |
| import click | |
| from rich.console import Console | |
| from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn | |
| from rich.table import Table | |
| from src.services.speaker_extraction import SpeakerExtractionService | |
# Single rich Console shared by every message, table and progress bar this
# module prints (both _display_results and extract_speaker write through it).
console = Console()
def _display_results(report: dict, output_path: Path):
    """Display extraction results in a formatted table and write a JSON report.

    Prints a summary table plus the output/report locations, and writes the
    full ``report`` dict as ``extraction_report.json`` into the parent
    directory of ``output_path`` (i.e. next to the requested output).

    Args:
        report: Result dict from the extraction service. Must contain
            ``segments_found``, ``segments_included``,
            ``total_duration_seconds``, ``average_confidence`` and
            ``processing_time_seconds``; may contain
            ``low_confidence_segments`` and ``output_file``.
        output_path: The requested output location. A plain ``str`` is
            accepted as well as a ``Path``. NOTE: when this is a directory
            (e.g. ``./alice_segments/``), ``.parent`` is the directory's
            parent, so the report lands beside the directory, not inside it.

    Raises:
        KeyError: if one of the required report keys is missing.
        TypeError: if ``report`` contains values ``json.dump`` cannot serialize.
        OSError: if the report file cannot be written.
    """
    # rich is already a dependency of this module; escape() makes file paths
    # print literally instead of having "[...]" parsed as markup tags.
    from rich.markup import escape

    # The CLI may hand over the raw option value (a str); normalise it so
    # ``.parent`` below always works.
    output_path = Path(output_path)

    console.print()
    console.print("[bold green]✓ Extraction Complete[/bold green]")
    console.print()

    # Summary table
    summary_table = Table(title="Extraction Summary", show_header=True)
    summary_table.add_column("Metric", style="cyan")
    summary_table.add_column("Value", style="white")
    summary_table.add_row("Segments found", str(report["segments_found"]))
    summary_table.add_row("Segments included", str(report["segments_included"]))
    summary_table.add_row("Total duration", f"{report['total_duration_seconds']:.1f}s")
    summary_table.add_row("Average confidence", f"{report['average_confidence']:.3f}")
    summary_table.add_row("Processing time", f"{report['processing_time_seconds']:.1f}s")

    low_confidence = report.get("low_confidence_segments", 0)
    if low_confidence > 0:
        # Highlight the row so near-threshold matches stand out.
        summary_table.add_row("Low confidence segments", str(low_confidence), style="yellow")

    console.print(summary_table)
    console.print()

    # Output files
    output_file = report.get("output_file")
    if output_file:
        console.print(f"[bold]Output:[/bold] {escape(str(output_file))}")

    # Persist the full report next to the output for later inspection.
    report_file = output_path.parent / "extraction_report.json"
    with open(report_file, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    console.print(f"[bold]Report:[/bold] {escape(str(report_file))}")
    console.print()
def extract_speaker(
    reference_clip,
    target_file,
    output,
    threshold,
    min_confidence,
    concatenate,
    silence,
    crossfade,
    sample_rate,
    bitrate,
):
    """
    Extract specific speaker from audio using reference clip.

    REFERENCE_CLIP: Path to audio file containing reference speaker's voice (3+ seconds)

    TARGET_FILE: Path to audio file to extract speaker from

    Examples:

        # Basic extraction with default settings
        voice-tools extract-speaker reference.m4a interview.m4a

        # Strict matching with custom output
        voice-tools extract-speaker ref.m4a target.m4a \\
            --threshold 0.30 --output alice_voice.m4a

        # Export to separate segment files
        voice-tools extract-speaker ref.m4a podcast.m4a \\
            --no-concatenate --output ./alice_segments/

    Exit codes:

        1  threshold or min-confidence outside 0.0-1.0
        2  the extraction service reported a failure
        3  no segments matched the reference speaker, or an unexpected error occurred
        4  the reference clip failed validation
    """
    # rich is already a dependency of this module; escape() keeps externally
    # supplied text (service messages, exception text) from being parsed as
    # rich markup, which could hide text or raise MarkupError.
    from rich.markup import escape

    console.print()
    console.print("[bold]Voice Tools - Speaker Extraction[/bold]")
    console.print()

    try:
        # Validate threshold range before paying for model initialisation.
        if not 0.0 <= threshold <= 1.0:
            console.print(
                "[bold red]Error:[/bold red] Threshold must be between 0.0 and 1.0", style="red"
            )
            sys.exit(1)
        if not 0.0 <= min_confidence <= 1.0:
            console.print(
                "[bold red]Error:[/bold red] Min confidence must be between 0.0 and 1.0",
                style="red",
            )
            sys.exit(1)

        # Initialize service
        console.print("Initializing speaker extraction models...")
        service = SpeakerExtractionService()
        console.print("[green]✓[/green] Models loaded")
        console.print()

        # Validate reference clip; a valid clip may still come back with a
        # message, which is surfaced when it contains "warning".
        is_valid, message = service.validate_reference_clip(str(reference_clip))
        if not is_valid:
            console.print(f"[bold red]Error:[/bold red] {escape(str(message))}", style="red")
            sys.exit(4)  # Exit code 4 for reference clip issues
        if message and "warning" in message.lower():
            console.print(f"[yellow]Warning:[/yellow] {escape(message)}")
            console.print()

        # A single progress task is created lazily on the first callback and
        # then re-described/re-sized for every subsequent stage.
        current_task = None
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("{task.completed}/{task.total}"),
            "•",
            TimeElapsedColumn(),
            console=console,
            transient=False,
        ) as prog:

            def progress_callback(stage: str, current: float, total: float):
                nonlocal current_task
                # total == 1.0 signals fractional progress (0.0-1.0): show it
                # as 0-100. Anything else is treated as an item count.
                if total == 1.0:
                    display_total = 100
                    display_current = int(current * 100)
                else:
                    display_total = int(total)
                    display_current = int(current)

                if current_task is None:
                    # Seed ``completed`` too, so the first reported value is
                    # not dropped.
                    current_task = prog.add_task(
                        stage, total=display_total, completed=display_current
                    )
                else:
                    prog.update(
                        current_task,
                        description=stage,
                        completed=display_current,
                        total=display_total,
                    )

            # Perform extraction. silence/crossfade are forwarded as
            # millisecond durations (per the service keyword names).
            report = service.extract_and_export(
                reference_clip=str(reference_clip),
                target_file=str(target_file),
                output_path=str(output),
                threshold=threshold,
                min_confidence=min_confidence,
                concatenate=concatenate,
                silence_duration_ms=silence,
                crossfade_duration_ms=crossfade,
                sample_rate=sample_rate,
                bitrate=bitrate,
                progress_callback=progress_callback,
            )

        # The service signals failure via the report rather than raising.
        if report.get("status") == "failed":
            error_type = report.get("error_type", "processing")
            # Color-code by error type
            color_map = {
                "audio_io": "red",
                "processing": "red",
                "validation": "yellow",
                "ssl": "magenta",
                "model_loading": "magenta",
            }
            color = color_map.get(error_type, "red")
            # .get(): a failed report without an "error" key must still exit 2,
            # not fall into the generic handler below with a KeyError.
            error_text = escape(str(report.get("error", "Unknown error")))
            console.print(
                f"[bold {color}]Error ({escape(str(error_type))}):[/bold {color}] {error_text}",
                style=color,
            )
            sys.exit(2)

        # Check if any segments were found
        if report["segments_included"] == 0:
            console.print()
            console.print(
                "[yellow]Warning:[/yellow] No segments matched reference speaker", style="yellow"
            )
            console.print(
                f"  Try lowering the threshold (current: {threshold:.2f}) for more permissive matching",
                style="dim",
            )
            sys.exit(3)  # Exit code 3 for no matches

        # Display results
        _display_results(report, output)

        # Show low confidence warning
        if report.get("low_confidence_segments", 0) > 0:
            console.print(
                f"[yellow]Note:[/yellow] {report['low_confidence_segments']} segment(s) "
                f"have confidence close to threshold. Consider raising threshold for stricter matching.",
                style="dim",
            )
            console.print()

    except Exception as e:
        # sys.exit() raises SystemExit (a BaseException), so the deliberate
        # exits above are not intercepted here.
        import traceback

        console.print(f"[bold red]Error:[/bold red] Unexpected error: {escape(str(e))}", style="red")
        console.print()
        console.print("[dim]Stack trace:[/dim]")
        # markup=False: print the traceback verbatim; it may contain "[...]".
        console.print(traceback.format_exc(), style="dim", markup=False)
        # NOTE: exit code 3 is shared with the "no matches" case above; kept
        # as-is because scripts may already depend on it.
        sys.exit(3)  # Exit code 3 for processing error