|
|
|
|
|
"""genprocessor.ipynb |
|
|
|
|
|
Automatically generated by Colab. |
|
|
|
|
|
Original file is located at |
|
|
https://colab.research.google.com/drive/1kvkhcC2RFcAMNh-jOb6NLFtX_lftKi3I |
|
|
""" |
|
|
|
|
|
|
|
|
import re |
|
|
from typing import Dict |
|
|
import mido |
|
|
|
|
|
class GENProcessor:
    """Decode the generator's text token format into a MIDI-like dict.

    ``decode_midi_file`` returns ``{"metadata": {"ticks_per_beat": int},
    "tracks": [[event, ...], ...]}``. Track 0 is always a header track with a
    tempo event and a 4/4 time-signature event. Every further track holds the
    sanitized events found between track markers.
    """

    # Inclusive (low, high) bounds for integer event fields. Out-of-range
    # values are clamped rather than rejected. The upper limits are the MIDI
    # channel / data-byte / set_tempo ranges, so model output such as
    # note=200 no longer makes mido raise when messages are built later.
    _INT_FIELD_BOUNDS = {
        'channel': (0, 15),
        'note': (0, 127),
        'velocity': (0, 127),
        'control': (0, 127),
        'value': (0, 127),
        'program': (0, 127),
        'tempo': (1, 0xFFFFFF),
    }

    # Accepted ticks_per_beat range. 75 is the original heuristic floor.
    # 32767 is the largest metrical division a MIDI header can encode
    # (bit 15 set would mean SMPTE timing).
    _TICKS_PER_BEAT_BOUNDS = (75, 32767)

    def __init__(self):
        # Track / metadata delimiter tokens used in the text format.
        self.START_TRACK = "<|START_TRACK|>"
        self.END_TRACK = "<|END_TRACK|>"
        self.START_METADATA = "<|START_METADATA|>"
        self.END_METADATA = "<|END_METADATA|>"

        # Canonical field order per event type (not used inside this class).
        self.field_order = {
            "metadata": ["type", "ticks_per_beat"],
            "tempo": ["type", "time", "tempo"],
            "time_signature": ["type", "time", "numerator", "denominator"],
            "track_name": ["type", "time", "name"],
            "program_change": ["type", "time", "channel", "program"],
            "control_change": ["type", "time", "channel", "control", "value"],
            "note_on": ["type", "time", "channel", "note", "velocity"]
        }

    def sanitize_event(self, event):
        """Validate and normalize one parsed event, mutating it in place.

        Args:
            event: dict with a ``'type'`` key and parameter values that are
                strings (from :meth:`parse_event_params`) or ints.

        Returns:
            The same dict, or ``None``. In the returned dict, ``time`` is a
            non-negative int, the fields in ``_INT_FIELD_BOUNDS`` are ints
            clamped into range, ``numerator`` is clamped to 2..4 and
            ``denominator`` is forced to 4. ``None`` is returned when the
            event has no type, when a required field for its type is missing,
            or when a numeric field is not an integer. A ``time_signature``
            with missing fields is the exception: it falls back to 4/4.
        """
        if not event or 'type' not in event:
            return None

        event_type = event['type']
        required_fields = {
            'note_on': {'time', 'channel', 'note', 'velocity'},
            'note_off': {'time', 'channel', 'note', 'velocity'},
            'control_change': {'time', 'channel', 'control', 'value'},
            'program_change': {'time', 'program'},
            'time_signature': {'time', 'numerator', 'denominator'},
            # Without these entries an event like "<tempo> time=0" passed
            # through with no payload and crashed the MIDI writer.
            'tempo': {'time', 'tempo'},
            'track_name': {'time', 'name'},
        }

        missing_fields = required_fields.get(event_type, set()) - set(event)
        if missing_fields:
            if event_type != 'time_signature':
                return None
            # Fall back to common time rather than dropping the event.
            event['numerator'] = 4
            event['denominator'] = 4
            event['time'] = event.get('time', 0)

        try:
            if 'time' in event:
                event['time'] = max(0, int(event['time']))
            for field, (low, high) in self._INT_FIELD_BOUNDS.items():
                if field in event:
                    event[field] = min(high, max(low, int(event[field])))
            if 'numerator' in event:
                event['numerator'] = min(4, max(2, int(event['numerator'])))
            if 'denominator' in event:
                # Only x/4 meters are produced; the denominator is forced.
                event['denominator'] = 4
        except (ValueError, TypeError):
            return None

        return event

    def parse_event_params(self, text: str) -> Dict:
        """Parse whitespace-separated ``key=value`` pairs from *text*.

        Tokens without ``=`` are ignored. Only the first ``=`` splits a token,
        so ``a=b=c`` yields ``{'a': 'b=c'}``. Values are left as strings, and
        a later duplicate key overwrites an earlier one.
        """
        params = {}
        for token in text.split():
            key, sep, value = token.partition('=')
            if sep:
                params[key.strip()] = value.strip()
        return params

    def _scan_ticks_per_beat(self, text):
        """Return the last ``ticks_per_beat`` value in *text* (clamped), or None."""
        low, high = self._TICKS_PER_BEAT_BOUNDS
        ticks = None
        for token in text.split():
            if "ticks_per_beat" in token or "ticks_beat" in token:
                match = re.search(r"ticks[_]?(?:per_)?beat=(\d+)", token)
                if match:
                    ticks = min(high, max(low, int(match.group(1))))
        return ticks

    @staticmethod
    def _last_time_zero_tempo(tracks):
        """Return the tempo of the last parsed tempo event at time 0, or None."""
        tempo = None
        for track in tracks:
            for event in track:
                if event.get("type") == "tempo" and event.get("time") == 0:
                    tempo = event["tempo"]
        return tempo

    def _flush_event(self, building_event, collecting_params, track):
        """Sanitize the pending event and append it to *track* if it is valid."""
        if building_event and collecting_params:
            sanitized = self.sanitize_event({**building_event, **collecting_params})
            if sanitized:
                track.append(sanitized)

    def decode_midi_file(self, text: str) -> Dict:
        """Decode the text representation of a MIDI file into a dict.

        *text* is split on whitespace. Event tags such as ``<note_on>`` are
        followed by ``key=value`` tokens, and ``<|START_TRACK|>`` /
        ``<|END_TRACK|>`` delimit tracks. Each completed event goes through
        :meth:`sanitize_event` and is dropped if invalid.

        Returns:
            ``{"metadata": {"ticks_per_beat": int}, "tracks": [...]}``.
            Track 0 is a default header track (tempo 500000, 4/4). Its tempo
            is replaced by the last parsed tempo event found at time 0. Later
            tracks hold the parsed events.
        """
        midi_data = {
            "metadata": {"ticks_per_beat": 480},
            "tracks": [[
                {"type": "tempo", "time": 0, "tempo": 500000},
                {"type": "time_signature", "time": 0, "numerator": 4, "denominator": 4},
            ]],
        }

        ticks_per_beat = self._scan_ticks_per_beat(text)
        if ticks_per_beat is not None:
            midi_data["metadata"]["ticks_per_beat"] = ticks_per_beat

        current_track = []
        building_event = None
        collecting_params = {}

        for token in text.split():
            # ticks_per_beat was captured above. A literal "tempo=0" is
            # skipped so a zero tempo never becomes an event parameter.
            if "ticks_per_beat" in token or "tempo=0" in token:
                continue

            if self.START_TRACK in token:
                if len(midi_data["tracks"]) == 1:
                    current_track = []
                continue

            if self.END_TRACK in token:
                # Flush the event still being collected. Previously the last
                # event of every track was silently dropped here.
                self._flush_event(building_event, collecting_params, current_track)
                if current_track:
                    midi_data["tracks"].append(current_track)
                current_track = []
                building_event = None
                collecting_params = {}
                continue

            if token.startswith("<") and ">" in token:
                self._flush_event(building_event, collecting_params, current_track)
                # Reset before inspecting the new tag. An unrecognized or
                # excluded tag must not leave the flushed event pending, or it
                # would be flushed (duplicated) again at the next tag.
                building_event = None
                collecting_params = {}
                event_type = re.match(r"<(\w+)>", token)
                if event_type and event_type.group(1) not in ['START_METADATA', 'composer_', 'position_']:
                    building_event = {"type": event_type.group(1)}
                    params_text = token[token.find(">") + 1:].strip()
                    collecting_params = self.parse_event_params(params_text)
                continue

            if building_event and '=' in token:
                collecting_params.update(self.parse_event_params(token))

        # Input may end without <|END_TRACK|>: flush what is still pending.
        self._flush_event(building_event, collecting_params, current_track)
        if current_track:
            midi_data["tracks"].append(current_track)

        # The former line-oriented pre-scan iterated whitespace tokens and so
        # effectively never saw "tempo" and "time=0" together. Derive the
        # header tempo from the parsed events instead.
        header_tempo = self._last_time_zero_tempo(midi_data["tracks"][1:])
        if header_tempo is not None:
            midi_data["tracks"][0][0]["tempo"] = header_tempo

        return midi_data
|
|
|
|
|
def _token_to_message(token, delta_time):
    """Build the mido message for *token* with the given delta time.

    Returns None for event types this writer does not support.
    """
    event_type = token.get("type")
    if event_type in ("note_on", "note_off"):
        return mido.Message(event_type,
                            channel=token["channel"],
                            note=token["note"],
                            velocity=token["velocity"],
                            time=delta_time)
    if event_type == "program_change":
        # The decoder does not require a channel for program_change events,
        # so fall back to channel 0 instead of raising KeyError.
        return mido.Message('program_change',
                            channel=token.get("channel", 0),
                            program=token["program"],
                            time=delta_time)
    if event_type == "control_change":
        return mido.Message('control_change',
                            channel=token["channel"],
                            control=token["control"],
                            value=token["value"],
                            time=delta_time)
    if event_type == "tempo":
        return mido.MetaMessage('set_tempo',
                                tempo=token["tempo"],
                                time=delta_time)
    if event_type == "time_signature":
        return mido.MetaMessage('time_signature',
                                numerator=token["numerator"],
                                denominator=token["denominator"],
                                time=delta_time)
    if event_type == "track_name":
        return mido.MetaMessage('track_name',
                                name=token["name"],
                                time=delta_time)
    return None


def generated_tokens_to_midi(tokens, output_path):
    """Convert decoded musical events into a MIDI file saved at *output_path*.

    Args:
        tokens: dict with ``tokens["metadata"]["ticks_per_beat"]`` and
            ``tokens["tracks"]``, a list of event-dict lists. Each event's
            ``"time"`` is treated as an absolute tick position (missing
            ``time`` counts as 0).
        output_path: destination passed to ``mido.MidiFile.save``.

    Events of unsupported types are skipped. Each track's events are
    stable-sorted by time, so simultaneous events keep their input order.
    """
    midi_file = mido.MidiFile(ticks_per_beat=tokens["metadata"]["ticks_per_beat"])

    for track_tokens in tokens["tracks"]:
        track = mido.MidiTrack()
        midi_file.tracks.append(track)

        # Absolute time of the last message actually written to the track.
        # Skipped events must not advance it, or their time would be lost
        # from the next emitted message's delta.
        last_time = 0
        for token in sorted(track_tokens, key=lambda event: event.get("time", 0)):
            event_time = token.get("time", 0)
            msg = _token_to_message(token, int(event_time - last_time))
            if msg is None:
                continue
            track.append(msg)
            last_time = event_time

    midi_file.save(output_path)