virtual_keyboard / benchmark.py
FJFehr's picture
feat: add CPU/GPU generation benchmark script
29dbf34
#!/usr/bin/env python3
"""
CPU vs GPU generation benchmark for the Godzilla MIDI model.
Sweeps all combinations of input length x generation length.
Usage:
uv run python benchmark.py
uv run python benchmark.py --runs 5 --candidates 1 --cpu-only
"""
import argparse
import datetime
import io
import math
import sys
import time
import torch
from midi_model import generate_godzilla_continuation
def _note_run(count: int, spacing: float) -> list[dict]:
    """Build ``count`` note events spaced ``spacing`` seconds apart.

    Pitches start at note 60 and wrap around after 12 steps; velocity is
    fixed at 80 and every event is on channel 0.
    """
    return [
        dict(
            type="note",
            note=60 + (step % 12),
            velocity=80,
            time=step * spacing,
            channel=0,
        )
        for step in range(count)
    ]


# Short input: 8 notes, 0.5s apart (~4 seconds, ~24 prompt tokens)
SHORT_EVENTS = _note_run(8, 0.5)

# Long input: 90 notes, 0.2s apart (~18 seconds — fills the prompt window)
LONG_EVENTS = _note_run(90, 0.2)

# Human-readable label -> event list; the labels appear in the result tables.
INPUT_FIXTURES = {
    "short (8 notes, ~4s)": SHORT_EVENTS,
    "long (90 notes, ~18s)": LONG_EVENTS,
}

# Matches the four UI presets in keyboard.js
GENERATION_LENGTHS = [32, 64, 96, 128]
def gpu_name() -> str:
    """Return the name of CUDA device 0, or ``"N/A"`` when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return "N/A"
    return torch.cuda.get_device_name(0)
def stddev(values: list[float]) -> float:
    """Return the sample standard deviation of ``values``.

    Uses the ``n - 1`` denominator. With fewer than two values there is no
    spread to estimate, so 0.0 is returned.
    """
    count = len(values)
    if count >= 2:
        centre = sum(values) / count
        squared_deviation = sum((v - centre) ** 2 for v in values)
        return math.sqrt(squared_deviation / (count - 1))
    return 0.0
def run_generation(
    events: list[dict], device: str, tokens: int, candidates: int
) -> float:
    """Run one generation call and return its wall-clock time in milliseconds.

    Args:
        events: Input note events passed to the generator as the prompt.
        device: Device string forwarded to the generator (e.g. "cpu", "cuda").
        tokens: Number of tokens to generate.
        candidates: Number of candidates requested from the generator.

    Returns:
        Elapsed time of the call in milliseconds.

    The seed is fixed at 42 so every timed call does the same work.
    """
    # CUDA kernels are launched asynchronously. Synchronising before the timer
    # starts and again before it stops keeps earlier queued work out of the
    # sample and makes sure this call's GPU work has finished before it is
    # measured. This is a no-op cost-wise when the generator already blocks.
    on_cuda = device.startswith("cuda") and torch.cuda.is_available()
    if on_cuda:
        torch.cuda.synchronize(torch.device(device))

    t0 = time.perf_counter()
    generate_godzilla_continuation(
        events,
        generate_tokens=tokens,
        device=device,
        num_candidates=candidates,
        seed=42,
    )
    if on_cuda:
        torch.cuda.synchronize(torch.device(device))
    return (time.perf_counter() - t0) * 1000.0
def benchmark_device(
    device: str, runs: int, candidates: int
) -> dict[tuple[str, int], list[float]]:
    """Time every input fixture x generation length combination on one device.

    Args:
        device: Device string handed to ``run_generation``.
        runs: Number of timed repetitions per combination.
        candidates: Candidate count handed to ``run_generation``.

    Returns:
        Mapping of ``(input label, generated tokens)`` to the list of
        per-run timings in milliseconds, in sweep order.
    """
    rule = "=" * 72
    print(f"\n{rule}")
    print(f" Device: {device.upper()} | candidates={candidates}")
    print(rule)

    # One untimed call on the smallest combination so one-off start-up cost
    # does not land in the first measured sample.
    print(" [warm-up] loading model + first inference...")
    run_generation(SHORT_EVENTS, device, GENERATION_LENGTHS[0], candidates)

    timings_by_combo: dict[tuple[str, int], list[float]] = {}
    for label, fixture in INPUT_FIXTURES.items():
        for length in GENERATION_LENGTHS:
            print(f" input={label} gen={length:>3} tokens", end=" ", flush=True)
            samples = []
            for attempt in range(1, runs + 1):
                elapsed_ms = run_generation(fixture, device, length, candidates)
                samples.append(elapsed_ms)
                # Progress is flushed per run so long sweeps show live output.
                print(f"[{attempt}:{elapsed_ms:.0f}ms]", end=" ", flush=True)
            print()
            timings_by_combo[label, length] = samples
    return timings_by_combo
def print_summary(
    device: str, results: dict[tuple[str, int], list[float]], candidates: int
) -> None:
    """Print a per-combination statistics table for one device.

    Args:
        device: Device label shown (upper-cased) in the table title.
        results: Mapping of ``(input label, generated tokens)`` to per-run
            timings in milliseconds.
        candidates: Candidate count shown in the table title.

    Each row reports mean, sample standard deviation, min and max of the
    timings plus generated tokens per second. A combination with no timings
    is printed as a row of "n/a" instead of raising ZeroDivisionError.
    """
    print(f"\n{'=' * 80}")
    print(f" SUMMARY — {device.upper()} | candidates={candidates}")
    print(f"{'=' * 80}")
    header = f" {'Input':<24} {'Gen tok':>7} {'Mean ms':>8} {'Mean s':>7} {'Std ms':>7} {'Min ms':>7} {'Max ms':>7} {'tok/s':>7}"
    print(header)
    print(" " + "-" * (len(header) - 2))
    for (input_label, gen_tokens), timings in results.items():
        if not timings:
            # Nothing was measured for this combination (e.g. zero runs).
            print(
                f" {input_label:<24} {gen_tokens:>7} {'n/a':>8} {'n/a':>7}"
                f" {'n/a':>7} {'n/a':>7} {'n/a':>7} {'n/a':>7}"
            )
            continue
        mean = sum(timings) / len(timings)
        std = stddev(timings)
        # Guard the throughput division against a zero mean.
        tok_per_s = gen_tokens / (mean / 1000.0) if mean > 0 else math.inf
        print(
            f" {input_label:<24} {gen_tokens:>7} {mean:>8.0f} {mean / 1000:>7.2f}"
            f" {std:>7.1f} {min(timings):>7.0f} {max(timings):>7.0f} {tok_per_s:>7.1f}"
        )
def main():
    """Parse CLI options, run the benchmarks, and save the console output.

    The CPU benchmark always runs; the CUDA benchmark runs when CUDA is
    available and ``--cpu-only`` is not set. Everything printed during the run
    is echoed to the console and also captured, then written to ``--output``
    after a successful run.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--runs", type=int, default=5)
    parser.add_argument("--candidates", type=int, default=1)
    parser.add_argument("--output", type=str, default="benchmark_results.txt")
    parser.add_argument("--cpu-only", action="store_true", help="Skip GPU benchmark")
    args = parser.parse_args()
    if args.runs < 1:
        # With no timed runs the summary statistics are undefined.
        parser.error("--runs must be at least 1")

    # Tee all output to stdout and a buffer for saving
    buffer = io.StringIO()
    # Remember whatever stream was active so it can be restored exactly,
    # even if a caller had already redirected stdout.
    previous_stdout = sys.stdout

    class Tee:
        """Minimal stdout stand-in that writes to the console and the buffer."""

        def write(self, msg):
            previous_stdout.write(msg)
            return buffer.write(msg)

        def flush(self):
            previous_stdout.flush()

        def __getattr__(self, name):
            # Anything else (isatty, encoding, ...) is answered by the real stream.
            return getattr(previous_stdout, name)

    sys.stdout = Tee()
    try:
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f"Benchmark run: {timestamp}")
        print(f"GPU: {gpu_name()}")
        print(f"Runs per combination: {args.runs} | Candidates: {args.candidates}")
        print(
            f"Input sizes: short={len(SHORT_EVENTS)} notes, long={len(LONG_EVENTS)} notes"
        )
        print(f"Generation sizes: {GENERATION_LENGTHS} tokens")

        all_results: dict[str, dict[tuple[str, int], list[float]]] = {}
        all_results["cpu"] = benchmark_device("cpu", args.runs, args.candidates)
        if args.cpu_only:
            print("\n[--cpu-only flag set — skipping GPU benchmark]")
        elif torch.cuda.is_available():
            all_results["cuda"] = benchmark_device("cuda", args.runs, args.candidates)
        else:
            print("\n[CUDA not available — skipping GPU benchmark]")

        for device, results in all_results.items():
            print_summary(device, results, args.candidates)

        # GPU speedup table (if both ran)
        if "cpu" in all_results and "cuda" in all_results:
            print(f"\n{'=' * 80}")
            print(" GPU SPEEDUP")
            print(f"{'=' * 80}")
            header = f" {'Input':<24} {'Gen tok':>7} {'CPU ms':>8} {'CPU s':>6} {'GPU ms':>8} {'GPU s':>6} {'Speedup':>8}"
            print(header)
            print(" " + "-" * (len(header) - 2))
            for key in all_results["cpu"]:
                cpu_mean = sum(all_results["cpu"][key]) / len(all_results["cpu"][key])
                gpu_mean = sum(all_results["cuda"][key]) / len(all_results["cuda"][key])
                speedup = cpu_mean / gpu_mean
                input_label, gen_tokens = key
                print(
                    f" {input_label:<24} {gen_tokens:>7} {cpu_mean:>8.0f} {cpu_mean / 1000:>6.2f}"
                    f" {gpu_mean:>8.0f} {gpu_mean / 1000:>6.2f} {speedup:>7.2f}x"
                )
        print()
    finally:
        # Always hand stdout back, including when a benchmark raises.
        sys.stdout = previous_stdout

    # The captured text contains non-ASCII characters, so write it as UTF-8
    # rather than relying on the platform default encoding.
    with open(args.output, "w", encoding="utf-8") as f:
        f.write(buffer.getvalue())
    print(f"Results saved to {args.output}")


if __name__ == "__main__":
    main()