"""Gradio app that extends an uploaded MIDI file with a pretrained MIDI model.

The uploaded MIDI is tokenized, continued by ``skytnt/midi-model-tv2o-medium``,
written back to a ``.mid`` file for download, and rendered to audio with a
SoundFont synthesizer for in-browser preview.
"""

### IMPORTS
import tempfile

import gradio as gr
import numpy as np
import torch

from midi_model import MIDIModel
from midi_tokenizer import MIDITokenizerV2
from midi_synthesizer import MidiSynthesizer
from MIDI import midi2score, score2opus, opus2midi

### MODEL & TOKENIZER
model_name = "skytnt/midi-model-tv2o-medium"
tokenizer = MIDITokenizerV2()
model = MIDIModel.from_pretrained(model_name)
synth = MidiSynthesizer("./gm.sf2")  # SoundFont used to render the audio preview
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


### HELPER FUNCTIONS
def encode_midi(bin_score):
    """Encode the raw bytes of a MIDI file into a numpy array of tokens.

    Args:
        bin_score: Raw MIDI file bytes, as accepted by ``MIDI.midi2score``.

    Returns:
        np.ndarray: ``int64`` array built from ``tokenizer.tokenize``'s output.
    """
    score = midi2score(bin_score)
    tokens = tokenizer.tokenize(score)
    # np.long is missing in NumPy 1.24-1.26 (AttributeError) and in NumPy 2.x
    # means C long (32-bit on Windows); int64 is explicit and portable.
    return np.array(tokens, dtype=np.int64)


def extend_midi(bin_score, max_len, output_midi_path, generate_audio=True):
    """Extend an existing MIDI file and write the result to disk.

    Args:
        bin_score: Raw bytes of the MIDI file to continue.
        max_len: Length limit forwarded to ``model.generate``.
        output_midi_path: Path where the generated MIDI file is written.
        generate_audio: If True, also render the generated music to audio.

    Returns:
        The value returned by ``synth.synthesis`` when ``generate_audio`` is
        True, otherwise ``None``.
    """
    inputs = encode_midi(bin_score)
    # NOTE(review): confirm whether MIDIModel.generate treats max_len as the
    # total sequence length (prompt included) or as the number of new tokens;
    # if it is the total, prompts longer than max_len get no continuation.
    outputs = model.generate(inputs, max_len=max_len)
    print(f"Generated MIDI. Output shape: {outputs.shape}")

    # Only the first sequence of the generated batch is used.
    result = tokenizer.detokenize(outputs[0])
    opus = score2opus(result)
    midi_bytes = opus2midi(opus)
    with open(output_midi_path, "wb") as f:
        f.write(midi_bytes)

    if not generate_audio:
        # Explicit return: there is no audio to hand back in this case.
        return None
    return synth.synthesis(opus)


### EVENT HANDLERS
def process_midi(midi_file, max_len, progress=gr.Progress(track_tqdm=True)):
    """Process the uploaded MIDI file and generate the extended output.

    Args:
        midi_file: Raw bytes of the uploaded MIDI file (the upload widget uses
            ``type="binary"``), or ``None`` when nothing was uploaded.
        max_len: Value of the "Maximum Length (tokens)" slider.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors tqdm
            progress bars created while this handler runs.

    Returns:
        tuple: ``((sample_rate, audio), midi_path)`` for the audio preview and
        the MIDI download, or ``(None, None)`` if no file was uploaded.
    """
    if midi_file is None:
        return None, None

    # Reserve a unique output path. delete=False keeps the file after it is
    # closed so Gradio can serve it for download; nothing here deletes it
    # afterwards, so these files accumulate in the temp directory.
    with tempfile.NamedTemporaryFile(
        delete=False, prefix="output", suffix=".mid", mode="wb"
    ) as temp_output:
        output_midi_path = temp_output.name

    # Slider values may arrive as floats; the token limit must be an integer.
    output_audio = extend_midi(midi_file, int(max_len), output_midi_path)
    return (synth.sample_rate, output_audio), output_midi_path


### USER INTERFACE
# Create Gradio interface using Blocks API
with gr.Blocks(title="MIDI Extender") as demo:
    gr.Markdown("# MIDI Extender")

    with gr.Row():
        # Left side - Input section
        with gr.Column():
            gr.Markdown("### Input")
            midi_input = gr.File(
                label="Upload MIDI File",
                file_types=[".mid", ".midi"],
                type="binary",
            )
            max_len = gr.Slider(
                label="Maximum Length (tokens)",
                minimum=16,
                maximum=512,
                value=128,
                step=1,  # token counts are integers; step=0 is not a valid increment
            )
            process_btn = gr.Button("Extend MIDI", variant="primary")

        # Right side - Output section
        with gr.Column():
            gr.Markdown("### Output")
            audio_output = gr.Audio(
                label="Audio Preview",
                type="numpy",
            )
            midi_output = gr.File(
                label="Download Generated MIDI",
            )

    # Connect the button to the processing function
    process_btn.click(
        fn=process_midi,
        inputs=[midi_input, max_len],
        outputs=[audio_output, midi_output],
    )

### RUN
if __name__ == "__main__":
    demo.launch()