import gradio as gr
import numpy as np
import torch

from model.DiffSynthSampler import DiffSynthSampler
from tools import safe_int
from webUI.natural_language_guided_STFT.utils import encodeBatch2GradioOutput, latent_representation_to_Gradio_image
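
# Builds the "InterpolationCond." tab: a batch of sounds is sampled while the CLAP text
# condition is linearly interpolated between two prompts, so the batch forms a gradient
# from prompt 1 to prompt 2.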
def get_interpolation_with_condition_module(gradioWebUI, interpolation_with_text_state):
    # Load configurations
    uNet = gradioWebUI.uNet
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels
    timesteps = gradioWebUI.timesteps
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

    def diffusion_random_sample(text2sound_prompts_1, text2sound_prompts_2, text2sound_negative_prompts,
                                text2sound_batchsize, text2sound_duration,
                                text2sound_guidance_scale, text2sound_sampler,
                                text2sound_sample_steps, text2sound_seed,
                                interpolation_with_text_dict):
        """Sample a batch of sounds whose condition interpolates from prompt 1 to prompt 2."""
        text2sound_sample_steps = int(text2sound_sample_steps)
        text2sound_seed = safe_int(text2sound_seed, 12345678)
        # Todo: take care of text2sound_time_resolution/width
        width = int(time_resolution * ((text2sound_duration + 1) / 4) / VAE_scale)
        text2sound_batchsize = int(text2sound_batchsize)

        # Encode both positive prompts and the negative prompt with CLAP
        text2sound_embedding_1 = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts_1], padding=True, return_tensors="pt"))[0].to(device)
        text2sound_embedding_2 = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts_2], padding=True, return_tensors="pt"))[0].to(device)

        CFG = int(text2sound_guidance_scale)

        mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
        unconditional_condition = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[0]
        mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device))
        mySampler.respace(list(np.linspace(0, timesteps - 1, text2sound_sample_steps, dtype=np.int32)))

        # Linearly interpolate between the two text embeddings across the batch
        condition = torch.linspace(1, 0, steps=text2sound_batchsize).unsqueeze(1).to(device) * text2sound_embedding_1 + \
                    torch.linspace(0, 1, steps=text2sound_batchsize).unsqueeze(1).to(device) * text2sound_embedding_2

        # Todo: move this code
        torch.manual_seed(text2sound_seed)
        initial_noise = torch.randn(text2sound_batchsize, channels, height, width).to(device)

        latent_representations, initial_noise = \
            mySampler.sample(model=uNet, shape=(text2sound_batchsize, channels, height, width), seed=text2sound_seed,
                             return_tensor=True, condition=condition, sampler=text2sound_sampler,
                             initial_noise=initial_noise)

        # Keep only the final denoising step for decoding
        latent_representations = latent_representations[-1]
        interpolation_with_text_dict["latent_representations"] = latent_representations

        latent_representation_gradio_images = []
        quantized_latent_representation_gradio_images = []
        new_sound_spectrogram_gradio_images = []
        new_sound_rec_signals_gradio = []

        # Quantize the latents, then decode them to spectrograms and waveforms for display
        quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations)
        # Todo: remove hard-coding
        flipped_log_spectrums, rec_signals = encodeBatch2GradioOutput(VAE_decoder, quantized_latent_representations,
                                                                      resolution=(512, width * VAE_scale),
                                                                      centralized=False, squared=squared)

        for i in range(text2sound_batchsize):
            latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i]))
            quantized_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(quantized_latent_representations[i]))
            new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i])
            new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i]))
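
        # Concatenate all spectrograms horizontally so the whole interpolation can be viewed at once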
        def concatenate_arrays(arrays_list):
            return np.concatenate(arrays_list, axis=1)

        concatenated_spectrogram_gradio_image = concatenate_arrays(new_sound_spectrogram_gradio_images)

        interpolation_with_text_dict["latent_representation_gradio_images"] = latent_representation_gradio_images
        interpolation_with_text_dict["quantized_latent_representation_gradio_images"] = \
            quantized_latent_representation_gradio_images
        interpolation_with_text_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
        interpolation_with_text_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio

        return {text2sound_latent_representation_image: interpolation_with_text_dict["latent_representation_gradio_images"][0],
                text2sound_quantized_latent_representation_image:
                    interpolation_with_text_dict["quantized_latent_representation_gradio_images"][0],
                text2sound_sampled_concatenated_spectrogram_image: concatenated_spectrogram_gradio_image,
                text2sound_sampled_spectrogram_image: interpolation_with_text_dict["new_sound_spectrogram_gradio_images"][0],
                text2sound_sampled_audio: interpolation_with_text_dict["new_sound_rec_signals_gradio"][0],
                text2sound_seed_textbox: text2sound_seed,
                interpolation_with_text_state: interpolation_with_text_dict,
                text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0,
                                                          step=1, visible=True, label="Sample index",
                                                          info="Swipe to view other samples")}

    def show_random_sample(sample_index, text2sound_dict):
        """Display the sample at `sample_index` from the cached batch."""
        sample_index = int(sample_index)
        return {text2sound_latent_representation_image:
                    text2sound_dict["latent_representation_gradio_images"][sample_index],
                text2sound_quantized_latent_representation_image:
                    text2sound_dict["quantized_latent_representation_gradio_images"][sample_index],
                text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][sample_index],
                text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]}
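
    # UI layout for the interpolation tab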
with gr.Tab("InterpolationCond."):
gr.Markdown("Use interpolation to generate a gradient sound sequence.")
with gr.Row(variant="panel"):
with gr.Column(scale=3):
text2sound_prompts_1_textbox = gr.Textbox(label="Positive prompt 1", lines=2, value="organ")
text2sound_prompts_2_textbox = gr.Textbox(label="Positive prompt 2", lines=2, value="string")
text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")
with gr.Column(scale=1):
text2sound_sampling_button = gr.Button(variant="primary",
value="Generate a batch of samples and show "
"the first one",
scale=1)
text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
label="Sample index",
info="Swipe to view other samples")
with gr.Row(variant="panel"):
with gr.Column(scale=1, variant="panel"):
text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider(cpu_batchsize=3)
text2sound_duration_slider = gradioWebUI.get_duration_slider()
text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
text2sound_seed_textbox = gradioWebUI.get_seed_textbox()
with gr.Column(scale=1):
with gr.Row(variant="panel"):
text2sound_sampled_concatenated_spectrogram_image = gr.Image(label="Interpolations", type="numpy",
height=420, scale=8)
text2sound_sampled_spectrogram_image = gr.Image(label="Selected spectrogram", type="numpy",
height=420, scale=1)
text2sound_sampled_audio = gr.Audio(type="numpy", label="Play")
with gr.Row(variant="panel"):
text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy",
height=200, width=100)
text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
type="numpy", height=200, width=100)
    text2sound_sampling_button.click(diffusion_random_sample,
                                     inputs=[text2sound_prompts_1_textbox,
                                             text2sound_prompts_2_textbox,
                                             text2sound_negative_prompts_textbox,
                                             text2sound_batchsize_slider,
                                             text2sound_duration_slider,
                                             text2sound_guidance_scale_slider,
                                             text2sound_sampler_radio,
                                             text2sound_sample_steps_slider,
                                             text2sound_seed_textbox,
                                             interpolation_with_text_state],
                                     outputs=[text2sound_latent_representation_image,
                                              text2sound_quantized_latent_representation_image,
                                              text2sound_sampled_concatenated_spectrogram_image,
                                              text2sound_sampled_spectrogram_image,
                                              text2sound_sampled_audio,
                                              text2sound_seed_textbox,
                                              interpolation_with_text_state,
                                              text2sound_sample_index_slider])

    text2sound_sample_index_slider.change(show_random_sample,
                                          inputs=[text2sound_sample_index_slider, interpolation_with_text_state],
                                          outputs=[text2sound_latent_representation_image,
                                                   text2sound_quantized_latent_representation_image,
                                                   text2sound_sampled_spectrogram_image,
                                                   text2sound_sampled_audio])
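
# Minimal usage sketch (an assumption about the surrounding app, not part of this module's API):
# this module is expected to be mounted inside a gr.Blocks context with a gr.State dict, e.g.
#
#     with gr.Blocks() as demo:
#         interpolation_with_text_state = gr.State(value={})
#         get_interpolation_with_condition_module(gradioWebUI, interpolation_with_text_state)
#     demo.launch()
#
# where `gradioWebUI` is the shared configuration object providing the models and UI helpers
# read at the top of get_interpolation_with_condition_module.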