File size: 10,443 Bytes
b88cc47
 
 
 
 
 
 
 
 
 
82e5c79
 
 
 
b88cc47
b74de4e
 
b88cc47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82e5c79
b88cc47
 
 
d5db813
b88cc47
 
d5db813
b88cc47
 
 
d5db813
b88cc47
 
 
d5db813
b88cc47
d5db813
b88cc47
d5db813
b88cc47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82e5c79
b88cc47
 
 
 
 
 
 
d5db813
b88cc47
 
d5db813
b88cc47
 
be81083
b88cc47
 
 
 
d5db813
b88cc47
d5db813
b88cc47
 
 
 
 
 
 
 
 
 
 
 
d5db813
b88cc47
 
 
d5db813
b88cc47
 
 
d5db813
b88cc47
d5db813
 
 
 
82e5c79
b88cc47
 
d5db813
 
 
 
 
 
 
 
 
 
b88cc47
 
d5db813
b88cc47
 
 
d5db813
 
b88cc47
 
 
 
d5db813
b88cc47
d5db813
 
 
 
 
 
b88cc47
 
 
edd98f3
d5db813
b88cc47
 
 
 
d5db813
b88cc47
 
58b268b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import gradio as gr
import os
import json
import numpy as np
import torch
import librosa
from tools import VAE_out_put_to_spc, np_power_to_db
from model.VAE_torchV import Encoder, Decoder

# Shape of the random placeholder image returned by the reconstruction tab.
SPECTROGRAM_RESOLUTION = (512, 256, 3)
# Models are moved to this device and checkpoints are remapped onto it.
device = "cpu"

# 24 is presumably the latent dimension: latent vectors elsewhere in this file
# are reshaped to (-1, 24) before being handed to the decoder.
encoder = Encoder((1, 512, 256), 24, N2=0, channel_sizes=[64, 64, 64, 96, 96, 128, 160, 216]).to(device)
decoder = Decoder(24, N2=0, N3=8, channel_sizes=[64, 64, 64, 96, 96, 128, 160, 216]).to(device)
# Checkpoint prefix: weights are read from models/<model_name>_{encoder,decoder}_CA.pt.
model_name = "test"
encoder.load_state_dict(torch.load(f"models/{model_name}_encoder_CA.pt", map_location=torch.device(device)))
decoder.load_state_dict(torch.load(f"models/{model_name}_decoder_CA.pt", map_location=torch.device(device)))
# NOTE(review): neither model is switched to eval() mode. If the networks
# contain dropout or batch-norm layers, inference results depend on batch
# composition -- confirm against model/VAE_torchV.py before changing.

# Named latent vectors offered in the UI dropdowns: one random "init" entry
# plus the predefined examples loaded from JSON (name -> array-like values).
INIT_ENCODE_CACHE = {"init": np.random.random((24, ))}
with open('webUI/initial_example_encodes.json', 'r', encoding='utf-8') as f:
    list_dict = json.load(f)
for k, v in list_dict.items():
    INIT_ENCODE_CACHE[k] = np.array(v)

#################################
def prepare_image(image):
    """Map a dB-scaled array onto 8-bit pixel intensities.

    The range [-80, 0] is mapped linearly onto [0, 255]. Values outside that
    range are clamped to the nearest bound so the ``uint8`` cast cannot wrap
    around (e.g. +10 would otherwise end up as a dark pixel).

    Args:
        image: NumPy array of values nominally in [-80, 0].

    Returns:
        ``uint8`` NumPy array with the same shape as ``image``.
    """
    normalized = (image + 80.0) / 80.0
    # Clamp before casting: out-of-range floats do not saturate when
    # converted to uint8, they wrap (or are undefined for large values).
    normalized = np.clip(normalized, 0.0, 1.0)
    return (255.0 * normalized).astype(np.uint8)

def encodeBatch2GradioOutput(latent_vector_batch, resolution=(512, 256)):
    """Decode a batch of latent vectors into spectrogram images and audio.

    Each decoder output is converted with ``VAE_out_put_to_spc``, reshaped to
    ``resolution``, rendered as an RGB image and inverted to a waveform with
    Griffin-Lim.

    Args:
        latent_vector_batch: Tensor passed as-is to the module-level
            ``decoder``; the first axis of the decoder output is iterated as
            the batch axis.
        resolution: ``(n_freq_bins, n_frames)`` of one spectrogram. All array
            sizes below are derived from it.

    Returns:
        Tuple ``(images, sample_rate, signals)``: a list of ``uint8`` arrays
        of shape ``(n_freq_bins, n_frames, 3)``, the constant sample rate
        16000, and a list of reconstructed waveforms (one per batch entry).
    """
    n_freq, n_frames = resolution

    # Inference only: no autograd graph is needed for the decoder pass.
    with torch.no_grad():
        reconstruction_batch = decoder(latent_vector_batch).to("cpu").detach().numpy()

    flipped_log_spectrums, rec_signals = [], []
    for reconstruction in reconstruction_batch:
        spc = np.reshape(VAE_out_put_to_spc(reconstruction), resolution)
        log_spectrum = np_power_to_db(np.abs(spc))
        # Flip vertically so row 0 of the image is the last spectrogram row.
        flipped = np.flipud(log_spectrum)

        # Red and green carry the spectrogram; blue is a constant -60 tint.
        colorful_spc = np.empty((n_freq, n_frames, 3))
        colorful_spc[:, :, 0] = flipped
        colorful_spc[:, :, 1] = flipped
        colorful_spc[:, :, 2] = -60.0
        image = prepare_image(colorful_spc)

        # Griffin-Lim expects 1 + n_fft/2 rows, so one extra all-zero row is
        # appended and n_fft is inferred as 2 * n_freq.
        # NOTE(review): sqrt is taken of `spc`, not of its absolute value;
        # negative entries would become NaN -- confirm spc is non-negative.
        abs_spec = np.zeros((n_freq + 1, n_frames))
        abs_spec[:n_freq, :] = np.sqrt(spc)
        rec_signal = librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=2 * n_freq)

        flipped_log_spectrums.append(image)
        rec_signals.append(rec_signal)

    return flipped_log_spectrums, 16000, rec_signals


def get_example_module(encodeCache):
    """Build the "Examples" tab and wire its button.

    Args:
        encodeCache: Gradio component used as both input and output of the
            callback; its value is indexed by example name to obtain a
            latent vector.
    """

    def show_example(selected_example, encodeCache):
        """Decode the chosen example into a spectrogram image and audio.

        Returns:
            ``(image, (sample_rate, waveform), encodeCache)``.

        Raises:
            gr.Error: If no known example is selected.
        """
        # Without this guard an untouched dropdown ends in a bare KeyError.
        if selected_example not in encodeCache:
            raise gr.Error("Please choose an example first.")

        example_encode = torch.Tensor(np.reshape(encodeCache[selected_example], (-1, 24))).to(device)
        images, sample_rate, signals = encodeBatch2GradioOutput(example_encode)
        return images[0], (sample_rate, signals[0]), encodeCache

    with gr.Tab("Examples"):
        gr.Markdown("Some predefined examples.")
        with gr.Row():
            with gr.Column():
                selected_example = gr.Dropdown(
                    list(INIT_ENCODE_CACHE.keys()), label="Examples", info="Choose one example! More samples coming."
                )
                example_button = gr.Button(value="Show example")
            with gr.Column():
                example_image_output = gr.Image(label="Spectrogram", type="numpy")
                example_image_output.style(height=250, width=600)
                example_audio_output = gr.Audio(type="numpy", label="Play the example!")
    example_button.click(show_example, inputs=[selected_example, encodeCache],
                         outputs=[example_image_output, example_audio_output, encodeCache])

def get_reconstruction_module():
    """Build the "Reconstruction" tab.

    The tab is a placeholder: its button callback ignores the input and
    returns a random array of shape ``SPECTROGRAM_RESOLUTION``.
    """

    def do_nothing(image_input):
        """Return random noise shaped like a spectrogram image."""
        noise = np.random.random(SPECTROGRAM_RESOLUTION)
        return noise

    with gr.Tab("Reconstruction"):
        gr.Markdown("Test reconstruction.")
        with gr.Row():
            with gr.Column():
                batch_index_box = gr.Number(label="Batch_index")
                generate_button = gr.Button(value="Generate")
            with gr.Column():
                output_image = gr.Image(label="Reconstruction", type="numpy")
                output_image.style(height=250, width=600)

    generate_button.click(
        do_nothing,
        inputs=batch_index_box,
        outputs=output_image,
    )

def get_interpolation_module(encodeCache):
    """Build the "Interpolation" tab and wire its callbacks.

    Args:
        encodeCache: Gradio component used as input and output of the
            callbacks; its value is a dict mapping names to latent vectors.
    """

    def interpolate(first_interpulation_input, second_interpulation_input, interpulation_input_ratio, encodeCache):
        """Decode two cached latent vectors and their linear blend.

        The blend is ``first * ratio + second * (1 - ratio)``.

        Returns:
            Image and ``(sample_rate, waveform)`` for the first input, the
            second input and the blend (in that order), followed by
            ``encodeCache``.

        Raises:
            gr.Error: If either input is not a key of ``encodeCache``.
        """
        # Without this guard an unselected dropdown ends in a bare KeyError.
        if first_interpulation_input not in encodeCache or second_interpulation_input not in encodeCache:
            raise gr.Error("Please choose both inputs first.")

        first_encode = torch.Tensor(np.reshape(encodeCache[first_interpulation_input], (-1, 24)))
        second_encode = torch.Tensor(np.reshape(encodeCache[second_interpulation_input], (-1, 24)))
        ratio = torch.Tensor([interpulation_input_ratio])
        blended_encode = first_encode * ratio + second_encode * (1 - ratio)

        # Concatenate along axis 0 so the decoder gets a 2-D (rows, 24) tensor,
        # the same layout the other tabs pass, instead of the 3-D tensor that
        # stacking would create.
        encode_batch = torch.cat((first_encode, second_encode, blended_encode), dim=0).to(device)
        images, sample_rate, signals = encodeBatch2GradioOutput(encode_batch)
        return (
            images[0], (sample_rate, signals[0]),
            images[1], (sample_rate, signals[1]),
            images[2], (sample_rate, signals[2]),
            encodeCache,
        )

    def refresh_interpolation_input(encodeCache):
        """Reload both dropdowns with the names currently in ``encodeCache``."""
        names = list(encodeCache.keys())
        return gr.Dropdown.update(choices=names), gr.Dropdown.update(choices=list(names)), encodeCache

    with gr.Tab("Interpolation"):
        gr.Markdown("Test Interpolation. Sounds that you sampled can be added to the input dropdown by clicking [Refresh].")
        with gr.Row():
            with gr.Column():
                interpulation_refresh_button = gr.Button(value="Refresh")
                with gr.Row():
                    first_interpulation_input = gr.Dropdown(list(INIT_ENCODE_CACHE.keys()), label="First input")
                    second_interpulation_input = gr.Dropdown(list(INIT_ENCODE_CACHE.keys()), label="Second input")
                with gr.Row():
                    first_input_audio = gr.Audio(type="numpy", label="Play first input")
                    first_input_audio.style(length=125)
                    # Label fixed: it previously read "sePlay second input".
                    second_input_audio = gr.Audio(type="numpy", label="Play second input")
                    second_input_audio.style(length=125)

                interpulation_input_ratio = gr.Slider(minimum=-0.20, maximum=1.20, value=0.5, step=0.01, label="Ratio of the first input.")
                # Label fixed: it previously read "Interpulate".
                interpulation_button = gr.Button(value="Interpolate")
            with gr.Column():
                with gr.Row():
                    first_input_spectrogram = gr.Image(label="First Input", type="numpy")
                    first_input_spectrogram.style(height=250, width=125)
                    interpolation_spectrogram = gr.Image(label="Interpolation", type="numpy")
                    interpolation_spectrogram.style(height=250, width=125)
                    second_input_spectrogram = gr.Image(label="Second Input", type="numpy")
                    second_input_spectrogram.style(height=250, width=125)
                interpolation_audio = gr.Audio(type="numpy", label="Play interpolation")
                interpolation_audio.style(length=125)

    interpulation_refresh_button.click(refresh_interpolation_input, inputs=[encodeCache],
                                       outputs=[first_interpulation_input, second_interpulation_input, encodeCache])
    interpulation_button.click(interpolate, inputs=[first_interpulation_input, second_interpulation_input, interpulation_input_ratio, encodeCache],
                               outputs=[first_input_spectrogram, first_input_audio, second_input_spectrogram, second_input_audio, interpolation_spectrogram, interpolation_audio, encodeCache])

def get_random_sampling_module(encodeCache, current_encode):
    """Build the "Random sampling" tab and wire its callbacks.

    Args:
        encodeCache: Gradio component whose value is the dict of named latent
            vectors; saved samples are added to it.
        current_encode: Gradio component holding the most recently sampled
            latent vector.
    """

    def random_sample(sigma, current_encode):
        """Draw ``sigma * N(0, 1)`` as a (1, 24) latent vector and decode it.

        Returns:
            ``(image, (sample_rate, waveform), new_current_encode)``.

        Raises:
            gr.Error: If ``sigma`` is empty.
        """
        # An emptied number field would otherwise crash torch.Tensor([None]).
        if sigma is None:
            raise gr.Error("Please enter a value for sigma.")

        random_encode = torch.Tensor([sigma]) * torch.randn(1, 24)
        # random_encode = torch.Tensor([mu]) + torch.Tensor([sigma]) * torch.randn(1, 24)
        # Independent float64 copy of the sample; this is what gets stored.
        current_encode = random_encode.detach().numpy().astype(np.float64)
        images, sample_rate, signals = encodeBatch2GradioOutput(random_encode.to(device))
        return images[0], (sample_rate, signals[0]), current_encode

    def save_sample(save_name, current_encode, encodeCache):
        """Store ``current_encode`` in ``encodeCache`` under ``save_name``.

        Returns:
            ``(status_message, current_encode, encodeCache)``.
        """
        # The state starts out as an all-ones array; seeing exactly that
        # placeholder means nothing has been sampled yet.
        if np.array_equal(current_encode, np.ones_like(current_encode)):
            return "Please generate one sample.", current_encode, encodeCache
        if not save_name:
            return "The save name is empty.", current_encode, encodeCache
        encodeCache[save_name] = current_encode
        return "Sample saved.", current_encode, encodeCache

    with gr.Tab("Random sampling"):
        gr.Markdown("Sample new sound! Feel free to name and save your samples!")
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    # mu = gr.Number(label="mu")
                    sigma = gr.Number(value=1.0, label="sigma")
                random_sampling_button = gr.Button(value="Sample")
            with gr.Column():
                random_sampling_spectrogram = gr.Image(label="Random sampling", type="numpy")
                random_sampling_spectrogram.style(height=250, width=600)
                random_sampling_audio = gr.Audio(type="numpy", label="Play the sample")
                random_sampling_audio.style(length=125)
                save_name_input = gr.Textbox(label="Name your sound")
                save_button = gr.Button(value="save")
                save_name_output = gr.Textbox(label="Save it for interpolation")

    random_sampling_button.click(random_sample, inputs=[sigma, current_encode], outputs=[random_sampling_spectrogram, random_sampling_audio, current_encode])
    save_button.click(save_sample, inputs=[save_name_input, current_encode, encodeCache], outputs=[save_name_output, current_encode, encodeCache])


# Assemble the UI: a header line followed by one tab per module. The gr.State
# objects are created here and handed to the tab builders, which use them as
# inputs/outputs of their event handlers.
with gr.Blocks() as demo:
  gr.Markdown("WebUI for [DL for sound generation]. webUI version:1.0. Model version: [UNDERFIT_torch_15_5_2023].")
  # Most recently sampled latent vector; starts as all ones, which
  # save_sample reports as "Please generate one sample."
  current_encode = gr.State(value=np.ones((1, 24)))
  # Name -> latent vector dict shared by the examples, sampling and
  # interpolation tabs.
  initial_examples = gr.State(value=INIT_ENCODE_CACHE)
  # initial_interpolation_examples = gr.State(value={"init": np.random.random(SPECTROGRAM_RESOLUTION)})
  get_example_module(initial_examples)
  # The reconstruction tab is a placeholder and is currently disabled.
  # get_reconstruction_module()
  get_random_sampling_module(initial_examples, current_encode)
  get_interpolation_module(initial_examples)

# NOTE(review): launch() runs whenever this module is executed or imported;
# there is no `if __name__ == "__main__":` guard.
demo.launch()