|
|
import symusic |
|
|
import pretty_midi |
|
|
import numpy as np |
|
|
from dataclasses import dataclass, asdict |
|
|
from typing import List, Tuple, Dict, TypeVar, Generic, Type |
|
|
import json |
|
|
import random |
|
|
import logging |
|
|
from util import crop_sm |
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# MIDI pitches 22 through 81 (inclusive). Not referenced elsewhere in this module.
MIDI_DRUM_PITCHES = range(22, 81 + 1)
|
|
|
|
|
|
|
|
def crop_sm(sm, n_beats):
    """
    Return a copy of ``sm`` truncated to its first ``n_beats`` beats.

    NOTE: this definition shadows the ``crop_sm`` imported from ``util`` at the
    top of the file.

    Parameters:
    -----------
    sm : object
        Score-like object exposing ``copy()``, ``tpq``, ``end()`` and
        ``clip(start, end, clip_end=...)``.
    n_beats : int
        Number of beats to keep.

    Returns:
    --------
    object
        The copy itself when it already ends within the limit, otherwise the
        result of clipping the copy to ``[0, n_beats * tpq]`` with
        ``clip_end=True``.
    """
    result = sm.copy()
    max_tick = n_beats * result.tpq

    # Nothing extends past the limit: the plain copy is returned.
    if result.end() <= max_tick:
        return result

    return result.clip(0, max_tick, clip_end=True)
|
|
|
|
|
|
|
|
class Quantizer:
    """Snap scalar values onto a fixed set of evenly spaced bin values.

    Attributes:
        range: the ``value_range`` passed in, stored unchanged.
        n_bins: number of bins.
        bins: ``np.linspace(value_range[0], value_range[1], n_bins)``, rounded
            to ``int`` when ``round_values`` is True.
    """

    def __init__(
        self, value_range: Tuple[float, float], n_bins: int, round_values: bool = False
    ):
        grid = np.linspace(value_range[0], value_range[1], n_bins)
        self.range = value_range
        self.n_bins = n_bins
        # np.round uses round-half-to-even, so e.g. 32.5 becomes 32.
        self.bins = np.round(grid).astype(int) if round_values else grid

    def quantize(self, value: float):
        """Return the bin value nearest to ``value`` (first bin wins ties)."""
        distances = np.abs(self.bins - value)
        nearest = np.argmin(distances)
        return self.bins[nearest]
|
|
|
|
|
@dataclass
class TokenizerConfig:
    """Base class for tokenizer configuration dataclasses.

    It declares no fields of its own; concrete tokenizer configs subclass it.
    """


# Type variable for the concrete config type a tokenizer is parameterized with.
T = TypeVar("T", bound=TokenizerConfig)
|
|
|
|
|
|
|
|
class BaseTokenizer(Generic[T]):
    """Abstract base class for MIDI tokenizers.

    Subclasses populate ``vocab`` (index -> token string) and ``token_to_idx``
    (token string -> index), and implement ``midi_to_tokens`` and
    ``tokens_to_midi``. Subclasses that want ``from_json`` must set the
    ``config_cls`` class attribute to their config dataclass.
    """

    config_cls: Type[T]

    def __init__(self, config: T) -> None:
        self.config = config
        self.vocab: List[str] = []
        self.token_to_idx: Dict[str, int] = {}
        # Sentinel padding id. It is not a vocabulary index: passing it to
        # ids_to_tokens would return the LAST vocab entry (negative indexing).
        self.pad_token_id = -1

    def to_json(self, path: str) -> None:
        """Save the tokenizer configuration (the config's ``__dict__``) to a JSON file."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.config.__dict__, f, indent=2)

    @classmethod
    def from_json(cls, path: str):
        """Load a configuration written by ``to_json`` and build a tokenizer from it.

        JSON has no tuple type, so tuple-valued fields come back as lists.

        Raises:
            AttributeError: if the class does not define ``config_cls``.
        """
        config_cls = getattr(cls, "config_cls", None)
        if config_cls is None:
            raise AttributeError(
                f"{cls.__name__} does not define config_cls; "
                "cannot load its configuration from JSON."
            )
        with open(path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        return cls(config_cls(**config_dict))

    def midi_to_tokens(self, midi: symusic.Score) -> List[str]:
        """Convert a score to a list of token strings (implemented by subclasses)."""
        raise NotImplementedError

    def tokens_to_midi(self, tokens: List[str]) -> symusic.Score:
        """Convert a list of token strings back to a score (implemented by subclasses)."""
        raise NotImplementedError

    def ids_to_midi(self, ids: List[int]) -> symusic.Score:
        """Decode token ids to a score via ``ids_to_tokens`` and ``tokens_to_midi``."""
        return self.tokens_to_midi(self.ids_to_tokens(ids))

    def midi_to_ids(self, midi: symusic.Score) -> List[int]:
        """Encode a score to token ids via ``midi_to_tokens`` and ``tokens_to_ids``."""
        return self.tokens_to_ids(self.midi_to_tokens(midi))

    def tokens_to_ids(self, tokens: List[str]) -> List[int]:
        """Convert tokens to their vocabulary indices.

        Raises:
            KeyError: if a token is not in the vocabulary.
        """
        return [self.token_to_idx[token] for token in tokens]

    def ids_to_tokens(self, ids: List[int]) -> List[str]:
        """Convert vocabulary indices back to tokens.

        Uses plain list indexing, so negative ids (e.g. ``pad_token_id == -1``)
        index from the end of the vocabulary; filter them out before decoding.
        """
        return [self.vocab[idx] for idx in ids]
|
|
|
|
|
|
|
|
@dataclass
class TanjaTokenizerConfig(TokenizerConfig):
    """Hyper-parameters for ``TanjaTokenizer``."""

    # Resolution (ticks per quarter note) that scores are resampled to.
    ticks_per_beat: int
    # Grid step, in ticks, used for Onset/Offset tokens; also the number of
    # distinct Microtiming values.
    coarse_ticks_per_beat: int
    # (min, max) tempo in qpm, quantized into n_tempo_bins Tempo tokens.
    tempo_range: Tuple[int, int]
    n_tempo_bins: int
    # Number of Velocity tokens spanning velocities 1..127.
    n_velocity_bins: int
    # Length of the encoded window in bars of 4 beats.
    n_bars: int
    # Fixed number of note slots per sequence; unused slots are "inactive".
    n_events: int

    def dict(self):
        """Return every field as a ``{name: str(value)}`` mapping."""
        stringified = {}
        for field_name, field_value in asdict(self).items():
            stringified[field_name] = str(field_value)
        return stringified
|
|
|
|
|
class TanjaTokenizer(BaseTokenizer):
    '''
    CMLM Tokenizer.

    Encodes a fixed-length window of a score as a flat token sequence:

        Tempo
        then n_events note slots of 7 attributes each:
        Program Pitch Onset Microtiming Offset Duration Velocity

    Slots not filled by real notes hold "*_inactive" tokens. The vocabulary
    also contains a MASK_None token.
    '''

    config_cls = TanjaTokenizerConfig

    # (token prefix, assertion message, parser for the text after the last "_")
    # for each of the 7 note attributes, in sequence order. Messages are kept
    # verbatim from the original decoder.
    _NOTE_FIELDS = (
        ("Program_", "First token must be a program token",
         lambda raw: -1 if raw == "Drums" else int(raw)),
        ("Pitch_", "Second token must be a pitch token",
         lambda raw: int(raw.split("Drum")[-1]) if "Drum" in raw else int(raw)),
        ("Onset_", "Third token must be an onset token", int),
        ("Microtiming_", "Fourth token must be an onset token", int),
        ("Offset_", "Fifth token must be an offset token", int),
        ("Duration_", "Sixth token must be a duration token", int),
        ("Velocity_", "Seventh token must be a velocity token", int),
    )

    def __init__(self, config: TanjaTokenizerConfig):
        # Sets config, pad_token_id and empty vocab/token_to_idx.
        super().__init__(config)

        ticks_per_beat = config.ticks_per_beat
        coarse = config.coarse_ticks_per_beat
        self.n_beats = config.n_bars * 4
        window_ticks = self.n_beats * ticks_per_beat

        # NOTE: vocabulary order defines token ids; do not reorder.
        self.vocab = ["BOS_None", "EOS_None", "SEP_None", "PAD_None", "MASK_None"]

        self.tempo_quantizer = Quantizer(
            config.tempo_range, config.n_tempo_bins, round_values=True
        )
        self.vocab.extend(f"Tempo_{tempo}" for tempo in self.tempo_quantizer.bins)

        self.vocab.extend(f"Program_{i}" for i in range(128))
        self.vocab.append("Program_Drums")
        self.vocab.append("Program_inactive")

        self.vocab.extend(f"Pitch_{pitch}" for pitch in range(128))
        self.vocab.extend(f"Pitch_Drum{pitch}" for pitch in range(128))
        self.vocab.append("Pitch_inactive")

        self.vocab.extend(f"Onset_{i}" for i in range(0, window_ticks, coarse))
        self.vocab.append("Onset_inactive")

        self.vocab.extend(f"Microtiming_{i}" for i in range(coarse))
        self.vocab.append("Microtiming_inactive")

        # Offsets extend one beat past the window so that the end-of-window
        # value that get_note_tokens clamps to is representable (when coarse
        # divides the window length).
        self.vocab.extend(
            f"Offset_{i}" for i in range(0, (self.n_beats + 1) * ticks_per_beat, coarse)
        )
        self.vocab.append("Offset_inactive")

        # Power-of-two durations from a 32nd note up to the full window.
        # max(1, ...) guards ticks_per_beat < 8: the 32nd-note length would floor
        # to 0 and the doubling loop would never terminate.
        self.durations = []
        ticks = max(1, (ticks_per_beat * 4) // 32)
        while ticks <= window_ticks:
            self.durations.append(ticks)
            ticks *= 2
        self.vocab.extend(f"Duration_{d}" for d in self.durations)
        self.vocab.append("Duration_inactive")

        self.velocity_quantizer = Quantizer(
            (1, 127), config.n_velocity_bins, round_values=True
        )
        self.vocab.extend(f"Velocity_{v}" for v in self.velocity_quantizer.bins)
        self.vocab.append("Velocity_inactive")

        self.event_attribute_order = [
            "Program",
            "Pitch",
            "Onset",
            "Microtiming",
            "Offset",
            "Duration",
            "Velocity",
        ]

        self.token_to_idx = {token: idx for idx, token in enumerate(self.vocab)}

    def remove_special_tokens(self, tokens: List[str]) -> List[str]:
        """Return a new list without BOS/EOS/SEP/PAD tokens (MASK_None is kept)."""
        special_tokens = ["BOS_None", "EOS_None", "SEP_None", "PAD_None"]
        return [token for token in tokens if token not in special_tokens]

    def get_inactive_note_tokens(self):
        """Return the 7 "<Attribute>_inactive" tokens that fill an empty note slot."""
        return [f"{attribute}_inactive" for attribute in self.event_attribute_order]

    def get_closest_duration(self, duration: float) -> int:
        """Return the entry of self.durations nearest to ``duration``.

        Ties resolve to the shorter duration (self.durations is ascending).
        """
        return min(self.durations, key=lambda x: abs(x - duration))

    def get_note_tokens(self, note, program, is_drums):
        """Encode one note as its 7 attribute tokens (see event_attribute_order)."""
        coarse = self.config.coarse_ticks_per_beat
        window_ticks = self.n_beats * self.config.ticks_per_beat

        if is_drums:
            program_token = "Program_Drums"
            pitch_token = f"Pitch_Drum{note.pitch}"
        else:
            program_token = f"Program_{program}"
            pitch_token = f"Pitch_{note.pitch}"

        onset_coarse = int(coarse * (note.start // coarse))
        microtiming = int(note.start % coarse)
        # Clamp so a note ending after the window still maps to a vocab token.
        offset_coarse = min(int(coarse * (note.end // coarse)), window_ticks)
        duration = self.get_closest_duration(note.end - note.start)
        velocity = self.velocity_quantizer.quantize(note.velocity)

        return [
            program_token,
            pitch_token,
            f"Onset_{onset_coarse}",
            f"Microtiming_{microtiming}",
            f"Offset_{offset_coarse}",
            f"Duration_{duration}",
            f"Velocity_{velocity}",
        ]

    def tokens_to_ids(self, tokens):
        return super().tokens_to_ids(tokens)

    def ids_to_tokens(self, ids):
        return super().ids_to_tokens(ids)

    def midi_to_token_ids(self, midi, shuffle_events=True):
        """Convert a MIDI score to token IDs."""
        tokens = self.midi_to_tokens(midi, shuffle_events)
        return self.tokens_to_ids(tokens)

    def token_ids_to_midi(self, token_ids):
        """Convert token IDs back to a MIDI score."""
        tokens = self.ids_to_tokens(token_ids)
        return self.tokens_to_midi(tokens)

    @staticmethod
    def _check_time_signature(midi):
        """Raise ValueError unless the last time signature is 4/4.

        A score without any time signature is accepted: Standard MIDI Files
        assume 4/4 when no time signature event is present.
        """
        if len(midi.time_signatures) == 0:
            return
        time_signature = midi.time_signatures[-1]
        if time_signature.numerator != 4 or time_signature.denominator != 4:
            raise ValueError(
                "Only 4/4 time signature is supported for Tanja tokenizer."
            )

    def midi_to_tokens(self, midi, shuffle_events=True):
        """Encode a score as ``[Tempo, *n_events note slots * 7 tokens]``.

        The score is resampled to ``ticks_per_beat`` and cropped to
        ``n_beats``. Real notes are sorted (lexicographically on their token
        strings), padded with inactive slots up to ``n_events``, and then
        shuffled when ``shuffle_events`` is True.

        Raises:
            AssertionError: if the input has no notes or more than n_events notes.
            ValueError: if the (last) time signature is not 4/4.
        """
        assert midi.note_num() > 0, "MIDI file must contain at least one note"
        assert midi.note_num() <= self.config.n_events, "MIDI file must contain less than n_events notes"

        midi = midi.copy().resample(self.config.ticks_per_beat)
        midi = crop_sm(midi, self.n_beats)

        self._check_time_signature(midi)

        tempo = midi.tempos[-1].qpm if len(midi.tempos) > 0 else 120
        tempo_token = f"Tempo_{self.tempo_quantizer.quantize(tempo)}"

        note_tokens = [
            self.get_note_tokens(note, track.program, track.is_drum)
            for track in midi.tracks
            for note in track.notes
        ]
        # Canonical (deterministic) order before padding/shuffling.
        note_tokens.sort()

        n_inactive_notes = self.config.n_events - len(note_tokens)
        note_tokens.extend(
            self.get_inactive_note_tokens() for _ in range(n_inactive_notes)
        )

        if shuffle_events:
            note_tokens = random.sample(note_tokens, len(note_tokens))

        tokens = [tempo_token]
        for event in note_tokens:
            tokens.extend(event)
        return tokens

    def get_prob_mask(self, idx):
        """Return a 0/1 mask over the vocab of tokens valid at sequence position ``idx``.

        Position 0 allows only Tempo tokens; later positions cycle through
        event_attribute_order.
        """
        if idx == 0:
            prefix = "Tempo_"
        else:
            prefix = self.event_attribute_order[(idx - 1) % len(self.event_attribute_order)]
        return [1 if token.startswith(prefix) else 0 for token in self.vocab]

    def _parse_note_tokens(self, note_tokens):
        """Parse one 7-token note slot into its integer attribute values.

        Returns None as soon as an inactive attribute is met (later tokens in
        the slot are then not validated, as before).

        Raises:
            AssertionError: if a token does not carry the expected prefix.
            IndexError: if the slot is truncated (fewer than 7 tokens).
        """
        values = []
        for token, (prefix, message, parse) in zip(note_tokens, self._NOTE_FIELDS):
            assert token.startswith(prefix), message
            raw = token.split("_")[-1]
            if raw == "inactive":
                return None
            values.append(parse(raw))
        if len(values) < len(self._NOTE_FIELDS):
            raise IndexError(f"Incomplete note slot: {note_tokens}")
        return values

    def tokens_to_midi(self, tokens):
        """Decode a token sequence produced by ``midi_to_tokens`` into a score.

        BOS/EOS/SEP/PAD tokens are ignored. Slots with any inactive attribute
        are skipped. Notes are grouped into one track per program (drums use
        program -1 internally and become an ``is_drum`` track with program 0).
        """
        tokens = self.remove_special_tokens(tokens)

        midi = symusic.Score()
        midi = midi.resample(self.config.ticks_per_beat)

        tempo_token, body = tokens[0], tokens[1:]
        tempo = int(tempo_token.split("_")[-1])
        midi.tempos = [symusic.Tempo(qpm=tempo, time=0)]
        midi.time_signatures.append(symusic.TimeSignature(numerator=4, denominator=4, time=0))

        slot_size = len(self.event_attribute_order)
        program_notes = {}
        for start in range(0, len(body), slot_size):
            note_tokens = body[start:start + slot_size]
            logger.debug("Note tokens: %s", note_tokens)

            parsed = self._parse_note_tokens(note_tokens)
            if parsed is None:
                continue
            program, pitch, onset_coarse, onset_fine, offset, token_duration, velocity = parsed

            onset_tick = onset_coarse + onset_fine
            # The microtiming is assumed to apply to the offset as well, so the
            # duration is offset - onset_coarse. When onset and offset share a
            # coarse cell (or the tokens are inconsistent) that is <= 0; fall
            # back to the explicit Duration token instead of a degenerate note.
            duration = (offset + onset_fine) - onset_tick
            if duration <= 0:
                duration = token_duration

            program_notes.setdefault(program, []).append(
                symusic.Note(
                    time=onset_tick,
                    pitch=pitch,
                    velocity=velocity,
                    duration=duration,
                )
            )

        for program in sorted(program_notes):
            notes = program_notes[program]
            notes.sort(key=lambda note: (note.start, note.end, note.pitch, note.velocity))
            track = symusic.Track(is_drum=program == -1, program=program if program != -1 else 0)
            for note in notes:
                track.notes.append(note)
            midi.tracks.append(track)

        return midi
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class IrmaTokenizerConfig(TokenizerConfig):
    """Hyper-parameters for ``IrmaTokenizer``."""

    # Resolution (ticks per quarter note) that scores are resampled to.
    ticks_per_beat: int
    # Position slots per beat; a 4/4 bar has 4 * positions_per_beat positions.
    positions_per_beat: int
    # (min, max) tempo in qpm, quantized into n_tempo_bins Tempo tokens.
    tempo_range: Tuple[int, int]
    n_tempo_bins: int
    # Number of Velocity tokens spanning velocities 1..127.
    n_velocity_bins: int
    # Not referenced by IrmaTokenizer in this file.
    n_bars: int
    # Piecewise duration grid: each (end_beat, divisions_per_beat) entry adds
    # durations from the previous end_beat (initially 0) up to end_beat, spaced
    # ticks_per_beat / divisions_per_beat ticks apart.
    duration_ranges: List[Tuple[int, int]]

    def dict(self):
        """Return every field as a ``{name: str(value)}`` mapping."""
        stringified = {}
        for field_name, field_value in asdict(self).items():
            stringified[field_name] = str(field_value)
        return stringified
|
|
|
|
|
|
|
|
class IrmaTokenizer(BaseTokenizer):
    '''
    Irma Tokenizer.

    The sequence starts with a header: one Tempo token, then one Program token
    per non-empty track (track order is shuffled when shuffle_tracks is True).
    The time signature is not tokenized; only 4/4 is supported.

    Then the body follows: one part per track, in the same order as the
    Program tokens, each introduced by Track_None:

        Track_None Bar_None Position_12 Shift_2 Pitch_60 Velocity_100 Duration_...
        # Shift is an offset in ticks from the current position; it is emitted
        # only when it differs from the previous shift at that position.
        Track_None ...

    Several tracks may share a program.
    '''

    config_cls = IrmaTokenizerConfig

    def __init__(self, config: IrmaTokenizerConfig):
        super().__init__(config)

        ticks_per_beat = self.config.ticks_per_beat
        bar_ticks = ticks_per_beat * 4
        self.ticks_per_position = ticks_per_beat / self.config.positions_per_beat

        # NOTE: vocabulary order defines token ids; do not reorder.
        self.vocab = ["BOS_None", "EOS_None", "SEP_None", "PAD_None", "Bar_None", "Track_None"]

        self.tempo_quantizer = Quantizer(
            self.config.tempo_range, self.config.n_tempo_bins, round_values=True
        )
        self.vocab.extend(f"Tempo_{tempo}" for tempo in self.tempo_quantizer.bins)

        self.vocab.extend(f"Program_{i}" for i in range(128))
        self.vocab.append("Program_Drums")

        positions_per_bar = 4 * config.positions_per_beat
        self.vocab.extend(f"Position_{i}" for i in range(positions_per_bar))

        # Shift_0 is never emitted: a new Position implies a shift of 0.
        self.vocab.extend(f"Shift_{i}" for i in range(1, int(self.ticks_per_position)))

        self.vocab.extend(f"Pitch_{pitch}" for pitch in range(128))
        self.vocab.extend(f"Pitch_Drum{pitch}" for pitch in range(128))

        for dur_range in self.config.duration_ranges:
            assert ticks_per_beat % dur_range[1] == 0, f"Duration division {dur_range[1]} must be a divisor of ticks_per_beat {ticks_per_beat}"

        # Piecewise-uniform duration grid; see IrmaTokenizerConfig.duration_ranges.
        self.durations = []
        range_start = 0
        for dur_range in self.config.duration_ranges:
            range_end = dur_range[0]
            step = int(ticks_per_beat / dur_range[1])
            for ticks in range(range_start * ticks_per_beat, range_end * ticks_per_beat, step):
                self.vocab.append(f"Duration_{ticks}d{bar_ticks}")
                self.durations.append(ticks)
            range_start = range_end

        self.velocity_quantizer = Quantizer(
            (1, 127), self.config.n_velocity_bins, round_values=True
        )
        self.vocab.extend(f"Velocity_{v}" for v in self.velocity_quantizer.bins)

        self.token_to_idx = {token: idx for idx, token in enumerate(self.vocab)}

    def midi_to_token_ids(self, midi: symusic.Score, shuffle_tracks=True) -> List[int]:
        """Convert a MIDI score to token IDs."""
        tokens = self.midi_to_tokens(midi, shuffle_tracks)
        return self.tokens_to_ids(tokens)

    def remove_special_tokens(self, tokens: List[str]) -> List[str]:
        """Return a new list without BOS/EOS/SEP/PAD tokens (Bar/Track are kept)."""
        special_tokens = ["BOS_None", "EOS_None", "SEP_None", "PAD_None"]
        return [token for token in tokens if token not in special_tokens]

    def token_ids_to_midi(self, token_ids: List[int]) -> symusic.Score:
        """Convert token IDs back to a MIDI score."""
        tokens = self.ids_to_tokens(token_ids)
        return self.tokens_to_midi(tokens)

    def get_closest_duration(self, duration: float) -> int:
        """Return the entry of self.durations nearest to ``duration`` (first wins ties)."""
        return min(self.durations, key=lambda x: abs(x - duration))

    @staticmethod
    def _check_time_signature(midi):
        """Raise ValueError unless the last time signature is 4/4.

        A score without any time signature is accepted: Standard MIDI Files
        assume 4/4 when no time signature event is present.
        """
        if len(midi.time_signatures) == 0:
            return
        time_signature = midi.time_signatures[-1]
        if time_signature.numerator != 4 or time_signature.denominator != 4:
            raise ValueError(
                "Only 4/4 time signature is supported for Irma tokenizer."
            )

    def _encode_track(self, track) -> List[str]:
        """Encode one track's notes as a Track_None-prefixed token list."""
        bar_ticks = self.config.ticks_per_beat * 4
        part = ["Track_None"]

        bar_count = -1
        curr_position = -1
        curr_shift = 0

        notes = track.notes.copy()
        notes.sort(key=lambda note: (note.start, note.pitch, note.velocity))
        for note in notes:
            # One Bar_None per bar boundary reached, including empty bars.
            bar_idx = note.start // bar_ticks
            while bar_count < bar_idx:
                part.append("Bar_None")
                bar_count += 1
                curr_position = -1
                curr_shift = 0

            onset = note.start
            position = int(onset % bar_ticks // self.ticks_per_position)
            if position != curr_position:
                part.append(f"Position_{position}")
                curr_position = position
                curr_shift = 0

            shift = int(onset % self.ticks_per_position)
            if shift != curr_shift:
                part.append(f"Shift_{shift}")
                curr_shift = shift

            part.append(f"Pitch_Drum{note.pitch}" if track.is_drum else f"Pitch_{note.pitch}")
            part.append(f"Velocity_{self.velocity_quantizer.quantize(note.velocity)}")

            closest_duration = self.get_closest_duration(note.end - note.start)
            part.append(f"Duration_{closest_duration}d{bar_ticks}")

        return part

    def midi_to_tokens(self, midi: symusic.Score, shuffle_tracks=True) -> List[str]:
        """Convert a MIDI score to tokens.

        Tracks without notes are dropped.

        Raises:
            ValueError: if the (last) time signature is not 4/4.
        """
        midi = midi.copy().resample(self.config.ticks_per_beat)

        tempo = midi.tempos[-1].qpm if len(midi.tempos) > 0 else 120
        self._check_time_signature(midi)
        tempo_token = f"Tempo_{self.tempo_quantizer.quantize(tempo)}"

        tracks = [track for track in midi.tracks if len(track.notes) > 0]
        if shuffle_tracks:
            tracks = random.sample(tracks, len(tracks))

        program_tokens = []
        track_tokens = []
        for track in tracks:
            program_tokens.append("Program_Drums" if track.is_drum else f"Program_{track.program}")
            track_tokens.append(self._encode_track(track))

        tokens = [tempo_token, *program_tokens]
        for part in track_tokens:
            tokens.extend(part)
        return tokens

    @staticmethod
    def _split_list_by_value(lst, value):
        """Split ``lst`` on every occurrence of ``value``, dropping empty chunks."""
        result = []
        current_sublist = []
        for item in lst:
            if item == value:
                if current_sublist:
                    result.append(current_sublist)
                current_sublist = []
            else:
                current_sublist.append(item)
        if current_sublist:
            result.append(current_sublist)
        return result

    def _decode_track(self, track_tokens, track_program):
        """Build a symusic.Track from one body part (without its Track_None)."""
        bar_ticks = self.config.ticks_per_beat * 4
        is_drum = track_program == "Program_Drums"
        track = symusic.Track(
            is_drum=is_drum,
            program=0 if is_drum else int(track_program.split("_")[-1]),
        )

        bar_count = -1
        curr_position = 0
        curr_shift = 0
        for token in track_tokens:
            if token.startswith("Bar_None"):
                bar_count += 1
                curr_position = 0
                # Mirror the encoder, which resets the shift at every bar.
                curr_shift = 0
            elif token.startswith("Position_"):
                curr_position = int(token.split("_")[-1])
                curr_shift = 0
            elif token.startswith("Shift_"):
                curr_shift = int(token.split("_")[-1])
            elif token.startswith("Pitch_"):
                pitch_str = token.split("_")[-1]
                if pitch_str.startswith("Drum"):
                    pitch = int(pitch_str.split("Drum")[-1])
                else:
                    pitch = int(pitch_str)
            elif token.startswith("Velocity_"):
                velocity = int(token.split("_")[-1])
            elif token.startswith("Duration_"):
                # Duration tokens close a note: "Duration_<ticks>d<bar_ticks>".
                duration = int(token.split("_")[-1].split("d")[0])
                time = int(bar_count * bar_ticks + curr_position * self.ticks_per_position + curr_shift)
                track.notes.append(
                    symusic.Note(time=time, pitch=pitch, velocity=velocity, duration=duration)
                )
        return track

    def tokens_to_midi(self, tokens):
        """Decode a token sequence produced by ``midi_to_tokens`` into a score.

        Raises:
            AssertionError: on a missing Tempo token, a non-Program header token,
                or a mismatch between header programs and body tracks.
        """
        tokens = self.remove_special_tokens(tokens)

        assert tokens[0].startswith("Tempo_"), "First token must be a tempo token"
        tempo_token = tokens[0]

        # Header: Program tokens up to the first Track_None.
        program_tokens = []
        idx = 1
        while idx < len(tokens) and not tokens[idx].startswith("Track_None"):
            pr_token = tokens[idx]
            assert pr_token.startswith("Program_"), "Program token must start with Program_"
            program_tokens.append(pr_token)
            idx += 1
        body = tokens[idx:]

        midi = symusic.Score()
        midi = midi.resample(self.config.ticks_per_beat)

        tempo = int(tempo_token.split("_")[-1])
        midi.tempos = [symusic.Tempo(qpm=tempo, time=0)]
        midi.time_signatures.append(symusic.TimeSignature(numerator=4, denominator=4, time=0))

        tokens_split_by_track = self._split_list_by_value(body, "Track_None")
        assert len(tokens_split_by_track) == len(program_tokens), "Number of tracks must be equal to number of programs"

        for track_tokens, track_program in zip(tokens_split_by_track, program_tokens):
            midi.tracks.append(self._decode_track(track_tokens, track_program))

        return midi