""" Drum sample quality metrics — completeness, cleanness, and overall scoring. Replaces the naive "60% centroid + 40% energy" selection with production-grade quality assessment grounded in: - SI-SDR / BSS_eval (Le Roux et al., 1811.02508) - MAPSS leakage vs self-distortion framework (Ivry et al., 2509.09212) - ADT onset precision (Callender et al., 2004.00188) """ import numpy as np import librosa import scipy.stats import warnings # ───────────────────────────────────────────────────────────────────────────── # Completeness metrics: is the full transient + decay captured? # ───────────────────────────────────────────────────────────────────────────── def tail_peak_ratio(y: np.ndarray, sr: int) -> float: """C1: ratio of tail energy to peak energy. Low = good (fully decayed). High = truncated. """ rms = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0] if len(rms) < 10: return 1.0 # too short to evaluate peak_idx = np.argmax(rms) post = rms[peak_idx:] if len(post) < 5: return 0.5 tail_energy = np.mean(post[-max(3, len(post)//5):]) return float(tail_energy / (rms[peak_idx] + 1e-8)) def decay_linearity(y: np.ndarray, sr: int) -> tuple[float, bool]: """C2: R² of log-linear decay fit. High R² = clean exponential decay. Returns (r_squared, is_decaying). """ rms = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0] if len(rms) < 10: return 0.0, False peak_idx = np.argmax(rms) post = rms[peak_idx:] if len(post) < 5: return 0.0, False x = np.arange(len(post)) log_rms = np.log(post + 1e-8) slope, _, r, _, _ = scipy.stats.linregress(x, log_rms) return float(r ** 2), bool(slope < 0) def temporal_centroid_ms(y: np.ndarray, sr: int) -> float: """C3: temporal centroid in milliseconds. Where the 'center of mass' of the energy is. Too early = truncated; too late = bleed-dominated.""" rms = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0] times = np.arange(len(rms)) * 128 / sr total = np.sum(rms ** 2) + 1e-8 tc = np.sum(times * rms ** 2) / total return float(tc * 1000) # Expected temporal centroid ranges per drum type (milliseconds) TC_RANGES = { 'kick': (15, 100), 'snare': (8, 60), 'hihat': (3, 30), 'hihat_closed': (3, 20), 'hihat_open': (5, 50), 'tom': (10, 80), 'cymbal': (10, 100), 'perc_high': (3, 40), 'perc_low': (10, 80), } def compute_completeness(y: np.ndarray, sr: int, drum_type: str = 'kick') -> float: """Composite completeness score [0, 1]. Higher = more complete.""" # C1: tail/peak ratio tr = tail_peak_ratio(y, sr) c1 = max(0.0, 1.0 - tr * 5) # 0.0 at tr=0.2, 1.0 at tr=0.0 # C2: decay linearity r2, decaying = decay_linearity(y, sr) c2 = r2 if decaying else r2 * 0.3 # penalize non-decaying # C3: temporal centroid in expected range tc = temporal_centroid_ms(y, sr) lo, hi = TC_RANGES.get(drum_type, (5, 150)) if lo <= tc <= hi: c3 = 1.0 elif tc < lo: c3 = max(0.2, tc / lo) # too early = potentially truncated pre-onset else: c3 = max(0.2, hi / tc) # too late = bleed extending the sound return float(c1 * 0.50 + c2 * 0.30 + c3 * 0.20) # ───────────────────────────────────────────────────────────────────────────── # Cleanness metrics: absence of bleed and artifacts # ───────────────────────────────────────────────────────────────────────────── # Spectral band definitions per drum type: (signal_band, bleed_band, threshold_dB) SPECTRAL_BANDS = { 'kick': ((30, 300), (3000, 20000), 20), 'snare': ((100, 8000), (8000, 20000), 10), 'hihat': ((3000, 20000), (20, 200), 20), 'hihat_closed': ((3000, 20000), (20, 200), 20), 'hihat_open': ((2000, 20000), (20, 200), 18), 'tom': ((50, 2000), (4000, 20000), 15), 'cymbal': ((2000, 20000), (20, 300), 18), 'perc_high': ((2000, 20000), (20, 500), 15), 'perc_low': ((30, 2000), (4000, 20000), 15), } def pre_onset_energy_db(y: np.ndarray, sr: int) -> float: """N1: energy ratio of pre-onset region vs signal region (dB). Very negative = clean start. Near 0 = pre-noise/bleed.""" onsets = librosa.onset.onset_detect(y=y, sr=sr, units='samples', backtrack=True) if len(onsets) == 0: return -20.0 # assume decent if no onset found os = int(onsets[0]) pre_len = int(sr * 0.02) # 20ms before onset sig_len = int(sr * 0.1) # 100ms of signal pre = y[max(0, os - pre_len):os] sig = y[os:os + sig_len] if len(pre) < 10 or len(sig) < 10: return -20.0 pre_e = np.mean(pre ** 2) + 1e-12 sig_e = np.mean(sig ** 2) + 1e-12 return float(10 * np.log10(pre_e / sig_e)) def spectral_signal_to_bleed(y: np.ndarray, sr: int, drum_type: str = 'kick') -> float: """N2: signal-band energy vs bleed-band energy (dB). Higher = cleaner.""" D = np.abs(librosa.stft(y, n_fft=2048)) freqs = librosa.fft_frequencies(sr=sr, n_fft=2048) sb, bb, _ = SPECTRAL_BANDS.get(drum_type, ((50, 8000), (8000, 20000), 15)) sig_mask = (freqs >= sb[0]) & (freqs < sb[1]) ble_mask = (freqs >= bb[0]) & (freqs < bb[1]) sig_e = D[sig_mask].mean() + 1e-8 ble_e = D[ble_mask].mean() + 1e-8 return float(20 * np.log10(sig_e / ble_e)) def tail_zcr(y: np.ndarray, sr: int) -> float: """N3: zero-crossing rate in the tail. High ZCR = likely cymbal/hihat bleed.""" # Start tail at 100ms post-attack tail_start = int(sr * 0.1) if tail_start >= len(y) - 100: return 0.0 tail = y[tail_start:] return float(np.mean(librosa.feature.zero_crossing_rate(y=tail))) def robust_snr_db(y: np.ndarray) -> float: """N4: percentile-based SNR estimate. Robust to single-sample spikes.""" y_sq = y ** 2 peak = np.percentile(y_sq, 99) + 1e-12 noise = np.percentile(y_sq, 10) + 1e-12 return float(10 * np.log10(peak / noise)) def compute_cleanness(y: np.ndarray, sr: int, drum_type: str = 'kick') -> float: """Composite cleanness score [0, 1]. Higher = cleaner.""" # N1: pre-onset energy pre_db = pre_onset_energy_db(y, sr) n1 = np.clip((-pre_db - 5) / 30, 0, 1) # -35dB → 1.0, -5dB → 0.0 # N2: spectral SBL _, _, thresh = SPECTRAL_BANDS.get(drum_type, ((50, 8000), (8000, 20000), 15)) sbl = spectral_signal_to_bleed(y, sr, drum_type) n2 = np.clip((sbl - thresh) / 30, 0, 1) # N3: tail ZCR (relevant mainly for kick/tom where cymbal bleed is obvious) if drum_type in ('kick', 'tom', 'perc_low'): zcr = tail_zcr(y, sr) n3 = np.clip(1.0 - zcr * 10, 0, 1) else: n3 = 0.7 # neutral for non-kick types # N4: robust SNR snr = robust_snr_db(y) n4 = np.clip((snr - 10) / 40, 0, 1) return float(n1 * 0.30 + n2 * 0.35 + n3 * 0.15 + n4 * 0.20) # ───────────────────────────────────────────────────────────────────────────── # Onset quality # ───────────────────────────────────────────────────────────────────────────── def onset_sharpness(y: np.ndarray, sr: int) -> float: """Onset transient sharpness: peak onset strength / mean. High = punchy attack. Low = mushy/missed transient.""" onset_env = librosa.onset.onset_strength(y=y, sr=sr) if len(onset_env) < 2: return 1.0 return float(np.max(onset_env) / (np.mean(onset_env) + 1e-8)) def compute_onset_quality(y: np.ndarray, sr: int) -> float: """Onset quality score [0, 1].""" sharpness = onset_sharpness(y, sr) # sharpness > 5 = excellent, 1 = terrible return float(np.clip((sharpness - 1.0) / 5.0, 0, 1)) # ───────────────────────────────────────────────────────────────────────────── # Combined score # ───────────────────────────────────────────────────────────────────────────── def drum_sample_score(y: np.ndarray, sr: int, drum_type: str = 'kick', centroid_dist: float = 0.0, cluster_radius: float = 1.0) -> dict: """ Production-quality score for a drum sample. Returns dict with individual components and total score [0, 100]. Weights: cleanness 40%, completeness 30%, onset 20%, representativeness 10%. """ C = compute_completeness(y, sr, drum_type) N = compute_cleanness(y, sr, drum_type) O = compute_onset_quality(y, sr) R = 1.0 / (1.0 + centroid_dist / (cluster_radius + 1e-8)) total = (C * 0.30 + N * 0.40 + O * 0.20 + R * 0.10) * 100 return { 'total': float(total), 'completeness': float(C), 'cleanness': float(N), 'onset_quality': float(O), 'representativeness': float(R), 'components': { 'tail_peak_ratio': tail_peak_ratio(y, sr), 'temporal_centroid_ms': temporal_centroid_ms(y, sr), 'pre_onset_db': pre_onset_energy_db(y, sr), 'spectral_sbl_db': spectral_signal_to_bleed(y, sr, drum_type), 'robust_snr_db': robust_snr_db(y), 'onset_sharpness': onset_sharpness(y, sr), } } # ───────────────────────────────────────────────────────────────────────────── # Reference-based metrics (for evaluation against ground truth) # ───────────────────────────────────────────────────────────────────────────── def compute_si_sdr(ref: np.ndarray, est: np.ndarray) -> float: """Scale-Invariant SDR (Le Roux et al. 2019). Primary quality metric.""" ref = ref - ref.mean() est = est - est.mean() eps = 1e-8 alpha = np.dot(ref, est) / (np.dot(ref, ref) + eps) e_target = alpha * ref e_residual = est - e_target return float(10 * np.log10( (np.dot(e_target, e_target) + eps) / (np.dot(e_residual, e_residual) + eps) )) def compute_spectral_convergence(ref: np.ndarray, est: np.ndarray) -> float: """Spectral convergence [0, 1]. Lower = better frequency match.""" n = min(len(ref), len(est)) S_ref = np.abs(librosa.stft(ref[:n])) + 1e-8 S_est = np.abs(librosa.stft(est[:n])) + 1e-8 return float(np.linalg.norm(S_ref - S_est, 'fro') / (np.linalg.norm(S_ref, 'fro') + 1e-8)) def compute_log_spectral_distance(ref: np.ndarray, est: np.ndarray) -> float: """Log spectral distance (dB). Lower = better.""" n = min(len(ref), len(est)) S_ref = np.abs(librosa.stft(ref[:n])) + 1e-8 S_est = np.abs(librosa.stft(est[:n])) + 1e-8 return float(np.mean(np.sqrt( np.mean((20 * np.log10(S_ref) - 20 * np.log10(S_est)) ** 2, axis=0) ))) def compute_mfcc_distance(ref: np.ndarray, est: np.ndarray, sr: int) -> float: """MFCC cosine distance. Lower = more similar timbre.""" n = min(len(ref), len(est)) mfcc_ref = librosa.feature.mfcc(y=ref[:n], sr=sr, n_mfcc=13).mean(axis=1) mfcc_est = librosa.feature.mfcc(y=est[:n], sr=sr, n_mfcc=13).mean(axis=1) return float(np.linalg.norm(mfcc_ref - mfcc_est)) def compute_envelope_correlation(ref: np.ndarray, est: np.ndarray, hop: int = 512) -> float: """Amplitude envelope correlation. Higher = better attack/decay shape match.""" n = min(len(ref), len(est)) ref, est = ref[:n], est[:n] frames = range(0, n - hop, hop) if len(frames) < 2: return 0.0 er = np.array([np.max(np.abs(ref[i:i + hop])) for i in frames]) ee = np.array([np.max(np.abs(est[i:i + hop])) for i in frames]) if er.std() < 1e-8 or ee.std() < 1e-8: return 0.0 return float(np.corrcoef(er, ee)[0, 1]) def compute_all_reference_metrics(ref: np.ndarray, est: np.ndarray, sr: int) -> dict: """Compute all reference-based metrics between ground truth and extracted sample.""" n = min(len(ref), len(est)) ref_t = ref[:n] est_t = est[:n] return { 'SI-SDR (dB)': compute_si_sdr(ref_t, est_t), 'Spectral Convergence': compute_spectral_convergence(ref_t, est_t), 'Log Spectral Distance (dB)': compute_log_spectral_distance(ref_t, est_t), 'MFCC Distance': compute_mfcc_distance(ref_t, est_t, sr), 'Envelope Correlation': compute_envelope_correlation(ref_t, est_t), }