Spaces:
Sleeping
Sleeping
| """ | |
| Utilities for converting between MIDI data and encoded/tokenized data. | |
| """ | |
| from collections import defaultdict | |
| import mido | |
| from anticipation.config import * | |
| from anticipation.vocab import * | |
| from anticipation.ops import unpad | |
def midi_to_interarrival(midifile, debug=False, stats=False):
    """Tokenize a MIDI file as a stream of time-step and note events.

    The file is walked message by message. Whenever a note_on/note_off is
    reached, the time accumulated since the previous note event is emitted
    as a time-step token (if it rounds to at least one tick), followed by a
    note-start or note-end token that encodes instrument and pitch.

    Args:
        midifile: passed straight to ``mido.MidiFile``.
        debug: print message types this function does not recognize.
        stats: also return how many time gaps were clipped.

    Returns:
        The token list, or ``(tokens, truncations)`` when ``stats`` is true.

    Raises:
        ValueError: a message reports a negative time.
    """
    # Recognized message types that are intentionally not modeled
    # (tempo/meter, expressive controllers, metadata, misc. system messages).
    not_modeled = {
        'set_tempo', 'time_signature',
        'aftertouch', 'polytouch', 'pitchwheel', 'sequencer_specific',
        'control_change',
        'track_name', 'text', 'end_of_track', 'lyrics', 'key_signature',
        'copyright', 'marker', 'instrument_name', 'cue_marker',
        'device_name', 'sequence_number',
        'channel_prefix',
        'midi_port', 'smpte_offset', 'sysex',
    }

    midi = mido.MidiFile(midifile)

    tokens = []
    truncations = 0
    elapsed = 0  # time since the last emitted note event
    programs = defaultdict(int)  # channel -> program; code 0 = piano by default
    for message in midi:
        elapsed += message.time

        # sanity check: negative time?
        if message.time < 0:
            raise ValueError

        kind = message.type
        if kind == 'program_change':
            programs[message.channel] = message.program
            continue

        if kind not in ('note_on', 'note_off'):
            if debug and kind not in not_modeled:
                print('UNHANDLED MESSAGE', kind, message)
            continue

        # Quantize the gap and clip it to the largest representable step.
        raw_ticks = round(TIME_RESOLUTION*elapsed)
        delta_ticks = min(raw_ticks, MAX_INTERARRIVAL-1)
        if delta_ticks != raw_ticks:
            truncations += 1

        if delta_ticks > 0:
            tokens.append(MIDI_TIME_OFFSET + delta_ticks)

        # Channel 9 is percussion and is given the pseudo-instrument 128.
        inst = 128 if message.channel == 9 else programs[message.channel]

        # A note_on with velocity 0 is treated as a note end.
        is_onset = kind == 'note_on' and message.velocity > 0
        base = MIDI_START_OFFSET if is_onset else MIDI_END_OFFSET
        tokens.append(base + (2**7)*inst + message.note)
        elapsed = 0

    return (tokens, truncations) if stats else tokens
def interarrival_to_midi(tokens, debug=False):
    """Rebuild a ``mido.MidiFile`` from an interarrival token stream.

    Time-step tokens advance a global clock. Note-start and note-end tokens
    are decoded into (instrument, pitch) and written to one track per
    instrument; a track is created the first time its instrument starts a
    note. Note ends for an instrument that never started a note are skipped.

    Args:
        tokens: iterable of interarrival tokens.
        debug: report skipped note ends.

    Returns:
        The assembled ``mido.MidiFile``.
    """
    midi_out = mido.MidiFile()
    midi_out.ticks_per_beat = TIME_RESOLUTION // 2  # 2 beats/second at quarter=120

    # instrument -> (track, time of the last event on that track, channel)
    state = {}
    now = 0
    next_channel = 0
    for token in tokens:
        if token == MIDI_SEPARATOR:
            continue

        if token < MIDI_START_OFFSET:
            now += token - MIDI_TIME_OFFSET
            continue

        is_onset = token < MIDI_END_OFFSET
        base = MIDI_START_OFFSET if is_onset else MIDI_END_OFFSET
        instrument, pitch = divmod(token - base, 2**7)

        if instrument in state:
            track, last_time, channel = state[instrument]
        elif is_onset:
            # First note of this instrument: open a track and pick a channel.
            channel = next_channel
            last_time = 0
            track = mido.MidiTrack()
            midi_out.tracks.append(track)
            if instrument == 128:  # drums always go on channel 9
                channel = 9
                program = 0
            else:
                program = instrument
            track.append(mido.Message('program_change', channel=channel, program=program))
            next_channel += 1
            if next_channel == 9:
                next_channel += 1  # skip the drums track
        else:
            # shouldn't happen because we should have a corresponding onset
            if debug:
                print('IGNORING bad offset')
            continue

        delta = now - last_time
        if is_onset:
            event = mido.Message('note_on', note=pitch, channel=channel, velocity=96, time=delta)
        else:
            event = mido.Message('note_off', note=pitch, channel=channel, time=delta)
        track.append(event)
        state[instrument] = (track, now, channel)

    return midi_out
def midi_to_compound(midifile, debug=False):
    """Convert MIDI data into a flat list of compound note tokens.

    Every note onset contributes five consecutive integers::

        (time, duration, note, instrument, velocity)

    * ``time`` and ``duration`` are quantized with ``TIME_RESOLUTION``
      (ticks per second, assuming ``message.time`` is in seconds as when
      iterating a ``mido.MidiFile``).
    * ``duration`` stays ``-1`` for a note whose end is never seen.
    * ``instrument`` is the program currently selected on the note's
      channel (0 if none was selected), or 128 for anything on channel 9
      (drums). It is a plain instrument number with no channel bits, which
      is the range ``compound_to_events`` asserts and the value
      ``compound_to_midi`` passes to ``program_change``.

    Args:
        midifile: path of a MIDI file (``str``), or an already-loaded object
            that yields mido-style messages when iterated.
        debug: print warnings about unmatched note ends, unclosed notes and
            unrecognized message types.

    Returns:
        list[int] whose length is five times the number of note onsets.

    Raises:
        ValueError: a message reports a negative time.
    """
    # Recognized message types that are intentionally not modeled
    # (tempo/meter, expressive controllers, metadata, misc. system messages).
    not_modeled = frozenset([
        'set_tempo', 'time_signature',
        'aftertouch', 'polytouch', 'pitchwheel', 'sequencer_specific',
        'control_change',
        'track_name', 'text', 'end_of_track', 'lyrics', 'key_signature',
        'copyright', 'marker', 'instrument_name', 'cue_marker',
        'device_name', 'sequence_number',
        'channel_prefix',
        'midi_port', 'smpte_offset', 'sysex',
    ])

    midi = mido.MidiFile(midifile) if isinstance(midifile, str) else midifile

    tokens = []
    note_idx = 0
    # (channel, note) -> FIFO of (index of the compound token, onset time).
    # Keyed without the program so that a program change between a note's
    # start and end cannot orphan the note end.
    open_notes = defaultdict(list)

    time = 0
    instruments = defaultdict(int)  # channel -> program; default 0 = piano
    for message in midi:
        time += message.time

        # sanity check: negative time?
        if message.time < 0:
            raise ValueError('MIDI message with negative time')

        if message.type == 'program_change':
            # Always track the latest program so mid-file changes are honored.
            instruments[message.channel] = message.program
        elif message.type in ('note_on', 'note_off'):
            # special case: channel 9 is drums!
            instr = 128 if message.channel == 9 else instruments[message.channel]
            key = (message.channel, message.note)

            if message.type == 'note_on' and message.velocity > 0:  # onset
                time_in_ticks = round(TIME_RESOLUTION*time)  # time quantization
                tokens.extend((
                    time_in_ticks,
                    -1,  # duration placeholder, filled in at the note end
                    message.note,
                    instr,
                    message.velocity,
                ))
                open_notes[key].append((note_idx, time))
                note_idx += 1
            else:  # offset (note_off, or note_on with velocity 0)
                try:
                    open_idx, onset_time = open_notes[key].pop(0)
                except IndexError:
                    if debug:
                        print('WARNING: ignoring bad offset')
                else:
                    duration_ticks = round(TIME_RESOLUTION*(time-onset_time))
                    tokens[5*open_idx + 1] = duration_ticks
        elif debug and message.type not in not_modeled:
            print('UNHANDLED MESSAGE', message.type, message)

    unclosed_count = sum(len(pending) for pending in open_notes.values())
    if debug and unclosed_count > 0:
        print(f'WARNING: {unclosed_count} unclosed notes')
        print(' ', midifile)

    return tokens
def compound_to_midi(tokens, debug=False):
    """Render compound note tokens as a ``mido.MidiFile``.

    ``tokens`` is read five values at a time as
    ``(time, duration, note, instrument, velocity)``. Each note becomes a
    note_on at ``time`` and a note_off at ``time + duration``, written to a
    track dedicated to its instrument. Instrument 128 is drums and is placed
    on channel 9 with program 0; any other instrument value is used directly
    as the MIDI program number.

    Args:
        tokens: flat sequence of compound tokens (a trailing partial group
            of fewer than five values is ignored).
        debug: report note ends whose instrument never started a note.

    Returns:
        The assembled ``mido.MidiFile``.
    """
    mid = mido.MidiFile()
    mid.ticks_per_beat = TIME_RESOLUTION // 2  # 2 beats/second at quarter=120

    # Bucket onsets (0) and offsets (1) by absolute time. Consuming one
    # iterator five times per step walks the flat list in groups of five.
    it = iter(tokens)
    time_index = defaultdict(list)
    for time_in_ticks, duration, note, instrument, velocity in zip(it, it, it, it, it):
        time_index[(time_in_ticks, 0)].append((note, instrument, velocity))  # 0 = onset
        time_index[(time_in_ticks+duration, 1)].append((note, instrument, velocity))  # 1 = offset

    # Tracks are created lazily, on an instrument's first onset, so the file
    # never contains empty tracks or stray program changes.
    # NOTE(review): with more than 15 non-drum instruments the allocated
    # channel exceeds 15, which mido is expected to reject - confirm.
    track_idx = {}  # maps instrument to (track, time of last event, channel)
    num_tracks = 0
    for time_in_ticks, event_type in sorted(time_index):
        for note, instrument, velocity in time_index[(time_in_ticks, event_type)]:
            if event_type == 0:  # onset
                try:
                    track, previous_time, idx = track_idx[instrument]
                except KeyError:
                    idx = num_tracks
                    previous_time = 0
                    track = mido.MidiTrack()
                    mid.tracks.append(track)
                    if instrument == 128:  # drums always go on channel 9
                        idx = 9
                        message = mido.Message('program_change', channel=idx, program=0)
                    else:
                        message = mido.Message('program_change', channel=idx, program=instrument)
                    track.append(message)
                    num_tracks += 1
                    if num_tracks == 9:
                        num_tracks += 1  # skip the drums track

                track.append(mido.Message(
                    'note_on', note=note, channel=idx, velocity=velocity,
                    time=time_in_ticks-previous_time))
                track_idx[instrument] = (track, time_in_ticks, idx)
            else:  # offset
                try:
                    track, previous_time, idx = track_idx[instrument]
                except KeyError:
                    # shouldn't happen because we should have a corresponding onset
                    if debug:
                        print('IGNORING bad offset')
                    continue

                track.append(mido.Message(
                    'note_off', note=note, channel=idx,
                    time=time_in_ticks-previous_time))
                track_idx[instrument] = (track, time_in_ticks, idx)

    return mid
def compound_to_events(tokens, stats=False):
    """Convert compound note tokens into (time, duration, note) event tokens.

    The input is read as groups of five
    ``(time, duration, note, instrument, velocity)``. Velocity is dropped,
    note and instrument are merged into one token, and each field is shifted
    by its vocabulary offset. The input is not modified.

    * A note value of -1 maps to ``SEPARATOR`` (before ``NOTE_OFFSET`` is
      added); otherwise the merged value is ``MAX_PITCH*instrument + note``.
    * A duration of -1 (unknown) becomes ``TIME_RESOLUTION//4`` ticks, i.e.
      a quarter of a second; durations of ``MAX_DUR`` or more are clipped to
      ``MAX_DUR-1`` and counted as truncations.

    Args:
        tokens: flat sequence whose length is a multiple of five.
        stats: also return the number of clipped durations.

    Returns:
        The event token list (three tokens per note), or
        ``(events, truncations)`` when ``stats`` is true. An empty input
        yields an empty list.

    Raises:
        AssertionError: the length is not a multiple of five, or a time,
            note or instrument value is out of range.
    """
    assert len(tokens) % 5 == 0

    times = tokens[0::5]
    durations = tokens[1::5]
    notes = tokens[2::5]
    instruments = tokens[3::5]
    # tokens[4::5] holds velocities, which the event vocabulary does not model

    assert all(-1 <= note < 2**7 for note in notes)
    assert all(-1 <= instr < 129 for instr in instruments)
    # all() rather than min() so that an empty token list is accepted
    assert all(time >= 0 for time in times)

    # combine (note, instrument)
    note_tokens = [NOTE_OFFSET + (SEPARATOR if note == -1 else MAX_PITCH*instr + note)
                   for note, instr in zip(notes, instruments)]

    # max duration cutoff and set unknown durations to 250ms
    truncations = sum(1 for dur in durations if dur >= MAX_DUR)
    duration_tokens = [DUR_OFFSET + (TIME_RESOLUTION//4 if dur == -1 else min(dur, MAX_DUR-1))
                       for dur in durations]

    time_tokens = [TIME_OFFSET + time for time in times]

    events = [tok
              for triple in zip(time_tokens, duration_tokens, note_tokens)
              for tok in triple]

    if stats:
        return events, truncations

    return events
def events_to_compound(tokens, debug=False):
    """Convert (time, duration, note) event tokens into compound note tokens.

    Steps, in order:

    1. ``unpad`` the sequence (helper from ``anticipation.ops``; presumably
       strips padding - confirm against its definition).
    2. Shift tokens at or above ``CONTROL_OFFSET`` down by ``CONTROL_OFFSET``,
       leaving ``SEPARATOR`` untouched.
    3. Remove the per-field vocabulary offsets.
    4. Treat each ``SEPARATOR`` triple as a sequence boundary: events after
       it are shifted in time by the latest end time of the preceding
       sequence, so consecutive sequences play back to back.
    5. Drop the separators and expand each event into
       ``(time, duration, note, instrument, velocity)`` with a fixed
       velocity of 72.

    Args:
        tokens: event token sequence.
        debug: print a line at every sequence boundary.

    Returns:
        Flat list of compound tokens, five per event. A sequence without
        events yields an empty list.

    Raises:
        AssertionError: separators or events do not come in groups of
            three, or a decoded value is out of range.
    """
    tokens = unpad(tokens)

    # move all tokens to zero-offset for synthesis
    tokens = [tok - CONTROL_OFFSET if tok >= CONTROL_OFFSET and tok != SEPARATOR else tok
              for tok in tokens]

    # remove type offsets
    tokens[0::3] = [tok - TIME_OFFSET if tok != SEPARATOR else tok for tok in tokens[0::3]]
    tokens[1::3] = [tok - DUR_OFFSET if tok != SEPARATOR else tok for tok in tokens[1::3]]
    tokens[2::3] = [tok - NOTE_OFFSET if tok != SEPARATOR else tok for tok in tokens[2::3]]

    offset = 0     # add max time from previous track for synthesis
    track_max = 0  # keep track of max time in track
    for j, (time, dur, note) in enumerate(zip(tokens[0::3], tokens[1::3], tokens[2::3])):
        if note == SEPARATOR:
            offset += track_max
            track_max = 0
            if debug:
                print('Sequence Boundary')
        else:
            # track_max uses the un-shifted time: it is relative to the
            # start of the current sequence
            track_max = max(track_max, time+dur)
            tokens[3*j] += offset

    # strip sequence separators
    assert tokens.count(SEPARATOR) % 3 == 0
    tokens = [tok for tok in tokens if tok != SEPARATOR]

    assert len(tokens) % 3 == 0
    out = []
    for time, dur, note in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        instr, pitch = divmod(note, 2**7)
        out.extend((time, dur, pitch, instr, 72))  # 72 = default velocity

    # all() rather than max() so that a sequence without events is accepted
    assert all(dur < MAX_DUR for dur in out[1::5])
    assert all(pitch < MAX_PITCH for pitch in out[2::5])
    assert all(instr < MAX_INSTR for instr in out[3::5])
    assert all(tok >= 0 for tok in out)

    return out
def events_to_midi(tokens, debug=False):
    """Turn event tokens into a MIDI file object.

    Runs ``events_to_compound`` and feeds its result to
    ``compound_to_midi``, forwarding ``debug`` to both.
    """
    compound = events_to_compound(tokens, debug=debug)
    return compound_to_midi(compound, debug=debug)
def midi_to_events(midifile, debug=False):
    """Turn MIDI data into event tokens.

    Runs ``midi_to_compound`` (forwarding ``debug``) and feeds its result
    to ``compound_to_events`` with default options.
    """
    compound = midi_to_compound(midifile, debug=debug)
    return compound_to_events(compound)