Spaces:
Sleeping
Sleeping
File size: 10,972 Bytes
73779e1 71388c0 73779e1 71388c0 73779e1 71388c0 73779e1 71388c0 73779e1 71388c0 73779e1 71388c0 73779e1 71388c0 73779e1 71388c0 73779e1 71388c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 |
# -*- coding: utf-8 -*-
"""loop_ downloadOnly_MidiMusicGenApp.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1gFFJW56DAbeLYqKWqi6jZN_einOQBf4j
"""
import gc
import os
import re
import tempfile

import gradio as gr
import torch
from midi2audio import FluidSynth
from pydub import AudioSegment
from transformers import GPT2LMHeadModel

from genprocessor import GENProcessor, generated_tokens_to_midi
from miditokenizer import MIDITokenizer
# Set before any tokenizer is created.
# NOTE(review): presumably this silences the Hugging Face tokenizers
# parallelism/fork warning - confirm it is still needed with MIDITokenizer.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Load model and tokenizer
# NOTE(review): the safe-globals allow-list is documented as applying to
# weights_only=True loads; the load below uses weights_only=False, so confirm
# whether these two registrations still have any effect.
torch.serialization.add_safe_globals([set])
torch.serialization.add_safe_globals([GPT2LMHeadModel])
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
# objects. Only ever point this at a checkpoint file you trust.
model = torch.load('model_complete_18epochs.pkl',map_location=device, weights_only=False)
tokenizer = MIDITokenizer()
processor = GENProcessor()
# Put the loaded model into evaluation mode; this app only runs generation.
model.eval()
class State:
    """Process-wide holder for the most recent generation.

    NOTE(review): a single module-level instance is shared by every request
    the app serves, so concurrent users overwrite each other's music.
    """

    def __init__(self):
        # None until something has been generated. Despite the name, the
        # wrappers in this file store the decoded MIDI data here, not text.
        self.generated_text = None


state = State()
#functions to adjust timing & combine generated song parts
def adjust_midi_timing(midi_data, start_time=0):
    """Close large gaps between events and re-base them to begin at ``start_time``.

    The first track is treated as the tempo track and is passed through
    unchanged. Events of all remaining tracks are merged into one track and
    sorted by their absolute ``time``. Whenever an event lies more than two
    beats after its predecessor, the gap is shortened to ``ticks_per_beat // 8``
    ticks and every later event is moved earlier by the same amount, so the
    spacing of the events after the gap is preserved. Finally the events are
    shifted so that the earliest non-zero time lands on ``start_time``; events
    at time 0 are placed on ``start_time`` as well.

    Args:
        midi_data: Mapping with ``'metadata'`` (containing ``'ticks_per_beat'``)
            and ``'tracks'`` (a list of event-dict lists, each event carrying
            an absolute ``'time'``).
        start_time: Tick at which the adjusted events should begin.

    Returns:
        A new dict ``{'metadata': ..., 'tracks': [tempo_track, merged_events]}``.
        The input events are copied, never modified. If anything goes wrong
        the error is printed and ``midi_data`` itself is returned unchanged
        (deliberate best-effort behaviour).
    """
    try:
        # Keep tempo track separate
        tempo_track = midi_data['tracks'][0]
        ticks_per_beat = midi_data['metadata']['ticks_per_beat']

        # Thresholds derived from the file's resolution.
        gap_threshold = ticks_per_beat * 2
        # One eighth of a beat (a 32nd note when the beat is a quarter note).
        small_increment = ticks_per_beat // 8

        # Work on shallow copies so the caller's event dicts stay untouched.
        events = [event.copy() for track in midi_data['tracks'][1:] for event in track]
        events.sort(key=lambda event: event['time'])

        # Compress large gaps. ``removed_ticks`` carries the total amount cut
        # so far, so events after a gap keep their spacing relative to each
        # other instead of each being squeezed against its predecessor.
        removed_ticks = 0
        current_time = events[0]['time'] if events else 0
        for event in events:
            event_time = event['time'] - removed_ticks
            if event_time - current_time > gap_threshold:
                compressed_time = current_time + small_increment
                removed_ticks += event_time - compressed_time
                event_time = compressed_time
            event['time'] = event_time
            current_time = event_time

        # Re-base on the first non-zero time.
        first_time = min(
            (event['time'] for event in events if event['time'] != 0),
            default=0,
        )
        for event in events:
            if event['time'] != 0:
                event['time'] = (event['time'] - first_time) + start_time
            else:
                event['time'] = start_time

        return {'metadata': midi_data['metadata'], 'tracks': [tempo_track, events]}
    except Exception as e:
        # Best effort: report and hand back the unmodified input.
        print(f"Error adjusting MIDI timing: {str(e)}")
        return midi_data
def combine_tracks(first, continued):
    """Combine two generated sequences into a single song.

    Args:
        first: The song so far, either as generated text (decoded here with
            ``GENProcessor.decode_midi_file``) or as already-decoded MIDI data.
        continued: Generated text of the continuation; always decoded here.

    Returns:
        A dict with the first part's ``'metadata'`` and two tracks: the first
        part's tempo track, and the first part's note events followed by the
        continuation's events re-based to start at the first part's last
        event time. A ``'composer'`` entry on the decoded first part is
        carried over, because ``extract_context`` reads that key when building
        the next prompt and would otherwise fall back to its default.
    """
    decoder = GENProcessor()

    # Check if first input is already decoded or needs decoding
    if isinstance(first, str):
        gendecoded = decoder.decode_midi_file(first)
    else:
        gendecoded = first

    first_midi = adjust_midi_timing(gendecoded, start_time=0)
    first_notes = first_midi['tracks'][1]

    # An empty note track must not abort the merge: start the continuation at 0.
    last_time = max((event['time'] for event in first_notes), default=0)

    # Adjust timing of the second part so it starts where the first one ends.
    second_midi = adjust_midi_timing(
        decoder.decode_midi_file(continued), start_time=last_time
    )

    full_song = {
        'metadata': first_midi['metadata'],
        'tracks': [
            first_midi['tracks'][0],  # Keep tempo track
            first_notes + second_midi['tracks'][1],  # Combine note tracks
        ],
    }
    # adjust_midi_timing rebuilds the dict from metadata and tracks only, so
    # read the composer from the decoded input rather than from first_midi.
    if isinstance(gendecoded, dict) and 'composer' in gendecoded:
        full_song['composer'] = gendecoded['composer']
    return full_song
def _format_prompt_metadata(composer, ticks, numerator):
    """Render the metadata header that starts a continuation prompt."""
    return (f"<|START_METADATA|> <|composer_{composer}|><metadata> "
            f"ticks_per_beat={ticks} <|START_TRACK|> "
            f"tempo=500000 <time_signature> time=0 numerator={numerator} denominator=4")


def extract_context(text_or_midi, num_events=3):
    """Build the pieces of a continuation prompt from a previous generation.

    Args:
        text_or_midi: Either decoded MIDI data (a dict with ``'metadata'`` and
            ``'tracks'``, optionally ``'composer'``) or the generated text.
        num_events: How many trailing events to keep as context.

    Returns:
        A tuple ``(metadata, context, last_time)``:

        * ``metadata`` - the prompt header string, or ``None`` when the input
          is text from which composer and ticks-per-beat cannot be parsed.
        * ``context`` - the last ``num_events`` events rendered as text.
        * ``last_time`` - time of the last context event. For dict input this
          is 0 when the track is empty; for text input it is ``None`` when no
          event time is found.
    """
    if isinstance(text_or_midi, dict):
        ticks = text_or_midi['metadata']['ticks_per_beat']
        # NOTE(review): assumes the second event of the tempo track is the
        # time signature - confirm against the decoder's output.
        numerator = text_or_midi['tracks'][0][1]['numerator']
        composer = text_or_midi.get('composer', 'Bach')  # default to Bach if missing

        # Last events of the note track, in time order.
        sorted_events = sorted(text_or_midi['tracks'][1], key=lambda x: x['time'])
        last_events = sorted_events[-num_events:]

        result_metadata = _format_prompt_metadata(composer, ticks, numerator)

        # Events carrying a note are rendered with note/velocity; all others
        # are rendered as control changes with control/value.
        context = " ".join(f"<{event['type']}> time={event['time']} channel={event['channel']} " +
                           (f"note={event['note']} velocity={event['velocity']}" if 'note' in event else
                            f"control={event['control']} value={event['value']}")
                           for event in last_events)
        last_time = last_events[-1]['time'] if last_events else 0
        return result_metadata, context, last_time

    # Input is string
    result_metadata = None
    if "<|START_METADATA|>" in text_or_midi:
        composer_match = re.search(r"<\|composer_([^|]+)\|>", text_or_midi)
        ticks_match = re.search(r"ticks[_]?(?:per_)?beat=(\d+)", text_or_midi)
        time_sig_match = re.search(r"numerator=(\d+)", text_or_midi)
        if composer_match and ticks_match:
            numerator = 4
            if time_sig_match:
                try:
                    # Clamp the parsed numerator into the 2..4 range.
                    numerator = min(4, max(2, int(time_sig_match.group(1))))
                except ValueError:
                    pass
            result_metadata = _format_prompt_metadata(
                composer_match.group(1),
                max(75, int(ticks_match.group(1))),  # never go below 75 ticks
                numerator,
            )

    # Get last N complete events and timestamp
    matches = re.finditer(r"<(note_on|control_change|note_off)>.*?(value=\d+|velocity=\d+)", text_or_midi)
    events = list(matches)[-num_events:]
    context = " ".join(event.group(0) for event in events)

    # Use the time of the LAST context event, matching the dict branch.
    times = re.findall(r"time=(\d+)", context)
    last_time = int(times[-1]) if times else None

    return result_metadata, context, last_time
def continue_sequence(generated_text, num_loops=1):
    """Extend a song by repeatedly generating and appending a continuation.

    Each round builds a prompt from the current song (metadata header plus its
    last few events), generates a continuation from that prompt, and merges
    the continuation onto the song.

    Args:
        generated_text: The song to extend; passed as-is to ``extract_context``
            and ``combine_tracks``.
        num_loops: Number of continuation rounds to run.

    Returns:
        The extended song, or ``generated_text`` unchanged when ``num_loops``
        is 0.
    """
    song = generated_text
    for round_number in range(1, num_loops + 1):
        print(f"Generating loop {round_number}/{num_loops}")
        # The third element (time of the last event) is not needed here.
        header, tail_events, _ = extract_context(song)
        prompt = header + tail_events
        print(prompt)
        song = combine_tracks(song, generate_music(prompt))
    return song
#Functions to generate music
def generate_music(prompt):
    """Sample a sequence from the module-level model for ``prompt``.

    Args:
        prompt: Text prompt handed to the module-level tokenizer.

    Returns:
        The first generated sequence, decoded back to text by the tokenizer.
        NOTE(review): presumably this includes the prompt tokens, as is usual
        for decoder-only ``generate`` - confirm.
    """
    # The tokenizer needs a pad token for padding=True; reuse EOS if unset.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True,
    )

    target_device = model.device
    sampling_options = {
        "max_length": 1024,
        "do_sample": True,
        "temperature": 0.6,  # adjust creativity
        "top_k": 30,
        "top_p": 0.90,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    sequences = model.generate(
        input_ids=encoded["input_ids"].to(target_device),
        attention_mask=encoded["attention_mask"].to(target_device),
        **sampling_options,
    )

    return tokenizer.decode(sequences[0])
def generate_wrapper(composer):
    """Generate a new piece for ``composer`` and write it to a MIDI file.

    Memory is garbage-collected (and the CUDA cache emptied when CUDA is
    available) before generation and again afterwards, even on failure.

    Args:
        composer: Composer name inserted into the prompt's composer tag.

    Returns:
        A tuple ``(midi_path, update)``: the path of the temporary ``.mid``
        file and a Gradio update that makes the continue button visible.
    """
    try:
        # Clear memory before generation
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Format the prompt with the selected composer
        prompt = f"<|START_METADATA|> <|composer_{composer}|><metadata> ticks_per_beat="
        generated_text = generate_music(prompt)
        midi_data = adjust_midi_timing(processor.decode_midi_file(generated_text))

        # Keep the composer with the stored song: extract_context reads a
        # 'composer' key when building continuation prompts and otherwise
        # falls back to its default, ignoring the user's selection. A copy is
        # stored so the data written to disk below is exactly what it was.
        state.generated_text = dict(midi_data, composer=composer)

        # Reserve a temp path, then close the handle BEFORE writing by name;
        # reopening a still-open NamedTemporaryFile is platform-dependent.
        with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as tmp:
            midi_path = tmp.name
        generated_tokens_to_midi(midi_data, midi_path)
        return midi_path, gr.update(visible=True)
    finally:
        # Clear memory after generation
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def continue_wrapper():
    """Extend the stored song and write the result to a new MIDI file.

    Runs ``continue_sequence`` with 8 rounds on the song held in ``state`` and
    stores the extended song back into ``state``.

    Returns:
        Path of the temporary ``.mid`` file, or ``None`` when nothing has been
        generated yet.
    """
    if state.generated_text is None:
        return None

    try:
        extended_text = continue_sequence(state.generated_text, num_loops=8)
        state.generated_text = extended_text  # Update the stored song

        # Reserve a temp path, then close the handle BEFORE writing by name;
        # reopening a still-open NamedTemporaryFile is platform-dependent.
        with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as tmp:
            midi_path = tmp.name
        generated_tokens_to_midi(extended_text, midi_path)
        return midi_path
    finally:
        # Same cleanup generate_wrapper performs; this path runs several
        # generations back to back.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Gradio UI: a composer picker, a generate button, the resulting MIDI file,
# and a continue button that stays hidden until a piece has been generated.
with gr.Blocks() as iface:
    gr.Markdown("""
# MAI: MIDI AI Music Generation Model
Select a composer whose musical style you'd like to emulate. Generate an original sequence inspired by that composer's unique sound.
It should take a few minutes. Once it's ready, you can download the audio file.
If you like the opening, you can continue the sequence and make your song longer, repeat, or try again. It may take a few minutes to continue the sequence.
""")

    with gr.Column():
        composer_input = gr.Dropdown(
            label="Select Composer",
            choices=["Bach", "Chopin"],
            value="Bach",
        )
        generate_btn = gr.Button("Generate Music")
        output_file = gr.File(label="Generated MIDI File")
        continue_btn = gr.Button("Continue Sequence (add 10 seconds)", visible=False)

    # Step 1 (not queued): clear the previous file and hide the continue
    # button. Step 2 is chained with .success, so generation only starts
    # after the reset step completed without error.
    reset_event = generate_btn.click(
        fn=lambda: (None, gr.update(visible=False)),
        inputs=None,
        outputs=[output_file, continue_btn],
        queue=False,
    )
    reset_event.success(
        fn=generate_wrapper,
        inputs=composer_input,
        outputs=[output_file, continue_btn],
    )

    continue_btn.click(fn=continue_wrapper, outputs=output_file)

iface.launch()
|