NanoMaestro / app.py
utkucoban's picture
NanoMaestro Full model weights released
47dfee0 verified
import gradio as gr
import torch
import os
import random
import glob
import numpy as np
import pretty_midi
import scipy.io.wavfile
# --- Project-local dependencies ---
# These modules are expected to live next to app.py (the `model`, `dataset`
# and `utilities` packages plus `processor.py`).
try:
    from model.music_transformer import MusicTransformer
    from processor import encode_midi, decode_midi
    from dataset.e_piano import process_midi
    # Wildcard kept on purpose: TORCH_LABEL_TYPE is used further down and is
    # presumably defined in this module -- confirm before narrowing the import.
    from utilities.constants import *
    from utilities.device import get_device, use_cuda
except ImportError as e:
    print("Error: Could not import necessary files.")
    print("Make sure app.py is in the same folder as 'model', 'processor.py', etc.")
    print(f"Details: {e}")
    # `exit()` is an interactive helper injected by `site` and would report
    # success (status 0); raise SystemExit with a non-zero status so callers
    # and process supervisors can see that start-up failed.
    raise SystemExit(1)
# --- Model hyperparameters ---
# Architecture settings handed to MusicTransformer when the checkpoint is
# loaded (per the original author's note, taken from the training logs).
MODEL_CONFIG = dict(
    n_layers=6,
    num_heads=8,
    d_model=512,
    dim_feedforward=1024,
    max_sequence=2048,
    rpr=True,
)

# Module-wide handle to the loaded network; stays None until load_model() succeeds.
model = None

# Device selected once at import time by the project's helper.
device = get_device()
print(f"Using device: {device}")
def load_model(model_path):
    """
    Load the trained MusicTransformer weights and publish the model globally.

    The network is built and its weights are loaded into a local variable
    first; the module-level ``model`` is only replaced once everything has
    succeeded. A failed load therefore never leaves a randomly initialised
    (or half-loaded) network behind for ``generate_music`` to pick up.

    Args:
        model_path: Filesystem path of the saved state dict.

    Returns:
        A human-readable status string (success or error) for the UI.
    """
    global model

    if model_path is None or not os.path.exists(model_path):
        return "Error: Model file not found. Please check the path."

    try:
        print("Loading model...")
        candidate = MusicTransformer(
            n_layers=MODEL_CONFIG["n_layers"],
            num_heads=MODEL_CONFIG["num_heads"],
            d_model=MODEL_CONFIG["d_model"],
            dim_feedforward=MODEL_CONFIG["dim_feedforward"],
            max_sequence=MODEL_CONFIG["max_sequence"],
            rpr=MODEL_CONFIG["rpr"],
        ).to(device)

        # map_location moves the tensors onto the active device;
        # weights_only=True restricts unpickling to tensor data.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        candidate.load_state_dict(state_dict)
        candidate.eval()
    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        return f"Error loading model: {e}"

    # Only now make the fully initialised model visible to the rest of the app.
    model = candidate
    print("Model loaded successfully.")
    return f"Model '{model_path}' loaded successfully."
def midi_to_wav(midi_file_path, wav_file_path, fs=44100):
    """
    Synthesize a MIDI file to a 16-bit PCM WAV file using pretty_midi's
    built-in (simple) sine wave synthesizer.

    The synthesizer returns a floating-point waveform. Casting that directly
    to int16 truncates almost every sample to 0 (a silent file), so the
    waveform is peak-normalised and scaled to the int16 range first.

    Args:
        midi_file_path: Path of the MIDI file to render.
        wav_file_path: Path the WAV file is written to.
        fs: Sample rate in Hz used for synthesis and for the WAV header.

    Returns:
        ``wav_file_path`` on success, ``None`` if anything went wrong
        (the error is printed).
    """
    try:
        pm = pretty_midi.PrettyMIDI(midi_file_path)
        waveform = np.asarray(pm.synthesize(fs=fs), dtype=np.float64)

        # Guard against NaN/inf samples (e.g. from normalising an all-zero
        # signal) before scaling.
        waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0)

        peak = float(np.max(np.abs(waveform))) if waveform.size else 0.0
        if peak > 0.0:
            waveform = waveform / peak

        # Map [-1.0, 1.0] onto the signed 16-bit sample range.
        pcm = np.round(waveform * np.iinfo(np.int16).max).astype(np.int16)

        scipy.io.wavfile.write(wav_file_path, fs, pcm)
        return wav_file_path
    except Exception as e:
        # Best-effort: the caller treats None as "no audio preview available".
        print(f"Error during MIDI to WAV conversion: {e}")
        return None
def _tokenize_primer(midi_path, primer_length, random_start):
    """Encode a MIDI file and cut a primer of ``primer_length`` tokens from it.

    Returns the primer as a tensor on ``device``, or ``None`` when
    ``encode_midi`` yields no events for the file.
    """
    raw_mid = encode_midi(midi_path)
    if not raw_mid:
        return None
    primer_tokens, _ = process_midi(raw_mid, primer_length, random_seq=random_start)
    return torch.tensor(primer_tokens, dtype=TORCH_LABEL_TYPE, device=device)


def generate_music(primer_type, uploaded_midi, upload_start_location, maestro_path, maestro_start_location,
                   primer_length, generation_length_new, progress=gr.Progress(track_tqdm=True)):
    """
    Generator behind the "Generate Music" button.

    Builds a primer according to ``primer_type``, runs the model, writes the
    newly generated tokens to a MIDI file and renders a WAV preview.

    Args:
        primer_type: One of "Generate from Silence", "Random Maestro MIDI"
            or "Upload My Own MIDI".
        uploaded_midi: The uploaded file (an object with a ``.name`` path or
            a plain path string); only used for the upload primer.
        upload_start_location: "Random Location" picks a random chunk of the
            uploaded file; anything else starts at the beginning.
        maestro_path: Folder searched recursively for .mid/.midi files.
        maestro_start_location: Same meaning as ``upload_start_location`` but
            for the random Maestro file.
        primer_length: Number of primer tokens requested from a MIDI file.
            Not used when generating from silence.
        generation_length_new: Number of new tokens to generate.
        progress: Gradio progress tracker; declared so Gradio can mirror
            tqdm progress, not used directly.

    Yields:
        ``(status_message, midi_path_or_None, wav_path_or_None)`` tuples.
    """
    if model is None:
        yield "Error: Model is not loaded. Please load a model first.", None, None
        # Stop here; without this the function would carry on with no model.
        return

    try:
        # Slider values may arrive as floats; lengths must be ints.
        primer_length = int(primer_length)
        generation_length_new = int(generation_length_new)

        # --- 1. Prepare the primer ---
        primer = None

        if primer_type == "Generate from Silence":
            yield "Generating from silence...", None, None
            # Single seed token taken verbatim from the original code.
            # NOTE(review): confirm which event index 372 encodes.
            primer = torch.tensor([372], dtype=TORCH_LABEL_TYPE, device=device)

        elif primer_type == "Random Maestro MIDI":
            yield "Finding random Maestro file...", None, None
            if maestro_path is None or not os.path.isdir(maestro_path):
                yield f"Error: Maestro path '{maestro_path}' is not valid.", None, None
                return

            midi_files = [
                path
                for pattern in ("*.mid", "*.midi")
                for path in glob.glob(os.path.join(maestro_path, "**", pattern), recursive=True)
            ]
            if not midi_files:
                yield f"Error: No .mid/.midi files found in '{maestro_path}'.", None, None
                return

            random_file = random.choice(midi_files)
            yield f"Tokenizing random file: {os.path.basename(random_file)}...", None, None
            primer = _tokenize_primer(
                random_file, primer_length, maestro_start_location == "Random Location"
            )
            if primer is None:
                yield "Error: Could not read MIDI messages.", None, None
                return

        elif primer_type == "Upload My Own MIDI":
            if uploaded_midi is None:
                yield "Error: Please upload a MIDI file.", None, None
                return

            # Accept both a file-like object exposing `.name` and a bare path string.
            upload_path = getattr(uploaded_midi, "name", uploaded_midi)
            yield f"Tokenizing uploaded MIDI: {os.path.basename(upload_path)}...", None, None
            primer = _tokenize_primer(
                upload_path, primer_length, upload_start_location == "Random Location"
            )
            if primer is None:
                yield "Error: Could not read MIDI messages.", None, None
                return

        num_primer = 0 if primer is None else primer.shape[0]
        if num_primer == 0:
            yield "Error: Primer processing resulted in 0 tokens.", None, None
            return

        # Size the target from the primer actually in use, so the requested
        # number of new tokens is honoured (notably for the 1-token silence
        # primer, where the primer-length slider must be ignored).
        total_target_length = num_primer + generation_length_new
        if total_target_length > MODEL_CONFIG["max_sequence"]:
            total_target_length = MODEL_CONFIG["max_sequence"]
            yield f"Warning: Clamping to {total_target_length} tokens.", None, None
        new_token_count = total_target_length - num_primer

        # --- 2. Run generation ---
        yield f"Primed with {num_primer} tokens. Generating {new_token_count} new tokens...", None, None
        primer_batch = primer.unsqueeze(0)
        model.eval()
        with torch.no_grad():
            rand_seq = model.generate(primer_batch, total_target_length, beam=0)

        # --- 3. Process and save output ---
        # Drop the primer so only the newly generated continuation is saved.
        generated_only_tokens = rand_seq[0][num_primer:]
        if len(generated_only_tokens) == 0:
            yield "Warning: Generation produced 0 new tokens.", None, None
            return

        midi_output_filename = "generation_output.mid"
        wav_output_filename = "generation_output.wav"

        decode_midi(generated_only_tokens.cpu().numpy(), midi_output_filename)

        # Hand the MIDI file to the UI right away; audio rendering follows.
        yield "Synthesizing audio...", midi_output_filename, None
        wav_path = midi_to_wav(midi_output_filename, wav_output_filename)

        if wav_path:
            yield "Generation Complete!", midi_output_filename, wav_path
        else:
            yield "Generation complete (WAV synthesis failed).", midi_output_filename, None

    except Exception as e:
        # UI boundary: surface the failure in the status box instead of crashing.
        yield f"An error occurred: {e}", None, None
# --- Gradio interface ---
# Layout: a model-loading column beside a generation-settings column, followed
# by the status box and the MIDI / audio outputs. Event wiring is at the end.
# NOTE(review): the nesting below was reconstructed from the statement order;
# confirm the placement of `status_output` relative to the "### 3" row.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🎹 Music Transformer Generation UI")
    gr.Markdown("Load your trained model and generate music from silence, a random seed, or your own MIDI file.")

    with gr.Row():
        # Left column: checkpoint path, load button and load status.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Load Model")
            model_path_input = gr.Textbox(
                label="Path to your .pickle model file",
                value="best_acc_weights.pickle",
            )
            load_button = gr.Button("Load Model", variant="primary")
            load_status = gr.Textbox(label="Model Status", interactive=False)

        # Right column: primer selection and length controls.
        with gr.Column(scale=2):
            gr.Markdown("### 2. Configure Generation")
            primer_type_input = gr.Radio(
                label="Choose Primer Type",
                choices=["Generate from Silence", "Random Maestro MIDI", "Upload My Own MIDI"],
                value="Generate from Silence",
            )

            # Shown only for the "Random Maestro MIDI" primer.
            with gr.Column(visible=False) as maestro_options:
                maestro_path_input = gr.Textbox(
                    label="Path to RAW Maestro MIDI Folder (searches all subfolders)",
                    value="./maestro-v2.0.0",
                )
                maestro_start_location_input = gr.Radio(
                    label="Primer Start Location",
                    choices=["Start of File", "Random Location"],
                    value="Random Location",
                    info="Selects a random chunk from the file, giving more variety.",
                )

            # Shown only for the "Upload My Own MIDI" primer.
            with gr.Column(visible=False) as upload_options:
                uploaded_midi_input = gr.File(
                    label="Upload Your MIDI Primer",
                    file_types=[".mid", ".midi"],
                )
                upload_start_location_input = gr.Radio(
                    label="Primer Start Location",
                    choices=["Start of File", "Random Location"],
                    value="Start of File",
                )

            primer_length_slider = gr.Slider(
                label="Primer Length (Tokens)",
                minimum=64,
                maximum=2000,
                value=512,
                step=32,
                info="How many tokens to use from the primer file. Ignored for 'Silence'.",
            )
            generation_length_slider = gr.Slider(
                label="New Tokens to Generate",
                minimum=128,
                maximum=2048,
                value=1024,
                step=32,
                info="How many new tokens to create after the primer.",
            )
            generate_button = gr.Button("Generate Music", variant="primary")

    with gr.Row():
        gr.Markdown("### 3. Get Your Music")

    status_output = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        output_midi_file = gr.File(label="Download Generated MIDI")
        # Audio player fed with the path of the rendered WAV file.
        output_wav_file = gr.Audio(label="Listen to Generated WAV", type="filepath")

    # --- Event wiring ---
    def update_ui(primer_type):
        """Show or hide the primer-specific controls for the chosen primer type."""
        show_maestro = primer_type == "Random Maestro MIDI"
        show_upload = primer_type == "Upload My Own MIDI"
        show_primer_length = primer_type != "Generate from Silence"
        return {
            maestro_options: gr.Column(visible=show_maestro),
            upload_options: gr.Column(visible=show_upload),
            primer_length_slider: gr.Slider(visible=show_primer_length),
        }

    primer_type_input.change(
        fn=update_ui,
        inputs=primer_type_input,
        outputs=[maestro_options, upload_options, primer_length_slider],
    )

    load_button.click(
        fn=load_model,
        inputs=model_path_input,
        outputs=load_status,
    )

    # generate_music yields (status, midi path, wav path); the outputs list
    # below follows that order.
    generate_button.click(
        fn=generate_music,
        inputs=[
            primer_type_input,
            uploaded_midi_input,
            upload_start_location_input,
            maestro_path_input,
            maestro_start_location_input,
            primer_length_slider,
            generation_length_slider,
        ],
        outputs=[status_output, output_midi_file, output_wav_file],
    )
if __name__ == "__main__":
    # Without a CUDA device, warn the user and tell the project's device
    # helper not to use CUDA before the UI starts.
    cuda_available = torch.cuda.is_available()
    if not cuda_available:
        print("----- WARNING: CUDA devices not detected. This will cause the model to run very slow! -----")
        use_cuda(False)

    print("Launching Gradio UI...")
    app.launch()