"""Benchmark BS-RoFormer stem separation on a local audio file.

Runs the separator one or more times against a single input file, records
per-phase timings (model load, preprocessing, inference, export), and writes
a JSON summary plus a per-run CSV under the chosen output directory.
"""

import argparse
import csv
import json
import math
import statistics
import sys
import tempfile
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import numpy as np
import soundfile as sf
import torch
import torchaudio

# The repository root is taken to be the parent of this script's directory;
# it is put on sys.path so that ``from backend import separator`` resolves
# when the script is executed directly.
REPO_ROOT = Path(__file__).resolve().parent.parent
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

DEFAULT_STEMS = ["Bass", "Drums", "Other", "Vocals", "Guitar", "Piano"]
DEFAULT_FORMAT = "wav"
VALID_FORMATS = {"wav", "mp3", "aac"}


def parse_args() -> argparse.Namespace:
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="Run BS-RoFormer stem separation on a local file and save benchmark data."
    )
    parser.add_argument("input_file", help="Path to the source audio file")
    parser.add_argument(
        "--runs",
        type=int,
        default=1,
        help="Number of benchmark runs to execute (default: 1)",
    )
    parser.add_argument(
        "--stems",
        default="all",
        help="Comma-separated stems to export. Use 'all' for every stem.",
    )
    parser.add_argument(
        "--output-format",
        default=DEFAULT_FORMAT,
        choices=sorted(VALID_FORMATS),
        help="Stem export format (default: wav)",
    )
    parser.add_argument(
        "--output-root",
        help="Directory to place benchmark outputs. Defaults to a temp directory.",
    )
    parser.add_argument(
        "--reload-model-per-run",
        action="store_true",
        help="Force a fresh model load before each run to measure cold-start cost every time.",
    )
    parser.add_argument(
        "--skip-save-stems",
        action="store_true",
        help="Run the benchmark but skip writing separated stem files to disk.",
    )
    return parser.parse_args()


def main() -> int:
    """Run the benchmark end to end.

    Returns:
        Process exit code: 0 on success, 1 when the arguments are invalid.

    Raises:
        SystemExit: If ``--stems`` names an unknown stem or selects none
            (raised by ``parse_stems``).
    """
    args = parse_args()

    # Validate the cheap CLI inputs first so that obviously bad arguments are
    # reported without importing the separator backend.
    input_path = Path(args.input_file).expanduser().resolve()
    # is_file() rather than exists(): a directory is not a usable input.
    if not input_path.is_file():
        print(f"Input file not found: {input_path}", file=sys.stderr)
        return 1

    if args.runs < 1:
        print("--runs must be at least 1", file=sys.stderr)
        return 1

    stems = parse_stems(args.stems)
    separator_module = get_separator_module()
    output_root = prepare_output_root(args.output_root)
    source_info = read_source_file_info(input_path)
    save_stems = not args.skip_save_stems

    summary: dict[str, Any] = {
        "script": str(Path(__file__).resolve()),
        "created_at_utc": datetime.now(timezone.utc).isoformat(),
        "input_file": str(input_path),
        "output_root": str(output_root),
        "device": str(separator_module.DEVICE),
        "runs_requested": args.runs,
        "reload_model_per_run": args.reload_model_per_run,
        "save_stems": save_stems,
        "stems_requested": stems,
        "output_format": args.output_format,
        "source_info": source_info,
        "runs": [],
    }

    print(f"Input file: {input_path}")
    print(f"Clip duration: {source_info['duration_seconds']:.3f}s")
    print(f"Output root: {output_root}")
    print(f"Device: {separator_module.DEVICE}")
    print(f"Runs: {args.runs}")
    print(f"Stems: {', '.join(stems)}")
    print(f"Output format: {args.output_format}")

    for run_index in range(1, args.runs + 1):
        run_dir = output_root / f"run_{run_index:02d}"
        run_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n=== Run {run_index}/{args.runs} ===")
        run_result = benchmark_run(
            input_path=input_path,
            run_dir=run_dir,
            stems=stems,
            output_format=args.output_format,
            reload_model=args.reload_model_per_run,
            save_stems=save_stems,
        )
        summary["runs"].append(run_result)

        timings = run_result["timings"]
        print(
            "Total: "
            f"{timings['total_seconds']:.3f}s | "
            f"Load: {timings['model_load_seconds']:.3f}s | "
            f"Inference: {timings['inference_seconds']:.3f}s | "
            f"RTF(total): {run_result['performance']['realtime_factor_total']:.3f}"
        )
        print(f"Run output: {run_dir}")

    summary["aggregate"] = aggregate_runs(summary["runs"])

    benchmark_json_path = output_root / "benchmark_summary.json"
    benchmark_csv_path = output_root / "benchmark_runs.csv"
    benchmark_json_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    write_runs_csv(benchmark_csv_path, summary["runs"])

    print("\nBenchmark complete.")
    print(f"Summary JSON: {benchmark_json_path}")
    print(f"Runs CSV: {benchmark_csv_path}")
    print(f"Latest stems directory: {summary['runs'][-1]['stems_output_dir']}")
    return 0


def _export_stems(
    service: Any,
    separator_module: Any,
    separated: Any,
    stems: list[str],
    stems_output_dir: Path,
    output_format: str,
) -> dict[str, str]:
    """Write the requested stems to ``stems_output_dir`` and return their paths.

    ``separated`` is indexed in ``separator_module.STEM_ORDER`` order; each
    entry is clipped to [-1.0, 1.0] and handed to ``service._write_output``.
    """
    exported: dict[str, str] = {}
    for index, stem_key in enumerate(separator_module.STEM_ORDER):
        canonical = separator_module.STEM_NAME_MAP[stem_key]
        if canonical not in stems:
            continue

        # detach()/cpu() make the conversion safe even if the separator hands
        # back a CUDA-resident or grad-tracking tensor; both are no-ops for a
        # plain CPU tensor.
        stem_audio = separated[index].detach().cpu().numpy().T
        stem_audio = np.clip(stem_audio, -1.0, 1.0)

        output_path = stems_output_dir / f"{canonical}.{output_format}"
        service._write_output(str(output_path), stem_audio, output_format)
        exported[canonical] = str(output_path)
    return exported


def benchmark_run(
    input_path: Path,
    run_dir: Path,
    stems: list[str],
    output_format: str,
    reload_model: bool,
    save_stems: bool,
) -> dict[str, Any]:
    """Execute one timed separation run and persist its metrics.

    Args:
        input_path: Audio file to separate.
        run_dir: Directory receiving this run's stems and ``run_metrics.json``.
        stems: Canonical stem names to export.
        output_format: File extension/format passed to the service's writer.
        reload_model: When true, the service state is reset before loading.
        save_stems: When false, separation still runs but nothing is exported.

    Returns:
        A JSON-serializable dict with clip info, model info, the chunk plan,
        per-phase timings, derived performance ratios and progress events.
    """
    separator_module = get_separator_module()
    service = create_service(reload_model)
    progress_events: list[dict[str, float | str]] = []

    def progress_callback(state: str, pct: float):
        progress_events.append(
            {
                "state": state,
                "progress": pct,
                "timestamp_monotonic": time.perf_counter(),
            }
        )

    run_started_at = time.perf_counter()

    # --- model load ---
    model_load_started = time.perf_counter()
    service.load_model()
    # CUDA work is asynchronous; synchronize so the timer covers it.
    sync_device()
    model_load_seconds = time.perf_counter() - model_load_started

    # --- read ---
    read_started = time.perf_counter()
    audio, input_sample_rate = sf.read(str(input_path))
    read_seconds = time.perf_counter() - read_started

    original_channels = 1 if audio.ndim == 1 else audio.shape[1]

    # --- mono -> stereo (duplicate the single channel) ---
    mono_to_stereo_seconds = 0.0
    if audio.ndim == 1:
        mono_to_stereo_started = time.perf_counter()
        audio = np.stack([audio, audio], axis=1)
        mono_to_stereo_seconds = time.perf_counter() - mono_to_stereo_started

    # --- numpy (frames, channels) -> tensor (channels, frames) ---
    tensor_started = time.perf_counter()
    audio_tensor = torch.tensor(audio.T, dtype=torch.float32)
    tensor_prepare_seconds = time.perf_counter() - tensor_started

    # --- resample to the service's sample rate if needed ---
    resample_seconds = 0.0
    if input_sample_rate != service.sample_rate:
        resample_started = time.perf_counter()
        resampler = torchaudio.transforms.Resample(input_sample_rate, service.sample_rate)
        audio_tensor = resampler(audio_tensor)
        sync_device()
        resample_seconds = time.perf_counter() - resample_started

    resampled_samples = int(audio_tensor.shape[1])
    clip_duration_seconds = resampled_samples / float(service.sample_rate)
    chunk_plan = compute_chunk_plan(
        total_samples=resampled_samples,
        chunk_size=service.chunk_size,
        num_overlap=service.num_overlap,
    )

    # --- inference ---
    on_cuda = separator_module.DEVICE.type == "cuda"
    peak_vram_mb = None
    if on_cuda:
        # Reset so the peak reported below reflects this run's inference only.
        torch.cuda.reset_peak_memory_stats(separator_module.DEVICE)

    inference_started = time.perf_counter()
    separated = service._process_audio(audio_tensor, progress_callback)
    sync_device()
    inference_seconds = time.perf_counter() - inference_started

    if on_cuda:
        peak_vram_mb = bytes_to_mb(torch.cuda.max_memory_allocated(separator_module.DEVICE))

    # --- export (the stems directory is created even when nothing is saved) ---
    export_started = time.perf_counter()
    stems_output_dir = run_dir / "stems"
    stems_output_dir.mkdir(parents=True, exist_ok=True)

    exported_files: dict[str, str] = {}
    if save_stems:
        exported_files = _export_stems(
            service,
            separator_module,
            separated,
            stems,
            stems_output_dir,
            output_format,
        )
    export_seconds = time.perf_counter() - export_started

    total_seconds = time.perf_counter() - run_started_at
    preprocessing_seconds = (
        read_seconds + mono_to_stereo_seconds + tensor_prepare_seconds + resample_seconds
    )

    # Realtime factor = processing time / clip duration; throughput is the inverse.
    performance = {
        "realtime_factor_total": safe_divide(total_seconds, clip_duration_seconds),
        "realtime_factor_inference": safe_divide(inference_seconds, clip_duration_seconds),
        "throughput_total_x": safe_divide(clip_duration_seconds, total_seconds),
        "throughput_inference_x": safe_divide(clip_duration_seconds, inference_seconds),
    }

    run_result: dict[str, Any] = {
        "run_dir": str(run_dir),
        "stems_output_dir": str(stems_output_dir),
        "exported_files": exported_files,
        "clip": {
            "input_sample_rate": input_sample_rate,
            "output_sample_rate": service.sample_rate,
            "input_channels": original_channels,
            "processed_channels": int(audio_tensor.shape[0]),
            "duration_seconds": clip_duration_seconds,
            "processed_samples": resampled_samples,
        },
        "model": {
            "device": str(separator_module.DEVICE),
            "chunk_size": service.chunk_size,
            "num_overlap": service.num_overlap,
            "peak_vram_mb": peak_vram_mb,
        },
        "chunk_plan": chunk_plan,
        "timings": {
            "model_load_seconds": model_load_seconds,
            "audio_read_seconds": read_seconds,
            "mono_to_stereo_seconds": mono_to_stereo_seconds,
            "tensor_prepare_seconds": tensor_prepare_seconds,
            "resample_seconds": resample_seconds,
            "preprocessing_seconds": preprocessing_seconds,
            "inference_seconds": inference_seconds,
            "export_seconds": export_seconds,
            "total_seconds": total_seconds,
        },
        "performance": performance,
        "progress_events": normalize_progress_events(progress_events, run_started_at),
    }

    per_run_json_path = run_dir / "run_metrics.json"
    per_run_json_path.write_text(json.dumps(run_result, indent=2), encoding="utf-8")
    # Added after the file is written, so the on-disk JSON omits its own path.
    run_result["run_metrics_path"] = str(per_run_json_path)
    return run_result


def create_service(reload_model: bool):
    """Return a ``StemSeparatorService``, optionally forcing a fresh instance.

    When ``reload_model`` is true the class-level ``_instance`` and
    ``_model_loaded`` attributes are cleared first (presumably the service
    caches itself there; confirm against ``backend.separator``).
    """
    separator_module = get_separator_module()
    service_cls = separator_module.StemSeparatorService
    if reload_model:
        service_cls._instance = None
        service_cls._model_loaded = False
    return service_cls()


def parse_stems(raw_value: str) -> list[str]:
    """Parse the ``--stems`` value into a list of canonical stem names.

    ``"all"`` (case-insensitive) selects every entry of ``DEFAULT_STEMS``.
    Otherwise the value is split on commas, each part is stripped and
    title-cased, and duplicates are dropped while keeping first-seen order.

    Raises:
        SystemExit: If any stem is unknown or no stem is selected at all.
    """
    if raw_value.strip().lower() == "all":
        return list(DEFAULT_STEMS)

    # dict.fromkeys de-duplicates while preserving insertion order.
    selected = list(
        dict.fromkeys(part.strip().title() for part in raw_value.split(",") if part.strip())
    )
    valid_display = ", ".join(DEFAULT_STEMS)

    invalid = [stem for stem in selected if stem not in DEFAULT_STEMS]
    if invalid:
        raise SystemExit(
            f"Invalid stems: {', '.join(invalid)}. Valid stems are: {valid_display}"
        )
    if not selected:
        # e.g. --stems "" or --stems ","; --skip-save-stems is the supported
        # way to export nothing.
        raise SystemExit(f"No stems selected. Valid stems are: {valid_display}")
    return selected


def prepare_output_root(output_root: str | None) -> Path:
    """Return the benchmark output directory, creating it if necessary.

    With no ``output_root`` a fresh temporary directory is created whose name
    carries a local-time timestamp.
    """
    if output_root:
        path = Path(output_root).expanduser().resolve()
        path.mkdir(parents=True, exist_ok=True)
        return path

    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    return Path(tempfile.mkdtemp(prefix=f"stem-benchmark-{stamp}-"))


def read_source_file_info(input_path: Path) -> dict[str, Any]:
    """Collect container metadata and file size for the input audio file."""
    info = sf.info(str(input_path))
    size_bytes = input_path.stat().st_size
    return {
        "path": str(input_path),
        "format": info.format,
        "subtype": info.subtype,
        "sample_rate": info.samplerate,
        "channels": info.channels,
        "frames": info.frames,
        "duration_seconds": safe_divide(info.frames, info.samplerate),
        "file_size_bytes": size_bytes,
        "file_size_mb": bytes_to_mb(size_bytes),
    }


def compute_chunk_plan(total_samples: int, chunk_size: int, num_overlap: int) -> dict[str, Any]:
    """Estimate how the clip is split into overlapping chunks.

    The hop between chunks is ``chunk_size // num_overlap``. Clips shorter
    than one chunk are padded up to ``chunk_size``; longer clips are padded
    so the final hop lands exactly on the end of the padded signal.
    NOTE(review): presumably mirrors the separator's own chunking - confirm
    against ``backend.separator``.

    Raises:
        ValueError: If the resulting step is not positive.
    """
    if chunk_size <= 0 or num_overlap <= 0 or chunk_size < num_overlap:
        raise ValueError(
            "chunk_size and num_overlap must be positive with chunk_size >= num_overlap "
            f"(got chunk_size={chunk_size}, num_overlap={num_overlap})"
        )

    step = chunk_size // num_overlap
    pad_needed = max(0, chunk_size - total_samples)
    if total_samples > chunk_size:
        remainder = (total_samples - chunk_size) % step
        if remainder != 0:
            pad_needed = step - remainder

    padded_len = total_samples + pad_needed
    total_chunks = len(range(0, padded_len - chunk_size + 1, step))

    return {
        "step_samples": step,
        "pad_needed_samples": pad_needed,
        "padded_length_samples": padded_len,
        "estimated_chunk_count": total_chunks,
    }


def normalize_progress_events(
    progress_events: list[dict[str, float | str]],
    run_started_at: float,
) -> list[dict[str, float | str]]:
    """Convert raw ``perf_counter`` timestamps to seconds since the run started."""
    return [
        {
            "state": str(event["state"]),
            "progress": float(event["progress"]),
            "seconds_from_run_start": float(event["timestamp_monotonic"]) - run_started_at,
        }
        for event in progress_events
    ]


def aggregate_runs(runs: list[dict[str, Any]]) -> dict[str, Any]:
    """Summarize every timing and performance metric across all runs.

    The metric names are taken from the first run. An empty ``runs`` list
    yields empty aggregates.
    """
    if not runs:
        return {"timings": {}, "performance": {}}

    return {
        section: {
            key: summarize_metric([float(run[section][key]) for run in runs])
            for key in runs[0][section]
        }
        for section in ("timings", "performance")
    }


def summarize_metric(values: list[float]) -> dict[str, float]:
    """Return min, max, mean, median and sample stdev of ``values``.

    The stdev is reported as 0.0 for a single value, where the sample
    standard deviation is undefined.
    """
    return {
        "min": min(values),
        "max": max(values),
        "mean": statistics.mean(values),
        "median": statistics.median(values),
        "stdev": statistics.stdev(values) if len(values) > 1 else 0.0,
    }


def write_runs_csv(csv_path: Path, runs: list[dict[str, Any]]) -> None:
    """Write one flattened CSV row per run to ``csv_path``."""
    fieldnames = [
        "run_dir",
        "stems_output_dir",
        "duration_seconds",
        "input_sample_rate",
        "output_sample_rate",
        "estimated_chunk_count",
        "model_load_seconds",
        "audio_read_seconds",
        "mono_to_stereo_seconds",
        "tensor_prepare_seconds",
        "resample_seconds",
        "preprocessing_seconds",
        "inference_seconds",
        "export_seconds",
        "total_seconds",
        "realtime_factor_total",
        "realtime_factor_inference",
        "throughput_total_x",
        "throughput_inference_x",
        "peak_vram_mb",
    ]

    with csv_path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        for run in runs:
            clip = run["clip"]
            timings = run["timings"]
            performance = run["performance"]
            writer.writerow(
                {
                    "run_dir": run["run_dir"],
                    "stems_output_dir": run["stems_output_dir"],
                    "duration_seconds": clip["duration_seconds"],
                    "input_sample_rate": clip["input_sample_rate"],
                    "output_sample_rate": clip["output_sample_rate"],
                    "estimated_chunk_count": run["chunk_plan"]["estimated_chunk_count"],
                    "model_load_seconds": timings["model_load_seconds"],
                    "audio_read_seconds": timings["audio_read_seconds"],
                    "mono_to_stereo_seconds": timings["mono_to_stereo_seconds"],
                    "tensor_prepare_seconds": timings["tensor_prepare_seconds"],
                    "resample_seconds": timings["resample_seconds"],
                    "preprocessing_seconds": timings["preprocessing_seconds"],
                    "inference_seconds": timings["inference_seconds"],
                    "export_seconds": timings["export_seconds"],
                    "total_seconds": timings["total_seconds"],
                    "realtime_factor_total": performance["realtime_factor_total"],
                    "realtime_factor_inference": performance["realtime_factor_inference"],
                    "throughput_total_x": performance["throughput_total_x"],
                    "throughput_inference_x": performance["throughput_inference_x"],
                    "peak_vram_mb": run["model"]["peak_vram_mb"],
                }
            )


def sync_device() -> None:
    """Block until queued CUDA work finishes so wall-clock timings include it.

    Does nothing when the separator's device is not CUDA.
    """
    separator_module = get_separator_module()
    if separator_module.DEVICE.type == "cuda":
        torch.cuda.synchronize(separator_module.DEVICE)


def safe_divide(numerator: float, denominator: float) -> float:
    """Divide, returning ``math.inf`` instead of raising on a zero denominator.

    Note: ``json.dumps`` serializes ``inf`` as the non-standard token
    ``Infinity``, which strict JSON parsers reject.
    """
    if denominator == 0:
        return math.inf
    return numerator / denominator


def bytes_to_mb(value: int) -> float:
    """Convert a byte count to mebibytes (1 MB = 1024 * 1024 bytes here)."""
    return value / (1024 * 1024)


def get_separator_module():
    """Import and return ``backend.separator`` lazily, on first use."""
    from backend import separator

    return separator


if __name__ == "__main__":
    raise SystemExit(main())