Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| """genprocessor.ipynb | |
| Automatically generated by Colab. | |
| Original file is located at | |
| https://colab.research.google.com/drive/1kvkhcC2RFcAMNh-jOb6NLFtX_lftKi3I | |
| """ | |
| # Process generated text into a MIDI-compatible file | |
| import re | |
| from typing import Dict | |
| import mido | |
class GENProcessor:
    """Turn generated token text into a dictionary of MIDI events.

    The text is read as a whitespace-separated token stream: a tag such as
    ``<note_on>`` opens an event, the ``key=value`` tokens that follow supply
    its parameters, and ``<|START_TRACK|>`` / ``<|END_TRACK|>`` delimit tracks.
    """

    # Tag names (or tag-name prefixes) that never open a MIDI event.
    _SKIPPED_TAG_PREFIXES = ("START_METADATA", "composer_", "position_")

    # Fields an event must carry to be usable. Events missing any of them are
    # dropped, except time_signature, which falls back to 4/4.
    _REQUIRED_FIELDS = {
        'note_on': {'time', 'channel', 'note', 'velocity'},
        'note_off': {'time', 'channel', 'note', 'velocity'},
        'control_change': {'time', 'channel', 'control', 'value'},
        'program_change': {'time', 'program'},
        'time_signature': {'time', 'numerator', 'denominator'},
        'tempo': {'tempo'},
    }

    # Inclusive (low, high) bounds for numeric fields; None means unbounded.
    # The upper bounds are the ranges mido accepts for channels, 7-bit data
    # bytes and set_tempo, so sanitized events can always be written.
    _FIELD_LIMITS = {
        'time': (0, None),
        'channel': (0, 15),
        'note': (0, 127),
        'velocity': (0, 127),
        'control': (0, 127),
        'value': (0, 127),
        'program': (0, 127),
        'tempo': (0, 0xFFFFFF),
    }

    def __init__(self):
        self.START_TRACK = "<|START_TRACK|>"
        self.END_TRACK = "<|END_TRACK|>"
        self.START_METADATA = "<|START_METADATA|>"
        self.END_METADATA = "<|END_METADATA|>"
        self.field_order = {
            "metadata": ["type", "ticks_per_beat"],
            "tempo": ["type", "time", "tempo"],
            "time_signature": ["type", "time", "numerator", "denominator"],
            "track_name": ["type", "time", "name"],
            "program_change": ["type", "time", "channel", "program"],
            "control_change": ["type", "time", "channel", "control", "value"],
            "note_on": ["type", "time", "channel", "note", "velocity"]
        }

    def sanitize_event(self, event):
        """Validate and normalise one parsed event in place.

        Numeric fields are converted to ``int`` and clamped into the ranges
        in ``_FIELD_LIMITS``. The time-signature numerator is clamped to 2..4
        and the denominator is forced to 4. ``program_change`` events get a
        default channel of 0, and every event gets a default ``time`` of 0.

        Returns:
            The same (mutated) dict, or ``None`` if the event has no type,
            lacks a required field, or holds a non-numeric value in a numeric
            field.
        """
        if not event or 'type' not in event:
            return None
        event_type = event['type']

        missing_fields = self._REQUIRED_FIELDS.get(event_type, set()) - set(event)
        if missing_fields:
            if event_type != 'time_signature':
                return None
            # Fall back to common time rather than dropping the signature.
            event['numerator'] = 4
            event['denominator'] = 4
            event['time'] = event.get('time', 0)

        # Fields the MIDI writer needs but the text may omit.
        if event_type == 'program_change':
            event.setdefault('channel', 0)
        event.setdefault('time', 0)

        try:
            for field, (low, high) in self._FIELD_LIMITS.items():
                if field in event:
                    number = max(low, int(event[field]))
                    event[field] = number if high is None else min(high, number)
            if 'numerator' in event:
                event['numerator'] = min(4, max(2, int(event['numerator'])))
            if 'denominator' in event:
                event['denominator'] = 4
        except (ValueError, TypeError):
            return None
        return event

    def parse_event_params(self, text: str) -> Dict:
        """Parse whitespace-separated ``key=value`` tokens into a dict.

        Tokens without ``=`` are ignored, and only the first ``=`` splits,
        e.g. ``"a=1 b=x=y c"`` -> ``{"a": "1", "b": "x=y"}``. Values stay
        strings.
        """
        params = {}
        for chunk in text.split():
            key, sep, value = chunk.partition('=')
            if sep:
                params[key.strip()] = value.strip()
        return params

    def _flush_event(self, building_event, params, track):
        """Sanitize the pending event and append it to *track* if usable."""
        if building_event and params:
            sanitized = self.sanitize_event({**building_event, **params})
            if sanitized:
                track.append(sanitized)

    def decode_midi_file(self, text: str) -> Dict:
        """Decode the text representation of a MIDI file into a dictionary.

        Returns:
            ``{"metadata": {"ticks_per_beat": int}, "tracks": [...]}``.
            Track 0 always holds a default tempo (500000) and a 4/4 time
            signature. Each decoded non-empty track is appended after it as a
            list of sanitized event dicts.
        """
        midi_data = {
            "metadata": {
                "ticks_per_beat": 480
            },
            "tracks": [
                [  # First track always contains tempo and time signature defaults
                    {
                        "type": "tempo",
                        "time": 0,
                        "tempo": 500000
                    },
                    {
                        "type": "time_signature",
                        "time": 0,
                        "numerator": 4,
                        "denominator": 4
                    }
                ]
            ]
        }
        # str.split() yields whitespace-separated tokens, not lines.
        tokens = text.split()

        # First pass: metadata. The last occurrence wins.
        for token in tokens:
            if "ticks_per_beat" in token or "ticks_beat" in token:
                match = re.search(r"ticks[_]?(?:per_)?beat=(\d+)", token)
                if match:
                    midi_data["metadata"]["ticks_per_beat"] = max(75, int(match.group(1)))
            elif "tempo" in token and "time=0" in token:
                # NOTE(review): because tokens are whitespace-split, this only
                # fires when "tempo=" and "time=0" share a single token.
                match = re.search(r"tempo=(\d+)", token)
                if match:
                    midi_data["tracks"][0][0]["tempo"] = int(match.group(1))

        # Second pass: events.
        current_track = []
        building_event = None
        collecting_params = {}
        for token in tokens:
            # Skip metadata tokens already consumed above.
            if "ticks_per_beat" in token or "tempo=0" in token:
                continue
            if self.START_TRACK in token:
                if len(midi_data["tracks"]) == 1:  # opening the first decoded track
                    current_track = []
                continue
            if self.END_TRACK in token:
                # Flush the event still being built; without this the last
                # event of every track would be lost.
                self._flush_event(building_event, collecting_params, current_track)
                if current_track:
                    midi_data["tracks"].append(current_track)
                current_track = []
                building_event = None
                collecting_params = {}
                continue
            if token.startswith("<") and ">" in token:
                # Any tag ends the pending event. Reset state so a tag that does
                # not open a new event cannot cause the old one to be emitted twice.
                self._flush_event(building_event, collecting_params, current_track)
                building_event = None
                collecting_params = {}
                match = re.match(r"<(\w+)>", token)
                if match and not match.group(1).startswith(self._SKIPPED_TAG_PREFIXES):
                    building_event = {"type": match.group(1)}
                    collecting_params = self.parse_event_params(
                        token[token.find(">") + 1:].strip())
                continue
            # Additional parameters for the event being built.
            if building_event and '=' in token:
                collecting_params.update(self.parse_event_params(token))

        # The text may end without a closing END_TRACK tag.
        self._flush_event(building_event, collecting_params, current_track)
        if current_track:
            midi_data["tracks"].append(current_track)
        return midi_data
def generated_tokens_to_midi(tokens, output_path):
    """Write decoded musical events to a MIDI file.

    Args:
        tokens: Mapping with ``tokens["metadata"]["ticks_per_beat"]`` and
            ``tokens["tracks"]``, a list of tracks where each track is a list
            of event dicts (the shape returned by
            ``GENProcessor.decode_midi_file``).
        output_path: Destination handed to ``mido.MidiFile.save``.

    Each event's ``time`` is treated as an absolute tick position. Events are
    sorted by it and converted to the delta times that mido messages carry.
    Events whose ``type`` has no MIDI mapping below are skipped.
    """
    midi_file = mido.MidiFile(ticks_per_beat=tokens["metadata"]["ticks_per_beat"])

    for track_tokens in tokens["tracks"]:
        track = mido.MidiTrack()
        midi_file.tracks.append(track)
        last_time = 0

        # A missing "time" sorts as 0 instead of raising KeyError. int() guards
        # against numeric strings that were never sanitized.
        for token in sorted(track_tokens, key=lambda ev: int(ev.get("time", 0))):
            event_time = int(token.get("time", 0))
            delta_time = event_time - last_time
            event_type = token.get("type")

            if event_type in ("note_on", "note_off"):
                msg = mido.Message(event_type,
                                   channel=int(token["channel"]),
                                   note=int(token["note"]),
                                   velocity=int(token["velocity"]),
                                   time=delta_time)
            elif event_type == "program_change":
                # GENProcessor.sanitize_event does not require a channel for
                # program_change, so default to channel 0.
                msg = mido.Message('program_change',
                                   channel=int(token.get("channel", 0)),
                                   program=int(token["program"]),
                                   time=delta_time)
            elif event_type == "control_change":
                msg = mido.Message('control_change',
                                   channel=int(token["channel"]),
                                   control=int(token["control"]),
                                   value=int(token["value"]),
                                   time=delta_time)
            elif event_type == "tempo":
                # Parsed tempo values may still be strings; mido needs an int.
                msg = mido.MetaMessage('set_tempo',
                                       tempo=int(token["tempo"]),
                                       time=delta_time)
            elif event_type == "time_signature":
                msg = mido.MetaMessage('time_signature',
                                       numerator=int(token["numerator"]),
                                       denominator=int(token["denominator"]),
                                       time=delta_time)
            elif event_type == "track_name":
                msg = mido.MetaMessage('track_name',
                                       name=token["name"],
                                       time=delta_time)
            else:
                # Unmapped type: skip it WITHOUT advancing last_time. Otherwise
                # its tick gap would be lost and later events would shift earlier.
                continue

            track.append(msg)
            last_time = event_time

    midi_file.save(output_path)