tlmdesign committed on
Commit
d9e88e8
·
verified ·
1 Parent(s): 6589cd9

Upload 3 files

Browse files
Files changed (3) hide show
  1. genprocessor.py +253 -0
  2. midimusicgenapp.py +147 -0
  3. miditokenizer.py +66 -0
genprocessor.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """genprocessor.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1kvkhcC2RFcAMNh-jOb6NLFtX_lftKi3I
8
+ """
9
+
10
+ #Process generated text to MIDI compatible file
11
+ import re
12
+ from typing import Dict
13
+ import mido
14
+
15
class GENProcessor:
    """Convert model-generated token text into a MIDI-ready dictionary.

    Produces {"metadata": {...}, "tracks": [[event, ...], ...]} where each
    event is a dict with a "type" key plus the numeric fields that
    ``generated_tokens_to_midi`` needs to build mido messages.
    """

    def __init__(self):
        # Structural marker tokens emitted by the model.
        self.START_TRACK = "<|START_TRACK|>"
        self.END_TRACK = "<|END_TRACK|>"
        self.START_METADATA = "<|START_METADATA|>"
        self.END_METADATA = "<|END_METADATA|>"

        # Canonical field order per event type (kept for serialization reference).
        self.field_order = {
            "metadata": ["type", "ticks_per_beat"],
            "tempo": ["type", "time", "tempo"],
            "time_signature": ["type", "time", "numerator", "denominator"],
            "track_name": ["type", "time", "name"],
            "program_change": ["type", "time", "channel", "program"],
            "control_change": ["type", "time", "channel", "control", "value"],
            "note_on": ["type", "time", "channel", "note", "velocity"]
        }

    def sanitize_event(self, event):
        """Validate one event dict; return the cleaned event or None if unusable.

        Ensures required fields are present and clamps numeric fields into
        legal MIDI ranges (data bytes 0-127, channels 0-15) so that message
        construction downstream cannot raise on out-of-range values.
        """
        if not event or 'type' not in event:
            return None

        event_type = event['type']
        required_fields = {
            'note_on': {'time', 'channel', 'note', 'velocity'},
            'note_off': {'time', 'channel', 'note', 'velocity'},
            'control_change': {'time', 'channel', 'control', 'value'},
            'program_change': {'time', 'program'},
            'time_signature': {'time', 'numerator', 'denominator'}
        }

        if event_type in required_fields:
            missing_fields = required_fields[event_type] - set(event.keys())

            if missing_fields:
                if event_type == 'time_signature':
                    # Recoverable: fall back to a 4/4 signature.
                    event['numerator'] = 4  # set default
                    event['denominator'] = 4  # set default
                    event['time'] = event.get('time', 0)
                else:
                    return None

        try:
            # Validate fields. FIX: the original only clamped the lower bound;
            # values above the MIDI maxima (127 for data bytes, 15 for
            # channels) would make mido.Message raise later.
            if 'time' in event:
                event['time'] = max(0, int(event['time']))
            if 'channel' in event:
                event['channel'] = min(15, max(0, int(event['channel'])))
            if 'note' in event:
                event['note'] = min(127, max(0, int(event['note'])))
            if 'velocity' in event:
                event['velocity'] = min(127, max(0, int(event['velocity'])))
            if 'control' in event:
                event['control'] = min(127, max(0, int(event['control'])))
            if 'value' in event:
                event['value'] = min(127, max(0, int(event['value'])))
            if 'program' in event:
                event['program'] = min(127, max(0, int(event['program'])))
            if 'numerator' in event:
                numerator = int(event['numerator'])
                event['numerator'] = min(4, max(2, numerator))
            if 'denominator' in event:
                # Downstream only accepts 4 (power-of-two requirement); force it.
                event['denominator'] = 4

        except (ValueError, TypeError):
            return None

        return event

    def parse_event_params(self, text: str) -> Dict:
        """Parse 'key=value' parameters from a chunk of generated text."""
        return {p.split('=', 1)[0].strip(): p.split('=', 1)[1].strip()
                for p in text.split() if '=' in p and len(p.split('=', 1)) == 2}

    def _flush_event(self, building_event, collecting_params, current_track):
        """Finalize the pending event (if complete) into current_track."""
        if building_event and collecting_params:
            sanitized = self.sanitize_event({**building_event, **collecting_params})
            if sanitized:
                current_track.append(sanitized)

    def decode_midi_file(self, text: str) -> Dict:
        """Decode the text representation of a MIDI file into a dictionary.

        The text is treated as a stream of whitespace-separated tokens.
        Always returns a playable structure: track 0 carries default
        tempo/time-signature events that are overridden when the text
        provides values.
        """
        # Template with defaults in case data is missing.
        midi_data = {
            "metadata": {
                "ticks_per_beat": 480
            },
            "tracks": [
                [  # First track always contains tempo and time signature
                    {
                        "type": "tempo",
                        "time": 0,
                        "tempo": 500000
                    },
                    {
                        "type": "time_signature",
                        "time": 0,
                        "numerator": 4,
                        "denominator": 4
                    }
                ]
            ]
        }

        # Pre-pass: pull global metadata values out of the token stream.
        metadata_values = {}
        for line in text.split():
            if "ticks_per_beat" in line or "ticks_beat" in line:
                match = re.search(r"ticks[_]?(?:per_)?beat=(\d+)", line)
                if match:
                    # Floor at 75 to keep resolution usable.
                    metadata_values["ticks_per_beat"] = max(75, int(match.group(1)))
            elif "tempo" in line and "time=0" in line:
                # NOTE(review): tokens are single whitespace-split words, so a
                # token rarely contains both substrings — confirm the intended
                # tempo-extraction format against the generator's output.
                match = re.search(r"tempo=(\d+)", line)
                if match:
                    metadata_values["tempo"] = int(match.group(1))

        # Update template with any found metadata values.
        if "ticks_per_beat" in metadata_values:
            midi_data["metadata"]["ticks_per_beat"] = metadata_values["ticks_per_beat"]
        if "tempo" in metadata_values:
            midi_data["tracks"][0][0]["tempo"] = metadata_values["tempo"]

        # Parse the actual events.
        current_track = []
        building_event = None
        collecting_params = {}

        for line in text.split():
            line = line.strip()
            if not line:
                continue

            # Skip metadata tokens we already processed.
            if "ticks_per_beat" in line or "tempo=0" in line:
                continue

            # Track boundaries.
            if "<|START_TRACK|>" in line:
                if len(midi_data["tracks"]) == 1:  # starting the second track
                    current_track = []
                continue

            if "<|END_TRACK|>" in line:
                # FIX: flush the still-pending event first; previously the
                # last event of every track was silently dropped.
                self._flush_event(building_event, collecting_params, current_track)
                if current_track:  # only add non-empty tracks after the first
                    midi_data["tracks"].append(current_track)
                current_track = []
                building_event = None
                collecting_params = {}
                continue

            # An event tag ("<note_on>" etc.) starts a new event.
            if line.startswith("<") and ">" in line:
                self._flush_event(building_event, collecting_params, current_track)
                # FIX: reset state after flushing; previously a stale event
                # could be flushed twice (duplicated) when the next tag was
                # filtered out.
                building_event = None
                collecting_params = {}

                event_type = re.match(r"<(\w+)>", line)
                if event_type and event_type.group(1) not in ['START_METADATA', 'composer_', 'position_']:
                    building_event = {"type": event_type.group(1)}
                    params_text = line[line.find(">") + 1:].strip()
                    collecting_params = self.parse_event_params(params_text)
                continue

            # Collect additional key=value parameters for the open event.
            if building_event and '=' in line:
                collecting_params.update(self.parse_event_params(line))

        # FIX: flush any trailing event, then append the final track.
        self._flush_event(building_event, collecting_params, current_track)
        if current_track:
            midi_data["tracks"].append(current_track)

        return midi_data
184
+
185
def generated_tokens_to_midi(tokens, output_path):
    """Render a decoded token dictionary to a MIDI file on disk.

    ``tokens`` is the structure produced by GENProcessor.decode_midi_file:
    {"metadata": {"ticks_per_beat": int}, "tracks": [[event, ...], ...]}
    where every event carries an absolute "time" in ticks. Events are
    re-ordered chronologically and written with delta times via mido.
    """
    midi_file = mido.MidiFile(ticks_per_beat=tokens["metadata"]["ticks_per_beat"])

    # Channel messages: event type -> fields copied onto the mido.Message.
    channel_specs = {
        "note_on": ("channel", "note", "velocity"),
        "note_off": ("channel", "note", "velocity"),
        "program_change": ("channel", "program"),
        "control_change": ("channel", "control", "value"),
    }
    # Meta messages: event type -> (mido meta type, fields).
    meta_specs = {
        "tempo": ("set_tempo", ("tempo",)),
        "time_signature": ("time_signature", ("numerator", "denominator")),
        "track_name": ("track_name", ("name",)),
    }

    for events in tokens["tracks"]:
        track = mido.MidiTrack()
        midi_file.tracks.append(track)

        previous_time = 0
        # Absolute times -> chronological order, then delta encoding.
        for event in sorted(events, key=lambda e: e["time"]):
            delta = int(event["time"] - previous_time)
            previous_time = event["time"]

            kind = event["type"]
            if kind in channel_specs:
                fields = {name: event[name] for name in channel_specs[kind]}
                track.append(mido.Message(kind, time=delta, **fields))
            elif kind in meta_specs:
                meta_type, names = meta_specs[kind]
                fields = {name: event[name] for name in names}
                track.append(mido.MetaMessage(meta_type, time=delta, **fields))
            # Unknown event types are skipped (but still advance previous_time).

    midi_file.save(output_path)
midimusicgenapp.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """MidiMusicGenApp.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1Dn99ii_FiQTx-z5B0dX0br0Gc0U9MUqD
8
+ """
9
+
10
+ import gradio as gr
11
+ import torch
12
+ from transformers import GPT2LMHeadModel
13
+ from miditokenizer import MIDITokenizer
14
+ from genprocessor import GENProcessor, generated_tokens_to_midi
15
+ from midi2audio import FluidSynth
16
+ from pydub import AudioSegment
17
+ import tempfile
18
+ import os
19
# Disable tokenizer parallelism to avoid fork-related warnings/deadlocks.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Load model and tokenizer
# Prefer GPU when available; the checkpoint is mapped onto this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): torch.load unpickles a full model object — arbitrary code can
# run during loading, so this checkpoint must come from a trusted source.
model = torch.load('model_complete_18epochs.pkl',map_location=device)
tokenizer = MIDITokenizer()
processor = GENProcessor()
# Inference only: switch off dropout/batch-norm training behavior.
model.eval()
27
+
28
+ #functions to adjust timing & combine generated song parts
29
def adjust_midi_timing(midi_data, start_time=0):
    """Normalize event timing and merge all non-tempo tracks into one.

    - Track 0 (tempo/time-signature) is passed through untouched.
    - All other tracks are merged and sorted by absolute time.
    - Gaps longer than two beats are collapsed to an eighth-note step.
    - The earliest non-zero event time is shifted to ``start_time``.

    The input dictionary's events are never mutated; on any error the
    original ``midi_data`` is returned unchanged (best-effort behavior).
    """
    try:
        # Keep tempo track separate.
        tempo_track = midi_data['tracks'][0]
        ticks_per_beat = midi_data['metadata']['ticks_per_beat']

        # Thresholds derived from resolution: two beats / an eighth note.
        gap_threshold = ticks_per_beat * 2
        small_increment = ticks_per_beat // 8  # eighth note

        # Gather every non-tempo event and sort by absolute time.
        all_events = []
        for track in midi_data['tracks'][1:]:
            all_events.extend(track)
        all_events.sort(key=lambda x: x['time'])

        # Close large gaps. FIX: work on copies — the original rewrote
        # event['time'] on the caller's event dicts in place.
        sequential_events = []
        current_time = all_events[0]['time'] if all_events else 0

        for event in all_events:
            event = dict(event)
            if event['time'] - current_time > gap_threshold:
                event['time'] = current_time + small_increment
            current_time = event['time']
            sequential_events.append(event)

        # Earliest non-zero time becomes the new origin.
        first_time = min((event['time'] for event in sequential_events if event['time'] != 0), default=0)

        adjusted_data = {'metadata': midi_data['metadata'], 'tracks': [tempo_track]}

        # Re-base all events relative to start_time (events at 0 stay at start_time).
        adjusted_track = []
        for event in sequential_events:
            if event['time'] != 0:
                event['time'] = (event['time'] - first_time) + start_time
            else:
                event['time'] = start_time
            adjusted_track.append(event)

        adjusted_data['tracks'].append(adjusted_track)
        return adjusted_data

    except Exception as e:
        # Best-effort: fall back to the unadjusted data rather than fail.
        print(f"Error adjusting MIDI timing: {str(e)}")
        return midi_data
77
+
78
+ #Functions to generate music
79
+
80
def generate_music(prompt):
    """Generate a MIDI-token text sequence from ``prompt`` with the loaded model."""
    # Make sure the tokenizer can pad batched inputs.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True,
    )

    # Sampling configuration: moderately conservative creativity.
    sampling_args = dict(
        max_length=1024,
        do_sample=True,
        temperature=0.6,
        top_k=30,
        top_p=0.90,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    output_sequences = model.generate(
        input_ids=encoded["input_ids"].to(model.device),
        attention_mask=encoded["attention_mask"].to(model.device),
        **sampling_args,
    )

    # Decode the first (and only) generated sequence back to text.
    return tokenizer.decode(output_sequences[0])
111
+
112
def generate_wrapper(composer):
    """Gradio callback: generate a piece for ``composer`` and return an MP3 path.

    Renders the generated token text to a temporary MIDI file, synthesizes
    it to WAV with FluidSynth, and compresses to MP3 for the audio widget.
    """
    # Format the prompt with the selected composer.
    prompt = f"<|START_METADATA|> <|composer_{composer}|><metadata> ticks_per_beat="
    print(prompt)
    generated_text = generate_music(prompt)
    midi_data = adjust_midi_timing(processor.decode_midi_file(generated_text))

    # Reserve a temp path; close the handle before writing so the MIDI
    # library opens the file itself (also avoids Windows sharing issues).
    with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as tmp:
        midi_path = tmp.name
    # FIX: derive sibling paths with splitext instead of str.replace, which
    # would corrupt any path containing '.mid' elsewhere.
    base, _ = os.path.splitext(midi_path)
    wav_file = base + '.wav'
    mp3_file = base + '.mp3'

    try:
        generated_tokens_to_midi(midi_data, midi_path)
        # Synthesize MIDI -> WAV with a soundfont, then compress to MP3.
        fs = FluidSynth(sound_font='FluidR3Mono_GM.sf3')
        fs.midi_to_audio(midi_path, wav_file)
        audio = AudioSegment.from_wav(wav_file)
        audio.export(mp3_file, format="mp3")
        # The MP3 must survive: gradio serves it from this path.
        return mp3_file
    finally:
        # FIX: remove intermediates; previously every request leaked a
        # .mid and a .wav file into the temp directory.
        for path in (midi_path, wav_file):
            try:
                os.remove(path)
            except OSError:
                pass
131
+
132
# Minimal Gradio UI: pick a composer, receive a rendered MP3 of the piece.
iface = gr.Interface(
    fn=generate_wrapper,
    inputs=[
        gr.Dropdown(
            choices=["Bach", "Chopin"],
            label="Select Composer",
            value="Bach"  # default value
        )
    ],
    # generate_wrapper returns a filesystem path to the MP3.
    outputs=gr.Audio(type="filepath", label="Generated MIDI"),
    title="MAI: MIDI AI Music Generation Model",
    description="Generate MIDI sequences",
    flagging_mode="never"  # hide the flagging button
)

iface.launch()
miditokenizer.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """miditokenizer.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/16YJUBYcqKYPVIhwzKNi4ELnTftr2TcUY
8
+ """
9
+
10
+ #We use a base GPT2 tokenizer with additional functions to handle composer tokens
11
+ #Datasets are created by processing our files in chunks, due to model sequence limits
12
+ #Position information is added to each chunk as additional pattern/data for training
13
+
14
+ from transformers import GPT2TokenizerFast, GPT2LMHeadModel
15
+ from torch.utils.data import Dataset
16
+ from pathlib import Path
17
+ import torch
18
+
19
class MIDITokenizer:
    """GPT-2 tokenizer wrapper with MIDI-event and composer special tokens."""

    def __init__(self, pretrained_model='gpt2'):
        self.base_tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_model)

        # Structural markers plus one tag per supported MIDI event type.
        markers = [
            '<|START_METADATA|>',
            '<|END_METADATA|>',
            '<|START_TRACK|>',
            '<|END_TRACK|>',
            '<metadata>',
            '<tempo>',
            '<time_signature>',
            '<program_change>',
            '<note_on>',
            '<note_off>',
            '<control_change>'
        ]
        self.base_tokenizer.add_special_tokens({
            'additional_special_tokens': markers,
            'pad_token': '[PAD]'
        })

        # Mirror the commonly used token attributes for convenient access.
        for attr in ('pad_token_id', 'eos_token_id', 'bos_token_id',
                     'pad_token', 'eos_token', 'bos_token'):
            setattr(self, attr, getattr(self.base_tokenizer, attr))

    def add_composer_tokens(self, composers):
        """Register one special token per composer name: <|composer_NAME|>."""
        self.base_tokenizer.add_special_tokens({
            'additional_special_tokens': [f'<|composer_{c}|>' for c in composers]
        })

    def __call__(self, text, **kwargs):
        """Tokenize ``text`` exactly like the wrapped GPT-2 tokenizer."""
        return self.base_tokenizer(text, **kwargs)

    def decode(self, token_ids, **kwargs):
        """Decode while preserving special tokens."""
        return self.base_tokenizer.decode(token_ids, skip_special_tokens=False, **kwargs)

    def pad(self, *args, **kwargs):
        """Delegate batch padding to the wrapped tokenizer."""
        return self.base_tokenizer.pad(*args, **kwargs)

    def get_vocab(self):
        """Return the full token -> id mapping of the wrapped tokenizer."""
        return self.base_tokenizer.get_vocab()