""" MIDI Tokenizer - REMI-style encoding for MIDI files Converts MIDI files to/from token sequences that can be processed by the transformer model. Supports conditioning tokens for tempo, instrument, mood, style, key, role, and section labels. """ import json from dataclasses import dataclass, field from pathlib import Path from typing import Optional import numpy as np import pretty_midi # Vocabulary ranges PITCH_RANGE = (21, 109) # Piano range A0-C8 VELOCITY_BINS = 32 DURATION_BINS = 64 POSITION_RESOLUTION = 32 # 32nd notes for rolls and finer groove MAX_BARS = 256 TEMPO_BINS = 32 TEMPO_RANGE = (40, 240) STYLE_TOKENS = [ "trap", "reggaeton", "house", "techno", "edm", "hiphop", "lofi", "ambient", "pop", "rnb", "drill", "cinematic" ] KEY_TOKENS = [ "C_MAJOR", "C_MINOR", "C#_MAJOR", "C#_MINOR", "D_MAJOR", "D_MINOR", "D#_MAJOR", "D#_MINOR", "E_MAJOR", "E_MINOR", "F_MAJOR", "F_MINOR", "F#_MAJOR", "F#_MINOR", "G_MAJOR", "G_MINOR", "G#_MAJOR", "G#_MINOR", "A_MAJOR", "A_MINOR", "A#_MAJOR", "A#_MINOR", "B_MAJOR", "B_MINOR", ] ROLE_TOKENS = ["drums", "bass", "lead", "pad", "fx", "chords", "melody"] SECTION_TOKENS = ["intro", "verse", "chorus", "bridge", "drop", "breakdown", "outro"] @dataclass class TokenizerConfig: """Configuration for MIDI tokenizer.""" pitch_range: tuple = PITCH_RANGE velocity_bins: int = VELOCITY_BINS duration_bins: int = DURATION_BINS position_resolution: int = POSITION_RESOLUTION max_bars: int = MAX_BARS tempo_bins: int = TEMPO_BINS tempo_range: tuple = TEMPO_RANGE max_seq_len: int = 2048 # Conditioning use_tempo_condition: bool = True use_instrument_condition: bool = True use_mood_condition: bool = True use_style_condition: bool = True use_key_condition: bool = True use_role_condition: bool = True use_section_condition: bool = True @dataclass class MIDITokenizer: """ REMI-style MIDI tokenizer with conditioning support. Token structure: - Special tokens: PAD, BOS, EOS, SEP, MASK - Conditioning tokens: TEMPO_*, INST_*, MOOD_* - Bar tokens: BAR - Position tokens: POS_0 to POS_{resolution-1} - Pitch tokens: PITCH_21 to PITCH_108 - Velocity tokens: VEL_0 to VEL_{bins-1} - Duration tokens: DUR_0 to DUR_{bins-1} """ config: TokenizerConfig = field(default_factory=TokenizerConfig) def __post_init__(self): self._build_vocab() def _build_vocab(self): """Build the vocabulary mapping.""" self.token_to_id = {} self.id_to_token = {} idx = 0 # Special tokens self.special_tokens = ["", "", "", "", "", ""] for token in self.special_tokens: self.token_to_id[token] = idx self.id_to_token[idx] = token idx += 1 self.pad_id = self.token_to_id[""] self.bos_id = self.token_to_id[""] self.eos_id = self.token_to_id[""] self.sep_id = self.token_to_id[""] self.mask_id = self.token_to_id[""] self.unk_id = self.token_to_id[""] self.tempo_tokens = [] self.instrument_tokens = [] self.mood_tokens = [] self.style_tokens = [] self.key_tokens = [] self.role_tokens = [] self.section_tokens = [] # Tempo conditioning tokens (quantized BPM) if self.config.use_tempo_condition: tempo_min, tempo_max = self.config.tempo_range for i in range(self.config.tempo_bins): bpm = tempo_min + (tempo_max - tempo_min) * i / (self.config.tempo_bins - 1) token = f"" self.tempo_tokens.append(token) self.token_to_id[token] = idx self.id_to_token[idx] = token idx += 1 # Instrument conditioning tokens (GM program categories) if self.config.use_instrument_condition: self.instrument_tokens = [ "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "" ] for token in self.instrument_tokens: self.token_to_id[token] = idx self.id_to_token[idx] = token idx += 1 # Mood conditioning tokens if self.config.use_mood_condition: self.mood_tokens = [ "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "" ] for token in self.mood_tokens: self.token_to_id[token] = idx self.id_to_token[idx] = token idx += 1 if self.config.use_style_condition: self.style_tokens = [f"" for style in STYLE_TOKENS] for token in self.style_tokens: self.token_to_id[token] = idx self.id_to_token[idx] = token idx += 1 if self.config.use_key_condition: self.key_tokens = [f"" for key in KEY_TOKENS] for token in self.key_tokens: self.token_to_id[token] = idx self.id_to_token[idx] = token idx += 1 if self.config.use_role_condition: self.role_tokens = [f"" for role in ROLE_TOKENS] for token in self.role_tokens: self.token_to_id[token] = idx self.id_to_token[idx] = token idx += 1 if self.config.use_section_condition: self.section_tokens = [f"" for section in SECTION_TOKENS] for token in self.section_tokens: self.token_to_id[token] = idx self.id_to_token[idx] = token idx += 1 # Bar token self.bar_token = "" self.token_to_id[self.bar_token] = idx self.id_to_token[idx] = self.bar_token idx += 1 # Position tokens (within bar) self.position_tokens = [] for i in range(self.config.position_resolution * 4): # 4 beats per bar token = f"" self.position_tokens.append(token) self.token_to_id[token] = idx self.id_to_token[idx] = token idx += 1 # Pitch tokens self.pitch_tokens = [] for pitch in range(self.config.pitch_range[0], self.config.pitch_range[1]): token = f"" self.pitch_tokens.append(token) self.token_to_id[token] = idx self.id_to_token[idx] = token idx += 1 # Velocity tokens self.velocity_tokens = [] for i in range(self.config.velocity_bins): token = f"" self.velocity_tokens.append(token) self.token_to_id[token] = idx self.id_to_token[idx] = token idx += 1 # Duration tokens self.duration_tokens = [] for i in range(self.config.duration_bins): token = f"" self.duration_tokens.append(token) self.token_to_id[token] = idx self.id_to_token[idx] = token idx += 1 self.vocab_size = idx def _tempo_to_token(self, bpm: float) -> str: """Convert BPM to tempo token.""" tempo_min, tempo_max = self.config.tempo_range bpm = np.clip(bpm, tempo_min, tempo_max) bin_idx = int((bpm - tempo_min) / (tempo_max - tempo_min) * (self.config.tempo_bins - 1)) return self.tempo_tokens[bin_idx] def _token_to_tempo(self, token: str) -> float: """Convert tempo token back to BPM.""" if token in self.tempo_tokens: idx = self.tempo_tokens.index(token) tempo_min, tempo_max = self.config.tempo_range return tempo_min + (tempo_max - tempo_min) * idx / (self.config.tempo_bins - 1) return 120.0 # Default def _program_to_instrument(self, program: int, is_drum: bool = False) -> str: """Convert MIDI program number to instrument category token.""" if is_drum: return "" categories = [ "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "" ] return categories[program // 8] def _velocity_to_bin(self, velocity: int) -> int: """Quantize velocity to bin.""" return min(velocity * self.config.velocity_bins // 128, self.config.velocity_bins - 1) def _bin_to_velocity(self, bin_idx: int) -> int: """Convert bin back to velocity.""" return int((bin_idx + 0.5) * 128 / self.config.velocity_bins) def _duration_to_bin(self, duration: float, ticks_per_beat: int = 480) -> int: """Quantize duration to bin (in ticks).""" # Duration bins: 0 = 1/64 note, 63 = whole note tied tick_per_bin = ticks_per_beat // 8 # 1/32 note base bin_idx = int(duration / tick_per_bin) return min(bin_idx, self.config.duration_bins - 1) def _bin_to_duration(self, bin_idx: int, ticks_per_beat: int = 480) -> float: """Convert bin back to duration in ticks.""" tick_per_bin = ticks_per_beat // 8 return (bin_idx + 0.5) * tick_per_bin def encode( self, midi_path: str | Path, tempo: Optional[float] = None, instrument: Optional[str] = None, mood: Optional[str] = None, style: Optional[str] = None, key: Optional[str] = None, role: Optional[str] = None, section: Optional[str] = None, ) -> list[int]: """ Encode a MIDI file to token sequence. Args: midi_path: Path to MIDI file tempo: Optional tempo override (detected from file if None) instrument: Optional instrument hint mood: Optional mood label style: Optional style label key: Optional key label, e.g. "D_MINOR" role: Optional track role section: Optional section label Returns: List of token IDs """ try: pm = pretty_midi.PrettyMIDI(str(midi_path)) except Exception as e: raise ValueError(f"Failed to load MIDI: {e}") tokens = [""] # Add conditioning tokens if self.config.use_tempo_condition: if tempo is None: # Estimate tempo from file tempo_times, tempos = pm.get_tempo_changes() tempo = tempos[0] if len(tempos) > 0 else 120.0 tokens.append(self._tempo_to_token(tempo)) if self.config.use_instrument_condition and instrument: inst_token = f"" if inst_token in self.token_to_id: tokens.append(inst_token) elif self.config.use_instrument_condition and len(pm.instruments) > 0: # Use first non-drum instrument for inst in pm.instruments: if not inst.is_drum: tokens.append(self._program_to_instrument(inst.program)) break if self.config.use_mood_condition and mood: mood_token = f"" if mood_token in self.token_to_id: tokens.append(mood_token) if self.config.use_style_condition and style: style_token = f"" if style_token in self.token_to_id: tokens.append(style_token) if self.config.use_key_condition and key: normalized_key = key.upper().replace(" ", "_") key_token = f"" if key_token in self.token_to_id: tokens.append(key_token) if self.config.use_role_condition and role: role_token = f"" if role_token in self.token_to_id: tokens.append(role_token) if self.config.use_section_condition and section: section_token = f"" if section_token in self.token_to_id: tokens.append(section_token) tokens.append("") # Collect all notes with timing all_notes = [] ticks_per_beat = 480 # Standard for inst in pm.instruments: for note in inst.notes: # Convert time to ticks start_tick = int(note.start * ticks_per_beat * (tempo or 120) / 60) duration_tick = int(note.duration * ticks_per_beat * (tempo or 120) / 60) all_notes.append({ 'start': start_tick, 'pitch': note.pitch, 'velocity': note.velocity, 'duration': duration_tick, 'is_drum': inst.is_drum }) # Sort by start time all_notes.sort(key=lambda x: (x['start'], x['pitch'])) # Convert to tokens ticks_per_bar = ticks_per_beat * 4 # 4/4 time ticks_per_pos = ticks_per_bar // (self.config.position_resolution * 4) current_bar = -1 for note in all_notes: # Skip notes outside pitch range (except drums) if not note['is_drum']: if note['pitch'] < self.config.pitch_range[0] or note['pitch'] >= self.config.pitch_range[1]: continue bar = note['start'] // ticks_per_bar pos = (note['start'] % ticks_per_bar) // ticks_per_pos # Add bar tokens if needed while current_bar < bar: current_bar += 1 tokens.append("") if current_bar >= self.config.max_bars: break if current_bar >= self.config.max_bars: break # Position pos = min(pos, len(self.position_tokens) - 1) tokens.append(f"") # Pitch pitch = np.clip(note['pitch'], self.config.pitch_range[0], self.config.pitch_range[1] - 1) tokens.append(f"") # Velocity vel_bin = self._velocity_to_bin(note['velocity']) tokens.append(f"") # Duration dur_bin = self._duration_to_bin(note['duration'], ticks_per_beat) tokens.append(f"") # Limit sequence length if len(tokens) >= self.config.max_seq_len - 1: break tokens.append("") # Convert to IDs token_ids = [self.token_to_id.get(t, self.unk_id) for t in tokens] return token_ids def decode(self, token_ids: list[int], output_path: Optional[str | Path] = None) -> pretty_midi.PrettyMIDI: """ Decode token sequence back to MIDI. Args: token_ids: List of token IDs output_path: Optional path to save MIDI file Returns: PrettyMIDI object """ tokens = [self.id_to_token.get(tid, "") for tid in token_ids] # Parse conditioning tempo = 120.0 for token in tokens: if token.startswith("": current_bar += 1 i += 1 continue if token.startswith(" list[list[int]]: """Encode multiple MIDI files.""" return [self.encode(p, **kwargs) for p in midi_paths] def pad_sequence(self, token_ids: list[int], max_len: Optional[int] = None) -> list[int]: """Pad or truncate sequence to max length.""" max_len = max_len or self.config.max_seq_len if len(token_ids) > max_len: return token_ids[:max_len] return token_ids + [self.pad_id] * (max_len - len(token_ids)) def save(self, path: str | Path): """Save tokenizer to file.""" path = Path(path) path.mkdir(parents=True, exist_ok=True) with open(path / "tokenizer_config.json", "w") as f: json.dump(self.config.__dict__, f, indent=2) with open(path / "vocab.json", "w") as f: json.dump(self.token_to_id, f, indent=2) @classmethod def load(cls, path: str | Path) -> "MIDITokenizer": """Load tokenizer from file.""" path = Path(path) with open(path / "tokenizer_config.json") as f: config_dict = json.load(f) config = TokenizerConfig(**config_dict) return cls(config=config)