Spaces:
Sleeping
Sleeping
| """MIDI tokenizer for bach-gpt. | |
| Vocabulary | |
| ---------- | |
| Structural: PAD, EOS, BAR_START, BAR_END, PHRASE_START, PHRASE_END, | |
| REST, CHORD_START, CHORD_END (9) | |
| Pitch: MIDI 21..108 (88) | |
| Duration: 32 log-quantized bins over [0.03125, 4.0] seconds (32) | |
| Time-shift: 32 log-quantized bins over [0.03125, 4.0] seconds (32) | |
| Velocity: 16 uniform bins over [0, 127] (V0..V15) (16) | |
| Voice/chan: 16 GM families + 1 drums (VC0..VC16) (17) | |
| Tempo: 16 log-spaced bins over [40, 240] BPM (T0..T15) (16) | |
| Position: 16 sub-beat positions per bar (POS0..POS15) (16) | |
| Meter: 8 common time signatures + OTHER (METER_*) (9) | |
| Voice role: ROLE_BASS, ROLE_INNER, ROLE_TOP (within chord brackets) (3) | |
| Total vocab size: 238 tokens. | |
| API | |
| --- | |
| encode(pm: pretty_midi.PrettyMIDI) -> List[int] | |
| decode(ids: List[int]) -> pretty_midi.PrettyMIDI | |
| round_trip_test(pm) -> (passed: bool, details: dict) | |
| Note on fidelity: encode/decode preserve the pitch multiset and instrument | |
| family per note but timing is fuzzy due to log quantization. BAR_* and | |
| PHRASE_* markers are emitted from PrettyMIDI downbeats when a time signature | |
| is present. A tempo token (T*) is emitted at PHRASE_START and on tempo | |
| changes. A VC* token is emitted whenever the active track's instrument | |
| family changes. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import math | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| import numpy as np | |
| import pretty_midi | |
| # --- Vocabulary construction -------------------------------------------------- | |
| STRUCTURAL = [ | |
| "PAD", | |
| "EOS", | |
| "BAR_START", | |
| "BAR_END", | |
| "PHRASE_START", | |
| "PHRASE_END", | |
| "REST", | |
| "CHORD_START", | |
| "CHORD_END", | |
| ] | |
| ROLES = ["ROLE_BASS", "ROLE_INNER", "ROLE_TOP"] | |
| # 24 keys: indices 0..11 = C..B major; 12..23 = C..B minor (PrettyMIDI's | |
| # convention via key_number). | |
| KEYS = [f"KEY_{i}" for i in range(24)] | |
| # Bar-header axes: emitted right after each BAR_START as a coarse summary | |
| # of the bar's harmonic, density, and register content. | |
| ROOT_NAMES = [f"ROOT_{i}" for i in range(12)] | |
| N_DENS_BINS = 4 # 0-3, 4-7, 8-15, 16+ | |
| N_REG_BINS = 4 # <48, 48-59, 60-71, 72+ | |
| DENS_NAMES = [f"DENS_{i}" for i in range(N_DENS_BINS)] | |
| REG_NAMES = [f"REG_{i}" for i in range(N_REG_BINS)] | |
| # Bar-repetition markers: emitted right after bar header when this bar's | |
| # pitch multiset matches a bar K positions earlier (K in {1, 2, 4, 8}). | |
| REF_DISTANCES = [1, 2, 4, 8] | |
| REF_NAMES = [f"REF_BAR_{k}" for k in REF_DISTANCES] | |
| # Caption-segment markers for cross-modal alignment with MidiCaps-style | |
| # multi-sentence captions. Emit CAP_SEG_<i> at PHRASE_START to bind the | |
| # next phrase to the i-th caption segment. | |
| N_CAP_SEGS = 8 | |
| CAP_SEG_NAMES = [f"CAP_SEG_{i}" for i in range(N_CAP_SEGS)] | |
| # Pedal tokens: GM CC#64 (sustain), CC#66 (sostenuto), CC#67 (soft). | |
| PEDAL_CC_NUMBERS = {64: "SUS", 66: "SOS", 67: "SFT"} | |
| PEDAL_NAMES = [ | |
| f"PEDAL_{p}_{state}" | |
| for p in PEDAL_CC_NUMBERS.values() | |
| for state in ("UP", "DOWN") | |
| ] | |
| # Continuous-controller tokens for the most common expressive CCs, each | |
| # quantized to 8 bins over [0, 128). | |
| N_CC_BINS = 8 | |
| CC_TYPES = {1: "MOD", 7: "VOL", 10: "PAN", 11: "EXPR"} | |
| CC_NAMES = [ | |
| f"CC_{name}_{i}" | |
| for name in CC_TYPES.values() | |
| for i in range(N_CC_BINS) | |
| ] | |
| # Pitch-bend tokens. PrettyMIDI gives 14-bit values in [-8192, 8191]; | |
| # quantize to 16 uniform bins and emit as PB_<i>. | |
| N_PB_BINS = 16 | |
| PB_NAMES = [f"PB_{i}" for i in range(N_PB_BINS)] | |
| # Reverse maps for the decoder: short name -> CC number. | |
| PEDAL_NAME_TO_CC = {v: k for k, v in PEDAL_CC_NUMBERS.items()} | |
| CC_NAME_TO_NUMBER = {v: k for k, v in CC_TYPES.items()} | |
| METERS = [ | |
| "METER_2_4", | |
| "METER_3_4", | |
| "METER_4_4", | |
| "METER_5_4", | |
| "METER_6_8", | |
| "METER_7_8", | |
| "METER_9_8", | |
| "METER_12_8", | |
| "METER_OTHER", | |
| ] | |
| # Bar length in quarter-notes for each meter (used by decoder). | |
| METER_QUARTERS: Dict[str, float] = { | |
| "METER_2_4": 2.0, | |
| "METER_3_4": 3.0, | |
| "METER_4_4": 4.0, | |
| "METER_5_4": 5.0, | |
| "METER_6_8": 3.0, | |
| "METER_7_8": 3.5, | |
| "METER_9_8": 4.5, | |
| "METER_12_8": 6.0, | |
| "METER_OTHER": 4.0, | |
| } | |
| PITCH_MIN, PITCH_MAX = 21, 108 # 88 pitches (A0..C8) | |
| N_PITCH = PITCH_MAX - PITCH_MIN + 1 # 88 | |
| N_DUR_BINS = 32 | |
| N_SHIFT_BINS = 32 | |
| N_VEL_BINS = 16 | |
| # 16 GM instrument families + 1 reserved drum voice (VC16). | |
| N_VOICE_BINS = 17 | |
| DRUM_VOICE = 16 | |
| N_TEMPO_BINS = 16 | |
| # Sub-beat resolution per bar (sixteenth-note grid in 4/4). | |
| N_POS_BINS = 16 | |
| # Log-quantization range: 2**-5 s (~31 ms) to 4 s. | |
| TIME_MIN, TIME_MAX = 2 ** -5, 4.0 | |
| LOG_TIME_EDGES = np.linspace(math.log(TIME_MIN), math.log(TIME_MAX), N_DUR_BINS + 1) | |
| # Tempo log-quantization range: 40..240 BPM. | |
| TEMPO_MIN, TEMPO_MAX = 40.0, 240.0 | |
| LOG_TEMPO_EDGES = np.linspace( | |
| math.log(TEMPO_MIN), math.log(TEMPO_MAX), N_TEMPO_BINS + 1 | |
| ) | |
| def _build_vocab() -> Tuple[List[str], Dict[str, int]]: | |
| tokens: List[str] = list(STRUCTURAL) | |
| tokens += [f"P{p}" for p in range(PITCH_MIN, PITCH_MAX + 1)] | |
| tokens += [f"D{i}" for i in range(N_DUR_BINS)] | |
| tokens += [f"TS{i}" for i in range(N_SHIFT_BINS)] | |
| tokens += [f"V{i}" for i in range(N_VEL_BINS)] | |
| tokens += [f"VC{i}" for i in range(N_VOICE_BINS)] | |
| tokens += [f"T{i}" for i in range(N_TEMPO_BINS)] | |
| tokens += [f"POS{i}" for i in range(N_POS_BINS)] | |
| tokens += list(METERS) | |
| tokens += list(ROLES) | |
| tokens += list(KEYS) | |
| tokens += list(ROOT_NAMES) | |
| tokens += list(DENS_NAMES) | |
| tokens += list(REG_NAMES) | |
| tokens += list(REF_NAMES) | |
| tokens += list(CAP_SEG_NAMES) | |
| tokens += list(PEDAL_NAMES) | |
| tokens += list(CC_NAMES) | |
| tokens += list(PB_NAMES) | |
| t2i = {t: i for i, t in enumerate(tokens)} | |
| return tokens, t2i | |
| TOKENS, TOKEN2ID = _build_vocab() | |
| ID2TOKEN = {i: t for t, i in TOKEN2ID.items()} | |
| VOCAB_SIZE = len(TOKENS) | |
| # Default location for fitted velocity quantiles. See _maybe_load_vel_edges | |
| # at the end of this module for the auto-load. | |
| DEFAULT_VEL_QUANTILES_PATH = ( | |
| Path(__file__).resolve().parent.parent / "data" / "tokenizer" / "velocity_quantiles.json" | |
| ) | |
| PAD = TOKEN2ID["PAD"] | |
| EOS = TOKEN2ID["EOS"] | |
| BAR_START = TOKEN2ID["BAR_START"] | |
| BAR_END = TOKEN2ID["BAR_END"] | |
| PHRASE_START = TOKEN2ID["PHRASE_START"] | |
| PHRASE_END = TOKEN2ID["PHRASE_END"] | |
| REST = TOKEN2ID["REST"] | |
| CHORD_START = TOKEN2ID["CHORD_START"] | |
| CHORD_END = TOKEN2ID["CHORD_END"] | |
| ROLE_BASS = TOKEN2ID["ROLE_BASS"] | |
| ROLE_INNER = TOKEN2ID["ROLE_INNER"] | |
| ROLE_TOP = TOKEN2ID["ROLE_TOP"] | |
| # --- Quantization helpers ----------------------------------------------------- | |
| def _log_bin(x: float, edges=LOG_TIME_EDGES) -> int: | |
| """Map a positive time (s) to a log-bin index in [0, N-1].""" | |
| x = max(x, TIME_MIN) | |
| x = min(x, TIME_MAX) | |
| logx = math.log(x) | |
| # digitize returns 1..len(edges)-1; clip to valid bin range. | |
| idx = int(np.digitize(logx, edges)) - 1 | |
| return max(0, min(N_DUR_BINS - 1, idx)) | |
| def _bin_center(i: int, edges=LOG_TIME_EDGES) -> float: | |
| lo, hi = edges[i], edges[i + 1] | |
| return math.exp(0.5 * (lo + hi)) | |
| # Optional corpus-fit quantile edges for velocity bins. When set, _vel_bin | |
| # uses these instead of uniform binning. Loaded from a JSON file by | |
| # load_velocity_quantiles(); fit_velocity_quantiles() trains them. | |
| _VEL_EDGES: Optional[np.ndarray] = None | |
| def _vel_bin(v: int) -> int: | |
| v = max(0, min(127, int(v))) | |
| if _VEL_EDGES is not None: | |
| idx = int(np.searchsorted(_VEL_EDGES, v, side="right")) - 1 | |
| return max(0, min(N_VEL_BINS - 1, idx)) | |
| return min(N_VEL_BINS - 1, v * N_VEL_BINS // 128) | |
| def _vel_center(i: int) -> int: | |
| if _VEL_EDGES is not None: | |
| lo = float(_VEL_EDGES[i]) | |
| hi = float(_VEL_EDGES[i + 1]) | |
| return int(round(0.5 * (lo + hi))) | |
| return int((i + 0.5) * 128 / N_VEL_BINS) | |
| def fit_velocity_quantiles(velocities: List[int], n_bins: int = N_VEL_BINS) -> List[float]: | |
| """Compute quantile bin edges for a corpus of velocity values.""" | |
| if not velocities: | |
| return [i * 128 / n_bins for i in range(n_bins + 1)] | |
| vs = np.asarray(velocities, dtype=np.float64) | |
| qs = np.linspace(0.0, 1.0, n_bins + 1) | |
| edges = np.quantile(vs, qs).tolist() | |
| edges[0] = 0.0 | |
| edges[-1] = 128.0 | |
| # Force monotonic increasing in case of heavy ties. | |
| for i in range(1, len(edges)): | |
| if edges[i] <= edges[i - 1]: | |
| edges[i] = edges[i - 1] + 1e-3 | |
| return edges | |
| def load_velocity_quantiles(path: Path) -> bool: | |
| """Install bin edges from a JSON file. Returns True if loaded.""" | |
| global _VEL_EDGES | |
| p = Path(path) | |
| if not p.exists(): | |
| return False | |
| edges = json.loads(p.read_text()) | |
| if not isinstance(edges, list) or len(edges) != N_VEL_BINS + 1: | |
| return False | |
| _VEL_EDGES = np.asarray(edges, dtype=np.float64) | |
| return True | |
| def save_velocity_quantiles(edges: List[float], path: Path) -> None: | |
| p = Path(path) | |
| p.parent.mkdir(parents=True, exist_ok=True) | |
| p.write_text(json.dumps(list(edges))) | |
| def _tempo_bin(bpm: float) -> int: | |
| bpm = max(TEMPO_MIN, min(TEMPO_MAX, float(bpm))) | |
| idx = int(np.digitize(math.log(bpm), LOG_TEMPO_EDGES)) - 1 | |
| return max(0, min(N_TEMPO_BINS - 1, idx)) | |
| def _tempo_center(i: int) -> float: | |
| lo, hi = LOG_TEMPO_EDGES[i], LOG_TEMPO_EDGES[i + 1] | |
| return math.exp(0.5 * (lo + hi)) | |
| def _program_family(program: int) -> int: | |
| """Map a General MIDI program (0..127) to its 16-family index.""" | |
| return max(0, min(N_VOICE_BINS - 1, int(program) // 8)) | |
| # --- Encode ------------------------------------------------------------------- | |
| class _Event: | |
| onset: float | |
| voice: int # GM family index 0..15 or DRUM_VOICE | |
| kind: str = "note" # 'note' | 'pedal' | 'cc' | 'pb' | |
| # Note fields | |
| pitch: int = 0 | |
| duration: float = 0.0 | |
| velocity: int = 0 | |
| # Pedal fields | |
| pedal_type: str = "SUS" # 'SUS' | 'SOS' | 'SFT' | |
| pedal_state: str = "UP" # 'UP' | 'DOWN' | |
| # CC fields (continuous controllers) | |
| cc_type: str = "MOD" # name from CC_TYPES values | |
| cc_bin: int = 0 | |
| # Pitch-bend fields | |
| pb_bin: int = 0 | |
| def _cc_bin(value: int) -> int: | |
| v = max(0, min(127, int(value))) | |
| return min(N_CC_BINS - 1, v * N_CC_BINS // 128) | |
| def _cc_center(i: int) -> int: | |
| return int((i + 0.5) * 128 / N_CC_BINS) | |
| def _pb_bin(value: int) -> int: | |
| v = max(-8192, min(8191, int(value))) | |
| return min(N_PB_BINS - 1, (v + 8192) * N_PB_BINS // 16384) | |
| def _pb_center(i: int) -> int: | |
| return int((i + 0.5) * 16384 / N_PB_BINS) - 8192 | |
| def _extract_events(pm: pretty_midi.PrettyMIDI) -> List[_Event]: | |
| events: List[_Event] = [] | |
| for inst in pm.instruments: | |
| voice = DRUM_VOICE if inst.is_drum else _program_family(inst.program) | |
| for n in inst.notes: | |
| if PITCH_MIN <= n.pitch <= PITCH_MAX: | |
| events.append( | |
| _Event( | |
| onset=n.start, voice=voice, kind="note", | |
| pitch=n.pitch, | |
| duration=max(n.end - n.start, TIME_MIN), | |
| velocity=n.velocity, | |
| ) | |
| ) | |
| for cc in getattr(inst, "control_changes", []) or []: | |
| num = int(cc.number) | |
| if num in PEDAL_CC_NUMBERS: | |
| events.append( | |
| _Event( | |
| onset=float(cc.time), voice=voice, kind="pedal", | |
| pedal_type=PEDAL_CC_NUMBERS[num], | |
| pedal_state="DOWN" if int(cc.value) >= 64 else "UP", | |
| ) | |
| ) | |
| elif num in CC_TYPES: | |
| events.append( | |
| _Event( | |
| onset=float(cc.time), voice=voice, kind="cc", | |
| cc_type=CC_TYPES[num], | |
| cc_bin=_cc_bin(cc.value), | |
| ) | |
| ) | |
| for pb in getattr(inst, "pitch_bends", []) or []: | |
| events.append( | |
| _Event( | |
| onset=float(pb.time), voice=voice, kind="pb", | |
| pb_bin=_pb_bin(pb.pitch), | |
| ) | |
| ) | |
| # Note events sort by pitch within onset; non-note events keep their input order. | |
| events.sort(key=lambda e: (e.onset, e.kind != "note", e.voice, e.pitch)) | |
| return events | |
| def _tempo_changes(pm: pretty_midi.PrettyMIDI) -> List[Tuple[float, float]]: | |
| """Return sorted (time_s, bpm) pairs from a PrettyMIDI's tempo map.""" | |
| try: | |
| times, tempos = pm.get_tempo_changes() | |
| except Exception: | |
| return [(0.0, 120.0)] | |
| pairs = list(zip(times.tolist(), tempos.tolist())) if len(times) else [] | |
| if not pairs or pairs[0][0] > 1e-6: | |
| pairs.insert(0, (0.0, pairs[0][1] if pairs else 120.0)) | |
| return pairs | |
| def _meter_token(num: int, den: int) -> str: | |
| name = f"METER_{num}_{den}" | |
| return name if name in METER_QUARTERS else "METER_OTHER" | |
| def _key_changes(pm: pretty_midi.PrettyMIDI) -> List[Tuple[float, str]]: | |
| """Return sorted (time_s, key_name) pairs from key_signature_changes.""" | |
| out: List[Tuple[float, str]] = [] | |
| for ks in getattr(pm, "key_signature_changes", []) or []: | |
| kn = int(getattr(ks, "key_number", 0)) % 24 | |
| out.append((float(ks.time), f"KEY_{kn}")) | |
| if not out or out[0][0] > 1e-6: | |
| out.insert(0, (0.0, out[0][1] if out else "KEY_0")) | |
| return out | |
| def _meter_changes(pm: pretty_midi.PrettyMIDI) -> List[Tuple[float, str]]: | |
| """Return sorted (time_s, meter_name) pairs from time-signature changes.""" | |
| out: List[Tuple[float, str]] = [] | |
| for ts in getattr(pm, "time_signature_changes", []) or []: | |
| out.append((float(ts.time), _meter_token(ts.numerator, ts.denominator))) | |
| if not out or out[0][0] > 1e-6: | |
| out.insert(0, (0.0, out[0][1] if out else "METER_4_4")) | |
| return out | |
| def _bin_density(n: int) -> int: | |
| if n < 4: | |
| return 0 | |
| if n < 8: | |
| return 1 | |
| if n < 16: | |
| return 2 | |
| return 3 | |
| def _bin_register(mean_pitch: float) -> int: | |
| if mean_pitch < 48: | |
| return 0 | |
| if mean_pitch < 60: | |
| return 1 | |
| if mean_pitch < 72: | |
| return 2 | |
| return 3 | |
| def _bar_header_tokens(bar_events: List["_Event"]) -> List[int]: | |
| """Return [ROOT_<n>, DENS_<n>, REG_<n>] tokens summarizing a bar.""" | |
| if not bar_events: | |
| return [ | |
| TOKEN2ID[ROOT_NAMES[0]], | |
| TOKEN2ID[DENS_NAMES[0]], | |
| TOKEN2ID[REG_NAMES[0]], | |
| ] | |
| lowest = min(e.pitch for e in bar_events) | |
| root_pc = lowest % 12 | |
| n = len(bar_events) | |
| mean_pitch = sum(e.pitch for e in bar_events) / n | |
| return [ | |
| TOKEN2ID[ROOT_NAMES[root_pc]], | |
| TOKEN2ID[DENS_NAMES[_bin_density(n)]], | |
| TOKEN2ID[REG_NAMES[_bin_register(mean_pitch)]], | |
| ] | |
| def _group_by_onset(events: List["_Event"], eps: float = TIME_MIN) -> List[List["_Event"]]: | |
| """Group consecutive *note* events with coincident onsets into chord | |
| groups. Non-note events (pedal/cc/pb) are emitted as size-1 groups so | |
| they keep their place in the timeline but never bracket as chords. | |
| """ | |
| groups: List[List[_Event]] = [] | |
| cur: List[_Event] = [] | |
| cur_onset: Optional[float] = None | |
| def _flush() -> None: | |
| nonlocal cur, cur_onset | |
| if cur: | |
| groups.append(cur) | |
| cur = [] | |
| cur_onset = None | |
| for ev in events: | |
| if ev.kind != "note": | |
| _flush() | |
| groups.append([ev]) | |
| continue | |
| if cur_onset is None or abs(ev.onset - cur_onset) <= eps: | |
| cur.append(ev) | |
| if cur_onset is None: | |
| cur_onset = ev.onset | |
| else: | |
| _flush() | |
| cur = [ev] | |
| cur_onset = ev.onset | |
| _flush() | |
| return groups | |
| def _downbeats(pm: pretty_midi.PrettyMIDI) -> np.ndarray: | |
| try: | |
| db = pm.get_downbeats() | |
| return np.asarray(db) if db is not None else np.array([]) | |
| except Exception: | |
| return np.array([]) | |
| def encode(pm: pretty_midi.PrettyMIDI) -> List[int]: | |
| """Encode a PrettyMIDI object to a list of vocabulary ids. | |
| Stream layout: PHRASE_START T<n> METER_X_Y [VC<v>] [BAR_START] ... | |
| For each onset group (chord = co-located notes): | |
| [tempo/meter/bar tokens if any cross this onset] | |
| [POS<p> if in a bar AND position changed; otherwise TS<n> as fallback] | |
| if size>1: CHORD_START <per-note: VC?, ROLE, V, P, D> CHORD_END | |
| else: <per-note: VC?, V, P, D> | |
| Notes within a chord are sorted by pitch ascending; lowest gets ROLE_BASS, | |
| highest ROLE_TOP, middle pitches ROLE_INNER. | |
| """ | |
| events = _extract_events(pm) | |
| ids: List[int] = [PHRASE_START] | |
| if not events: | |
| ids.append(PHRASE_END) | |
| ids.append(EOS) | |
| return ids | |
| tempo_map = _tempo_changes(pm) | |
| tempo_iter = iter(tempo_map) | |
| cur_tempo = next(tempo_iter, (0.0, 120.0)) | |
| next_tempo = next(tempo_iter, None) | |
| ids.append(TOKEN2ID[f"T{_tempo_bin(cur_tempo[1])}"]) | |
| meter_map = _meter_changes(pm) | |
| meter_iter = iter(meter_map) | |
| cur_meter = next(meter_iter, (0.0, "METER_4_4")) | |
| next_meter = next(meter_iter, None) | |
| ids.append(TOKEN2ID[cur_meter[1]]) | |
| key_map = _key_changes(pm) | |
| key_iter = iter(key_map) | |
| cur_key = next(key_iter, (0.0, "KEY_0")) | |
| next_key = next(key_iter, None) | |
| ids.append(TOKEN2ID[cur_key[1]]) | |
| downbeats = list(_downbeats(pm)) | |
| # Precompute *note* events per bar for the header summary. Non-note | |
| # events (pedals/CC/PB) are excluded so they don't skew ROOT/DENS/REG. | |
| bar_events_by_idx: Dict[int, List[_Event]] = {} | |
| if downbeats: | |
| db_arr = np.asarray(downbeats) | |
| for ev in events: | |
| if ev.kind != "note": | |
| continue | |
| i = max( | |
| 0, | |
| int(np.searchsorted(db_arr, ev.onset, side="right")) - 1, | |
| ) | |
| bar_events_by_idx.setdefault(i, []).append(ev) | |
| db_idx = 0 | |
| in_bar = False | |
| bar_start_time: Optional[float] = None | |
| bar_duration: Optional[float] = None | |
| last_pos_in_bar: Optional[int] = None | |
| # History of pitch multisets per emitted bar (for REF_BAR_K matching). | |
| bar_pitch_history: List[Tuple[int, ...]] = [] | |
| def _bar_pitches(idx: int) -> Tuple[int, ...]: | |
| return tuple(sorted(e.pitch for e in bar_events_by_idx.get(idx, []))) | |
| def _emit_bar_start(bar_index: int) -> None: | |
| ids.append(BAR_START) | |
| ids.extend(_bar_header_tokens(bar_events_by_idx.get(bar_index, []))) | |
| fp = _bar_pitches(bar_index) | |
| for k in REF_DISTANCES: | |
| if k <= len(bar_pitch_history) and fp and fp == bar_pitch_history[-k]: | |
| ids.append(TOKEN2ID[f"REF_BAR_{k}"]) | |
| break | |
| bar_pitch_history.append(fp) | |
| groups = _group_by_onset(events) | |
| current_voice = groups[0][0].voice | |
| ids.append(TOKEN2ID[f"VC{current_voice}"]) | |
| prev_onset = groups[0][0].onset | |
| while db_idx < len(downbeats) and downbeats[db_idx] <= prev_onset + 1e-6: | |
| if in_bar: | |
| ids.append(BAR_END) | |
| _emit_bar_start(db_idx) | |
| in_bar = True | |
| bar_start_time = float(downbeats[db_idx]) | |
| if db_idx + 1 < len(downbeats): | |
| bar_duration = float(downbeats[db_idx + 1] - downbeats[db_idx]) | |
| last_pos_in_bar = None | |
| db_idx += 1 | |
| for g_idx, group in enumerate(groups): | |
| onset = group[0].onset | |
| while next_tempo is not None and next_tempo[0] <= onset + 1e-6: | |
| ids.append(TOKEN2ID[f"T{_tempo_bin(next_tempo[1])}"]) | |
| next_tempo = next(tempo_iter, None) | |
| while next_meter is not None and next_meter[0] <= onset + 1e-6: | |
| ids.append(TOKEN2ID[next_meter[1]]) | |
| next_meter = next(meter_iter, None) | |
| while next_key is not None and next_key[0] <= onset + 1e-6: | |
| ids.append(TOKEN2ID[next_key[1]]) | |
| next_key = next(key_iter, None) | |
| while db_idx < len(downbeats) and downbeats[db_idx] <= onset + 1e-6: | |
| if in_bar: | |
| ids.append(BAR_END) | |
| _emit_bar_start(db_idx) | |
| in_bar = True | |
| bar_start_time = float(downbeats[db_idx]) | |
| if db_idx + 1 < len(downbeats): | |
| bar_duration = float(downbeats[db_idx + 1] - downbeats[db_idx]) | |
| last_pos_in_bar = None | |
| db_idx += 1 | |
| if in_bar and bar_duration and bar_duration > 1e-6: | |
| pos_bin = int(round((onset - bar_start_time) / bar_duration * N_POS_BINS)) | |
| pos_bin = max(0, min(N_POS_BINS - 1, pos_bin)) | |
| if pos_bin != last_pos_in_bar: | |
| ids.append(TOKEN2ID[f"POS{pos_bin}"]) | |
| last_pos_in_bar = pos_bin | |
| else: | |
| shift = 0.0 if g_idx == 0 else onset - prev_onset | |
| if shift > TIME_MIN: | |
| ids.append(TOKEN2ID[f"TS{_log_bin(shift)}"]) | |
| if group[0].kind != "note": | |
| ev = group[0] | |
| if ev.voice != current_voice: | |
| ids.append(TOKEN2ID[f"VC{ev.voice}"]) | |
| current_voice = ev.voice | |
| if ev.kind == "pedal": | |
| ids.append(TOKEN2ID[f"PEDAL_{ev.pedal_type}_{ev.pedal_state}"]) | |
| elif ev.kind == "cc": | |
| ids.append(TOKEN2ID[f"CC_{ev.cc_type}_{ev.cc_bin}"]) | |
| elif ev.kind == "pb": | |
| ids.append(TOKEN2ID[f"PB_{ev.pb_bin}"]) | |
| prev_onset = onset | |
| continue | |
| notes = sorted(group, key=lambda e: e.pitch) | |
| is_chord = len(notes) > 1 | |
| if is_chord: | |
| ids.append(CHORD_START) | |
| for n_idx, ev in enumerate(notes): | |
| if ev.voice != current_voice: | |
| ids.append(TOKEN2ID[f"VC{ev.voice}"]) | |
| current_voice = ev.voice | |
| if is_chord: | |
| if n_idx == 0: | |
| ids.append(ROLE_BASS) | |
| elif n_idx == len(notes) - 1: | |
| ids.append(ROLE_TOP) | |
| else: | |
| ids.append(ROLE_INNER) | |
| ids.append(TOKEN2ID[f"V{_vel_bin(ev.velocity)}"]) | |
| ids.append(TOKEN2ID[f"P{ev.pitch}"]) | |
| ids.append(TOKEN2ID[f"D{_log_bin(ev.duration)}"]) | |
| if is_chord: | |
| ids.append(CHORD_END) | |
| prev_onset = onset | |
| if in_bar: | |
| ids.append(BAR_END) | |
| ids.append(PHRASE_END) | |
| ids.append(EOS) | |
| return ids | |
| # --- Decode ------------------------------------------------------------------- | |
| # GM family -> representative program number (one per family of 8). | |
| FAMILY_PROGRAMS = { | |
| 0: 0, # Piano | |
| 1: 8, # Chromatic Percussion | |
| 2: 16, # Organ | |
| 3: 24, # Guitar | |
| 4: 32, # Bass | |
| 5: 40, # Strings | |
| 6: 48, # Ensemble | |
| 7: 56, # Brass | |
| 8: 64, # Reed | |
| 9: 72, # Pipe | |
| 10: 80, # Synth Lead | |
| 11: 88, # Synth Pad | |
| 12: 96, # Synth Effects | |
| 13: 104, # Ethnic | |
| 14: 112, # Percussive | |
| 15: 120, # Sound Effects | |
| } | |
| def _kind(t: str): | |
| """Classify a token name into a (kind, value) pair for the decoder.""" | |
| if t in STRUCTURAL: | |
| return ("struct", t) | |
| if t in ROLES: | |
| return ("role", t) | |
| if t in METERS: | |
| return ("meter", t) | |
| if t.startswith("KEY_") and t[4:].isdigit(): | |
| return ("key", int(t[4:])) | |
| if t.startswith("ROOT_") and t[5:].isdigit(): | |
| return ("root", int(t[5:])) | |
| if t.startswith("DENS_") and t[5:].isdigit(): | |
| return ("dens", int(t[5:])) | |
| if t.startswith("REG_") and t[4:].isdigit(): | |
| return ("reg", int(t[4:])) | |
| if t.startswith("REF_BAR_") and t[8:].isdigit(): | |
| return ("ref", int(t[8:])) | |
| if t.startswith("CAP_SEG_") and t[8:].isdigit(): | |
| return ("capseg", int(t[8:])) | |
| if t.startswith("PEDAL_"): | |
| return ("pedal", t) | |
| if t.startswith("CC_"): | |
| return ("cc", t) | |
| if t.startswith("PB_") and t[3:].isdigit(): | |
| return ("pb", int(t[3:])) | |
| if t.startswith("TS") and t[2:].isdigit(): | |
| return ("ts", int(t[2:])) | |
| if t.startswith("VC") and t[2:].isdigit(): | |
| return ("voice", int(t[2:])) | |
| if t.startswith("POS") and t[3:].isdigit(): | |
| return ("pos", int(t[3:])) | |
| if t.startswith("D") and t[1:].isdigit(): | |
| return ("dur", int(t[1:])) | |
| if t.startswith("V") and t[1:].isdigit(): | |
| return ("vel", int(t[1:])) | |
| if t.startswith("T") and t[1:].isdigit(): | |
| return ("tempo", int(t[1:])) | |
| if t.startswith("P") and t[1:].isdigit(): | |
| return ("pitch", int(t[1:])) | |
| return ("struct", t) | |
| def decode(ids: List[int], default_tempo: float = 120.0) -> pretty_midi.PrettyMIDI: | |
| """Decode a token id list back to a PrettyMIDI. Timing is reconstructed | |
| from POS within bars (using current tempo + meter) or TS deltas as a | |
| fallback. Pitches and instrument families are preserved exactly. | |
| """ | |
| initial_tempo = default_tempo | |
| for tid in ids: | |
| t = ID2TOKEN.get(tid, "") | |
| if t.startswith("T") and not t.startswith("TS") and t[1:].isdigit(): | |
| initial_tempo = _tempo_center(int(t[1:])) | |
| break | |
| pm = pretty_midi.PrettyMIDI(initial_tempo=initial_tempo) | |
| instruments: Dict[int, pretty_midi.Instrument] = {} | |
| current_voice = 0 | |
| def get_inst(v: int) -> pretty_midi.Instrument: | |
| if v not in instruments: | |
| if v == DRUM_VOICE: | |
| instruments[v] = pretty_midi.Instrument( | |
| program=0, | |
| is_drum=True, | |
| name="drums", | |
| ) | |
| else: | |
| prog = FAMILY_PROGRAMS.get(v, 0) | |
| instruments[v] = pretty_midi.Instrument( | |
| program=prog, | |
| name=f"family_{v}", | |
| ) | |
| return instruments[v] | |
| current_tempo = initial_tempo | |
| current_meter_quarters = METER_QUARTERS["METER_4_4"] | |
| bar_duration = current_meter_quarters * 60.0 / current_tempo | |
| bar_start_time = 0.0 | |
| n_bars_seen = 0 | |
| current_time = 0.0 | |
| pending_velocity = 64 | |
| i = 0 | |
| while i < len(ids): | |
| kind, val = _kind(ID2TOKEN.get(ids[i], "PAD")) | |
| if kind == "ts": | |
| current_time += _bin_center(val) | |
| elif kind == "pos": | |
| current_time = bar_start_time + (int(val) / N_POS_BINS) * bar_duration | |
| elif kind == "voice": | |
| current_voice = int(val) | |
| elif kind == "vel": | |
| pending_velocity = _vel_center(int(val)) | |
| elif kind == "tempo": | |
| current_tempo = _tempo_center(int(val)) | |
| bar_duration = current_meter_quarters * 60.0 / current_tempo | |
| elif kind == "meter": | |
| current_meter_quarters = METER_QUARTERS.get(val, 4.0) | |
| bar_duration = current_meter_quarters * 60.0 / current_tempo | |
| elif kind == "key": | |
| try: | |
| pm.key_signature_changes.append( | |
| pretty_midi.KeySignature(int(val), float(current_time)) | |
| ) | |
| except Exception: | |
| pass | |
| elif kind == "pedal": | |
| # PEDAL_<SUS|SOS|SFT>_<UP|DOWN> | |
| parts = str(val).split("_") | |
| if len(parts) == 3: | |
| ptype, pstate = parts[1], parts[2] | |
| cc_num = PEDAL_NAME_TO_CC.get(ptype) | |
| if cc_num is not None: | |
| inst = get_inst(current_voice) | |
| inst.control_changes.append( | |
| pretty_midi.ControlChange( | |
| number=cc_num, | |
| value=127 if pstate == "DOWN" else 0, | |
| time=float(current_time), | |
| ) | |
| ) | |
| elif kind == "cc": | |
| # CC_<NAME>_<BIN> | |
| parts = str(val).split("_") | |
| if len(parts) == 3 and parts[2].isdigit(): | |
| cname, bidx = parts[1], int(parts[2]) | |
| cc_num = CC_NAME_TO_NUMBER.get(cname) | |
| if cc_num is not None: | |
| inst = get_inst(current_voice) | |
| inst.control_changes.append( | |
| pretty_midi.ControlChange( | |
| number=cc_num, | |
| value=_cc_center(bidx), | |
| time=float(current_time), | |
| ) | |
| ) | |
| elif kind == "pb": | |
| inst = get_inst(current_voice) | |
| inst.pitch_bends.append( | |
| pretty_midi.PitchBend( | |
| pitch=_pb_center(int(val)), | |
| time=float(current_time), | |
| ) | |
| ) | |
| elif kind == "struct" and val == "BAR_START": | |
| if n_bars_seen > 0: | |
| bar_start_time += bar_duration | |
| current_time = bar_start_time | |
| n_bars_seen += 1 | |
| elif kind == "struct" and val == "REST": | |
| current_time += 0.25 | |
| elif kind == "pitch": | |
| duration = 0.25 | |
| j = i + 1 | |
| while j < len(ids): | |
| nt = ID2TOKEN.get(ids[j], "PAD") | |
| nkind, nval = _kind(nt) | |
| if nkind == "dur": | |
| duration = _bin_center(int(nval)) | |
| break | |
| if nkind in ( | |
| "pitch", "ts", "pos", "voice", "vel", | |
| "tempo", "meter", "role", "key", | |
| "root", "dens", "reg", "ref", "capseg", | |
| "pedal", "cc", "pb", | |
| ): | |
| break | |
| if nkind == "struct" and nval not in ("CHORD_START", "CHORD_END"): | |
| break | |
| j += 1 | |
| note = pretty_midi.Note( | |
| velocity=int(pending_velocity), | |
| pitch=int(val), | |
| start=current_time, | |
| end=current_time + max(duration, 0.01), | |
| ) | |
| get_inst(current_voice).notes.append(note) | |
| # role / chord brackets / bar_end / phrase / pad / eos: no timing effect | |
| i += 1 | |
| for voice in sorted(instruments): | |
| inst = instruments[voice] | |
| if inst.notes: | |
| pm.instruments.append(inst) | |
| return pm | |
| def inject_caption_segments(ids: List[int], n_segs: int = N_CAP_SEGS) -> List[int]: | |
| """Insert CAP_SEG_<i % n_segs> right after each PHRASE_START. | |
| Use this when you have a multi-sentence caption split into ``n_segs`` | |
| parts and want to bind the i-th phrase of the encoded MIDI to the | |
| i-th caption segment. Emission is opt-in because the segment count | |
| only makes sense in the presence of an external caption. | |
| """ | |
| if n_segs <= 0 or n_segs > N_CAP_SEGS: | |
| raise ValueError(f"n_segs must be in [1, {N_CAP_SEGS}]") | |
| out: List[int] = [] | |
| seen_phrases = 0 | |
| for tid in ids: | |
| out.append(tid) | |
| if tid == PHRASE_START: | |
| out.append(TOKEN2ID[CAP_SEG_NAMES[seen_phrases % n_segs]]) | |
| seen_phrases += 1 | |
| return out | |
| # --- Round-trip test ---------------------------------------------------------- | |
| def _voice_label(inst: "pretty_midi.Instrument") -> int: | |
| return DRUM_VOICE if inst.is_drum else _program_family(inst.program) | |
| def round_trip_test(pm: pretty_midi.PrettyMIDI) -> Tuple[bool, Dict]: | |
| """Verify the (pitch, voice) multiset is preserved through encode+decode. | |
| Timing is not checked because log quantization is lossy. | |
| """ | |
| original = sorted( | |
| (n.pitch, _voice_label(inst)) | |
| for inst in pm.instruments | |
| for n in inst.notes | |
| if PITCH_MIN <= n.pitch <= PITCH_MAX | |
| ) | |
| ids = encode(pm) | |
| pm2 = decode(ids) | |
| reconstructed = sorted( | |
| (n.pitch, _voice_label(inst)) | |
| for inst in pm2.instruments | |
| for n in inst.notes | |
| ) | |
| passed = original == reconstructed | |
| return passed, { | |
| "n_orig": len(original), | |
| "n_recon": len(reconstructed), | |
| "n_tokens": len(ids), | |
| "vocab_size": VOCAB_SIZE, | |
| } | |
| # --- Auto-load fitted velocity quantiles ------------------------------------- | |
| if DEFAULT_VEL_QUANTILES_PATH.exists(): | |
| try: | |
| load_velocity_quantiles(DEFAULT_VEL_QUANTILES_PATH) | |
| except Exception: | |
| pass | |
| # --- CLI ---------------------------------------------------------------------- | |
| def _cli_fit_velocity_quantiles(sample_dir: Path, out_path: Path) -> None: | |
| paths = sorted(sample_dir.rglob("*.mid")) + sorted(sample_dir.rglob("*.midi")) | |
| if not paths: | |
| raise SystemExit(f"No MIDI files under {sample_dir}") | |
| velocities: List[int] = [] | |
| n_failed = 0 | |
| for p in paths: | |
| try: | |
| pm = pretty_midi.PrettyMIDI(str(p)) | |
| except Exception: | |
| n_failed += 1 | |
| continue | |
| for inst in pm.instruments: | |
| for n in inst.notes: | |
| velocities.append(int(n.velocity)) | |
| edges = fit_velocity_quantiles(velocities, n_bins=N_VEL_BINS) | |
| save_velocity_quantiles(edges, out_path) | |
| print( | |
| f"[velocity] files={len(paths)} failed={n_failed} " | |
| f"velocities={len(velocities)} -> {out_path}" | |
| ) | |
| print(f"[velocity] edges = {[round(e, 2) for e in edges]}") | |
| if __name__ == "__main__": | |
| import argparse as _argparse | |
| import sys as _sys | |
| if len(_sys.argv) > 1 and _sys.argv[1] == "fit-velocity": | |
| p = _argparse.ArgumentParser() | |
| p.add_argument( | |
| "--sample-dir", | |
| type=str, | |
| default=str(Path(__file__).resolve().parent.parent / "data" / "gigamidi" / "sample"), | |
| ) | |
| p.add_argument("--out", type=str, default=str(DEFAULT_VEL_QUANTILES_PATH)) | |
| args = p.parse_args(_sys.argv[2:]) | |
| _cli_fit_velocity_quantiles(Path(args.sample_dir), Path(args.out)) | |
| _sys.exit(0) | |
| print(f"Vocab size: {VOCAB_SIZE}") | |
| print(f" structural: {len(STRUCTURAL)}") | |
| print(f" pitch: {N_PITCH}") | |
| print(f" duration: {N_DUR_BINS}") | |
| print(f" time-shift: {N_SHIFT_BINS}") | |
| print(f" velocity: {N_VEL_BINS}") | |
| print(f" voice/family: {N_VOICE_BINS}") | |
| print(f" tempo: {N_TEMPO_BINS}") | |
| # Smoke test: build a tiny C-major scale and round-trip it. | |
| pm = pretty_midi.PrettyMIDI() | |
| inst = pretty_midi.Instrument(program=0) | |
| t = 0.0 | |
| for p in [60, 62, 64, 65, 67, 69, 71, 72]: | |
| inst.notes.append(pretty_midi.Note(velocity=80, pitch=p, start=t, end=t + 0.5)) | |
| t += 0.5 | |
| pm.instruments.append(inst) | |
| ok, info = round_trip_test(pm) | |
| print(f"\nSmoke test round-trip: {'PASS' if ok else 'FAIL'} {info}") | |
| # Multi-track / multi-velocity test. | |
| pm2 = pretty_midi.PrettyMIDI(initial_tempo=92.0) | |
| piano = pretty_midi.Instrument(program=0) # family 0 (piano) | |
| bass = pretty_midi.Instrument(program=33) # family 4 (bass) | |
| strings = pretty_midi.Instrument(program=48) # family 6 (ensemble) | |
| t = 0.0 | |
| for p, v in [(60, 30), (64, 80), (67, 110), (72, 60)]: | |
| piano.notes.append(pretty_midi.Note(velocity=v, pitch=p, start=t, end=t + 0.5)) | |
| bass.notes.append(pretty_midi.Note(velocity=70, pitch=p - 24, start=t, end=t + 0.5)) | |
| strings.notes.append(pretty_midi.Note(velocity=50, pitch=p + 12, start=t, end=t + 0.5)) | |
| t += 0.5 | |
| pm2.instruments += [piano, bass, strings] | |
| ok2, info2 = round_trip_test(pm2) | |
| print(f"Multi-track round-trip: {'PASS' if ok2 else 'FAIL'} {info2}") | |
| ids = encode(pm2) | |
| has_tempo = any( | |
| ID2TOKEN[i].startswith("T") and not ID2TOKEN[i].startswith("TS") | |
| and ID2TOKEN[i][1:].isdigit() | |
| for i in ids | |
| ) | |
| has_voice = any(ID2TOKEN[i].startswith("VC") for i in ids) | |
| has_meter = any(ID2TOKEN[i] in METERS for i in ids) | |
| has_pos = any(ID2TOKEN[i].startswith("POS") for i in ids) | |
| has_chord = CHORD_START in ids | |
| has_role = any(t in ids for t in (ROLE_BASS, ROLE_INNER, ROLE_TOP)) | |
| print( | |
| f"Stream features: tempo={has_tempo} voice={has_voice} " | |
| f"meter={has_meter} pos={has_pos} chord={has_chord} role={has_role}" | |
| ) | |