# MAI_MidiAI_Music / genprocessor.py
# Uploaded by tlmdesign ("Upload 3 files", commit d9e88e8)
# -*- coding: utf-8 -*-
"""genprocessor.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1kvkhcC2RFcAMNh-jOb6NLFtX_lftKi3I
"""
# Process generated text into a MIDI-compatible file.
import re
from typing import Dict
import mido
class GENProcessor:
    """Turn model-generated text into a dictionary of MIDI tracks and events.

    The text is consumed as a stream of whitespace-separated tokens: event tags
    such as ``<note_on>`` followed by ``key=value`` parameter tokens, with
    tracks delimited by ``<|START_TRACK|>`` and ``<|END_TRACK|>``.
    """

    # Inclusive upper bounds for MIDI channel-message fields (channels are
    # 0-15, data bytes are 7-bit). mido refuses to build messages outside
    # these ranges, so generated values are clamped instead.
    _FIELD_MAX = {
        "channel": 15,
        "note": 127,
        "velocity": 127,
        "control": 127,
        "value": 127,
        "program": 127,
    }
    # A set_tempo meta message stores microseconds per beat in 24 bits.
    _MAX_TEMPO = 0xFFFFFF
    # The Standard MIDI File header stores ticks per beat in 15 bits.
    _MAX_TICKS_PER_BEAT = 0x7FFF
    # Tags starting with these prefixes are context markers, not MIDI events.
    _NON_EVENT_PREFIXES = ("composer_", "position_")

    def __init__(self):
        # Marker tokens that delimit sections of the generated text.
        self.START_TRACK = "<|START_TRACK|>"
        self.END_TRACK = "<|END_TRACK|>"
        self.START_METADATA = "<|START_METADATA|>"
        self.END_METADATA = "<|END_METADATA|>"
        # Expected field order per event type (not referenced elsewhere in
        # this class).
        self.field_order = {
            "metadata": ["type", "ticks_per_beat"],
            "tempo": ["type", "time", "tempo"],
            "time_signature": ["type", "time", "numerator", "denominator"],
            "track_name": ["type", "time", "name"],
            "program_change": ["type", "time", "channel", "program"],
            "control_change": ["type", "time", "channel", "control", "value"],
            "note_on": ["type", "time", "channel", "note", "velocity"]
        }

    def sanitize_event(self, event):
        """Validate and normalise one decoded event, mutating it in place.

        Numeric fields are converted to ``int`` and clamped into valid MIDI
        ranges. A time signature is restricted to 2/4, 3/4 or 4/4.

        Args:
            event: Dict with a ``"type"`` key and parsed parameter values,
                which may still be strings.

        Returns:
            The same dict, normalised. Returns None if it is empty, has no
            ``type``, lacks required fields, or holds a non-integer number.
        """
        if not event or "type" not in event:
            return None
        event_type = event["type"]
        required_fields = {
            "note_on": {"time", "channel", "note", "velocity"},
            "note_off": {"time", "channel", "note", "velocity"},
            "control_change": {"time", "channel", "control", "value"},
            "program_change": {"time", "program"},
            "time_signature": {"time", "numerator", "denominator"},
        }
        if event_type in required_fields and required_fields[event_type] - set(event):
            if event_type != "time_signature":
                return None
            # Fill in only what is missing so a valid numerator is not lost.
            event.setdefault("numerator", 4)
            event.setdefault("denominator", 4)
            event.setdefault("time", 0)
        if event_type == "program_change":
            # A channel is needed to build the MIDI message; fall back to 0.
            event.setdefault("channel", 0)
        try:
            if "time" in event:
                event["time"] = max(0, int(event["time"]))
            for field, upper in self._FIELD_MAX.items():
                if field in event:
                    event[field] = min(upper, max(0, int(event[field])))
            if "tempo" in event:
                event["tempo"] = min(self._MAX_TEMPO, max(1, int(event["tempo"])))
            if "numerator" in event:
                event["numerator"] = min(4, max(2, int(event["numerator"])))
            if "denominator" in event:
                event["denominator"] = 4
        except (ValueError, TypeError):
            return None
        return event

    def parse_event_params(self, text: str) -> Dict:
        """Parse ``key=value`` tokens from whitespace-separated *text*.

        Tokens without ``=`` are ignored. Only the first ``=`` splits, and
        values are returned as strings.
        """
        params = {}
        for token in text.split():
            key, sep, value = token.partition("=")
            if sep:
                params[key.strip()] = value.strip()
        return params

    def decode_midi_file(self, text: str) -> Dict:
        """Decode generated text into a MIDI dictionary.

        Returns ``{"metadata": {"ticks_per_beat": int}, "tracks": [...]}``.

        ``tracks[0]`` always holds a tempo event and a 4/4 time-signature
        event. They default to 500000 and 480 ticks per beat unless the
        text overrides them.

        Each non-empty ``<|START_TRACK|>`` ... ``<|END_TRACK|>`` section
        follows as its own list of sanitised event dicts with absolute
        ``time`` values. Events after the last section form a final track.
        """
        # Template with defaults in case data is missing.
        midi_data = {
            "metadata": {"ticks_per_beat": 480},
            "tracks": [
                [
                    {"type": "tempo", "time": 0, "tempo": 500000},
                    {"type": "time_signature", "time": 0,
                     "numerator": 4, "denominator": 4},
                ]
            ],
        }
        metadata = self._extract_metadata(text)
        if "ticks_per_beat" in metadata:
            midi_data["metadata"]["ticks_per_beat"] = metadata["ticks_per_beat"]
        if "tempo" in metadata:
            midi_data["tracks"][0][0]["tempo"] = metadata["tempo"]
        midi_data["tracks"].extend(self._parse_tracks(text))
        return midi_data

    def _extract_metadata(self, text: str) -> Dict:
        """Scan tokens for a ticks-per-beat value and an initial tempo."""
        values = {}
        for token in text.split():
            if "ticks_per_beat" in token or "ticks_beat" in token:
                match = re.search(r"ticks[_]?(?:per_)?beat=(\d+)", token)
                if match:
                    values["ticks_per_beat"] = min(
                        self._MAX_TICKS_PER_BEAT, max(75, int(match.group(1))))
            # NOTE(review): tokens are whitespace-separated, so this branch only
            # fires when "tempo" and "time=0" share one token — confirm format.
            elif "tempo" in token and "time=0" in token:
                match = re.search(r"tempo=(\d+)", token)
                if match:
                    values["tempo"] = min(self._MAX_TEMPO, max(1, int(match.group(1))))
        return values

    def _is_event_tag(self, tag: str) -> bool:
        """Return True if *tag* (the text inside ``<...>``) names a MIDI event."""
        return tag != "START_METADATA" and not tag.startswith(self._NON_EVENT_PREFIXES)

    def _flush_event(self, track, event, params):
        """Sanitise a pending event and append it to *track* if it is valid."""
        if event and params:
            sanitized = self.sanitize_event({**event, **params})
            if sanitized:
                track.append(sanitized)

    def _parse_tracks(self, text: str) -> list:
        """Parse the event stream into a list of non-empty tracks."""
        tracks = []
        current_track = []
        building_event = None
        collecting_params = {}
        # str.split() already drops surrounding whitespace and empty tokens.
        for token in text.split():
            # Metadata tokens were handled by _extract_metadata.
            if "ticks_per_beat" in token or "ticks_beat" in token or "tempo=0" in token:
                continue
            if self.START_TRACK in token:
                if not tracks:
                    # Discard anything collected before the first track.
                    current_track = []
                    building_event, collecting_params = None, {}
                continue
            if self.END_TRACK in token:
                # Flush the pending event so a track's last event is kept.
                self._flush_event(current_track, building_event, collecting_params)
                if current_track:
                    tracks.append(current_track)
                current_track = []
                building_event, collecting_params = None, {}
                continue
            if token.startswith("<") and ">" in token:
                self._flush_event(current_track, building_event, collecting_params)
                tag_match = re.match(r"<(\w+)>", token)
                tag = tag_match.group(1) if tag_match else None
                if tag and self._is_event_tag(tag):
                    building_event = {"type": tag}
                    collecting_params = self.parse_event_params(
                        token[token.find(">") + 1:].strip())
                else:
                    # Non-event marker: stop collecting so later params are not
                    # attached to the previous (already flushed) event.
                    building_event, collecting_params = None, {}
                continue
            if building_event and "=" in token:
                collecting_params.update(self.parse_event_params(token))
        # Flush whatever is pending when the text ends without <|END_TRACK|>.
        self._flush_event(current_track, building_event, collecting_params)
        if current_track:
            tracks.append(current_track)
        return tracks
def _token_to_midi_message(token, delta_time):
    """Build the mido message for one decoded event.

    Args:
        token: Event dict with a ``"type"`` key plus the fields that type uses.
        delta_time: Ticks since the previous event in the same track.

    Returns:
        A ``mido.Message`` / ``mido.MetaMessage``, or None for event types this
        converter does not handle.
    """
    # Parsed parameter values may still be strings; coerce to int because
    # mido rejects non-integer message attributes.
    event_type = token["type"]
    if event_type in ("note_on", "note_off"):
        return mido.Message(event_type,
                            channel=int(token["channel"]),
                            note=int(token["note"]),
                            velocity=int(token["velocity"]),
                            time=delta_time)
    if event_type == "program_change":
        # Decoded program_change events are not guaranteed to carry a channel;
        # default to channel 0 rather than raising KeyError.
        return mido.Message("program_change",
                            channel=int(token.get("channel", 0)),
                            program=int(token["program"]),
                            time=delta_time)
    if event_type == "control_change":
        return mido.Message("control_change",
                            channel=int(token["channel"]),
                            control=int(token["control"]),
                            value=int(token["value"]),
                            time=delta_time)
    if event_type == "tempo":
        return mido.MetaMessage("set_tempo",
                                tempo=int(token["tempo"]),
                                time=delta_time)
    if event_type == "time_signature":
        return mido.MetaMessage("time_signature",
                                numerator=int(token["numerator"]),
                                denominator=int(token["denominator"]),
                                time=delta_time)
    if event_type == "track_name":
        return mido.MetaMessage("track_name",
                                name=token["name"],
                                time=delta_time)
    return None


def generated_tokens_to_midi(tokens, output_path):
    """Convert decoded musical events into a MIDI file and save it.

    Args:
        tokens: Dict shaped like the output of
            ``GENProcessor.decode_midi_file``:
            ``{"metadata": {"ticks_per_beat": int}, "tracks": [[event, ...], ...]}``.
            Each event is a dict with a ``"type"`` and an absolute ``"time"``
            in ticks (a missing time is treated as 0).
        output_path: Destination passed to ``mido.MidiFile.save``.

    Events whose type is not note_on, note_off, program_change,
    control_change, tempo, time_signature or track_name are skipped.
    """
    midi_file = mido.MidiFile(ticks_per_beat=int(tokens["metadata"]["ticks_per_beat"]))
    for track_tokens in tokens["tracks"]:
        track = mido.MidiTrack()
        midi_file.tracks.append(track)
        last_time = 0
        # Events carry absolute times but MIDI stores deltas, so walk them in
        # time order (sorted() is stable: simultaneous events keep their order).
        for token in sorted(track_tokens, key=lambda t: int(t.get("time", 0))):
            abs_time = int(token.get("time", 0))
            delta_time = abs_time - last_time
            last_time = abs_time
            msg = _token_to_midi_message(token, delta_time)
            if msg is not None:
                track.append(msg)
    midi_file.save(output_path)