tlmdesign committed
Commit f542560 · verified · Parent(s): 43d673b

Upload 6 files

Files changed (6)
  1. README.md +10 -7
  2. genprocessor.py +253 -0
  3. midimusicgenapp.py +151 -0
  4. miditokenizer.py +66 -0
  5. packages.txt +2 -0
  6. requirements.txt +6 -0
README.md CHANGED
@@ -1,14 +1,17 @@
 ---
-title: MAI MidiAI Playback
-emoji: 📚
-colorFrom: pink
-colorTo: green
+title: MAI MIDI AI Music Generation Model
+Author: Tara Manuel
+colorFrom: purple
+colorTo: gray
 sdk: gradio
-sdk_version: 5.7.1
+sdk_version: 4.43.0
 app_file: app.py
 pinned: false
 license: mit
-short_description: 'MAI: MIDI AI Music Generation Model'
+
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+This project implements a GPT-2 based model for generating MIDI music sequences
+using deep learning. The model is trained on MIDI files from the MAESTRO
+dataset, converted to a special JSON format. The model uses HuggingFace's
+transformers library for training.
+
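For orientation, the "special JSON format" above is built from an intermediate text representation using the special tokens defined in miditokenizer.py, which genprocessor.py parses back into events. A sketch of what a generated sequence looks like (this particular sequence is invented for illustration; field order follows `field_order` in genprocessor.py):

```
<|START_METADATA|> <|composer_Bach|><metadata> ticks_per_beat=480 <|END_METADATA|>
<|START_TRACK|> <tempo> time=0 tempo=500000 <time_signature> time=0 numerator=4 denominator=4 <|END_TRACK|>
<|START_TRACK|> <program_change> time=0 channel=0 program=0 <note_on> time=0 channel=0 note=60 velocity=80 <note_on> time=480 channel=0 note=60 velocity=0 <|END_TRACK|>
```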
 
genprocessor.py ADDED
@@ -0,0 +1,253 @@
+# -*- coding: utf-8 -*-
+"""genprocessor.ipynb
+
+Automatically generated by Colab.
+
+Original file is located at
+    https://colab.research.google.com/drive/1kvkhcC2RFcAMNh-jOb6NLFtX_lftKi3I
+"""
+
+# Process generated text into a MIDI-compatible file
+import re
+from typing import Dict
+import mido
+
+
+class GENProcessor:
+    def __init__(self):
+        self.START_TRACK = "<|START_TRACK|>"
+        self.END_TRACK = "<|END_TRACK|>"
+        self.START_METADATA = "<|START_METADATA|>"
+        self.END_METADATA = "<|END_METADATA|>"
+
+        self.field_order = {
+            "metadata": ["type", "ticks_per_beat"],
+            "tempo": ["type", "time", "tempo"],
+            "time_signature": ["type", "time", "numerator", "denominator"],
+            "track_name": ["type", "time", "name"],
+            "program_change": ["type", "time", "channel", "program"],
+            "control_change": ["type", "time", "channel", "control", "value"],
+            "note_on": ["type", "time", "channel", "note", "velocity"]
+        }
+
+    def sanitize_event(self, event):
+        """Make sure events have all required fields."""
+        if not event or 'type' not in event:
+            return None
+
+        event_type = event['type']
+        required_fields = {
+            'note_on': {'time', 'channel', 'note', 'velocity'},
+            'note_off': {'time', 'channel', 'note', 'velocity'},
+            'control_change': {'time', 'channel', 'control', 'value'},
+            'program_change': {'time', 'program'},
+            'time_signature': {'time', 'numerator', 'denominator'}
+        }
+
+        if event_type in required_fields:
+            missing_fields = required_fields[event_type] - set(event.keys())
+
+            if missing_fields:
+                if event_type == 'time_signature':
+                    event['numerator'] = 4    # set default
+                    event['denominator'] = 4  # set default
+                    event['time'] = event.get('time', 0)
+                else:
+                    return None
+
+        try:
+            # Validate numeric fields, clamping to sane minimums
+            if 'time' in event:
+                event['time'] = max(0, int(event['time']))
+            if 'channel' in event:
+                event['channel'] = max(0, int(event['channel']))
+            if 'note' in event:
+                event['note'] = max(0, int(event['note']))
+            if 'velocity' in event:
+                event['velocity'] = max(0, int(event['velocity']))
+            if 'control' in event:
+                event['control'] = max(0, int(event['control']))
+            if 'value' in event:
+                event['value'] = max(0, int(event['value']))
+            if 'program' in event:
+                event['program'] = max(0, int(event['program']))
+            if 'numerator' in event:
+                numerator = int(event['numerator'])
+                event['numerator'] = min(4, max(2, numerator))
+            if 'denominator' in event:
+                event['denominator'] = 4
+
+        except (ValueError, TypeError):
+            return None
+
+        return event
+
+    def parse_event_params(self, text: str) -> Dict:
+        """Parse key=value parameters from a line of text."""
+        return {p.split('=', 1)[0].strip(): p.split('=', 1)[1].strip()
+                for p in text.split() if '=' in p and len(p.split('=', 1)) == 2}
+
+    def decode_midi_file(self, text: str) -> Dict:
+        """Decode the text representation of a MIDI file into a dictionary."""
+        # Create a template with defaults in case data is missing
+        midi_data = {
+            "metadata": {
+                "ticks_per_beat": 480
+            },
+            "tracks": [
+                [  # First track always contains tempo and time signature; set defaults
+                    {
+                        "type": "tempo",
+                        "time": 0,
+                        "tempo": 500000
+                    },
+                    {
+                        "type": "time_signature",
+                        "time": 0,
+                        "numerator": 4,
+                        "denominator": 4
+                    }
+                ]
+            ]
+        }
+
+        # Parse the text to get all metadata values
+        # (text.split() yields whitespace-separated tokens, so each "line" is one token)
+        metadata_values = {}
+        for line in text.split():
+            if "ticks_per_beat" in line or "ticks_beat" in line:
+                match = re.search(r"ticks[_]?(?:per_)?beat=(\d+)", line)
+                if match:
+                    metadata_values["ticks_per_beat"] = max(75, int(match.group(1)))
+            elif "tempo" in line and "time=0" in line:
+                match = re.search(r"tempo=(\d+)", line)
+                if match:
+                    metadata_values["tempo"] = int(match.group(1))
+
+        # Update the template with any metadata values that were found
+        if "ticks_per_beat" in metadata_values:
+            midi_data["metadata"]["ticks_per_beat"] = metadata_values["ticks_per_beat"]
+        if "tempo" in metadata_values:
+            midi_data["tracks"][0][0]["tempo"] = metadata_values["tempo"]
+
+        # Parse the actual events
+        current_track = []
+        building_event = None
+        collecting_params = {}
+
+        for line in text.split():
+            line = line.strip()
+            if not line:
+                continue
+
+            # Skip metadata lines we already processed
+            if "ticks_per_beat" in line or "tempo=0" in line:
+                continue
+
+            # Track boundaries
+            if "<|START_TRACK|>" in line:
+                if len(midi_data["tracks"]) == 1:  # Starting the second track
+                    current_track = []
+                continue
+
+            if "<|END_TRACK|>" in line:
+                if current_track:  # Only add non-empty tracks after the first one
+                    midi_data["tracks"].append(current_track)
+                current_track = []
+                building_event = None
+                collecting_params = {}
+                continue
+
+            # Handle events
+            if line.startswith("<") and ">" in line:
+                # Flush the event being built before starting a new one
+                if building_event and collecting_params:
+                    full_event = {**building_event, **collecting_params}
+                    sanitized = self.sanitize_event(full_event)
+                    if sanitized:
+                        current_track.append(sanitized)
+
+                event_type = re.match(r"<(\w+)>", line)
+                if event_type and event_type.group(1) not in ['START_METADATA', 'composer_', 'position_']:
+                    building_event = {"type": event_type.group(1)}
+                    params_text = line[line.find(">") + 1:].strip()
+                    collecting_params = self.parse_event_params(params_text)
+                continue
+
+            # Collect additional parameters
+            if building_event and '=' in line:
+                collecting_params.update(self.parse_event_params(line))
+
+        # Add any remaining events in the last track
+        if current_track:
+            midi_data["tracks"].append(current_track)
+
+        return midi_data
+
+
+# Module-level so the app can import it alongside GENProcessor
+def generated_tokens_to_midi(tokens, output_path):
+    """Convert tokenized musical events back into a MIDI file."""
+    midi_file = mido.MidiFile(ticks_per_beat=tokens["metadata"]["ticks_per_beat"])
+
+    for track_tokens in tokens["tracks"]:
+        track = mido.MidiTrack()
+        midi_file.tracks.append(track)
+
+        last_time = 0
+
+        # Sort events by absolute time
+        sorted_tokens = sorted(track_tokens, key=lambda x: x["time"])
+
+        for token in sorted_tokens:
+            # Convert absolute time to delta time
+            delta_time = token["time"] - last_time
+            last_time = token["time"]
+
+            if token["type"] == "note_on":
+                msg = mido.Message('note_on',
+                                   channel=token["channel"],
+                                   note=token["note"],
+                                   velocity=token["velocity"],
+                                   time=int(delta_time))
+                track.append(msg)
+
+            elif token["type"] == "note_off":
+                msg = mido.Message('note_off',
+                                   channel=token["channel"],
+                                   note=token["note"],
+                                   velocity=token["velocity"],
+                                   time=int(delta_time))
+                track.append(msg)
+
+            elif token["type"] == "program_change":
+                msg = mido.Message('program_change',
+                                   channel=token["channel"],
+                                   program=token["program"],
+                                   time=int(delta_time))
+                track.append(msg)
+
+            elif token["type"] == "control_change":
+                msg = mido.Message('control_change',
+                                   channel=token["channel"],
+                                   control=token["control"],
+                                   value=token["value"],
+                                   time=int(delta_time))
+                track.append(msg)
+
+            elif token["type"] == "tempo":
+                msg = mido.MetaMessage('set_tempo',
+                                       tempo=token["tempo"],
+                                       time=int(delta_time))
+                track.append(msg)
+
+            elif token["type"] == "time_signature":
+                msg = mido.MetaMessage('time_signature',
+                                       numerator=token["numerator"],
+                                       denominator=token["denominator"],
+                                       time=int(delta_time))
+                track.append(msg)
+
+            elif token["type"] == "track_name":
+                msg = mido.MetaMessage('track_name',
+                                       name=token["name"],
+                                       time=int(delta_time))
+                track.append(msg)
+
+    midi_file.save(output_path)
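A minimal round-trip sketch of how this module is used (the `generated_text` string below is invented for illustration; in the app it comes from the model):

```python
from genprocessor import GENProcessor, generated_tokens_to_midi

processor = GENProcessor()
generated_text = (
    "<|START_TRACK|> "
    "<program_change> time=0 channel=0 program=0 "
    "<note_on> time=0 channel=0 note=60 velocity=80 "
    "<note_on> time=480 channel=0 note=60 velocity=0 "
    "<|END_TRACK|>"
)
midi_data = processor.decode_midi_file(generated_text)  # text -> dict of metadata + tracks
generated_tokens_to_midi(midi_data, "out.mid")          # dict -> playable .mid file
```

Note that decode_midi_file always seeds a first track with a default tempo (500000 µs per beat) and 4/4 time signature, so even a sparse generation yields a valid MIDI file.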
midimusicgenapp.py ADDED
@@ -0,0 +1,151 @@
+# -*- coding: utf-8 -*-
+"""MidiMusicGenApp.ipynb
+
+Automatically generated by Colab.
+
+Original file is located at
+    https://colab.research.google.com/drive/1Dn99ii_FiQTx-z5B0dX0br0Gc0U9MUqD
+"""
+
+import gradio as gr
+import torch
+from transformers import GPT2LMHeadModel
+from miditokenizer import MIDITokenizer
+from genprocessor import GENProcessor, generated_tokens_to_midi
+from midi2audio import FluidSynth
+from pydub import AudioSegment
+import tempfile
+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+# Load model and tokenizer
+torch.serialization.add_safe_globals([set])
+torch.serialization.add_safe_globals([GPT2LMHeadModel])
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = torch.load('model_complete_18epochs.pkl', map_location=device, weights_only=False)
+tokenizer = MIDITokenizer()
+processor = GENProcessor()
+model.eval()
+
+# Functions to adjust timing and combine generated song parts
+def adjust_midi_timing(midi_data, start_time=0):
+    """Adjust MIDI timing with an optional start time. Prevent large gaps based on ticks_per_beat."""
+    try:
+        # Keep the tempo track separate
+        tempo_track = midi_data['tracks'][0]
+        ticks_per_beat = midi_data['metadata']['ticks_per_beat']
+
+        # Calculate thresholds based on ticks_per_beat
+        gap_threshold = ticks_per_beat * 2
+        small_increment = ticks_per_beat // 8  # Eighth note
+
+        # Gather all other events and sort by time
+        all_events = []
+        for track in midi_data['tracks'][1:]:
+            all_events.extend(track)
+        all_events.sort(key=lambda x: x['time'])
+
+        # Walk through events sequentially, compressing large gaps
+        sequential_events = []
+        current_time = all_events[0]['time'] if all_events else 0
+
+        for event in all_events:
+            if event['time'] - current_time > gap_threshold:
+                event['time'] = current_time + small_increment
+            current_time = event['time']
+            sequential_events.append(event)
+
+        # Find the first non-zero time
+        first_time = min((event['time'] for event in sequential_events if event['time'] != 0), default=0)
+
+        adjusted_data = {'metadata': midi_data['metadata'], 'tracks': [tempo_track]}
+
+        # Shift all events so the piece starts at start_time
+        adjusted_track = []
+        for event in sequential_events:
+            adjusted_event = event.copy()
+            if event['time'] != 0:
+                adjusted_event['time'] = (event['time'] - first_time) + start_time
+            else:
+                adjusted_event['time'] = start_time
+            adjusted_track.append(adjusted_event)
+
+        adjusted_data['tracks'].append(adjusted_track)
+        return adjusted_data
+
+    except Exception as e:
+        print(f"Error adjusting MIDI timing: {str(e)}")
+        return midi_data
+
+# Functions to generate music
+def generate_music(prompt):
+    """Generate music from a given prompt."""
+    # Tokenize
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    inputs = tokenizer(
+        prompt,
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        add_special_tokens=True
+    )
+
+    # Generate
+    output_sequences = model.generate(
+        input_ids=inputs["input_ids"].to(model.device),
+        attention_mask=inputs["attention_mask"].to(model.device),
+        max_length=1024,
+        do_sample=True,
+        temperature=0.6,  # adjust creativity
+        top_k=30,
+        top_p=0.90,
+        pad_token_id=tokenizer.eos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+    )
+
+    # Decode the generated sequence
+    generated_text = tokenizer.decode(output_sequences[0])
+
+    return generated_text
+
+def generate_wrapper(composer):
+    # Format the prompt with the selected composer
+    prompt = f"<|START_METADATA|> <|composer_{composer}|><metadata> ticks_per_beat="
+    generated_text = generate_music(prompt)
+    midi_data = adjust_midi_timing(processor.decode_midi_file(generated_text))
+    print(midi_data)
+
+    # Create a temp file for the MIDI output
+    with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as tmp:
+        generated_tokens_to_midi(midi_data, tmp.name)
+
+    # Convert to WAV
+    fs = FluidSynth(sound_font='FluidR3Mono_GM.sf3')
+    wav_file = tmp.name.replace('.mid', '.wav')
+    fs.midi_to_audio(tmp.name, wav_file)
+
+    # Convert to MP3
+    mp3_file = wav_file.replace('.wav', '.mp3')
+    audio = AudioSegment.from_wav(wav_file)
+    audio.export(mp3_file, format="mp3")
+    return mp3_file
+
+iface = gr.Interface(
+    fn=generate_wrapper,
+    inputs=[
+        gr.Dropdown(
+            choices=["Bach", "Chopin"],
+            label="Select Composer",
+            value="Bach"  # default value
+        )
+    ],
+    outputs=gr.Audio(type="filepath", label="Generated MIDI"),
+    title="MAI: MIDI AI Music Generation Model",
+    description="Select a composer whose musical style you'd like to emulate. Generate an original sequence inspired by that composer's unique sound. It should take a few minutes. Once it's ready, you can listen to the clip or download the audio file."
+    #description="Compose Music in the Style of Your Favorite Composer. Select a composer to generate a music sequence in the style of selected composer",
+    #flagging_mode="never"
+)
+
+iface.launch()
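Note that the app expects two local assets that are not part of this commit: the trained weights (model_complete_18epochs.pkl, loaded via torch.load) and a soundfont (FluidR3Mono_GM.sf3, used by FluidSynth); the fluidsynth and ffmpeg binaries come from packages.txt. With those in place, a hypothetical smoke test without the Gradio UI might look like:

```python
# Hypothetical smoke test; assumes model_complete_18epochs.pkl and
# FluidR3Mono_GM.sf3 are present in the working directory.
mp3_path = generate_wrapper("Bach")
print(mp3_path)  # path to the rendered MP3
```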
miditokenizer.py ADDED
@@ -0,0 +1,66 @@
+# -*- coding: utf-8 -*-
+"""miditokenizer.ipynb
+
+Automatically generated by Colab.
+
+Original file is located at
+    https://colab.research.google.com/drive/16YJUBYcqKYPVIhwzKNi4ELnTftr2TcUY
+"""
+
+# We use a base GPT-2 tokenizer with additional functions to handle composer tokens.
+# Datasets are created by processing our files in chunks, due to model sequence limits.
+# Position information is added to each chunk as additional pattern/data for training.
+
+from transformers import GPT2TokenizerFast, GPT2LMHeadModel
+from torch.utils.data import Dataset
+from pathlib import Path
+import torch
+
+class MIDITokenizer:
+    """Tokenization specific to MIDI data, with special tokens."""
+    def __init__(self, pretrained_model='gpt2'):
+        self.base_tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_model)
+        special_tokens = {
+            'additional_special_tokens': [
+                '<|START_METADATA|>',
+                '<|END_METADATA|>',
+                '<|START_TRACK|>',
+                '<|END_TRACK|>',
+                '<metadata>',
+                '<tempo>',
+                '<time_signature>',
+                '<program_change>',
+                '<note_on>',
+                '<note_off>',
+                '<control_change>'
+            ],
+            'pad_token': '[PAD]'
+        }
+        self.base_tokenizer.add_special_tokens(special_tokens)
+        self.pad_token_id = self.base_tokenizer.pad_token_id
+        self.eos_token_id = self.base_tokenizer.eos_token_id
+        self.bos_token_id = self.base_tokenizer.bos_token_id
+        self.pad_token = self.base_tokenizer.pad_token
+        self.eos_token = self.base_tokenizer.eos_token
+        self.bos_token = self.base_tokenizer.bos_token
+
+    def add_composer_tokens(self, composers):
+        # Composer tokens, e.g. <|composer_Bach|>
+        composer_tokens = [f'<|composer_{c}|>' for c in composers]
+        self.base_tokenizer.add_special_tokens({
+            'additional_special_tokens': composer_tokens
+        })
+
+    def __call__(self, text, **kwargs):
+        return self.base_tokenizer(text, **kwargs)
+
+    def decode(self, token_ids, **kwargs):
+        """Decode while preserving special tokens."""
+        return self.base_tokenizer.decode(token_ids, skip_special_tokens=False, **kwargs)
+
+    def pad(self, *args, **kwargs):
+        return self.base_tokenizer.pad(*args, **kwargs)
+
+    def get_vocab(self):
+        return self.base_tokenizer.get_vocab()
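The composer conditioning in the app's prompt relies on add_composer_tokens; a brief sketch of extending and using the tokenizer (the composer names mirror the app's dropdown):

```python
tokenizer = MIDITokenizer()
tokenizer.add_composer_tokens(["Bach", "Chopin"])  # registers <|composer_Bach|>, <|composer_Chopin|>

# Each composer token is now a single vocabulary entry rather than raw text
enc = tokenizer("<|composer_Bach|><metadata> ticks_per_beat=480")
print(tokenizer.decode(enc["input_ids"]))  # special tokens are preserved on decode
```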
packages.txt ADDED
@@ -0,0 +1,2 @@
+fluidsynth
+ffmpeg
requirements.txt ADDED
@@ -0,0 +1,6 @@
+gradio
+torch
+transformers
+mido
+midi2audio
+pydub