# (removed non-Python paste residue: "Spaces:" / "Runtime error" lines from the original export)
| # Generate ByteCover and CLAP Embeddings for a dataset and put to Pinecone | |
| import argparse | |
| import os | |
| from typing import Iterator | |
| import time | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import laion_clap | |
| from tqdm import tqdm | |
| from pinecone.grpc import PineconeGRPC as Pinecone | |
| from pinecone import PodSpec, PineconeApiException | |
| import ffmpeg | |
| from bytecover.models.train_module import TrainModule | |
| from bytecover.models.data_loader import ByteCoverDataset | |
| from bytecover.utils import load_config | |
class BatchGenerator:
    """Splits a DataFrame into roughly equal-sized batches for upload."""

    def __init__(self, batch_size: int = 10) -> None:
        self.batch_size = batch_size

    def to_batches(self, df: pd.DataFrame) -> Iterator[pd.DataFrame]:
        """Yield chunks of *df*, each close to ``batch_size`` rows."""
        n_chunks = self.splits_num(len(df))
        if n_chunks <= 1:
            yield df
            return
        yield from np.array_split(df, n_chunks)

    def splits_num(self, elements: int) -> int:
        """Number of chunks to cut *elements* rows into (Python rounding)."""
        return round(elements / self.batch_size)

    # Allow the instance itself to be used as the batching callable.
    __call__ = to_batches
| # quantization | |
def int16_to_float32(x):
    """Dequantize int16 PCM samples to float32 in roughly [-1, 1]."""
    scaled = x / 32767.0
    return scaled.astype(np.float32)
def float32_to_int16(x):
    """Quantize float samples to int16 PCM, clipping to [-1, 1] first."""
    clipped = np.clip(x, a_min=-1., a_max=1.)
    scaled = clipped * 32767.
    return scaled.astype(np.int16)
def flatten_vector_embed(vector_embed):
    """Return the embedding's elements as a flat Python list."""
    flat = vector_embed.flatten()
    return [value for value in flat]
def grab_song_title(vector_name):
    """Everything before the first underscore is the track identifier."""
    head, _, _ = vector_name.partition("_")
    return head
def convert_to_npfloat64(original_array):
    """Cast an embedding (any array-like) to a float64 NumPy array."""
    return np.array(original_array, dtype=np.float64)
def convert_to_npfloat64_to_list(vector_embed_64):
    """Unpack a float64 NumPy vector into a plain Python list of scalars."""
    return [*vector_embed_64]
def look_up_metadata(track_id, meta_dataframe, meta_col_interest):
    """Look up one metadata field for a track.

    track_id: vector name of the form "spotify:track:<id>_<clip>_<n>.mp3";
        the part before the first "_" must match a value in the 'uri' column.
    meta_dataframe: DataFrame of all metadata rows (must contain 'uri').
    meta_col_interest: column to read, e.g. album / artist_names /
        popularity / release_date / genre.

    Returns the first matching value, or "unknown" when there is no
    matching row or the column does not exist.
    """
    df_id = track_id.split("_")[0]
    meta_row = meta_dataframe[meta_dataframe['uri'] == df_id].reset_index(drop=True)
    try:
        return meta_row[meta_col_interest][0]
    # Narrowed from a bare `except:`: a missing column raises KeyError and an
    # empty result set raises KeyError/IndexError on the [0] lookup.
    except (KeyError, IndexError):
        return "unknown"
def strip_year_from_date(full_date):
    """Return the 4-digit year from a release date.

    Accepts either a bare integer year or an ISO-style date string
    ("YYYY-MM-DD"); anything unsliceable yields the sentinel "CHECK_THIS"
    so bad rows can be found later.
    """
    # isinstance replaces the non-idiomatic `type(full_date) == int` check.
    if isinstance(full_date, int):
        return str(full_date)
    try:
        return full_date[:4]
    # Narrowed from a bare `except:`: slicing a non-sequence (None, float,
    # numpy scalar, ...) raises TypeError.
    except TypeError:
        return "CHECK_THIS"
def strip_vector_clip(vector_name):
    """From a name like 'id_clip_n.ext', return the clip field (second '_' part)."""
    stem = vector_name.split(".")[0]
    return stem.split("_")[1]
def get_triplet_num(vector_name_str):
    """Return the 1-based triplet index (third '_' field of the name) as a string."""
    raw = vector_name_str.split("_")[2].split(".")[0]
    return str(int(raw) + 1)
def _load_metadata(metadata_dir):
    """Read every JSON metadata file in *metadata_dir* into one DataFrame
    and add a 'year' column stripped from 'release_date'."""
    meta_list = sorted(os.listdir(metadata_dir))
    # BUGFIX: the original loop ran `range(1, len(meta_list) - 1)` and
    # silently dropped the LAST metadata file; read them all.
    frames = [pd.read_json(os.path.join(metadata_dir, name)) for name in meta_list]
    meta_df = pd.concat(frames).reset_index(drop=True)
    meta_df["year"] = meta_df.apply(lambda row: strip_year_from_date(row['release_date']), axis=1)
    return meta_df


def _bytecover_embeddings(audio_dir, file_list):
    """Load the ByteCover model and return {filename: embedding tensor}."""
    print("Loading ByteCover model")
    bytecover_config = load_config(config_path="bytecover/config.yaml")
    bytecover_module = TrainModule(bytecover_config)
    bytecover_model = bytecover_module.model
    # Prefer the best checkpoint, then an explicit test checkpoint, else warn.
    if bytecover_module.best_model_path is not None:
        bytecover_model.load_state_dict(torch.load(bytecover_module.best_model_path), strict=False)
        print(f"Best model loaded from checkpoint: {bytecover_module.best_model_path}")
    elif bytecover_module.config["test"]["model_ckpt"] is not None:
        bytecover_model.load_state_dict(torch.load(bytecover_module.config["test"]["model_ckpt"], map_location='cpu'), strict=False)
        print(f'Model loaded from checkpoint: {bytecover_module.config["test"]["model_ckpt"]}')
    elif bytecover_module.state == "initializing":
        print("Warning: Running with random weights")
    bytecover_model.eval()
    audio_dict_bytecover = {}
    for file in tqdm(file_list, desc="Generating Bytecover Embeddings"):
        # Skip files already embedded (lets a partial run be resumed).
        if file in audio_dict_bytecover:
            continue
        # BUGFIX: original used `audio_dir + file` (no separator), which only
        # worked when the caller passed a trailing slash.
        file_path = os.path.join(audio_dir, file)
        try:
            # Decode via an ffmpeg subprocess, down-mixing to mono and
            # resampling to 22.05 kHz. Requires the ffmpeg CLI and the
            # `ffmpeg-python` package.
            out, _ = (
                ffmpeg.input(file_path, threads=0)
                .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=22050)
                .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
            )
        except ffmpeg.Error as e:
            raise RuntimeError(
                f"Failed to load audio:{file_path}\n{e.stderr.decode()}"
            ) from e
        # int16 PCM -> float32 in [-1, 1)
        audio = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
        song_tensor = torch.from_numpy(audio)
        # Grab the ByteCover embedding ('f_t' head) for the whole clip.
        audio_embed = bytecover_model.forward(song_tensor.to(bytecover_module.config["device"]))['f_t'].detach()
        audio_dict_bytecover[file] = audio_embed.squeeze()
    return audio_dict_bytecover


def _clap_embeddings(audio_dir, file_list):
    """Load the CLAP model and return {filename: embedding array}."""
    print("Loading CLAP model")
    clap_model = laion_clap.CLAP_Module(enable_fusion=False)
    clap_model.load_ckpt()  # download the default pretrained checkpoint.
    audio_dict_CLAP = {}
    for file in tqdm(file_list, desc="Generating CLAP Embeddings"):
        # Skip files already embedded (lets a partial run be resumed).
        if file in audio_dict_CLAP:
            continue
        full_path = os.path.join(audio_dir, file)
        # Embed straight from the file path via the laion_clap helper.
        audio_embed = clap_model.get_audio_embedding_from_filelist(x=[full_path], use_tensor=False)
        audio_dict_CLAP[file] = audio_embed
    return audio_dict_CLAP


def _build_flat_df(audio_dict, meta_df):
    """Turn a {filename: embedding} dict into a flat DataFrame with the
    vector columns plus per-track metadata for the Pinecone upload."""
    flat_df = pd.DataFrame(audio_dict.items(), columns=['vector_name', 'vector_embed']).reset_index()
    flat_df.columns = ['vector_id', 'vector_name', 'vector_embed']
    flat_df["song_title"] = flat_df.apply(lambda row: grab_song_title(row['vector_name']), axis=1)
    flat_df["flat_vector_embed"] = flat_df.apply(lambda row: flatten_vector_embed(row['vector_embed']), axis=1)
    flat_df["flat_vector_embed_64"] = flat_df.apply(lambda row: convert_to_npfloat64(row['flat_vector_embed']), axis=1)
    flat_df["flat_vector_embed_64_list"] = flat_df.apply(lambda row: convert_to_npfloat64_to_list(row['flat_vector_embed_64']), axis=1)
    # Metadata columns joined in via the 'uri' prefix of the vector name.
    # Bind meta_col as a default arg so each lambda keeps its own value.
    for col, meta_col in [("genre", "genre"), ("album", "album"), ("name", "name"),
                          ("artist", "artist_names"), ("year", "year")]:
        flat_df[col] = flat_df.apply(
            lambda row, m=meta_col: look_up_metadata(row['vector_name'], meta_df, m), axis=1)
    flat_df["vector_clip_num"] = flat_df.apply(lambda row: strip_vector_clip(row['vector_name']), axis=1)
    flat_df['embedding_triplet_num'] = flat_df.vector_name.apply(get_triplet_num)
    print("unique songs:", len(flat_df.song_title.unique()))
    return flat_df


def _upload_to_pinecone(flat_dfs, index_naming_conv):
    """Create (or reuse) one Pinecone index per embedding type, upsert all
    vectors in batches, attach per-vector metadata, and snapshot each index
    to a collection.

    NOTE: pod-based indexes in this environment are billed — take them down
    when not in use. Reads the API key from the PC_API_KEY env var.
    """
    api_key = os.environ['PC_API_KEY']
    pc = Pinecone(api_key=api_key)
    index_name_clap = f'clap-{index_naming_conv}'  # free (comes with plan, can have 100k records)
    index_name_bytecover = f'bytecover-{index_naming_conv}'  # free (comes with plan, can have 100k records)
    index_env = 'us-west1-gcp'  # NOT free (take down when not in use)
    pod_type = 'p1.x1'  # NOT free (take down when not in use)
    # flat_dfs arrives ordered [CLAP, ByteCover]; their dims are 512 / 2048.
    for index_name, flat_df, index_dim in zip(
            [index_name_clap, index_name_bytecover], flat_dfs, [512, 2048]):
        try:
            pc.create_index(
                name=index_name,
                dimension=index_dim,
                metric="cosine",
                spec=PodSpec(
                    environment=index_env,
                    pod_type=pod_type,
                    pods=1
                ),
                deletion_protection="disabled"
            )
        except PineconeApiException:
            print(f"WARNING: INDEX {index_name} ALREADY EXISTS")
        time.sleep(5)  # give a freshly created index a moment to become ready
        index = pc.Index(index_name)
        df_batcher = BatchGenerator(64)
        for batch_df in tqdm(df_batcher(flat_df), desc="Uploading batches"):
            index.upsert(vectors=list(zip(batch_df["vector_name"], batch_df["flat_vector_embed_64_list"])))
        # Metadata is attached per vector; failures are collected, not fatal.
        failed_list_update_metadata = []
        for vec_id in tqdm(range(len(flat_df)), desc="Adding metadata"):
            try:
                row = flat_df.iloc[vec_id]
                index.update(id=str(row['vector_name']),
                             set_metadata={"genre": row['genre'],
                                           "song": row['name'],
                                           "album": row['album'],
                                           "artists": row['artist'],
                                           "year": str(row['year']),
                                           "clip_num": row['vector_clip_num'],
                                           "triplet_num": str(row['embedding_triplet_num']),
                                           "spotify_id": row['song_title']
                                           })
            # Narrowed from a bare `except:` — still best-effort, but no longer
            # swallows KeyboardInterrupt/SystemExit.
            except Exception:
                print("failed on:", vec_id)
                failed_list_update_metadata.append(vec_id)
        pc.create_collection(index_name, index_name)


def generate(audio_dir, metadata_dir, index_naming_conv):
    """Generate ByteCover and CLAP embeddings for every audio file in
    *audio_dir*, join them with track metadata from *metadata_dir*, and
    upload the vectors plus metadata to two Pinecone indexes named
    'clap-<index_naming_conv>' and 'bytecover-<index_naming_conv>'.
    """
    file_list = [f for f in os.listdir(audio_dir)]
    print(f"Found {len(file_list)} files")
    meta_df = _load_metadata(metadata_dir)
    audio_dict_bytecover = _bytecover_embeddings(audio_dir, file_list)
    audio_dict_CLAP = _clap_embeddings(audio_dir, file_list)
    # CLAP first, then ByteCover — the upload step relies on this order
    # to pair each frame with its index name and dimension.
    flat_dfs = [_build_flat_df(d, meta_df) for d in (audio_dict_CLAP, audio_dict_bytecover)]
    _upload_to_pinecone(flat_dfs, index_naming_conv)
if __name__ == "__main__":
    # Command-line entry point: three positional args, forwarded to generate().
    cli = argparse.ArgumentParser(
        description="Generate ByteCover and CLAP Embeddings for a dataset and put to Pinecone"
    )
    cli.add_argument('audio_dir')
    cli.add_argument('metadata_dir')
    cli.add_argument('index_name')
    ns = cli.parse_args()
    generate(ns.audio_dir, ns.metadata_dir, ns.index_name)