Spaces:
Running
Running
| ### IMPORTS | |
| from midi_model import MIDIModel | |
| from midi_tokenizer import MIDITokenizerV2 | |
| from midi_synthesizer import MidiSynthesizer | |
| from MIDI import midi2score, score2opus, opus2midi | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| import tempfile | |
### MODEL & TOKENIZER
# Checkpoint identifier handed to MIDIModel.from_pretrained.
model_name = "skytnt/midi-model-tv2o-medium"

# Prefer the GPU when torch reports one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = MIDITokenizerV2()

model = MIDIModel.from_pretrained(model_name)
model.to(device)

# Synthesizer loaded from the local "./gm.sf2" file; used to render the
# generated MIDI to audio for the preview player.
synth = MidiSynthesizer("./gm.sf2")
### HELPER FUNCTIONS
def encode_midi(bin_score):
    """Encode the binary contents of a MIDI file into a numpy array of tokens.

    Args:
        bin_score: Raw MIDI file contents, parsed with ``midi2score``.

    Returns:
        np.ndarray of dtype int64 built from ``tokenizer.tokenize``'s output
        (presumably one row per event -- confirm against the tokenizer).
    """
    midi = midi2score(bin_score)
    tokens = tokenizer.tokenize(midi)
    # ``np.long`` does not exist in NumPy 1.24-1.26 (AttributeError), and in
    # NumPy 2.x it means C ``long`` (32-bit on Windows). Use an explicit
    # 64-bit integer type instead.
    return np.array(tokens, dtype=np.int64)
def extend_midi(bin_score, max_len, output_midi_path, generate_audio=True):
    """Continue an existing MIDI piece with the model and save the result.

    Args:
        bin_score: Raw MIDI file contents, encoded via ``encode_midi``.
        max_len: Forwarded unchanged to ``model.generate`` as ``max_len``.
            NOTE(review): confirm whether generate counts the prompt tokens
            toward this limit.
        output_midi_path: Destination path; the generated MIDI bytes are
            written there (the file is created or truncated).
        generate_audio: When true, also render the generated opus to audio.

    Returns:
        The result of ``synth.synthesis`` for the generated opus when
        ``generate_audio`` is true, otherwise ``None``.
    """
    prompt_tokens = encode_midi(bin_score)
    generated = model.generate(prompt_tokens, max_len=max_len)
    print(f"Generated MIDI. Output shape: {generated.shape}")

    # Only index 0 of the generated output is decoded (presumably the first
    # sequence of a batch).
    score = tokenizer.detokenize(generated[0])
    opus = score2opus(score)
    midi_bytes = opus2midi(opus)

    with open(output_midi_path, "wb") as out_file:
        out_file.write(midi_bytes)

    return synth.synthesis(opus) if generate_audio else None
### EVENT HANDLERS
def process_midi(midi_file, max_len, progress=gr.Progress(track_tqdm=True)):
    """
    Process the uploaded MIDI file and generate extended output.

    Args:
        midi_file: Raw bytes of the uploaded MIDI file (the input ``gr.File``
            is created with ``type="binary"``), or None if nothing was uploaded.
        max_len: Value of the "Maximum Length (tokens)" slider; coerced to
            ``int`` and forwarded to ``extend_midi``.
        progress: Gradio progress tracker (``track_tqdm=True``); not used
            directly in this body.

    Returns:
        tuple: ``((sample_rate, audio), midi_path)`` for the audio preview and
        the download, or ``(None, None)`` when no file was uploaded.
    """
    if midi_file is None:
        return None, None

    # Reserve a unique path for the output MIDI. delete=False keeps the file
    # on disk after the handle is closed so Gradio can serve it for download.
    # Nothing in this module deletes it afterwards, so it stays in the temp
    # directory.
    with tempfile.NamedTemporaryFile(
        delete=False,
        prefix="output",
        suffix=".mid",
        mode="wb",
    ) as temp_output:
        output_midi_path = temp_output.name

    try:
        # Slider values may arrive as floats; generation expects a whole
        # number of tokens.
        output_audio = extend_midi(midi_file, int(max_len), output_midi_path)
    except Exception:
        # Don't leak the (empty) temp file when generation fails; the original
        # error is re-raised unchanged.
        import os
        from contextlib import suppress

        with suppress(OSError):
            os.remove(output_midi_path)
        raise

    return (synth.sample_rate, output_audio), output_midi_path
### USER INTERFACE
# Create Gradio interface using Blocks API
with gr.Blocks(title="MIDI Extender") as demo:
    gr.Markdown("# MIDI Extender")
    with gr.Row():
        # Left side - Input section
        with gr.Column():
            gr.Markdown("### Input")
            midi_input = gr.File(
                label="Upload MIDI File",
                file_types=[".mid", ".midi"],
                # Hand the upload to process_midi as raw bytes, which
                # encode_midi parses with midi2score.
                type="binary",
            )
            max_len = gr.Slider(
                label="Maximum Length (tokens)",
                minimum=16,
                maximum=512,
                value=128,
                # Token counts are whole numbers. A step of 0 is not a
                # meaningful slider increment.
                step=1,
            )
            process_btn = gr.Button("Extend MIDI", variant="primary")
        # Right side - Output section
        with gr.Column():
            gr.Markdown("### Output")
            audio_output = gr.Audio(
                label="Audio Preview",
                # process_midi returns a (sample_rate, samples) tuple.
                type="numpy",
            )
            midi_output = gr.File(label="Download Generated MIDI")
    # Connect the button to the processing function. This must be registered
    # inside the Blocks context.
    process_btn.click(
        fn=process_midi,
        inputs=[midi_input, max_len],
        outputs=[audio_output, midi_output],
    )

### RUN
if __name__ == "__main__":
    demo.launch()