# drum-sample-extractor / quality_metrics.py
# Commit 1b8f186 (verified) by rikhoffbauer2: "Add quality_metrics.py"
"""
Drum sample quality metrics — completeness, cleanness, and overall scoring.
Replaces the naive "60% centroid + 40% energy" selection with production-grade
quality assessment grounded in:
- SI-SDR / BSS_eval (Le Roux et al., 1811.02508)
- MAPSS leakage vs self-distortion framework (Ivry et al., 2509.09212)
- ADT onset precision (Callender et al., 2004.00188)
"""
import numpy as np
import librosa
import scipy.stats
import warnings
# ─────────────────────────────────────────────────────────────────────────────
# Completeness metrics: is the full transient + decay captured?
# ─────────────────────────────────────────────────────────────────────────────
def tail_peak_ratio(y: np.ndarray, sr: int) -> float:
    """C1: tail-to-peak RMS ratio.

    Computes frame RMS (512-sample frames, 128-sample hop), finds the loudest
    frame, and divides the mean RMS of the last fifth (at least 3 frames) of
    the post-peak region by that peak RMS. Low values mean the sound has fully
    decayed; high values suggest it was cut off early.

    Returns 1.0 when there are fewer than 10 RMS frames, and 0.5 when fewer
    than 5 frames remain from the peak onward. ``sr`` is accepted for API
    symmetry but is not used.
    """
    envelope = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0]
    if envelope.size < 10:
        return 1.0  # too short to evaluate
    peak_at = int(np.argmax(envelope))
    after_peak = envelope[peak_at:]
    if after_peak.size < 5:
        return 0.5
    tail_frames = max(3, after_peak.size // 5)
    tail_mean = np.mean(after_peak[-tail_frames:])
    return float(tail_mean / (envelope[peak_at] + 1e-8))
def decay_linearity(y: np.ndarray, sr: int) -> tuple[float, bool]:
    """C2: goodness of a straight-line fit to the log-RMS decay.

    Fits ``log(rms + 1e-8)`` against frame index over the frames from the RMS
    peak onward (512-sample frames, 128-sample hop).

    Returns ``(r_squared, is_decaying)``, where ``is_decaying`` is True when
    the fitted slope is negative. A high R² with a negative slope indicates a
    clean exponential decay. Returns ``(0.0, False)`` when there are fewer
    than 10 RMS frames or fewer than 5 frames from the peak onward.
    ``sr`` is unused.
    """
    envelope = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0]
    if envelope.size < 10:
        return 0.0, False
    decay = envelope[int(np.argmax(envelope)):]
    if decay.size < 5:
        return 0.0, False
    fit = scipy.stats.linregress(np.arange(decay.size), np.log(decay + 1e-8))
    return float(fit.rvalue ** 2), bool(fit.slope < 0)
def temporal_centroid_ms(y: np.ndarray, sr: int) -> float:
    """C3: energy-weighted temporal centroid, in milliseconds.

    Each RMS frame (512-sample frames, 128-sample hop) is placed at
    ``index * 128 / sr`` seconds and weighted by its squared RMS. An early
    centroid hints at truncation; a late one hints that bleed is stretching
    the sound.
    """
    hop = 128
    power = librosa.feature.rms(y=y, frame_length=512, hop_length=hop)[0] ** 2
    frame_times = np.arange(power.size) * hop / sr
    centroid_s = np.sum(frame_times * power) / (np.sum(power) + 1e-8)
    return float(centroid_s * 1000)
# Expected temporal-centroid window per drum type, in milliseconds: (low, high).
# compute_completeness gives full credit inside the window and partial credit
# outside it (early = possibly truncated, late = possibly bleed-extended).
TC_RANGES = dict(
    kick=(15, 100),
    snare=(8, 60),
    hihat=(3, 30),
    hihat_closed=(3, 20),
    hihat_open=(5, 50),
    tom=(10, 80),
    cymbal=(10, 100),
    perc_high=(3, 40),
    perc_low=(10, 80),
)
def compute_completeness(y: np.ndarray, sr: int, drum_type: str = 'kick') -> float:
    """Composite completeness score in [0, 1]; higher means more complete.

    Weighted blend of three sub-scores:
      * 50% tail/peak ratio: ``max(0, 1 - 5 * ratio)``, i.e. 1.0 at ratio 0
        and 0.0 at ratio >= 0.2;
      * 30% decay-fit R², multiplied by 0.3 when the fitted slope is not
        negative;
      * 20% temporal centroid: 1.0 inside the ``TC_RANGES`` window for
        ``drum_type`` (5–150 ms for unknown types), otherwise ``tc/low`` when
        early or ``high/tc`` when late, floored at 0.2.
    """
    ratio_score = max(0.0, 1.0 - tail_peak_ratio(y, sr) * 5)

    r_squared, is_decaying = decay_linearity(y, sr)
    decay_score = r_squared if is_decaying else r_squared * 0.3

    centroid = temporal_centroid_ms(y, sr)
    low, high = TC_RANGES.get(drum_type, (5, 150))
    if low <= centroid <= high:
        centroid_score = 1.0
    elif centroid < low:
        centroid_score = max(0.2, centroid / low)   # early: onset may be clipped
    else:
        centroid_score = max(0.2, high / centroid)  # late: bleed stretching the sound

    return float(ratio_score * 0.50 + decay_score * 0.30 + centroid_score * 0.20)
# ─────────────────────────────────────────────────────────────────────────────
# Cleanness metrics: absence of bleed and artifacts
# ─────────────────────────────────────────────────────────────────────────────
# Per-drum-type spectral layout: (signal_band_hz, bleed_band_hz, threshold_db).
# Bands are (low, high) in Hz. threshold_db is the signal-to-bleed level that
# compute_cleanness treats as the zero point of its spectral sub-score.
SPECTRAL_BANDS = dict(
    kick=((30, 300), (3000, 20000), 20),
    snare=((100, 8000), (8000, 20000), 10),
    hihat=((3000, 20000), (20, 200), 20),
    hihat_closed=((3000, 20000), (20, 200), 20),
    hihat_open=((2000, 20000), (20, 200), 18),
    tom=((50, 2000), (4000, 20000), 15),
    cymbal=((2000, 20000), (20, 300), 18),
    perc_high=((2000, 20000), (20, 500), 15),
    perc_low=((30, 2000), (4000, 20000), 15),
)
def pre_onset_energy_db(y: np.ndarray, sr: int) -> float:
    """N1: pre-onset vs post-onset mean power, in dB.

    Takes the first (backtracked) onset from ``librosa.onset.onset_detect``
    and compares the mean power of up to 20 ms before it with the mean power
    of the 100 ms starting at it. Strongly negative values mean a clean start;
    values near 0 dB mean noise or bleed precedes the hit.

    Returns -20.0 when no onset is found or when either window holds fewer
    than 10 samples.
    """
    fallback_db = -20.0  # assume a decent start when it cannot be measured
    onset_samples = librosa.onset.onset_detect(y=y, sr=sr, units='samples', backtrack=True)
    if len(onset_samples) == 0:
        return fallback_db
    onset = int(onset_samples[0])
    before = y[max(0, onset - int(sr * 0.02)):onset]   # up to 20 ms before the onset
    after = y[onset:onset + int(sr * 0.1)]             # 100 ms from the onset
    if min(len(before), len(after)) < 10:
        return fallback_db
    power_before = np.mean(before ** 2) + 1e-12
    power_after = np.mean(after ** 2) + 1e-12
    return float(10 * np.log10(power_before / power_after))
def spectral_signal_to_bleed(y: np.ndarray, sr: int,
                             drum_type: str = 'kick') -> float:
    """N2: mean STFT magnitude in the signal band vs the bleed band, in dB.

    Bands come from ``SPECTRAL_BANDS[drum_type]``. The fallback is a
    50–8000 Hz signal band and an 8000–20000 Hz bleed band. Higher values
    mean cleaner.

    If either band contains no STFT bins at this sample rate (e.g. a bleed band
    starting at or above Nyquist), the ratio cannot be measured. In that case
    0.0 dB (signal and bleed treated as equal) is returned, instead of letting
    the mean of an empty slice produce NaN that would poison downstream scores.
    """
    n_fft = 2048
    D = np.abs(librosa.stft(y, n_fft=n_fft))
    freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
    sb, bb, _ = SPECTRAL_BANDS.get(drum_type, ((50, 8000), (8000, 20000), 15))
    sig_bins = D[(freqs >= sb[0]) & (freqs < sb[1])]
    ble_bins = D[(freqs >= bb[0]) & (freqs < bb[1])]
    if sig_bins.size == 0 or ble_bins.size == 0:
        return 0.0  # band not representable at this sr: unmeasurable, not NaN
    sig_e = sig_bins.mean() + 1e-8
    ble_e = ble_bins.mean() + 1e-8
    return float(20 * np.log10(sig_e / ble_e))
def tail_zcr(y: np.ndarray, sr: int) -> float:
    """N3: mean zero-crossing rate of the tail; high values hint at cymbal/hi-hat bleed.

    The tail is everything from 100 ms after the start of the buffer (not
    after a detected attack) to the end. Returns 0.0 when 100 samples or
    fewer would remain.
    """
    start = int(sr * 0.1)
    if start >= len(y) - 100:
        return 0.0
    return float(np.mean(librosa.feature.zero_crossing_rate(y=y[start:])))
def robust_snr_db(y: np.ndarray) -> float:
    """N4: percentile-based SNR estimate, in dB.

    Ratio of the 99th to the 10th percentile of per-sample power ``y**2``.
    Using percentiles instead of max/min makes the estimate robust to
    single-sample spikes. Because it works on instantaneous sample power, the
    low percentile also picks up samples near zero crossings. Digital silence
    is floored by a 1e-12 epsilon, which caps the result near 120 dB for a
    full-scale (|y| ~ 1) signal.

    Input is converted to float64 first, so integer PCM does not overflow
    when squared. Returns 0.0 for an empty input, where no SNR can be
    estimated.
    """
    power = np.square(np.asarray(y, dtype=np.float64))
    if power.size == 0:
        return 0.0
    peak = np.percentile(power, 99) + 1e-12
    noise = np.percentile(power, 10) + 1e-12
    return float(10 * np.log10(peak / noise))
def compute_cleanness(y: np.ndarray, sr: int, drum_type: str = 'kick') -> float:
    """Composite cleanness score in [0, 1]; higher means cleaner.

    Weighted blend of four sub-scores, each clipped to [0, 1]:
      * 30% pre-onset energy: -35 dB -> 1.0, -5 dB -> 0.0;
      * 35% spectral signal-to-bleed, measured above the drum type's
        ``SPECTRAL_BANDS`` threshold over a 30 dB span;
      * 15% tail zero-crossing rate (kick/tom/perc_low only; a fixed 0.7
        otherwise);
      * 20% robust SNR: 10 dB -> 0.0, 50 dB -> 1.0.
    """
    pre_db = pre_onset_energy_db(y, sr)
    onset_score = np.clip((-pre_db - 5) / 30, 0, 1)

    _, _, threshold_db = SPECTRAL_BANDS.get(drum_type, ((50, 8000), (8000, 20000), 15))
    sbl_db = spectral_signal_to_bleed(y, sr, drum_type)
    bleed_score = np.clip((sbl_db - threshold_db) / 30, 0, 1)

    # Tail ZCR is only informative for low-pitched types, where cymbal bleed stands out.
    if drum_type in ('kick', 'tom', 'perc_low'):
        zcr_score = np.clip(1.0 - tail_zcr(y, sr) * 10, 0, 1)
    else:
        zcr_score = 0.7  # fixed neutral value for other types

    snr_score = np.clip((robust_snr_db(y) - 10) / 40, 0, 1)

    return float(onset_score * 0.30 + bleed_score * 0.35
                 + zcr_score * 0.15 + snr_score * 0.20)
# ─────────────────────────────────────────────────────────────────────────────
# Onset quality
# ─────────────────────────────────────────────────────────────────────────────
def onset_sharpness(y: np.ndarray, sr: int) -> float:
    """Peak-to-mean ratio of the librosa onset-strength envelope.

    High values indicate a punchy, well-defined attack; low values indicate a
    soft or smeared transient. Returns 1.0 when the envelope has fewer than
    2 frames.
    """
    strength = librosa.onset.onset_strength(y=y, sr=sr)
    if len(strength) >= 2:
        return float(np.max(strength) / (np.mean(strength) + 1e-8))
    return 1.0
def compute_onset_quality(y: np.ndarray, sr: int) -> float:
    """Onset quality score in [0, 1].

    Linear map of onset sharpness: 1.0 or below -> 0.0, 6.0 or above -> 1.0
    (e.g. a sharpness of 5 scores 0.8).
    """
    raw = (onset_sharpness(y, sr) - 1.0) / 5.0
    return float(np.clip(raw, 0, 1))
# ─────────────────────────────────────────────────────────────────────────────
# Combined score
# ─────────────────────────────────────────────────────────────────────────────
def drum_sample_score(y: np.ndarray, sr: int, drum_type: str = 'kick',
                      centroid_dist: float = 0.0,
                      cluster_radius: float = 1.0) -> dict:
    """Overall production-quality score for one drum sample.

    Sub-scores and weights:
      * completeness 30% (``compute_completeness``)
      * cleanness 40% (``compute_cleanness``)
      * onset quality 20% (``compute_onset_quality``)
      * representativeness 10%: ``1 / (1 + centroid_dist / (cluster_radius + 1e-8))``,
        which is in (0, 1] for non-negative ``centroid_dist``

    Returns a dict with ``'total'`` (weighted sum scaled by 100), each
    sub-score, and ``'components'``, a dict of raw metric values for
    inspection. The raw metrics are recomputed here, so each underlying
    analysis runs a second time.
    """
    completeness = compute_completeness(y, sr, drum_type)
    cleanness = compute_cleanness(y, sr, drum_type)
    onset_quality = compute_onset_quality(y, sr)
    representativeness = 1.0 / (1.0 + centroid_dist / (cluster_radius + 1e-8))

    weighted = (completeness * 0.30 + cleanness * 0.40
                + onset_quality * 0.20 + representativeness * 0.10)

    result = {
        'total': float(weighted * 100),
        'completeness': float(completeness),
        'cleanness': float(cleanness),
        'onset_quality': float(onset_quality),
        'representativeness': float(representativeness),
    }
    result['components'] = {
        'tail_peak_ratio': tail_peak_ratio(y, sr),
        'temporal_centroid_ms': temporal_centroid_ms(y, sr),
        'pre_onset_db': pre_onset_energy_db(y, sr),
        'spectral_sbl_db': spectral_signal_to_bleed(y, sr, drum_type),
        'robust_snr_db': robust_snr_db(y),
        'onset_sharpness': onset_sharpness(y, sr),
    }
    return result
# ─────────────────────────────────────────────────────────────────────────────
# Reference-based metrics (for evaluation against ground truth)
# ─────────────────────────────────────────────────────────────────────────────
def compute_si_sdr(ref: np.ndarray, est: np.ndarray) -> float:
"""Scale-Invariant SDR (Le Roux et al. 2019). Primary quality metric."""
ref = ref - ref.mean()
est = est - est.mean()
eps = 1e-8
alpha = np.dot(ref, est) / (np.dot(ref, ref) + eps)
e_target = alpha * ref
e_residual = est - e_target
return float(10 * np.log10(
(np.dot(e_target, e_target) + eps) / (np.dot(e_residual, e_residual) + eps)
))
def compute_spectral_convergence(ref: np.ndarray, est: np.ndarray) -> float:
    """Spectral convergence: ``||S_ref - S_est||_F / ||S_ref||_F``. Lower is better.

    ``S`` is the STFT magnitude (default ``librosa.stft`` settings, plus a
    1e-8 floor) of each signal truncated to the common length. The value is
    0 for identical spectra. It is NOT bounded to [0, 1]: an estimate much
    louder than the reference yields values above 1.
    """
    n = min(len(ref), len(est))
    S_ref = np.abs(librosa.stft(ref[:n])) + 1e-8
    S_est = np.abs(librosa.stft(est[:n])) + 1e-8
    return float(np.linalg.norm(S_ref - S_est, 'fro') /
                 (np.linalg.norm(S_ref, 'fro') + 1e-8))
def compute_log_spectral_distance(ref: np.ndarray, est: np.ndarray) -> float:
    """Log-spectral distance in dB; lower is better.

    Both signals are truncated to the shorter length and transformed with a
    default ``librosa.stft``. Magnitudes are converted to dB as
    ``20*log10(|S| + 1e-8)``. The RMS difference is taken over axis 0
    (frequency bins in librosa's STFT layout) and then averaged over frames.
    """
    length = min(len(ref), len(est))
    ref_db = 20 * np.log10(np.abs(librosa.stft(ref[:length])) + 1e-8)
    est_db = 20 * np.log10(np.abs(librosa.stft(est[:length])) + 1e-8)
    per_frame = np.sqrt(np.mean((ref_db - est_db) ** 2, axis=0))
    return float(np.mean(per_frame))
def compute_mfcc_distance(ref: np.ndarray, est: np.ndarray, sr: int) -> float:
    """Euclidean distance between time-averaged MFCC vectors. Lower = more similar timbre.

    Both signals are truncated to the shorter length. 13 MFCCs are computed
    for each and averaged over frames, and the L2 norm of the difference is
    returned. This is an L2 distance, not a cosine distance: it is unbounded,
    and it includes the 0th coefficient, which tracks overall log-energy.
    """
    n = min(len(ref), len(est))
    mfcc_ref = librosa.feature.mfcc(y=ref[:n], sr=sr, n_mfcc=13).mean(axis=1)
    mfcc_est = librosa.feature.mfcc(y=est[:n], sr=sr, n_mfcc=13).mean(axis=1)
    return float(np.linalg.norm(mfcc_ref - mfcc_est))
def compute_envelope_correlation(ref: np.ndarray, est: np.ndarray,
                                 hop: int = 512) -> float:
    """Pearson correlation of peak-amplitude envelopes. Higher = better attack/decay match.

    Both 1-D signals are truncated to their common length and split into
    consecutive non-overlapping frames of ``hop`` samples. A trailing partial
    frame is ignored, but every full frame is used: the previous
    ``range(0, n - hop, hop)`` dropped the last full frame. Each frame is
    reduced to its peak absolute amplitude.

    Returns 0.0 when there are fewer than 2 frames or either envelope is
    (near) constant, since correlation is undefined then.

    Raises:
        ValueError: if ``hop`` is not positive.
    """
    if hop <= 0:
        raise ValueError("hop must be a positive integer")
    n = min(len(ref), len(est))
    n_frames = n // hop
    if n_frames < 2:
        return 0.0
    usable = n_frames * hop
    er = np.abs(np.asarray(ref[:usable])).reshape(n_frames, hop).max(axis=1)
    ee = np.abs(np.asarray(est[:usable])).reshape(n_frames, hop).max(axis=1)
    if er.std() < 1e-8 or ee.std() < 1e-8:
        return 0.0
    return float(np.corrcoef(er, ee)[0, 1])
def compute_all_reference_metrics(ref: np.ndarray, est: np.ndarray,
                                  sr: int) -> dict:
    """Run every reference-based metric on a ground-truth / extracted pair.

    Both signals are first cut to their common length. The returned dict maps
    human-readable metric labels to float values.
    """
    common = min(len(ref), len(est))
    ref_c, est_c = ref[:common], est[:common]
    metrics = {}
    metrics['SI-SDR (dB)'] = compute_si_sdr(ref_c, est_c)
    metrics['Spectral Convergence'] = compute_spectral_convergence(ref_c, est_c)
    metrics['Log Spectral Distance (dB)'] = compute_log_spectral_distance(ref_c, est_c)
    metrics['MFCC Distance'] = compute_mfcc_distance(ref_c, est_c, sr)
    metrics['Envelope Correlation'] = compute_envelope_correlation(ref_c, est_c)
    return metrics