# MAI_MidiAI_Playback / midimusicgenapp.py
# (Hugging Face upload metadata: tlmdesign, "Upload 6 files", commit f542560)
# -*- coding: utf-8 -*-
"""MidiMusicGenApp.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1Dn99ii_FiQTx-z5B0dX0br0Gc0U9MUqD
"""
import gradio as gr
import torch
from transformers import GPT2LMHeadModel
from miditokenizer import MIDITokenizer
from genprocessor import GENProcessor, generated_tokens_to_midi
from midi2audio import FluidSynth
from pydub import AudioSegment
import tempfile
import os
# Disable HF tokenizers parallelism (avoids fork-related warnings/deadlocks
# when the web server spawns worker processes).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Load model and tokenizer
# Allowlist the classes pickled inside the checkpoint for torch's
# safe-loading machinery before calling torch.load below.
torch.serialization.add_safe_globals([set])
torch.serialization.add_safe_globals([GPT2LMHeadModel])
# Prefer GPU when available; weights are mapped onto this device at load.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): weights_only=False unpickles arbitrary objects — acceptable
# only because this checkpoint ships with the app (trusted source).
model = torch.load('model_complete_18epochs.pkl',map_location=device, weights_only=False)
tokenizer = MIDITokenizer()  # project-local MIDI token vocabulary/codec
processor = GENProcessor()   # project-local token-sequence <-> MIDI converter
model.eval()  # inference mode (disables dropout, etc.)
#functions to adjust timing & combine generated song parts
def adjust_midi_timing(midi_data, start_time=0):
    """Adjust MIDI timing with optional start time.

    Collapses any gap longer than two beats down to an eighth-note
    increment, then shifts events so the first non-zero event lands at
    ``start_time`` (events at time 0 are pinned to ``start_time``).

    Args:
        midi_data: Dict with 'metadata' (must contain 'ticks_per_beat')
            and 'tracks' — tracks[0] is the tempo track and is passed
            through untouched; the rest are merged into one event track.
        start_time: Tick offset applied to every adjusted event.

    Returns:
        A new dict {'metadata': ..., 'tracks': [tempo_track, merged_track]}.
        The input ``midi_data`` is NOT mutated (the original implementation
        wrote adjusted times back into the caller's event dicts).
        On any error the input is returned unchanged (best-effort).
    """
    try:
        # Keep tempo track separate; it is carried over as-is.
        tempo_track = midi_data['tracks'][0]
        ticks_per_beat = midi_data['metadata']['ticks_per_beat']
        # Thresholds scale with resolution: gaps longer than two beats
        # are collapsed to an eighth note.
        gap_threshold = ticks_per_beat * 2
        small_increment = ticks_per_beat // 8  # Eighth note
        # Merge all non-tempo tracks, copying each event so the caller's
        # data is never modified, then sort chronologically.
        all_events = [event.copy()
                      for track in midi_data['tracks'][1:]
                      for event in track]
        all_events.sort(key=lambda e: e['time'])
        # Walk events in order; any event following a large gap is pulled
        # back to just after the previous event.
        sequential_events = []
        current_time = all_events[0]['time'] if all_events else 0
        for event in all_events:
            if event['time'] - current_time > gap_threshold:
                event['time'] = current_time + small_increment
            current_time = event['time']
            sequential_events.append(event)
        # Shift so the earliest non-zero event starts at start_time.
        first_time = min((event['time'] for event in sequential_events
                          if event['time'] != 0), default=0)
        adjusted_track = []
        for event in sequential_events:
            # Events are already private copies, safe to update in place.
            if event['time'] != 0:
                event['time'] = (event['time'] - first_time) + start_time
            else:
                event['time'] = start_time
            adjusted_track.append(event)
        return {'metadata': midi_data['metadata'],
                'tracks': [tempo_track, adjusted_track]}
    except Exception as e:
        # Best-effort: return the unadjusted data rather than fail playback.
        print(f"Error adjusting MIDI timing: {str(e)}")
        return midi_data
#Functions to generate music
def generate_music(prompt):
    """Sample a MIDI token sequence from the model for the given prompt.

    Args:
        prompt: Seed text (metadata/composer tags) fed to the model.

    Returns:
        The decoded generated token sequence as a string.
    """
    # padding=True below requires a pad token; fall back to EOS if unset.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True,
    )
    # Sampling configuration kept together for readability.
    sampling_config = dict(
        max_length=1024,
        do_sample=True,
        temperature=0.6,  # adjust creativity
        top_k=30,
        top_p=0.90,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    sequences = model.generate(
        input_ids=encoded["input_ids"].to(model.device),
        attention_mask=encoded["attention_mask"].to(model.device),
        **sampling_config,
    )
    # Decode the first (and only) generated id sequence back to text.
    return tokenizer.decode(sequences[0])
def generate_wrapper(composer):
    """Gradio callback: generate a piece in the selected composer's style.

    Args:
        composer: Composer name from the dropdown (e.g. "Bach").

    Returns:
        Path to an MP3 file containing the synthesized audio.
    """
    # Seed the model with the metadata tags for the chosen composer.
    prompt = f"<|START_METADATA|> <|composer_{composer}|><metadata> ticks_per_beat="
    generated_text = generate_music(prompt)
    midi_data = adjust_midi_timing(processor.decode_midi_file(generated_text))
    # Reserve a temp path, then write AFTER the handle is closed (writing
    # by name while the NamedTemporaryFile is still open fails on Windows).
    with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as tmp:
        midi_path = tmp.name
    generated_tokens_to_midi(midi_data, midi_path)
    # Derive sibling paths via splitext; the original str.replace('.mid', ...)
    # would corrupt a path containing '.mid' anywhere else in it.
    base, _ = os.path.splitext(midi_path)
    wav_file = base + '.wav'
    mp3_file = base + '.mp3'
    try:
        # Render MIDI -> WAV with FluidSynth, then transcode WAV -> MP3.
        fs = FluidSynth(sound_font='FluidR3Mono_GM.sf3')
        fs.midi_to_audio(midi_path, wav_file)
        AudioSegment.from_wav(wav_file).export(mp3_file, format="mp3")
    finally:
        # Clean up intermediates; keep only the MP3 that Gradio serves.
        for leftover in (midi_path, wav_file):
            try:
                os.remove(leftover)
            except OSError:
                pass
    return mp3_file
# Build the Gradio UI: composer dropdown in, audio player out.
iface = gr.Interface(
    fn=generate_wrapper,
    inputs=[
        gr.Dropdown(
            choices=["Bach", "Chopin"],
            label="Select Composer",
            value="Bach",  # default selection
        )
    ],
    outputs=gr.Audio(type="filepath", label="Generated MIDI"),
    title="MAI: MIDI AI Music Generation Model",
    description=(
        "Select a composer whose musical style you'd like to emulate. "
        "Generate an original sequence inspired by that composer's unique "
        "sound. It should take a few minutes. Once it's ready, you can "
        "listen to the clip or download the audio file."
    ),
)

# Launch the server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()