voice-tools / src /cli /progress.py
jcudit's picture
jcudit HF Staff
feat: complete audio speaker separation feature with 3 workflows
cb39c05
"""
Progress reporting and statistics display for CLI.
Uses Rich library for beautiful terminal output with progress bars,
tables, and formatted statistics.
"""
from typing import Dict, List, Optional
from rich import box
from rich.console import Console
from rich.panel import Panel
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from rich.table import Table
from rich.text import Text
# Module-level Rich console; the progress reporter and every display_* helper
# in this file print through this one instance.
console = Console()
class ExtractionProgress:
    """
    Manages progress reporting for voice extraction tasks.

    Drives a Rich ``Progress`` display made of one long-lived "overall" task
    (advanced once per file) and at most one transient per-file task.
    Every method is a no-op while no display is active, so callers can use
    the reporter unconditionally.  The class is also a context manager:
    leaving the ``with`` block always stops the display.
    """

    def __init__(self):
        """Initialize the progress reporter with no active display."""
        self.progress: Optional[Progress] = None
        # Task ids are only meaningful for the Progress instance that issued
        # them, so they are cleared whenever that instance is stopped.
        self.overall_task = None
        self.current_task = None

    def start(self, total_files: int):
        """
        Start progress tracking.

        A display left over from an earlier ``start()`` is stopped first so
        that a previous live display is never left running on the console.

        Args:
            total_files: Total number of files to process
        """
        self.stop()
        self.progress = Progress(
            SpinnerColumn(),
            TextColumn("[bold blue]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            console=console,
        )
        self.progress.start()
        self.overall_task = self.progress.add_task("[cyan]Overall Progress", total=total_files)

    def start_file(self, filename: str):
        """
        Start processing a new file.

        Replaces the previous per-file task, if one is still shown.

        Args:
            filename: Name of the file being processed
        """
        if not self.progress:
            return
        # Local import: only needed once a display is actually active.
        from rich.markup import escape

        if self.current_task is not None:
            self.progress.remove_task(self.current_task)
        # The description is rendered as Rich markup; escape the name so that
        # bracketed parts such as "[final]" are shown literally.
        label = escape(str(filename))
        self.current_task = self.progress.add_task(f"[green]Processing: {label}", total=100)

    def update_file(self, progress: int, status: str = ""):
        """
        Update progress for current file.

        The task description is replaced by "Processing" plus the optional
        status, so the filename set by ``start_file`` is no longer shown
        afterwards.  ``status`` is inserted as-is and may contain Rich markup.

        Args:
            progress: Progress percentage (0-100)
            status: Optional status message
        """
        if self.progress and self.current_task is not None:
            description = "[green]Processing"
            if status:
                description += f": {status}"
            self.progress.update(self.current_task, completed=progress, description=description)

    def complete_file(self, success: bool = True):
        """
        Mark current file as complete.

        Removes the per-file task and advances the overall task by one.

        Args:
            success: Whether file was processed successfully.  Currently not
                used by the display; the overall task advances either way.
        """
        if not self.progress:
            return
        if self.current_task is not None:
            self.progress.remove_task(self.current_task)
            self.current_task = None
        if self.overall_task is not None:
            self.progress.update(self.overall_task, advance=1)

    def stop(self):
        """Stop progress tracking and forget all task ids."""
        if self.progress:
            self.progress.stop()
            self.progress = None
        # Ids from the stopped display must not leak into a later start():
        # removing an id the new Progress never issued would fail.
        self.overall_task = None
        self.current_task = None

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit; stops the display, never suppresses errors."""
        self.stop()
def display_header(title: str):
    """
    Print a title banner framed in a cyan panel.

    The panel is preceded and followed by one blank line.

    Args:
        title: Header title
    """
    banner = Panel.fit(f"[bold cyan]{title}[/bold cyan]", border_style="cyan")
    console.print()
    console.print(banner)
    console.print()
def display_config(config: Dict):
    """
    Print the configuration as a two-column, header-less table.

    Each entry becomes one row; values are converted with ``str``.

    Args:
        config: Configuration dictionary
    """
    settings = Table(title="Configuration", box=box.ROUNDED, show_header=False)
    settings.add_column("Setting", style="cyan", no_wrap=True)
    settings.add_column("Value", style="white")

    for name in config:
        settings.add_row(name, str(config[name]))

    console.print(settings)
    console.print()
def display_statistics(stats: Dict):
    """
    Print extraction statistics as a two-column, header-less table.

    Only keys present in ``stats`` produce a row; rows always appear in the
    fixed order listed below, whatever the ordering of the dictionary.

    Args:
        stats: Statistics dictionary
    """
    report = Table(title="Extraction Statistics", box=box.ROUNDED, show_header=False)
    report.add_column("Metric", style="cyan", no_wrap=True)
    report.add_column("Value", style="white")

    def seconds_and_minutes(value):
        # Durations are shown both in seconds and in minutes.
        return f"{value:.2f}s ({value / 60:.2f} min)"

    def failure_count(count):
        # Any failure is highlighted in red, a clean run in green.
        colour = "red" if count > 0 else "green"
        return f"[{colour}]{count}[/{colour}]"

    # (stats key, row label, formatter) in display order.
    rows = (
        ("total_input_duration", "Total Input Duration", seconds_and_minutes),
        ("total_extracted_duration", "Total Extracted Duration", seconds_and_minutes),
        ("extraction_percentage", "Extraction Percentage", lambda v: f"{v:.1f}%"),
        ("files_processed", "Files Processed", str),
        ("files_failed", "Files Failed", failure_count),
        ("segments_extracted", "Segments Extracted", str),
        ("average_segment_duration", "Average Segment Duration", lambda v: f"{v:.2f}s"),
        ("average_confidence", "Average Confidence", lambda v: f"{v:.2f}"),
        ("duration", "Processing Time", seconds_and_minutes),
    )
    for key, label, render in rows:
        if key in stats:
            report.add_row(label, render(stats[key]))

    console.print(report)
    console.print()
def display_failures(failures: List[Dict]):
    """
    Display failed files in a formatted table.

    Nothing is printed when ``failures`` is empty.  Each entry must provide
    the keys ``"file"`` and ``"error"``; both values are converted with
    ``str`` and markup-escaped, so ``Path`` objects, exception instances and
    text containing square brackets are shown literally.

    Args:
        failures: List of failure dictionaries

    Raises:
        KeyError: If an entry lacks the ``"file"`` or ``"error"`` key.
    """
    if not failures:
        return
    # Local import: only needed when there is something to render.
    from rich.markup import escape

    table = Table(title="Failed Files", box=box.ROUNDED)
    table.add_column("File", style="yellow", no_wrap=True)
    table.add_column("Error", style="red")
    for failure in failures:
        # Table cells given as strings are parsed as Rich markup; escaping
        # keeps bracketed text in paths and error messages intact.
        table.add_row(escape(str(failure["file"])), escape(str(failure["error"])))
    console.print(table)
    console.print()
def display_vad_stats(vad_stats: Dict):
    """
    Print voice-activity-detection statistics as a two-column table.

    Returns without printing when ``vad_stats`` is empty.  Only keys that
    are present produce a row.

    Args:
        vad_stats: VAD statistics dictionary
    """
    if not vad_stats:
        return

    summary = Table(title="Voice Activity Detection", box=box.ROUNDED, show_header=False)
    summary.add_column("Metric", style="cyan", no_wrap=True)
    summary.add_column("Value", style="white")

    # Both duration rows share the same seconds-and-minutes format.
    for key, label in (("total_duration", "Total Duration"), ("voice_duration", "Voice Duration")):
        if key in vad_stats:
            secs = vad_stats[key]
            summary.add_row(label, f"{secs:.2f}s ({secs / 60:.2f} min)")

    if "voice_percentage" in vad_stats:
        pct = vad_stats["voice_percentage"]
        # Colour bands: above 20 green, above 10 yellow, otherwise red.
        if pct > 20:
            colour = "green"
        elif pct > 10:
            colour = "yellow"
        else:
            colour = "red"
        summary.add_row("Voice Activity", f"[{colour}]{pct:.1f}%[/{colour}]")

    if "num_segments" in vad_stats:
        summary.add_row("Voice Segments", str(vad_stats["num_segments"]))

    if "worth_processing" in vad_stats:
        if vad_stats["worth_processing"]:
            verdict = "[green]Yes[/green]"
        else:
            verdict = "[red]No[/red]"
        summary.add_row("Worth Processing", verdict)

    console.print(summary)
    console.print()
def display_success(message: str):
    """
    Display a success message prefixed with a green check mark.

    The message is printed through Rich, so markup it contains is rendered.

    Args:
        message: Success message
    """
    # The prefix glyph is U+2713 (check mark); it had been stored as mojibake.
    console.print(f"[green]✓[/green] {message}")
def display_warning(message: str):
    """
    Print a warning line prefixed with a yellow warning sign.

    Args:
        message: Warning message
    """
    marker = "[yellow]⚠[/yellow]"
    console.print(f"{marker} {message}")
def display_error(message: str):
    """
    Display an error message prefixed with a red ballot X.

    The message is printed through Rich, so markup it contains is rendered.

    Args:
        message: Error message
    """
    # The prefix glyph is U+2717 (ballot X); it had been stored as mojibake.
    console.print(f"[red]✗[/red] {message}")
def display_info(message: str):
    """
    Display an info message prefixed with a blue information sign.

    The message is printed through Rich, so markup it contains is rendered.

    Args:
        message: Info message
    """
    # The prefix glyph is U+2139 (information source); it had been stored as
    # mojibake.
    console.print(f"[blue]ℹ[/blue] {message}")
def display_segment_details(segments: List[Dict], max_rows: int = 20):
    """
    Display extracted segment details in a table.

    At most ``max_rows`` segments are rendered; when more exist, a dimmed
    note states how many were shown out of the total.  Nothing is printed
    when ``segments`` is empty.

    Each segment must provide ``"start"``, ``"end"`` and ``"duration"``.
    ``"segment_type"`` (default ``"unknown"``), ``"confidence"`` (default 0)
    and ``"snr"`` are optional.

    Args:
        segments: List of segment dictionaries
        max_rows: Maximum number of segments to render (default 20)

    Raises:
        KeyError: If a rendered segment lacks a required key.
    """
    if not segments:
        return

    table = Table(title="Extracted Segments", box=box.ROUNDED)
    table.add_column("#", style="cyan", no_wrap=True)
    table.add_column("Start", style="white")
    table.add_column("End", style="white")
    table.add_column("Duration", style="white")
    table.add_column("Type", style="green")
    table.add_column("Confidence", style="yellow")
    table.add_column("Quality", style="magenta")

    # Only the leading slice is rendered so the "Showing first N" note below
    # is truthful and long runs do not flood the terminal.
    for i, segment in enumerate(segments[:max_rows], 1):
        segment_type = segment.get("segment_type", "unknown")
        confidence = segment.get("confidence", 0)
        snr = segment.get("snr")

        # Quality indicator: check mark when SNR is at least 15, ballot X
        # below that, "?" when no SNR was recorded.
        if snr is None:
            quality = "?"
        else:
            quality = "✓" if snr >= 15 else "✗"

        table.add_row(
            str(i),
            f"{segment['start']:.2f}s",
            f"{segment['end']:.2f}s",
            f"{segment['duration']:.2f}s",
            # Enum-like objects expose .value; str() keeps the cell
            # renderable even when that value is not a string.
            str(getattr(segment_type, "value", segment_type)),
            f"{confidence:.2f}",
            quality,
        )

    console.print(table)
    if len(segments) > max_rows:
        console.print(f"[dim]Showing first {max_rows} of {len(segments)} segments[/dim]")
    console.print()
def display_quality_report(quality_metrics: Dict):
    """
    Display quality metrics report.

    Renders one row per metric present in ``quality_metrics`` with its value
    and a pass/fail status.  A metric passes when it reaches its threshold:
    ``snr`` >= 15, ``stoi`` >= 0.70, ``pesq`` >= 2.0.  Passing statuses are
    shown in green and failing ones in red.

    Args:
        quality_metrics: Quality metrics dictionary
    """
    table = Table(title="Quality Metrics", box=box.ROUNDED, show_header=False)
    table.add_column("Metric", style="cyan", no_wrap=True)
    table.add_column("Value", style="white")
    table.add_column("Status", style="green")

    # (metric key, row label, pass threshold, unit suffix) in display order.
    checks = (
        ("snr", "SNR (Signal-to-Noise Ratio)", 15, " dB"),
        ("stoi", "STOI (Intelligibility)", 0.70, ""),
        ("pesq", "PESQ (Perceptual Quality)", 2.0, ""),
    )
    for key, label, threshold, unit in checks:
        if key not in quality_metrics:
            continue
        value = quality_metrics[key]
        # Explicit colours: the column's green style must not make a failing
        # metric look like a success.
        if value >= threshold:
            status = "[green]✓ Pass[/green]"
        else:
            status = "[red]✗ Fail[/red]"
        table.add_row(label, f"{value:.2f}{unit}", status)

    console.print(table)
    console.print()