Sound_VAE / app.py
WeixuanYuan's picture
Update app.py
be81083
import gradio as gr
import os
import json
import numpy as np
import torch
import librosa
from tools import VAE_out_put_to_spc, np_power_to_db
from model.VAE_torchV import Encoder, Decoder
# Shape of the RGB spectrogram images used by the UI: (rows, columns, channels).
SPECTROGRAM_RESOLUTION = (512, 256, 3)

# All inference in this app runs on the CPU.
device = "cpu"

# Build the VAE halves and move them to `device`. The constructor arguments
# (input shape, latent size 24, channel sizes) are defined by model.VAE_torchV.
# NOTE(review): neither module is switched to eval() mode here - confirm whether
# the architecture contains layers that behave differently in training mode.
encoder = Encoder((1, 512, 256), 24, N2=0, channel_sizes=[64, 64, 64, 96, 96, 128, 160, 216]).to(device)
decoder = Decoder(24, N2=0, N3=8, channel_sizes=[64, 64, 64, 96, 96, 128, 160, 216]).to(device)

# Checkpoint files are named "<model_name>_encoder_CA.pt" / "<model_name>_decoder_CA.pt".
model_name = "test"
encoder.load_state_dict(
    torch.load(f"models/{model_name}_encoder_CA.pt", map_location=torch.device(device))
)
decoder.load_state_dict(
    torch.load(f"models/{model_name}_decoder_CA.pt", map_location=torch.device(device))
)

# Name -> latent encoding. "init" is a fresh random vector of length 24; the
# remaining entries come from the JSON file shipped with the web UI.
INIT_ENCODE_CACHE = {"init": np.random.random((24,))}
with open("webUI/initial_example_encodes.json", "r", encoding="utf-8") as f:
    list_dict = json.load(f)
for k, v in list_dict.items():
    INIT_ENCODE_CACHE[k] = np.array(v)
#################################
def prepare_image(image):
    """Convert a dB-valued array into 8-bit pixel intensities.

    The mapping is linear: -80.0 becomes 0 and 0.0 becomes 255. Inputs outside
    [-80, 0] are clamped to that range first, because casting an out-of-range
    float to ``uint8`` would otherwise wrap around and produce wrong pixels.

    Args:
        image: numpy array of values on the dB scale described above.

    Returns:
        ``uint8`` numpy array with the same shape as ``image``.
    """
    # Normalise to [0, 1], clamping anything louder than 0 dB or quieter than -80 dB.
    normalized = np.clip((image + 80.0) / 80.0, 0.0, 1.0)
    # astype truncates towards zero, e.g. 127.5 -> 127.
    return (255.0 * normalized).astype(np.uint8)
def encodeBatch2GradioOutput(latent_vector_batch, resolution=(512, 256)):
    """Decode a batch of latent vectors into spectrogram images and audio.

    Every latent vector is run through the module-level ``decoder``. Each
    decoder output is converted with ``VAE_out_put_to_spc``, rendered as an
    RGB image and inverted to a waveform with Griffin-Lim.

    Args:
        latent_vector_batch: tensor accepted by ``decoder``. Callers in this
            file pass float tensors whose last dimension is 24.
        resolution: ``(rows, columns)`` of one decoded spectrogram. The
            default matches the values that were previously hard-coded.

    Returns:
        A tuple ``(images, sample_rate, signals)``:
        ``images`` is a list of ``uint8`` arrays of shape ``(rows, columns, 3)``,
        ``sample_rate`` is the constant 16000, and ``signals`` is a list with
        one ``librosa.griffinlim`` result per batch entry.
    """
    n_rows, n_cols = resolution

    # Inference only, so there is no need to record an autograd graph.
    with torch.no_grad():
        reconstruction_batch = decoder(latent_vector_batch).to("cpu").detach().numpy()

    flipped_log_spectrums, rec_signals = [], []
    for reconstruction in reconstruction_batch:
        spc = np.reshape(VAE_out_put_to_spc(reconstruction), (n_rows, n_cols))

        # Image: dB-scaled magnitude, flipped vertically (presumably so that
        # row 0 of the spectrogram ends up at the bottom of the picture).
        log_spectrum = np_power_to_db(np.abs(spc))
        flipped_log_spectrum = np.flipud(log_spectrum)
        # First two channels carry the spectrogram, the third is a constant -60.0.
        colorful_spc = np.empty((n_rows, n_cols, 3))
        colorful_spc[:, :, 0] = flipped_log_spectrum
        colorful_spc[:, :, 1] = flipped_log_spectrum
        colorful_spc[:, :, 2] = -60.0
        spectrogram_image = prepare_image(colorful_spc)

        # Audio: square root of the spectrogram, padded with one all-zero row.
        # With n_fft left unset, librosa infers it from the row count
        # (513 rows -> n_fft 1024 for the default resolution).
        abs_spec = np.zeros((n_rows + 1, n_cols))
        abs_spec[:n_rows, :] = np.sqrt(spc)
        rec_signal = librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=1024)

        flipped_log_spectrums.append(spectrogram_image)
        rec_signals.append(rec_signal)
    return flipped_log_spectrums, 16000, rec_signals
def get_example_module(encodeCache):
    """Build the "Examples" tab, which renders one stored latent encoding.

    Must be called inside an active ``gr.Blocks`` context.

    Args:
        encodeCache: Gradio component whose value is a dict mapping a name to
            a latent encoding; it is wired as both input and output of the
            button handler.
    """

    def show_example(selected_example, encodeCache):
        """Decode the chosen example into a spectrogram image and audio.

        Args:
            selected_example: current value of the dropdown.
            encodeCache: dict mapping example names to latent encodings.

        Returns:
            ``(image, (sample_rate, signal), encodeCache)``.

        Raises:
            gr.Error: if nothing is selected or the name is not in the cache.
        """
        # The dropdown starts without a value, so the first click may arrive
        # with None; show a readable message instead of a raw KeyError.
        if selected_example not in encodeCache:
            raise gr.Error("Please choose an example first.")

        example_encode = torch.Tensor(np.reshape(encodeCache[selected_example], (-1, 24))).to(device)
        flipped_log_spectrums, sampleRate, rec_signals = encodeBatch2GradioOutput(example_encode)
        return flipped_log_spectrums[0], (sampleRate, rec_signals[0]), encodeCache

    # NOTE(review): the nesting of the layout contexts below was reconstructed
    # from a copy of the file without indentation - confirm against the UI.
    with gr.Tab("Examples"):
        gr.Markdown("Some predefined examples.")
        with gr.Row():
            with gr.Column():
                selected_example = gr.Dropdown(
                    list(INIT_ENCODE_CACHE.keys()), label="Examples", info="Choose one example! More samples coming."
                )
                example_button = gr.Button(value="Show example")
            with gr.Column():
                example_image_output = gr.Image(label="Spectrogram", type="numpy")
                example_image_output.style(height=250, width=600)
                example_audio_output = gr.Audio(type="numpy", label="Play the example!")
        example_button.click(
            show_example,
            inputs=[selected_example, encodeCache],
            outputs=[example_image_output, example_audio_output, encodeCache],
        )
def get_reconstruction_module():
    """Build the placeholder "Reconstruction" tab.

    The button handler does not reconstruct anything yet: it ignores its input
    and returns a random array of shape ``SPECTROGRAM_RESOLUTION``.
    Must be called inside an active ``gr.Blocks`` context.
    """

    def do_nothing(image_input):
        """Ignore ``image_input`` and return a random image-shaped array."""
        return np.random.random(SPECTROGRAM_RESOLUTION)

    # NOTE(review): the nesting of the layout contexts below was reconstructed
    # from a copy of the file without indentation - confirm against the UI.
    with gr.Tab("Reconstruction"):
        gr.Markdown("Test reconstruction.")
        with gr.Row():
            with gr.Column():
                batch_index_box = gr.Number(label="Batch_index")
                generate_button = gr.Button(value="Generate")
            with gr.Column():
                reconstruction_view = gr.Image(label="Reconstruction", type="numpy")
                reconstruction_view.style(height=250, width=600)
        generate_button.click(
            do_nothing,
            inputs=batch_index_box,
            outputs=reconstruction_view,
        )
def get_interpolation_module(encodeCache):
    """Build the "Interpolation" tab, which blends two stored encodings.

    Must be called inside an active ``gr.Blocks`` context.

    Args:
        encodeCache: Gradio component whose value is a dict mapping a name to
            a latent encoding; it is wired as input and output of both
            handlers.
    """

    def interpolate(first_interpulation_input, second_interpulation_input, interpulation_input_ratio, encodeCache):
        """Decode both inputs and their linear blend.

        Args:
            first_interpulation_input: name of the first encoding.
            second_interpulation_input: name of the second encoding.
            interpulation_input_ratio: weight of the first encoding; the
                second one gets ``1 - ratio``.
            encodeCache: dict mapping names to latent encodings.

        Returns:
            Image and ``(sample_rate, signal)`` for the first input, the
            second input and the interpolation (in that order), followed by
            ``encodeCache``.

        Raises:
            gr.Error: if either dropdown has no valid selection.
        """
        # Both dropdowns start empty; report that clearly instead of a KeyError.
        if first_interpulation_input not in encodeCache or second_interpulation_input not in encodeCache:
            raise gr.Error("Please choose both inputs first.")

        first_encode = torch.Tensor(np.reshape(encodeCache[first_interpulation_input], (-1, 24)))
        second_encode = torch.Tensor(np.reshape(encodeCache[second_interpulation_input], (-1, 24)))
        ratio = torch.Tensor([interpulation_input_ratio])
        blended_encode = first_encode * ratio + second_encode * (1 - ratio)

        # Concatenate along the batch axis so the decoder receives an (N, 24)
        # batch like every other caller; stacking would add an extra axis.
        encode_batch = torch.cat((first_encode, second_encode, blended_encode), dim=0).to(device)

        flipped_log_spectrums, sampleRate, rec_signals = encodeBatch2GradioOutput(encode_batch)
        return (
            flipped_log_spectrums[0], (sampleRate, rec_signals[0]),
            flipped_log_spectrums[1], (sampleRate, rec_signals[1]),
            flipped_log_spectrums[2], (sampleRate, rec_signals[2]),
            encodeCache,
        )

    def refresh_interpolation_input(encodeCache):
        """Reload both dropdowns with the names currently in ``encodeCache``."""
        names = list(encodeCache.keys())
        return gr.Dropdown.update(choices=names), gr.Dropdown.update(choices=names), encodeCache

    # NOTE(review): the nesting of the layout contexts below was reconstructed
    # from a copy of the file without indentation - confirm against the UI.
    with gr.Tab("Interpolation"):
        gr.Markdown("Test Interpolation. Sounds that you sampled can be added to the input dropdown by clicking [Refresh].")
        with gr.Row():
            with gr.Column():
                interpulation_refresh_button = gr.Button(value="Refresh")
                with gr.Row():
                    first_interpulation_input = gr.Dropdown(list(INIT_ENCODE_CACHE.keys()), label="First input")
                    second_interpulation_input = gr.Dropdown(list(INIT_ENCODE_CACHE.keys()), label="Second input")
                with gr.Row():
                    first_input_audio = gr.Audio(type="numpy", label="Play first input")
                    first_input_audio.style(length=125)
                    second_input_audio = gr.Audio(type="numpy", label="Play second input")
                    second_input_audio.style(length=125)
                interpulation_input_ratio = gr.Slider(minimum=-0.20, maximum=1.20, value=0.5, step=0.01, label="Ratio of the first input.")
                interpulation_button = gr.Button(value="Interpolate")
            with gr.Column():
                with gr.Row():
                    first_input_spectrogram = gr.Image(label="First Input", type="numpy")
                    first_input_spectrogram.style(height=250, width=125)
                    interpolation_spectrogram = gr.Image(label="Interpolation", type="numpy")
                    interpolation_spectrogram.style(height=250, width=125)
                    second_input_spectrogram = gr.Image(label="Second Input", type="numpy")
                    second_input_spectrogram.style(height=250, width=125)
                interpolation_audio = gr.Audio(type="numpy", label="Play interpolation")
                interpolation_audio.style(length=125)
        interpulation_refresh_button.click(
            refresh_interpolation_input,
            inputs=[encodeCache],
            outputs=[first_interpulation_input, second_interpulation_input, encodeCache],
        )
        interpulation_button.click(
            interpolate,
            inputs=[first_interpulation_input, second_interpulation_input, interpulation_input_ratio, encodeCache],
            outputs=[
                first_input_spectrogram, first_input_audio,
                second_input_spectrogram, second_input_audio,
                interpolation_spectrogram, interpolation_audio,
                encodeCache,
            ],
        )
def get_random_sampling_module(encodeCache, current_encode):
    """Build the "Random sampling" tab: draw a latent vector, then name and save it.

    Must be called inside an active ``gr.Blocks`` context.

    Args:
        encodeCache: Gradio component whose value is a dict mapping a name to
            a latent encoding; saved samples are added to it.
        current_encode: Gradio component holding the most recently drawn
            encoding. The caller in this file initialises it with
            ``np.ones((1, 24))`` as a "nothing sampled yet" placeholder.
    """

    def random_sample(sigma, current_encode):
        """Draw ``sigma * N(0, 1)`` in the 24-d latent space and decode it.

        Args:
            sigma: scale applied to the standard-normal draw.
            current_encode: previous state value; it is replaced, not reused.

        Returns:
            ``(image, (sample_rate, signal), new_encode)`` where ``new_encode``
            is a float64 array of shape ``(1, 24)``.

        Raises:
            gr.Error: if the sigma field is empty.
        """
        if sigma is None:
            raise gr.Error("Please enter a value for sigma.")

        random_encode = torch.Tensor([sigma]) * torch.randn(1, 24)
        # random_encode = torch.Tensor([mu]) + torch.Tensor([sigma]) * torch.randn(1, 24)

        # Store a fresh copy. Deriving it from the old state ("old * 0.0 + new")
        # would carry any NaN/inf from a previous draw into every later one.
        new_encode = random_encode.detach().numpy().astype(np.float64)

        random_encode = random_encode.to(device)
        flipped_log_spectrums, sampleRate, rec_signals = encodeBatch2GradioOutput(random_encode)
        return flipped_log_spectrums[0], (sampleRate, rec_signals[0]), new_encode

    def save_sample(save_name, current_encode, encodeCache):
        """Store the current encoding in ``encodeCache`` under ``save_name``.

        Args:
            save_name: name typed by the user.
            current_encode: the encoding to save.
            encodeCache: dict mapping names to latent encodings; updated in
                place, and an existing entry with the same name is overwritten.

        Returns:
            ``(status_message, current_encode, encodeCache)``.
        """
        # An all-ones array is the placeholder state, i.e. nothing sampled yet.
        if np.array_equal(current_encode, np.ones_like(current_encode)):
            return "Please generate one sample.", current_encode, encodeCache
        # Reject missing and whitespace-only names.
        if not save_name or not save_name.strip():
            return "The save name is empty.", current_encode, encodeCache
        encodeCache[save_name] = current_encode
        return "Sample saved.", current_encode, encodeCache

    # NOTE(review): the nesting of the layout contexts below was reconstructed
    # from a copy of the file without indentation - confirm against the UI.
    with gr.Tab("Random sampling"):
        gr.Markdown("Sample new sound! Feel free to name and save your samples!")
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    # mu = gr.Number(label="mu")
                    sigma = gr.Number(value=1.0, label="sigma")
                random_sampling_button = gr.Button(value="Sample")
            with gr.Column():
                random_sampling_spectrogram = gr.Image(label="Random sampling", type="numpy")
                random_sampling_spectrogram.style(height=250, width=600)
                random_sampling_audio = gr.Audio(type="numpy", label="Play the sample")
                random_sampling_audio.style(length=125)
                save_name_input = gr.Textbox(label="Name your sound")
                save_button = gr.Button(value="save")
                save_name_output = gr.Textbox(label="Save it for interpolation")
        random_sampling_button.click(
            random_sample,
            inputs=[sigma, current_encode],
            outputs=[random_sampling_spectrogram, random_sampling_audio, current_encode],
        )
        save_button.click(
            save_sample,
            inputs=[save_name_input, current_encode, encodeCache],
            outputs=[save_name_output, current_encode, encodeCache],
        )
# Assemble the web UI: a header, two state holders and one tab per module.
with gr.Blocks() as demo:
    gr.Markdown("WebUI for [DL for sound generation]. webUI version:1.0. Model version: [UNDERFIT_torch_15_5_2023].")

    # Placeholder for the most recently sampled encoding; all ones until the
    # user draws a sample in the "Random sampling" tab.
    current_encode = gr.State(value=np.ones((1, 24)))
    # Name -> encoding cache shared by all tabs, seeded with the shipped examples.
    initial_examples = gr.State(value=INIT_ENCODE_CACHE)
    # initial_interpolation_examples = gr.State(value={"init": np.random.random(SPECTROGRAM_RESOLUTION)})

    # Tabs appear in the order in which they are created here.
    get_example_module(initial_examples)
    # get_reconstruction_module()  # placeholder tab, currently disabled
    get_random_sampling_module(initial_examples, current_encode)
    get_interpolation_module(initial_examples)

demo.launch()