# mroctopus/transcriber/chords.py — Mr Octopus piano tutorial app
# (initial commit by Ewan, f0a176a)
"""Chord detection from MIDI files using template-matching music theory.
Analyzes a MIDI file to detect chords at each note onset, producing a
time-stamped list of chord events with root, quality, and constituent notes.
Designed for the Mr. Octopus piano tutorial pipeline.
"""
import json
from pathlib import Path
from collections import defaultdict
import pretty_midi
import numpy as np
# ---------------------------------------------------------------------------
# Music theory constants
# ---------------------------------------------------------------------------
# Pitch-class index (0-11) -> sharp-based note name.
NOTE_NAMES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
# Enharmonic display preferences: use flats for certain roots to match
# standard music notation (e.g. Bb major, not A# major).
ENHARMONIC_DISPLAY = {
    "C": "C", "C#": "Db", "D": "D", "D#": "Eb", "E": "E", "F": "F",
    "F#": "F#", "G": "G", "G#": "Ab", "A": "A", "A#": "Bb", "B": "B",
}
# Chord templates: quality name -> set of semitone intervals from root.
# Each template is a frozenset of pitch-class intervals (0 = root).
# NOTE: insertion order matters — match_chord iterates this dict in order
# and only replaces the best candidate on a strictly higher score, so
# earlier entries win exact score ties.
CHORD_TEMPLATES = {
    # Triads
    "major": frozenset([0, 4, 7]),
    "minor": frozenset([0, 3, 7]),
    "diminished": frozenset([0, 3, 6]),
    "augmented": frozenset([0, 4, 8]),
    # Suspended
    "sus2": frozenset([0, 2, 7]),
    "sus4": frozenset([0, 5, 7]),
    # Seventh chords
    "dominant 7": frozenset([0, 4, 7, 10]),
    "major 7": frozenset([0, 4, 7, 11]),
    "minor 7": frozenset([0, 3, 7, 10]),
    "diminished 7": frozenset([0, 3, 6, 9]),
    "half-dim 7": frozenset([0, 3, 6, 10]),
    "min/maj 7": frozenset([0, 3, 7, 11]),
    "augmented 7": frozenset([0, 4, 8, 10]),
    # Extended / added-tone
    "add9": frozenset([0, 2, 4, 7]),
    "minor add9": frozenset([0, 2, 3, 7]),
    "6": frozenset([0, 4, 7, 9]),
    "minor 6": frozenset([0, 3, 7, 9]),
}
# Short suffix for display (e.g. "Cm7", "Gdim", "Fsus4")
QUALITY_SUFFIX = {
    "major": "",
    "minor": "m",
    "diminished": "dim",
    "augmented": "aug",
    "sus2": "sus2",
    "sus4": "sus4",
    "dominant 7": "7",
    "major 7": "maj7",
    "minor 7": "m7",
    "diminished 7": "dim7",
    "half-dim 7": "m7b5",
    "min/maj 7": "m(maj7)",
    "augmented 7": "aug7",
    "add9": "add9",
    "minor add9": "madd9",
    "6": "6",
    "minor 6": "m6",
}
# Priority ordering for tie-breaking when multiple templates match equally.
# Lower index = preferred (roughly: common triads/sevenths first, then the
# rarer qualities).  match_chord subtracts index * 0.05 from the score.
QUALITY_PRIORITY = [
    "major", "minor", "dominant 7", "minor 7", "major 7",
    "diminished", "augmented", "half-dim 7", "diminished 7",
    "6", "minor 6", "sus4", "sus2", "add9", "minor add9",
    "min/maj 7", "augmented 7",
]
# ---------------------------------------------------------------------------
# Frame extraction
# ---------------------------------------------------------------------------
def extract_note_frames(midi_data, onset_tolerance=0.05):
    """Group MIDI notes into near-simultaneous frames (chords / single notes).

    All notes, across every instrument, are sorted by onset; a note joins
    the current group when its onset lies within ``onset_tolerance``
    seconds of the group's FIRST note.  Returns a list of frame dicts,
    sorted by start time, each with keys:

        "start"      earliest onset in the group (float, seconds)
        "end"        latest note-off in the group (float, seconds)
        "pitches"    MIDI pitch numbers (list[int])
        "velocities" corresponding velocities (list[int])
    """
    notes = [note for inst in midi_data.instruments for note in inst.notes]
    notes.sort(key=lambda n: n.start)
    if not notes:
        return []
    frames = []
    group = [notes[0]]
    for note in notes[1:]:
        # The tolerance window is anchored on the group's first onset,
        # not the previous note, so groups cannot drift forward.
        if note.start - group[0].start <= onset_tolerance:
            group.append(note)
        else:
            frames.append(_group_to_frame(group))
            group = [note]
    frames.append(_group_to_frame(group))
    return frames


def _group_to_frame(notes):
    """Collapse a group of pretty_midi Note objects into one frame dict."""
    return {
        "start": min(note.start for note in notes),
        "end": max(note.end for note in notes),
        "pitches": [note.pitch for note in notes],
        "velocities": [note.velocity for note in notes],
    }
# ---------------------------------------------------------------------------
# Template matching
# ---------------------------------------------------------------------------
def pitch_class_set(pitches):
    """Return the distinct pitch classes (0-11) of the given MIDI pitches."""
    return {pitch % 12 for pitch in pitches}
def match_chord(pitches, velocities=None):
    """Identify the best-fitting chord for a collection of MIDI pitches.

    Template matching: every root (0-11) is paired with every entry in
    CHORD_TEMPLATES and scored against the sounding pitch classes:

    1. Template tones that are present add velocity-derived weight.
    2. Missing template tones and extra sounding notes are penalized.
    3. Exact covers, root-in-bass voicings, loud roots and common
       qualities (per QUALITY_PRIORITY) earn bonuses.

    Returns a dict with keys "root" (pitch class 0-11), "root_name",
    "quality", "chord_name", "notes" (the matched template's tones as
    note names — not necessarily the exact played notes) and
    "midi_pitches".  Fewer than two distinct pitch classes yields a
    single-note result (or None for empty input); when no template
    matches at least two tones, a bass-plus-intervals fallback label
    is returned instead.
    """
    sounding = pitch_class_set(pitches)
    if len(sounding) < 2:
        return _single_note_result(pitches) if pitches else None

    # Accumulate velocity per pitch class, then normalize so the max is 1.0.
    weight = defaultdict(float)
    if velocities and len(velocities) == len(pitches):
        for pitch, vel in zip(pitches, velocities):
            weight[pitch % 12] += vel
    else:
        for pitch in pitches:
            weight[pitch % 12] += 80  # default velocity
    heaviest = max(weight.values()) if weight else 1.0
    for pc in weight:
        weight[pc] /= heaviest

    bass_pc = min(pitches) % 12  # lowest pitch, for the root-position bonus
    best_score = -999
    best_result = None
    for root in range(12):
        for quality, template in CHORD_TEMPLATES.items():
            # Transpose the template to this candidate root.
            chord_pcs = frozenset((root + step) % 12 for step in template)
            hit_weight = 0.0
            hit_count = 0
            for pc in chord_pcs:
                if pc in sounding:
                    hit_weight += weight.get(pc, 0.5)
                    hit_count += 1
            # A viable template must explain at least two pitch classes.
            if hit_count < 2:
                continue
            extras = len(sounding - chord_pcs)    # played, not in template
            missing = len(chord_pcs) - hit_count  # in template, not played
            score = hit_weight * 2.0 - missing * 1.5 - extras * 0.5
            if extras == 0 and missing == 0:
                score += 3.0  # template exactly covers the played notes
            if root == bass_pc:
                score += 0.8  # root position (root in the bass)
            score += weight.get(root, 0) * 0.3  # emphasized (loud) root
            rank = (QUALITY_PRIORITY.index(quality)
                    if quality in QUALITY_PRIORITY else len(QUALITY_PRIORITY))
            score -= rank * 0.05  # slight preference for simpler qualities
            if score > best_score:
                best_score = score
                display = ENHARMONIC_DISPLAY[NOTE_NAMES[root]]
                best_result = {
                    "root": root,
                    "root_name": display,
                    "quality": quality,
                    "chord_name": display + QUALITY_SUFFIX.get(quality, quality),
                    "notes": sorted(ENHARMONIC_DISPLAY[NOTE_NAMES[pc]]
                                    for pc in chord_pcs),
                    "midi_pitches": sorted(pitches),
                }
    # Nothing matched well enough: describe the bass + intervals instead.
    if best_result is None:
        return _fallback_chord(pitches)
    return best_result
def _single_note_result(pitches):
    """Describe a lone pitch class as a "note" event rather than a chord."""
    if not pitches:
        return None
    pitch_class = pitches[0] % 12
    label = ENHARMONIC_DISPLAY[NOTE_NAMES[pitch_class]]
    return {
        "root": pitch_class,
        "root_name": label,
        "quality": "note",
        "chord_name": label,
        "notes": [label],
        "midi_pitches": sorted(pitches),
    }
def _fallback_chord(pitches):
    """Label an unrecognized pitch combination as bass note + interval list."""
    classes = pitch_class_set(pitches)
    bass = min(pitches) % 12
    bass_label = ENHARMONIC_DISPLAY[NOTE_NAMES[bass]]
    # Describe the remaining tones as semitone offsets above the bass.
    offsets = sorted((pc - bass) % 12 for pc in classes if pc != bass)
    return {
        "root": bass,
        "root_name": bass_label,
        "quality": "unknown",
        "chord_name": bass_label + "(" + ",".join(str(i) for i in offsets) + ")",
        "notes": sorted(ENHARMONIC_DISPLAY[NOTE_NAMES[pc]] for pc in classes),
        "midi_pitches": sorted(pitches),
    }
# ---------------------------------------------------------------------------
# Smoothing
# ---------------------------------------------------------------------------
def smooth_chords(chord_events, min_duration=0.1):
    """Remove very short chord changes and merge consecutive identical chords.

    Pass 1: any interior event shorter than ``min_duration`` seconds is
    relabeled with the previous event's chord (name, quality, root, notes),
    absorbing transient misdetections.  Pass 2: consecutive events with the
    same chord name are merged into one event spanning their union, with
    ``midi_pitches`` combined.

    Parameters
    ----------
    chord_events : list[dict]
        Events with at least "start_time", "end_time", "chord_name",
        "quality", "root_note", "notes" and (optionally) "midi_pitches".
    min_duration : float
        Minimum duration (seconds) for an interior event to survive.

    Returns
    -------
    list[dict]
        New event dicts; the input list and its dicts are NOT modified.
        (The previous implementation's shallow list copy shared the dicts,
        so smoothing silently mutated the caller's events.)
    """
    if not chord_events:
        return chord_events
    # Copy each dict so neither pass mutates the caller's data.
    filtered = [dict(event) for event in chord_events]
    # Pass 1: absorb extremely short transient chords (< min_duration)
    # into the preceding chord.  First and last events are kept as-is.
    for i in range(1, len(filtered) - 1):
        duration = filtered[i]["end_time"] - filtered[i]["start_time"]
        if duration < min_duration:
            prev = filtered[i - 1]
            filtered[i]["chord_name"] = prev["chord_name"]
            filtered[i]["quality"] = prev["quality"]
            filtered[i]["root_note"] = prev["root_note"]
            filtered[i]["notes"] = prev["notes"]
    # Pass 2: merge consecutive events with the same chord name.
    merged = [filtered[0]]
    for event in filtered[1:]:
        if event["chord_name"] == merged[-1]["chord_name"]:
            # Extend the previous event and union the pitch sets.
            merged[-1]["end_time"] = event["end_time"]
            pitches = set(merged[-1].get("midi_pitches", []))
            pitches.update(event.get("midi_pitches", []))
            merged[-1]["midi_pitches"] = sorted(pitches)
        else:
            merged.append(event)
    return merged
# ---------------------------------------------------------------------------
# Main detection pipeline
# ---------------------------------------------------------------------------
def detect_chords(midi_path, output_path=None, onset_tolerance=0.05,
                  min_chord_duration=0.1):
    """Detect chords from a MIDI file and save results as JSON.

    Pipeline: load the MIDI, group notes into onset frames, match each
    frame against the chord templates, smooth out transient changes,
    then write the events to JSON and return them.

    Parameters
    ----------
    midi_path : str or Path
        Path to the input MIDI file.
    output_path : str or Path, optional
        Path for the output JSON file.  Defaults to the MIDI filename
        with a "_chords.json" suffix, next to the input.
    onset_tolerance : float
        Maximum time difference (seconds) to group notes into one frame.
    min_chord_duration : float
        Minimum duration for a chord event; shorter events are smoothed.

    Returns
    -------
    list[dict]
        Chord events with keys start_time, end_time, chord_name,
        root_note, quality, notes and midi_pitches (all also written
        to the JSON file).
    """
    midi_path = Path(midi_path)
    if output_path is None:
        output_path = midi_path.with_name(midi_path.stem + "_chords.json")
    else:
        output_path = Path(output_path)
    print(f"\nChord detection: {midi_path.name}")

    midi_data = pretty_midi.PrettyMIDI(str(midi_path))
    frames = extract_note_frames(midi_data, onset_tolerance=onset_tolerance)
    print(f" Extracted {len(frames)} note frames")
    if not frames:
        # Still emit an (empty) JSON file so downstream steps find it.
        _write_json([], output_path)
        return []

    # One chord candidate per frame.
    raw_events = []
    for frame in frames:
        chord = match_chord(frame["pitches"], frame["velocities"])
        if chord is None:
            continue
        raw_events.append({
            "start_time": round(frame["start"], 4),
            "end_time": round(frame["end"], 4),
            "chord_name": chord["chord_name"],
            "root_note": chord["root_name"],
            "quality": chord["quality"],
            "notes": chord["notes"],
            "midi_pitches": chord["midi_pitches"],
        })
    print(f" Identified {len(raw_events)} raw chord events")

    smoothed = smooth_chords(raw_events, min_duration=min_chord_duration)
    print(f" After smoothing: {len(smoothed)} chord events")
    # Re-round times (merging can carry over un-rounded float arithmetic).
    for event in smoothed:
        event["start_time"] = round(event["start_time"], 4)
        event["end_time"] = round(event["end_time"], 4)

    unique_chords = set(e["chord_name"] for e in smoothed)
    print(f" Unique chords: {len(unique_chords)} ({', '.join(sorted(unique_chords))})")

    _write_json(smoothed, output_path)
    print(f" Saved to {output_path}")
    return smoothed
def _write_json(data, path):
    """Write chord events to *path* as a small versioned JSON document."""
    payload = {
        "version": 1,
        "description": "Chord detection output from Mr. Octopus piano tutorial pipeline",
        "chord_count": len(data),
        "chords": data,
    }
    with open(path, "w") as handle:
        json.dump(payload, handle, indent=2)
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import sys

    # Usage: python chords.py <midi_file> [output.json]
    if len(sys.argv) < 2:
        print("Usage: python chords.py <midi_file> [output.json]")
        print()
        print("Analyzes a MIDI file and detects chords at each note onset.")
        print("Outputs a JSON file with timestamped chord events.")
        sys.exit(1)
    out_file = sys.argv[2] if len(sys.argv) > 2 else None
    events = detect_chords(sys.argv[1], out_file)
    print(f"\nDetected {len(events)} chord events")