Spaces:
Build error
Build error
| import gradio as gr | |
| import os | |
| import json | |
| import numpy as np | |
| import torch | |
| import librosa | |
| from tools import VAE_out_put_to_spc, np_power_to_db | |
| from model.VAE_torchV import Encoder, Decoder | |
# Shape of the RGB image arrays handed to the Gradio image components.
SPECTROGRAM_RESOLUTION = (512, 256, 3)
device = "cpu"

# Architecture arguments are hard-coded; presumably they have to match the
# checkpoints loaded below (load_state_dict is called with default strictness).
encoder = Encoder((1, 512, 256), 24, N2=0, channel_sizes=[64, 64, 64, 96, 96, 128, 160, 216]).to(device)
decoder = Decoder(24, N2=0, N3=8, channel_sizes=[64, 64, 64, 96, 96, 128, 160, 216]).to(device)

# model_name selects the checkpoint pair "models/<model_name>_{encoder,decoder}_CA.pt".
model_name = "test"
encoder.load_state_dict(torch.load(f"models/{model_name}_encoder_CA.pt", map_location=torch.device(device)))
decoder.load_state_dict(torch.load(f"models/{model_name}_decoder_CA.pt", map_location=torch.device(device)))

# This script only runs inference, so switch both networks to evaluation mode;
# otherwise layers such as dropout / batch-norm (if the models contain any)
# would keep their training-time behaviour.
encoder.eval()
decoder.eval()

# Maps an example name to its latent code. "init" is a random 24-value
# placeholder; the remaining entries come from the JSON file below.
INIT_ENCODE_CACHE = {"init": np.random.random((24, ))}
with open('webUI/initial_example_encodes.json', 'r') as f:
    list_dict = json.load(f)
for k, encode_values in list_dict.items():
    INIT_ENCODE_CACHE[k] = np.array(encode_values)
| ################################# | |
def prepare_image(image):
    """Convert a dB-scaled array into ``uint8`` pixel values.

    The range [-80, 0] is mapped linearly onto [0, 255]. Values outside that
    range are clamped to the nearest bound first, because casting an
    out-of-range float to ``uint8`` wraps around (or is undefined for negative
    values) and would show up as random speckles in the image.

    Args:
        image: Array-like of dB values.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with the same shape as ``image``.
    """
    normalized = (np.asarray(image) + 80.0) / 80.0
    # Clamp before the cast so that over/undershoot saturates instead of wrapping.
    normalized = np.clip(normalized, 0.0, 1.0)
    return (255.0 * normalized).astype(np.uint8)
def encodeBatch2GradioOutput(latent_vector_batch, resolution=(512, 256)):
    """Decode a batch of latent codes into spectrogram images and audio signals.

    Args:
        latent_vector_batch: Tensor handed unchanged to the module-level
            ``decoder``; the callers in this file pass 24-value latent codes.
        resolution: ``(rows, columns)`` each decoded spectrogram is reshaped to.

    Returns:
        A tuple ``(images, sample_rate, signals)``: ``images`` is a list of
        ``uint8`` arrays of shape ``(rows, columns, 3)``, ``sample_rate`` is
        the constant 16000, and ``signals`` is a list of waveforms
        reconstructed with Griffin-Lim, in the same order as the input batch.
    """
    n_rows, n_cols = resolution

    # Pure inference: do not record an autograd graph for the decoder pass.
    with torch.no_grad():
        reconstruction_batch = decoder(latent_vector_batch).to("cpu").detach().numpy()

    flipped_log_spectrums, rec_signals = [], []
    for reconstruction in reconstruction_batch:
        spc = np.reshape(VAE_out_put_to_spc(reconstruction), resolution)
        magnitude_spectrum = np.abs(spc)

        # --- image: dB spectrogram, flipped vertically, packed into 3 channels ---
        flipped_log_spectrum = np.flipud(np_power_to_db(magnitude_spectrum))
        colorful_spc = np.empty((n_rows, n_cols, 3))
        colorful_spc[:, :, 0] = flipped_log_spectrum
        colorful_spc[:, :, 1] = flipped_log_spectrum
        colorful_spc[:, :, 2] = -60.0  # constant third channel
        spectrogram_image = prepare_image(colorful_spc)

        # --- audio: square root of the spectrogram, padded by one zero row ---
        # The square root is taken from the absolute values so that a stray
        # negative bin cannot inject NaN into Griffin-Lim (the image path
        # already uses the absolute values as well).
        abs_spec = np.zeros((n_rows + 1, n_cols))
        abs_spec[:n_rows, :] = np.sqrt(magnitude_spectrum)
        # NOTE(review): hop_length/win_length are the values used with the
        # default 512-row resolution; other resolutions may need other values.
        rec_signal = librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=1024)

        flipped_log_spectrums.append(spectrogram_image)
        rec_signals.append(rec_signal)
    return flipped_log_spectrums, 16000, rec_signals
def get_example_module(encodeCache):
    """Build the "Examples" tab and wire up its button.

    Args:
        encodeCache: Gradio component used as both input and output of the
            click handler; its value is the dict mapping example names to
            latent codes.
    """

    def show_example(selected_example, encodeCache):
        """Decode the selected example into a spectrogram image and audio."""
        # The dropdown starts without a selection; report that to the user
        # instead of failing with a bare KeyError.
        if selected_example not in encodeCache:
            raise gr.Error("Please choose an example first.")
        example_encode = torch.Tensor(np.reshape(encodeCache[selected_example], (-1, 24))).to(device)
        flipped_log_spectrums, sampleRate, rec_signals = encodeBatch2GradioOutput(example_encode)
        flipped_log_spectrum, rec_signal = flipped_log_spectrums[0], rec_signals[0]
        return flipped_log_spectrum, (sampleRate, rec_signal), encodeCache

    # NOTE(review): the nesting of the layout containers is reconstructed — confirm against the running UI.
    with gr.Tab("Examples"):
        gr.Markdown("Some predefined examples.")
        with gr.Row():
            with gr.Column():
                selected_example = gr.Dropdown(
                    list(INIT_ENCODE_CACHE.keys()), label="Examples", info="Choose one example! More samples coming."
                )
                example_button = gr.Button(value="Show example")
            with gr.Column():
                example_image_output = gr.Image(label="Spectrogram", type="numpy")
                example_image_output.style(height=250, width=600)
                example_audio_output = gr.Audio(type="numpy", label="Play the example!")
        example_button.click(show_example, inputs=[selected_example, encodeCache],
                             outputs=[example_image_output, example_audio_output, encodeCache])
def get_reconstruction_module():
    """Build the "Reconstruction" tab.

    The tab is a placeholder: its button ignores the entered batch index and
    displays a random image of shape ``SPECTROGRAM_RESOLUTION``.
    """

    def do_nothing(image_input):
        """Return random noise shaped like a spectrogram image; the input is unused."""
        noise_image = np.random.random(SPECTROGRAM_RESOLUTION)
        return noise_image

    # NOTE(review): the nesting of the layout containers is reconstructed — confirm against the running UI.
    with gr.Tab("Reconstruction"):
        gr.Markdown("Test reconstruction.")
        with gr.Row():
            # Left column: controls.
            with gr.Column():
                batch_index_box = gr.Number(label="Batch_index")
                generate_button = gr.Button(value="Generate")
            # Right column: result image.
            with gr.Column():
                reconstruction_view = gr.Image(label="Reconstruction", type="numpy")
                reconstruction_view.style(height=250, width=600)
        generate_button.click(
            do_nothing,
            inputs=batch_index_box,
            outputs=reconstruction_view,
        )
def get_interpolation_module(encodeCache):
    """Build the "Interpolation" tab and wire up its buttons.

    Args:
        encodeCache: Gradio component used as input and output of the
            handlers; its value is the dict mapping sound names to latent
            codes.
    """

    def interpolate(first_interpulation_input, second_interpulation_input, interpulation_input_ratio, encodeCache):
        """Decode both selected sounds and their linear blend in latent space.

        The blend is ``first * ratio + second * (1 - ratio)``.

        Returns:
            Image and ``(sample_rate, signal)`` audio for the first input, the
            second input and the interpolation (in that order), followed by
            the unchanged cache.
        """
        # Both dropdowns start empty; give a readable message instead of a KeyError.
        for chosen_name in (first_interpulation_input, second_interpulation_input):
            if chosen_name not in encodeCache:
                raise gr.Error("Please choose both inputs first.")

        first_interpulation_input_encode = torch.Tensor(np.reshape(encodeCache[first_interpulation_input], (-1, 24)))
        second_interpulation_input_encode = torch.Tensor(np.reshape(encodeCache[second_interpulation_input], (-1, 24)))
        ratio = torch.Tensor([interpulation_input_ratio])
        interpulation_encode = first_interpulation_input_encode * ratio + second_interpulation_input_encode * (1 - ratio)

        # Concatenate along the existing batch axis so the decoder receives
        # (batch, 24), like the other callers in this file; torch.stack would
        # insert an extra axis and produce (3, 1, 24).
        interpulation_input_encode = torch.cat(
            (first_interpulation_input_encode, second_interpulation_input_encode, interpulation_encode), dim=0
        ).to(device)

        flipped_log_spectrums, sampleRate, rec_signals = encodeBatch2GradioOutput(interpulation_input_encode)
        first_flipped_log_spectrum, first_rec_signal = flipped_log_spectrums[0], rec_signals[0]
        second_flipped_log_spectrum, second_rec_signal = flipped_log_spectrums[1], rec_signals[1]
        interpolation_flipped_log_spectrum, interpolation_rec_signal = flipped_log_spectrums[2], rec_signals[2]
        return (
            first_flipped_log_spectrum, (sampleRate, first_rec_signal),
            second_flipped_log_spectrum, (sampleRate, second_rec_signal),
            interpolation_flipped_log_spectrum, (sampleRate, interpolation_rec_signal),
            encodeCache,
        )

    def refresh_interpolation_input(encodeCache):
        """Reload both dropdowns with the names currently stored in the cache."""
        names = list(encodeCache.keys())
        return gr.Dropdown.update(choices=names), gr.Dropdown.update(choices=names), encodeCache

    # NOTE(review): the nesting of the layout containers is reconstructed — confirm against the running UI.
    with gr.Tab("Interpolation"):
        gr.Markdown("Test Interpolation. Sounds that you sampled can be added to the input dropdown by clicking [Refresh].")
        with gr.Row():
            with gr.Column():
                interpulation_refresh_button = gr.Button(value="Refresh")
                with gr.Row():
                    first_interpulation_input = gr.Dropdown(list(INIT_ENCODE_CACHE.keys()), label="First input")
                    second_interpulation_input = gr.Dropdown(list(INIT_ENCODE_CACHE.keys()), label="Second input")
                with gr.Row():
                    first_input_audio = gr.Audio(type="numpy", label="Play first input")
                    first_input_audio.style(length=125)
                    # Label fixed: it previously read "sePlay second input".
                    second_input_audio = gr.Audio(type="numpy", label="Play second input")
                    second_input_audio.style(length=125)
                interpulation_input_ratio = gr.Slider(minimum=-0.20, maximum=1.20, value=0.5, step=0.01, label="Ratio of the first input.")
                # Button caption fixed: it previously read "Interpulate".
                interpulation_button = gr.Button(value="Interpolate")
            with gr.Column():
                with gr.Row():
                    first_input_spectrogram = gr.Image(label="First Input", type="numpy")
                    first_input_spectrogram.style(height=250, width=125)
                    interpolation_spectrogram = gr.Image(label="Interpolation", type="numpy")
                    interpolation_spectrogram.style(height=250, width=125)
                    second_input_spectrogram = gr.Image(label="Second Input", type="numpy")
                    second_input_spectrogram.style(height=250, width=125)
                interpolation_audio = gr.Audio(type="numpy", label="Play interpolation")
                interpolation_audio.style(length=125)
        interpulation_refresh_button.click(refresh_interpolation_input, inputs=[encodeCache],
                                           outputs=[first_interpulation_input, second_interpulation_input, encodeCache])
        interpulation_button.click(interpolate, inputs=[first_interpulation_input, second_interpulation_input, interpulation_input_ratio, encodeCache],
                                   outputs=[first_input_spectrogram, first_input_audio, second_input_spectrogram, second_input_audio, interpolation_spectrogram, interpolation_audio, encodeCache])
def get_random_sampling_module(encodeCache, current_encode):
    """Build the "Random sampling" tab and wire up its buttons.

    Args:
        encodeCache: Gradio component whose value is the dict mapping sound
            names to latent codes; saved samples are added to it.
        current_encode: Gradio component holding the most recently sampled
            latent code. Its initial value is an all-ones array, which
            ``save_sample`` treats as "nothing sampled yet".
    """

    def random_sample(sigma, current_encode):
        """Draw ``sigma * N(0, 1)`` as a (1, 24) latent code and decode it."""
        # A cleared number field arrives as None.
        if sigma is None:
            raise gr.Error("Please enter a value for sigma.")
        random_encode = torch.Tensor([sigma]) * torch.randn(1, 24)
        # random_encode = torch.Tensor([mu]) + torch.Tensor([sigma]) * torch.randn(1, 24)

        # Replace the stored code outright. The previous
        # "current_encode * 0.0 + sample" form carried inf/NaN from an earlier
        # sample into every later one.
        current_encode = random_encode.detach().numpy().astype(np.float64)

        random_encode = random_encode.to(device)
        flipped_log_spectrums, sampleRate, rec_signals = encodeBatch2GradioOutput(random_encode)
        random_log_spectrum, random_rec_signal = flipped_log_spectrums[0], rec_signals[0]
        return random_log_spectrum, (sampleRate, random_rec_signal), current_encode

    def save_sample(save_name, current_encode, encodeCache):
        """Store the current latent code in the cache under ``save_name``.

        Returns:
            A status message, the unchanged latent code and the cache.
        """
        # The state starts as an all-ones placeholder. Compare element-wise
        # rather than by sum, so a real sample that happens to sum to 24 is
        # still accepted.
        if np.array_equal(current_encode, np.ones_like(current_encode)):
            return "Please generate one sample.", current_encode, encodeCache
        # Reject missing or whitespace-only names.
        if not save_name or not save_name.strip():
            return "The save name is empty.", current_encode, encodeCache
        encodeCache[save_name] = current_encode
        return "Sample saved.", current_encode, encodeCache

    # NOTE(review): the nesting of the layout containers is reconstructed — confirm against the running UI.
    with gr.Tab("Random sampling"):
        gr.Markdown("Sample new sound! Feel free to name and save your samples!")
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    # mu = gr.Number(label="mu")
                    sigma = gr.Number(value=1.0, label="sigma")
                random_sampling_button = gr.Button(value="Sample")
            with gr.Column():
                random_sampling_spectrogram = gr.Image(label="Random sampling", type="numpy")
                random_sampling_spectrogram.style(height=250, width=600)
                random_sampling_audio = gr.Audio(type="numpy", label="Play the sample")
                random_sampling_audio.style(length=125)
                save_name_input = gr.Textbox(label="Name your sound")
                save_button = gr.Button(value="save")
                save_name_output = gr.Textbox(label="Save it for interpolation")
        random_sampling_button.click(random_sample, inputs=[sigma, current_encode], outputs=[random_sampling_spectrogram, random_sampling_audio, current_encode])
        save_button.click(save_sample, inputs=[save_name_input, current_encode, encodeCache], outputs=[save_name_output, current_encode, encodeCache])
# Assemble the web UI. The get_*_module calls create their tabs in call order,
# so the order of the calls below determines the order of the tabs.
with gr.Blocks() as demo:
    gr.Markdown("WebUI for [DL for sound generation]. webUI version:1.0. Model version: [UNDERFIT_torch_15_5_2023].")
    # Most recently sampled latent code; starts as an all-ones placeholder,
    # which the save handler interprets as "nothing sampled yet".
    current_encode = gr.State(value=np.ones((1, 24)))
    # Name -> latent-code dict shared by the example, sampling and interpolation tabs.
    initial_examples = gr.State(value=INIT_ENCODE_CACHE)
    # initial_interpolation_examples = gr.State(value={"init": np.random.random(SPECTROGRAM_RESOLUTION)})
    get_example_module(initial_examples)
    # The reconstruction tab is a placeholder and is currently disabled.
    # get_reconstruction_module()
    get_random_sampling_module(initial_examples, current_encode)
    get_interpolation_module(initial_examples)
# Start the Gradio server with default settings.
demo.launch()