"""
CLI command for speaker separation.

Analyzes M4A audio files to detect and separate speakers into individual output streams.
"""

import logging
from pathlib import Path
from typing import Optional

import click
from rich.console import Console
from rich.markup import escape
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from rich.table import Table

console = Console()
logger = logging.getLogger(__name__)

# Input file extensions accepted by `_validate_arguments`.
_SUPPORTED_INPUT_FORMATS = (".m4a", ".aac", ".wav", ".mp3")

# Sample-rate bounds enforced by `_validate_arguments` (Hz).
_MIN_SAMPLE_RATE = 8000
_MAX_SAMPLE_RATE = 48000
_M4A_MAX_SAMPLE_RATE = 48000

# Rich colour used when printing a failed report, keyed by the report's "error_type".
_ERROR_TYPE_COLORS = {
    "audio_io": "red",
    "processing": "red",
    "validation": "yellow",
    "ssl": "magenta",
    "model_loading": "magenta",
}

# Process exit codes used by `separate`.
_EXIT_OK = 0
_EXIT_INVALID_ARGS = 1
_EXIT_SEPARATION_FAILED = 2
_EXIT_UNEXPECTED = 3


@click.command()
@click.argument("input_file", type=click.Path(exists=True, path_type=Path))
@click.option(
    "--output-dir",
    "-o",
    type=click.Path(path_type=Path),
    default="./separated_speakers/",
    help="Directory for output files (default: ./separated_speakers/)",
)
@click.option(
    "--min-speakers",
    type=int,
    default=2,
    help="Minimum number of speakers to detect (default: 2)",
)
@click.option(
    "--max-speakers",
    type=int,
    default=5,
    help="Maximum number of speakers to detect (default: 5)",
)
@click.option(
    "--num-speakers",
    type=int,
    default=None,
    help="Exact number of speakers (overrides min/max)",
)
@click.option(
    "--output-format",
    type=click.Choice(["m4a", "wav", "mp3"], case_sensitive=False),
    default="m4a",
    help="Output format (default: m4a)",
)
@click.option(
    "--sample-rate",
    type=int,
    default=44100,
    help="Output sample rate in Hz (max 48000 for m4a, default: 44100)",
)
@click.option(
    "--bitrate",
    type=str,
    default="192k",
    help="Output bitrate for compressed formats (default: 192k)",
)
@click.option("--progress/--no-progress", default=True, help="Show/hide progress indicators")
@click.option("--verbose", "-v", is_flag=True, help="Enable verbose logging")
@click.option("--quiet", "-q", is_flag=True, help="Minimal output (errors only)")
def separate(
    input_file: Path,
    output_dir: Path,
    min_speakers: int,
    max_speakers: int,
    num_speakers: Optional[int],
    output_format: str,
    sample_rate: int,
    bitrate: str,
    progress: bool,
    verbose: bool,
    quiet: bool,
):
    """
    Separate speakers from multi-speaker audio files.

    Analyzes an M4A audio file to detect and separate all distinct speakers
    into individual output audio streams.

    INPUT_FILE: Path to the M4A/AAC audio file to analyze

    Examples:

    \b
    # Basic usage
    voice-tools separate interview.m4a

    \b
    # Specify output directory
    voice-tools separate podcast.m4a --output-dir ./speakers/

    \b
    # Known speaker count
    voice-tools separate meeting.m4a --num-speakers 4

    \b
    # WAV output with high quality
    voice-tools separate audio.m4a --output-format wav --sample-rate 48000
    """
    # Exit codes: 0 success, 1 invalid arguments, 2 the service returned a
    # "failed" report, 3 any other unexpected exception.

    # Configure logging (adjusts the root logger's level only).
    if verbose:
        logging.getLogger().setLevel(logging.DEBUG)
    elif quiet:
        logging.getLogger().setLevel(logging.ERROR)

    # Argument problems are user errors, not crashes: report them plainly,
    # without a logged traceback, and with a dedicated exit code.
    try:
        _validate_arguments(
            input_file, min_speakers, max_speakers, num_speakers, output_format, sample_rate
        )
    except (FileNotFoundError, ValueError) as e:
        console.print(f"[red]Error:[/red] {escape(str(e))}")
        raise SystemExit(_EXIT_INVALID_ARGS) from e

    try:
        if not quiet:
            console.print("\n[bold cyan]Voice Tools - Speaker Separation[/bold cyan]\n")

        output_dir.mkdir(parents=True, exist_ok=True)

        # An exact speaker count pins both bounds.
        if num_speakers is not None:
            min_speakers = max_speakers = num_speakers

        # Lazy import to avoid pyannote loading on CLI startup.
        from ..services.speaker_separation import SpeakerSeparationService

        show_progress = progress and not quiet

        if show_progress:
            console.print("[cyan]Initializing speaker separation models...[/cyan]")
        service = SpeakerSeparationService()
        if show_progress:
            console.print("[green]✓[/green] Models loaded\n")

        export_kwargs = {
            "input_file": str(input_file),
            "output_dir": str(output_dir),
            "min_speakers": min_speakers,
            "max_speakers": max_speakers,
            "output_format": output_format,
            "sample_rate": sample_rate,
            "bitrate": bitrate,
        }

        if show_progress:
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                MofNCompleteColumn(),
                TextColumn("•"),
                TimeElapsedColumn(),
                TextColumn("•"),
                TimeRemainingColumn(),
                console=console,
            ) as prog:
                report = service.separate_and_export(
                    **export_kwargs,
                    progress_callback=_make_progress_callback(prog),
                )
        else:
            report = service.separate_and_export(**export_kwargs)

        # The service signals handled failures through the report, not exceptions.
        if report.get("status") == "failed":
            _print_failure(report)
            raise SystemExit(_EXIT_SEPARATION_FAILED)

        if not quiet:
            _display_results(report, output_dir)

        raise SystemExit(_EXIT_OK)
    except Exception as e:
        # SystemExit derives from BaseException, so the exits above pass through.
        console.print(f"[red]Error:[/red] Unexpected error: {escape(str(e))}")
        logger.exception("Separation failed")
        raise SystemExit(_EXIT_UNEXPECTED)


def _make_progress_callback(prog: Progress):
    """Return a ``(stage, current, total)`` callback that drives one Rich task.

    When ``total == 1.0`` the values are treated as a fraction (0.0-1.0) and
    scaled to 0-100 for display; otherwise they are used as integer counts.
    The first call creates the task; later calls update its description,
    completed count and total.
    """
    task_id = None

    def progress_callback(stage: str, current: float, total: float) -> None:
        nonlocal task_id
        if total == 1.0:
            display_total = 100
            display_current = int(current * 100)
        else:
            display_total = int(total)
            display_current = int(current)

        if task_id is None:
            # Seed the task with the first reported position so it is not lost.
            task_id = prog.add_task(stage, total=display_total, completed=display_current)
        else:
            prog.update(
                task_id,
                description=stage,
                completed=display_current,
                total=display_total,
            )

    return progress_callback


def _print_failure(report: dict) -> None:
    """Print a failed report's error, coloured by its ``error_type``."""
    error_type = report.get("error_type", "processing")
    color = _ERROR_TYPE_COLORS.get(error_type, "red")
    message = escape(str(report.get("error", "Unknown error")))
    console.print(f"[{color}]Error ({error_type}):[/{color}] {message}")


def _validate_arguments(
    input_file: Path,
    min_speakers: int,
    max_speakers: int,
    num_speakers: Optional[int],
    output_format: str,
    sample_rate: int,
):
    """Validate command-line arguments.

    Raises:
        FileNotFoundError: ``input_file`` does not exist.
        ValueError: unsupported input extension, invalid speaker counts, or a
            sample rate outside the allowed range.
    """
    # Check file exists (already checked by Click, but double-check).
    if not input_file.exists():
        raise FileNotFoundError(f"Input file '{input_file}' not found")

    if input_file.suffix.lower() not in _SUPPORTED_INPUT_FORMATS:
        expected = ", ".join(ext.lstrip(".").upper() for ext in _SUPPORTED_INPUT_FORMATS)
        got = input_file.suffix.upper() or "no extension"
        raise ValueError(f"File format not supported. Expected one of {expected}, got {got}")

    # An exact count overrides min/max, so only validate what will be used.
    if num_speakers is not None:
        if num_speakers < 1:
            raise ValueError(f"Number of speakers must be at least 1, got {num_speakers}")
    else:
        if min_speakers < 1:
            raise ValueError(f"Minimum speakers must be at least 1, got {min_speakers}")
        if max_speakers < min_speakers:
            raise ValueError(
                f"min-speakers ({min_speakers}) cannot exceed max-speakers ({max_speakers})"
            )

    # Format-specific limit first so its more precise message is reachable.
    if output_format.lower() == "m4a" and sample_rate > _M4A_MAX_SAMPLE_RATE:
        raise ValueError(
            f"Sample rate {sample_rate} exceeds M4A limit of {_M4A_MAX_SAMPLE_RATE} Hz"
        )
    if not _MIN_SAMPLE_RATE <= sample_rate <= _MAX_SAMPLE_RATE:
        raise ValueError(
            f"Sample rate must be between {_MIN_SAMPLE_RATE}-{_MAX_SAMPLE_RATE} Hz, "
            f"got {sample_rate}"
        )


def _display_results(report: dict, output_dir: Path):
    """Display separation results in formatted tables.

    Reads ``speakers_detected``, ``processing_time_seconds`` and
    ``input_duration_seconds`` from ``report`` unconditionally. It also uses
    ``overlapping_segments``, ``output_files`` (entries with ``speaker_id``,
    ``duration`` and ``file``) and ``quality_metrics`` when present.
    """
    console.print("\n[bold green]✓ Separation Complete[/bold green]\n")

    # Summary table
    summary_table = Table(title="Summary", show_header=False, box=None)
    summary_table.add_column("Metric", style="cyan")
    summary_table.add_column("Value", style="white")
    summary_table.add_row("Speakers detected", str(report["speakers_detected"]))
    summary_table.add_row("Processing time", f"{report['processing_time_seconds']:.1f}s")
    summary_table.add_row("Input duration", f"{report['input_duration_seconds']:.1f}s")
    if report.get("overlapping_segments"):
        summary_table.add_row("Overlapping segments", str(report["overlapping_segments"]))
    console.print(summary_table)
    console.print()

    # Speaker details table (cells must be str for Rich; escape to avoid markup parsing).
    output_files = report.get("output_files")
    if output_files:
        speaker_table = Table(title="Speaker Details", show_header=True)
        speaker_table.add_column("Speaker", style="cyan")
        speaker_table.add_column("Duration", style="white")
        speaker_table.add_column("File", style="white")
        for output in output_files:
            speaker_table.add_row(
                escape(str(output["speaker_id"])),
                f"{output['duration']:.1f}s",
                escape(str(output["file"])),
            )
        console.print(speaker_table)
        console.print()

    # Quality metrics
    if "quality_metrics" in report:
        metrics = report["quality_metrics"]
        quality_table = Table(title="Quality Metrics", show_header=False, box=None)
        quality_table.add_column("Metric", style="cyan")
        quality_table.add_column("Value", style="white")
        quality_table.add_row("Average confidence", f"{metrics['average_confidence']:.2f}")
        if metrics.get("low_confidence_segments"):
            quality_table.add_row(
                "Low confidence segments", str(metrics["low_confidence_segments"])
            )
        console.print(quality_table)
        console.print()

    # Output location
    report_path = output_dir / "separation_report.json"
    console.print(f"[green]Output saved to:[/green] {escape(str(output_dir))}")
    console.print(f"[green]Report saved to:[/green] {escape(str(report_path))}\n")