"""
Synthetic drum song generator with known ground-truth samples.

Generates realistic drum patterns by:
1. Synthesizing individual drum samples (kick, snare, hihat, etc.) with
   controlled parameters.
2. Placing them in musical patterns with velocity variation, timing
   humanization, and overlap.
3. Optionally mixing in a bass line for realistic Demucs testing.
4. Returning both the mix AND the isolated ground-truth samples + onset map.

This gives a precise evaluation setup: we know exactly which samples are
where, so extracted samples can be compared against the originals.
"""

import numpy as np
from scipy.signal import butter, filtfilt, lfilter
from dataclasses import dataclass, field
from typing import Optional

try:
    import soundfile as sf
except ImportError:  # only save_ground_truth() needs soundfile
    sf = None
import json


@dataclass
class GroundTruthSample:
    """A ground-truth drum sample used to build the synthetic song."""

    name: str  # e.g. "kick", "snare"
    audio: np.ndarray  # the clean one-shot sample
    sr: int
    frequency_range: tuple  # (low_hz, high_hz) primary energy band

    @property
    def duration(self) -> float:
        """Length of the one-shot in seconds."""
        return len(self.audio) / self.sr


@dataclass
class PlacedHit:
    """A single hit placed in the timeline."""

    sample_name: str
    onset_time: float  # seconds
    velocity: float  # 0-1 amplitude multiplier
    # The audio placed in the timeline (velocity applied). Not rescaled by
    # the final peak normalization in assemble_song().
    audio: np.ndarray
    sr: int


@dataclass
class SyntheticSong:
    """A complete synthetic drum song with ground truth."""

    mix: np.ndarray  # full mix audio
    drums_only: np.ndarray  # drums-only mix
    sr: int
    bpm: float
    duration: float
    samples: dict  # {name: GroundTruthSample}
    hits: list  # [PlacedHit, ...]
    per_sample_stems: dict  # {name: np.ndarray} isolated stems
    pattern_description: str


# ─────────────────────────────────────────────────────────────────────────────
# Sample synthesis (parametric drum sounds)
# ─────────────────────────────────────────────────────────────────────────────

def _butter_filter(y, sr, fmin=None, fmax=None, order=4):
    """Apply a zero-phase Butterworth band-, low- or high-pass filter.

    Args:
        y: 1-D signal.
        sr: Sample rate in Hz.
        fmin: High-pass cutoff in Hz (None/0 = no high-pass).
        fmax: Low-pass cutoff in Hz (None/0 = no low-pass).
        order: Butterworth order passed to ``scipy.signal.butter``.

    Returns:
        The filtered signal (``y`` itself if no cutoff is given).

    Cutoffs at or above Nyquist are handled here instead of letting
    ``butter`` raise: a low-pass there removes nothing and is dropped, while a
    high-pass there removes the whole band, so zeros are returned. This lets
    the synthesizers run at sample rates below 44.1 kHz.
    """
    nyq = sr / 2
    if fmax and fmax >= nyq:
        fmax = None  # no content above Nyquist to remove
    if fmin and fmin >= nyq:
        return np.zeros_like(y)  # pass band lies entirely above Nyquist
    if fmin and fmax:
        b, a = butter(order, [fmin / nyq, fmax / nyq], btype='band')
    elif fmin:
        b, a = butter(order, fmin / nyq, btype='high')
    elif fmax:
        b, a = butter(order, fmax / nyq, btype='low')
    else:
        return y
    return filtfilt(b, a, y)


def _randn(rng, n):
    """Draw ``n`` standard-normal values from ``rng`` (global NumPy RNG if None)."""
    return (np.random if rng is None else rng).randn(n)


def synthesize_kick(sr: int = 44100, pitch: float = 60.0, decay: float = 12.0,
                    punch: float = 150.0, duration: float = 0.25,
                    noise_amount: float = 0.05,
                    rng: Optional[np.random.RandomState] = None) -> np.ndarray:
    """Synthesize a kick drum: sine sweep + sub thump + click.

    ``rng`` supplies the click noise; None uses the global NumPy RNG.
    Returns float32 audio peak-normalized to 0.95.
    """
    t = np.arange(int(sr * duration)) / sr
    # Frequency sweep: punch Hz -> pitch Hz
    freq = (punch - pitch) * np.exp(-30 * t) + pitch
    phase = 2 * np.pi * np.cumsum(freq / sr)
    body = np.sin(phase) * np.exp(-decay * t)
    # Sub thump
    sub = 0.4 * np.sin(2 * np.pi * pitch * t) * np.exp(-15 * t)
    # Click transient
    click = noise_amount * _randn(rng, len(t)) * np.exp(-200 * t)
    click = _butter_filter(click, sr, fmax=4000)
    kick = body + sub + click
    kick = kick / (np.abs(kick).max() + 1e-8) * 0.95
    return kick.astype(np.float32)


def synthesize_snare(sr: int = 44100, body_freq: float = 200.0,
                     noise_decay: float = 12.0, body_decay: float = 20.0,
                     duration: float = 0.25, wire_amount: float = 0.6,
                     rng: Optional[np.random.RandomState] = None) -> np.ndarray:
    """Synthesize a snare drum: body tone + noise wires.

    ``rng`` supplies the wire noise; None uses the global NumPy RNG.
    Returns float32 audio peak-normalized to 0.95.
    """
    t = np.arange(int(sr * duration)) / sr
    # Body
    body = np.sin(2 * np.pi * body_freq * t) * np.exp(-body_decay * t) * 0.5
    # Snare wires (filtered noise)
    noise = _randn(rng, len(t)) * np.exp(-noise_decay * t) * wire_amount
    noise = _butter_filter(noise, sr, fmin=1000, fmax=10000)
    # Overtone ring
    ring = 0.15 * np.sin(2 * np.pi * body_freq * 2.7 * t) * np.exp(-25 * t)
    snare = body + noise + ring
    snare = snare / (np.abs(snare).max() + 1e-8) * 0.95
    return snare.astype(np.float32)


def synthesize_hihat(sr: int = 44100, is_open: bool = False,
                     brightness: float = 8000.0,
                     duration: Optional[float] = None,
                     rng: Optional[np.random.RandomState] = None) -> np.ndarray:
    """Synthesize a hi-hat: filtered noise with metallic overtones.

    ``duration`` defaults to 0.4 s (open) or 0.08 s (closed). ``rng``
    supplies the noise; None uses the global NumPy RNG. Returns float32
    audio peak-normalized to 0.7.
    """
    if duration is None:
        duration = 0.4 if is_open else 0.08
    t = np.arange(int(sr * duration)) / sr
    decay = 6.0 if is_open else 40.0
    noise = _randn(rng, len(t)) * np.exp(-decay * t)
    noise = _butter_filter(noise, sr, fmin=brightness)
    # Metallic overtones
    metal = 0.2 * np.sin(2 * np.pi * 6500 * t) * np.exp(-(decay + 5) * t)
    metal += 0.1 * np.sin(2 * np.pi * 9200 * t) * np.exp(-(decay + 8) * t)
    hh = noise + metal
    hh = hh / (np.abs(hh).max() + 1e-8) * 0.7
    return hh.astype(np.float32)


def synthesize_tom(sr: int = 44100, pitch: float = 120.0, decay: float = 10.0,
                   duration: float = 0.3,
                   rng: Optional[np.random.RandomState] = None) -> np.ndarray:
    """Synthesize a tom: pitched body + slight noise.

    ``rng`` supplies the noise; None uses the global NumPy RNG.
    Returns float32 audio peak-normalized to 0.9.
    """
    t = np.arange(int(sr * duration)) / sr
    freq = pitch * 1.5 * np.exp(-8 * t) + pitch
    phase = 2 * np.pi * np.cumsum(freq / sr)
    body = np.sin(phase) * np.exp(-decay * t)
    noise = 0.1 * _randn(rng, len(t)) * np.exp(-20 * t)
    noise = _butter_filter(noise, sr, fmin=200, fmax=3000)
    tom = body + noise
    tom = tom / (np.abs(tom).max() + 1e-8) * 0.9
    return tom.astype(np.float32)


def synthesize_cymbal(sr: int = 44100, duration: float = 1.5,
                      rng: Optional[np.random.RandomState] = None) -> np.ndarray:
    """Synthesize a crash/ride cymbal: dense metallic noise.

    ``rng`` supplies the noise; None uses the global NumPy RNG.
    Returns float32 audio peak-normalized to 0.6.
    """
    t = np.arange(int(sr * duration)) / sr
    noise = _randn(rng, len(t)) * np.exp(-3 * t)
    noise = _butter_filter(noise, sr, fmin=3000)
    # Multiple metallic partials
    partials = sum(
        (0.15 / (i + 1)) * np.sin(2 * np.pi * f * t) * np.exp(-(2 + i) * t)
        for i, f in enumerate([4200, 5800, 7300, 9100, 11500])
    )
    cym = noise + partials
    cym = cym / (np.abs(cym).max() + 1e-8) * 0.6
    return cym.astype(np.float32)


def synthesize_bass_note(sr: int = 44100, freq: float = 65.0,
                         duration: float = 0.5) -> np.ndarray:
    """Synthesize a bass note for adding to the mix (tests Demucs separation).

    Three harmonics with a quick attack / exponential decay, low-passed at
    500 Hz. Deterministic. Returns float32 audio peak-normalized to 0.5.
    """
    t = np.arange(int(sr * duration)) / sr
    # Sawtooth-ish bass with harmonics
    wave = (np.sin(2 * np.pi * freq * t)
            + 0.5 * np.sin(2 * np.pi * freq * 2 * t)
            + 0.25 * np.sin(2 * np.pi * freq * 3 * t))
    envelope = np.minimum(t * 50, 1.0) * np.exp(-3 * t)  # quick attack, slow decay
    bass = wave * envelope
    bass = _butter_filter(bass, sr, fmax=500)
    bass = bass / (np.abs(bass).max() + 1e-8) * 0.5
    return bass.astype(np.float32)


# ─────────────────────────────────────────────────────────────────────────────
# Sample set creation with controlled variation
# ─────────────────────────────────────────────────────────────────────────────

def create_sample_set(sr: int = 44100, seed: int = 42,
                      variation: str = "medium") -> dict:
    """Create a set of ground-truth drum samples with parametric variation.

    Args:
        sr: Sample rate in Hz.
        seed: Seed for both the parameter variation and the noise in each
            sample; the same seed yields identical samples.
        variation: "none" (identical hits), "low", "medium", "high".

    Returns:
        {name: GroundTruthSample} for kick, snare, hihat_closed, hihat_open,
        tom and cymbal.
    """
    rng = np.random.RandomState(seed)
    # The noise gets its own stream so the parameter draws made by vary() are
    # unaffected, while the noise is still reproducible for a given seed.
    noise_rng = np.random.RandomState(
        None if seed is None else (seed + 1) % (2 ** 32))
    # Base parameters with per-variation noise
    var_scale = {"none": 0.0, "low": 0.05, "medium": 0.15, "high": 0.3}[variation]

    def vary(val, amount=None):
        a = amount if amount is not None else var_scale
        return val * (1.0 + rng.uniform(-a, a))

    samples = {
        'kick': GroundTruthSample(
            name='kick',
            audio=synthesize_kick(sr, pitch=vary(60), decay=vary(12),
                                  punch=vary(150), rng=noise_rng),
            sr=sr, frequency_range=(30, 300),
        ),
        'snare': GroundTruthSample(
            name='snare',
            audio=synthesize_snare(sr, body_freq=vary(200),
                                   noise_decay=vary(12), rng=noise_rng),
            sr=sr, frequency_range=(100, 8000),
        ),
        'hihat_closed': GroundTruthSample(
            name='hihat_closed',
            audio=synthesize_hihat(sr, is_open=False, brightness=vary(8000),
                                   rng=noise_rng),
            sr=sr, frequency_range=(3000, 20000),
        ),
        'hihat_open': GroundTruthSample(
            name='hihat_open',
            audio=synthesize_hihat(sr, is_open=True, brightness=vary(7000),
                                   rng=noise_rng),
            sr=sr, frequency_range=(2000, 20000),
        ),
        'tom': GroundTruthSample(
            name='tom',
            audio=synthesize_tom(sr, pitch=vary(120), decay=vary(10),
                                 rng=noise_rng),
            sr=sr, frequency_range=(50, 2000),
        ),
        'cymbal': GroundTruthSample(
            name='cymbal',
            audio=synthesize_cymbal(sr, rng=noise_rng),
            sr=sr, frequency_range=(2000, 20000),
        ),
    }
    return samples


# ─────────────────────────────────────────────────────────────────────────────
# Pattern generation
# ─────────────────────────────────────────────────────────────────────────────

def generate_basic_rock(bars: int = 4) -> dict:
    """Basic rock pattern.

    Returns {sample_name: [(beat_position, velocity), ...]}.
    """
    pattern = {
        'kick': [], 'snare': [], 'hihat_closed': [], 'hihat_open': [],
    }
    for bar in range(bars):
        offset = bar * 4  # 4 beats per bar
        # Kick on 1 and 3
        pattern['kick'].extend([(offset + 0, 0.9), (offset + 2, 0.85)])
        # Snare on 2 and 4
        pattern['snare'].extend([(offset + 1, 0.85), (offset + 3, 0.9)])
        # HH on every 8th note
        for eighth in range(8):
            vel = 0.6 if eighth % 2 == 0 else 0.4  # accented downbeats
            pattern['hihat_closed'].append((offset + eighth * 0.5, vel))
        # Open hat on "& of 4"
        pattern['hihat_open'].append((offset + 3.5, 0.55))
    return pattern


def generate_funk_pattern(bars: int = 4) -> dict:
    """Funky syncopated pattern with ghost notes and a tom fill in the last bar.

    Returns {sample_name: [(beat_position, velocity), ...]}.
    """
    pattern = {
        'kick': [], 'snare': [], 'hihat_closed': [], 'hihat_open': [], 'tom': [],
    }
    for bar in range(bars):
        o = bar * 4
        # Syncopated kick
        pattern['kick'].extend([
            (o + 0, 0.95), (o + 0.75, 0.6), (o + 2, 0.9), (o + 2.5, 0.7)
        ])
        # Snare with ghost notes
        pattern['snare'].extend([
            (o + 1, 0.9), (o + 1.75, 0.3), (o + 3, 0.85), (o + 3.25, 0.25)
        ])
        # 16th note hats
        for sixteenth in range(16):
            vel = 0.5 + 0.2 * (sixteenth % 4 == 0)
            pattern['hihat_closed'].append((o + sixteenth * 0.25, vel))
        # Tom fill on last bar
        if bar == bars - 1:
            pattern['tom'].extend([
                (o + 3, 0.8), (o + 3.25, 0.75), (o + 3.5, 0.85), (o + 3.75, 0.9)
            ])
    return pattern


def generate_halftime_pattern(bars: int = 4) -> dict:
    """Half-time/trap-influenced pattern.

    Returns {sample_name: [(beat_position, velocity), ...]}.
    """
    pattern = {
        'kick': [], 'snare': [], 'hihat_closed': [], 'cymbal': [],
    }
    for bar in range(bars):
        o = bar * 4
        # Kick on 1
        pattern['kick'].append((o + 0, 0.95))
        # Occasional double kick
        if bar % 2 == 1:
            pattern['kick'].append((o + 0.5, 0.7))
        # Snare on 3 only (half time)
        pattern['snare'].append((o + 2, 0.9))
        # Fast hats
        for sixteenth in range(16):
            vel = 0.3 + 0.15 * (sixteenth % 2 == 0)
            pattern['hihat_closed'].append((o + sixteenth * 0.25, vel))
        # Crash on bar 1
        if bar == 0:
            pattern['cymbal'].append((o + 0, 0.7))
    return pattern


PATTERNS = {
    'rock': generate_basic_rock,
    'funk': generate_funk_pattern,
    'halftime': generate_halftime_pattern,
}


# ─────────────────────────────────────────────────────────────────────────────
# Song assembly
# ─────────────────────────────────────────────────────────────────────────────

# Default two-bar bass phrase as (beat offset, frequency in Hz): a note on
# beats 1 and 3 of each bar, 65 Hz in the first bar and 82 Hz in the second.
_BASS_PHRASE = [(0, 65), (2, 65), (4, 82), (6, 82)]
_BASS_PHRASE_BEATS = 8


def _default_bass_notes(max_beat: float) -> list:
    """Tile the two-bar default bass phrase over beats 0..max_beat (inclusive).

    The phrase spans 8 beats, so it is repeated every 8 beats. Repeating it
    every bar (4 beats) would stack two notes of different pitch on the
    same onset.
    """
    n_phrases = int(max_beat // _BASS_PHRASE_BEATS) + 1
    return [
        (beat + p * _BASS_PHRASE_BEATS, freq)
        for p in range(n_phrases)
        for beat, freq in _BASS_PHRASE
        if beat + p * _BASS_PHRASE_BEATS <= max_beat
    ]


def assemble_song(
    samples: dict,
    pattern: dict,
    sr: int = 44100,
    bpm: float = 120.0,
    humanize_timing_ms: float = 5.0,
    humanize_velocity: float = 0.05,
    add_bass: bool = True,
    bass_notes: Optional[list] = None,
    room_noise_db: float = -60.0,
    seed: int = 42,
) -> SyntheticSong:
    """Assemble a complete synthetic song from samples and a pattern.

    Args:
        samples: {name: GroundTruthSample}, e.g. from create_sample_set().
        pattern: {sample_name: [(beat_position, velocity), ...]}. Names not
            present in ``samples`` are ignored.
        sr: Sample rate in Hz.
        bpm: Tempo; one beat lasts 60 / bpm seconds.
        humanize_timing_ms: Std-dev of the Gaussian onset jitter, in ms.
        humanize_velocity: Relative uniform velocity jitter (+/- fraction);
            the result is clipped to [0.05, 1.0].
        add_bass: Whether to add a synthetic bass line to ``mix``.
        bass_notes: [(beat_position, freq_hz), ...]. If None, a default
            two-bar phrase is tiled over the song. Notes starting after the
            end of the song are skipped.
        room_noise_db: Std-dev of the white noise floor added to ``mix``, in
            dB relative to 1.0, applied before peak normalization.
        seed: Seed for the humanization and room-noise RNG.

    Returns:
        SyntheticSong of length (last beat + 2 beats). ``mix``, ``drums_only``
        and ``per_sample_stems`` share one scale factor so that the mix
        peaks at 0.9. ``hits[i].audio`` is NOT rescaled.
    """
    rng = np.random.RandomState(seed)
    beat_dur = 60.0 / bpm

    # Calculate total duration: last event + 2 beats of tail.
    all_beats = [e[0] for events in pattern.values() if events for e in events]
    max_beat = max(all_beats) if all_beats else 4
    total_dur = (max_beat + 2) * beat_dur
    total_samples = int(total_dur * sr)

    # Initialize stems
    drums_mix = np.zeros(total_samples, dtype=np.float64)
    per_sample = {name: np.zeros(total_samples, dtype=np.float64) for name in samples}
    hits = []

    # Place each hit
    for sample_name, events in pattern.items():
        if sample_name not in samples:
            continue
        sample = samples[sample_name]
        for beat_pos, velocity in events:
            # Humanize timing
            timing_offset = rng.normal(0, humanize_timing_ms / 1000.0)
            onset_time = beat_pos * beat_dur + timing_offset
            onset_time = max(0, onset_time)
            # Humanize velocity
            vel = velocity * (1.0 + rng.uniform(-humanize_velocity, humanize_velocity))
            vel = np.clip(vel, 0.05, 1.0)

            # Place in timeline (truncated at the end of the song)
            start = int(onset_time * sr)
            audio = sample.audio * vel
            end = min(start + len(audio), total_samples)
            actual_len = end - start
            if actual_len <= 0:
                continue
            drums_mix[start:end] += audio[:actual_len]
            per_sample[sample_name][start:end] += audio[:actual_len]
            hits.append(PlacedHit(
                sample_name=sample_name,
                onset_time=onset_time,
                velocity=vel,
                audio=audio[:actual_len],
                sr=sr,
            ))

    # Optional bass line (tests Demucs separation)
    bass_track = np.zeros(total_samples, dtype=np.float64)
    if add_bass:
        if bass_notes is None:
            bass_notes = _default_bass_notes(max_beat)
        for beat_pos, freq in bass_notes:
            start = max(0, int(beat_pos * beat_dur * sr))
            if start >= total_samples:
                continue  # note would start after the song ends
            bass = synthesize_bass_note(sr, freq=freq, duration=beat_dur * 2)
            end = min(start + len(bass), total_samples)
            bass_track[start:end] += bass[:end - start]

    # Add room noise
    noise = rng.randn(total_samples) * (10 ** (room_noise_db / 20))

    # Final mix
    full_mix = drums_mix + bass_track + noise

    # Normalize every track with the same factor so stems still sum to the drums.
    peak = np.abs(full_mix).max()
    if peak > 0:
        scale = 0.9 / peak
        full_mix *= scale
        drums_mix *= scale
        for name in per_sample:
            per_sample[name] *= scale

    return SyntheticSong(
        mix=full_mix.astype(np.float32),
        drums_only=drums_mix.astype(np.float32),
        sr=sr,
        bpm=bpm,
        duration=total_dur,
        samples=samples,
        hits=hits,
        per_sample_stems=per_sample,
        pattern_description=str({k: len(v) for k, v in pattern.items()}),
    )


def generate_test_song(
    pattern_name: str = 'rock',
    bars: int = 4,
    bpm: float = 120.0,
    sr: int = 44100,
    variation: str = 'medium',
    add_bass: bool = True,
    seed: int = 42,
) -> SyntheticSong:
    """High-level function: generate a complete test song with ground truth.

    ``pattern_name`` is a key of PATTERNS; unknown names fall back to the
    basic rock pattern. ``seed`` drives both sample synthesis and assembly.
    """
    samples = create_sample_set(sr=sr, seed=seed, variation=variation)
    pattern_fn = PATTERNS.get(pattern_name, generate_basic_rock)
    pattern = pattern_fn(bars=bars)
    return assemble_song(
        samples=samples,
        pattern=pattern,
        sr=sr,
        bpm=bpm,
        add_bass=add_bass,
        seed=seed,
    )


def save_ground_truth(song: SyntheticSong, output_dir: str):
    """Save all ground truth data for evaluation.

    Writes into ``output_dir``:
    - mix.wav and drums_only.wav;
    - gt_samples/<name>.wav;
    - gt_stems/<name>_stem.wav (all 24-bit PCM);
    - hit_map.json.

    Raises:
        ImportError: if the optional ``soundfile`` package is not installed.
    """
    import os

    if sf is None:
        raise ImportError("save_ground_truth() requires the 'soundfile' package")

    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'gt_samples'), exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'gt_stems'), exist_ok=True)

    # Save mix and drums
    sf.write(os.path.join(output_dir, 'mix.wav'), song.mix, song.sr, subtype='PCM_24')
    sf.write(os.path.join(output_dir, 'drums_only.wav'), song.drums_only, song.sr,
             subtype='PCM_24')

    # Save individual samples
    for name, sample in song.samples.items():
        sf.write(os.path.join(output_dir, 'gt_samples', f'{name}.wav'),
                 sample.audio, sample.sr, subtype='PCM_24')

    # Save per-sample stems
    for name, stem in song.per_sample_stems.items():
        sf.write(os.path.join(output_dir, 'gt_stems', f'{name}_stem.wav'),
                 stem, song.sr, subtype='PCM_24')

    # Save hit map
    hit_map = [
        {'sample': h.sample_name, 'onset': h.onset_time, 'velocity': h.velocity}
        for h in song.hits
    ]
    with open(os.path.join(output_dir, 'hit_map.json'), 'w', encoding='utf-8') as f:
        json.dump({
            'bpm': song.bpm,
            'duration': song.duration,
            'sr': song.sr,
            'pattern': song.pattern_description,
            'hits': hit_map,
        }, f, indent=2)