"""Octuple/CPWord-style compound tokenizer for bach-gpt. Each event becomes a single "step" with N parallel feature ids instead of N consecutive tokens. The compound encoder is a thin post-processor over the 1D ``tokenizer.encode`` stream, so all structural logic (chord brackets, voice-role, bar headers, REF, key/meter/tempo, etc.) is shared. Step layout ----------- A compound step is a tuple of ``N_AXES`` feature ids, one per axis: (step_type, pitch, duration, velocity, position, voice, aux) Step types ---------- 0 PAD all other axes = sentinel 1 BOS / PHRASE_START 2 EOS / PHRASE_END 3 BAR_START pos = 0; aux carries packed (root, density, register) 4 BAR_END 5 NOTE pitch/dur/vel filled; voice + position; aux = role 6 PEDAL_DOWN voice + position; aux = pedal_type (0/1/2 = SUS/SOS/SFT) 7 PEDAL_UP voice + position; aux = pedal_type 8 CC_CHANGE voice + position; aux = CC_type * N_CC_BINS + bin 9 PITCH_BEND voice + position; aux = pb_bin 10 CHORD_START structural; pos optional 11 CHORD_END structural Sentinels --------- A per-axis "no value" id is the last entry of each axis. Models embed this just like any other id; the sentinel just means "this axis doesn't apply at this step type." """ from __future__ import annotations from typing import List, Sequence, Tuple import pretty_midi from tokenizer import ( BAR_END, BAR_START, CHORD_END, CHORD_START, DRUM_VOICE, EOS, ID2TOKEN, METERS, METER_QUARTERS, N_CC_BINS, N_PB_BINS, N_POS_BINS, N_VEL_BINS, N_VOICE_BINS, PHRASE_END, PHRASE_START, PITCH_MAX, PITCH_MIN, REF_NAMES, ROLES, ROOT_NAMES, DENS_NAMES, REG_NAMES, PB_NAMES, CC_NAMES, PEDAL_NAMES, CC_TYPES, PEDAL_CC_NUMBERS, TOKEN2ID, VOCAB_SIZE, _bin_center, _cc_center, _pb_center, _tempo_center, _vel_center, decode as decode_1d, encode as encode_1d, ) # --- Step type vocabulary ---------------------------------------------------- STEP_PAD = 0 STEP_BOS = 1 STEP_EOS = 2 STEP_BAR_START = 3 STEP_BAR_END = 4 STEP_NOTE = 5 STEP_PEDAL_DOWN = 6 STEP_PEDAL_UP = 7 STEP_CC = 8 STEP_PB = 9 STEP_CHORD_START = 10 STEP_CHORD_END = 11 N_STEP_TYPES = 12 # Per-axis cardinalities (last entry of each is the sentinel). N_PITCH_AXIS = (PITCH_MAX - PITCH_MIN + 1) + 1 # 88 + sentinel N_DUR_AXIS = 32 + 1 N_VEL_AXIS = N_VEL_BINS + 1 N_POS_AXIS = N_POS_BINS + 1 N_VOICE_AXIS = N_VOICE_BINS + 1 # Aux axis is type-multiplexed. We size it for the largest payload. N_PEDAL_TYPES = 3 N_CC_AUX = len(CC_TYPES) * N_CC_BINS N_AUX = max(len(ROLES) + 1, N_PEDAL_TYPES, N_CC_AUX, N_PB_BINS) N_AUX_AXIS = N_AUX + 1 AXIS_SIZES = [ N_STEP_TYPES, N_PITCH_AXIS, N_DUR_AXIS, N_VEL_AXIS, N_POS_AXIS, N_VOICE_AXIS, N_AUX_AXIS, ] AXIS_NAMES = ["step", "pitch", "dur", "vel", "pos", "voice", "aux"] N_AXES = len(AXIS_SIZES) # Per-axis sentinel ids = the last index in that axis. SENTINELS = [n - 1 for n in AXIS_SIZES] def empty_step() -> List[int]: """Return a step with all axes set to their sentinel.""" return list(SENTINELS) # --- Helpers ----------------------------------------------------------------- PEDAL_TYPE_TO_AUX = {"SUS": 0, "SOS": 1, "SFT": 2} ROLE_NAME_TO_AUX = {"ROLE_BASS": 0, "ROLE_INNER": 1, "ROLE_TOP": 2} CC_NAME_TO_AUX_BASE = { name: i * N_CC_BINS for i, name in enumerate(CC_TYPES.values()) } def _classify(name: str) -> Tuple[str, str]: """Return (kind, payload) for a 1D token name.""" if name in ("PAD", "EOS", "PHRASE_START", "PHRASE_END", "BAR_START", "BAR_END", "REST", "CHORD_START", "CHORD_END"): return ("struct", name) if name in ROLES: return ("role", name) if name in METERS: return ("meter", name) if name in REF_NAMES: return ("ref", name) if name in PEDAL_NAMES: return ("pedal", name) if name in CC_NAMES: return ("cc", name) if name in PB_NAMES: return ("pb", name) if name in ROOT_NAMES: return ("root", name) if name in DENS_NAMES: return ("dens", name) if name in REG_NAMES: return ("reg", name) if name.startswith("VC") and name[2:].isdigit(): return ("voice", name[2:]) if name.startswith("POS") and name[3:].isdigit(): return ("pos", name[3:]) if name.startswith("TS") and name[2:].isdigit(): return ("ts", name[2:]) if name.startswith("D") and name[1:].isdigit(): return ("dur", name[1:]) if name.startswith("V") and name[1:].isdigit(): return ("vel", name[1:]) if name.startswith("T") and name[1:].isdigit(): return ("tempo", name[1:]) if name.startswith("KEY_") and name[4:].isdigit(): return ("key", name[4:]) if name.startswith("CAP_SEG_") and name[8:].isdigit(): return ("capseg", name[8:]) if name.startswith("P") and name[1:].isdigit(): return ("pitch", name[1:]) return ("struct", name) # --- Encode ------------------------------------------------------------------ def encode_compound(pm: pretty_midi.PrettyMIDI) -> List[List[int]]: """Convert a PrettyMIDI to a list of compound steps via the 1D encoder. The 1D stream is walked once; each musical event collapses into a single compound step. Bare metadata tokens (KEY/METER/TEMPO/ROLE/ bar headers/REF/CAP_SEG) update running state but don't emit a step. """ ids_1d = encode_1d(pm) steps: List[List[int]] = [] cur_voice = SENTINELS[5] cur_pos = SENTINELS[4] pending_role = SENTINELS[6] pending_vel = SENTINELS[3] i = 0 while i < len(ids_1d): name = ID2TOKEN.get(ids_1d[i], "PAD") kind, payload = _classify(name) if kind == "struct": tag = payload if tag in ("PHRASE_START",): steps.append([ STEP_BOS, SENTINELS[1], SENTINELS[2], SENTINELS[3], SENTINELS[4], SENTINELS[5], SENTINELS[6], ]) elif tag in ("PHRASE_END",): steps.append([ STEP_EOS, SENTINELS[1], SENTINELS[2], SENTINELS[3], SENTINELS[4], SENTINELS[5], SENTINELS[6], ]) elif tag == "BAR_START": steps.append([ STEP_BAR_START, SENTINELS[1], SENTINELS[2], SENTINELS[3], 0, cur_voice, SENTINELS[6], ]) cur_pos = 0 elif tag == "BAR_END": steps.append([ STEP_BAR_END, SENTINELS[1], SENTINELS[2], SENTINELS[3], SENTINELS[4], cur_voice, SENTINELS[6], ]) elif tag == "CHORD_START": steps.append([ STEP_CHORD_START, SENTINELS[1], SENTINELS[2], SENTINELS[3], cur_pos, cur_voice, SENTINELS[6], ]) elif tag == "CHORD_END": steps.append([ STEP_CHORD_END, SENTINELS[1], SENTINELS[2], SENTINELS[3], cur_pos, cur_voice, SENTINELS[6], ]) # PAD, EOS, REST: ignored; metadata-only are skipped quietly elif kind == "voice": cur_voice = int(payload) elif kind == "pos": cur_pos = int(payload) elif kind == "vel": pending_vel = int(payload) elif kind == "role": pending_role = ROLE_NAME_TO_AUX.get(payload, SENTINELS[6]) elif kind == "pitch": # Look ahead for the duration token bound to this pitch. midi = int(payload) dur_bin = SENTINELS[2] j = i + 1 while j < len(ids_1d): nname = ID2TOKEN.get(ids_1d[j], "PAD") nkind, npay = _classify(nname) if nkind == "dur": dur_bin = int(npay) break if nkind in ("pitch", "vel", "voice", "pos", "ts", "tempo", "meter", "key", "role", "root", "dens", "reg", "ref", "pedal", "cc", "pb", "capseg"): break j += 1 steps.append([ STEP_NOTE, midi - PITCH_MIN, dur_bin, pending_vel, cur_pos, cur_voice, pending_role, ]) pending_role = SENTINELS[6] elif kind == "pedal": # PEDAL__ parts = payload.split("_") ptype = PEDAL_TYPE_TO_AUX.get(parts[1], 0) stype = STEP_PEDAL_DOWN if parts[2] == "DOWN" else STEP_PEDAL_UP steps.append([ stype, SENTINELS[1], SENTINELS[2], SENTINELS[3], cur_pos, cur_voice, ptype, ]) elif kind == "cc": # CC__ parts = payload.split("_") cname = parts[1] cbin = int(parts[2]) aux = CC_NAME_TO_AUX_BASE.get(cname, 0) + cbin steps.append([ STEP_CC, SENTINELS[1], SENTINELS[2], SENTINELS[3], cur_pos, cur_voice, aux, ]) elif kind == "pb": steps.append([ STEP_PB, SENTINELS[1], SENTINELS[2], SENTINELS[3], cur_pos, cur_voice, int(payload[3:]) if payload[3:].isdigit() else 0, ]) # Other metadata kinds (tempo/meter/key/ref/root/dens/reg/capseg/ts) # update running state implicitly via the 1D decoder; the compound # stream omits them since the model can derive them from absolute # position + KEY/METER/TEMPO that the 1D decoder also handles. i += 1 return steps # --- Decode ------------------------------------------------------------------ def decode_compound(steps: Sequence[Sequence[int]]) -> pretty_midi.PrettyMIDI: """Reconstruct a PrettyMIDI from a compound step list. Approximates timing using the assumption that each BAR_START corresponds to one bar at 4/4 + 120 BPM (i.e., 2.0 s per bar). Pitches and instrument routing are exact; precise timing requires the 1D stream. """ pm = pretty_midi.PrettyMIDI(initial_tempo=120.0) instruments: dict = {} bar_duration = 2.0 # 4/4 at 120 BPM bar_idx = 0 bar_start_time = 0.0 cur_time = 0.0 def get_inst(v: int) -> pretty_midi.Instrument: if v not in instruments: from tokenizer import FAMILY_PROGRAMS if v == DRUM_VOICE: instruments[v] = pretty_midi.Instrument(program=0, is_drum=True, name="drums") else: instruments[v] = pretty_midi.Instrument( program=FAMILY_PROGRAMS.get(v, 0), name=f"family_{v}", ) return instruments[v] for step in steps: stype = int(step[0]) if stype == STEP_BAR_START: if bar_idx > 0: bar_start_time += bar_duration cur_time = bar_start_time bar_idx += 1 elif stype == STEP_NOTE: midi_off = int(step[1]) dur_bin = int(step[2]) vel_bin = int(step[3]) pos_bin = int(step[4]) voice = int(step[5]) if voice == SENTINELS[5]: continue cur_time = bar_start_time + (pos_bin / N_POS_BINS) * bar_duration if pos_bin != SENTINELS[4] else cur_time duration = _bin_center(dur_bin) if dur_bin != SENTINELS[2] else 0.25 velocity = _vel_center(vel_bin) if vel_bin != SENTINELS[3] else 64 pitch = midi_off + PITCH_MIN note = pretty_midi.Note( velocity=int(velocity), pitch=int(pitch), start=cur_time, end=cur_time + max(duration, 0.01), ) get_inst(voice).notes.append(note) elif stype in (STEP_PEDAL_DOWN, STEP_PEDAL_UP): voice = int(step[5]) ptype_aux = int(step[6]) if voice != SENTINELS[5] and ptype_aux < N_PEDAL_TYPES: ptype_name = list(PEDAL_TYPE_TO_AUX.keys())[ptype_aux] cc_num = next(k for k, v in PEDAL_CC_NUMBERS.items() if v == ptype_name) value = 127 if stype == STEP_PEDAL_DOWN else 0 get_inst(voice).control_changes.append( pretty_midi.ControlChange(number=cc_num, value=value, time=cur_time) ) elif stype == STEP_CC: voice = int(step[5]) aux = int(step[6]) if voice != SENTINELS[5] and aux < N_CC_AUX: cc_idx = aux // N_CC_BINS cc_bin = aux % N_CC_BINS cc_name = list(CC_TYPES.values())[cc_idx] cc_num = next(k for k, v in CC_TYPES.items() if v == cc_name) get_inst(voice).control_changes.append( pretty_midi.ControlChange( number=cc_num, value=_cc_center(cc_bin), time=cur_time ) ) elif stype == STEP_PB: voice = int(step[5]) if voice != SENTINELS[5]: get_inst(voice).pitch_bends.append( pretty_midi.PitchBend(pitch=_pb_center(int(step[6])), time=cur_time) ) # CHORD_*, BAR_END, BOS, EOS, PAD: structural, no PrettyMIDI side effect for v in sorted(instruments): if instruments[v].notes or instruments[v].control_changes or instruments[v].pitch_bends: pm.instruments.append(instruments[v]) return pm # --- Smoke test -------------------------------------------------------------- if __name__ == "__main__": print(f"Compound axes: {AXIS_NAMES}") print(f"Axis sizes: {AXIS_SIZES}") print(f"Sentinels: {SENTINELS}") pm = pretty_midi.PrettyMIDI(initial_tempo=120.0) inst = pretty_midi.Instrument(program=0) t = 0.0 for p in [60, 64, 67, 72]: inst.notes.append(pretty_midi.Note(velocity=80, pitch=p, start=t, end=t + 0.5)) t += 0.5 pm.instruments.append(inst) steps = encode_compound(pm) print(f"\n{len(steps)} compound steps from {len(inst.notes)} notes:") for s in steps[:8]: print(f" {s}") pm2 = decode_compound(steps) n_recon = sum(len(i.notes) for i in pm2.instruments) print(f"\nDecoded {n_recon} notes (orig {len(inst.notes)})")