rikhoffbauer2 commited on
Commit
ab7ecaf
Β·
verified Β·
1 Parent(s): 1b8f186

Add synth_generator.py

Browse files
Files changed (1) hide show
  1. synth_generator.py +507 -0
synth_generator.py ADDED
@@ -0,0 +1,507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Synthetic drum song generator with known ground-truth samples.
3
+
4
+ Generates realistic drum patterns by:
5
+ 1. Synthesizing individual drum samples (kick, snare, hihat, etc.) with controlled parameters
6
+ 2. Placing them in musical patterns with velocity variation, timing humanization, and overlap
7
+ 3. Optionally mixing with bass/harmony for realistic Demucs testing
8
+ 4. Returning both the mix AND the isolated ground-truth samples + onset map
9
+
10
+ This gives us a perfect evaluation setup: we know exactly which samples are where,
11
+ so we can compare extracted samples against the originals.
12
+ """
13
+
14
+ import numpy as np
15
+ from scipy.signal import butter, filtfilt, lfilter
16
+ from dataclasses import dataclass, field
17
+ from typing import Optional
18
+ import soundfile as sf
19
+ import json
20
+
21
+
22
+ @dataclass
23
+ class GroundTruthSample:
24
+ """A ground-truth drum sample used to build the synthetic song."""
25
+ name: str # e.g. "kick", "snare"
26
+ audio: np.ndarray # the clean one-shot sample
27
+ sr: int
28
+ frequency_range: tuple # (low_hz, high_hz) primary energy band
29
+
30
+ @property
31
+ def duration(self) -> float:
32
+ return len(self.audio) / self.sr
33
+
34
+
35
+ @dataclass
36
+ class PlacedHit:
37
+ """A single hit placed in the timeline."""
38
+ sample_name: str
39
+ onset_time: float # seconds
40
+ velocity: float # 0-1 amplitude multiplier
41
+ audio: np.ndarray # the actual audio placed (with velocity applied)
42
+ sr: int
43
+
44
+
45
+ @dataclass
46
+ class SyntheticSong:
47
+ """A complete synthetic drum song with ground truth."""
48
+ mix: np.ndarray # full mix audio
49
+ drums_only: np.ndarray # drums-only mix
50
+ sr: int
51
+ bpm: float
52
+ duration: float
53
+ samples: dict # {name: GroundTruthSample}
54
+ hits: list # [PlacedHit, ...]
55
+ per_sample_stems: dict # {name: np.ndarray} isolated stems
56
+ pattern_description: str
57
+
58
+
59
+ # ─────────────────────────────────────────────────────────────────────────────
60
+ # Sample synthesis (parametric drum sounds)
61
+ # ─────────────────────────────────────────────────────────────────────────────
62
+
63
+ def _butter_filter(y, sr, fmin=None, fmax=None, order=4):
64
+ """Apply butterworth bandpass/lowpass/highpass filter."""
65
+ nyq = sr / 2
66
+ if fmin and fmax:
67
+ b, a = butter(order, [fmin / nyq, fmax / nyq], btype='band')
68
+ elif fmin:
69
+ b, a = butter(order, fmin / nyq, btype='high')
70
+ elif fmax:
71
+ b, a = butter(order, fmax / nyq, btype='low')
72
+ else:
73
+ return y
74
+ return filtfilt(b, a, y)
75
+
76
+
77
+ def synthesize_kick(sr: int = 44100, pitch: float = 60.0,
78
+ decay: float = 12.0, punch: float = 150.0,
79
+ duration: float = 0.25, noise_amount: float = 0.05) -> np.ndarray:
80
+ """Synthesize a kick drum: sine sweep + sub thump + click."""
81
+ t = np.arange(int(sr * duration)) / sr
82
+ # Frequency sweep: punch Hz β†’ pitch Hz
83
+ freq = (punch - pitch) * np.exp(-30 * t) + pitch
84
+ phase = 2 * np.pi * np.cumsum(freq / sr)
85
+ body = np.sin(phase) * np.exp(-decay * t)
86
+ # Sub thump
87
+ sub = 0.4 * np.sin(2 * np.pi * pitch * t) * np.exp(-15 * t)
88
+ # Click transient
89
+ click = noise_amount * np.random.randn(len(t)) * np.exp(-200 * t)
90
+ click = _butter_filter(click, sr, fmax=4000)
91
+
92
+ kick = body + sub + click
93
+ kick = kick / (np.abs(kick).max() + 1e-8) * 0.95
94
+ return kick.astype(np.float32)
95
+
96
+
97
+ def synthesize_snare(sr: int = 44100, body_freq: float = 200.0,
98
+ noise_decay: float = 12.0, body_decay: float = 20.0,
99
+ duration: float = 0.25, wire_amount: float = 0.6) -> np.ndarray:
100
+ """Synthesize a snare drum: body tone + noise wires."""
101
+ t = np.arange(int(sr * duration)) / sr
102
+ # Body
103
+ body = np.sin(2 * np.pi * body_freq * t) * np.exp(-body_decay * t) * 0.5
104
+ # Snare wires (filtered noise)
105
+ noise = np.random.randn(len(t)) * np.exp(-noise_decay * t) * wire_amount
106
+ noise = _butter_filter(noise, sr, fmin=1000, fmax=10000)
107
+ # Overtone ring
108
+ ring = 0.15 * np.sin(2 * np.pi * body_freq * 2.7 * t) * np.exp(-25 * t)
109
+
110
+ snare = body + noise + ring
111
+ snare = snare / (np.abs(snare).max() + 1e-8) * 0.95
112
+ return snare.astype(np.float32)
113
+
114
+
115
+ def synthesize_hihat(sr: int = 44100, is_open: bool = False,
116
+ brightness: float = 8000.0,
117
+ duration: float = None) -> np.ndarray:
118
+ """Synthesize a hi-hat: filtered noise with metallic overtones."""
119
+ if duration is None:
120
+ duration = 0.4 if is_open else 0.08
121
+ t = np.arange(int(sr * duration)) / sr
122
+ decay = 6.0 if is_open else 40.0
123
+
124
+ noise = np.random.randn(len(t)) * np.exp(-decay * t)
125
+ noise = _butter_filter(noise, sr, fmin=brightness)
126
+ # Metallic overtones
127
+ metal = 0.2 * np.sin(2 * np.pi * 6500 * t) * np.exp(-(decay + 5) * t)
128
+ metal += 0.1 * np.sin(2 * np.pi * 9200 * t) * np.exp(-(decay + 8) * t)
129
+
130
+ hh = noise + metal
131
+ hh = hh / (np.abs(hh).max() + 1e-8) * 0.7
132
+ return hh.astype(np.float32)
133
+
134
+
135
+ def synthesize_tom(sr: int = 44100, pitch: float = 120.0,
136
+ decay: float = 10.0, duration: float = 0.3) -> np.ndarray:
137
+ """Synthesize a tom: pitched body + slight noise."""
138
+ t = np.arange(int(sr * duration)) / sr
139
+ freq = pitch * 1.5 * np.exp(-8 * t) + pitch
140
+ phase = 2 * np.pi * np.cumsum(freq / sr)
141
+ body = np.sin(phase) * np.exp(-decay * t)
142
+ noise = 0.1 * np.random.randn(len(t)) * np.exp(-20 * t)
143
+ noise = _butter_filter(noise, sr, fmin=200, fmax=3000)
144
+ tom = body + noise
145
+ tom = tom / (np.abs(tom).max() + 1e-8) * 0.9
146
+ return tom.astype(np.float32)
147
+
148
+
149
+ def synthesize_cymbal(sr: int = 44100, duration: float = 1.5) -> np.ndarray:
150
+ """Synthesize a crash/ride cymbal: dense metallic noise."""
151
+ t = np.arange(int(sr * duration)) / sr
152
+ noise = np.random.randn(len(t)) * np.exp(-3 * t)
153
+ noise = _butter_filter(noise, sr, fmin=3000)
154
+ # Multiple metallic partials
155
+ partials = sum(
156
+ (0.15 / (i + 1)) * np.sin(2 * np.pi * f * t) * np.exp(-(2 + i) * t)
157
+ for i, f in enumerate([4200, 5800, 7300, 9100, 11500])
158
+ )
159
+ cym = noise + partials
160
+ cym = cym / (np.abs(cym).max() + 1e-8) * 0.6
161
+ return cym.astype(np.float32)
162
+
163
+
164
+ def synthesize_bass_note(sr: int = 44100, freq: float = 65.0,
165
+ duration: float = 0.5) -> np.ndarray:
166
+ """Synthesize a bass note for adding to the mix (tests Demucs separation)."""
167
+ t = np.arange(int(sr * duration)) / sr
168
+ # Sawtooth-ish bass with harmonics
169
+ wave = (np.sin(2 * np.pi * freq * t) +
170
+ 0.5 * np.sin(2 * np.pi * freq * 2 * t) +
171
+ 0.25 * np.sin(2 * np.pi * freq * 3 * t))
172
+ envelope = np.minimum(t * 50, 1.0) * np.exp(-3 * t) # quick attack, slow decay
173
+ bass = wave * envelope
174
+ bass = _butter_filter(bass, sr, fmax=500)
175
+ bass = bass / (np.abs(bass).max() + 1e-8) * 0.5
176
+ return bass.astype(np.float32)
177
+
178
+
179
+ # ─────────────────────────────────────────────────────────────────────────────
180
+ # Sample set creation with controlled variation
181
+ # ─────────────────────────────────────────────────────────────────────────────
182
+
183
+ def create_sample_set(sr: int = 44100, seed: int = 42,
184
+ variation: str = "medium") -> dict:
185
+ """Create a set of ground-truth drum samples with parametric variation.
186
+
187
+ Args:
188
+ variation: "none" (identical hits), "low", "medium", "high"
189
+ """
190
+ rng = np.random.RandomState(seed)
191
+
192
+ # Base parameters with per-variation noise
193
+ var_scale = {"none": 0.0, "low": 0.05, "medium": 0.15, "high": 0.3}[variation]
194
+
195
+ def vary(val, amount=None):
196
+ a = amount if amount is not None else var_scale
197
+ return val * (1.0 + rng.uniform(-a, a))
198
+
199
+ samples = {
200
+ 'kick': GroundTruthSample(
201
+ name='kick',
202
+ audio=synthesize_kick(sr, pitch=vary(60), decay=vary(12), punch=vary(150)),
203
+ sr=sr,
204
+ frequency_range=(30, 300),
205
+ ),
206
+ 'snare': GroundTruthSample(
207
+ name='snare',
208
+ audio=synthesize_snare(sr, body_freq=vary(200), noise_decay=vary(12)),
209
+ sr=sr,
210
+ frequency_range=(100, 8000),
211
+ ),
212
+ 'hihat_closed': GroundTruthSample(
213
+ name='hihat_closed',
214
+ audio=synthesize_hihat(sr, is_open=False, brightness=vary(8000)),
215
+ sr=sr,
216
+ frequency_range=(3000, 20000),
217
+ ),
218
+ 'hihat_open': GroundTruthSample(
219
+ name='hihat_open',
220
+ audio=synthesize_hihat(sr, is_open=True, brightness=vary(7000)),
221
+ sr=sr,
222
+ frequency_range=(2000, 20000),
223
+ ),
224
+ 'tom': GroundTruthSample(
225
+ name='tom',
226
+ audio=synthesize_tom(sr, pitch=vary(120), decay=vary(10)),
227
+ sr=sr,
228
+ frequency_range=(50, 2000),
229
+ ),
230
+ 'cymbal': GroundTruthSample(
231
+ name='cymbal',
232
+ audio=synthesize_cymbal(sr),
233
+ sr=sr,
234
+ frequency_range=(2000, 20000),
235
+ ),
236
+ }
237
+ return samples
238
+
239
+
240
+ # ─────────────────────────────────────────────────────────────────────────────
241
+ # Pattern generation
242
+ # ─────────────────────────────────────────────────────────────────────────────
243
+
244
+ def generate_basic_rock(bars: int = 4) -> dict:
245
+ """Basic rock pattern. Returns {sample_name: [(beat_position, velocity), ...]}"""
246
+ pattern = {
247
+ 'kick': [],
248
+ 'snare': [],
249
+ 'hihat_closed': [],
250
+ 'hihat_open': [],
251
+ }
252
+ for bar in range(bars):
253
+ offset = bar * 4 # 4 beats per bar
254
+ # Kick on 1 and 3
255
+ pattern['kick'].extend([(offset + 0, 0.9), (offset + 2, 0.85)])
256
+ # Snare on 2 and 4
257
+ pattern['snare'].extend([(offset + 1, 0.85), (offset + 3, 0.9)])
258
+ # HH on every 8th note
259
+ for eighth in range(8):
260
+ vel = 0.6 if eighth % 2 == 0 else 0.4 # accented downbeats
261
+ pattern['hihat_closed'].append((offset + eighth * 0.5, vel))
262
+ # Open hat on "& of 4"
263
+ pattern['hihat_open'].append((offset + 3.5, 0.55))
264
+ return pattern
265
+
266
+
267
+ def generate_funk_pattern(bars: int = 4) -> dict:
268
+ """Funky syncopated pattern with ghost notes."""
269
+ pattern = {
270
+ 'kick': [],
271
+ 'snare': [],
272
+ 'hihat_closed': [],
273
+ 'hihat_open': [],
274
+ 'tom': [],
275
+ }
276
+ for bar in range(bars):
277
+ o = bar * 4
278
+ # Syncopated kick
279
+ pattern['kick'].extend([
280
+ (o + 0, 0.95), (o + 0.75, 0.6), (o + 2, 0.9), (o + 2.5, 0.7)
281
+ ])
282
+ # Snare with ghost notes
283
+ pattern['snare'].extend([
284
+ (o + 1, 0.9), (o + 1.75, 0.3), (o + 3, 0.85), (o + 3.25, 0.25)
285
+ ])
286
+ # 16th note hats
287
+ for sixteenth in range(16):
288
+ vel = 0.5 + 0.2 * (sixteenth % 4 == 0)
289
+ pattern['hihat_closed'].append((o + sixteenth * 0.25, vel))
290
+ # Tom fill on last bar
291
+ if bar == bars - 1:
292
+ pattern['tom'].extend([
293
+ (o + 3, 0.8), (o + 3.25, 0.75), (o + 3.5, 0.85), (o + 3.75, 0.9)
294
+ ])
295
+ return pattern
296
+
297
+
298
+ def generate_halftime_pattern(bars: int = 4) -> dict:
299
+ """Half-time/trap-influenced pattern."""
300
+ pattern = {
301
+ 'kick': [],
302
+ 'snare': [],
303
+ 'hihat_closed': [],
304
+ 'cymbal': [],
305
+ }
306
+ for bar in range(bars):
307
+ o = bar * 4
308
+ # Kick on 1
309
+ pattern['kick'].append((o + 0, 0.95))
310
+ # Occasional double kick
311
+ if bar % 2 == 1:
312
+ pattern['kick'].append((o + 0.5, 0.7))
313
+ # Snare on 3 only (half time)
314
+ pattern['snare'].append((o + 2, 0.9))
315
+ # Fast hats
316
+ for sixteenth in range(16):
317
+ vel = 0.3 + 0.15 * (sixteenth % 2 == 0)
318
+ pattern['hihat_closed'].append((o + sixteenth * 0.25, vel))
319
+ # Crash on bar 1
320
+ if bar == 0:
321
+ pattern['cymbal'].append((o + 0, 0.7))
322
+ return pattern
323
+
324
+
325
+ PATTERNS = {
326
+ 'rock': generate_basic_rock,
327
+ 'funk': generate_funk_pattern,
328
+ 'halftime': generate_halftime_pattern,
329
+ }
330
+
331
+
332
+ # ─────────────────────────────────────────────────────────────────────────────
333
+ # Song assembly
334
+ # ─────────────────────────────────────────────────────────────────────────────
335
+
336
+ def assemble_song(
337
+ samples: dict,
338
+ pattern: dict,
339
+ sr: int = 44100,
340
+ bpm: float = 120.0,
341
+ humanize_timing_ms: float = 5.0,
342
+ humanize_velocity: float = 0.05,
343
+ add_bass: bool = True,
344
+ bass_notes: list = None,
345
+ room_noise_db: float = -60.0,
346
+ seed: int = 42,
347
+ ) -> SyntheticSong:
348
+ """Assemble a complete synthetic song from samples and pattern."""
349
+ rng = np.random.RandomState(seed)
350
+ beat_dur = 60.0 / bpm
351
+
352
+ # Calculate total duration
353
+ all_beats = []
354
+ for name, events in pattern.items():
355
+ if events:
356
+ all_beats.extend([e[0] for e in events])
357
+ max_beat = max(all_beats) if all_beats else 4
358
+ total_dur = (max_beat + 2) * beat_dur # add 2 beats of tail
359
+ total_samples = int(total_dur * sr)
360
+
361
+ # Initialize stems
362
+ drums_mix = np.zeros(total_samples, dtype=np.float64)
363
+ per_sample = {name: np.zeros(total_samples, dtype=np.float64) for name in samples}
364
+ hits = []
365
+
366
+ # Place each hit
367
+ for sample_name, events in pattern.items():
368
+ if sample_name not in samples:
369
+ continue
370
+ sample = samples[sample_name]
371
+ for beat_pos, velocity in events:
372
+ # Humanize timing
373
+ timing_offset = rng.normal(0, humanize_timing_ms / 1000.0)
374
+ onset_time = beat_pos * beat_dur + timing_offset
375
+ onset_time = max(0, onset_time)
376
+
377
+ # Humanize velocity
378
+ vel = velocity * (1.0 + rng.uniform(-humanize_velocity, humanize_velocity))
379
+ vel = np.clip(vel, 0.05, 1.0)
380
+
381
+ # Place in timeline
382
+ start = int(onset_time * sr)
383
+ audio = sample.audio * vel
384
+ end = min(start + len(audio), total_samples)
385
+ actual_len = end - start
386
+
387
+ if actual_len <= 0:
388
+ continue
389
+
390
+ drums_mix[start:end] += audio[:actual_len]
391
+ per_sample[sample_name][start:end] += audio[:actual_len]
392
+
393
+ hits.append(PlacedHit(
394
+ sample_name=sample_name,
395
+ onset_time=onset_time,
396
+ velocity=vel,
397
+ audio=audio[:actual_len],
398
+ sr=sr,
399
+ ))
400
+
401
+ # Optional bass line (tests Demucs separation)
402
+ bass_track = np.zeros(total_samples, dtype=np.float64)
403
+ if add_bass:
404
+ if bass_notes is None:
405
+ # Simple root note bass on beat 1 and 3
406
+ bass_notes_list = [(0, 65), (2, 65), (4, 82), (6, 82)]
407
+ # Repeat for all bars
408
+ n_bars = int(max_beat / 4) + 1
409
+ bass_notes = []
410
+ for bar in range(n_bars):
411
+ for beat, freq in bass_notes_list:
412
+ if beat + bar * 4 <= max_beat:
413
+ bass_notes.append((beat + bar * 4, freq))
414
+
415
+ for beat_pos, freq in bass_notes:
416
+ onset = beat_pos * beat_dur
417
+ start = int(onset * sr)
418
+ bass = synthesize_bass_note(sr, freq=freq, duration=beat_dur * 2)
419
+ end = min(start + len(bass), total_samples)
420
+ bass_track[start:end] += bass[:end - start]
421
+
422
+ # Add room noise
423
+ noise = rng.randn(total_samples) * (10 ** (room_noise_db / 20))
424
+
425
+ # Final mix
426
+ full_mix = drums_mix + bass_track + noise
427
+
428
+ # Normalize
429
+ peak = np.abs(full_mix).max()
430
+ if peak > 0:
431
+ scale = 0.9 / peak
432
+ full_mix *= scale
433
+ drums_mix *= scale
434
+ for name in per_sample:
435
+ per_sample[name] *= scale
436
+
437
+ return SyntheticSong(
438
+ mix=full_mix.astype(np.float32),
439
+ drums_only=drums_mix.astype(np.float32),
440
+ sr=sr,
441
+ bpm=bpm,
442
+ duration=total_dur,
443
+ samples=samples,
444
+ hits=hits,
445
+ per_sample_stems=per_sample,
446
+ pattern_description=str({k: len(v) for k, v in pattern.items()}),
447
+ )
448
+
449
+
450
+ def generate_test_song(
451
+ pattern_name: str = 'rock',
452
+ bars: int = 4,
453
+ bpm: float = 120.0,
454
+ sr: int = 44100,
455
+ variation: str = 'medium',
456
+ add_bass: bool = True,
457
+ seed: int = 42,
458
+ ) -> SyntheticSong:
459
+ """High-level function: generate a complete test song with ground truth."""
460
+ samples = create_sample_set(sr=sr, seed=seed, variation=variation)
461
+ pattern_fn = PATTERNS.get(pattern_name, generate_basic_rock)
462
+ pattern = pattern_fn(bars=bars)
463
+
464
+ return assemble_song(
465
+ samples=samples,
466
+ pattern=pattern,
467
+ sr=sr,
468
+ bpm=bpm,
469
+ add_bass=add_bass,
470
+ seed=seed,
471
+ )
472
+
473
+
474
+ def save_ground_truth(song: SyntheticSong, output_dir: str):
475
+ """Save all ground truth data for evaluation."""
476
+ import os
477
+ os.makedirs(output_dir, exist_ok=True)
478
+ os.makedirs(os.path.join(output_dir, 'gt_samples'), exist_ok=True)
479
+ os.makedirs(os.path.join(output_dir, 'gt_stems'), exist_ok=True)
480
+
481
+ # Save mix and drums
482
+ sf.write(os.path.join(output_dir, 'mix.wav'), song.mix, song.sr, subtype='PCM_24')
483
+ sf.write(os.path.join(output_dir, 'drums_only.wav'), song.drums_only, song.sr, subtype='PCM_24')
484
+
485
+ # Save individual samples
486
+ for name, sample in song.samples.items():
487
+ sf.write(os.path.join(output_dir, 'gt_samples', f'{name}.wav'),
488
+ sample.audio, sample.sr, subtype='PCM_24')
489
+
490
+ # Save per-sample stems
491
+ for name, stem in song.per_sample_stems.items():
492
+ sf.write(os.path.join(output_dir, 'gt_stems', f'{name}_stem.wav'),
493
+ stem, song.sr, subtype='PCM_24')
494
+
495
+ # Save hit map
496
+ hit_map = [
497
+ {'sample': h.sample_name, 'onset': h.onset_time, 'velocity': h.velocity}
498
+ for h in song.hits
499
+ ]
500
+ with open(os.path.join(output_dir, 'hit_map.json'), 'w') as f:
501
+ json.dump({
502
+ 'bpm': song.bpm,
503
+ 'duration': song.duration,
504
+ 'sr': song.sr,
505
+ 'pattern': song.pattern_description,
506
+ 'hits': hit_map,
507
+ }, f, indent=2)