# NOTE(review): the three lines below were Hugging Face web-page scrape
# residue (author/commit banner), not Python — preserved as comments so
# the file parses:
# projectlosangeles's picture
# Update app.py
# cc4736b verified
#====================================================================
# https://huggingface.co/spaces/projectlosangeles/Orpheus-MIDI-Search
#====================================================================
"""
Search for similar MIDIs with Orpheus embeddings
"""
#====================================================================
print('=' * 70)
print("Orpheus MIDI Search Gradio App")
print('=' * 70)
print("Loading modules...")

import os

# Enable the hf_transfer accelerated download backend for Hub files
# (must be set before huggingface_hub performs any downloads).
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import time as reqtime
import datetime
from pytz import timezone

import numpy as np

from sentence_transformers import SentenceTransformer, util

import matplotlib.pyplot as plt

import gradio as gr

from huggingface_hub import hf_hub_download

import TMIDIX
from midi_to_colab_audio import midi_to_colab_audio

#====================================================================
# Timezone used only for request start/end timestamps in the logs.
PDT = timezone('US/Pacific')

# Sentence-transformer checkpoint used to embed title/artist queries;
# must match the model that produced the corpus .npy embeddings below.
MODEL_CHECKPOINT = 'all-mpnet-base-v2'

# Hugging Face dataset repo that hosts the pre-computed search datasets.
EMB_DATASETS_REPO = 'projectlosangeles/Orpheus-MIDI-Search'

# Multi-instrumental dataset files:
# [0] pickled (song/artist/caption/score/matches) records,
# [1] pre-computed corpus embeddings (numpy).
MI_EMB_DATASET_FILES = ['168082_Orpheus_Song_Artist_Cap_Score_Sim_Dataset_CC_BY_NC_SA.pickle',
                        '168082_orpheus_song_artist_corpus_emb_all_mpnet_base_v2.npy'
                        ]

# Solo-piano dataset files: same pair layout for the piano-only corpus.
SP_EMB_DATASET_FILES = ['164598_Orpheus_Piano_Song_Artist_Cap_Score_Sim_Dataset_CC_BY_NC_SA.pickle',
                        '164598_orpheus_piano_song_artist_corpus_emb_all_mpnet_base_v2.npy'
                        ]

# SoundFont used by midi_to_colab_audio to render MIDI into audio.
SOUNDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'

#====================================================================
dataset = 'Multi-Instrumental' # You can switch it to Piano one if you want
#====================================================================
print('=' * 70)
print("Done loading modules!")
print('=' * 70)
#====================================================================
print('Loading Sentence Transformer model...')
print('=' * 70)

# Loaded once at import time; reused by MIDI_Search for every query.
model = SentenceTransformer(MODEL_CHECKPOINT)

print('=' * 70)
print('Done!')
print('=' * 70)
#====================================================================
# -----------------------------
# RENDER FUNCTION
# -----------------------------
def render_midi_output(final_composition, input_title=''):
    """
    Turn a decoded token sequence into the Gradio output triple.

    Decodes/writes the MIDI via save_midi, plots the resulting score
    with TMIDIX, and renders 16 kHz audio from the written .mid file.

    Parameters:
        final_composition: sequence of Orpheus token ids.
        input_title (str): base name used for the MIDI file and plot title.

    Returns:
        tuple: ((16000, audio_array), matplotlib plot, path to .mid file).
    """
    midi_fname = input_title + '.mid'

    # Decode the tokens and write the MIDI file to disk.
    score = save_midi(final_composition,
                      input_title=input_title
                      )

    # Piano-roll style plot of the decoded score.
    plot = TMIDIX.plot_ms_SONG(score,
                               plot_title=input_title,
                               return_plt=True
                               )

    # Render audio from the just-written .mid file at 16 kHz.
    audio = midi_to_colab_audio(midi_fname,
                                soundfont_path=SOUNDFONT_PATH,
                                sample_rate=16000,
                                output_for_gradio=True
                                )

    return (16000, audio), plot, midi_fname
# -----------------------------
# SAVE MIDI FUNCTION
# -----------------------------
def save_midi(tokens, input_title=''):
    """
    Decode an Orpheus token sequence into note events and write a MIDI file.

    Token ranges (as decoded below — presumed Orpheus encoding, confirm
    against the tokenizer used to build the dataset):
        0-255       : delta-time token; advances time by token*16 units
        256-16767   : patch/pitch token; patch = (tok-256) // 128,
                      pitch = (tok-256) % 128 (patch 128 routes to drums)
        16768-18815 : duration/velocity token; emits the pending note

    Parameters:
        tokens: iterable of int token ids.
        input_title (str): used as MIDI output signature and file name.

    Returns:
        The enhanced score (list of note events) produced by
        TMIDIX.patch_enhanced_score_notes.
    """
    time = 0
    dur = 1
    vel = 90
    pitch = 60
    channel = 0
    patch = 0

    patches = [-1] * 16   # patch currently assigned to each channel (-1 = free)
    channels = [0] * 16   # 1 = channel already taken
    channels[9] = 1       # channel 9 reserved for drums (General MIDI)

    song_f = []

    for ss in tokens:

        if 0 <= ss < 256:
            # Time token: advance absolute start time (16-unit steps;
            # presumably milliseconds, given the ms_SONG naming — confirm).
            time += ss * 16

        if 256 <= ss < 16768:
            # Combined patch/pitch token.
            patch = (ss-256) // 128

            if patch < 128:
                if patch not in patches:
                    # New patch: allocate the first free channel, or
                    # overflow onto channel 15 when all are taken.
                    if 0 in channels:
                        cha = channels.index(0)
                        channels[cha] = 1
                    else:
                        cha = 15

                    patches[cha] = patch
                    channel = patches.index(patch)
                else:
                    # Patch already has a channel; reuse it.
                    channel = patches.index(patch)

            if patch == 128:
                # Patch 128 encodes percussion -> fixed MIDI channel 9.
                channel = 9

            pitch = (ss-256) % 128

        if 16768 <= ss < 18816:
            # Duration/velocity token: finalize and emit one note event
            # using the most recent time/channel/pitch/patch state.
            dur = ((ss-16768) // 8) * 16
            vel = (((ss-16768) % 8)+1) * 15

            song_f.append(['note', time, dur, channel, pitch, vel, patch])

    # Unassigned channels default to patch 0 (piano).
    patches = [0 if x==-1 else x for x in patches]

    output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(song_f)

    # Writes '<input_title>.mid' to the current working directory.
    TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,
                                             output_signature=input_title,
                                             output_file_name=input_title,
                                             track_name='Project Los Angeles',
                                             list_of_MIDI_patches=patches,
                                             verbose=False
                                             )

    return output_score
# -----------------------------
# DATASET LOADER FUNCTION
# -----------------------------
def load_dataset(dataset_name):
    """
    Download the requested embeddings dataset files from the HF Hub.

    Parameters:
        dataset_name (str): 'Multi-Instrumental' selects the full dataset;
            any other value selects the solo-piano dataset.

    Returns:
        tuple[str, str]: local paths to (pickled records file,
        numpy corpus-embeddings file), as cached by hf_hub_download.
    """
    print('=' * 70)
    print("Loading requested Orpheus MIDI search dataset...")
    print('=' * 70)

    # Select the file pair once instead of duplicating the download calls.
    if dataset_name == 'Multi-Instrumental':
        dataset_files = MI_EMB_DATASET_FILES
    else:
        dataset_files = SP_EMB_DATASET_FILES

    emb_dataset = hf_hub_download(repo_id=EMB_DATASETS_REPO,
                                  repo_type='dataset',
                                  filename=dataset_files[0]
                                  )
    emb_dataset_corpus = hf_hub_download(repo_id=EMB_DATASETS_REPO,
                                         repo_type='dataset',
                                         filename=dataset_files[1]
                                         )

    print('=' * 70)
    print("Done!")
    print('=' * 70)

    return emb_dataset, emb_dataset_corpus
# -----------------------------
# MAIN MIDI SEARCH FUNCTION
# -----------------------------
def MIDI_Search(title, artist):
    """
    Find the best-matching MIDI for a title/artist query and render it.

    Encodes the query with the sentence-transformer model, selects the
    most similar corpus entry by cosine similarity, renders the top match
    plus each of its stored nearest-neighbour matches, and packages
    everything in the order the Gradio outputs list expects.

    Parameters:
        title (str): song title entered by the user (may be empty).
        artist (str): song artist entered by the user (may be empty).

    Returns:
        list: [top-ten summary, query title, query caption, query audio,
               query plot, query MIDI file] followed by five entries
               (title, caption, audio, plot, file) per match.
    """
    print('=' * 70)
    print("Request start time:", datetime.datetime.now(PDT).strftime("%Y-%m-%d %H:%M:%S"))
    start_time = reqtime.time()
    print('=' * 70)
    print('Req title:', title)
    print('Req artist:', artist)
    print('Req dataset:', dataset)
    print('=' * 70)

    # Build the query string from whichever fields were supplied.
    # (Same outcomes as the original nested ifs, stated directly.)
    if title and artist:
        input_title = title + ' --- ' + artist
    elif title:
        input_title = title
    elif artist:
        input_title = artist
    else:
        input_title = ''

    print('Searching for best matching title...')

    query_embedding = model.encode([input_title])

    # Cosine similarity of the query against every corpus embedding.
    similarities = util.cos_sim(query_embedding,
                                embeddings_dataset_corpus
                                )

    selected_title_index = np.argmax(similarities).tolist()
    selected_title = song_artist_list[selected_title_index]

    print('Done!')
    print('=' * 70)
    print('Selected title:', selected_title)
    print('Selected title index:', selected_title_index)
    print('=' * 70)

    print('Rendering selected title...')
    print('=' * 70)

    final_outputs = []

    # FIX: use fresh names here — the original rebound the `title`/`artist`
    # parameters and later rebound `matches` inside the loop iterating it.
    song, song_artist, cap, score, matches = embeddings_dataset[selected_title_index]

    audio, plot, fname = render_midi_output(score, selected_title)

    top_ten_titles_list = [song_artist_list[i] for i, s in matches]

    top_ten_titles = ''
    for i, t in enumerate(top_ten_titles_list):
        top_ten_titles += str(i+1) + ') ' + t + '\n'

    final_outputs.extend([top_ten_titles,
                          selected_title,
                          cap,
                          audio,
                          plot,
                          fname
                          ])

    print('Done!')
    print('=' * 70)
    print('Rendering top 10 titles...')
    print('=' * 70)

    # Render each stored nearest-neighbour match of the selected title.
    for match_idx, match_sim in matches:
        m_song, m_artist, m_cap, m_score, m_matches = embeddings_dataset[match_idx]
        m_title = song_artist_list[match_idx]
        m_audio, m_plot, m_fname = render_midi_output(m_score, m_title)
        final_outputs.extend([m_title,
                              m_cap,
                              m_audio,
                              m_plot,
                              m_fname
                              ])

    print('Done!')
    print('=' * 70)
    print(top_ten_titles)
    print("Request end time:", datetime.datetime.now(PDT).strftime("%Y-%m-%d %H:%M:%S"))
    print('=' * 70)

    end_time = reqtime.time()
    execution_time = end_time - start_time
    print(f"Request execution time: {execution_time} seconds")
    print('=' * 70)

    return final_outputs
#====================================================================
print('=' * 70)
print('Prepping requested embeddings dataset...')

# Download the dataset pair selected by the module-level `dataset` flag.
emb_dat, emb_dat_cor = load_dataset(dataset)

print('=' * 70)
print('Loading requested embeddings dataset...')

# Records unpickled as (song, artist, caption, score_tokens, matches)
# tuples — see the unpacking in MIDI_Search.
embeddings_dataset = TMIDIX.Tegridy_Any_Pickle_File_Reader(emb_dat,
                                                           verbose=False
                                                           )

# Human-readable "song --- artist" labels, index-aligned with both the
# records above and the corpus embeddings below.
song_artist_list = [d[0] + ' --- ' + d[1] for d in embeddings_dataset]

# Pre-computed sentence-transformer embeddings for the whole corpus.
embeddings_dataset_corpus = np.load(emb_dat_cor)

print('Done!')
print('=' * 70)
#====================================================================
# -----------------------------
# GRADIO INTERFACE SETUP
# -----------------------------
# NOTE: component creation order defines the page layout AND the order of
# the `outputs` list, which must match the list returned by MIDI_Search.
with gr.Blocks() as demo:

    gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Orpheus MIDI Search</h1>")
    gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Search for similar MIDIs with Orpheus embeddings</h1>")
    gr.HTML("""
Check out <a href="https://huggingface.co/datasets/projectlosangeles/Godzilla-MIDI-Dataset">Godzilla MIDI Dataset</a> on Hugging Face
<p>
<a href="https://huggingface.co/spaces/asigalov61/Orpheus-MIDI-Search?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate in Hugging Face">
</a>
</p>
""")

    # --- query inputs ---
    gr.Markdown("# Enter any desired title, artist or both\n\n")

    title = gr.Textbox(label="Song Title", value="Family Guy")
    artist = gr.Textbox(label="Song Artist", value="TV Themes")

    search_btn = gr.Button(value='Search', variant="primary")
    gr.ClearButton(components=[title, artist])

    # --- query result components ---
    gr.Markdown("# Search results")
    gr.Markdown("## Top 10 matches summary")

    top_ten_matches = gr.Textbox(label="Top 10 matches", lines=11, max_lines=11)

    gr.Markdown("## Query MIDI preview")

    query_midi_title = gr.Textbox(label="Query MIDI title", lines=1, max_lines=1)
    query_midi_cap = gr.Textbox(label="Query MIDI caption", lines=7, max_lines=7)
    query_audio = gr.Audio(label="Query MIDI audio", format="wav", elem_id="midi_audio")
    query_plot = gr.Plot(label="Query MIDI score plot")
    query_midi = gr.File(label="Query MIDI file", file_types=[".mid"])

    # First six outputs: summary + query preview components.
    outputs = [top_ten_matches,
               query_midi_title,
               query_midi_cap,
               query_audio,
               query_plot,
               query_midi
               ]

    gr.Markdown("## Top 10 matches previews")

    # One tab of five components per match, extending `outputs` in the same
    # (title, caption, audio, plot, file) order that MIDI_Search emits.
    for i in range(10):
        with gr.Tab(f"Match # {i}"):
            title_output = gr.Textbox(label=f"Match # {i} MIDI title", lines=1, max_lines=1)
            cap_output = gr.Textbox(label=f"Match # {i} MIDI caption", lines=7, max_lines=7)
            audio_output = gr.Audio(label=f"Match # {i} MIDI audio", format="mp3")
            plot_output = gr.Plot(label=f"Match # {i} MIDI plot")
            midi_output = gr.File(label=f"Match # {i} MIDI file", file_types=[".mid"])

            outputs.extend([title_output,
                            cap_output,
                            audio_output,
                            plot_output,
                            midi_output
                            ])

    # Wire the search button: (title, artist) in, full outputs list back.
    search_btn.click(
        MIDI_Search,
        [title,
         artist
         ],
        outputs
    )

demo.launch()