voice-tools / src /cli /separate.py
jcudit's picture
jcudit HF Staff
feat: implement cross-mode robustness fixes (phases 1-8)
95e1515
"""
CLI command for speaker separation.
Analyzes M4A audio files to detect and separate speakers into individual output streams.
"""
import logging
from pathlib import Path
from typing import Optional
import click
from rich.console import Console
from rich.progress import (
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from rich.table import Table
# Shared Rich console used for every user-facing message printed by this module.
console = Console()
# Module-level logger; the `separate` command adjusts verbosity by setting the
# root logger's level (DEBUG for --verbose, ERROR for --quiet).
logger = logging.getLogger(__name__)
@click.command()
@click.argument("input_file", type=click.Path(exists=True, dir_okay=False, path_type=Path))
@click.option(
    "--output-dir",
    "-o",
    type=click.Path(path_type=Path),
    default="./separated_speakers/",
    help="Directory for output files (default: ./separated_speakers/)",
)
@click.option(
    "--min-speakers",
    type=int,
    default=2,
    help="Minimum number of speakers to detect (default: 2)",
)
@click.option(
    "--max-speakers",
    type=int,
    default=5,
    help="Maximum number of speakers to detect (default: 5)",
)
@click.option(
    "--num-speakers",
    type=int,
    default=None,
    help="Exact number of speakers (overrides min/max)",
)
@click.option(
    "--output-format",
    type=click.Choice(["m4a", "wav", "mp3"], case_sensitive=False),
    default="m4a",
    help="Output format (default: m4a)",
)
@click.option(
    "--sample-rate",
    type=int,
    default=44100,
    help="Output sample rate in Hz (max 48000 for m4a, default: 44100)",
)
@click.option(
    "--bitrate",
    type=str,
    default="192k",
    help="Output bitrate for compressed formats (default: 192k)",
)
@click.option("--progress/--no-progress", default=True, help="Show/hide progress indicators")
@click.option("--verbose", "-v", is_flag=True, help="Enable verbose logging")
@click.option("--quiet", "-q", is_flag=True, help="Minimal output (errors only)")
def separate(
    input_file: Path,
    output_dir: Path,
    min_speakers: int,
    max_speakers: int,
    num_speakers: Optional[int],
    output_format: str,
    sample_rate: int,
    bitrate: str,
    progress: bool,
    verbose: bool,
    quiet: bool,
):
    """
    Separate speakers from multi-speaker audio files.

    Analyzes an audio file to detect and separate all distinct speakers
    into individual output audio streams.

    INPUT_FILE: Path to the audio file to analyze (M4A, AAC, WAV or MP3)

    Examples:

    \b
    # Basic usage
    voice-tools separate interview.m4a

    \b
    # Specify output directory
    voice-tools separate podcast.m4a --output-dir ./speakers/

    \b
    # Known speaker count
    voice-tools separate meeting.m4a --num-speakers 4

    \b
    # WAV output with high quality
    voice-tools separate audio.m4a --output-format wav --sample-rate 48000

    \b
    Exit codes:
      0  success
      1  invalid arguments
      2  separation failed (service returned an error report)
      3  unexpected error
    """
    # Configure logging (levels are applied to the root logger).
    if verbose:
        logging.getLogger().setLevel(logging.DEBUG)
    elif quiet:
        logging.getLogger().setLevel(logging.ERROR)

    # Progress output is suppressed entirely in quiet mode.
    show_progress = progress and not quiet

    try:
        # Bad arguments are user errors: report them plainly and exit 1
        # instead of letting the generic handler below label them
        # "Unexpected error" with a traceback and exit code 3.
        try:
            _validate_arguments(
                input_file, min_speakers, max_speakers, num_speakers, output_format, sample_rate
            )
        except (ValueError, FileNotFoundError) as e:
            console.print(f"[red]Error:[/red] {e}")
            raise SystemExit(1) from None

        if not quiet:
            console.print("\n[bold cyan]Voice Tools - Speaker Separation[/bold cyan]\n")

        output_dir.mkdir(parents=True, exist_ok=True)

        # --num-speakers pins both bounds to the same value.
        if num_speakers is not None:
            min_speakers = max_speakers = num_speakers

        # Lazy import to avoid loading pyannote on CLI startup.
        from ..services.speaker_separation import SpeakerSeparationService

        if show_progress:
            console.print("[cyan]Initializing speaker separation models...[/cyan]")
        service = SpeakerSeparationService()
        if show_progress:
            console.print("[green]✓[/green] Models loaded\n")

        separation_kwargs = dict(
            input_file=str(input_file),
            output_dir=str(output_dir),
            min_speakers=min_speakers,
            max_speakers=max_speakers,
            output_format=output_format,
            sample_rate=sample_rate,
            bitrate=bitrate,
        )
        if show_progress:
            report = _separate_with_progress(service, separation_kwargs)
        else:
            report = service.separate_and_export(**separation_kwargs)

        # The service signals handled failures through the report, not exceptions.
        if report.get("status") == "failed":
            _print_error_report(report)
            raise SystemExit(2)

        if not quiet:
            _display_results(report, output_dir)

        # SystemExit derives from BaseException, so it bypasses the handler below.
        raise SystemExit(0)
    except Exception as e:
        console.print(f"[red]Error:[/red] Unexpected error: {e}")
        logger.exception("Separation failed")
        raise SystemExit(3)


def _separate_with_progress(service, separation_kwargs: dict) -> dict:
    """
    Run ``service.separate_and_export`` while rendering a Rich progress bar.

    Args:
        service: Object exposing ``separate_and_export(**kwargs, progress_callback=...)``.
        separation_kwargs: Keyword arguments forwarded unchanged to the service.

    Returns:
        Whatever ``separate_and_export`` returns (the report dict).
    """
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TextColumn("•"),
        TimeElapsedColumn(),
        TextColumn("•"),
        TimeRemainingColumn(),
        console=console,
    ) as prog:
        task_id = None

        def progress_callback(stage: str, current: float, total: float):
            nonlocal task_id
            # total == 1.0 is treated as a 0.0-1.0 fraction and scaled to
            # percent; any other total is treated as an integer item count.
            if total == 1.0:
                display_total = 100
                display_current = int(current * 100)
            else:
                display_total = int(total)
                display_current = int(current)

            if task_id is None:
                # Seed the task with the first report's progress, not 0.
                task_id = prog.add_task(stage, total=display_total, completed=display_current)
            else:
                prog.update(
                    task_id,
                    description=stage,
                    completed=display_current,
                    total=display_total,
                )

        return service.separate_and_export(**separation_kwargs, progress_callback=progress_callback)


def _print_error_report(report: dict) -> None:
    """Print a failed-separation report, colour-coded by its ``error_type``."""
    error_type = report.get("error_type", "processing")
    color_map = {
        "audio_io": "red",
        "processing": "red",
        "validation": "yellow",
        "ssl": "magenta",
        "model_loading": "magenta",
    }
    color = color_map.get(error_type, "red")
    # .get so a malformed report still yields exit code 2, not a KeyError.
    message = report.get("error", "Unknown error")
    console.print(f"[{color}]Error ({error_type}):[/{color}] {message}")
def _validate_arguments(
    input_file: Path,
    min_speakers: int,
    max_speakers: int,
    num_speakers: Optional[int],
    output_format: str,
    sample_rate: int,
):
    """
    Validate command-line arguments before any model is loaded.

    Args:
        input_file: Audio file to analyze; must exist and have a supported
            extension (.m4a, .aac, .wav, .mp3, case-insensitive).
        min_speakers: Lower speaker bound; ignored when ``num_speakers`` is set.
        max_speakers: Upper speaker bound; ignored when ``num_speakers`` is set.
        num_speakers: Exact speaker count, or None to use min/max.
        output_format: Requested output format. Every supported format
            currently shares the 48000 Hz ceiling, so it does not affect the
            checks; kept for interface compatibility.
        sample_rate: Output sample rate in Hz; must be within 8000-48000.

    Raises:
        FileNotFoundError: If ``input_file`` does not exist.
        ValueError: On an unsupported extension, invalid speaker counts or an
            out-of-range sample rate.
    """
    # Click already checks existence, but this function may be called directly.
    if not input_file.exists():
        raise FileNotFoundError(f"Input file '{input_file}' not found")

    supported_formats = (".m4a", ".aac", ".wav", ".mp3")
    suffix = input_file.suffix.lower()
    if suffix not in supported_formats:
        expected = ", ".join(fmt.upper() for fmt in supported_formats)
        got = input_file.suffix.upper() or "no extension"
        raise ValueError(f"File format not supported. Expected one of {expected}, got {got}")

    # --num-speakers overrides min/max in the caller, so only one of the two
    # modes is validated; stale min/max values must not reject a valid run.
    if num_speakers is not None:
        if num_speakers < 1:
            raise ValueError(f"Number of speakers must be at least 1, got {num_speakers}")
    else:
        if min_speakers < 1:
            raise ValueError(f"Minimum speakers must be at least 1, got {min_speakers}")
        if max_speakers < min_speakers:
            raise ValueError(
                f"min-speakers ({min_speakers}) cannot exceed max-speakers ({max_speakers})"
            )

    # 48000 Hz is both the general ceiling and the M4A limit, so one range
    # check covers every output format.
    if not 8000 <= sample_rate <= 48000:
        raise ValueError(f"Sample rate must be between 8000-48000 Hz, got {sample_rate}")
def _display_results(report: dict, output_dir: Path):
    """
    Render a successful separation report as Rich tables on the shared console.

    Args:
        report: Result dict returned by the separation service. Required keys:
            ``speakers_detected``, ``processing_time_seconds`` and
            ``input_duration_seconds``. Optional keys: ``overlapping_segments``,
            ``output_files`` (list of dicts with ``speaker_id``, ``duration``
            and ``file``) and ``quality_metrics`` (dict with
            ``average_confidence`` and optionally ``low_confidence_segments``).
        output_dir: Directory the outputs were written to. It is shown to the
            user and used to build the displayed report path. This assumes the
            service writes ``separation_report.json`` there; that is not
            verified here.
    """
    console.print("\n[bold green]✓ Separation Complete[/bold green]\n")

    # Summary table
    summary_table = Table(title="Summary", show_header=False, box=None)
    summary_table.add_column("Metric", style="cyan")
    summary_table.add_column("Value", style="white")
    summary_table.add_row("Speakers detected", str(report["speakers_detected"]))
    summary_table.add_row("Processing time", f"{report['processing_time_seconds']:.1f}s")
    summary_table.add_row("Input duration", f"{report['input_duration_seconds']:.1f}s")
    if report.get("overlapping_segments"):
        summary_table.add_row("Overlapping segments", str(report["overlapping_segments"]))
    console.print(summary_table)
    console.print()

    # Speaker details table. A missing/empty list just skips the table rather
    # than turning a successful run into an "unexpected error" (KeyError).
    output_files = report.get("output_files") or []
    if output_files:
        speaker_table = Table(title="Speaker Details", show_header=True)
        speaker_table.add_column("Speaker", style="cyan")
        speaker_table.add_column("Duration", style="white")
        speaker_table.add_column("File", style="white")
        for output in output_files:
            # Table.add_row only accepts str/renderables; coerce so numeric
            # speaker ids or Path objects don't raise NotRenderableError.
            speaker_table.add_row(
                str(output["speaker_id"]),
                f"{output['duration']:.1f}s",
                str(output["file"]),
            )
        console.print(speaker_table)
        console.print()

    # Quality metrics (skipped when absent, None or empty).
    metrics = report.get("quality_metrics")
    if metrics:
        quality_table = Table(title="Quality Metrics", show_header=False, box=None)
        quality_table.add_column("Metric", style="cyan")
        quality_table.add_column("Value", style="white")
        quality_table.add_row("Average confidence", f"{metrics['average_confidence']:.2f}")
        if metrics.get("low_confidence_segments"):
            quality_table.add_row(
                "Low confidence segments", str(metrics["low_confidence_segments"])
            )
        console.print(quality_table)
        console.print()

    # Output location
    console.print(f"[green]Output saved to:[/green] {output_dir}")
    console.print(f"[green]Report saved to:[/green] {output_dir / 'separation_report.json'}\n")