Spaces:
Sleeping
Sleeping
| """Byte-pair encoding on top of the base MIDI tokenizer. | |
| Trains greedy pair merges over sequences of base token ids and | |
| exposes apply/unapply for use during dataset construction and decode. | |
| Merge id space | |
| -------------- | |
| Base ids occupy [0, base_vocab_size). Merge i (0-indexed) is assigned id | |
| base_vocab_size + i, so the BPE-aware vocab size is base_vocab_size + n_merges. | |
| Boundary protection | |
| ------------------- | |
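Pairs where either side is in ``no_merge_ids`` are never merged. With the
ids from ``default_no_merge_ids()`` (PAD, EOS and the phrase/bar start/end
markers) the model still sees explicit sequence terminators and structural
boundaries.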
| API | |
| --- | |
| train_bpe(streams, n_merges, base_vocab_size, no_merge_ids) -> merges | |
| apply_bpe(ids, merges) -> List[int] | |
| unapply_bpe(ids, merges) -> List[int] | |
| save(merges, path), load(path) | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import random | |
| from collections import Counter | |
| from pathlib import Path | |
| from typing import Dict, Iterable, List, Optional, Sequence, Tuple | |
# A merge table is a list of (left, right, merged_id) entries in
# learning order. The order matters: earlier merges may participate in
# later merges. Stored as a JSON list of [left, right, merged_id].
| Merge = Tuple[int, int, int] | |
def default_no_merge_ids() -> set:
    """Return the structural token ids that BPE must leave unmerged.

    The result contains PAD, EOS, PHRASE_START, PHRASE_END, BAR_START and
    BAR_END, so sequence, phrase and bar boundaries stay visible as
    standalone tokens. The ids are imported from ``tokenizer`` at call
    time rather than at module import time.
    """
    from tokenizer import (
        BAR_END,
        BAR_START,
        EOS,
        PAD,
        PHRASE_END,
        PHRASE_START,
    )

    protected = (PAD, EOS, PHRASE_START, PHRASE_END, BAR_START, BAR_END)
    return set(protected)
def _count_pairs(
    streams: Sequence[Sequence[int]],
    no_merge_ids: set,
) -> Counter:
    """Count adjacent token pairs across all streams.

    A pair is skipped when either of its two tokens is in ``no_merge_ids``.
    Pairs are never formed across stream boundaries. Keys are inserted in
    first-seen order.
    """
    tally: Counter = Counter()
    for stream in streams:
        # Walk positions 1..len-1 and look back one token to form each pair.
        for idx in range(1, len(stream)):
            left, right = stream[idx - 1], stream[idx]
            if left not in no_merge_ids and right not in no_merge_ids:
                tally[(left, right)] += 1
    return tally
def _replace_pair(
    seq: Sequence[int],
    pair: Tuple[int, int],
    new_id: int,
    dropout: float = 0.0,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Return a copy of ``seq`` with adjacent ``pair`` tokens merged.

    The scan runs left to right without overlap: once a pair is merged,
    both of its tokens are consumed. When ``dropout > 0`` every matching
    occurrence is skipped with that probability (BPE-dropout, Provilkov
    et al. 2020); a skipped occurrence emits only its left token and the
    scan resumes at the right one. Random draws come from ``rng`` if one
    is given, otherwise from the ``random`` module.
    """
    left, right = pair
    sampler = rng or random
    use_dropout = dropout > 0.0
    merged: List[int] = []
    pos = 0
    last = len(seq) - 1
    while pos <= last:
        token = seq[pos]
        is_match = pos < last and token == left and seq[pos + 1] == right
        # A random number is drawn only for real candidates, and only
        # when dropout is enabled.
        if is_match and not (use_dropout and sampler.random() < dropout):
            merged.append(new_id)
            pos += 2
            continue
        merged.append(token)
        pos += 1
    return merged
def train_bpe(
    streams: Sequence[Sequence[int]],
    n_merges: int,
    base_vocab_size: int,
    no_merge_ids: Iterable[int] = (),
    min_pair_count: int = 2,
) -> List[Merge]:
    """Learn up to ``n_merges`` greedy BPE merges over base-id streams.

    Each round counts adjacent pairs in the current working copy of the
    streams, takes the most frequent pair and replaces it everywhere with
    a fresh id. Ids are handed out consecutively from ``base_vocab_size``.
    Training stops early when no mergeable pair remains or when the best
    pair occurs fewer than ``min_pair_count`` times. Pairs touching an id
    in ``no_merge_ids`` are never merged. The input streams are not
    modified.

    Returns the learned ``(left, right, merged_id)`` entries in learning
    order.
    """
    protected = set(no_merge_ids)
    corpus: List[List[int]] = [list(stream) for stream in streams]
    learned: List[Merge] = []
    for step in range(n_merges):
        pair_counts = _count_pairs(corpus, protected)
        if not pair_counts:
            break
        top_pair, top_count = pair_counts.most_common(1)[0]
        if top_count < min_pair_count:
            break
        # One id per completed round, so the id is just the round index
        # offset by the base vocabulary size.
        merged_id = base_vocab_size + step
        learned.append((top_pair[0], top_pair[1], merged_id))
        corpus = [_replace_pair(stream, top_pair, merged_id) for stream in corpus]
    return learned
def apply_bpe(
    ids: Sequence[int],
    merges: Sequence[Merge],
    dropout: float = 0.0,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Encode ``ids`` by replaying every merge in learned order.

    Each merge makes one full pass over the current sequence, so the cost
    is O(M * N) for M merges and N tokens; acceptable for offline use.

    ``dropout`` turns on BPE-dropout: each merge candidate is skipped with
    this probability, so the same base sequence can yield different
    segmentations. Use 0.0 at inference time and roughly 0.1 during
    training. ``rng`` optionally supplies the random source.
    """
    tokens: List[int] = list(ids)
    for entry in merges:
        left, right, merged = entry
        tokens = _replace_pair(
            tokens,
            (left, right),
            merged,
            dropout=dropout,
            rng=rng,
        )
    return tokens
def unapply_bpe(ids: Sequence[int], merges: Sequence[Merge]) -> List[int]:
    """Expand merged ids back to the base-token sequence.

    Every id that appears as a ``merged_id`` in ``merges`` is replaced by
    its ``(left, right)`` components, recursively, until only ids without
    an expansion remain. The order of entries in ``merges`` does not
    matter for decoding. If a merged id occurs more than once in
    ``merges``, the last entry wins.

    Expansion is done in a single left-to-right pass with an explicit
    stack, so the work is proportional to the size of the output no matter
    how deeply merges are nested, and deep nesting cannot hit the
    recursion limit.

    The merge table must be acyclic (as tables produced by ``train_bpe``
    are); a table in which an id expands to itself, directly or
    indirectly, never terminates.
    """
    expand: Dict[int, Tuple[int, int]] = {m[2]: (m[0], m[1]) for m in merges}
    if not expand:
        return list(ids)
    out: List[int] = []
    for tid in ids:
        if tid not in expand:
            out.append(tid)
            continue
        pending: List[int] = [tid]
        while pending:
            current = pending.pop()
            parts = expand.get(current)
            if parts is None:
                out.append(current)
            else:
                # Push right before left so the left component is popped,
                # and therefore emitted, first.
                pending.append(parts[1])
                pending.append(parts[0])
    return out
def save(merges: Sequence[Merge], path: Path) -> None:
    """Write the merge table to ``path`` as a JSON list of lists.

    Each merge becomes one ``[left, right, merged_id]`` row. Missing
    parent directories are created; an existing file is overwritten.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    rows = [list(entry) for entry in merges]
    target.write_text(json.dumps(rows))
def load(path: Path) -> List[Merge]:
    """Read a merge table previously written as JSON.

    Returns an empty list when ``path`` does not exist. Otherwise every
    row is unpacked into exactly three values, each converted with
    ``int``, giving ``(left, right, merged_id)`` tuples in file order.
    """
    source = Path(path)
    if source.exists():
        rows = json.loads(source.read_text())
        return [(int(left), int(right), int(merged)) for left, right, merged in rows]
    return []
def effective_vocab_size(base_vocab_size: int, merges: Sequence[Merge]) -> int:
    """Return the vocabulary size once merges are included.

    This is ``base_vocab_size`` plus one id per entry in ``merges``.
    """
    n_merged_ids = len(merges)
    return base_vocab_size + n_merged_ids
| # --- CLI ---------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Command-line entry point: encode every MIDI file under --sample-dir,
    # train a merge table on the resulting streams, report the compression
    # achieved and save the table to --out.
    import argparse
    import sys
    from pathlib import Path as _P

    # Put this file's directory on sys.path before importing project modules.
    _SRC = _P(__file__).resolve().parent
    if str(_SRC) not in sys.path:
        sys.path.insert(0, str(_SRC))

    import pretty_midi  # noqa: E402
    from tokenizer import VOCAB_SIZE, encode  # noqa: E402

    parser = argparse.ArgumentParser(description="Train BPE on encoded MIDI.")
    parser.add_argument(
        "--sample-dir",
        type=str,
        default=str(_SRC.parent / "data" / "gigamidi" / "sample"),
    )
    parser.add_argument("--n-merges", type=int, default=2000)
    parser.add_argument(
        "--out",
        type=str,
        default=str(_SRC.parent / "data" / "bpe" / "merges.json"),
    )
    args = parser.parse_args()

    sample_dir = _P(args.sample_dir)
    midi_paths = (
        sorted(sample_dir.rglob("*.mid"))
        + sorted(sample_dir.rglob("*.midi"))
    )
    if not midi_paths:
        raise SystemExit(f"No MIDI files found under {sample_dir}")

    streams: List[List[int]] = []
    n_failed = 0
    for p in midi_paths:
        try:
            pm = pretty_midi.PrettyMIDI(str(p))
            streams.append(encode(pm))
        except Exception as exc:
            # Best effort: one unreadable file must not abort the run, but
            # say which file was skipped and why.
            n_failed += 1
            print(f"[bpe] skipped {p}: {exc!r}", file=sys.stderr)

    n_base_tokens = sum(len(s) for s in streams)
    print(
        f"[bpe] streams={len(streams)} failed={n_failed} "
        f"base_tokens={n_base_tokens}"
    )
    if not streams:
        # Training on nothing would overwrite --out with an empty table.
        raise SystemExit(
            f"None of the {len(midi_paths)} MIDI files could be encoded; "
            f"not writing {args.out}"
        )

    merges = train_bpe(
        streams=streams,
        n_merges=args.n_merges,
        base_vocab_size=VOCAB_SIZE,
        # Same protected ids the library exposes, kept in one place.
        no_merge_ids=default_no_merge_ids(),
    )

    after = sum(len(apply_bpe(s, merges)) for s in streams)
    print(
        f"[bpe] learned {len(merges)} merges; "
        f"compression: {n_base_tokens} -> {after} "
        f"({(1 - after / max(n_base_tokens, 1)) * 100:.1f}% fewer tokens)"
    )

    out_path = _P(args.out)
    save(merges, out_path)
    print(f"[bpe] saved -> {out_path}")