Coda / src /compound.py
Prajanya Gupta
initial deploy
6b7b403
"""Octuple/CPWord-style compound tokenizer for bach-gpt.
Each event becomes a single "step" with N parallel feature ids instead of
N consecutive tokens. The compound encoder is a thin post-processor over
the 1D ``tokenizer.encode`` stream, so all structural logic (chord
brackets, voice-role, bar headers, REF, key/meter/tempo, etc.) is shared.
Step layout
-----------
A compound step is a tuple of ``N_AXES`` feature ids, one per axis:
(step_type, pitch, duration, velocity, position, voice, aux)
Step types
----------
0 PAD all other axes = sentinel
1 BOS / PHRASE_START
2 EOS / PHRASE_END
3 BAR_START pos = 0; voice = current voice; aux = sentinel (bar-header root/density/register tokens are not carried into the compound stream)
4 BAR_END
5 NOTE pitch/dur/vel filled; voice + position; aux = role
6 PEDAL_DOWN voice + position; aux = pedal_type (0/1/2 = SUS/SOS/SFT)
7 PEDAL_UP voice + position; aux = pedal_type
8 CC_CHANGE voice + position; aux = CC_type * N_CC_BINS + bin
9 PITCH_BEND voice + position; aux = pb_bin
10 CHORD_START structural; pos optional
11 CHORD_END structural
Sentinels
---------
A per-axis "no value" id is the last entry of each axis. Models embed
this just like any other id; the sentinel just means "this axis doesn't
apply at this step type."
"""
from __future__ import annotations
from typing import List, Sequence, Tuple
import pretty_midi
from tokenizer import (
BAR_END,
BAR_START,
CHORD_END,
CHORD_START,
DRUM_VOICE,
EOS,
ID2TOKEN,
METERS,
METER_QUARTERS,
N_CC_BINS,
N_PB_BINS,
N_POS_BINS,
N_VEL_BINS,
N_VOICE_BINS,
PHRASE_END,
PHRASE_START,
PITCH_MAX,
PITCH_MIN,
REF_NAMES,
ROLES,
ROOT_NAMES,
DENS_NAMES,
REG_NAMES,
PB_NAMES,
CC_NAMES,
PEDAL_NAMES,
CC_TYPES,
PEDAL_CC_NUMBERS,
TOKEN2ID,
VOCAB_SIZE,
_bin_center,
_cc_center,
_pb_center,
_tempo_center,
_vel_center,
decode as decode_1d,
encode as encode_1d,
)
# --- Step type vocabulary ----------------------------------------------------
# One id per compound step kind. These ids are written into axis 0 of every
# compound step, so renumbering them changes the meaning of existing
# encodings; append new kinds at the end only.
(
    STEP_PAD,
    STEP_BOS,
    STEP_EOS,
    STEP_BAR_START,
    STEP_BAR_END,
    STEP_NOTE,
    STEP_PEDAL_DOWN,
    STEP_PEDAL_UP,
    STEP_CC,
    STEP_PB,
    STEP_CHORD_START,
    STEP_CHORD_END,
) = range(12)
# Size of the step-type axis (ids 0..STEP_CHORD_END).
N_STEP_TYPES = STEP_CHORD_END + 1
# Per-axis cardinalities. Every axis reserves its LAST id as a sentinel that
# means "this axis does not apply to this step type".
N_PITCH_AXIS = (PITCH_MAX - PITCH_MIN + 1) + 1  # one id per pitch + sentinel
N_DUR_AXIS = 32 + 1  # 32 duration bins + sentinel
N_VEL_AXIS = N_VEL_BINS + 1
N_POS_AXIS = N_POS_BINS + 1
N_VOICE_AXIS = N_VOICE_BINS + 1

# The aux axis is shared (multiplexed) by several step types, so it is
# sized for the widest payload any of them needs.
N_PEDAL_TYPES = 3
N_CC_AUX = len(CC_TYPES) * N_CC_BINS
N_AUX = max(len(ROLES) + 1, N_PEDAL_TYPES, N_CC_AUX, N_PB_BINS)
N_AUX_AXIS = N_AUX + 1

# (name, size) for each axis, in the order the axes appear inside a step.
_AXIS_SPEC = (
    ("step", N_STEP_TYPES),
    ("pitch", N_PITCH_AXIS),
    ("dur", N_DUR_AXIS),
    ("vel", N_VEL_AXIS),
    ("pos", N_POS_AXIS),
    ("voice", N_VOICE_AXIS),
    ("aux", N_AUX_AXIS),
)
AXIS_NAMES = [axis_name for axis_name, _ in _AXIS_SPEC]
AXIS_SIZES = [axis_size for _, axis_size in _AXIS_SPEC]
N_AXES = len(AXIS_SIZES)
# The sentinel id of each axis is its largest valid index.
SENTINELS = [axis_size - 1 for axis_size in AXIS_SIZES]
def empty_step() -> List[int]:
    """Build a step whose every axis holds that axis's sentinel id.

    A fresh list is returned on each call, so callers may mutate it
    without affecting ``SENTINELS``.
    """
    return [*SENTINELS]
# --- Helpers -----------------------------------------------------------------
# Aux-axis payload for pedal steps (pedal type) and for note steps (voice
# role). Key order is significant: the decoder maps a pedal aux id back to
# its name by position in PEDAL_TYPE_TO_AUX.
PEDAL_TYPE_TO_AUX = {name: idx for idx, name in enumerate(("SUS", "SOS", "SFT"))}
ROLE_NAME_TO_AUX = {
    name: idx for idx, name in enumerate(("ROLE_BASS", "ROLE_INNER", "ROLE_TOP"))
}
# First aux id of each CC type: the k-th CC type (in CC_TYPES value order)
# owns aux ids [k * N_CC_BINS, (k + 1) * N_CC_BINS), and a CC step's aux is
# this base plus its value bin.
CC_NAME_TO_AUX_BASE = dict(
    zip(CC_TYPES.values(), (slot * N_CC_BINS for slot in range(len(CC_TYPES))))
)
# 1D token names treated as purely structural, matched verbatim.
_STRUCT_TOKENS = frozenset((
    "PAD", "EOS", "PHRASE_START", "PHRASE_END",
    "BAR_START", "BAR_END", "REST",
    "CHORD_START", "CHORD_END",
))

# (prefix, kind) for tokens spelled ``<prefix><digits>``; the payload is the
# digit string after the prefix. Tried in this order.
_NUMERIC_PREFIX_KINDS = (
    ("VC", "voice"),
    ("POS", "pos"),
    ("TS", "ts"),
    ("D", "dur"),
    ("V", "vel"),
    ("T", "tempo"),
    ("KEY_", "key"),
    ("CAP_SEG_", "capseg"),
    ("P", "pitch"),
)


def _classify(name: str) -> Tuple[str, str]:
    """Map a 1D token name to a ``(kind, payload)`` pair.

    Resolution order: the structural set, then the named vocabularies
    imported from ``tokenizer`` (payload = the full name), then the
    ``<prefix><digits>`` families (payload = the digits). Anything that
    matches none of these falls back to ``("struct", name)``.
    """
    if name in _STRUCT_TOKENS:
        return ("struct", name)

    named_vocabularies = (
        (ROLES, "role"),
        (METERS, "meter"),
        (REF_NAMES, "ref"),
        (PEDAL_NAMES, "pedal"),
        (CC_NAMES, "cc"),
        (PB_NAMES, "pb"),
        (ROOT_NAMES, "root"),
        (DENS_NAMES, "dens"),
        (REG_NAMES, "reg"),
    )
    for vocabulary, kind in named_vocabularies:
        if name in vocabulary:
            return (kind, name)

    for prefix, kind in _NUMERIC_PREFIX_KINDS:
        digits = name[len(prefix):]
        if name.startswith(prefix) and digits.isdigit():
            return (kind, digits)

    return ("struct", name)
# --- Encode ------------------------------------------------------------------
def _duration_after(tokens: Sequence[Tuple[str, str]], start: int, default: int) -> int:
    """Return the duration bin bound to a pitch token at index ``start - 1``.

    Structural tokens are skipped; the first non-structural token decides.
    A ``dur`` token yields its bin. Any other token, or running off the end
    of the stream, means the pitch carries no duration, and ``default`` is
    returned.
    """
    for k in range(start, len(tokens)):
        kind, payload = tokens[k]
        if kind == "struct":
            continue
        return int(payload) if kind == "dur" else default
    return default


def encode_compound(pm: pretty_midi.PrettyMIDI) -> List[List[int]]:
    """Encode a PrettyMIDI into compound steps by post-processing the 1D stream.

    The ``encode_1d`` token stream is classified once and walked left to
    right:

    * PHRASE_START/PHRASE_END, BAR_START/BAR_END, CHORD_START/CHORD_END,
      pitch, pedal, CC and pitch-bend tokens each emit one compound step.
    * VC / POS / V / ROLE tokens only update the running voice, position,
      velocity and pending role. The role is consumed by the next note.
    * Every other token (durations are read via lookahead from the pitch;
      tempo, meter, key, time signature, REF, bar headers, CAP_SEG, PAD,
      EOS, REST) emits nothing.

    Returns:
        One ``N_AXES``-long list of ids per compound step.
    """
    tokens = [_classify(ID2TOKEN.get(tok_id, "PAD")) for tok_id in encode_1d(pm)]
    _, no_pitch, no_dur, no_vel, no_pos, no_voice, no_aux = SENTINELS

    steps: List[List[int]] = []

    def emit(step_type: int, *, pitch: int = no_pitch, dur: int = no_dur,
             vel: int = no_vel, pos: int = no_pos, voice: int = no_voice,
             aux: int = no_aux) -> None:
        # Axes not supplied by the caller default to their sentinel.
        steps.append([step_type, pitch, dur, vel, pos, voice, aux])

    voice_now, pos_now = no_voice, no_pos
    vel_now, role_now = no_vel, no_aux

    for idx, (kind, payload) in enumerate(tokens):
        if kind == "struct":
            if payload == "PHRASE_START":
                emit(STEP_BOS)
            elif payload == "PHRASE_END":
                emit(STEP_EOS)
            elif payload == "BAR_START":
                # A bar always opens at position 0.
                emit(STEP_BAR_START, pos=0, voice=voice_now)
                pos_now = 0
            elif payload == "BAR_END":
                emit(STEP_BAR_END, voice=voice_now)
            elif payload == "CHORD_START":
                emit(STEP_CHORD_START, pos=pos_now, voice=voice_now)
            elif payload == "CHORD_END":
                emit(STEP_CHORD_END, pos=pos_now, voice=voice_now)
            # PAD / EOS / REST and unrecognised names emit nothing.
        elif kind == "voice":
            voice_now = int(payload)
        elif kind == "pos":
            pos_now = int(payload)
        elif kind == "vel":
            vel_now = int(payload)
        elif kind == "role":
            role_now = ROLE_NAME_TO_AUX.get(payload, no_aux)
        elif kind == "pitch":
            emit(
                STEP_NOTE,
                pitch=int(payload) - PITCH_MIN,
                dur=_duration_after(tokens, idx + 1, no_dur),
                vel=vel_now,
                pos=pos_now,
                voice=voice_now,
                aux=role_now,
            )
            # A role annotation applies to a single note only.
            role_now = no_aux
        elif kind == "pedal":
            # Payload is spelled PEDAL_<TYPE>_<STATE>.
            fields = payload.split("_")
            pedal_aux = PEDAL_TYPE_TO_AUX.get(fields[1], 0)
            step_type = STEP_PEDAL_DOWN if fields[2] == "DOWN" else STEP_PEDAL_UP
            emit(step_type, pos=pos_now, voice=voice_now, aux=pedal_aux)
        elif kind == "cc":
            # Payload is spelled CC_<NAME>_<BIN>.
            fields = payload.split("_")
            cc_aux = CC_NAME_TO_AUX_BASE.get(fields[1], 0) + int(fields[2])
            emit(STEP_CC, pos=pos_now, voice=voice_now, aux=cc_aux)
        elif kind == "pb":
            bend = payload[3:]
            emit(STEP_PB, pos=pos_now, voice=voice_now,
                 aux=int(bend) if bend.isdigit() else 0)
        # dur, tempo, meter, key, ts, ref, root, dens, reg and capseg tokens
        # produce no compound step.
    return steps
# --- Decode ------------------------------------------------------------------
def decode_compound(steps: Sequence[Sequence[int]]) -> pretty_midi.PrettyMIDI:
    """Reconstruct a PrettyMIDI from a compound step list.

    Timing is approximate. Every bar is assumed to be 4/4 at 120 BPM
    (2.0 s). The first BAR_START opens a bar at t = 0, and each later
    BAR_START advances the bar clock by 2.0 s.

    Note, pedal, CC and pitch-bend steps are each placed at
    ``bar_start + pos / N_POS_BINS * bar_duration``, using the step's own
    position axis. A step whose position is the sentinel reuses the
    previous event time. Pitches and voice-to-instrument routing are exact;
    precise timing requires the 1D stream.

    Steps that cannot be realised are skipped:

    * any event step with a sentinel voice;
    * notes whose pitch id is the sentinel or otherwise outside the pitch
      axis;
    * pedal, CC and pitch-bend steps whose aux id lies outside that step
      type's payload range;
    * pedal types with no entry in ``PEDAL_CC_NUMBERS``.

    Args:
        steps: Sequence of ``N_AXES``-long id sequences, e.g. the output of
            ``encode_compound`` or steps sampled from a model.

    Returns:
        A PrettyMIDI holding, in sorted voice order, one instrument per voice
        that received at least one note, control change or pitch bend.
    """
    pm = pretty_midi.PrettyMIDI(initial_tempo=120.0)
    instruments: dict = {}
    bar_duration = 2.0  # 4/4 at 120 BPM
    bar_idx = 0
    bar_start_time = 0.0
    cur_time = 0.0

    no_pitch, no_dur, no_vel, no_pos, no_voice = SENTINELS[1:6]

    # Lookup tables built once instead of on every step.
    pedal_names = list(PEDAL_TYPE_TO_AUX)  # pedal aux id -> "SUS"/"SOS"/"SFT"
    pedal_cc_number: dict = {}
    for cc_num, pedal_name in PEDAL_CC_NUMBERS.items():
        pedal_cc_number.setdefault(pedal_name, cc_num)  # keep first listed number
    cc_names = list(CC_TYPES.values())  # CC slot -> CC name
    cc_number_by_name: dict = {}
    for cc_num, cc_name in CC_TYPES.items():
        cc_number_by_name.setdefault(cc_name, cc_num)  # keep first listed number

    def get_inst(v: int) -> pretty_midi.Instrument:
        """Return (creating on first use) the instrument for voice ``v``."""
        if v not in instruments:
            from tokenizer import FAMILY_PROGRAMS
            if v == DRUM_VOICE:
                instruments[v] = pretty_midi.Instrument(program=0, is_drum=True, name="drums")
            else:
                instruments[v] = pretty_midi.Instrument(
                    program=FAMILY_PROGRAMS.get(v, 0),
                    name=f"family_{v}",
                )
        return instruments[v]

    def event_time(pos_bin: int) -> float:
        """Absolute onset of a position bin in the current bar."""
        if pos_bin == no_pos:
            return cur_time
        return bar_start_time + (pos_bin / N_POS_BINS) * bar_duration

    event_types = (STEP_NOTE, STEP_PEDAL_DOWN, STEP_PEDAL_UP, STEP_CC, STEP_PB)

    for step in steps:
        stype = int(step[0])
        if stype == STEP_BAR_START:
            if bar_idx > 0:
                bar_start_time += bar_duration
            cur_time = bar_start_time
            bar_idx += 1
            continue
        if stype not in event_types:
            # CHORD_*, BAR_END, BOS, EOS, PAD: no PrettyMIDI side effect.
            continue

        voice = int(step[5])
        if voice == no_voice:
            continue
        # Every event step carries its own position, so control events are
        # timed from it rather than inheriting the last note's onset.
        cur_time = event_time(int(step[4]))

        if stype == STEP_NOTE:
            midi_off = int(step[1])
            if not 0 <= midi_off < no_pitch:
                continue  # sentinel / out-of-vocabulary pitch id
            dur_bin = int(step[2])
            vel_bin = int(step[3])
            # NOTE(review): _bin_center's result is used directly as seconds
            # here; confirm the unit against tokenizer.
            duration = _bin_center(dur_bin) if dur_bin != no_dur else 0.25
            velocity = _vel_center(vel_bin) if vel_bin != no_vel else 64
            get_inst(voice).notes.append(
                pretty_midi.Note(
                    velocity=int(velocity),
                    pitch=int(midi_off + PITCH_MIN),
                    start=cur_time,
                    end=cur_time + max(duration, 0.01),
                )
            )
        elif stype in (STEP_PEDAL_DOWN, STEP_PEDAL_UP):
            ptype_aux = int(step[6])
            if not 0 <= ptype_aux < N_PEDAL_TYPES:
                continue
            cc_num = pedal_cc_number.get(pedal_names[ptype_aux])
            if cc_num is None:
                continue  # no MIDI controller mapped to this pedal type
            value = 127 if stype == STEP_PEDAL_DOWN else 0
            get_inst(voice).control_changes.append(
                pretty_midi.ControlChange(number=cc_num, value=value, time=cur_time)
            )
        elif stype == STEP_CC:
            aux = int(step[6])
            if not 0 <= aux < N_CC_AUX:
                continue
            cc_idx, cc_bin = divmod(aux, N_CC_BINS)
            cc_num = cc_number_by_name[cc_names[cc_idx]]
            get_inst(voice).control_changes.append(
                pretty_midi.ControlChange(
                    number=cc_num, value=_cc_center(cc_bin), time=cur_time
                )
            )
        else:  # STEP_PB
            aux = int(step[6])
            if not 0 <= aux < N_PB_BINS:
                continue  # sentinel / out-of-range bend bin
            get_inst(voice).pitch_bends.append(
                pretty_midi.PitchBend(pitch=_pb_center(aux), time=cur_time)
            )

    for v in sorted(instruments):
        inst = instruments[v]
        if inst.notes or inst.control_changes or inst.pitch_bends:
            pm.instruments.append(inst)
    return pm
# --- Smoke test --------------------------------------------------------------
if __name__ == "__main__":
    print(f"Compound axes: {AXIS_NAMES}")
    print(f"Axis sizes: {AXIS_SIZES}")
    print(f"Sentinels: {SENTINELS}")

    # Four back-to-back half-second notes on a single piano track.
    demo = pretty_midi.PrettyMIDI(initial_tempo=120.0)
    piano = pretty_midi.Instrument(program=0)
    piano.notes.extend(
        pretty_midi.Note(velocity=80, pitch=pitch, start=k * 0.5, end=k * 0.5 + 0.5)
        for k, pitch in enumerate([60, 64, 67, 72])
    )
    demo.instruments.append(piano)

    encoded = encode_compound(demo)
    print(f"\n{len(encoded)} compound steps from {len(piano.notes)} notes:")
    for step in encoded[:8]:
        print(f" {step}")

    rebuilt = decode_compound(encoded)
    n_recon = sum(len(track.notes) for track in rebuilt.instruments)
    print(f"\nDecoded {n_recon} notes (orig {len(piano.notes)})")