# Hugging Face Spaces — status: Running on Zero
#=================================================================================
# https://huggingface.co/spaces/projectlosangeles/Orpheus-Masked-Pitches-Inpainter
#=================================================================================

# Startup banner
print('=' * 70)
print('Orpheus Masked Pitches Inpainter Gradio App')
print('=' * 70)

import os

# These environment flags are set before the remaining imports.
# NOTE(review): presumably so that huggingface_hub and the transformer module
# pick them up at import time — confirm before reordering imports.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ['USE_FLASH_ATTENTION'] = '1'

import time as reqtime  # aliased; used below to time each request

from pytz import timezone

import torch

# Process-wide PyTorch settings: float32 matmul precision, TF32, and the
# scaled-dot-product-attention backends that are allowed to be selected.
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_cudnn_sdp(True)

import spaces
import gradio as gr

# Star import: provides TransformerWrapper and Encoder used below.
# NOTE(review): predict_masked_tokens is not defined in this file and is
# presumably exported by this module as well — confirm.
from x_transformer_2_3_1 import *

import datetime
import random
import tqdm

from midi_to_colab_audio import midi_to_colab_audio
import TMIDIX

import matplotlib.pyplot as plt

from huggingface_hub import hf_hub_download
# =================================================================================================

# Directory into which save_midi() writes the generated MIDI files
OUTPUT_MIDIS_DIR = 'output_midis'

# =================================================================================================

print('=' * 70)
print('Loading models...')
print('=' * 70)

print('Loading Orpheus masked encoder model...')
print('=' * 70)

SEQ_LEN = 2048   # model max_seq_len; input sequences are truncated to this length
PAD_IDX = 18820  # vocabulary size is PAD_IDX+1
DEVICE = 'cuda'

# Encoder-type transformer built from the x_transformer_2_3_1 classes
model = TransformerWrapper(
    num_tokens = PAD_IDX+1,
    max_seq_len = SEQ_LEN,
    attn_layers = Encoder(dim = 2048,
                          depth = 12,
                          heads = 16,
                          rotary_pos_emb = True,
                          attn_flash = True
                          )
)

model.to(DEVICE)

print('=' * 70)
print('Loading model checkpoint...')

# Download (or reuse the cached copy of) the trained weights from the HF Hub
checkpoint = hf_hub_download(
    repo_id='asigalov61/Orpheus-Music-Transformer',
    filename='Orpheus_Music_Transformer_Masked_Encoder_Trained_Model_23000_steps_0.6548_loss_0.8132_acc.pth'
)

# weights_only=True restricts unpickling to tensors/primitive containers
model.load_state_dict(torch.load(checkpoint, map_location=DEVICE, weights_only=True))

model.eval()

# model = torch.compile(model)

print('=' * 70)
print('Done!')
print('=' * 70)

# =================================================================================================

# bfloat16 autocast context.
# NOTE(review): `ctx` is not referenced anywhere else in the visible code.
dtype = torch.bfloat16
ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)

print('Done!')
print('=' * 70)

# =================================================================================================

print('Loading SoundFont...')

# SoundFont used by Inpaint_Pitches() to render the output MIDI to audio
SOUNDFONT_PATH = hf_hub_download(repo_id='projectlosangeles/soundfonts4u',
                                 repo_type='dataset',
                                 filename='SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
                                 )

print('Done!')
print('=' * 70)

# =================================================================================================
def load_midi(input_midi):
    """Read a MIDI file and encode it as a flat list of integer tokens.

    Token layout produced here:
      * 18816                      -- start-of-sequence token (always first)
      * delta time of each chord   -- appended as-is, once per chord
      * (128 * patch) + pitch + 256          -- one per note
      * (8 * duration) + velocity + 16768    -- one per note, right after
                                                its patch/pitch token

    Args:
        input_midi: object whose ``.name`` attribute is the MIDI file path.

    Returns:
        The token list, or ``[18816]`` when the processed score has no notes.
    """

    score = TMIDIX.midi2single_track_ms_score(
        input_midi.name,
        do_not_check_MIDI_signature=True
    )

    processed = TMIDIX.advanced_score_processor(
        score,
        return_enhanced_score_notes=True,
        apply_sustain=True
    )

    # Nothing usable in the file: return just the start token
    if not processed or not processed[0]:
        return [18816]

    notes = TMIDIX.augment_enhanced_score_notes(processed[0], sort_drums_last=True)
    notes = TMIDIX.remove_duplicate_pitches_from_escore_notes(notes)
    notes = TMIDIX.fix_escore_notes_durations(notes, min_notes_gap=0)

    delta_notes = TMIDIX.delta_score_notes(notes)

    # Drop the leading field of every note before grouping notes into chords
    chords = TMIDIX.chordify_score([note[1:] for note in delta_notes])

    tokens = [18816]

    #=======================================================
    # MAIN PROCESSING CYCLE
    #=======================================================

    for chord in chords:

        # One delta-time token per chord, taken from its first note
        tokens.append(chord[0][0])

        for note in chord:

            # Clamp every field to its token range
            duration = max(1, min(255, note[1]))
            patch = max(0, min(128, note[5]))
            pitch = max(1, min(127, note[3]))

            # Octo-velocity: 8..127 is quantized to 0..7
            clamped_vel = max(8, min(127, note[4]))
            octo_vel = round(clamped_vel / 15) - 1

            patch_pitch_tok = (128 * patch) + pitch + 256
            dur_vel_tok = (8 * duration) + octo_vel + 16768

            tokens.extend([patch_pitch_tok, dur_vel_tok])

    return tokens
| # ================================================================================================= | |
def save_midi(tokens):
    """Decode a token sequence into a score and write it out via TMIDIX.

    Token ranges handled:
      * 0..255        -- advance the running time by ``token * 16``
      * 256..16767    -- set the current patch and pitch (patch 128 always
                         goes to channel 9; other patches get a channel
                         assigned on first use)
      * 16768..18815  -- set duration and velocity and emit one note
    Tokens outside these ranges are ignored.

    Args:
        tokens: iterable of integer tokens.

    Returns:
        ``(output_fname, output_score)`` where ``output_fname`` is the output
        path *without* extension inside ``OUTPUT_MIDIS_DIR``, or
        ``(None, None)`` when no notes were decoded.
    """

    abs_time = 0
    duration = 1
    velocity = 90
    pitch = 60
    channel = 0
    patch = 0

    # patches[ch] holds the patch assigned to channel ch (-1 = unassigned);
    # channel_taken[ch] is 1 once the channel is in use. Channel 9 is
    # reserved up front.
    patches = [-1] * 16
    channel_taken = [0] * 16
    channel_taken[9] = 1

    notes = []

    for tok in tokens:

        if 0 <= tok < 256:
            abs_time += tok * 16

        elif 256 <= tok < 16768:
            patch = (tok - 256) // 128

            if patch < 128:
                if patch not in patches:
                    # First free channel, or fall back to channel 15 when
                    # every channel is already taken
                    slot = channel_taken.index(0) if 0 in channel_taken else 15
                    channel_taken[slot] = 1
                    patches[slot] = patch

                channel = patches.index(patch)

            elif patch == 128:
                channel = 9

            pitch = (tok - 256) % 128

        elif 16768 <= tok < 18816:
            duration = ((tok - 16768) // 8) * 16
            velocity = (((tok - 16768) % 8) + 1) * 15

            notes.append(['note', abs_time, duration, channel, pitch, velocity, patch])

    # No notes decoded: nothing to write
    if not notes:
        return None, None

    notes = TMIDIX.remove_duplicate_pitches_from_escore_notes(notes)
    notes = TMIDIX.fix_escore_notes_durations(notes, min_notes_gap=0)

    output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(notes)

    # Timestamped file name; PDT is a module-level timezone object
    stamp = datetime.datetime.now(PDT)
    ms4 = stamp.strftime("%f")[:4]  # first four digits of microseconds

    fname = (
        "Orpheus-Masked-Pitches-Inpainter-Composition-"
        + stamp.strftime(f"%Y-%m-%d-%H-%M-%S-{ms4}")
    )

    os.makedirs(OUTPUT_MIDIS_DIR, exist_ok=True)
    output_fname = os.path.join(OUTPUT_MIDIS_DIR, fname)

    TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
        output_score,
        output_signature='Orpheus Masked Pitches Inpainter',
        output_file_name=output_fname,
        track_name='Project Los Angeles',
        list_of_MIDI_patches=patches,
        verbose=False
    )

    return output_fname, output_score
| # ================================================================================================= | |
def inpaint_pitches(inp_seq,
                    input_patch,
                    input_inpaint_ratio,
                    input_num_prime_notes
                    ):
    """Mask the patch/pitch tokens of one patch and let the model re-predict them.

    Args:
        inp_seq: token list as produced by ``load_midi``; it is truncated to
            ``SEQ_LEN`` tokens.
        input_patch: patch number whose patch/pitch tokens are inpainted.
        input_inpaint_ratio: fraction of the candidate positions to mask;
            values below 1 select a random subset of that size.
        input_num_prime_notes: number of leading matching notes that are left
            untouched.

    Returns:
        ``results['predicted_ids']`` as returned by ``predict_masked_tokens``.
        NOTE(review): ``predict_masked_tokens`` is not defined in this file —
        presumably it comes from the ``x_transformer_2_3_1`` star import.
    """

    print('*' * 70)
    print('Inpainting pitches...')

    # UI sliders may hand over whole numbers as floats; the token-range math
    # and the slice below need real ints.
    input_patch = int(input_patch)
    input_num_prime_notes = int(input_num_prime_notes)

    inp_seq = inp_seq[:SEQ_LEN]

    # Exclusive bounds of the patch/pitch token range for the chosen patch
    low = (128 * input_patch) + 256
    high = (128 * (input_patch + 1)) + 256

    # Walk the sequence itself rather than range(SEQ_LEN): compositions
    # shorter than SEQ_LEN tokens would otherwise raise IndexError.
    m_pos = [pos for pos, tok in enumerate(inp_seq) if low < tok < high]

    # Keep the first notes of the patch as an unmasked prime
    m_pos = m_pos[input_num_prime_notes:]

    if input_inpaint_ratio < 1:
        m_pos = sorted(random.sample(m_pos, k=int(round(len(m_pos) * input_inpaint_ratio))))

    results = predict_masked_tokens(model, inp_seq, mask_positions=m_pos, topk=1)

    output_seq = results['predicted_ids']

    print('Done!')
    print('=' * 70)

    return output_seq
| # ================================================================================================= | |
def Inpaint_Pitches(input_midi,
                    input_patch,
                    input_inpaint_ratio,
                    input_num_prime_notes
                    ):
    """Gradio handler: load a MIDI, inpaint pitches, then save, render and plot it.

    Args:
        input_midi: uploaded file object exposing ``.name`` (the file path),
            or ``None`` when nothing was uploaded.
        input_patch: patch number whose pitches are inpainted.
        input_inpaint_ratio: fraction of candidate pitches to inpaint.
        input_num_prime_notes: number of leading notes left untouched.

    Returns:
        ``(output_audio, output_plot, output_midi_path)`` where
        ``output_audio`` is ``(16000, samples)``. Returns
        ``(None, None, None)`` when no input file was given or when the
        result contains no notes to save.
    """

    if input_midi is None:
        return None, None, None

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()
    print('=' * 70)

    fn = os.path.basename(input_midi.name)

    print('Input file name:', fn)
    print('Input patch:', input_patch)
    print('Input inpaint ratio:', input_inpaint_ratio)
    print('Input number of prime notes:', input_num_prime_notes)
    print('=' * 70)

    print('Loading MIDI...')

    inp_seq = load_midi(input_midi)

    print('Composition has', len(inp_seq), 'tokens')
    print('Sample composition tokens:', inp_seq[:5])
    print('=' * 70)

    #===============================================================================

    output_seq = inpaint_pitches(inp_seq,
                                 input_patch,
                                 input_inpaint_ratio,
                                 input_num_prime_notes
                                 )

    #===============================================================================

    print('Saving MIDI...')
    print('=' * 70)

    output_fname, output_score = save_midi(output_seq)

    # save_midi returns (None, None) when no notes were decoded; without this
    # guard the path concatenation below would raise TypeError.
    if output_fname is None:
        print('No notes to save!')
        print('=' * 70)
        return None, None, None

    output_midi_path = output_fname + '.mid'

    #===============================================================================

    print('Rendering results...')
    print('=' * 70)

    audio = midi_to_colab_audio(output_midi_path,
                                soundfont_path=SOUNDFONT_PATH,
                                sample_rate=16000,
                                output_for_gradio=True
                                )

    #========================================================

    output_audio = (16000, audio)

    output_plot = TMIDIX.plot_ms_SONG(output_score,
                                      plot_title=os.path.basename(output_fname)+'.mid',
                                      return_plt=True
                                      )

    print('Done!')
    print('=' * 70)

    #========================================================

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_audio, output_plot, output_midi_path
| # ================================================================================================= | |
# Timezone used for all request/startup timestamps and output file names
PDT = timezone('US/Pacific')

print('=' * 70)
print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
print('=' * 70)

# NOTE(review): `soundfont` is not referenced anywhere else in the visible
# code; rendering uses SOUNDFONT_PATH instead.
soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

app = gr.Blocks()

# Components are laid out in the order they are created inside this context
with app:

    gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Orpheus Masked Pitches Inpainter</h1>")
    gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Instantly inpaint pitches in any MIDI with Orpheus masked encoder</h1>")

    # Row of link buttons (duplicate space, models, spaces, dataset)
    with gr.Row(elem_classes="duplicate-row"):
        gr.DuplicateButton(
            value="🤗 Duplicate 🤗",
            variant="huggingface",
            size="md",
            link="https://huggingface.co/spaces/projectlosangeles/Orpheus-Masked-Pitches-Inpainter?duplicate=true",
            link_target="_blank"
        )
        gr.Button(
            value="❤️ Models ❤️",
            variant="huggingface",
            size="md",
            link="https://huggingface.co/asigalov61/Orpheus-Music-Transformer",
            link_target="_blank"
        )
        gr.Button(
            value="🚀 Spaces 🚀",
            variant="huggingface",
            size="md",
            link="https://huggingface.co/collections/asigalov61/orpheus-music-transformer",
            link_target="_blank"
        )
        gr.Button(
            value="🦖 Dataset 🦖",
            variant="huggingface",
            size="md",
            link="https://huggingface.co/datasets/projectlosangeles/Godzilla-MIDI-Dataset",
            link_target="_blank"
        )

    # ---- Inputs ----
    gr.Markdown("## Upload your MIDI or select an example MIDI at the bottom of the page")

    input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])

    input_patch = gr.Slider(0, 128, value=40, step=1, label="Patch number to inpaint")
    input_num_prime_notes = gr.Slider(0, 64, value=16, step=1, label="Number of prime notes")
    input_inpaint_ratio = gr.Slider(0.01, 1.0, value=1, step=0.01, label="Pitches inpaint ratio")

    run_btn = gr.Button("Inpaint Pitches", variant="primary")

    # ---- Outputs ----
    gr.Markdown("## Generation results")

    output_audio = gr.Audio(label="Output MIDI audio", format="mp3", elem_id="midi_audio")
    output_plot = gr.Plot(label="Output MIDI score plot")
    output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

    # The input list order must match the Inpaint_Pitches parameter order:
    # (input_midi, input_patch, input_inpaint_ratio, input_num_prime_notes)
    run_event = run_btn.click(Inpaint_Pitches,
                              [input_midi,
                               input_patch,
                               input_inpaint_ratio,
                               input_num_prime_notes
                               ],
                              [output_audio,
                               output_plot,
                               output_midi
                               ])

    # Example rows use the same column order as the inputs list below:
    # file, patch, inpaint ratio, number of prime notes
    gr.Examples(
        [["Gang Stop.mid", 40, 1, 16],
         ["Soli.mid", 40, 1, 16]
         ],
        [input_midi,
         input_patch,
         input_inpaint_ratio,
         input_num_prime_notes
         ],
        [output_audio,
         output_plot,
         output_midi
         ],
        Inpaint_Pitches
    )

# Start the web app; mcp_server=True is passed through to Gradio's launch()
app.launch(mcp_server=True)