|
|
"""
|
|
|
Utilities for converting to and from Midi data and encoded/tokenized data.
|
|
|
"""
|
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
|
import mido
|
|
|
import numpy as np
|
|
|
|
|
|
from anticipation.config import *
|
|
|
from anticipation.vocab import *
|
|
|
from anticipation.ops import unpad
|
|
|
|
|
|
|
|
|
def midi_to_interarrival(midifile, debug=False, stats=False):
    """
    Tokenizes a midi file as a stream of inter-arrival time tokens and note tokens.

    The file is opened with mido and its messages are iterated in the order
    mido yields them. Each note_on/note_off message emits one note token,
    preceded by a time token whenever the rounded time elapsed since the
    previous note message is non-zero.

    Args:
        midifile: handed directly to ``mido.MidiFile``.
        debug: if true, print every message whose type is not recognized.
        stats: if true, additionally return the number of clipped time gaps.

    Returns:
        The list of tokens, or ``(tokens, truncations)`` when ``stats`` is true.

    Raises:
        ValueError: if a message reports a negative time.
    """
    # recognized message types that contribute nothing to the token stream
    ignored_types = (
        'time_signature',
        'aftertouch', 'polytouch', 'pitchwheel', 'sequencer_specific',
        'control_change',
        'track_name', 'text', 'end_of_track', 'lyrics', 'key_signature',
        'copyright', 'marker', 'instrument_name', 'cue_marker',
        'device_name', 'sequence_number',
        'channel_prefix',
        'midi_port', 'smpte_offset', 'sysex',
    )

    song = mido.MidiFile(midifile)

    programs = defaultdict(int)  # channel -> program; unseen channels read as 0
    tempo = 500000               # updated by set_tempo messages, never read here
    out = []
    clipped = 0
    elapsed = 0                  # time accumulated since the last note message

    for msg in song:
        elapsed += msg.time

        # sanity check on the message time
        if msg.time < 0:
            raise ValueError

        kind = msg.type
        if kind == 'program_change':
            programs[msg.channel] = msg.program
            continue
        if kind == 'set_tempo':
            tempo = msg.tempo
            continue
        if kind in ignored_types:
            continue
        if kind not in ('note_on', 'note_off'):
            if debug:
                print('UNHANDLED MESSAGE', kind, msg)
            continue

        # clip the gap so the time token stays below MAX_INTERARRIVAL
        exact_ticks = round(TIME_RESOLUTION*elapsed)
        ticks = min(exact_ticks, MAX_INTERARRIVAL-1)
        if ticks != exact_ticks:
            clipped += 1

        # only emit a time token if time has passed since the last note message
        if ticks > 0:
            out.append(MIDI_TIME_OFFSET + ticks)

        # channel 9 is always given instrument code 128
        inst = 128 if msg.channel == 9 else programs[msg.channel]

        # a note_on with zero velocity is encoded like a note_off
        if kind == 'note_on' and msg.velocity > 0:
            base = MIDI_START_OFFSET
        else:
            base = MIDI_END_OFFSET

        out.append(base + (2**7)*inst + msg.note)
        elapsed = 0

    if stats:
        return out, clipped

    return out
|
|
|
|
|
|
|
|
|
def interarrival_to_midi(tokens, debug=False):
    """
    Converts an inter-arrival token sequence back into a mido.MidiFile.

    Time tokens advance a running clock. Note tokens are decoded into an
    (instrument, pitch) pair, and each instrument gets its own track and
    channel the first time one of its onsets is seen. Instrument code 128 is
    always placed on channel 9 with program 0. The other instruments take
    channels 0, 1, 2, ... in order of first appearance, skipping channel 9.

    The drum track does not consume a slot from the melodic channel counter,
    so drums plus 15 other instruments fit into channels 0-15. With more than
    15 non-drum instruments the channel number exceeds 15; no fallback is
    attempted here, so the error comes from mido.Message.

    Args:
        tokens: iterable of integer tokens; MIDI_SEPARATOR tokens are skipped.
        debug: if true, print a notice for every offset token that refers to
            an instrument which has not had an onset yet.

    Returns:
        A mido.MidiFile with one track per instrument. Every note_on is
        written with velocity 96.
    """
    mid = mido.MidiFile()
    # at the MIDI default tempo of 2 beats per second, this gives
    # TIME_RESOLUTION ticks per second
    mid.ticks_per_beat = TIME_RESOLUTION // 2

    track_idx = {}      # instrument -> (track, time of its last event, channel)
    next_channel = 0    # next free channel for a non-drum instrument
    time_in_ticks = 0
    for token in tokens:
        if token == MIDI_SEPARATOR:
            continue

        # time token: advance the clock
        if token < MIDI_START_OFFSET:
            time_in_ticks += token - MIDI_TIME_OFFSET
            continue

        is_onset = token < MIDI_END_OFFSET
        base = MIDI_START_OFFSET if is_onset else MIDI_END_OFFSET
        instrument, pitch = divmod(token - base, 2**7)

        state = track_idx.get(instrument)
        if state is not None:
            track, previous_time, idx = state
        elif not is_onset:
            # an offset for an instrument that never started a note
            if debug:
                print('IGNORING bad offset')

            continue
        else:
            previous_time = 0
            track = mido.MidiTrack()
            mid.tracks.append(track)
            if instrument == 128:
                # drums always live on channel 9 and do not use up a
                # melodic channel
                idx = 9
                program = 0
            else:
                idx = next_channel
                program = instrument
                next_channel += 1
                if next_channel == 9:
                    next_channel += 1  # channel 9 is reserved for drums
            track.append(mido.Message('program_change', channel=idx, program=program))

        # message times are relative to the previous event on the same track
        delta = time_in_ticks - previous_time
        if is_onset:
            track.append(mido.Message('note_on', note=pitch, channel=idx, velocity=96, time=delta))
        else:
            track.append(mido.Message('note_off', note=pitch, channel=idx, time=delta))
        track_idx[instrument] = (track, time_in_ticks, idx)

    return mid
|
|
|
|
|
|
|
|
|
def midi_to_compound(midifile, debug=False, quantize=True):
    """
    Converts midi file to a compound tokenization that stores each note as
    a 5-tuple of (time, duration, note, instrument, velocity).

    Note that mido measures the time of a midi message in seconds, which we
    multiply by TIME_RESOLUTION to get a time in ticks.

    A note_off (or a note_on with velocity 0) closes the oldest open note
    with the same channel and pitch. Matching is by channel and pitch only,
    so a program_change between onset and offset does not leave the note
    unclosed. Notes that are never closed keep a duration of -1.

    Args:
        midifile: a path (str or numpy scalar) opened with mido.MidiFile, or
            any other object, which is iterated directly as a sequence of
            mido-style messages.
        debug: if true, print warnings about unmatched offsets, unclosed
            notes and unrecognized message types.
        quantize: if true, onset times are rounded to whole ticks; otherwise
            they are left as unrounded products. Durations are always rounded.

    Returns:
        A flat list with five entries per note:
        time, duration, note, instrument, velocity.

    Raises:
        ValueError: if a message reports a negative time.
    """
    if isinstance(midifile, (str, np.generic)):
        midi = mido.MidiFile(str(midifile))
    else:
        midi = midifile

    # recognized message types that do not affect the tokenization
    ignored_types = frozenset([
        'set_tempo', 'time_signature',
        'aftertouch', 'polytouch', 'pitchwheel', 'sequencer_specific',
        'control_change',
        'track_name', 'text', 'end_of_track', 'lyrics', 'key_signature',
        'copyright', 'marker', 'instrument_name', 'cue_marker',
        'device_name', 'sequence_number',
        'channel_prefix',
        'midi_port', 'smpte_offset', 'sysex',
    ])

    tokens = []
    # (channel, pitch) -> [(index of the duration slot in tokens, onset time)]
    open_notes = defaultdict(list)
    instruments = defaultdict(int)  # channel -> program; unseen channels read as 0

    time = 0
    for message in midi:
        # sanity check on the message time
        if message.time < 0:
            raise ValueError(f'negative time in midi message: {message}')

        time += message.time

        if message.type == 'program_change':
            instruments[message.channel] = message.program
        elif message.type in ('note_on', 'note_off'):
            key = (message.channel, message.note)
            if message.type == 'note_on' and message.velocity > 0:
                # channel 9 is always given instrument code 128
                instr = 128 if message.channel == 9 else instruments[message.channel]

                if quantize:
                    time_in_ticks = round(TIME_RESOLUTION*time)
                else:
                    time_in_ticks = TIME_RESOLUTION*time

                # the duration is a placeholder until the offset arrives
                open_notes[key].append((len(tokens) + 1, time))
                tokens.extend((time_in_ticks, -1, message.note, instr, message.velocity))
            else:
                pending = open_notes[key]
                if pending:
                    # close the oldest open note on this channel and pitch
                    duration_slot, onset_time = pending.pop(0)
                    tokens[duration_slot] = round(TIME_RESOLUTION*(time-onset_time))
                elif debug:
                    print('WARNING: ignoring bad offset')
        elif message.type in ignored_types:
            pass
        elif debug:
            print('UNHANDLED MESSAGE', message.type, message)

    unclosed_count = sum(len(pending) for pending in open_notes.values())
    if debug and unclosed_count > 0:
        print(f'WARNING: {unclosed_count} unclosed notes')
        print(' ', midifile)

    return tokens
|
|
|
|
|
|
|
|
|
def compound_to_midi(tokens, debug=False):
    """
    Converts a compound tokenization into a mido.MidiFile.

    The tokens are read five at a time as
    (time, duration, note, instrument, velocity). Every note contributes an
    onset event at ``time`` and an offset event at ``time + duration``. Events
    are replayed in time order, with onsets before offsets at the same tick.

    Each instrument gets its own track and channel the first time one of its
    onsets is seen. Instrument code 128 is always placed on channel 9 with
    program 0. The other instruments take channels 0, 1, 2, ... in order of
    first appearance, skipping channel 9.

    The drum track does not consume a slot from the melodic channel counter,
    so drums plus 15 other instruments fit into channels 0-15. With more than
    15 non-drum instruments the channel number exceeds 15; no fallback is
    attempted here, so the error comes from mido.Message.

    Args:
        tokens: flat iterable of compound tokens; a trailing incomplete
            group of fewer than five values is dropped.
        debug: if true, print a notice for every offset event whose
            instrument has no track yet.

    Returns:
        A mido.MidiFile with one track per instrument.
    """
    mid = mido.MidiFile()
    # at the MIDI default tempo of 2 beats per second, this gives
    # TIME_RESOLUTION ticks per second
    mid.ticks_per_beat = TIME_RESOLUTION // 2

    # bucket events by (tick, kind), where kind 0 = onset and 1 = offset
    it = iter(tokens)
    time_index = defaultdict(list)
    for time_in_ticks, duration, note, instrument, velocity in zip(it, it, it, it, it):
        time_index[(time_in_ticks, 0)].append((note, instrument, velocity))
        time_index[(time_in_ticks+duration, 1)].append((note, instrument, velocity))

    track_idx = {}      # instrument -> (track, time of its last event, channel)
    next_channel = 0    # next free channel for a non-drum instrument
    for time_in_ticks, event_type in sorted(time_index):
        for note, instrument, velocity in time_index[(time_in_ticks, event_type)]:
            is_onset = event_type == 0

            state = track_idx.get(instrument)
            if state is not None:
                track, previous_time, idx = state
            elif not is_onset:
                # an offset for an instrument that never started a note
                if debug:
                    print('IGNORING bad offset')

                continue
            else:
                previous_time = 0
                track = mido.MidiTrack()
                mid.tracks.append(track)
                if instrument == 128:
                    # drums always live on channel 9 and do not use up a
                    # melodic channel
                    idx = 9
                    program = 0
                else:
                    idx = next_channel
                    program = instrument
                    next_channel += 1
                    if next_channel == 9:
                        next_channel += 1  # channel 9 is reserved for drums
                track.append(mido.Message('program_change', channel=idx, program=program))

            # message times are relative to the previous event on the same track
            delta = time_in_ticks - previous_time
            if is_onset:
                track.append(mido.Message(
                    'note_on', note=note, channel=idx, velocity=velocity,
                    time=delta))
            else:
                track.append(mido.Message(
                    'note_off', note=note, channel=idx,
                    time=delta))
            track_idx[instrument] = (track, time_in_ticks, idx)

    return mid
|
|
|
|
|
|
|
|
|
def compound_to_events(tokens, stats=False):
    """
    Converts a compound tokenization to a sequence of events according to
    Definition 2.2 in the anticipation paper, removing velocity and combining
    note and instrument into a single token.

    Each 5-tuple (time, duration, note, instrument, velocity) becomes three
    tokens:

        TIME_OFFSET + time
        DUR_OFFSET  + duration   (-1 becomes TIME_RESOLUTION//4; values are
                                  capped at MAX_DUR-1)
        NOTE_OFFSET + MAX_PITCH*instrument + note
                                 (a note of -1 becomes NOTE_OFFSET + SEPARATOR)

    The input is not modified. Empty input yields an empty result.

    Args:
        tokens: flat sequence of compound tokens whose length is a multiple
            of five.
        stats: if true, also return the number of durations that were
            at least MAX_DUR and therefore capped.

    Returns:
        The list of event tokens, or ``(tokens, truncations)`` when
        ``stats`` is true.

    Raises:
        AssertionError: if the length is not a multiple of five, or a note,
            instrument or time value is out of range.
    """
    assert len(tokens) % 5 == 0, 'compound tokens must come in groups of five'

    events = []
    truncations = 0
    it = iter(tokens)
    for time, dur, note, instr, _velocity in zip(it, it, it, it, it):
        # sanity checks on the compound fields
        assert -1 <= note < 2**7, f'note out of range: {note}'
        assert -1 <= instr < 129, f'instrument out of range: {instr}'
        assert time >= 0, f'negative time: {time}'

        # combine (note, instrument) into a single note token
        note_token = SEPARATOR if note == -1 else MAX_PITCH*instr + note

        # count, then cap, overlong durations; -1 marks an unknown duration
        if dur >= MAX_DUR:
            truncations += 1
        dur_token = TIME_RESOLUTION//4 if dur == -1 else min(dur, MAX_DUR-1)

        events.append(TIME_OFFSET + time)
        events.append(DUR_OFFSET + dur_token)
        events.append(NOTE_OFFSET + note_token)

    if stats:
        return events, truncations

    return events
|
|
|
|
|
|
|
|
|
def events_to_compound(tokens, debug=False):
    """
    Converts a sequence of event tokens back to a compound tokenization.

    The tokens are first passed through ``unpad``. Tokens at or above
    CONTROL_OFFSET (other than SEPARATOR) are shifted down by CONTROL_OFFSET.
    The sequence is then read as (time, duration, note) triples and the
    per-position vocabulary offsets are removed.

    A triple whose note is SEPARATOR marks a sequence boundary. The times of
    the following events are shifted by the largest ``time + duration`` seen
    before the boundary, so that concatenated sequences play one after the
    other. All SEPARATOR tokens are then dropped.

    Each remaining triple becomes (time, duration, note, instrument, velocity)
    with the velocity fixed at 72. Empty input yields an empty result.

    Args:
        tokens: sequence of event tokens.
        debug: if true, print a line at every sequence boundary.

    Returns:
        A flat list with five entries per note.

    Raises:
        AssertionError: if the separators or remaining tokens do not form
            whole triples, or a decoded value is out of range.
    """
    tokens = unpad(tokens)

    # shift tokens at or above CONTROL_OFFSET down; the separator is left alone
    tokens = [tok - CONTROL_OFFSET if tok >= CONTROL_OFFSET and tok != SEPARATOR else tok
              for tok in tokens]

    # remove the vocabulary offset that belongs to each position in a triple
    tokens[0::3] = [tok - TIME_OFFSET if tok != SEPARATOR else tok for tok in tokens[0::3]]
    tokens[1::3] = [tok - DUR_OFFSET if tok != SEPARATOR else tok for tok in tokens[1::3]]
    tokens[2::3] = [tok - NOTE_OFFSET if tok != SEPARATOR else tok for tok in tokens[2::3]]

    # lay concatenated sequences end to end in time
    offset = 0
    track_max = 0
    for j, (time, dur, note) in enumerate(zip(tokens[0::3], tokens[1::3], tokens[2::3])):
        if note == SEPARATOR:
            offset += track_max
            track_max = 0
            if debug:
                print('Sequence Boundary')
        else:
            # track_max is measured before the offset is applied
            track_max = max(track_max, time+dur)
            tokens[3*j] += offset

    # separators must come in whole triples before they are stripped
    assert len([tok for tok in tokens if tok == SEPARATOR]) % 3 == 0
    tokens = [tok for tok in tokens if tok != SEPARATOR]

    assert len(tokens) % 3 == 0

    out = []
    for time, dur, note in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        # split the combined note token back into instrument and pitch
        instrument, pitch = divmod(note, 2**7)
        out.extend((time, dur, pitch, instrument, 72))

    # all() instead of max() so that an empty result passes the checks
    assert all(dur < MAX_DUR for dur in out[1::5])
    assert all(pitch < MAX_PITCH for pitch in out[2::5])
    assert all(instr < MAX_INSTR for instr in out[3::5])
    assert all(tok >= 0 for tok in out)

    return out
|
|
|
|
|
|
|
|
|
def events_to_midi(tokens, debug=False):
    """
    Converts event tokens to midi by going through the compound tokenization.

    ``debug`` is forwarded to both conversion steps.
    """
    compound = events_to_compound(tokens, debug=debug)
    return compound_to_midi(compound, debug=debug)
|
|
|
|
|
|
def midi_to_events(midifile, debug=False, quantize=True):
    """
    Converts midi to event tokens by going through the compound tokenization.

    ``debug`` and ``quantize`` are forwarded to ``midi_to_compound``.
    """
    compound = midi_to_compound(midifile, debug=debug, quantize=quantize)
    return compound_to_events(compound)
|
|
|
|