rikhoffbauer2 committed on
Commit
0a6fa16
·
verified ·
1 Parent(s): 33c4b3e

v3: sample_extractor.py

Browse files
Files changed (1) hide show
  1. sample_extractor.py +131 -0
sample_extractor.py CHANGED
@@ -475,6 +475,137 @@ def build_sample_map(clusters: list) -> dict:
475
  }
476
 
477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478
  # ─── Main pipeline ───────────────────────────────────────────────────────────
479
 
480
  def run_pipeline(
 
475
  }
476
 
477
 
478
+
479
+ # ─── BPM Detection ───────────────────────────────────────────────────────────
480
+
481
def detect_bpm(y: np.ndarray, sr: int) -> float:
    """Estimate the tempo of an audio signal in beats per minute.

    Two estimates are derived from the same onset-strength envelope: the
    global tempo reported by ``librosa.feature.tempo`` and the tempo implied
    by the median inter-beat interval of ``librosa.beat.beat_track``.  The
    first estimate that lies inside the 70-200 BPM window wins.  When neither
    does, the primary estimate is doubled or halved repeatedly until it lands
    inside the window, which resolves the usual half/double-time ambiguity.

    Args:
        y: Audio samples, passed straight through to librosa.
        sr: Sample rate of ``y`` in Hz.

    Returns:
        The selected tempo, rounded to one decimal place.
    """
    bpm_min, bpm_max = 70.0, 200.0

    onset_env = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median)

    # Primary estimate: global tempo from the onset envelope.
    tempo_arr = librosa.feature.tempo(onset_envelope=onset_env, sr=sr)
    bpm = float(tempo_arr.item())
    candidates = [bpm]

    # Secondary estimate: median inter-beat interval from the beat tracker.
    # At least three beats are required so the median covers two intervals.
    _, beats = librosa.beat.beat_track(onset_envelope=onset_env, sr=sr, units='time')
    if len(beats) > 2:
        median_ibi = float(np.median(np.diff(beats)))
        if median_ibi > 0:  # avoid dividing by zero on degenerate beat grids
            candidates.append(60.0 / median_ibi)

    # Prefer whichever estimate already sits in the plausible window,
    # giving priority to the primary one.
    for candidate in candidates:
        if bpm_min <= candidate <= bpm_max:
            return round(candidate, 1)

    # Neither estimate is plausible: fold the primary one by octaves.  The
    # window spans more than a factor of two, so any positive finite value
    # converges; non-positive or non-finite values are returned untouched
    # rather than looping forever.
    if np.isfinite(bpm) and bpm > 0:
        while bpm < bpm_min:
            bpm *= 2
        while bpm > bpm_max:
            bpm /= 2

    return round(bpm, 1)
505
+
506
+
507
+ # ─── MIDI Rendering with extracted samples ───────────────────────────────────
508
+
509
def render_midi_with_samples(clusters: list, sr: int = 44100) -> np.ndarray:
    """Render the reconstruction to audio using each cluster's best sample.

    For every hit in every cluster, the cluster's ``best_hit.audio`` is mixed
    into the output at the hit's onset, scaled by a gain derived from the
    hit's RMS energy relative to the best hit's RMS energy.

    Args:
        clusters: Objects exposing ``hits`` and ``best_hit``.  Each hit must
            expose ``onset_time`` and ``duration`` (seconds) and
            ``rms_energy``; ``best_hit`` must also expose ``audio``.
        sr: Sample rate in Hz used to convert onset times to sample indices.

    Returns:
        A one-dimensional ``float32`` array, peak-normalised to 0.9 unless
        the mix is silent.  Its length is the latest ``onset + duration``
        plus a one-second tail, extended if any placed sample runs longer.
    """
    # Baseline length: latest hit end plus a one-second tail.
    max_end = 0.0
    for c in clusters:
        for h in c.hits:
            max_end = max(max_end, h.onset_time + h.duration)
    total_samples = int((max_end + 1.0) * sr)

    # First pass: resolve where each sample goes and how loud it is, while
    # tracking the final buffer length.  Sizing the buffer once avoids
    # re-allocating it with np.concatenate for every hit that overruns.
    placements = []
    for c in clusters:
        if len(c.hits) == 0:
            continue  # nothing to place; do not touch best_hit at all
        sample = np.asarray(c.best_hit.audio, dtype=np.float64)
        # Reference energy for velocity scaling; fall back to 0.1 when the
        # best hit carries no usable energy value.
        ref_energy = c.best_hit.rms_energy if c.best_hit.rms_energy > 0 else 0.1

        for h in c.hits:
            # Clamp negative energies to zero: a negative base raised to 0.5
            # yields a complex number (or NaN) and would corrupt the mix.
            energy_ratio = max(0.0, h.rms_energy) / (ref_energy + 1e-8)
            gain = min(2.0, energy_ratio) ** 0.5  # perceptual square-root scaling

            # Clamp negative onsets to the buffer start; a negative start
            # index would otherwise select an empty or wrapped slice.
            start_idx = max(0, int(h.onset_time * sr))
            placements.append((start_idx, gain, sample))
            total_samples = max(total_samples, start_idx + len(sample))

    # Second pass: mix everything into the pre-sized buffer.
    buf = np.zeros(total_samples, dtype=np.float64)
    for start_idx, gain, sample in placements:
        buf[start_idx:start_idx + len(sample)] += sample * gain

    # Peak-normalise, leaving silent or empty buffers untouched.
    if buf.size:
        pk = np.abs(buf).max()
        if pk > 1e-8:
            buf = buf / pk * 0.9
    return buf.astype(np.float32)
542
+
543
+
544
+ # ─── ZIP Archive Export ──────────────────────────────────────────────────────
545
+
546
def build_archive(clusters: list, bpm: float, sr: int,
                  midi_path: str = None, rendered_audio: np.ndarray = None) -> str:
    """Build a ZIP archive with all samples, an index, MIDI and rendered audio.

    Archive layout:
        samples/<label>.wav               best hit of each cluster (24-bit PCM)
        samples/<label>__synthesized.wav  only when ``c.synthesized`` is set
        index.json                        metadata for every sample
        reconstruction.mid                only when ``midi_path`` exists
        rendered_reconstruction.wav       only when ``rendered_audio`` is given

    Args:
        clusters: Objects exposing ``label``, ``best_hit``, ``hits``,
            ``count``, ``midi_note`` and ``synthesized``.
        bpm: Detected tempo, stored in the index rounded to one decimal.
        sr: Sample rate in Hz used for every WAV written.
        midi_path: Optional path to a MIDI file to include; ignored when it
            is ``None`` or does not exist.
        rendered_audio: Optional audio array to include as a 16-bit WAV.

    Returns:
        Path to the ZIP file, created in the system temp directory.  The
        caller owns the file and is responsible for deleting it.
    """
    import io
    import tempfile
    import zipfile

    def _json_default(value):
        """Convert NumPy scalars/arrays that ``json`` cannot encode itself."""
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        raise TypeError(
            f"Object of type {type(value).__name__} is not JSON serializable")

    def _wav_bytes(audio, subtype):
        """Encode ``audio`` as an in-memory WAV file and return its bytes."""
        wav_buf = io.BytesIO()
        sf.write(wav_buf, audio, sr, format='WAV', subtype=subtype)
        return wav_buf.getvalue()

    index = {
        'bpm': round(bpm, 1),
        'sample_rate': sr,
        'total_clusters': len(clusters),
        'total_hits': sum(c.count for c in clusters),
        'samples': {},
    }

    # mkstemp creates the file atomically with a unique name; the deprecated
    # mktemp only returns a name, which another process could claim first.
    fd, zip_path = tempfile.mkstemp(suffix='.zip')
    os.close(fd)

    try:
        with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_STORED) as zf:
            for c in clusters:
                best = c.best_hit
                fname = f"samples/{c.label}.wav"
                zf.writestr(fname, _wav_bytes(best.audio, 'PCM_24'))

                # All onset times for this cluster, in chronological order.
                onset_times = sorted(h.onset_time for h in c.hits)

                index['samples'][c.label] = {
                    'file': fname,
                    # Label with its final "_<suffix>" part removed.
                    'classification': c.label.rsplit('_', 1)[0],
                    'midi_note': c.midi_note,
                    'occurrences': c.count,
                    'onset_times_sec': [round(t, 4) for t in onset_times],
                    'duration_sec': round(best.duration, 4),
                    'rms_energy': round(best.rms_energy, 6),
                    'spectral_centroid_hz': round(best.spectral_centroid, 1),
                }

                # Also include the synthesized version if available.
                if c.synthesized is not None:
                    syn_fname = f"samples/{c.label}__synthesized.wav"
                    zf.writestr(syn_fname, _wav_bytes(c.synthesized, 'PCM_24'))
                    index['samples'][c.label]['synthesized_file'] = syn_fname

            # The default hook keeps NumPy scalars (e.g. float32 energies or
            # int64 note numbers) from aborting serialization.
            zf.writestr('index.json',
                        json.dumps(index, indent=2, default=_json_default))

            if midi_path and os.path.exists(midi_path):
                zf.write(midi_path, 'reconstruction.mid')

            if rendered_audio is not None:
                zf.writestr('rendered_reconstruction.wav',
                            _wav_bytes(rendered_audio, 'PCM_16'))
    except Exception:
        # Do not leave a partial archive behind; cleanup is best-effort so
        # the original error is the one that propagates.
        try:
            os.remove(zip_path)
        except OSError:
            pass
        raise

    return zip_path
608
+
609
  # ─── Main pipeline ───────────────────────────────────────────────────────────
610
 
611
  def run_pipeline(