Spaces:
Sleeping
Sleeping
| """ | |
| Drum sample quality metrics — completeness, cleanness, and overall scoring. | |
| Replaces the naive "60% centroid + 40% energy" selection with production-grade | |
| quality assessment grounded in: | |
| - SI-SDR / BSS_eval (Le Roux et al., 1811.02508) | |
| - MAPSS leakage vs self-distortion framework (Ivry et al., 2509.09212) | |
| - ADT onset precision (Callender et al., 2004.00188) | |
| """ | |
| import numpy as np | |
| import librosa | |
| import scipy.stats | |
| import warnings | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Completeness metrics: is the full transient + decay captured? | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def tail_peak_ratio(y: np.ndarray, sr: int) -> float:
    """C1: tail-to-peak RMS ratio.

    Frame-wise RMS is computed with 512-sample frames and a 128-sample hop.
    The mean RMS of the last fifth (at least 3 frames) of the frames from the
    peak onward is divided by the peak RMS. Low values mean the sound has
    decayed by the end of the buffer; high values suggest it was cut off.

    Returns 1.0 when fewer than 10 RMS frames exist, and 0.5 when fewer than
    5 frames remain from the peak onward. ``sr`` is accepted for API symmetry
    but not used.
    """
    envelope = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0]
    if envelope.size < 10:
        return 1.0  # not enough frames to judge the decay
    peak = int(np.argmax(envelope))
    decay = envelope[peak:]
    if decay.size < 5:
        return 0.5  # peak sits at the very end of the buffer: ambiguous
    tail_frames = max(3, decay.size // 5)
    tail_level = decay[-tail_frames:].mean()
    return float(tail_level / (envelope[peak] + 1e-8))
def decay_linearity(y: np.ndarray, sr: int) -> tuple[float, bool]:
    """C2: how well a straight line fits the log-RMS decay.

    A least-squares line (``scipy.stats.linregress``) is fit to
    ``log(rms + 1e-8)`` over the frames from the RMS peak onward. R² near 1
    means the envelope is close to a clean exponential decay.

    Returns:
        ``(r_squared, is_decaying)``. ``is_decaying`` is True when the fitted
        slope is negative. ``(0.0, False)`` is returned when fewer than 10 RMS
        frames exist or fewer than 5 remain from the peak onward. ``sr`` is
        not used.
    """
    envelope = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0]
    if envelope.size < 10:
        return 0.0, False
    decay = envelope[int(np.argmax(envelope)):]
    if decay.size < 5:
        return 0.0, False
    fit = scipy.stats.linregress(np.arange(decay.size), np.log(decay + 1e-8))
    return float(fit.rvalue ** 2), bool(fit.slope < 0)
def temporal_centroid_ms(y: np.ndarray, sr: int) -> float:
    """C3: energy-weighted temporal centroid, in milliseconds.

    Each RMS frame (512-sample window, 128-sample hop) is placed at time
    ``frame_index * 128 / sr`` seconds and weighted by its squared RMS. The
    weighted mean time is returned in ms. A very early centroid hints at a
    truncated sound; a very late one at a tail stretched out by bleed.
    """
    hop = 128
    power = librosa.feature.rms(y=y, frame_length=512, hop_length=hop)[0] ** 2
    frame_times = np.arange(power.size) * hop / sr
    centroid_s = np.sum(frame_times * power) / (np.sum(power) + 1e-8)
    return float(centroid_s * 1000)
# Expected temporal-centroid window per drum type, as (low_ms, high_ms).
TC_RANGES = dict(
    kick=(15, 100),
    snare=(8, 60),
    hihat=(3, 30),
    hihat_closed=(3, 20),
    hihat_open=(5, 50),
    tom=(10, 80),
    cymbal=(10, 100),
    perc_high=(3, 40),
    perc_low=(10, 80),
)
def compute_completeness(y: np.ndarray, sr: int, drum_type: str = 'kick') -> float:
    """Composite completeness score in [0, 1]; higher = more of the hit captured.

    Weighted blend of three sub-scores:
      * 50% tail/peak ratio: 1.0 at ratio 0, falling linearly to 0.0 at
        ratio >= 0.2;
      * 30% log-decay R²: multiplied by 0.3 when the fitted slope is not
        negative;
      * 20% temporal centroid: 1.0 inside ``TC_RANGES[drum_type]`` (fallback
        (5, 150) ms). Outside the window it is the ratio to the violated bound,
        floored at 0.2.
    """
    # C1: a fully decayed tail scores 1.0.
    c1 = max(0.0, 1.0 - tail_peak_ratio(y, sr) * 5)

    # C2: penalize envelopes whose fitted slope is not downward.
    r2, decaying = decay_linearity(y, sr)
    c2 = r2 if decaying else r2 * 0.3

    # C3: too early suggests truncation; too late suggests bleed lengthening the sound.
    tc = temporal_centroid_ms(y, sr)
    lo, hi = TC_RANGES.get(drum_type, (5, 150))
    if lo <= tc <= hi:
        c3 = 1.0
    else:
        c3 = max(0.2, tc / lo if tc < lo else hi / tc)

    return float(c1 * 0.50 + c2 * 0.30 + c3 * 0.20)
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Cleanness metrics: absence of bleed and artifacts | |
| # ───────────────────────────────────────────────────────────────────────────── | |
# Per-drum-type spectral layout: (signal_band, bleed_band, threshold_dB).
# Each band is a (low_hz, high_hz) pair; threshold_dB is the signal-to-bleed
# level that cleanness scoring measures against.
SPECTRAL_BANDS = dict(
    kick=((30, 300), (3000, 20000), 20),
    snare=((100, 8000), (8000, 20000), 10),
    hihat=((3000, 20000), (20, 200), 20),
    hihat_closed=((3000, 20000), (20, 200), 20),
    hihat_open=((2000, 20000), (20, 200), 18),
    tom=((50, 2000), (4000, 20000), 15),
    cymbal=((2000, 20000), (20, 300), 18),
    perc_high=((2000, 20000), (20, 500), 15),
    perc_low=((30, 2000), (4000, 20000), 15),
)
def pre_onset_energy_db(y: np.ndarray, sr: int) -> float:
    """N1: level of the 20 ms before the first onset relative to the 100 ms after it, in dB.

    The first onset comes from ``librosa.onset.onset_detect`` (sample units,
    with backtracking). Strongly negative values mean silence precedes the hit;
    values near 0 mean noise or bleed is already present before it.
    Returns -20.0 when no onset is found or either window has < 10 samples.
    """
    onset_samples = librosa.onset.onset_detect(y=y, sr=sr, units='samples', backtrack=True)
    if len(onset_samples) == 0:
        return -20.0  # nothing detected: treat as reasonably clean
    onset = int(onset_samples[0])
    before = y[max(0, onset - int(sr * 0.02)):onset]   # up to 20 ms pre-onset
    after = y[onset:onset + int(sr * 0.1)]              # up to 100 ms post-onset
    if min(len(before), len(after)) < 10:
        return -20.0
    ratio = (np.mean(before ** 2) + 1e-12) / (np.mean(after ** 2) + 1e-12)
    return float(10 * np.log10(ratio))
def spectral_signal_to_bleed(y: np.ndarray, sr: int,
                             drum_type: str = 'kick') -> float:
    """N2: signal-band level vs bleed-band level, in dB. Higher = cleaner.

    Band edges come from ``SPECTRAL_BANDS[drum_type]``. The fallback is a
    signal band of 50-8000 Hz and a bleed band of 8000-20000 Hz. A band's
    level is the mean STFT magnitude (n_fft=2048) over the bins whose centre
    frequency lies in ``[low, high)``.

    A band that contains no STFT bins is treated as having zero energy. This
    happens, e.g., when a bleed band starts above the Nyquist frequency at low
    sample rates. Without this, the empty mean would yield NaN and poison
    downstream scores.
    """
    D = np.abs(librosa.stft(y, n_fft=2048))
    freqs = librosa.fft_frequencies(sr=sr, n_fft=2048)
    sb, bb, _ = SPECTRAL_BANDS.get(drum_type, ((50, 8000), (8000, 20000), 15))

    def _band_level(band):
        # Mean magnitude of the bins in [low, high); 0.0 when no bin falls inside.
        mask = (freqs >= band[0]) & (freqs < band[1])
        return D[mask].mean() if mask.any() else 0.0

    sig_e = _band_level(sb) + 1e-8
    ble_e = _band_level(bb) + 1e-8
    return float(20 * np.log10(sig_e / ble_e))
def tail_zcr(y: np.ndarray, sr: int) -> float:
    """N3: mean zero-crossing rate of everything after the first 100 ms.

    The 100 ms offset is measured from the start of the buffer, not from a
    detected attack, so this assumes the sample begins at (or very near) the
    hit. A high ZCR in that region suggests hi-hat/cymbal bleed.
    Returns 0.0 when 100 or fewer samples remain after the offset.
    """
    start = int(sr * 0.1)
    if len(y) - start <= 100:
        return 0.0
    zcr_frames = librosa.feature.zero_crossing_rate(y=y[start:])
    return float(zcr_frames.mean())
def robust_snr_db(y: np.ndarray) -> float:
    """N4: percentile-based SNR estimate, in dB.

    The 99th percentile of the instantaneous power ``y**2`` serves as a
    spike-robust "peak". Its 10th percentile serves as the noise floor.
    Both get a 1e-12 floor before the ratio is taken.
    """
    power = np.square(y)
    loud, quiet = np.percentile(power, [99, 10])
    return float(10 * np.log10((loud + 1e-12) / (quiet + 1e-12)))
def compute_cleanness(y: np.ndarray, sr: int, drum_type: str = 'kick') -> float:
    """Composite cleanness score in [0, 1]; higher = less bleed and noise.

    Weighted blend of four clipped sub-scores:
      * 30% pre-onset level: 1.0 at <= -35 dB, 0.0 at >= -5 dB;
      * 35% signal-to-bleed: 0.0 at the drum type's threshold, 1.0 at
        threshold + 30 dB (threshold from ``SPECTRAL_BANDS``, fallback 15 dB);
      * 15% tail ZCR: for kick/tom/perc_low only, 1.0 at ZCR 0 and 0.0 at
        ZCR >= 0.1; other types get a fixed neutral 0.7;
      * 20% robust SNR: 0.0 at <= 10 dB, 1.0 at >= 50 dB.
    """
    n1 = np.clip((-pre_onset_energy_db(y, sr) - 5) / 30, 0, 1)

    _, _, threshold = SPECTRAL_BANDS.get(drum_type, ((50, 8000), (8000, 20000), 15))
    n2 = np.clip((spectral_signal_to_bleed(y, sr, drum_type) - threshold) / 30, 0, 1)

    # Tail ZCR is only informative for low-pitched drums, where cymbal bleed stands out.
    low_drums = ('kick', 'tom', 'perc_low')
    n3 = np.clip(1.0 - tail_zcr(y, sr) * 10, 0, 1) if drum_type in low_drums else 0.7

    n4 = np.clip((robust_snr_db(y) - 10) / 40, 0, 1)
    return float(n1 * 0.30 + n2 * 0.35 + n3 * 0.15 + n4 * 0.20)
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Onset quality | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def onset_sharpness(y: np.ndarray, sr: int) -> float:
    """Peak-to-mean ratio of librosa's onset-strength envelope.

    Large values indicate one pronounced transient (punchy attack). Values
    near 1 indicate a flat envelope (soft or missed attack). Returns 1.0 when
    the envelope has fewer than 2 frames.
    """
    strength = librosa.onset.onset_strength(y=y, sr=sr)
    if len(strength) >= 2:
        return float(strength.max() / (strength.mean() + 1e-8))
    return 1.0
def compute_onset_quality(y: np.ndarray, sr: int) -> float:
    """Onset quality score in [0, 1].

    Linear map of ``onset_sharpness``: sharpness 1 (flat envelope) maps to
    0.0, and sharpness 6 or more maps to 1.0. Values are clipped outside
    that range.
    """
    score = (onset_sharpness(y, sr) - 1.0) / 5.0
    return float(np.clip(score, 0, 1))
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Combined score | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def drum_sample_score(y: np.ndarray, sr: int, drum_type: str = 'kick',
                      centroid_dist: float = 0.0,
                      cluster_radius: float = 1.0) -> dict:
    """Overall production-quality score for one drum sample.

    Sub-scores (each in [0, 1]) and their weights in ``total`` (scaled to 0-100):
      * completeness (30%): ``compute_completeness``;
      * cleanness (40%): ``compute_cleanness``;
      * onset_quality (20%): ``compute_onset_quality``;
      * representativeness (10%): ``1 / (1 + centroid_dist / cluster_radius)``.
        This is 1.0 at the cluster centroid and about 0.5 one radius away.

    The ``components`` entry exposes the raw metrics behind the sub-scores.
    These are computed again here, so each underlying analysis runs twice.
    """
    completeness = compute_completeness(y, sr, drum_type)
    cleanness = compute_cleanness(y, sr, drum_type)
    onset_q = compute_onset_quality(y, sr)
    represent = 1.0 / (1.0 + centroid_dist / (cluster_radius + 1e-8))

    weighted = (completeness * 0.30 + cleanness * 0.40
                + onset_q * 0.20 + represent * 0.10)

    raw_metrics = {
        'tail_peak_ratio': tail_peak_ratio(y, sr),
        'temporal_centroid_ms': temporal_centroid_ms(y, sr),
        'pre_onset_db': pre_onset_energy_db(y, sr),
        'spectral_sbl_db': spectral_signal_to_bleed(y, sr, drum_type),
        'robust_snr_db': robust_snr_db(y),
        'onset_sharpness': onset_sharpness(y, sr),
    }
    return {
        'total': float(weighted * 100),
        'completeness': float(completeness),
        'cleanness': float(cleanness),
        'onset_quality': float(onset_q),
        'representativeness': float(represent),
        'components': raw_metrics,
    }
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Reference-based metrics (for evaluation against ground truth) | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def compute_si_sdr(ref: np.ndarray, est: np.ndarray) -> float:
    """Scale-invariant signal-to-distortion ratio in dB (Le Roux et al., 2019).

    Both signals are truncated to their common length and made zero-mean.
    The estimate is projected onto the reference (``alpha * ref``), and SI-SDR
    is the energy ratio of that projection to the residual. Higher is better.

    Args:
        ref: Ground-truth signal (1-D array-like).
        est: Estimated / extracted signal (1-D array-like).

    Returns:
        SI-SDR in dB. An epsilon of 1e-8 keeps the ratio finite for silent
        inputs.
    """
    ref = np.asarray(ref)
    est = np.asarray(est)
    # Truncate like the other reference metrics. np.dot would otherwise
    # raise on mismatched lengths.
    n = min(len(ref), len(est))
    ref = ref[:n] - ref[:n].mean()
    est = est[:n] - est[:n].mean()
    eps = 1e-8
    alpha = np.dot(ref, est) / (np.dot(ref, ref) + eps)
    e_target = alpha * ref
    e_residual = est - e_target
    return float(10 * np.log10(
        (np.dot(e_target, e_target) + eps) / (np.dot(e_residual, e_residual) + eps)
    ))
def compute_spectral_convergence(ref: np.ndarray, est: np.ndarray) -> float:
    """Spectral convergence ``||S_ref - S_est||_F / ||S_ref||_F`` on STFT magnitudes.

    Signals are truncated to their common length first. 0 means identical
    magnitude spectrograms, and lower is better. The value is not capped at
    1: an estimate much louder than the reference can exceed it.
    """
    length = min(len(ref), len(est))
    mag_ref, mag_est = (np.abs(librosa.stft(sig[:length])) + 1e-8 for sig in (ref, est))
    diff_norm = np.linalg.norm(mag_ref - mag_est, 'fro')
    ref_norm = np.linalg.norm(mag_ref, 'fro')
    return float(diff_norm / (ref_norm + 1e-8))
def compute_log_spectral_distance(ref: np.ndarray, est: np.ndarray) -> float:
    """Log-spectral distance in dB (lower is better).

    Both signals are truncated to their common length. For each STFT frame,
    the root-mean-square over frequency bins of the difference between the two
    dB-magnitude spectra is taken. The result is the mean over frames.
    """
    length = min(len(ref), len(est))
    db_ref = 20 * np.log10(np.abs(librosa.stft(ref[:length])) + 1e-8)
    db_est = 20 * np.log10(np.abs(librosa.stft(est[:length])) + 1e-8)
    per_frame = np.sqrt(np.mean(np.square(db_ref - db_est), axis=0))
    return float(per_frame.mean())
def compute_mfcc_distance(ref: np.ndarray, est: np.ndarray, sr: int) -> float:
    """Euclidean distance between time-averaged MFCC vectors (lower = closer timbre).

    Both signals are truncated to their common length. 13 MFCCs are computed
    for each, averaged over frames, and compared with the L2 norm. This is an
    L2 distance, not a cosine distance, so it is sensitive to MFCC magnitudes.
    """
    length = min(len(ref), len(est))

    def mean_mfcc(signal):
        # Average each of the 13 coefficients over time.
        return librosa.feature.mfcc(y=signal[:length], sr=sr, n_mfcc=13).mean(axis=1)

    return float(np.linalg.norm(mean_mfcc(ref) - mean_mfcc(est)))
def compute_envelope_correlation(ref: np.ndarray, est: np.ndarray,
                                 hop: int = 512) -> float:
    """Pearson correlation of peak-amplitude envelopes (higher = closer attack/decay shape).

    Both signals are truncated to their common length and split into
    consecutive, non-overlapping frames of ``hop`` samples. Each frame's
    envelope value is its maximum absolute sample. A trailing partial frame is
    ignored; every full frame, including one ending exactly at the last
    sample, is used.

    Returns 0.0 when fewer than two full frames are available, or when either
    envelope has a standard deviation below 1e-8 (correlation undefined).

    Raises:
        ValueError: if ``hop`` is less than 1.
    """
    if hop < 1:
        raise ValueError(f"hop must be >= 1, got {hop}")
    ref = np.asarray(ref)
    est = np.asarray(est)
    n = min(len(ref), len(est))
    # n // hop counts all full frames, including the final one ending at n.
    n_frames = n // hop
    if n_frames < 2:
        return 0.0
    span = n_frames * hop
    env_ref = np.abs(ref[:span]).reshape(n_frames, hop).max(axis=1)
    env_est = np.abs(est[:span]).reshape(n_frames, hop).max(axis=1)
    if env_ref.std() < 1e-8 or env_est.std() < 1e-8:
        return 0.0
    return float(np.corrcoef(env_ref, env_est)[0, 1])
def compute_all_reference_metrics(ref: np.ndarray, est: np.ndarray,
                                  sr: int) -> dict:
    """Run every reference-based metric on a (ground truth, extracted sample) pair.

    Both signals are first truncated to their common length. The returned
    dict maps human-readable metric names (with units where applicable) to
    float values.
    """
    common = min(len(ref), len(est))
    ref_c, est_c = ref[:common], est[:common]
    metrics = {}
    metrics['SI-SDR (dB)'] = compute_si_sdr(ref_c, est_c)
    metrics['Spectral Convergence'] = compute_spectral_convergence(ref_c, est_c)
    metrics['Log Spectral Distance (dB)'] = compute_log_spectral_distance(ref_c, est_c)
    metrics['MFCC Distance'] = compute_mfcc_distance(ref_c, est_c, sr)
    metrics['Envelope Correlation'] = compute_envelope_correlation(ref_c, est_c)
    return metrics