# stem-separator/scripts/benchmark_separator.py
# Uploaded by sourav-das via huggingface_hub (commit 7dfae77, verified).
import argparse
import csv
import json
import math
import statistics
import sys
import tempfile
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import numpy as np
import soundfile as sf
import torch
import torchaudio
# Make the repository root importable so `from backend import separator` works
# when this script is executed directly (e.g. `python scripts/benchmark_separator.py`).
REPO_ROOT = Path(__file__).resolve().parent.parent
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
# Canonical stem names accepted on the CLI; "--stems all" expands to this list.
DEFAULT_STEMS = ["Bass", "Drums", "Other", "Vocals", "Guitar", "Piano"]
# Default and accepted export formats for separated stem files.
DEFAULT_FORMAT = "wav"
VALID_FORMATS = {"wav", "mp3", "aac"}
def parse_args() -> argparse.Namespace:
    """Parse the benchmark script's command-line options.

    Positional argument: the input audio file. Flags control run count,
    stem selection, export format, output location, per-run model reload,
    and whether separated stems are written to disk.
    """
    cli = argparse.ArgumentParser(
        description="Run BS-RoFormer stem separation on a local file and save benchmark data."
    )
    cli.add_argument("input_file", help="Path to the source audio file")
    cli.add_argument(
        "--runs",
        type=int,
        default=1,
        help="Number of benchmark runs to execute (default: 1)",
    )
    cli.add_argument(
        "--stems",
        default="all",
        help="Comma-separated stems to export. Use 'all' for every stem.",
    )
    cli.add_argument(
        "--output-format",
        default=DEFAULT_FORMAT,
        choices=sorted(VALID_FORMATS),
        help="Stem export format (default: wav)",
    )
    cli.add_argument(
        "--output-root",
        help="Directory to place benchmark outputs. Defaults to a temp directory.",
    )
    cli.add_argument(
        "--reload-model-per-run",
        action="store_true",
        help="Force a fresh model load before each run to measure cold-start cost every time.",
    )
    cli.add_argument(
        "--skip-save-stems",
        action="store_true",
        help="Run the benchmark but skip writing separated stem files to disk.",
    )
    return cli.parse_args()
def main() -> int:
    """CLI entry point: validate arguments, run the benchmark, write reports.

    Returns a process exit code (0 on success, 1 on bad input). Each requested
    run executes benchmark_run() into its own run_XX directory under the
    output root, after which an aggregate JSON summary and a per-run CSV are
    written next to the run directories.
    """
    args = parse_args()
    separator_module = get_separator_module()
    input_path = Path(args.input_file).expanduser().resolve()
    if not input_path.exists():
        print(f"Input file not found: {input_path}", file=sys.stderr)
        return 1
    if args.runs < 1:
        print("--runs must be at least 1", file=sys.stderr)
        return 1
    stems = parse_stems(args.stems)
    output_root = prepare_output_root(args.output_root)
    source_info = read_source_file_info(input_path)
    # Top-level report: configuration snapshot plus one entry per run
    # (the "runs" list is appended to in the loop below).
    summary: dict[str, Any] = {
        "script": str(Path(__file__).resolve()),
        "created_at_utc": datetime.now(timezone.utc).isoformat(),
        "input_file": str(input_path),
        "output_root": str(output_root),
        "device": str(separator_module.DEVICE),
        "runs_requested": args.runs,
        "reload_model_per_run": args.reload_model_per_run,
        "save_stems": not args.skip_save_stems,
        "stems_requested": stems,
        "output_format": args.output_format,
        "source_info": source_info,
        "runs": [],
    }
    print(f"Input file: {input_path}")
    print(f"Clip duration: {source_info['duration_seconds']:.3f}s")
    print(f"Output root: {output_root}")
    print(f"Device: {separator_module.DEVICE}")
    print(f"Runs: {args.runs}")
    print(f"Stems: {', '.join(stems)}")
    print(f"Output format: {args.output_format}")
    for run_index in range(1, args.runs + 1):
        # Each run gets its own zero-padded directory: run_01, run_02, ...
        run_dir = output_root / f"run_{run_index:02d}"
        run_dir.mkdir(parents=True, exist_ok=True)
        print(f"\n=== Run {run_index}/{args.runs} ===")
        run_result = benchmark_run(
            input_path=input_path,
            run_dir=run_dir,
            stems=stems,
            output_format=args.output_format,
            reload_model=args.reload_model_per_run,
            save_stems=not args.skip_save_stems,
        )
        summary["runs"].append(run_result)
        print(
            "Total: "
            f"{run_result['timings']['total_seconds']:.3f}s | "
            f"Load: {run_result['timings']['model_load_seconds']:.3f}s | "
            f"Inference: {run_result['timings']['inference_seconds']:.3f}s | "
            f"RTF(total): {run_result['performance']['realtime_factor_total']:.3f}"
        )
        print(f"Run output: {run_dir}")
    # Min/max/mean/median/stdev across runs for every timing and perf metric.
    summary["aggregate"] = aggregate_runs(summary["runs"])
    benchmark_json_path = output_root / "benchmark_summary.json"
    benchmark_csv_path = output_root / "benchmark_runs.csv"
    benchmark_json_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    write_runs_csv(benchmark_csv_path, summary["runs"])
    print("\nBenchmark complete.")
    print(f"Summary JSON: {benchmark_json_path}")
    print(f"Runs CSV: {benchmark_csv_path}")
    # NOTE(review): this directory is created even with --skip-save-stems,
    # so the print below is safe but may point at an empty directory.
    print(f"Latest stems directory: {summary['runs'][-1]['stems_output_dir']}")
    return 0
def benchmark_run(
    input_path: Path,
    run_dir: Path,
    stems: list[str],
    output_format: str,
    reload_model: bool,
    save_stems: bool,
) -> dict[str, Any]:
    """Execute one full separation run and return its metrics dictionary.

    Stages are timed individually with time.perf_counter(): model load,
    audio read, mono->stereo duplication, tensor preparation, resampling,
    inference, and stem export. CUDA peak VRAM is captured when the device
    is a GPU. The result is also written to <run_dir>/run_metrics.json.
    """
    separator_module = get_separator_module()
    service = create_service(reload_model)
    progress_events: list[dict[str, float | str]] = []

    def progress_callback(state: str, pct: float) -> None:
        # Record each progress event with a monotonic timestamp; the offsets
        # are rebased relative to the run start by normalize_progress_events().
        progress_events.append(
            {
                "state": state,
                "progress": pct,
                "timestamp_monotonic": time.perf_counter(),
            }
        )

    run_started_at = time.perf_counter()
    # --- model load (sync_device ensures queued CUDA work is included) ---
    model_load_started = time.perf_counter()
    service.load_model()
    sync_device()
    model_load_seconds = time.perf_counter() - model_load_started
    # --- read source audio from disk ---
    read_started = time.perf_counter()
    audio, input_sample_rate = sf.read(str(input_path))
    read_seconds = time.perf_counter() - read_started
    original_channels = 1 if audio.ndim == 1 else audio.shape[1]
    # --- duplicate mono to stereo (model presumably expects 2 channels) ---
    mono_to_stereo_seconds = 0.0
    if audio.ndim == 1:
        mono_to_stereo_started = time.perf_counter()
        audio = np.stack([audio, audio], axis=1)
        mono_to_stereo_seconds = time.perf_counter() - mono_to_stereo_started
    # --- convert to a float32 (channels, samples) tensor ---
    tensor_started = time.perf_counter()
    audio_tensor = torch.tensor(audio.T, dtype=torch.float32)
    tensor_prepare_seconds = time.perf_counter() - tensor_started
    # --- resample only when the source rate differs from the model's rate ---
    resample_seconds = 0.0
    if input_sample_rate != service.sample_rate:
        resample_started = time.perf_counter()
        resampler = torchaudio.transforms.Resample(input_sample_rate, service.sample_rate)
        audio_tensor = resampler(audio_tensor)
        sync_device()
        resample_seconds = time.perf_counter() - resample_started
    resampled_samples = int(audio_tensor.shape[1])
    clip_duration_seconds = resampled_samples / float(service.sample_rate)
    # Predicted chunking layout (for reporting; inference does its own chunking).
    chunk_plan = compute_chunk_plan(
        total_samples=resampled_samples,
        chunk_size=service.chunk_size,
        num_overlap=service.num_overlap,
    )
    # --- inference, with peak-VRAM tracking on CUDA ---
    peak_vram_mb = None
    if separator_module.DEVICE.type == "cuda":
        torch.cuda.reset_peak_memory_stats(separator_module.DEVICE)
    inference_started = time.perf_counter()
    # NOTE(review): relies on the service's private _process_audio; the
    # returned tensors are indexed in STEM_ORDER (see export loop below).
    separated = service._process_audio(audio_tensor, progress_callback)
    sync_device()
    inference_seconds = time.perf_counter() - inference_started
    if separator_module.DEVICE.type == "cuda":
        peak_vram_mb = bytes_to_mb(torch.cuda.max_memory_allocated(separator_module.DEVICE))
    # --- export the requested stems (skipped entirely with --skip-save-stems) ---
    export_started = time.perf_counter()
    stems_output_dir = run_dir / "stems"
    stems_output_dir.mkdir(parents=True, exist_ok=True)
    exported_files: dict[str, str] = {}
    if save_stems:
        for i, stem_key in enumerate(separator_module.STEM_ORDER):
            canonical = separator_module.STEM_NAME_MAP[stem_key]
            if canonical not in stems:
                continue
            # Back to (samples, channels), clipped to the valid [-1, 1] range.
            stem_audio = separated[i].numpy().T
            stem_audio = np.clip(stem_audio, -1.0, 1.0)
            output_filename = f"{canonical}.{output_format}"
            output_path = stems_output_dir / output_filename
            service._write_output(str(output_path), stem_audio, output_format)
            exported_files[canonical] = str(output_path)
    export_seconds = time.perf_counter() - export_started
    total_seconds = time.perf_counter() - run_started_at
    # Sum of the pre-inference stages (excludes model load).
    preprocessing_seconds = (
        read_seconds
        + mono_to_stereo_seconds
        + tensor_prepare_seconds
        + resample_seconds
    )
    # RTF < 1 means faster than real time; throughput is its reciprocal.
    performance = {
        "realtime_factor_total": safe_divide(total_seconds, clip_duration_seconds),
        "realtime_factor_inference": safe_divide(inference_seconds, clip_duration_seconds),
        "throughput_total_x": safe_divide(clip_duration_seconds, total_seconds),
        "throughput_inference_x": safe_divide(clip_duration_seconds, inference_seconds),
    }
    run_result: dict[str, Any] = {
        "run_dir": str(run_dir),
        "stems_output_dir": str(stems_output_dir),
        "exported_files": exported_files,
        "clip": {
            "input_sample_rate": input_sample_rate,
            "output_sample_rate": service.sample_rate,
            "input_channels": original_channels,
            "processed_channels": int(audio_tensor.shape[0]),
            "duration_seconds": clip_duration_seconds,
            "processed_samples": resampled_samples,
        },
        "model": {
            "device": str(separator_module.DEVICE),
            "chunk_size": service.chunk_size,
            "num_overlap": service.num_overlap,
            "peak_vram_mb": peak_vram_mb,
        },
        "chunk_plan": chunk_plan,
        "timings": {
            "model_load_seconds": model_load_seconds,
            "audio_read_seconds": read_seconds,
            "mono_to_stereo_seconds": mono_to_stereo_seconds,
            "tensor_prepare_seconds": tensor_prepare_seconds,
            "resample_seconds": resample_seconds,
            "preprocessing_seconds": preprocessing_seconds,
            "inference_seconds": inference_seconds,
            "export_seconds": export_seconds,
            "total_seconds": total_seconds,
        },
        "performance": performance,
        "progress_events": normalize_progress_events(progress_events, run_started_at),
    }
    # Persist per-run metrics next to the stems; the path is added after the
    # JSON dump so the file does not reference itself.
    per_run_json_path = run_dir / "run_metrics.json"
    per_run_json_path.write_text(json.dumps(run_result, indent=2), encoding="utf-8")
    run_result["run_metrics_path"] = str(per_run_json_path)
    return run_result
def create_service(reload_model: bool):
    """Construct a StemSeparatorService instance.

    When *reload_model* is true, the class-level cached instance and
    model-loaded flag are cleared first so the next load_model() call pays
    the full cold-start cost again.
    """
    module = get_separator_module()
    if reload_model:
        # Reset cached class state before constructing a fresh service.
        module.StemSeparatorService._instance = None
        module.StemSeparatorService._model_loaded = False
    return module.StemSeparatorService()
def parse_stems(raw_value: str) -> list[str]:
    """Translate the --stems CLI value into a list of canonical stem names.

    "all" (case-insensitive) expands to every name in DEFAULT_STEMS.
    Otherwise the value is comma-split, whitespace-trimmed, title-cased,
    and validated; unknown names raise SystemExit with a helpful message.
    """
    if raw_value.strip().lower() == "all":
        return list(DEFAULT_STEMS)
    chosen = [token.strip().title() for token in raw_value.split(",") if token.strip()]
    unknown = [name for name in chosen if name not in DEFAULT_STEMS]
    if unknown:
        valid_display = ", ".join(DEFAULT_STEMS)
        raise SystemExit(
            f"Invalid stems: {', '.join(unknown)}. Valid stems are: {valid_display}"
        )
    return chosen
def prepare_output_root(output_root: str | None) -> Path:
if output_root:
path = Path(output_root).expanduser().resolve()
path.mkdir(parents=True, exist_ok=True)
return path
stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
path = Path(tempfile.mkdtemp(prefix=f"stem-benchmark-{stamp}-"))
return path
def read_source_file_info(input_path: Path) -> dict[str, Any]:
    """Collect header-level metadata about the source audio file.

    Uses soundfile's info() to obtain format, subtype, sample rate, channel
    count and frame count without decoding the audio, plus on-disk size via
    stat(). Duration falls back to math.inf when the sample rate is zero
    (safe_divide semantics).
    """
    meta = sf.info(str(input_path))
    size_bytes = input_path.stat().st_size
    return {
        "path": str(input_path),
        "format": meta.format,
        "subtype": meta.subtype,
        "sample_rate": meta.samplerate,
        "channels": meta.channels,
        "frames": meta.frames,
        "duration_seconds": safe_divide(meta.frames, meta.samplerate),
        "file_size_bytes": size_bytes,
        "file_size_mb": bytes_to_mb(size_bytes),
    }
def compute_chunk_plan(total_samples: int, chunk_size: int, num_overlap: int) -> dict[str, Any]:
    """Predict the separator's sliding-window chunk layout for reporting.

    The hop (step) is chunk_size // num_overlap. Clips no longer than one
    window are padded up to a single full chunk; longer clips are padded so
    the final window ends exactly at the padded length.
    """
    step = chunk_size // num_overlap
    if total_samples <= chunk_size:
        # Short clip: pad up to exactly one window (0 when lengths match).
        pad_needed = chunk_size - total_samples
    else:
        # Pad so (padded_len - chunk_size) is a whole number of hops.
        leftover = (total_samples - chunk_size) % step
        pad_needed = step - leftover if leftover else 0
    padded_len = total_samples + pad_needed
    total_chunks = len(range(0, padded_len - chunk_size + 1, step))
    return {
        "step_samples": step,
        "pad_needed_samples": pad_needed,
        "padded_length_samples": padded_len,
        "estimated_chunk_count": total_chunks,
    }
def normalize_progress_events(
progress_events: list[dict[str, float | str]],
run_started_at: float,
) -> list[dict[str, float | str]]:
normalized: list[dict[str, float | str]] = []
for event in progress_events:
timestamp = float(event["timestamp_monotonic"])
normalized.append(
{
"state": str(event["state"]),
"progress": float(event["progress"]),
"seconds_from_run_start": timestamp - run_started_at,
}
)
return normalized
def aggregate_runs(runs: list[dict[str, Any]]) -> dict[str, Any]:
    """Summarize timing and performance metrics across all runs.

    For every key in the first run's "timings" and "performance" sections,
    produces min/max/mean/median/stdev via summarize_metric(). Assumes at
    least one run and that all runs share the same metric keys.
    """

    def collapse(section: str) -> dict[str, dict[str, float]]:
        # Key order follows the first run's section.
        return {
            key: summarize_metric([float(run[section][key]) for run in runs])
            for key in runs[0][section]
        }

    return {"timings": collapse("timings"), "performance": collapse("performance")}
def summarize_metric(values: list[float]) -> dict[str, float]:
    """Return min/max/mean/median/stdev for one metric across runs.

    stdev is defined as 0.0 for a single sample, since statistics.stdev
    raises on fewer than two values.
    """
    spread = statistics.stdev(values) if len(values) > 1 else 0.0
    return {
        "min": min(values),
        "max": max(values),
        "mean": statistics.mean(values),
        "median": statistics.median(values),
        "stdev": spread,
    }
def write_runs_csv(csv_path: Path, runs: list[dict[str, Any]]) -> None:
    """Write one CSV row per benchmark run to *csv_path*.

    Columns are declared once as (csv field, section key, value key) triples;
    a section of None means the value lives at the top level of the run dict.
    This keeps the header and the per-row extraction in lockstep.
    """
    columns = [
        ("run_dir", None, "run_dir"),
        ("stems_output_dir", None, "stems_output_dir"),
        ("duration_seconds", "clip", "duration_seconds"),
        ("input_sample_rate", "clip", "input_sample_rate"),
        ("output_sample_rate", "clip", "output_sample_rate"),
        ("estimated_chunk_count", "chunk_plan", "estimated_chunk_count"),
        ("model_load_seconds", "timings", "model_load_seconds"),
        ("audio_read_seconds", "timings", "audio_read_seconds"),
        ("mono_to_stereo_seconds", "timings", "mono_to_stereo_seconds"),
        ("tensor_prepare_seconds", "timings", "tensor_prepare_seconds"),
        ("resample_seconds", "timings", "resample_seconds"),
        ("preprocessing_seconds", "timings", "preprocessing_seconds"),
        ("inference_seconds", "timings", "inference_seconds"),
        ("export_seconds", "timings", "export_seconds"),
        ("total_seconds", "timings", "total_seconds"),
        ("realtime_factor_total", "performance", "realtime_factor_total"),
        ("realtime_factor_inference", "performance", "realtime_factor_inference"),
        ("throughput_total_x", "performance", "throughput_total_x"),
        ("throughput_inference_x", "performance", "throughput_inference_x"),
        ("peak_vram_mb", "model", "peak_vram_mb"),
    ]
    fieldnames = [field for field, _, _ in columns]
    with csv_path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        for run in runs:
            writer.writerow(
                {
                    field: (run[key] if section is None else run[section][key])
                    for field, section, key in columns
                }
            )
def sync_device() -> None:
    """Wait for queued CUDA work on the separator's device to finish.

    Called around timed sections so perf_counter() deltas include GPU work;
    a no-op on non-CUDA devices.
    """
    device = get_separator_module().DEVICE
    if device.type == "cuda":
        torch.cuda.synchronize(device)
def safe_divide(numerator: float, denominator: float) -> float:
    """Divide, mapping a zero denominator to math.inf instead of raising."""
    return numerator / denominator if denominator != 0 else math.inf
def bytes_to_mb(value: int) -> float:
    """Convert a byte count to mebibytes (1 MiB = 1024 * 1024 bytes)."""
    mib = 1 << 20
    return value / mib
def get_separator_module():
    """Lazily import and return the backend separator module.

    Deferred (function-scope import) so the sys.path adjustment at the top
    of this script has already run, and so the import cost is paid only when
    the module is actually needed.
    """
    from backend import separator
    return separator
if __name__ == "__main__":
    # Propagate main()'s integer return code as the process exit status.
    raise SystemExit(main())