Upload 3 files
- genprocessor.py +253 -0
- midimusicgenapp.py +147 -0
- miditokenizer.py +66 -0
genprocessor.py
ADDED
@@ -0,0 +1,253 @@
# -*- coding: utf-8 -*-
"""genprocessor.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1kvkhcC2RFcAMNh-jOb6NLFtX_lftKi3I
"""

# Process generated text into a MIDI-compatible file
import re
from typing import Dict
import mido

class GENProcessor:
    def __init__(self):
        self.START_TRACK = "<|START_TRACK|>"
        self.END_TRACK = "<|END_TRACK|>"
        self.START_METADATA = "<|START_METADATA|>"
        self.END_METADATA = "<|END_METADATA|>"

        self.field_order = {
            "metadata": ["type", "ticks_per_beat"],
            "tempo": ["type", "time", "tempo"],
            "time_signature": ["type", "time", "numerator", "denominator"],
            "track_name": ["type", "time", "name"],
            "program_change": ["type", "time", "channel", "program"],
            "control_change": ["type", "time", "channel", "control", "value"],
            "note_on": ["type", "time", "channel", "note", "velocity"]
        }

    def sanitize_event(self, event):
        """Make sure events have all required fields and valid values."""
        if not event or 'type' not in event:
            return None

        event_type = event['type']
        required_fields = {
            'note_on': {'time', 'channel', 'note', 'velocity'},
            'note_off': {'time', 'channel', 'note', 'velocity'},
            'control_change': {'time', 'channel', 'control', 'value'},
            'program_change': {'time', 'program'},
            'time_signature': {'time', 'numerator', 'denominator'}
        }

        if event_type in required_fields:
            missing_fields = required_fields[event_type] - set(event.keys())

            if missing_fields:
                if event_type == 'time_signature':
                    event['numerator'] = 4    # set default
                    event['denominator'] = 4  # set default
                    event['time'] = event.get('time', 0)
                else:
                    return None

        try:
            # Validate fields, clamping to the ranges mido accepts
            if 'time' in event:
                event['time'] = max(0, int(event['time']))
            if 'channel' in event:
                event['channel'] = min(15, max(0, int(event['channel'])))  # MIDI channels are 0-15
            if 'note' in event:
                event['note'] = min(127, max(0, int(event['note'])))       # MIDI data bytes are 0-127
            if 'velocity' in event:
                event['velocity'] = min(127, max(0, int(event['velocity'])))
            if 'control' in event:
                event['control'] = min(127, max(0, int(event['control'])))
            if 'value' in event:
                event['value'] = min(127, max(0, int(event['value'])))
            if 'program' in event:
                event['program'] = min(127, max(0, int(event['program'])))
            if 'numerator' in event:
                numerator = int(event['numerator'])
                event['numerator'] = min(4, max(2, numerator))
            if 'denominator' in event:
                event['denominator'] = 4

        except (ValueError, TypeError):
            return None

        return event

    def parse_event_params(self, text: str) -> Dict:
        """Parse key=value parameters from a line of text."""
        return {p.split('=', 1)[0].strip(): p.split('=', 1)[1].strip()
                for p in text.split() if '=' in p and len(p.split('=', 1)) == 2}

    def decode_midi_file(self, text: str) -> Dict:
        """Decode the text representation of a MIDI file into a dictionary."""
        # Create a template with defaults in case data is missing
        midi_data = {
            "metadata": {
                "ticks_per_beat": 480
            },
            "tracks": [
                [  # First track always contains tempo and time signature; set defaults
                    {
                        "type": "tempo",
                        "time": 0,
                        "tempo": 500000
                    },
                    {
                        "type": "time_signature",
                        "time": 0,
                        "numerator": 4,
                        "denominator": 4
                    }
                ]
            ]
        }

        # Parse the text to get all metadata values
        metadata_values = {}
        for line in text.split():
            if "ticks_per_beat" in line or "ticks_beat" in line:
                match = re.search(r"ticks[_]?(?:per_)?beat=(\d+)", line)
                if match:
                    metadata_values["ticks_per_beat"] = max(75, int(match.group(1)))
            elif "tempo" in line and "time=0" in line:
                match = re.search(r"tempo=(\d+)", line)
                if match:
                    metadata_values["tempo"] = int(match.group(1))

        # Update the template with any found metadata values
        if "ticks_per_beat" in metadata_values:
            midi_data["metadata"]["ticks_per_beat"] = metadata_values["ticks_per_beat"]
        if "tempo" in metadata_values:
            midi_data["tracks"][0][0]["tempo"] = metadata_values["tempo"]

        # Parse the actual events
        current_track = []
        building_event = None
        collecting_params = {}

        for line in text.split():
            line = line.strip()
            if not line:
                continue

            # Skip metadata lines we already processed
            if "ticks_per_beat" in line or "tempo=0" in line:
                continue

            # Track boundaries
            if "<|START_TRACK|>" in line:
                if len(midi_data["tracks"]) == 1:  # If we're starting the second track
                    current_track = []
                continue

            if "<|END_TRACK|>" in line:
                if current_track:  # Only add non-empty tracks after the first one
                    midi_data["tracks"].append(current_track)
                current_track = []
                building_event = None
                collecting_params = {}
                continue

            # Handle events
            if line.startswith("<") and ">" in line:
                # Flush the event we were building before starting a new one
                if building_event and collecting_params:
                    full_event = {**building_event, **collecting_params}
                    sanitized = self.sanitize_event(full_event)
                    if sanitized:
                        current_track.append(sanitized)

                event_type = re.match(r"<(\w+)>", line)
                if event_type and event_type.group(1) not in ['START_METADATA', 'composer_', 'position_']:
                    building_event = {"type": event_type.group(1)}
                    params_text = line[line.find(">") + 1:].strip()
                    collecting_params = self.parse_event_params(params_text)
                continue

            # Collect additional parameters
            if building_event and '=' in line:
                collecting_params.update(self.parse_event_params(line))

        # Add any remaining events in the last track
        if current_track:
            midi_data["tracks"].append(current_track)

        return midi_data

def generated_tokens_to_midi(tokens, output_path):
    """Convert tokenized musical events back into a playable MIDI file."""
    midi_file = mido.MidiFile(ticks_per_beat=tokens["metadata"]["ticks_per_beat"])

    for track_tokens in tokens["tracks"]:
        track = mido.MidiTrack()
        midi_file.tracks.append(track)

        last_time = 0

        # Sort events by absolute time
        sorted_tokens = sorted(track_tokens, key=lambda x: x["time"])

        for token in sorted_tokens:
            # mido messages carry delta time: ticks since the previous message
            delta_time = token["time"] - last_time
            last_time = token["time"]

            if token["type"] == "note_on":
                msg = mido.Message('note_on',
                                   channel=token["channel"],
                                   note=token["note"],
                                   velocity=token["velocity"],
                                   time=int(delta_time))
                track.append(msg)

            elif token["type"] == "note_off":
                msg = mido.Message('note_off',
                                   channel=token["channel"],
                                   note=token["note"],
                                   velocity=token["velocity"],
                                   time=int(delta_time))
                track.append(msg)

            elif token["type"] == "program_change":
                # channel may be absent after sanitizing (only time/program are required); default to 0
                msg = mido.Message('program_change',
                                   channel=token.get("channel", 0),
                                   program=token["program"],
                                   time=int(delta_time))
                track.append(msg)

            elif token["type"] == "control_change":
                msg = mido.Message('control_change',
                                   channel=token["channel"],
                                   control=token["control"],
                                   value=token["value"],
                                   time=int(delta_time))
                track.append(msg)

            elif token["type"] == "tempo":
                msg = mido.MetaMessage('set_tempo',
                                       tempo=token["tempo"],
                                       time=int(delta_time))
                track.append(msg)

            elif token["type"] == "time_signature":
                msg = mido.MetaMessage('time_signature',
                                       numerator=token["numerator"],
                                       denominator=token["denominator"],
                                       time=int(delta_time))
                track.append(msg)

            elif token["type"] == "track_name":
                msg = mido.MetaMessage('track_name',
                                       name=token["name"],
                                       time=int(delta_time))
                track.append(msg)

    midi_file.save(output_path)
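A minimal round-trip sketch of how this module is meant to be used (the file names here are illustrative placeholders, not part of the upload): decode generated text into the event dictionary, then write it out as a standard MIDI file.

from genprocessor import GENProcessor, generated_tokens_to_midi

processor = GENProcessor()
with open("generated.txt") as f:  # placeholder: text produced by the language model
    generated_text = f.read()

midi_data = processor.decode_midi_file(generated_text)
generated_tokens_to_midi(midi_data, "output.mid")  # placeholder output path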
midimusicgenapp.py
ADDED
@@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-
"""MidiMusicGenApp.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1Dn99ii_FiQTx-z5B0dX0br0Gc0U9MUqD
"""

import gradio as gr
import torch
from transformers import GPT2LMHeadModel  # needed so torch.load can unpickle the full model
from miditokenizer import MIDITokenizer
from genprocessor import GENProcessor, generated_tokens_to_midi
from midi2audio import FluidSynth
from pydub import AudioSegment
import tempfile
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Load model and tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load('model_complete_18epochs.pkl', map_location=device)
tokenizer = MIDITokenizer()
processor = GENProcessor()
model.eval()

# Functions to adjust timing and combine generated song parts

def adjust_midi_timing(midi_data, start_time=0):
    """Adjust MIDI timing with an optional start time. Prevent large gaps based on ticks_per_beat."""
    try:
        # Keep the tempo track separate
        tempo_track = midi_data['tracks'][0]
        ticks_per_beat = midi_data['metadata']['ticks_per_beat']

        # Calculate thresholds based on ticks_per_beat
        gap_threshold = ticks_per_beat * 2
        small_increment = ticks_per_beat // 8  # an eighth of a beat

        # Get all other events and sort by time
        all_events = []
        for track in midi_data['tracks'][1:]:
            all_events.extend(track)
        all_events.sort(key=lambda x: x['time'])

        # Make times sequential, collapsing large gaps
        sequential_events = []
        current_time = all_events[0]['time'] if all_events else 0

        for event in all_events:
            if event['time'] - current_time > gap_threshold:
                event['time'] = current_time + small_increment
            current_time = event['time']
            sequential_events.append(event)

        # Find the first non-zero time
        first_time = min((event['time'] for event in sequential_events if event['time'] != 0), default=0)

        adjusted_data = {'metadata': midi_data['metadata'], 'tracks': [tempo_track]}

        # Shift all events so the piece starts at start_time
        adjusted_track = []
        for event in sequential_events:
            adjusted_event = event.copy()
            if event['time'] != 0:
                adjusted_event['time'] = (event['time'] - first_time) + start_time
            else:
                adjusted_event['time'] = start_time
            adjusted_track.append(adjusted_event)

        adjusted_data['tracks'].append(adjusted_track)
        return adjusted_data

    except Exception as e:
        print(f"Error adjusting MIDI timing: {str(e)}")
        return midi_data

# Functions to generate music

def generate_music(prompt):
    """Generate music text from a given prompt."""
    # Tokenize
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True
    )

    # Generate
    output_sequences = model.generate(
        input_ids=inputs["input_ids"].to(model.device),
        attention_mask=inputs["attention_mask"].to(model.device),
        max_length=1024,
        do_sample=True,
        temperature=0.6,  # adjust creativity
        top_k=30,
        top_p=0.90,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Decode the generated sequence
    generated_text = tokenizer.decode(output_sequences[0])

    return generated_text

def generate_wrapper(composer):
    # Format the prompt with the selected composer
    prompt = f"<|START_METADATA|> <|composer_{composer}|><metadata> ticks_per_beat="
    print(prompt)
    generated_text = generate_music(prompt)
    midi_data = adjust_midi_timing(processor.decode_midi_file(generated_text))

    # Create a temp file for the MIDI output
    with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as tmp:
        generated_tokens_to_midi(midi_data, tmp.name)
        # Convert to WAV
        fs = FluidSynth(sound_font='FluidR3Mono_GM.sf3')
        wav_file = tmp.name.replace('.mid', '.wav')
        fs.midi_to_audio(tmp.name, wav_file)

    mp3_file = wav_file.replace('.wav', '.mp3')
    audio = AudioSegment.from_wav(wav_file)
    audio.export(mp3_file, format="mp3")
    return mp3_file

iface = gr.Interface(
    fn=generate_wrapper,
    inputs=[
        gr.Dropdown(
            choices=["Bach", "Chopin"],
            label="Select Composer",
            value="Bach"  # default value
        )
    ],
    outputs=gr.Audio(type="filepath", label="Generated MIDI"),
    title="MAI: MIDI AI Music Generation Model",
    description="Generate MIDI sequences",
    flagging_mode="never"
)

iface.launch()
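For context, the text that generate_wrapper feeds into GENProcessor.decode_midi_file is a whitespace-separated stream of event tags and key=value pairs. An illustrative fragment, inferred from the parser and the tokenizer's special tokens rather than copied from actual model output:

<|START_METADATA|> <|composer_Bach|><metadata> ticks_per_beat=480 <|END_METADATA|>
<|START_TRACK|> <tempo> time=0 tempo=500000 <time_signature> time=0 numerator=4 denominator=4 <|END_TRACK|>
<|START_TRACK|> <program_change> time=0 channel=0 program=0 <note_on> time=0 channel=0 note=60 velocity=64 <note_off> time=480 channel=0 note=60 velocity=64 <|END_TRACK|>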
miditokenizer.py
ADDED
@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
"""miditokenizer.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/16YJUBYcqKYPVIhwzKNi4ELnTftr2TcUY
"""

# We use a base GPT-2 tokenizer with additional functions to handle composer tokens.
# Datasets are created by processing our files in chunks, due to model sequence limits.
# Position information is added to each chunk as an additional pattern for training.

from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from torch.utils.data import Dataset
from pathlib import Path
import torch

class MIDITokenizer:
    """Tokenization specific to MIDI data, with special tokens for events."""
    def __init__(self, pretrained_model='gpt2'):
        self.base_tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_model)
        special_tokens = {
            'additional_special_tokens': [
                '<|START_METADATA|>',
                '<|END_METADATA|>',
                '<|START_TRACK|>',
                '<|END_TRACK|>',
                '<metadata>',
                '<tempo>',
                '<time_signature>',
                '<program_change>',
                '<note_on>',
                '<note_off>',
                '<control_change>'
            ],
            'pad_token': '[PAD]'
        }
        self.base_tokenizer.add_special_tokens(special_tokens)
        self.pad_token_id = self.base_tokenizer.pad_token_id
        self.eos_token_id = self.base_tokenizer.eos_token_id
        self.bos_token_id = self.base_tokenizer.bos_token_id
        self.pad_token = self.base_tokenizer.pad_token
        self.eos_token = self.base_tokenizer.eos_token
        self.bos_token = self.base_tokenizer.bos_token

    def add_composer_tokens(self, composers):
        """Register one special token per composer, e.g. <|composer_Bach|>."""
        composer_tokens = [f'<|composer_{c}|>' for c in composers]
        self.base_tokenizer.add_special_tokens({
            'additional_special_tokens': composer_tokens
        })

    def __call__(self, text, **kwargs):
        return self.base_tokenizer(text, **kwargs)

    def decode(self, token_ids, **kwargs):
        """Decode while preserving special tokens."""
        return self.base_tokenizer.decode(token_ids, skip_special_tokens=False, **kwargs)

    def pad(self, *args, **kwargs):
        return self.base_tokenizer.pad(*args, **kwargs)

    def get_vocab(self):
        return self.base_tokenizer.get_vocab()
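A quick usage sketch for the tokenizer. The composer list mirrors the app's dropdown; the embedding resize is an assumption about how the training notebook wires this up, since adding special tokens grows the vocabulary beyond GPT-2's default embedding matrix:

from transformers import GPT2LMHeadModel
from miditokenizer import MIDITokenizer

tokenizer = MIDITokenizer()
tokenizer.add_composer_tokens(["Bach", "Chopin"])  # composers assumed from the app's dropdown

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.resize_token_embeddings(len(tokenizer.get_vocab()))  # assumed step: match the enlarged vocab

enc = tokenizer("<|START_METADATA|> <|composer_Bach|><metadata> ticks_per_beat=480",
                return_tensors="pt")
print(tokenizer.decode(enc["input_ids"][0]))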