# projectlosangeles's picture
# Update app.py
# 7a98bb4 verified
#=================================================================================
# https://huggingface.co/spaces/projectlosangeles/Orpheus-Masked-Pitches-Inpainter
#=================================================================================
# Orpheus Masked Pitches Inpainter: a Gradio app that loads the Orpheus masked
# encoder model and uses it to re-predict (inpaint) the pitches of one selected
# patch (instrument) in an uploaded MIDI file.

print('=' * 70)
print('Orpheus Masked Pitches Inpainter Gradio App')
print('=' * 70)

import os

# These environment variables are set before the remaining imports, presumably
# so that libraries which read them at import time pick them up.
# HF_HUB_ENABLE_HF_TRANSFER is documented by huggingface_hub as enabling the
# hf_transfer download backend.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# NOTE(review): presumably consumed by the vendored transformer module or a
# dependency; confirm which library reads USE_FLASH_ATTENTION.
os.environ['USE_FLASH_ATTENTION'] = '1'

import time as reqtime  # wall-clock timing of requests
from pytz import timezone  # builds the US/Pacific timezone used for timestamps

import torch

# Allow TF32 matmuls/convolutions and enable every scaled-dot-product-attention
# backend so PyTorch is free to pick the fastest kernel available on the GPU.
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_cudnn_sdp(True)

import spaces  # provides the @spaces.GPU decorator used on inpaint_pitches
import gradio as gr

# NOTE(review): star import; presumably the source of TransformerWrapper,
# Encoder and predict_masked_tokens used below — confirm.
from x_transformer_2_3_1 import *

import datetime
import random
import tqdm  # NOTE(review): not referenced in the visible code

from midi_to_colab_audio import midi_to_colab_audio  # renders a .mid file to audio

import TMIDIX  # MIDI parsing, score processing, writing and plotting helpers

import matplotlib.pyplot as plt  # NOTE(review): `plt` is not referenced in the visible code

from huggingface_hub import hf_hub_download  # fetches the checkpoint and SoundFont
# =================================================================================================
# Directory (relative to the current working directory) into which save_midi()
# writes the generated MIDI files; it is created on demand.
OUTPUT_MIDIS_DIR = 'output_midis'
# =================================================================================================
print('=' * 70)
print('Loading models...')
print('=' * 70)
print('Loading Orpheus masked encoder model...')
print('=' * 70)

# Maximum number of tokens the model processes; inpaint_pitches() truncates its
# input to this length.
SEQ_LEN = 2048

# Highest token id; the vocabulary size is PAD_IDX + 1 (see num_tokens below).
# Presumably the padding token id, judging by the name — confirm.
PAD_IDX = 18820

# The model and its weights are placed on this device, so a CUDA GPU is required.
DEVICE = 'cuda'

# Encoder-only (non-causal) x-transformers model: 12 layers, width 2048,
# 16 attention heads, rotary position embeddings and flash attention.
model = TransformerWrapper(
    num_tokens = PAD_IDX+1,
    max_seq_len = SEQ_LEN,
    attn_layers = Encoder(dim = 2048,
                          depth = 12,
                          heads = 16,
                          rotary_pos_emb = True,
                          attn_flash = True
                          )
    )

model.to(DEVICE)
print('=' * 70)

print('Loading model checkpoint...')

# Download (or reuse the locally cached copy of) the pretrained masked-encoder
# checkpoint from the Hugging Face Hub.
checkpoint = hf_hub_download(
    repo_id='asigalov61/Orpheus-Music-Transformer',
    filename='Orpheus_Music_Transformer_Masked_Encoder_Trained_Model_23000_steps_0.6548_loss_0.8132_acc.pth'
)

# weights_only=True restricts unpickling to tensors and plain containers.
model.load_state_dict(torch.load(checkpoint, map_location=DEVICE, weights_only=True))

# Inference mode for layers such as dropout.
model.eval()

# model = torch.compile(model)

print('=' * 70)
print('Done!')
print('=' * 70)
# =================================================================================================
# Mixed-precision settings: `ctx` is a CUDA autocast context manager; code run
# inside `with ctx:` executes autocast-eligible ops in bfloat16.
dtype = torch.bfloat16
ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)

print('Done!')
print('=' * 70)
# =================================================================================================
print('Loading SoundFont...')

# Local path of the SoundFont that midi_to_colab_audio() uses to render output
# MIDI to audio, downloaded from the projectlosangeles/soundfonts4u dataset repo.
SOUNDFONT_PATH = hf_hub_download(repo_id='projectlosangeles/soundfonts4u',
                                 repo_type='dataset',
                                 filename='SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
                                 )

print('Done!')
print('=' * 70)
# =================================================================================================
def load_midi(input_midi):
    """Process the input MIDI file and create a token sequence.

    The sequence starts with token 18816, presumably a sequence-start marker.
    After it, each chord contributes one delta-time token followed by two
    tokens per note:

    * delta time: ``0-255`` (clamped so it never spills into the note range)
    * patch/pitch: ``256 + 128 * patch + pitch`` with patch 0-128 and pitch
      1-127, i.e. ``256-16767``
    * duration/velocity: ``16768 + 8 * duration + velocity_bin`` with
      duration 1-255 and velocity_bin 0-7, i.e. ``16768-18815``

    Times and durations are in the units produced by TMIDIX's augmentation
    step. NOTE(review): save_midi() multiplies them by 16, so presumably they
    are 16 ms steps — confirm.

    Args:
        input_midi: uploaded file object; only its ``.name`` attribute (the
            path of the MIDI file on disk) is used.

    Returns:
        list[int]: the token sequence, or ``[18816]`` if the MIDI has no notes.
    """
    raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name, do_not_check_MIDI_signature=True)

    escore_notes = TMIDIX.advanced_score_processor(raw_score,
                                                   return_enhanced_score_notes=True,
                                                   apply_sustain=True
                                                   )

    # No notes at all: return just the start token.
    if not (escore_notes and escore_notes[0]):
        return [18816]

    escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes[0],
                                                       sort_drums_last=True
                                                       )

    escore_notes = TMIDIX.remove_duplicate_pitches_from_escore_notes(escore_notes)

    escore_notes = TMIDIX.fix_escore_notes_durations(escore_notes,
                                                     min_notes_gap=0
                                                     )

    dscore = TMIDIX.delta_score_notes(escore_notes)

    # Drop the leading 'note' tag, so each event is
    # [time, dur, channel, pitch, velocity, patch]; then group into chords.
    dcscore = TMIDIX.chordify_score([d[1:] for d in dscore])

    melody_chords = [18816]

    #=======================================================
    # MAIN PROCESSING CYCLE
    #=======================================================

    for c in dcscore:

        # Delta time of the chord (time field of its first note). Clamp it like
        # the other fields so a long pause can never produce a token >= 256,
        # which would be decoded as a patch/pitch token.
        delta_time = max(0, min(255, c[0][0]))

        melody_chords.append(delta_time)

        for e in c:

            #=======================================================

            # Durations
            dur = max(1, min(255, e[1]))

            # Patches (128 is the drum patch; save_midi() maps it to channel 9)
            pat = max(0, min(128, e[5]))

            # Pitches
            ptc = max(1, min(127, e[3]))

            # Velocities
            # Calculating octo-velocity: 8..127 -> bins 0..7
            vel = max(8, min(127, e[4]))
            velocity = round(vel / 15)-1

            #=======================================================
            # FINAL NOTE SEQ
            #=======================================================

            # Writing final note
            pat_ptc = (128 * pat) + ptc
            dur_vel = (8 * dur) + velocity

            melody_chords.extend([pat_ptc+256, dur_vel+16768])

    return melody_chords
# =================================================================================================
def save_midi(tokens):
    """Decode a token sequence into notes and write it to a MIDI file with TMIDIX.

    Token layout (the inverse of what load_midi() produces):

    * ``0-255``: delta time; advances the running clock by ``token * 16``.
    * ``256-16767``: ``256 + 128 * patch + pitch``; sets the current patch,
      its MIDI channel and the current pitch.
    * ``16768-18815``: ``16768 + 8 * dur + vel_bin``; sets the duration
      (``dur * 16``) and velocity (``(vel_bin + 1) * 15``) and emits one note
      with the current clock/channel/pitch/patch.
    * Any other token (e.g. 18816) is ignored.

    Patches 0-127 get MIDI channels in order of first appearance, skipping
    channel 9. Once every channel is taken, channel 15 is reassigned to each
    new patch. Patch 128 always goes to channel 9.

    Args:
        tokens: iterable of integer token ids.

    Returns:
        ``(output_fname, output_score)``, where ``output_fname`` is the base
        path inside OUTPUT_MIDIS_DIR handed to the MIDI converter (callers
        append ``'.mid'``). Returns ``(None, None)`` if no note was decoded.
    """
    clock = 0
    note_dur = 1
    note_vel = 90
    note_pitch = 60
    note_channel = 0
    note_patch = 0

    channel_patch = [-1] * 16   # patch currently assigned to each channel (-1 = free)
    channel_taken = [0] * 16    # 1 once a channel has been handed out
    channel_taken[9] = 1        # channel 9 is kept for patch 128

    decoded = []

    for tok in tokens:

        if 0 <= tok < 256:
            clock += tok * 16

        elif 256 <= tok < 16768:
            note_patch = (tok - 256) // 128
            note_pitch = (tok - 256) % 128

            if note_patch == 128:
                note_channel = 9
            else:
                if note_patch not in channel_patch:
                    # First free channel, or overwrite channel 15 when none is left.
                    slot = channel_taken.index(0) if 0 in channel_taken else 15
                    channel_taken[slot] = 1
                    channel_patch[slot] = note_patch
                note_channel = channel_patch.index(note_patch)

        elif 16768 <= tok < 18816:
            note_dur = ((tok - 16768) // 8) * 16
            note_vel = (((tok - 16768) % 8) + 1) * 15
            decoded.append(['note', clock, note_dur, note_channel, note_pitch, note_vel, note_patch])

    if not decoded:
        return None, None

    decoded = TMIDIX.remove_duplicate_pitches_from_escore_notes(decoded)
    decoded = TMIDIX.fix_escore_notes_durations(decoded, min_notes_gap=0)

    output_score, midi_patches, _overflow_patches = TMIDIX.patch_enhanced_score_notes(decoded)

    # Timestamped file name: date, time and the first four microsecond digits.
    stamp = datetime.datetime.now(PDT)
    ms4 = stamp.strftime("%f")[:4]
    base_name = "Orpheus-Masked-Pitches-Inpainter-Composition-" + stamp.strftime(f"%Y-%m-%d-%H-%M-%S-{ms4}")

    os.makedirs(OUTPUT_MIDIS_DIR, exist_ok=True)
    output_fname = os.path.join(OUTPUT_MIDIS_DIR, base_name)

    TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
        output_score,
        output_signature='Orpheus Masked Pitches Inpainter',
        output_file_name=output_fname,
        track_name='Project Los Angeles',
        list_of_MIDI_patches=midi_patches,
        verbose=False
    )

    return output_fname, output_score
# =================================================================================================
@spaces.GPU
def inpaint_pitches(inp_seq,
                    input_patch,
                    input_inpaint_ratio,
                    input_num_prime_notes
                    ):
    """Re-predict (inpaint) the pitches of one patch with the masked encoder.

    Args:
        inp_seq: token sequence as produced by load_midi(); only the first
            SEQ_LEN tokens are used.
        input_patch: patch number (0-128) whose patch/pitch tokens are masked.
        input_num_prime_notes: number of leading notes of that patch left
            unmasked as context for the model.
        input_inpaint_ratio: when < 1, only this fraction of the remaining
            maskable positions (chosen at random) is masked.

    Returns:
        ``results['predicted_ids']`` from predict_masked_tokens(). If there is
        nothing to mask, the truncated input sequence is returned unchanged.
    """
    print('*' * 70)
    print('Inpainting pitches...')

    # Gradio sliders may deliver floats; slice bounds must be ints.
    input_patch = int(input_patch)
    input_num_prime_notes = max(0, int(input_num_prime_notes))

    inp_seq = inp_seq[:SEQ_LEN]

    # Patch/pitch tokens of this patch are 256 + 128 * patch + pitch (pitch >= 1).
    low = (128 * input_patch) + 256
    high = (128 * (input_patch + 1)) + 256

    # Iterate over the actual sequence: it is often shorter than SEQ_LEN, and
    # indexing up to SEQ_LEN raised IndexError for such inputs.
    m_pos = [i for i, tok in enumerate(inp_seq) if low < tok < high]

    # Keep the first N notes of the patch as unmasked "prime" context.
    m_pos = m_pos[input_num_prime_notes:]

    if input_inpaint_ratio < 1:
        k = max(0, min(len(m_pos), int(round(len(m_pos) * input_inpaint_ratio))))
        m_pos = sorted(random.sample(m_pos, k=k))

    if not m_pos:
        # Patch absent or fully primed: nothing for the model to predict.
        print('Nothing to inpaint!')
        print('=' * 70)
        return inp_seq

    # Run the model under the bf16 autocast context defined at load time,
    # without building an autograd graph.
    with ctx, torch.no_grad():
        results = predict_masked_tokens(model, inp_seq, mask_positions=m_pos, topk=1)

    output_seq = results['predicted_ids']

    print('Done!')
    print('=' * 70)

    return output_seq
# =================================================================================================
def Inpaint_Pitches(input_midi,
                    input_patch,
                    input_inpaint_ratio,
                    input_num_prime_notes
                    ):
    """Gradio handler: inpaint the pitches of one patch in an uploaded MIDI.

    Pipeline: load_midi() -> inpaint_pitches() -> save_midi() -> audio
    rendering and score plot.

    Args:
        input_midi: uploaded MIDI file object (its ``.name`` is the path on
            disk), or None when nothing was uploaded.
        input_patch: patch number (0-128) whose pitches are re-predicted.
        input_inpaint_ratio: fraction of the eligible notes to re-predict.
        input_num_prime_notes: number of leading notes of that patch kept as
            context.

    Returns:
        ``(output_audio, output_plot, output_midi_path)``:

        * ``output_audio`` is ``(16000, audio)`` for gr.Audio.
        * ``output_plot`` is whatever ``TMIDIX.plot_ms_SONG(...,
          return_plt=True)`` returns, for gr.Plot.
        * ``output_midi_path`` is the path of the written ``.mid`` file.

        Returns ``(None, None, None)`` if there is no input or the output
        decodes to no notes.
    """
    if input_midi is None:
        return None, None, None

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()
    print('=' * 70)

    fn = os.path.basename(input_midi.name)

    print('Input file name:', fn)
    print('Input patch:', input_patch)
    print('Input inpaint ratio:', input_inpaint_ratio)
    print('Input number of prime notes:', input_num_prime_notes)
    print('=' * 70)

    print('Loading MIDI...')

    inp_seq = load_midi(input_midi)

    print('Composition has', len(inp_seq), 'tokens')
    print('Sample composition tokens:', inp_seq[:5])
    print('=' * 70)

    #===============================================================================

    output_seq = inpaint_pitches(inp_seq,
                                 input_patch,
                                 input_inpaint_ratio,
                                 input_num_prime_notes
                                 )

    #===============================================================================

    print('Saving MIDI...')
    print('=' * 70)

    output_fname, output_score = save_midi(output_seq)

    if output_fname is None:
        # save_midi() returns (None, None) when the sequence holds no notes
        # (e.g. an empty MIDI); stop here instead of failing on None + '.mid'.
        print('No notes to save!')
        print('=' * 70)
        return None, None, None

    output_midi_path = output_fname + '.mid'

    #===============================================================================

    print('Rendering results...')
    print('=' * 70)

    sample_rate = 16000

    audio = midi_to_colab_audio(output_midi_path,
                                soundfont_path=SOUNDFONT_PATH,
                                sample_rate=sample_rate,
                                output_for_gradio=True
                                )

    #========================================================

    output_audio = (sample_rate, audio)

    output_plot = TMIDIX.plot_ms_SONG(output_score,
                                      plot_title=os.path.basename(output_midi_path),
                                      return_plt=True
                                      )

    print('Done!')
    print('=' * 70)

    #========================================================

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_audio, output_plot, output_midi_path
# =================================================================================================
# US/Pacific timezone shared by every human-readable timestamp in the app
# (request logs and output MIDI file names).
PDT = timezone('US/Pacific')

# Log when the app process started.
print('=' * 70)
print(f'App start time: {datetime.datetime.now(PDT):%Y-%m-%d %H:%M:%S}')
print('=' * 70)
# Gradio UI: upload a MIDI, choose patch / prime notes / inpaint ratio, and get
# back rendered audio, a score plot and the output MIDI file.

# NOTE(review): `soundfont` is not referenced elsewhere in the visible code;
# audio rendering uses SOUNDFONT_PATH instead.
soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

app = gr.Blocks()

with app:

    gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Orpheus Masked Pitches Inpainter</h1>")
    gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Instantly inpaint pitches in any MIDI with Orpheus masked encoder</h1>")

    # Row of external link buttons: duplicate this Space, model, related Spaces, dataset.
    with gr.Row(elem_classes="duplicate-row"):
        gr.DuplicateButton(
            value="🤗 Duplicate 🤗",
            variant="huggingface",
            size="md",
            link="https://huggingface.co/spaces/projectlosangeles/Orpheus-Masked-Pitches-Inpainter?duplicate=true",
            link_target="_blank"
        )

        gr.Button(
            value="❤️ Models ❤️",
            variant="huggingface",
            size="md",
            link="https://huggingface.co/asigalov61/Orpheus-Music-Transformer",
            link_target="_blank"
        )

        gr.Button(
            value="🚀 Spaces 🚀",
            variant="huggingface",
            size="md",
            link="https://huggingface.co/collections/asigalov61/orpheus-music-transformer",
            link_target="_blank"
        )

        gr.Button(
            value="🦖 Dataset 🦖",
            variant="huggingface",
            size="md",
            link="https://huggingface.co/datasets/projectlosangeles/Godzilla-MIDI-Dataset",
            link_target="_blank"
        )

    gr.Markdown("## Upload your MIDI or select an example MIDI at the bottom of the page")

    input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
    # Patches 0-127 plus 128, which load_midi()/save_midi() use for drums (channel 9).
    input_patch = gr.Slider(0, 128, value=40, step=1, label="Patch number to inpaint")
    # Leading notes of the selected patch that stay unmasked as context.
    input_num_prime_notes = gr.Slider(0, 64, value=16, step=1, label="Number of prime notes")
    # Fraction of the remaining notes of the patch whose pitches are re-predicted.
    input_inpaint_ratio = gr.Slider(0.01, 1.0, value=1, step=0.01, label="Pitches inpaint ratio")

    run_btn = gr.Button("Inpaint Pitches", variant="primary")

    gr.Markdown("## Generation results")

    output_audio = gr.Audio(label="Output MIDI audio", format="mp3", elem_id="midi_audio")
    output_plot = gr.Plot(label="Output MIDI score plot")
    output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

    # The inputs list is passed positionally, so its order must match the
    # parameter order of Inpaint_Pitches; the outputs match its return tuple.
    run_event = run_btn.click(Inpaint_Pitches,
                              [input_midi,
                               input_patch,
                               input_inpaint_ratio,
                               input_num_prime_notes
                               ],
                              [output_audio,
                               output_plot,
                               output_midi
                               ])

    # Each example row is [MIDI file, patch, inpaint ratio, prime notes], in the
    # same order as the inputs list. NOTE(review): the example .mid files are
    # referenced by relative path and presumably live next to this app — confirm.
    gr.Examples(
        [["Gang Stop.mid", 40, 1, 16],
         ["Soli.mid", 40, 1, 16]
         ],
        [input_midi,
         input_patch,
         input_inpaint_ratio,
         input_num_prime_notes
         ],
        [output_audio,
         output_plot,
         output_midi
         ],
        Inpaint_Pitches
    )

# mcp_server=True additionally exposes the app's functions through Gradio's MCP server.
app.launch(mcp_server=True)