| | import gradio as gr
|
| | import torch
|
| | import os
|
| | import random
|
| | import glob
|
| | import numpy as np
|
| | import pretty_midi
|
| | import scipy.io.wavfile
|
| |
|
| |
|
| |
|
try:
    from model.music_transformer import MusicTransformer
    from processor import encode_midi, decode_midi
    from dataset.e_piano import process_midi
    # Star import kept on purpose: the code below relies on names such as
    # TORCH_LABEL_TYPE that presumably come from this module.
    from utilities.constants import *
    from utilities.device import get_device, use_cuda
except ImportError as e:
    # As the message says, these modules are resolved relative to the folder the
    # app is launched from, so a wrong working directory is the usual cause.
    print("Error: Could not import necessary files.")
    print("Make sure app.py is in the same folder as 'model', 'processor.py', etc.")
    print(f"Details: {e}")
    # exit() is a `site`-module convenience for the interactive shell and exits with
    # status 0; raise SystemExit(1) so callers and process managers see the failure.
    raise SystemExit(1) from e
|
| |
|
| |
|
| |
|
# Architecture hyper-parameters handed to MusicTransformer by load_model().
# They must match the configuration the checkpoint was trained with; mismatched
# layer sizes typically make load_state_dict() reject the weights.
# "max_sequence" also caps the total token length requested in generate_music().
MODEL_CONFIG = dict(
    n_layers=6,
    num_heads=8,
    d_model=512,
    dim_feedforward=1024,
    max_sequence=2048,
    rpr=True,
)
|
| |
|
| |
|
| |
|
# Torch device resolved once at import time; load_model() and generate_music()
# read this module-level name when they run.
device = get_device()
print(f"Using device: {device}")

# Set by load_model(); generate_music() checks it for None before generating.
model = None
|
| |
|
| |
|
def load_model(model_path):
    """
    Load trained MusicTransformer weights from ``model_path`` into the global ``model``.

    The network is built from MODEL_CONFIG on ``device`` and is only published to the
    module-level ``model`` after its state dict has loaded. A failed load therefore
    never leaves a randomly initialised network behind for generate_music() to use;
    any previously loaded model stays in place.

    Args:
        model_path: Path to a saved state dict (e.g. "best_acc_weights.pickle").

    Returns:
        A human-readable status string for the UI. Errors are reported in the
        returned string rather than raised.
    """
    global model
    # isfile (not exists) so that a directory is reported as "not found" up front.
    if model_path is None or not os.path.isfile(model_path):
        return "Error: Model file not found. Please check the path."

    try:
        print("Loading model...")
        new_model = MusicTransformer(
            n_layers=MODEL_CONFIG["n_layers"],
            num_heads=MODEL_CONFIG["num_heads"],
            d_model=MODEL_CONFIG["d_model"],
            dim_feedforward=MODEL_CONFIG["dim_feedforward"],
            max_sequence=MODEL_CONFIG["max_sequence"],
            rpr=MODEL_CONFIG["rpr"]
        ).to(device)

        # weights_only=True restricts unpickling to tensors and plain containers,
        # so a crafted checkpoint file cannot run arbitrary code during loading.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        new_model.load_state_dict(state_dict)
        new_model.eval()
    except Exception as e:
        return f"Error loading model: {e}"

    # Publish only a fully initialised model.
    model = new_model
    print("Model loaded successfully.")
    return f"Model '{model_path}' loaded successfully."
|
| |
|
| |
|
| |
|
def midi_to_wav(midi_file_path, wav_file_path, fs=44100):
    """
    Render a MIDI file to a 16-bit PCM WAV using pretty_midi's simple sine-wave synthesizer.

    Args:
        midi_file_path: Path of the MIDI file to render.
        wav_file_path: Destination path of the WAV file.
        fs: Sample rate in Hz, used both for synthesis and in the WAV header.

    Returns:
        ``wav_file_path`` on success, or ``None`` if reading, synthesis or writing
        failed (the error is printed rather than raised).
    """
    try:
        pm = pretty_midi.PrettyMIDI(midi_file_path)
        audio_data = pm.synthesize(fs=fs)

        # synthesize() yields a float waveform normalised to [-1, 1]. Casting that
        # straight to int16 truncates nearly every sample to 0 (silent output), so
        # scale to the int16 range first. nan_to_num guards against the 0/0 that
        # normalisation produces when every sample is zero.
        audio_data = np.nan_to_num(np.asarray(audio_data, dtype=np.float64))
        pcm = (np.clip(audio_data, -1.0, 1.0) * np.iinfo(np.int16).max).astype(np.int16)

        scipy.io.wavfile.write(wav_file_path, fs, pcm)
        return wav_file_path
    except Exception as e:
        print(f"Error during MIDI to WAV conversion: {e}")
        return None
|
| |
|
| |
|
| |
|
| |
|
def _uploaded_file_path(uploaded_midi):
    """
    Return the filesystem path held by a ``gr.File`` value.

    Gradio 4+ passes a plain path string (default ``type="filepath"``), while older
    Gradio passed a tempfile wrapper exposing the path as ``.name``; accept both.
    """
    return getattr(uploaded_midi, "name", uploaded_midi)


def _primer_from_midi(midi_path, primer_length, random_start):
    """
    Tokenize ``midi_path`` and cut a primer of ``primer_length`` tokens from it.

    Args:
        midi_path: Path of the MIDI file to encode with encode_midi().
        primer_length: Number of tokens requested from process_midi().
        random_start: Forwarded as process_midi()'s ``random_seq`` flag.

    Returns:
        A TORCH_LABEL_TYPE tensor on ``device`` (presumably 1-D; generate_music adds
        the batch dimension), or None when encode_midi() produced no events.
    """
    raw_mid = encode_midi(midi_path)
    if raw_mid is None or len(raw_mid) == 0:
        return None
    primer_tokens, _ = process_midi(raw_mid, primer_length, random_seq=random_start)
    return torch.tensor(primer_tokens, dtype=TORCH_LABEL_TYPE, device=device)


def generate_music(primer_type, uploaded_midi, upload_start_location, maestro_path, maestro_start_location,
                   primer_length, generation_length_new, progress=gr.Progress(track_tqdm=True)):
    """
    Gradio callback for the "Generate Music" button.

    This is a generator. Each ``yield`` is a ``(status, midi_path, wav_path)`` triple
    that updates the status box, the MIDI download and the audio player; ``None``
    leaves that output empty.

    Args:
        primer_type: "Generate from Silence", "Random Maestro MIDI" or "Upload My Own MIDI".
        uploaded_midi: Value of the upload ``gr.File`` (path string or file wrapper), or None.
        upload_start_location: "Start of File" or "Random Location" for uploaded primers.
        maestro_path: Folder searched recursively for .mid/.midi files.
        maestro_start_location: "Start of File" or "Random Location" for Maestro primers.
        primer_length: Tokens to take from a primer file (not used for silence).
        generation_length_new: Number of new tokens to generate after the primer.
        progress: Gradio progress tracker; ``track_tqdm`` mirrors tqdm bars in the UI.
    """
    # Take one reference up front so a concurrent "Load Model" click cannot swap the
    # network out in the middle of this generation.
    current_model = model
    if current_model is None:
        yield "Error: Model is not loaded. Please load a model first.", None, None
        return

    try:
        # Slider values may arrive as floats; token counts must be ints.
        generation_length_new = int(generation_length_new)

        primer = None
        if primer_type == "Generate from Silence":
            yield "Generating from silence...", None, None
            # Fixed single-token seed. NOTE(review): 372 is a vocabulary index from the
            # project's event encoding; confirm which event it represents.
            primer = torch.tensor([372], dtype=TORCH_LABEL_TYPE, device=device)

        elif primer_type == "Random Maestro MIDI":
            yield "Finding random Maestro file...", None, None
            if maestro_path is None or not os.path.isdir(maestro_path):
                yield f"Error: Maestro path '{maestro_path}' is not valid.", None, None
                return

            midi_files = [
                path
                for pattern in ("*.mid", "*.midi")
                for path in glob.glob(os.path.join(maestro_path, "**", pattern), recursive=True)
            ]
            if not midi_files:
                yield f"Error: No .mid/.midi files found in '{maestro_path}'.", None, None
                return

            random_file = random.choice(midi_files)
            yield f"Tokenizing random file: {os.path.basename(random_file)}...", None, None
            is_random_start = maestro_start_location == "Random Location"
            primer = _primer_from_midi(random_file, int(primer_length), is_random_start)
            if primer is None:
                yield "Error: Could not read MIDI messages.", None, None
                return

        elif primer_type == "Upload My Own MIDI":
            if uploaded_midi is None:
                yield "Error: Please upload a MIDI file.", None, None
                return

            upload_path = _uploaded_file_path(uploaded_midi)
            yield f"Tokenizing uploaded MIDI: {os.path.basename(upload_path)}...", None, None
            is_random_start = upload_start_location == "Random Location"
            primer = _primer_from_midi(upload_path, int(primer_length), is_random_start)
            if primer is None:
                yield "Error: Could not read MIDI messages.", None, None
                return

        num_primer = 0 if primer is None else primer.shape[0]
        if num_primer == 0:
            yield "Error: Primer processing resulted in 0 tokens.", None, None
            return

        # Budget the new tokens on top of the primer actually used (a single token for
        # silence, not the hidden slider value), capped at the model's max sequence.
        max_sequence = MODEL_CONFIG["max_sequence"]
        total_target_length = num_primer + generation_length_new
        if total_target_length > max_sequence:
            total_target_length = max_sequence
            yield f"Warning: Clamping to {total_target_length} tokens.", None, None

        num_new_tokens = max(total_target_length - num_primer, 0)
        yield f"Primed with {num_primer} tokens. Generating {num_new_tokens} new tokens...", None, None

        primer_batch = primer.unsqueeze(0)  # leading batch dimension of size 1

        current_model.eval()
        with torch.set_grad_enabled(False):
            rand_seq = current_model.generate(primer_batch, total_target_length, beam=0)

        # Keep only the continuation; the primer tokens are not written to the output.
        generated_only_tokens = rand_seq[0][num_primer:]
        if len(generated_only_tokens) == 0:
            yield "Warning: Generation produced 0 new tokens.", None, None
            return

        midi_output_filename = "generation_output.mid"
        wav_output_filename = "generation_output.wav"
        decode_midi(generated_only_tokens.cpu().numpy(), midi_output_filename)

        # Offer the MIDI immediately; WAV synthesis can take a moment.
        yield "Synthesizing audio...", midi_output_filename, None
        wav_path = midi_to_wav(midi_output_filename, wav_output_filename)

        if wav_path:
            yield "Generation Complete!", midi_output_filename, wav_path
        else:
            yield "Generation complete (WAV synthesis failed).", midi_output_filename, None

    except Exception as e:
        import traceback
        traceback.print_exc()  # keep the full stack in the server log; the UI gets a one-liner
        yield f"An error occurred: {e}", None, None
|
| |
|
| |
|
| |
|
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🎹 Music Transformer Generation UI")
    gr.Markdown("Load your trained model and generate music from silence, a random seed, or your own MIDI file.")

    with gr.Row():
        # Left column: model loading.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Load Model")
            model_path_input = gr.Textbox(
                label="Path to your .pickle model file",
                value="best_acc_weights.pickle"
            )
            load_button = gr.Button("Load Model", variant="primary")
            load_status = gr.Textbox(label="Model Status", interactive=False)

        # Right column: primer selection and generation settings.
        with gr.Column(scale=2):
            gr.Markdown("### 2. Configure Generation")

            primer_type_input = gr.Radio(
                label="Choose Primer Type",
                choices=["Generate from Silence", "Random Maestro MIDI", "Upload My Own MIDI"],
                value="Generate from Silence"
            )

            # Shown by update_ui only while "Random Maestro MIDI" is selected.
            with gr.Column(visible=False) as maestro_options:
                maestro_path_input = gr.Textbox(
                    label="Path to RAW Maestro MIDI Folder (searches all subfolders)",
                    value="./maestro-v2.0.0"
                )
                maestro_start_location_input = gr.Radio(
                    label="Primer Start Location",
                    choices=["Start of File", "Random Location"],
                    value="Random Location",
                    info="Selects a random chunk from the file, giving more variety."
                )

            # Shown by update_ui only while "Upload My Own MIDI" is selected.
            with gr.Column(visible=False) as upload_options:
                uploaded_midi_input = gr.File(
                    label="Upload Your MIDI Primer",
                    file_types=[".mid", ".midi"]
                )
                upload_start_location_input = gr.Radio(
                    label="Primer Start Location",
                    choices=["Start of File", "Random Location"],
                    value="Start of File"
                )

            # Starts hidden because the default primer type is "Generate from Silence",
            # the same case in which update_ui hides it; it appears once a file-based
            # primer type is chosen. Hidden components still pass their value.
            primer_length_slider = gr.Slider(
                label="Primer Length (Tokens)",
                minimum=64,
                maximum=2000,
                value=512,
                step=32,
                info="How many tokens to use from the primer file. Ignored for 'Silence'.",
                visible=False,
            )

            generation_length_slider = gr.Slider(
                label="New Tokens to Generate",
                minimum=128,
                maximum=2048,
                value=1024,
                step=32,
                info="How many new tokens to create after the primer."
            )

            generate_button = gr.Button("Generate Music", variant="primary")

    # Outputs: status text, MIDI download and an audio preview.
    with gr.Row():
        gr.Markdown("### 3. Get Your Music")
        status_output = gr.Textbox(label="Status", interactive=False)
    with gr.Row():
        output_midi_file = gr.File(label="Download Generated MIDI")
        output_wav_file = gr.Audio(label="Listen to Generated WAV", type="filepath")

    def update_ui(primer_type):
        """Toggle the option groups (and primer slider) that apply to ``primer_type``."""
        return {
            maestro_options: gr.Column(visible=(primer_type == "Random Maestro MIDI")),
            upload_options: gr.Column(visible=(primer_type == "Upload My Own MIDI")),
            primer_length_slider: gr.Slider(visible=(primer_type != "Generate from Silence"))
        }

    primer_type_input.change(
        fn=update_ui,
        inputs=primer_type_input,
        outputs=[maestro_options, upload_options, primer_length_slider]
    )

    load_button.click(
        fn=load_model,
        inputs=model_path_input,
        outputs=load_status
    )

    # Input order must match generate_music's positional parameters.
    generate_button.click(
        fn=generate_music,
        inputs=[
            primer_type_input,
            uploaded_midi_input,
            upload_start_location_input,
            maestro_path_input,
            maestro_start_location_input,
            primer_length_slider,
            generation_length_slider
        ],
        outputs=[status_output, output_midi_file, output_wav_file]
    )
|
| |
|
| |
|
if __name__ == "__main__":
    if not torch.cuda.is_available():
        print("----- WARNING: CUDA devices not detected. This will cause the model to run very slow! -----")
        use_cuda(False)
        # The module-level `device` was resolved by get_device() at import time, before
        # this flag changed. Presumably get_device() honours use_cuda(); re-resolve so
        # load_model() and generate_music() (which read `device` when called) see it.
        device = get_device()
        print(f"Using device: {device}")

    print("Launching Gradio UI...")
    app.launch()