Spaces:
Sleeping
Sleeping
| """Chord detection from MIDI files using template-matching music theory. | |
| Analyzes a MIDI file to detect chords at each note onset, producing a | |
| time-stamped list of chord events with root, quality, and constituent notes. | |
| Designed for the Mr. Octopus piano tutorial pipeline. | |
| """ | |
| import json | |
| from pathlib import Path | |
| from collections import defaultdict | |
| import pretty_midi | |
| import numpy as np | |
| # --------------------------------------------------------------------------- | |
| # Music theory constants | |
| # --------------------------------------------------------------------------- | |
| NOTE_NAMES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"] | |
| # Enharmonic display preferences: use flats for certain roots to match | |
| # standard music notation (e.g. Bb major, not A# major). | |
| ENHARMONIC_DISPLAY = { | |
| "C": "C", "C#": "Db", "D": "D", "D#": "Eb", "E": "E", "F": "F", | |
| "F#": "F#", "G": "G", "G#": "Ab", "A": "A", "A#": "Bb", "B": "B", | |
| } | |
| # Chord templates: quality name -> set of semitone intervals from root. | |
| # Each template is a frozenset of pitch-class intervals (0 = root). | |
| CHORD_TEMPLATES = { | |
| # Triads | |
| "major": frozenset([0, 4, 7]), | |
| "minor": frozenset([0, 3, 7]), | |
| "diminished": frozenset([0, 3, 6]), | |
| "augmented": frozenset([0, 4, 8]), | |
| # Suspended | |
| "sus2": frozenset([0, 2, 7]), | |
| "sus4": frozenset([0, 5, 7]), | |
| # Seventh chords | |
| "dominant 7": frozenset([0, 4, 7, 10]), | |
| "major 7": frozenset([0, 4, 7, 11]), | |
| "minor 7": frozenset([0, 3, 7, 10]), | |
| "diminished 7": frozenset([0, 3, 6, 9]), | |
| "half-dim 7": frozenset([0, 3, 6, 10]), | |
| "min/maj 7": frozenset([0, 3, 7, 11]), | |
| "augmented 7": frozenset([0, 4, 8, 10]), | |
| # Extended / added-tone | |
| "add9": frozenset([0, 2, 4, 7]), | |
| "minor add9": frozenset([0, 2, 3, 7]), | |
| "6": frozenset([0, 4, 7, 9]), | |
| "minor 6": frozenset([0, 3, 7, 9]), | |
| } | |
| # Short suffix for display (e.g. "Cm7", "Gdim", "Fsus4") | |
| QUALITY_SUFFIX = { | |
| "major": "", | |
| "minor": "m", | |
| "diminished": "dim", | |
| "augmented": "aug", | |
| "sus2": "sus2", | |
| "sus4": "sus4", | |
| "dominant 7": "7", | |
| "major 7": "maj7", | |
| "minor 7": "m7", | |
| "diminished 7": "dim7", | |
| "half-dim 7": "m7b5", | |
| "min/maj 7": "m(maj7)", | |
| "augmented 7": "aug7", | |
| "add9": "add9", | |
| "minor add9": "madd9", | |
| "6": "6", | |
| "minor 6": "m6", | |
| } | |
| # Priority ordering for tie-breaking when multiple templates match equally. | |
| # Lower index = preferred. Triads > sevenths > extended > suspended. | |
| QUALITY_PRIORITY = [ | |
| "major", "minor", "dominant 7", "minor 7", "major 7", | |
| "diminished", "augmented", "half-dim 7", "diminished 7", | |
| "6", "minor 6", "sus4", "sus2", "add9", "minor add9", | |
| "min/maj 7", "augmented 7", | |
| ] | |
| # --------------------------------------------------------------------------- | |
| # Frame extraction | |
| # --------------------------------------------------------------------------- | |
def extract_note_frames(midi_data, onset_tolerance=0.05):
    """Partition MIDI notes into near-simultaneous groups ("frames").

    A note joins the current frame when its onset falls within
    `onset_tolerance` seconds of the frame's FIRST note; otherwise it
    starts a new frame. Returns a list of frame dicts (see
    `_group_to_frame`) sorted by start time:

        {
            "start": float,       # earliest onset in the group
            "end": float,         # latest note-off in the group
            "pitches": [int],     # MIDI pitch numbers
            "velocities": [int],  # corresponding velocities
        }
    """
    # Flatten notes from every instrument (piano files usually have one).
    pool = [note for inst in midi_data.instruments for note in inst.notes]
    pool.sort(key=lambda note: note.start)
    if not pool:
        return []

    frames = []
    group = [pool[0]]
    for note in pool[1:]:
        if note.start - group[0].start > onset_tolerance:
            # Onset gap too large: close out the current frame.
            frames.append(_group_to_frame(group))
            group = [note]
        else:
            group.append(note)
    frames.append(_group_to_frame(group))
    return frames
| def _group_to_frame(notes): | |
| """Convert a group of pretty_midi Note objects into a frame dict.""" | |
| return { | |
| "start": min(n.start for n in notes), | |
| "end": max(n.end for n in notes), | |
| "pitches": [n.pitch for n in notes], | |
| "velocities": [n.velocity for n in notes], | |
| } | |
| # --------------------------------------------------------------------------- | |
| # Template matching | |
| # --------------------------------------------------------------------------- | |
def pitch_class_set(pitches):
    """Reduce MIDI pitch numbers to their set of pitch classes (0-11)."""
    return {pitch % 12 for pitch in pitches}
def match_chord(pitches, velocities=None):
    """Identify a chord from a set of MIDI pitches.

    Uses a template-matching approach that tests every possible root (0-11)
    against every chord template. Scoring:
        1. Count how many template tones are present in the pitch-class set
           (weighted by velocity when available).
        2. Penalize extra notes not in the template and missing template tones.
        3. Bonus for a perfect cover (no extras, no misses).
        4. Handle inversions: the bass note does not need to be the root,
           but root-position gets a small bonus.

    Parameters
    ----------
    pitches : list[int]
        MIDI pitch numbers sounding together. Must be non-empty to get a result.
    velocities : list[int], optional
        Per-pitch velocities; used as weights when the length matches `pitches`.

    Returns
    -------
    dict or None
        {
            "root": int,            # pitch class 0-11
            "root_name": str,       # e.g. "C", "Db"
            "quality": str,         # e.g. "minor 7"
            "chord_name": str,      # e.g. "Cm7"
            "notes": [str],         # constituent note names
            "midi_pitches": [int],  # original MIDI pitches, sorted
        }
        A single pitch class yields a "note" result; None if `pitches` is empty.
    """
    pcs = pitch_class_set(pitches)
    if len(pcs) < 2:
        return _single_note_result(pitches) if pitches else None

    # Build a velocity weight map (pitch class -> total velocity).
    pc_weight = defaultdict(float)
    if velocities and len(velocities) == len(pitches):
        for p, v in zip(pitches, velocities):
            pc_weight[p % 12] += v
    else:
        for p in pitches:
            pc_weight[p % 12] += 80  # default velocity

    # Normalize weights so the max is 1.0.
    # FIX: `or 1.0` guards against all-zero velocities (e.g. velocity-0 notes),
    # which previously raised ZeroDivisionError.
    max_w = max(pc_weight.values(), default=1.0) or 1.0
    for pc in pc_weight:
        pc_weight[pc] /= max_w

    # Determine the bass note (lowest pitch) for the root-position bonus.
    bass_pc = min(pitches) % 12

    best_score = -999
    best_result = None
    for root in range(12):
        for quality, template in CHORD_TEMPLATES.items():
            # Transpose template to this root.
            transposed = frozenset((root + interval) % 12 for interval in template)
            # Weighted count of template tones present.
            matched_weight = 0.0
            matched_count = 0
            for pc in transposed:
                if pc in pcs:
                    matched_weight += pc_weight.get(pc, 0.5)
                    matched_count += 1
            # A template must match at least 2 pitch classes to be viable;
            # checking here (before scoring) skips pointless arithmetic.
            if matched_count < 2:
                continue
            # Input pitch classes NOT explained by the template.
            extra_notes = len(pcs - transposed)
            # Template tones absent from the input.
            missing = len(transposed) - matched_count
            # Base score: reward matches, penalize misses and extras.
            score = matched_weight * 2.0 - missing * 1.5 - extra_notes * 0.5
            # Bonus if this template perfectly covers all input notes.
            if extra_notes == 0 and missing == 0:
                score += 3.0
            # Bonus if root is the bass note (root position).
            if root == bass_pc:
                score += 0.8
            # Bonus for root having high velocity.
            score += pc_weight.get(root, 0) * 0.3
            # Smaller bonus for simpler chord types (triads over 7ths).
            priority_idx = QUALITY_PRIORITY.index(quality) if quality in QUALITY_PRIORITY else len(QUALITY_PRIORITY)
            score -= priority_idx * 0.05

            if score > best_score:
                best_score = score
                root_name = ENHARMONIC_DISPLAY[NOTE_NAMES[root]]
                suffix = QUALITY_SUFFIX.get(quality, quality)
                best_result = {
                    "root": root,
                    "root_name": root_name,
                    "quality": quality,
                    "chord_name": f"{root_name}{suffix}",
                    "notes": sorted([ENHARMONIC_DISPLAY[NOTE_NAMES[pc]] for pc in transposed]),
                    "midi_pitches": sorted(pitches),
                }

    # If no template matched well enough, describe the bass + intervals instead.
    if best_result is None:
        return _fallback_chord(pitches)
    return best_result
def _single_note_result(pitches):
    """Build a degenerate "chord" result for a single pitch class."""
    if not pitches:
        return None
    note_pc = pitches[0] % 12
    note_name = ENHARMONIC_DISPLAY[NOTE_NAMES[note_pc]]
    return {
        "root": note_pc,
        "root_name": note_name,
        "quality": "note",
        "chord_name": note_name,
        "notes": [note_name],
        "midi_pitches": sorted(pitches),
    }
def _fallback_chord(pitches):
    """Label an unrecognized pitch combination as bass note + interval list."""
    classes = pitch_class_set(pitches)
    bass = min(pitches) % 12
    bass_name = ENHARMONIC_DISPLAY[NOTE_NAMES[bass]]
    # Describe the remaining pitch classes as semitone offsets above the bass.
    steps = sorted((pc - bass) % 12 for pc in classes if pc != bass)
    interval_str = ",".join(str(step) for step in steps)
    return {
        "root": bass,
        "root_name": bass_name,
        "quality": "unknown",
        "chord_name": f"{bass_name}({interval_str})",
        "notes": sorted(ENHARMONIC_DISPLAY[NOTE_NAMES[pc]] for pc in classes),
        "midi_pitches": sorted(pitches),
    }
| # --------------------------------------------------------------------------- | |
| # Smoothing | |
| # --------------------------------------------------------------------------- | |
def smooth_chords(chord_events, min_duration=0.1):
    """Remove very short chord changes and merge consecutive identical chords.

    Pass 1 relabels any interior event shorter than `min_duration` with the
    preceding event's chord (name, quality, root, notes). Pass 2 merges runs
    of consecutive events sharing a chord name, extending the end time and
    taking the union of their MIDI pitches.

    Parameters
    ----------
    chord_events : list[dict]
        Events with "start_time", "end_time", "chord_name", "quality",
        "root_note", "notes", and optionally "midi_pitches".
    min_duration : float
        Minimum duration (seconds) for an interior event to survive Pass 1.

    Returns
    -------
    list[dict]
        New, smoothed event list. The input list and its dicts are NOT
        mutated (previously Pass 1 rewrote the caller's dicts in place
        because only the outer list was copied).
    """
    if not chord_events:
        return chord_events

    # FIX: copy each event dict, not just the list, so the caller's events
    # are never mutated by the passes below.
    filtered = [dict(event) for event in chord_events]

    # Pass 1: absorb extremely short interior chords into the previous chord.
    for i in range(1, len(filtered) - 1):
        duration = filtered[i]["end_time"] - filtered[i]["start_time"]
        if duration < min_duration:
            prev = filtered[i - 1]
            filtered[i]["chord_name"] = prev["chord_name"]
            filtered[i]["quality"] = prev["quality"]
            filtered[i]["root_note"] = prev["root_note"]
            filtered[i]["notes"] = prev["notes"]

    # Pass 2: merge consecutive events with the same chord name.
    merged = [filtered[0]]
    for event in filtered[1:]:
        if event["chord_name"] == merged[-1]["chord_name"]:
            # Extend the previous event and union its pitches.
            merged[-1]["end_time"] = event["end_time"]
            pitches = set(merged[-1].get("midi_pitches", []))
            pitches.update(event.get("midi_pitches", []))
            merged[-1]["midi_pitches"] = sorted(pitches)
        else:
            merged.append(event)
    return merged
| # --------------------------------------------------------------------------- | |
| # Main detection pipeline | |
| # --------------------------------------------------------------------------- | |
def detect_chords(midi_path, output_path=None, onset_tolerance=0.05,
                  min_chord_duration=0.1):
    """Detect chords from a MIDI file and save the results as JSON.

    Parameters
    ----------
    midi_path : str or Path
        Path to the input MIDI file.
    output_path : str or Path, optional
        Path for the output JSON file. Defaults to the MIDI filename
        with "_chords.json" suffix.
    onset_tolerance : float
        Maximum time difference (seconds) to group notes into the same frame.
    min_chord_duration : float
        Minimum duration for a chord event; shorter events get smoothed away.

    Returns
    -------
    list[dict]
        Chord events, each with: start_time, end_time (seconds),
        chord_name (e.g. "Am7"), root_note (e.g. "A"), quality
        (e.g. "minor 7"), notes (note names), midi_pitches (MIDI numbers).
    """
    midi_path = Path(midi_path)
    if output_path is None:
        out_path = midi_path.with_name(midi_path.stem + "_chords.json")
    else:
        out_path = Path(output_path)

    print(f"\nChord detection: {midi_path.name}")

    # Load the MIDI file and group its notes into onset frames.
    midi_data = pretty_midi.PrettyMIDI(str(midi_path))
    frames = extract_note_frames(midi_data, onset_tolerance=onset_tolerance)
    print(f" Extracted {len(frames)} note frames")
    if not frames:
        _write_json([], out_path)
        return []

    # Match each frame against the chord templates.
    raw_events = []
    for frame in frames:
        chord = match_chord(frame["pitches"], frame["velocities"])
        if chord is not None:
            raw_events.append({
                "start_time": round(frame["start"], 4),
                "end_time": round(frame["end"], 4),
                "chord_name": chord["chord_name"],
                "root_note": chord["root_name"],
                "quality": chord["quality"],
                "notes": chord["notes"],
                "midi_pitches": chord["midi_pitches"],
            })
    print(f" Identified {len(raw_events)} raw chord events")

    # Smooth away transient chords and merge repeats.
    smoothed = smooth_chords(raw_events, min_duration=min_chord_duration)
    print(f" After smoothing: {len(smoothed)} chord events")

    # Round all times for clean output.
    for event in smoothed:
        event["start_time"] = round(event["start_time"], 4)
        event["end_time"] = round(event["end_time"], 4)

    # Summarize the distinct chord vocabulary found.
    unique_chords = {event["chord_name"] for event in smoothed}
    print(f" Unique chords: {len(unique_chords)} ({', '.join(sorted(unique_chords))})")

    _write_json(smoothed, out_path)
    print(f" Saved to {out_path}")
    return smoothed
| def _write_json(data, path): | |
| """Write chord data to a JSON file.""" | |
| output = { | |
| "version": 1, | |
| "description": "Chord detection output from Mr. Octopus piano tutorial pipeline", | |
| "chord_count": len(data), | |
| "chords": data, | |
| } | |
| with open(path, "w") as f: | |
| json.dump(output, f, indent=2) | |
| # --------------------------------------------------------------------------- | |
| # CLI entry point | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    import sys

    # Require at least the input MIDI file path; print usage and exit
    # non-zero otherwise.
    if len(sys.argv) < 2:
        print("Usage: python chords.py <midi_file> [output.json]")
        print()
        print("Analyzes a MIDI file and detects chords at each note onset.")
        print("Outputs a JSON file with timestamped chord events.")
        sys.exit(1)
    midi_file = sys.argv[1]
    # Optional second argument: explicit output JSON path (defaults to
    # "<midi stem>_chords.json" next to the input when omitted).
    out_file = sys.argv[2] if len(sys.argv) > 2 else None
    events = detect_chords(midi_file, out_file)
    print(f"\nDetected {len(events)} chord events")