# Coda / src/bpe.py
# Prajanya Gupta
# initial deploy
# 6b7b403
"""Byte-pair encoding on top of the base MIDI tokenizer.
Trains greedy pair merges over sequences of base token ids and
exposes apply/unapply for use during dataset construction and decode.
Merge id space
--------------
Base ids occupy [0, base_vocab_size). Merge i (0-indexed) is assigned id
base_vocab_size + i, so the BPE-aware vocab size is base_vocab_size + n_merges.
Boundary protection
-------------------
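Pairs where either side is in ``no_merge_ids`` are never merged. The set
returned by ``default_no_merge_ids()`` (also used by the CLI) holds PAD, EOS
and the phrase/bar start/end markers, so the model still sees explicit
sequence terminators and structural boundaries.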
API
---
train_bpe(streams, n_merges, base_vocab_size, no_merge_ids) -> merges
apply_bpe(ids, merges) -> List[int]
unapply_bpe(ids, merges) -> List[int]
save(merges, path), load(path)
"""
from __future__ import annotations
import json
import random
from collections import Counter
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
# A merge table is a list of (left, right, merged_id) entries in
# learning order. The order matters: earlier merges may participate in
# later merges. Stored as a JSON list of [left, right, merged_id].
Merge = Tuple[int, int, int]
def default_no_merge_ids() -> set:
    """Return the structural token ids that BPE must leave unmerged.

    The set holds PAD, EOS and the phrase/bar start/end markers taken from
    the ``tokenizer`` module, so sequence, phrase and bar boundaries remain
    visible as individual tokens.
    """
    # Function-scope import: ``tokenizer`` is only resolved when this is called.
    from tokenizer import BAR_END, BAR_START, EOS, PAD, PHRASE_END, PHRASE_START

    structural_ids = (PAD, EOS, PHRASE_START, PHRASE_END, BAR_START, BAR_END)
    return set(structural_ids)
def _count_pairs(
    streams: Sequence[Sequence[int]],
    no_merge_ids: set,
) -> Counter:
    """Count adjacent id pairs over every stream.

    A pair is left out when either of its ids is in ``no_merge_ids``.
    Overlapping occurrences are all counted, e.g. ``[x, x, x]`` contributes
    two ``(x, x)`` pairs.
    """
    tally: Counter = Counter()
    for stream in streams:
        # Pairs are fed in stream order, so first-seen order in the counter
        # matches a plain left-to-right scan.
        tally.update(
            (left, right)
            for left, right in zip(stream, stream[1:])
            if left not in no_merge_ids and right not in no_merge_ids
        )
    return tally
def _replace_pair(
    seq: Sequence[int],
    pair: Tuple[int, int],
    new_id: int,
    dropout: float = 0.0,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Return a copy of ``seq`` with adjacent ``pair`` tokens fused into ``new_id``.

    The scan runs left to right and a fused pair consumes both tokens, so
    matches never overlap. When ``dropout > 0`` every match is skipped
    independently with that probability and its first token is emitted
    unchanged (BPE-dropout, Provilkov et al. 2020). Randomness comes from
    ``rng`` when given, otherwise from the module-level ``random``.
    """
    left, right = pair
    sampler = rng or random
    use_dropout = dropout > 0.0
    last = len(seq) - 1

    merged: List[int] = []
    pos = 0
    while pos <= last:
        token = seq[pos]
        is_match = pos < last and token == left and seq[pos + 1] == right
        # A random number is drawn only for real matches with dropout enabled.
        if is_match and not (use_dropout and sampler.random() < dropout):
            merged.append(new_id)
            pos += 2
            continue
        merged.append(token)
        pos += 1
    return merged
def train_bpe(
    streams: Sequence[Sequence[int]],
    n_merges: int,
    base_vocab_size: int,
    no_merge_ids: Iterable[int] = (),
    min_pair_count: int = 2,
) -> List[Merge]:
    """Learn up to ``n_merges`` greedy pair merges from base-id sequences.

    Each round counts adjacent pairs (ignoring any pair that touches
    ``no_merge_ids``), takes the most frequent one, records it under the
    next free id (ids start at ``base_vocab_size``) and rewrites the working
    corpus with that merge applied. Training stops early when no countable
    pair is left or the best pair occurs fewer than ``min_pair_count`` times.

    The input streams are copied; the caller's sequences are not modified.
    Returns the merges in the order they were learned.
    """
    protected = set(no_merge_ids)
    corpus: List[List[int]] = [list(stream) for stream in streams]
    learned: List[Merge] = []

    for step in range(n_merges):
        pair_counts = _count_pairs(corpus, protected)
        if not pair_counts:
            break

        top_pair, top_count = pair_counts.most_common(1)[0]
        if top_count < min_pair_count:
            break

        new_id = base_vocab_size + step
        left, right = top_pair
        learned.append((left, right, new_id))
        corpus = [_replace_pair(stream, top_pair, new_id) for stream in corpus]

    return learned
def apply_bpe(
    ids: Sequence[int],
    merges: Sequence[Merge],
    dropout: float = 0.0,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Encode ``ids`` by replaying ``merges`` in the order they were learned.

    One full pass over the sequence is made per merge, i.e. O(M * N) work
    per stream, which is acceptable for offline use.

    ``dropout`` turns on BPE-dropout: every merge candidate is skipped at
    random with this probability, so the same base sequence can yield
    several segmentations. Use 0.0 at inference time and roughly 0.1 during
    training. ``rng`` is forwarded to the pair-replacement helper.
    """
    encoded = list(ids)
    for merge in merges:
        left, right, merged_id = merge
        encoded = _replace_pair(
            encoded,
            (left, right),
            merged_id,
            dropout=dropout,
            rng=rng,
        )
    return encoded
def unapply_bpe(ids: Sequence[int], merges: Sequence[Merge]) -> List[int]:
    """Expand merged ids back to the ids they were built from.

    Every id that appears as a ``merged_id`` in ``merges`` is replaced by its
    ``(left, right)`` components, recursively, until only ids without a merge
    entry remain. Ids not present in the table pass through unchanged and the
    relative order of tokens is preserved.

    The expansion is done in a single left-to-right pass with an explicit
    stack, so the cost is proportional to the expanded length rather than
    to (nesting depth x sequence length).

    NOTE: a table in which an id expands, directly or transitively, to
    itself has no finite expansion and this function will not terminate on
    it; tables produced by ``train_bpe`` from valid base ids do not contain
    such entries.
    """
    expand: Dict[int, Tuple[int, int]] = {m[2]: (m[0], m[1]) for m in merges}
    if not expand:
        return list(ids)

    out: List[int] = []
    pending: List[int] = []
    for tid in ids:
        pending.append(tid)
        while pending:
            current = pending.pop()
            parts = expand.get(current)
            if parts is None:
                out.append(current)
            else:
                # Push right first so that left is popped, and therefore
                # emitted, first.
                pending.append(parts[1])
                pending.append(parts[0])
    return out
def save(merges: Sequence[Merge], path: Path) -> None:
    """Write ``merges`` to ``path`` as a JSON list of ``[left, right, merged_id]``.

    Missing parent directories are created; an existing file is overwritten.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    rows = [list(entry) for entry in merges]
    target.write_text(json.dumps(rows))
def load(path: Path) -> List[Merge]:
    """Read a merge table previously written as JSON.

    Returns an empty list when ``path`` does not exist. Every entry is
    unpacked as ``[left, right, merged_id]`` and each field is coerced to
    ``int``.
    """
    source = Path(path)
    if not source.exists():
        return []

    merges: List[Merge] = []
    for left, right, merged_id in json.loads(source.read_text()):
        merges.append((int(left), int(right), int(merged_id)))
    return merges
def effective_vocab_size(base_vocab_size: int, merges: Sequence[Merge]) -> int:
    """Return the vocabulary size once every merge has its own id."""
    n_merges = len(merges)
    return base_vocab_size + n_merges
# --- CLI ----------------------------------------------------------------------
if __name__ == "__main__":
    # Command-line entry point: encode every MIDI file under --sample-dir,
    # train a merge table on the encoded streams, report the compression
    # achieved and write the table to --out.
    import argparse
    import sys
    from pathlib import Path as _P

    # Put this file's directory on sys.path so the sibling ``tokenizer``
    # module resolves when the file is run directly as a script.
    _SRC = _P(__file__).resolve().parent
    if str(_SRC) not in sys.path:
        sys.path.insert(0, str(_SRC))

    import pretty_midi  # noqa: E402
    from tokenizer import VOCAB_SIZE, encode  # noqa: E402

    parser = argparse.ArgumentParser(description="Train BPE on encoded MIDI.")
    parser.add_argument(
        "--sample-dir",
        type=str,
        default=str(_SRC.parent / "data" / "gigamidi" / "sample"),
    )
    parser.add_argument("--n-merges", type=int, default=2000)
    parser.add_argument(
        "--out",
        type=str,
        default=str(_SRC.parent / "data" / "bpe" / "merges.json"),
    )
    args = parser.parse_args()

    sample_dir = _P(args.sample_dir)
    midi_paths = (
        sorted(sample_dir.rglob("*.mid"))
        + sorted(sample_dir.rglob("*.midi"))
    )
    if not midi_paths:
        raise SystemExit(f"No MIDI files found under {sample_dir}")

    streams: List[List[int]] = []
    n_failed = 0
    for p in midi_paths:
        try:
            pm = pretty_midi.PrettyMIDI(str(p))
            streams.append(encode(pm))
        except Exception as exc:
            # Best effort: one unreadable file must not abort the run, but
            # say which file was dropped and why instead of only counting it.
            n_failed += 1
            print(f"[bpe] skipped {p}: {exc!r}", file=sys.stderr)

    n_base_tokens = sum(len(s) for s in streams)
    print(
        f"[bpe] streams={len(streams)} failed={n_failed} "
        f"base_tokens={n_base_tokens}"
    )
    if not streams:
        # Without this guard an empty merge table would be written below,
        # overwriting any existing file at --out.
        raise SystemExit(
            f"No MIDI files under {sample_dir} could be encoded; "
            "not writing a merge table"
        )

    merges = train_bpe(
        streams=streams,
        n_merges=args.n_merges,
        base_vocab_size=VOCAB_SIZE,
        # Same protected set the library exposes, rather than a second
        # hand-maintained copy of the list.
        no_merge_ids=default_no_merge_ids(),
    )

    after = sum(len(apply_bpe(s, merges)) for s in streams)
    print(
        f"[bpe] learned {len(merges)} merges; "
        f"compression: {n_base_tokens} -> {after} "
        f"({(1 - after / max(n_base_tokens, 1)) * 100:.1f}% fewer tokens)"
    )

    out_path = _P(args.out)
    save(merges, out_path)
    print(f"[bpe] saved -> {out_path}")