| """ |
| Synthetic drum song generator with known ground-truth samples. |
| |
| Generates realistic drum patterns by: |
| 1. Synthesizing individual drum samples (kick, snare, hihat, etc.) with controlled parameters |
| 2. Placing them in musical patterns with velocity variation, timing humanization, and overlap |
| 3. Optionally mixing with bass/harmony for realistic Demucs testing |
| 4. Returning both the mix AND the isolated ground-truth samples + onset map |
| |
| This gives us a perfect evaluation setup: we know exactly which samples are where, |
| so we can compare extracted samples against the originals. |
| """ |
|
|
| import numpy as np |
| from scipy.signal import butter, filtfilt, lfilter |
| from dataclasses import dataclass, field |
| from typing import Optional |
| import soundfile as sf |
| import json |
|
|
|
|
@dataclass
class GroundTruthSample:
    """One reference drum sound from which synthetic songs are assembled.

    The clip length in seconds is available through the ``duration``
    property.
    """
    name: str  # identifier of the sample, e.g. 'kick'
    audio: np.ndarray  # waveform of the sample
    sr: int  # sample rate of ``audio``
    frequency_range: tuple  # presumably (low_hz, high_hz) -- confirm with callers

    @property
    def duration(self) -> float:
        """Return the clip length in seconds (sample count divided by ``sr``)."""
        n_frames = len(self.audio)
        return n_frames / self.sr
|
|
|
|
@dataclass
class PlacedHit:
    """Record of one drum hit that was written into the song timeline."""
    sample_name: str  # name of the sample that was triggered
    onset_time: float  # position of the hit in seconds from the start of the song
    velocity: float  # gain that was applied to the sample for this hit
    audio: np.ndarray  # the audio that was actually added for this hit
    sr: int  # sample rate of ``audio``
|
|
|
|
@dataclass
class SyntheticSong:
    """Rendered synthetic song bundled with everything needed to score it."""
    mix: np.ndarray  # full mixture (drums plus any bass and noise)
    drums_only: np.ndarray  # the drum hits alone
    sr: int  # sample rate shared by the audio arrays
    bpm: float  # tempo the pattern was rendered at
    duration: float  # song length in seconds
    samples: dict  # sample name -> GroundTruthSample used to build the song
    hits: list  # PlacedHit entries, one per rendered drum hit
    per_sample_stems: dict  # sample name -> track containing only that sample's hits
    pattern_description: str  # short textual summary of the pattern
|
|
|
|
| |
| |
| |
|
|
| def _butter_filter(y, sr, fmin=None, fmax=None, order=4): |
| """Apply butterworth bandpass/lowpass/highpass filter.""" |
| nyq = sr / 2 |
| if fmin and fmax: |
| b, a = butter(order, [fmin / nyq, fmax / nyq], btype='band') |
| elif fmin: |
| b, a = butter(order, fmin / nyq, btype='high') |
| elif fmax: |
| b, a = butter(order, fmax / nyq, btype='low') |
| else: |
| return y |
| return filtfilt(b, a, y) |
|
|
|
|
def synthesize_kick(sr: int = 44100, pitch: float = 60.0,
                    decay: float = 12.0, punch: float = 150.0,
                    duration: float = 0.25, noise_amount: float = 0.05,
                    rng: Optional[np.random.RandomState] = None) -> np.ndarray:
    """Synthesize a kick drum: sine sweep + sub thump + click.

    Args:
        sr: Sample rate in Hz.
        pitch: Frequency in Hz that the sweep settles on; also the sub tone.
        decay: Exponential decay rate of the swept body.
        punch: Frequency in Hz at which the sweep starts.
        duration: Length of the sample in seconds.
        noise_amount: Gain of the noise click at the attack.
        rng: Optional random source for the click noise. Any object with a
            ``standard_normal(size)`` method works (``RandomState`` or
            ``Generator``). When omitted, NumPy's global RNG is used, so
            the result is only reproducible if the caller seeded it.

    Returns:
        Mono float32 waveform normalized to a peak of about 0.95.
    """
    noise_source = np.random if rng is None else rng
    t = np.arange(int(sr * duration)) / sr

    # Instantaneous frequency glides from ``punch`` down towards ``pitch``;
    # integrating it (cumulative sum) gives the oscillator phase.
    freq = (punch - pitch) * np.exp(-30 * t) + pitch
    phase = 2 * np.pi * np.cumsum(freq / sr)
    body = np.sin(phase) * np.exp(-decay * t)

    # Steady tone at ``pitch`` underneath the sweep.
    sub = 0.4 * np.sin(2 * np.pi * pitch * t) * np.exp(-15 * t)

    # Very short noise burst, low-passed at 4 kHz, for the beater click.
    click = noise_amount * noise_source.standard_normal(len(t)) * np.exp(-200 * t)
    click = _butter_filter(click, sr, fmax=4000)

    kick = body + sub + click
    # The epsilon avoids division by zero for an all-zero signal.
    kick = kick / (np.abs(kick).max() + 1e-8) * 0.95
    return kick.astype(np.float32)
|
|
|
|
def synthesize_snare(sr: int = 44100, body_freq: float = 200.0,
                     noise_decay: float = 12.0, body_decay: float = 20.0,
                     duration: float = 0.25, wire_amount: float = 0.6,
                     rng: Optional[np.random.RandomState] = None) -> np.ndarray:
    """Synthesize a snare drum: body tone + noise wires.

    Args:
        sr: Sample rate in Hz.
        body_freq: Frequency in Hz of the tonal body.
        noise_decay: Exponential decay rate of the wire noise.
        body_decay: Exponential decay rate of the tonal body.
        duration: Length of the sample in seconds.
        wire_amount: Gain of the wire noise before filtering.
        rng: Optional random source for the wire noise. Any object with a
            ``standard_normal(size)`` method works (``RandomState`` or
            ``Generator``). When omitted, NumPy's global RNG is used, so
            the result is only reproducible if the caller seeded it.

    Returns:
        Mono float32 waveform normalized to a peak of about 0.95.
    """
    noise_source = np.random if rng is None else rng
    t = np.arange(int(sr * duration)) / sr

    body = np.sin(2 * np.pi * body_freq * t) * np.exp(-body_decay * t) * 0.5

    # Decaying white noise band-limited to 1-10 kHz stands in for the wires.
    noise = noise_source.standard_normal(len(t)) * np.exp(-noise_decay * t) * wire_amount
    noise = _butter_filter(noise, sr, fmin=1000, fmax=10000)

    # Quiet overtone at 2.7x the body frequency.
    ring = 0.15 * np.sin(2 * np.pi * body_freq * 2.7 * t) * np.exp(-25 * t)

    snare = body + noise + ring
    # The epsilon avoids division by zero for an all-zero signal.
    snare = snare / (np.abs(snare).max() + 1e-8) * 0.95
    return snare.astype(np.float32)
|
|
|
|
def synthesize_hihat(sr: int = 44100, is_open: bool = False,
                     brightness: float = 8000.0,
                     duration: Optional[float] = None,
                     rng: Optional[np.random.RandomState] = None) -> np.ndarray:
    """Synthesize a hi-hat: filtered noise with metallic overtones.

    Args:
        sr: Sample rate in Hz.
        is_open: ``True`` for an open hat (slow decay), ``False`` for closed.
        brightness: High-pass cutoff in Hz applied to the noise.
        duration: Length in seconds. Defaults to 0.4 for an open hat and
            0.08 for a closed one.
        rng: Optional random source for the noise. Any object with a
            ``standard_normal(size)`` method works (``RandomState`` or
            ``Generator``). When omitted, NumPy's global RNG is used, so
            the result is only reproducible if the caller seeded it.

    Returns:
        Mono float32 waveform normalized to a peak of about 0.7.
    """
    noise_source = np.random if rng is None else rng
    if duration is None:
        duration = 0.4 if is_open else 0.08
    t = np.arange(int(sr * duration)) / sr
    decay = 6.0 if is_open else 40.0

    noise = noise_source.standard_normal(len(t)) * np.exp(-decay * t)
    noise = _butter_filter(noise, sr, fmin=brightness)

    # Two fixed sine partials that die away slightly faster than the noise.
    metal = 0.2 * np.sin(2 * np.pi * 6500 * t) * np.exp(-(decay + 5) * t)
    metal += 0.1 * np.sin(2 * np.pi * 9200 * t) * np.exp(-(decay + 8) * t)

    hh = noise + metal
    # The epsilon avoids division by zero for an all-zero signal.
    hh = hh / (np.abs(hh).max() + 1e-8) * 0.7
    return hh.astype(np.float32)
|
|
|
|
def synthesize_tom(sr: int = 44100, pitch: float = 120.0,
                   decay: float = 10.0, duration: float = 0.3,
                   rng: Optional[np.random.RandomState] = None) -> np.ndarray:
    """Synthesize a tom: pitched body + slight noise.

    Args:
        sr: Sample rate in Hz.
        pitch: Frequency in Hz that the body settles on.
        decay: Exponential decay rate of the body.
        duration: Length of the sample in seconds.
        rng: Optional random source for the noise layer. Any object with a
            ``standard_normal(size)`` method works (``RandomState`` or
            ``Generator``). When omitted, NumPy's global RNG is used, so
            the result is only reproducible if the caller seeded it.

    Returns:
        Mono float32 waveform normalized to a peak of about 0.9.
    """
    noise_source = np.random if rng is None else rng
    t = np.arange(int(sr * duration)) / sr

    # Frequency starts at 2.5x ``pitch`` and glides down towards ``pitch``;
    # the cumulative sum turns it into oscillator phase.
    freq = pitch * 1.5 * np.exp(-8 * t) + pitch
    phase = 2 * np.pi * np.cumsum(freq / sr)
    body = np.sin(phase) * np.exp(-decay * t)

    # Quiet, fast-decaying noise band-limited to 200 Hz - 3 kHz.
    noise = 0.1 * noise_source.standard_normal(len(t)) * np.exp(-20 * t)
    noise = _butter_filter(noise, sr, fmin=200, fmax=3000)

    tom = body + noise
    # The epsilon avoids division by zero for an all-zero signal.
    tom = tom / (np.abs(tom).max() + 1e-8) * 0.9
    return tom.astype(np.float32)
|
|
|
|
def synthesize_cymbal(sr: int = 44100, duration: float = 1.5,
                      rng: Optional[np.random.RandomState] = None) -> np.ndarray:
    """Synthesize a crash/ride cymbal: dense metallic noise.

    Args:
        sr: Sample rate in Hz.
        duration: Length of the sample in seconds.
        rng: Optional random source for the noise. Any object with a
            ``standard_normal(size)`` method works (``RandomState`` or
            ``Generator``). When omitted, NumPy's global RNG is used, so
            the result is only reproducible if the caller seeded it.

    Returns:
        Mono float32 waveform normalized to a peak of about 0.6.
    """
    noise_source = np.random if rng is None else rng
    t = np.arange(int(sr * duration)) / sr

    # Slowly decaying noise, high-passed at 3 kHz.
    noise = noise_source.standard_normal(len(t)) * np.exp(-3 * t)
    noise = _butter_filter(noise, sr, fmin=3000)

    # Five fixed sine partials; each later one is quieter and decays faster.
    partials = sum(
        (0.15 / (i + 1)) * np.sin(2 * np.pi * f * t) * np.exp(-(2 + i) * t)
        for i, f in enumerate([4200, 5800, 7300, 9100, 11500])
    )

    cym = noise + partials
    # The epsilon avoids division by zero for an all-zero signal.
    cym = cym / (np.abs(cym).max() + 1e-8) * 0.6
    return cym.astype(np.float32)
|
|
|
|
def synthesize_bass_note(sr: int = 44100, freq: float = 65.0,
                         duration: float = 0.5) -> np.ndarray:
    """Render one bass note to layer under the drums (tests Demucs separation).

    The note is a fundamental plus its 2nd and 3rd harmonics, shaped by a
    short linear attack and an exponential decay, low-passed at 500 Hz and
    normalized to a peak of about 0.5.

    Args:
        sr: Sample rate in Hz.
        freq: Fundamental frequency in Hz.
        duration: Length of the note in seconds.

    Returns:
        Mono float32 waveform.
    """
    n_frames = int(sr * duration)
    t = np.arange(n_frames) / sr
    omega = 2 * np.pi * freq

    fundamental = np.sin(omega * t)
    second = 0.5 * np.sin(omega * 2 * t)
    third = 0.25 * np.sin(omega * 3 * t)
    tone = fundamental + second + third

    # The attack ramps from 0 to 1 over the first 20 ms (t * 50 reaches 1).
    attack = np.minimum(t * 50, 1.0)
    release = np.exp(-3 * t)
    shaped = tone * (attack * release)

    filtered = _butter_filter(shaped, sr, fmax=500)
    # The epsilon avoids division by zero for an all-zero signal.
    peak = np.abs(filtered).max() + 1e-8
    return (filtered / peak * 0.5).astype(np.float32)
|
|
|
|
| |
| |
| |
|
|
def create_sample_set(sr: int = 44100, seed: int = 42,
                      variation: str = "medium") -> dict:
    """Create a set of ground-truth drum samples with parametric variation.

    The same ``seed`` always yields the same samples: it drives both the
    parameter jitter and the noise used inside the synthesizers.

    Args:
        sr: Sample rate in Hz for every generated sample.
        seed: Seed for the parameter jitter and the synthesis noise.
        variation: Jitter applied to the nominal synthesis parameters:
            "none" (nominal values), "low" (+/-5%), "medium" (+/-15%) or
            "high" (+/-30%).

    Returns:
        Dict mapping sample name ('kick', 'snare', 'hihat_closed',
        'hihat_open', 'tom', 'cymbal') to its ``GroundTruthSample``.

    Raises:
        KeyError: If ``variation`` is not one of the supported names.
    """
    rng = np.random.RandomState(seed)
    var_scale = {"none": 0.0, "low": 0.05, "medium": 0.15, "high": 0.3}[variation]

    def vary(val):
        # Scale ``val`` by a random factor within +/- ``var_scale``.
        return val * (1.0 + rng.uniform(-var_scale, var_scale))

    # By default the synthesize_* helpers draw their noise from NumPy's
    # global RNG. Seed it for the duration of this call so the samples are
    # reproducible, then restore the caller's state so nothing leaks out.
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        samples = {
            'kick': GroundTruthSample(
                name='kick',
                audio=synthesize_kick(sr, pitch=vary(60), decay=vary(12), punch=vary(150)),
                sr=sr,
                frequency_range=(30, 300),
            ),
            'snare': GroundTruthSample(
                name='snare',
                audio=synthesize_snare(sr, body_freq=vary(200), noise_decay=vary(12)),
                sr=sr,
                frequency_range=(100, 8000),
            ),
            'hihat_closed': GroundTruthSample(
                name='hihat_closed',
                audio=synthesize_hihat(sr, is_open=False, brightness=vary(8000)),
                sr=sr,
                frequency_range=(3000, 20000),
            ),
            'hihat_open': GroundTruthSample(
                name='hihat_open',
                audio=synthesize_hihat(sr, is_open=True, brightness=vary(7000)),
                sr=sr,
                frequency_range=(2000, 20000),
            ),
            'tom': GroundTruthSample(
                name='tom',
                audio=synthesize_tom(sr, pitch=vary(120), decay=vary(10)),
                sr=sr,
                frequency_range=(50, 2000),
            ),
            'cymbal': GroundTruthSample(
                name='cymbal',
                audio=synthesize_cymbal(sr),
                sr=sr,
                frequency_range=(2000, 20000),
            ),
        }
    finally:
        np.random.set_state(saved_state)
    return samples
|
|
|
|
| |
| |
| |
|
|
def generate_basic_rock(bars: int = 4) -> dict:
    """Build a straight rock beat in 4/4.

    Each bar has kick on beats 0 and 2, snare on beats 1 and 3, closed
    hi-hat on every eighth note (louder on the beat), and an open hi-hat on
    the last eighth.

    Args:
        bars: Number of 4-beat bars to generate.

    Returns:
        {sample_name: [(beat_position, velocity), ...]}
    """
    kicks, snares, closed_hats, open_hats = [], [], [], []

    for bar_idx in range(bars):
        base = 4 * bar_idx

        kicks += [(base + 0, 0.9), (base + 2, 0.85)]
        snares += [(base + 1, 0.85), (base + 3, 0.9)]

        # Eighth notes: on-beat hits are accented over the off-beats.
        closed_hats += [
            (base + step * 0.5, 0.4 if step % 2 else 0.6)
            for step in range(8)
        ]

        open_hats.append((base + 3.5, 0.55))

    return {
        'kick': kicks,
        'snare': snares,
        'hihat_closed': closed_hats,
        'hihat_open': open_hats,
    }
|
|
|
|
def generate_funk_pattern(bars: int = 4) -> dict:
    """Build a syncopated funk groove with ghost notes and a closing tom fill.

    Each bar has four kicks (two syncopated), two main snare hits plus two
    quiet ghost notes, and closed hi-hat on every sixteenth with an accent
    on each beat. The last bar ends with a four-hit tom fill on beat 3.

    Args:
        bars: Number of 4-beat bars to generate.

    Returns:
        {sample_name: [(beat_position, velocity), ...]}
    """
    kicks, snares, hats, toms = [], [], [], []

    for bar_idx in range(bars):
        base = 4 * bar_idx

        kicks += [
            (base + 0, 0.95), (base + 0.75, 0.6), (base + 2, 0.9), (base + 2.5, 0.7)
        ]

        # Velocities 0.3 and 0.25 are the ghost notes.
        snares += [
            (base + 1, 0.9), (base + 1.75, 0.3), (base + 3, 0.85), (base + 3.25, 0.25)
        ]

        for tick in range(16):
            accent = 0.2 if tick % 4 == 0 else 0.0
            hats.append((base + tick * 0.25, 0.5 + accent))

    # The tom fill only appears in the final bar (and not at all for bars < 1).
    if bars >= 1:
        fill_base = 4 * (bars - 1)
        toms += [
            (fill_base + 3, 0.8), (fill_base + 3.25, 0.75),
            (fill_base + 3.5, 0.85), (fill_base + 3.75, 0.9)
        ]

    return {
        'kick': kicks,
        'snare': snares,
        'hihat_closed': hats,
        'hihat_open': [],
        'tom': toms,
    }
|
|
|
|
def generate_halftime_pattern(bars: int = 4) -> dict:
    """Build a half-time / trap-influenced groove.

    Each bar has a kick on beat 0 (plus an extra kick half a beat later in
    odd-numbered bars), a single snare on beat 2, and closed hi-hat on every
    sixteenth with alternating velocity. A cymbal marks the first downbeat.

    Args:
        bars: Number of 4-beat bars to generate.

    Returns:
        {sample_name: [(beat_position, velocity), ...]}
    """
    kicks, snares, hats = [], [], []

    for bar_idx in range(bars):
        base = 4 * bar_idx

        kicks.append((base + 0, 0.95))
        if bar_idx % 2:
            kicks.append((base + 0.5, 0.7))

        snares.append((base + 2, 0.9))

        for tick in range(16):
            boost = 0.15 if tick % 2 == 0 else 0.0
            hats.append((base + tick * 0.25, 0.3 + boost))

    # A single cymbal hit on the very first beat, if there is any bar at all.
    cymbals = [(0, 0.7)] if bars >= 1 else []

    return {
        'kick': kicks,
        'snare': snares,
        'hihat_closed': hats,
        'cymbal': cymbals,
    }
|
|
|
|
# Registry of available drum patterns: pattern name -> generator function.
# Each generator accepts ``bars`` and returns
# {sample_name: [(beat_position, velocity), ...]}.
PATTERNS = {
    'rock': generate_basic_rock,
    'funk': generate_funk_pattern,
    'halftime': generate_halftime_pattern,
}
|
|
|
|
| |
| |
| |
|
|
def _default_bass_line(max_beat: float) -> list:
    """Tile the built-in two-bar bass riff up to ``max_beat`` (inclusive)."""
    riff = [(0, 65), (2, 65), (4, 82), (6, 82)]
    # The riff spans two 4-beat bars, so it must repeat every 8 beats;
    # repeating it every 4 beats would stack 65 Hz and 82 Hz notes on the
    # same beat.
    riff_length_beats = 8
    notes = []
    cycle_start = 0
    while cycle_start <= max_beat:
        notes.extend(
            (cycle_start + beat, freq)
            for beat, freq in riff
            if cycle_start + beat <= max_beat
        )
        cycle_start += riff_length_beats
    return notes


def assemble_song(
    samples: dict,
    pattern: dict,
    sr: int = 44100,
    bpm: float = 120.0,
    humanize_timing_ms: float = 5.0,
    humanize_velocity: float = 0.05,
    add_bass: bool = True,
    bass_notes: list = None,
    room_noise_db: float = -60.0,
    seed: int = 42,
) -> SyntheticSong:
    """Assemble a complete synthetic song from samples and pattern.

    Args:
        samples: Dict of sample name -> object exposing an ``audio`` array.
        pattern: Dict of sample name -> [(beat_position, velocity), ...].
            Names that are missing from ``samples`` are skipped.
        sr: Sample rate in Hz of the rendered song.
        bpm: Tempo in beats per minute.
        humanize_timing_ms: Standard deviation, in milliseconds, of the
            Gaussian jitter added to every drum onset.
        humanize_velocity: Maximum relative velocity jitter (uniform, +/-).
        add_bass: Whether to layer a synthesized bass line into the mix.
        bass_notes: Optional [(beat_position, freq_hz), ...]. When ``None``
            a built-in two-bar riff is repeated for the whole song.
        room_noise_db: Level of the white-noise floor in dB relative to 1.0.
        seed: Seed for the humanization and the noise floor.

    Returns:
        A ``SyntheticSong``. ``mix``, ``drums_only`` and the per-sample stems
        share one normalization gain that brings the mix peak to 0.9. The
        ``audio`` stored on each hit is velocity-scaled but NOT normalized,
        and the stems are left as float64.
    """
    rng = np.random.RandomState(seed)
    beat_dur = 60.0 / bpm

    # Song length: last pattern beat plus two beats of tail for the decays.
    all_beats = []
    for events in pattern.values():
        if events:
            all_beats.extend(e[0] for e in events)
    max_beat = max(all_beats) if all_beats else 4
    total_dur = (max_beat + 2) * beat_dur
    total_samples = int(total_dur * sr)

    drums_mix = np.zeros(total_samples, dtype=np.float64)
    per_sample = {name: np.zeros(total_samples, dtype=np.float64) for name in samples}
    hits = []

    for sample_name, events in pattern.items():
        if sample_name not in samples:
            continue
        sample = samples[sample_name]
        for beat_pos, velocity in events:
            # Humanize timing; onsets are never allowed before time zero.
            timing_offset = rng.normal(0, humanize_timing_ms / 1000.0)
            onset_time = max(0, beat_pos * beat_dur + timing_offset)

            # Humanize velocity and keep it within a sane gain range.
            vel = velocity * (1.0 + rng.uniform(-humanize_velocity, humanize_velocity))
            vel = np.clip(vel, 0.05, 1.0)

            start = int(onset_time * sr)
            audio = sample.audio * vel
            end = min(start + len(audio), total_samples)
            actual_len = end - start
            if actual_len <= 0:
                # The hit starts at or beyond the end of the song.
                continue

            drums_mix[start:end] += audio[:actual_len]
            per_sample[sample_name][start:end] += audio[:actual_len]

            hits.append(PlacedHit(
                sample_name=sample_name,
                onset_time=onset_time,
                velocity=vel,
                audio=audio[:actual_len],
                sr=sr,
            ))

    bass_track = np.zeros(total_samples, dtype=np.float64)
    if add_bass:
        if bass_notes is None:
            bass_notes = _default_bass_line(max_beat)

        for beat_pos, freq in bass_notes:
            # Clamp to time zero, as is done for the drum onsets.
            start = int(max(0.0, beat_pos * beat_dur) * sr)
            if start >= total_samples:
                # A note beyond the end of the song would otherwise produce
                # a negative slice length and a shape-mismatch error.
                continue
            bass = synthesize_bass_note(sr, freq=freq, duration=beat_dur * 2)
            end = min(start + len(bass), total_samples)
            bass_track[start:end] += bass[:end - start]

    # White-noise floor at ``room_noise_db`` relative to 1.0.
    noise = rng.randn(total_samples) * (10 ** (room_noise_db / 20))

    full_mix = drums_mix + bass_track + noise

    # One shared gain keeps the stems aligned in level with the mix.
    peak = np.abs(full_mix).max()
    if peak > 0:
        scale = 0.9 / peak
        full_mix *= scale
        drums_mix *= scale
        for name in per_sample:
            per_sample[name] *= scale

    return SyntheticSong(
        mix=full_mix.astype(np.float32),
        drums_only=drums_mix.astype(np.float32),
        sr=sr,
        bpm=bpm,
        duration=total_dur,
        samples=samples,
        hits=hits,
        per_sample_stems=per_sample,
        pattern_description=str({k: len(v) for k, v in pattern.items()}),
    )
|
|
|
|
def generate_test_song(
    pattern_name: str = 'rock',
    bars: int = 4,
    bpm: float = 120.0,
    sr: int = 44100,
    variation: str = 'medium',
    add_bass: bool = True,
    seed: int = 42,
) -> SyntheticSong:
    """One-call entry point: build samples, pick a pattern, render the song.

    Args:
        pattern_name: Key into ``PATTERNS``. An unknown name falls back to
            the basic rock pattern.
        bars: Number of bars handed to the pattern generator.
        bpm: Tempo in beats per minute.
        sr: Sample rate in Hz.
        variation: Variation level forwarded to ``create_sample_set``.
        add_bass: Whether a bass line is mixed in.
        seed: Seed used for both sample creation and song assembly.

    Returns:
        The assembled ``SyntheticSong`` with its ground truth.
    """
    sample_set = create_sample_set(sr=sr, seed=seed, variation=variation)

    if pattern_name in PATTERNS:
        build_pattern = PATTERNS[pattern_name]
    else:
        build_pattern = generate_basic_rock
    drum_pattern = build_pattern(bars=bars)

    song = assemble_song(
        samples=sample_set,
        pattern=drum_pattern,
        sr=sr,
        bpm=bpm,
        add_bass=add_bass,
        seed=seed,
    )
    return song
|
|
|
|
def save_ground_truth(song: SyntheticSong, output_dir: str):
    """Write a song and all of its ground truth to ``output_dir``.

    Files written (audio as 24-bit PCM WAV):
        mix.wav, drums_only.wav   -- the rendered mixes
        gt_samples/<name>.wav     -- every ground-truth sample
        gt_stems/<name>_stem.wav  -- every per-sample stem
        hit_map.json              -- bpm, duration, sr, pattern and hit list

    Args:
        song: The song whose data should be saved.
        output_dir: Target directory; created if it does not exist.
    """
    import os

    samples_dir = os.path.join(output_dir, 'gt_samples')
    stems_dir = os.path.join(output_dir, 'gt_stems')
    for directory in (output_dir, samples_dir, stems_dir):
        os.makedirs(directory, exist_ok=True)

    def write_wav(path, audio, rate):
        # All audio in the ground-truth bundle uses the same PCM format.
        sf.write(path, audio, rate, subtype='PCM_24')

    write_wav(os.path.join(output_dir, 'mix.wav'), song.mix, song.sr)
    write_wav(os.path.join(output_dir, 'drums_only.wav'), song.drums_only, song.sr)

    for sample_name, gt_sample in song.samples.items():
        write_wav(os.path.join(samples_dir, f'{sample_name}.wav'),
                  gt_sample.audio, gt_sample.sr)

    for stem_name, stem_audio in song.per_sample_stems.items():
        write_wav(os.path.join(stems_dir, f'{stem_name}_stem.wav'),
                  stem_audio, song.sr)

    hit_records = []
    for hit in song.hits:
        hit_records.append({
            'sample': hit.sample_name,
            'onset': hit.onset_time,
            'velocity': hit.velocity,
        })

    metadata = {
        'bpm': song.bpm,
        'duration': song.duration,
        'sr': song.sr,
        'pattern': song.pattern_description,
        'hits': hit_records,
    }
    with open(os.path.join(output_dir, 'hit_map.json'), 'w') as f:
        json.dump(metadata, f, indent=2)
|
|