from pathlib import Path
import yaml
import time
import uuid
import numpy as np
import audiotools as at
import argbind
import shutil
import torch
from datetime import datetime
import gradio as gr
import spaces

from vampnet.interface import Interface, signal_concat
from vampnet import mask as pmask
from ytmusicapi import YTMusic

# from pyharp import AudioLabel, LabelList
from bytecover.models.train_module import TrainModule
from bytecover.utils import initialize_logging, load_config
import pinecone
import laion_clap
from tqdm import tqdm
import os


### INIT BYTECOVER
# Module-level setup: connect to the two Pinecone indexes (CLAP = timbral,
# ByteCover = melodic), load the ByteCover model and create a YouTube Music
# client used to build result links.

print(f"Is CUDA available: {torch.cuda.is_available()}")
# get_device_name/current_device raise on CPU-only machines, so only report
# the device name when CUDA is actually present.
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Index hosts come from the environment so they are not hard-coded here.
index_clap = pinecone.Index(os.environ["PC_API_KEY"], host=os.environ["CLAP_INDEX"])
index_bytecover = pinecone.Index(os.environ["PC_API_KEY"], host=os.environ["BC_INDEX"])

print("Loading ByteCover model")
_bytecover_config_path = (
    "bytecover/config_gpu.yaml" if torch.cuda.is_available() else "bytecover/config.yaml"
)
bytecover_config = load_config(config_path=_bytecover_config_path)

bytecover_module = TrainModule(bytecover_config)
bytecover_model = bytecover_module.model

# Checkpoints are loaded onto the CPU first (load_state_dict then copies them
# into the model's own parameters), so a GPU-saved checkpoint also loads on a
# CPU-only host. Both branches use the same map_location for consistency.
if bytecover_module.best_model_path is not None:
    bytecover_model.load_state_dict(
        torch.load(bytecover_module.best_model_path, map_location="cpu"), strict=False
    )
    print(f"Best model loaded from checkpoint: {bytecover_module.best_model_path}")
elif bytecover_module.config["test"]["model_ckpt"] is not None:
    bytecover_model.load_state_dict(
        torch.load(bytecover_module.config["test"]["model_ckpt"], map_location="cpu"),
        strict=False,
    )
    print(f'Model loaded from checkpoint: {bytecover_module.config["test"]["model_ckpt"]}')
elif bytecover_module.state == "initializing":
    print("Warning: Running with random weights")

bytecover_model.eval()

ytm = YTMusic()
print("Loading CLAP model")
if torch.cuda.is_available():
    clap_model = laion_clap.CLAP_Module(enable_fusion=False, device="cuda:0")
else:
    clap_model = laion_clap.CLAP_Module(enable_fusion=False)
clap_model.load_ckpt()  # download the default pretrained checkpoint.
print("Models loaded!")


def convert_to_npfloat64(original_array):
    """Return ``original_array`` as a float64 NumPy array."""
    return np.array(original_array, dtype=np.float64)


def convert_to_npfloat64_to_list(vector_embed_64):
    """Return the elements of a 1-D array as a plain Python list."""
    return list(vector_embed_64)


def flatten_vector_embed(vector_embed):
    """Flatten a tensor/array of any shape into a list of its elements."""
    return list(vector_embed.flatten())


def format_time(num_seconds):
    """Format a non-negative number of seconds as ``M:SS``.

    Fractional seconds are truncated. This lets float chunk offsets such as
    ``7.5`` render as ``0:07`` instead of failing on the ``d`` format spec.
    """
    total = int(num_seconds)
    return f"{total // 60}:{total % 60:02d}"


def chunk_audio(chunk_size, sig, sr):
    """Split ``sig`` along dim 1 into chunks of ``chunk_size`` seconds.

    Args:
        chunk_size: chunk length in seconds.
        sig: a tensor whose dim 1 is the sample axis.
        sr: sample rate of ``sig`` in Hz.

    Returns:
        A tuple of tensors, each exactly ``int(chunk_size * sr)`` samples long.
        A shorter trailing remainder is dropped, so an input shorter than one
        chunk yields an empty tuple.
    """
    chunk_samples = int(chunk_size * sr)
    print(f"Chunk samples: {chunk_samples}")
    print(f"Shape of audio: {sig.shape}")
    split_points = list(range(chunk_samples, sig.shape[1], chunk_samples))
    chunks = torch.tensor_split(sig, split_points, dim=1)
    if chunks[-1].shape[1] < chunk_samples:
        print("Cutting last chunk due to length")
        chunks = chunks[:-1]
    print(f"Number of chunks: {len(chunks)}")
    return chunks


def _embedding_to_list(embedding):
    """Turn a model output tensor into the flat float list Pinecone queries take."""
    return convert_to_npfloat64_to_list(
        convert_to_npfloat64(flatten_vector_embed(embedding.detach().cpu()))
    )


def _bytecover_embeddings(chunks):
    """Embed each chunk with ByteCover (its ``f_t`` output); one list per chunk."""
    device_name = bytecover_module.config["device"]
    embeddings = []
    with torch.no_grad():
        for chunk in tqdm(chunks):
            output = bytecover_model.forward(chunk.to(device_name))["f_t"]
            embeddings.append(_embedding_to_list(output))
    return embeddings


def _clap_embeddings(chunks):
    """Embed each chunk with CLAP; one flat float list per chunk."""
    # NOTE(review): laion_clap's examples feed 48 kHz audio; chunks here are at
    # the signal's own sample rate without resampling -- confirm this is intended.
    embeddings = []
    with torch.no_grad():
        for chunk in tqdm(chunks):
            output = clap_model.get_audio_embedding_from_data(chunk, use_tensor=True)
            embeddings.append(_embedding_to_list(output))
    return embeddings


def _query_index(pinecone_index, embeddings, chunk_s):
    """Query ``pinecone_index`` once per chunk embedding.

    Returns a list of ``[score, chunk_start_seconds, metadata]`` entries
    (top 10 neighbours per chunk, in query order).
    """
    candidates = []
    for i, embedding in enumerate(embeddings):
        print(f"Getting match {i + 1} of {len(embeddings)}")
        matches = pinecone_index.query(
            vector=embedding,
            top_k=10,
            include_metadata=True,
        )["matches"]
        for match in matches:
            candidates.append([match["score"], i * chunk_s, match["metadata"]])
    print("Matches obtained!")
    return candidates


def _youtube_music_link(song_title, song_artists, clip_num):
    """Link to the best YouTube Music hit for the song, or a search page if none."""
    from urllib.parse import quote_plus

    query = f"{song_title} {song_artists}"
    results = ytm.search(query, filter="songs", limit=1)
    video_id = results[0].get("videoId") if results else None
    if not video_id:
        return f"https://music.youtube.com/search?q={quote_plus(query)}"
    # assumes indexed clips are 10 s long (clip_num * 10 = offset) -- TODO confirm
    return f"https://music.youtube.com/watch?v={video_id}&t={int(clip_num) * 10}"


def _format_matches(candidates, max_tracks):
    """Render up to ``max_tracks`` distinct tracks (by ``spotify_id``) as Markdown.

    ``candidates`` must already be sorted best-first.
    """
    parts = []
    seen_tracks = set()
    for score, match_time, metadata in candidates:
        if len(seen_tracks) >= max_tracks:
            break
        track_id = metadata["spotify_id"]
        if track_id in seen_tracks:
            continue
        seen_tracks.add(track_id)

        song_title = metadata["song"]
        song_artists = metadata["artists"]
        if isinstance(song_artists, list):
            song_artists = ", ".join(song_artists)
        song_genre = metadata["genre"]
        song_link = _youtube_music_link(song_title, song_artists, metadata["clip_num"])

        parts.append(
            f"{format_time(match_time)}: \n [{song_title} by {song_artists}]({song_link}) "
            f"\n Genre: {song_genre} \n Similarity: {score}\n\n"
        )
    return "".join(parts)


def bytecover(sig, bytecover_match_ct=3, clap_match_ct=3, chunk_size=None, include_melodic=False):
    """Find catalogue tracks resembling ``sig`` and describe them as Markdown.

    The mono mix of ``sig`` is cut into fixed-length chunks. Each chunk is
    embedded and queried against a Pinecone index, and the best-scoring
    distinct tracks are listed with a YouTube Music link.

    Args:
        sig: an ``audiotools.AudioSignal`` to analyse.
        bytecover_match_ct: max distinct tracks listed under "Melodic Matches".
        clap_match_ct: max distinct tracks listed under "Timbral Matches".
        chunk_size: chunk length in seconds for both models. When None,
            ByteCover uses 10 s chunks and CLAP uses 3 s chunks.
        include_melodic: when False (the default, matching the behaviour this
            function always had), ByteCover matching is skipped entirely and
            only CLAP (timbral) matches are reported.

    Returns:
        A Markdown string, one section per enabled similarity type.
    """
    sig_mono = sig.copy().to_mono().audio_data.squeeze(1)
    if chunk_size is None:
        bc_chunk_s, clap_chunk_s = 10, 3
    else:
        bc_chunk_s = clap_chunk_s = chunk_size

    # (header, embeddings, index, max tracks, chunk length) per similarity type.
    sections = []
    if include_melodic:
        print("Getting Bytecover embeddings")
        bc_chunks = chunk_audio(bc_chunk_s, sig_mono, sig.sample_rate)
        sections.append(
            ("# Melodic Matches\n", _bytecover_embeddings(bc_chunks),
             index_bytecover, bytecover_match_ct, bc_chunk_s)
        )
    print("Getting CLAP embeddings")
    clap_chunks = chunk_audio(clap_chunk_s, sig_mono, sig.sample_rate)
    sections.append(
        ("# Timbral Matches\n", _clap_embeddings(clap_chunks),
         index_clap, clap_match_ct, clap_chunk_s)
    )

    output_md = ""
    for header, embeddings, pinecone_index, max_tracks, chunk_s in sections:
        output_md += header
        candidates = _query_index(pinecone_index, embeddings, chunk_s)
        candidates.sort(key=lambda item: item[0], reverse=True)
        output_md += _format_matches(candidates, max_tracks)
    return output_md

### END BYTECOVER


device = "cuda" if torch.cuda.is_available() else "cpu"

interface = Interface.default()
with open("DEFAULT_MODEL") as _default_model_file:
    init_model_choice = _default_model_file.read().strip()
# load the init model
interface.load_finetuned(init_model_choice)


def to_output(sig):
    """Convert an AudioSignal to gradio's ``(sample_rate, samples)`` tuple.

    Only the first batch item's first channel is returned.
    """
    return sig.sample_rate, sig.cpu().detach().numpy()[0][0]


MAX_DURATION_S = 10


def _normalize_samples(samples):
    """Scale integer PCM samples to [-1, 1]; float arrays are returned unchanged."""
    # np.iinfo only accepts integer dtypes, so float input must bypass it.
    if np.issubdtype(samples.dtype, np.integer):
        return samples / np.iinfo(samples.dtype).max
    return samples


def load_audio(file):
    """Load user audio for the input widget.

    Accepts a path, a file object with ``.name``, or a gradio
    ``(sample_rate, samples)`` tuple. Files are reduced to a salient excerpt
    of at most ``MAX_DURATION_S`` seconds. Returns None when ``file`` is None
    (e.g. the upload widget was cleared).
    """
    print(file)
    if file is None:
        return None
    if isinstance(file, tuple):
        # not a file
        sr, samples = file
        return sr, _normalize_samples(samples)
    filepath = file if isinstance(file, str) else file.name
    sig = at.AudioSignal.salient_excerpt(filepath, duration=MAX_DURATION_S)
    return to_output(sig)


def load_example_audio():
    """Load the bundled example clip."""
    return load_audio("./assets/example.wav")


from torch_pitch_shift import pitch_shift, get_fast_shifts


def shift_pitch(signal, interval: int):
    """Pitch-shift ``signal`` in place by ``interval`` semitones and return it."""
    signal.samples = pitch_shift(
        signal.samples, shift=interval, sample_rate=signal.sample_rate
    )
    return signal


def mask_preview(periodic_p, n_mask_codebooks, onset_mask_width, dropout):
    """Render a preview of the mask for the given settings; returns the PNG path.

    ``onset_mask_width`` is accepted for the UI wiring but not used here.
    """
    import matplotlib.pyplot as plt

    # A dummy code grid is enough to visualise the mask pattern.
    codes = torch.zeros((1, 14, 80)).to(device)
    mask = interface.build_mask(
        codes,
        periodic_prompt=periodic_p,
        _dropout=dropout,
        upper_codebook_mask=n_mask_codebooks,
    )
    plt.clf()
    interface.visualize_codes(mask)
    plt.title("mask preview")
    # savefig does not create missing directories.
    os.makedirs("scratch", exist_ok=True)
    plt.savefig("scratch/mask-prev.png")
    return "scratch/mask-prev.png"


@spaces.GPU
def _vamp_internal(
    seed,
    input_audio,
    model_choice,
    pitch_shift_amt,
    periodic_p,
    n_mask_codebooks,
    onset_mask_width,
    dropout,
    sampletemp,
    typical_filtering,
    typical_mass,
    typical_min_tokens,
    top_p,
    sample_cutoff,
    stretch_factor,
    sampling_steps,
    beat_mask_ms,
    num_feedback_steps,
    api=False,
):
    """Generate a VampNet variation of ``input_audio`` and find similar tracks.

    Args:
        seed: RNG seed; values <= 0 pick a random seed.
        input_audio: gradio ``(sample_rate, samples)`` tuple.
        (remaining args mirror the UI controls)

    Returns:
        ``((sample_rate, samples), markdown)`` -- the generated audio and the
        similarity report produced by ``bytecover``.

    Raises:
        gr.Error: if no input audio was provided.
    """
    t0 = time.time()
    interface.to("cuda" if torch.cuda.is_available() else "cpu")
    print(f"using device {interface.device}")

    _seed = seed if seed > 0 else int(torch.randint(0, 2**32, (1,)).item())
    at.util.seed(_seed)

    if input_audio is None:
        raise gr.Error("please upload an audio file")

    sr, samples = input_audio
    samples = _normalize_samples(samples)
    if samples.ndim == 2:
        # gradio numpy audio is (samples, channels); AudioSignal reads a 2-D
        # array as (channels, samples), so transpose multichannel input.
        samples = np.ascontiguousarray(samples.T)
    sig = at.AudioSignal(samples, sr)

    # reload the model if necessary
    interface.load_finetuned(model_choice)

    if pitch_shift_amt != 0:
        sig = shift_pitch(sig, pitch_shift_amt)

    codes = interface.encode(sig)
    mask = interface.build_mask(
        codes,
        sig,
        rand_mask_intensity=1.0,
        prefix_s=0.0,
        suffix_s=0.0,
        periodic_prompt=int(periodic_p),
        periodic_prompt_width=1,
        onset_mask_width=onset_mask_width,
        _dropout=dropout,
        upper_codebook_mask=int(n_mask_codebooks),
    )

    # NOTE(review): top_p, sample_cutoff, sampling_steps, beat_mask_ms and
    # num_feedback_steps arrive from the UI but are not used below; the values
    # passed to interface.vamp are fixed. Confirm this is intentional.
    interface.set_chunk_size(10.0)
    codes, mask = interface.vamp(
        codes,
        mask,
        batch_size=1,
        feedback_steps=1,
        _sampling_steps=12 if sig.duration < 6.0 else 24,
        time_stretch_factor=stretch_factor,
        return_mask=True,
        temperature=sampletemp,
        typical_filtering=typical_filtering,
        typical_mass=typical_mass,
        typical_min_tokens=typical_min_tokens,
        top_p=None,
        seed=_seed,
        sample_cutoff=1.0,
    )

    print(f"vamp took {time.time() - t0} seconds")

    sig = interface.decode(codes)

    # Keyword arguments: positional order differs from bytecover's signature.
    labels = bytecover(sig, bytecover_match_ct=3, clap_match_ct=3, chunk_size=3.0)

    return to_output(sig), labels


def vamp(input_audio, sampletemp, top_p, periodic_p,
         dropout, stretch_factor, onset_mask_width,
         typical_filtering, typical_mass, typical_min_tokens,
         seed, model_choice, n_mask_codebooks,
         pitch_shift_amt, sample_cutoff, sampling_steps,
         beat_mask_ms, num_feedback_steps):
    """UI entry point: forward the widget values (in ``_inputs`` order) to ``_vamp_internal``."""
    return _vamp_internal(
        seed=seed,
        input_audio=input_audio,
        model_choice=model_choice,
        pitch_shift_amt=pitch_shift_amt,
        periodic_p=periodic_p,
        n_mask_codebooks=n_mask_codebooks,
        onset_mask_width=onset_mask_width,
        dropout=dropout,
        sampletemp=sampletemp,
        typical_filtering=typical_filtering,
        typical_mass=typical_mass,
        typical_min_tokens=typical_min_tokens,
        top_p=top_p,
        sample_cutoff=sample_cutoff,
        stretch_factor=stretch_factor,
        sampling_steps=sampling_steps,
        beat_mask_ms=beat_mask_ms,
        num_feedback_steps=num_feedback_steps,
        api=False,
    )


# preset name -> (periodic_p, n_mask_codebooks, onset_mask_width, dropout, beat_mask_ms)
MASK_PRESETS = {
    "timbre transfer": (2, 1, 0, 0.0, 0),
    "small variation": (5, 4, 0, 0.0, 0),
    "small variation (follow beat)": (7, 4, 0, 0.0, 50),
    "medium variation": (7, 4, 0, 0.0, 0),
    "medium variation (follow beat)": (13, 4, 0, 0.0, 50),
    "large variation": (13, 4, 0, 0.2, 0),
    "large variation (follow beat)": (0, 4, 0, 0.0, 80),
    "unconditional": (0, 1, 0, 0.0, 0),
}


def change_preset(preset_dropdown):
    """Return the five mask-control values for ``preset_dropdown``.

    Order: (periodic_p, n_mask_codebooks, onset_mask_width, dropout,
    beat_mask_ms). Unknown preset names leave all five controls unchanged.
    """
    try:
        return MASK_PRESETS[preset_dropdown]
    except KeyError:
        return tuple(gr.update() for _ in range(5))


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            manual_audio_upload = gr.File(
                label=f"upload some audio (will be randomly trimmed to max of {MAX_DURATION_S}s)",
                file_types=["audio"],
            )
            load_example_audio_button = gr.Button("or load example audio")

            input_audio = gr.Audio(
                label="input audio",
                interactive=False,
                type="numpy",
            )

            # connect widgets
            load_example_audio_button.click(
                fn=load_example_audio,
                inputs=[],
                outputs=[input_audio],
            )
            manual_audio_upload.change(
                fn=load_audio,
                inputs=[manual_audio_upload],
                outputs=[input_audio],
            )

        # mask settings
        with gr.Column():
            with gr.Accordion("manual controls", open=True):
                periodic_p = gr.Slider(
                    label="periodic prompt",
                    minimum=0,
                    maximum=13,
                    step=1,
                    value=7,
                )
                onset_mask_width = gr.Slider(
                    label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) does not affect mask preview",
                    minimum=0,
                    maximum=100,
                    step=1,
                    value=0,
                    visible=True,
                )
                # minimum is 0 so the default (and presets' 0) is in range.
                beat_mask_ms = gr.Slider(
                    label="beat mask width (milliseconds) does not affect mask preview",
                    minimum=0,
                    maximum=200,
                    step=1,
                    value=0,
                    visible=True,
                )
                n_mask_codebooks = gr.Slider(
                    label="compression prompt ",
                    value=3,
                    minimum=1,
                    maximum=14,
                    step=1,
                )
                dropout = gr.Slider(
                    label="mask dropout",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.0,
                )
                num_feedback_steps = gr.Slider(
                    label="feedback steps (token telephone) -- turn it up for better timbre/rhythm transfer quality, but it's slower!",
                    minimum=1,
                    maximum=8,
                    step=1,
                    value=1,
                )
                preset_dropdown = gr.Dropdown(
                    label="preset",
                    choices=list(MASK_PRESETS),
                    value="medium variation",
                )
                preset_dropdown.change(
                    fn=change_preset,
                    inputs=[preset_dropdown],
                    outputs=[periodic_p, n_mask_codebooks, onset_mask_width, dropout, beat_mask_ms],
                )

            maskimg = gr.Image(
                label="mask image",
                interactive=False,
                type="filepath",
            )

            with gr.Accordion("extras ", open=False):
                pitch_shift_amt = gr.Slider(
                    label="pitch shift amount (semitones)",
                    minimum=-12,
                    maximum=12,
                    step=1,
                    value=0,
                )
                stretch_factor = gr.Slider(
                    label="time stretch factor",
                    minimum=0,
                    maximum=8,
                    step=1,
                    value=1,
                )

            with gr.Accordion("sampling settings", open=False):
                sampletemp = gr.Slider(
                    label="sample temperature",
                    minimum=0.1,
                    maximum=10.0,
                    value=1.0,
                    step=0.001,
                )
                top_p = gr.Slider(
                    label="top p (0.0 = off)",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                )
                typical_filtering = gr.Checkbox(
                    label="typical filtering ",
                    value=True,
                )
                typical_mass = gr.Slider(
                    label="typical mass (should probably stay between 0.1 and 0.5)",
                    minimum=0.01,
                    maximum=0.99,
                    value=0.15,
                )
                typical_min_tokens = gr.Slider(
                    label="typical min tokens (should probably stay between 1 and 256)",
                    minimum=1,
                    maximum=256,
                    step=1,
                    value=64,
                )
                # maximum is 1.0 so the default value is within range.
                sample_cutoff = gr.Slider(
                    label="sample cutoff",
                    minimum=0.0,
                    maximum=1.0,
                    value=1.0,
                    step=0.01,
                )
                sampling_steps = gr.Slider(
                    label="sampling steps",
                    minimum=1,
                    maximum=128,
                    step=1,
                    value=36,
                )
                seed = gr.Number(
                    label="seed (0 for random)",
                    value=0,
                    precision=0,
                )

        # model choice and outputs
        with gr.Column():
            model_choice = gr.Dropdown(
                label="model choice",
                choices=list(interface.available_models()),
                value=init_model_choice,
                visible=True,
            )

            vamp_button = gr.Button("generate (vamp)!!!")

            audio_outs = []
            use_as_input_btns = []
            for i in range(1):
                with gr.Column():
                    audio_outs.append(gr.Audio(
                        label=f"output audio {i+1}",
                        interactive=False,
                        type="numpy",
                    ))
                    use_as_input_btns.append(
                        gr.Button("use as input (feedback)")
                    )

            labels = gr.Markdown(label="output labels")

    # mask preview change
    for widget in (periodic_p, n_mask_codebooks, onset_mask_width, dropout):
        widget.change(
            fn=mask_preview,
            inputs=[periodic_p, n_mask_codebooks, onset_mask_width, dropout],
            outputs=[maskimg],
        )

    # Order must match vamp()'s positional parameters.
    _inputs = [
        input_audio,
        sampletemp,
        top_p,
        periodic_p,
        dropout,
        stretch_factor,
        onset_mask_width,
        typical_filtering,
        typical_mass,
        typical_min_tokens,
        seed,
        model_choice,
        n_mask_codebooks,
        pitch_shift_amt,
        sample_cutoff,
        sampling_steps,
        beat_mask_ms,
        num_feedback_steps,
    ]

    # connect widgets
    vamp_button.click(
        fn=vamp,
        inputs=_inputs,
        outputs=[audio_outs[0], labels],
    )
# Serve the UI. If the process is interrupted from the keyboard, remove the
# "gradio-outputs" directory (ignoring errors) and re-raise the interrupt.
try:
    demo.queue().launch(share=True)
except KeyboardInterrupt:
    shutil.rmtree("gradio-outputs", ignore_errors=True)
    raise