# -*- coding: utf-8 -*-
"""loop_ downloadOnly_MidiMusicGenApp.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1gFFJW56DAbeLYqKWqi6jZN_einOQBf4j
"""
import gradio as gr
import torch
import gc
import re
import tempfile
import os

from transformers import GPT2LMHeadModel
from miditokenizer import MIDITokenizer
from genprocessor import GENProcessor, generated_tokens_to_midi

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Load model and tokenizer
torch.serialization.add_safe_globals([set])
torch.serialization.add_safe_globals([GPT2LMHeadModel])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load('model_complete_18epochs.pkl', map_location=device, weights_only=False)
tokenizer = MIDITokenizer()
processor = GENProcessor()
model.eval()
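# Note: weights_only=False deserializes the full pickle (arbitrary code can run),
# so only load checkpoints you trust. The add_safe_globals() calls above are only
# required for weights_only=True loading and are kept here for compatibility.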


class State:
    def __init__(self):
        self.generated_text = None  # Holds the most recent piece (decoded MIDI dict)


state = State()
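# Note: this module-level State instance is shared by every Gradio session, so
# concurrent users would overwrite each other's generated piece.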

# Functions to adjust timing & combine generated song parts
def adjust_midi_timing(midi_data, start_time=0):
    """Adjust MIDI timing with an optional start time. Prevent large gaps based on ticks_per_beat."""
    try:
        # Keep the tempo track separate
        tempo_track = midi_data['tracks'][0]
        ticks_per_beat = midi_data['metadata']['ticks_per_beat']
        # Calculate thresholds based on ticks_per_beat
        gap_threshold = ticks_per_beat * 2
        small_increment = ticks_per_beat // 8  # An eighth of a beat (a 32nd note)
        # Gather all non-tempo events and sort by time
        all_events = []
        for track in midi_data['tracks'][1:]:
            all_events.extend(track)
        all_events.sort(key=lambda x: x['time'])
        # Walk the events in order, collapsing any large gaps
        sequential_events = []
        current_time = all_events[0]['time'] if all_events else 0
        for event in all_events:
            if event['time'] - current_time > gap_threshold:
                event['time'] = current_time + small_increment
            current_time = event['time']
            sequential_events.append(event)
        # Find the first non-zero time
        first_time = min((event['time'] for event in sequential_events if event['time'] != 0), default=0)
        adjusted_data = {'metadata': midi_data['metadata'], 'tracks': [tempo_track]}
        # Shift all events so the sequence starts at start_time
        adjusted_track = []
        for event in sequential_events:
            adjusted_event = event.copy()
            if event['time'] != 0:
                adjusted_event['time'] = (event['time'] - first_time) + start_time
            else:
                adjusted_event['time'] = start_time
            adjusted_track.append(adjusted_event)
        adjusted_data['tracks'].append(adjusted_track)
        return adjusted_data
    except Exception as e:
        print(f"Error adjusting MIDI timing: {str(e)}")
        return midi_data
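
# Shape of the decoded MIDI data assumed throughout this file (inferred from
# how it is accessed; illustrative values):
# {
#     'metadata': {'ticks_per_beat': 480, ...},
#     'tracks': [
#         [...],                                  # track 0: tempo / time-signature events
#         [{'type': 'note_on', 'time': 0,         # track 1: note events
#           'channel': 0, 'note': 60, 'velocity': 80}, ...],
#     ],
# }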

def combine_tracks(first, continued):
    """Combine two generated sequences into a single song."""
    # Check if the first input is already decoded or still needs decoding
    if isinstance(first, str):
        gendecoded = processor.decode_midi_file(first)
    else:
        gendecoded = first
    first_midi = adjust_midi_timing(gendecoded, start_time=0)
    # Get the last time from the single note track
    last_time = max(event['time'] for event in first_midi['tracks'][1])
    # Adjust the timing of the second MIDI so it starts where the first ends
    second_midi = adjust_midi_timing(processor.decode_midi_file(continued), start_time=last_time)
    # Combine into a single song
    full_song = {
        'metadata': first_midi['metadata'],
        'tracks': [
            first_midi['tracks'][0],  # Keep tempo track
            first_midi['tracks'][1] + second_midi['tracks'][1]  # Combine note tracks
        ]
    }
    return full_song
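
# Note: combine_tracks keeps the first sequence's metadata and tempo track, so
# the continuation is assumed to share the same ticks_per_beat and tempo.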

def extract_context(text_or_midi, num_events=3):
    """Get the metadata, composer, and last few events from the generated output to build a prompt that continues the sequence."""
    result_metadata = None
    if isinstance(text_or_midi, dict):
        # Extract metadata from the decoded MIDI structure
        ticks = text_or_midi['metadata']['ticks_per_beat']
        numerator = text_or_midi['tracks'][0][1]['numerator']
        composer = text_or_midi.get('composer', 'Bach')  # Default to Bach if missing
        # Get the last events from the note track
        all_events = text_or_midi['tracks'][1]
        sorted_events = sorted(all_events, key=lambda x: x['time'])
        last_events = sorted_events[-num_events:]
        # Format the metadata string
        result_metadata = (f"<|START_METADATA|> <|composer_{composer}|><metadata> "
                           f"ticks_per_beat={ticks} <|START_TRACK|> "
                           f"tempo=500000 <time_signature> time=0 numerator={numerator} denominator=4")
        # Format the context string from the last events
        context = " ".join(f"<{event['type']}> time={event['time']} channel={event['channel']} " +
                           (f"note={event['note']} velocity={event['velocity']}" if 'note' in event else
                            f"control={event['control']} value={event['value']}")
                           for event in last_events)
        last_time = last_events[-1]['time'] if last_events else 0
    else:
        # Input is a text string of tokens
        if "<|START_METADATA|>" in text_or_midi:
            composer_match = re.search(r"<\|composer_([^|]+)\|>", text_or_midi)
            ticks_match = re.search(r"ticks[_]?(?:per_)?beat=(\d+)", text_or_midi)
            time_sig_match = re.search(r"numerator=(\d+)", text_or_midi)
            if composer_match and ticks_match:
                composer = composer_match.group(1)
                numerator = 4
                if time_sig_match:
                    try:
                        num = int(time_sig_match.group(1))
                        numerator = min(4, max(2, num))
                    except ValueError:
                        pass
                result_metadata = (f"<|START_METADATA|> <|composer_{composer}|><metadata> "
                                   f"ticks_per_beat={max(75, int(ticks_match.group(1)))} <|START_TRACK|> "
                                   f"tempo=500000 <time_signature> time=0 numerator={numerator} denominator=4")
        # Get the last N complete events and the final timestamp
        matches = re.finditer(r"<(note_on|control_change|note_off)>.*?(value=\d+|velocity=\d+)", text_or_midi)
        events = list(matches)[-num_events:]
        context = " ".join(event.group(0) for event in events)
        # Take the last (not first) time= in the context as the final timestamp
        times = re.findall(r"time=(\d+)", context)
        last_time = int(times[-1]) if times else None
    return result_metadata, context, last_time
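
# Illustrative prompt assembled from metadata + context (example values only):
# "<|START_METADATA|> <|composer_Bach|><metadata> ticks_per_beat=480 <|START_TRACK|> "
# "tempo=500000 <time_signature> time=0 numerator=4 denominator=4"
# "<note_on> time=960 channel=0 note=60 velocity=80 ..."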

def continue_sequence(generated_text, num_loops=1):
    """Continue the sequence: extract info from the previous sequence to build a prompt, then append the result to the previous sequence."""
    full_song = generated_text
    for i in range(num_loops):
        print(f"Generating loop {i+1}/{num_loops}")
        metadata, context, last_time = extract_context(full_song)
        print(metadata + context)
        continued = generate_music(metadata + context)
        full_song = combine_tracks(full_song, continued)
    return full_song

# Functions to generate music
def generate_music(prompt):
    """Generate music based on a given prompt."""
    # Tokenize
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True
    )
    # Generate
    output_sequences = model.generate(
        input_ids=inputs["input_ids"].to(model.device),
        attention_mask=inputs["attention_mask"].to(model.device),
        max_length=1024,
        do_sample=True,
        temperature=0.6,  # adjust creativity
        top_k=30,
        top_p=0.90,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Decode the generated sequence
    generated_text = tokenizer.decode(output_sequences[0])
    return generated_text
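
# The sampling settings trade coherence for variety: lower temperature, top_k,
# and top_p give more conservative continuations. Because do_sample=True, each
# call yields a different piece; call torch.manual_seed(...) beforehand if you
# need reproducible output.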

def generate_wrapper(composer):
    try:
        # Clear memory before generation
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Format the prompt with the selected composer
        prompt = f"<|START_METADATA|> <|composer_{composer}|><metadata> ticks_per_beat="
        generated_text = generate_music(prompt)
        midi_data = adjust_midi_timing(processor.decode_midi_file(generated_text))
        state.generated_text = midi_data
        # Create a temp file for the MIDI output
        with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as tmp:
            generated_tokens_to_midi(midi_data, tmp.name)
        return tmp.name, gr.update(visible=True)
    finally:
        # Clear memory after generation
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

def continue_wrapper():
    """Wrapper for continue_sequence."""
    if state.generated_text is not None:
        # Continue the sequence from the stored state
        extended_text = continue_sequence(state.generated_text, num_loops=8)
        state.generated_text = extended_text  # Update the stored sequence
        # Create a temp file for the MIDI output
        with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as tmp:
            generated_tokens_to_midi(extended_text, tmp.name)
        return tmp.name
    return None
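
# num_loops controls how much material each click appends; the button's
# "(add 10 seconds)" label is an approximation of eight continuation passes.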

with gr.Blocks() as iface:
    gr.Markdown("""
# MAI: MIDI AI Music Generation Model
Select a composer whose musical style you'd like to emulate, then generate an original sequence inspired by that composer's unique sound.
Generation takes a few minutes; once the piece is ready, you can download the MIDI file.
If you like the opening, you can continue the sequence to make your song longer, repeat the continuation, or generate a new piece. Continuing the sequence may also take a few minutes.
""")
    with gr.Column():
        composer_input = gr.Dropdown(
            choices=["Bach", "Chopin"],
            label="Select Composer",
            value="Bach"
        )
        generate_btn = gr.Button("Generate Music")
        output_file = gr.File(label="Generated MIDI File")
        continue_btn = gr.Button("Continue Sequence (add 10 seconds)", visible=False)
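    # Clear the previous download and hide the continue button first; the
    # chained .success() handler then runs generation once the reset succeeds.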
    generate_btn.click(
        lambda: (None, gr.update(visible=False)),
        None,
        [output_file, continue_btn],
        queue=False
    ).success(
        generate_wrapper,
        inputs=composer_input,
        outputs=[output_file, continue_btn]
    )
    continue_btn.click(
        fn=continue_wrapper,
        outputs=output_file
    )

iface.launch()