"""MIDI tokenizer for bach-gpt.

Vocabulary
----------
Structural:  PAD, EOS, BAR_START, BAR_END, PHRASE_START, PHRASE_END,
             REST, CHORD_START, CHORD_END                              (9)
Pitch:       MIDI 21..108 (P21..P108)                                 (88)
Duration:    32 log-quantized bins over [0.03125, 4.0] s (D0..D31)    (32)
Time-shift:  32 log-quantized bins over [0.03125, 4.0] s (TS0..TS31)  (32)
Velocity:    16 bins over [0, 127] (V0..V15); uniform unless
             corpus-fit quantile edges have been loaded               (16)
Voice/chan:  16 GM families + 1 drums (VC0..VC16)                     (17)
Tempo:       16 log-spaced bins over [40, 240] BPM (T0..T15)          (16)
Position:    16 sub-beat positions per bar (POS0..POS15)              (16)
Meter:       8 common time signatures + OTHER (METER_*)                (9)
Voice role:  ROLE_BASS, ROLE_INNER, ROLE_TOP (inside chord brackets)   (3)
Key:         KEY_0..KEY_23 (0..11 major, 12..23 minor)                (24)
Bar header:  ROOT_0..ROOT_11, DENS_0..DENS_3, REG_0..REG_3            (20)
Bar repeat:  REF_BAR_1, REF_BAR_2, REF_BAR_4, REF_BAR_8                (4)
Caption:     CAP_SEG_0..CAP_SEG_7 (opt-in via inject_caption_segments) (8)
Pedal:       PEDAL_{SUS,SOS,SFT}_{UP,DOWN}                             (6)
Controller:  CC_{MOD,VOL,PAN,EXPR}_{0..7}                             (32)
Pitch bend:  PB_0..PB_15                                              (16)

Total vocab size: 348 tokens (``VOCAB_SIZE``). Token ids are assigned in
the order of the table above.

API
---
encode(pm: pretty_midi.PrettyMIDI) -> List[int]
decode(ids: List[int], default_tempo: float = 120.0) -> pretty_midi.PrettyMIDI
round_trip_test(pm) -> (passed: bool, details: dict)
inject_caption_segments(ids, n_segs=N_CAP_SEGS) -> List[int]
fit_velocity_quantiles / load_velocity_quantiles / save_velocity_quantiles

Note on fidelity: encode/decode preserve the pitch multiset (for pitches
21..108) and the instrument family per note, but timing is fuzzy due to
log/grid quantization. BAR_START/BAR_END markers are derived from
PrettyMIDI.get_downbeats(); the whole stream is wrapped in a single
PHRASE_START ... PHRASE_END pair followed by EOS. Tempo (T*), meter
(METER_*) and key (KEY_*) tokens are emitted at the start of the stream
and again whenever they change. A VC* token is emitted whenever the
instrument family changes between consecutive emitted events.
"""
from __future__ import annotations

import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pretty_midi

# --- Vocabulary construction --------------------------------------------------

STRUCTURAL = [
    "PAD",
    "EOS",
    "BAR_START",
    "BAR_END",
    "PHRASE_START",
    "PHRASE_END",
    "REST",
    "CHORD_START",
    "CHORD_END",
]
ROLES = ["ROLE_BASS", "ROLE_INNER", "ROLE_TOP"]
# 24 keys: indices 0..11 = C..B major; 12..23 = C..B minor (PrettyMIDI's
# convention via key_number).
KEYS = [f"KEY_{i}" for i in range(24)]
# Bar-header axes: emitted right after each BAR_START as a coarse summary
# of the bar's harmonic, density, and register content.
# Root = pitch class (0..11) of the bar's lowest note.
ROOT_NAMES = list(map("ROOT_{}".format, range(12)))
N_DENS_BINS = 4  # note counts: 0-3, 4-7, 8-15, 16+
N_REG_BINS = 4  # mean pitch: below 48, 48-59, 60-71, 72 and up
DENS_NAMES = list(map("DENS_{}".format, range(N_DENS_BINS)))
REG_NAMES = list(map("REG_{}".format, range(N_REG_BINS)))

# Bar-repetition markers: emitted right after the bar header when this
# bar's pitch multiset equals the one K bars earlier (K in {1, 2, 4, 8}).
REF_DISTANCES = [1, 2, 4, 8]
REF_NAMES = ["REF_BAR_" + str(dist) for dist in REF_DISTANCES]

# Caption-segment markers for cross-modal alignment with MidiCaps-style
# multi-sentence captions. CAP_SEG_{i} placed after PHRASE_START binds the
# following phrase to the i-th caption segment.
N_CAP_SEGS = 8
CAP_SEG_NAMES = ["CAP_SEG_" + str(seg) for seg in range(N_CAP_SEGS)]

# Pedal tokens: GM CC#64 (sustain), CC#66 (sostenuto), CC#67 (soft), each
# with an UP and a DOWN state -> PEDAL_{TYPE}_{STATE}.
PEDAL_CC_NUMBERS = {64: "SUS", 66: "SOS", 67: "SFT"}
PEDAL_NAMES = [
    "PEDAL_" + pedal + "_" + state
    for pedal in PEDAL_CC_NUMBERS.values()
    for state in ("UP", "DOWN")
]

# Continuous-controller tokens CC_{NAME}_{BIN} for the most common
# expressive CCs, each quantized to 8 bins over [0, 128).
N_CC_BINS = 8
CC_TYPES = {1: "MOD", 7: "VOL", 10: "PAN", 11: "EXPR"}
CC_NAMES = [
    "CC_" + label + "_" + str(b)
    for label in CC_TYPES.values()
    for b in range(N_CC_BINS)
]

# Pitch-bend tokens PB_{BIN}: PrettyMIDI's 14-bit values in [-8192, 8191]
# quantized to 16 uniform bins.
N_PB_BINS = 16
PB_NAMES = ["PB_" + str(b) for b in range(N_PB_BINS)]

# Reverse maps for the decoder: short name -> CC number.
PEDAL_NAME_TO_CC = dict(zip(PEDAL_CC_NUMBERS.values(), PEDAL_CC_NUMBERS.keys()))
CC_NAME_TO_NUMBER = dict(zip(CC_TYPES.values(), CC_TYPES.keys()))

# Supported time signatures plus a catch-all.
METERS = [
    f"METER_{num}_{den}"
    for num, den in ((2, 4), (3, 4), (4, 4), (5, 4), (6, 8), (7, 8), (9, 8), (12, 8))
] + ["METER_OTHER"]

# Bar length in quarter-notes for each meter (used by decoder).
METER_QUARTERS: Dict[str, float] = dict(
    METER_2_4=2.0,
    METER_3_4=3.0,
    METER_4_4=4.0,
    METER_5_4=5.0,
    METER_6_8=3.0,  # six eighths = three quarters
    METER_7_8=3.5,
    METER_9_8=4.5,
    METER_12_8=6.0,
    METER_OTHER=4.0,  # unrecognized meters decode as a 4/4-length bar
)

# Pitch range A0 (21) .. C8 (108): the 88 piano keys.
PITCH_MIN = 21
PITCH_MAX = 108
N_PITCH = PITCH_MAX - PITCH_MIN + 1  # 88
N_DUR_BINS = 32
N_SHIFT_BINS = 32
N_VEL_BINS = 16
# 16 GM instrument families plus one reserved drum voice (VC16).
N_VOICE_BINS = 17
DRUM_VOICE = 16
N_TEMPO_BINS = 16
# Sub-beat resolution per bar (a sixteenth-note grid in 4/4).
N_POS_BINS = 16

# Log-quantization range for durations and time-shifts:
# 2**-5 s (~31 ms) up to 4 s.
TIME_MIN = 2 ** -5
TIME_MAX = 4.0
LOG_TIME_EDGES = np.linspace(
    math.log(TIME_MIN), math.log(TIME_MAX), num=N_DUR_BINS + 1
)

# Tempo log-quantization range: 40..240 BPM.
TEMPO_MIN = 40.0
TEMPO_MAX = 240.0
LOG_TEMPO_EDGES = np.linspace(
    math.log(TEMPO_MIN), math.log(TEMPO_MAX), num=N_TEMPO_BINS + 1
)


def _build_vocab() -> Tuple[List[str], Dict[str, int]]:
    """Return the ordered token list and its name -> id lookup.

    Ids are list positions, so the order of ``sections`` below defines the
    numbering of every token; reordering it renumbers the vocabulary.
    """
    sections: List[List[str]] = [
        list(STRUCTURAL),
        [f"P{p}" for p in range(PITCH_MIN, PITCH_MAX + 1)],
        [f"D{i}" for i in range(N_DUR_BINS)],
        [f"TS{i}" for i in range(N_SHIFT_BINS)],
        [f"V{i}" for i in range(N_VEL_BINS)],
        [f"VC{i}" for i in range(N_VOICE_BINS)],
        [f"T{i}" for i in range(N_TEMPO_BINS)],
        [f"POS{i}" for i in range(N_POS_BINS)],
        list(METERS),
        list(ROLES),
        list(KEYS),
        list(ROOT_NAMES),
        list(DENS_NAMES),
        list(REG_NAMES),
        list(REF_NAMES),
        list(CAP_SEG_NAMES),
        list(PEDAL_NAMES),
        list(CC_NAMES),
        list(PB_NAMES),
    ]
    vocab = [tok for section in sections for tok in section]
    lookup = {tok: idx for idx, tok in enumerate(vocab)}
    return vocab, lookup


TOKENS, TOKEN2ID = _build_vocab()
ID2TOKEN = {idx: tok for tok, idx in TOKEN2ID.items()}
VOCAB_SIZE = len(TOKENS)

# Default location for fitted velocity quantiles. The auto-load block near
# the end of this module installs them at import time if the file exists.
DEFAULT_VEL_QUANTILES_PATH = (
    Path(__file__).resolve().parent.parent
    / "data"
    / "tokenizer"
    / "velocity_quantiles.json"
)

PAD = TOKEN2ID["PAD"]
EOS = TOKEN2ID["EOS"]
BAR_START = TOKEN2ID["BAR_START"]
BAR_END = TOKEN2ID["BAR_END"]
PHRASE_START = TOKEN2ID["PHRASE_START"]
PHRASE_END = TOKEN2ID["PHRASE_END"]
REST = TOKEN2ID["REST"]
CHORD_START = TOKEN2ID["CHORD_START"]
CHORD_END = TOKEN2ID["CHORD_END"]
ROLE_BASS = TOKEN2ID["ROLE_BASS"]
ROLE_INNER = TOKEN2ID["ROLE_INNER"]
ROLE_TOP = TOKEN2ID["ROLE_TOP"]


# --- Quantization helpers -----------------------------------------------------


def _log_bin(x: float, edges=LOG_TIME_EDGES) -> int:
    """Map a positive time (s) to a log-bin index in [0, len(edges) - 2].

    ``x`` is first clamped to [TIME_MIN, TIME_MAX] (which also guards
    ``math.log`` against non-positive input). ``edges`` are log-domain bin
    edges; the number of bins is derived from them rather than assumed to
    be N_DUR_BINS, so callers passing custom edges get in-range indices.
    """
    n_bins = len(edges) - 1
    x = min(max(x, TIME_MIN), TIME_MAX)
    # digitize returns 1..len(edges); shift to 0-based and clip to a bin.
    idx = int(np.digitize(math.log(x), edges)) - 1
    return max(0, min(n_bins - 1, idx))


def _bin_center(i: int, edges=LOG_TIME_EDGES) -> float:
    """Return the geometric center (s) of log-bin ``i``."""
    lo, hi = edges[i], edges[i + 1]
    return math.exp(0.5 * (lo + hi))


# Optional corpus-fit quantile edges for velocity bins. When set, _vel_bin
# uses these instead of uniform binning. Loaded from a JSON file by
# load_velocity_quantiles(); fit_velocity_quantiles() trains them.
_VEL_EDGES: Optional[np.ndarray] = None


def _vel_bin(v: int) -> int:
    """Map a MIDI velocity to a V* bin index in [0, N_VEL_BINS - 1].

    Uses the corpus-fit quantile edges in ``_VEL_EDGES`` when loaded,
    otherwise uniform bins of width 128 / N_VEL_BINS.
    """
    v = max(0, min(127, int(v)))
    if _VEL_EDGES is not None:
        idx = int(np.searchsorted(_VEL_EDGES, v, side="right")) - 1
        return max(0, min(N_VEL_BINS - 1, idx))
    return min(N_VEL_BINS - 1, v * N_VEL_BINS // 128)


def _vel_center(i: int) -> int:
    """Return the representative MIDI velocity for V* bin ``i``.

    The result is clamped to [1, 127]. With quantile edges, a heavily tied
    corpus yields bins such as [127, 128) whose rounded midpoint is 128,
    and a [0, 1) first bin rounds to 0; neither is a valid note-on velocity.
    """
    if _VEL_EDGES is not None:
        lo = float(_VEL_EDGES[i])
        hi = float(_VEL_EDGES[i + 1])
        center = int(round(0.5 * (lo + hi)))
    else:
        center = int((i + 0.5) * 128 / N_VEL_BINS)
    return max(1, min(127, center))


def fit_velocity_quantiles(velocities: List[int], n_bins: int = N_VEL_BINS) -> List[float]:
    """Compute quantile bin edges for a corpus of velocity values.

    Returns ``n_bins + 1`` edges. An empty corpus yields uniform edges over
    [0, 128]. Otherwise the outer edges are set to 0 and 128 and tied edges
    are nudged apart by 1e-3 so the result is strictly increasing.
    """
    if not velocities:
        return [i * 128 / n_bins for i in range(n_bins + 1)]
    vs = np.asarray(velocities, dtype=np.float64)
    qs = np.linspace(0.0, 1.0, n_bins + 1)
    edges = np.quantile(vs, qs).tolist()
    edges[0] = 0.0
    edges[-1] = 128.0
    # Force monotonic increasing in case of heavy ties.
    for i in range(1, len(edges)):
        if edges[i] <= edges[i - 1]:
            edges[i] = edges[i - 1] + 1e-3
    return edges


def load_velocity_quantiles(path: Path) -> bool:
    """Install velocity bin edges from a JSON file.

    Returns True if loaded. Returns False (leaving the current edges in
    place) when the file does not exist, is not a JSON list of length
    N_VEL_BINS + 1, or holds edges that are not finite and strictly
    increasing -- ``np.searchsorted`` in ``_vel_bin`` requires sorted edges.
    Malformed JSON or non-numeric entries still raise.
    """
    global _VEL_EDGES
    p = Path(path)
    if not p.exists():
        return False
    edges = json.loads(p.read_text())
    if not isinstance(edges, list) or len(edges) != N_VEL_BINS + 1:
        return False
    arr = np.asarray(edges, dtype=np.float64)
    if not np.all(np.isfinite(arr)) or np.any(np.diff(arr) <= 0):
        return False
    _VEL_EDGES = arr
    return True


def save_velocity_quantiles(edges: List[float], path: Path) -> None:
    """Write velocity bin edges to ``path`` as a JSON list (creates parents)."""
    p = Path(path)
    p.parent.mkdir(parents=True, exist_ok=True)
    p.write_text(json.dumps(list(edges)))


def _tempo_bin(bpm: float) -> int:
    """Map a tempo (BPM) to a T* bin, clamping to [TEMPO_MIN, TEMPO_MAX]."""
    bpm = max(TEMPO_MIN, min(TEMPO_MAX, float(bpm)))
    idx = int(np.digitize(math.log(bpm), LOG_TEMPO_EDGES)) - 1
    return max(0, min(N_TEMPO_BINS - 1, idx))


def _tempo_center(i: int) -> float:
    """Return the geometric center (BPM) of tempo bin ``i``."""
    lo, hi = LOG_TEMPO_EDGES[i], LOG_TEMPO_EDGES[i + 1]
    return math.exp(0.5 * (lo + hi))


def _program_family(program: int) -> int:
    """Map a General MIDI program (0..127) to its 16-family index (0..15).

    Clipped to the last GM family so an out-of-range program can never
    alias the reserved drum voice (DRUM_VOICE).
    """
    return max(0, min(DRUM_VOICE - 1, int(program) // 8))


# --- Encode -------------------------------------------------------------------


@dataclass
class _Event:
    """One timeline event extracted from a PrettyMIDI instrument."""

    onset: float
    voice: int  # GM family index 0..15 or DRUM_VOICE
    kind: str = "note"  # 'note' | 'pedal' | 'cc' | 'pb'
    # Note fields
    pitch: int = 0
    duration: float = 0.0
    velocity: int = 0
    # Pedal fields
    pedal_type: str = "SUS"  # 'SUS' | 'SOS' | 'SFT'
    pedal_state: str = "UP"  # 'UP' | 'DOWN'
    # CC fields (continuous controllers)
    cc_type: str = "MOD"  # name from CC_TYPES values
    cc_bin: int = 0
    # Pitch-bend fields
    pb_bin: int = 0


def _cc_bin(value: int) -> int:
    """Quantize a CC value (0..127) to one of N_CC_BINS uniform bins."""
    v = max(0, min(127, int(value)))
    return min(N_CC_BINS - 1, v * N_CC_BINS // 128)


def _cc_center(i: int) -> int:
    """Return the representative CC value for bin ``i``."""
    return int((i + 0.5) * 128 / N_CC_BINS)


def _pb_bin(value: int) -> int:
    """Quantize a 14-bit pitch bend in [-8192, 8191] to N_PB_BINS bins."""
    v = max(-8192, min(8191, int(value)))
    return min(N_PB_BINS - 1, (v + 8192) * N_PB_BINS // 16384)


def _pb_center(i: int) -> int:
    """Return the representative pitch-bend value for bin ``i``."""
    return int((i + 0.5) * 16384 / N_PB_BINS) - 8192


def _extract_events(pm: pretty_midi.PrettyMIDI) -> List[_Event]:
    """Flatten all instruments into one time-sorted list of events.

    Notes outside PITCH_MIN..PITCH_MAX are dropped. Pedal CCs (64/66/67),
    the expressive CCs in CC_TYPES, and pitch bends become non-note events.
    """
    events: List[_Event] = []
    for inst in pm.instruments:
        voice = DRUM_VOICE if inst.is_drum else _program_family(inst.program)
        for n in inst.notes:
            if PITCH_MIN <= n.pitch <= PITCH_MAX:
                events.append(
                    _Event(
                        onset=n.start,
                        voice=voice,
                        kind="note",
                        pitch=n.pitch,
                        duration=max(n.end - n.start, TIME_MIN),
                        velocity=n.velocity,
                    )
                )
        for cc in getattr(inst, "control_changes", []) or []:
            num = int(cc.number)
            if num in PEDAL_CC_NUMBERS:
                events.append(
                    _Event(
                        onset=float(cc.time),
                        voice=voice,
                        kind="pedal",
                        pedal_type=PEDAL_CC_NUMBERS[num],
                        pedal_state="DOWN" if int(cc.value) >= 64 else "UP",
                    )
                )
            elif num in CC_TYPES:
                events.append(
                    _Event(
                        onset=float(cc.time),
                        voice=voice,
                        kind="cc",
                        cc_type=CC_TYPES[num],
                        cc_bin=_cc_bin(cc.value),
                    )
                )
        for pb in getattr(inst, "pitch_bends", []) or []:
            events.append(
                _Event(
                    onset=float(pb.time),
                    voice=voice,
                    kind="pb",
                    pb_bin=_pb_bin(pb.pitch),
                )
            )
    # At equal onsets notes come before non-note events; within each group
    # events are ordered by voice, then pitch (non-note pitch is always 0).
    # The sort is stable, so non-note events sharing onset and voice keep
    # their input order.
    events.sort(key=lambda e: (e.onset, e.kind != "note", e.voice, e.pitch))
    return events


def _tempo_changes(pm: pretty_midi.PrettyMIDI) -> List[Tuple[float, float]]:
    """Return sorted (time_s, bpm) pairs from a PrettyMIDI's tempo map."""
    try:
        times, tempos = pm.get_tempo_changes()
    except Exception:
        return [(0.0, 120.0)]
    pairs = list(zip(times.tolist(), tempos.tolist())) if len(times) else []
    if not pairs or pairs[0][0] > 1e-6:
        pairs.insert(0, (0.0, pairs[0][1] if pairs else 120.0))
    return pairs


def _meter_token(num: int, den: int) -> str:
    """Return the METER_* token for num/den, or METER_OTHER if unsupported."""
    name = f"METER_{num}_{den}"
    return name if name in METER_QUARTERS else "METER_OTHER"


def _key_changes(pm: pretty_midi.PrettyMIDI) -> List[Tuple[float, str]]:
    """Return sorted (time_s, key_name) pairs from key_signature_changes."""
    out: List[Tuple[float, str]] = []
    for ks in getattr(pm, "key_signature_changes", []) or []:
        kn = int(getattr(ks, "key_number", 0)) % 24
        out.append((float(ks.time), f"KEY_{kn}"))
    if not out or out[0][0] > 1e-6:
        out.insert(0, (0.0, out[0][1] if out else "KEY_0"))
    return out


def _meter_changes(pm: pretty_midi.PrettyMIDI) -> List[Tuple[float, str]]:
    """Return sorted (time_s, meter_name) pairs from time-signature changes."""
    out: List[Tuple[float, str]] = []
    for ts in getattr(pm, "time_signature_changes", []) or []:
        out.append((float(ts.time), _meter_token(ts.numerator, ts.denominator)))
    if not out or out[0][0] > 1e-6:
        out.insert(0, (0.0, out[0][1] if out else "METER_4_4"))
    return out


def _bin_density(n: int) -> int:
    """Bucket a per-bar note count: 0-3, 4-7, 8-15, 16+."""
    if n < 4:
        return 0
    if n < 8:
        return 1
    if n < 16:
        return 2
    return 3


def _bin_register(mean_pitch: float) -> int:
    """Bucket a mean MIDI pitch: below 48, 48-59, 60-71, 72 and up."""
    if mean_pitch < 48:
        return 0
    if mean_pitch < 60:
        return 1
    if mean_pitch < 72:
        return 2
    return 3


def _bar_header_tokens(bar_events: List["_Event"]) -> List[int]:
    """Return [ROOT_, DENS_, REG_] tokens summarizing a bar."""
    if not bar_events:
        return [
            TOKEN2ID[ROOT_NAMES[0]],
            TOKEN2ID[DENS_NAMES[0]],
            TOKEN2ID[REG_NAMES[0]],
        ]
    lowest = min(e.pitch for e in bar_events)
    root_pc = lowest % 12
    n = len(bar_events)
    mean_pitch = sum(e.pitch for e in bar_events) / n
    return [
        TOKEN2ID[ROOT_NAMES[root_pc]],
        TOKEN2ID[DENS_NAMES[_bin_density(n)]],
        TOKEN2ID[REG_NAMES[_bin_register(mean_pitch)]],
    ]


def _group_by_onset(events: List["_Event"], eps: float = TIME_MIN) -> List[List["_Event"]]:
    """Group consecutive *note* events with coincident onsets into chord groups.

    A note joins the current group when its onset is within ``eps`` of the
    group's first onset. Non-note events (pedal/cc/pb) are emitted as size-1
    groups so they keep their place in the timeline but never bracket as
    chords.
    """
    groups: List[List[_Event]] = []
    cur: List[_Event] = []
    cur_onset: Optional[float] = None

    def _flush() -> None:
        nonlocal cur, cur_onset
        if cur:
            groups.append(cur)
        cur = []
        cur_onset = None

    for ev in events:
        if ev.kind != "note":
            _flush()
            groups.append([ev])
            continue
        if cur_onset is None or abs(ev.onset - cur_onset) <= eps:
            cur.append(ev)
            if cur_onset is None:
                cur_onset = ev.onset
        else:
            _flush()
            cur = [ev]
            cur_onset = ev.onset
    _flush()
    return groups


def _downbeats(pm: pretty_midi.PrettyMIDI) -> np.ndarray:
    """Return PrettyMIDI downbeat times, or an empty array on failure."""
    try:
        db = pm.get_downbeats()
        return np.asarray(db) if db is not None else np.array([])
    except Exception:
        return np.array([])


def encode(pm: pretty_midi.PrettyMIDI) -> List[int]:
    """Encode a PrettyMIDI object to a list of vocabulary ids.

    Stream layout::

        PHRASE_START T{tempo} METER_{n}_{d} KEY_{k} VC{family}
        [BAR_START ROOT_{pc} DENS_{d} REG_{r} [REF_BAR_{k}]] ...

    For each onset group (a chord of co-located notes, or a single
    pedal/CC/pitch-bend event):

        [T / METER / KEY tokens for changes at or before this onset]
        [BAR_END BAR_START header... for each downbeat crossed]
            (for the very first group the bars are opened before any
             change tokens)
        [POS{p} if in a bar AND the position changed; otherwise
         TS{bin} as a fallback when the shift exceeds TIME_MIN]
        if size > 1:  CHORD_START ([VC] ROLE_* V{vel} P{pitch} D{dur})... CHORD_END
        else:         [VC] V{vel} P{pitch} D{dur}
                      (or [VC] PEDAL_* / CC_* / PB_* for a non-note event)

    The stream ends with [BAR_END] PHRASE_END EOS; an input with no events
    encodes to PHRASE_START PHRASE_END EOS.

    Notes within a chord are sorted by pitch ascending; lowest gets
    ROLE_BASS, highest ROLE_TOP, middle pitches ROLE_INNER.
    """
    # Slack used when deciding whether a change or downbeat has been reached.
    boundary_eps = 1e-6

    events = _extract_events(pm)
    ids: List[int] = [PHRASE_START]
    if not events:
        ids.append(PHRASE_END)
        ids.append(EOS)
        return ids

    # Stream header: initial tempo, meter and key, each with a cursor over
    # the later changes.
    tempo_map = _tempo_changes(pm)
    tempo_iter = iter(tempo_map)
    cur_tempo = next(tempo_iter, (0.0, 120.0))
    next_tempo = next(tempo_iter, None)
    ids.append(TOKEN2ID[f"T{_tempo_bin(cur_tempo[1])}"])

    meter_map = _meter_changes(pm)
    meter_iter = iter(meter_map)
    cur_meter = next(meter_iter, (0.0, "METER_4_4"))
    next_meter = next(meter_iter, None)
    ids.append(TOKEN2ID[cur_meter[1]])

    key_map = _key_changes(pm)
    key_iter = iter(key_map)
    cur_key = next(key_iter, (0.0, "KEY_0"))
    next_key = next(key_iter, None)
    ids.append(TOKEN2ID[cur_key[1]])

    downbeats = list(_downbeats(pm))

    # Precompute *note* events per bar for the header summary. Non-note
    # events (pedals/CC/PB) are excluded so they don't skew ROOT/DENS/REG.
    bar_events_by_idx: Dict[int, List[_Event]] = {}
    if downbeats:
        db_arr = np.asarray(downbeats)
        for ev in events:
            if ev.kind != "note":
                continue
            # Same slack as the bar-opening test in _advance_bars, so a note
            # a hair before a downbeat is summarized in the bar it is
            # actually emitted in.
            i = max(
                0,
                int(np.searchsorted(db_arr, ev.onset + boundary_eps, side="right")) - 1,
            )
            bar_events_by_idx.setdefault(i, []).append(ev)

    db_idx = 0
    in_bar = False
    bar_start_time: Optional[float] = None
    bar_duration: Optional[float] = None
    last_pos_in_bar: Optional[int] = None
    # History of pitch multisets per emitted bar (for REF_BAR_K matching).
    bar_pitch_history: List[Tuple[int, ...]] = []

    def _bar_pitches(idx: int) -> Tuple[int, ...]:
        """Sorted pitch multiset of the notes assigned to bar ``idx``."""
        return tuple(sorted(e.pitch for e in bar_events_by_idx.get(idx, [])))

    def _emit_bar_start(bar_index: int) -> None:
        """Emit BAR_START, the bar header, and an optional REF_BAR_K."""
        ids.append(BAR_START)
        ids.extend(_bar_header_tokens(bar_events_by_idx.get(bar_index, [])))
        fp = _bar_pitches(bar_index)
        for k in REF_DISTANCES:
            if k <= len(bar_pitch_history) and fp and fp == bar_pitch_history[-k]:
                ids.append(TOKEN2ID[f"REF_BAR_{k}"])
                break
        bar_pitch_history.append(fp)

    def _advance_bars(t: float) -> None:
        """Close/open bars for every downbeat at or before ``t``."""
        nonlocal db_idx, in_bar, bar_start_time, bar_duration, last_pos_in_bar
        while db_idx < len(downbeats) and downbeats[db_idx] <= t + boundary_eps:
            if in_bar:
                ids.append(BAR_END)
            _emit_bar_start(db_idx)
            in_bar = True
            bar_start_time = float(downbeats[db_idx])
            # The final bar has no following downbeat and keeps the
            # previous bar's length.
            if db_idx + 1 < len(downbeats):
                bar_duration = float(downbeats[db_idx + 1] - downbeats[db_idx])
            last_pos_in_bar = None
            db_idx += 1

    groups = _group_by_onset(events)
    current_voice = groups[0][0].voice
    ids.append(TOKEN2ID[f"VC{current_voice}"])
    prev_onset = groups[0][0].onset
    _advance_bars(prev_onset)

    for g_idx, group in enumerate(groups):
        onset = group[0].onset
        limit = onset + boundary_eps
        while next_tempo is not None and next_tempo[0] <= limit:
            ids.append(TOKEN2ID[f"T{_tempo_bin(next_tempo[1])}"])
            next_tempo = next(tempo_iter, None)
        while next_meter is not None and next_meter[0] <= limit:
            ids.append(TOKEN2ID[next_meter[1]])
            next_meter = next(meter_iter, None)
        while next_key is not None and next_key[0] <= limit:
            ids.append(TOKEN2ID[next_key[1]])
            next_key = next(key_iter, None)
        _advance_bars(onset)

        if in_bar and bar_duration and bar_duration > 1e-6:
            pos_bin = int(round((onset - bar_start_time) / bar_duration * N_POS_BINS))
            pos_bin = max(0, min(N_POS_BINS - 1, pos_bin))
            if pos_bin != last_pos_in_bar:
                ids.append(TOKEN2ID[f"POS{pos_bin}"])
                last_pos_in_bar = pos_bin
        else:
            shift = 0.0 if g_idx == 0 else onset - prev_onset
            if shift > TIME_MIN:
                ids.append(TOKEN2ID[f"TS{_log_bin(shift)}"])

        if group[0].kind != "note":
            ev = group[0]
            if ev.voice != current_voice:
                ids.append(TOKEN2ID[f"VC{ev.voice}"])
                current_voice = ev.voice
            if ev.kind == "pedal":
                ids.append(TOKEN2ID[f"PEDAL_{ev.pedal_type}_{ev.pedal_state}"])
            elif ev.kind == "cc":
                ids.append(TOKEN2ID[f"CC_{ev.cc_type}_{ev.cc_bin}"])
            elif ev.kind == "pb":
                ids.append(TOKEN2ID[f"PB_{ev.pb_bin}"])
            prev_onset = onset
            continue

        notes = sorted(group, key=lambda e: e.pitch)
        is_chord = len(notes) > 1
        if is_chord:
            ids.append(CHORD_START)
        for n_idx, ev in enumerate(notes):
            if ev.voice != current_voice:
                ids.append(TOKEN2ID[f"VC{ev.voice}"])
                current_voice = ev.voice
            if is_chord:
                if n_idx == 0:
                    ids.append(ROLE_BASS)
                elif n_idx == len(notes) - 1:
                    ids.append(ROLE_TOP)
                else:
                    ids.append(ROLE_INNER)
            ids.append(TOKEN2ID[f"V{_vel_bin(ev.velocity)}"])
            ids.append(TOKEN2ID[f"P{ev.pitch}"])
            ids.append(TOKEN2ID[f"D{_log_bin(ev.duration)}"])
        if is_chord:
            ids.append(CHORD_END)
        prev_onset = onset

    if in_bar:
        ids.append(BAR_END)
    ids.append(PHRASE_END)
    ids.append(EOS)
    return ids


# --- Decode -------------------------------------------------------------------

# GM family -> representative program number (one per family of 8).
FAMILY_PROGRAMS = {
    0: 0,  # Piano
    1: 8,  # Chromatic Percussion
    2: 16,  # Organ
    3: 24,  # Guitar
    4: 32,  # Bass
    5: 40,  # Strings
    6: 48,  # Ensemble
    7: 56,  # Brass
    8: 64,  # Reed
    9: 72,  # Pipe
    10: 80,  # Synth Lead
    11: 88,  # Synth Pad
    12: 96,  # Synth Effects
    13: 104,  # Ethnic
    14: 112,  # Percussive
    15: 120,  # Sound Effects
}


def _kind(t: str):
    """Classify a token name into a (kind, value) pair for the decoder.

    Prefix checks are ordered so that longer prefixes win (TS before T,
    VC before V, POS before P). Unknown names fall back to ("struct", t).
    """
    if t in STRUCTURAL:
        return ("struct", t)
    if t in ROLES:
        return ("role", t)
    if t in METERS:
        return ("meter", t)
    if t.startswith("KEY_") and t[4:].isdigit():
        return ("key", int(t[4:]))
    if t.startswith("ROOT_") and t[5:].isdigit():
        return ("root", int(t[5:]))
    if t.startswith("DENS_") and t[5:].isdigit():
        return ("dens", int(t[5:]))
    if t.startswith("REG_") and t[4:].isdigit():
        return ("reg", int(t[4:]))
    if t.startswith("REF_BAR_") and t[8:].isdigit():
        return ("ref", int(t[8:]))
    if t.startswith("CAP_SEG_") and t[8:].isdigit():
        return ("capseg", int(t[8:]))
    if t.startswith("PEDAL_"):
        return ("pedal", t)
    if t.startswith("CC_"):
        return ("cc", t)
    if t.startswith("PB_") and t[3:].isdigit():
        return ("pb", int(t[3:]))
    if t.startswith("TS") and t[2:].isdigit():
        return ("ts", int(t[2:]))
    if t.startswith("VC") and t[2:].isdigit():
        return ("voice", int(t[2:]))
    if t.startswith("POS") and t[3:].isdigit():
        return ("pos", int(t[3:]))
    if t.startswith("D") and t[1:].isdigit():
        return ("dur", int(t[1:]))
    if t.startswith("V") and t[1:].isdigit():
        return ("vel", int(t[1:]))
    if t.startswith("T") and t[1:].isdigit():
        return ("tempo", int(t[1:]))
    if t.startswith("P") and t[1:].isdigit():
        return ("pitch", int(t[1:]))
    return ("struct", t)


def decode(ids: List[int], default_tempo: float = 120.0) -> pretty_midi.PrettyMIDI:
    """Decode a token id list back to a PrettyMIDI.

    Timing is reconstructed from POS tokens within bars (using the current
    tempo + meter) or TS deltas as a fallback. Pitches and instrument
    families are preserved exactly.

    Each bar's length is fixed when the bar opens. It uses the tempo/meter
    in effect at its BAR_START, updated by any T*/METER_* token that
    arrives before the bar's first POS/TS/REST/note. Later tempo/meter
    tokens only affect the following bars.

    This matters because ``encode`` emits tempo/meter changes that fall on
    a downbeat *before* the BAR_END/BAR_START pair. The bar that just ended
    must therefore be advanced by its own length, not by the new one.

    Args:
        ids: token ids as produced by ``encode``; unknown ids decode as PAD.
        default_tempo: BPM used when ``ids`` contains no T* token.

    Returns:
        A PrettyMIDI with one instrument per voice family that has notes.
    """
    initial_tempo = default_tempo
    for tid in ids:
        t = ID2TOKEN.get(tid, "")
        if t.startswith("T") and not t.startswith("TS") and t[1:].isdigit():
            initial_tempo = _tempo_center(int(t[1:]))
            break

    pm = pretty_midi.PrettyMIDI(initial_tempo=initial_tempo)
    instruments: Dict[int, pretty_midi.Instrument] = {}
    current_voice = 0

    def get_inst(v: int) -> pretty_midi.Instrument:
        """Return (creating on first use) the instrument for voice ``v``."""
        if v not in instruments:
            if v == DRUM_VOICE:
                instruments[v] = pretty_midi.Instrument(
                    program=0,
                    is_drum=True,
                    name="drums",
                )
            else:
                prog = FAMILY_PROGRAMS.get(v, 0)
                instruments[v] = pretty_midi.Instrument(
                    program=prog,
                    name=f"family_{v}",
                )
        return instruments[v]

    current_tempo = initial_tempo
    current_meter_quarters = METER_QUARTERS["METER_4_4"]
    # Bar length implied by the latest tempo/meter tokens.
    bar_duration = current_meter_quarters * 60.0 / current_tempo
    # Length of the currently open bar (frozen once the bar has content).
    open_bar_duration = bar_duration
    # True once the open bar has placed a POS/TS/REST/note on the timeline.
    bar_positioned = False
    bar_start_time = 0.0
    n_bars_seen = 0
    current_time = 0.0
    pending_velocity = 64

    i = 0
    while i < len(ids):
        kind, val = _kind(ID2TOKEN.get(ids[i], "PAD"))
        if kind == "ts":
            current_time += _bin_center(val)
            bar_positioned = True
        elif kind == "pos":
            current_time = bar_start_time + (int(val) / N_POS_BINS) * open_bar_duration
            bar_positioned = True
        elif kind == "voice":
            current_voice = int(val)
        elif kind == "vel":
            pending_velocity = _vel_center(int(val))
        elif kind == "tempo":
            current_tempo = _tempo_center(int(val))
            bar_duration = current_meter_quarters * 60.0 / current_tempo
            if not bar_positioned:
                open_bar_duration = bar_duration
        elif kind == "meter":
            current_meter_quarters = METER_QUARTERS.get(val, 4.0)
            bar_duration = current_meter_quarters * 60.0 / current_tempo
            if not bar_positioned:
                open_bar_duration = bar_duration
        elif kind == "key":
            try:
                pm.key_signature_changes.append(
                    pretty_midi.KeySignature(int(val), float(current_time))
                )
            except Exception:
                pass
        elif kind == "pedal":
            # PEDAL_{TYPE}_{STATE}
            parts = str(val).split("_")
            if len(parts) == 3:
                ptype, pstate = parts[1], parts[2]
                cc_num = PEDAL_NAME_TO_CC.get(ptype)
                if cc_num is not None:
                    inst = get_inst(current_voice)
                    inst.control_changes.append(
                        pretty_midi.ControlChange(
                            number=cc_num,
                            value=127 if pstate == "DOWN" else 0,
                            time=float(current_time),
                        )
                    )
        elif kind == "cc":
            # CC_{NAME}_{BIN}
            parts = str(val).split("_")
            if len(parts) == 3 and parts[2].isdigit():
                cname, bidx = parts[1], int(parts[2])
                cc_num = CC_NAME_TO_NUMBER.get(cname)
                if cc_num is not None:
                    inst = get_inst(current_voice)
                    inst.control_changes.append(
                        pretty_midi.ControlChange(
                            number=cc_num,
                            value=_cc_center(bidx),
                            time=float(current_time),
                        )
                    )
        elif kind == "pb":
            inst = get_inst(current_voice)
            inst.pitch_bends.append(
                pretty_midi.PitchBend(
                    pitch=_pb_center(int(val)),
                    time=float(current_time),
                )
            )
        elif kind == "struct" and val == "BAR_START":
            if n_bars_seen > 0:
                # Advance by the length of the bar that just ended.
                bar_start_time += open_bar_duration
            current_time = bar_start_time
            n_bars_seen += 1
            open_bar_duration = bar_duration
            bar_positioned = False
        elif kind == "struct" and val == "REST":
            current_time += 0.25
            bar_positioned = True
        elif kind == "pitch":
            duration = 0.25
            # Look ahead for this note's D* token, stopping at the next
            # token that belongs to another event.
            j = i + 1
            while j < len(ids):
                nt = ID2TOKEN.get(ids[j], "PAD")
                nkind, nval = _kind(nt)
                if nkind == "dur":
                    duration = _bin_center(int(nval))
                    break
                if nkind in (
                    "pitch",
                    "ts",
                    "pos",
                    "voice",
                    "vel",
                    "tempo",
                    "meter",
                    "role",
                    "key",
                    "root",
                    "dens",
                    "reg",
                    "ref",
                    "capseg",
                    "pedal",
                    "cc",
                    "pb",
                ):
                    break
                if nkind == "struct" and nval not in ("CHORD_START", "CHORD_END"):
                    break
                j += 1
            note = pretty_midi.Note(
                # MIDI note-on velocities must lie in 1..127.
                velocity=min(127, max(1, int(pending_velocity))),
                pitch=int(val),
                start=current_time,
                end=current_time + max(duration, 0.01),
            )
            get_inst(current_voice).notes.append(note)
            bar_positioned = True
        # role / chord brackets / bar_end / phrase / pad / eos: no timing effect
        i += 1

    for voice in sorted(instruments):
        inst = instruments[voice]
        if inst.notes:
            pm.instruments.append(inst)
    return pm


def inject_caption_segments(ids: List[int], n_segs: int = N_CAP_SEGS) -> List[int]:
    """Insert CAP_SEG_{i} right after each PHRASE_START.

    Use this when you have a multi-sentence caption split into ``n_segs``
    parts and want to bind the i-th phrase of the encoded MIDI to the i-th
    caption segment (phrases beyond ``n_segs`` wrap around modulo
    ``n_segs``). Emission is opt-in because the segment count only makes
    sense in the presence of an external caption.

    Raises:
        ValueError: if ``n_segs`` is outside [1, N_CAP_SEGS].
    """
    if n_segs <= 0 or n_segs > N_CAP_SEGS:
        raise ValueError(f"n_segs must be in [1, {N_CAP_SEGS}]")
    out: List[int] = []
    seen_phrases = 0
    for tid in ids:
        out.append(tid)
        if tid == PHRASE_START:
            out.append(TOKEN2ID[CAP_SEG_NAMES[seen_phrases % n_segs]])
            seen_phrases += 1
    return out


# --- Round-trip test ----------------------------------------------------------


def _voice_label(inst: "pretty_midi.Instrument") -> int:
    """Return the voice index (GM family or DRUM_VOICE) of an instrument."""
    return DRUM_VOICE if inst.is_drum else _program_family(inst.program)


def round_trip_test(pm: pretty_midi.PrettyMIDI) -> Tuple[bool, Dict]:
    """Verify the (pitch, voice) multiset is preserved through encode+decode.

    Timing is not checked because log quantization is lossy.
    """
    original = sorted(
        (n.pitch, _voice_label(inst))
        for inst in pm.instruments
        for n in inst.notes
        if PITCH_MIN <= n.pitch <= PITCH_MAX
    )
    ids = encode(pm)
    pm2 = decode(ids)
    reconstructed = sorted(
        (n.pitch, _voice_label(inst)) for inst in pm2.instruments for n in inst.notes
    )
    passed = original == reconstructed
    return passed, {
        "n_orig": len(original),
        "n_recon": len(reconstructed),
        "n_tokens": len(ids),
        "vocab_size": VOCAB_SIZE,
    }


# --- Auto-load fitted velocity quantiles -------------------------------------

# Best-effort: a missing or malformed quantile file leaves uniform binning.
if DEFAULT_VEL_QUANTILES_PATH.exists():
    try:
        load_velocity_quantiles(DEFAULT_VEL_QUANTILES_PATH)
    except Exception:
        pass


# --- CLI ----------------------------------------------------------------------


def _cli_fit_velocity_quantiles(sample_dir: Path, out_path: Path) -> None:
    """Fit velocity quantile edges over all MIDI files under ``sample_dir``."""
    paths = sorted(sample_dir.rglob("*.mid")) + sorted(sample_dir.rglob("*.midi"))
    if not paths:
        raise SystemExit(f"No MIDI files under {sample_dir}")
    velocities: List[int] = []
    n_failed = 0
    for p in paths:
        try:
            pm = pretty_midi.PrettyMIDI(str(p))
        except Exception:
            n_failed += 1
            continue
        for inst in pm.instruments:
            for n in inst.notes:
                velocities.append(int(n.velocity))
    edges = fit_velocity_quantiles(velocities, n_bins=N_VEL_BINS)
    save_velocity_quantiles(edges, out_path)
    print(
        f"[velocity] files={len(paths)} failed={n_failed} "
        f"velocities={len(velocities)} -> {out_path}"
    )
    print(f"[velocity] edges = {[round(e, 2) for e in edges]}")


if __name__ == "__main__":
    import argparse as _argparse
    import sys as _sys

    if len(_sys.argv) > 1 and _sys.argv[1] == "fit-velocity":
        p = _argparse.ArgumentParser()
        p.add_argument(
            "--sample-dir",
            type=str,
            default=str(Path(__file__).resolve().parent.parent / "data" / "gigamidi" / "sample"),
        )
        p.add_argument("--out", type=str, default=str(DEFAULT_VEL_QUANTILES_PATH))
        args = p.parse_args(_sys.argv[2:])
        _cli_fit_velocity_quantiles(Path(args.sample_dir), Path(args.out))
        _sys.exit(0)

    print(f"Vocab size: {VOCAB_SIZE}")
    print(f"  structural:    {len(STRUCTURAL)}")
    print(f"  pitch:         {N_PITCH}")
    print(f"  duration:      {N_DUR_BINS}")
    print(f"  time-shift:    {N_SHIFT_BINS}")
    print(f"  velocity:      {N_VEL_BINS}")
    print(f"  voice/family:  {N_VOICE_BINS}")
    print(f"  tempo:         {N_TEMPO_BINS}")

    # Smoke test: build a tiny C-major scale and round-trip it.
    pm = pretty_midi.PrettyMIDI()
    inst = pretty_midi.Instrument(program=0)
    t = 0.0
    for p in [60, 62, 64, 65, 67, 69, 71, 72]:
        inst.notes.append(pretty_midi.Note(velocity=80, pitch=p, start=t, end=t + 0.5))
        t += 0.5
    pm.instruments.append(inst)
    ok, info = round_trip_test(pm)
    print(f"\nSmoke test round-trip: {'PASS' if ok else 'FAIL'}  {info}")

    # Multi-track / multi-velocity test.
    pm2 = pretty_midi.PrettyMIDI(initial_tempo=92.0)
    piano = pretty_midi.Instrument(program=0)  # family 0 (piano)
    bass = pretty_midi.Instrument(program=33)  # family 4 (bass)
    strings = pretty_midi.Instrument(program=48)  # family 6 (ensemble)
    t = 0.0
    for p, v in [(60, 30), (64, 80), (67, 110), (72, 60)]:
        piano.notes.append(pretty_midi.Note(velocity=v, pitch=p, start=t, end=t + 0.5))
        bass.notes.append(pretty_midi.Note(velocity=70, pitch=p - 24, start=t, end=t + 0.5))
        strings.notes.append(pretty_midi.Note(velocity=50, pitch=p + 12, start=t, end=t + 0.5))
        t += 0.5
    pm2.instruments += [piano, bass, strings]
    ok2, info2 = round_trip_test(pm2)
    print(f"Multi-track round-trip: {'PASS' if ok2 else 'FAIL'}  {info2}")

    ids = encode(pm2)
    has_tempo = any(
        ID2TOKEN[i].startswith("T") and not ID2TOKEN[i].startswith("TS")
        and ID2TOKEN[i][1:].isdigit() for i in ids
    )
    has_voice = any(ID2TOKEN[i].startswith("VC") for i in ids)
    has_meter = any(ID2TOKEN[i] in METERS for i in ids)
    has_pos = any(ID2TOKEN[i].startswith("POS") for i in ids)
    has_chord = CHORD_START in ids
    has_role = any(t in ids for t in (ROLE_BASS, ROLE_INNER, ROLE_TOP))
    print(
        f"Stream features: tempo={has_tempo} voice={has_voice} "
        f"meter={has_meter} pos={has_pos} chord={has_chord} role={has_role}"
    )