Spaces:
Running
Running
| import argparse | |
| import csv | |
| import json | |
| import math | |
| import statistics | |
| import sys | |
| import tempfile | |
| import time | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import Any | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
| import torchaudio | |
# Make the repository root importable so `backend.separator` resolves when
# this script is run directly rather than as an installed module.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# Canonical stem names the benchmark can export.
DEFAULT_STEMS = ["Bass", "Drums", "Other", "Vocals", "Guitar", "Piano"]
# Export format used when --output-format is not given.
DEFAULT_FORMAT = "wav"
VALID_FORMATS = {"wav", "mp3", "aac"}
def parse_args() -> argparse.Namespace:
    """Build the benchmark CLI and parse ``sys.argv``.

    Returns:
        Parsed namespace with input_file, runs, stems, output_format,
        output_root, reload_model_per_run and skip_save_stems.
    """
    parser = argparse.ArgumentParser(
        description="Run BS-RoFormer stem separation on a local file and save benchmark data."
    )
    parser.add_argument("input_file", help="Path to the source audio file")
    parser.add_argument(
        "--runs",
        type=int,
        default=1,
        help="Number of benchmark runs to execute (default: 1)",
    )
    parser.add_argument(
        "--stems",
        default="all",
        help="Comma-separated stems to export. Use 'all' for every stem.",
    )
    parser.add_argument(
        "--output-format",
        default=DEFAULT_FORMAT,
        choices=sorted(VALID_FORMATS),
        help="Stem export format (default: wav)",
    )
    parser.add_argument(
        "--output-root",
        help="Directory to place benchmark outputs. Defaults to a temp directory.",
    )
    parser.add_argument(
        "--reload-model-per-run",
        action="store_true",
        help="Force a fresh model load before each run to measure cold-start cost every time.",
    )
    parser.add_argument(
        "--skip-save-stems",
        action="store_true",
        help="Run the benchmark but skip writing separated stem files to disk.",
    )
    return parser.parse_args()
def main() -> int:
    """CLI entry point: validate arguments, run the benchmark, write reports.

    Returns:
        Process exit code: 0 on success, 1 on a bad input path or run count.
    """
    args = parse_args()
    sep = get_separator_module()

    src_path = Path(args.input_file).expanduser().resolve()
    if not src_path.exists():
        print(f"Input file not found: {src_path}", file=sys.stderr)
        return 1
    if args.runs < 1:
        print("--runs must be at least 1", file=sys.stderr)
        return 1

    stems = parse_stems(args.stems)
    out_root = prepare_output_root(args.output_root)
    src_info = read_source_file_info(src_path)

    # Insertion order is preserved in the summary JSON, so metadata first,
    # per-run results appended below.
    summary: dict[str, Any] = {
        "script": str(Path(__file__).resolve()),
        "created_at_utc": datetime.now(timezone.utc).isoformat(),
        "input_file": str(src_path),
        "output_root": str(out_root),
        "device": str(sep.DEVICE),
        "runs_requested": args.runs,
        "reload_model_per_run": args.reload_model_per_run,
        "save_stems": not args.skip_save_stems,
        "stems_requested": stems,
        "output_format": args.output_format,
        "source_info": src_info,
        "runs": [],
    }

    print(f"Input file: {src_path}")
    print(f"Clip duration: {src_info['duration_seconds']:.3f}s")
    print(f"Output root: {out_root}")
    print(f"Device: {sep.DEVICE}")
    print(f"Runs: {args.runs}")
    print(f"Stems: {', '.join(stems)}")
    print(f"Output format: {args.output_format}")

    for run_no in range(1, args.runs + 1):
        run_dir = out_root / f"run_{run_no:02d}"
        run_dir.mkdir(parents=True, exist_ok=True)
        print(f"\n=== Run {run_no}/{args.runs} ===")
        result = benchmark_run(
            input_path=src_path,
            run_dir=run_dir,
            stems=stems,
            output_format=args.output_format,
            reload_model=args.reload_model_per_run,
            save_stems=not args.skip_save_stems,
        )
        summary["runs"].append(result)
        timings = result["timings"]
        print(
            "Total: "
            f"{timings['total_seconds']:.3f}s | "
            f"Load: {timings['model_load_seconds']:.3f}s | "
            f"Inference: {timings['inference_seconds']:.3f}s | "
            f"RTF(total): {result['performance']['realtime_factor_total']:.3f}"
        )
        print(f"Run output: {run_dir}")

    summary["aggregate"] = aggregate_runs(summary["runs"])

    json_path = out_root / "benchmark_summary.json"
    csv_path = out_root / "benchmark_runs.csv"
    json_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    write_runs_csv(csv_path, summary["runs"])

    print("\nBenchmark complete.")
    print(f"Summary JSON: {json_path}")
    print(f"Runs CSV: {csv_path}")
    # args.runs >= 1 was enforced above, so summary["runs"] is non-empty here.
    print(f"Latest stems directory: {summary['runs'][-1]['stems_output_dir']}")
    return 0
def benchmark_run(
    input_path: Path,
    run_dir: Path,
    stems: list[str],
    output_format: str,
    reload_model: bool,
    save_stems: bool,
) -> dict[str, Any]:
    """Execute one full separation pass and collect detailed timing metrics.

    Each stage (model load, audio read, mono->stereo, tensor prep, resample,
    inference, export) is timed with time.perf_counter, with a device sync
    before stopping GPU-involved timers so CUDA async work is included.

    Args:
        input_path: Audio file to separate.
        run_dir: Directory for this run's outputs; "stems/" and
            "run_metrics.json" are created inside it.
        stems: Canonical stem names to export (subset of DEFAULT_STEMS).
        output_format: File extension/format for exported stems.
        reload_model: If True, the service singleton is reset first so the
            model load is a cold start.
        save_stems: If False, separation still runs but no stem files are
            written.

    Returns:
        Nested dict with clip info, model info, chunk plan, per-stage
        timings, derived performance ratios, and normalized progress events.
    """
    separator_module = get_separator_module()
    service = create_service(reload_model)
    # Raw progress callbacks captured with monotonic timestamps; normalized
    # to run-relative seconds at the end.
    progress_events: list[dict[str, float | str]] = []
    def progress_callback(state: str, pct: float) -> None:
        # Invoked by the service during _process_audio.
        progress_events.append(
            {
                "state": state,
                "progress": pct,
                "timestamp_monotonic": time.perf_counter(),
            }
        )
    run_started_at = time.perf_counter()
    # --- Stage: model load (cold or warm depending on reload_model) ---
    model_load_started = time.perf_counter()
    service.load_model()
    sync_device()
    model_load_seconds = time.perf_counter() - model_load_started
    # --- Stage: read source audio from disk ---
    read_started = time.perf_counter()
    audio, input_sample_rate = sf.read(str(input_path))
    read_seconds = time.perf_counter() - read_started
    # soundfile returns (frames,) for mono or (frames, channels) otherwise.
    original_channels = 1 if audio.ndim == 1 else audio.shape[1]
    mono_to_stereo_seconds = 0.0
    if audio.ndim == 1:
        # Duplicate the mono channel; the model presumably expects stereo
        # input — TODO confirm against backend.separator.
        mono_to_stereo_started = time.perf_counter()
        audio = np.stack([audio, audio], axis=1)
        mono_to_stereo_seconds = time.perf_counter() - mono_to_stereo_started
    # --- Stage: convert to a (channels, samples) float32 tensor ---
    tensor_started = time.perf_counter()
    audio_tensor = torch.tensor(audio.T, dtype=torch.float32)
    tensor_prepare_seconds = time.perf_counter() - tensor_started
    resample_seconds = 0.0
    if input_sample_rate != service.sample_rate:
        # --- Stage: resample to the model's native rate ---
        resample_started = time.perf_counter()
        resampler = torchaudio.transforms.Resample(input_sample_rate, service.sample_rate)
        audio_tensor = resampler(audio_tensor)
        sync_device()
        resample_seconds = time.perf_counter() - resample_started
    resampled_samples = int(audio_tensor.shape[1])
    clip_duration_seconds = resampled_samples / float(service.sample_rate)
    # Estimate how the service will chunk this clip (mirrors its overlap
    # scheme; informational only — the service does its own chunking).
    chunk_plan = compute_chunk_plan(
        total_samples=resampled_samples,
        chunk_size=service.chunk_size,
        num_overlap=service.num_overlap,
    )
    peak_vram_mb = None
    if separator_module.DEVICE.type == "cuda":
        # Reset peak stats so max_memory_allocated reflects this run only.
        torch.cuda.reset_peak_memory_stats(separator_module.DEVICE)
    # --- Stage: inference (the actual separation) ---
    inference_started = time.perf_counter()
    # NOTE(review): relies on the service's private _process_audio API.
    separated = service._process_audio(audio_tensor, progress_callback)
    sync_device()
    inference_seconds = time.perf_counter() - inference_started
    if separator_module.DEVICE.type == "cuda":
        peak_vram_mb = bytes_to_mb(torch.cuda.max_memory_allocated(separator_module.DEVICE))
    # --- Stage: export stems to disk (skipped when save_stems is False,
    # but the timer still runs so export_seconds is ~0 rather than absent) ---
    export_started = time.perf_counter()
    stems_output_dir = run_dir / "stems"
    stems_output_dir.mkdir(parents=True, exist_ok=True)
    exported_files: dict[str, str] = {}
    if save_stems:
        for i, stem_key in enumerate(separator_module.STEM_ORDER):
            canonical = separator_module.STEM_NAME_MAP[stem_key]
            if canonical not in stems:
                continue
            # Back to (samples, channels); assumes _process_audio returns
            # CPU tensors — TODO confirm.
            stem_audio = separated[i].numpy().T
            # Guard against clipping before the format encoder sees it.
            stem_audio = np.clip(stem_audio, -1.0, 1.0)
            output_filename = f"{canonical}.{output_format}"
            output_path = stems_output_dir / output_filename
            service._write_output(str(output_path), stem_audio, output_format)
            exported_files[canonical] = str(output_path)
    export_seconds = time.perf_counter() - export_started
    total_seconds = time.perf_counter() - run_started_at
    preprocessing_seconds = (
        read_seconds
        + mono_to_stereo_seconds
        + tensor_prepare_seconds
        + resample_seconds
    )
    # Realtime factor < 1 means faster than realtime; throughput is its
    # reciprocal (audio seconds processed per wall-clock second).
    performance = {
        "realtime_factor_total": safe_divide(total_seconds, clip_duration_seconds),
        "realtime_factor_inference": safe_divide(inference_seconds, clip_duration_seconds),
        "throughput_total_x": safe_divide(clip_duration_seconds, total_seconds),
        "throughput_inference_x": safe_divide(clip_duration_seconds, inference_seconds),
    }
    run_result: dict[str, Any] = {
        "run_dir": str(run_dir),
        "stems_output_dir": str(stems_output_dir),
        "exported_files": exported_files,
        "clip": {
            "input_sample_rate": input_sample_rate,
            "output_sample_rate": service.sample_rate,
            "input_channels": original_channels,
            "processed_channels": int(audio_tensor.shape[0]),
            "duration_seconds": clip_duration_seconds,
            "processed_samples": resampled_samples,
        },
        "model": {
            "device": str(separator_module.DEVICE),
            "chunk_size": service.chunk_size,
            "num_overlap": service.num_overlap,
            "peak_vram_mb": peak_vram_mb,
        },
        "chunk_plan": chunk_plan,
        "timings": {
            "model_load_seconds": model_load_seconds,
            "audio_read_seconds": read_seconds,
            "mono_to_stereo_seconds": mono_to_stereo_seconds,
            "tensor_prepare_seconds": tensor_prepare_seconds,
            "resample_seconds": resample_seconds,
            "preprocessing_seconds": preprocessing_seconds,
            "inference_seconds": inference_seconds,
            "export_seconds": export_seconds,
            "total_seconds": total_seconds,
        },
        "performance": performance,
        "progress_events": normalize_progress_events(progress_events, run_started_at),
    }
    # Persist the per-run metrics next to the stems, then record the path
    # (written after dumping so the path key is not in the JSON itself).
    per_run_json_path = run_dir / "run_metrics.json"
    per_run_json_path.write_text(json.dumps(run_result, indent=2), encoding="utf-8")
    run_result["run_metrics_path"] = str(per_run_json_path)
    return run_result
def create_service(reload_model: bool):
    """Return a StemSeparatorService instance.

    When reload_model is True, the class-level singleton state is cleared
    first so the next instantiation performs a cold model load.
    """
    module = get_separator_module()
    svc_cls = module.StemSeparatorService
    if reload_model:
        # NOTE(review): pokes private class attributes to reset the singleton.
        svc_cls._instance = None
        svc_cls._model_loaded = False
    return svc_cls()
def parse_stems(raw_value: str) -> list[str]:
    """Parse the --stems CLI value into a list of canonical stem names.

    "all" (case-insensitive) yields every stem in DEFAULT_STEMS; otherwise
    the comma-separated names are title-cased and validated.

    Raises:
        SystemExit: if any requested stem is not a known stem name.
    """
    if raw_value.strip().lower() == "all":
        return list(DEFAULT_STEMS)
    requested = []
    for piece in raw_value.split(","):
        piece = piece.strip()
        if piece:
            requested.append(piece.title())
    unknown = [name for name in requested if name not in DEFAULT_STEMS]
    if unknown:
        valid_display = ", ".join(DEFAULT_STEMS)
        raise SystemExit(
            f"Invalid stems: {', '.join(unknown)}. Valid stems are: {valid_display}"
        )
    return requested
| def prepare_output_root(output_root: str | None) -> Path: | |
| if output_root: | |
| path = Path(output_root).expanduser().resolve() | |
| path.mkdir(parents=True, exist_ok=True) | |
| return path | |
| stamp = datetime.now().strftime("%Y%m%d-%H%M%S") | |
| path = Path(tempfile.mkdtemp(prefix=f"stem-benchmark-{stamp}-")) | |
| return path | |
def read_source_file_info(input_path: Path) -> dict[str, Any]:
    """Describe the source audio file: format, rate, channels, size, duration."""
    meta = sf.info(str(input_path))
    size_bytes = input_path.stat().st_size
    return {
        "path": str(input_path),
        "format": meta.format,
        "subtype": meta.subtype,
        "sample_rate": meta.samplerate,
        "channels": meta.channels,
        "frames": meta.frames,
        "duration_seconds": safe_divide(meta.frames, meta.samplerate),
        "file_size_bytes": size_bytes,
        "file_size_mb": bytes_to_mb(size_bytes),
    }
def compute_chunk_plan(total_samples: int, chunk_size: int, num_overlap: int) -> dict[str, Any]:
    """Estimate padding and chunk count for overlap-add chunked inference.

    Chunks of chunk_size samples advance by chunk_size // num_overlap, and
    the clip is padded so the final chunk lines up exactly.
    """
    step = chunk_size // num_overlap
    if total_samples <= chunk_size:
        # Short clip: pad up to exactly one full chunk.
        pad = chunk_size - total_samples
    else:
        leftover = (total_samples - chunk_size) % step
        pad = (step - leftover) if leftover else 0
    padded = total_samples + pad
    # pad guarantees (padded - chunk_size) is a multiple of step.
    chunk_count = (padded - chunk_size) // step + 1
    return {
        "step_samples": step,
        "pad_needed_samples": pad,
        "padded_length_samples": padded,
        "estimated_chunk_count": chunk_count,
    }
| def normalize_progress_events( | |
| progress_events: list[dict[str, float | str]], | |
| run_started_at: float, | |
| ) -> list[dict[str, float | str]]: | |
| normalized: list[dict[str, float | str]] = [] | |
| for event in progress_events: | |
| timestamp = float(event["timestamp_monotonic"]) | |
| normalized.append( | |
| { | |
| "state": str(event["state"]), | |
| "progress": float(event["progress"]), | |
| "seconds_from_run_start": timestamp - run_started_at, | |
| } | |
| ) | |
| return normalized | |
def aggregate_runs(runs: list[dict[str, Any]]) -> dict[str, Any]:
    """Summarize per-run timing and performance metrics across all runs.

    Args:
        runs: Run-result dicts, each with "timings" and "performance" maps
            of metric name -> float.

    Returns:
        {"timings": {...}, "performance": {...}} where each metric maps to
        its min/max/mean/median/stdev summary. Empty input yields empty
        summaries (previously this raised IndexError on ``runs[0]``).
    """
    if not runs:
        return {"timings": {}, "performance": {}}
    timing_keys = list(runs[0]["timings"].keys())
    performance_keys = list(runs[0]["performance"].keys())
    return {
        "timings": {
            key: summarize_metric([float(run["timings"][key]) for run in runs])
            for key in timing_keys
        },
        "performance": {
            key: summarize_metric([float(run["performance"][key]) for run in runs])
            for key in performance_keys
        },
    }
def summarize_metric(values: list[float]) -> dict[str, float]:
    """Return min/max/mean/median/stdev for a list of metric samples.

    stdev is reported as 0.0 for a single sample (statistics.stdev needs
    at least two points).
    """
    spread = statistics.stdev(values) if len(values) > 1 else 0.0
    return {
        "min": min(values),
        "max": max(values),
        "mean": statistics.mean(values),
        "median": statistics.median(values),
        "stdev": spread,
    }
def write_runs_csv(csv_path: Path, runs: list[dict[str, Any]]) -> None:
    """Flatten each run result into one CSV row and write them to csv_path."""
    timing_cols = (
        "model_load_seconds",
        "audio_read_seconds",
        "mono_to_stereo_seconds",
        "tensor_prepare_seconds",
        "resample_seconds",
        "preprocessing_seconds",
        "inference_seconds",
        "export_seconds",
        "total_seconds",
    )
    performance_cols = (
        "realtime_factor_total",
        "realtime_factor_inference",
        "throughput_total_x",
        "throughput_inference_x",
    )
    fieldnames = [
        "run_dir",
        "stems_output_dir",
        "duration_seconds",
        "input_sample_rate",
        "output_sample_rate",
        "estimated_chunk_count",
        *timing_cols,
        *performance_cols,
        "peak_vram_mb",
    ]

    def flatten(run: dict[str, Any]) -> dict[str, Any]:
        # One CSV row from the nested run-result structure.
        clip = run["clip"]
        row = {
            "run_dir": run["run_dir"],
            "stems_output_dir": run["stems_output_dir"],
            "duration_seconds": clip["duration_seconds"],
            "input_sample_rate": clip["input_sample_rate"],
            "output_sample_rate": clip["output_sample_rate"],
            "estimated_chunk_count": run["chunk_plan"]["estimated_chunk_count"],
            "peak_vram_mb": run["model"]["peak_vram_mb"],
        }
        row.update({col: run["timings"][col] for col in timing_cols})
        row.update({col: run["performance"][col] for col in performance_cols})
        return row

    with csv_path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(flatten(run) for run in runs)
def sync_device() -> None:
    """Wait for pending CUDA work on the separator device; no-op on CPU."""
    device = get_separator_module().DEVICE
    if device.type == "cuda":
        torch.cuda.synchronize(device)
def safe_divide(numerator: float, denominator: float) -> float:
    """Divide, returning math.inf instead of raising on a zero denominator."""
    return numerator / denominator if denominator != 0 else math.inf
def bytes_to_mb(value: int) -> float:
    """Convert a byte count to mebibytes (MiB, 2**20 bytes)."""
    return value / (1 << 20)
def get_separator_module():
    """Import the project separator module lazily (after sys.path setup)."""
    from backend import separator as separator_module
    return separator_module
if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return code, same as the
    # explicit raise.
    sys.exit(main())