Spaces:
Sleeping
Sleeping
File size: 13,784 Bytes
1b8f186 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 | """
Drum sample quality metrics — completeness, cleanness, and overall scoring.
Replaces the naive "60% centroid + 40% energy" selection with production-grade
quality assessment grounded in:
- SI-SDR / BSS_eval (Le Roux et al., 1811.02508)
- MAPSS leakage vs self-distortion framework (Ivry et al., 2509.09212)
- ADT onset precision (Callender et al., 2004.00188)
"""
import numpy as np
import librosa
import scipy.stats
import warnings
# ─────────────────────────────────────────────────────────────────────────────
# Completeness metrics: is the full transient + decay captured?
# ─────────────────────────────────────────────────────────────────────────────
def tail_peak_ratio(y: np.ndarray, sr: int) -> float:
    """C1: level of the decay tail relative to the envelope peak.

    The envelope is the frame-wise RMS of ``y`` (512-sample frames, 128-sample
    hop). The tail is the last fifth of the frames from the peak onwards, but
    never fewer than three frames.

    Low values mean the sound has fully decayed; high values suggest the
    sample was truncated. Returns 1.0 when the envelope has fewer than 10
    frames and 0.5 when fewer than 5 frames remain from the peak onwards.
    ``sr`` is accepted but not used by this metric.
    """
    envelope = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0]
    if len(envelope) < 10:
        # Too short to evaluate.
        return 1.0

    peak_pos = np.argmax(envelope)
    after_peak = envelope[peak_pos:]
    n_after = len(after_peak)
    if n_after < 5:
        return 0.5

    tail_frames = max(3, n_after // 5)
    tail_level = np.mean(after_peak[-tail_frames:])
    peak_level = envelope[peak_pos]
    # The small constant keeps the division finite for an all-zero envelope.
    return float(tail_level / (peak_level + 1e-8))
def decay_linearity(y: np.ndarray, sr: int) -> tuple[float, bool]:
    """C2: R² of a straight-line fit to the log of the post-peak RMS envelope.

    A high R² means the envelope is close to a clean exponential decay.

    Returns ``(r_squared, is_decaying)`` where ``is_decaying`` is True when
    the fitted slope is negative. Returns ``(0.0, False)`` when the envelope
    has fewer than 10 frames or fewer than 5 frames from the peak onwards.
    ``sr`` is accepted but not used by this metric.
    """
    envelope = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0]
    not_evaluable = (0.0, False)
    if len(envelope) < 10:
        return not_evaluable

    after_peak = envelope[np.argmax(envelope):]
    if len(after_peak) < 5:
        return not_evaluable

    frame_idx = np.arange(len(after_peak))
    # The offset keeps log() finite for silent frames.
    log_envelope = np.log(after_peak + 1e-8)
    fit = scipy.stats.linregress(frame_idx, log_envelope)
    return float(fit.rvalue ** 2), bool(fit.slope < 0)
def temporal_centroid_ms(y: np.ndarray, sr: int) -> float:
    """C3: temporal centroid of the signal energy, in milliseconds.

    This is the energy-weighted mean frame time, using the squared RMS
    envelope (512-sample frames, 128-sample hop) as the weight. A centroid
    that is too early suggests truncation; one that is too late suggests the
    sample is dominated by bleed.
    """
    envelope = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0]
    power = envelope ** 2
    # Frame start times in seconds (hop of 128 samples).
    frame_times = np.arange(len(envelope)) * 128 / sr
    centroid_s = np.sum(frame_times * power) / (np.sum(power) + 1e-8)
    return float(centroid_s * 1000)
# Expected temporal centroid window per drum type, as (low_ms, high_ms).
TC_RANGES = dict(
    kick=(15, 100),
    snare=(8, 60),
    hihat=(3, 30),
    hihat_closed=(3, 20),
    hihat_open=(5, 50),
    tom=(10, 80),
    cymbal=(10, 100),
    perc_high=(3, 40),
    perc_low=(10, 80),
)
def compute_completeness(y: np.ndarray, sr: int, drum_type: str = 'kick') -> float:
    """Composite completeness score; higher means a more complete sample.

    Weighted blend of three sub-scores:
      * C1 (50%): tail/peak ratio, mapped so 0.0 scores 1.0 and >= 0.2 scores 0.0.
      * C2 (30%): decay-fit R², scaled by 0.3 when the envelope is not decaying.
      * C3 (20%): 1.0 when the temporal centroid lies in the expected window
        for ``drum_type`` (``TC_RANGES``; unknown types use 5-150 ms),
        otherwise a ratio-based penalty floored at 0.2.
    """
    # C1: tail/peak ratio.
    tail_ratio = tail_peak_ratio(y, sr)
    tail_score = max(0.0, 1.0 - tail_ratio * 5)

    # C2: decay linearity, penalized when the fit is not decaying.
    r_squared, is_decaying = decay_linearity(y, sr)
    decay_score = r_squared if is_decaying else r_squared * 0.3

    # C3: temporal centroid against the expected window.
    centroid = temporal_centroid_ms(y, sr)
    low, high = TC_RANGES.get(drum_type, (5, 150))
    if low <= centroid <= high:
        centroid_score = 1.0
    else:
        # Too early: potentially truncated. Too late: bleed extending the sound.
        closeness = centroid / low if centroid < low else high / centroid
        centroid_score = max(0.2, closeness)

    return float(tail_score * 0.50 + decay_score * 0.30 + centroid_score * 0.20)
# ─────────────────────────────────────────────────────────────────────────────
# Cleanness metrics: absence of bleed and artifacts
# ─────────────────────────────────────────────────────────────────────────────
# Per drum type: (signal_band_hz, bleed_band_hz, threshold_dB).
# Each band is a (low_hz, high_hz) pair.
SPECTRAL_BANDS = dict(
    kick=((30, 300), (3000, 20000), 20),
    snare=((100, 8000), (8000, 20000), 10),
    hihat=((3000, 20000), (20, 200), 20),
    hihat_closed=((3000, 20000), (20, 200), 20),
    hihat_open=((2000, 20000), (20, 200), 18),
    tom=((50, 2000), (4000, 20000), 15),
    cymbal=((2000, 20000), (20, 300), 18),
    perc_high=((2000, 20000), (20, 500), 15),
    perc_low=((30, 2000), (4000, 20000), 15),
)
def pre_onset_energy_db(y: np.ndarray, sr: int) -> float:
    """N1: mean power just before the first onset relative to just after it (dB).

    Compares up to 20 ms before the first detected (backtracked) onset with
    up to 100 ms starting at that onset. Very negative values mean a clean
    start; values near 0 indicate pre-noise or bleed.

    Returns -20.0 when no onset is detected or when either window holds
    fewer than 10 samples.
    """
    onset_samples = librosa.onset.onset_detect(
        y=y, sr=sr, units='samples', backtrack=True
    )
    if len(onset_samples) == 0:
        # Assume a decent start if no onset was found.
        return -20.0

    first_onset = int(onset_samples[0])
    pre_start = max(0, first_onset - int(sr * 0.02))  # 20 ms before the onset
    sig_end = first_onset + int(sr * 0.1)             # 100 ms of signal
    before = y[pre_start:first_onset]
    after = y[first_onset:sig_end]
    if min(len(before), len(after)) < 10:
        return -20.0

    pre_power = np.mean(before ** 2) + 1e-12
    sig_power = np.mean(after ** 2) + 1e-12
    return float(10 * np.log10(pre_power / sig_power))
def spectral_signal_to_bleed(y: np.ndarray, sr: int,
                             drum_type: str = 'kick') -> float:
    """N2: mean signal-band magnitude vs mean bleed-band magnitude (dB).

    Higher values mean a cleaner sample. Bands come from ``SPECTRAL_BANDS``
    for ``drum_type``; unknown types fall back to a 50-8000 Hz signal band
    and an 8000-20000 Hz bleed band.

    A band that selects no STFT bins (for example a bleed band that lies
    entirely above the Nyquist frequency at a low sample rate) is treated as
    having zero magnitude. Without this guard the mean of an empty selection
    is NaN, which would propagate into the cleanness and total scores.
    """
    magnitude = np.abs(librosa.stft(y, n_fft=2048))
    freqs = librosa.fft_frequencies(sr=sr, n_fft=2048)
    signal_band, bleed_band, _ = SPECTRAL_BANDS.get(
        drum_type, ((50, 8000), (8000, 20000), 15)
    )

    def _band_mean(band) -> float:
        """Mean magnitude over bins with band[0] <= f < band[1]; 0.0 if none."""
        selected = magnitude[(freqs >= band[0]) & (freqs < band[1])]
        if selected.size == 0:
            return 0.0
        return float(selected.mean())

    # The offsets keep the ratio and the log finite for silent bands.
    signal_level = _band_mean(signal_band) + 1e-8
    bleed_level = _band_mean(bleed_band) + 1e-8
    return float(20 * np.log10(signal_level / bleed_level))
def tail_zcr(y: np.ndarray, sr: int) -> float:
    """N3: mean zero-crossing rate of the tail of the sample.

    The tail is everything from 100 ms after the start of ``y``. A high rate
    in the tail likely indicates cymbal or hi-hat bleed.

    Returns 0.0 when fewer than about 100 samples would remain in the tail.
    """
    offset = int(sr * 0.1)  # 100 ms into the clip
    if len(y) - 100 <= offset:
        return 0.0
    rates = librosa.feature.zero_crossing_rate(y=y[offset:])
    return float(np.mean(rates))
def robust_snr_db(y: np.ndarray) -> float:
    """N4: percentile-based SNR estimate in dB.

    Uses the 99th percentile of the squared samples as the peak level and the
    10th percentile as the noise level, which makes the estimate robust to
    single-sample spikes.
    """
    power = y ** 2
    noise_level, peak_level = np.percentile(power, [10, 99])
    # The offsets keep the ratio finite when either percentile is exactly zero.
    ratio = (peak_level + 1e-12) / (noise_level + 1e-12)
    return float(10 * np.log10(ratio))
def compute_cleanness(y: np.ndarray, sr: int, drum_type: str = 'kick') -> float:
    """Composite cleanness score; higher means cleaner.

    Weighted blend of four sub-scores, each clipped to [0, 1]:
      * N1 (30%): pre-onset energy; -35 dB -> 1.0, -5 dB -> 0.0.
      * N2 (35%): spectral signal-to-bleed ratio above the per-type threshold,
        reaching 1.0 at threshold + 30 dB.
      * N3 (15%): tail zero-crossing rate, evaluated only for 'kick', 'tom'
        and 'perc_low'; every other type gets a neutral 0.7.
      * N4 (20%): robust SNR; 10 dB -> 0.0, 50 dB -> 1.0.
    """
    # N1: pre-onset energy.
    onset_db = pre_onset_energy_db(y, sr)
    n1 = np.clip((-onset_db - 5) / 30, 0, 1)

    # N2: spectral signal-to-bleed ratio relative to the per-type threshold.
    threshold_db = SPECTRAL_BANDS.get(
        drum_type, ((50, 8000), (8000, 20000), 15)
    )[2]
    band_ratio_db = spectral_signal_to_bleed(y, sr, drum_type)
    n2 = np.clip((band_ratio_db - threshold_db) / 30, 0, 1)

    # N3: tail ZCR, relevant mainly where cymbal bleed is obvious.
    n3 = 0.7  # neutral for the remaining drum types
    if drum_type in ('kick', 'tom', 'perc_low'):
        n3 = np.clip(1.0 - tail_zcr(y, sr) * 10, 0, 1)

    # N4: robust SNR.
    n4 = np.clip((robust_snr_db(y) - 10) / 40, 0, 1)

    weighted = n1 * 0.30 + n2 * 0.35 + n3 * 0.15 + n4 * 0.20
    return float(weighted)
# ─────────────────────────────────────────────────────────────────────────────
# Onset quality
# ─────────────────────────────────────────────────────────────────────────────
def onset_sharpness(y: np.ndarray, sr: int) -> float:
    """Onset transient sharpness: peak onset strength divided by its mean.

    High values indicate a punchy attack; low values a mushy or missed
    transient. Returns 1.0 when the onset envelope has fewer than two frames.
    """
    envelope = librosa.onset.onset_strength(y=y, sr=sr)
    if len(envelope) < 2:
        return 1.0
    strongest = np.max(envelope)
    average = np.mean(envelope)
    return float(strongest / (average + 1e-8))
def compute_onset_quality(y: np.ndarray, sr: int) -> float:
    """Onset quality score in [0, 1], derived from onset sharpness.

    A sharpness of 1 (peak equals mean) scores 0.0; the score rises linearly
    and saturates at 1.0 once the sharpness reaches 6.
    """
    peak_to_mean = onset_sharpness(y, sr)
    scaled = (peak_to_mean - 1.0) / 5.0
    return float(np.clip(scaled, 0, 1))
# ─────────────────────────────────────────────────────────────────────────────
# Combined score
# ─────────────────────────────────────────────────────────────────────────────
def drum_sample_score(y: np.ndarray, sr: int, drum_type: str = 'kick',
                      centroid_dist: float = 0.0,
                      cluster_radius: float = 1.0) -> dict:
    """Score a drum sample and report the individual components.

    Weights: cleanness 40%, completeness 30%, onset quality 20%,
    representativeness 10%. Representativeness is
    ``1 / (1 + centroid_dist / cluster_radius)``, so a sample at its cluster
    centroid scores 1.0.

    Returns a dict with 'total' (the weighted sum scaled by 100), the four
    sub-scores, and a 'components' dict holding the raw underlying metrics.
    """
    completeness = compute_completeness(y, sr, drum_type)
    cleanness = compute_cleanness(y, sr, drum_type)
    onset_quality = compute_onset_quality(y, sr)
    # The small constant guards against a zero cluster radius.
    representativeness = 1.0 / (1.0 + centroid_dist / (cluster_radius + 1e-8))

    weighted_sum = (
        completeness * 0.30
        + cleanness * 0.40
        + onset_quality * 0.20
        + representativeness * 0.10
    )

    report = {
        'total': float(weighted_sum * 100),
        'completeness': float(completeness),
        'cleanness': float(cleanness),
        'onset_quality': float(onset_quality),
        'representativeness': float(representativeness),
    }
    # Raw metrics, recomputed here so callers can inspect them directly.
    report['components'] = {
        'tail_peak_ratio': tail_peak_ratio(y, sr),
        'temporal_centroid_ms': temporal_centroid_ms(y, sr),
        'pre_onset_db': pre_onset_energy_db(y, sr),
        'spectral_sbl_db': spectral_signal_to_bleed(y, sr, drum_type),
        'robust_snr_db': robust_snr_db(y),
        'onset_sharpness': onset_sharpness(y, sr),
    }
    return report
# ─────────────────────────────────────────────────────────────────────────────
# Reference-based metrics (for evaluation against ground truth)
# ─────────────────────────────────────────────────────────────────────────────
def compute_si_sdr(ref: np.ndarray, est: np.ndarray) -> float:
    """Scale-Invariant SDR in dB (Le Roux et al. 2019). Primary quality metric.

    Both signals are truncated to their common length, as the other
    reference-based metrics in this module do, so a direct call with signals
    of unequal length no longer fails inside ``np.dot``. Each signal is then
    mean-centred, the estimate is projected onto the reference, and the
    energy of that projection is compared with the energy of the residual.

    Args:
        ref: 1-D ground-truth signal.
        est: 1-D estimated signal.

    Returns:
        SI-SDR in dB; higher is better. Returns 0.0 when either input is
        empty.
    """
    ref = np.asarray(ref)
    est = np.asarray(est)
    n = min(len(ref), len(est))
    if n == 0:
        # Nothing to compare; avoids taking the mean of an empty slice.
        return 0.0

    # Truncate first so the means are taken over the compared samples only.
    ref = ref[:n] - ref[:n].mean()
    est = est[:n] - est[:n].mean()

    eps = 1e-8
    # Optimal scaling of the reference towards the estimate.
    alpha = np.dot(ref, est) / (np.dot(ref, ref) + eps)
    e_target = alpha * ref
    e_residual = est - e_target
    target_energy = np.dot(e_target, e_target) + eps
    residual_energy = np.dot(e_residual, e_residual) + eps
    return float(10 * np.log10(target_energy / residual_energy))
def compute_spectral_convergence(ref: np.ndarray, est: np.ndarray) -> float:
    """Spectral convergence between two signals; lower is a better match.

    Frobenius norm of the difference between the STFT magnitudes, divided by
    the Frobenius norm of the reference magnitude. The value is non-negative
    and 0 for identical spectra. Inputs are truncated to their common length.
    """
    common = min(len(ref), len(est))
    mag_ref = np.abs(librosa.stft(ref[:common])) + 1e-8
    mag_est = np.abs(librosa.stft(est[:common])) + 1e-8
    error_norm = np.linalg.norm(mag_ref - mag_est, 'fro')
    ref_norm = np.linalg.norm(mag_ref, 'fro')
    return float(error_norm / (ref_norm + 1e-8))
def compute_log_spectral_distance(ref: np.ndarray, est: np.ndarray) -> float:
    """Log spectral distance in dB; lower is better.

    For each STFT frame, takes the root-mean-square over frequency of the
    difference between the two dB magnitude spectra, then averages over
    frames. Inputs are truncated to their common length.
    """
    common = min(len(ref), len(est))
    mag_ref = np.abs(librosa.stft(ref[:common])) + 1e-8
    mag_est = np.abs(librosa.stft(est[:common])) + 1e-8
    db_ref = 20 * np.log10(mag_ref)
    db_est = 20 * np.log10(mag_est)
    # axis=0 is the first STFT axis, reduced per frame.
    per_frame_rms = np.sqrt(np.mean((db_ref - db_est) ** 2, axis=0))
    return float(np.mean(per_frame_rms))
def compute_mfcc_distance(ref: np.ndarray, est: np.ndarray, sr: int) -> float:
    """Euclidean distance between time-averaged MFCC vectors.

    Lower means more similar timbre. Despite what earlier documentation
    said, this is NOT a cosine distance: it is the L2 norm of the difference
    between the two mean 13-coefficient MFCC vectors, so it is unbounded and
    depends on overall level.

    Args:
        ref: ground-truth signal.
        est: estimated signal; both are truncated to their common length.
        sr: sample rate passed to the MFCC computation.

    Returns:
        Non-negative distance; 0.0 for identical mean MFCC vectors.
    """
    common = min(len(ref), len(est))
    # Average each coefficient over time (axis=1) to get one vector per signal.
    mean_ref = librosa.feature.mfcc(y=ref[:common], sr=sr, n_mfcc=13).mean(axis=1)
    mean_est = librosa.feature.mfcc(y=est[:common], sr=sr, n_mfcc=13).mean(axis=1)
    return float(np.linalg.norm(mean_ref - mean_est))
def compute_envelope_correlation(ref: np.ndarray, est: np.ndarray,
                                 hop: int = 512) -> float:
    """Amplitude envelope correlation. Higher = better attack/decay shape match.

    Each signal is cut into consecutive, non-overlapping frames of ``hop``
    samples; the envelope is the peak absolute amplitude of each frame, and
    the result is the Pearson correlation between the two envelopes.

    Every complete frame is used. The previous frame range stopped one frame
    early whenever the common length was a multiple of ``hop``, so two full
    frames were reported as "too short". A trailing partial frame is still
    ignored.

    Args:
        ref: 1-D ground-truth signal.
        est: 1-D estimated signal; both are truncated to their common length.
        hop: frame length in samples; must be positive.

    Returns:
        Correlation in [-1, 1], or 0.0 when fewer than two complete frames
        are available or either envelope is (numerically) constant.

    Raises:
        ValueError: if ``hop`` is not positive.
    """
    if hop <= 0:
        raise ValueError("hop must be a positive integer")

    n = min(len(ref), len(est))
    n_frames = n // hop
    if n_frames < 2:
        return 0.0

    usable = n_frames * hop

    def _frame_peaks(signal) -> np.ndarray:
        """Peak absolute amplitude of each complete frame."""
        return np.abs(np.asarray(signal)[:usable]).reshape(n_frames, hop).max(axis=1)

    env_ref = _frame_peaks(ref)
    env_est = _frame_peaks(est)
    # A constant envelope has no defined correlation; report no match.
    if env_ref.std() < 1e-8 or env_est.std() < 1e-8:
        return 0.0
    return float(np.corrcoef(env_ref, env_est)[0, 1])
def compute_all_reference_metrics(ref: np.ndarray, est: np.ndarray,
                                  sr: int) -> dict:
    """Run every reference-based metric on a ground-truth/extracted pair.

    Both signals are truncated to their common length before any metric is
    computed. Returns a dict keyed by human-readable metric name.
    """
    common = min(len(ref), len(est))
    truth, extracted = ref[:common], est[:common]

    results = {}
    results['SI-SDR (dB)'] = compute_si_sdr(truth, extracted)
    results['Spectral Convergence'] = compute_spectral_convergence(truth, extracted)
    results['Log Spectral Distance (dB)'] = compute_log_spectral_distance(truth, extracted)
    results['MFCC Distance'] = compute_mfcc_distance(truth, extracted, sr)
    results['Envelope Correlation'] = compute_envelope_correlation(truth, extracted)
    return results
|