File size: 3,456 Bytes
5be9d07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
### IMPORTS

from midi_model import MIDIModel
from midi_tokenizer import MIDITokenizerV2
from midi_synthesizer import MidiSynthesizer
from MIDI import midi2score, score2opus, opus2midi
import numpy as np
import torch
import gradio as gr
import tempfile


### MODEL & TOKENIZER

model_name = "skytnt/midi-model-tv2o-medium"

# Prefer the GPU when PyTorch can see one; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Shared, module-level objects used by the helpers below: the tokenizer
# (score <-> token ids), the pretrained model (moved to `device`), and the
# synthesizer that renders previews from the ./gm.sf2 SoundFont.
# NOTE(review): the tokenizer is built with default settings rather than
# taken from the checkpoint — confirm it matches what the model expects.
tokenizer = MIDITokenizerV2()
model = MIDIModel.from_pretrained(model_name)
model.to(device)
synth = MidiSynthesizer("./gm.sf2")


### HELPER FUNCTIONS

def encode_midi(bin_score):
    """Encode the raw contents of a MIDI file into a NumPy array of tokens.

    Args:
        bin_score: The bytes of a ``.mid``/``.midi`` file, as accepted by
            ``midi2score``.

    Returns:
        np.ndarray: ``int64`` array built from ``tokenizer.tokenize``'s output
        (presumably 2-D: one row of token ids per event — confirm against
        ``MIDITokenizerV2``).
    """
    score = midi2score(bin_score)
    tokens = tokenizer.tokenize(score)
    # ``np.long`` was removed in NumPy 1.24 (AttributeError) and re-added in
    # 2.0 as the platform C ``long`` (32-bit on Windows). ``np.int64`` is the
    # same 64-bit integer type on every NumPy version and platform.
    return np.array(tokens, dtype=np.int64)

def extend_midi(bin_score, max_len, output_midi_path, generate_audio=True):
    """Extend the contents of an existing MIDI file with generated events.

    The input is tokenized, continued with ``model.generate``, converted back
    to a standard MIDI file written to ``output_midi_path`` and, optionally,
    rendered to audio with ``synth``.

    Args:
        bin_score: Raw bytes of the source MIDI file.
        max_len: Length limit forwarded to ``model.generate`` (coerced to int,
            since UI widgets may deliver it as a float).
            NOTE(review): whether this counts the prompt events or only the
            new ones is decided by ``MIDIModel.generate`` — confirm; if it
            includes the prompt, long inputs may receive no continuation.
        output_midi_path: Destination file; overwritten with the result.
        generate_audio: When true, also synthesize the generated piece.

    Returns:
        The result of ``synth.synthesis`` for the generated piece when
        ``generate_audio`` is true, otherwise ``None``.
    """
    inputs = encode_midi(bin_score)
    outputs = model.generate(inputs, max_len=int(max_len))
    print(f"Generated MIDI. Output shape: {outputs.shape}")

    # Only the first generated sequence is used. ``tolist()`` gives the
    # detokenizer plain Python ints rather than NumPy/tensor scalars, so the
    # score -> opus -> MIDI bytes pipeline only ever sees builtin integers.
    result = tokenizer.detokenize(outputs[0].tolist())

    opus = score2opus(result)
    midi_bytes = opus2midi(opus)

    with open(output_midi_path, "wb") as f:
        f.write(midi_bytes)

    if not generate_audio:
        return None
    return synth.synthesis(opus)


### EVENT HANDLERS

def process_midi(midi_file, max_len, progress=gr.Progress(track_tqdm=True)):
    """
    Process the uploaded MIDI file and generate extended output.

    Args:
        midi_file: Raw bytes of the uploaded MIDI file (the upload widget is
            configured with ``type="binary"``), or ``None`` if nothing was
            uploaded.
        max_len: Length limit forwarded to the generator; coerced to int.
        progress: Not referenced in the body. Its ``track_tqdm=True`` default
            asks Gradio to mirror tqdm progress bars in the UI.

    Returns:
        tuple: ``((sample_rate, audio), midi_path)`` for the audio preview and
        the download widget, or ``(None, None)`` when no file was provided.
    """
    if midi_file is None:
        return None, None

    # Reserve a unique output path. ``delete=False`` keeps the file after the
    # handle closes so it can be served for download; nothing in this app
    # removes it afterwards, so it remains in the temp directory.
    with tempfile.NamedTemporaryFile(
        delete=False,
        prefix="output",
        suffix=".mid",
    ) as temp_output:
        output_midi_path = temp_output.name

    output_audio = extend_midi(midi_file, int(max_len), output_midi_path)

    return (synth.sample_rate, output_audio), output_midi_path


### USER INTERFACE

# Create Gradio interface using Blocks API
with gr.Blocks(title="MIDI Extender") as demo:
    gr.Markdown("# MIDI Extender")

    with gr.Row():
        # Left side - input: the MIDI upload and generation length.
        with gr.Column():
            gr.Markdown("### Input")
            midi_input = gr.File(
                label="Upload MIDI File",
                file_types=[".mid", ".midi"],
                # Deliver the upload as raw bytes; the handler passes them
                # straight through to the MIDI parser.
                type="binary",
            )
            max_len = gr.Slider(
                label="Maximum Length (tokens)",
                minimum=16,
                maximum=512,
                value=128,
                # Token counts are whole numbers; a step of 0 is not a
                # meaningful increment.
                step=1,
            )
            process_btn = gr.Button("Extend MIDI", variant="primary")

        # Right side - output: audio preview and the generated MIDI file.
        with gr.Column():
            gr.Markdown("### Output")
            audio_output = gr.Audio(
                label="Audio Preview",
                # The handler returns a (sample_rate, samples) tuple here.
                type="numpy",
            )
            midi_output = gr.File(
                label="Download Generated MIDI",
            )

    # Clicking the button calls process_midi(midi_input, max_len) and routes
    # its two return values to the preview and download components.
    process_btn.click(
        fn=process_midi,
        inputs=[midi_input, max_len],
        outputs=[audio_output, midi_output],
    )


### RUN

if __name__ == "__main__":
    demo.launch()