Spaces:
Sleeping
Sleeping
v3: sample_extractor.py
Browse files- sample_extractor.py +131 -0
sample_extractor.py
CHANGED
|
@@ -475,6 +475,137 @@ def build_sample_map(clusters: list) -> dict:
|
|
| 475 |
}
|
| 476 |
|
| 477 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 478 |
# ─── Main pipeline ───────────────────────────────────────────────────────────
|
| 479 |
|
| 480 |
def run_pipeline(
|
|
|
|
| 475 |
}
|
| 476 |
|
| 477 |
|
| 478 |
+
|
| 479 |
+
# ─── BPM Detection ───────────────────────────────────────────────────────────
|
| 480 |
+
|
| 481 |
+
def _fold_bpm(bpm: float, low: float = 70.0, high: float = 200.0) -> float:
    """Double or halve *bpm* until it lies inside ``[low, high]``.

    Non-positive, NaN and infinite values are returned unchanged because no
    amount of doubling/halving can bring them into range.  The loops
    terminate for the default bounds since ``high >= 2 * low``.
    """
    if not 0 < bpm < float('inf'):
        return bpm
    while bpm < low:
        bpm *= 2
    while bpm > high:
        bpm /= 2
    return bpm


def detect_bpm(y: np.ndarray, sr: int) -> float:
    """Detect BPM from audio and resolve the halving/doubling ambiguity.

    Two estimates are computed from the same onset-strength envelope
    (median-aggregated): the global tempo from ``librosa.feature.tempo`` and
    the tempo implied by the median inter-beat interval of
    ``librosa.beat.beat_track``.  The first of the two that falls inside
    [70, 200] BPM wins; if neither does, the primary estimate is doubled or
    halved repeatedly until it lands in that range.

    Args:
        y: Audio signal passed straight to librosa (presumably mono, since
            the tempo array is read with ``.item()`` -- TODO confirm).
        sr: Sample rate of ``y`` in Hz.

    Returns:
        The tempo in beats per minute, rounded to one decimal place.
    """
    onset_env = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median)

    # Primary estimate
    tempo_arr = librosa.feature.tempo(onset_envelope=onset_env, sr=sr)
    bpm = float(tempo_arr.item())

    # Cross-check with beat_track inter-beat interval
    _, beats = librosa.beat.beat_track(onset_envelope=onset_env, sr=sr, units='time')
    if len(beats) > 2:
        median_ibi = float(np.median(np.diff(beats)))
        # Guard the division: a degenerate beat grid must not crash detection.
        if median_ibi > 0:
            ibi_bpm = 60.0 / median_ibi
            # If the two estimates diverge by ~2x, prefer the one in [70, 200]
            for candidate in (bpm, ibi_bpm):
                if 70 <= candidate <= 200:
                    bpm = candidate
                    break

    # Force into reasonable range.  This is a no-op when an in-range candidate
    # was selected above, and otherwise keeps doubling/halving (a single step
    # is not enough for e.g. 30 BPM or 450 BPM estimates).
    return round(_fold_bpm(bpm), 1)
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
# ─── MIDI Rendering with extracted samples ───────────────────────────────────
|
| 508 |
+
|
| 509 |
+
def render_midi_with_samples(clusters: list, sr: int = 44100) -> np.ndarray:
    """Render the MIDI reconstruction back to audio using extracted samples.

    Each cluster's ``best_hit`` audio is mixed in at the onset of every hit in
    that cluster, scaled by the hit's energy relative to the best hit.

    Args:
        clusters: Objects exposing ``hits`` (each with ``onset_time``,
            ``duration`` and ``rms_energy``) and ``best_hit`` (with ``audio``
            and ``rms_energy``).  ``audio`` is assumed to be a 1-D array --
            TODO confirm against the cluster definition.
        sr: Sample rate used to convert onset times (seconds) to indices.

    Returns:
        A float32 array, peak-normalized to 0.9 unless the mix is silent.
        Its length is the latest ``onset_time + duration`` plus a one-second
        tail, extended if any placed sample runs past that point.
    """
    # Determine total length
    max_end = 0.0
    for c in clusters:
        for h in c.hits:
            max_end = max(max_end, h.onset_time + h.duration)
    total_samples = int((max_end + 1.0) * sr)  # +1s tail

    # Build note -> sample placements first so the mix buffer can be allocated
    # once at its final size instead of being re-concatenated inside the loop.
    placements = []
    for c in clusters:
        sample = c.best_hit.audio.astype(np.float64)
        # Compute reference energy for velocity scaling
        ref_energy = c.best_hit.rms_energy if c.best_hit.rms_energy > 0 else 0.1

        for h in c.hits:
            # Velocity: scale by hit energy relative to best hit.  The ratio is
            # clamped to [0, 2]; the lower bound keeps the square root real.
            ratio = h.rms_energy / (ref_energy + 1e-8)
            vel_scale = max(0.0, min(2.0, ratio)) ** 0.5  # perceptual square-root scaling

            # A negative onset would become a negative slice start and either
            # wrap around or fail to broadcast, so pin it to the buffer start.
            start_idx = max(0, int(h.onset_time * sr))
            total_samples = max(total_samples, start_idx + len(sample))
            placements.append((start_idx, sample, vel_scale))

    buf = np.zeros(total_samples, dtype=np.float64)
    for start_idx, sample, vel_scale in placements:
        buf[start_idx:start_idx + len(sample)] += sample * vel_scale

    # Normalize (skip for an empty or silent buffer)
    pk = float(np.abs(buf).max()) if buf.size else 0.0
    if pk > 1e-8:
        buf = buf / pk * 0.9
    return buf.astype(np.float32)
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
# ─── ZIP Archive Export ──────────────────────────────────────────────────────
|
| 545 |
+
|
| 546 |
+
def _to_builtin(value):
    """Return NumPy scalars as plain Python numbers; pass anything else through."""
    return value.item() if isinstance(value, np.generic) else value


def _wav_bytes(audio, sr: int, subtype: str) -> bytes:
    """Encode *audio* as an in-memory WAV file and return its raw bytes."""
    import io

    wav_buf = io.BytesIO()
    sf.write(wav_buf, audio, sr, format='WAV', subtype=subtype)
    return wav_buf.getvalue()


def build_archive(clusters: list, bpm: float, sr: int,
                  midi_path: str = None, rendered_audio: np.ndarray = None) -> str:
    """Build a ZIP archive containing all samples, index, MIDI, and rendered audio.

    Archive layout:
        samples/<label>.wav               best hit of each cluster (24-bit PCM)
        samples/<label>__synthesized.wav  only when ``c.synthesized`` is set
        index.json                        tempo, sample rate and per-sample metadata
        reconstruction.mid                only when ``midi_path`` exists on disk
        rendered_reconstruction.wav       only when ``rendered_audio`` is given (16-bit PCM)

    Entries are stored uncompressed (``ZIP_STORED``).

    Args:
        clusters: Objects exposing ``label``, ``count``, ``midi_note``,
            ``hits``, ``best_hit`` and ``synthesized``.
        bpm: Tempo written to the index, rounded to one decimal place.
        sr: Sample rate used for every WAV file and recorded in the index.
        midi_path: Optional path of a MIDI file to include.
        rendered_audio: Optional audio to include as the rendered mix.

    Returns:
        Path to the ZIP file.  The caller owns the file and is responsible for
        deleting it.

    Raises:
        Whatever the underlying write raises; the partially written archive is
        removed before the exception propagates.
    """
    import tempfile
    import zipfile

    # mkstemp creates the file atomically; the deprecated mktemp only returns a
    # name, which another process could claim before the archive is opened.
    fd, zip_path = tempfile.mkstemp(suffix='.zip')
    os.close(fd)

    # Numeric fields are converted to built-in types because json cannot
    # encode NumPy scalars such as np.float32 or np.int64.
    index = {
        'bpm': round(float(bpm), 1),
        'sample_rate': _to_builtin(sr),
        'total_clusters': len(clusters),
        'total_hits': _to_builtin(sum(c.count for c in clusters)),
        'samples': {},
    }

    try:
        with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_STORED) as zf:
            for c in clusters:
                best = c.best_hit
                fname = f"samples/{c.label}.wav"
                zf.writestr(fname, _wav_bytes(best.audio, sr, 'PCM_24'))

                # Collect all onset times for this cluster
                onset_times = sorted(h.onset_time for h in c.hits)

                entry = {
                    'file': fname,
                    # Label minus its last "_"-separated part.
                    'classification': c.label.rsplit('_', 1)[0],
                    'midi_note': _to_builtin(c.midi_note),
                    'occurrences': _to_builtin(c.count),
                    'onset_times_sec': [round(float(t), 4) for t in onset_times],
                    'duration_sec': round(float(best.duration), 4),
                    'rms_energy': round(float(best.rms_energy), 6),
                    'spectral_centroid_hz': round(float(best.spectral_centroid), 1),
                }

                # Also include synthesized version if available
                if c.synthesized is not None:
                    syn_fname = f"samples/{c.label}__synthesized.wav"
                    zf.writestr(syn_fname, _wav_bytes(c.synthesized, sr, 'PCM_24'))
                    entry['synthesized_file'] = syn_fname

                index['samples'][c.label] = entry

            # Add index
            zf.writestr('index.json', json.dumps(index, indent=2))

            # Add MIDI
            if midi_path and os.path.exists(midi_path):
                zf.write(midi_path, 'reconstruction.mid')

            # Add rendered audio
            if rendered_audio is not None:
                zf.writestr('rendered_reconstruction.wav',
                            _wav_bytes(rendered_audio, sr, 'PCM_16'))
    except Exception:
        # Do not leave a half-written archive behind in the temp directory.
        try:
            os.remove(zip_path)
        except OSError:
            pass
        raise

    return zip_path
|
| 608 |
+
|
| 609 |
# ─── Main pipeline ───────────────────────────────────────────────────────────
|
| 610 |
|
| 611 |
def run_pipeline(
|