|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Startup banner, dependency loading and global PyTorch backend configuration.
#
# NOTE: statement order is significant in this section. Progress messages are
# printed between the import groups, and USE_FLASH_ATTENTION is exported
# before torch is imported, so the imports are deliberately not regrouped.

print('=' * 70)
print('MIDI Templates Inpainter Gradio App')

print('=' * 70)
print('Loading core MIDI Templates Inpainter modules...')

import os
import copy
import pickle
import time as reqtime
import datetime
from pytz import timezone

print('=' * 70)
print('Loading main MIDI Templates Inpainter modules...')

# Must be set before torch is imported below.
os.environ['USE_FLASH_ATTENTION'] = '1'

import torch

# Trade some float32 matmul precision for speed and allow TF32 kernels.
torch.set_float32_matmul_precision('medium')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Enable every scaled-dot-product-attention backend so PyTorch can pick one.
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_cudnn_sdp(True)

from huggingface_hub import hf_hub_download

import TMIDIX

from midi_to_colab_audio import midi_to_colab_audio

# Star import kept on purpose: the rest of the file uses TransformerWrapper,
# Decoder, AutoregressiveWrapper and top_p from this module without a prefix.
from x_transformer_1_23_2 import *

import random

import tqdm

print('=' * 70)
print('Loading aux MIDI Templates Inpainter modules...')

import matplotlib.pyplot as plt

import gradio as gr
import spaces

print('=' * 70)
print('PyTorch version:', torch.__version__)
print('=' * 70)
print('Done!')
print('Enjoy! :)')
print('=' * 70)
|
|
|
|
|
|
|
|
|
|
|
# Application-wide configuration constants.

# File name of the trained model checkpoint fetched from the Hugging Face Hub.
GMT_MODEL_CHECKPOINT = 'Giant_Music_Transformer_Medium_Trained_Model_42174_steps_0.5211_loss_0.8542_acc.pth'

# SoundFont used to render the generated MIDI to audio.
# NOTE: the identifier is misspelled ("SOUDFONT"), but it is referenced under
# this exact name by the request handler, so it is intentionally kept.
SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'

# Upper bound on the number of template notes processed per request.
MAX_NOTES_TO_INPAINT = 1280
|
|
|
|
|
|
|
|
|
|
|
# Build the Giant Music Transformer, load its trained weights and compile it.

print('=' * 70)
print('Instantiating Giant Music Transformer model...')

device_type = 'cuda'
dtype = 'bfloat16'

# Map the precision name to the torch dtype used by the autocast context.
ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

SEQ_LEN = 8192
PAD_IDX = 19463

gmt_model = TransformerWrapper(
    num_tokens=PAD_IDX + 1,  # vocabulary size: every token id up to and including PAD_IDX
    max_seq_len=SEQ_LEN,
    attn_layers=Decoder(dim=2048,
                        depth=8,
                        heads=32,
                        rotary_pos_emb=True,
                        attn_flash=True
                        )
)

gmt_model = AutoregressiveWrapper(gmt_model, ignore_index=PAD_IDX, pad_value=PAD_IDX)

print('=' * 70)
print('Loading model checkpoint...')

gmt_model_checkpoint = hf_hub_download(repo_id='asigalov61/Giant-Music-Transformer', filename=GMT_MODEL_CHECKPOINT)

# Weights are loaded on the CPU; the request handler moves the model to the
# GPU. weights_only=True restricts unpickling to tensor data.
gmt_model.load_state_dict(torch.load(gmt_model_checkpoint, map_location='cpu', weights_only=True))

gmt_model = torch.compile(gmt_model, mode='max-autotune')

print('=' * 70)
print('Done!')
print('=' * 70)
print('Model will use', dtype, 'precision...')
print('=' * 70)
|
|
|
|
|
|
|
|
|
|
|
# Download the MIDI Templates dataset and load it into memory.

print('Loading MIDI Templates dataset...')

MIDI_Templates_Dataset = hf_hub_download(repo_id='asigalov61/MIDI-Templates',
                                         repo_type='dataset',
                                         filename='MIDI_Templates_16384_Processed_MIDIs_CC_BY_NC_SA.pickle')

# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in
# the file. This is only acceptable while the dataset repository above is
# fully trusted; consider a safer serialization format if that ever changes.
with open(MIDI_Templates_Dataset, 'rb') as f:
    midi_templates = pickle.load(f)

print('Done!')
print('=' * 70)
|
|
|
|
|
|
|
|
|
|
|
def toks_to_score(score_tokens):
    """Decode a flat MIDI-template token sequence into a list of note events.

    Token ranges handled (any other value is ignored):

    * ``0..255``   - delta time, added to the running absolute time
    * ``257..511`` - duration (``token - 256``)
    * ``512..640`` - patch (``token - 512``); patches 40, 24, 35 and 128 also
      select channels 0, 1, 2 and 9, other patches keep the current channel
    * ``641..767`` - pitch (``token - 640``)
    * ``769..895`` - velocity (``token - 768``); a velocity token completes
      the note, so this is the point where the event is emitted

    Args:
        score_tokens: iterable of integer tokens.

    Returns:
        A list of ``['note', time, dur, channel, pitch, vel, patch]`` events.
    """

    # Patch -> MIDI channel assignments; unlisted patches leave the channel as is.
    patch_channels = {40: 0, 24: 1, 35: 2, 128: 9}

    song_f = []

    # Running note state, used until the matching tokens override it.
    time = 0
    dur = 8
    vel = 90
    pitch = 60
    channel = 0
    patch = 40

    for m in score_tokens:

        if 0 <= m < 256:
            time += m

        elif 256 < m < 512:
            dur = m - 256

        elif 511 < m < 641:
            patch = m - 512
            channel = patch_channels.get(patch, channel)

        elif 640 < m < 768:
            pitch = m - 640

        elif 768 < m < 896:
            vel = m - 768

            # Velocity is the last token of a note: emit the event now.
            song_f.append(['note', time, dur, channel, pitch, vel, patch])

    return song_f
|
|
|
|
|
|
|
|
|
|
|
def score_to_toks(score):
    """Encode note events as a model token sequence.

    Each event is expected to look like
    ``['note', time, dur, channel, pitch, vel, patch]``.

    The output starts with three header tokens (``19461``, ``19331`` and
    ``19332 + patch`` of the earliest note; presumably sequence-start markers
    of the model vocabulary - TODO confirm), followed by three tokens per note:

    * delta time to the previous note, clamped to ``0..255``
    * ``256 + 8 * duration + velocity_bucket`` (duration clamped to
      ``0..255``, velocity reduced to one of 8 buckets)
    * ``2304 + 129 * patch + pitch`` (pitch clamped to ``1..127``; notes on
      channel 9 always use patch 128)

    Args:
        score: list of note events. The list is not modified; a time-sorted
            copy is used internally.

    Returns:
        A list of integer tokens, or an empty list when ``score`` is empty.
    """

    # Nothing to encode: there is no first note to build the header from.
    if not score:
        return []

    # Sort a copy (stable, by start time) so the caller's list is left intact.
    ordered = sorted(score, key=lambda x: x[1])

    tokens = [19461, 19331, 19332 + ordered[0][6]]

    pe = ordered[0]

    for e in ordered:

        # Timings
        delta_time = max(0, min(255, e[1] - pe[1]))

        # Durations and channels
        dur = max(0, min(255, e[2]))
        cha = max(0, min(15, e[3]))

        # Patches: channel 9 is encoded with the dedicated patch 128
        if cha == 9:
            pat = 128
        else:
            pat = e[6]

        # Pitches
        ptc = max(1, min(127, e[4]))

        # Velocities: reduce 8..127 to a bucket index in 0..7
        vel = max(8, min(127, e[5]))
        velocity = round(vel / 15) - 1

        dur_vel = (8 * dur) + velocity
        pat_ptc = (129 * pat) + ptc

        tokens.extend([delta_time, dur_vel + 256, pat_ptc + 2304])

        pe = e

    return tokens
|
|
|
|
|
|
|
|
|
|
|
def first_note_idx(score_tokens, patch):
    """Return the index of the first patch-pitch token that uses ``patch``.

    Patch-pitch tokens are the values in ``2304..18944``; their patch is
    ``(token - 2304) // 129``.

    Args:
        score_tokens: iterable of integer tokens.
        patch: patch number to look for.

    Returns:
        The index of the first matching token. When nothing matches, the
        index of the last token is returned, which is ``-1`` for an empty
        input (the original raised ``UnboundLocalError`` in that case).
    """

    last_idx = -1

    for last_idx, t in enumerate(score_tokens):
        if 2304 <= t < 18945 and (t - 2304) // 129 == patch:
            break

    return last_idx
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU
def Inpaint_MIDI_Template(midi_template_idx,
                          inpainting_mode,
                          model_temperature,
                          model_sampling_top_p
                          ):
    """Inpaint the pitches of selected parts of a MIDI template and render it.

    Args:
        midi_template_idx: index of the template to use; any value below 0
            selects a random template.
        inpainting_mode: collection of part names to inpaint ('Melody',
            'Accompaniment', 'Base'); when empty, the melody is inpainted.
        model_temperature: sampling temperature passed to the model.
        model_sampling_top_p: top-p threshold passed to the model.

    Returns:
        A tuple ``(output_audio, output_plot, output_midi)`` where
        ``output_audio`` is ``(16000, audio)``, ``output_plot`` is the result
        of ``TMIDIX.plot_ms_SONG`` and ``output_midi`` is the name of the
        written ``.mid`` file.
    """

    def inpaint(melody_chords,
                inpaint_MIDI_patch=None,
                number_of_prime_tokens=0,
                number_of_memory_tokens=4096,
                temperature=1.0,
                model_sampling_top_p_value=0.96,
                verbose=False
                ):
        """Re-sample the pitch of every note played by the given patches.

        The first ``number_of_prime_tokens`` tokens are copied unchanged. After
        that, each patch-pitch token (``2304..18944``) whose patch is listed in
        ``inpaint_MIDI_patch`` is replaced: the model samples one token from
        the last ``number_of_memory_tokens`` output tokens, and only the pitch
        part of that sample is used, combined with the original patch. All
        other tokens are copied unchanged. The result is decoded into
        ``['note', time, dur, channel, pitch, vel, patch]`` events.
        """

        # Avoid a shared mutable default; [0] was the original default value.
        if inpaint_MIDI_patch is None:
            inpaint_MIDI_patch = [0]

        if verbose:
            print('=' * 70)
            print('Giant Music Transformer Inpainting Model Generator')
            print('=' * 70)

        # Prime tokens are taken over verbatim.
        out2 = list(melody_chords[:number_of_prime_tokens])

        for i in tqdm.tqdm(range(number_of_prime_tokens, len(melody_chords))):

            token = melody_chords[i]
            cpatch = (token - 2304) // 129

            if 2304 <= token < 18945 and cpatch in inpaint_MIDI_patch:

                inp = torch.LongTensor(out2[-number_of_memory_tokens:]).cuda()

                with ctx:
                    out1 = gmt_model.generate(inp,
                                              1,
                                              filter_logits_fn=top_p,
                                              filter_kwargs={'thres': model_sampling_top_p_value},
                                              temperature=temperature,
                                              return_prime=False,
                                              verbose=False)

                max_acc_sample = out1.tolist()[0][0]

                # Keep the template's patch, take only the sampled pitch.
                cpitch = (max_acc_sample - 2304) % 129

                out2.append((cpatch * 129) + cpitch + 2304)

            else:
                out2.append(token)

        # Decode the token stream back into note events.
        song_f = []

        time = 0
        dur = 0
        vel = 90
        pitch = 0
        channel = 0

        # patches[c] is the patch assigned to channel c (-1 = unassigned);
        # channels[c] flags channel c as taken. Channel 9 is reserved.
        patches = [-1] * 16

        channels = [0] * 16
        channels[9] = 1

        for ss in out2:

            if 0 <= ss < 256:

                time += ss

            if 256 <= ss < 2304:

                dur = ((ss - 256) // 8)
                vel = (((ss - 256) % 8) + 1) * 15

            if 2304 <= ss < 18945:

                patch = (ss - 2304) // 129

                if patch < 128:

                    if patch not in patches:
                        if 0 in channels:
                            cha = channels.index(0)
                            channels[cha] = 1
                        else:
                            # No free channel left: reuse channel 15.
                            cha = 15

                        patches[cha] = patch
                        channel = patches.index(patch)
                    else:
                        channel = patches.index(patch)

                if patch == 128:
                    channel = 9

                pitch = (ss - 2304) % 129

                song_f.append(['note', time, dur, channel, pitch, vel, patch])

        if verbose:
            print('Done!')
            print('=' * 70)

        return song_f

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()
    print('=' * 70)
    print('Requested settings:')
    print('=' * 70)
    print('MIDI template idx:', midi_template_idx)
    print('Parts to inpaint:', inpainting_mode)
    print('Model temperature:', model_temperature)
    print('Model sampling top p:', model_sampling_top_p)
    print('=' * 70)

    # The checkpoint was loaded on the CPU; move the model to the GPU now.
    gmt_model.to(device_type)
    gmt_model.eval()

    print('Selecting and loading MIDI template...')

    # Resolve the index first instead of searching the whole dataset for the
    # chosen template afterwards (slow, and ambiguous for duplicate entries).
    if midi_template_idx > -1:
        mt_idx = int(midi_template_idx)

    else:
        mt_idx = random.randrange(len(midi_templates))

    midi_template = midi_templates[mt_idx]
    mt_md5 = midi_template[0]

    inp_score = toks_to_score(midi_template[1])[:MAX_NOTES_TO_INPAINT]

    print('=' * 70)
    print('Selected MIDI template idx:', mt_idx)
    print('Selected MIDI template md5:', mt_md5)

    print('=' * 70)
    print('Prepping patches...')

    mel_pat = 40
    acc_pat = 24
    bass_pat = 35

    ipatches = []

    if 'Melody' in inpainting_mode:
        ipatches.append(mel_pat)

    if 'Accompaniment' in inpainting_mode:
        ipatches.append(acc_pat)

    # 'Base' matches the option label used by the UI checkbox group.
    if 'Base' in inpainting_mode:
        ipatches.append(bass_pat)

    # Nothing selected: fall back to inpainting the melody.
    if not inpainting_mode:
        ipatches.append(mel_pat)

    print('=' * 70)

    print('Inpainting...')

    inp_score_tokens = score_to_toks(inp_score)

    # Forward the user's sampling settings; previously they were only printed
    # and the helper's defaults were always used.
    out_score = inpaint(inp_score_tokens,
                        ipatches,
                        300,
                        temperature=model_temperature,
                        model_sampling_top_p_value=model_sampling_top_p)

    print('Done!')
    print('=' * 70)

    print('Patching final score...')

    output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(out_score)

    print('Done!')
    print('=' * 70)

    print('Rendering results...')

    print('=' * 70)
    print('Sample events', output_score[:3])
    print('=' * 70)

    fn1 = "Inpainted-MIDI-Template-" + str(mt_idx) + "-" + str(mt_md5)

    # Writes fn1 + '.mid'; the returned statistics are not needed here.
    TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,
                                             output_signature='MIDI Templates Inpainter',
                                             output_file_name=fn1,
                                             track_name='Project Los Angeles',
                                             list_of_MIDI_patches=patches,
                                             timings_multiplier=16
                                             )

    new_fn = fn1 + '.mid'

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=SOUDFONT_PATH,
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    print('Done!')
    print('=' * 70)

    output_midi = str(new_fn)
    output_audio = (16000, audio)

    output_plot = TMIDIX.plot_ms_SONG(output_score,
                                      plot_title=output_midi,
                                      timings_multiplier=16,
                                      return_plt=True,
                                      )

    print('Output MIDI file name:', output_midi)
    print('=' * 70)

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_audio, output_plot, output_midi
|
|
|
|
|
|
|
|
|
|
|
# Time zone used for every timestamp the app prints (also read by the
# request handler when it logs request start and end times).
PDT = timezone('US/Pacific')

print('=' * 70)
print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
print('=' * 70)
|
|
|
|
|
|
|
|
|
|
|
# Gradio user interface: template and model options, an "Inpaint" button and
# the three result widgets (audio, score plot, MIDI file).
with gr.Blocks() as demo:

    gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Templates Inpainter</h1>")
    gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Inpaint pitches in MIDI templates to create unique songs</h1>")
    gr.HTML("""
            <p>
                <a href="https://huggingface.co/spaces/asigalov61/MIDI-Templates-Inpainter?duplicate=true">
                    <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate in Hugging Face">
                </a>
            </p>

            for faster execution and endless generation!
            """)

    gr.Markdown("## MIDI Templates options")

    # -1 requests a random template. The upper bound follows the loaded
    # dataset so the slider can never address a template that does not exist.
    midi_template_idx = gr.Slider(-1, len(midi_templates) - 1, value=-1, step=1, label="Desired MIDI template number")
    inpainting_mode = gr.CheckboxGroup(['Melody', 'Accompaniment', 'Base'], value=['Melody'], label="Which MIDI template parts to inpaint")

    gr.Markdown("## Model options")

    model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
    model_sampling_top_p = gr.Slider(0.1, 1, value=0.96, step=0.01, label="Model sampling top p value")

    generate_btn = gr.Button("Inpaint", variant="primary")

    gr.Markdown("## Inpainting results")

    output_audio = gr.Audio(label="MIDI audio", format="wav", elem_id="midi_audio")
    output_plot = gr.Plot(label="MIDI score plot")
    output_midi = gr.File(label="MIDI file", file_types=[".mid"])

    # Inputs and outputs are listed in the order of the handler's parameters
    # and of its returned tuple.
    generate_btn.click(Inpaint_MIDI_Template,
                       [midi_template_idx,
                        inpainting_mode,
                        model_temperature,
                        model_sampling_top_p
                        ],
                       [output_audio,
                        output_plot,
                        output_midi,
                        ]
                       )

demo.launch()
|
|
|
|
|
|