"""
Search for similar MIDIs with Orpheus embeddings
"""

print('=' * 70)
print("Orpheus MIDI Search Gradio App")
print('=' * 70)
print("Loading modules...")

import os

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import time as reqtime
import datetime
from pytz import timezone

import numpy as np

from sentence_transformers import SentenceTransformer, util

import matplotlib.pyplot as plt

import gradio as gr

from huggingface_hub import hf_hub_download

import TMIDIX

from midi_to_colab_audio import midi_to_colab_audio

PDT = timezone('US/Pacific')

MODEL_CHECKPOINT = 'all-mpnet-base-v2'
EMB_DATASETS_REPO = 'projectlosangeles/Orpheus-MIDI-Search'

MI_EMB_DATASET_FILES = ['168082_Orpheus_Song_Artist_Cap_Score_Sim_Dataset_CC_BY_NC_SA.pickle',
                        '168082_orpheus_song_artist_corpus_emb_all_mpnet_base_v2.npy'
                       ]
SP_EMB_DATASET_FILES = ['164598_Orpheus_Piano_Song_Artist_Cap_Score_Sim_Dataset_CC_BY_NC_SA.pickle',
                        '164598_orpheus_piano_song_artist_corpus_emb_all_mpnet_base_v2.npy'
                       ]

SOUNDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'

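# Which embeddings dataset to load at startup: 'Multi-Instrumental' selects the
# full dataset; any other value falls back to the solo piano dataset (see
# load_dataset() below).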
dataset = 'Multi-Instrumental'

print('=' * 70)
print("Done loading modules!")
print('=' * 70)

print('Loading Sentence Transformer model...')
print('=' * 70)
model = SentenceTransformer(MODEL_CHECKPOINT)
print('=' * 70)
print('Done!')
print('=' * 70)

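# Renders a decoded MIDI score into the three artifacts the UI needs: a
# (sample_rate, waveform) tuple for gr.Audio (synthesized through the bundled
# SoundFont at 16 kHz), a matplotlib piano-roll plot, and the .mid file name.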
def render_midi_output(final_composition, input_title=''):

    midi_score = save_midi(final_composition,
                           input_title=input_title
                          )

    midi_plot = TMIDIX.plot_ms_SONG(midi_score,
                                    plot_title=input_title,
                                    return_plt=True
                                   )

    midi_audio = midi_to_colab_audio(input_title + '.mid',
                                     soundfont_path=SOUNDFONT_PATH,
                                     sample_rate=16000,
                                     output_for_gradio=True
                                    )

    return (16000, midi_audio), midi_plot, input_title + '.mid'

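# Decodes a sequence of Orpheus score tokens into a TMIDIX ms_SONG score and
# writes it to disk via TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter. Token ranges,
# as implemented below:
#   0-255       delta start-time, in 16 ms steps
#   256-16767   patch * 128 + pitch (patch 128 is drums, forced to channel 9)
#   16768-18815 duration (16 ms steps) and velocity (eight 15-step bins); a note
#               event is emitted when this token is seen
# Example: the tokens [8, 256 + 60, 16768 + 24*8 + 5] decode to a single note
# 128 ms after the previous event, piano (patch 0) middle C, 384 ms long,
# velocity 90.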
def save_midi(tokens, input_title=''):

    time = 0
    dur = 1
    vel = 90
    pitch = 60
    channel = 0
    patch = 0

    patches = [-1] * 16

    channels = [0] * 16
    channels[9] = 1

    song_f = []

    for ss in tokens:

        if 0 <= ss < 256:

            time += ss * 16

        if 256 <= ss < 16768:

            patch = (ss-256) // 128

            if patch < 128:

                if patch not in patches:
                    if 0 in channels:
                        cha = channels.index(0)
                        channels[cha] = 1
                    else:
                        cha = 15

                    patches[cha] = patch
                    channel = patches.index(patch)
                else:
                    channel = patches.index(patch)

            if patch == 128:
                channel = 9

            pitch = (ss-256) % 128

        if 16768 <= ss < 18816:

            dur = ((ss-16768) // 8) * 16
            vel = (((ss-16768) % 8)+1) * 15

            song_f.append(['note', time, dur, channel, pitch, vel, patch])

    patches = [0 if x==-1 else x for x in patches]

    output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(song_f)

    TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,
                                             output_signature=input_title,
                                             output_file_name=input_title,
                                             track_name='Project Los Angeles',
                                             list_of_MIDI_patches=patches,
                                             verbose=False
                                            )

    return output_score

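# Downloads the requested search dataset from the Hugging Face Hub: a pickle of
# per-title records (song, artist, caption, score tokens, pre-computed similar
# titles) and the matching pre-computed all-mpnet-base-v2 embedding matrix
# (.npy). Returns the two local file paths.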
def load_dataset(dataset_name):

    print('=' * 70)
    print("Loading requested Orpheus MIDI search dataset...")
    print('=' * 70)

    if dataset_name == 'Multi-Instrumental':
        emb_dataset = hf_hub_download(repo_id=EMB_DATASETS_REPO,
                                      repo_type='dataset',
                                      filename=MI_EMB_DATASET_FILES[0]
                                     )

        emb_dataset_corpus = hf_hub_download(repo_id=EMB_DATASETS_REPO,
                                             repo_type='dataset',
                                             filename=MI_EMB_DATASET_FILES[1]
                                            )

    else:
        emb_dataset = hf_hub_download(repo_id=EMB_DATASETS_REPO,
                                      repo_type='dataset',
                                      filename=SP_EMB_DATASET_FILES[0]
                                     )

        emb_dataset_corpus = hf_hub_download(repo_id=EMB_DATASETS_REPO,
                                             repo_type='dataset',
                                             filename=SP_EMB_DATASET_FILES[1]
                                            )

    print('=' * 70)
    print("Done!")
    print('=' * 70)

    return emb_dataset, emb_dataset_corpus

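# Main search handler wired to the Search button. Returns a flat list of six
# query outputs followed by five outputs per match, in the exact order of the
# `outputs` component list assembled in the Gradio UI below. Assumes each
# dataset record stores exactly ten pre-computed similar titles.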
def MIDI_Search(title, artist):

    """
    Search the loaded Orpheus embeddings dataset for the closest matching
    title/artist entry and render it, along with its top 10 similar MIDIs,
    for the Gradio UI.
    """

    print('=' * 70)
    print("Request start time:", datetime.datetime.now(PDT).strftime("%Y-%m-%d %H:%M:%S"))
    start_time = reqtime.time()

    print('=' * 70)
    print('Req title:', title)
    print('Req artist:', artist)
    print('Req dataset:', dataset)
    print('=' * 70)

    if title and artist:
        input_title = title + ' --- ' + artist

    elif title:
        input_title = title

    elif artist:
        input_title = artist

    else:
        input_title = ''

    print('Searching for best matching title...')

    # Embed the query and pick the closest corpus entry by cosine similarity
    query_embedding = model.encode([input_title])

    similarities = util.cos_sim(query_embedding,
                                embeddings_dataset_corpus
                               )

    selected_title_index = np.argmax(similarities).tolist()

    selected_title = song_artist_list[selected_title_index]

    print('Done!')
    print('=' * 70)
    print('Selected title:', selected_title)
    print('Selected title index:', selected_title_index)
    print('=' * 70)

    print('Rendering selected title...')
    print('=' * 70)

    final_outputs = []

    song, artist, cap, score, matches = embeddings_dataset[selected_title_index]

    audio, plot, fname = render_midi_output(score, selected_title)

    top_ten_titles_list = [song_artist_list[i] for i, s in matches]

    top_ten_titles = ''

    for i, t in enumerate(top_ten_titles_list):
        top_ten_titles += str(i+1) + ') ' + t + '\n'

    final_outputs.extend([top_ten_titles,
                          selected_title,
                          cap,
                          audio,
                          plot,
                          fname
                         ])

    print('Done!')
    print('=' * 70)

    print('Rendering top 10 titles...')
    print('=' * 70)

    for idx, sim in matches:

        m_song, m_artist, m_cap, m_score, m_matches = embeddings_dataset[idx]
        m_title = song_artist_list[idx]

        m_audio, m_plot, m_fname = render_midi_output(m_score, m_title)

        final_outputs.extend([m_title,
                              m_cap,
                              m_audio,
                              m_plot,
                              m_fname
                             ])

    print('Done!')
    print('=' * 70)

    print(top_ten_titles)

    print("Request end time:", datetime.datetime.now(PDT).strftime("%Y-%m-%d %H:%M:%S"))
    print('=' * 70)

    end_time = reqtime.time()
    execution_time = end_time - start_time

    print(f"Request execution time: {execution_time} seconds")
    print('=' * 70)

    return final_outputs

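# Load the selected embeddings dataset once at startup: the pickled records, a
# "Song --- Artist" lookup list used for titles, and the corpus embedding matrix
# that queries are compared against with cosine similarity.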
print('=' * 70)
print('Prepping requested embeddings dataset...')

emb_dat, emb_dat_cor = load_dataset(dataset)

print('=' * 70)
print('Loading requested embeddings dataset...')

embeddings_dataset = TMIDIX.Tegridy_Any_Pickle_File_Reader(emb_dat,
                                                           verbose=False
                                                          )

song_artist_list = [d[0] + ' --- ' + d[1] for d in embeddings_dataset]

embeddings_dataset_corpus = np.load(emb_dat_cor)

print('Done!')
print('=' * 70)

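# Gradio UI: a single search form, a query MIDI preview section, and ten tabs
# for the pre-computed matches. The `outputs` list below must stay in sync with
# the order of values returned by MIDI_Search().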
with gr.Blocks() as demo:

    gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Orpheus MIDI Search</h1>")
    gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Search for similar MIDIs with Orpheus embeddings</h1>")

    gr.HTML("""
            Check out <a href="https://huggingface.co/datasets/projectlosangeles/Godzilla-MIDI-Dataset">Godzilla MIDI Dataset</a> on Hugging Face
            <p>
                <a href="https://huggingface.co/spaces/asigalov61/Orpheus-MIDI-Search?duplicate=true">
                    <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate this Space">
                </a>
            </p>
            """)

    gr.Markdown("# Enter any desired song title, artist, or both")

    title = gr.Textbox(label="Song Title", value="Family Guy")
    artist = gr.Textbox(label="Song Artist", value="TV Themes")

    search_btn = gr.Button(value='Search', variant="primary")
    gr.ClearButton(components=[title, artist])

    gr.Markdown("# Search results")

    gr.Markdown("## Top 10 matches summary")

    top_ten_matches = gr.Textbox(label="Top 10 matches", lines=11, max_lines=11)

    gr.Markdown("## Query MIDI preview")

    query_midi_title = gr.Textbox(label="Query MIDI title", lines=1, max_lines=1)
    query_midi_cap = gr.Textbox(label="Query MIDI caption", lines=7, max_lines=7)
    query_audio = gr.Audio(label="Query MIDI audio", format="wav", elem_id="midi_audio")
    query_plot = gr.Plot(label="Query MIDI score plot")
    query_midi = gr.File(label="Query MIDI file", file_types=[".mid"])

    outputs = [top_ten_matches,
               query_midi_title,
               query_midi_cap,
               query_audio,
               query_plot,
               query_midi
              ]

    gr.Markdown("## Top 10 matches previews")

    for i in range(10):
        with gr.Tab(f"Match # {i+1}"):
            title_output = gr.Textbox(label=f"Match # {i+1} MIDI title", lines=1, max_lines=1)
            cap_output = gr.Textbox(label=f"Match # {i+1} MIDI caption", lines=7, max_lines=7)
            audio_output = gr.Audio(label=f"Match # {i+1} MIDI audio", format="mp3")
            plot_output = gr.Plot(label=f"Match # {i+1} MIDI plot")
            midi_output = gr.File(label=f"Match # {i+1} MIDI file", file_types=[".mid"])

        outputs.extend([title_output,
                        cap_output,
                        audio_output,
                        plot_output,
                        midi_output
                       ])

    search_btn.click(
        MIDI_Search,
        [title,
         artist
        ],
        outputs
    )

demo.launch()