Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| import torch | |
| import time | |
| from src.cocktails.pipeline.get_affect2affective_cluster import get_affect2affective_cluster | |
| from src.music2cocktailrep.training.latent_translation.setup_trained_model import setup_trained_model | |
| from src.music2cocktailrep.pipeline.music2affect import setup_pretrained_affective_models | |
| global music2affect, find_affective_cluster, translation_vae | |
| import streamlit as st | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
def setup_translation_models():
    """Load the translation models and publish them as module globals.

    Assigns the module-level names ``music2affect``,
    ``find_affective_cluster`` and ``translation_vae`` that the other
    functions in this module read.

    Returns:
        The object produced by ``setup_trained_model()`` (also stored in the
        ``translation_vae`` global).
    """
    global music2affect, find_affective_cluster, translation_vae

    # setup_pretrained_affective_models() yields a pair; only the first
    # element is used here, the second one is discarded.
    music2affect, _unused_keys = setup_pretrained_affective_models()
    find_affective_cluster = get_affect2affective_cluster()

    trained_model = setup_trained_model()
    translation_vae = trained_model
    return trained_model
def music2affect_cluster(handcoded_rep):
    """Map a handcoded music representation to an affective cluster.

    Relies on the module globals ``music2affect`` and
    ``find_affective_cluster``, which the visible code only assigns inside
    ``setup_translation_models()``.

    Args:
        handcoded_rep: Input forwarded unchanged to ``music2affect``.

    Returns:
        A ``(cluster_id, affects)`` tuple, where ``affects`` is the output of
        ``music2affect`` clipped to the range [-1, 1] and ``cluster_id`` is
        what ``find_affective_cluster`` returns for those clipped values.
    """
    raw_affects = music2affect(handcoded_rep)
    # Bound the predictions to [-1, 1] before looking up the cluster.
    affects = np.clip(raw_affects, -1, 1)
    return find_affective_cluster(affects), affects
def music2flavor(music_ai_rep, affective_cluster_id):
    """Translate a music representation into a cocktail representation.

    Calls the module-global ``translation_vae`` (assigned by
    ``setup_translation_models()``) with ``modality_out='cocktail'``.

    Args:
        music_ai_rep: Input passed to ``translation_vae``.
        affective_cluster_id: Accepted for interface compatibility; not used
            by this function.

    Returns:
        Whatever ``translation_vae`` returns for the cocktail modality.
    """
    return translation_vae(music_ai_rep, modality_out='cocktail')
def debug_translation(music_ai_rep):
    """Run the translation model back onto the music modality.

    Calls the module-global ``translation_vae`` (assigned by
    ``setup_translation_models()``) with ``modality_out='music'``.

    Args:
        music_ai_rep: Input passed to ``translation_vae``.

    Returns:
        Whatever ``translation_vae`` returns for the music modality.
    """
    return translation_vae(music_ai_rep, modality_out='music')
def music2cocktailrep(music_ai_rep, handcoded_music_rep, verbose=False, level=0):
    """Map a music representation to a cocktail representation.

    Args:
        music_ai_rep: Input forwarded to ``music2flavor``.
        handcoded_music_rep: Currently unused, because the affective-cluster
            step that consumed it is disabled (see below).
        verbose: When true, progress messages are printed.
        level: Number of spaces the header message is indented by; sub-step
            messages are indented two spaces further.

    Returns:
        A ``(cocktail_rep, affective_cluster_id, affect)`` tuple. The last
        two entries are always ``None`` while the affective-cluster step is
        disabled.
    """
    init_time = time.time()

    # All sub-step messages share one indent, two spaces deeper than the
    # header. The previous mix of `level*2` and `level + 2` only lined up
    # when level == 2.
    header_pad = ' ' * level
    step_pad = ' ' * (level + 2)

    if verbose:
        print(header_pad + 'Synesthetic mapping..')
        print(step_pad + 'Mapping to affective cluster.')

    # The affective-cluster mapping is disabled; placeholders are returned.
    # affective_cluster_id, affect = music2affect_cluster(handcoded_music_rep)
    affective_cluster_id, affect = None, None

    if verbose:
        print(step_pad + 'Mapping to flavors.')
    cocktail_rep = music2flavor(music_ai_rep, affective_cluster_id)

    if verbose:
        print(step_pad + f'Mapped in {int(time.time() - init_time)} seconds.')
    return cocktail_rep, affective_cluster_id, affect
| # def sigmoid(x, shift, beta): | |
| # return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2 | |
| # | |
| # cluster_colors = ['#%06X' % random.randint(0, 0xFFFFFF) for _ in range(10)] | |
| # def plot_cluster_ids_dataset(handcoded_rep_path): | |
| # import matplotlib.pyplot as plt | |
| # reps, _, _ = get_data(handcoded_rep_path, keys) | |
| # cluster_ids, affects = music2affect_cluster(reps) | |
| # # plt.figure() | |
| # # affects2 = affects.copy() | |
| # # affects2 = sigmoid(affects2, 0.05, 8) | |
| # # plt.hist(affects2[:, 2], bins=30) | |
| # # plt.xlim([-1, 1]) | |
| # fig = plt.figure() | |
| # ax = fig.add_subplot(projection='3d') | |
| # ax.set_xlim([-1, 1]) | |
| # ax.set_ylim([-1, 1]) | |
| # ax.set_zlim([-1, 1]) | |
| # for cluster_id in sorted(set(cluster_ids)): | |
| # indexes = np.argwhere(cluster_ids == cluster_id).flatten() | |
| # if len(indexes) > 0: | |
| # ax.scatter(affects[indexes, 0], affects[indexes, 1], affects[indexes, 2], c=cluster_colors[cluster_id], s=150) | |
| # ax.set_xlabel('Valence') | |
| # ax.set_ylabel('Arousal') | |
| # ax.set_zlabel('Dominance') | |
| # plt.figure() | |
| # plt.bar(range(10), [np.argwhere(cluster_ids == i).size for i in range(10)]) | |
| # plt.show() | |
| # | |
| # plot_cluster_ids_dataset(handcoded_rep_path) |