"""
Benchmark script: Compare all approaches on test audio files.
Runs every registered approach on each input file, collects metrics,
and writes CSV and JSON comparison reports.
Usage:
uv run python scripts/benchmark.py --data-dir data --output-dir benchmark_results
"""
import argparse
import csv
import json
import logging
import sys
import time
from datetime import datetime
from pathlib import Path

# Ensure the project root is importable before project modules are imported,
# so the script also works when run directly rather than via the package.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from approaches import list_approaches, get_approach

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("benchmark")
class BenchmarkRunner:
"""Run benchmark across all approaches and audio files."""
def __init__(self, data_dir: str, output_dir: str):
self.data_dir = Path(data_dir)
self.output_dir = Path(output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
# Results storage
self.results = []
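        # Accumulates one record per (input file, approach) pair.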
    def find_test_files(self) -> list[Path]:
"""Find all WAV files in data directory."""
wav_files = sorted(self.data_dir.glob("*.wav"))
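        # Warn here; the caller decides whether an empty list is fatal.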
if not wav_files:
log.warning(f"No WAV files found in {self.data_dir}")
return wav_files
def run_benchmark(self, whisper_model: str = "base"):
"""Run all approaches on all test files."""
test_files = self.find_test_files()
if not test_files:
log.error(f"No test files found in {self.data_dir}")
return
log.info("="*70)
log.info(f"BENCHMARK: {len(list_approaches())} approaches × {len(test_files)} files")
log.info("="*70)
for input_file in test_files:
log.info(f"\n{'='*70}")
log.info(f"File: {input_file.name}")
log.info(f"{'='*70}")
for approach_name in list_approaches():
log.info(f"\n Testing approach: {approach_name}")
log.info("-"*70)
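                # Isolate failures: one crashing approach must not abort the run.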
try:
result = self._run_approach(
approach_name,
input_file,
whisper_model,
)
self.results.append(result)
except Exception as e:
log.error(f" FAILED: {e}")
result = {
"timestamp": datetime.now().isoformat(),
"input_file": input_file.name,
"approach": approach_name,
"status": "FAILED",
"error": str(e),
}
self.results.append(result)
log.info(f"\n{'='*70}")
log.info("BENCHMARK COMPLETE")
log.info(f"{'='*70}")
# Save results
self._save_results()
self._print_summary()
def _run_approach(self, approach_name: str, input_file: Path, whisper_model: str):
"""Run single approach on single file."""
# Create output directory for this run
output_subdir = self.output_dir / approach_name / input_file.stem
output_subdir.mkdir(parents=True, exist_ok=True)
# Initialize approach
approach_class = get_approach(approach_name)
approach = approach_class()
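        # Every registered approach exposes the same run() interface and
        # returns an output object with timing, speaker, and method metadata.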
# Run
        # perf_counter is monotonic, so elapsed time is immune to clock jumps.
        start_time = time.perf_counter()
pipeline_output = approach.run(
input_file=str(input_file),
output_dir=str(output_subdir),
whisper_model=whisper_model,
)
        execution_time = time.perf_counter() - start_time
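        # One flat record per run for CSV/JSON export; realtime_factor > 1.0
        # means the approach processed audio faster than real time.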
        result = {
            "timestamp": datetime.now().isoformat(),
            "input_file": input_file.name,
            "input_size_mb": input_file.stat().st_size / (1024 * 1024),
            "approach": approach_name,
            "status": "SUCCESS",
            "duration_seconds": pipeline_output.duration_seconds,
            "execution_time_seconds": execution_time,
            "realtime_factor": (pipeline_output.duration_seconds / execution_time)
            if execution_time > 0 else 0.0,
            "n_speakers": pipeline_output.n_speakers,
            "talker_of_interest": pipeline_output.talker_of_interest,
            "separation_method": pipeline_output.separation_method,
            "doa_method": pipeline_output.doa_method,
            "gender_method": pipeline_output.gender_method,
            "asr_model": pipeline_output.asr_model,
            "output_dir": str(output_subdir),
        }
# Log metrics
log.info(" Status: SUCCESS")
log.info(f" Execution time: {execution_time:.2f}s")
log.info(f" Speakers: {pipeline_output.n_speakers}")
log.info(f" ToI: Speaker {pipeline_output.talker_of_interest}")
log.info(f" Output: {output_subdir}")
return result
def _save_results(self):
"""Save results to CSV and JSON."""
# Save CSV
csv_path = self.output_dir / "benchmark_results.csv"
        if self.results:
            # Take the union of keys across all rows: FAILED rows carry an
            # "error" field that SUCCESS rows lack, and DictWriter raises on
            # row keys absent from fieldnames.
            fieldnames = list(dict.fromkeys(k for r in self.results for k in r))
            with open(csv_path, "w", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writeheader()
                writer.writerows(self.results)
            log.info(f"\nSaved: {csv_path}")
# Save JSON
json_path = self.output_dir / "benchmark_results.json"
with open(json_path, 'w') as f:
json.dump(self.results, f, indent=2)
log.info(f"Saved: {json_path}")
def _print_summary(self):
"""Print summary statistics."""
if not self.results:
return
log.info("\n" + "="*70)
log.info("SUMMARY")
log.info("="*70)
# Group by approach
by_approach = {}
for result in self.results:
            by_approach.setdefault(result.get("approach"), []).append(result)
# Print stats per approach
for approach, runs in sorted(by_approach.items()):
successful = [r for r in runs if r.get("status") == "SUCCESS"]
failed = [r for r in runs if r.get("status") == "FAILED"]
log.info(f"\nApproach: {approach}")
log.info(f" Successful: {len(successful)}/{len(runs)}")
if successful:
avg_exec_time = sum(r["execution_time_seconds"] for r in successful) / len(successful)
                avg_rtf = sum(r.get("realtime_factor", 0) for r in successful) / len(successful)
                log.info(f"  Avg execution time: {avg_exec_time:.2f}s")
                log.info(f"  Avg real-time factor: {avg_rtf:.1f}x")
if failed:
log.info(f" Failed runs: {len(failed)}")
def main():
parser = argparse.ArgumentParser(description="Benchmark all approaches")
parser.add_argument("--data-dir", default="data", help="Directory with test WAV files")
parser.add_argument("--output-dir", default="benchmark_results", help="Output directory")
    parser.add_argument("-w", "--whisper-model", default="base",
                        help="Whisper model size (e.g. tiny, base, small)")
    parser.add_argument("-v", "--verbose", action="store_true",
                        help="Enable debug logging")
args = parser.parse_args()
if args.verbose:
logging.getLogger().setLevel(logging.DEBUG)
data_path = Path(args.data_dir)
if not data_path.exists():
log.error(f"Data directory not found: {data_path}")
return 1
runner = BenchmarkRunner(args.data_dir, args.output_dir)
runner.run_benchmark(args.whisper_model)
return 0
if __name__ == "__main__":
sys.exit(main())