Spaces:
Paused
Paused
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from model.DiffSynthSampler import DiffSynthSampler | |
| from tools import safe_int | |
| from webUI.natural_language_guided_STFT.utils import encodeBatch2GradioOutput, latent_representation_to_Gradio_image | |
def get_interpolation_with_condition_module(gradioWebUI, interpolation_with_text_state):
    """Build the "InterpolationCond." tab and wire up its event handlers.

    The tab samples a batch of latents whose text condition is linearly
    interpolated between two positive prompts. It decodes them to
    spectrogram images and audio, then lets the user browse the batch with a
    slider.

    This function creates a ``gr.Tab``, so it is expected to run inside an
    enclosing ``gr.Blocks`` context.

    Args:
        gradioWebUI: holder of the loaded models and settings (uNet, VAE
            quantizer/decoder, CLAP model and tokenizer, resolutions, device,
            ...). It also provides the factory methods for the shared
            sampling-control widgets.
        interpolation_with_text_state: Gradio state component (presumably a
            ``gr.State`` holding a dict). It caches the most recent batch so
            ``show_random_sample`` can show any sample without re-sampling.
    """
    # Load configurations
    uNet = gradioWebUI.uNet
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    height = int(freq_resolution / VAE_scale)
    channels = gradioWebUI.channels
    timesteps = gradioWebUI.timesteps
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

    def _encode_prompt(prompt):
        """Return the CLAP text embedding of a single prompt, moved to ``device``."""
        # Inference only: the text encoder needs no autograd graph.
        with torch.no_grad():
            tokens = CLAP_tokenizer([prompt], padding=True, return_tensors="pt")
            # The prompt is encoded as a batch of one; [0] takes its embedding.
            return CLAP.get_text_features(**tokens)[0].to(device)

    def diffusion_random_sample(text2sound_prompts_1, text2sound_prompts_2, text2sound_negative_prompts,
                                text2sound_batchsize, text2sound_duration, text2sound_guidance_scale,
                                text2sound_sampler, text2sound_sample_steps, text2sound_seed,
                                interpolation_with_text_dict):
        """Sample a batch that sweeps from prompt 1 to prompt 2 and cache the results.

        Sample ``i`` is conditioned on ``w1_i * emb(prompt_1) + w2_i * emb(prompt_2)``.
        ``w1`` runs linearly from 1 to 0 and ``w2`` from 0 to 1 across the
        batch. A batch of one therefore uses prompt 1 only.

        The latents and every per-sample image/audio are stored into
        ``interpolation_with_text_dict`` (mutated in place).

        Returns:
            A dict mapping output components to their new values. It shows
            sample 0 and reveals the sample-index slider.
        """
        text2sound_sample_steps = int(text2sound_sample_steps)
        # Presumably falls back to 12345678 when the seed text is not an integer.
        text2sound_seed = safe_int(text2sound_seed, 12345678)
        # Todo: take care of text2sound_time_resolution/width
        # The latent width grows with the requested duration.
        width = int(time_resolution * ((text2sound_duration + 1) / 4) / VAE_scale)
        text2sound_batchsize = int(text2sound_batchsize)

        text2sound_embedding_1 = _encode_prompt(text2sound_prompts_1)
        text2sound_embedding_2 = _encode_prompt(text2sound_prompts_2)

        # NOTE(review): int() truncates a fractional guidance scale; confirm
        # that is intended for the guidance-scale slider's step size.
        CFG = int(text2sound_guidance_scale)
        mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
        unconditional_condition = _encode_prompt(text2sound_negative_prompts)
        mySampler.activate_classifier_free_guidance(CFG, unconditional_condition)
        # Respace the full schedule down to the requested number of sampling steps.
        mySampler.respace(list(np.linspace(0, timesteps - 1, text2sound_sample_steps, dtype=np.int32)))

        # Interpolation weights of shape (batch, 1). They broadcast against each
        # (presumably 1-D) prompt embedding to give one condition row per sample.
        weight_1 = torch.linspace(1, 0, steps=text2sound_batchsize).unsqueeze(1).to(device)
        weight_2 = torch.linspace(0, 1, steps=text2sound_batchsize).unsqueeze(1).to(device)
        condition = weight_1 * text2sound_embedding_1 + weight_2 * text2sound_embedding_2

        # Todo: move this code
        # Seed first, so the initial noise is reproducible from the seed shown in the UI.
        torch.manual_seed(text2sound_seed)
        initial_noise = torch.randn(text2sound_batchsize, channels, height, width).to(device)
        latent_representations, initial_noise = mySampler.sample(
            model=uNet, shape=(text2sound_batchsize, channels, height, width), seed=text2sound_seed,
            return_tensor=True, condition=condition, sampler=text2sound_sampler, initial_noise=initial_noise)
        # The sampler appears to return intermediate steps; the last entry is the final batch.
        latent_representations = latent_representations[-1]
        interpolation_with_text_dict["latent_representations"] = latent_representations

        # The quantizer's loss and auxiliary outputs are not needed here.
        quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations)
        # Todo: remove hard-coding
        flipped_log_spectrums, rec_signals = encodeBatch2GradioOutput(VAE_decoder, quantized_latent_representations,
                                                                      resolution=(512, width * VAE_scale),
                                                                      centralized=False, squared=squared)

        sample_indices = range(text2sound_batchsize)
        latent_representation_gradio_images = [
            latent_representation_to_Gradio_image(latent_representations[i]) for i in sample_indices]
        quantized_latent_representation_gradio_images = [
            latent_representation_to_Gradio_image(quantized_latent_representations[i]) for i in sample_indices]
        new_sound_spectrogram_gradio_images = [flipped_log_spectrums[i] for i in sample_indices]
        # gr.Audio(type="numpy") expects (sample_rate, signal) tuples.
        new_sound_rec_signals_gradio = [(sample_rate, rec_signals[i]) for i in sample_indices]

        # Place the spectrograms side by side (axis=1, presumably the image width)
        # so the whole interpolation sweep is visible in one image.
        concatenated_spectrogram_gradio_image = np.concatenate(new_sound_spectrogram_gradio_images, axis=1)

        interpolation_with_text_dict["latent_representation_gradio_images"] = latent_representation_gradio_images
        interpolation_with_text_dict["quantized_latent_representation_gradio_images"] = \
            quantized_latent_representation_gradio_images
        interpolation_with_text_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
        interpolation_with_text_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio

        return {text2sound_latent_representation_image:
                    interpolation_with_text_dict["latent_representation_gradio_images"][0],
                text2sound_quantized_latent_representation_image:
                    interpolation_with_text_dict["quantized_latent_representation_gradio_images"][0],
                text2sound_sampled_concatenated_spectrogram_image: concatenated_spectrogram_gradio_image,
                text2sound_sampled_spectrogram_image: interpolation_with_text_dict["new_sound_spectrogram_gradio_images"][0],
                text2sound_sampled_audio: interpolation_with_text_dict["new_sound_rec_signals_gradio"][0],
                text2sound_seed_textbox: text2sound_seed,
                interpolation_with_text_state: interpolation_with_text_dict,
                # Resize the slider to the new batch and make it visible.
                text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1,
                                                          visible=True,
                                                          label="Sample index.",
                                                          info="Swipe to view other samples")}

    def show_random_sample(sample_index, text2sound_dict):
        """Show sample ``sample_index`` from the batch cached by ``diffusion_random_sample``."""
        sample_index = int(sample_index)
        return {text2sound_latent_representation_image:
                    text2sound_dict["latent_representation_gradio_images"][sample_index],
                text2sound_quantized_latent_representation_image:
                    text2sound_dict["quantized_latent_representation_gradio_images"][sample_index],
                text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][sample_index],
                text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]}

    # The handlers above refer to the components below through closures; they
    # are bound by the time any event fires.
    with gr.Tab("InterpolationCond."):
        gr.Markdown("Use interpolation to generate a gradient sound sequence.")
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):
                text2sound_prompts_1_textbox = gr.Textbox(label="Positive prompt 1", lines=2, value="organ")
                text2sound_prompts_2_textbox = gr.Textbox(label="Positive prompt 2", lines=2, value="string")
                text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")
            with gr.Column(scale=1):
                text2sound_sampling_button = gr.Button(variant="primary",
                                                       value="Generate a batch of samples and show "
                                                             "the first one",
                                                       scale=1)
                # Hidden until the first batch is generated (diffusion_random_sample reveals it).
                text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
                                                           label="Sample index",
                                                           info="Swipe to view other samples")
        with gr.Row(variant="panel"):
            with gr.Column(scale=1, variant="panel"):
                text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
                text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
                text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider(cpu_batchsize=3)
                text2sound_duration_slider = gradioWebUI.get_duration_slider()
                text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
                text2sound_seed_textbox = gradioWebUI.get_seed_textbox()
            with gr.Column(scale=1):
                with gr.Row(variant="panel"):
                    text2sound_sampled_concatenated_spectrogram_image = gr.Image(label="Interpolations", type="numpy",
                                                                                 height=420, scale=8)
                    text2sound_sampled_spectrogram_image = gr.Image(label="Selected spectrogram", type="numpy",
                                                                    height=420, scale=1)
                text2sound_sampled_audio = gr.Audio(type="numpy", label="Play")
                with gr.Row(variant="panel"):
                    text2sound_latent_representation_image = gr.Image(label="Sampled latent representation",
                                                                      type="numpy", height=200, width=100)
                    text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
                                                                                type="numpy", height=200, width=100)

        text2sound_sampling_button.click(diffusion_random_sample,
                                         inputs=[text2sound_prompts_1_textbox,
                                                 text2sound_prompts_2_textbox,
                                                 text2sound_negative_prompts_textbox,
                                                 text2sound_batchsize_slider,
                                                 text2sound_duration_slider,
                                                 text2sound_guidance_scale_slider, text2sound_sampler_radio,
                                                 text2sound_sample_steps_slider,
                                                 text2sound_seed_textbox,
                                                 interpolation_with_text_state],
                                         outputs=[text2sound_latent_representation_image,
                                                  text2sound_quantized_latent_representation_image,
                                                  text2sound_sampled_concatenated_spectrogram_image,
                                                  text2sound_sampled_spectrogram_image,
                                                  text2sound_sampled_audio,
                                                  text2sound_seed_textbox,
                                                  interpolation_with_text_state,
                                                  text2sound_sample_index_slider])
        text2sound_sample_index_slider.change(show_random_sample,
                                              inputs=[text2sound_sample_index_slider, interpolation_with_text_state],
                                              outputs=[text2sound_latent_representation_image,
                                                       text2sound_quantized_latent_representation_image,
                                                       text2sound_sampled_spectrogram_image,
                                                       text2sound_sampled_audio])