Ewan Claude Opus 4.6 committed on
Commit
f474670
·
1 Parent(s): 1f67352

Merge repeated same-pitch notes into sustains, increase drum detection sensitivity

Browse files

Piano: New merge_repeated_notes() step (7c) checks onset energy at re-attack
points of consecutive same-pitch notes. If re-attack energy is below 1.2x
the median onset strength, the notes are merged into one sustained note
instead of stuttering re-strikes. Preserves real repeated notes with genuine
attack energy.

Drums: Significantly increased detection sensitivity:
- Hi-hat band: delta 0.05->0.03, RMS threshold 0.003->0.001, wait 3->2
- Mid band: delta 0.06->0.04, RMS 0.005->0.003, wait 3->2
- Low band: delta 0.08->0.06
- Added full-band safety net pass (catches hits all sub-bands miss)
- Improved fallback classifier for full-band-only detections
- Merge window 25ms->30ms for better cross-band alignment

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

transcriber/drums.py CHANGED
@@ -1,9 +1,11 @@
1
  """Drum transcription via multi-band onset detection + spectral classification.
2
 
3
  Uses sub-band filtering to detect onsets independently in low/mid/high
4
- frequency ranges, then merges and classifies based on which bands triggered.
5
- This handles simultaneous hits (kick+hihat, kick+snare+hihat) naturally
6
- since each drum occupies different frequency bands.
 
 
7
 
8
  Input: isolated Demucs drums stem (already separated from other instruments).
9
  Output: JSON with lane-based drum events.
@@ -21,7 +23,7 @@ from scipy.signal import butter, filtfilt
21
  LANES = ["crash", "ride", "hihat", "tom_high", "snare", "tom_low", "kick"]
22
 
23
  # Merge tolerance: onsets within this window across bands are the same hit
24
- MERGE_WINDOW = 0.025 # 25ms
25
 
26
 
27
  def _bandpass(y, low, high, sr, order=4):
@@ -48,7 +50,7 @@ def _highpass(y, low, sr, order=4):
48
  def _detect_band_onsets(y_band, sr, hop_length, delta=0.06, wait=3, rms_threshold=0.005):
49
  """Detect onsets in a filtered frequency band.
50
 
51
- Returns list of (time, sample) tuples for onsets above the RMS threshold.
52
  """
53
  onset_env = librosa.onset.onset_strength(
54
  y=y_band, sr=sr, hop_length=hop_length,
@@ -102,7 +104,6 @@ def _classify_from_bands(low_hit, mid_hit, high_hit, y, sr, onset_sample):
102
  n_fft = min(4096, len(segment))
103
  if n_fft < 256:
104
  n_fft = 256
105
- # Zero-pad if segment is shorter than n_fft
106
  if len(segment) < n_fft:
107
  segment = np.pad(segment, (0, n_fft - len(segment)))
108
  fft = np.abs(np.fft.rfft(segment, n=n_fft))
@@ -173,13 +174,11 @@ def _classify_from_bands(low_hit, mid_hit, high_hit, y, sr, onset_sample):
173
 
174
  elif has_low and has_mid and has_high:
175
  # All three bands → snare (full broadband) or complex hit
176
- # Snare triggers all bands: low body, mid fundamental, high snare wires
177
  low_rms = low_hit[2] if low_hit else 0
178
  high_rms = high_hit[2] if high_hit else 0
179
 
180
  if mid_r > 0.15 and flatness > 0.03:
181
  results.append(("snare", velocity))
182
- # If low is much stronger than expected for snare, also a kick
183
  if low_rms > 0.08 and sub_low_r > 0.25:
184
  results.append(("kick", min(1.0, low_rms * 10)))
185
  else:
@@ -192,13 +191,22 @@ def _classify_from_bands(low_hit, mid_hit, high_hit, y, sr, onset_sample):
192
  if low_mid_r > 0.4:
193
  results.append(("tom_high", velocity))
194
  else:
195
- results.append(("snare", velocity * 0.6)) # ghost note
196
 
197
  else:
198
- # Fallback: use spectral features
199
- if low_r > 0.5:
 
 
 
 
 
 
 
 
 
200
  results.append(("kick", velocity))
201
- elif high_r > 0.3:
202
  results.append(("hihat", velocity))
203
  else:
204
  results.append(("snare", velocity))
@@ -210,7 +218,8 @@ def transcribe_drums(audio_path, output_path):
210
  """Transcribe a drums stem to a lane-based drum tab JSON.
211
 
212
  Uses multi-band onset detection: filters the signal into low/mid/high
213
- bands, detects onsets independently in each, then merges and classifies
 
214
  based on which bands triggered at each time point.
215
 
216
  Args:
@@ -235,17 +244,32 @@ def transcribe_drums(audio_path, output_path):
235
  y_high = _highpass(y, 3000, sr) # hi-hat, crash, ride
236
 
237
  # ── Step 2: Per-band onset detection ─────────────────────────────────
 
 
 
 
 
238
  print(" Drums: detecting per-band onsets...")
239
- low_onsets = _detect_band_onsets(y_low, sr, hop_length, delta=0.08, rms_threshold=0.008)
240
- mid_onsets = _detect_band_onsets(y_mid, sr, hop_length, delta=0.06, rms_threshold=0.005)
241
- high_onsets = _detect_band_onsets(y_high, sr, hop_length, delta=0.05, rms_threshold=0.003)
242
- print(f" Low: {len(low_onsets)}, Mid: {len(mid_onsets)}, High: {len(high_onsets)}")
 
 
 
 
 
 
 
 
 
 
 
243
 
244
  # ── Step 3: Merge onsets across bands ────────────────────────────────
245
- # Collect all onset times, then group within MERGE_WINDOW
246
  print(" Drums: merging cross-band onsets...")
247
  all_times = set()
248
- for onsets in [low_onsets, mid_onsets, high_onsets]:
249
  for t, s, r in onsets:
250
  all_times.add(t)
251
 
@@ -262,6 +286,8 @@ def transcribe_drums(audio_path, output_path):
262
  group = [t]
263
  merged_times.append(np.mean(group))
264
 
 
 
265
  # For each merged onset, find which bands triggered
266
  def find_band_hit(band_onsets, target_time):
267
  """Find the band onset closest to target_time within MERGE_WINDOW."""
@@ -292,8 +318,13 @@ def transcribe_drums(audio_path, output_path):
292
  if onset_sample is None or s < onset_sample:
293
  onset_sample = s
294
 
 
295
  if onset_sample is None:
296
- continue
 
 
 
 
297
 
298
  hits = _classify_from_bands(low_hit, mid_hit, high_hit, y, sr, onset_sample)
299
  for lane, vel in hits:
 
1
  """Drum transcription via multi-band onset detection + spectral classification.
2
 
3
  Uses sub-band filtering to detect onsets independently in low/mid/high
4
+ frequency ranges, plus a full-band pass as a safety net. Merges across
5
+ bands and classifies based on which bands triggered.
6
+
7
+ Handles simultaneous hits (kick+hihat, kick+snare+hihat) naturally since
8
+ each drum occupies different frequency bands.
9
 
10
  Input: isolated Demucs drums stem (already separated from other instruments).
11
  Output: JSON with lane-based drum events.
 
23
  LANES = ["crash", "ride", "hihat", "tom_high", "snare", "tom_low", "kick"]
24
 
25
  # Merge tolerance: onsets within this window across bands are the same hit
26
+ MERGE_WINDOW = 0.030 # 30ms
27
 
28
 
29
  def _bandpass(y, low, high, sr, order=4):
 
50
  def _detect_band_onsets(y_band, sr, hop_length, delta=0.06, wait=3, rms_threshold=0.005):
51
  """Detect onsets in a filtered frequency band.
52
 
53
+ Returns list of (time, sample, rms) tuples for onsets above the threshold.
54
  """
55
  onset_env = librosa.onset.onset_strength(
56
  y=y_band, sr=sr, hop_length=hop_length,
 
104
  n_fft = min(4096, len(segment))
105
  if n_fft < 256:
106
  n_fft = 256
 
107
  if len(segment) < n_fft:
108
  segment = np.pad(segment, (0, n_fft - len(segment)))
109
  fft = np.abs(np.fft.rfft(segment, n=n_fft))
 
174
 
175
  elif has_low and has_mid and has_high:
176
  # All three bands → snare (full broadband) or complex hit
 
177
  low_rms = low_hit[2] if low_hit else 0
178
  high_rms = high_hit[2] if high_hit else 0
179
 
180
  if mid_r > 0.15 and flatness > 0.03:
181
  results.append(("snare", velocity))
 
182
  if low_rms > 0.08 and sub_low_r > 0.25:
183
  results.append(("kick", min(1.0, low_rms * 10)))
184
  else:
 
191
  if low_mid_r > 0.4:
192
  results.append(("tom_high", velocity))
193
  else:
194
+ results.append(("snare", velocity * 0.6))
195
 
196
  else:
197
+ # Fallback (only full-band detected, no sub-band): use spectral features
198
+ if low_r > 0.5 and centroid < 400:
199
+ results.append(("kick", velocity))
200
+ elif high_r > 0.35 and centroid > 4000:
201
+ if flatness > 0.15 and velocity > 0.4:
202
+ results.append(("crash", velocity))
203
+ else:
204
+ results.append(("hihat", velocity))
205
+ elif mid_r > 0.3 and flatness > 0.05:
206
+ results.append(("snare", velocity))
207
+ elif centroid < 500:
208
  results.append(("kick", velocity))
209
+ elif centroid > 3000:
210
  results.append(("hihat", velocity))
211
  else:
212
  results.append(("snare", velocity))
 
218
  """Transcribe a drums stem to a lane-based drum tab JSON.
219
 
220
  Uses multi-band onset detection: filters the signal into low/mid/high
221
+ bands, detects onsets independently in each, plus a full-band safety
222
+ net to catch any hits missed by sub-band detection. Merges and classifies
223
  based on which bands triggered at each time point.
224
 
225
  Args:
 
244
  y_high = _highpass(y, 3000, sr) # hi-hat, crash, ride
245
 
246
  # ── Step 2: Per-band onset detection ─────────────────────────────────
247
+ # Sensitivity tuned per band:
248
+ # - Low: moderate delta, kick/toms are loud and clear
249
+ # - Mid: moderate delta, snare is usually prominent
250
+ # - High: LOW delta + low RMS threshold — hi-hats are quiet but frequent
251
+ # - Full: catches anything the sub-bands miss
252
  print(" Drums: detecting per-band onsets...")
253
+ low_onsets = _detect_band_onsets(
254
+ y_low, sr, hop_length, delta=0.06, wait=3, rms_threshold=0.005
255
+ )
256
+ mid_onsets = _detect_band_onsets(
257
+ y_mid, sr, hop_length, delta=0.04, wait=2, rms_threshold=0.003
258
+ )
259
+ high_onsets = _detect_band_onsets(
260
+ y_high, sr, hop_length, delta=0.03, wait=2, rms_threshold=0.001
261
+ )
262
+ # Full-band safety net — catches hits that sub-band filters miss
263
+ full_onsets = _detect_band_onsets(
264
+ y, sr, hop_length, delta=0.04, wait=2, rms_threshold=0.005
265
+ )
266
+ print(f" Low: {len(low_onsets)}, Mid: {len(mid_onsets)}, "
267
+ f"High: {len(high_onsets)}, Full: {len(full_onsets)}")
268
 
269
  # ── Step 3: Merge onsets across bands ────────────────────────────────
 
270
  print(" Drums: merging cross-band onsets...")
271
  all_times = set()
272
+ for onsets in [low_onsets, mid_onsets, high_onsets, full_onsets]:
273
  for t, s, r in onsets:
274
  all_times.add(t)
275
 
 
286
  group = [t]
287
  merged_times.append(np.mean(group))
288
 
289
+ print(f" {len(merged_times)} merged onsets (from {len(all_times)} raw)")
290
+
291
  # For each merged onset, find which bands triggered
292
  def find_band_hit(band_onsets, target_time):
293
  """Find the band onset closest to target_time within MERGE_WINDOW."""
 
318
  if onset_sample is None or s < onset_sample:
319
  onset_sample = s
320
 
321
+ # If no sub-band hit, use the full-band onset
322
  if onset_sample is None:
323
+ full_hit = find_band_hit(full_onsets, onset_time)
324
+ if full_hit is not None:
325
+ onset_sample = full_hit[1]
326
+ else:
327
+ continue
328
 
329
  hits = _classify_from_bands(low_hit, mid_hit, high_hit, y, sr, onset_sample)
330
  for lane, vel in hits:
transcriber/optimize.py CHANGED
@@ -803,6 +803,85 @@ def remove_hand_outliers(midi_data, hand_split=60, gap_threshold=7):
803
  return midi_out, removed
804
 
805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
806
  def consolidate_rhythm(midi_data, y, sr, hop_length=512, max_snap=0.04):
807
  """Consolidate note onsets onto a dominant rhythmic pattern.
808
 
@@ -1640,6 +1719,11 @@ def optimize(original_audio_path, midi_path, output_path=None):
1640
  midi_data, rhythm_snapped, n_dominant = consolidate_rhythm(midi_data, y, sr, hop_length)
1641
  print(f" Snapped {rhythm_snapped} notes to {n_dominant} dominant subdivisions")
1642
 
 
 
 
 
 
1643
  # Step 8: Fix overlaps and enforce min duration (LAST — after all position changes)
1644
  print("\nStep 8: Fixing overlaps and enforcing min duration...")
1645
  midi_data, notes_trimmed, durations_enforced = fix_note_overlap(midi_data)
 
803
  return midi_out, removed
804
 
805
 
806
+ def merge_repeated_notes(midi_data, y, sr, hop_length=512, min_gap=0.15):
807
+ """Merge consecutive same-pitch notes that lack a real re-attack.
808
+
809
+ Basic-pitch often fragments a single sustained note into multiple short
810
+ re-strikes. This step checks whether a repeated note has genuine onset
811
+ energy at the re-attack point. If not, the notes are merged into one
812
+ sustained note.
813
+
814
+ Args:
815
+ min_gap: If the gap between notes is larger than this (seconds),
816
+ always keep separate — the silence itself is musical. Default 150ms.
817
+ """
818
+ midi_out = copy.deepcopy(midi_data)
819
+ merged_count = 0
820
+
821
+ # Compute onset strength envelope for verification
822
+ onset_env = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length)
823
+
824
+ for instrument in midi_out.instruments:
825
+ # Sort by pitch then start time to find consecutive same-pitch notes
826
+ notes = sorted(instrument.notes, key=lambda n: (n.pitch, n.start))
827
+ to_remove = set()
828
+
829
+ i = 0
830
+ while i < len(notes) - 1:
831
+ if i in to_remove:
832
+ i += 1
833
+ continue
834
+
835
+ note = notes[i]
836
+ j = i + 1
837
+
838
+ # Walk forward through consecutive same-pitch notes
839
+ while j < len(notes) and notes[j].pitch == note.pitch:
840
+ if j in to_remove:
841
+ j += 1
842
+ continue
843
+
844
+ next_note = notes[j]
845
+ gap = next_note.start - note.end
846
+
847
+ # If there's a real gap (silence), keep them separate
848
+ if gap > min_gap:
849
+ break
850
+
851
+ # If the next note starts before or just after this one ends,
852
+ # check for onset energy at the re-attack point
853
+ reattack_time = next_note.start
854
+ reattack_frame = int(reattack_time * sr / hop_length)
855
+
856
+ has_onset = False
857
+ if 0 <= reattack_frame < len(onset_env):
858
+ # Check onset strength in a small window around the re-attack
859
+ lo = max(0, reattack_frame - 1)
860
+ hi = min(len(onset_env), reattack_frame + 2)
861
+ local_strength = float(np.max(onset_env[lo:hi]))
862
+
863
+ # Compare to the median onset strength — if re-attack is
864
+ # not above 1.2x the median, it's not a real new attack
865
+ median_strength = float(np.median(onset_env[onset_env > 0])) if np.any(onset_env > 0) else 0
866
+ has_onset = local_strength > median_strength * 1.2
867
+
868
+ if not has_onset:
869
+ # Merge: extend current note to cover the next one
870
+ note.end = max(note.end, next_note.end)
871
+ to_remove.add(j)
872
+ merged_count += 1
873
+ j += 1
874
+ else:
875
+ # Real re-attack — stop merging
876
+ break
877
+
878
+ i = j if j > i + 1 else i + 1
879
+
880
+ instrument.notes = [n for k, n in enumerate(notes) if k not in to_remove]
881
+
882
+ return midi_out, merged_count
883
+
884
+
885
  def consolidate_rhythm(midi_data, y, sr, hop_length=512, max_snap=0.04):
886
  """Consolidate note onsets onto a dominant rhythmic pattern.
887
 
 
1719
  midi_data, rhythm_snapped, n_dominant = consolidate_rhythm(midi_data, y, sr, hop_length)
1720
  print(f" Snapped {rhythm_snapped} notes to {n_dominant} dominant subdivisions")
1721
 
1722
+ # Step 7c: Merge repeated consecutive same-pitch notes without real re-attack
1723
+ print("\nStep 7c: Merging repeated notes without re-attack energy...")
1724
+ midi_data, notes_merged = merge_repeated_notes(midi_data, y, sr, hop_length)
1725
+ print(f" Merged {notes_merged} repeated notes into sustains")
1726
+
1727
  # Step 8: Fix overlaps and enforce min duration (LAST — after all position changes)
1728
  print("\nStep 8: Fixing overlaps and enforcing min duration...")
1729
  midi_data, notes_trimmed, durations_enforced = fix_note_overlap(midi_data)
transcriber/optimize_other.py CHANGED
@@ -34,6 +34,7 @@ from optimize import (
34
  remove_harmonic_ghosts,
35
  remove_hand_outliers,
36
  consolidate_rhythm,
 
37
  )
38
 
39
 
@@ -159,6 +160,11 @@ def optimize_other(original_audio_path, midi_path, output_path=None, mix_audio_p
159
  midi_data, rhythm_snapped, n_dominant = consolidate_rhythm(midi_data, y, sr, hop_length)
160
  print(f" Snapped {rhythm_snapped} notes to {n_dominant} dominant subdivisions")
161
 
 
 
 
 
 
162
  # Step 8: Fix overlaps and enforce min duration
163
  print("\nStep 8: Fixing overlaps...")
164
  midi_data, notes_trimmed, durations_enforced = fix_note_overlap(midi_data)
 
34
  remove_harmonic_ghosts,
35
  remove_hand_outliers,
36
  consolidate_rhythm,
37
+ merge_repeated_notes,
38
  )
39
 
40
 
 
160
  midi_data, rhythm_snapped, n_dominant = consolidate_rhythm(midi_data, y, sr, hop_length)
161
  print(f" Snapped {rhythm_snapped} notes to {n_dominant} dominant subdivisions")
162
 
163
+ # Step 7c: Merge repeated same-pitch notes without real re-attack
164
+ print("\nStep 7c: Merging repeated notes without re-attack energy...")
165
+ midi_data, notes_merged = merge_repeated_notes(midi_data, y, sr, hop_length)
166
+ print(f" Merged {notes_merged} repeated notes into sustains")
167
+
168
  # Step 8: Fix overlaps and enforce min duration
169
  print("\nStep 8: Fixing overlaps...")
170
  midi_data, notes_trimmed, durations_enforced = fix_note_overlap(midi_data)