# midi-extender / app.py
# Last commit 0167848 by MarcVida: "add more deps and rename main script"
### IMPORTS
import contextlib
import os
import tempfile

import gradio as gr
import numpy as np
import torch

from MIDI import midi2score, score2opus, opus2midi
from midi_model import MIDIModel
from midi_synthesizer import MidiSynthesizer
from midi_tokenizer import MIDITokenizerV2
### MODEL & TOKENIZER
# Choose the compute device up front so the model can be moved right after loading.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pretrained checkpoint identifier handed to MIDIModel.from_pretrained.
model_name = "skytnt/midi-model-tv2o-medium"
# NOTE(review): this tokenizer is constructed with default settings,
# independently of the checkpoint; confirm its configuration matches the
# one the model was trained with.
tokenizer = MIDITokenizerV2()
model = MIDIModel.from_pretrained(model_name)
model.to(device)

# Renders opus data to audio for the preview, using the SoundFont at ./gm.sf2.
synth = MidiSynthesizer("./gm.sf2")
### HELPER FUNCTIONS
def encode_midi(bin_score):
    """Encode the binary contents of a MIDI file into a numpy array of tokens.

    Args:
        bin_score: Raw bytes of a MIDI file, passed straight to ``midi2score``.

    Returns:
        np.ndarray: The tokenizer's output as a 64-bit integer array
        (presumably shaped ``(num_events, tokens_per_event)`` -- confirm
        against ``MIDITokenizerV2.tokenize``).
    """
    score = midi2score(bin_score)
    tokens = tokenizer.tokenize(score)
    # `np.long` was removed in NumPy 1.24 (AttributeError), and in NumPy 2.x
    # it aliases the platform C long, which is only 32-bit on Windows.
    # int64 is explicit and portable.
    return np.array(tokens, dtype=np.int64)
def extend_midi(bin_score, max_len, output_midi_path, generate_audio=True):
    """Continue an existing MIDI file with model-generated material.

    The generated MIDI is always written to ``output_midi_path``.

    Args:
        bin_score: Binary contents of the source MIDI file.
        max_len: Forwarded unchanged to ``model.generate(max_len=...)``.
            NOTE(review): confirm whether ``generate`` counts the prompt's
            own events toward this limit.
        output_midi_path: Destination path for the generated ``.mid`` file.
        generate_audio: When true, also render the result with ``synth``.

    Returns:
        The audio returned by ``synth.synthesis`` when ``generate_audio`` is
        true, otherwise ``None``.
    """
    prompt_tokens = encode_midi(bin_score)
    generated = model.generate(prompt_tokens, max_len=max_len)
    print(f"Generated MIDI. Output shape: {generated.shape}")

    # Only generated[0] is decoded -- presumably the first sequence of a
    # batch; confirm against MIDIModel.generate.
    opus = score2opus(tokenizer.detokenize(generated[0]))
    midi_bytes = opus2midi(opus)
    with open(output_midi_path, "wb") as midi_out:
        midi_out.write(midi_bytes)

    if not generate_audio:
        return None
    return synth.synthesis(opus)
### EVENT HANDLERS
def process_midi(midi_file, max_len, progress=gr.Progress(track_tqdm=True)):
    """Extend an uploaded MIDI file and return outputs for the UI.

    Args:
        midi_file: Raw bytes of the uploaded MIDI file (the upload component
            uses ``type="binary"``), or ``None`` when nothing was uploaded.
        max_len: Value of the "Maximum Length (tokens)" slider.
        progress: Gradio progress tracker; with ``track_tqdm=True`` Gradio
            reports tqdm progress from generation in the UI.

    Returns:
        tuple: ``((sample_rate, audio), midi_path)`` for the audio preview
        and the downloadable MIDI, or ``(None, None)`` if no file was given.

    Raises:
        Whatever ``extend_midi`` raises; the temporary output file is removed
        before the exception propagates.
    """
    if midi_file is None:
        return None, None

    # Reserve a uniquely named output path. delete=False keeps the file on
    # disk after the handle closes so Gradio can serve it for download; it
    # is NOT deleted automatically afterwards.
    with tempfile.NamedTemporaryFile(
        delete=False,
        prefix="output",
        suffix=".mid",
        mode="wb",
    ) as temp_output:
        output_midi_path = temp_output.name

    try:
        # Gradio slider values are typed as floats; token counts are integers.
        output_audio = extend_midi(midi_file, int(max_len), output_midi_path)
    except Exception:
        # Don't leave an orphaned temp file behind when generation fails.
        with contextlib.suppress(OSError):
            os.remove(output_midi_path)
        raise

    return (synth.sample_rate, output_audio), output_midi_path
### USER INTERFACE
# Gradio Blocks layout: MIDI upload and length slider on the left, audio
# preview and MIDI download on the right.
with gr.Blocks(title="MIDI Extender") as demo:
    gr.Markdown("# MIDI Extender")
    with gr.Row():
        # Left side - input section
        with gr.Column():
            gr.Markdown("### Input")
            # type="binary" hands process_midi the raw file bytes.
            midi_input = gr.File(
                label="Upload MIDI File",
                file_types=[".mid", ".midi"],
                type="binary",
            )
            # Token counts are integers, so the slider moves in whole steps.
            max_len = gr.Slider(
                label="Maximum Length (tokens)",
                minimum=16,
                maximum=512,
                value=128,
                step=1,
            )
            process_btn = gr.Button("Extend MIDI", variant="primary")
        # Right side - output section
        with gr.Column():
            gr.Markdown("### Output")
            # type="numpy" matches the (sample_rate, audio) tuple that
            # process_midi returns.
            audio_output = gr.Audio(
                label="Audio Preview",
                type="numpy",
            )
            midi_output = gr.File(
                label="Download Generated MIDI",
            )
    # Connect the button to the processing function (must be registered
    # inside the Blocks context).
    process_btn.click(
        fn=process_midi,
        inputs=[midi_input, max_len],
        outputs=[audio_output, midi_output],
    )
### RUN
if __name__ == "__main__":
    # Start the Gradio server for the interface built above.
    demo.launch()