WeixuanYuan committed on
Commit
2b389c5
·
verified ·
1 Parent(s): 05f9bd8

Upload 49 files

Browse files
Files changed (49) hide show
  1. model/DiffSynthSampler.py +425 -0
  2. model/GAN.py +262 -0
  3. model/VQGAN.py +684 -0
  4. model/__pycache__/DiffSynthSampler.cpython-310.pyc +0 -0
  5. model/__pycache__/GAN.cpython-310.pyc +0 -0
  6. model/__pycache__/VQGAN.cpython-310.pyc +0 -0
  7. model/__pycache__/diffusion.cpython-310.pyc +0 -0
  8. model/__pycache__/diffusion_components.cpython-310.pyc +0 -0
  9. model/__pycache__/multimodal_model.cpython-310.pyc +0 -0
  10. model/__pycache__/perceptual_label_predictor.cpython-37.pyc +0 -0
  11. model/__pycache__/timbre_encoder_pretrain.cpython-310.pyc +0 -0
  12. model/diffusion.py +371 -0
  13. model/diffusion_components.py +351 -0
  14. model/multimodal_model.py +274 -0
  15. model/timbre_encoder_pretrain.py +220 -0
  16. tools.py +344 -0
  17. webUI/__pycache__/app.cpython-310.pyc +0 -0
  18. webUI/deprecated/interpolationWithCondition.py +178 -0
  19. webUI/deprecated/interpolationWithXT.py +173 -0
  20. webUI/natural_language_guided/GAN.py +164 -0
  21. webUI/natural_language_guided/README.py +53 -0
  22. webUI/natural_language_guided/__pycache__/README.cpython-310.pyc +0 -0
  23. webUI/natural_language_guided/__pycache__/README_STFT.cpython-310.pyc +0 -0
  24. webUI/natural_language_guided/__pycache__/buildInstrument_STFT.cpython-310.pyc +0 -0
  25. webUI/natural_language_guided/__pycache__/build_instrument.cpython-310.pyc +0 -0
  26. webUI/natural_language_guided/__pycache__/gradioWebUI.cpython-310.pyc +0 -0
  27. webUI/natural_language_guided/__pycache__/gradioWebUI_STFT.cpython-310.pyc +0 -0
  28. webUI/natural_language_guided/__pycache__/gradio_webUI.cpython-310.pyc +0 -0
  29. webUI/natural_language_guided/__pycache__/inpaintWithText.cpython-310.pyc +0 -0
  30. webUI/natural_language_guided/__pycache__/inpaintWithText_STFT.cpython-310.pyc +0 -0
  31. webUI/natural_language_guided/__pycache__/inpaint_with_text.cpython-310.pyc +0 -0
  32. webUI/natural_language_guided/__pycache__/rec.cpython-310.pyc +0 -0
  33. webUI/natural_language_guided/__pycache__/recSTFT.cpython-310.pyc +0 -0
  34. webUI/natural_language_guided/__pycache__/sound2soundWithText.cpython-310.pyc +0 -0
  35. webUI/natural_language_guided/__pycache__/sound2soundWithText_STFT.cpython-310.pyc +0 -0
  36. webUI/natural_language_guided/__pycache__/sound2sound_with_text.cpython-310.pyc +0 -0
  37. webUI/natural_language_guided/__pycache__/text2sound.cpython-310.pyc +0 -0
  38. webUI/natural_language_guided/__pycache__/text2sound_STFT.cpython-310.pyc +0 -0
  39. webUI/natural_language_guided/__pycache__/track_maker.cpython-310.pyc +0 -0
  40. webUI/natural_language_guided/__pycache__/utils.cpython-310.pyc +0 -0
  41. webUI/natural_language_guided/build_instrument.py +274 -0
  42. webUI/natural_language_guided/gradio_webUI.py +68 -0
  43. webUI/natural_language_guided/inpaint_with_text.py +441 -0
  44. webUI/natural_language_guided/rec.py +190 -0
  45. webUI/natural_language_guided/sound2sound_with_text.py +416 -0
  46. webUI/natural_language_guided/super_resolution_with_text.py +387 -0
  47. webUI/natural_language_guided/text2sound.py +212 -0
  48. webUI/natural_language_guided/track_maker.py +192 -0
  49. webUI/natural_language_guided/utils.py +174 -0
model/DiffSynthSampler.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from tqdm import tqdm
4
+
5
+
6
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
7
+ """
8
+ Extract values from a 1-D numpy array for a batch of indices.
9
+
10
+ :param arr: the 1-D numpy array.
11
+ :param timesteps: a tensor of indices into the array to extract.
12
+ :param broadcast_shape: a larger shape of K dimensions with the batch
13
+ dimension equal to the length of timesteps.
14
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
15
+ """
16
+ res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
17
+ while len(res.shape) < len(broadcast_shape):
18
+ res = res[..., None]
19
+ return res.expand(broadcast_shape)
20
+
21
+
22
+ class DiffSynthSampler:
23
+
24
+ def __init__(self, timesteps, beta_start=0.0001, beta_end=0.02, device=None, mute=False,
25
+ height=128, max_batchsize=16, max_width=256, channels=4, train_width=64, noise_strategy="repeat"):
26
+ if device is None:
27
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
28
+ else:
29
+ self.device = device
30
+ self.height = height
31
+ self.train_width = train_width
32
+ self.max_batchsize = max_batchsize
33
+ self.max_width = max_width
34
+ self.channels = channels
35
+ self.num_timesteps = timesteps
36
+ self.timestep_map = list(range(self.num_timesteps))
37
+ self.betas = np.array(np.linspace(beta_start, beta_end, self.num_timesteps), dtype=np.float64)
38
+ self.respaced = False
39
+ self.define_beta_schedule()
40
+ self.CFG = 1.0
41
+ self.mute = mute
42
+ self.noise_strategy = noise_strategy
43
+
44
+ def get_deterministic_noise_tensor_non_repeat(self, batchsize, width, reference_noise=None):
45
+ if reference_noise is None:
46
+ large_noise_tensor = torch.randn((self.max_batchsize, self.channels, self.height, self.max_width), device=self.device)
47
+ else:
48
+ assert reference_noise.shape == (batchsize, self.channels, self.height, self.max_width), "reference_noise shape mismatch"
49
+ large_noise_tensor = reference_noise
50
+ return large_noise_tensor[:batchsize, :, :, :width], None
51
+
52
+ def get_deterministic_noise_tensor(self, batchsize, width, reference_noise=None):
53
+ if self.noise_strategy == "repeat":
54
+ noise, concat_points = self.get_deterministic_noise_tensor_repeat(batchsize, width, reference_noise=reference_noise)
55
+ return noise, concat_points
56
+ else:
57
+ noise, concat_points = self.get_deterministic_noise_tensor_non_repeat(batchsize, width, reference_noise=reference_noise)
58
+ return noise, concat_points
59
+
60
+
61
+ def get_deterministic_noise_tensor_repeat(self, batchsize, width, reference_noise=None):
62
+ # 生成与训练数据长度相等的噪音
63
+ if reference_noise is None:
64
+ train_noise_tensor = torch.randn((self.max_batchsize, self.channels, self.height, self.train_width), device=self.device)
65
+ else:
66
+ assert reference_noise.shape == (batchsize, self.channels, self.height, self.train_width), "reference_noise shape mismatch"
67
+ train_noise_tensor = reference_noise
68
+
69
+ release_width = int(self.train_width * 1.0 / 4)
70
+ first_part_width = self.train_width - release_width
71
+
72
+ first_part = train_noise_tensor[:batchsize, :, :, :first_part_width]
73
+ release_part = train_noise_tensor[:batchsize, :, :, -release_width:]
74
+
75
+ # 如果所需 length 小于等于 origin length,去掉 first_part 的中间部分
76
+ if width <= self.train_width:
77
+ _first_part_head_width = int((width - release_width) / 2)
78
+ _first_part_tail_width = width - release_width - _first_part_head_width
79
+ all_parts = [first_part[:, :, :, :_first_part_head_width], first_part[:, :, :, -_first_part_tail_width:], release_part]
80
+
81
+ # 沿第四维度拼接张量
82
+ noise_tensor = torch.cat(all_parts, dim=3)
83
+
84
+ # 记录拼接点的位置
85
+ concat_points = [0]
86
+ for part in all_parts[:-1]:
87
+ next_point = concat_points[-1] + part.size(3)
88
+ concat_points.append(next_point)
89
+
90
+ return noise_tensor, concat_points
91
+
92
+ # 如果所需 length 大于 origin length,不断地从中间插入 first_part 的中间部分
93
+ else:
94
+ # 计算需要重复front_width的次数
95
+ repeats = (width - release_width) // first_part_width
96
+ extra = (width - release_width) % first_part_width
97
+
98
+ _repeat_first_part_head_width = int(first_part_width / 2)
99
+ _repeat_first_part_tail_width = first_part_width - _repeat_first_part_head_width
100
+
101
+ repeated_first_head_parts = [first_part[:, :, :, :_repeat_first_part_head_width] for _ in range(repeats)]
102
+ repeated_first_tail_parts = [first_part[:, :, :, -_repeat_first_part_tail_width:] for _ in range(repeats)]
103
+
104
+ # 计算起始索引
105
+ _middle_part_start_index = (first_part_width - extra) // 2
106
+ # 切片张量以获取中间部分
107
+ middle_part = first_part[:, :, :, _middle_part_start_index: _middle_part_start_index + extra]
108
+
109
+ all_parts = repeated_first_head_parts + [middle_part] + repeated_first_tail_parts + [release_part]
110
+
111
+ # 沿第四维度拼接张量
112
+ noise_tensor = torch.cat(all_parts, dim=3)
113
+
114
+ # 记录拼接点的位置
115
+ concat_points = [0]
116
+ for part in all_parts[:-1]:
117
+ next_point = concat_points[-1] + part.size(3)
118
+ concat_points.append(next_point)
119
+
120
+ return noise_tensor, concat_points
121
+
122
+ def define_beta_schedule(self):
123
+ assert self.respaced == False, "This schedule has already been respaced!"
124
+ # define alphas
125
+ self.alphas = 1.0 - self.betas
126
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
127
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
128
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
129
+
130
+ # calculations for diffusion q(x_t | x_{t-1}) and others
131
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
132
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
133
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
134
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
135
+ self.sqrt_recip_alphas = np.sqrt(1.0 / self.alphas)
136
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
137
+
138
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
139
+ self.posterior_variance = (self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod))
140
+
141
+ def activate_classifier_free_guidance(self, CFG, unconditional_condition):
142
+ assert (
143
+ not unconditional_condition is None) or CFG == 1.0, "For CFG != 1.0, unconditional_condition must be available"
144
+ self.CFG = CFG
145
+ self.unconditional_condition = unconditional_condition
146
+
147
+ def respace(self, use_timesteps=None):
148
+ if not use_timesteps is None:
149
+ last_alpha_cumprod = 1.0
150
+ new_betas = []
151
+ self.timestep_map = []
152
+ for i, _alpha_cumprod in enumerate(self.alphas_cumprod):
153
+ if i in use_timesteps:
154
+ new_betas.append(1 - _alpha_cumprod / last_alpha_cumprod)
155
+ last_alpha_cumprod = _alpha_cumprod
156
+ self.timestep_map.append(i)
157
+ self.num_timesteps = len(use_timesteps)
158
+ self.betas = np.array(new_betas)
159
+ self.define_beta_schedule()
160
+ self.respaced = True
161
+
162
+ def generate_linear_noise(self, shape, variance=1.0, first_endpoint=None, second_endpoint=None):
163
+ assert shape[1] == self.channels, "shape[1] != self.channels"
164
+ assert shape[2] == self.height, "shape[2] != self.height"
165
+ noise = torch.empty(*shape, device=self.device)
166
+
167
+ # 第三种情况:两个端点都不是None,进行线性插值
168
+ if first_endpoint is not None and second_endpoint is not None:
169
+ for i in range(shape[0]):
170
+ alpha = i / (shape[0] - 1) # 插值系数
171
+ noise[i] = alpha * second_endpoint + (1 - alpha) * first_endpoint
172
+ return noise # 返回插值后的结果,不需要进行后续的均值和方差调整
173
+ else:
174
+ # 第一个端点不是None
175
+ if first_endpoint is not None:
176
+ noise[0] = first_endpoint
177
+ if shape[0] > 1:
178
+ noise[1], _ = self.get_deterministic_noise_tensor(1, shape[3])[0]
179
+ else:
180
+ noise[0], _ = self.get_deterministic_noise_tensor(1, shape[3])[0]
181
+ if shape[0] > 1:
182
+ noise[1], _ = self.get_deterministic_noise_tensor(1, shape[3])[0]
183
+
184
+ # 生成其他的噪声点
185
+ for i in range(2, shape[0]):
186
+ noise[i] = 2 * noise[i - 1] - noise[i - 2]
187
+
188
+ # 当只有一个端点被指定时
189
+ current_var = noise.var()
190
+ stddev_ratio = torch.sqrt(variance / current_var)
191
+ noise = noise * stddev_ratio
192
+
193
+ # 如果第一个端点被指定,进行平移调整
194
+ if first_endpoint is not None:
195
+ shift = first_endpoint - noise[0]
196
+ noise += shift
197
+
198
+ return noise
199
+
200
+ def q_sample(self, x_start, t, noise=None):
201
+ """
202
+ Diffuse the data for a given number of diffusion steps.
203
+
204
+ In other words, sample from q(x_t | x_0).
205
+
206
+ :param x_start: the initial data batch.
207
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
208
+ :param noise: if specified, the split-out normal noise.
209
+ :return: A noisy version of x_start.
210
+ """
211
+ assert x_start.shape[1] == self.channels, "shape[1] != self.channels"
212
+ assert x_start.shape[2] == self.height, "shape[2] != self.height"
213
+
214
+ if noise is None:
215
+ # noise = torch.randn_like(x_start)
216
+ noise, _ = self.get_deterministic_noise_tensor(x_start.shape[0], x_start.shape[3])
217
+
218
+ assert noise.shape == x_start.shape
219
+ return (
220
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
221
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
222
+ * noise
223
+ )
224
+
225
+ @torch.no_grad()
226
+ def ddim_sample(self, model, x, t, condition=None, ddim_eta=0.0):
227
+ map_tensor = torch.tensor(self.timestep_map, device=t.device, dtype=t.dtype)
228
+ mapped_t = map_tensor[t]
229
+
230
+ # Todo: add CFG
231
+
232
+ if self.CFG == 1.0:
233
+ pred_noise = model(x, mapped_t, condition)
234
+ else:
235
+ unconditional_condition = self.unconditional_condition.unsqueeze(0).repeat(
236
+ *([x.shape[0]] + [1] * len(self.unconditional_condition.shape)))
237
+ x_in = torch.cat([x] * 2)
238
+ t_in = torch.cat([mapped_t] * 2)
239
+ c_in = torch.cat([unconditional_condition, condition])
240
+ noise_uncond, noise = model(x_in, t_in, c_in).chunk(2)
241
+ pred_noise = noise_uncond + self.CFG * (noise - noise_uncond)
242
+
243
+ # Todo: END
244
+
245
+ alpha_cumprod_t = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
246
+ alpha_cumprod_t_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
247
+
248
+ pred_x0 = (x - torch.sqrt((1. - alpha_cumprod_t)) * pred_noise) / torch.sqrt(alpha_cumprod_t)
249
+
250
+ sigmas_t = (
251
+ ddim_eta
252
+ * torch.sqrt((1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t))
253
+ * torch.sqrt(1 - alpha_cumprod_t / alpha_cumprod_t_prev)
254
+ )
255
+
256
+ pred_dir_xt = torch.sqrt(1 - alpha_cumprod_t_prev - sigmas_t ** 2) * pred_noise
257
+
258
+
259
+ step_noise, _ = self.get_deterministic_noise_tensor(x.shape[0], x.shape[3])
260
+
261
+
262
+ x_prev = torch.sqrt(alpha_cumprod_t_prev) * pred_x0 + pred_dir_xt + sigmas_t * step_noise
263
+
264
+ return x_prev
265
+
266
+ def p_sample(self, model, x, t, condition=None, sampler="ddim"):
267
+ if sampler == "ddim":
268
+ return self.ddim_sample(model, x, t, condition=condition, ddim_eta=0.0)
269
+ elif sampler == "ddpm":
270
+ return self.ddim_sample(model, x, t, condition=condition, ddim_eta=1.0)
271
+ else:
272
+ raise NotImplementedError()
273
+
274
+ def get_dynamic_masks(self, n_masks, shape, concat_points, mask_flexivity=0.8):
275
+ release_length = int(self.train_width / 4)
276
+ assert shape[3] == (concat_points[-1] + release_length), "shape[3] != (concat_points[-1] + release_length)"
277
+
278
+ fraction_lengths = [concat_points[i + 1] - concat_points[i] for i in range(len(concat_points) - 1)]
279
+
280
+ # Todo: remove hard-coding
281
+ n_guidance_steps = int(n_masks * mask_flexivity)
282
+ n_free_steps = n_masks - n_guidance_steps
283
+
284
+ masks = []
285
+ # Todo: 在一半的 steps 内收缩 mask。也就是说,在后程对 release 以外的区域不做inpaint,而是 img2img
286
+ for i in range(n_guidance_steps):
287
+ # mask = 1, freeze
288
+ step_i_mask = torch.zeros((shape[0], 1, shape[2], shape[3]), dtype=torch.float32).to(self.device)
289
+ step_i_mask[:, :, :, -release_length:] = 1.0
290
+
291
+ for fraction_index in range(len(fraction_lengths)):
292
+
293
+ _fraction_mask_length = int((n_guidance_steps - 1 - i) / (n_guidance_steps - 1) * fraction_lengths[fraction_index])
294
+
295
+ if fraction_index == 0:
296
+ step_i_mask[:, :, :, :_fraction_mask_length] = 1.0
297
+ elif fraction_index == len(fraction_lengths) - 1:
298
+ if not _fraction_mask_length == 0:
299
+ step_i_mask[:, :, :, -_fraction_mask_length - release_length:] = 1.0
300
+ else:
301
+ fraction_mask_start_position = int((fraction_lengths[fraction_index] - _fraction_mask_length) / 2)
302
+
303
+ step_i_mask[:, :, :,
304
+ concat_points[fraction_index] + fraction_mask_start_position:concat_points[
305
+ fraction_index] + fraction_mask_start_position + _fraction_mask_length] = 1.0
306
+ masks.append(step_i_mask)
307
+
308
+ for i in range(n_free_steps):
309
+ step_i_mask = torch.zeros((shape[0], 1, shape[2], shape[3]), dtype=torch.float32).to(self.device)
310
+ step_i_mask[:, :, :, -release_length:] = 1.0
311
+ masks.append(step_i_mask)
312
+
313
+ masks.reverse()
314
+ return masks
315
+
316
+ @torch.no_grad()
317
+ def p_sample_loop(self, model, shape, initial_noise=None, start_noise_level_ratio=1.0, end_noise_level_ratio=0.0,
318
+ return_tensor=False, condition=None, guide_img=None,
319
+ mask=None, sampler="ddim", inpaint=False, use_dynamic_mask=False, mask_flexivity=0.8):
320
+
321
+ assert shape[1] == self.channels, "shape[1] != self.channels"
322
+ assert shape[2] == self.height, "shape[2] != self.height"
323
+
324
+ initial_noise, _ = self.get_deterministic_noise_tensor(shape[0], shape[3], reference_noise=initial_noise)
325
+ assert initial_noise.shape == shape, "initial_noise.shape != shape"
326
+
327
+ start_noise_level_index = int(self.num_timesteps * start_noise_level_ratio) # not included!!!
328
+ end_noise_level_index = int(self.num_timesteps * end_noise_level_ratio)
329
+
330
+ timesteps = reversed(range(end_noise_level_index, start_noise_level_index))
331
+
332
+ # configure initial img
333
+ assert (start_noise_level_ratio == 1.0) or (
334
+ not guide_img is None), "A guide_img must be given to sample from a non-pure-noise."
335
+
336
+ if guide_img is None:
337
+ img = initial_noise
338
+ else:
339
+ guide_img, concat_points = self.get_deterministic_noise_tensor_repeat(shape[0], shape[3], reference_noise=guide_img)
340
+ assert guide_img.shape == shape, "guide_img.shape != shape"
341
+
342
+ if start_noise_level_index > 0:
343
+ t = torch.full((shape[0],), start_noise_level_index-1, device=self.device).long() # -1 for start_noise_level_index not included
344
+ img = self.q_sample(guide_img, t, noise=initial_noise)
345
+ else:
346
+ print("Zero noise added to the guidance latent representation.")
347
+ img = guide_img
348
+
349
+ # get masks
350
+ n_masks = start_noise_level_index - end_noise_level_index
351
+ if use_dynamic_mask:
352
+ masks = self.get_dynamic_masks(n_masks, shape, concat_points, mask_flexivity)
353
+ else:
354
+ masks = [mask for _ in range(n_masks)]
355
+
356
+ imgs = [img]
357
+ current_mask = None
358
+
359
+
360
+ for i in tqdm(timesteps, total=start_noise_level_index - end_noise_level_index, disable=self.mute):
361
+
362
+ # if i == 3:
363
+ # return [img], initial_noise # 第1排,第1列
364
+
365
+ img = self.p_sample(model, img, torch.full((shape[0],), i, device=self.device, dtype=torch.long),
366
+ condition=condition,
367
+ sampler=sampler)
368
+ # if i == 3:
369
+ # return [img], initial_noise # 第1排,第2列
370
+
371
+ if inpaint:
372
+ if i > 0:
373
+ t = torch.full((shape[0],), int(i-1), device=self.device).long()
374
+ img_noise_t = self.q_sample(guide_img, t, noise=initial_noise)
375
+ # if i == 3:
376
+ # return [img_noise_t], initial_noise # 第2排,第2列
377
+ current_mask = masks.pop()
378
+ img = current_mask * img_noise_t + (1 - current_mask) * img
379
+ # if i == 3:
380
+ # return [img], initial_noise # 第1.5排,最后1列
381
+ else:
382
+ img = current_mask * guide_img + (1 - current_mask) * img
383
+
384
+ if return_tensor:
385
+ imgs.append(img)
386
+ else:
387
+ imgs.append(img.cpu().numpy())
388
+
389
+ return imgs, initial_noise
390
+
391
+
392
+ def sample(self, model, shape, return_tensor=False, condition=None, sampler="ddim", initial_noise=None, seed=None):
393
+ if not seed is None:
394
+ torch.manual_seed(seed)
395
+ return self.p_sample_loop(model, shape, initial_noise=initial_noise, start_noise_level_ratio=1.0, end_noise_level_ratio=0.0,
396
+ return_tensor=return_tensor, condition=condition, sampler=sampler)
397
+
398
+ def interpolate(self, model, shape, variance, first_endpoint=None, second_endpoint=None, return_tensor=False,
399
+ condition=None, sampler="ddim", seed=None):
400
+ if not seed is None:
401
+ torch.manual_seed(seed)
402
+ linear_noise = self.generate_linear_noise(shape, variance, first_endpoint=first_endpoint,
403
+ second_endpoint=second_endpoint)
404
+ return self.p_sample_loop(model, shape, initial_noise=linear_noise, start_noise_level_ratio=1.0,
405
+ end_noise_level_ratio=0.0,
406
+ return_tensor=return_tensor, condition=condition, sampler=sampler)
407
+
408
+ def img_guided_sample(self, model, shape, noising_strength, guide_img, return_tensor=False, condition=None,
409
+ sampler="ddim", initial_noise=None, seed=None):
410
+ if not seed is None:
411
+ torch.manual_seed(seed)
412
+ assert guide_img.shape[-1] == shape[-1], "guide_img.shape[:-1] != shape[:-1]"
413
+ return self.p_sample_loop(model, shape, start_noise_level_ratio=noising_strength, end_noise_level_ratio=0.0,
414
+ return_tensor=return_tensor, condition=condition, sampler=sampler,
415
+ guide_img=guide_img, initial_noise=initial_noise)
416
+
417
+ def inpaint_sample(self, model, shape, noising_strength, guide_img, mask, return_tensor=False, condition=None,
418
+ sampler="ddim", initial_noise=None, use_dynamic_mask=False, end_noise_level_ratio=0.0, seed=None,
419
+ mask_flexivity=0.8):
420
+ if not seed is None:
421
+ torch.manual_seed(seed)
422
+ return self.p_sample_loop(model, shape, start_noise_level_ratio=noising_strength, end_noise_level_ratio=end_noise_level_ratio,
423
+ return_tensor=return_tensor, condition=condition, guide_img=guide_img, mask=mask,
424
+ sampler=sampler, inpaint=True, initial_noise=initial_noise, use_dynamic_mask=use_dynamic_mask,
425
+ mask_flexivity=mask_flexivity)
model/GAN.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from six.moves import xrange
6
+ from torch.utils.tensorboard import SummaryWriter
7
+ import random
8
+
9
+ from model.diffusion import ConditionedUnet
10
+ from tools import create_key
11
+
12
class Discriminator(nn.Module):
    """
    Conditional discriminator scoring (latent image, text embedding) pairs.

    The image branch downsamples a 4-channel latent through four strided
    convolution stages to a 512-dim feature vector; the text branch projects
    the label embedding to 512 dims; both are concatenated and mapped to a
    single real/fake logit.
    """

    def __init__(self, label_emb_dim):
        super(Discriminator, self).__init__()
        # Image feature extractor: strided convs, then global pooling.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(4, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.AdaptiveAvgPool2d(1),  # pool to 1x1 so any input size works
            nn.Flatten(),
        )

        # Text-embedding branch.
        self.text_embedding = nn.Sequential(
            nn.Linear(label_emb_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Final head over the concatenated image and text features.
        self.fc = nn.Linear(512 + 512, 1)

    def forward(self, x, text_emb):
        """Return one unnormalised real/fake logit per sample."""
        image_features = self.conv_layers(x)
        text_features = self.text_embedding(text_emb)
        combined = torch.cat((image_features, text_features), dim=1)
        return self.fc(combined)
47
+
48
+
49
+
50
def evaluate_GAN(device, generator, discriminator, iterator, encodes2embeddings_mapping):
    """
    Estimate discriminator accuracy on real and generated batches.

    Draws 100 batches, generates fakes conditioned on the batch labels, and
    measures how often the discriminator classifies each correctly.

    :param device: torch device for all tensors and models.
    :param generator: conditional generator, called as generator(noise, labels).
    :param discriminator: model returning one logit per (image, label) pair.
    :param iterator: data source yielding (data, attributes) batches.
    :param encodes2embeddings_mapping: maps create_key(attribute) to the list
        of candidate condition embeddings for that sample.
    :return: (average real accuracy, average fake accuracy).
    """
    generator.to(device)
    discriminator.to(device)
    generator.eval()
    discriminator.eval()

    real_accs = []
    fake_accs = []

    with torch.no_grad():
        for _ in range(100):
            # NOTE(review): next(iter(...)) restarts the iterator each call —
            # confirm this yields varied batches for the given data source.
            data, attributes = next(iter(iterator))

            conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
            selected_conditions = [random.choice(conditions_of_one_sample) for conditions_of_one_sample in conditions]
            selected_conditions = torch.stack(selected_conditions).float().to(device)

            real_images = data.to(device)
            labels = selected_conditions.to(device)

            # Generate fakes for the same labels.
            noise = torch.randn_like(real_images).to(device)
            # Bug fix: the generator is conditional (trained via
            # generator(noise, labels) in train_GAN) but was invoked here
            # without the labels.
            fake_images = generator(noise, labels)

            real_preds = discriminator(real_images, labels).reshape(-1)
            fake_preds = discriminator(fake_images, labels).reshape(-1)
            # Bug fix: the discriminator outputs raw logits (trained with
            # BCEWithLogitsLoss), so the real/fake decision boundary is 0,
            # not 0.5.
            real_acc = (real_preds > 0.0).float().mean().item()  # accuracy on real images
            fake_acc = (fake_preds < 0.0).float().mean().item()  # accuracy on generated images

            real_accs.append(real_acc)
            fake_accs.append(fake_acc)

    average_real_acc = sum(real_accs) / len(real_accs)
    average_fake_acc = sum(fake_accs) / len(fake_accs)

    return average_real_acc, average_fake_acc
91
+
92
+
93
def get_Generator(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    """
    Build a ConditionedUnet generator and optionally load pretrained weights.

    :param model_Config: kwargs forwarded to ConditionedUnet.
    :param load_pretrain: when True, load models/{model_name}_generator.pth.
    :param model_name: checkpoint name prefix (required when load_pretrain).
    :param device: device to place the model on.
    :return: the generator (in eval mode when pretrained weights were loaded).
    """
    generator = ConditionedUnet(**model_Config)
    # Typo fix in the log message: "intialized" -> "initialized".
    print(f"Model initialized, size: {sum(p.numel() for p in generator.parameters() if p.requires_grad)}")
    generator.to(device)

    if load_pretrain:
        print(f"Loading weights from models/{model_name}_generator.pth")
        checkpoint = torch.load(f'models/{model_name}_generator.pth', map_location=device)
        generator.load_state_dict(checkpoint['model_state_dict'])
        generator.eval()
    return generator
104
+
105
+
106
def get_Discriminator(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    """
    Build a Discriminator and optionally load pretrained weights.

    :param model_Config: kwargs forwarded to Discriminator.
    :param load_pretrain: when True, load models/{model_name}_discriminator.pth.
    :param model_name: checkpoint name prefix (required when load_pretrain).
    :param device: device to place the model on.
    :return: the discriminator (in eval mode when pretrained weights were loaded).
    """
    discriminator = Discriminator(**model_Config)
    # Typo fix in the log message: "intialized" -> "initialized".
    print(f"Model initialized, size: {sum(p.numel() for p in discriminator.parameters() if p.requires_grad)}")
    discriminator.to(device)

    if load_pretrain:
        print(f"Loading weights from models/{model_name}_discriminator.pth")
        checkpoint = torch.load(f'models/{model_name}_discriminator.pth', map_location=device)
        discriminator.load_state_dict(checkpoint['model_state_dict'])
        discriminator.eval()
    return discriminator
117
+
118
+
119
def train_GAN(device, init_model_name, unetConfig, BATCH_SIZE, lr_G, lr_D, max_iter, iterator, load_pretrain,
              encodes2embeddings_mapping, save_steps, unconditional_condition, uncondition_rate, save_model_name=None):
    """
    Adversarial training loop for the conditional generator/discriminator pair.

    :param device: torch device for training.
    :param init_model_name: checkpoint prefix to load when load_pretrain.
    :param unetConfig: kwargs for ConditionedUnet; also the base dict of the
        hyperparameter JSON saved alongside checkpoints.
    :param BATCH_SIZE: recorded in the hyperparameter JSON.
    :param lr_G: generator learning rate.
    :param lr_D: discriminator learning rate.
    :param max_iter: number of training iterations; 0 returns the models as-is.
    :param iterator: data source yielding (data, attributes) batches.
    :param encodes2embeddings_mapping: maps create_key(attribute) to candidate
        condition embeddings.
    :param save_steps: checkpoint/logging interval in iterations.
    :param unconditional_condition: embedding used when a sample is dropped to
        the unconditional branch.
    :param uncondition_rate: probability of replacing a sample's condition
        with the unconditional embedding.
    :param save_model_name: checkpoint prefix for saving; defaults to
        init_model_name.
    :return: (generator, discriminator, optimizer_G, optimizer_D).
    """

    if save_model_name is None:
        save_model_name = init_model_name

    def save_model_hyperparameter(model_name, unetConfig, BATCH_SIZE, model_size, current_iter, current_loss):
        # Persist training metadata next to the checkpoint.
        # NOTE(review): this writes into the caller's unetConfig dict in
        # place — confirm that side effect is intended.
        model_hyperparameter = unetConfig
        model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
        model_hyperparameter["lr_G"] = lr_G
        model_hyperparameter["lr_D"] = lr_D
        model_hyperparameter["model_size"] = model_size
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["current_loss"] = current_loss
        with open(f"models/hyperparameters/{model_name}_GAN.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    generator = ConditionedUnet(**unetConfig)
    discriminator = Discriminator(unetConfig["label_emb_dim"])
    generator_size = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    discriminator_size = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)

    print(f"Generator trainable parameters: {generator_size}, discriminator trainable parameters: {discriminator_size}")
    generator.to(device)
    discriminator.to(device)
    optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad, generator.parameters()), lr=lr_G, amsgrad=False)
    optimizer_D = torch.optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()), lr=lr_D, amsgrad=False)

    if load_pretrain:
        # Message fix: checkpoints are saved/loaded as .pth, the log said .pt.
        print(f"Loading weights from models/{init_model_name}_generator.pth")
        checkpoint = torch.load(f'models/{init_model_name}_generator.pth')
        generator.load_state_dict(checkpoint['model_state_dict'])
        optimizer_G.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"Loading weights from models/{init_model_name}_discriminator.pth")
        checkpoint = torch.load(f'models/{init_model_name}_discriminator.pth')
        discriminator.load_state_dict(checkpoint['model_state_dict'])
        optimizer_D.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Model initialized.")
    if max_iter == 0:
        print("Return model directly.")
        return generator, discriminator, optimizer_G, optimizer_D

    train_loss_G, train_loss_D = [], []
    writer = SummaryWriter(f'runs/{save_model_name}_GAN')

    criterion = nn.BCEWithLogitsLoss()
    generator.train()
    for i in xrange(max_iter):
        # NOTE(review): next(iter(iterator)) restarts the iterator every
        # step — confirm the data source reshuffles on iter().
        data, attributes = next(iter(iterator))
        data = data.to(device)

        # Pick one condition per sample; drop to the unconditional embedding
        # with probability uncondition_rate (classifier-free-guidance style).
        conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
        unconditional_condition_copy = torch.tensor(unconditional_condition, dtype=torch.float32).to(device).detach()
        selected_conditions = [unconditional_condition_copy if random.random() < uncondition_rate else random.choice(
            conditions_of_one_sample) for conditions_of_one_sample in conditions]
        batch_size = len(selected_conditions)
        selected_conditions = torch.stack(selected_conditions).float().to(device)

        real_images = data.to(device)
        labels = selected_conditions.to(device)

        # Target labels for the BCE-with-logits loss.
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ========== discriminator update ==========
        optimizer_D.zero_grad()

        # Loss on real images.
        outputs_real = discriminator(real_images, labels)
        loss_D_real = criterion(outputs_real, real_labels)

        # Generate fakes for the same labels.
        noise = torch.randn_like(real_images).to(device)
        fake_images = generator(noise, labels)

        # detach() keeps the generator out of the discriminator's backward pass.
        outputs_fake = discriminator(fake_images.detach(), labels)
        loss_D_fake = criterion(outputs_fake, fake_labels)

        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        optimizer_D.step()

        # ========== generator update ==========
        optimizer_G.zero_grad()

        # The generator tries to make the discriminator output "real".
        outputs_fake = discriminator(fake_images, labels)
        loss_G = criterion(outputs_fake, real_labels)

        loss_G.backward()
        optimizer_G.step()

        train_loss_G.append(loss_G.item())
        train_loss_D.append(loss_D.item())
        # Global step count taken from the optimizer state (robust to resume).
        step = int(optimizer_G.state_dict()['state'][list(optimizer_G.state_dict()['state'].keys())[0]]['step'].numpy())

        if (i + 1) % 100 == 0:
            print('%d step' % (step))

        if (i + 1) % save_steps == 0:
            current_loss_D = np.mean(train_loss_D[-save_steps:])
            current_loss_G = np.mean(train_loss_G[-save_steps:])
            print('current_loss_G: %.5f' % current_loss_G)
            print('current_loss_D: %.5f' % current_loss_D)

            writer.add_scalar(f"current_loss_G", current_loss_G, step)
            writer.add_scalar(f"current_loss_D", current_loss_D, step)

            torch.save({
                'model_state_dict': generator.state_dict(),
                'optimizer_state_dict': optimizer_G.state_dict(),
            }, f'models/{save_model_name}_generator.pth')
            save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, generator_size, step, current_loss_G)
            torch.save({
                'model_state_dict': discriminator.state_dict(),
                'optimizer_state_dict': optimizer_D.state_dict(),
            }, f'models/{save_model_name}_discriminator.pth')
            save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, discriminator_size, step, current_loss_D)

        # Periodic snapshot into the history folder.
        if step % 10000 == 0:
            torch.save({
                'model_state_dict': generator.state_dict(),
                'optimizer_state_dict': optimizer_G.state_dict(),
            }, f'models/history/{save_model_name}_{step}_generator.pth')
            torch.save({
                'model_state_dict': discriminator.state_dict(),
                'optimizer_state_dict': optimizer_D.state_dict(),
            }, f'models/history/{save_model_name}_{step}_discriminator.pth')

    return generator, discriminator, optimizer_G, optimizer_D
261
+
262
+
model/VQGAN.py ADDED
@@ -0,0 +1,684 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from torch.utils.tensorboard import SummaryWriter
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from six.moves import xrange
8
+ from einops import rearrange
9
+ from torchvision import models
10
+
11
+
12
def Normalize(in_channels, num_groups=32, norm_type="groupnorm"):
    """Return a 2-D normalization module for `in_channels` feature channels.

    `norm_type == "batchnorm"` yields `BatchNorm2d`; any other value falls
    back to `GroupNorm` with eps=1e-6 and learnable affine parameters.
    """
    if norm_type == "batchnorm":
        return torch.nn.BatchNorm2d(in_channels)
    return torch.nn.GroupNorm(
        num_groups=num_groups,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
    )
18
+
19
+
20
def nonlinearity(x, act_type="relu"):
    """Apply the configured activation: ReLU for "relu", swish otherwise.

    Swish is x * sigmoid(x); any `act_type` other than "relu" selects it.
    """
    if act_type == "relu":
        return F.relu(x)
    # Fallback branch: swish activation.
    return x * torch.sigmoid(x)
28
+
29
+
30
class VectorQuantizer(nn.Module):
    """Vector-quantization bottleneck with a gradient-trained codebook.

    Each spatial position of a BCHW feature map is snapped to its nearest
    codebook entry. Returns the straight-through-estimated quantized tensor,
    the VQ loss (codebook + commitment terms), and the batch perplexity.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        # Codebook initialized uniformly in [-1/K, 1/K].
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        # BCHW -> BHWC so channels become the embedding axis.
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        original_shape = inputs.shape

        # Collapse batch and spatial dims: (B*H*W, C).
        flattened = inputs.view(-1, self._embedding_dim)

        # Squared Euclidean distance to every code, expanded as
        # |x|^2 + |e|^2 - 2 x.e to avoid an explicit broadcasted subtraction.
        codebook = self._embedding.weight
        distances = (
            torch.sum(flattened ** 2, dim=1, keepdim=True)
            + torch.sum(codebook ** 2, dim=1)
            - 2 * torch.matmul(flattened, codebook.t())
        )

        # One-hot assignment of each vector to its nearest code.
        nearest = torch.argmin(distances, dim=1).unsqueeze(1)
        one_hot = torch.zeros(nearest.shape[0], self._num_embeddings, device=inputs.device)
        one_hot.scatter_(1, nearest, 1)

        # Look up selected codes and restore the BHWC layout.
        quantized = torch.matmul(one_hot, codebook).view(original_shape)

        # Codebook loss pulls codes toward encoder outputs; the commitment
        # term keeps encoder outputs close to their assigned codes.
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        # Straight-through estimator: gradients bypass the quantization step.
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(one_hot, dim=0)
        # Perplexity: effective number of codes used by this batch.
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # BHWC -> BCHW for the caller; encodings are not exposed here.
        min_encodings, min_encoding_indices = None, None
        return quantized.permute(0, 3, 1, 2).contiguous(), loss, (perplexity, min_encodings, min_encoding_indices)
76
+
77
+
78
class VectorQuantizerEMA(nn.Module):
    """Vector quantization layer based on exponential moving average"""
    # Codebook entries are maintained via EMA statistics of the encoder
    # outputs assigned to them, so no codebook-gradient loss is needed:
    # only the commitment term contributes to the returned loss.

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        # Args:
        #   num_embeddings: codebook size K.
        #   embedding_dim: dimensionality of each code vector.
        #   commitment_cost: weight of the commitment (encoder-side) loss.
        #   decay: EMA decay factor for cluster sizes and code sums.
        #   epsilon: Laplace-smoothing constant that avoids empty clusters.
        super(VectorQuantizerEMA, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost

        # Running per-code assignment counts (buffer: persisted, not trained).
        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        # Running per-code sum of assigned encoder outputs.
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()

        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input to (B*H*W, C)
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances: squared Euclidean expanded as |x|^2 + |e|^2 - 2 x.e
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding: one-hot assignment to the nearest code
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Use EMA to update the embedding vectors (training mode only)
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)

            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                (self._ema_cluster_size + self._epsilon)
                / (n + self._num_embeddings * self._epsilon) * n)

            dw = torch.matmul(encodings.t(), flat_input)
            # NOTE(review): rebinding nn.Parameter objects inside forward
            # replaces the registered parameters each step; this mirrors the
            # reference VQ-VAE EMA implementation, but any optimizer state on
            # these parameters becomes stale -- confirm intended.
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)

            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))

        # Loss: EMA maintains the codebook, so only the commitment term remains
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss

        # Straight Through Estimator: gradients bypass quantization
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        # Perplexity: effective number of codes in use for this batch
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        min_encodings, min_encoding_indices = None, None
        return quantized.permute(0, 3, 1, 2).contiguous(), loss, (perplexity, min_encodings, min_encoding_indices)
147
+
148
+
149
class DownSample(nn.Module):
    """Halve the spatial resolution with a strided 4x4 convolution."""

    def __init__(self, in_channels, out_channels):
        super(DownSample, self).__init__()
        # kernel 4, stride 2, padding 1 => exact /2 on H and W.
        self._conv2d = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        return self._conv2d(x)
161
+
162
+
163
class UpSample(nn.Module):
    """Double the spatial resolution with a strided 4x4 transposed convolution."""

    def __init__(self, in_channels, out_channels):
        super(UpSample, self).__init__()
        # kernel 4, stride 2, padding 1 => exact x2 on H and W.
        self._conv2d = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        return self._conv2d(x)
175
+
176
+
177
class ResnetBlock(nn.Module):
    """Pre-activation residual block: norm -> activation -> conv, plus skip.

    With `double_conv` a second norm/activation/dropout/conv stage is applied.
    When input and output channel counts differ, the skip path is projected
    with either a 3x3 conv (`conv_shortcut=True`) or a 1x1 conv.
    An optional time embedding `temb` is projected and added after conv1.
    """

    def __init__(self, *, in_channels, out_channels=None, double_conv=False, conv_shortcut=False,
                 dropout=0.0, temb_channels=512, norm_type="groupnorm", act_type="relu", num_groups=32):
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.act_type = act_type

        # First (always present) stage.
        self.norm1 = Normalize(in_channels, norm_type=norm_type, num_groups=num_groups)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels > 0:
            # Projects the time embedding into the output-channel dimension.
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)

        # Optional second stage.
        self.double_conv = double_conv
        if self.double_conv:
            self.norm2 = Normalize(out_channels, norm_type=norm_type, num_groups=num_groups)
            self.dropout = torch.nn.Dropout(dropout)
            self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Skip-path projection, only needed when channel counts differ.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels,
                                                     kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels,
                                                    kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb=None):
        out = self.norm1(x)
        out = nonlinearity(out, act_type=self.act_type)
        out = self.conv1(out)

        if temb is not None:
            # Broadcast the projected time embedding over the spatial dims.
            out = out + self.temb_proj(nonlinearity(temb, act_type=self.act_type))[:, :, None, None]

        if self.double_conv:
            out = self.norm2(out)
            out = nonlinearity(out, act_type=self.act_type)
            out = self.dropout(out)
            out = self.conv2(out)

        skip = x
        if self.in_channels != self.out_channels:
            skip = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)

        return skip + out
245
+
246
+
247
class LinearAttention(nn.Module):
    """Linear-complexity attention over 2-D feature maps.

    Based on Katharopoulos et al. (2020): keys are softmax-normalized and
    aggregated with the values into a small context matrix before the queries
    are applied, giving O(N) cost in the number of spatial positions.
    """

    def __init__(self, dim, heads=4, dim_head=32, with_skip=True):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        # Single conv produces q, k, v stacked along the channel axis.
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

        self.with_skip = with_skip
        if self.with_skip:
            # 1x1 projection on the residual path.
            self.nin_shortcut = torch.nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        # Normalize keys over positions, fold values into a (d x e) context,
        # then apply the queries -- avoids the full N x N attention matrix.
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        attended = torch.einsum('bhde,bhdn->bhen', context, q)
        attended = rearrange(attended, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)

        projected = self.to_out(attended)
        if self.with_skip:
            return projected + self.nin_shortcut(x)
        return projected
273
+
274
+
275
class Encoder(nn.Module):
    """The encoder, consisting of alternating stacks of ResNet blocks, efficient attention modules, and downsampling layers."""
    # Args:
    #   hidden_channels: channel width after each downsampling stage; the
    #       number of DownSample layers equals len(hidden_channels).
    #   embedding_dim: channel width of the output latent (fed to the VQ layer).
    #   block_depth: (block_depth - 1) ResnetBlocks are stacked per stage.
    #   attn_pos: channel widths at which a LinearAttention layer is inserted.
    # NOTE(review): the inter-stage activation is a hardcoded nn.ReLU() even
    # when act_type selects a different nonlinearity -- confirm intended.

    def __init__(self, in_channels, hidden_channels, embedding_dim, block_depth=2,
                 attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu", num_groups=32):
        super(Encoder, self).__init__()

        if attn_pos is None:
            attn_pos = []
        # Initial downsample straight from the input channels.
        self._layers = nn.ModuleList([DownSample(in_channels, hidden_channels[0])])
        current_channel = hidden_channels[0]

        # Intermediate stages: ResnetBlocks (+ optional attention), then
        # norm -> ReLU -> DownSample to the next channel width.
        for i in range(1, len(hidden_channels)):
            for _ in range(block_depth - 1):
                self._layers.append(ResnetBlock(in_channels=current_channel,
                                                out_channels=current_channel,
                                                double_conv=False,
                                                conv_shortcut=False,
                                                norm_type=norm_type,
                                                act_type=act_type,
                                                num_groups=num_groups))
                if current_channel in attn_pos:
                    self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))

            self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
            self._layers.append(nn.ReLU())
            self._layers.append(DownSample(current_channel, hidden_channels[i]))
            current_channel = hidden_channels[i]

        # Final residual stack at the deepest resolution.
        for _ in range(block_depth - 1):
            self._layers.append(ResnetBlock(in_channels=current_channel,
                                            out_channels=current_channel,
                                            double_conv=False,
                                            conv_shortcut=False,
                                            norm_type=norm_type,
                                            act_type=act_type,
                                            num_groups=num_groups))
            if current_channel in attn_pos:
                self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))

        # Conv1x1: hidden_channels[-1] -> embedding_dim
        self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
        self._layers.append(nn.ReLU())
        self._layers.append(nn.Conv2d(in_channels=current_channel,
                                      out_channels=embedding_dim,
                                      kernel_size=1,
                                      stride=1))

    def forward(self, x):
        # Plain sequential pass through the assembled layer list.
        for layer in self._layers:
            x = layer(x)
        return x
327
+
328
+
329
class Decoder(nn.Module):
    """The decoder, consisting of alternating stacks of ResNet blocks, efficient attention modules, and upsampling layers."""
    # Mirrors Encoder: hidden_channels is traversed in reverse, with one
    # UpSample per stage plus a final extra UpSample before the output block.
    # NOTE(review): forward assumes the output has exactly 3 channels
    # (log-magnitude + cos/sin phase), i.e. out_channels == 3 -- confirm.

    def __init__(self, embedding_dim, hidden_channels, out_channels, block_depth=2,
                 attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu",
                 num_groups=32):
        super(Decoder, self).__init__()

        if attn_pos is None:
            attn_pos = []
        reversed_hidden_channels = list(reversed(hidden_channels))

        # Conv1x1: embedding_dim -> deepest hidden width
        self._layers = nn.ModuleList([nn.Conv2d(in_channels=embedding_dim,
                                                out_channels=reversed_hidden_channels[0],
                                                kernel_size=1, stride=1, bias=False)])

        current_channel = reversed_hidden_channels[0]

        # Residual stack at the deepest resolution (attention before block).
        for _ in range(block_depth - 1):
            if current_channel in attn_pos:
                self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
            self._layers.append(ResnetBlock(in_channels=current_channel,
                                            out_channels=current_channel,
                                            double_conv=False,
                                            conv_shortcut=False,
                                            norm_type=norm_type,
                                            act_type=act_type,
                                            num_groups=num_groups))

        # One stage per remaining width: norm -> ReLU -> UpSample, then
        # ResnetBlocks (+ optional attention) at the new width.
        for i in range(1, len(reversed_hidden_channels)):
            self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
            self._layers.append(nn.ReLU())
            self._layers.append(UpSample(current_channel, reversed_hidden_channels[i]))
            current_channel = reversed_hidden_channels[i]

            for _ in range(block_depth - 1):
                if current_channel in attn_pos:
                    self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
                self._layers.append(ResnetBlock(in_channels=current_channel,
                                                out_channels=current_channel,
                                                double_conv=False,
                                                conv_shortcut=False,
                                                norm_type=norm_type,
                                                act_type=act_type,
                                                num_groups=num_groups))

        # Extra upsample: the Encoder applies one more DownSample than it has
        # stages, so the Decoder compensates here.
        self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
        self._layers.append(nn.ReLU())
        self._layers.append(UpSample(current_channel, current_channel))

        # final layers
        self._layers.append(ResnetBlock(in_channels=current_channel,
                                        out_channels=out_channels,
                                        double_conv=False,
                                        conv_shortcut=False,
                                        norm_type=norm_type,
                                        act_type=act_type,
                                        num_groups=num_groups))


    def forward(self, x):
        for layer in self._layers:
            x = layer(x)

        # Output head for the 3-channel spectral triplet:
        # channel 0: log-magnitude, kept non-negative via softplus;
        # channels 1-2: cos/sin phase, squashed into [-1, 1] via tanh.
        log_magnitude = torch.nn.functional.softplus(x[:, 0, :, :])

        cos_phase = torch.tanh(x[:, 1, :, :])
        sin_phase = torch.tanh(x[:, 2, :, :])
        x = torch.stack([log_magnitude, cos_phase, sin_phase], dim=1)

        return x
401
+
402
+
403
class VQGAN_Discriminator(nn.Module):
    """The discriminator employs an 18-layer-ResNet architecture , with the first layer replaced by a 2D convolutional
    layer that accommodates spectral representation inputs and the last two layers replaced by a binary classifier
    layer."""

    def __init__(self, in_channels=1):
        super(VQGAN_Discriminator, self).__init__()
        resnet = models.resnet18(pretrained=True)

        # Replace the first layer so it accepts `in_channels` inputs
        # (e.g. single-channel spectra) instead of RGB.
        resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

        # Keep ResNet's feature-extraction trunk (drop the avgpool + fc head).
        self.features = nn.Sequential(*list(resnet.children())[:-2])

        # Binary real/fake head. NOTE(review): the Sigmoid means outputs are
        # probabilities in (0, 1), so this should be paired with BCELoss,
        # not BCEWithLogitsLoss (which would apply a second sigmoid).
        self.classifier = nn.Sequential(
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Extract features, global-average-pool to (B, 512), then classify.
        x = self.features(x)
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
430
+
431
+
432
class VQGAN(nn.Module):
    """The VQ-GAN model. <https://openaccess.thecvf.com/content/CVPR2021/html/Esser_Taming_Transformers_for_High-Resolution_Image_Synthesis_CVPR_2021_paper.html?ref=>

    Encoder -> vector quantizer (EMA variant when decay > 0) -> Decoder.
    """

    def __init__(self, in_channels, hidden_channels, embedding_dim, out_channels, block_depth=2,
                 attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu",
                 num_embeddings=1024, commitment_cost=0.25, decay=0.99, num_groups=32):
        super(VQGAN, self).__init__()

        # Bug fix: the encoder previously received the literal string
        # "act_type" instead of the `act_type` argument, which silently
        # forced the swish branch of `nonlinearity` regardless of config.
        self._encoder = Encoder(in_channels, hidden_channels, embedding_dim, block_depth=block_depth,
                                attn_pos=attn_pos, attn_with_skip=attn_with_skip, norm_type=norm_type,
                                act_type=act_type, num_groups=num_groups)

        # decay > 0 selects the EMA-updated codebook variant.
        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
                                              commitment_cost, decay)
        else:
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                           commitment_cost)
        self._decoder = Decoder(embedding_dim, hidden_channels, out_channels, block_depth=block_depth,
                                attn_pos=attn_pos, attn_with_skip=attn_with_skip, norm_type=norm_type,
                                act_type=act_type, num_groups=num_groups)

    def forward(self, x):
        """Encode, quantize, decode.

        Returns:
            (vq_loss, x_recon, perplexity) -- VQ loss term, reconstruction,
            and codebook-usage perplexity from the quantizer.
        """
        z = self._encoder(x)
        quantized, vq_loss, (perplexity, _, _) = self._vq_vae(z)
        x_recon = self._decoder(quantized)

        return vq_loss, x_recon, perplexity
459
+
460
+
461
+ class ReconstructionLoss(nn.Module):
462
+ def __init__(self, w1, w2, epsilon=1e-3):
463
+ super(ReconstructionLoss, self).__init__()
464
+ self.w1 = w1
465
+ self.w2 = w2
466
+ self.epsilon = epsilon
467
+
468
+ def weighted_mae_loss(self, y_true, y_pred):
469
+ # avoid divide by zero
470
+ y_true_safe = torch.clamp(y_true, min=self.epsilon)
471
+
472
+ # compute weighted MAE
473
+ loss = torch.mean(torch.abs(y_pred - y_true) / y_true_safe)
474
+ return loss
475
+
476
+ def mae_loss(self, y_true, y_pred):
477
+ loss = torch.mean(torch.abs(y_pred - y_true))
478
+ return loss
479
+
480
+ def forward(self, y_pred, y_true):
481
+ # loss for magnitude channel
482
+ log_magnitude_loss = self.w1 * self.weighted_mae_loss(y_pred[:, 0, :, :], y_true[:, 0, :, :])
483
+
484
+ # loss for phase channels
485
+ phase_loss = self.w2 * self.mae_loss(y_pred[:, 1:, :, :], y_true[:, 1:, :, :])
486
+
487
+ # sum up
488
+ rec_loss = log_magnitude_loss + phase_loss
489
+ return log_magnitude_loss, phase_loss, rec_loss
490
+
491
+
492
def evaluate_VQGAN(model, discriminator, iterator, reconstructionLoss, adversarial_loss, trainingConfig):
    """Estimate the current combined VQ-GAN loss over 100 sampled batches.

    The combined loss mirrors the generator objective used during training:
    reconstruction + vq_weight * vq_loss + adver_weight * adversarial term
    (discriminator score of the reconstruction against all-"real" labels).

    Args:
        model: the VQGAN under evaluation (switched to eval mode here).
        discriminator: adversarial critic producing per-sample scores.
        iterator: data loader; a fresh batch is drawn per step.
        reconstructionLoss: ReconstructionLoss instance.
        adversarial_loss: BCE-style criterion matching the discriminator output.
        trainingConfig: dict with "device", "vq_weight", "adver_weight".

    Returns:
        float: mean combined loss over the sampled batches.
    """
    model.to(trainingConfig["device"])
    model.eval()
    batch_losses = []
    # Evaluation only: disable autograd to avoid building graphs
    # (results are unchanged, memory use is much lower).
    with torch.no_grad():
        for _ in range(100):
            data = next(iter(iterator))
            data = data.to(trainingConfig["device"])

            # All-ones targets: score how "real" the reconstructions look.
            real_labels = torch.ones(data.size(0), 1).to(trainingConfig["device"])

            vq_loss, data_recon, perplexity = model(data)

            fake_preds = discriminator(data_recon)
            adver_loss = adversarial_loss(fake_preds, real_labels)

            log_magnitude_loss, phase_loss, rec_loss = reconstructionLoss(data_recon, data)
            loss = rec_loss + trainingConfig["vq_weight"] * vq_loss + trainingConfig["adver_weight"] * adver_loss

            batch_losses.append(loss.item())
    return np.mean(batch_losses)
515
+
516
+
517
def get_VQGAN(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    """Build a VQGAN from config, optionally restoring pretrained weights.

    Args:
        model_Config: kwargs forwarded to the VQGAN constructor.
        load_pretrain: when True, load `models/{model_name}_imageVQVAE.pth`
            and switch the model to eval mode.
        model_name: checkpoint name prefix (required when load_pretrain=True).
        device: target device for the model and checkpoint tensors.

    Returns:
        The constructed (and possibly pretrained) VQGAN.
    """
    VQVAE = VQGAN(**model_Config)
    print(f"Model intialized, size: {sum(p.numel() for p in VQVAE.parameters() if p.requires_grad)}")
    VQVAE.to(device)

    if not load_pretrain:
        return VQVAE

    # Restore weights and freeze into inference mode.
    print(f"Loading weights from models/{model_name}_imageVQVAE.pth")
    checkpoint = torch.load(f'models/{model_name}_imageVQVAE.pth', map_location=device)
    VQVAE.load_state_dict(checkpoint['model_state_dict'])
    VQVAE.eval()
    return VQVAE
528
+
529
+
530
def train_VQGAN(model_Config, trainingConfig, iterator):
    """Adversarially train a VQ-GAN on spectral data.

    Each iteration performs one discriminator update (real batch vs. detached
    reconstruction) followed by one generator/VQ-VAE update (reconstruction +
    VQ + adversarial terms). Running means are logged to TensorBoard and both
    networks are checkpointed whenever the averaged total loss improves.

    Args:
        model_Config: kwargs for the VQGAN constructor.
        trainingConfig: dict with keys such as "device", "lr", "d_lr",
            "model_name", "load_pretrain", "max_iter", "save_steps",
            "w1", "w2", "threshold", "vq_weight", "adver_weight".
        iterator: data loader yielding training batches.

    Returns:
        The trained VQGAN model.
    """

    def save_model_hyperparameter(model_Config, trainingConfig, current_iter,
                                  log_magnitude_loss, phase_loss, current_perplexity, current_vq_loss,
                                  current_loss):
        """Dump the merged model/training config plus current metrics to JSON."""
        model_name = trainingConfig["model_name"]
        # NOTE(review): this mutates model_Config in place via update().
        model_hyperparameter = model_Config
        model_hyperparameter.update(trainingConfig)
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["log_magnitude_loss"] = log_magnitude_loss
        model_hyperparameter["phase_loss"] = phase_loss
        # Bug fix: this key was previously misspelled "erplexity".
        model_hyperparameter["perplexity"] = current_perplexity
        model_hyperparameter["vq_loss"] = current_vq_loss
        model_hyperparameter["total_loss"] = current_loss

        with open(f"models/hyperparameters/{model_name}_VQGAN_STFT.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    # ---- initialize VQ-VAE (generator) ----
    model = VQGAN(**model_Config)
    print(f"VQ_VAE size: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    model.to(trainingConfig["device"])

    VAE_optimizer = torch.optim.Adam(model.parameters(), lr=trainingConfig["lr"], amsgrad=False)
    model_name = trainingConfig["model_name"]

    if trainingConfig["load_pretrain"]:
        print(f"Loading weights from models/{model_name}_imageVQVAE.pth")
        checkpoint = torch.load(f'models/{model_name}_imageVQVAE.pth', map_location=trainingConfig["device"])
        model.load_state_dict(checkpoint['model_state_dict'])
        VAE_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("VAE initialized.")
    if trainingConfig["max_iter"] == 0:
        print("Return VAE directly.")
        return model

    # ---- initialize discriminator ----
    discriminator = VQGAN_Discriminator(model_Config["in_channels"])
    print(f"Discriminator size: {sum(p.numel() for p in discriminator.parameters() if p.requires_grad)}")
    discriminator.to(trainingConfig["device"])

    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=trainingConfig["d_lr"], amsgrad=False)

    if trainingConfig["load_pretrain"]:
        print(f"Loading weights from models/{model_name}_imageVQVAE_discriminator.pth")
        checkpoint = torch.load(f'models/{model_name}_imageVQVAE_discriminator.pth', map_location=trainingConfig["device"])
        discriminator.load_state_dict(checkpoint['model_state_dict'])
        discriminator_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Discriminator initialized.")

    # ---- training ----
    train_res_phase_loss, train_res_perplexity, train_res_log_magnitude_loss, train_res_vq_loss, train_res_loss = [], [], [], [], []
    train_discriminator_loss, train_adverserial_loss = [], []

    reconstructionLoss = ReconstructionLoss(w1=trainingConfig["w1"], w2=trainingConfig["w2"], epsilon=trainingConfig["threshold"])

    # Bug fix: VQGAN_Discriminator ends in a Sigmoid, so its outputs are
    # probabilities. BCEWithLogitsLoss would apply a second sigmoid and
    # squash the gradients; BCELoss is the matching criterion.
    adversarial_loss = nn.BCELoss()
    writer = SummaryWriter(f'runs/{model_name}_VQVAE_lr=1e-4')

    previous_lowest_loss = evaluate_VQGAN(model, discriminator, iterator,
                                          reconstructionLoss, adversarial_loss, trainingConfig)
    print(f"initial_loss: {previous_lowest_loss}")

    model.train()
    for i in xrange(trainingConfig["max_iter"]):
        data = next(iter(iterator))
        data = data.to(trainingConfig["device"])

        # true/fake targets for the adversarial criterion
        real_labels = torch.ones(data.size(0), 1).to(trainingConfig["device"])
        fake_labels = torch.zeros(data.size(0), 1).to(trainingConfig["device"])

        # -- discriminator update: real batch vs. detached reconstruction --
        discriminator_optimizer.zero_grad()

        vq_loss, data_recon, perplexity = model(data)

        real_preds = discriminator(data)
        fake_preds = discriminator(data_recon.detach())

        loss_real = adversarial_loss(real_preds, real_labels)
        loss_fake = adversarial_loss(fake_preds, fake_labels)

        loss_D = loss_real + loss_fake
        loss_D.backward()
        discriminator_optimizer.step()

        # -- generator/VQ-VAE update: reconstruct + fool the discriminator --
        VAE_optimizer.zero_grad()

        fake_preds = discriminator(data_recon)
        adver_loss = adversarial_loss(fake_preds, real_labels)

        log_magnitude_loss, phase_loss, rec_loss = reconstructionLoss(data_recon, data)

        loss = rec_loss + trainingConfig["vq_weight"] * vq_loss + trainingConfig["adver_weight"] * adver_loss
        loss.backward()
        VAE_optimizer.step()

        train_discriminator_loss.append(loss_D.item())
        train_adverserial_loss.append(trainingConfig["adver_weight"] * adver_loss.item())
        train_res_log_magnitude_loss.append(log_magnitude_loss.item())
        train_res_phase_loss.append(phase_loss.item())
        train_res_perplexity.append(perplexity.item())
        train_res_vq_loss.append(trainingConfig["vq_weight"] * vq_loss.item())
        train_res_loss.append(loss.item())
        # Recover the global optimizer step count for logging/checkpointing.
        step = int(VAE_optimizer.state_dict()['state'][list(VAE_optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy())

        save_steps = trainingConfig["save_steps"]
        if (i + 1) % 100 == 0:
            print('%d step' % (step))

        if (i + 1) % save_steps == 0:
            current_discriminator_loss = np.mean(train_discriminator_loss[-save_steps:])
            current_adverserial_loss = np.mean(train_adverserial_loss[-save_steps:])
            current_log_magnitude_loss = np.mean(train_res_log_magnitude_loss[-save_steps:])
            current_phase_loss = np.mean(train_res_phase_loss[-save_steps:])
            current_perplexity = np.mean(train_res_perplexity[-save_steps:])
            current_vq_loss = np.mean(train_res_vq_loss[-save_steps:])
            current_loss = np.mean(train_res_loss[-save_steps:])

            print('discriminator_loss: %.3f' % current_discriminator_loss)
            print('adverserial_loss: %.3f' % current_adverserial_loss)
            print('log_magnitude_loss: %.3f' % current_log_magnitude_loss)
            print('phase_loss: %.3f' % current_phase_loss)
            print('perplexity: %.3f' % current_perplexity)
            print('vq_loss: %.3f' % current_vq_loss)
            print('total_loss: %.3f' % current_loss)
            writer.add_scalar(f"log_magnitude_loss", current_log_magnitude_loss, step)
            writer.add_scalar(f"phase_loss", current_phase_loss, step)
            writer.add_scalar(f"perplexity", current_perplexity, step)
            writer.add_scalar(f"vq_loss", current_vq_loss, step)
            writer.add_scalar(f"total_loss", current_loss, step)
            # Checkpoint both networks only when the running mean improves.
            if current_loss < previous_lowest_loss:
                previous_lowest_loss = current_loss

                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': VAE_optimizer.state_dict(),
                }, f'models/{model_name}_imageVQVAE.pth')

                torch.save({
                    'model_state_dict': discriminator.state_dict(),
                    'optimizer_state_dict': discriminator_optimizer.state_dict(),
                }, f'models/{model_name}_imageVQVAE_discriminator.pth')

                save_model_hyperparameter(model_Config, trainingConfig, step,
                                          current_log_magnitude_loss, current_phase_loss, current_perplexity, current_vq_loss,
                                          current_loss)

    return model
model/__pycache__/DiffSynthSampler.cpython-310.pyc ADDED
Binary file (12.8 kB). View file
 
model/__pycache__/GAN.cpython-310.pyc ADDED
Binary file (7.49 kB). View file
 
model/__pycache__/VQGAN.cpython-310.pyc ADDED
Binary file (18.8 kB). View file
 
model/__pycache__/diffusion.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
model/__pycache__/diffusion_components.cpython-310.pyc ADDED
Binary file (11.6 kB). View file
 
model/__pycache__/multimodal_model.cpython-310.pyc ADDED
Binary file (9.88 kB). View file
 
model/__pycache__/perceptual_label_predictor.cpython-37.pyc ADDED
Binary file (1.67 kB). View file
 
model/__pycache__/timbre_encoder_pretrain.cpython-310.pyc ADDED
Binary file (7.96 kB). View file
 
model/diffusion.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from functools import partial
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+ from six.moves import xrange
9
+ from torch.utils.tensorboard import SummaryWriter
10
+ import random
11
+
12
+ from metrics.IS import get_inception_score
13
+ from tools import create_key
14
+
15
+ from model.diffusion_components import default, ConvNextBlock, ResnetBlock, SinusoidalPositionEmbeddings, Residual, \
16
+ PreNorm, \
17
+ Downsample, Upsample, exists, q_sample, get_beta_schedule, pad_and_concat, ConditionalEmbedding, \
18
+ LinearCrossAttention, LinearCrossAttentionAdd
19
+
20
+
21
class ConditionedUnet(nn.Module):
    """Conditioned U-Net denoiser for diffusion.

    Predicts noise from a noisy input, a timestep, and an optional condition
    (an integer class label or a prompt embedding, depending on
    ``condition_type``). Conditioning is injected into every attention layer;
    the timestep embedding is injected into every ConvNext/Resnet block.
    """

    def __init__(
            self,
            in_dim,
            out_dim=None,
            down_dims=None,
            up_dims=None,
            mid_depth=3,
            with_time_emb=True,
            time_dim=None,
            resnet_block_groups=8,
            use_convnext=True,
            convnext_mult=2,
            attn_type="linear_cat",
            n_label_class=11,
            condition_type="instrument_family",
            label_emb_dim=128,
    ):
        super().__init__()

        # +1 reserves an extra label slot (presumably the "unconditional" class
        # used for classifier-free guidance — confirm against training code).
        self.label_embedding = ConditionalEmbedding(int(n_label_class + 1), int(label_emb_dim), condition_type)

        if up_dims is None:
            up_dims = [128, 128, 64, 32]
        if down_dims is None:
            down_dims = [32, 32, 64, 128]

        out_dim = default(out_dim, in_dim)
        # The decoder must mirror the encoder so channel counts line up.
        assert len(down_dims) == len(up_dims), "len(down_dims) != len(up_dims)"
        assert down_dims[0] == up_dims[-1], "down_dims[0] != up_dims[-1]"
        assert up_dims[0] == down_dims[-1], "up_dims[0] != down_dims[-1]"
        down_in_out = list(zip(down_dims[:-1], down_dims[1:]))
        up_in_out = list(zip(up_dims[:-1], up_dims[1:]))
        print(f"down_in_out: {down_in_out}")
        print(f"up_in_out: {up_in_out}")
        time_dim = default(time_dim, int(down_dims[0] * 4))

        self.init_conv = nn.Conv2d(in_dim, down_dims[0], 7, padding=3)

        # Choose the residual block flavour for the whole network.
        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # Choose how condition embeddings are merged into attention.
        if attn_type == "linear_cat":
            attn_klass = partial(LinearCrossAttention)
        elif attn_type == "linear_add":
            attn_klass = partial(LinearCrossAttentionAdd)
        else:
            raise NotImplementedError()

        # time embeddings: sinusoidal encoding followed by a 2-layer MLP
        if with_time_emb:
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(down_dims[0]),
                nn.Linear(down_dims[0], time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # left (encoder) layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        skip_dims = []

        for down_dim_in, down_dim_out in down_in_out:
            # Each stage: block -> attn -> block -> attn -> downsample.
            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(down_dim_in, down_dim_out, time_emb_dim=time_dim),

                        Residual(PreNorm(down_dim_out, attn_klass(down_dim_out, label_emb_dim=label_emb_dim, ))),
                        block_klass(down_dim_out, down_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(down_dim_out, attn_klass(down_dim_out, label_emb_dim=label_emb_dim, ))),
                        Downsample(down_dim_out),
                    ]
                )
            )
            skip_dims.append(down_dim_out)

        # bottleneck: mid_left outputs are pushed onto the skip stack and
        # consumed (concatenated) by mid_right, so mid_right blocks take 2x channels.
        mid_dim = down_dims[-1]
        self.mid_left = nn.ModuleList([])
        self.mid_right = nn.ModuleList([])
        for _ in range(mid_depth - 1):
            self.mid_left.append(block_klass(mid_dim, mid_dim, time_emb_dim=time_dim))
            self.mid_right.append(block_klass(mid_dim * 2, mid_dim, time_emb_dim=time_dim))
        self.mid_mid = nn.ModuleList(
            [
                block_klass(mid_dim, mid_dim, time_emb_dim=time_dim),
                Residual(PreNorm(mid_dim, attn_klass(mid_dim, label_emb_dim=label_emb_dim, ))),
                block_klass(mid_dim, mid_dim, time_emb_dim=time_dim),
            ]
        )

        # right (decoder) layers; each stage consumes three skip tensors,
        # matching the three pushes per encoder stage above.
        for ind, (up_dim_in, up_dim_out) in enumerate(up_in_out):
            skip_dim = skip_dims.pop()  # down_dim_out
            self.ups.append(
                nn.ModuleList(
                    [
                        # pop&cat (h/2, w/2, down_dim_out)
                        block_klass(up_dim_in + skip_dim, up_dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(up_dim_in, attn_klass(up_dim_in, label_emb_dim=label_emb_dim, ))),
                        Upsample(up_dim_in),
                        # pop&cat (h, w, down_dim_out)
                        block_klass(up_dim_in + skip_dim, up_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(up_dim_out, attn_klass(up_dim_out, label_emb_dim=label_emb_dim, ))),
                        # pop&cat (h, w, down_dim_out)
                        block_klass(up_dim_out + skip_dim, up_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(up_dim_out, attn_klass(up_dim_out, label_emb_dim=label_emb_dim, ))),
                    ]
                )
            )

        # Final skip comes from init_conv's output pushed first in forward().
        self.final_conv = nn.Sequential(
            block_klass(down_dims[0] + up_dims[-1], up_dims[-1]), nn.Conv2d(up_dims[-1], out_dim, 3, padding=1)
        )

    def size(self):
        """Print total and trainable parameter counts."""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total parameters: {total_params}")
        print(f"Trainable parameters: {trainable_params}")


    def forward(self, x, time, condition=None):
        """Predict noise for input `x` at timestep `time` under `condition`.

        Returns a tensor with `out_dim` channels and the same spatial size as `x`.
        """

        if condition is not None:
            condition_emb = self.label_embedding(condition)
        else:
            condition_emb = None

        # `h` is the stack of skip tensors; pops below mirror pushes exactly.
        h = []

        x = self.init_conv(x)
        h.append(x)

        time_emb = self.time_mlp(time) if exists(self.time_mlp) else None

        # downsample
        for block1, attn1, block2, attn2, downsample in self.downs:
            x = block1(x, time_emb)
            x = attn1(x, condition_emb)
            h.append(x)
            x = block2(x, time_emb)
            x = attn2(x, condition_emb)
            h.append(x)
            x = downsample(x)
            h.append(x)

        # bottleneck

        for block in self.mid_left:
            x = block(x, time_emb)
            h.append(x)

        (block1, attn, block2) = self.mid_mid
        x = block1(x, time_emb)
        x = attn(x, condition_emb)
        x = block2(x, time_emb)

        for block in self.mid_right:
            # This is U-Net!!!
            x = pad_and_concat(h.pop(), x)
            x = block(x, time_emb)

        # upsample
        for block1, attn1, upsample, block2, attn2, block3, attn3 in self.ups:
            x = pad_and_concat(h.pop(), x)
            x = block1(x, time_emb)
            x = attn1(x, condition_emb)
            x = upsample(x)

            x = pad_and_concat(h.pop(), x)
            x = block2(x, time_emb)
            x = attn2(x, condition_emb)

            x = pad_and_concat(h.pop(), x)
            x = block3(x, time_emb)
            x = attn3(x, condition_emb)

        x = pad_and_concat(h.pop(), x)
        x = self.final_conv(x)
        return x
209
+
210
+
211
def conditional_p_losses(denoise_model, x_start, t, condition, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod,
                         noise=None, loss_type="l1"):
    """Diffusion training loss: noise `x_start` to step `t`, predict the noise, compare.

    Raises NotImplementedError for an unknown `loss_type`.
    """
    noise = torch.randn_like(x_start) if noise is None else noise

    x_noisy = q_sample(x_start=x_start, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod,
                       sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
    predicted_noise = denoise_model(x_noisy, t, condition)

    # Dispatch table instead of an if/elif chain.
    loss_fns = {
        'l1': F.l1_loss,
        'l2': F.mse_loss,
        'huber': F.smooth_l1_loss,
    }
    if loss_type not in loss_fns:
        raise NotImplementedError()
    return loss_fns[loss_type](noise, predicted_noise)
230
+
231
+
232
def evaluate_diffusion_model(device, model, iterator, BATCH_SIZE, timesteps, unetConfig, encodes2embeddings_mapping,
                             uncondition_rate, unconditional_condition):
    """Estimate the mean denoising loss of `model` over 500 freshly drawn batches.

    Conditions are looked up per sample via `create_key`; with probability
    `uncondition_rate` a sample's condition is replaced by
    `unconditional_condition` (classifier-free guidance style).
    `unetConfig` is accepted for signature compatibility but unused here.
    Returns the mean "huber" loss as a float.
    """
    model.to(device)
    model.eval()
    eva_loss = []
    sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, _, _ = get_beta_schedule(timesteps)
    # Fix: original used Python-2 `xrange`, a NameError on Python 3.
    # no_grad: evaluation only, so skip autograd bookkeeping (values unchanged).
    with torch.no_grad():
        for _ in range(500):
            data, attributes = next(iter(iterator))
            data = data.to(device)

            conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
            # NOTE(review): unlike train_diffusion_model, `unconditional_condition`
            # is used as-is here — presumably already a tensor; confirm at call site.
            selected_conditions = [
                unconditional_condition if random.random() < uncondition_rate else random.choice(conditions_of_one_sample)
                for conditions_of_one_sample in conditions]

            selected_conditions = torch.stack(selected_conditions).float().to(device)

            t = torch.randint(0, timesteps, (BATCH_SIZE,), device=device).long()
            loss = conditional_p_losses(model, data, t, selected_conditions, loss_type="huber",
                                        sqrt_alphas_cumprod=sqrt_alphas_cumprod,
                                        sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod)

            eva_loss.append(loss.item())
    initial_loss = np.mean(eva_loss)
    return initial_loss
257
+
258
+
259
def get_diffusion_model(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    """Build a ConditionedUnet from `model_Config`; optionally restore weights
    from models/{model_name}_UNet.pth and switch to eval mode."""
    UNet = ConditionedUnet(**model_Config)
    trainable = sum(p.numel() for p in UNet.parameters() if p.requires_grad)
    print(f"Model intialized, size: {trainable}")
    UNet.to(device)

    if load_pretrain:
        print(f"Loading weights from models/{model_name}_UNet.pth")
        checkpoint = torch.load(f'models/{model_name}_UNet.pth', map_location=device)
        UNet.load_state_dict(checkpoint['model_state_dict'])
        UNet.eval()
    return UNet
270
+
271
+
272
def train_diffusion_model(VAE, text_encoder, CLAP_tokenizer, timbre_encoder, device, init_model_name, unetConfig, BATCH_SIZE, timesteps, lr, max_iter, iterator, load_pretrain,
                          encodes2embeddings_mapping, uncondition_rate, unconditional_condition, save_steps=5000, init_loss=None, save_model_name=None,
                          n_IS_batches=50):
    """Train the ConditionedUnet denoiser with classifier-free-guidance dropout.

    Per step: draw a batch, map each sample's attributes to candidate condition
    embeddings, replace a sample's condition with `unconditional_condition`
    with probability `uncondition_rate`, and minimize the huber denoising loss.
    Checkpoints every `save_steps`; every 20000 steps also logs an inception
    score and archives the checkpoint. Returns (model, optimizer).

    Fixes vs. original: Python-2 `xrange` -> `range` (NameError on Python 3);
    load message said `.pt` while the file loaded is `.pth`; the optimizer
    `step` counter is moved to CPU before `.numpy()` (fails on CUDA otherwise,
    and matches the equivalent code in multimodal_model.py).
    """

    if save_model_name is None:
        save_model_name = init_model_name

    def save_model_hyperparameter(model_name, unetConfig, BATCH_SIZE, lr, model_size, current_iter, current_loss):
        # Persist the training configuration alongside the checkpoint.
        model_hyperparameter = unetConfig
        model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
        model_hyperparameter["lr"] = lr
        model_hyperparameter["model_size"] = model_size
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["current_loss"] = current_loss
        with open(f"models/hyperparameters/{model_name}_UNet.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    model = ConditionedUnet(**unetConfig)
    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {model_size}")
    model.to(device)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, amsgrad=False)

    if load_pretrain:
        print(f"Loading weights from models/{init_model_name}_UNet.pth")
        checkpoint = torch.load(f'models/{init_model_name}_UNet.pth')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Model initialized.")
        if max_iter == 0:
            print("Return model directly.")
            return model, optimizer


    train_loss = []
    writer = SummaryWriter(f'runs/{save_model_name}_UNet')
    if init_loss is None:
        previous_loss = evaluate_diffusion_model(device, model, iterator, BATCH_SIZE, timesteps, unetConfig, encodes2embeddings_mapping,
                                                 uncondition_rate, unconditional_condition)
    else:
        previous_loss = init_loss
    print(f"initial_IS: {previous_loss}")
    sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, _, _ = get_beta_schedule(timesteps)

    model.train()
    for i in range(max_iter):
        data, attributes = next(iter(iterator))
        data = data.to(device)

        # Candidate condition embeddings for each sample in the batch.
        conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
        unconditional_condition_copy = torch.tensor(unconditional_condition, dtype=torch.float32).to(device).detach()
        # Classifier-free guidance dropout: sometimes train unconditionally.
        selected_conditions = [unconditional_condition_copy if random.random() < uncondition_rate else random.choice(
            conditions_of_one_sample) for conditions_of_one_sample in conditions]

        selected_conditions = torch.stack(selected_conditions).float().to(device)

        optimizer.zero_grad()

        t = torch.randint(0, timesteps, (BATCH_SIZE,), device=device).long()
        loss = conditional_p_losses(model, data, t, selected_conditions, loss_type="huber",
                                    sqrt_alphas_cumprod=sqrt_alphas_cumprod,
                                    sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod)

        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())
        # Global step counter read back from the optimizer state (survives resume).
        step = int(optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy())

        if step % 100 == 0:
            print('%d step' % (step))

        if step % save_steps == 0:
            current_loss = np.mean(train_loss[-save_steps:])
            print(f"current_loss = {current_loss}")
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'models/{save_model_name}_UNet.pth')
            save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, lr, model_size, step, current_loss)


        if step % 20000 == 0:
            # Periodic quality probe: inception score on unconditional samples.
            current_IS = get_inception_score(device, model, VAE, text_encoder, CLAP_tokenizer, timbre_encoder, n_IS_batches,
                                             positive_prompts="", negative_prompts="", CFG=1, sample_steps=20, task="STFT")
            print('current_IS: %.5f' % current_IS)
            current_loss = np.mean(train_loss[-save_steps:])

            writer.add_scalar(f"current_IS", current_IS, step)

            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'models/history/{save_model_name}_{step}_UNet.pth')
            save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, lr, model_size, step, current_loss)

    return model, optimizer
370
+
371
+
model/diffusion_components.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn.functional as F
2
+ import torch
3
+ from torch import nn
4
+ from einops import rearrange
5
+ from inspect import isfunction
6
+ import math
7
+ from tqdm import tqdm
8
+
9
+
10
def exists(x):
    """True iff *x* is not None (truthiness helper for optional values)."""
    return x is not None
13
+
14
+
15
def default(val, d):
    """Return *val* unless it is None; otherwise resolve the fallback *d*,
    calling it first if it is a plain function."""
    if val is not None:
        return val
    return d() if isfunction(d) else d
20
+
21
+
22
class Residual(nn.Module):
    """Wrap *fn* so the module computes ``fn(x, ...) + x`` (a skip connection)."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        out = self.fn(x, *args, **kwargs)
        return out + x
30
+
31
+
32
def Upsample(dim):
    """2x spatial upsampling via a stride-2 transposed conv; channel count unchanged."""
    return nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1)
35
+
36
+
37
def Downsample(dim):
    """2x spatial downsampling via a stride-2 conv; channel count unchanged."""
    return nn.Conv2d(dim, dim, kernel_size=4, stride=2, padding=1)
40
+
41
+
42
class SinusoidalPositionEmbeddings(nn.Module):
    """Map a batch of scalar timesteps to *dim*-dimensional sinusoidal embeddings.

    First dim//2 channels are sines, last dim//2 are cosines, with frequencies
    spaced geometrically from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half = self.dim // 2
        # log-spaced frequencies: exp(-k * log(10000)/(half-1)) for k = 0..half-1
        step = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=time.device) * -step)
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
57
+
58
+
59
class Block(nn.Module):
    """conv3x3 -> GroupNorm -> optional (scale, shift) modulation -> SiLU."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        h = self.norm(self.proj(x))
        if scale_shift is not None:
            scale, shift = scale_shift
            # FiLM-style modulation; (scale + 1) keeps identity at scale == 0.
            h = h * (scale + 1) + shift
        return self.act(h)
78
+
79
+
80
class ResnetBlock(nn.Module):
    """Two conv Blocks with an additive timestep embedding between them and a
    1x1 residual projection. <https://arxiv.org/abs/1512.03385>"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
            if time_emb_dim is not None
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.block1(x)

        if self.mlp is not None and time_emb is not None:
            # Broadcast the projected time embedding over the spatial dims.
            h = h + self.mlp(time_emb)[:, :, None, None]

        h = self.block2(h)
        return h + self.res_conv(x)
105
+
106
+
107
class ConvNextBlock(nn.Module):
    """Depthwise conv7x7 (+ additive time conditioning) -> norm -> conv3x3 ->
    GELU -> norm -> conv3x3, with a 1x1 residual projection."""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if time_emb_dim is not None
            else None
        )

        # Depthwise ("ds") 7x7 convolution: groups == channels.
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)

        if self.mlp is not None and time_emb is not None:
            # Broadcast the projected time embedding over the spatial dims.
            h = h + self.mlp(time_emb)[:, :, None, None]

        h = self.net(h)
        return h + self.res_conv(x)
140
+
141
+
142
class PreNorm(nn.Module):
    """Apply GroupNorm(1, dim) to the input before delegating to *fn*."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x, *args, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, *args, **kwargs)
153
+
154
+
155
class ConditionalEmbedding(nn.Module):
    """Condition encoder: an Embedding table for integer labels, or a Linear
    projection for pre-computed text embeddings."""

    def __init__(self, num_labels, embedding_dim, condition_type="instrument_family"):
        super(ConditionalEmbedding, self).__init__()
        if condition_type == "instrument_family":
            # Integer class labels -> learned embedding vectors.
            self.embedding = nn.Embedding(num_labels, embedding_dim)
        elif condition_type == "natural_language_prompt":
            # Text embeddings are already dense; just project them.
            self.embedding = nn.Linear(embedding_dim, embedding_dim, bias=True)
        else:
            raise NotImplementedError()

    def forward(self, labels):
        return self.embedding(labels)
169
+
170
+
171
class LinearCrossAttention(nn.Module):
    """Combination of efficient attention and cross attention.

    Linear (efficient) self-attention over spatial positions; when a condition
    embedding is given, its projected key/value are appended as one extra
    attention position ("linear_cat" variant).
    """

    def __init__(self, dim, heads=4, label_emb_dim=128, dim_head=32):
        super().__init__()
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5  # standard 1/sqrt(d) attention scaling
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))

        # embedding for key and value
        self.label_key = nn.Linear(label_emb_dim, hidden_dim)
        self.label_value = nn.Linear(label_emb_dim, hidden_dim)

    def forward(self, x, label_embedding=None):
        # x: (b, dim, h, w); label_embedding: (b, label_emb_dim) or None.
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        # Split heads and flatten the spatial grid into one axis of length h*w.
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        if label_embedding is not None:
            # Append the condition as a single extra key/value position.
            label_k = self.label_key(label_embedding).view(b, self.heads, self.dim_head, 1)
            label_v = self.label_value(label_embedding).view(b, self.heads, self.dim_head, 1)

            k = torch.cat([k, label_k], dim=-1)
            v = torch.cat([v, label_v], dim=-1)

        # Linear-attention trick: softmax q over channels, k over positions,
        # then contract k with v first so cost is linear in h*w.
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)
        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)
208
+
209
+
210
def pad_to_match(encoder_tensor, decoder_tensor):
    """Zero-pad *decoder_tensor* so its spatial dims match *encoder_tensor*.

    Assumes the encoder map is at least as large in both spatial dims. Padding
    is split evenly (extra pixel goes to the right/bottom).

    :param encoder_tensor: The feature map from the encoder.
    :param decoder_tensor: The feature map from the decoder that needs to be upsampled.
    :return: Padded decoder_tensor with the same spatial dimensions as encoder_tensor.
    """
    enc_h, enc_w = encoder_tensor.shape[2:]
    dec_h, dec_w = decoder_tensor.shape[2:]

    gap_w = enc_w - dec_w
    gap_h = enc_h - dec_h

    # F.pad takes (left, right, top, bottom) for the last two dims.
    pad_spec = (
        gap_w // 2,
        gap_w - gap_w // 2,
        gap_h // 2,
        gap_h - gap_h // 2,
    )
    return F.pad(decoder_tensor, pad_spec)
234
+
235
+
236
def pad_and_concat(encoder_tensor, decoder_tensor):
    """Pad *decoder_tensor* to the encoder's spatial size, then concatenate
    [encoder, padded decoder] along the channel dimension.

    :param encoder_tensor: The feature map from the encoder.
    :param decoder_tensor: The feature map from the decoder that needs to be concatenated with encoder_tensor.
    :return: Concatenated tensor.
    """
    aligned_decoder = pad_to_match(encoder_tensor, decoder_tensor)
    return torch.cat((encoder_tensor, aligned_decoder), dim=1)
250
+
251
+
252
class LinearCrossAttentionAdd(nn.Module):
    """Linear attention where the condition is ADDED to every query/key
    position ("linear_add" variant), instead of concatenated as an extra
    position as in LinearCrossAttention."""

    def __init__(self, dim, heads=4, label_emb_dim=128, dim_head=32):
        super().__init__()
        self.dim = dim
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5  # standard 1/sqrt(d) attention scaling
        self.heads = heads
        self.label_emb_dim = label_emb_dim
        self.dim_head = dim_head

        self.hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(self.dim, self.hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(self.hidden_dim, dim, 1), nn.GroupNorm(1, dim))

        # embedding for key and value
        # NOTE(review): despite the comment above, these project the condition
        # into KEY and QUERY space (there is no label_value here).
        self.label_key = nn.Linear(label_emb_dim, self.hidden_dim)
        self.label_query = nn.Linear(label_emb_dim, self.hidden_dim)


    def forward(self, x, condition=None):
        # x: (b, dim, h, w); condition: (b, label_emb_dim) or None.
        b, c, h, w = x.shape

        qkv = self.to_qkv(x).chunk(3, dim=1)

        # Split heads and flatten the spatial grid into one axis of length h*w.
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        # if condition exists, add its projections to every query/key position
        # (broadcast over the trailing position axis).
        if condition is not None:
            label_k = self.label_key(condition).view(b, self.heads, self.dim_head, 1)
            label_q = self.label_query(condition).view(b, self.heads, self.dim_head, 1)
            k = k + label_k
            q = q + label_q

        # Linear-attention trick: softmax q over channels, k over positions,
        # then contract k with v first so cost is linear in h*w.
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)
        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)
294
+
295
+
296
+
297
def linear_beta_schedule(timesteps):
    """Linearly spaced DDPM noise variances beta_t from 1e-4 to 2e-2."""
    return torch.linspace(0.0001, 0.02, steps=timesteps)
301
+
302
+
303
def get_beta_schedule(timesteps):
    """Precompute the DDPM schedule tensors for *timesteps* steps.

    Returns (sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod,
    posterior_variance, sqrt_recip_alphas).
    """
    betas = linear_beta_schedule(timesteps=timesteps)

    # alpha_t = 1 - beta_t and its running product \bar{alpha}_t
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # \bar{alpha}_{t-1}, with \bar{alpha}_0 defined as 1
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
    sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

    # coefficients of q(x_t | x_0)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

    # variance of the posterior q(x_{t-1} | x_t, x_0)
    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
    return sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, posterior_variance, sqrt_recip_alphas
319
+
320
+
321
def extract(a, t, x_shape):
    """Gather a[t] for each batch element and reshape to (b, 1, ..., 1) so it
    broadcasts against a tensor of shape *x_shape*."""
    batch = t.shape[0]
    gathered = a.gather(-1, t.cpu())
    broadcast_shape = (batch,) + (1,) * (len(x_shape) - 1)
    return gathered.reshape(*broadcast_shape).to(t.device)
325
+
326
+
327
+ # forward diffusion
328
# forward diffusion
def q_sample(x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise=None):
    """Sample q(x_t | x_0): scale the clean signal and mix in Gaussian noise
    according to the schedule coefficients at step *t*."""
    noise = torch.randn_like(x_start) if noise is None else noise

    signal_coef = extract(sqrt_alphas_cumprod, t, x_start.shape)
    noise_coef = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)

    return signal_coef * x_start + noise_coef * noise
338
+
339
+
340
+
341
+
342
+
343
+
344
+
345
+
346
+
347
+
348
+
349
+
350
+
351
+
model/multimodal_model.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import json
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+
10
+ from tools import create_key
11
+ from model.timbre_encoder_pretrain import get_timbre_encoder
12
+
13
+
14
class ProjectionLayer(nn.Module):
    """Linear projection followed by a GELU->Linear->Dropout residual branch
    and a LayerNorm (attribute names preserved for checkpoint compatibility)."""

    def __init__(self, input_dim, output_dim, dropout):
        super(ProjectionLayer, self).__init__()
        self.projection = nn.Linear(input_dim, output_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(output_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, x):
        projected = self.projection(x)
        branch = self.dropout(self.fc(self.gelu(projected)))
        return self.layer_norm(branch + projected)
33
+
34
+
35
class ProjectionHead(nn.Module):
    """Stack of ProjectionLayer blocks; only the first maps embedding_dim,
    subsequent layers map projection_dim -> projection_dim."""

    def __init__(self, embedding_dim, projection_dim, dropout, num_layers=2):
        super(ProjectionHead, self).__init__()
        in_dims = [embedding_dim if i == 0 else projection_dim for i in range(num_layers)]
        self.layers = nn.ModuleList(
            [ProjectionLayer(d, projection_dim, dropout) for d in in_dims]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
48
+
49
+
50
class multi_modal_model(nn.Module):
    """The multi-modal model for contrastive learning.

    CLIP-style setup: a (frozen by default) timbre encoder and text encoder,
    each followed by a trainable projection head into a shared embedding
    space. forward() returns the symmetric contrastive loss for a batch.
    """

    def __init__(
            self,
            timbre_encoder,
            text_encoder,
            spectrogram_feature_dim,
            text_feature_dim,
            multi_modal_emb_dim,
            temperature,
            dropout,
            num_projection_layers=1,
            freeze_spectrogram_encoder=True,
            freeze_text_encoder=True,
    ):
        super().__init__()
        self.timbre_encoder = timbre_encoder
        self.text_encoder = text_encoder

        self.multi_modal_emb_dim = multi_modal_emb_dim

        # Projection heads map each modality into the shared embedding space.
        self.text_projection = ProjectionHead(embedding_dim=text_feature_dim,
                                              projection_dim=self.multi_modal_emb_dim, dropout=dropout,
                                              num_layers=num_projection_layers)

        self.spectrogram_projection = ProjectionHead(embedding_dim=spectrogram_feature_dim,
                                                     projection_dim=self.multi_modal_emb_dim, dropout=dropout,
                                                     num_layers=num_projection_layers)

        # Softmax temperature for the contrastive logits.
        self.temperature = temperature

        # Make spectrogram_encoder parameters non-trainable
        for param in self.timbre_encoder.parameters():
            param.requires_grad = not freeze_spectrogram_encoder

        # Make text_encoder parameters non-trainable
        for param in self.text_encoder.parameters():
            param.requires_grad = not freeze_text_encoder

    def forward(self, spectrogram_batch, tokenized_text_batch):
        """Return the scalar contrastive loss for a paired batch."""
        # Getting Image and Text Embeddings (with same dimension)
        # timbre_encoder returns a 5-tuple; only the first element (features) is used.
        spectrogram_features, _, _, _, _ = self.timbre_encoder(spectrogram_batch)
        text_features = self.text_encoder.get_text_features(**tokenized_text_batch)

        # Concat and apply projection
        spectrogram_embeddings = self.spectrogram_projection(spectrogram_features)
        text_embeddings = self.text_projection(text_features)

        # Calculating the Loss: soft targets from within-modality similarity
        # (CLIP-with-soft-labels formulation).
        logits = (text_embeddings @ spectrogram_embeddings.T) / self.temperature
        images_similarity = spectrogram_embeddings @ spectrogram_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        contrastive_loss = (images_loss + texts_loss) / 2.0  # shape: (batch_size)
        contrastive_loss = contrastive_loss.mean()

        return contrastive_loss


    def get_text_features(self, input_ids, attention_mask):
        """Encode tokenized text and project into the shared embedding space."""
        text_features = self.text_encoder.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
        return self.text_projection(text_features)


    def get_timbre_features(self, spectrogram_batch):
        """Encode spectrograms and project into the shared embedding space."""
        spectrogram_features, _, _, _, _ = self.timbre_encoder(spectrogram_batch)
        return self.spectrogram_projection(spectrogram_features)
122
+
123
+
124
def cross_entropy(preds, targets, reduction='none'):
    """Cross entropy with soft (probability-distribution) targets.

    :param preds: raw logits, shape (batch, classes).
    :param targets: target distribution per row, same shape as preds.
    :param reduction: "none" for a per-row loss vector, "mean" for a scalar.
    :raises ValueError: for an unsupported reduction (the original silently
        returned None, which surfaced later as an opaque AttributeError).
    """
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
    raise ValueError(f"Unsupported reduction: {reduction!r}")
132
+
133
def get_multi_modal_model(timbre_encoder, text_encoder, model_Config, load_pretrain=False, model_name=None, device="cpu"):
    """Build the multi-modal contrastive model; optionally restore weights
    from models/{model_name}_MMM.pth and switch to eval mode."""
    mmm = multi_modal_model(timbre_encoder, text_encoder, **model_Config)
    trainable = sum(p.numel() for p in mmm.parameters() if p.requires_grad)
    print(f"Model intialized, size: {trainable}")
    mmm.to(device)

    if load_pretrain:
        print(f"Loading weights from models/{model_name}_MMM.pth")
        checkpoint = torch.load(f'models/{model_name}_MMM.pth', map_location=device)
        mmm.load_state_dict(checkpoint['model_state_dict'])
        mmm.eval()
    return mmm
144
+
145
+
146
def train_epoch(text_tokenizer, model, train_loader, labels_mapping, optimizer, device):
    """Run one contrastive optimization step on a batch with pairwise-distinct
    label keys; returns the scalar loss."""

    def _draw_batch():
        batch, attrs = next(iter(train_loader))
        return batch, attrs, [create_key(a) for a in attrs]

    data, attributes, keys = _draw_batch()
    # Duplicate keys break the contrastive targets, so redraw until unique.
    while len(set(keys)) != len(keys):
        data, attributes, keys = _draw_batch()

    data = data.to(device)

    text_options = [labels_mapping[create_key(a)] for a in attributes]
    # One caption chosen uniformly at random per sample (same RNG calls as before).
    selected_texts = [opts[random.randint(0, len(opts) - 1)] for opts in text_options]

    tokenized_text = text_tokenizer(selected_texts, padding=True, return_tensors="pt").to(device)

    loss = model(data, tokenized_text)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()
167
+
168
+
169
def valid_epoch(text_tokenizer, model, valid_loader, labels_mapping, device):
    """Compute the contrastive loss on one batch with pairwise-distinct keys.

    Improvement: the forward pass is wrapped in torch.no_grad() — the loss is
    never backpropagated here, so tracking gradients only wasted memory; the
    returned value is unchanged.
    NOTE(review): the model is left in whatever train/eval mode the caller set
    (dropout may be active) — confirm whether eval() is intended here.
    """
    (data, attributes) = next(iter(valid_loader))
    keys = [create_key(attribute) for attribute in attributes]

    # Duplicate keys break the contrastive targets, so redraw until unique.
    while(len(set(keys)) != len(keys)):
        (data, attributes) = next(iter(valid_loader))
        keys = [create_key(attribute) for attribute in attributes]

    data = data.to(device)
    texts = [labels_mapping[create_key(attribute)] for attribute in attributes]
    selected_texts = [l[random.randint(0, len(l) - 1)] for l in texts]

    tokenized_text = text_tokenizer(selected_texts, padding=True, return_tensors="pt").to(device)

    with torch.no_grad():
        loss = model(data, tokenized_text)
    return loss.item()
185
+
186
+
187
def train_multi_modal_model(device, training_dataloader, labels_mapping, text_tokenizer, text_encoder,
                            timbre_encoder_Config, MMM_config, MMM_training_config,
                            mmm_name, BATCH_SIZE, max_iter=0, load_pretrain=True,
                            timbre_encoder_name=None, init_loss=None, save_steps=2000):
    """Train the multi-modal contrastive model for `max_iter` steps.

    Builds the timbre encoder (always from pretrained weights), assembles the
    multi-modal model, sets per-parameter-group learning rates, optionally
    resumes from models/{mmm_name}_MMM.pth, and checkpoints only when the
    running mean loss improves. Returns (model, optimizer).
    """

    def save_model_hyperparameter(model_name, MMM_config, MMM_training_config, BATCH_SIZE, model_size, current_iter,
                                  current_loss):
        # Persist the merged configuration alongside the checkpoint.
        # NOTE(review): this mutates MMM_config in place via update().
        model_hyperparameter = MMM_config
        model_hyperparameter.update(MMM_training_config)
        model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
        model_hyperparameter["model_size"] = model_size
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["current_loss"] = current_loss
        with open(f"models/hyperparameters/{model_name}_MMM.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    timbreEncoder = get_timbre_encoder(timbre_encoder_Config, load_pretrain=True, model_name=timbre_encoder_name,
                                       device=device)

    mmm = multi_modal_model(timbreEncoder, text_encoder, **MMM_config).to(device)

    print(f"spectrogram_encoder parameter: {sum(p.numel() for p in mmm.timbre_encoder.parameters())}")
    print(f"text_encoder parameter: {sum(p.numel() for p in mmm.text_encoder.parameters())}")
    print(f"spectrogram_projection parameter: {sum(p.numel() for p in mmm.spectrogram_projection.parameters())}")
    print(f"text_projection parameter: {sum(p.numel() for p in mmm.text_projection.parameters())}")
    total_parameters = sum(p.numel() for p in mmm.parameters())
    trainable_parameters = sum(p.numel() for p in mmm.parameters() if p.requires_grad)
    print(f"Trainable/Total parameter: {trainable_parameters}/{total_parameters}")

    # Parameter groups: projection heads always train; encoders only when unfrozen.
    params = [
        {"params": itertools.chain(
            mmm.spectrogram_projection.parameters(),
            mmm.text_projection.parameters(),
        ), "lr": MMM_training_config["head_lr"], "weight_decay": MMM_training_config["head_weight_decay"]},
    ]
    if not MMM_config["freeze_text_encoder"]:
        params.append({"params": mmm.text_encoder.parameters(), "lr": MMM_training_config["text_encoder_lr"],
                       "weight_decay": MMM_training_config["text_encoder_weight_decay"]})
    if not MMM_config["freeze_spectrogram_encoder"]:
        params.append({"params": mmm.timbre_encoder.parameters(), "lr": MMM_training_config["spectrogram_encoder_lr"],
                       "weight_decay": MMM_training_config["timbre_encoder_weight_decay"]})

    optimizer = torch.optim.AdamW(params, weight_decay=0.)

    if load_pretrain:
        # NOTE(review): message says .pt but the file loaded is .pth.
        print(f"Loading weights from models/{mmm_name}_MMM.pt")
        checkpoint = torch.load(f'models/{mmm_name}_MMM.pth')
        mmm.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Model initialized.")

    if max_iter == 0:
        print("Return model directly.")
        return mmm, optimizer

    # Baseline loss: either measured on one batch or supplied by the caller.
    if init_loss is None:
        previous_lowest_loss = valid_epoch(text_tokenizer, mmm, training_dataloader, labels_mapping, device)
    else:
        previous_lowest_loss = init_loss
    print(f"Initial total loss: {previous_lowest_loss}")

    train_loss_list = []
    for i in range(max_iter):

        mmm.train()
        train_loss = train_epoch(text_tokenizer, mmm, training_dataloader, labels_mapping, optimizer, device)
        train_loss_list.append(train_loss)

        # Global step counter read back from the optimizer state (survives resume).
        step = int(
            optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy())
        if (i + 1) % 100 == 0:
            print('%d step' % (step))

        if (i + 1) % save_steps == 0:
            current_loss = np.mean(train_loss_list[-save_steps:])
            print(f"train_total_loss: {current_loss}")
            # Only checkpoint when the running mean loss improves.
            if current_loss < previous_lowest_loss:
                previous_lowest_loss = current_loss
                torch.save({
                    'model_state_dict': mmm.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, f'models/{mmm_name}_MMM.pth')
                save_model_hyperparameter(mmm_name, MMM_config, MMM_training_config, BATCH_SIZE, total_parameters, step,
                                          current_loss)

    return mmm, optimizer
model/timbre_encoder_pretrain.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.utils.tensorboard import SummaryWriter
6
+ from tools import create_key
7
+
8
+
9
class TimbreEncoder(nn.Module):
    """LSTM-based timbre classifier / feature extractor.

    The input is flattened to a (batch, time, features) sequence, projected,
    run through an LSTM, and the last time step's hidden state is classified
    into four heads:

    * instrument, instrument family, velocity — log-probabilities
      (``nn.LogSoftmax``), intended for ``nn.NLLLoss``;
    * qualities — independent per-label probabilities (``sigmoid``),
      intended for ``nn.BCELoss``.

    Args:
        input_dim: size of the flattened per-time-step feature vector;
            must equal C*F of the (batch, C, F, T) input — TODO confirm
            the input layout against the data pipeline.
        feature_dim: width of the linear projection fed to the LSTM.
        hidden_dim: LSTM hidden size (also the classifier input size).
        num_instrument_classes / num_instrument_family_classes /
        num_velocity_classes / num_qualities: output sizes of the heads.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_dim, feature_dim, hidden_dim, num_instrument_classes, num_instrument_family_classes, num_velocity_classes, num_qualities, num_layers=1):
        super(TimbreEncoder, self).__init__()

        # Input layer: per-time-step projection from input_dim to feature_dim.
        self.input_layer = nn.Linear(input_dim, feature_dim)

        # LSTM Layer (batch_first: tensors are (batch, seq, feature)).
        self.lstm = nn.LSTM(feature_dim, hidden_dim, num_layers=num_layers, batch_first=True)

        # Fully Connected Layers for classification (one head per target).
        self.instrument_classifier_layer = nn.Linear(hidden_dim, num_instrument_classes)
        self.instrument_family_classifier_layer = nn.Linear(hidden_dim, num_instrument_family_classes)
        self.velocity_classifier_layer = nn.Linear(hidden_dim, num_velocity_classes)
        self.qualities_classifier_layer = nn.Linear(hidden_dim, num_qualities)

        # NOTE: despite the name, this is LogSoftmax (log-probabilities),
        # which pairs with NLLLoss in the training loop below.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Run the encoder.

        Args:
            x: 4-D tensor; the first dim is batch and the last is the
               sequence (time) axis. The two middle dims are flattened
               into the per-step feature vector.

        Returns:
            (feature, instrument_logits, instrument_family_logits,
             velocity_logits, qualities) where `feature` is the LSTM's
            last-step hidden state and the three `*_logits` are
            log-probabilities; `qualities` are sigmoid probabilities.
        """
        # # Merge first two dimensions
        batch_size, _, _, seq_len = x.shape
        x = x.view(batch_size, -1, seq_len)  # [batch_size, input_dim, seq_len]

        # Forward propagate LSTM
        x = x.permute(0, 2, 1)  # -> (batch, seq_len, input_dim) for batch_first LSTM
        x = self.input_layer(x)
        feature, _ = self.lstm(x)
        feature = feature[:, -1, :]  # keep only the final time step's output

        # Apply classification layers
        instrument_logits = self.instrument_classifier_layer(feature)
        instrument_family_logits = self.instrument_family_classifier_layer(feature)
        velocity_logits = self.velocity_classifier_layer(feature)
        qualities = self.qualities_classifier_layer(feature)

        # Apply Softmax (LogSoftmax for the multi-class heads, sigmoid for multi-label)
        instrument_logits = self.softmax(instrument_logits)
        instrument_family_logits= self.softmax(instrument_family_logits)
        velocity_logits = self.softmax(velocity_logits)
        qualities = torch.sigmoid(qualities)

        return feature, instrument_logits, instrument_family_logits, velocity_logits, qualities
52
+
53
+
54
def get_multiclass_acc(outputs, ground_truth):
    """Top-1 accuracy of `outputs` against integer labels, in percent."""
    predicted = torch.max(outputs.data, 1)[1]
    n_correct = (predicted == ground_truth).sum().item()
    n_total = ground_truth.size(0)
    return 100 * n_correct / n_total
60
+
61
def get_binary_accuracy(y_pred, y_true):
    """Element-wise accuracy (percent) of multi-label predictions.

    Probabilities in `y_pred` are thresholded at 0.5 and compared with
    the 0/1 targets in `y_true`.
    """
    hard_labels = (y_pred > 0.5).int()
    match_rate = (hard_labels == y_true).float().mean()
    return match_rate.item() * 100.0
69
+
70
+
71
def get_timbre_encoder(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    """Construct a TimbreEncoder in eval mode, optionally loading weights.

    Args:
        model_Config: kwargs forwarded to ``TimbreEncoder(**model_Config)``.
        load_pretrain: when True, load ``models/{model_name}_timbre_encoder.pth``.
        model_name: checkpoint name stem; required if ``load_pretrain`` is True.
        device: device to move the model (and mapped weights) to.

    Returns:
        The (possibly pretrained) model, switched to eval mode.
    """
    timbreEncoder = TimbreEncoder(**model_Config)
    # Fix: original message had a typo ("intialized").
    print(f"Model initialized, size: {sum(p.numel() for p in timbreEncoder.parameters() if p.requires_grad)}")
    timbreEncoder.to(device)

    if load_pretrain:
        print(f"Loading weights from models/{model_name}_timbre_encoder.pth")
        # map_location lets CPU-only hosts load GPU-trained checkpoints.
        checkpoint = torch.load(f'models/{model_name}_timbre_encoder.pth', map_location=device)
        timbreEncoder.load_state_dict(checkpoint['model_state_dict'])
    timbreEncoder.eval()
    return timbreEncoder
82
+
83
+
84
def evaluate_timbre_encoder(device, model, iterator, nll_Loss, bce_Loss, n_sample=100):
    """Estimate the model's mean total loss over `n_sample` batches.

    The total loss is the sum of the three NLL classification losses
    (instrument, family, velocity) and the BCE loss on the qualities head.

    Fix: evaluation now runs under ``torch.no_grad()`` — the original built
    autograd graphs for every batch, wasting memory for no benefit.
    """
    model.to(device)
    model.eval()

    eva_loss = []
    with torch.no_grad():
        for i in range(n_sample):
            # NOTE(review): next(iter(iterator)) re-creates the iterator each pass.
            # With a shuffling DataLoader this draws a fresh random batch; with a
            # non-shuffling iterable it would evaluate the same first batch every
            # time — confirm the iterator shuffles.
            representation, attributes = next(iter(iterator))

            # Targets: integer class labels plus a 10-bit qualities vector
            # recovered from the tail of the sample's key string.
            instrument = torch.tensor([s["instrument"] for s in attributes], dtype=torch.long).to(device)
            instrument_family = torch.tensor([s["instrument_family"] for s in attributes], dtype=torch.long).to(device)
            velocity = torch.tensor([s["velocity"] for s in attributes], dtype=torch.long).to(device)
            qualities = torch.tensor([[int(char) for char in create_key(attribute)[-10:]] for attribute in attributes], dtype=torch.float32).to(device)

            _, instrument_logits, instrument_family_logits, velocity_logits, qualities_pred = model(representation.to(device))

            # compute loss
            instrument_loss = nll_Loss(instrument_logits, instrument)
            instrument_family_loss = nll_Loss(instrument_family_logits, instrument_family)
            velocity_loss = nll_Loss(velocity_logits, velocity)
            qualities_loss = bce_Loss(qualities_pred, qualities)

            loss = instrument_loss + instrument_family_loss + velocity_loss + qualities_loss

            eva_loss.append(loss.item())

    return np.mean(eva_loss)
111
+
112
+
113
def train_timbre_encoder(device, model_name, timbre_encoder_Config, BATCH_SIZE, lr, max_iter, training_iterator, load_pretrain):
    """Train a TimbreEncoder and checkpoint the best-so-far weights.

    Every 500 steps the mean training loss over that window is compared with
    the best seen; improvements are saved to
    ``models/{model_name}_timbre_encoder.pth`` (plus a JSON hyper-parameter
    sidecar) and logged to TensorBoard.

    Returns:
        (model, current_best_model): the final model and a deep-copied
        snapshot of the best checkpointed state.
    """
    import copy  # local import: only needed to snapshot the best model

    def save_model_hyperparameter(model_name, timbre_encoder_Config, BATCH_SIZE, lr, model_size, current_iter,
                                  current_loss):
        # Persist the config plus training state next to the checkpoint.
        # NOTE: this mutates timbre_encoder_Config in place (adds keys).
        model_hyperparameter = timbre_encoder_Config
        model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
        model_hyperparameter["lr"] = lr
        model_hyperparameter["model_size"] = model_size
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["current_loss"] = current_loss
        with open(f"models/hyperparameters/{model_name}_timbre_encoder.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    model = TimbreEncoder(**timbre_encoder_Config)
    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model size: {model_size}")
    model.to(device)
    # LogSoftmax heads pair with NLLLoss; the sigmoid qualities head with BCELoss.
    nll_Loss = torch.nn.NLLLoss()
    bce_Loss = torch.nn.BCELoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=False)

    if load_pretrain:
        # Fix: message previously said ".pt" while the file actually loaded is ".pth".
        print(f"Loading weights from models/{model_name}_timbre_encoder.pth")
        checkpoint = torch.load(f'models/{model_name}_timbre_encoder.pth')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Model initialized.")
    if max_iter == 0:
        print("Return model directly.")
        return model, model

    train_loss, training_instrument_acc, training_instrument_family_acc, training_velocity_acc, training_qualities_acc = [], [], [], [], []
    writer = SummaryWriter(f'runs/{model_name}_timbre_encoder')
    current_best_model = model
    previous_lowest_loss = 100.0
    print(f"initial__loss: {previous_lowest_loss}")

    for i in range(max_iter):
        model.train()

        # NOTE(review): next(iter(...)) re-creates the iterator each step; with a
        # shuffling DataLoader this draws a fresh random batch each iteration.
        representation, attributes = next(iter(training_iterator))

        instrument = torch.tensor([s["instrument"] for s in attributes], dtype=torch.long).to(device)
        instrument_family = torch.tensor([s["instrument_family"] for s in attributes], dtype=torch.long).to(device)
        velocity = torch.tensor([s["velocity"] for s in attributes], dtype=torch.long).to(device)
        qualities = torch.tensor([[int(char) for char in create_key(attribute)[-10:]] for attribute in attributes], dtype=torch.float32).to(device)

        optimizer.zero_grad()

        _, instrument_logits, instrument_family_logits, velocity_logits, qualities_pred = model(representation.to(device))

        # compute loss: sum of the three classification losses + qualities BCE.
        instrument_loss = nll_Loss(instrument_logits, instrument)
        instrument_family_loss = nll_Loss(instrument_family_logits, instrument_family)
        velocity_loss = nll_Loss(velocity_logits, velocity)
        qualities_loss = bce_Loss(qualities_pred, qualities)

        loss = instrument_loss + instrument_family_loss + velocity_loss + qualities_loss

        loss.backward()
        optimizer.step()
        instrument_acc = get_multiclass_acc(instrument_logits, instrument)
        instrument_family_acc = get_multiclass_acc(instrument_family_logits, instrument_family)
        velocity_acc = get_multiclass_acc(velocity_logits, velocity)
        qualities_acc = get_binary_accuracy(qualities_pred, qualities)

        train_loss.append(loss.item())
        training_instrument_acc.append(instrument_acc)
        training_instrument_family_acc.append(instrument_family_acc)
        training_velocity_acc.append(velocity_acc)
        training_qualities_acc.append(qualities_acc)
        # Fix: move the optimizer's step counter to CPU before .numpy(); the
        # original crashed when the optimizer state lived on a CUDA device.
        step = int(optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy())

        if (i + 1) % 100 == 0:
            print('%d step' % (step))

        save_steps = 500
        if (i + 1) % save_steps == 0:
            current_loss = np.mean(train_loss[-save_steps:])
            current_instrument_acc = np.mean(training_instrument_acc[-save_steps:])
            current_instrument_family_acc = np.mean(training_instrument_family_acc[-save_steps:])
            current_velocity_acc = np.mean(training_velocity_acc[-save_steps:])
            current_qualities_acc = np.mean(training_qualities_acc[-save_steps:])
            print('train_loss: %.5f' % current_loss)
            print('current_instrument_acc: %.5f' % current_instrument_acc)
            print('current_instrument_family_acc: %.5f' % current_instrument_family_acc)
            print('current_velocity_acc: %.5f' % current_velocity_acc)
            print('current_qualities_acc: %.5f' % current_qualities_acc)
            writer.add_scalar(f"train_loss", current_loss, step)
            writer.add_scalar(f"current_instrument_acc", current_instrument_acc, step)
            writer.add_scalar(f"current_instrument_family_acc", current_instrument_family_acc, step)
            writer.add_scalar(f"current_velocity_acc", current_velocity_acc, step)
            writer.add_scalar(f"current_qualities_acc", current_qualities_acc, step)

            if current_loss < previous_lowest_loss:
                previous_lowest_loss = current_loss
                # Fix: snapshot the weights. The original assigned the live model
                # object, so the returned "best model" was always identical to
                # the final model regardless of when the best loss occurred.
                current_best_model = copy.deepcopy(model)
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, f'models/{model_name}_timbre_encoder.pth')
                save_model_hyperparameter(model_name, timbre_encoder_Config, BATCH_SIZE, lr, model_size, step,
                                          current_loss)

    return model, current_best_model
219
+
220
+
tools.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import matplotlib
4
+ import librosa
5
+ from scipy.io.wavfile import write
6
+ import torch
7
+
8
k = 1e-16  # module-wide epsilon guarding divisions/logs against zero (used by VAE_out_put_to_spc)
9
+
10
def np_log10(x):
    """Numerically safe base-10 logarithm: log10(x + 1e-16)."""
    shifted = x + 1e-16  # epsilon keeps log finite at x == 0
    return np.log(shifted) / np.log(10)
15
+
16
+
17
def sigmoid(x):
    """Logistic sigmoid: 1 / (1 + exp(-x))."""
    return 1 / (1 + np.exp(-x))
21
+
22
+
23
def inv_sigmoid(s):
    """Numerically safe inverse sigmoid (logit): log(s / (1 - s) + 1e-16)."""
    odds = s / (1 - s)
    return np.log(odds + 1e-16)
27
+
28
+
29
def spc_to_VAE_input(spc):
    """Compress a non-negative spectrogram from [0, inf) into [0, 1) via x/(1+x). (deprecated )"""
    return spc / (spc + 1)
32
+
33
+
34
def VAE_out_put_to_spc(o):
    """Inverse of 'spc_to_VAE_input': map [0, 1) back to [0, inf). (deprecated )"""
    # k is the module-level epsilon; it keeps the division finite at o == 1.
    denominator = 1 - o + k
    return o / denominator
37
+
38
+
39
+
40
def np_power_to_db(S, amin=1e-16, top_db=80.0):
    """Convert a power spectrogram to dB relative to its own peak. (deprecated )

    Values are floored at `amin` before the log, and the dynamic range is
    clipped to `top_db` below the maximum.
    """
    ref = S.max()
    # dB relative to the peak value of S.
    log_spec = 10.0 * np_log10(np.maximum(amin, S)) - 10.0 * np_log10(np.maximum(amin, ref))
    # Limit the dynamic range to top_db below the maximum.
    return np.maximum(log_spec, log_spec.max() - top_db)
50
+
51
+
52
def show_spc(spc):
    """Show a spectrogram. (deprecated )

    Reshapes to 2-D, converts to a log (dB) scale and displays it with
    matplotlib, flipped vertically so low frequencies sit at the bottom.
    """
    s = np.shape(spc)
    spc = np.reshape(spc, (s[0], s[1]))  # drop any trailing singleton dims
    magnitude_spectrum = np.abs(spc)
    log_spectrum = np_power_to_db(magnitude_spectrum)
    plt.imshow(np.flipud(log_spectrum))  # flip so the image reads bottom-up in frequency
    plt.show()
60
+
61
+
62
def save_results(spectrogram, spectrogram_image_path, waveform_path):
    """Save the input 'spectrogram' and its waveform (reconstructed by Griffin Lim)
    to path provided by 'spectrogram_image_path' and 'waveform_path'.

    Assumes a (512, 256) power spectrogram — TODO confirm against callers.
    The waveform is written as a 16 kHz wav file.
    """
    # Image: log-scaled (dB) view of the spectrogram, clipped to [-100, 0] dB.
    magnitude_spectrum = np.abs(spectrogram)
    log_spc = np_power_to_db(magnitude_spectrum)
    log_spc = np.reshape(log_spc, (512, 256))
    matplotlib.pyplot.imsave(spectrogram_image_path, log_spc, vmin=-100, vmax=0,
                             origin='lower')

    # save waveform: rebuild a 513-bin magnitude STFT (the extra, highest bin is
    # left zero) and invert it with Griffin-Lim. sqrt() because `spectrogram`
    # holds power values while griffinlim expects magnitudes.
    abs_spec = np.zeros((513, 256))
    abs_spec[:512, :] = abs_spec[:512, :] + np.sqrt(np.reshape(spectrogram, (512, 256)))
    rec_signal = librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=1024)
    write(waveform_path, 16000, rec_signal)
76
+
77
+
78
def plot_log_spectrogram(signal: np.ndarray,
                         path: str,
                         n_fft=2048,
                         frame_length=1024,
                         frame_step=256):
    """Save spectrogram.

    Computes the STFT of `signal`, converts it to a log-power (dB) image
    clipped to [-100, 0] dB, and writes it to `path`.
    """
    stft = librosa.stft(signal, n_fft=n_fft, hop_length=frame_step, win_length=frame_length)
    # Power spectrum = Re^2 + Im^2 (already non-negative; the abs below is a no-op).
    amp = np.square(np.real(stft)) + np.square(np.imag(stft))
    magnitude_spectrum = np.abs(amp)
    # NOTE(review): despite the name, this is a log *linear-frequency* power
    # spectrogram, not a mel spectrogram.
    log_mel = np_power_to_db(magnitude_spectrum)
    matplotlib.pyplot.imsave(path, log_mel, vmin=-100, vmax=0, origin='lower')
89
+
90
+
91
def visualize_feature_maps(device, model, inputs, channel_indices=[0, 3,]):
    """
    Visualize feature maps before and after quantization for given input.

    Parameters:
    - model: Your VQ-VAE model.
    - inputs: A batch of input data.
    - channel_indices: Indices of feature map channels to visualize.

    For each sample in the batch, shows the encoder output and the
    vector-quantized output side by side for the requested channels.
    """
    model.eval()
    inputs = inputs.to(device)

    # No gradients needed: pure inspection of the encoder / quantizer outputs.
    with torch.no_grad():
        z_e = model._encoder(inputs)
        z_q, loss, (perplexity, min_encodings, min_encoding_indices) = model._vq_vae(z_e)

    # Assuming inputs have shape [batch_size, channels, height, width]
    batch_size = z_e.size(0)

    for idx in range(batch_size):
        # Two columns per requested channel: pre- and post-quantization.
        fig, axs = plt.subplots(1, len(channel_indices)*2, figsize=(15, 5))

        for i, channel_idx in enumerate(channel_indices):
            # Plot encoder output
            axs[2*i].imshow(z_e[idx][channel_idx].cpu().numpy(), cmap='viridis')
            axs[2*i].set_title(f"Encoder Output - Channel {channel_idx}")

            # Plot quantized output
            axs[2*i+1].imshow(z_q[idx][channel_idx].cpu().numpy(), cmap='viridis')
            axs[2*i+1].set_title(f"Quantized Output - Channel {channel_idx}")

        plt.show()
123
+
124
+
125
def adjust_audio_length(audio, desired_length, original_sample_rate, target_sample_rate):
    """
    Adjust the audio length to the desired length and resample to target sample rate.

    Parameters:
    - audio (np.array): The input audio signal
    - desired_length (int): The desired length of the output audio
    - original_sample_rate (int): The original sample rate of the audio
    - target_sample_rate (int): The target sample rate for the output audio

    Returns:
    - np.array: The adjusted and resampled audio (truncated or zero-padded
      to exactly `desired_length` samples)
    """
    # Resample only when the rates actually differ.
    if original_sample_rate != target_sample_rate:
        audio = librosa.core.resample(audio, orig_sr=original_sample_rate, target_sr=target_sample_rate)

    current_length = len(audio)
    if current_length > desired_length:
        # Too long: truncate.
        return audio[:desired_length]
    if current_length < desired_length:
        # Too short: zero-pad at the end.
        padded_audio = np.zeros(desired_length)
        padded_audio[:current_length] = audio
        return padded_audio
    return audio
151
+
152
+
153
def safe_int(s, default=0):
    """Convert `s` to int, returning `default` if the conversion fails.

    Fix: also catch TypeError — the original only caught ValueError, so
    e.g. ``safe_int(None)`` raised instead of falling back to `default`.
    """
    try:
        return int(s)
    except (ValueError, TypeError):
        return default
158
+
159
+
160
def pad_spectrogram(D):
    """Resize spectrogram to (512, 256) by dropping the DC row and zero-padding time. (deprecated )"""
    trimmed = D[1:, :]  # drop the first (DC) frequency row

    pad_width = 256 - trimmed.shape[1]
    return np.pad(trimmed, ((0, 0), (0, pad_width)), 'constant')
167
+
168
+
169
def pad_STFT(D, time_resolution=256):
    """Resize a spectral matrix: drop the DC row, then zero-pad the time axis
    up to `time_resolution` frames (no-op if already long enough, or if
    `time_resolution` is None)."""
    trimmed = D[1:, :]  # remove the first (DC) frequency row

    if time_resolution is None:
        return trimmed

    missing = time_resolution - trimmed.shape[1]
    if missing <= 0:
        return trimmed
    return np.pad(trimmed, ((0, 0), (0, missing)), 'constant')
182
+
183
+
184
def depad_STFT(D_padded):
    """Inverse function of 'pad_STFT': re-insert an all-zero DC row on top."""
    dc_row = np.zeros((1, D_padded.shape[1]))
    return np.concatenate([dc_row, D_padded], axis=0)
191
+
192
+
193
def nnData2Audio(spectrogram_batch, resolution=(512, 256), squared=False):
    """Transform batch of numpy spectrogram into signals and encodings.

    Each network output is mapped back to a spectrogram with
    `VAE_out_put_to_spc`, lifted into a (freq+1, time) magnitude STFT
    (extra bin left zero), and inverted with Griffin-Lim.

    Parameters:
    - spectrogram_batch: numpy array or torch tensor of network outputs.
    - resolution: (frequency_resolution, time_resolution) of each spectrogram.
    - squared: when True the values are treated as power and sqrt'ed to
      magnitudes before Griffin-Lim.

    Returns a list of reconstructed 1-D signals.
    """
    # Todo: remove resolution hard-coding
    frequency_resolution, time_resolution = resolution

    # Accept torch tensors transparently.
    if isinstance(spectrogram_batch, torch.Tensor):
        spectrogram_batch = spectrogram_batch.to("cpu").detach().numpy()

    origin_signals = []
    for spectrogram in spectrogram_batch:
        spc = VAE_out_put_to_spc(spectrogram)

        # get_audio: build the full-height magnitude STFT; row 0 stays zero.
        abs_spec = np.zeros((frequency_resolution+1, time_resolution))

        if squared:
            # Power values -> magnitudes.
            abs_spec[1:, :] = abs_spec[1:, :] + np.sqrt(np.reshape(spc, (frequency_resolution, time_resolution)))
        else:
            abs_spec[1:, :] = abs_spec[1:, :] + np.reshape(spc, (frequency_resolution, time_resolution))

        origin_signal = librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=1024)
        origin_signals.append(origin_signal)

    return origin_signals
217
+
218
+
219
def amp_to_audio(amp, n_iter=50):
    """The Griffin-Lim algorithm.

    Reconstruct a time-domain signal from a magnitude spectrogram `amp`;
    `n_iter` controls the number of phase-estimation iterations.
    """
    y_reconstructed = librosa.griffinlim(amp, n_iter=n_iter, hop_length=256, win_length=1024)
    return y_reconstructed
223
+
224
+
225
def rescale(amp, method="log1p"):
    """Rescale an amplitude array for network consumption.

    Supported methods: "log1p" (log(1+x)) and
    "NormalizedLogisticCompression" (x / (1+x)).
    """
    if method == "log1p":
        return np.log1p(amp)
    if method == "NormalizedLogisticCompression":
        return amp / (1.0 + amp)
    raise NotImplementedError()
233
+
234
+
235
def unrescale(scaled_amp, method="NormalizedLogisticCompression"):
    """Inverse function of 'rescale'.

    NOTE(review): the default method here ("NormalizedLogisticCompression")
    differs from rescale's default ("log1p") — callers relying on defaults
    should pass `method` explicitly.
    """
    if method == "log1p":
        return np.expm1(scaled_amp)
    if method == "NormalizedLogisticCompression":
        # Epsilon keeps the division finite as scaled_amp approaches 1.
        return scaled_amp / (1.0 - scaled_amp + 1e-10)
    raise NotImplementedError()
243
+
244
+
245
def create_key(attributes):
    """Create unique key for each multi-label.

    The key is "{instrument_source_str}_{instrument_family_str}_{qualities}"
    where `qualities` is the 0/1 flags concatenated as a string (and thus
    recoverable from the key's tail).
    """
    qualities_str = "".join(str(q) for q in attributes["qualities"])
    source = attributes["instrument_source_str"]
    family = attributes["instrument_family_str"]
    return f"{source}_{family}_{qualities_str}"
252
+
253
+
254
def merge_dictionaries(dicts):
    """Merge dictionaries, combining values of shared keys with `+=`."""
    merged_dict = {}
    for source in dicts:
        for key, value in source.items():
            if key not in merged_dict:
                merged_dict[key] = value
            else:
                merged_dict[key] += value
    return merged_dict
264
+
265
+
266
def adsr_envelope(signal, sample_rate, duration, attack_time, decay_time, sustain_level, release_time):
    """
    Apply an ADSR envelope to an audio signal.

    :param signal: The original audio signal (numpy array).
    :param sample_rate: The sample rate of the audio signal.
    :param duration: Total note duration (attack + decay + sustain) in seconds.
    :param attack_time: Attack time in seconds.
    :param decay_time: Decay time in seconds.
    :param sustain_level: Sustain level as a fraction of the peak (0 to 1).
    :param release_time: Release time in seconds (must be <= 1.0; the release
        segment is embedded in a fixed one-second zero-padded tail).
    :return: The audio signal with the ADSR envelope applied; its length is
        the envelope length (duration plus one second), so the input may be
        truncated or zero-extended to match.
    """
    # assert (duration_samples + int(1.0 * sample_rate)) <= len(signal), "(duration_samples + sample_rate) > len(signal)"
    assert release_time <= 1.0, "release_time > 1.0"

    # Sample counts for each ADSR phase.
    n_duration = int(duration * sample_rate)
    n_attack = int(attack_time * sample_rate)
    n_decay = int(decay_time * sample_rate)
    n_release = int(release_time * sample_rate)
    n_sustain = max(0, n_duration - n_attack - n_decay)

    # Piecewise-linear envelope segments.
    attack_env = np.linspace(0, 1, n_attack)
    decay_env = np.linspace(1, sustain_level, n_decay)
    sustain_env = np.full(n_sustain, sustain_level)
    release_env = np.linspace(sustain_level, 0, n_release)
    # The release lives at the start of a fixed one-second silent tail.
    release_tail = np.zeros(int(1.0 * sample_rate))
    release_tail[:len(release_env)] = release_env

    envelope = np.concatenate([attack_env, decay_env, sustain_env, release_tail])

    # Multiply signal and envelope, matching lengths to the envelope.
    if len(envelope) <= len(signal):
        return signal[:len(envelope)] * envelope
    padded_signal = np.zeros(len(envelope))
    padded_signal[:len(signal)] = signal
    return padded_signal * envelope
309
+
310
+
311
def rms_normalize(audio, target_rms=0.1):
    """Scale `audio` so its RMS equals `target_rms`."""
    current_rms = np.sqrt(np.mean(np.square(audio)))
    return audio * (target_rms / current_rms)
317
+
318
+
319
def encode_stft(D):
    """'STFT+' function that transform spectral matrix into spectral representation.

    Encodes a complex STFT as three real channels stacked on axis 0:
    [log1p(magnitude), cos(phase), sin(phase)].
    """
    magnitude = np.abs(D)
    phase = np.angle(D)

    log_magnitude = np.log1p(magnitude)
    return np.stack([log_magnitude, np.cos(phase), np.sin(phase)], axis=0)
331
+
332
+
333
def decode_stft(encoded_D):
    """'ISTFT+' function that reconstructs spectral matrix from spectral representation.

    Inverse of `encode_stft`: channel 0 is log1p(magnitude), channels 1/2
    are cos/sin of the phase.
    """
    log_magnitude, cos_phase, sin_phase = encoded_D[0, ...], encoded_D[1, ...], encoded_D[2, ...]

    magnitude = np.expm1(log_magnitude)
    phase = np.arctan2(sin_phase, cos_phase)

    return magnitude * (np.cos(phase) + 1j * np.sin(phase))
webUI/__pycache__/app.cpython-310.pyc ADDED
Binary file (10.4 kB). View file
 
webUI/deprecated/interpolationWithCondition.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+
5
+ from model.DiffSynthSampler import DiffSynthSampler
6
+ from tools import safe_int
7
+ from webUI.natural_language_guided_STFT.utils import encodeBatch2GradioOutput, latent_representation_to_Gradio_image
8
+
9
+
10
def get_interpolation_with_condition_module(gradioWebUI, interpolation_with_text_state):
    """Build the "InterpolationCond." Gradio tab.

    Samples a batch of sounds whose diffusion conditions linearly interpolate
    between the CLAP embeddings of two text prompts, so the batch forms a
    gradient from prompt 1 to prompt 2. Results are cached in
    `interpolation_with_text_state` so the index slider can page through the
    batch without re-sampling.
    """
    # Load configurations (all shared model objects come from the app-level wrapper).
    uNet = gradioWebUI.uNet
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    height, width, channels = int(freq_resolution/VAE_scale), int(time_resolution/VAE_scale), gradioWebUI.channels
    timesteps = gradioWebUI.timesteps
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

    def diffusion_random_sample(text2sound_prompts_1, text2sound_prompts_2, text2sound_negative_prompts, text2sound_batchsize,
                                text2sound_duration,
                                text2sound_guidance_scale, text2sound_sampler,
                                text2sound_sample_steps, text2sound_seed,
                                interpolation_with_text_dict):
        """Sample a batch along the prompt-1 -> prompt-2 embedding line and cache outputs."""
        text2sound_sample_steps = int(text2sound_sample_steps)
        text2sound_seed = safe_int(text2sound_seed, 12345678)
        # Todo: take care of text2sound_time_resolution/width
        # NOTE: this local `width` (duration-dependent latent width) shadows the
        # module-level default computed above.
        width = int(time_resolution*((text2sound_duration+1)/4) / VAE_scale)
        text2sound_batchsize = int(text2sound_batchsize)

        # CLAP text embeddings for the two interpolation endpoints.
        text2sound_embedding_1 = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts_1], padding=True, return_tensors="pt"))[0].to(device)
        text2sound_embedding_2 = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts_2], padding=True, return_tensors="pt"))[0].to(device)

        CFG = int(text2sound_guidance_scale)

        mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
        # Negative prompt supplies the unconditional branch for classifier-free guidance.
        unconditional_condition = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[0]
        mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device))

        # Respace the diffusion schedule to the requested number of steps.
        mySampler.respace(list(np.linspace(0, timesteps - 1, text2sound_sample_steps, dtype=np.int32)))

        # One condition per batch element: linear blend from embedding 1 to embedding 2.
        condition = torch.linspace(1, 0, steps=text2sound_batchsize).unsqueeze(1).to(device) * text2sound_embedding_1 + \
                    torch.linspace(0, 1, steps=text2sound_batchsize).unsqueeze(1).to(device) * text2sound_embedding_2

        # Todo: move this code
        # Shared seeded noise so only the condition varies across the batch.
        torch.manual_seed(text2sound_seed)
        initial_noise = torch.randn(text2sound_batchsize, channels, height, width).to(device)

        latent_representations, initial_noise = \
            mySampler.sample(model=uNet, shape=(text2sound_batchsize, channels, height, width), seed=text2sound_seed,
                             return_tensor=True, condition=condition, sampler=text2sound_sampler, initial_noise=initial_noise)

        # Keep only the final denoising step.
        latent_representations = latent_representations[-1]

        interpolation_with_text_dict["latent_representations"] = latent_representations

        latent_representation_gradio_images = []
        quantized_latent_representation_gradio_images = []
        new_sound_spectrogram_gradio_images = []
        new_sound_rec_signals_gradio = []

        # Quantize the latents and decode them to spectrogram images + audio.
        quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations)
        # Todo: remove hard-coding
        flipped_log_spectrums, rec_signals = encodeBatch2GradioOutput(VAE_decoder, quantized_latent_representations,
                                                                     resolution=(512, width * VAE_scale), centralized=False,
                                                                     squared=squared)

        for i in range(text2sound_batchsize):
            latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i]))
            quantized_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(quantized_latent_representations[i]))
            new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i])
            new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i]))

        def concatenate_arrays(arrays_list):
            # Stitch all spectrograms side by side (along the time axis image-wise).
            return np.concatenate(arrays_list, axis=1)

        concatenated_spectrogram_gradio_image = concatenate_arrays(new_sound_spectrogram_gradio_images)

        # Cache everything so the slider callback can page without resampling.
        interpolation_with_text_dict["latent_representation_gradio_images"] = latent_representation_gradio_images
        interpolation_with_text_dict["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images
        interpolation_with_text_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
        interpolation_with_text_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio

        # Show sample 0 and reveal the index slider sized to the batch.
        return {text2sound_latent_representation_image: interpolation_with_text_dict["latent_representation_gradio_images"][0],
                text2sound_quantized_latent_representation_image:
                    interpolation_with_text_dict["quantized_latent_representation_gradio_images"][0],
                text2sound_sampled_concatenated_spectrogram_image: concatenated_spectrogram_gradio_image,
                text2sound_sampled_spectrogram_image: interpolation_with_text_dict["new_sound_spectrogram_gradio_images"][0],
                text2sound_sampled_audio: interpolation_with_text_dict["new_sound_rec_signals_gradio"][0],
                text2sound_seed_textbox: text2sound_seed,
                interpolation_with_text_state: interpolation_with_text_dict,
                text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1,
                                                          visible=True,
                                                          label="Sample index.",
                                                          info="Swipe to view other samples")}

    def show_random_sample(sample_index, text2sound_dict):
        """Slider callback: display the cached sample at `sample_index`."""
        sample_index = int(sample_index)
        return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][
            sample_index],
                text2sound_quantized_latent_representation_image:
                    text2sound_dict["quantized_latent_representation_gradio_images"][sample_index],
                text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][sample_index],
                text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]}

    # ----- UI layout -----
    with gr.Tab("InterpolationCond."):
        gr.Markdown("Use interpolation to generate a gradient sound sequence.")
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):
                text2sound_prompts_1_textbox = gr.Textbox(label="Positive prompt 1", lines=2, value="organ")
                text2sound_prompts_2_textbox = gr.Textbox(label="Positive prompt 2", lines=2, value="string")
                text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")

            with gr.Column(scale=1):
                text2sound_sampling_button = gr.Button(variant="primary",
                                                       value="Generate a batch of samples and show "
                                                             "the first one",
                                                       scale=1)
                # Hidden until a batch exists; re-ranged by diffusion_random_sample.
                text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
                                                           label="Sample index",
                                                           info="Swipe to view other samples")
        with gr.Row(variant="panel"):
            with gr.Column(scale=1, variant="panel"):
                # Shared sampling controls provided by the app-level wrapper.
                text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
                text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
                text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider(cpu_batchsize=3)
                text2sound_duration_slider = gradioWebUI.get_duration_slider()
                text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
                text2sound_seed_textbox = gradioWebUI.get_seed_textbox()

            with gr.Column(scale=1):
                with gr.Row(variant="panel"):
                    text2sound_sampled_concatenated_spectrogram_image = gr.Image(label="Interpolations", type="numpy",
                                                                                 height=420, scale=8)
                    text2sound_sampled_spectrogram_image = gr.Image(label="Selected spectrogram", type="numpy",
                                                                    height=420, scale=1)
                text2sound_sampled_audio = gr.Audio(type="numpy", label="Play")

        with gr.Row(variant="panel"):
            text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy",
                                                              height=200, width=100)
            text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
                                                                        type="numpy", height=200, width=100)

    # ----- Event wiring -----
    text2sound_sampling_button.click(diffusion_random_sample,
                                     inputs=[text2sound_prompts_1_textbox,
                                             text2sound_prompts_2_textbox,
                                             text2sound_negative_prompts_textbox,
                                             text2sound_batchsize_slider,
                                             text2sound_duration_slider,
                                             text2sound_guidance_scale_slider, text2sound_sampler_radio,
                                             text2sound_sample_steps_slider,
                                             text2sound_seed_textbox,
                                             interpolation_with_text_state],
                                     outputs=[text2sound_latent_representation_image,
                                              text2sound_quantized_latent_representation_image,
                                              text2sound_sampled_concatenated_spectrogram_image,
                                              text2sound_sampled_spectrogram_image,
                                              text2sound_sampled_audio,
                                              text2sound_seed_textbox,
                                              interpolation_with_text_state,
                                              text2sound_sample_index_slider])
    text2sound_sample_index_slider.change(show_random_sample,
                                          inputs=[text2sound_sample_index_slider, interpolation_with_text_state],
                                          outputs=[text2sound_latent_representation_image,
                                                   text2sound_quantized_latent_representation_image,
                                                   text2sound_sampled_spectrogram_image,
                                                   text2sound_sampled_audio])
webUI/deprecated/interpolationWithXT.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+
5
+ from model.DiffSynthSampler import DiffSynthSampler
6
+ from tools import safe_int
7
+ from webUI.natural_language_guided.utils import encodeBatch2GradioOutput, latent_representation_to_Gradio_image
8
+
9
+
10
def get_interpolation_with_xT_module(gradioWebUI, interpolation_with_text_state):
    """Build the "InterpolationXT" Gradio tab.

    Samples a batch of sounds whose initial diffusion noises (x_T) are
    interpolated, producing a gradual transition between sounds, and wires
    the widgets to the sampling/browsing callbacks.

    Args:
        gradioWebUI: container object exposing the shared models
            (uNet, VAE, CLAP, ...) and widget-factory helpers.
        interpolation_with_text_state: gr.State dict shared across callbacks.
    """
    # Shared models / configuration pulled once from the web-UI container.
    uNet = gradioWebUI.uNet
    freq_resolution = gradioWebUI.freq_resolution
    time_resolution = gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    height = int(freq_resolution / VAE_scale)
    width = int(time_resolution / VAE_scale)
    channels = gradioWebUI.channels
    timesteps = gradioWebUI.timesteps
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

    def diffusion_random_sample(text2sound_prompts, text2sound_negative_prompts, text2sound_batchsize,
                                text2sound_duration,
                                text2sound_noise_variance, text2sound_guidance_scale, text2sound_sampler,
                                text2sound_sample_steps, text2sound_seed,
                                interpolation_with_text_dict):
        """Sample one batch of interpolated sounds and populate the result widgets."""
        num_steps = int(text2sound_sample_steps)
        seed = safe_int(text2sound_seed, 12345678)
        # Todo: take care of text2sound_time_resolution/width
        # Latent width scales with the requested duration (seconds).
        gen_width = int(time_resolution * ((text2sound_duration + 1) / 4) / VAE_scale)
        batchsize = int(text2sound_batchsize)

        # CLAP text embedding used as the diffusion condition.
        prompt_embedding = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(device)

        guidance = int(text2sound_guidance_scale)

        diff_sampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
        negative_embedding = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[0]
        diff_sampler.activate_classifier_free_guidance(guidance, negative_embedding.to(device))
        diff_sampler.respace(list(np.linspace(0, timesteps - 1, num_steps, dtype=np.int32)))

        condition = prompt_embedding.repeat(batchsize, 1)
        latent_representations, initial_noise = diff_sampler.interpolate(
            model=uNet, shape=(batchsize, channels, height, gen_width),
            seed=seed,
            variance=text2sound_noise_variance,
            return_tensor=True, condition=condition, sampler=text2sound_sampler)

        # Keep only the fully denoised (final) step.
        latent_representations = latent_representations[-1]
        interpolation_with_text_dict["latent_representations"] = latent_representations

        quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations)
        # Todo: remove hard-coding
        flipped_log_spectrums, rec_signals = encodeBatch2GradioOutput(
            VAE_decoder, quantized_latent_representations,
            resolution=(512, gen_width * VAE_scale), centralized=False,
            squared=squared)

        latent_images = [latent_representation_to_Gradio_image(latent_representations[i])
                         for i in range(batchsize)]
        quantized_images = [latent_representation_to_Gradio_image(quantized_latent_representations[i])
                            for i in range(batchsize)]
        spectrogram_images = [flipped_log_spectrums[i] for i in range(batchsize)]
        audio_outputs = [(sample_rate, rec_signals[i]) for i in range(batchsize)]

        # Side-by-side strip of all spectrograms in the batch.
        concatenated_spectrogram_image = np.concatenate(spectrogram_images, axis=1)

        interpolation_with_text_dict["latent_representation_gradio_images"] = latent_images
        interpolation_with_text_dict["quantized_latent_representation_gradio_images"] = quantized_images
        interpolation_with_text_dict["new_sound_spectrogram_gradio_images"] = spectrogram_images
        interpolation_with_text_dict["new_sound_rec_signals_gradio"] = audio_outputs

        return {text2sound_latent_representation_image: latent_images[0],
                text2sound_quantized_latent_representation_image: quantized_images[0],
                text2sound_sampled_concatenated_spectrogram_image: concatenated_spectrogram_image,
                text2sound_sampled_spectrogram_image: spectrogram_images[0],
                text2sound_sampled_audio: audio_outputs[0],
                text2sound_seed_textbox: seed,
                interpolation_with_text_state: interpolation_with_text_dict,
                text2sound_sample_index_slider: gr.update(minimum=0, maximum=batchsize - 1, value=0, step=1,
                                                          visible=True,
                                                          label="Sample index.",
                                                          info="Swipe to view other samples")}

    def show_random_sample(sample_index, text2sound_dict):
        """Display the batch entry selected by the index slider."""
        idx = int(sample_index)
        return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][idx],
                text2sound_quantized_latent_representation_image:
                    text2sound_dict["quantized_latent_representation_gradio_images"][idx],
                text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][idx],
                text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][idx]}

    # ---------------------------------------------------------------- UI ---
    with gr.Tab("InterpolationXT"):
        gr.Markdown("Use interpolation to generate a gradient sound sequence.")
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):
                text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
                text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")

            with gr.Column(scale=1):
                text2sound_sampling_button = gr.Button(
                    variant="primary",
                    value="Generate a batch of samples and show the first one",
                    scale=1)
                text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
                                                           label="Sample index",
                                                           info="Swipe to view other samples")
        with gr.Row(variant="panel"):
            with gr.Column(scale=1, variant="panel"):
                text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
                text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
                text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider(cpu_batchsize=3)
                text2sound_duration_slider = gradioWebUI.get_duration_slider()
                text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
                text2sound_seed_textbox = gradioWebUI.get_seed_textbox()
                text2sound_noise_variance_slider = gr.Slider(
                    minimum=0., maximum=5., value=1., step=0.01,
                    label="Noise variance",
                    info="The larger this value, the more diversity the interpolation has.")

            with gr.Column(scale=1):
                with gr.Row(variant="panel"):
                    text2sound_sampled_concatenated_spectrogram_image = gr.Image(label="Interpolations", type="numpy",
                                                                                 height=420, scale=8)
                    text2sound_sampled_spectrogram_image = gr.Image(label="Selected spectrogram", type="numpy",
                                                                    height=420, scale=1)
                    text2sound_sampled_audio = gr.Audio(type="numpy", label="Play")

                with gr.Row(variant="panel"):
                    text2sound_latent_representation_image = gr.Image(label="Sampled latent representation",
                                                                      type="numpy", height=200, width=100)
                    text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
                                                                                type="numpy", height=200, width=100)

    # ------------------------------------------------------------ wiring ---
    text2sound_sampling_button.click(diffusion_random_sample,
                                     inputs=[text2sound_prompts_textbox, text2sound_negative_prompts_textbox,
                                             text2sound_batchsize_slider,
                                             text2sound_duration_slider,
                                             text2sound_noise_variance_slider,
                                             text2sound_guidance_scale_slider, text2sound_sampler_radio,
                                             text2sound_sample_steps_slider,
                                             text2sound_seed_textbox,
                                             interpolation_with_text_state],
                                     outputs=[text2sound_latent_representation_image,
                                              text2sound_quantized_latent_representation_image,
                                              text2sound_sampled_concatenated_spectrogram_image,
                                              text2sound_sampled_spectrogram_image,
                                              text2sound_sampled_audio,
                                              text2sound_seed_textbox,
                                              interpolation_with_text_state,
                                              text2sound_sample_index_slider])
    text2sound_sample_index_slider.change(show_random_sample,
                                          inputs=[text2sound_sample_index_slider, interpolation_with_text_state],
                                          outputs=[text2sound_latent_representation_image,
                                                   text2sound_quantized_latent_representation_image,
                                                   text2sound_sampled_spectrogram_image,
                                                   text2sound_sampled_audio])
webUI/natural_language_guided/GAN.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+
5
+ from tools import safe_int
6
+ from webUI.natural_language_guided_STFT.utils import encodeBatch2GradioOutput, latent_representation_to_Gradio_image, \
7
+ add_instrument
8
+
9
+
10
def get_testGAN(gradioWebUI, text2sound_state, virtual_instruments_state):
    """Build the "Text2sound_GAN" Gradio tab.

    Single-step GAN generation of latent spectrogram representations
    conditioned on a CLAP text embedding, then VQ-quantized and decoded to
    audio for preview.

    Fixes vs. the previous revision:
      * the parsed seed was never applied, so the seed textbox had no effect
        on generation — the RNG is now seeded with it;
      * generation runs under ``torch.no_grad()`` (inference only);
      * a stray debug ``print`` of the latent tensor was removed.

    Args:
        gradioWebUI: container exposing shared models and widget factories.
        text2sound_state: gr.State dict holding the last generated batch.
        virtual_instruments_state: gr.State dict of saved instruments
            (unused here, kept for interface parity with sibling modules).
    """
    # Shared models / configuration.
    gan_generator = gradioWebUI.GAN_generator
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate

    def gan_random_sample(text2sound_prompts, text2sound_negative_prompts, text2sound_batchsize,
                          text2sound_duration,
                          text2sound_guidance_scale, text2sound_sampler,
                          text2sound_sample_steps, text2sound_seed,
                          text2sound_dict):
        """Generate a batch with the GAN and populate the result widgets.

        Note: negative prompts, sampler choice and sample steps are accepted
        for interface parity with the diffusion tabs but are not used by the
        one-shot GAN generator.
        """
        text2sound_seed = safe_int(text2sound_seed, 12345678)
        # Bug fix: seed all torch RNGs so the seed textbox actually controls
        # the generated noise and results are reproducible.
        torch.manual_seed(text2sound_seed)

        # Latent width scales with the requested duration (seconds).
        width = int(time_resolution * ((text2sound_duration + 1) / 4) / VAE_scale)
        text2sound_batchsize = int(text2sound_batchsize)

        text2sound_embedding = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(
                device)

        CFG = int(text2sound_guidance_scale)

        condition = text2sound_embedding.repeat(text2sound_batchsize, 1)

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            noise = torch.randn(text2sound_batchsize, channels, height, width).to(device)
            latent_representations = gan_generator(noise, condition)
            quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations)
            # Todo: remove hard-coding
            flipped_log_spectrums, rec_signals = encodeBatch2GradioOutput(VAE_decoder,
                                                                          quantized_latent_representations,
                                                                          resolution=(512, width * VAE_scale),
                                                                          centralized=False,
                                                                          squared=squared)

        latent_representation_gradio_images = []
        quantized_latent_representation_gradio_images = []
        new_sound_spectrogram_gradio_images = []
        new_sound_rec_signals_gradio = []
        for i in range(text2sound_batchsize):
            latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i]))
            quantized_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(quantized_latent_representations[i]))
            new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i])
            new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i]))

        # Cache everything (as numpy) for the index-slider browser and
        # downstream tabs.
        text2sound_dict["latent_representations"] = latent_representations.to("cpu").detach().numpy()
        text2sound_dict["quantized_latent_representations"] = quantized_latent_representations.to("cpu").detach().numpy()
        text2sound_dict["latent_representation_gradio_images"] = latent_representation_gradio_images
        text2sound_dict["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images
        text2sound_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
        text2sound_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio

        text2sound_dict["condition"] = condition.to("cpu").detach().numpy()
        text2sound_dict["guidance_scale"] = CFG
        text2sound_dict["sampler"] = text2sound_sampler

        return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][0],
                text2sound_quantized_latent_representation_image:
                    text2sound_dict["quantized_latent_representation_gradio_images"][0],
                text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][0],
                text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][0],
                text2sound_seed_textbox: text2sound_seed,
                text2sound_state: text2sound_dict,
                text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1,
                                                          visible=True,
                                                          label="Sample index.",
                                                          info="Swipe to view other samples")}

    def show_random_sample(sample_index, text2sound_dict):
        """Display the batch entry selected by the index slider."""
        sample_index = int(sample_index)
        text2sound_dict["sample_index"] = sample_index
        return {text2sound_latent_representation_image:
                    text2sound_dict["latent_representation_gradio_images"][sample_index],
                text2sound_quantized_latent_representation_image:
                    text2sound_dict["quantized_latent_representation_gradio_images"][sample_index],
                text2sound_sampled_spectrogram_image:
                    text2sound_dict["new_sound_spectrogram_gradio_images"][sample_index],
                text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]}

    # ---------------------------------------------------------------- UI ---
    with gr.Tab("Text2sound_GAN"):
        gr.Markdown("Use neural networks to select random sounds using your favorite instrument!")
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):
                text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
                text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")

            with gr.Column(scale=1):
                text2sound_sampling_button = gr.Button(variant="primary",
                                                       value="Generate a batch of samples and show "
                                                             "the first one",
                                                       scale=1)
                text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
                                                           label="Sample index",
                                                           info="Swipe to view other samples")
        with gr.Row(variant="panel"):
            with gr.Column(scale=1, variant="panel"):
                text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
                text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
                text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
                text2sound_duration_slider = gradioWebUI.get_duration_slider()
                text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
                text2sound_seed_textbox = gradioWebUI.get_seed_textbox()

            with gr.Column(scale=1):
                text2sound_sampled_spectrogram_image = gr.Image(label="Sampled spectrogram", type="numpy", height=420)
                text2sound_sampled_audio = gr.Audio(type="numpy", label="Play")

                with gr.Row(variant="panel"):
                    text2sound_latent_representation_image = gr.Image(label="Sampled latent representation",
                                                                      type="numpy", height=200, width=100)
                    text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
                                                                                type="numpy", height=200, width=100)

    # ------------------------------------------------------------ wiring ---
    text2sound_sampling_button.click(gan_random_sample,
                                     inputs=[text2sound_prompts_textbox,
                                             text2sound_negative_prompts_textbox,
                                             text2sound_batchsize_slider,
                                             text2sound_duration_slider,
                                             text2sound_guidance_scale_slider, text2sound_sampler_radio,
                                             text2sound_sample_steps_slider,
                                             text2sound_seed_textbox,
                                             text2sound_state],
                                     outputs=[text2sound_latent_representation_image,
                                              text2sound_quantized_latent_representation_image,
                                              text2sound_sampled_spectrogram_image,
                                              text2sound_sampled_audio,
                                              text2sound_seed_textbox,
                                              text2sound_state,
                                              text2sound_sample_index_slider])

    text2sound_sample_index_slider.change(show_random_sample,
                                          inputs=[text2sound_sample_index_slider, text2sound_state],
                                          outputs=[text2sound_latent_representation_image,
                                                   text2sound_quantized_latent_representation_image,
                                                   text2sound_sampled_spectrogram_image,
                                                   text2sound_sampled_audio])
webUI/natural_language_guided/README.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ readme_content = """## Stable Diffusion for Sound Generation
4
+
5
+ This project applies stable diffusion[1] to sound generation. Inspired by the work of AUTOMATIC1111, 2022[2], we have implemented a preliminary version of text2sound, sound2sound, inpaint, as well as an additional interpolation feature, all accessible through a web UI.
6
+
7
+ ### Neural Network Training Data:
8
+ The neural network is trained using the filtered NSynth dataset[3], which is a large-scale and high-quality collection of annotated musical notes, comprising 305,979 musical notes. However, for this project, only samples with a pitch set to E3 were used, resulting in an actual training sample size of 4,096, making it a low-resource project.
9
+
10
+ The training took place on an NVIDIA Tesla T4 GPU and spanned approximately 10 hours.
11
+
12
+ ### Natural Language Guidance:
13
+ Natural language guidance is derived from the multi-label annotations of the NSynth dataset. The labels included in the training are:
14
+
15
+ - **Instrument Families**: bass, brass, flute, guitar, keyboard, mallet, organ, reed, string, synth lead, vocal.
16
+
17
+ - **Instrument Sources**: acoustic, electronic, synthetic.
18
+
19
+ - **Note Qualities**: bright, dark, distortion, fast decay, long release, multiphonic, nonlinear env, percussive, reverb, tempo-synced.
20
+
21
+ ### Usage Hints:
22
+
23
+ 1. **Prompt Format**: It's recommended to use the format “label1, label2, label3“, e.g., ”organ, dark, long release“.
24
+
25
+ 2. **Unique Sounds**: If you keep generating the same sound, try setting a different seed!
26
+
27
+ 3. **Sample Indexing**: Drag the "Sample index slider" to view other samples within the generated batch.
28
+
29
+ 4. **Running on CPU**: Be cautious with the settings for 'batchsize' and 'sample_steps' when running on CPU to avoid timeouts. Recommended settings are batchsize ≤ 4 and sample_steps = 15.
30
+
31
+ 5. **Editing Sounds**: Generated audio can be downloaded and then re-uploaded for further editing at the sound2sound/inpaint sections.
32
+
33
+ 6. **Guidance Scale**: A higher 'guidance_scale' intensifies the influence of natural language conditioning on the generation[4]. It's recommended to set it between 3 and 10.
34
+
35
+ 7. **Noising Strength**: A smaller 'noising_strength' value makes the generated sound closer to the input sound.
36
+
37
+ References:
38
+
39
+ [1] Rombach, R., Blattmann, A., Lorenz, D., Esser, P., & Ommer, B. (2022). High-resolution image synthesis with latent diffusion models. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (pp. 10684-10695).
40
+
41
+ [2] AUTOMATIC1111. (2022). Stable Diffusion Web UI [Computer software]. Retrieved from https://github.com/AUTOMATIC1111/stable-diffusion-webui
42
+
43
+ [3] Engel, J., Resnick, C., Roberts, A., Dieleman, S., Eck, D., Simonyan, K., & Norouzi, M. (2017). Neural Audio Synthesis of Musical Notes with WaveNet Autoencoders.
44
+
45
+ [4] Ho, J., & Salimans, T. (2022). Classifier-free diffusion guidance. arXiv preprint arXiv:2207.12598.
46
+ """
47
+
48
def get_readme_module():
    """Build the read-only "README" tab that displays the project notes."""
    with gr.Tab("README"):
        with gr.Column(scale=3):
            gr.Textbox(label="readme", lines=40, value=readme_content, interactive=False)
webUI/natural_language_guided/__pycache__/README.cpython-310.pyc ADDED
Binary file (3.51 kB). View file
 
webUI/natural_language_guided/__pycache__/README_STFT.cpython-310.pyc ADDED
Binary file (3.51 kB). View file
 
webUI/natural_language_guided/__pycache__/buildInstrument_STFT.cpython-310.pyc ADDED
Binary file (8.26 kB). View file
 
webUI/natural_language_guided/__pycache__/build_instrument.cpython-310.pyc ADDED
Binary file (8.08 kB). View file
 
webUI/natural_language_guided/__pycache__/gradioWebUI.cpython-310.pyc ADDED
Binary file (3.61 kB). View file
 
webUI/natural_language_guided/__pycache__/gradioWebUI_STFT.cpython-310.pyc ADDED
Binary file (3.62 kB). View file
 
webUI/natural_language_guided/__pycache__/gradio_webUI.cpython-310.pyc ADDED
Binary file (3.61 kB). View file
 
webUI/natural_language_guided/__pycache__/inpaintWithText.cpython-310.pyc ADDED
Binary file (12.5 kB). View file
 
webUI/natural_language_guided/__pycache__/inpaintWithText_STFT.cpython-310.pyc ADDED
Binary file (11.6 kB). View file
 
webUI/natural_language_guided/__pycache__/inpaint_with_text.cpython-310.pyc ADDED
Binary file (12.5 kB). View file
 
webUI/natural_language_guided/__pycache__/rec.cpython-310.pyc ADDED
Binary file (6.11 kB). View file
 
webUI/natural_language_guided/__pycache__/recSTFT.cpython-310.pyc ADDED
Binary file (6.11 kB). View file
 
webUI/natural_language_guided/__pycache__/sound2soundWithText.cpython-310.pyc ADDED
Binary file (10.7 kB). View file
 
webUI/natural_language_guided/__pycache__/sound2soundWithText_STFT.cpython-310.pyc ADDED
Binary file (10.7 kB). View file
 
webUI/natural_language_guided/__pycache__/sound2sound_with_text.cpython-310.pyc ADDED
Binary file (10.7 kB). View file
 
webUI/natural_language_guided/__pycache__/text2sound.cpython-310.pyc ADDED
Binary file (6.45 kB). View file
 
webUI/natural_language_guided/__pycache__/text2sound_STFT.cpython-310.pyc ADDED
Binary file (6.43 kB). View file
 
webUI/natural_language_guided/__pycache__/track_maker.cpython-310.pyc ADDED
Binary file (7 kB). View file
 
webUI/natural_language_guided/__pycache__/utils.cpython-310.pyc ADDED
Binary file (4.78 kB). View file
 
webUI/natural_language_guided/build_instrument.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import numpy as np
3
+ import torch
4
+ import gradio as gr
5
+ import mido
6
+ from io import BytesIO
7
+ import pyrubberband as pyrb
8
+
9
+ from model.DiffSynthSampler import DiffSynthSampler
10
+ from tools import adsr_envelope, adjust_audio_length
11
+ from webUI.natural_language_guided.track_maker import DiffSynth
12
+ from webUI.natural_language_guided.utils import encodeBatch2GradioOutput_STFT, phase_to_Gradio_image, \
13
+ spectrogram_to_Gradio_image
14
+
15
+
16
+ def get_build_instrument_module(gradioWebUI, virtual_instruments_state):
17
+ # Load configurations
18
+ uNet = gradioWebUI.uNet
19
+ freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
20
+ VAE_scale = gradioWebUI.VAE_scale
21
+ height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels
22
+
23
+ timesteps = gradioWebUI.timesteps
24
+ VAE_quantizer = gradioWebUI.VAE_quantizer
25
+ VAE_decoder = gradioWebUI.VAE_decoder
26
+ CLAP = gradioWebUI.CLAP
27
+ CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
28
+ device = gradioWebUI.device
29
+ squared = gradioWebUI.squared
30
+ sample_rate = gradioWebUI.sample_rate
31
+ noise_strategy = gradioWebUI.noise_strategy
32
+
33
def select_sound(virtual_instrument_name, virtual_instruments_dict):
    """Show the cached spectrogram, phase image and audio of a saved instrument."""
    instrument = virtual_instruments_dict["virtual_instruments"][virtual_instrument_name]
    return {
        source_sound_spectrogram_image: instrument["spectrogram_gradio_image"],
        source_sound_phase_image: instrument["phase_gradio_image"],
        source_sound_audio: instrument["signal"],
    }
40
+
41
def make_track(inpaint_steps, midi, noising_strength, attack, before_release, instrument_names, virtual_instruments_dict):
    """Render a full track from a MIDI file using the selected instruments.

    `instrument_names` is an "@"-separated list of saved instrument names;
    each one contributes a per-instrument sampler configuration for DiffSynth.
    """
    if noising_strength < 1:
        print(f"Warning: making track with noising_strength = {noising_strength} < 1")

    virtual_instruments = virtual_instruments_dict["virtual_instruments"]
    sample_steps = int(inpaint_steps)
    names = instrument_names.split("@")

    instruments_configs = {}
    for name in names:
        instrument = virtual_instruments[name]
        latent = torch.tensor(instrument["latent_representation"], dtype=torch.float32).to(device)
        # Single-voice rendering: batch dimension of 1.
        latent = latent.repeat(1, 1, 1, 1)
        # NOTE(review): the MIDI bytes are re-parsed on every loop iteration;
        # the last parse is the one passed to get_music below.
        mid = mido.MidiFile(file=BytesIO(midi))
        instruments_configs[name] = {
            'sample_steps': sample_steps,
            'sampler': instrument["sampler"],
            'noising_strength': noising_strength,
            'latent_representation': latent,
            'attack': attack,
            'before_release': before_release}

    diffSynth = DiffSynth(instruments_configs, uNet, VAE_quantizer, VAE_decoder, CLAP, CLAP_tokenizer, device)
    full_audio = diffSynth.get_music(mid, names)

    return {track_audio: (sample_rate, full_audio)}
74
+
75
def test_duration_inpaint(virtual_instrument_name, inpaint_steps, duration, noising_strength, end_noise_level_ratio, attack, before_release, mask_flexivity, virtual_instruments_dict, use_dynamic_mask):
    """Re-render a saved instrument at a new duration via masked inpainting.

    The attack and the tail (before_release) of the latent are frozen
    (mask = 1) and only the middle is regenerated, so the timbre's onset and
    release are preserved while the sustain is stretched/shrunk.
    """
    # Latent width for the requested duration (seconds).
    target_width = int(time_resolution * ((duration + 1) / 4) / VAE_scale)

    instrument = virtual_instruments_dict["virtual_instruments"][virtual_instrument_name]
    guide_latent = torch.tensor(instrument["latent_representation"], dtype=torch.float32).to(device)
    sample_steps = int(inpaint_steps)
    sampler = instrument["sampler"]
    batchsize = 1

    inpaint_sampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
    inpaint_sampler.respace(list(np.linspace(0, timesteps - 1, sample_steps, dtype=np.int32)))

    guide_latent = guide_latent.repeat(batchsize, 1, 1, 1)

    # mask = 1, freeze: protect the attack and the pre-release tail.
    freeze_mask = torch.zeros((batchsize, 1, height, target_width), dtype=torch.float32).to(device)
    freeze_mask[:, :, :, :int(time_resolution * (attack / 4) / VAE_scale)] = 1.0
    freeze_mask[:, :, :, -int(time_resolution * ((before_release + 1) / 4) / VAE_scale):] = 1.0

    # Unconditional (empty-prompt) CLAP embedding as the condition.
    empty_embedding = \
        CLAP.get_text_features(**CLAP_tokenizer([""], padding=True, return_tensors="pt"))[0].to(device)
    condition = empty_embedding.repeat(1, 1)

    latent_representations, initial_noise = inpaint_sampler.inpaint_sample(
        model=uNet, shape=(batchsize, channels, height, target_width),
        noising_strength=noising_strength,
        guide_img=guide_latent, mask=freeze_mask, return_tensor=True,
        condition=condition, sampler=sampler,
        use_dynamic_mask=use_dynamic_mask,
        end_noise_level_ratio=end_noise_level_ratio,
        mask_flexivity=mask_flexivity)

    # Keep only the final denoising step.
    latent_representations = latent_representations[-1]

    quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations)
    # Todo: remove hard-coding
    flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(
        VAE_decoder, quantized_latent_representations,
        resolution=(512, target_width * VAE_scale),
        original_STFT_batch=None)

    return {test_duration_spectrogram_image: flipped_log_spectrums[0],
            test_duration_phase_image: flipped_phases[0],
            test_duration_audio: (sample_rate, rec_signals[0])}
129
+
130
def test_duration_envelope(virtual_instrument_name, duration, noising_strength, attack, before_release, release, virtual_instruments_dict):
    """Change a saved instrument's duration by applying an ADSR-style envelope.

    Only the release segment of the envelope is shaped (attack/decay are zero,
    sustain is full level); the result is previewed as STFT magnitude/phase
    images plus audio. `noising_strength`, `attack` and `before_release` are
    accepted for interface parity with the inpainting variant but unused here.
    """
    instrument = virtual_instruments_dict["virtual_instruments"][virtual_instrument_name]
    sample_rate, signal = instrument["signal"]

    shaped = adsr_envelope(signal=signal, sample_rate=sample_rate, duration=duration,
                           attack_time=0.0, decay_time=0.0, sustain_level=1.0, release_time=release)

    # STFT preview; drop the DC bin to match the model's spectrogram layout.
    stft = librosa.stft(shaped, n_fft=1024, hop_length=256, win_length=1024)[1:, :]
    magnitude = np.abs(stft)
    phase = np.angle(stft)

    return {test_duration_spectrogram_image: spectrogram_to_Gradio_image(magnitude),
            test_duration_phase_image: phase_to_Gradio_image(phase),
            test_duration_audio: (sample_rate, shaped)}
149
+
150
def test_duration_stretch(virtual_instrument_name, duration, noising_strength, attack, before_release, release, virtual_instruments_dict, source_duration=3.0):
    """Change a saved instrument's duration by rubberband time-stretching.

    The saved signal is stretched by ``source_duration / duration`` and then
    padded/trimmed to exactly ``duration + 1`` seconds, and previewed as STFT
    magnitude/phase images plus audio.

    Improvement: the source length (previously a hard-coded ``3``) is now the
    ``source_duration`` keyword (default 3.0 keeps the old behavior).
    ``noising_strength``, ``attack``, ``before_release`` and ``release`` are
    accepted for interface parity with the sibling callbacks but unused here.
    """
    instrument = virtual_instruments_dict["virtual_instruments"][virtual_instrument_name]
    sample_rate, signal = instrument["signal"]

    # pyrubberband rate > 1 shortens, < 1 lengthens the signal.
    rate = source_duration / duration
    stretched = pyrb.time_stretch(signal, sample_rate, rate)
    stretched = adjust_audio_length(stretched, int((duration + 1) * sample_rate), sample_rate, sample_rate)

    # STFT preview; drop the DC bin to match the model's spectrogram layout.
    stft = librosa.stft(stretched, n_fft=1024, hop_length=256, win_length=1024)[1:, :]
    magnitude = np.abs(stft)
    phase = np.angle(stft)

    return {test_duration_spectrogram_image: spectrogram_to_Gradio_image(magnitude),
            test_duration_phase_image: phase_to_Gradio_image(phase),
            test_duration_audio: (sample_rate, stretched)}
170
+
171
+
172
    # TestInTrack tab: preview duration changes (envelope / stretch / inpaint)
    # on a saved instrument and render an uploaded MIDI file into a track.
    with gr.Tab("TestInTrack"):
        gr.Markdown("Make music with generated sounds!")
        # Control row: instrument selection plus envelope/inpaint parameters.
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):
                instrument_name_textbox = gr.Textbox(label="Instrument name", lines=1,
                                                     placeholder="Name of your instrument", scale=1)
                select_instrument_button = gr.Button(variant="primary", value="Select", scale=1)
            with gr.Column(scale=3):
                inpaint_steps_slider = gr.Slider(minimum=5.0, maximum=999.0, value=20.0, step=1.0, label="inpaint_steps")
                noising_strength_slider = gradioWebUI.get_noising_strength_slider(default_noising_strength=1.)
                end_noise_level_ratio_slider = gr.Slider(minimum=0.0, maximum=1., value=0.0, step=0.01, label="end_noise_level_ratio")
                attack_slider = gr.Slider(minimum=0.0, maximum=1.5, value=0.5, step=0.01, label="attack in sec")
                before_release_slider = gr.Slider(minimum=0.0, maximum=1.5, value=0.5, step=0.01, label="before_release in sec")
                release_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="release in sec")
                mask_flexivity_slider = gr.Slider(minimum=0.01, maximum=1.00, value=1., step=0.01, label="mask_flexivity")
            with gr.Column(scale=3):
                use_dynamic_mask_checkbox = gr.Checkbox(label="Use dynamic mask", value=True)
                test_duration_envelope_button = gr.Button(variant="primary", value="Apply envelope", scale=1)
                test_duration_stretch_button = gr.Button(variant="primary", value="Apply stretch", scale=1)
                test_duration_inpaint_button = gr.Button(variant="primary", value="Inpaint different duration", scale=1)
                duration_slider = gradioWebUI.get_duration_slider()

        # Preview row: selected source sound vs. duration-adjusted result.
        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                with gr.Row(variant="panel"):
                    source_sound_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy",
                                                              height=600, scale=1)
                    source_sound_phase_image = gr.Image(label="New sound phase", type="numpy",
                                                        height=600, scale=1)
                source_sound_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False)

            with gr.Column(scale=3):
                with gr.Row(variant="panel"):
                    test_duration_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy",
                                                               height=600, scale=1)
                    test_duration_phase_image = gr.Image(label="New sound phase", type="numpy",
                                                         height=600, scale=1)
                test_duration_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False)

        # Track row: render an uploaded MIDI with the named instruments.
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                # track_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy",
                #                                    height=420, scale=1)
                midi_file = gr.File(label="Upload midi file", type="binary")
                instrument_names_textbox = gr.Textbox(label="Instrument names", lines=2,
                                                      placeholder="Names of your instrument used to play the midi", scale=1)
                track_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False)
                make_track_button = gr.Button(variant="primary", value="Make track", scale=1)

        # Event wiring: each button dispatches to the handlers defined above.
        select_instrument_button.click(select_sound,
                                       inputs=[instrument_name_textbox, virtual_instruments_state],
                                       outputs=[source_sound_spectrogram_image,
                                                source_sound_phase_image,
                                                source_sound_audio])

        test_duration_envelope_button.click(test_duration_envelope,
                                            inputs=[instrument_name_textbox, duration_slider,
                                                    noising_strength_slider,
                                                    attack_slider,
                                                    before_release_slider,
                                                    release_slider,
                                                    virtual_instruments_state,
                                                    ],
                                            outputs=[test_duration_spectrogram_image,
                                                     test_duration_phase_image,
                                                     test_duration_audio])

        test_duration_stretch_button.click(test_duration_stretch,
                                           inputs=[instrument_name_textbox, duration_slider,
                                                   noising_strength_slider,
                                                   attack_slider,
                                                   before_release_slider,
                                                   release_slider,
                                                   virtual_instruments_state,
                                                   ],
                                           outputs=[test_duration_spectrogram_image,
                                                    test_duration_phase_image,
                                                    test_duration_audio])

        test_duration_inpaint_button.click(test_duration_inpaint,
                                           inputs=[instrument_name_textbox,
                                                   inpaint_steps_slider,
                                                   duration_slider,
                                                   noising_strength_slider,
                                                   end_noise_level_ratio_slider,
                                                   attack_slider,
                                                   before_release_slider,
                                                   mask_flexivity_slider,
                                                   virtual_instruments_state,
                                                   use_dynamic_mask_checkbox],
                                           outputs=[test_duration_spectrogram_image,
                                                    test_duration_phase_image,
                                                    test_duration_audio])

        make_track_button.click(make_track,
                                inputs=[inpaint_steps_slider, midi_file,
                                        noising_strength_slider,
                                        attack_slider,
                                        before_release_slider,
                                        instrument_names_textbox,
                                        virtual_instruments_state],
                                outputs=[track_audio])
274
+
webUI/natural_language_guided/gradio_webUI.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+
4
class GradioWebUI():
    """Shared configuration and widget-factory object for all Gradio tabs.

    Bundles the trained models (VQ-VAE stages, diffusion UNet, CLAP text
    encoder) with synthesis hyper-parameters, and exposes factory methods for
    the sliders/radios that several tabs share so their defaults stay
    consistent across the UI.
    """

    def __init__(self, device, VAE, uNet, CLAP, CLAP_tokenizer,
                 freq_resolution=512, time_resolution=256, channels=4, timesteps=1000,
                 sample_rate=16000, squared=False, VAE_scale=4,
                 flexible_duration=False, noise_strategy="repeat",
                 GAN_generator = None):
        self.device = device
        # Unpack the VQ-VAE into its three stages for direct access by tabs.
        self.VAE_encoder, self.VAE_quantizer, self.VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder
        self.uNet = uNet
        self.CLAP, self.CLAP_tokenizer = CLAP, CLAP_tokenizer
        self.freq_resolution, self.time_resolution = freq_resolution, time_resolution
        self.channels = channels
        self.GAN_generator = GAN_generator

        self.timesteps = timesteps
        self.sample_rate = sample_rate
        self.squared = squared
        self.VAE_scale = VAE_scale
        self.flexible_duration = flexible_duration
        self.noise_strategy = noise_strategy

        # Per-task Gradio session-state containers shared by the tab modules.
        self.text2sound_state = gr.State(value={})
        self.interpolation_state = gr.State(value={})
        self.sound2sound_state = gr.State(value={})
        self.inpaint_state = gr.State(value={})

    def get_sample_steps_slider(self):
        """Sampling-steps slider; fewer default steps on CPU for speed."""
        default_steps = 10 if (self.device == "cpu") else 20
        return gr.Slider(minimum=10, maximum=100, value=default_steps, step=1,
                         label="Sample steps",
                         info="Sampling steps. The more sampling steps, the better the "
                              "theoretical result, but the time it consumes.")

    def get_sampler_radio(self):
        """Sampler choice; dpmsolver variants are currently disabled."""
        # return gr.Radio(choices=["ddpm", "ddim", "dpmsolver++", "dpmsolver"], value="ddim", label="Sampler")
        return gr.Radio(choices=["ddpm", "ddim"], value="ddim", label="Sampler")

    def get_batchsize_slider(self, cpu_batchsize=1):
        """Batch-size slider; defaults to a small batch on CPU."""
        return gr.Slider(minimum=1., maximum=16, value=cpu_batchsize if (self.device == "cpu") else 8, step=1, label="Batchsize")

    def get_time_resolution_slider(self):
        """Latent time-resolution slider, scaled down by the VAE factor."""
        return gr.Slider(minimum=16., maximum=int(1024/self.VAE_scale), value=int(256/self.VAE_scale), step=1, label="Time resolution", interactive=True)

    def get_duration_slider(self):
        """Duration slider; continuous range when flexible durations are on."""
        if self.flexible_duration:
            return gr.Slider(minimum=0.25, maximum=8., value=3., step=0.01, label="duration in sec")
        else:
            return gr.Slider(minimum=1., maximum=8., value=3., step=1., label="duration in sec")

    def get_guidance_scale_slider(self):
        """Classifier-free-guidance scale slider (0 = unconditional)."""
        return gr.Slider(minimum=0., maximum=20., value=6., step=1.,
                         label="Guidance scale",
                         info="The larger this value, the more the generated sound is "
                              "influenced by the condition. Setting it to 0 is equivalent to "
                              "the negative case.")

    def get_noising_strength_slider(self, default_noising_strength=0.7):
        """Noising-strength slider for sound2sound / inpainting."""
        return gr.Slider(minimum=0.0, maximum=1.00, value=default_noising_strength, step=0.01,
                         label="noising strength",
                         info="The smaller this value, the more the generated sound is "
                              "closed to the origin.")

    def get_seed_textbox(self):
        """Seed textbox; non-numeric input is handled downstream by safe_int."""
        return gr.Textbox(label="Seed", lines=1, placeholder="seed", value=0)
webUI/natural_language_guided/inpaint_with_text.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import numpy as np
3
+ import torch
4
+ import gradio as gr
5
+ from scipy.ndimage import zoom
6
+
7
+ from model.DiffSynthSampler import DiffSynthSampler
8
+ from tools import adjust_audio_length, safe_int, pad_STFT, encode_stft
9
+ from webUI.natural_language_guided.utils import latent_representation_to_Gradio_image, InputBatch2Encode_STFT, encodeBatch2GradioOutput_STFT, add_instrument
10
+
11
+
12
def get_triangle_mask(height, width):
    """Return a (height, width) float array with 1.0 where ``i < slope * j``.

    The mask is 1 above the line ``row = slope * col`` (a triangular region)
    and 0 elsewhere. Vectorized with a NumPy broadcast comparison instead of
    the original O(height * width) Python double loop; the returned values
    and dtype (float64 zeros/ones) are unchanged.
    """
    slope = 8 / 3
    rows = np.arange(height).reshape(-1, 1)   # column vector of row indices
    cols = np.arange(width).reshape(1, -1)    # row vector of column indices
    return (rows < slope * cols).astype(float)
20
+
21
+
22
+ def get_inpaint_with_text_module(gradioWebUI, inpaintWithText_state, virtual_instruments_state):
23
+ # Load configurations
24
+ uNet = gradioWebUI.uNet
25
+ freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
26
+ VAE_scale = gradioWebUI.VAE_scale
27
+ height, width, channels = int(freq_resolution/VAE_scale), int(time_resolution/VAE_scale), gradioWebUI.channels
28
+ timesteps = gradioWebUI.timesteps
29
+ VAE_encoder = gradioWebUI.VAE_encoder
30
+ VAE_quantizer = gradioWebUI.VAE_quantizer
31
+ VAE_decoder = gradioWebUI.VAE_decoder
32
+ CLAP = gradioWebUI.CLAP
33
+ CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
34
+ device = gradioWebUI.device
35
+ squared = gradioWebUI.squared
36
+ sample_rate = gradioWebUI.sample_rate
37
+ noise_strategy = gradioWebUI.noise_strategy
38
+
39
+ def receive_uopoad_origin_audio(sound2sound_duration, sound2sound_origin_source, sound2sound_origin_upload, sound2sound_origin_microphone,
40
+ inpaintWithText_dict):
41
+
42
+ if sound2sound_origin_source == "upload":
43
+ origin_sr, origin_audio = sound2sound_origin_upload
44
+ else:
45
+ origin_sr, origin_audio = sound2sound_origin_microphone
46
+
47
+ origin_audio = origin_audio / np.max(np.abs(origin_audio))
48
+
49
+ width = int(time_resolution*((sound2sound_duration+1)/4) / VAE_scale)
50
+ audio_length = 256 * (VAE_scale * width - 1)
51
+ origin_audio = adjust_audio_length(origin_audio, audio_length, origin_sr, sample_rate)
52
+
53
+ D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024)
54
+ padded_D = pad_STFT(D)
55
+ encoded_D = encode_stft(padded_D)
56
+
57
+ # Todo: justify batchsize to 1
58
+ origin_spectrogram_batch_tensor = torch.from_numpy(
59
+ np.repeat(encoded_D[np.newaxis, :, :, :], 1, axis=0)).float().to(device)
60
+
61
+ # Todo: remove hard-coding
62
+ origin_flipped_log_spectrums, origin_flipped_phases, origin_signals, origin_latent_representations, quantized_origin_latent_representations = InputBatch2Encode_STFT(
63
+ VAE_encoder, origin_spectrogram_batch_tensor, resolution=(512, width * VAE_scale), quantizer=VAE_quantizer, squared=squared)
64
+
65
+ if sound2sound_origin_source == "upload":
66
+ inpaintWithText_dict["origin_upload_latent_representations"] = origin_latent_representations.tolist()
67
+ inpaintWithText_dict[
68
+ "sound2sound_origin_upload_latent_representation_image"] = latent_representation_to_Gradio_image(
69
+ origin_latent_representations[0]).tolist()
70
+ inpaintWithText_dict[
71
+ "sound2sound_origin_upload_quantized_latent_representation_image"] = latent_representation_to_Gradio_image(
72
+ quantized_origin_latent_representations[0]).tolist()
73
+ return {sound2sound_origin_spectrogram_upload_image: origin_flipped_log_spectrums[0],
74
+ sound2sound_origin_phase_upload_image: origin_flipped_phases[0],
75
+ sound2sound_origin_spectrogram_microphone_image: gr.update(),
76
+ sound2sound_origin_phase_microphone_image: gr.update(),
77
+ sound2sound_origin_upload_latent_representation_image: latent_representation_to_Gradio_image(
78
+ origin_latent_representations[0]),
79
+ sound2sound_origin_upload_quantized_latent_representation_image: latent_representation_to_Gradio_image(
80
+ quantized_origin_latent_representations[0]),
81
+ sound2sound_origin_microphone_latent_representation_image: gr.update(),
82
+ sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(),
83
+ inpaintWithText_state: inpaintWithText_dict}
84
+ else:
85
+ inpaintWithText_dict["origin_microphone_latent_representations"] = origin_latent_representations.tolist()
86
+ inpaintWithText_dict[
87
+ "sound2sound_origin_microphone_latent_representation_image"] = latent_representation_to_Gradio_image(
88
+ origin_latent_representations[0]).tolist()
89
+ inpaintWithText_dict[
90
+ "sound2sound_origin_microphone_quantized_latent_representation_image"] = latent_representation_to_Gradio_image(
91
+ quantized_origin_latent_representations[0]).tolist()
92
+ return {sound2sound_origin_spectrogram_upload_image: origin_flipped_log_spectrums[0],
93
+ sound2sound_origin_phase_upload_image: origin_flipped_phases[0],
94
+ sound2sound_origin_spectrogram_microphone_image: gr.update(),
95
+ sound2sound_origin_phase_microphone_image: gr.update(),
96
+ sound2sound_origin_upload_latent_representation_image: latent_representation_to_Gradio_image(
97
+ origin_latent_representations[0]),
98
+ sound2sound_origin_upload_quantized_latent_representation_image: latent_representation_to_Gradio_image(
99
+ quantized_origin_latent_representations[0]),
100
+ sound2sound_origin_microphone_latent_representation_image: gr.update(),
101
+ sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(),
102
+ inpaintWithText_state: inpaintWithText_dict}
103
+
104
    def sound2sound_sample(sound2sound_origin_spectrogram_upload, sound2sound_origin_spectrogram_microphone,
                           text2sound_prompts, text2sound_negative_prompts, sound2sound_batchsize,
                           sound2sound_guidance_scale, sound2sound_sampler,
                           sound2sound_sample_steps, sound2sound_origin_source,
                           sound2sound_noising_strength, sound2sound_seed, sound2sound_inpaint_area,
                           mask_time_begin, mask_time_end, mask_frequency_begin, mask_frequency_end, inpaintWithText_dict
                           ):
        """Run text-guided latent inpainting on the cached origin sound.

        Builds the inpainting mask from the user's sketch plus the rectangular
        time/frequency sliders, runs the diffusion sampler with classifier-free
        guidance, decodes the resulting latents back to audio, caches all batch
        results in the session dict, and returns the first sample's component
        updates.
        """
        # input preprocessing
        sound2sound_seed = safe_int(sound2sound_seed, 12345678)
        sound2sound_batchsize = int(sound2sound_batchsize)
        noising_strength = sound2sound_noising_strength
        sound2sound_sample_steps = int(sound2sound_sample_steps)
        # NOTE(review): int() truncates fractional guidance scales — confirm intended.
        CFG = int(sound2sound_guidance_scale)

        # CLAP text embedding of the positive prompt (the diffusion condition).
        text2sound_embedding = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(device)

        # Fetch the cached origin latents and the sketch mask for the active source.
        if sound2sound_origin_source == "upload":
            origin_latent_representations = torch.tensor(
                inpaintWithText_dict["origin_upload_latent_representations"]).repeat(sound2sound_batchsize, 1, 1, 1).to(
                device)
            mask = np.array(sound2sound_origin_spectrogram_upload["mask"])
        elif sound2sound_origin_source == "microphone":
            origin_latent_representations = torch.tensor(
                inpaintWithText_dict["origin_microphone_latent_representations"]).repeat(sound2sound_batchsize, 1, 1, 1).to(
                device)
            mask = np.array(sound2sound_origin_spectrogram_microphone["mask"])
        else:
            print("Input source not in ['upload', 'microphone']!")
            raise NotImplementedError()

        # Collapse the RGB sketch to a binary mask, then downscale it to
        # latent resolution.
        merged_mask = np.all(mask == 255, axis=2).astype(np.uint8)
        latent_mask = zoom(merged_mask, (1 / VAE_scale, 1 / VAE_scale))
        latent_mask = np.clip(latent_mask, 0, 1)
        print(f"latent_mask.avg = {np.mean(latent_mask)}")
        # Add the rectangular slider-defined region (rows = frequency pixels,
        # cols = seconds converted to latent columns — presumably 4 s full
        # scale; verify against the slider ranges).
        latent_mask[int(mask_frequency_begin):int(mask_frequency_end), int(mask_time_begin*time_resolution/(VAE_scale*4)):int(mask_time_end*time_resolution/(VAE_scale*4))] = 1


        # latent_mask = get_triangle_mask(128, 64)

        print(f"latent_mask.avg = {np.mean(latent_mask)}")
        # Mask convention: 1 = keep, 0 = regenerate; invert when the user asked
        # to inpaint the masked region.
        if sound2sound_inpaint_area == "inpaint masked":
            latent_mask = 1 - latent_mask
        latent_mask = torch.from_numpy(latent_mask).unsqueeze(0).unsqueeze(1).repeat(sound2sound_batchsize, channels, 1,
                                                                                    1).float().to(device)
        # Flip vertically — presumably because Gradio images are displayed
        # with the frequency axis inverted; confirm against the image helpers.
        latent_mask = torch.flip(latent_mask, [2])

        mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
        unconditional_condition = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[0]
        mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device))

        # More total steps at lower noising strength so the truncated schedule
        # still runs the requested number of denoising steps.
        normalized_sample_steps = int(sound2sound_sample_steps / noising_strength)

        mySampler.respace(list(np.linspace(0, timesteps - 1, normalized_sample_steps, dtype=np.int32)))

        # Todo: remove hard-coding
        width = origin_latent_representations.shape[-1]
        condition = text2sound_embedding.repeat(sound2sound_batchsize, 1)

        new_sound_latent_representations, initial_noise = \
            mySampler.inpaint_sample(model=uNet, shape=(sound2sound_batchsize, channels, height, width),
                                     seed=sound2sound_seed,
                                     noising_strength=noising_strength,
                                     guide_img=origin_latent_representations, mask=latent_mask, return_tensor=True,
                                     condition=condition, sampler=sound2sound_sampler)

        # Keep only the final denoising step.
        new_sound_latent_representations = new_sound_latent_representations[-1]

        # Quantize new sound latent representations
        quantized_new_sound_latent_representations, loss, (_, _, _) = VAE_quantizer(new_sound_latent_representations)
        new_sound_flipped_log_spectrums, new_sound_flipped_phases, new_sound_signals, _, _, _ = encodeBatch2GradioOutput_STFT(VAE_decoder,
                                                                                                                              quantized_new_sound_latent_representations,
                                                                                                                              resolution=(
                                                                                                                                  512,
                                                                                                                                  width * VAE_scale),
                                                                                                                              original_STFT_batch=None
                                                                                                                              )

        # Render every batch element once so the index slider can page through
        # them without re-running the sampler.
        new_sound_latent_representation_gradio_images = []
        new_sound_quantized_latent_representation_gradio_images = []
        new_sound_spectrogram_gradio_images = []
        new_sound_phase_gradio_images = []
        new_sound_rec_signals_gradio = []
        for i in range(sound2sound_batchsize):
            new_sound_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(new_sound_latent_representations[i]))
            new_sound_quantized_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(quantized_new_sound_latent_representations[i]))
            new_sound_spectrogram_gradio_images.append(new_sound_flipped_log_spectrums[i])
            new_sound_phase_gradio_images.append(new_sound_flipped_phases[i])
            new_sound_rec_signals_gradio.append((sample_rate, new_sound_signals[i]))

        inpaintWithText_dict[
            "new_sound_latent_representation_gradio_images"] = new_sound_latent_representation_gradio_images
        inpaintWithText_dict[
            "new_sound_quantized_latent_representation_gradio_images"] = new_sound_quantized_latent_representation_gradio_images
        inpaintWithText_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
        inpaintWithText_dict["new_sound_phase_gradio_images"] = new_sound_phase_gradio_images
        inpaintWithText_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio

        inpaintWithText_dict["latent_representations"] = new_sound_latent_representations.to("cpu").detach().numpy()
        inpaintWithText_dict["quantized_latent_representations"] = quantized_new_sound_latent_representations.to("cpu").detach().numpy()
        inpaintWithText_dict["sampler"] = sound2sound_sampler

        # Show the first sample and reveal the index slider sized to the batch.
        return {sound2sound_new_sound_latent_representation_image: latent_representation_to_Gradio_image(
            new_sound_latent_representations[0]),
                sound2sound_new_sound_quantized_latent_representation_image: latent_representation_to_Gradio_image(
                    quantized_new_sound_latent_representations[0]),
                sound2sound_new_sound_spectrogram_image: new_sound_flipped_log_spectrums[0],
                sound2sound_new_sound_phase_image: new_sound_flipped_phases[0],
                sound2sound_new_sound_audio: (sample_rate, new_sound_signals[0]),
                sound2sound_sample_index_slider: gr.update(minimum=0, maximum=sound2sound_batchsize - 1, value=0,
                                                           step=1.0,
                                                           visible=True,
                                                           label="Sample index",
                                                           info="Swipe to view other samples"),
                sound2sound_seed_textbox: sound2sound_seed,
                inpaintWithText_state: inpaintWithText_dict}
224
+
225
+ def show_sound2sound_sample(sound2sound_sample_index, inpaintWithText_dict):
226
+ sample_index = int(sound2sound_sample_index)
227
+ return {sound2sound_new_sound_latent_representation_image:
228
+ inpaintWithText_dict["new_sound_latent_representation_gradio_images"][sample_index],
229
+ sound2sound_new_sound_quantized_latent_representation_image:
230
+ inpaintWithText_dict["new_sound_quantized_latent_representation_gradio_images"][sample_index],
231
+ sound2sound_new_sound_spectrogram_image: inpaintWithText_dict["new_sound_spectrogram_gradio_images"][
232
+ sample_index],
233
+ sound2sound_new_sound_phase_image: inpaintWithText_dict["new_sound_phase_gradio_images"][
234
+ sample_index],
235
+ sound2sound_new_sound_audio: inpaintWithText_dict["new_sound_rec_signals_gradio"][sample_index]}
236
+
237
+ def sound2sound_switch_origin_source(sound2sound_origin_source):
238
+
239
+ if sound2sound_origin_source == "upload":
240
+ return {sound2sound_origin_upload_audio: gr.update(visible=True),
241
+ sound2sound_origin_microphone_audio: gr.update(visible=False),
242
+ sound2sound_origin_spectrogram_upload_image: gr.update(visible=True),
243
+ sound2sound_origin_phase_upload_image: gr.update(visible=True),
244
+ sound2sound_origin_spectrogram_microphone_image: gr.update(visible=False),
245
+ sound2sound_origin_phase_microphone_image: gr.update(visible=False),
246
+ sound2sound_origin_upload_latent_representation_image: gr.update(visible=True),
247
+ sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=True),
248
+ sound2sound_origin_microphone_latent_representation_image: gr.update(visible=False),
249
+ sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=False)}
250
+ elif sound2sound_origin_source == "microphone":
251
+ return {sound2sound_origin_upload_audio: gr.update(visible=False),
252
+ sound2sound_origin_microphone_audio: gr.update(visible=True),
253
+ sound2sound_origin_spectrogram_upload_image: gr.update(visible=False),
254
+ sound2sound_origin_phase_upload_image: gr.update(visible=False),
255
+ sound2sound_origin_spectrogram_microphone_image: gr.update(visible=True),
256
+ sound2sound_origin_phase_microphone_image: gr.update(visible=True),
257
+ sound2sound_origin_upload_latent_representation_image: gr.update(visible=False),
258
+ sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=False),
259
+ sound2sound_origin_microphone_latent_representation_image: gr.update(visible=True),
260
+ sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=True)}
261
+ else:
262
+ print("Input source not in ['upload', 'microphone']!")
263
+
264
+ def save_virtual_instrument(sample_index, virtual_instrument_name, sound2sound_dict, virtual_instruments_dict):
265
+
266
+ virtual_instruments_dict = add_instrument(sound2sound_dict, virtual_instruments_dict, virtual_instrument_name, sample_index)
267
+ return {virtual_instruments_state: virtual_instruments_dict,
268
+ sound2sound_instrument_name_textbox: gr.Textbox(label="Instrument name", lines=1,
269
+ placeholder=f"Saved as {virtual_instrument_name}!")}
270
+
271
    # Inpaint tab: sketch a mask on the origin spectrogram and regenerate the
    # selected region guided by a text prompt.
    with gr.Tab("Inpaint"):
        gr.Markdown("Select the area to inpaint and use the prompt to guide the synthesis of a new sound!")
        # Prompt row + generate button and (initially hidden) batch-index slider.
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):
                text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
                text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")

            with gr.Column(scale=1):
                sound2sound_sample_button = gr.Button(variant="primary", value="Generate", scale=1)

                sound2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
                                                            label="Sample index",
                                                            info="Swipe to view other samples")

        with gr.Row(variant="panel"):
            # Left column: origin-sound input, sampler settings, mask sliders.
            with gr.Column(scale=1):
                with gr.Tab("Origin sound"):
                    sound2sound_duration_slider = gradioWebUI.get_duration_slider()
                    sound2sound_origin_source_radio = gr.Radio(choices=["upload", "microphone"], value="upload",
                                                               label="Input source")

                    sound2sound_origin_upload_audio = gr.Audio(type="numpy", label="Upload", source="upload",
                                                               interactive=True, visible=True)
                    sound2sound_origin_microphone_audio = gr.Audio(type="numpy", label="Record", source="microphone",
                                                                   interactive=True, visible=False)
                    # Sketchable spectrograms (tool="sketch") provide the mask.
                    with gr.Row(variant="panel"):
                        sound2sound_origin_spectrogram_upload_image = gr.Image(label="Original upload spectrogram",
                                                                               type="numpy", height=600,
                                                                               visible=True, tool="sketch")
                        sound2sound_origin_phase_upload_image = gr.Image(label="Original upload phase",
                                                                         type="numpy", height=600,
                                                                         visible=True)
                        sound2sound_origin_spectrogram_microphone_image = gr.Image(label="Original microphone spectrogram",
                                                                                   type="numpy", height=600,
                                                                                   visible=False, tool="sketch")
                        sound2sound_origin_phase_microphone_image = gr.Image(label="Original microphone phase",
                                                                             type="numpy", height=600,
                                                                             visible=False)
                    sound2sound_inpaint_area_radio = gr.Radio(choices=["inpaint masked", "inpaint not masked"],
                                                              value="inpaint masked")

                with gr.Tab("Sound2sound settings"):
                    sound2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
                    sound2sound_sampler_radio = gradioWebUI.get_sampler_radio()
                    sound2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
                    sound2sound_noising_strength_slider = gradioWebUI.get_noising_strength_slider()
                    sound2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
                    sound2sound_seed_textbox = gradioWebUI.get_seed_textbox()

                # Rectangular mask regions added on top of the sketch mask.
                with gr.Tab("Mask prototypes"):
                    with gr.Tab("Mask along time axis"):
                        mask_time_begin_slider = gr.Slider(minimum=0.0, maximum=4.00, value=0.0, step=0.01, label="Begin time")
                        mask_time_end_slider = gr.Slider(minimum=0.0, maximum=4.00, value=0.0, step=0.01, label="End time")
                    with gr.Tab("Mask along frequency axis"):
                        mask_frequency_begin_slider = gr.Slider(minimum=0, maximum=127, value=0, step=1, label="Begin freq pixel")
                        mask_frequency_end_slider = gr.Slider(minimum=0, maximum=127, value=0, step=1, label="End freq pixel")

            # Right column: generated result viewers and instrument saving.
            with gr.Column(scale=1):
                sound2sound_new_sound_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False)
                with gr.Row(variant="panel"):
                    sound2sound_new_sound_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy",
                                                                       height=600, scale=1)
                    sound2sound_new_sound_phase_image = gr.Image(label="New sound phase", type="numpy",
                                                                 height=600, scale=1)


                with gr.Row(variant="panel"):
                    sound2sound_instrument_name_textbox = gr.Textbox(label="Instrument name", lines=1,
                                                                     placeholder="Name of your instrument")
                    sound2sound_save_instrument_button = gr.Button(variant="primary",
                                                                   value="Save instrument",
                                                                   scale=1)

        # Latent-representation debug viewers (one pair per input source).
        with gr.Row(variant="panel"):
            sound2sound_origin_upload_latent_representation_image = gr.Image(label="Original latent representation",
                                                                             type="numpy", height=800,
                                                                             visible=True)
            sound2sound_origin_upload_quantized_latent_representation_image = gr.Image(
                label="Original quantized latent representation", type="numpy", height=800, visible=True)

            sound2sound_origin_microphone_latent_representation_image = gr.Image(label="Original latent representation",
                                                                                 type="numpy", height=800,
                                                                                 visible=False)
            sound2sound_origin_microphone_quantized_latent_representation_image = gr.Image(
                label="Original quantized latent representation", type="numpy", height=800, visible=False)

            sound2sound_new_sound_latent_representation_image = gr.Image(label="New latent representation",
                                                                         type="numpy", height=800)
            sound2sound_new_sound_quantized_latent_representation_image = gr.Image(
                label="New sound quantized latent representation", type="numpy", height=800)

        # Event wiring: both audio inputs funnel into the same encoder handler.
        sound2sound_origin_upload_audio.change(receive_uopoad_origin_audio,
                                               inputs=[sound2sound_duration_slider, sound2sound_origin_source_radio, sound2sound_origin_upload_audio,
                                                       sound2sound_origin_microphone_audio, inpaintWithText_state],
                                               outputs=[sound2sound_origin_spectrogram_upload_image,
                                                        sound2sound_origin_phase_upload_image,
                                                        sound2sound_origin_spectrogram_microphone_image,
                                                        sound2sound_origin_phase_microphone_image,
                                                        sound2sound_origin_upload_latent_representation_image,
                                                        sound2sound_origin_upload_quantized_latent_representation_image,
                                                        sound2sound_origin_microphone_latent_representation_image,
                                                        sound2sound_origin_microphone_quantized_latent_representation_image,
                                                        inpaintWithText_state])
        sound2sound_origin_microphone_audio.change(receive_uopoad_origin_audio,
                                                   inputs=[sound2sound_duration_slider, sound2sound_origin_source_radio, sound2sound_origin_upload_audio,
                                                           sound2sound_origin_microphone_audio, inpaintWithText_state],
                                                   outputs=[sound2sound_origin_spectrogram_upload_image,
                                                            sound2sound_origin_phase_upload_image,
                                                            sound2sound_origin_spectrogram_microphone_image,
                                                            sound2sound_origin_phase_microphone_image,
                                                            sound2sound_origin_upload_latent_representation_image,
                                                            sound2sound_origin_upload_quantized_latent_representation_image,
                                                            sound2sound_origin_microphone_latent_representation_image,
                                                            sound2sound_origin_microphone_quantized_latent_representation_image,
                                                            inpaintWithText_state])

        sound2sound_sample_button.click(sound2sound_sample,
                                        inputs=[sound2sound_origin_spectrogram_upload_image,
                                                sound2sound_origin_spectrogram_microphone_image,
                                                text2sound_prompts_textbox,
                                                text2sound_negative_prompts_textbox,
                                                sound2sound_batchsize_slider,
                                                sound2sound_guidance_scale_slider,
                                                sound2sound_sampler_radio,
                                                sound2sound_sample_steps_slider,
                                                sound2sound_origin_source_radio,
                                                sound2sound_noising_strength_slider,
                                                sound2sound_seed_textbox,
                                                sound2sound_inpaint_area_radio,
                                                mask_time_begin_slider,
                                                mask_time_end_slider,
                                                mask_frequency_begin_slider,
                                                mask_frequency_end_slider,
                                                inpaintWithText_state],
                                        outputs=[sound2sound_new_sound_latent_representation_image,
                                                 sound2sound_new_sound_quantized_latent_representation_image,
                                                 sound2sound_new_sound_spectrogram_image,
                                                 sound2sound_new_sound_phase_image,
                                                 sound2sound_new_sound_audio,
                                                 sound2sound_sample_index_slider,
                                                 sound2sound_seed_textbox,
                                                 inpaintWithText_state])

        sound2sound_sample_index_slider.change(show_sound2sound_sample,
                                               inputs=[sound2sound_sample_index_slider, inpaintWithText_state],
                                               outputs=[sound2sound_new_sound_latent_representation_image,
                                                        sound2sound_new_sound_quantized_latent_representation_image,
                                                        sound2sound_new_sound_spectrogram_image,
                                                        sound2sound_new_sound_phase_image,
                                                        sound2sound_new_sound_audio])

        sound2sound_origin_source_radio.change(sound2sound_switch_origin_source,
                                               inputs=[sound2sound_origin_source_radio],
                                               outputs=[sound2sound_origin_upload_audio,
                                                        sound2sound_origin_microphone_audio,
                                                        sound2sound_origin_spectrogram_upload_image,
                                                        sound2sound_origin_phase_upload_image,
                                                        sound2sound_origin_spectrogram_microphone_image,
                                                        sound2sound_origin_phase_microphone_image,
                                                        sound2sound_origin_upload_latent_representation_image,
                                                        sound2sound_origin_upload_quantized_latent_representation_image,
                                                        sound2sound_origin_microphone_latent_representation_image,
                                                        sound2sound_origin_microphone_quantized_latent_representation_image])

        sound2sound_save_instrument_button.click(save_virtual_instrument,
                                                 inputs=[sound2sound_sample_index_slider,
                                                         sound2sound_instrument_name_textbox,
                                                         inpaintWithText_state,
                                                         virtual_instruments_state],
                                                 outputs=[virtual_instruments_state,
                                                          sound2sound_instrument_name_textbox])
webUI/natural_language_guided/rec.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from data_generation.nsynth import get_nsynth_dataloader
4
+ from webUI.natural_language_guided_STFT.utils import encodeBatch2GradioOutput_STFT, InputBatch2Encode_STFT, \
5
+ latent_representation_to_Gradio_image
6
+
7
+
8
def get_recSTFT_module(gradioWebUI, reconstruction_state):
    """Build the "Reconstruction" Gradio tab.

    Encodes random NSynth STFT batches with the VAE, decodes them back, and
    displays original vs. reconstructed spectrograms/phases/audio.

    Args:
        gradioWebUI: shared configuration object carrying the models
            (VAE encoder/decoder/quantizer, uNet, CLAP) and settings.
        reconstruction_state: gr.State used as an encode cache across events.
    """
    # Load configurations
    uNet = gradioWebUI.uNet
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    # Latent spatial size = STFT resolution divided by the VAE down-scaling factor.
    height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels

    timesteps = gradioWebUI.timesteps
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_encoder = gradioWebUI.VAE_encoder
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

26
def generate_reconstruction_samples(sample_source, batchsize_slider, encodeCache,
                                    reconstruction_samples):
    """Encode a random dataset batch with the VAE, decode it back, and cache
    originals plus reconstructions (with and without amplitude adjustment,
    "WOA") for display; returns gradio updates showing sample 0.

    Raises:
        NotImplementedError: if *sample_source* is not a known dataset split.
    """
    vae_batchsize = int(batchsize_slider)

    # The three dataset branches differed only in the HDF5 path; use a table.
    # Make sure to use your actual paths.
    dataset_paths = {
        "text2sound_trainSTFT": 'data/NSynth/nsynth-STFT-train-52.hdf5',
        "text2sound_validSTFT": 'data/NSynth/nsynth-STFT-valid-52.hdf5',
        "text2sound_testSTFT": 'data/NSynth/nsynth-STFT-test-52.hdf5',
    }
    if sample_source not in dataset_paths:
        raise NotImplementedError()
    iterator = get_nsynth_dataloader(dataset_paths[sample_source], batch_size=vae_batchsize, shuffle=True,
                                     get_latent_representation=False, with_meta_data=False,
                                     task="STFT")

    spectrogram_batch = next(iter(iterator))

    origin_flipped_log_spectrums, origin_flipped_phases, origin_signals, latent_representations, quantized_latent_representations = InputBatch2Encode_STFT(
        VAE_encoder, spectrogram_batch, resolution=(512, width * VAE_scale), quantizer=VAE_quantizer, squared=squared)

    # Fix: perform the no-quantizer fallback BEFORE indexing the quantized
    # latents below; previously the None-check came after the loop, where an
    # actual None would already have raised.
    if quantized_latent_representations is None:
        quantized_latent_representations = latent_representations

    latent_representation_gradio_images, quantized_latent_representation_gradio_images = [], []
    for i in range(vae_batchsize):
        latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i]))
        quantized_latent_representation_gradio_images.append(
            latent_representation_to_Gradio_image(quantized_latent_representations[i]))

    # Decode; passing the original STFT batch also yields the "_WOA"
    # (without-amplitude-adjustment) variants.
    reconstruction_flipped_log_spectrums, reconstruction_flipped_phases, reconstruction_signals, \
        reconstruction_flipped_log_spectrums_WOA, reconstruction_flipped_phases_WOA, reconstruction_signals_WOA = \
        encodeBatch2GradioOutput_STFT(VAE_decoder,
                                      quantized_latent_representations,
                                      resolution=(512, width * VAE_scale),
                                      original_STFT_batch=spectrogram_batch)

    # Cache everything show_reconstruction_sample needs to switch samples.
    reconstruction_samples["origin_flipped_log_spectrums"] = origin_flipped_log_spectrums
    reconstruction_samples["origin_flipped_phases"] = origin_flipped_phases
    reconstruction_samples["origin_signals"] = origin_signals
    reconstruction_samples["latent_representation_gradio_images"] = latent_representation_gradio_images
    reconstruction_samples["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images
    reconstruction_samples["reconstruction_flipped_log_spectrums"] = reconstruction_flipped_log_spectrums
    reconstruction_samples["reconstruction_flipped_phases"] = reconstruction_flipped_phases
    reconstruction_samples["reconstruction_signals"] = reconstruction_signals
    reconstruction_samples["reconstruction_flipped_log_spectrums_WOA"] = reconstruction_flipped_log_spectrums_WOA
    reconstruction_samples["reconstruction_flipped_phases_WOA"] = reconstruction_flipped_phases_WOA
    reconstruction_samples["reconstruction_signals_WOA"] = reconstruction_signals_WOA
    reconstruction_samples["sampleRate"] = sample_rate

    # Display the first sample by default; reveal the index slider.
    return {origin_amplitude_image_output: origin_flipped_log_spectrums[0],
            origin_phase_image_output: origin_flipped_phases[0],
            origin_audio_output: (sample_rate, origin_signals[0]),
            latent_representation_image_output: latent_representation_gradio_images[0],
            quantized_latent_representation_image_output: quantized_latent_representation_gradio_images[0],
            reconstruction_amplitude_image_output: reconstruction_flipped_log_spectrums[0],
            reconstruction_phase_image_output: reconstruction_flipped_phases[0],
            reconstruction_audio_output: (sample_rate, reconstruction_signals[0]),
            reconstruction_amplitude_image_output_WOA: reconstruction_flipped_log_spectrums_WOA[0],
            reconstruction_phase_image_output_WOA: reconstruction_flipped_phases_WOA[0],
            reconstruction_audio_output_WOA: (sample_rate, reconstruction_signals_WOA[0]),
            sample_index_slider: gr.update(minimum=0, maximum=vae_batchsize - 1, value=0, step=1.0,
                                           label="Sample index.",
                                           info="Slide to view other samples", scale=1, visible=True),
            reconstruction_state: encodeCache,
            reconstruction_samples_state: reconstruction_samples}
118
def show_reconstruction_sample(sample_index, encodeCache_state, reconstruction_samples_state):
    """Return the cached display data for the selected sample index.

    Reads everything from the cache populated by
    generate_reconstruction_samples and passes both states back through.
    """
    idx = int(sample_index)
    cached = reconstruction_samples_state
    sr = cached["sampleRate"]

    def pick(key):
        # One cached entry per sample; select the requested one.
        return cached[key][idx]

    return (pick("origin_flipped_log_spectrums"),
            pick("origin_flipped_phases"),
            (sr, pick("origin_signals")),
            pick("latent_representation_gradio_images"),
            pick("quantized_latent_representation_gradio_images"),
            pick("reconstruction_flipped_log_spectrums"),
            pick("reconstruction_flipped_phases"),
            (sr, pick("reconstruction_signals")),
            pick("reconstruction_flipped_log_spectrums_WOA"),
            pick("reconstruction_flipped_phases_WOA"),
            (sr, pick("reconstruction_signals_WOA")),
            encodeCache_state, reconstruction_samples_state)
144
# UI layout and event wiring for the Reconstruction tab.
with gr.Tab("Reconstruction"):
    reconstruction_samples_state = gr.State(value={})
    gr.Markdown("Test reconstruction.")
    with gr.Row(variant="panel"):
        with gr.Column():
            # Fix: the default value was "text2sound_trainf", which is not in
            # `choices`; default to the train split instead.
            sample_source_radio = gr.Radio(
                choices=["synthetic", "external", "text2sound_trainSTFT", "text2sound_testSTFT", "text2sound_validSTFT"],
                value="text2sound_trainSTFT", info="Info placeholder", scale=2)
            batchsize_slider = gr.Slider(minimum=1., maximum=16., value=4., step=1.,
                                         label="batchsize")
        with gr.Column():
            generate_button = gr.Button(variant="primary", value="Generate reconstruction samples", scale=1)
            # Hidden until a batch is generated; its range is set dynamically.
            sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, label="Sample index.",
                                            info="Slide to view other samples", scale=1, visible=False)
    # Three columns: original, reconstruction, reconstruction without
    # amplitude adjustment (WOA).
    with gr.Row(variant="panel"):
        with gr.Column():
            origin_amplitude_image_output = gr.Image(label="Spectrogram", type="numpy", height=300, width=100, scale=1)
            origin_phase_image_output = gr.Image(label="Phase", type="numpy", height=300, width=100, scale=1)
            origin_audio_output = gr.Audio(type="numpy", label="Play the example!")
        with gr.Column():
            reconstruction_amplitude_image_output = gr.Image(label="Spectrogram", type="numpy", height=300, width=100, scale=1)
            reconstruction_phase_image_output = gr.Image(label="Phase", type="numpy", height=300, width=100, scale=1)
            reconstruction_audio_output = gr.Audio(type="numpy", label="Play the example!")
        with gr.Column():
            reconstruction_amplitude_image_output_WOA = gr.Image(label="Spectrogram", type="numpy", height=300, width=100, scale=1)
            reconstruction_phase_image_output_WOA = gr.Image(label="Phase", type="numpy", height=300, width=100, scale=1)
            reconstruction_audio_output_WOA = gr.Audio(type="numpy", label="Play the example!")
    with gr.Row(variant="panel", equal_height=True):
        latent_representation_image_output = gr.Image(label="latent_representation", type="numpy", height=300, width=100)
        quantized_latent_representation_image_output = gr.Image(label="quantized", type="numpy", height=300, width=100)

    generate_button.click(generate_reconstruction_samples,
                          inputs=[sample_source_radio, batchsize_slider, reconstruction_state,
                                  reconstruction_samples_state],
                          outputs=[origin_amplitude_image_output, origin_phase_image_output, origin_audio_output,
                                   latent_representation_image_output, quantized_latent_representation_image_output,
                                   reconstruction_amplitude_image_output, reconstruction_phase_image_output, reconstruction_audio_output,
                                   reconstruction_amplitude_image_output_WOA, reconstruction_phase_image_output_WOA, reconstruction_audio_output_WOA,
                                   sample_index_slider, reconstruction_state, reconstruction_samples_state])

    sample_index_slider.change(show_reconstruction_sample,
                               inputs=[sample_index_slider, reconstruction_state, reconstruction_samples_state],
                               outputs=[origin_amplitude_image_output, origin_phase_image_output, origin_audio_output,
                                        latent_representation_image_output, quantized_latent_representation_image_output,
                                        reconstruction_amplitude_image_output, reconstruction_phase_image_output, reconstruction_audio_output,
                                        reconstruction_amplitude_image_output_WOA, reconstruction_phase_image_output_WOA, reconstruction_audio_output_WOA,
                                        reconstruction_state, reconstruction_samples_state])
webUI/natural_language_guided/sound2sound_with_text.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import librosa
3
+ import numpy as np
4
+ import torch
5
+
6
+ from model.DiffSynthSampler import DiffSynthSampler
7
+ from tools import pad_STFT, encode_stft
8
+ from tools import safe_int, adjust_audio_length
9
+ from webUI.natural_language_guided.utils import InputBatch2Encode_STFT, encodeBatch2GradioOutput_STFT, \
10
+ latent_representation_to_Gradio_image
11
+
12
+
13
def get_sound2sound_with_text_module(gradioWebUI, sound2sound_with_text_state, virtual_instruments_state):
    """Build the "Sound2Sound" Gradio tab: text-guided img2img diffusion that
    regenerates an uploaded or recorded sound according to a prompt.

    Args:
        gradioWebUI: shared configuration object carrying the models
            (uNet, VAE encoder/quantizer/decoder, CLAP) and settings.
        sound2sound_with_text_state: gr.State caching encodings and results.
        virtual_instruments_state: gr.State holding saved virtual instruments.
    """
    # Load configurations
    uNet = gradioWebUI.uNet
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    # Latent spatial size = STFT resolution divided by the VAE down-scaling factor.
    height, width, channels = int(freq_resolution/VAE_scale), int(time_resolution/VAE_scale), gradioWebUI.channels
    timesteps = gradioWebUI.timesteps
    VAE_encoder = gradioWebUI.VAE_encoder
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

30
def receive_upload_origin_audio(sound2sound_duration, sound2sound_origin_source,
                                sound2sound_origin_upload, sound2sound_origin_microphone,
                                sound2sound_with_text_dict, virtual_instruments_dict):
    """Encode a user-provided sound (upload or microphone) into VAE latents.

    Normalizes and resamples the audio, computes its STFT, encodes it, caches
    the latents for later sampling, registers it as a virtual instrument
    ("s2sup"/"s2sre"), and returns gradio updates for the relevant widgets.
    """
    if sound2sound_origin_source == "upload":
        origin_sr, origin_audio = sound2sound_origin_upload
    else:
        origin_sr, origin_audio = sound2sound_origin_microphone

    # Peak-normalize to [-1, 1].
    origin_audio = origin_audio / np.max(np.abs(origin_audio))

    # Latent width follows the requested duration; audio length is aligned to
    # the STFT hop grid (hop_length=256).
    width = int(time_resolution * ((sound2sound_duration + 1) / 4) / VAE_scale)
    audio_length = 256 * (VAE_scale * width - 1)
    origin_audio = adjust_audio_length(origin_audio, audio_length, origin_sr, sample_rate)

    D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024)
    padded_D = pad_STFT(D)
    encoded_D = encode_stft(padded_D)

    # Todo: justify batchsize to 1
    origin_spectrogram_batch_tensor = torch.from_numpy(
        np.repeat(encoded_D[np.newaxis, :, :, :], 1, axis=0)).float().to(device)

    # Todo: remove hard-coding
    origin_flipped_log_spectrums, origin_flipped_phases, origin_signals, origin_latent_representations, quantized_origin_latent_representations = InputBatch2Encode_STFT(
        VAE_encoder, origin_spectrogram_batch_tensor, resolution=(512, width * VAE_scale),
        quantizer=VAE_quantizer, squared=squared)

    # Empty-prompt CLAP embedding is used as both condition and negative condition.
    default_condition = CLAP.get_text_features(**CLAP_tokenizer([""], padding=True, return_tensors="pt"))[0].to("cpu").detach().numpy()

    latent_image = latent_representation_to_Gradio_image(origin_latent_representations[0])
    quantized_latent_image = latent_representation_to_Gradio_image(quantized_origin_latent_representations[0])

    # Fix: build the virtual instrument once, identically for both sources.
    # The microphone branch previously stored raw tensors instead of numpy
    # arrays, stored the signal without its sample rate, and omitted the
    # phase image — inconsistent with the upload branch.
    virtual_instrument = {"condition": default_condition,
                          "negative_condition": default_condition,  # care!!!
                          "CFG": 1,
                          "latent_representation": origin_latent_representations[0].to("cpu").detach().numpy(),
                          "quantized_latent_representation": quantized_origin_latent_representations[0].to("cpu").detach().numpy(),
                          "sampler": "ddim",
                          "signal": (sample_rate, origin_audio),
                          "spectrogram_gradio_image": origin_flipped_log_spectrums[0],
                          "phase_gradio_image": origin_flipped_phases[0]}
    virtual_instruments = virtual_instruments_dict["virtual_instruments"]

    if sound2sound_origin_source == "upload":
        sound2sound_with_text_dict["origin_upload_latent_representations"] = origin_latent_representations.tolist()
        sound2sound_with_text_dict[
            "sound2sound_origin_upload_latent_representation_image"] = latent_image.tolist()
        sound2sound_with_text_dict[
            "sound2sound_origin_upload_quantized_latent_representation_image"] = quantized_latent_image.tolist()

        virtual_instruments["s2sup"] = virtual_instrument
        virtual_instruments_dict["virtual_instruments"] = virtual_instruments

        return {sound2sound_origin_spectrogram_upload_image: origin_flipped_log_spectrums[0],
                sound2sound_origin_phase_upload_image: origin_flipped_phases[0],
                sound2sound_origin_spectrogram_microphone_image: gr.update(),
                sound2sound_origin_phase_microphone_image: gr.update(),
                sound2sound_origin_upload_latent_representation_image: latent_image,
                sound2sound_origin_upload_quantized_latent_representation_image: quantized_latent_image,
                sound2sound_origin_microphone_latent_representation_image: gr.update(),
                sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(),
                sound2sound_with_text_state: sound2sound_with_text_dict,
                virtual_instruments_state: virtual_instruments_dict}
    else:
        sound2sound_with_text_dict["origin_microphone_latent_representations"] = origin_latent_representations.tolist()
        sound2sound_with_text_dict[
            "sound2sound_origin_microphone_latent_representation_image"] = latent_image.tolist()
        sound2sound_with_text_dict[
            "sound2sound_origin_microphone_quantized_latent_representation_image"] = quantized_latent_image.tolist()

        virtual_instruments["s2sre"] = virtual_instrument
        virtual_instruments_dict["virtual_instruments"] = virtual_instruments

        return {sound2sound_origin_spectrogram_upload_image: gr.update(),
                sound2sound_origin_phase_upload_image: gr.update(),
                sound2sound_origin_spectrogram_microphone_image: origin_flipped_log_spectrums[0],
                sound2sound_origin_phase_microphone_image: origin_flipped_phases[0],
                sound2sound_origin_upload_latent_representation_image: gr.update(),
                sound2sound_origin_upload_quantized_latent_representation_image: gr.update(),
                sound2sound_origin_microphone_latent_representation_image: latent_image,
                sound2sound_origin_microphone_quantized_latent_representation_image: quantized_latent_image,
                sound2sound_with_text_state: sound2sound_with_text_dict,
                virtual_instruments_state: virtual_instruments_dict}
127
def sound2sound_sample(sound2sound_prompts, sound2sound_negative_prompts, sound2sound_batchsize,
                       sound2sound_guidance_scale, sound2sound_sampler,
                       sound2sound_sample_steps,
                       sound2sound_origin_source,
                       sound2sound_noising_strength, sound2sound_seed, sound2sound_dict, virtual_instruments_dict):
    """Run text-guided img2img diffusion on the previously encoded origin
    sound and return gradio updates for the generated batch (sample 0 shown).

    Raises:
        NotImplementedError: if *sound2sound_origin_source* is not
            "upload" or "microphone".
    """
    # input processing
    sound2sound_seed = safe_int(sound2sound_seed, 12345678)
    sound2sound_batchsize = int(sound2sound_batchsize)
    noising_strength = sound2sound_noising_strength
    sound2sound_sample_steps = int(sound2sound_sample_steps)
    CFG = int(sound2sound_guidance_scale)

    # Replicate the cached single-sample latent across the batch dimension.
    if sound2sound_origin_source == "upload":
        origin_latent_representations = torch.tensor(
            sound2sound_dict["origin_upload_latent_representations"]).repeat(sound2sound_batchsize, 1, 1, 1).to(
            device)
    elif sound2sound_origin_source == "microphone":
        origin_latent_representations = torch.tensor(
            sound2sound_dict["origin_microphone_latent_representations"]).repeat(sound2sound_batchsize, 1, 1, 1).to(
            device)
    else:
        print("Input source not in ['upload', 'microphone']!")
        raise NotImplementedError()

    # sound2sound: CLAP text embedding conditions the diffusion model.
    text2sound_embedding = \
        CLAP.get_text_features(**CLAP_tokenizer([sound2sound_prompts], padding=True, return_tensors="pt"))[0].to(
            device)

    mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
    # Negative prompt embedding drives classifier-free guidance.
    unconditional_condition = \
        CLAP.get_text_features(**CLAP_tokenizer([sound2sound_negative_prompts], padding=True, return_tensors="pt"))[
            0]
    mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device))

    # Scale the schedule so that, after noising only part of the chain, the
    # number of executed denoising steps is ~sound2sound_sample_steps.
    normalized_sample_steps = int(sound2sound_sample_steps / noising_strength)
    mySampler.respace(list(np.linspace(0, timesteps - 1, normalized_sample_steps, dtype=np.int32)))

    condition = text2sound_embedding.repeat(sound2sound_batchsize, 1)

    # Todo: remove-hard coding
    # Latent width comes from the origin encoding (depends on duration).
    width = origin_latent_representations.shape[-1]
    new_sound_latent_representations, initial_noise = \
        mySampler.img_guided_sample(model=uNet, shape=(sound2sound_batchsize, channels, height, width),
                                    seed=sound2sound_seed,
                                    noising_strength=noising_strength,
                                    guide_img=origin_latent_representations, return_tensor=True,
                                    condition=condition,
                                    sampler=sound2sound_sampler)

    # Keep only the final denoising step of the returned trajectory.
    new_sound_latent_representations = new_sound_latent_representations[-1]

    # Quantize new sound latent representations
    quantized_new_sound_latent_representations, loss, (_, _, _) = VAE_quantizer(new_sound_latent_representations)

    new_sound_flipped_log_spectrums, new_sound_flipped_phases, new_sound_signals, _, _, _ = encodeBatch2GradioOutput_STFT(VAE_decoder,
                                                                                                                          quantized_new_sound_latent_representations,
                                                                                                                          resolution=(
                                                                                                                              512,
                                                                                                                              width * VAE_scale),
                                                                                                                          original_STFT_batch=None
                                                                                                                          )

    # Build per-sample display data and cache it for the index slider.
    new_sound_latent_representation_gradio_images = []
    new_sound_quantized_latent_representation_gradio_images = []
    new_sound_spectrogram_gradio_images = []
    new_sound_phase_gradio_images = []
    new_sound_rec_signals_gradio = []
    for i in range(sound2sound_batchsize):
        new_sound_latent_representation_gradio_images.append(
            latent_representation_to_Gradio_image(new_sound_latent_representations[i]))
        new_sound_quantized_latent_representation_gradio_images.append(
            latent_representation_to_Gradio_image(quantized_new_sound_latent_representations[i]))
        new_sound_spectrogram_gradio_images.append(new_sound_flipped_log_spectrums[i])
        new_sound_phase_gradio_images.append(new_sound_flipped_phases[i])
        new_sound_rec_signals_gradio.append((sample_rate, new_sound_signals[i]))

    sound2sound_dict[
        "new_sound_latent_representation_gradio_images"] = new_sound_latent_representation_gradio_images
    sound2sound_dict[
        "new_sound_quantized_latent_representation_gradio_images"] = new_sound_quantized_latent_representation_gradio_images
    sound2sound_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
    sound2sound_dict["new_sound_phase_gradio_images"] = new_sound_phase_gradio_images
    sound2sound_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio

    # Show sample 0; reveal the index slider sized to the batch.
    return {sound2sound_new_sound_latent_representation_image: latent_representation_to_Gradio_image(
        new_sound_latent_representations[0]),
        sound2sound_new_sound_quantized_latent_representation_image: latent_representation_to_Gradio_image(
            quantized_new_sound_latent_representations[0]),
        sound2sound_new_sound_spectrogram_image: new_sound_flipped_log_spectrums[0],
        sound2sound_new_sound_phase_image: new_sound_flipped_phases[0],
        sound2sound_new_sound_audio: (sample_rate, new_sound_signals[0]),
        sound2sound_sample_index_slider: gr.update(minimum=0, maximum=sound2sound_batchsize - 1, value=0,
                                                   step=1.0,
                                                   visible=True,
                                                   label="Sample index",
                                                   info="Swipe to view other samples"),
        sound2sound_seed_textbox: sound2sound_seed,
        sound2sound_with_text_state: sound2sound_dict,
        virtual_instruments_state: virtual_instruments_dict}
231
def show_sound2sound_sample(sound2sound_sample_index, sound2sound_with_text_dict):
    """Display the cached generation result at the selected batch index."""
    idx = int(sound2sound_sample_index)
    cached = sound2sound_with_text_dict
    return {
        sound2sound_new_sound_latent_representation_image:
            cached["new_sound_latent_representation_gradio_images"][idx],
        sound2sound_new_sound_quantized_latent_representation_image:
            cached["new_sound_quantized_latent_representation_gradio_images"][idx],
        sound2sound_new_sound_spectrogram_image:
            cached["new_sound_spectrogram_gradio_images"][idx],
        sound2sound_new_sound_phase_image:
            cached["new_sound_phase_gradio_images"][idx],
        sound2sound_new_sound_audio:
            cached["new_sound_rec_signals_gradio"][idx],
    }
243
def sound2sound_switch_origin_source(sound2sound_origin_source):
    """Toggle visibility of upload vs. microphone input widgets.

    Raises:
        NotImplementedError: for an unknown source (fix: previously the
        unknown-source branch only printed and implicitly returned None,
        which gradio would then report as an opaque output-count error;
        now consistent with sound2sound_sample).
    """
    if sound2sound_origin_source == "upload":
        show_upload = True
    elif sound2sound_origin_source == "microphone":
        show_upload = False
    else:
        print("Input source not in ['upload', 'microphone']!")
        raise NotImplementedError()

    # Upload widgets are visible exactly when microphone widgets are not.
    return {sound2sound_origin_upload_audio: gr.update(visible=show_upload),
            sound2sound_origin_microphone_audio: gr.update(visible=not show_upload),
            sound2sound_origin_spectrogram_upload_image: gr.update(visible=show_upload),
            sound2sound_origin_phase_upload_image: gr.update(visible=show_upload),
            sound2sound_origin_spectrogram_microphone_image: gr.update(visible=not show_upload),
            sound2sound_origin_phase_microphone_image: gr.update(visible=not show_upload),
            sound2sound_origin_upload_latent_representation_image: gr.update(visible=show_upload),
            sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=show_upload),
            sound2sound_origin_microphone_latent_representation_image: gr.update(visible=not show_upload),
            sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=not show_upload)}
+ with gr.Tab("Sound2Sound"):
271
+ gr.Markdown("Generate new sound based on a given sound!")
272
+ with gr.Row(variant="panel"):
273
+ with gr.Column(scale=3):
274
+ sound2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
275
+ text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")
276
+
277
+ with gr.Column(scale=1):
278
+ sound2sound_sample_button = gr.Button(variant="primary", value="Generate", scale=1)
279
+
280
+ sound2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
281
+ label="Sample index",
282
+ info="Swipe to view other samples")
283
+
284
+ with gr.Row(variant="panel"):
285
+ with gr.Column(scale=1):
286
+ with gr.Tab("Origin sound"):
287
+ sound2sound_duration_slider = gradioWebUI.get_duration_slider()
288
+ sound2sound_origin_source_radio = gr.Radio(choices=["upload", "microphone"], value="upload",
289
+ label="Input source")
290
+
291
+ sound2sound_origin_upload_audio = gr.Audio(type="numpy", label="Upload", source="upload",
292
+ interactive=True, visible=True)
293
+ sound2sound_origin_microphone_audio = gr.Audio(type="numpy", label="Record", source="microphone",
294
+ interactive=True, visible=False)
295
+ with gr.Row(variant="panel"):
296
+ sound2sound_origin_spectrogram_upload_image = gr.Image(label="Original upload spectrogram",
297
+ type="numpy", height=600,
298
+ visible=True)
299
+ sound2sound_origin_phase_upload_image = gr.Image(label="Original upload phase",
300
+ type="numpy", height=600,
301
+ visible=True)
302
+ sound2sound_origin_spectrogram_microphone_image = gr.Image(label="Original microphone spectrogram",
303
+ type="numpy", height=600,
304
+ visible=False)
305
+ sound2sound_origin_phase_microphone_image = gr.Image(label="Original microphone phase",
306
+ type="numpy", height=600,
307
+ visible=False)
308
+
309
+ with gr.Tab("Sound2sound settings"):
310
+ sound2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
311
+ sound2sound_sampler_radio = gradioWebUI.get_sampler_radio()
312
+ sound2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
313
+ sound2sound_noising_strength_slider = gradioWebUI.get_noising_strength_slider()
314
+ sound2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
315
+ sound2sound_seed_textbox = gradioWebUI.get_seed_textbox()
316
+
317
+ with gr.Column(scale=1):
318
+ sound2sound_new_sound_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False)
319
+ with gr.Row(variant="panel"):
320
+ sound2sound_new_sound_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy",
321
+ height=600, scale=1)
322
+ sound2sound_new_sound_phase_image = gr.Image(label="New sound phase", type="numpy",
323
+ height=600, scale=1)
324
+
325
+ with gr.Row(variant="panel"):
326
+ sound2sound_origin_upload_latent_representation_image = gr.Image(label="Original latent representation",
327
+ type="numpy", height=800,
328
+ visible=True)
329
+ sound2sound_origin_upload_quantized_latent_representation_image = gr.Image(
330
+ label="Original quantized latent representation", type="numpy", height=800, visible=True)
331
+
332
+ sound2sound_origin_microphone_latent_representation_image = gr.Image(label="Original latent representation",
333
+ type="numpy", height=800,
334
+ visible=False)
335
+ sound2sound_origin_microphone_quantized_latent_representation_image = gr.Image(
336
+ label="Original quantized latent representation", type="numpy", height=800, visible=False)
337
+
338
+ sound2sound_new_sound_latent_representation_image = gr.Image(label="New latent representation",
339
+ type="numpy", height=800)
340
+ sound2sound_new_sound_quantized_latent_representation_image = gr.Image(
341
+ label="New sound quantized latent representation", type="numpy", height=800)
342
+
343
+ sound2sound_origin_upload_audio.change(receive_upload_origin_audio,
344
+ inputs=[sound2sound_duration_slider, sound2sound_origin_source_radio,
345
+ sound2sound_origin_upload_audio,
346
+ sound2sound_origin_microphone_audio, sound2sound_with_text_state,
347
+ virtual_instruments_state],
348
+ outputs=[sound2sound_origin_spectrogram_upload_image,
349
+ sound2sound_origin_phase_upload_image,
350
+ sound2sound_origin_spectrogram_microphone_image,
351
+ sound2sound_origin_phase_microphone_image,
352
+ sound2sound_origin_upload_latent_representation_image,
353
+ sound2sound_origin_upload_quantized_latent_representation_image,
354
+ sound2sound_origin_microphone_latent_representation_image,
355
+ sound2sound_origin_microphone_quantized_latent_representation_image,
356
+ sound2sound_with_text_state,
357
+ virtual_instruments_state])
358
+
359
+ sound2sound_origin_microphone_audio.change(receive_upload_origin_audio,
360
+ inputs=[sound2sound_duration_slider,
361
+ sound2sound_origin_source_radio, sound2sound_origin_upload_audio,
362
+ sound2sound_origin_microphone_audio, sound2sound_with_text_state,
363
+ virtual_instruments_state],
364
+ outputs=[sound2sound_origin_spectrogram_upload_image,
365
+ sound2sound_origin_phase_upload_image,
366
+ sound2sound_origin_spectrogram_microphone_image,
367
+ sound2sound_origin_phase_microphone_image,
368
+ sound2sound_origin_upload_latent_representation_image,
369
+ sound2sound_origin_upload_quantized_latent_representation_image,
370
+ sound2sound_origin_microphone_latent_representation_image,
371
+ sound2sound_origin_microphone_quantized_latent_representation_image,
372
+ sound2sound_with_text_state,
373
+ virtual_instruments_state])
374
+
375
+ sound2sound_sample_button.click(sound2sound_sample,
376
+ inputs=[sound2sound_prompts_textbox,
377
+ text2sound_negative_prompts_textbox,
378
+ sound2sound_batchsize_slider,
379
+ sound2sound_guidance_scale_slider,
380
+ sound2sound_sampler_radio,
381
+ sound2sound_sample_steps_slider,
382
+ sound2sound_origin_source_radio,
383
+ sound2sound_noising_strength_slider,
384
+ sound2sound_seed_textbox,
385
+ sound2sound_with_text_state,
386
+ virtual_instruments_state],
387
+ outputs=[sound2sound_new_sound_latent_representation_image,
388
+ sound2sound_new_sound_quantized_latent_representation_image,
389
+ sound2sound_new_sound_spectrogram_image,
390
+ sound2sound_new_sound_phase_image,
391
+ sound2sound_new_sound_audio,
392
+ sound2sound_sample_index_slider,
393
+ sound2sound_seed_textbox,
394
+ sound2sound_with_text_state,
395
+ virtual_instruments_state])
396
+
397
+ sound2sound_sample_index_slider.change(show_sound2sound_sample,
398
+ inputs=[sound2sound_sample_index_slider, sound2sound_with_text_state],
399
+ outputs=[sound2sound_new_sound_latent_representation_image,
400
+ sound2sound_new_sound_quantized_latent_representation_image,
401
+ sound2sound_new_sound_spectrogram_image,
402
+ sound2sound_new_sound_phase_image,
403
+ sound2sound_new_sound_audio])
404
+
405
+ sound2sound_origin_source_radio.change(sound2sound_switch_origin_source,
406
+ inputs=[sound2sound_origin_source_radio],
407
+ outputs=[sound2sound_origin_upload_audio,
408
+ sound2sound_origin_microphone_audio,
409
+ sound2sound_origin_spectrogram_upload_image,
410
+ sound2sound_origin_phase_upload_image,
411
+ sound2sound_origin_spectrogram_microphone_image,
412
+ sound2sound_origin_phase_microphone_image,
413
+ sound2sound_origin_upload_latent_representation_image,
414
+ sound2sound_origin_upload_quantized_latent_representation_image,
415
+ sound2sound_origin_microphone_latent_representation_image,
416
+ sound2sound_origin_microphone_quantized_latent_representation_image])
webUI/natural_language_guided/super_resolution_with_text.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import numpy as np
3
+ import torch
4
+ import gradio as gr
5
+ from scipy.ndimage import zoom
6
+
7
+ from model.DiffSynthSampler import DiffSynthSampler
8
+ from tools import adjust_audio_length, rescale, safe_int, pad_STFT, encode_stft
9
+ from webUI.natural_language_guided_STFT.utils import latent_representation_to_Gradio_image
10
+ from webUI.natural_language_guided_STFT.utils import InputBatch2Encode_STFT, encodeBatch2GradioOutput_STFT
11
+
12
+
13
def get_super_resolution_with_text_module(gradioWebUI, inpaintWithText_state):
    """Build the "Super Resolution" Gradio tab.

    The tab lets the user upload (or record) a sound, encodes it into the VAE
    latent space, and then runs a text-guided latent *inpainting* diffusion
    pass over a canvas twice the training height: the lower half of the canvas
    holds the original sound's latents and the upper (high-frequency) half is
    synthesised, guided by the positive/negative prompts.

    Args:
        gradioWebUI: shared container exposing the models and settings
            (uNet, VAE encoder/quantizer/decoder, CLAP, device, ...).
        inpaintWithText_state: gr.State dict shared by this tab's callbacks.
    """
    # Load configurations
    uNet = gradioWebUI.uNet
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels
    timesteps = gradioWebUI.timesteps
    VAE_encoder = gradioWebUI.VAE_encoder
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

    def receive_upload_origin_audio(sound2sound_duration, sound2sound_origin_source, sound2sound_origin_upload,
                                    sound2sound_origin_microphone, inpaintWithText_dict):
        """Encode the uploaded/recorded sound and display its spectrogram, phase and latents.

        Stores the encoder latents in the shared state dict under a key that
        depends on the input source, so `sound2sound_sample` can pick them up.
        """
        if sound2sound_origin_source == "upload":
            origin_sr, origin_audio = sound2sound_origin_upload
        else:
            origin_sr, origin_audio = sound2sound_origin_microphone

        # Peak-normalize; guard against an all-zero (silent) recording, which
        # previously produced a 0/0 NaN array.
        peak = np.max(np.abs(origin_audio))
        if peak > 0:
            origin_audio = origin_audio / peak

        # Latent width follows the duration slider; audio is then resampled /
        # trimmed to exactly match the STFT frame grid (hop 256).
        width = int(time_resolution * ((sound2sound_duration + 1) / 4) / VAE_scale)
        audio_length = 256 * (VAE_scale * width - 1)
        origin_audio = adjust_audio_length(origin_audio, audio_length, origin_sr, sample_rate)

        D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024)
        padded_D = pad_STFT(D)
        encoded_D = encode_stft(padded_D)

        # Todo: justify batchsize to 1
        origin_spectrogram_batch_tensor = torch.from_numpy(
            np.repeat(encoded_D[np.newaxis, :, :, :], 1, axis=0)).float().to(device)

        # Todo: remove hard-coding (512 = spectrogram display height)
        origin_flipped_log_spectrums, origin_flipped_phases, origin_signals, origin_latent_representations, \
            quantized_origin_latent_representations = InputBatch2Encode_STFT(
                VAE_encoder, origin_spectrogram_batch_tensor, resolution=(512, width * VAE_scale),
                quantizer=VAE_quantizer, squared=squared)

        latent_image = latent_representation_to_Gradio_image(origin_latent_representations[0])
        quantized_latent_image = latent_representation_to_Gradio_image(quantized_origin_latent_representations[0])

        if sound2sound_origin_source == "upload":
            inpaintWithText_dict["origin_upload_latent_representations"] = origin_latent_representations.tolist()
            inpaintWithText_dict[
                "sound2sound_origin_upload_latent_representation_image"] = latent_image.tolist()
            inpaintWithText_dict[
                "sound2sound_origin_upload_quantized_latent_representation_image"] = quantized_latent_image.tolist()
            return {sound2sound_origin_spectrogram_upload_image: origin_flipped_log_spectrums[0],
                    sound2sound_origin_phase_upload_image: origin_flipped_phases[0],
                    sound2sound_origin_spectrogram_microphone_image: gr.update(),
                    sound2sound_origin_phase_microphone_image: gr.update(),
                    sound2sound_origin_upload_latent_representation_image: latent_image,
                    sound2sound_origin_upload_quantized_latent_representation_image: quantized_latent_image,
                    sound2sound_origin_microphone_latent_representation_image: gr.update(),
                    sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(),
                    inpaintWithText_state: inpaintWithText_dict}
        else:
            inpaintWithText_dict["origin_microphone_latent_representations"] = origin_latent_representations.tolist()
            inpaintWithText_dict[
                "sound2sound_origin_microphone_latent_representation_image"] = latent_image.tolist()
            inpaintWithText_dict[
                "sound2sound_origin_microphone_quantized_latent_representation_image"] = quantized_latent_image.tolist()
            # BUGFIX: this branch used to return the microphone results into
            # the *upload* components (hidden in microphone mode) and left the
            # microphone components untouched; route them correctly instead.
            return {sound2sound_origin_spectrogram_upload_image: gr.update(),
                    sound2sound_origin_phase_upload_image: gr.update(),
                    sound2sound_origin_spectrogram_microphone_image: origin_flipped_log_spectrums[0],
                    sound2sound_origin_phase_microphone_image: origin_flipped_phases[0],
                    sound2sound_origin_upload_latent_representation_image: gr.update(),
                    sound2sound_origin_upload_quantized_latent_representation_image: gr.update(),
                    sound2sound_origin_microphone_latent_representation_image: latent_image,
                    sound2sound_origin_microphone_quantized_latent_representation_image: quantized_latent_image,
                    inpaintWithText_state: inpaintWithText_dict}

    def sound2sound_sample(sound2sound_origin_spectrogram_upload, sound2sound_origin_spectrogram_microphone,
                           text2sound_prompts, text2sound_negative_prompts, sound2sound_batchsize,
                           sound2sound_guidance_scale, sound2sound_sampler,
                           sound2sound_sample_steps, sound2sound_origin_source,
                           sound2sound_noising_strength, sound2sound_seed, sound2sound_inpaint_area,
                           inpaintWithText_dict):
        """Synthesise the upper-frequency half of the latent canvas via inpainting.

        Note: the two leading spectrogram-image parameters are wired as Gradio
        inputs but are not read here — the latents are taken from the state
        dict populated by `receive_upload_origin_audio`.
        """
        # Input preprocessing
        sound2sound_seed = safe_int(sound2sound_seed, 12345678)
        sound2sound_batchsize = int(sound2sound_batchsize)
        noising_strength = sound2sound_noising_strength
        sound2sound_sample_steps = int(sound2sound_sample_steps)
        CFG = int(sound2sound_guidance_scale)

        text2sound_embedding = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(
                device)

        if sound2sound_origin_source == "upload":
            origin_latent_representations = torch.tensor(
                inpaintWithText_dict["origin_upload_latent_representations"]).repeat(
                sound2sound_batchsize, 1, 1, 1).to(device)
        elif sound2sound_origin_source == "microphone":
            origin_latent_representations = torch.tensor(
                inpaintWithText_dict["origin_microphone_latent_representations"]).repeat(
                sound2sound_batchsize, 1, 1, 1).to(device)
        else:
            print("Input source not in ['upload', 'microphone']!")
            raise NotImplementedError()

        # Double-height canvas: the lower half holds the original latents, the
        # upper half is generated.
        # NOTE(review): 256/128/192/64 look tied to the model's latent
        # geometry (height*2, height, width) — confirm before changing.
        high_resolution_latent_representations = torch.zeros((sound2sound_batchsize, channels, 256, 64)).to(device)
        high_resolution_latent_representations[:, :, :128, :] = origin_latent_representations
        latent_mask = np.ones((256, 64))
        latent_mask[192:, :] = 0.0
        print(f"latent_mask mean: {np.mean(latent_mask)}")

        if sound2sound_inpaint_area == "inpaint masked":
            latent_mask = 1 - latent_mask
        latent_mask = torch.from_numpy(latent_mask).unsqueeze(0).unsqueeze(1).repeat(
            sound2sound_batchsize, channels, 1, 1).float().to(device)
        # Flip along the frequency axis to match the flipped display/latent layout.
        latent_mask = torch.flip(latent_mask, [2])

        mySampler = DiffSynthSampler(timesteps, height=height * 2, channels=channels, noise_strategy=noise_strategy)
        unconditional_condition = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[0]
        mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device))

        # Stretch the step schedule so that `sound2sound_sample_steps` steps
        # remain after skipping the (1 - noising_strength) prefix.
        normalized_sample_steps = int(sound2sound_sample_steps / noising_strength)
        mySampler.respace(list(np.linspace(0, timesteps - 1, normalized_sample_steps, dtype=np.int32)))

        # Todo: remove hard-coding
        width = high_resolution_latent_representations.shape[-1]
        condition = text2sound_embedding.repeat(sound2sound_batchsize, 1)

        new_sound_latent_representations, initial_noise = \
            mySampler.inpaint_sample(model=uNet, shape=(sound2sound_batchsize, channels, height * 2, width),
                                     seed=sound2sound_seed,
                                     noising_strength=noising_strength,
                                     guide_img=high_resolution_latent_representations, mask=latent_mask,
                                     return_tensor=True, condition=condition, sampler=sound2sound_sampler)

        # Keep only the final denoising step.
        new_sound_latent_representations = new_sound_latent_representations[-1]

        # Quantize new sound latent representations, then decode to
        # spectrogram / phase / waveform for display.
        quantized_new_sound_latent_representations, loss, (_, _, _) = VAE_quantizer(new_sound_latent_representations)
        new_sound_flipped_log_spectrums, new_sound_flipped_phases, new_sound_signals, _, _, _ = \
            encodeBatch2GradioOutput_STFT(VAE_decoder, quantized_new_sound_latent_representations,
                                          resolution=(1024, width * VAE_scale),
                                          original_STFT_batch=None)

        new_sound_latent_representation_gradio_images = []
        new_sound_quantized_latent_representation_gradio_images = []
        new_sound_spectrogram_gradio_images = []
        new_sound_phase_gradio_images = []
        new_sound_rec_signals_gradio = []
        for i in range(sound2sound_batchsize):
            new_sound_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(new_sound_latent_representations[i]))
            new_sound_quantized_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(quantized_new_sound_latent_representations[i]))
            new_sound_spectrogram_gradio_images.append(new_sound_flipped_log_spectrums[i])
            new_sound_phase_gradio_images.append(new_sound_flipped_phases[i])
            new_sound_rec_signals_gradio.append((sample_rate, new_sound_signals[i]))

        inpaintWithText_dict[
            "new_sound_latent_representation_gradio_images"] = new_sound_latent_representation_gradio_images
        inpaintWithText_dict[
            "new_sound_quantized_latent_representation_gradio_images"] = new_sound_quantized_latent_representation_gradio_images
        inpaintWithText_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
        inpaintWithText_dict["new_sound_phase_gradio_images"] = new_sound_phase_gradio_images
        inpaintWithText_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio

        return {sound2sound_new_sound_latent_representation_image: latent_representation_to_Gradio_image(
                    new_sound_latent_representations[0]),
                sound2sound_new_sound_quantized_latent_representation_image: latent_representation_to_Gradio_image(
                    quantized_new_sound_latent_representations[0]),
                sound2sound_new_sound_spectrogram_image: new_sound_flipped_log_spectrums[0],
                sound2sound_new_sound_phase_image: new_sound_flipped_phases[0],
                sound2sound_new_sound_audio: (sample_rate, new_sound_signals[0]),
                sound2sound_sample_index_slider: gr.update(minimum=0, maximum=sound2sound_batchsize - 1, value=0,
                                                           step=1.0,
                                                           visible=True,
                                                           label="Sample index",
                                                           info="Swipe to view other samples"),
                sound2sound_seed_textbox: sound2sound_seed,
                inpaintWithText_state: inpaintWithText_dict}

    def show_sound2sound_sample(sound2sound_sample_index, inpaintWithText_dict):
        """Display the batch sample selected by the index slider."""
        sample_index = int(sound2sound_sample_index)
        return {sound2sound_new_sound_latent_representation_image:
                    inpaintWithText_dict["new_sound_latent_representation_gradio_images"][sample_index],
                sound2sound_new_sound_quantized_latent_representation_image:
                    inpaintWithText_dict["new_sound_quantized_latent_representation_gradio_images"][sample_index],
                sound2sound_new_sound_spectrogram_image: inpaintWithText_dict["new_sound_spectrogram_gradio_images"][
                    sample_index],
                sound2sound_new_sound_phase_image: inpaintWithText_dict["new_sound_phase_gradio_images"][
                    sample_index],
                sound2sound_new_sound_audio: inpaintWithText_dict["new_sound_rec_signals_gradio"][sample_index]}

    def sound2sound_switch_origin_source(sound2sound_origin_source):
        """Toggle visibility of the upload vs. microphone component groups."""
        if sound2sound_origin_source == "upload":
            return {sound2sound_origin_upload_audio: gr.update(visible=True),
                    sound2sound_origin_microphone_audio: gr.update(visible=False),
                    sound2sound_origin_spectrogram_upload_image: gr.update(visible=True),
                    sound2sound_origin_phase_upload_image: gr.update(visible=True),
                    sound2sound_origin_spectrogram_microphone_image: gr.update(visible=False),
                    sound2sound_origin_phase_microphone_image: gr.update(visible=False),
                    sound2sound_origin_upload_latent_representation_image: gr.update(visible=True),
                    sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=True),
                    sound2sound_origin_microphone_latent_representation_image: gr.update(visible=False),
                    sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=False)}
        elif sound2sound_origin_source == "microphone":
            return {sound2sound_origin_upload_audio: gr.update(visible=False),
                    sound2sound_origin_microphone_audio: gr.update(visible=True),
                    sound2sound_origin_spectrogram_upload_image: gr.update(visible=False),
                    sound2sound_origin_phase_upload_image: gr.update(visible=False),
                    sound2sound_origin_spectrogram_microphone_image: gr.update(visible=True),
                    sound2sound_origin_phase_microphone_image: gr.update(visible=True),
                    sound2sound_origin_upload_latent_representation_image: gr.update(visible=False),
                    sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=False),
                    sound2sound_origin_microphone_latent_representation_image: gr.update(visible=True),
                    sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=True)}
        else:
            # Unreachable with the current two-choice radio; log and fall through.
            print("Input source not in ['upload', 'microphone']!")

    with gr.Tab("Super Resolution"):
        gr.Markdown("Select the area to inpaint and use the prompt to guide the synthesis of a new sound!")
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):
                text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
                text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")

            with gr.Column(scale=1):
                sound2sound_sample_button = gr.Button(variant="primary", value="Generate", scale=1)
                sound2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
                                                            label="Sample index",
                                                            info="Swipe to view other samples")

        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                with gr.Tab("Origin sound"):
                    sound2sound_duration_slider = gradioWebUI.get_duration_slider()
                    sound2sound_origin_source_radio = gr.Radio(choices=["upload", "microphone"], value="upload",
                                                               label="Input source")

                    sound2sound_origin_upload_audio = gr.Audio(type="numpy", label="Upload", source="upload",
                                                               interactive=True, visible=True)
                    sound2sound_origin_microphone_audio = gr.Audio(type="numpy", label="Record", source="microphone",
                                                                   interactive=True, visible=False)
                    with gr.Row(variant="panel"):
                        sound2sound_origin_spectrogram_upload_image = gr.Image(label="Original upload spectrogram",
                                                                               type="numpy", height=600,
                                                                               visible=True, tool="sketch")
                        sound2sound_origin_phase_upload_image = gr.Image(label="Original upload phase",
                                                                         type="numpy", height=600,
                                                                         visible=True)
                        sound2sound_origin_spectrogram_microphone_image = gr.Image(
                            label="Original microphone spectrogram",
                            type="numpy", height=600,
                            visible=False, tool="sketch")
                        sound2sound_origin_phase_microphone_image = gr.Image(label="Original microphone phase",
                                                                             type="numpy", height=600,
                                                                             visible=False)
                    sound2sound_inpaint_area_radio = gr.Radio(choices=["inpaint masked", "inpaint not masked"],
                                                              value="inpaint masked")

                with gr.Tab("Sound2sound settings"):
                    sound2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
                    sound2sound_sampler_radio = gradioWebUI.get_sampler_radio()
                    sound2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
                    sound2sound_noising_strength_slider = gradioWebUI.get_noising_strength_slider(
                        default_noising_strength=1.0)
                    sound2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
                    sound2sound_seed_textbox = gradioWebUI.get_seed_textbox()

            with gr.Column(scale=1):
                sound2sound_new_sound_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False)
                with gr.Row(variant="panel"):
                    sound2sound_new_sound_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy",
                                                                       height=1200, scale=1)
                    sound2sound_new_sound_phase_image = gr.Image(label="New sound phase", type="numpy",
                                                                 height=1200, scale=1)

        with gr.Row(variant="panel"):
            sound2sound_origin_upload_latent_representation_image = gr.Image(label="Original latent representation",
                                                                             type="numpy", height=1200,
                                                                             visible=True)
            sound2sound_origin_upload_quantized_latent_representation_image = gr.Image(
                label="Original quantized latent representation", type="numpy", height=1200, visible=True)

            sound2sound_origin_microphone_latent_representation_image = gr.Image(
                label="Original latent representation",
                type="numpy", height=1200,
                visible=False)
            sound2sound_origin_microphone_quantized_latent_representation_image = gr.Image(
                label="Original quantized latent representation", type="numpy", height=1200, visible=False)

            sound2sound_new_sound_latent_representation_image = gr.Image(label="New latent representation",
                                                                         type="numpy", height=1200)
            sound2sound_new_sound_quantized_latent_representation_image = gr.Image(
                label="New sound quantized latent representation", type="numpy", height=1200)

    # Event wiring: both audio sources feed the same encoder callback; the
    # sample button runs the diffusion inpainting; the slider pages through
    # the generated batch; the radio toggles upload/microphone widgets.
    sound2sound_origin_upload_audio.change(receive_upload_origin_audio,
                                           inputs=[sound2sound_duration_slider, sound2sound_origin_source_radio,
                                                   sound2sound_origin_upload_audio,
                                                   sound2sound_origin_microphone_audio, inpaintWithText_state],
                                           outputs=[sound2sound_origin_spectrogram_upload_image,
                                                    sound2sound_origin_phase_upload_image,
                                                    sound2sound_origin_spectrogram_microphone_image,
                                                    sound2sound_origin_phase_microphone_image,
                                                    sound2sound_origin_upload_latent_representation_image,
                                                    sound2sound_origin_upload_quantized_latent_representation_image,
                                                    sound2sound_origin_microphone_latent_representation_image,
                                                    sound2sound_origin_microphone_quantized_latent_representation_image,
                                                    inpaintWithText_state])
    sound2sound_origin_microphone_audio.change(receive_upload_origin_audio,
                                               inputs=[sound2sound_duration_slider, sound2sound_origin_source_radio,
                                                       sound2sound_origin_upload_audio,
                                                       sound2sound_origin_microphone_audio, inpaintWithText_state],
                                               outputs=[sound2sound_origin_spectrogram_upload_image,
                                                        sound2sound_origin_phase_upload_image,
                                                        sound2sound_origin_spectrogram_microphone_image,
                                                        sound2sound_origin_phase_microphone_image,
                                                        sound2sound_origin_upload_latent_representation_image,
                                                        sound2sound_origin_upload_quantized_latent_representation_image,
                                                        sound2sound_origin_microphone_latent_representation_image,
                                                        sound2sound_origin_microphone_quantized_latent_representation_image,
                                                        inpaintWithText_state])

    sound2sound_sample_button.click(sound2sound_sample,
                                    inputs=[sound2sound_origin_spectrogram_upload_image,
                                            sound2sound_origin_spectrogram_microphone_image,
                                            text2sound_prompts_textbox,
                                            text2sound_negative_prompts_textbox,
                                            sound2sound_batchsize_slider,
                                            sound2sound_guidance_scale_slider,
                                            sound2sound_sampler_radio,
                                            sound2sound_sample_steps_slider,
                                            sound2sound_origin_source_radio,
                                            sound2sound_noising_strength_slider,
                                            sound2sound_seed_textbox,
                                            sound2sound_inpaint_area_radio,
                                            inpaintWithText_state],
                                    outputs=[sound2sound_new_sound_latent_representation_image,
                                             sound2sound_new_sound_quantized_latent_representation_image,
                                             sound2sound_new_sound_spectrogram_image,
                                             sound2sound_new_sound_phase_image,
                                             sound2sound_new_sound_audio,
                                             sound2sound_sample_index_slider,
                                             sound2sound_seed_textbox,
                                             inpaintWithText_state])

    sound2sound_sample_index_slider.change(show_sound2sound_sample,
                                           inputs=[sound2sound_sample_index_slider, inpaintWithText_state],
                                           outputs=[sound2sound_new_sound_latent_representation_image,
                                                    sound2sound_new_sound_quantized_latent_representation_image,
                                                    sound2sound_new_sound_spectrogram_image,
                                                    sound2sound_new_sound_phase_image,
                                                    sound2sound_new_sound_audio])

    sound2sound_origin_source_radio.change(sound2sound_switch_origin_source,
                                           inputs=[sound2sound_origin_source_radio],
                                           outputs=[sound2sound_origin_upload_audio,
                                                    sound2sound_origin_microphone_audio,
                                                    sound2sound_origin_spectrogram_upload_image,
                                                    sound2sound_origin_phase_upload_image,
                                                    sound2sound_origin_spectrogram_microphone_image,
                                                    sound2sound_origin_phase_microphone_image,
                                                    sound2sound_origin_upload_latent_representation_image,
                                                    sound2sound_origin_upload_quantized_latent_representation_image,
                                                    sound2sound_origin_microphone_latent_representation_image,
                                                    sound2sound_origin_microphone_quantized_latent_representation_image])
webUI/natural_language_guided/text2sound.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from matplotlib import pyplot as plt
5
+
6
+ from model.DiffSynthSampler import DiffSynthSampler
7
+ from tools import safe_int
8
+ from webUI.natural_language_guided.utils import latent_representation_to_Gradio_image, \
9
+ encodeBatch2GradioOutput_STFT, add_instrument
10
+
11
+
12
+ def get_text2sound_module(gradioWebUI, text2sound_state, virtual_instruments_state):
13
+ # Load configurations
14
+ uNet = gradioWebUI.uNet
15
+ freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
16
+ VAE_scale = gradioWebUI.VAE_scale
17
+ height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels
18
+
19
+ timesteps = gradioWebUI.timesteps
20
+ VAE_quantizer = gradioWebUI.VAE_quantizer
21
+ VAE_decoder = gradioWebUI.VAE_decoder
22
+ CLAP = gradioWebUI.CLAP
23
+ CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
24
+ device = gradioWebUI.device
25
+ squared = gradioWebUI.squared
26
+ sample_rate = gradioWebUI.sample_rate
27
+ noise_strategy = gradioWebUI.noise_strategy
28
+
29
+ def diffusion_random_sample(text2sound_prompts, text2sound_negative_prompts, text2sound_batchsize,
30
+ text2sound_duration,
31
+ text2sound_guidance_scale, text2sound_sampler,
32
+ text2sound_sample_steps, text2sound_seed,
33
+ text2sound_dict):
34
+ text2sound_sample_steps = int(text2sound_sample_steps)
35
+ text2sound_seed = safe_int(text2sound_seed, 12345678)
36
+
37
+ width = int(time_resolution * ((text2sound_duration + 1) / 4) / VAE_scale)
38
+
39
+ text2sound_batchsize = int(text2sound_batchsize)
40
+
41
+ text2sound_embedding = \
42
+ CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(
43
+ device)
44
+
45
+ CFG = int(text2sound_guidance_scale)
46
+
47
+ mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
48
+ negative_condition = \
49
+ CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[
50
+ 0]
51
+ mySampler.activate_classifier_free_guidance(CFG, negative_condition.to(device))
52
+
53
+ mySampler.respace(list(np.linspace(0, timesteps - 1, text2sound_sample_steps, dtype=np.int32)))
54
+
55
+ condition = text2sound_embedding.repeat(text2sound_batchsize, 1)
56
+
57
+ latent_representations, initial_noise = \
58
+ mySampler.sample(model=uNet, shape=(text2sound_batchsize, channels, height, width), seed=text2sound_seed,
59
+ return_tensor=True, condition=condition, sampler=text2sound_sampler)
60
+
61
+ latent_representations = latent_representations[-1]
62
+ print(latent_representations[0, 0, :3, :3])
63
+
64
+ latent_representation_gradio_images = []
65
+ quantized_latent_representation_gradio_images = []
66
+ new_sound_spectrogram_gradio_images = []
67
+ new_sound_phase_gradio_images = []
68
+ new_sound_rec_signals_gradio = []
69
+
70
+ quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations)
71
+ # Todo: remove hard-coding
72
+ flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(VAE_decoder,
73
+ quantized_latent_representations,
74
+ resolution=(
75
+ 512,
76
+ width * VAE_scale),
77
+ original_STFT_batch=None
78
+ )
79
+
80
+ for i in range(text2sound_batchsize):
81
+ latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i]))
82
+ quantized_latent_representation_gradio_images.append(
83
+ latent_representation_to_Gradio_image(quantized_latent_representations[i]))
84
+ new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i])
85
+ new_sound_phase_gradio_images.append(flipped_phases[i])
86
+ new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i]))
87
+
88
+ text2sound_dict["latent_representations"] = latent_representations.to("cpu").detach().numpy()
89
+ text2sound_dict["quantized_latent_representations"] = quantized_latent_representations.to("cpu").detach().numpy()
90
+ text2sound_dict["latent_representation_gradio_images"] = latent_representation_gradio_images
91
+ text2sound_dict["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images
92
+ text2sound_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
93
+ text2sound_dict["new_sound_phase_gradio_images"] = new_sound_phase_gradio_images
94
+ text2sound_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio
95
+
96
+ text2sound_dict["condition"] = condition.to("cpu").detach().numpy()
97
+ text2sound_dict["negative_condition"] = negative_condition.to("cpu").detach().numpy()
98
+ text2sound_dict["guidance_scale"] = CFG
99
+ text2sound_dict["sampler"] = text2sound_sampler
100
+
101
+ return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][0],
102
+ text2sound_quantized_latent_representation_image:
103
+ text2sound_dict["quantized_latent_representation_gradio_images"][0],
104
+ text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][0],
105
+ text2sound_sampled_phase_image: text2sound_dict["new_sound_phase_gradio_images"][0],
106
+ text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][0],
107
+ text2sound_seed_textbox: text2sound_seed,
108
+ text2sound_state: text2sound_dict,
109
+ text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1,
110
+ visible=True,
111
+ label="Sample index.",
112
+ info="Swipe to view other samples")}
113
+
114
def show_random_sample(sample_index, text2sound_dict):
    """Display the batch sample at `sample_index` in all output components.

    Records the chosen index in the state dict, then maps every Gradio
    output component to the cached image/audio of that sample.
    """
    idx = int(sample_index)
    text2sound_dict["sample_index"] = idx
    outputs = {
        text2sound_latent_representation_image:
            text2sound_dict["latent_representation_gradio_images"][idx],
        text2sound_quantized_latent_representation_image:
            text2sound_dict["quantized_latent_representation_gradio_images"][idx],
        text2sound_sampled_spectrogram_image:
            text2sound_dict["new_sound_spectrogram_gradio_images"][idx],
        text2sound_sampled_phase_image:
            text2sound_dict["new_sound_phase_gradio_images"][idx],
        text2sound_sampled_audio:
            text2sound_dict["new_sound_rec_signals_gradio"][idx],
    }
    return outputs
126
+
127
+
128
+
129
def save_virtual_instrument(sample_index, virtual_instrument_name, text2sound_dict, virtual_instruments_dict):
    """Persist the selected sample as a named virtual instrument.

    Delegates the bookkeeping to `add_instrument` and confirms the save by
    updating the name textbox's placeholder.
    """
    updated_instruments = add_instrument(
        text2sound_dict, virtual_instruments_dict, virtual_instrument_name, sample_index)

    confirmation_textbox = gr.Textbox(label="Instrument name", lines=1,
                                      placeholder=f"Saved as {virtual_instrument_name}!")
    return {virtual_instruments_state: updated_instruments,
            text2sound_instrument_name_textbox: confirmation_textbox}
136
+
137
+ with gr.Tab("Text2sound"):
138
+ gr.Markdown("Use neural networks to select random sounds using your favorite instrument!")
139
+ with gr.Row(variant="panel"):
140
+ with gr.Column(scale=3):
141
+ text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
142
+ text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")
143
+
144
+ with gr.Column(scale=1):
145
+ text2sound_sampling_button = gr.Button(variant="primary",
146
+ value="Generate a batch of samples and show "
147
+ "the first one",
148
+ scale=1)
149
+ text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
150
+ label="Sample index",
151
+ info="Swipe to view other samples")
152
+ with gr.Row(variant="panel"):
153
+ with gr.Column(variant="panel"):
154
+ text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
155
+ text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
156
+ text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
157
+ text2sound_duration_slider = gradioWebUI.get_duration_slider()
158
+ text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
159
+ text2sound_seed_textbox = gradioWebUI.get_seed_textbox()
160
+
161
+ with gr.Column(variant="panel"):
162
+ with gr.Row(variant="panel"):
163
+ text2sound_sampled_spectrogram_image = gr.Image(label="Sampled spectrogram", type="numpy", height=600)
164
+ text2sound_sampled_phase_image = gr.Image(label="Sampled phase", type="numpy", height=600)
165
+ text2sound_sampled_audio = gr.Audio(type="numpy", label="Play")
166
+
167
+ with gr.Row(variant="panel"):
168
+ text2sound_instrument_name_textbox = gr.Textbox(label="Instrument name", lines=1,
169
+ placeholder="Name of your instrument")
170
+ text2sound_save_instrument_button = gr.Button(variant="primary",
171
+ value="Save instrument",
172
+ scale=1)
173
+
174
+ with gr.Row(variant="panel"):
175
+ text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy",
176
+ height=200, width=100)
177
+ text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
178
+ type="numpy", height=200, width=100)
179
+
180
+ text2sound_sampling_button.click(diffusion_random_sample,
181
+ inputs=[text2sound_prompts_textbox,
182
+ text2sound_negative_prompts_textbox,
183
+ text2sound_batchsize_slider,
184
+ text2sound_duration_slider,
185
+ text2sound_guidance_scale_slider, text2sound_sampler_radio,
186
+ text2sound_sample_steps_slider,
187
+ text2sound_seed_textbox,
188
+ text2sound_state],
189
+ outputs=[text2sound_latent_representation_image,
190
+ text2sound_quantized_latent_representation_image,
191
+ text2sound_sampled_spectrogram_image,
192
+ text2sound_sampled_phase_image,
193
+ text2sound_sampled_audio,
194
+ text2sound_seed_textbox,
195
+ text2sound_state,
196
+ text2sound_sample_index_slider])
197
+
198
+ text2sound_save_instrument_button.click(save_virtual_instrument,
199
+ inputs=[text2sound_sample_index_slider,
200
+ text2sound_instrument_name_textbox,
201
+ text2sound_state,
202
+ virtual_instruments_state],
203
+ outputs=[virtual_instruments_state,
204
+ text2sound_instrument_name_textbox])
205
+
206
+ text2sound_sample_index_slider.change(show_random_sample,
207
+ inputs=[text2sound_sample_index_slider, text2sound_state],
208
+ outputs=[text2sound_latent_representation_image,
209
+ text2sound_quantized_latent_representation_image,
210
+ text2sound_sampled_spectrogram_image,
211
+ text2sound_sampled_phase_image,
212
+ text2sound_sampled_audio])
webUI/natural_language_guided/track_maker.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ from model.DiffSynthSampler import DiffSynthSampler
5
+ from webUI.natural_language_guided.utils import encodeBatch2GradioOutput_STFT
6
+ import mido
7
+ import pyrubberband as pyrb
8
+ from tqdm import tqdm
9
+
10
class NoteEvent:
    """A single MIDI note: pitch, velocity and timing (both times in ticks)."""

    def __init__(self, note, velocity, start_time, duration):
        # Pitch and loudness as carried by the MIDI message.
        self.note, self.velocity = note, velocity
        # Timing stays in MIDI ticks; conversion to seconds happens downstream.
        self.start_time, self.duration = start_time, duration

    def __str__(self):
        return f"Note {self.note}, velocity {self.velocity}, start_time {self.start_time}, duration {self.duration}"
19
+
20
+
21
class Track:
    """One MIDI track parsed into a tempo map and a list of NoteEvents.

    Times are kept in MIDI ticks and converted to seconds on demand via
    the tempo map (microseconds per beat) and `ticks_per_beat`.
    """

    DEFAULT_TEMPO = 500000  # 120 BPM expressed as microseconds per beat

    def __init__(self, track, ticks_per_beat):
        self.tempo_events = self._parse_tempo_events(track)
        self.events = self._parse_note_events(track)
        self.ticks_per_beat = ticks_per_beat

    def _parse_tempo_events(self, track):
        """Collect (delta_ticks, tempo) pairs for the track.

        Fix vs. the original: `current_tempo` is now updated when a
        `set_tempo` message arrives, so later messages are tagged with the
        tempo actually in effect instead of always the default 120 BPM.
        """
        tempo_events = []
        current_tempo = self.DEFAULT_TEMPO
        for msg in track:
            if msg.type == 'set_tempo':
                current_tempo = msg.tempo
                tempo_events.append((msg.time, current_tempo))
            elif not msg.is_meta:
                tempo_events.append((msg.time, current_tempo))
        return tempo_events

    def _parse_note_events(self, track):
        """Pair note-ons with their matching note-offs into NoteEvents.

        Fixes vs. the original: the stored velocity is the note-on's (the
        original recorded the note-off velocity, which is always 0); a
        per-pitch map pairs overlapping notes correctly instead of a single
        shared variable; explicit `note_off` messages are honored; and an
        unmatched note-off no longer hits an unbound variable.
        """
        events = []
        abs_time = 0
        pending = {}  # pitch -> (onset_tick, note-on velocity)
        for msg in track:
            if msg.is_meta:
                # NOTE(review): meta delta times were excluded from the
                # running clock in the original; behavior preserved.
                continue
            abs_time += msg.time
            if msg.type == 'note_on' and msg.velocity > 0:
                pending[msg.note] = (abs_time, msg.velocity)
            elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
                if msg.note in pending:
                    onset, velocity = pending.pop(msg.note)
                    events.append(NoteEvent(msg.note, velocity, onset, abs_time - onset))
        return events

    def synthesize_track(self, diffSynthSampler, sample_rate=16000):
        """Render the track into a mono float32 buffer.

        `diffSynthSampler(velocity, duration_sec)` must return an audio clip
        at the reference pitch; clips are cached per duration and
        pitch-shifted per note.
        """
        track_audio = np.zeros(int(self._get_total_time() * sample_rate), dtype=np.float32)
        duration_note_mapping = {}

        # NOTE(review): the original capped rendering at the first 50 events,
        # presumably to bound diffusion-sampling time; kept as-is.
        for event in tqdm(self.events[:50]):
            tempo = self._get_tempo_at(event.start_time)
            seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, tempo)
            start_time_sec = event.start_time * seconds_per_tick
            # Enforce a minimum duration so very short notes still synthesize.
            duration_sec = max(event.duration * seconds_per_tick, 0.75)
            start_sample = int(start_time_sec * sample_rate)

            cache_key = str(duration_sec)
            if cache_key not in duration_note_mapping:
                note_sample = diffSynthSampler(event.velocity, duration_sec)
                duration_note_mapping[cache_key] = note_sample / np.max(np.abs(note_sample))

            # Reference pitch is MIDI note 52; shift the cached clip to the event's pitch.
            note_audio = pyrb.pitch_shift(duration_note_mapping[cache_key], sample_rate, event.note - 52)
            # Clamp to the buffer so a note running past the end cannot raise
            # a broadcast error (the original sliced past the buffer length).
            end_sample = min(start_sample + len(note_audio), len(track_audio))
            track_audio[start_sample:end_sample] += note_audio[:end_sample - start_sample]

        return track_audio

    def _get_tempo_at(self, time_tick):
        """Return the tempo (µs per beat) in effect at absolute tick `time_tick`."""
        current_tempo = self.DEFAULT_TEMPO
        elapsed_ticks = 0
        for delta, tempo in self.tempo_events:
            if elapsed_ticks + delta > time_tick:
                return current_tempo
            elapsed_ticks += delta
            current_tempo = tempo
        return current_tempo

    def _get_total_time(self):
        """Length of the track in seconds.

        Computed as the latest note end (onset + duration). The original
        summed all durations, which under-allocates when notes are separated
        by rests and over-allocates when they overlap.
        """
        total_time = 0.0
        for event in self.events:
            seconds_per_tick = mido.tick2second(
                1, self.ticks_per_beat, self._get_tempo_at(event.start_time))
            total_time = max(total_time, (event.start_time + event.duration) * seconds_per_tick)
        return total_time
95
+
96
+
97
class DiffSynth:
    """MIDI-to-audio renderer: each track is played by a diffusion-sampled
    virtual instrument defined by a stored latent representation."""

    def __init__(self, instruments_configs, noise_prediction_model, VAE_quantizer, VAE_decoder, text_encoder, CLAP_tokenizer, device,
                 model_sample_rate=16000, timesteps=1000, channels=4, freq_resolution=512, time_resolution=256, VAE_scale=4, squared=False):

        self.noise_prediction_model = noise_prediction_model
        self.VAE_quantizer = VAE_quantizer
        self.VAE_decoder = VAE_decoder
        self.device = device
        self.model_sample_rate = model_sample_rate
        self.timesteps = timesteps
        self.channels = channels
        self.freq_resolution = freq_resolution
        self.time_resolution = time_resolution
        # Latent height: spectrogram frequency axis shrunk by the VAE downscale factor.
        self.height = int(freq_resolution/VAE_scale)
        self.VAE_scale = VAE_scale
        self.squared = squared
        self.text_encoder = text_encoder
        self.CLAP_tokenizer = CLAP_tokenizer

        # instruments_configs is a dict: string -> (condition, negative_condition, guidance_scale, sample_steps, seed, initial_noise, sampler)
        # NOTE(review): the keys actually read below are sample_steps, sampler,
        # noising_strength, latent_representation, attack, before_release.
        self.instruments_configs = instruments_configs
        self.diffSynthSamplers = {}
        self._update_instruments()


    def _update_instruments(self):
        """Build one closure-based note sampler per configured instrument."""

        def diffSynthSamplerWrapper(instruments_config):

            def diffSynthSampler(velocity, duration_sec, sample_rate=16000):
                # NOTE(review): `velocity` is unused and the text condition is
                # the empty prompt — timbre comes entirely from the stored
                # latent representation via inpainting.
                condition = self.text_encoder.get_text_features(**self.CLAP_tokenizer([""], padding=True, return_tensors="pt")).to(self.device)
                sample_steps = instruments_config['sample_steps']
                sampler = instruments_config['sampler']
                noising_strength = instruments_config['noising_strength']
                latent_representation = instruments_config['latent_representation']
                attack = instruments_config['attack']
                before_release = instruments_config['before_release']

                assert sample_rate == self.model_sample_rate, "sample_rate != model_sample_rate"

                # Latent width for the requested duration; the (duration+1)/4
                # scaling presumably reflects a 4 s training window with a 1 s
                # release tail — TODO confirm against training code.
                width = int(self.time_resolution * ((duration_sec + 1) / 4) / self.VAE_scale)

                mySampler = DiffSynthSampler(self.timesteps, height=128, channels=4, noise_strategy="repeat", mute=True)
                # Respace the full schedule down to `sample_steps` steps.
                mySampler.respace(list(np.linspace(0, self.timesteps - 1, sample_steps, dtype=np.int32)))

                # mask = 1, freeze
                # Freeze the attack region and the pre-release tail; the middle
                # is re-sampled so the note sustains for the requested duration.
                latent_mask = torch.zeros((1, 1, self.height, width), dtype=torch.float32).to(self.device)
                latent_mask[:, :, :, :int(self.time_resolution * (attack / 4) / self.VAE_scale)] = 1.0
                latent_mask[:, :, :, -int(self.time_resolution * ((before_release+1) / 4) / self.VAE_scale):] = 1.0

                latent_representations, _ = \
                    mySampler.inpaint_sample(model=self.noise_prediction_model, shape=(1, self.channels, self.height, width),
                                             noising_strength=noising_strength, condition=condition,
                                             guide_img=latent_representation, mask=latent_mask, return_tensor=True,
                                             sampler=sampler,
                                             use_dynamic_mask=True, end_noise_level_ratio=0.0,
                                             mask_flexivity=1.0)

                # Keep only the final denoising step's latents.
                latent_representations = latent_representations[-1]

                quantized_latent_representations, _, (_, _, _) = self.VAE_quantizer(latent_representations)
                # Todo: remove hard-coding

                flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(self.VAE_decoder,
                                                                                                            quantized_latent_representations,
                                                                                                            resolution=(
                                                                                                                512,
                                                                                                                width * self.VAE_scale),
                                                                                                            original_STFT_batch=None,
                                                                                                            )

                # Batch size is 1; return the single reconstructed signal.
                return rec_signals[0]

            return diffSynthSampler

        for key in self.instruments_configs.keys():
            self.diffSynthSamplers[key] = diffSynthSamplerWrapper(self.instruments_configs[key])

    def get_music(self, mid, instrument_names, sample_rate=16000):
        """Render a MIDI file, playing track i with instrument_names[i], and
        return the mixed mono signal."""
        tracks = [Track(t, mid.ticks_per_beat) for t in mid.tracks]
        assert len(tracks) == len(instrument_names), f"len(tracks) = {len(tracks)} != {len(instrument_names)} = len(instrument_names)"

        track_audios = [track.synthesize_track(self.diffSynthSamplers[instrument_names[i]], sample_rate=sample_rate) for i, track in enumerate(tracks)]

        # Pad every track to the longest one's length so they can be summed.
        max_length = max(len(audio) for audio in track_audios)
        full_audio = np.zeros(max_length, dtype=np.float32)  # initialize the mix to silence
        for audio in track_audios:
            # A track may be shorter than the longest; zero-pad its tail.
            padded_audio = np.pad(audio, (0, max_length - len(audio)), 'constant')
            full_audio += padded_audio  # mix this track in

        return full_audio
webUI/natural_language_guided/utils.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import numpy as np
3
+ import torch
4
+
5
+ from tools import np_power_to_db, decode_stft, depad_STFT
6
+
7
+
8
+ def spectrogram_to_Gradio_image(spc):
9
+ ### input: spc [np.ndarray]
10
+ frequency_resolution, time_resolution = spc.shape[-2], spc.shape[-1]
11
+ spc = np.reshape(spc, (frequency_resolution, time_resolution))
12
+
13
+ # Todo:
14
+ magnitude_spectrum = np.abs(spc)
15
+ log_spectrum = np_power_to_db(magnitude_spectrum)
16
+ flipped_log_spectrum = np.flipud(log_spectrum)
17
+
18
+ colorful_spc = np.ones((frequency_resolution, time_resolution, 3)) * -80.0
19
+ colorful_spc[:, :, 0] = flipped_log_spectrum
20
+ colorful_spc[:, :, 1] = flipped_log_spectrum
21
+ colorful_spc[:, :, 2] = np.ones((frequency_resolution, time_resolution)) * -60.0
22
+ # Rescale to 0-255 and convert to uint8
23
+ rescaled = (colorful_spc + 80.0) / 80.0
24
+ rescaled = (255.0 * rescaled).astype(np.uint8)
25
+ return rescaled
26
+
27
+
28
+ def phase_to_Gradio_image(phase):
29
+ ### input: spc [np.ndarray]
30
+ frequency_resolution, time_resolution = phase.shape[-2], phase.shape[-1]
31
+ phase = np.reshape(phase, (frequency_resolution, time_resolution))
32
+
33
+ # Todo:
34
+ flipped_phase = np.flipud(phase)
35
+ flipped_phase = (flipped_phase + 1.0) / 2.0
36
+
37
+ colorful_spc = np.zeros((frequency_resolution, time_resolution, 3))
38
+ colorful_spc[:, :, 0] = flipped_phase
39
+ colorful_spc[:, :, 1] = flipped_phase
40
+ colorful_spc[:, :, 2] = 0.2
41
+ # Rescale to 0-255 and convert to uint8
42
+ rescaled = (255.0 * colorful_spc).astype(np.uint8)
43
+ return rescaled
44
+
45
+
46
+ def latent_representation_to_Gradio_image(latent_representation):
47
+ # input: latent_representation [torch.tensor]
48
+ if not isinstance(latent_representation, np.ndarray):
49
+ latent_representation = latent_representation.to("cpu").detach().numpy()
50
+ image = latent_representation
51
+
52
+ def normalize_image(img):
53
+ min_val = img.min()
54
+ max_val = img.max()
55
+ normalized_img = ((img - min_val) / (max_val - min_val) * 255)
56
+ return normalized_img
57
+
58
+ image[0, :, :] = normalize_image(image[0, :, :])
59
+ image[1, :, :] = normalize_image(image[1, :, :])
60
+ image[2, :, :] = normalize_image(image[2, :, :])
61
+ image[3, :, :] = normalize_image(image[3, :, :])
62
+ image_transposed = np.transpose(image, (1, 2, 0))
63
+ enlarged_image = np.repeat(image_transposed, 8, axis=0)
64
+ enlarged_image = np.repeat(enlarged_image, 8, axis=1)
65
+ return np.flipud(enlarged_image).astype(np.uint8)
66
+
67
+
68
+ def InputBatch2Encode_STFT(encoder, STFT_batch, resolution=(512, 256), quantizer=None, squared=True):
69
+ """Transform batch of numpy spectrogram's into signals and encodings."""
70
+ # Todo: remove resolution hard-coding
71
+ frequency_resolution, time_resolution = resolution
72
+
73
+ device = next(encoder.parameters()).device
74
+ if not (quantizer is None):
75
+ latent_representation_batch = encoder(STFT_batch.to(device))
76
+ quantized_latent_representation_batch, loss, (_, _, _) = quantizer(latent_representation_batch)
77
+ else:
78
+ mu, logvar, latent_representation_batch = encoder(STFT_batch.to(device))
79
+ quantized_latent_representation_batch = None
80
+
81
+ STFT_batch = STFT_batch.to("cpu").detach().numpy()
82
+
83
+ origin_flipped_log_spectrums, origin_flipped_phases, origin_signals = [], [], []
84
+ for STFT in STFT_batch:
85
+
86
+ padded_D_rec = decode_stft(STFT)
87
+ D_rec = depad_STFT(padded_D_rec)
88
+ spc = np.abs(D_rec)
89
+ phase = np.angle(D_rec)
90
+
91
+ flipped_log_spectrum = spectrogram_to_Gradio_image(spc)
92
+ flipped_phase = phase_to_Gradio_image(phase)
93
+
94
+ # get_audio
95
+ rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024)
96
+
97
+ origin_flipped_log_spectrums.append(flipped_log_spectrum)
98
+ origin_flipped_phases.append(flipped_phase)
99
+ origin_signals.append(rec_signal)
100
+
101
+ return origin_flipped_log_spectrums, origin_flipped_phases, origin_signals, \
102
+ latent_representation_batch, quantized_latent_representation_batch
103
+
104
+
105
+ def encodeBatch2GradioOutput_STFT(decoder, latent_vector_batch, resolution=(512, 256), original_STFT_batch=None):
106
+ """Show a spectrogram."""
107
+ # Todo: remove resolution hard-coding
108
+ frequency_resolution, time_resolution = resolution
109
+
110
+ if isinstance(latent_vector_batch, np.ndarray):
111
+ latent_vector_batch = torch.from_numpy(latent_vector_batch).to(next(decoder.parameters()).device)
112
+
113
+ reconstruction_batch = decoder(latent_vector_batch).to("cpu").detach().numpy()
114
+
115
+ flipped_log_spectrums, flipped_phases, rec_signals = [], [], []
116
+ flipped_log_spectrums_with_original_amp, flipped_phases_with_original_amp, rec_signals_with_original_amp = [], [], []
117
+
118
+ for index, STFT in enumerate(reconstruction_batch):
119
+ padded_D_rec = decode_stft(STFT)
120
+ D_rec = depad_STFT(padded_D_rec)
121
+ spc = np.abs(D_rec)
122
+ phase = np.angle(D_rec)
123
+
124
+ flipped_log_spectrum = spectrogram_to_Gradio_image(spc)
125
+ flipped_phase = phase_to_Gradio_image(phase)
126
+
127
+ # get_audio
128
+ rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024)
129
+
130
+ flipped_log_spectrums.append(flipped_log_spectrum)
131
+ flipped_phases.append(flipped_phase)
132
+ rec_signals.append(rec_signal)
133
+
134
+ ##########################################
135
+
136
+ if original_STFT_batch is not None:
137
+ STFT[0, :, :] = original_STFT_batch[index, 0, :, :]
138
+
139
+ padded_D_rec = decode_stft(STFT)
140
+ D_rec = depad_STFT(padded_D_rec)
141
+ spc = np.abs(D_rec)
142
+ phase = np.angle(D_rec)
143
+
144
+ flipped_log_spectrum = spectrogram_to_Gradio_image(spc)
145
+ flipped_phase = phase_to_Gradio_image(phase)
146
+
147
+ # get_audio
148
+ rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024)
149
+
150
+ flipped_log_spectrums_with_original_amp.append(flipped_log_spectrum)
151
+ flipped_phases_with_original_amp.append(flipped_phase)
152
+ rec_signals_with_original_amp.append(rec_signal)
153
+
154
+
155
+ return flipped_log_spectrums, flipped_phases, rec_signals, \
156
+ flipped_log_spectrums_with_original_amp, flipped_phases_with_original_amp, rec_signals_with_original_amp
157
+
158
+
159
+
160
+ def add_instrument(source_dict, virtual_instruments_dict, virtual_instrument_name, sample_index):
161
+
162
+ virtual_instruments = virtual_instruments_dict["virtual_instruments"]
163
+ virtual_instrument = {
164
+ "latent_representation": source_dict["latent_representations"][sample_index],
165
+ "quantized_latent_representation": source_dict["quantized_latent_representations"][sample_index],
166
+ "sampler": source_dict["sampler"],
167
+ "signal": source_dict["new_sound_rec_signals_gradio"][sample_index],
168
+ "spectrogram_gradio_image": source_dict["new_sound_spectrogram_gradio_images"][
169
+ sample_index],
170
+ "phase_gradio_image": source_dict["new_sound_phase_gradio_images"][
171
+ sample_index]}
172
+ virtual_instruments[virtual_instrument_name] = virtual_instrument
173
+ virtual_instruments_dict["virtual_instruments"] = virtual_instruments
174
+ return virtual_instruments_dict