# NOTE: removed stray Hugging Face web-page text that was pasted above the
# shebang area ("asigalov61's picture / Update app.py / 673f330 verified") —
# it is not Python and broke the file.
#==================================================================================
# https://huggingface.co/spaces/asigalov61/MIDI-Templates-Inpainter
#==================================================================================

print('=' * 70)
print('MIDI Templates Inpainter Gradio App')
print('=' * 70)
print('Loading core MIDI Templates Inpainter modules...')

import os
import copy
import pickle

import time as reqtime
import datetime
from pytz import timezone

print('=' * 70)
print('Loading main MIDI Templates Inpainter modules...')

# Request flash-attention kernels; must be set before torch is imported.
os.environ['USE_FLASH_ATTENTION'] = '1'

import torch

# Trade a little matmul precision for speed on Ampere+ GPUs.
torch.set_float32_matmul_precision('medium')
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
# Enable every scaled-dot-product-attention backend so torch can pick the
# fastest one available at runtime.
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_cudnn_sdp(True)

from huggingface_hub import hf_hub_download

# Project-local modules (bundled with the Space).
import TMIDIX
from midi_to_colab_audio import midi_to_colab_audio
from x_transformer_1_23_2 import *  # provides TransformerWrapper, Decoder, AutoregressiveWrapper, top_p

import random
import tqdm

print('=' * 70)
print('Loading aux MIDI Templates Inpainter modules...')

import matplotlib.pyplot as plt
import gradio as gr
import spaces

print('=' * 70)
print('PyTorch version:', torch.__version__)
print('=' * 70)
print('Done!')
print('Enjoy! :)')
print('=' * 70)
#==================================================================================
# Global configuration and one-time model / dataset setup (runs at import time).
#==================================================================================

GMT_MODEL_CHECKPOINT = 'Giant_Music_Transformer_Medium_Trained_Model_42174_steps_0.5211_loss_0.8542_acc.pth'
# NOTE(review): constant name is misspelled ("SOUDFONT") but it is used
# consistently below, so it is left unchanged.
SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
# Hard cap on the number of template notes fed to the model per request.
MAX_NOTES_TO_INPAINT = 1280

#==================================================================================

print('=' * 70)
print('Instantiating Giant Music Transformer model...')

device_type = 'cuda'
dtype = 'bfloat16'
ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# Autocast context reused for every generate() call.
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

SEQ_LEN = 8192   # model context length
PAD_IDX = 19463  # padding token id; vocabulary size is PAD_IDX + 1

gmt_model = TransformerWrapper(
    num_tokens = PAD_IDX+1,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(dim = 2048,
                          depth = 8,
                          heads = 32,
                          rotary_pos_emb = True,
                          attn_flash = True
                          )
)

gmt_model = AutoregressiveWrapper(gmt_model, ignore_index=PAD_IDX, pad_value=PAD_IDX)

print('=' * 70)
print('Loading model checkpoint...')

gmt_model_checkpoint = hf_hub_download(repo_id='asigalov61/Giant-Music-Transformer', filename=GMT_MODEL_CHECKPOINT)

# weights_only=True: safe tensor-only deserialization of the checkpoint.
gmt_model.load_state_dict(torch.load(gmt_model_checkpoint, map_location='cpu', weights_only=True))

gmt_model = torch.compile(gmt_model, mode='max-autotune')

print('=' * 70)
print('Done!')
print('=' * 70)
print('Model will use', dtype, 'precision...')
print('=' * 70)

#==================================================================================

print('Loading MIDI Templates dataset...')

MIDI_Templates_Dataset = hf_hub_download(repo_id='asigalov61/MIDI-Templates',
                                         repo_type='dataset',
                                         filename='MIDI_Templates_16384_Processed_MIDIs_CC_BY_NC_SA.pickle')

# NOTE(review): pickle from a trusted dataset repo; do not point this at
# untrusted files. Each entry appears to be (md5, token_list) — see usage below.
with open(MIDI_Templates_Dataset, 'rb') as f:
    midi_templates = pickle.load(f)

print('Done!')
print('=' * 70)
#==================================================================================
def toks_to_score(score_tokens):
    """Decode a MIDI-template token sequence into a TMIDIX-style score.

    Token ranges (per note, in order):
        0-255    delta time (accumulated into absolute `time`)
        256-511  duration (value - 256)
        512-640  patch (value - 512; 128 == drums)
        641-767  pitch (value - 640)
        768-895  velocity (value - 768) — last token of a note, emits it

    Fix: the duration branch used `256 < m < 512`, which silently dropped
    token 256 (duration 0); it now uses `256 <= m < 512`.

    Returns a list of ['note', time, dur, channel, pitch, vel, patch] events.
    """
    song_f = []

    # Running decoder state (defaults used until overwritten by tokens).
    time = 0
    dur = 8
    vel = 90
    pitch = 60
    channel = 0
    patch = 40

    for m in score_tokens:
        if 0 <= m < 256:
            time += m
        elif 256 <= m < 512:  # was `256 < m` — dropped duration 0
            dur = m - 256
        elif 511 < m < 641:
            patch = m - 512
            # Fixed template instrumentation: 40=melody (violin), 24=acc
            # (guitar), 35=bass, 128=drums — matches the inpainting patches
            # used elsewhere in this app.
            if patch == 40:
                channel = 0
            elif patch == 24:
                channel = 1
            elif patch == 35:
                channel = 2
            elif patch == 128:
                channel = 9
        elif 640 < m < 768:
            pitch = m - 640
        elif 768 < m < 896:
            # Velocity is the final token of a note: emit the event here.
            vel = m - 768
            song_f.append(['note', time, dur, channel, pitch, vel, patch])

    return song_f
#==================================================================================
def score_to_toks(score):
score.sort(key=lambda x: x[1])
tokens = []
tokens.extend([19461, 19331, 19332+score[0][6]])
pe = score[0]
for e in score:
#=======================================================
# Timings...
# Cliping all values...
delta_time = max(0, min(255, e[1]-pe[1]))
# Durations and channels
dur = max(0, min(255, e[2]))
cha = max(0, min(15, e[3]))
# Patches
if cha == 9: # Drums patch will be == 128
pat = 128
else:
pat = e[6]
# Pitches
ptc = max(1, min(127, e[4]))
# Velocities
# Calculating octo-velocity
vel = max(8, min(127, e[5]))
velocity = round(vel / 15)-1
dur_vel = (8 * dur) + velocity
pat_ptc = (129 * pat) + ptc
tokens.extend([delta_time, dur_vel+256, pat_ptc+2304])
pe = e
return tokens
#==================================================================================
def first_note_idx(score_tokens, patch):
    """Return the index of the first patch/pitch token (range 2304-18944)
    whose encoded patch equals `patch`, or -1 when none is found.

    Fixes: the original raised NameError on an empty sequence and returned
    the last loop index (an arbitrary, valid-looking value) when no token
    matched; both cases now return -1 explicitly.
    """
    for i, t in enumerate(score_tokens):
        if 2304 <= t < 18945:
            # Patch is the high part of the combined patch/pitch token.
            if (t - 2304) // 129 == patch:
                return i
    return -1
#==================================================================================
@spaces.GPU
def Inpaint_MIDI_Template(midi_template_idx,
                          inpainting_mode,
                          model_temperature,
                          model_sampling_top_p
                          ):

    """Gradio handler: pick a MIDI template, inpaint the selected parts'
    pitches with the Giant Music Transformer, and render audio/plot/MIDI.

    Args:
        midi_template_idx: template index in `midi_templates`, or -1 for random.
        inpainting_mode: list of part names to inpaint ('Melody',
            'Accompaniment', 'Base' — the UI label's spelling).
        model_temperature: sampling temperature for generation.
        model_sampling_top_p: nucleus-sampling top-p threshold.

    Returns:
        (output_audio, output_plot, output_midi) for the Gradio outputs.

    Fix: `model_temperature` and `model_sampling_top_p` were accepted from
    the UI but never passed to `inpaint()`, so the sliders had no effect;
    they are now forwarded. Also removed the mutable default argument on
    the nested helper.
    """

    #===============================================================================

    def inpaint(melody_chords,
                inpaint_MIDI_patch=None,
                number_of_prime_tokens=0,
                number_of_memory_tokens=4096,
                temperature=1.0,
                model_sampling_top_p_value=0.96,
                verbose=False
                ):
        # Re-sample every pitch token belonging to a patch in
        # inpaint_MIDI_patch, keeping timings/durations/velocities intact,
        # then decode the token stream back into a score.

        # Avoid the shared-mutable-default pitfall.
        if inpaint_MIDI_patch is None:
            inpaint_MIDI_patch = [0]

        #=====================================================================

        if verbose:
            print('=' * 70)
            print('Giant Music Transformer Inpainting Model Generator')
            print('=' * 70)

        #=====================================================================

        # Copy the prime tokens through unchanged.
        out2 = []

        for m in melody_chords[:number_of_prime_tokens]:
            out2.append(m)

        #=====================================================================

        for i in tqdm.tqdm(range(number_of_prime_tokens, len(melody_chords))):

            cpatch = (melody_chords[i]-2304) // 129

            # Only patch/pitch tokens (2304-18944) of the requested patches
            # get re-sampled; everything else passes through verbatim.
            if 2304 <= melody_chords[i] < 18945 and cpatch in inpaint_MIDI_patch:

                # Condition on the most recent `number_of_memory_tokens` tokens.
                inp = torch.LongTensor(out2[-number_of_memory_tokens:]).cuda()

                with ctx:
                    out1 = gmt_model.generate(inp,
                                              1,
                                              filter_logits_fn=top_p,
                                              filter_kwargs={'thres': model_sampling_top_p_value},
                                              temperature=temperature,
                                              return_prime=False,
                                              verbose=False)

                max_acc_sample = out1.tolist()[0][0]

                # Keep the original patch, take only the sampled pitch.
                cpitch = (max_acc_sample-2304) % 129

                out2.extend([((cpatch * 129) + cpitch)+2304])

            else:
                out2.append(melody_chords[i])

        #=====================================================================

        # Decode GMT tokens back into ['note', ...] events, assigning
        # channels to patches on a first-come basis (channel 9 reserved
        # for drums).
        song_f = []

        time = 0
        dur = 0
        vel = 90
        pitch = 0
        channel = 0

        patches = [-1] * 16
        channels = [0] * 16
        channels[9] = 1  # reserve the drums channel

        for ss in out2:

            if 0 <= ss < 256:
                time += ss

            if 256 <= ss < 2304:
                dur = ((ss-256) // 8)
                vel = (((ss-256) % 8)+1) * 15

            if 2304 <= ss < 18945:

                patch = (ss-2304) // 129

                if patch < 128:
                    if patch not in patches:
                        if 0 in channels:
                            cha = channels.index(0)
                            channels[cha] = 1
                        else:
                            cha = 15  # overflow: dump extra patches on ch 15
                        patches[cha] = patch
                        channel = patches.index(patch)
                    else:
                        channel = patches.index(patch)

                if patch == 128:
                    channel = 9

                pitch = (ss-2304) % 129

                song_f.append(['note', time, dur, channel, pitch, vel, patch ])

        #=====================================================================

        if verbose:
            print('Done!')
            print('=' * 70)

        return song_f

    #===============================================================================

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()
    print('=' * 70)

    print('Requested settings:')
    print('=' * 70)
    print('MIDI template idx:', midi_template_idx)
    print('Parts to inpaint:', inpainting_mode)
    print('Model temperature:', model_temperature)
    print('Model sampling top p:', model_sampling_top_p)
    print('=' * 70)

    #==================================================================

    gmt_model.to(device_type)
    gmt_model.eval()

    #==================================================================

    print('Selecting and loading MIDI template...')

    if midi_template_idx > -1:
        midi_template = midi_templates[midi_template_idx]
    else:
        midi_template = random.choice(midi_templates)

    mt_idx = midi_templates.index(midi_template)
    mt_md5 = midi_template[0]

    # Decode the template tokens and cap the number of notes to inpaint.
    inp_score = toks_to_score(midi_template[1])[:MAX_NOTES_TO_INPAINT]

    print('=' * 70)
    print('Selected MIDI template idx:', mt_idx)
    print('Selected MIDI template md5:', mt_md5)

    #==================================================================

    print('=' * 70)
    print('Prepping patches...')

    # Part -> General MIDI patch mapping used by the templates.
    mel_pat = 40   # melody (violin)
    acc_pat = 24   # accompaniment (guitar)
    bass_pat = 35  # bass

    ipatches = []

    if 'Melody' in inpainting_mode:
        ipatches.append(mel_pat)

    if 'Accompaniment' in inpainting_mode:
        ipatches.append(acc_pat)

    # 'Base' matches the (misspelled) UI checkbox label for the bass part.
    if 'Base' in inpainting_mode:
        ipatches.append(bass_pat)

    # Nothing selected: default to inpainting the melody.
    if not inpainting_mode:
        ipatches.append(mel_pat)

    print('=' * 70)

    #==================================================================

    print('Inpainting...')

    inp_score_tokens = score_to_toks(inp_score)

    # Forward the UI sampling settings (previously they were silently ignored).
    out_score = inpaint(inp_score_tokens,
                        ipatches,
                        300,
                        temperature=model_temperature,
                        model_sampling_top_p_value=model_sampling_top_p
                        )

    print('Done!')
    print('=' * 70)

    #==================================================================

    print('Patching final score...')

    output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(out_score)

    print('Done!')
    print('=' * 70)

    #===============================================================================

    print('Rendering results...')
    print('=' * 70)
    print('Sample events', output_score[:3])
    print('=' * 70)

    fn1 = "Inpainted-MIDI-Template-" + str(mt_idx) + "-" + str(mt_md5)

    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,
                                                              output_signature = 'MIDI Templates Inpainter',
                                                              output_file_name = fn1,
                                                              track_name='Project Los Angeles',
                                                              list_of_MIDI_patches=patches,
                                                              timings_multiplier=16
                                                              )

    new_fn = fn1+'.mid'

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=SOUDFONT_PATH,
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    print('Done!')
    print('=' * 70)

    #========================================================

    output_midi = str(new_fn)
    output_audio = (16000, audio)

    output_plot = TMIDIX.plot_ms_SONG(output_score,
                                      plot_title=output_midi,
                                      timings_multiplier=16,
                                      return_plt=True,
                                      )

    print('Output MIDI file name:', output_midi)
    print('=' * 70)

    #========================================================

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_audio, output_plot, output_midi
#==================================================================================

# Pacific time zone used for all request/app timestamps in the logs.
PDT = timezone('US/Pacific')

print('=' * 70)
print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
print('=' * 70)
#==================================================================================
# Gradio UI: inputs (template index, parts, sampling settings) -> Inpaint
# button -> outputs (audio preview, score plot, downloadable MIDI file).
with gr.Blocks() as demo:

    #==================================================================================

    gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Templates Inpainter</h1>")
    gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Inpaint pitches in MIDI templates to create unique songs</h1>")

    gr.HTML("""
    <p>
    <a href="https://huggingface.co/spaces/asigalov61/MIDI-Templates-Inpainter?duplicate=true">
    <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate in Hugging Face">
    </a>
    </p>
    for faster execution and endless generation!
    """)

    #==================================================================================

    gr.Markdown("## MIDI Templates options")

    # -1 selects a random template (see the handler's midi_template_idx check).
    midi_template_idx = gr.Slider(-1, 16383, value=-1, step=1, label="Desired MIDI template number")
    # NOTE(review): 'Base' is a misspelling of 'Bass', but the handler matches
    # the same string, so both must change together if ever corrected.
    inpainting_mode = gr.CheckboxGroup(['Melody', 'Accompaniment', 'Base'], value=['Melody'], label="Which MIDI template parts to inpaint")

    gr.Markdown("## Model options")

    model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
    model_sampling_top_p = gr.Slider(0.1, 1, value=0.96, step=0.01, label="Model sampling top p value")

    generate_btn = gr.Button("Inpaint", variant="primary")

    gr.Markdown("## Inpainting results")

    output_audio = gr.Audio(label="MIDI audio", format="wav", elem_id="midi_audio")
    output_plot = gr.Plot(label="MIDI score plot")
    output_midi = gr.File(label="MIDI file", file_types=[".mid"])

    # Wire the button to the GPU handler; input/output order must match the
    # handler's signature and return tuple.
    generate_btn.click(Inpaint_MIDI_Template,
                       [midi_template_idx,
                        inpainting_mode,
                        model_temperature,
                        model_sampling_top_p
                        ],
                       [output_audio,
                        output_plot,
                        output_midi,
                        ]
                       )

#==================================================================================

demo.launch()

#==================================================================================