# Spaces: Running
# File size: 3,456 Bytes
# 5be9d07
### IMPORTS
from midi_model import MIDIModel
from midi_tokenizer import MIDITokenizerV2
from midi_synthesizer import MidiSynthesizer
from MIDI import midi2score, score2opus, opus2midi
import numpy as np
import torch
import gradio as gr
import tempfile
### MODEL & TOKENIZER
# Checkpoint identifier handed to MIDIModel.from_pretrained below.
model_name = "skytnt/midi-model-tv2o-medium"

# Module-level singletons shared by every request handled by this app.
tokenizer = MIDITokenizerV2()
model = MIDIModel.from_pretrained(model_name)
# The synthesizer is constructed from a file expected next to this script.
synth = MidiSynthesizer("./gm.sf2")

# Run on the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
### HELPER FUNCTIONS
def encode_midi(bin_score):
    """Encode the binary contents of a MIDI file into a numpy array of tokens.

    Args:
        bin_score: Raw contents of a MIDI file, passed as-is to ``midi2score``.

    Returns:
        A numpy ``int64`` array built from the tokenizer's output.
    """
    score = midi2score(bin_score)
    tokens = tokenizer.tokenize(score)
    # Use the explicit 64-bit type: the ``np.long`` alias does not exist in
    # NumPy 1.24-1.26 and maps to the platform C ``long`` in NumPy 2.x.
    return np.array(tokens, dtype=np.int64)
def extend_midi(bin_score, max_len, output_midi_path, generate_audio=True):
    """Extend the contents of an existing MIDI file and save the result.

    Args:
        bin_score: Raw contents of the MIDI file to extend.
        max_len: Value forwarded to ``model.generate`` as ``max_len``.
        output_midi_path: Path the generated MIDI file is written to.
        generate_audio: When true, also render the result with ``synth``.

    Returns:
        The value returned by ``synth.synthesis`` when ``generate_audio`` is
        true, otherwise ``None``.
    """
    inputs = encode_midi(bin_score)
    outputs = model.generate(inputs, max_len=max_len)
    print(f"Generated MIDI. Output shape: {outputs.shape}")

    # Only the first entry of the generated output is decoded and saved.
    result = tokenizer.detokenize(outputs[0])
    opus = score2opus(result)
    midi = opus2midi(opus)
    with open(output_midi_path, "wb") as f:
        f.write(midi)

    # Bind the name up front so the return below is well-defined even when
    # audio rendering is skipped.
    audio_data = None
    if generate_audio:
        audio_data = synth.synthesis(opus)
    return audio_data
### EVENT HANDLERS
def process_midi(midi_file, max_len, progress=gr.Progress(track_tqdm=True)):
    """Process the uploaded MIDI file and generate extended output.

    Args:
        midi_file: Contents of the uploaded MIDI file (the input component is
            declared with ``type="binary"``), or ``None`` if nothing was
            uploaded.
        max_len: Length limit chosen on the slider; forwarded to
            ``extend_midi`` as an integer.
        progress: Gradio progress tracker; not referenced directly in the
            body.

    Returns:
        tuple: ``((sample_rate, audio), midi_path)`` for preview and
        download, or ``(None, None)`` when no file was uploaded.
    """
    if midi_file is None:
        return None, None

    # Reserve a unique path for the output MIDI. delete=False keeps the file
    # on disk after it is closed so it can be rewritten and then served for
    # download; nothing in this function removes it afterwards.
    with tempfile.NamedTemporaryFile(
        delete=False,
        prefix="output",
        suffix=".mid",
        mode="wb",
    ) as temp_output:
        output_midi_path = temp_output.name

    # The slider value may arrive as a float; pass a whole number on.
    output_audio = extend_midi(midi_file, int(max_len), output_midi_path)
    return (synth.sample_rate, output_audio), output_midi_path
### USER INTERFACE
# Create Gradio interface using Blocks API
with gr.Blocks(title="MIDI Extender") as demo:
    gr.Markdown("# MIDI Extender")
    with gr.Row():
        # Left side - Input section
        with gr.Column():
            gr.Markdown("### Input")
            midi_input = gr.File(
                label="Upload MIDI File",
                file_types=[".mid", ".midi"],
                type="binary",
            )
            max_len = gr.Slider(
                label="Maximum Length (tokens)",
                minimum=16,
                maximum=512,
                value=128,
                # Token counts are whole numbers; a zero increment is not a
                # valid slider step.
                step=1,
            )
            process_btn = gr.Button("Extend MIDI", variant="primary")
        # Right side - Output section
        with gr.Column():
            gr.Markdown("### Output")
            audio_output = gr.Audio(
                label="Audio Preview",
                type="numpy",
            )
            midi_output = gr.File(
                label="Download Generated MIDI",
            )
    # Connect the button to the processing function
    process_btn.click(
        fn=process_midi,
        inputs=[midi_input, max_len],
        outputs=[audio_output, midi_output],
    )
### RUN
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()