Spaces:
Sleeping
Sleeping
| """Optimize MIDI transcription by correcting onsets, cleaning artifacts, and | |
| ensuring rhythmic accuracy against the original audio.""" | |
| import copy | |
| from pathlib import Path | |
| import numpy as np | |
| import pretty_midi | |
| import librosa | |
| from collections import Counter | |
def remove_leading_silence_notes(midi_data, y, sr):
    """Locate where the music starts and drop MIDI notes that precede it.

    The audio is cut into consecutive, non-overlapping windows of
    ``int(0.05 * sr)`` samples (about 50 ms). The music start is the start
    time of the first window whose RMS exceeds 5% of the loudest window.
    That time is then clamped so it never lies after the earliest MIDI
    note onset.

    NOTE(review): because of that clamp no note can start before the final
    cut-off, so with the current logic the note lists are always returned
    unchanged and ``removed`` is always 0. Confirm whether the protection
    of the first note was meant to be this strong.

    Args:
        midi_data: Object with an ``instruments`` list; each instrument has
            a ``notes`` list whose items expose ``start``. It is deep-copied,
            never mutated.
        y: 1-D numpy array of audio samples (assumed mono).
        sr: Sample rate of ``y`` in Hz.

    Returns:
        Tuple ``(midi_out, removed, music_start)``: the copied MIDI object,
        the number of notes dropped, and the detected start time in seconds.
    """
    midi_out = copy.deepcopy(midi_data)

    hop = int(0.05 * sr)
    if hop <= 0:
        # Sample rate too low to form even a one-sample window.
        return midi_out, 0, 0.0

    rms = np.array([
        np.sqrt(np.mean(y[i * hop:(i + 1) * hop] ** 2))
        for i in range(len(y) // hop)
    ])
    if len(rms) == 0:
        return midi_out, 0, 0.0

    # Music starts when RMS first exceeds 5% of the peak energy.
    max_rms = np.max(rms)
    music_start = 0.0
    for i, r in enumerate(rms):
        if r > max_rms * 0.05:
            # Use the real window length: int() truncation makes the window
            # shorter than 50 ms for rates such as 22050 Hz, so a nominal
            # ``i * 0.05`` would drift later and later through the file.
            music_start = i * hop / sr
            break
    if music_start < 0.1:
        return midi_out, 0, music_start

    # Never let the cut-off pass the earliest MIDI onset.
    earliest_onset = min(
        (n.start for inst in midi_out.instruments for n in inst.notes),
        default=0.0,
    )
    if music_start > earliest_onset:
        music_start = earliest_onset
    if music_start < 0.1:
        return midi_out, 0, music_start

    removed = 0
    for instrument in midi_out.instruments:
        kept = [n for n in instrument.notes if not n.start < music_start]
        removed += len(instrument.notes) - len(kept)
        instrument.notes = kept
    return midi_out, removed, music_start
def remove_trailing_silence_notes(midi_data, y, sr):
    """Drop MIDI notes that start after the audio has faded out.

    The audio is cut into consecutive windows of ``int(0.05 * sr)`` samples
    (about 50 ms). Searching backwards, the last window whose RMS exceeds 2%
    of the loudest window marks the end of the music. A 3-second protection
    zone is added after it (to keep natural decay/sustain) and the result is
    clamped to the audio duration. Notes whose ``start`` lies after that
    time are removed.

    Args:
        midi_data: Object with an ``instruments`` list; each instrument has
            a ``notes`` list whose items expose ``start``. It is deep-copied,
            never mutated.
        y: 1-D numpy array of audio samples (assumed mono).
        sr: Sample rate of ``y`` in Hz.

    Returns:
        Tuple ``(midi_out, removed, music_end)``: the copied MIDI object,
        the number of notes dropped, and the cut-off time in seconds.
    """
    midi_out = copy.deepcopy(midi_data)
    duration = len(y) / sr

    hop = int(0.05 * sr)
    if hop <= 0:
        # Sample rate too low to form even a one-sample window.
        return midi_out, 0, duration

    rms = np.array([
        np.sqrt(np.mean(y[i * hop:(i + 1) * hop] ** 2))
        for i in range(len(y) // hop)
    ])
    if len(rms) == 0:
        return midi_out, 0, duration

    max_rms = np.max(rms)
    music_end = duration
    for i in range(len(rms) - 1, -1, -1):
        if rms[i] > max_rms * 0.02:
            # End of window i in real time (the window is ``hop`` samples,
            # which is not exactly 50 ms when 0.05 * sr is fractional),
            # plus the 3-second decay protection zone.
            music_end = (i + 1) * hop / sr + 3.0
            break
    # Clamp to the actual audio duration.
    music_end = min(music_end, duration)

    removed = 0
    for instrument in midi_out.instruments:
        kept = [n for n in instrument.notes if not n.start > music_end]
        removed += len(instrument.notes) - len(kept)
        instrument.notes = kept
    return midi_out, removed, music_end
def remove_low_energy_notes(midi_data, y, sr, hop_length=512):
    """Remove notes whose onsets don't correspond to real audio energy.

    For every note the onset-strength envelope of the audio is sampled in a
    5-frame window centred on the frame nearest to the note's start, and the
    maximum is taken as the note's strength. Per instrument, the removal
    threshold is the 15th percentile of those strengths, floored at 0.5.
    A note below the threshold is still kept when another note of the same
    instrument starts within 30 ms of it and is at or above the threshold
    (i.e. it belongs to a chord with a strong onset).

    Args:
        midi_data: Object with an ``instruments`` list; each instrument has
            a ``notes`` list whose items expose ``start``. It is deep-copied,
            never mutated.
        y: Audio samples passed to ``librosa.onset.onset_strength``.
        sr: Sample rate of ``y`` in Hz.
        hop_length: Hop length (samples) of the onset-strength envelope.

    Returns:
        Tuple ``(midi_out, removed)``: the copied MIDI object and the number
        of notes dropped. The input is returned unchanged (as a copy) when
        there are no notes or the onset envelope is empty.
    """
    midi_out = copy.deepcopy(midi_data)

    # Nothing to filter: skip the onset analysis entirely.
    if not any(inst.notes for inst in midi_out.instruments):
        return midi_out, 0

    onset_env = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length)
    if len(onset_env) == 0:
        # No frames to measure against; argmin below would raise.
        return midi_out, 0
    onset_times = librosa.frames_to_time(
        np.arange(len(onset_env)), sr=sr, hop_length=hop_length
    )

    removed = 0
    for instrument in midi_out.instruments:
        notes = instrument.notes
        if not notes:
            continue

        # First pass: peak onset strength around each note's start.
        strengths = []
        for note in notes:
            frame = int(np.argmin(np.abs(onset_times - note.start)))
            lo = max(0, frame - 2)
            hi = min(len(onset_env), frame + 3)
            strengths.append(float(np.max(onset_env[lo:hi])))

        # Adaptive threshold with a fixed floor of 0.5.
        threshold = max(0.5, float(np.percentile(strengths, 15)))

        # Start times of the notes that pass on their own; a weak note is
        # rescued when one of these is within 30 ms of it.
        strong_starts = [
            note.start
            for note, strength in zip(notes, strengths)
            if strength >= threshold
        ]

        kept = []
        for note, strength in zip(notes, strengths):
            if strength >= threshold:
                kept.append(note)
            elif any(abs(s - note.start) < 0.03 for s in strong_starts):
                kept.append(note)
            else:
                removed += 1
        instrument.notes = kept
    return midi_out, removed
def remove_harmonic_ghosts(midi_data, y=None, sr=22050, hop_length=512):
    """Remove notes that are harmonic doublings of louder lower notes.

    Two stages, run per instrument on the notes sorted by start time:

    1. Pairwise: a note at MIDI pitch >= 48 is removed when another note
       starting within 100 ms lies exactly 7, 12, 19 or 24 semitones below
       it and the velocity ratio (upper / lower) is under a threshold:
       0.55 when the upper note is >= 72, otherwise 0.85 when the lower
       note is < 48, otherwise 0.55. When audio is given, a note whose CQT
       onset energy (first ~200 ms, around its fundamental bin) is above
       -12 dB relative to the CQT peak is exempt from this stage.
    2. Spectral masking (only when audio is given): a surviving note with
       55 <= pitch < 72 is removed when a surviving note with pitch < 55
       starts within 50 ms, the upper pitch is 12, 19, 24, 28 or 31
       semitones above it, the bass CQT energy exceeds the upper note's by
       more than 8 dB (first ~150 ms), and the upper velocity is below 70%
       of the bass velocity.

    Args:
        midi_data: Object with an ``instruments`` list; each instrument has
            a ``notes`` list whose items expose ``start``, ``pitch`` and
            ``velocity``. It is deep-copied, never mutated.
        y: Optional audio samples. When ``None`` no CQT is computed, stage 1
            runs without the energy exemption and stage 2 is skipped.
        sr: Sample rate of ``y`` in Hz.
        hop_length: CQT hop length in samples.

    Returns:
        Tuple ``(midi_out, removed)``: the copied MIDI object (each
        instrument's notes sorted by start) and the number of notes dropped.
    """
    midi_out = copy.deepcopy(midi_data)
    removed = 0
    harmonic_intervals = {7, 12, 19, 24}

    # CQT for energy verification, only if audio was provided.
    C_db = None
    N_BINS = 0
    if y is not None:
        N_BINS = 88 * 3
        FMIN = librosa.note_to_hz('A0')
        C = np.abs(librosa.cqt(
            y, sr=sr, hop_length=hop_length,
            fmin=FMIN, n_bins=N_BINS, bins_per_octave=36,
        ))
        C_db = librosa.amplitude_to_db(C, ref=np.max(C))

    for instrument in midi_out.instruments:
        notes = sorted(instrument.notes, key=lambda n: n.start)
        to_remove = set()

        # Stage 1: pairwise harmonic-interval check.
        for i, note in enumerate(notes):
            if i in to_remove:
                continue
            if note.pitch < 48:
                continue

            # Protect notes with clear CQT energy of their own.
            if C_db is not None:
                fund_bin = (note.pitch - 21) * 3 + 1
                if 0 <= fund_bin < C_db.shape[0]:
                    start_frame = max(0, int(note.start * sr / hop_length))
                    end_frame = min(
                        C_db.shape[1],
                        start_frame + max(1, int(0.2 * sr / hop_length)),
                    )
                    # A note starting at/after the last CQT frame has no
                    # frames to inspect; np.max on the empty slice would
                    # raise, so it simply gets no energy exemption.
                    if start_frame < end_frame:
                        lo = max(0, fund_bin - 1)
                        hi = min(C_db.shape[0], fund_bin + 2)
                        onset_db = float(np.max(C_db[lo:hi, start_frame:end_frame]))
                        if onset_db > -12.0:
                            continue

            for j, other in enumerate(notes):
                if i == j or j in to_remove:
                    continue
                if abs(other.start - note.start) > 0.10:
                    continue
                diff = note.pitch - other.pitch
                if diff in harmonic_intervals and diff > 0:
                    ratio = note.velocity / max(1, other.velocity)
                    if note.pitch >= 72:
                        # C5+: only remove if much quieter than the lower note.
                        limit = 0.55
                    elif other.pitch < 48:
                        # Lower note below C3: stricter, higher ratio allowed.
                        limit = 0.85
                    else:
                        limit = 0.55
                    if ratio < limit:
                        to_remove.add(i)
                        break

        # Stage 2: spectral masking between bass and mid-range notes.
        if C_db is not None:
            remaining = [(k, n) for k, n in enumerate(notes) if k not in to_remove]
            bass_notes = [(k, n) for k, n in remaining if n.pitch < 55]
            mid_notes = [(k, n) for k, n in remaining if 55 <= n.pitch < 72]
            for mid_k, mid_n in mid_notes:
                if mid_k in to_remove:
                    continue
                for bass_k, bass_n in bass_notes:
                    if abs(bass_n.start - mid_n.start) > 0.05:
                        continue
                    bass_pitch = bass_n.pitch
                    harmonic_pitches = {
                        bass_pitch + 12,  # 2nd harmonic (octave)
                        bass_pitch + 19,  # 3rd (octave + fifth)
                        bass_pitch + 24,  # 4th (2 octaves)
                        bass_pitch + 28,  # 5th (2 oct + major 3rd)
                        bass_pitch + 31,  # 6th (2 oct + fifth)
                    }
                    if mid_n.pitch not in harmonic_pitches:
                        continue
                    mid_bin = (mid_n.pitch - 21) * 3 + 1
                    bass_bin = (bass_pitch - 21) * 3 + 1
                    if not (0 <= mid_bin < N_BINS and 0 <= bass_bin < N_BINS):
                        continue
                    sf = max(0, int(mid_n.start * sr / hop_length))
                    ef = min(C_db.shape[1], sf + max(1, int(0.15 * sr / hop_length)))
                    if sf >= ef:
                        # Note starts beyond the analysed audio: no evidence.
                        continue
                    mid_energy = float(np.max(
                        C_db[max(0, mid_bin - 1):min(N_BINS, mid_bin + 2), sf:ef]
                    ))
                    bass_energy = float(np.max(
                        C_db[max(0, bass_bin - 1):min(N_BINS, bass_bin + 2), sf:ef]
                    ))
                    # Bass much louder (>8 dB) and the mid note quiet.
                    if bass_energy - mid_energy > 8.0 and mid_n.velocity < bass_n.velocity * 0.7:
                        to_remove.add(mid_k)
                        break

        instrument.notes = [n for k, n in enumerate(notes) if k not in to_remove]
        removed += len(to_remove)
    return midi_out, removed
def remove_phantom_notes(midi_data, max_pitch=None):
    """Remove rare high-register notes that look like harmonic artifacts.

    Only notes strictly above ``max_pitch`` are candidates. When
    ``max_pitch`` is ``None`` it is the 95th percentile of all note pitches
    (truncated to int). A candidate is removed when its exact pitch occurs
    fewer than 3 times across all instruments AND at least one of:

    - velocity below 40 (below 55 for pitches above MIDI 80),
    - duration under 80 ms,
    - no other note starts within 100 ms of it.

    Returns ``(midi_out, removed)`` where ``midi_out`` is a deep copy of the
    input with the phantom notes dropped.
    """
    midi_out = copy.deepcopy(midi_data)

    every_note = [n for inst in midi_out.instruments for n in inst.notes]
    if not every_note:
        return midi_out, 0

    pitches = [n.pitch for n in every_note]
    if max_pitch is None:
        max_pitch = int(np.percentile(pitches, 95))
    occurrences = Counter(pitches)

    # Onset-ordered view of every note, used for the neighbour scan.
    by_onset = sorted(every_note, key=lambda n: n.start)

    def has_no_neighbour(candidate):
        """True when no other note starts within 100 ms of ``candidate``."""
        horizon = candidate.start + 0.1
        for other in by_onset:
            if other is candidate:
                continue
            if other.start > horizon:
                # Sorted by onset: everything after this is further away.
                return True
            if abs(other.start - candidate.start) < 0.1:
                return False
        return True

    def looks_phantom(note):
        """Apply the rarity + (quiet | short | isolated) rule to one note."""
        if occurrences[note.pitch] >= 3:
            return False
        velocity_floor = 55 if note.pitch > 80 else 40
        if note.velocity < velocity_floor:
            return True
        if note.end - note.start < 0.08:
            return True
        return has_no_neighbour(note)

    removed = 0
    for instrument in midi_out.instruments:
        survivors = []
        for note in instrument.notes:
            if note.pitch > max_pitch and looks_phantom(note):
                removed += 1
            else:
                survivors.append(note)
        instrument.notes = survivors
    return midi_out, removed
def remove_spurious_onsets(midi_data, y, sr, ref_onsets, hop_length=512, complexity='simple'):
    """Remove MIDI notes that form false-positive onsets not backed by audio.

    MIDI onsets (note starts rounded to 0.1 ms) are first matched against
    ``ref_onsets``: for each reference onset the nearest MIDI onset is
    marked as matched when it is within 50 ms. Every unmatched MIDI onset
    is then tested against the rules below, in order, using the notes
    within 30 ms of it, the peak onset strength around it and the distance
    to the nearest reference onset. The first rule that fires marks the
    onset as spurious:

    1. Chord fragment: within 60 ms of a matched MIDI onset and
       strength < 2.0.
    2. Isolated ghost: a single note with strength < 1.5 or more than
       100 ms from any reference onset.
    3. Short + quiet: every note is shorter than 200 ms and has
       velocity < 50.
    4. Bass ghost: a single note below MIDI 40, velocity < 35, more than
       60 ms from any reference onset.
    5. More than 120 ms from any reference onset and strength < 3.0.
    6. Every note is shorter than 50 ms.
    7. More than 70 ms from any reference onset and strength < 2.5.

    Strength limits are multiplied by 1.5 / 1.2 / 1.0 and distance limits
    by 1.4 / 1.2 / 1.0 for ``complexity`` 'complex' / 'moderate' / anything
    else.

    NOTE(review): removal deletes every note within 30 ms of a spurious
    onset time, which can include notes whose own onset was matched.
    Confirm this collateral removal is intended.

    Args:
        midi_data: Object with an ``instruments`` list; each instrument has
            a ``notes`` list whose items expose ``start``, ``end``,
            ``pitch`` and ``velocity``. It is deep-copied, never mutated.
        y: Audio samples passed to ``librosa.onset.onset_strength``.
        sr: Sample rate of ``y`` in Hz.
        ref_onsets: Reference (audio) onset times in seconds; any sequence
            or array.
        hop_length: Hop length (samples) of the onset-strength envelope.
        complexity: 'simple', 'moderate' or 'complex'.

    Returns:
        Tuple ``(midi_out, removed_notes, removed_onsets)``. When there are
        no MIDI notes, no reference onsets, or no onset-strength frames,
        nothing can be judged and ``(copy, 0, 0)`` is returned.
    """
    midi_out = copy.deepcopy(midi_data)
    tolerance = 0.05

    all_notes = sorted(
        [n for inst in midi_out.instruments for n in inst.notes],
        key=lambda n: n.start,
    )
    # Accept lists as well as arrays; the distance maths below needs an array.
    ref_onsets = np.asarray(ref_onsets, dtype=float).ravel()
    if not all_notes or ref_onsets.size == 0:
        # argmin/min on empty arrays would raise; there is nothing to compare.
        return midi_out, 0, 0

    # Complex pieces are more permissive to preserve legitimate dense textures.
    if complexity == 'complex':
        strength_scale = 1.5
        dist_scale = 1.4
    elif complexity == 'moderate':
        strength_scale = 1.2
        dist_scale = 1.2
    else:
        strength_scale = 1.0
        dist_scale = 1.0

    onset_env = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length)
    if len(onset_env) == 0:
        return midi_out, 0, 0
    onset_times = librosa.frames_to_time(
        np.arange(len(onset_env)), sr=sr, hop_length=hop_length
    )

    midi_onsets_arr = np.array(sorted({round(n.start, 4) for n in all_notes}))

    # MIDI onsets already explained by a reference onset.
    matched_est = set()
    for r in ref_onsets:
        diffs = np.abs(midi_onsets_arr - r)
        best = int(np.argmin(diffs))
        if diffs[best] <= tolerance:
            matched_est.add(best)

    onsets_to_remove = set()
    for j, mo in enumerate(midi_onsets_arr):
        if j in matched_est:
            continue

        onset_notes = [n for n in all_notes if abs(n.start - mo) < 0.03]
        if not onset_notes:
            continue

        # Peak onset strength in a 5-frame window around this onset.
        frame = int(np.argmin(np.abs(onset_times - mo)))
        lo = max(0, frame - 2)
        hi = min(len(onset_env), frame + 3)
        strength = float(np.max(onset_env[lo:hi]))

        nearest_audio_ms = float(np.min(np.abs(ref_onsets - mo))) * 1000

        near_matched = any(
            abs(midi_onsets_arr[k] - mo) < 0.060
            for k in matched_est
        )
        single = len(onset_notes) == 1

        # Rule 1: chord fragment with weak audio energy.
        if near_matched and strength < 2.0 * strength_scale:
            onsets_to_remove.add(j)
        # Rule 2: isolated ghost.
        elif single and (strength < 1.5 * strength_scale
                         or nearest_audio_ms > 100 * dist_scale):
            onsets_to_remove.add(j)
        # Rule 3: short + quiet artifact.
        elif all(n.end - n.start < 0.2 and n.velocity < 50 for n in onset_notes):
            onsets_to_remove.add(j)
        # Rule 4: low-velocity bass ghost.
        elif (single and onset_notes[0].pitch < 40
              and onset_notes[0].velocity < 35
              and nearest_audio_ms > 60 * dist_scale):
            onsets_to_remove.add(j)
        # Rule 5: far from any audio onset with weak-to-moderate strength.
        elif nearest_audio_ms > 120 * dist_scale and strength < 3.0 * strength_scale:
            onsets_to_remove.add(j)
        # Rule 6: splinter artifacts, all notes very short.
        elif all(n.end - n.start < 0.05 for n in onset_notes):
            onsets_to_remove.add(j)
        # Rule 7: near-miss hallucination just outside the matching window.
        elif nearest_audio_ms > 70 * dist_scale and strength < 2.5 * strength_scale:
            onsets_to_remove.add(j)

    # Remove notes belonging to spurious onsets.
    times_to_remove = {midi_onsets_arr[j] for j in onsets_to_remove}
    removed = 0
    for instrument in midi_out.instruments:
        kept = []
        for note in instrument.notes:
            note_onset = round(note.start, 4)
            if any(abs(note_onset - t) < 0.03 for t in times_to_remove):
                removed += 1
            else:
                kept.append(note)
        instrument.notes = kept
    return midi_out, removed, len(onsets_to_remove)
def remove_pitch_unconfirmed_notes(midi_data, y, sr, hop_length=512):
    """Remove notes where the CQT has no energy at their fundamental pitch.

    Notes with 48 <= pitch <= 72 are always kept without any check. Every
    other note is tested on the maximum CQT level (dB relative to the CQT
    peak) in the first ~200 ms after its start, in the bins around its
    fundamental. The threshold depends on the register:

    - pitch < 40:       -42 dB
    - 40 <= pitch < 48: -35 dB
    - pitch > 72:       -25 dB

    A note more than 10 dB under its threshold is always removed. A note
    under the threshold by less is removed when at most 3 other notes start
    within 50 ms of it or its velocity is below 55; otherwise it is kept.
    Notes whose fundamental bin falls outside the CQT, or that start beyond
    the analysed frames, are kept.

    Args:
        midi_data: Object with an ``instruments`` list; each instrument has
            a ``notes`` list whose items expose ``start``, ``pitch`` and
            ``velocity``. It is deep-copied, never mutated.
        y: Audio samples passed to ``librosa.cqt``.
        sr: Sample rate of ``y`` in Hz.
        hop_length: CQT hop length in samples.

    Returns:
        Tuple ``(midi_out, removed)``: the copied MIDI object and the number
        of notes dropped.
    """
    midi_out = copy.deepcopy(midi_data)

    all_notes = sorted(
        [n for inst in midi_out.instruments for n in inst.notes],
        key=lambda n: n.start,
    )
    # Every note inside the exempt range would be kept regardless of the
    # audio, so skip the (expensive) CQT when nothing needs checking.
    if all(48 <= n.pitch <= 72 for n in all_notes):
        return midi_out, 0

    N_BINS = 88 * 3
    FMIN = librosa.note_to_hz('A0')
    C = np.abs(librosa.cqt(
        y, sr=sr, hop_length=hop_length,
        fmin=FMIN, n_bins=N_BINS, bins_per_octave=36,
    ))
    C_db = librosa.amplitude_to_db(C, ref=np.max(C))

    # Onset region: number of frames covering the first 200 ms.
    onset_frames = max(1, int(0.2 * sr / hop_length))

    removed = 0
    for instrument in midi_out.instruments:
        filtered = []
        for note in instrument.notes:
            if 48 <= note.pitch <= 72:
                filtered.append(note)
                continue

            fund_bin = (note.pitch - 21) * 3 + 1
            if fund_bin < 0 or fund_bin >= N_BINS:
                filtered.append(note)
                continue

            start_frame = max(0, int(note.start * sr / hop_length))
            check_end = min(C.shape[1], start_frame + onset_frames)
            if start_frame >= check_end:
                filtered.append(note)
                continue

            lo = max(0, fund_bin - 1)
            hi = min(N_BINS, fund_bin + 2)
            # Max energy in the onset region, not an average over the note.
            onset_db = float(np.max(C_db[lo:hi, start_frame:check_end]))

            if note.pitch < 40:
                thresh = -42.0
            elif note.pitch < 48:
                thresh = -35.0
            else:  # > 72, upper register
                thresh = -25.0

            if onset_db >= thresh:
                filtered.append(note)
                continue

            if onset_db < thresh - 10:
                # Extremely weak: always remove.
                removed += 1
                continue

            # Moderately weak: decide on chord context and velocity.
            concurrent = sum(
                1 for o in all_notes
                if abs(o.start - note.start) < 0.05 and o is not note
            )
            if concurrent <= 3 or note.velocity < 55:
                removed += 1
            else:
                filtered.append(note)
        instrument.notes = filtered
    return midi_out, removed
def apply_pitch_ceiling(midi_data, max_pitch=96):
    """Remove notes above a hard pitch ceiling (default C7 / MIDI 96).

    ``max_pitch`` is the highest pitch that is still allowed: only notes
    with ``pitch > max_pitch`` are dropped, so a note exactly at the
    ceiling is kept.

    Args:
        midi_data: Object with an ``instruments`` list; each instrument has
            a ``notes`` list whose items expose ``pitch``. It is
            deep-copied, never mutated.
        max_pitch: Highest MIDI pitch to keep (inclusive).

    Returns:
        Tuple ``(midi_out, removed)``: the copied MIDI object and the number
        of notes dropped.
    """
    midi_out = copy.deepcopy(midi_data)
    removed = 0
    for instrument in midi_out.instruments:
        # Strictly greater: the ceiling pitch itself is within range.
        kept = [note for note in instrument.notes if not note.pitch > max_pitch]
        removed += len(instrument.notes) - len(kept)
        instrument.notes = kept
    return midi_out, removed
def limit_concurrent_notes(midi_data, max_per_hand=4, hand_split=60, max_left_hand=None):
    """Cap the number of notes per chord, separately for each hand.

    Notes of an instrument are sorted by start and grouped into chords: a
    note joins the current chord while it starts less than 30 ms after the
    chord's first note. Each chord is split at ``hand_split`` (pitches below
    go to the left hand, the rest to the right). When a hand holds more
    notes than its cap, the highest-pitched note of that hand is protected
    and the lowest-velocity notes among the others are removed until the
    cap is met.

    Args:
        max_per_hand: Cap for the right hand (default 4).
        hand_split: MIDI pitch separating the hands (default 60).
        max_left_hand: Cap for the left hand (defaults to ``max_per_hand``).

    Returns:
        ``(midi_out, removed)``: a deep copy with the excess notes dropped
        (notes of each non-empty instrument sorted by start) and the number
        of notes removed.
    """
    left_cap = max_per_hand if max_left_hand is None else max_left_hand

    midi_out = copy.deepcopy(midi_data)
    removed = 0
    for instrument in midi_out.instruments:
        ordered = sorted(instrument.notes, key=lambda n: n.start)
        if not ordered:
            continue

        # Build chord groups as lists of positions into ``ordered``.
        groups = [[0]]
        for pos in range(1, len(ordered)):
            chord_start = ordered[groups[-1][0]].start
            if ordered[pos].start - chord_start < 0.03:
                groups[-1].append(pos)
            else:
                groups.append([pos])

        doomed = set()
        for group in groups:
            right_hand = [k for k in group if ordered[k].pitch >= hand_split]
            left_hand = [k for k in group if ordered[k].pitch < hand_split]
            for members, cap in ((right_hand, max_per_hand), (left_hand, left_cap)):
                surplus = len(members) - cap
                if surplus <= 0:
                    continue
                # The top voice of the hand is never trimmed.
                top = max(members, key=lambda k: ordered[k].pitch)
                # Quietest first; position breaks velocity ties.
                candidates = sorted(
                    (k for k in members if k != top),
                    key=lambda k: (ordered[k].velocity, k),
                )
                doomed.update(candidates[:surplus])

        instrument.notes = [n for k, n in enumerate(ordered) if k not in doomed]
        removed += len(doomed)
    return midi_out, removed
def limit_total_concurrent(midi_data, max_per_hand=4, hand_split=60, max_left_hand=None):
    """Cap the number of simultaneously sounding notes per hand.

    Notes of an instrument are visited in start order. For each note, the
    earlier notes of the same hand (split at ``hand_split``) that are still
    sounding at its start are collected. If they plus the new note exceed
    the hand's cap, earlier notes are cut short (their ``end`` is set to the
    new note's start), quietest first, until the cap is met. The
    highest-pitched note of the whole group is never cut, and neither is the
    new note. Notes left with a duration of 10 ms or less are dropped.

    Args:
        max_per_hand: Cap for the right hand (default 4).
        hand_split: MIDI pitch separating the hands (default 60).
        max_left_hand: Cap for the left hand (defaults to ``max_per_hand``).

    Returns:
        ``(midi_out, trimmed)``: a deep copy with shortened notes and the
        number of trim operations performed.
    """
    left_cap = max_per_hand if max_left_hand is None else max_left_hand

    midi_out = copy.deepcopy(midi_data)
    trimmed = 0
    for instrument in midi_out.instruments:
        ordered = sorted(instrument.notes, key=lambda n: n.start)
        if not ordered:
            continue

        for pos, incoming in enumerate(ordered):
            right_side = incoming.pitch >= hand_split
            cap = max_per_hand if right_side else left_cap

            # Earlier same-hand notes that have not ended yet.
            held = [
                k for k in range(pos)
                if ordered[k].end > incoming.start
                and (ordered[k].pitch >= hand_split) == right_side
            ]
            overflow = len(held) + 1 - cap
            if overflow <= 0:
                continue

            # Top voice among held notes and the incoming one is untouchable.
            top = max(held + [pos], key=lambda k: ordered[k].pitch)
            # Quietest first; position breaks velocity ties.
            victims = sorted(
                (k for k in held if k != top),
                key=lambda k: (ordered[k].velocity, k),
            )
            for k in victims[:overflow]:
                ordered[k].end = incoming.start
                trimmed += 1

        instrument.notes = [n for n in ordered if n.end - n.start > 0.01]
    return midi_out, trimmed
def remove_hand_outliers(midi_data, hand_split=60, gap_threshold=7):
    """Drop the lowest note of a hand when it sits far below the rest.

    Notes of an instrument are sorted by start and grouped into chords: a
    note joins the current chord while it starts less than 30 ms after the
    chord's first note. Each chord is split at ``hand_split``. For a hand
    holding at least 3 notes, the lowest note is removed when the gap to the
    second-lowest note is ``gap_threshold`` semitones or more. Only the
    lowest note is ever examined.

    Args:
        hand_split: MIDI pitch separating the hands (default 60 = C4).
        gap_threshold: Minimum gap in semitones between the two lowest
            notes of a hand for the lowest to count as an outlier
            (default 7).

    Returns:
        ``(midi_out, removed)``: a deep copy with outliers dropped (notes of
        each non-empty instrument sorted by start) and the number of notes
        removed.
    """
    midi_out = copy.deepcopy(midi_data)
    removed = 0
    for instrument in midi_out.instruments:
        ordered = sorted(instrument.notes, key=lambda n: n.start)
        if not ordered:
            continue

        # Build chord groups as lists of positions into ``ordered``.
        groups = [[0]]
        for pos in range(1, len(ordered)):
            chord_start = ordered[groups[-1][0]].start
            if ordered[pos].start - chord_start < 0.03:
                groups[-1].append(pos)
            else:
                groups.append([pos])

        doomed = set()
        for group in groups:
            right_hand = [k for k in group if ordered[k].pitch >= hand_split]
            left_hand = [k for k in group if ordered[k].pitch < hand_split]
            for members in (right_hand, left_hand):
                # Fewer than 3 notes: cannot tell an outlier from a cluster.
                if len(members) < 3:
                    continue
                # Lowest pitch first; position breaks pitch ties.
                ranked = sorted(members, key=lambda k: (ordered[k].pitch, k))
                lowest, runner_up = ranked[0], ranked[1]
                if ordered[runner_up].pitch - ordered[lowest].pitch >= gap_threshold:
                    doomed.add(lowest)

        instrument.notes = [n for k, n in enumerate(ordered) if k not in doomed]
        removed += len(doomed)
    return midi_out, removed
def enforce_hand_span(midi_data, max_span=12, hand_split=60):
    """Enforce that no hand plays notes wider than ``max_span`` semitones.

    Both hands anchor on their highest note and discard what lies more than
    ``max_span`` semitones below it. Two passes run per instrument on the
    notes sorted by start:

    1. Chord groups: notes starting less than 30 ms after the group's first
       note. Within each hand of a chord, every note more than ``max_span``
       below the hand's highest note is removed.
    2. Sustained overlap: for each remaining note, the earlier same-hand
       notes still sounding (ending more than 10 ms after its start) are
       collected together with it. Every note of that set lying more than
       ``max_span`` below the set's highest note is cut at the current
       note's start; if that leaves it shorter than 50 ms it is removed.

    Args:
        midi_data: Object with an ``instruments`` list; each instrument has
            a ``notes`` list whose items expose ``start``, ``end`` and
            ``pitch``. It is deep-copied, never mutated.
        max_span: Maximum interval in semitones (default 12 = octave).
        hand_split: MIDI pitch dividing left/right hand (default 60 = C4).

    Returns:
        Tuple ``(midi_out, removed)``: the copied MIDI object and the number
        of notes removed. Notes that were only shortened are not counted.
    """
    midi_out = copy.deepcopy(midi_data)
    removed = 0
    for instrument in midi_out.instruments:
        notes = sorted(instrument.notes, key=lambda n: n.start)
        if not notes:
            continue

        # Pass 1: chord groups (onsets within 30 ms of the group's first note).
        chords = [[0]]
        for i in range(1, len(notes)):
            if notes[i].start - notes[chords[-1][0]].start < 0.03:
                chords[-1].append(i)
            else:
                chords.append([i])

        to_remove = set()
        for chord_indices in chords:
            left = [idx for idx in chord_indices if notes[idx].pitch < hand_split]
            right = [idx for idx in chord_indices if notes[idx].pitch >= hand_split]
            for hand_indices in (right, left):
                if len(hand_indices) < 2:
                    continue
                by_pitch = sorted(hand_indices, key=lambda idx: notes[idx].pitch)
                anchor_pitch = notes[by_pitch[-1]].pitch
                if anchor_pitch - notes[by_pitch[0]].pitch <= max_span:
                    continue
                # Protect the highest note, drop whatever is out of reach.
                for idx in by_pitch[:-1]:
                    if anchor_pitch - notes[idx].pitch > max_span:
                        to_remove.add(idx)

        # Pass 2: sustained notes overlapping a newly started note.
        for i, note in enumerate(notes):
            if i in to_remove:
                continue
            is_right = note.pitch >= hand_split
            concurrent = [i]
            for j in range(i):
                if j in to_remove:
                    continue
                if notes[j].end > note.start + 0.01:
                    if (notes[j].pitch >= hand_split) == is_right:
                        concurrent.append(j)
            if len(concurrent) < 2:
                continue
            by_pitch = sorted(concurrent, key=lambda idx: notes[idx].pitch)
            anchor_pitch = notes[by_pitch[-1]].pitch
            if anchor_pitch - notes[by_pitch[0]].pitch <= max_span:
                continue
            for idx in by_pitch[:-1]:
                if anchor_pitch - notes[idx].pitch > max_span:
                    # Don't remove outright: end the note where the new one
                    # starts, and drop it only if almost nothing is left.
                    notes[idx].end = note.start
                    if notes[idx].end - notes[idx].start < 0.05:
                        # Counted once below via len(to_remove); counting it
                        # here as well would report the note twice.
                        to_remove.add(idx)

        instrument.notes = [n for k, n in enumerate(notes) if k not in to_remove]
        removed += len(to_remove)
    return midi_out, removed
def merge_repeated_notes(midi_data, y, sr, hop_length=512, min_gap=0.15):
    """Merge consecutive same-pitch notes that lack a real re-attack.

    Basic-pitch often fragments a single sustained note into multiple short
    re-strikes. This step checks whether a repeated note has genuine onset
    energy at the re-attack point. If not, the notes are merged into one
    sustained note.

    Args:
        midi_data: Object exposing ``instruments``; each instrument has a
            ``notes`` list whose items carry ``pitch``, ``start`` and ``end``.
            The input is deep-copied and never modified.
        y: Audio samples handed to ``librosa.onset.onset_strength``.
        sr: Sample rate of ``y``.
        hop_length: Hop length (samples) of the onset-strength envelope.
        min_gap: If the gap between notes is larger than this (seconds),
            always keep separate — the silence itself is musical. Default 150ms.

    Returns:
        ``(midi_out, merged_count)`` — the modified copy and the number of
        notes that were absorbed into a preceding note. Each instrument's
        ``notes`` list comes back sorted by ``(pitch, start)``.
    """
    midi_out = copy.deepcopy(midi_data)
    merged_count = 0

    # Onset strength envelope used to verify re-attacks.
    onset_env = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length)
    n_env = len(onset_env)

    # The reference level depends only on the envelope, so compute it once
    # instead of once per candidate note pair.
    positive = onset_env[onset_env > 0]
    median_strength = float(np.median(positive)) if positive.size else 0
    # A re-attack only counts as real above twice the median strength.
    reattack_floor = median_strength * 2.0

    for instrument in midi_out.instruments:
        # Sort by pitch then start so same-pitch notes are adjacent.
        notes = sorted(instrument.notes, key=lambda n: (n.pitch, n.start))
        to_remove = set()

        i = 0
        while i < len(notes) - 1:
            note = notes[i]
            j = i + 1
            # Walk forward through the following notes of the same pitch.
            while j < len(notes) and notes[j].pitch == note.pitch:
                next_note = notes[j]

                # A real gap (silence) keeps the notes separate.
                if next_note.start - note.end > min_gap:
                    break

                # Look for onset energy in a small window around the
                # re-attack; an out-of-range frame counts as "no onset".
                frame = int(next_note.start * sr / hop_length)
                has_onset = False
                if 0 <= frame < n_env:
                    lo = max(0, frame - 1)
                    hi = min(n_env, frame + 2)
                    local_strength = float(np.max(onset_env[lo:hi]))
                    has_onset = local_strength > reattack_floor

                if has_onset:
                    # Real re-attack — stop merging into this note.
                    break

                # Merge: extend the current note to cover the next one.
                note.end = max(note.end, next_note.end)
                to_remove.add(j)
                merged_count += 1
                j += 1

            # j is always the first index not merged into notes[i].
            i = j

        instrument.notes = [n for k, n in enumerate(notes) if k not in to_remove]

    return midi_out, merged_count
def consolidate_rhythm(midi_data, y, sr, hop_length=512, max_snap=0.06):
    """Consolidate note onsets onto a dominant rhythmic pattern.

    After onset correction, notes can scatter across many different
    micro-timings, losing the clean rhythmic feel. This step:

    1. Detects tempo and beat positions.
    2. Builds a histogram of note positions within each beat (16 bins
       per beat = 16th-note resolution).
    3. Identifies dominant subdivisions: bin 0 always, bin 8 if it holds
       any note, then the most populated bins that reach the minimum
       count, capped at 4 positions in total.
    4. Re-snaps onsets to the nearest dominant position.

    Onsets within 3 ms of a dominant position are untouched. Onsets within
    ``max_snap`` seconds of a dominant position are moved onto it. Onsets
    further away are moved to the nearest point of a tempo-derived 8th-note
    grid if that point lies within ``max_snap * 1.2`` seconds; otherwise
    they are left alone. Note durations are preserved.

    Args:
        midi_data: Object exposing ``instruments``; each instrument has a
            ``notes`` list whose items carry ``start`` and ``end``. The
            input is deep-copied and never modified.
        y: Audio samples handed to ``librosa.beat.beat_track``.
        sr: Sample rate of ``y``.
        hop_length: Hop length (samples) used for beat tracking.
        max_snap: Maximum distance (seconds) to snap onto a dominant
            position (default 60ms).

    Returns:
        ``(midi_out, snapped, n_dominant)`` — the modified copy, the number
        of notes moved, and the number of dominant subdivisions used.
        Returns ``(midi_out, 0, 0)`` when fewer than 4 beats are found,
        when there are no notes, or when no note falls inside a beat.
    """
    midi_out = copy.deepcopy(midi_data)

    # Detect tempo and beats. The tempo may come back as a scalar or as an
    # array of any rank, so unwrap it uniformly.
    tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr, hop_length=hop_length)
    tempo = float(np.atleast_1d(tempo)[0])

    # Fix tempo doubling: halve when that lands in a plausible range and
    # keep every other beat.
    if tempo > 140:
        half_tempo = tempo / 2
        if 50 <= half_tempo <= 120:
            tempo = half_tempo
            beat_frames = beat_frames[::2]

    beat_times = librosa.frames_to_time(beat_frames, sr=sr, hop_length=hop_length)
    if len(beat_times) < 4:
        return midi_out, 0, 0

    all_notes = [note for inst in midi_out.instruments for note in inst.notes]
    if not all_notes:
        return midi_out, 0, 0

    # ── Step 1: Build histogram of where notes fall within each beat ──
    # 16 bins per beat (16th-note resolution).
    n_bins = 16
    histogram = np.zeros(n_bins)
    for note in all_notes:
        # Index of the beat this note belongs to.
        beat_idx = np.searchsorted(beat_times, note.start, side='right') - 1
        if beat_idx < 0 or beat_idx >= len(beat_times) - 1:
            continue
        beat_start = beat_times[beat_idx]
        beat_dur = beat_times[beat_idx + 1] - beat_start
        if beat_dur <= 0:
            continue
        # Position within the beat as a fraction in [0, 1).
        frac = (note.start - beat_start) / beat_dur
        frac = max(0.0, min(frac, 0.9999))
        histogram[int(frac * n_bins)] += 1

    total_notes_in_beats = histogram.sum()
    if total_notes_in_beats == 0:
        return midi_out, 0, 0

    # ── Step 2: Identify dominant subdivisions ──
    # Always include the downbeat (0); include the half-beat (8) when used.
    dominant_bins = {0}
    if histogram[8] > 0:
        dominant_bins.add(8)

    # Add the most populated bins until there are 4 positions.
    # Fewer dominant positions = tighter grid = cleaner rhythm.
    ranked = sorted(range(n_bins), key=lambda i: histogram[i], reverse=True)
    min_count = max(total_notes_in_beats * 0.05, 4)  # at least 5% or 4 notes
    for b in ranked:
        if len(dominant_bins) >= 4:
            break
        if histogram[b] >= min_count:
            dominant_bins.add(b)

    dominant_fracs = sorted([b / n_bins for b in dominant_bins])
    print(f" Dominant subdivisions: {len(dominant_fracs)}/{n_bins} "
          f"(bins: {sorted(dominant_bins)})")

    # ── Step 3: Build full grid of dominant positions ──
    dominant_grid = []
    for i in range(len(beat_times) - 1):
        beat_start = beat_times[i]
        beat_dur = beat_times[i + 1] - beat_start
        for frac in dominant_fracs:
            dominant_grid.append(beat_start + frac * beat_dur)
    # Extend one beat past the last detected beat, reusing the last duration.
    if len(beat_times) >= 2:
        last_dur = beat_times[-1] - beat_times[-2]
        for frac in dominant_fracs:
            dominant_grid.append(beat_times[-1] + frac * last_dur)
    dominant_grid = np.array(dominant_grid)

    # ── Step 4: Build 8th-note fallback grid ──
    # Used for notes that are too far from any dominant position.
    beat_dur = 60.0 / tempo if tempo > 30 else 0.5
    eighth = beat_dur / 2.0
    fallback_grid = []
    if len(beat_times) >= 2:
        fb_t = max(0, beat_times[0] - beat_dur * 2)
        while fb_t <= beat_times[-1] + beat_dur * 2:
            fallback_grid.append(fb_t)
            fb_t += eighth
    fallback_grid = np.array(fallback_grid) if fallback_grid else np.array([0])

    # ── Step 5: Snap stray onsets to dominant grid (or fallback) ──
    snapped = 0
    for inst in midi_out.instruments:
        for note in inst.notes:
            diffs = np.abs(dominant_grid - note.start)
            nearest_idx = np.argmin(diffs)
            dist = diffs[nearest_idx]

            if dist < 0.003:
                # Already on a dominant position (within 3ms).
                continue

            if dist <= max_snap:
                duration = note.end - note.start
                note.start = dominant_grid[nearest_idx]
                note.end = note.start + duration
                snapped += 1
            else:
                # Fallback: snap to the nearest 8th note if close enough.
                fb_diffs = np.abs(fallback_grid - note.start)
                fb_idx = np.argmin(fb_diffs)
                if fb_diffs[fb_idx] <= max_snap * 1.2:
                    duration = note.end - note.start
                    note.start = fallback_grid[fb_idx]
                    note.end = note.start + duration
                    snapped += 1

    return midi_out, snapped, len(dominant_fracs)
def detect_sustain_regions(y, sr, hop_length=512):
    """Flag frames where the sustain pedal is likely engaged.

    Two indicators are combined per frame: low spectral flux (the spectrum
    is barely changing, i.e. no new attack) and a non-trivial RMS level
    (something is still ringing). The raw mask is then smoothed by filling
    short gaps and dropping short runs.

    Args:
        y: Audio samples.
        sr: Sample rate of ``y``.
        hop_length: Hop length (samples) for the STFT and RMS frames.

    Returns:
        Boolean array with one entry per analysis frame; True marks a
        frame considered sustained.
    """
    # Spectral flux: RMS over frequency of the frame-to-frame magnitude change.
    magnitude = np.abs(librosa.stft(y, hop_length=hop_length))
    frame_delta = np.diff(magnitude, axis=1)
    flux = np.sqrt(np.mean(frame_delta ** 2, axis=0))
    # diff drops one frame; pad at the front to restore the frame count.
    flux = np.concatenate([[0], flux])

    energy = librosa.feature.rms(y=y, hop_length=hop_length)[0]

    # Scale flux by its 95th percentile and energy by its peak, then trim
    # both to the shorter of the two.
    flux_scaled = flux / (np.percentile(flux, 95) + 1e-8)
    energy_scaled = energy / (np.max(energy) + 1e-8)
    n_frames = min(len(flux_scaled), len(energy_scaled))
    flux_scaled = flux_scaled[:n_frames]
    energy_scaled = energy_scaled[:n_frames]

    # Sustained = flux below its own 30th percentile AND energy above 10%
    # of the peak.
    low_flux_level = np.percentile(flux_scaled, 30)
    sustain_mask = (flux_scaled < low_flux_level) & (energy_scaled > 0.10)

    # Window sizes in frames: 200ms for gap closing, 300ms minimum run.
    close_frames = max(1, int(0.2 * sr / hop_length))
    min_region = max(1, int(0.3 * sr / hop_length))

    # Gap closing, done in place: a False frame becomes True when there is
    # a True frame within close_frames on both sides. Because the mask is
    # updated as the scan advances, earlier fills are visible to later frames.
    for idx in range(1, n_frames - 1):
        if sustain_mask[idx]:
            continue
        left_hit = sustain_mask[max(0, idx - close_frames):idx].any()
        right_hit = sustain_mask[idx + 1:min(n_frames, idx + close_frames + 1)].any()
        if left_hit and right_hit:
            sustain_mask[idx] = True

    # Drop runs shorter than min_region. A run is only measured when a
    # False frame terminates it, so a run reaching the final frame is kept.
    run_start = None
    for idx in range(n_frames):
        if sustain_mask[idx]:
            if run_start is None:
                run_start = idx
        elif run_start is not None:
            if idx - run_start < min_region:
                sustain_mask[run_start:idx] = False
            run_start = None

    return sustain_mask
def extend_note_durations(midi_data, y, sr, hop_length=512, max_per_hand=4, hand_split=60):
    """Extend MIDI note durations to match audio CQT energy decay.

    Basic-pitch systematically underestimates note durations. This uses
    the CQT spectrogram to find where the audio energy actually decays
    and extends each note to match.

    Concurrent-aware: won't extend a note past the point where doing so
    would exceed ``max_per_hand`` concurrent notes in the same hand. This
    prevents the downstream concurrent limiter from having to trim
    over-extended notes.

    Args:
        midi_data: Object exposing ``instruments``; each instrument has a
            ``notes`` list whose items carry ``pitch``, ``start`` and
            ``end``. The input is deep-copied and never modified.
        y: Audio samples.
        sr: Sample rate of ``y``.
        hop_length: Hop length (samples) of the CQT frames.
        max_per_hand: Maximum simultaneous notes allowed per hand.
        hand_split: Pitches ``>= hand_split`` are treated as right hand,
            lower pitches as left hand.

    Returns:
        ``(midi_out, extended)`` — the modified copy and the number of
        notes whose end time was moved later.
    """
    midi_out = copy.deepcopy(midi_data)

    # 3 CQT bins per semitone starting at A0 (MIDI 21).
    N_BINS = 88 * 3
    FMIN = librosa.note_to_hz('A0')
    C = np.abs(librosa.cqt(
        y, sr=sr, hop_length=hop_length,
        fmin=FMIN, n_bins=N_BINS, bins_per_octave=36,
    ))
    C_db = librosa.amplitude_to_db(C, ref=np.max(C))
    # Map [-80 dB, 0 dB] onto [0, 1].
    C_norm = (np.maximum(C_db, -80.0) + 80.0) / 80.0
    n_frames = C.shape[1]

    # Sustain-pedal regions get a longer extension allowance.
    sustain_mask = detect_sustain_regions(y, sr, hop_length)
    # Pad/trim to match the CQT frame count.
    if len(sustain_mask) < n_frames:
        padding = np.zeros(n_frames - len(sustain_mask), dtype=bool)
        sustain_mask = np.concatenate([sustain_mask, padding])
    else:
        sustain_mask = sustain_mask[:n_frames]

    # Per-frame count of sounding notes in each hand (O(1) lookup later).
    right_count = np.zeros(n_frames, dtype=int)
    left_count = np.zeros(n_frames, dtype=int)
    for inst in midi_out.instruments:
        for n in inst.notes:
            sf = max(0, int(n.start * sr / hop_length))
            ef = min(n_frames, int(n.end * sr / hop_length))
            if n.pitch >= hand_split:
                right_count[sf:ef] += 1
            else:
                left_count[sf:ef] += 1

    extended = 0
    for inst in midi_out.instruments:
        # Sorted by (pitch, start): the next note of the same pitch, if
        # any, is always the immediate successor.
        notes_sorted = sorted(inst.notes, key=lambda n: (n.pitch, n.start))
        n_notes = len(notes_sorted)

        for idx, note in enumerate(notes_sorted):
            # Centre bin of this pitch's three CQT bins.
            fund_bin = (note.pitch - 21) * 3 + 1
            if fund_bin < 0 or fund_bin >= N_BINS:
                continue

            end_frame = min(n_frames, int(note.end * sr / hop_length))

            # In sustain regions allow a longer extension (4s instead of
            # 2s) and a lower energy threshold.
            in_sustain = end_frame < n_frames and sustain_mask[min(end_frame, n_frames - 1)]
            max_ext_seconds = 4.0 if in_sustain else 2.0
            energy_thresh = 0.15 if in_sustain else 0.20
            limit = min(n_frames, end_frame + int(max_ext_seconds * sr / hop_length))

            # Don't extend into the next note at the same pitch.
            if idx + 1 < n_notes and notes_sorted[idx + 1].pitch == note.pitch:
                next_start = int(notes_sorted[idx + 1].start * sr / hop_length) - 1
                limit = min(limit, next_start)

            hand_count = right_count if note.pitch >= hand_split else left_count
            lo = max(0, fund_bin - 1)
            hi = min(N_BINS, fund_bin + 2)

            actual_end = end_frame
            for f in range(end_frame, limit):
                if np.mean(C_norm[lo:hi, f]) <= energy_thresh:
                    break
                # This note isn't counted in hand_count beyond end_frame,
                # so hand_count[f] >= max_per_hand means extending here
                # would create max_per_hand + 1 concurrent notes.
                if hand_count[f] >= max_per_hand:
                    break
                actual_end = f

            new_end = actual_end * hop_length / sr
            if new_end > note.end + 0.05:
                # Count the note over the newly covered frames. Use the
                # frame index directly: converting new_end back to frames
                # can land one frame short through float rounding.
                if actual_end > end_frame:
                    hand_count[end_frame:actual_end] += 1
                note.end = new_end
                extended += 1

    return midi_out, extended
def align_chords(midi_data, threshold=0.02):
    """Give every note of a chord the same onset time.

    Notes whose onsets fall within ``threshold`` seconds of the first note
    of a group are treated as one chord and moved to the group's median
    onset, which removes the 'flammy' spread that basic-pitch's ~12ms
    frame resolution can introduce. Durations are preserved.

    Args:
        midi_data: Object exposing ``instruments``; each instrument has a
            ``notes`` list whose items carry ``start`` and ``end``. The
            input is deep-copied and never modified.
        threshold: Maximum onset distance (seconds) from the first note
            of a group for a note to join that group.

    Returns:
        ``(midi_out, aligned)`` — the modified copy and the number of
        notes whose onset was changed.
    """
    result = copy.deepcopy(midi_data)
    moved = 0

    for track in result.instruments:
        ordered = sorted(track.notes, key=lambda n: n.start)
        total = len(ordered)
        head = 0
        while head < total:
            # Grow the group while onsets stay close to the group's first note.
            anchor = ordered[head].start
            tail = head + 1
            while tail < total and ordered[tail].start - anchor < threshold:
                tail += 1
            chord = ordered[head:tail]
            head = tail

            if len(chord) < 2:
                continue

            target = float(np.median([member.start for member in chord]))
            for member in chord:
                if member.start == target:
                    continue
                length = member.end - member.start
                member.start = target
                member.end = target + length
                moved += 1

    return result, moved
def quantize_to_beat_grid(midi_data, y, sr, hop_length=512, strength=0.5, max_snap=0.06):
    """Quantize note onsets to a detected beat grid.

    Uses librosa beat tracking to find the tempo and beat positions,
    builds a 16th-note grid, and moves onsets toward the nearest grid
    position. Note durations are preserved.

    Args:
        midi_data: Object exposing ``instruments``; each instrument has a
            ``notes`` list whose items carry ``start`` and ``end``. The
            input is deep-copied and never modified.
        y: Audio samples handed to ``librosa.beat.beat_track``.
        sr: Sample rate of ``y``.
        hop_length: Hop length (samples) used for beat tracking.
        strength: Fraction of the distance to the grid point that is
            applied (1.0 = full snap).
        max_snap: Only onsets strictly closer than this many seconds to
            a grid point are moved (default 60ms).

    Returns:
        ``(midi_out, quantized, tempo)`` — the modified copy, the number
        of moved notes that were more than 5 ms off the grid, and the
        tempo in BPM (after the tempo-doubling fix).
    """
    midi_out = copy.deepcopy(midi_data)

    # The tempo may come back as a scalar or as an array of any rank, so
    # unwrap it uniformly.
    tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr, hop_length=hop_length)
    tempo = float(np.atleast_1d(tempo)[0])

    # Fix tempo doubling: librosa often detects double the true tempo for
    # slow/moderate songs (e.g., 86 BPM → 172). If tempo > 140 and halving
    # gives a reasonable tempo (50-120), use the half tempo and keep only
    # every other beat.
    if tempo > 140:
        half_tempo = tempo / 2
        if 50 <= half_tempo <= 120:
            tempo = half_tempo
            beat_frames = beat_frames[::2]  # keep every other beat

    beat_times = librosa.frames_to_time(beat_frames, sr=sr, hop_length=hop_length)
    if len(beat_times) < 2:
        print(" Could not detect beats, skipping quantization")
        return midi_out, 0, tempo

    # Build a 16th-note grid: four subdivisions per detected beat interval.
    grid = []
    for i in range(len(beat_times) - 1):
        sixteenth = (beat_times[i + 1] - beat_times[i]) / 4
        for sub in range(4):
            grid.append(beat_times[i] + sub * sixteenth)
    # Extend one beat past the last detected beat, reusing the last duration.
    sixteenth = (beat_times[-1] - beat_times[-2]) / 4
    for sub in range(4):
        grid.append(beat_times[-1] + sub * sixteenth)
    grid = np.array(grid)

    quantized = 0
    for instrument in midi_out.instruments:
        for note in instrument.notes:
            nearest_idx = np.argmin(np.abs(grid - note.start))
            deviation = grid[nearest_idx] - note.start
            if abs(deviation) < max_snap:
                duration = note.end - note.start
                note.start = note.start + deviation * strength
                note.end = note.start + duration
                # Moves of 5 ms or less are applied but not counted.
                if abs(deviation) > 0.005:
                    quantized += 1

    return midi_out, quantized, tempo
def correct_onsets(midi_data, ref_onsets, min_off=0.02, max_off=0.15):
    """Correct chord onsets that are clearly misaligned with audio onsets.

    Groups notes into chords (onsets within 30 ms of the group's first
    note), then for each chord looks up the nearest audio onset. The chord
    is shifted onto it only if the distance lies strictly between
    ``min_off`` and ``max_off`` and neither adjacent chord is closer to
    that audio onset.

    Args:
        midi_data: Object exposing ``instruments``; each instrument has a
            ``notes`` list whose items carry ``start`` and ``end``. The
            input is deep-copied and never modified.
        ref_onsets: Audio onset times in seconds. When empty, no chord is
            corrected.
        min_off: Lower bound (seconds, exclusive) of the correction window.
        max_off: Upper bound (seconds, exclusive) of the correction window.

    Returns:
        ``(midi_out, corrections, total_shift, n_chords, initial_f1,
        final_f1)`` — the modified copy, the number of chords shifted, the
        summed absolute shift in seconds, the number of chord groups, and
        the ``onset_f1`` score before and after correction.
    """
    midi_out = copy.deepcopy(midi_data)
    refs = np.asarray(ref_onsets, dtype=float)

    all_notes = sorted(
        (n for inst in midi_out.instruments for n in inst.notes),
        key=lambda n: n.start,
    )

    # Group notes whose onset is within 30 ms of the group's first note.
    chord_groups = []
    if all_notes:
        current_group = [all_notes[0]]
        for candidate in all_notes[1:]:
            if candidate.start - current_group[0].start < 0.03:
                current_group.append(candidate)
            else:
                chord_groups.append(current_group)
                current_group = [candidate]
        chord_groups.append(current_group)

    chord_onsets = np.array([g[0].start for g in chord_groups])

    corrections = 0
    total_shift = 0.0
    # With no reference onsets there is nothing to align to; np.argmin
    # would raise on the empty array, so skip the pass entirely.
    if refs.size:
        last_idx = len(chord_onsets) - 1
        for group_idx, group in enumerate(chord_groups):
            diffs = refs - chord_onsets[group_idx]
            abs_diffs = np.abs(diffs)
            nearest_idx = np.argmin(abs_diffs)
            nearest_diff = diffs[nearest_idx]
            abs_diff = abs_diffs[nearest_idx]

            if not (min_off < abs_diff < max_off):
                continue

            # Leave the chord alone if a neighbouring chord is a better
            # match for the same audio onset.
            target = refs[nearest_idx]
            if group_idx > 0 and abs(target - chord_onsets[group_idx - 1]) < abs_diff:
                continue
            if group_idx < last_idx and abs(target - chord_onsets[group_idx + 1]) < abs_diff:
                continue

            for note in group:
                duration = note.end - note.start
                note.start = max(0, note.start + nearest_diff)
                note.end = note.start + duration
            corrections += 1
            total_shift += abs(nearest_diff)

    # chord_onsets still holds the pre-correction onsets at this point.
    initial_f1 = onset_f1(ref_onsets, chord_onsets)
    corrected_onsets = np.array([g[0].start for g in chord_groups])
    final_f1 = onset_f1(ref_onsets, corrected_onsets)

    return midi_out, corrections, total_shift, len(chord_groups), initial_f1, final_f1
def apply_global_offset(midi_data, ref_onsets):
    """Measure and correct systematic timing offset against audio onsets.

    For every distinct MIDI onset that has an audio onset within 100 ms,
    records the signed difference (audio minus MIDI). The median of those
    differences is added to every note's start; durations are preserved
    and starts are clamped at 0.

    Args:
        midi_data: Object exposing ``instruments``; each instrument has a
            ``notes`` list whose items carry ``start`` and ``end``. The
            input is deep-copied and never modified.
        ref_onsets: Audio onset times in seconds. May be empty.

    Returns:
        ``(midi_out, offset)`` — the copy and the offset applied in
        seconds. ``offset`` is ``0.0`` (and the notes are unchanged) when
        there are no reference onsets, no MIDI onset has a match, or the
        median offset is below 5 ms.
    """
    midi_out = copy.deepcopy(midi_data)
    refs = np.asarray(ref_onsets, dtype=float)

    # Nothing to compare against; np.min/np.argmin would raise on an
    # empty array.
    if refs.size == 0:
        return midi_out, 0.0

    all_onsets = sorted({n.start for inst in midi_out.instruments for n in inst.notes})

    diffs = []
    for midi_onset in all_onsets:
        distances = np.abs(refs - midi_onset)
        nearest = int(np.argmin(distances))
        if distances[nearest] < 0.10:
            # positive = MIDI is early, negative = late
            diffs.append(refs[nearest] - midi_onset)

    if not diffs:
        return midi_out, 0.0

    median_offset = float(np.median(diffs))

    # Only apply if the offset is meaningful (> 5ms).
    if abs(median_offset) < 0.005:
        return midi_out, 0.0

    for instrument in midi_out.instruments:
        for note in instrument.notes:
            duration = note.end - note.start
            note.start = max(0, note.start + median_offset)
            note.end = note.start + duration

    return midi_out, median_offset
def fix_note_overlap(midi_data, hand_split=60, min_duration=0.10):
    """Shorten overlapping right-hand notes and enforce a minimum length.

    Right-hand notes (pitch ``>= hand_split``) that overlap the next
    later-starting right-hand note by more than 50 ms are ended 10 ms
    before that note, but never cut below 70% of their original length.
    Afterwards every note in every instrument is given at least
    ``min_duration`` seconds.

    Args:
        midi_data: Object exposing ``instruments``; each instrument has a
            ``notes`` list whose items carry ``pitch``, ``start`` and
            ``end``. The input is deep-copied and never modified.
        hand_split: Lowest pitch treated as right hand.
        min_duration: Minimum note length in seconds.

    Returns:
        ``(midi_out, trimmed, enforced)`` — the modified copy, the number
        of right-hand notes trimmed, and the number of notes lengthened
        by the final minimum-duration pass.
    """
    result = copy.deepcopy(midi_data)
    trimmed = 0

    for track in result.instruments:
        upper = sorted(
            (n for n in track.notes if n.pitch >= hand_split),
            key=lambda n: (n.start, n.pitch),
        )
        count = len(upper)

        for pos, current in enumerate(upper):
            # Look at most 7 notes ahead for the first one starting later.
            for follower in upper[pos + 1:min(pos + 8, count)]:
                if follower.start <= current.start:
                    continue
                # Only significant overlaps (>50ms) are trimmed.
                if current.end - follower.start > 0.05:
                    length = current.end - current.start
                    # Never shorten by more than 30% of the original length.
                    floor_end = current.start + length * 0.7
                    current.end = max(follower.start - 0.01, floor_end)
                    if current.end - current.start < min_duration:
                        current.end = current.start + min_duration
                    trimmed += 1
                # Followers are sorted by start, so any later one overlaps
                # no more than this one; nothing further can trigger.
                break

    # Final pass over ALL notes catches any collapsed durations.
    enforced = 0
    for track in result.instruments:
        for item in track.notes:
            if item.end - item.start < min_duration:
                item.end = item.start + min_duration
                enforced += 1

    return result, trimmed, enforced
def recover_missing_notes(midi_data, y, sr, hop_length=512, snap_onsets=None):
    """Recover strong notes the transcriber missed using CQT analysis.

    Scans the audio CQT for pitch energy that isn't represented in the MIDI.
    When a pitch has strong, sustained energy but no corresponding MIDI note,
    a note is synthesized. Targets the upper register (C4 to B6, MIDI 60-95)
    where basic-pitch can under-detect, especially when harmonics cause
    false removal.

    Should be run AFTER all removal filters so the coverage map reflects
    what actually survived.

    Args:
        midi_data: Object exposing ``instruments``; each instrument has a
            ``notes`` list whose items carry ``pitch``, ``start`` and
            ``end``. The input is deep-copied and never modified.
        y: Audio samples.
        sr: Sample rate of ``y``.
        hop_length: Hop length (samples) of the CQT frames.
        snap_onsets: Optional onset times (seconds). A recovered note that
            starts within 60 ms of one of them is moved onto it, keeping
            its duration.

    Returns:
        ``(midi_out, recovered)`` — the modified copy and the number of
        notes added. New notes are appended to the first instrument; when
        the MIDI has no instruments nothing is recovered.
    """
    midi_out = copy.deepcopy(midi_data)

    # Recovered notes go into instruments[0]; without any instrument there
    # is nowhere to put them, so bail out instead of raising IndexError.
    if not midi_out.instruments:
        return midi_out, 0

    # 3 CQT bins per semitone starting at A0 (MIDI 21).
    N_BINS = 88 * 3
    FMIN = librosa.note_to_hz('A0')
    C = np.abs(librosa.cqt(
        y, sr=sr, hop_length=hop_length,
        fmin=FMIN, n_bins=N_BINS, bins_per_octave=36,
    ))
    C_db = librosa.amplitude_to_db(C, ref=np.max(C))
    n_frames = C.shape[1]
    times = librosa.frames_to_time(np.arange(n_frames), sr=sr, hop_length=hop_length)

    # Scan range: C4 (60) to B6 (95).
    low_pitch, high_pitch = 60, 96

    # Coverage mask: covered[p - low_pitch, f] is True when an existing
    # note of pitch p spans frame f. Only scanned pitches are tracked.
    covered = np.zeros((high_pitch - low_pitch, n_frames), dtype=bool)
    for inst in midi_out.instruments:
        for note in inst.notes:
            if low_pitch <= note.pitch < high_pitch:
                start_frame = max(0, int(note.start * sr / hop_length))
                end_frame = min(n_frames, int(note.end * sr / hop_length))
                covered[note.pitch - low_pitch, start_frame:end_frame] = True

    # Build the snap array once rather than per recovered region.
    snap_arr = None
    if snap_onsets is not None and len(snap_onsets) > 0:
        snap_arr = np.array(snap_onsets)

    recovered = 0
    min_energy = -10.0  # dB threshold — only recover notes with strong CQT energy
    min_duration_s = 0.05  # ~50ms minimum
    gap_tolerance = 3  # allow 3-frame dips without breaking a note

    for midi_pitch in range(low_pitch, high_pitch):
        fund_bin = (midi_pitch - 21) * 3 + 1
        if fund_bin < 0 or fund_bin >= N_BINS:
            continue
        lo = max(0, fund_bin - 1)
        hi = min(N_BINS, fund_bin + 2)

        # Harmonic check: skip the pitch if the octave below peaks 12+ dB
        # louder over the whole recording — this pitch is then likely a
        # harmonic rather than a played note.
        lower_pitch = midi_pitch - 12
        if lower_pitch >= 21:
            lower_bin = (lower_pitch - 21) * 3 + 1
            if 0 <= lower_bin < N_BINS:
                lower_lo = max(0, lower_bin - 1)
                lower_hi = min(N_BINS, lower_bin + 2)
                upper_energy = float(np.max(C_db[lo:hi, :]))
                lower_energy = float(np.max(C_db[lower_lo:lower_hi, :]))
                if lower_energy - upper_energy > 12:
                    continue

        # Per-frame peak energy across this pitch's three bins.
        pitch_energy = np.max(C_db[lo:hi, :], axis=0)

        # Frames with strong energy that no existing note covers.
        strong_uncovered = (pitch_energy >= min_energy) & ~covered[midi_pitch - low_pitch]
        n_flags = len(strong_uncovered)

        # Close small gaps in place: a frame within 5 dB of the threshold
        # is filled when flagged frames exist within gap_tolerance on both
        # sides.
        for f in range(1, n_flags - 1):
            if strong_uncovered[f] or pitch_energy[f] < min_energy - 5:
                continue
            before = strong_uncovered[max(0, f - gap_tolerance):f].any()
            after = strong_uncovered[f + 1:min(n_flags, f + gap_tolerance + 1)].any()
            if before and after:
                strong_uncovered[f] = True

        # Extract contiguous flagged regions as [start, end) frame pairs.
        regions = []
        region_start = None
        for f in range(n_flags):
            if strong_uncovered[f]:
                if region_start is None:
                    region_start = f
            elif region_start is not None:
                regions.append((region_start, f))
                region_start = None
        if region_start is not None:
            regions.append((region_start, n_flags))

        for start_f, end_f in regions:
            t_start = times[start_f]
            t_end = times[min(end_f, len(times) - 1)]
            if t_end - t_start < min_duration_s:
                continue

            # Velocity derived from the mean dB level, clamped to [35, 75].
            avg_energy = float(np.mean(pitch_energy[start_f:end_f]))
            velocity = min(75, max(35, int(55 + avg_energy * 1.5)))

            # Snap to the nearest supplied onset for rhythmic alignment.
            note_start = t_start
            note_end = t_end
            if snap_arr is not None:
                diffs = np.abs(snap_arr - t_start)
                nearest_idx = np.argmin(diffs)
                if diffs[nearest_idx] < 0.06:
                    dur = t_end - t_start
                    note_start = snap_arr[nearest_idx]
                    note_end = note_start + dur

            new_note = pretty_midi.Note(
                velocity=velocity,
                pitch=midi_pitch,
                start=note_start,
                end=note_end,
            )
            midi_out.instruments[0].notes.append(new_note)
            recovered += 1

    return midi_out, recovered
def estimate_complexity(midi_data, audio_duration):
    """Estimate piece complexity to adjust filter aggressiveness.

    Complex pieces need less aggressive ghost removal and wider tolerance
    for concurrent notes, since dense textures are intentional.

    Args:
        midi_data: Object exposing ``instruments``; each instrument has a
            ``notes`` list whose items carry ``start``. Not modified.
        audio_duration: Length of the audio in seconds.

    Returns:
        A dict with:
        - ``note_density``: notes per second
        - ``avg_polyphony``: mean number of notes starting within 30 ms
          of each distinct onset (onsets rounded to 1 ms)
        - ``complexity``: ``'complex'`` if density > 8 or polyphony > 3.5,
          ``'moderate'`` if density > 4 or polyphony > 2.5, otherwise
          ``'simple'``

        When there are no notes or ``audio_duration < 1`` the result is
        ``{'note_density': 0, 'avg_polyphony': 1, 'complexity': 'simple'}``.
    """
    starts = np.array(
        sorted(n.start for inst in midi_data.instruments for n in inst.notes),
        dtype=float,
    )
    if starts.size == 0 or audio_duration < 1:
        return {'note_density': 0, 'avg_polyphony': 1, 'complexity': 'simple'}

    note_density = len(starts) / audio_duration

    # Distinct onsets at millisecond resolution.
    onsets = np.array(sorted({round(float(s), 3) for s in starts}))

    # Count notes within 30 ms of each onset. A binary search first narrows
    # the sorted starts to a slightly wider window, then the exact strict
    # comparison is applied inside it — avoiding a scan of every note for
    # every onset.
    window = 0.03
    margin = window + 0.001
    lo = np.searchsorted(starts, onsets - margin, side='left')
    hi = np.searchsorted(starts, onsets + margin, side='right')
    polyphonies = [
        int(np.count_nonzero(np.abs(starts[a:b] - onset) < window))
        for onset, a, b in zip(onsets, lo, hi)
    ]
    avg_polyphony = np.mean(polyphonies) if polyphonies else 1

    if note_density > 8 or avg_polyphony > 3.5:
        complexity = 'complex'
    elif note_density > 4 or avg_polyphony > 2.5:
        complexity = 'moderate'
    else:
        complexity = 'simple'

    return {
        'note_density': note_density,
        'avg_polyphony': avg_polyphony,
        'complexity': complexity,
    }
| def optimize(original_audio_path, midi_path, output_path=None): | |
| """Full optimization pipeline.""" | |
| if output_path is None: | |
| output_path = midi_path | |
| sr = 22050 | |
| hop_length = 512 | |
| # Load audio and detect onsets | |
| print(f"Analyzing audio: {original_audio_path}") | |
| y, _ = librosa.load(original_audio_path, sr=sr, mono=True) | |
| audio_duration = len(y) / sr | |
| onset_env = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length) | |
| # Use backtrack=False: basic-pitch onsets align with energy peaks, not | |
| # the earlier rise points that backtrack finds (~50ms earlier). | |
| # Use delta=0.04 for higher sensitivity — detects ~15% more onsets, | |
| # reducing unmatched MIDI onsets from 116 to 80. | |
| ref_onset_frames = librosa.onset.onset_detect( | |
| onset_envelope=onset_env, sr=sr, hop_length=hop_length, | |
| backtrack=False, delta=0.04 | |
| ) | |
| ref_onsets = librosa.frames_to_time(ref_onset_frames, sr=sr, hop_length=hop_length) | |
| print(f" {audio_duration:.1f}s, {len(ref_onsets)} audio onsets") | |
| # Load MIDI | |
| midi_data = pretty_midi.PrettyMIDI(str(midi_path)) | |
| total_notes = sum(len(inst.notes) for inst in midi_data.instruments) | |
| print(f" {total_notes} MIDI notes") | |
| # Estimate complexity to adjust filter thresholds | |
| complexity_info = estimate_complexity(midi_data, audio_duration) | |
| complexity = complexity_info['complexity'] | |
| print(f" Complexity: {complexity} (density={complexity_info['note_density']:.1f} n/s, " | |
| f"polyphony={complexity_info['avg_polyphony']:.1f})") | |
| # Step 0: Remove notes in leading silence (mic rumble artifacts) | |
| print("\nStep 0: Removing notes in leading silence...") | |
| midi_data, silence_removed, music_start = remove_leading_silence_notes(midi_data, y, sr) | |
| if silence_removed: | |
| print(f" Music starts at {music_start:.2f}s, removed {silence_removed} noise notes") | |
| else: | |
| print(f" No leading silence detected") | |
| # Step 0b: Remove notes in trailing silence | |
| print("\nStep 0b: Removing notes in trailing silence...") | |
| midi_data, trail_removed, music_end = remove_trailing_silence_notes(midi_data, y, sr) | |
| if trail_removed: | |
| print(f" Music ends at {music_end:.2f}s, removed {trail_removed} trailing notes") | |
| else: | |
| print(f" No trailing silence notes detected") | |
| # Step 0c: Remove low-energy hallucinations | |
| print("\nStep 0c: Removing low-energy hallucinations...") | |
| midi_data, energy_removed = remove_low_energy_notes(midi_data, y, sr, hop_length) | |
| print(f" Removed {energy_removed} notes with no audio onset energy") | |
| # Step 0d: Remove harmonic ghost notes (CQT-aware) | |
| print("\nStep 0d: Removing harmonic ghost notes...") | |
| midi_data, ghosts_removed = remove_harmonic_ghosts(midi_data, y, sr, hop_length) | |
| print(f" Removed {ghosts_removed} octave-harmonic ghosts") | |
| # Step 1: Remove phantom high notes (conservative) | |
| print("\nStep 1: Removing phantom high notes...") | |
| midi_data, phantoms_removed = remove_phantom_notes(midi_data) | |
| print(f" Removed {phantoms_removed} phantom notes") | |
| # Step 1b: Hard pitch ceiling at C7 (MIDI 96) — extreme highs only | |
| print("\nStep 1b: Applying pitch ceiling (C7 / MIDI 96)...") | |
| midi_data, ceiling_removed = apply_pitch_ceiling(midi_data, max_pitch=96) | |
| print(f" Removed {ceiling_removed} notes above C7") | |
| # Step 2: Align chord notes to single onset | |
| print("\nStep 2: Aligning chord notes...") | |
| midi_data, chords_aligned = align_chords(midi_data) | |
| print(f" Aligned {chords_aligned} notes within chords") | |
| # Step 3: Full beat-grid quantization | |
| print("\nStep 3: Quantizing to beat grid...") | |
| midi_data, notes_quantized, detected_tempo = quantize_to_beat_grid( | |
| midi_data, y, sr, hop_length, strength=1.0 | |
| ) | |
| print(f" Detected tempo: {detected_tempo:.0f} BPM") | |
| print(f" Quantized {notes_quantized} notes (full snap)") | |
| # Step 4: Targeted onset correction against audio | |
| print("\nStep 4: Correcting onsets against audio...") | |
| midi_data, corrections, total_shift, n_chords, pre_f1, post_f1 = \ | |
| correct_onsets(midi_data, ref_onsets) | |
| avg_shift = (total_shift / corrections * 1000) if corrections > 0 else 0 | |
| print(f" Corrected {corrections}/{n_chords} (avg {avg_shift:.0f}ms)") | |
| print(f" Onset F1: {pre_f1:.4f} -> {post_f1:.4f}") | |
| # Step 5: Tight second correction pass (10-60ms window) | |
| print("\nStep 5: Fine-tuning onsets (tight pass)...") | |
| midi_data, corrections2, total_shift2, n_chords2, _, post_f1_2 = \ | |
| correct_onsets(midi_data, ref_onsets, min_off=0.01, max_off=0.06) | |
| avg_shift2 = (total_shift2 / corrections2 * 1000) if corrections2 > 0 else 0 | |
| print(f" Fine-tuned {corrections2}/{n_chords2} (avg {avg_shift2:.0f}ms)") | |
| print(f" Onset F1: {post_f1:.4f} -> {post_f1_2:.4f}") | |
| # Step 6: Micro-correction pass (5-25ms window) | |
| print("\nStep 6: Micro-correcting onsets...") | |
| midi_data, corrections3, total_shift3, n_chords3, _, post_f1_3 = \ | |
| correct_onsets(midi_data, ref_onsets, min_off=0.005, max_off=0.025) | |
| avg_shift3 = (total_shift3 / corrections3 * 1000) if corrections3 > 0 else 0 | |
| print(f" Micro-corrected {corrections3}/{n_chords3} (avg {avg_shift3:.0f}ms)") | |
| print(f" Onset F1: {post_f1_2:.4f} -> {post_f1_3:.4f}") | |
| # Step 6b: Remove spurious false-positive onsets | |
| print("\nStep 6b: Removing spurious onsets (false positive cleanup)...") | |
| midi_data, spurious_notes, spurious_onsets = remove_spurious_onsets( | |
| midi_data, y, sr, ref_onsets, hop_length, complexity=complexity | |
| ) | |
| print(f" Removed {spurious_notes} notes across {spurious_onsets} spurious onsets") | |
| # Step 6c: Wide onset recovery pass (50-120ms window) to rescue false negatives | |
| print("\nStep 6c: Wide onset recovery (rescuing false negatives)...") | |
| midi_data, corrections_wide, total_shift_wide, n_chords_wide, _, post_f1_wide = \ | |
| correct_onsets(midi_data, ref_onsets, min_off=0.04, max_off=0.12) | |
| avg_shift_wide = (total_shift_wide / corrections_wide * 1000) if corrections_wide > 0 else 0 | |
| print(f" Recovered {corrections_wide}/{n_chords_wide} (avg {avg_shift_wide:.0f}ms)") | |
| print(f" Onset F1: {post_f1_3:.4f} -> {post_f1_wide:.4f}") | |
| # Step 7: Global offset correction | |
| print("\nStep 7: Correcting systematic offset...") | |
| midi_data, offset = apply_global_offset(midi_data, ref_onsets) | |
| print(f" Applied {offset*1000:+.1f}ms global offset") | |
| # Step 7b: Rhythm consolidation — snap stray onsets to dominant pattern | |
| print("\nStep 7b: Consolidating rhythm pattern...") | |
| midi_data, rhythm_snapped, n_dominant = consolidate_rhythm(midi_data, y, sr, hop_length) | |
| print(f" Snapped {rhythm_snapped} notes to {n_dominant} dominant subdivisions") | |
| # Step 7c: Merge repeated consecutive same-pitch notes without real re-attack | |
| print("\nStep 7c: Merging repeated notes without re-attack energy...") | |
| midi_data, notes_merged = merge_repeated_notes(midi_data, y, sr, hop_length) | |
| print(f" Merged {notes_merged} repeated notes into sustains") | |
| # Step 8: Fix overlaps and enforce min duration (LAST — after all position changes) | |
| print("\nStep 8: Fixing overlaps and enforcing min duration...") | |
| midi_data, notes_trimmed, durations_enforced = fix_note_overlap(midi_data) | |
| print(f" Trimmed {notes_trimmed} overlapping notes") | |
| print(f" Enforced min duration on {durations_enforced} notes") | |
| # Step 8b: CQT-based duration extension | |
| print("\nStep 8b: Extending note durations to match audio decay...") | |
| midi_data, notes_extended = extend_note_durations(midi_data, y, sr, hop_length) | |
| print(f" Extended {notes_extended} notes to match audio CQT decay") | |
| # Step 8c: Re-enforce minimum duration after CQT extension | |
| min_dur_enforced_2 = 0 | |
| for instrument in midi_data.instruments: | |
| for note in instrument.notes: | |
| if note.end - note.start < 0.10: | |
| note.end = note.start + 0.10 | |
| min_dur_enforced_2 += 1 | |
| if min_dur_enforced_2: | |
| print(f"\nStep 8c: Re-enforced min duration on {min_dur_enforced_2} notes after CQT extension") | |
| # Step 8d: CQT pitch-specific energy filter (remove bass hallucinations) | |
| print("\nStep 8d: Removing pitch-unconfirmed bass notes...") | |
| midi_data, cqt_removed = remove_pitch_unconfirmed_notes(midi_data, y, sr, hop_length) | |
| print(f" Removed {cqt_removed} notes with no CQT energy at their pitch") | |
| # Step 8e: Recover missing notes from CQT energy | |
| # Runs late so the coverage map reflects what actually survived all filters. | |
| # Recovered notes won't be touched by phantom/spurious/pitch filters. | |
| print("\nStep 8e: Recovering missing notes from CQT analysis...") | |
| # Collect existing onset times to snap recovered notes to | |
| existing_onsets = sorted(set( | |
| round(n.start, 4) for inst in midi_data.instruments for n in inst.notes | |
| )) | |
| midi_data, notes_recovered = recover_missing_notes( | |
| midi_data, y, sr, hop_length, snap_onsets=existing_onsets | |
| ) | |
| print(f" Recovered {notes_recovered} notes from CQT energy") | |
| # Step 8f: Remove hand outliers — notes too far from their hand's cluster | |
| print("\nStep 8f: Removing hand outlier harmonics...") | |
| midi_data, outliers_removed = remove_hand_outliers(midi_data) | |
| print(f" Removed {outliers_removed} outlier notes") | |
| # Step 8f2: Enforce hand span — no chord wider than an octave per hand | |
| print("\nStep 8f2: Enforcing hand span limit (max 12 semitones per hand)...") | |
| midi_data, span_removed = enforce_hand_span(midi_data, max_span=12) | |
| print(f" Removed {span_removed} notes exceeding hand span") | |
| # Step 8g: Playability filter — limit per-onset chord size | |
| # Complex pieces get 5 notes/hand to preserve dense voicings | |
| # Left hand (bass) gets a tighter limit to avoid muddy chords | |
| max_rh = 3 if complexity == 'complex' else 2 | |
| max_lh = 2 if complexity == 'complex' else 1 | |
| print(f"\nStep 8g: Playability filter (RH max {max_rh}, LH max {max_lh} per chord)...") | |
| midi_data, playability_removed = limit_concurrent_notes( | |
| midi_data, max_per_hand=max_rh, max_left_hand=max_lh | |
| ) | |
| print(f" Removed {playability_removed} excess chord notes") | |
| # Step 8h: Limit total concurrent sounding notes | |
| print(f"\nStep 8h: Concurrent sounding limit (RH max {max_rh}, LH max {max_lh})...") | |
| midi_data, sustain_trimmed = limit_total_concurrent( | |
| midi_data, max_per_hand=max_rh, max_left_hand=max_lh | |
| ) | |
| print(f" Trimmed {sustain_trimmed} sustained notes to reduce pileup") | |
| # Step 9: Final rhythm consolidation — re-snap after all note manipulation | |
| # Steps 8b-8h may have shifted notes off the grid. This pass catches stragglers. | |
| # Uses wider snap (100ms) to aggressively force notes onto the grid. | |
| print("\nStep 9: Final rhythm consolidation...") | |
| midi_data, rhythm_snapped_2, n_dominant_2 = consolidate_rhythm( | |
| midi_data, y, sr, hop_length, max_snap=0.10 | |
| ) | |
| print(f" Re-snapped {rhythm_snapped_2} notes to {n_dominant_2} dominant subdivisions") | |
| # Final metrics | |
| final_onsets = [] | |
| for inst in midi_data.instruments: | |
| for n in inst.notes: | |
| final_onsets.append(n.start) | |
| final_onsets = np.unique(np.round(np.sort(final_onsets), 4)) | |
| final_f1 = onset_f1(ref_onsets, final_onsets) | |
| final_notes = sum(len(inst.notes) for inst in midi_data.instruments) | |
| # Duration sanity check | |
| all_durs = [n.end - n.start for inst in midi_data.instruments for n in inst.notes] | |
| min_dur = min(all_durs) * 1000 if all_durs else 0 | |
| print(f"\nDone:") | |
| print(f" Phantoms removed: {phantoms_removed}") | |
| print(f" Pitch ceiling removed: {ceiling_removed}") | |
| print(f" Playability filter: {playability_removed} chord / {sustain_trimmed} sustain") | |
| print(f" Chords aligned: {chords_aligned}") | |
| print(f" Notes quantized: {notes_quantized} ({detected_tempo:.0f} BPM)") | |
| print(f" Onsets corrected: {corrections}/{n_chords}") | |
| print(f" Spurious onsets removed: {spurious_onsets} ({spurious_notes} notes)") | |
| print(f" FN recovery corrections: {corrections_wide}") | |
| print(f" Global offset: {offset*1000:+.1f}ms") | |
| print(f" Overlaps trimmed: {notes_trimmed}") | |
| print(f" Min durations enforced: {durations_enforced}") | |
| print(f" Notes extended (CQT decay): {notes_extended}") | |
| # Playability check: max concurrent notes per hand | |
| all_final = sorted( | |
| [n for inst in midi_data.instruments for n in inst.notes], | |
| key=lambda n: n.start | |
| ) | |
| max_left = 0 | |
| max_right = 0 | |
| for i, note in enumerate(all_final): | |
| is_right = note.pitch >= 60 | |
| concurrent = sum(1 for o in all_final | |
| if o.start <= note.start < o.end | |
| and (o.pitch >= 60) == is_right) | |
| if is_right: | |
| max_right = max(max_right, concurrent) | |
| else: | |
| max_left = max(max_left, concurrent) | |
| print(f" Final onset F1: {final_f1:.4f}") | |
| print(f" Min note duration: {min_dur:.0f}ms") | |
| print(f" Max concurrent: L={max_left} R={max_right}") | |
| print(f" Notes: {total_notes} -> {final_notes}") | |
| # Final step: shift all notes so music starts at t=0 | |
| # (must be AFTER all audio-aligned processing like onset detection, CQT filters) | |
| if music_start > 0.1: | |
| print(f"\nShifting all notes by -{music_start:.2f}s so music starts at t=0...") | |
| for instrument in midi_data.instruments: | |
| for note in instrument.notes: | |
| note.start = max(0, note.start - music_start) | |
| note.end = max(note.start + 0.01, note.end - music_start) | |
| midi_data.write(str(output_path)) | |
| print(f" Written to {output_path}") | |
| # Step 9: Spectral fidelity analysis (CQT comparison) | |
| print("\nStep 9: Spectral fidelity analysis (CQT comparison)...") | |
| try: | |
| from spectral import spectral_fidelity | |
| spec_results = spectral_fidelity(y, sr, midi_data, hop_length) | |
| print(f" Spectral F1: {spec_results['spectral_f1']:.4f}") | |
| print(f" Spectral Precision: {spec_results['spectral_precision']:.4f}") | |
| print(f" Spectral Recall: {spec_results['spectral_recall']:.4f}") | |
| print(f" Spectral Similarity: {spec_results['spectral_similarity']:.4f}") | |
| # Save spectral report alongside MIDI | |
| import json | |
| report_path = str(output_path).replace('.mid', '_spectral.json') | |
| Path(report_path).write_text(json.dumps(spec_results, indent=2)) | |
| print(f" Report saved to {report_path}") | |
| except Exception as e: | |
| print(f" Spectral analysis failed: {e}") | |
| # Step 10: Chord detection | |
| print("\nStep 10: Detecting chords...") | |
| try: | |
| from chords import detect_chords | |
| chords_json_path = str(Path(output_path).with_name( | |
| Path(output_path).stem + "_chords.json" | |
| )) | |
| chord_events = detect_chords(str(output_path), chords_json_path) | |
| print(f" Detected {len(chord_events)} chord regions") | |
| except Exception as e: | |
| print(f" Chord detection failed: {e}") | |
| chord_events = [] | |
| return midi_data | |
def onset_f1(ref_onsets, est_onsets, tolerance=0.05):
    """Compute the onset-detection F1 score between reference and estimate.

    Each reference onset may be paired with at most one estimated onset (and
    vice versa); a pair counts as a hit when the two times differ by no more
    than ``tolerance``.

    Args:
        ref_onsets: Reference onset times (any 1-D sequence or NumPy array).
        est_onsets: Estimated onset times (any 1-D sequence or NumPy array).
        tolerance: Maximum absolute difference for a match, in the same units
            as the onset times. Defaults to 0.05.

    Returns:
        float: F1 in [0, 1]. Returns 1.0 when both inputs are empty and 0.0
        when exactly one is empty or when nothing matches.
    """
    # Coerce to sorted float arrays so plain lists work and so the sweep
    # below can assume ascending order regardless of caller ordering.
    ref = np.sort(np.asarray(ref_onsets, dtype=float))
    est = np.sort(np.asarray(est_onsets, dtype=float))

    if ref.size == 0 and est.size == 0:
        return 1.0
    if ref.size == 0 or est.size == 0:
        return 0.0

    # Two-pointer sweep over both sorted lists. For points on a line with a
    # uniform tolerance this yields a maximum one-to-one matching, so a
    # reference is never dropped just because its *nearest* estimate was
    # already claimed by an earlier reference.
    matched = 0
    i = j = 0
    while i < ref.size and j < est.size:
        gap = est[j] - ref[i]
        if abs(gap) <= tolerance:
            matched += 1
            i += 1
            j += 1
        elif gap > 0:
            # Current reference is too early for this and all later estimates.
            i += 1
        else:
            # Current estimate is too early for this and all later references.
            j += 1

    if matched == 0:
        return 0.0

    precision = matched / est.size
    recall = matched / ref.size
    return 2 * precision * recall / (precision + recall)
if __name__ == "__main__":
    import sys

    # Command-line entry point: the audio file and the MIDI file are
    # required; the output path is optional and passed as None when absent.
    cli_args = sys.argv[1:]
    if len(cli_args) < 2:
        print("Usage: python optimize.py <original_audio> <midi_file> [output.mid]")
        sys.exit(1)

    audio_path, midi_path = cli_args[0], cli_args[1]
    out_path = cli_args[2] if len(cli_args) > 2 else None
    optimize(audio_path, midi_path, out_path)