FJFehr Claude Sonnet 4.6 committed on
Commit
29dbf34
·
1 Parent(s): 4d8957d

feat: add CPU/GPU generation benchmark script

Browse files

Adds benchmark.py to measure Godzilla model generation speed across all
combinations of input length (short/long) and generation length (32–128
tokens), with mean, std, min, max, tok/s and GPU speedup reporting.

Also caps requires-python to <3.14 to avoid pydantic-core build failure
on Python 3.14 (pyo3 does not yet support it), and documents the
benchmark in README.md.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (3) hide show
  1. README.md +55 -0
  2. benchmark.py +204 -0
  3. pyproject.toml +1 -1
README.md CHANGED
@@ -201,6 +201,61 @@ Code consolidation to improve maintainability:
201
 
202
  ---
203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  ## 🛠️ Development Tips
205
 
206
  ### Debugging
 
201
 
202
  ---
203
 
204
+ ## ⚡ Benchmarking
205
+
206
+ `benchmark.py` measures Godzilla model generation speed across all combinations of input length and generation length, with CPU and GPU compared side by side.
207
+
208
+ ### What it tests
209
+
210
+ | Axis | Values |
211
+ |------|--------|
212
+ | Input length | Short (8 notes, ~4 s) · Long (90 notes, ~18 s) |
213
+ | Generation length | 32 · 64 · 96 · 128 tokens (matches the four UI presets) |
214
+ | Devices | CPU always · CUDA if available |
215
+
216
+ Each combination runs a warm-up pass (model load, timing discarded) followed by `--runs` timed passes. The summary tables report mean, std, min, max in both ms and seconds, plus tokens/sec and GPU speedup.
217
+
218
+ ### Usage
219
+
220
+ ```bash
221
+ # Full sweep — CPU + GPU (if available), 5 runs per combination
222
+ uv run python benchmark.py
223
+
224
+ # CPU only (useful for verifying the script or on CPU-only machines)
225
+ uv run python benchmark.py --cpu-only
226
+
227
+ # Increase runs for tighter statistics
228
+ uv run python benchmark.py --runs 10
229
+
230
+ # Multi-candidate generation (higher quality, slower)
231
+ uv run python benchmark.py --candidates 3
232
+ ```
233
+
234
+ Results are printed to stdout and saved to `benchmark_results.txt` (override with `--output`).
235
+
236
+ ### Example output
237
+
238
+ ```
239
+ ============================================================
240
+ Device: CUDA | candidates=1
241
+ ============================================================
242
+ [warm-up] loading model + first inference...
243
+ input=short (8 notes, ~4s) gen= 32 tokens [1:85ms] [2:82ms] ...
244
+ ...
245
+
246
+ ================================================================================
247
+ SUMMARY — CUDA | candidates=1
248
+ ================================================================================
249
+ Input Gen tok Mean ms Mean s Std ms Min ms Max ms tok/s
250
+ -----------------------------------------------------------------------------------------
251
+ short (8 notes, ~4s) 32 85 0.09 2.1 82 89 376.5
252
+ short (8 notes, ~4s) 128 290 0.29 4.3 284 297 441.4
253
+ long (90 notes, ~18s) 32 91 0.09 1.8 88 94 351.6
254
+ long (90 notes, ~18s) 128 305 0.31 3.9 299 312 419.7
255
+ ```
256
+
257
+ ---
258
+
259
  ## 🛠️ Development Tips
260
 
261
  ### Debugging
benchmark.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ CPU vs GPU generation benchmark for the Godzilla MIDI model.
4
+ Sweeps all combinations of input length x generation length.
5
+
6
+ Usage:
7
+ uv run python benchmark.py
8
+ uv run python benchmark.py --runs 5 --candidates 1 --cpu-only
9
+ """
10
+
11
+ import argparse
12
+ import datetime
13
+ import io
14
+ import math
15
+ import sys
16
+ import time
17
+ import torch
18
+ from midi_model import generate_godzilla_continuation
19
+
20
# Benchmark fixtures. Every event is one note; pitch cycles through a single
# chromatic octave starting at middle C (MIDI 60), velocity fixed at 80.

# Short input: 8 notes, 0.5s apart (~4 seconds, ~24 prompt tokens)
SHORT_EVENTS = [
    dict(type="note", note=60 + i % 12, velocity=80, time=i * 0.5, channel=0)
    for i in range(8)
]

# Long input: 90 notes, 0.2s apart (~18 seconds — fills the prompt window)
LONG_EVENTS = [
    dict(type="note", note=60 + i % 12, velocity=80, time=i * 0.2, channel=0)
    for i in range(90)
]

# Fixtures keyed by the label shown in console output and summary tables.
INPUT_FIXTURES = {
    "short (8 notes, ~4s)": SHORT_EVENTS,
    "long (90 notes, ~18s)": LONG_EVENTS,
}

# Matches the four UI presets in keyboard.js
GENERATION_LENGTHS = [32, 64, 96, 128]
51
+
52
+
53
def gpu_name() -> str:
    """Return the name of the first CUDA device, or "N/A" when no GPU is present."""
    return torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A"
57
+
58
+
59
def stddev(values: list[float]) -> float:
    """Sample standard deviation (Bessel-corrected); 0.0 for fewer than two values."""
    count = len(values)
    if count < 2:
        return 0.0
    avg = sum(values) / count
    variance = sum((v - avg) ** 2 for v in values) / (count - 1)
    return math.sqrt(variance)
65
+
66
+
67
def run_generation(
    events: list[dict], device: str, tokens: int, candidates: int
) -> float:
    """Time one generation call and return the elapsed wall-clock milliseconds.

    The fixed seed makes repeated runs generate the same continuation, so
    timing variance reflects hardware/runtime noise rather than output length.
    """
    start = time.perf_counter()
    generate_godzilla_continuation(
        events,
        generate_tokens=tokens,
        device=device,
        num_candidates=candidates,
        seed=42,
    )
    elapsed = time.perf_counter() - start
    return elapsed * 1000.0
80
+
81
+
82
def benchmark_device(
    device: str, runs: int, candidates: int
) -> dict[tuple[str, int], list[float]]:
    """Time every input-fixture x generation-length combination on one device.

    Returns a mapping of (input label, token count) -> per-run times in ms.
    """
    bar = "=" * 72
    print(f"\n{bar}")
    print(f" Device: {device.upper()} | candidates={candidates}")
    print(f"{bar}")

    # One untimed pass loads the model so the timed runs measure inference only.
    print(" [warm-up] loading model + first inference...")
    run_generation(SHORT_EVENTS, device, GENERATION_LENGTHS[0], candidates)

    results: dict[tuple[str, int], list[float]] = {}
    for input_label, events in INPUT_FIXTURES.items():
        for gen_tokens in GENERATION_LENGTHS:
            print(
                f" input={input_label} gen={gen_tokens:>3} tokens",
                end=" ",
                flush=True,
            )
            samples: list[float] = []
            for run_idx in range(runs):
                elapsed_ms = run_generation(events, device, gen_tokens, candidates)
                samples.append(elapsed_ms)
                print(f"[{run_idx + 1}:{elapsed_ms:.0f}ms]", end=" ", flush=True)
            print()
            results[(input_label, gen_tokens)] = samples

    return results
112
+
113
+
114
def print_summary(
    device: str, results: dict[tuple[str, int], list[float]], candidates: int
) -> None:
    """Print a per-combination stats table: mean/std/min/max (ms), mean (s), tok/s."""
    rule = "=" * 80
    print(f"\n{rule}")
    print(f" SUMMARY — {device.upper()} | candidates={candidates}")
    print(f"{rule}")
    header = f" {'Input':<24} {'Gen tok':>7} {'Mean ms':>8} {'Mean s':>7} {'Std ms':>7} {'Min ms':>7} {'Max ms':>7} {'tok/s':>7}"
    print(header)
    print(" " + "-" * (len(header) - 2))
    for (input_label, gen_tokens), timings in results.items():
        mean_ms = sum(timings) / len(timings)
        spread = stddev(timings)
        # Throughput from mean wall-clock time per generated token.
        rate = gen_tokens / (mean_ms / 1000.0)
        row = (
            f" {input_label:<24} {gen_tokens:>7} {mean_ms:>8.0f} {mean_ms / 1000:>7.2f}"
            f" {spread:>7.1f} {min(timings):>7.0f} {max(timings):>7.0f} {rate:>7.1f}"
        )
        print(row)
131
+
132
+
133
def main():
    """Parse CLI args, run the benchmark sweep, print summaries, save a report.

    All console output is mirrored into a buffer and written to ``--output``
    (default ``benchmark_results.txt``). CPU always runs; CUDA runs unless
    ``--cpu-only`` is given or no GPU is available.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--runs", type=int, default=5)
    parser.add_argument("--candidates", type=int, default=1)
    parser.add_argument("--output", type=str, default="benchmark_results.txt")
    parser.add_argument("--cpu-only", action="store_true", help="Skip GPU benchmark")
    args = parser.parse_args()

    # Tee all output to stdout and a buffer for saving
    buffer = io.StringIO()

    class Tee:
        def write(self, msg):
            sys.__stdout__.write(msg)
            buffer.write(msg)

        def flush(self):
            sys.__stdout__.flush()

    sys.stdout = Tee()
    # try/finally guarantees stdout is restored and the captured output is
    # saved even if a benchmark run raises (e.g. CUDA OOM, KeyboardInterrupt).
    # Previously an exception left sys.stdout hijacked and wrote no report.
    try:
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f"Benchmark run: {timestamp}")
        print(f"GPU: {gpu_name()}")
        print(f"Runs per combination: {args.runs} | Candidates: {args.candidates}")
        print(
            f"Input sizes: short={len(SHORT_EVENTS)} notes, long={len(LONG_EVENTS)} notes"
        )
        print(f"Generation sizes: {GENERATION_LENGTHS} tokens")

        all_results: dict[str, dict[tuple[str, int], list[float]]] = {}

        all_results["cpu"] = benchmark_device("cpu", args.runs, args.candidates)

        if args.cpu_only:
            print("\n[--cpu-only flag set — skipping GPU benchmark]")
        elif torch.cuda.is_available():
            all_results["cuda"] = benchmark_device("cuda", args.runs, args.candidates)
        else:
            print("\n[CUDA not available — skipping GPU benchmark]")

        for device, results in all_results.items():
            print_summary(device, results, args.candidates)

        # GPU speedup table (only meaningful when both devices ran)
        if "cpu" in all_results and "cuda" in all_results:
            print(f"\n{'=' * 80}")
            print(" GPU SPEEDUP")
            print(f"{'=' * 80}")
            header = f" {'Input':<24} {'Gen tok':>7} {'CPU ms':>8} {'CPU s':>6} {'GPU ms':>8} {'GPU s':>6} {'Speedup':>8}"
            print(header)
            print(" " + "-" * (len(header) - 2))
            # CPU and CUDA sweeps cover identical (input, tokens) keys, so
            # iterating the CPU keys pairs the means one-to-one.
            for key in all_results["cpu"]:
                cpu_mean = sum(all_results["cpu"][key]) / len(all_results["cpu"][key])
                gpu_mean = sum(all_results["cuda"][key]) / len(all_results["cuda"][key])
                speedup = cpu_mean / gpu_mean
                input_label, gen_tokens = key
                print(
                    f" {input_label:<24} {gen_tokens:>7} {cpu_mean:>8.0f} {cpu_mean / 1000:>6.2f}"
                    f" {gpu_mean:>8.0f} {gpu_mean / 1000:>6.2f} {speedup:>7.2f}x"
                )

        print()
    finally:
        sys.stdout = sys.__stdout__
        # Save whatever was captured — a partial report still has value after
        # a mid-run failure.
        with open(args.output, "w") as f:
            f.write(buffer.getvalue())
        print(f"Results saved to {args.output}")


if __name__ == "__main__":
    main()
pyproject.toml CHANGED
@@ -3,7 +3,7 @@ name = "virtual-keyboard"
3
  version = "0.1.0"
4
  description = "Add your description here"
5
  readme = "README.md"
6
- requires-python = ">=3.10"
7
  dependencies = [
8
  "einops>=0.6",
9
  "einx>=0.3.0",
 
3
  version = "0.1.0"
4
  description = "Add your description here"
5
  readme = "README.md"
6
+ requires-python = ">=3.10,<3.14"
7
  dependencies = [
8
  "einops>=0.6",
9
  "einx>=0.3.0",