Upload 49 files
- model/DiffSynthSampler.py +425 -0
- model/GAN.py +262 -0
- model/VQGAN.py +684 -0
- model/__pycache__/DiffSynthSampler.cpython-310.pyc +0 -0
- model/__pycache__/GAN.cpython-310.pyc +0 -0
- model/__pycache__/VQGAN.cpython-310.pyc +0 -0
- model/__pycache__/diffusion.cpython-310.pyc +0 -0
- model/__pycache__/diffusion_components.cpython-310.pyc +0 -0
- model/__pycache__/multimodal_model.cpython-310.pyc +0 -0
- model/__pycache__/perceptual_label_predictor.cpython-37.pyc +0 -0
- model/__pycache__/timbre_encoder_pretrain.cpython-310.pyc +0 -0
- model/diffusion.py +371 -0
- model/diffusion_components.py +351 -0
- model/multimodal_model.py +274 -0
- model/timbre_encoder_pretrain.py +220 -0
- tools.py +344 -0
- webUI/__pycache__/app.cpython-310.pyc +0 -0
- webUI/deprecated/interpolationWithCondition.py +178 -0
- webUI/deprecated/interpolationWithXT.py +173 -0
- webUI/natural_language_guided/GAN.py +164 -0
- webUI/natural_language_guided/README.py +53 -0
- webUI/natural_language_guided/__pycache__/README.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/README_STFT.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/buildInstrument_STFT.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/build_instrument.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/gradioWebUI.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/gradioWebUI_STFT.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/gradio_webUI.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/inpaintWithText.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/inpaintWithText_STFT.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/inpaint_with_text.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/rec.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/recSTFT.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/sound2soundWithText.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/sound2soundWithText_STFT.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/sound2sound_with_text.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/text2sound.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/text2sound_STFT.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/track_maker.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/utils.cpython-310.pyc +0 -0
- webUI/natural_language_guided/build_instrument.py +274 -0
- webUI/natural_language_guided/gradio_webUI.py +68 -0
- webUI/natural_language_guided/inpaint_with_text.py +441 -0
- webUI/natural_language_guided/rec.py +190 -0
- webUI/natural_language_guided/sound2sound_with_text.py +416 -0
- webUI/natural_language_guided/super_resolution_with_text.py +387 -0
- webUI/natural_language_guided/text2sound.py +212 -0
- webUI/natural_language_guided/track_maker.py +192 -0
- webUI/natural_language_guided/utils.py +174 -0
model/DiffSynthSampler.py
ADDED
@@ -0,0 +1,425 @@
import numpy as np
import torch
from tqdm import tqdm


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)
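
# Example of the broadcast behavior (shapes are illustrative): for a schedule
# array `arr` of shape (T,), timesteps `t` of shape (B,), and broadcast_shape
# (B, C, H, W), the helper returns arr[t] reshaped to (B, 1, 1, 1) and expanded
# to (B, C, H, W), so one scalar schedule value scales every element of a sample.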


class DiffSynthSampler:

    def __init__(self, timesteps, beta_start=0.0001, beta_end=0.02, device=None, mute=False,
                 height=128, max_batchsize=16, max_width=256, channels=4, train_width=64, noise_strategy="repeat"):
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.height = height
        self.train_width = train_width
        self.max_batchsize = max_batchsize
        self.max_width = max_width
        self.channels = channels
        self.num_timesteps = timesteps
        self.timestep_map = list(range(self.num_timesteps))
        self.betas = np.array(np.linspace(beta_start, beta_end, self.num_timesteps), dtype=np.float64)
        self.respaced = False
        self.define_beta_schedule()
        self.CFG = 1.0
        self.mute = mute
        self.noise_strategy = noise_strategy

    def get_deterministic_noise_tensor_non_repeat(self, batchsize, width, reference_noise=None):
        if reference_noise is None:
            large_noise_tensor = torch.randn((self.max_batchsize, self.channels, self.height, self.max_width), device=self.device)
        else:
            assert reference_noise.shape == (batchsize, self.channels, self.height, self.max_width), "reference_noise shape mismatch"
            large_noise_tensor = reference_noise
        return large_noise_tensor[:batchsize, :, :, :width], None

    def get_deterministic_noise_tensor(self, batchsize, width, reference_noise=None):
        if self.noise_strategy == "repeat":
            noise, concat_points = self.get_deterministic_noise_tensor_repeat(batchsize, width, reference_noise=reference_noise)
            return noise, concat_points
        else:
            noise, concat_points = self.get_deterministic_noise_tensor_non_repeat(batchsize, width, reference_noise=reference_noise)
            return noise, concat_points

    def get_deterministic_noise_tensor_repeat(self, batchsize, width, reference_noise=None):
        # Generate noise whose length matches the training data
        if reference_noise is None:
            train_noise_tensor = torch.randn((self.max_batchsize, self.channels, self.height, self.train_width), device=self.device)
        else:
            assert reference_noise.shape == (batchsize, self.channels, self.height, self.train_width), "reference_noise shape mismatch"
            train_noise_tensor = reference_noise

        release_width = int(self.train_width * 1.0 / 4)
        first_part_width = self.train_width - release_width

        first_part = train_noise_tensor[:batchsize, :, :, :first_part_width]
        release_part = train_noise_tensor[:batchsize, :, :, -release_width:]

        # If the requested length is at most the training length, drop the middle of first_part
        if width <= self.train_width:
            _first_part_head_width = int((width - release_width) / 2)
            _first_part_tail_width = width - release_width - _first_part_head_width
            all_parts = [first_part[:, :, :, :_first_part_head_width], first_part[:, :, :, -_first_part_tail_width:], release_part]

            # Concatenate the tensors along the fourth dimension
            noise_tensor = torch.cat(all_parts, dim=3)

            # Record the positions of the concatenation points
            concat_points = [0]
            for part in all_parts[:-1]:
                next_point = concat_points[-1] + part.size(3)
                concat_points.append(next_point)

            return noise_tensor, concat_points

        # If the requested length exceeds the training length, keep inserting the middle of first_part
        else:
            # Compute how many times first_part must be repeated
            repeats = (width - release_width) // first_part_width
            extra = (width - release_width) % first_part_width

            _repeat_first_part_head_width = int(first_part_width / 2)
            _repeat_first_part_tail_width = first_part_width - _repeat_first_part_head_width

            repeated_first_head_parts = [first_part[:, :, :, :_repeat_first_part_head_width] for _ in range(repeats)]
            repeated_first_tail_parts = [first_part[:, :, :, -_repeat_first_part_tail_width:] for _ in range(repeats)]

            # Compute the start index
            _middle_part_start_index = (first_part_width - extra) // 2
            # Slice the tensor to obtain the middle part
            middle_part = first_part[:, :, :, _middle_part_start_index: _middle_part_start_index + extra]

            all_parts = repeated_first_head_parts + [middle_part] + repeated_first_tail_parts + [release_part]

            # Concatenate the tensors along the fourth dimension
            noise_tensor = torch.cat(all_parts, dim=3)

            # Record the positions of the concatenation points
            concat_points = [0]
            for part in all_parts[:-1]:
                next_point = concat_points[-1] + part.size(3)
                concat_points.append(next_point)

            return noise_tensor, concat_points

    def define_beta_schedule(self):
        assert self.respaced == False, "This schedule has already been respaced!"
        # define alphas
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recip_alphas = np.sqrt(1.0 / self.alphas)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod))

    def activate_classifier_free_guidance(self, CFG, unconditional_condition):
        assert (
            unconditional_condition is not None) or CFG == 1.0, "For CFG != 1.0, unconditional_condition must be available"
        self.CFG = CFG
        self.unconditional_condition = unconditional_condition

    def respace(self, use_timesteps=None):
        if use_timesteps is not None:
            last_alpha_cumprod = 1.0
            new_betas = []
            self.timestep_map = []
            for i, _alpha_cumprod in enumerate(self.alphas_cumprod):
                if i in use_timesteps:
                    new_betas.append(1 - _alpha_cumprod / last_alpha_cumprod)
                    last_alpha_cumprod = _alpha_cumprod
                    self.timestep_map.append(i)
            self.num_timesteps = len(use_timesteps)
            self.betas = np.array(new_betas)
            self.define_beta_schedule()
            self.respaced = True

    def generate_linear_noise(self, shape, variance=1.0, first_endpoint=None, second_endpoint=None):
        assert shape[1] == self.channels, "shape[1] != self.channels"
        assert shape[2] == self.height, "shape[2] != self.height"
        noise = torch.empty(*shape, device=self.device)

        # Case 3: both endpoints are given, so interpolate linearly between them
        if first_endpoint is not None and second_endpoint is not None:
            for i in range(shape[0]):
                alpha = i / (shape[0] - 1)  # interpolation coefficient
                noise[i] = alpha * second_endpoint + (1 - alpha) * first_endpoint
            return noise  # return the interpolated result; no further mean/variance adjustment is needed
        else:
            # Only the first endpoint is given
            if first_endpoint is not None:
                noise[0] = first_endpoint
                if shape[0] > 1:
                    # [0][0]: take the noise tensor from the (noise, concat_points) tuple, then its single batch item
                    noise[1] = self.get_deterministic_noise_tensor(1, shape[3])[0][0]
            else:
                noise[0] = self.get_deterministic_noise_tensor(1, shape[3])[0][0]
                if shape[0] > 1:
                    noise[1] = self.get_deterministic_noise_tensor(1, shape[3])[0][0]

            # Extrapolate the remaining noise points linearly
            for i in range(2, shape[0]):
                noise[i] = 2 * noise[i - 1] - noise[i - 2]

            # When at most one endpoint is specified, rescale to the target variance
            current_var = noise.var()
            stddev_ratio = torch.sqrt(variance / current_var)
            noise = noise * stddev_ratio

            # If the first endpoint was specified, shift so that it is matched exactly
            if first_endpoint is not None:
                shift = first_endpoint - noise[0]
                noise += shift

            return noise

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.

        In other words, sample from q(x_t | x_0).

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start.
        """
        assert x_start.shape[1] == self.channels, "shape[1] != self.channels"
        assert x_start.shape[2] == self.height, "shape[2] != self.height"

        if noise is None:
            # noise = torch.randn_like(x_start)
            noise, _ = self.get_deterministic_noise_tensor(x_start.shape[0], x_start.shape[3])

        assert noise.shape == x_start.shape
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )
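
    # q_sample implements the closed form q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I),
    # i.e. x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, where alpha_bar_t
    # is self.alphas_cumprod[t]; this matches the two _extract_into_tensor terms above.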

    @torch.no_grad()
    def ddim_sample(self, model, x, t, condition=None, ddim_eta=0.0):
        map_tensor = torch.tensor(self.timestep_map, device=t.device, dtype=t.dtype)
        mapped_t = map_tensor[t]

        # Todo: add CFG

        if self.CFG == 1.0:
            pred_noise = model(x, mapped_t, condition)
        else:
            unconditional_condition = self.unconditional_condition.unsqueeze(0).repeat(
                *([x.shape[0]] + [1] * len(self.unconditional_condition.shape)))
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([mapped_t] * 2)
            c_in = torch.cat([unconditional_condition, condition])
            noise_uncond, noise = model(x_in, t_in, c_in).chunk(2)
            pred_noise = noise_uncond + self.CFG * (noise - noise_uncond)

        # Todo: END

        alpha_cumprod_t = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_cumprod_t_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)

        pred_x0 = (x - torch.sqrt((1. - alpha_cumprod_t)) * pred_noise) / torch.sqrt(alpha_cumprod_t)

        sigmas_t = (
            ddim_eta
            * torch.sqrt((1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t))
            * torch.sqrt(1 - alpha_cumprod_t / alpha_cumprod_t_prev)
        )

        pred_dir_xt = torch.sqrt(1 - alpha_cumprod_t_prev - sigmas_t ** 2) * pred_noise

        step_noise, _ = self.get_deterministic_noise_tensor(x.shape[0], x.shape[3])

        x_prev = torch.sqrt(alpha_cumprod_t_prev) * pred_x0 + pred_dir_xt + sigmas_t * step_noise

        return x_prev
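
    # The update above is the DDIM step of Song et al. (2020):
    #   x_{t-1} = sqrt(alpha_bar_{t-1}) * pred_x0
    #           + sqrt(1 - alpha_bar_{t-1} - sigma_t^2) * pred_noise
    #           + sigma_t * noise,
    # where ddim_eta = 0 gives deterministic DDIM and ddim_eta = 1 recovers DDPM-like sampling.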

    def p_sample(self, model, x, t, condition=None, sampler="ddim"):
        if sampler == "ddim":
            return self.ddim_sample(model, x, t, condition=condition, ddim_eta=0.0)
        elif sampler == "ddpm":
            return self.ddim_sample(model, x, t, condition=condition, ddim_eta=1.0)
        else:
            raise NotImplementedError()

    def get_dynamic_masks(self, n_masks, shape, concat_points, mask_flexivity=0.8):
        release_length = int(self.train_width / 4)
        assert shape[3] == (concat_points[-1] + release_length), "shape[3] != (concat_points[-1] + release_length)"

        fraction_lengths = [concat_points[i + 1] - concat_points[i] for i in range(len(concat_points) - 1)]

        # Todo: remove hard-coding
        n_guidance_steps = int(n_masks * mask_flexivity)
        n_free_steps = n_masks - n_guidance_steps

        masks = []
        # Todo: shrink the mask within the first half of the steps; that is, in the later
        # steps do not inpaint the regions outside the release but run img2img instead
        for i in range(n_guidance_steps):
            # mask = 1, freeze
            step_i_mask = torch.zeros((shape[0], 1, shape[2], shape[3]), dtype=torch.float32).to(self.device)
            step_i_mask[:, :, :, -release_length:] = 1.0

            for fraction_index in range(len(fraction_lengths)):

                _fraction_mask_length = int((n_guidance_steps - 1 - i) / (n_guidance_steps - 1) * fraction_lengths[fraction_index])

                if fraction_index == 0:
                    step_i_mask[:, :, :, :_fraction_mask_length] = 1.0
                elif fraction_index == len(fraction_lengths) - 1:
                    if not _fraction_mask_length == 0:
                        step_i_mask[:, :, :, -_fraction_mask_length - release_length:] = 1.0
                else:
                    fraction_mask_start_position = int((fraction_lengths[fraction_index] - _fraction_mask_length) / 2)

                    step_i_mask[:, :, :,
                                concat_points[fraction_index] + fraction_mask_start_position:concat_points[
                                    fraction_index] + fraction_mask_start_position + _fraction_mask_length] = 1.0
            masks.append(step_i_mask)

        for i in range(n_free_steps):
            step_i_mask = torch.zeros((shape[0], 1, shape[2], shape[3]), dtype=torch.float32).to(self.device)
            step_i_mask[:, :, :, -release_length:] = 1.0
            masks.append(step_i_mask)

        masks.reverse()
        return masks

    @torch.no_grad()
    def p_sample_loop(self, model, shape, initial_noise=None, start_noise_level_ratio=1.0, end_noise_level_ratio=0.0,
                      return_tensor=False, condition=None, guide_img=None,
                      mask=None, sampler="ddim", inpaint=False, use_dynamic_mask=False, mask_flexivity=0.8):

        assert shape[1] == self.channels, "shape[1] != self.channels"
        assert shape[2] == self.height, "shape[2] != self.height"

        initial_noise, _ = self.get_deterministic_noise_tensor(shape[0], shape[3], reference_noise=initial_noise)
        assert initial_noise.shape == shape, "initial_noise.shape != shape"

        start_noise_level_index = int(self.num_timesteps * start_noise_level_ratio)  # not included!!!
        end_noise_level_index = int(self.num_timesteps * end_noise_level_ratio)

        timesteps = reversed(range(end_noise_level_index, start_noise_level_index))

        # configure initial img
        assert (start_noise_level_ratio == 1.0) or (
            guide_img is not None), "A guide_img must be given to sample from a non-pure-noise."

        if guide_img is None:
            img = initial_noise
        else:
            guide_img, concat_points = self.get_deterministic_noise_tensor_repeat(shape[0], shape[3], reference_noise=guide_img)
            assert guide_img.shape == shape, "guide_img.shape != shape"

            if start_noise_level_index > 0:
                t = torch.full((shape[0],), start_noise_level_index - 1, device=self.device).long()  # -1 because start_noise_level_index is not included
                img = self.q_sample(guide_img, t, noise=initial_noise)
            else:
                print("Zero noise added to the guidance latent representation.")
                img = guide_img

        # get masks
        n_masks = start_noise_level_index - end_noise_level_index
        if use_dynamic_mask:
            masks = self.get_dynamic_masks(n_masks, shape, concat_points, mask_flexivity)
        else:
            masks = [mask for _ in range(n_masks)]

        imgs = [img]
        current_mask = None

        for i in tqdm(timesteps, total=start_noise_level_index - end_noise_level_index, disable=self.mute):

            # if i == 3:
            #     return [img], initial_noise  # row 1, column 1

            img = self.p_sample(model, img, torch.full((shape[0],), i, device=self.device, dtype=torch.long),
                                condition=condition,
                                sampler=sampler)
            # if i == 3:
            #     return [img], initial_noise  # row 1, column 2

            if inpaint:
                if i > 0:
                    t = torch.full((shape[0],), int(i - 1), device=self.device).long()
                    img_noise_t = self.q_sample(guide_img, t, noise=initial_noise)
                    # if i == 3:
                    #     return [img_noise_t], initial_noise  # row 2, column 2
                    current_mask = masks.pop()
                    img = current_mask * img_noise_t + (1 - current_mask) * img
                    # if i == 3:
                    #     return [img], initial_noise  # row 1.5, last column
                else:
                    img = current_mask * guide_img + (1 - current_mask) * img

            if return_tensor:
                imgs.append(img)
            else:
                imgs.append(img.cpu().numpy())

        return imgs, initial_noise

    def sample(self, model, shape, return_tensor=False, condition=None, sampler="ddim", initial_noise=None, seed=None):
        if seed is not None:
            torch.manual_seed(seed)
        return self.p_sample_loop(model, shape, initial_noise=initial_noise, start_noise_level_ratio=1.0, end_noise_level_ratio=0.0,
                                  return_tensor=return_tensor, condition=condition, sampler=sampler)

    def interpolate(self, model, shape, variance, first_endpoint=None, second_endpoint=None, return_tensor=False,
                    condition=None, sampler="ddim", seed=None):
        if seed is not None:
            torch.manual_seed(seed)
        linear_noise = self.generate_linear_noise(shape, variance, first_endpoint=first_endpoint,
                                                  second_endpoint=second_endpoint)
        return self.p_sample_loop(model, shape, initial_noise=linear_noise, start_noise_level_ratio=1.0,
                                  end_noise_level_ratio=0.0,
                                  return_tensor=return_tensor, condition=condition, sampler=sampler)

    def img_guided_sample(self, model, shape, noising_strength, guide_img, return_tensor=False, condition=None,
                          sampler="ddim", initial_noise=None, seed=None):
        if seed is not None:
            torch.manual_seed(seed)
        assert guide_img.shape[-1] == shape[-1], "guide_img.shape[-1] != shape[-1]"
        return self.p_sample_loop(model, shape, start_noise_level_ratio=noising_strength, end_noise_level_ratio=0.0,
                                  return_tensor=return_tensor, condition=condition, sampler=sampler,
                                  guide_img=guide_img, initial_noise=initial_noise)

    def inpaint_sample(self, model, shape, noising_strength, guide_img, mask, return_tensor=False, condition=None,
                       sampler="ddim", initial_noise=None, use_dynamic_mask=False, end_noise_level_ratio=0.0, seed=None,
                       mask_flexivity=0.8):
        if seed is not None:
            torch.manual_seed(seed)
        return self.p_sample_loop(model, shape, start_noise_level_ratio=noising_strength, end_noise_level_ratio=end_noise_level_ratio,
                                  return_tensor=return_tensor, condition=condition, guide_img=guide_img, mask=mask,
                                  sampler=sampler, inpaint=True, initial_noise=initial_noise, use_dynamic_mask=use_dynamic_mask,
                                  mask_flexivity=mask_flexivity)
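
A minimal usage sketch of the sampler above. The respacing choice and the latent shape are illustrative, and `unet`, `text_emb`, and `uncond_emb` are hypothetical stand-ins for a trained noise-prediction model with the `model(x, t, condition)` signature that `ddim_sample` expects and for its condition embeddings:

import torch
from model.DiffSynthSampler import DiffSynthSampler

sampler = DiffSynthSampler(timesteps=1000, height=128, channels=4, train_width=64)
sampler.respace(use_timesteps=list(range(0, 1000, 20)))  # 50 evenly spaced DDIM steps

# With a trained model and condition embeddings in hand, sampling would look like:
# sampler.activate_classifier_free_guidance(CFG=3.0, unconditional_condition=uncond_emb)
# imgs, initial_noise = sampler.sample(unet, shape=(1, 4, 128, 64),
#                                      condition=text_emb, sampler="ddim", seed=0)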
model/GAN.py
ADDED
@@ -0,0 +1,262 @@
import json
import numpy as np
import torch
from torch import nn
from six.moves import xrange
from torch.utils.tensorboard import SummaryWriter
import random

from model.diffusion import ConditionedUnet
from tools import create_key


class Discriminator(nn.Module):
    def __init__(self, label_emb_dim):
        super(Discriminator, self).__init__()
        # Convolutional layers for the feature map
        self.conv_layers = nn.Sequential(
            nn.Conv2d(4, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.AdaptiveAvgPool2d(1),  # adaptive pooling layer
            nn.Flatten()
        )

        # Text-embedding branch
        self.text_embedding = nn.Sequential(
            nn.Linear(label_emb_dim, 512),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Final fully connected layer of the discriminator
        self.fc = nn.Linear(512 + 512, 1)  # the two 512s come from the feature map and the text embedding

    def forward(self, x, text_emb):
        x = self.conv_layers(x)
        text_emb = self.text_embedding(text_emb)
        combined = torch.cat((x, text_emb), dim=1)
        output = self.fc(combined)
        return output


def evaluate_GAN(device, generator, discriminator, iterator, encodes2embeddings_mapping):
    generator.to(device)
    discriminator.to(device)
    generator.eval()
    discriminator.eval()

    real_accs = []
    fake_accs = []

    with torch.no_grad():
        for i in range(100):
            data, attributes = next(iter(iterator))
            data = data.to(device)

            conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
            selected_conditions = [random.choice(conditions_of_one_sample) for conditions_of_one_sample in conditions]
            selected_conditions = torch.stack(selected_conditions).float().to(device)

            # Move the data and labels to the device
            real_images = data.to(device)
            labels = selected_conditions.to(device)

            # Generate noise and fake images
            noise = torch.randn_like(real_images).to(device)
            fake_images = generator(noise, labels)

            # Evaluate the discriminator's performance
            real_preds = discriminator(real_images, labels).reshape(-1)
            fake_preds = discriminator(fake_images, labels).reshape(-1)
            real_acc = (real_preds > 0.5).float().mean().item()  # accuracy on real images
            fake_acc = (fake_preds < 0.5).float().mean().item()  # accuracy on generated images

            real_accs.append(real_acc)
            fake_accs.append(fake_acc)

    # Compute the average accuracies
    average_real_acc = sum(real_accs) / len(real_accs)
    average_fake_acc = sum(fake_accs) / len(fake_accs)

    return average_real_acc, average_fake_acc


def get_Generator(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    generator = ConditionedUnet(**model_Config)
    print(f"Model initialized, size: {sum(p.numel() for p in generator.parameters() if p.requires_grad)}")
    generator.to(device)

    if load_pretrain:
        print(f"Loading weights from models/{model_name}_generator.pth")
        checkpoint = torch.load(f'models/{model_name}_generator.pth', map_location=device)
        generator.load_state_dict(checkpoint['model_state_dict'])
    generator.eval()
    return generator


def get_Discriminator(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    discriminator = Discriminator(**model_Config)
    print(f"Model initialized, size: {sum(p.numel() for p in discriminator.parameters() if p.requires_grad)}")
    discriminator.to(device)

    if load_pretrain:
        print(f"Loading weights from models/{model_name}_discriminator.pth")
        checkpoint = torch.load(f'models/{model_name}_discriminator.pth', map_location=device)
        discriminator.load_state_dict(checkpoint['model_state_dict'])
    discriminator.eval()
    return discriminator


def train_GAN(device, init_model_name, unetConfig, BATCH_SIZE, lr_G, lr_D, max_iter, iterator, load_pretrain,
              encodes2embeddings_mapping, save_steps, unconditional_condition, uncondition_rate, save_model_name=None):

    if save_model_name is None:
        save_model_name = init_model_name

    def save_model_hyperparameter(model_name, unetConfig, BATCH_SIZE, model_size, current_iter, current_loss):
        model_hyperparameter = unetConfig
        model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
        model_hyperparameter["lr_G"] = lr_G
        model_hyperparameter["lr_D"] = lr_D
        model_hyperparameter["model_size"] = model_size
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["current_loss"] = current_loss
        with open(f"models/hyperparameters/{model_name}_GAN.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    generator = ConditionedUnet(**unetConfig)
    discriminator = Discriminator(unetConfig["label_emb_dim"])
    generator_size = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    discriminator_size = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)

    print(f"Generator trainable parameters: {generator_size}, discriminator trainable parameters: {discriminator_size}")
    generator.to(device)
    discriminator.to(device)
    optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad, generator.parameters()), lr=lr_G, amsgrad=False)
    optimizer_D = torch.optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()), lr=lr_D, amsgrad=False)

    if load_pretrain:
        print(f"Loading weights from models/{init_model_name}_generator.pth")
        checkpoint = torch.load(f'models/{init_model_name}_generator.pth')
        generator.load_state_dict(checkpoint['model_state_dict'])
        optimizer_G.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"Loading weights from models/{init_model_name}_discriminator.pth")
        checkpoint = torch.load(f'models/{init_model_name}_discriminator.pth')
        discriminator.load_state_dict(checkpoint['model_state_dict'])
        optimizer_D.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Model initialized.")
    if max_iter == 0:
        print("Return model directly.")
        return generator, discriminator, optimizer_G, optimizer_D

    train_loss_G, train_loss_D = [], []
    writer = SummaryWriter(f'runs/{save_model_name}_GAN')

    # average_real_acc, average_fake_acc = evaluate_GAN(device, generator, discriminator, iterator, encodes2embeddings_mapping)
    # print(f"average_real_acc, average_fake_acc: {average_real_acc, average_fake_acc}")

    criterion = nn.BCEWithLogitsLoss()
    generator.train()
    for i in xrange(max_iter):
        data, attributes = next(iter(iterator))
        data = data.to(device)

        conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
        unconditional_condition_copy = torch.tensor(unconditional_condition, dtype=torch.float32).to(device).detach()
        selected_conditions = [unconditional_condition_copy if random.random() < uncondition_rate else random.choice(
            conditions_of_one_sample) for conditions_of_one_sample in conditions]
        batch_size = len(selected_conditions)
        selected_conditions = torch.stack(selected_conditions).float().to(device)

        # Move the data and labels to the device
        real_images = data.to(device)
        labels = selected_conditions.to(device)

        # Real and fake target labels
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ========== Train the discriminator ==========
        optimizer_D.zero_grad()

        # Discriminator loss on real images
        outputs_real = discriminator(real_images, labels)
        loss_D_real = criterion(outputs_real, real_labels)

        # Generate fake images
        noise = torch.randn_like(real_images).to(device)
        fake_images = generator(noise, labels)

        # Discriminator loss on fake images
        outputs_fake = discriminator(fake_images.detach(), labels)
        loss_D_fake = criterion(outputs_fake, fake_labels)

        # Backpropagation and optimization
        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        optimizer_D.step()

        # ========== Train the generator ==========
        optimizer_G.zero_grad()

        # Generator loss
        outputs_fake = discriminator(fake_images, labels)
        loss_G = criterion(outputs_fake, real_labels)

        # Backpropagation and optimization
        loss_G.backward()
        optimizer_G.step()

        train_loss_G.append(loss_G.item())
        train_loss_D.append(loss_D.item())
        step = int(optimizer_G.state_dict()['state'][list(optimizer_G.state_dict()['state'].keys())[0]]['step'].numpy())

        if (i + 1) % 100 == 0:
            print('%d step' % (step))

        if (i + 1) % save_steps == 0:
            current_loss_D = np.mean(train_loss_D[-save_steps:])
            current_loss_G = np.mean(train_loss_G[-save_steps:])
            print('current_loss_G: %.5f' % current_loss_G)
            print('current_loss_D: %.5f' % current_loss_D)

            writer.add_scalar(f"current_loss_G", current_loss_G, step)
            writer.add_scalar(f"current_loss_D", current_loss_D, step)

            torch.save({
                'model_state_dict': generator.state_dict(),
                'optimizer_state_dict': optimizer_G.state_dict(),
            }, f'models/{save_model_name}_generator.pth')
            save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, generator_size, step, current_loss_G)
            torch.save({
                'model_state_dict': discriminator.state_dict(),
                'optimizer_state_dict': optimizer_D.state_dict(),
            }, f'models/{save_model_name}_discriminator.pth')
            save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, discriminator_size, step, current_loss_D)

            if step % 10000 == 0:
                torch.save({
                    'model_state_dict': generator.state_dict(),
                    'optimizer_state_dict': optimizer_G.state_dict(),
                }, f'models/history/{save_model_name}_{step}_generator.pth')
                torch.save({
                    'model_state_dict': discriminator.state_dict(),
                    'optimizer_state_dict': optimizer_D.state_dict(),
                }, f'models/history/{save_model_name}_{step}_discriminator.pth')

    return generator, discriminator, optimizer_G, optimizer_D
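
A quick shape sanity check for the Discriminator defined above; the embedding size and latent dimensions are example values chosen to be consistent with the layer definitions in this file:

import torch
from model.GAN import Discriminator

disc = Discriminator(label_emb_dim=512)  # example text-embedding size
latents = torch.randn(2, 4, 128, 64)     # (batch, 4, H, W), matching the first Conv2d
text_emb = torch.randn(2, 512)           # one label embedding per sample
logits = disc(latents, text_emb)         # raw scores; train_GAN wraps them in BCEWithLogitsLoss
print(logits.shape)                      # torch.Size([2, 1])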
model/VQGAN.py
ADDED
@@ -0,0 +1,684 @@
import json
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from six.moves import xrange
from einops import rearrange
from torchvision import models


def Normalize(in_channels, num_groups=32, norm_type="groupnorm"):
    """Normalization layer"""

    if norm_type == "batchnorm":
        return torch.nn.BatchNorm2d(in_channels)
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


def nonlinearity(x, act_type="relu"):
    """Nonlinear activation function"""

    if act_type == "relu":
        return F.relu(x)
    else:
        # swish
        return x * torch.sigmoid(x)


class VectorQuantizer(nn.Module):
    """Vector quantization layer"""

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input BHWC -> (BHW)C
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances (input-embedding)^2
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding (one-hot-encoding matrix)
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        min_encodings, min_encoding_indices = None, None
        return quantized.permute(0, 3, 1, 2).contiguous(), loss, (perplexity, min_encodings, min_encoding_indices)
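
# The loss above is the standard VQ-VAE objective (van den Oord et al., 2017):
#   loss = ||z_q - sg[z_e]||^2 + commitment_cost * ||sg[z_q] - z_e||^2,
# where sg[.] is stop-gradient, z_e = inputs and z_q = quantized. The line
# `quantized = inputs + (quantized - inputs).detach()` is the straight-through
# estimator: it keeps the quantized values in the forward pass while routing
# gradients straight through to the encoder.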
| 78 |
+
class VectorQuantizerEMA(nn.Module):
|
| 79 |
+
"""Vector quantization layer based on exponential moving average"""
|
| 80 |
+
|
| 81 |
+
def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
|
| 82 |
+
super(VectorQuantizerEMA, self).__init__()
|
| 83 |
+
|
| 84 |
+
self._embedding_dim = embedding_dim
|
| 85 |
+
self._num_embeddings = num_embeddings
|
| 86 |
+
|
| 87 |
+
self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
|
| 88 |
+
self._embedding.weight.data.normal_()
|
| 89 |
+
self._commitment_cost = commitment_cost
|
| 90 |
+
|
| 91 |
+
self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
|
| 92 |
+
self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
|
| 93 |
+
self._ema_w.data.normal_()
|
| 94 |
+
|
| 95 |
+
self._decay = decay
|
| 96 |
+
self._epsilon = epsilon
|
| 97 |
+
|
| 98 |
+
def forward(self, inputs):
|
| 99 |
+
# convert inputs from BCHW -> BHWC
|
| 100 |
+
inputs = inputs.permute(0, 2, 3, 1).contiguous()
|
| 101 |
+
input_shape = inputs.shape
|
| 102 |
+
|
| 103 |
+
# Flatten input
|
| 104 |
+
flat_input = inputs.view(-1, self._embedding_dim)
|
| 105 |
+
|
| 106 |
+
# Calculate distances
|
| 107 |
+
distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
|
| 108 |
+
+ torch.sum(self._embedding.weight ** 2, dim=1)
|
| 109 |
+
- 2 * torch.matmul(flat_input, self._embedding.weight.t()))
|
| 110 |
+
|
| 111 |
+
# Encoding
|
| 112 |
+
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
|
| 113 |
+
encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
|
| 114 |
+
encodings.scatter_(1, encoding_indices, 1)
|
| 115 |
+
|
| 116 |
+
# Quantize and unflatten
|
| 117 |
+
quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
|
| 118 |
+
|
| 119 |
+
# Use EMA to update the embedding vectors
|
| 120 |
+
if self.training:
|
| 121 |
+
self._ema_cluster_size = self._ema_cluster_size * self._decay + \
|
| 122 |
+
(1 - self._decay) * torch.sum(encodings, 0)
|
| 123 |
+
|
| 124 |
+
# Laplace smoothing of the cluster size
|
| 125 |
+
n = torch.sum(self._ema_cluster_size.data)
|
| 126 |
+
self._ema_cluster_size = (
|
| 127 |
+
(self._ema_cluster_size + self._epsilon)
|
| 128 |
+
/ (n + self._num_embeddings * self._epsilon) * n)
|
| 129 |
+
|
| 130 |
+
dw = torch.matmul(encodings.t(), flat_input)
|
| 131 |
+
self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
|
| 132 |
+
|
| 133 |
+
self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
|
| 134 |
+
|
| 135 |
+
# Loss
|
| 136 |
+
e_latent_loss = F.mse_loss(quantized.detach(), inputs)
|
| 137 |
+
loss = self._commitment_cost * e_latent_loss
|
| 138 |
+
|
| 139 |
+
# Straight Through Estimator
|
| 140 |
+
quantized = inputs + (quantized - inputs).detach()
|
| 141 |
+
avg_probs = torch.mean(encodings, dim=0)
|
| 142 |
+
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
|
| 143 |
+
|
| 144 |
+
# convert quantized from BHWC -> BCHW
|
| 145 |
+
min_encodings, min_encoding_indices = None, None
|
| 146 |
+
return quantized.permute(0, 3, 1, 2).contiguous(), loss, (perplexity, min_encodings, min_encoding_indices)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class DownSample(nn.Module):
|
| 150 |
+
"""DownSample layer"""
|
| 151 |
+
|
| 152 |
+
def __init__(self, in_channels, out_channels):
|
| 153 |
+
super(DownSample, self).__init__()
|
| 154 |
+
self._conv2d = nn.Conv2d(in_channels=in_channels,
|
| 155 |
+
out_channels=out_channels,
|
| 156 |
+
kernel_size=4,
|
| 157 |
+
stride=2, padding=1)
|
| 158 |
+
|
| 159 |
+
def forward(self, x):
|
| 160 |
+
return self._conv2d(x)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class UpSample(nn.Module):
|
| 164 |
+
"""UpSample layer"""
|
| 165 |
+
|
| 166 |
+
def __init__(self, in_channels, out_channels):
|
| 167 |
+
super(UpSample, self).__init__()
|
| 168 |
+
self._conv2d = nn.ConvTranspose2d(in_channels=in_channels,
|
| 169 |
+
out_channels=out_channels,
|
| 170 |
+
kernel_size=4,
|
| 171 |
+
stride=2, padding=1)
|
| 172 |
+
|
| 173 |
+
def forward(self, x):
|
| 174 |
+
return self._conv2d(x)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class ResnetBlock(nn.Module):
|
| 178 |
+
"""ResnetBlock is a combination of non-linearity, convolution, and normalization"""
|
| 179 |
+
|
| 180 |
+
def __init__(self, *, in_channels, out_channels=None, double_conv=False, conv_shortcut=False,
|
| 181 |
+
dropout=0.0, temb_channels=512, norm_type="groupnorm", act_type="relu", num_groups=32):
|
| 182 |
+
super().__init__()
|
| 183 |
+
self.in_channels = in_channels
|
| 184 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 185 |
+
self.out_channels = out_channels
|
| 186 |
+
self.use_conv_shortcut = conv_shortcut
|
| 187 |
+
self.act_type = act_type
|
| 188 |
+
|
| 189 |
+
self.norm1 = Normalize(in_channels, norm_type=norm_type, num_groups=num_groups)
|
| 190 |
+
self.conv1 = torch.nn.Conv2d(in_channels,
|
| 191 |
+
out_channels,
|
| 192 |
+
kernel_size=3,
|
| 193 |
+
stride=1,
|
| 194 |
+
padding=1)
|
| 195 |
+
if temb_channels > 0:
|
| 196 |
+
self.temb_proj = torch.nn.Linear(temb_channels,
|
| 197 |
+
out_channels)
|
| 198 |
+
|
| 199 |
+
self.double_conv = double_conv
|
| 200 |
+
if self.double_conv:
|
| 201 |
+
self.norm2 = Normalize(out_channels, norm_type=norm_type, num_groups=num_groups)
|
| 202 |
+
self.dropout = torch.nn.Dropout(dropout)
|
| 203 |
+
self.conv2 = torch.nn.Conv2d(out_channels,
|
| 204 |
+
out_channels,
|
| 205 |
+
kernel_size=3,
|
| 206 |
+
stride=1,
|
| 207 |
+
padding=1)
|
| 208 |
+
|
| 209 |
+
if self.in_channels != self.out_channels:
|
| 210 |
+
if self.use_conv_shortcut:
|
| 211 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
| 212 |
+
out_channels,
|
| 213 |
+
kernel_size=3,
|
| 214 |
+
stride=1,
|
| 215 |
+
padding=1)
|
| 216 |
+
else:
|
| 217 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
| 218 |
+
out_channels,
|
| 219 |
+
kernel_size=1,
|
| 220 |
+
stride=1,
|
| 221 |
+
padding=0)
|
| 222 |
+
|
| 223 |
+
def forward(self, x, temb=None):
|
| 224 |
+
h = x
|
| 225 |
+
h = self.norm1(h)
|
| 226 |
+
h = nonlinearity(h, act_type=self.act_type)
|
| 227 |
+
h = self.conv1(h)
|
| 228 |
+
|
| 229 |
+
if temb is not None:
|
| 230 |
+
h = h + self.temb_proj(nonlinearity(temb, act_type=self.act_type))[:, :, None, None]
|
| 231 |
+
|
| 232 |
+
if self.double_conv:
|
| 233 |
+
h = self.norm2(h)
|
| 234 |
+
h = nonlinearity(h, act_type=self.act_type)
|
| 235 |
+
h = self.dropout(h)
|
| 236 |
+
h = self.conv2(h)
|
| 237 |
+
|
| 238 |
+
if self.in_channels != self.out_channels:
|
| 239 |
+
if self.use_conv_shortcut:
|
| 240 |
+
x = self.conv_shortcut(x)
|
| 241 |
+
else:
|
| 242 |
+
x = self.nin_shortcut(x)
|
| 243 |
+
|
| 244 |
+
return x + h
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class LinearAttention(nn.Module):
|
| 248 |
+
"""Efficient attention block based on <https://proceedings.mlr.press/v119/katharopoulos20a.html>"""
|
| 249 |
+
|
| 250 |
+
def __init__(self, dim, heads=4, dim_head=32, with_skip=True):
|
| 251 |
+
super().__init__()
|
| 252 |
+
self.heads = heads
|
| 253 |
+
hidden_dim = dim_head * heads
|
| 254 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
| 255 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
| 256 |
+
|
| 257 |
+
self.with_skip = with_skip
|
| 258 |
+
if self.with_skip:
|
| 259 |
+
self.nin_shortcut = torch.nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0)
|
| 260 |
+
|
| 261 |
+
def forward(self, x):
|
| 262 |
+
b, c, h, w = x.shape
|
| 263 |
+
qkv = self.to_qkv(x)
|
| 264 |
+
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
|
| 265 |
+
k = k.softmax(dim=-1)
|
| 266 |
+
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
| 267 |
+
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
| 268 |
+
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
| 269 |
+
|
| 270 |
+
if self.with_skip:
|
| 271 |
+
return self.to_out(out) + self.nin_shortcut(x)
|
| 272 |
+
return self.to_out(out)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class Encoder(nn.Module):
    """The encoder, consisting of alternating stacks of ResNet blocks, efficient attention modules, and downsampling layers."""

    def __init__(self, in_channels, hidden_channels, embedding_dim, block_depth=2,
                 attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu", num_groups=32):
        super(Encoder, self).__init__()

        if attn_pos is None:
            attn_pos = []
        self._layers = nn.ModuleList([DownSample(in_channels, hidden_channels[0])])
        current_channel = hidden_channels[0]

        for i in range(1, len(hidden_channels)):
            for _ in range(block_depth - 1):
                self._layers.append(ResnetBlock(in_channels=current_channel,
                                                out_channels=current_channel,
                                                double_conv=False,
                                                conv_shortcut=False,
                                                norm_type=norm_type,
                                                act_type=act_type,
                                                num_groups=num_groups))
                if current_channel in attn_pos:
                    self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))

            self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
            self._layers.append(nn.ReLU())
            self._layers.append(DownSample(current_channel, hidden_channels[i]))
            current_channel = hidden_channels[i]

        for _ in range(block_depth - 1):
            self._layers.append(ResnetBlock(in_channels=current_channel,
                                            out_channels=current_channel,
                                            double_conv=False,
                                            conv_shortcut=False,
                                            norm_type=norm_type,
                                            act_type=act_type,
                                            num_groups=num_groups))
            if current_channel in attn_pos:
                self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))

        # Conv1x1: hidden_channels[-1] -> embedding_dim
        self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
        self._layers.append(nn.ReLU())
        self._layers.append(nn.Conv2d(in_channels=current_channel,
                                      out_channels=embedding_dim,
                                      kernel_size=1,
                                      stride=1))

    def forward(self, x):
        for layer in self._layers:
            x = layer(x)
        return x

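A quick shape sketch (hyperparameters illustrative, import path assumed): one DownSample per entry in hidden_channels means the spatial dims shrink by 2**len(hidden_channels), and the channel count ends at embedding_dim:

```python
import torch
from model.VQGAN import Encoder  # assumed import path

enc = Encoder(in_channels=3, hidden_channels=[32, 64, 128], embedding_dim=64)
z = enc(torch.randn(1, 3, 512, 256))  # three DownSample stages -> /8 spatially
print(z.shape)  # torch.Size([1, 64, 64, 32])
```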
class Decoder(nn.Module):
    """The decoder, consisting of alternating stacks of ResNet blocks, efficient attention modules, and upsampling layers."""

    def __init__(self, embedding_dim, hidden_channels, out_channels, block_depth=2,
                 attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu",
                 num_groups=32):
        super(Decoder, self).__init__()

        if attn_pos is None:
            attn_pos = []
        reversed_hidden_channels = list(reversed(hidden_channels))

        # Conv1x1: embedding_dim -> reversed_hidden_channels[0]
        self._layers = nn.ModuleList([nn.Conv2d(in_channels=embedding_dim,
                                                out_channels=reversed_hidden_channels[0],
                                                kernel_size=1, stride=1, bias=False)])

        current_channel = reversed_hidden_channels[0]

        for _ in range(block_depth - 1):
            if current_channel in attn_pos:
                self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
            self._layers.append(ResnetBlock(in_channels=current_channel,
                                            out_channels=current_channel,
                                            double_conv=False,
                                            conv_shortcut=False,
                                            norm_type=norm_type,
                                            act_type=act_type,
                                            num_groups=num_groups))

        for i in range(1, len(reversed_hidden_channels)):
            self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
            self._layers.append(nn.ReLU())
            self._layers.append(UpSample(current_channel, reversed_hidden_channels[i]))
            current_channel = reversed_hidden_channels[i]

            for _ in range(block_depth - 1):
                if current_channel in attn_pos:
                    self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
                self._layers.append(ResnetBlock(in_channels=current_channel,
                                                out_channels=current_channel,
                                                double_conv=False,
                                                conv_shortcut=False,
                                                norm_type=norm_type,
                                                act_type=act_type,
                                                num_groups=num_groups))

        self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
        self._layers.append(nn.ReLU())
        self._layers.append(UpSample(current_channel, current_channel))

        # final layers
        self._layers.append(ResnetBlock(in_channels=current_channel,
                                        out_channels=out_channels,
                                        double_conv=False,
                                        conv_shortcut=False,
                                        norm_type=norm_type,
                                        act_type=act_type,
                                        num_groups=num_groups))

    def forward(self, x):
        for layer in self._layers:
            x = layer(x)

        log_magnitude = torch.nn.functional.softplus(x[:, 0, :, :])
        cos_phase = torch.tanh(x[:, 1, :, :])
        sin_phase = torch.tanh(x[:, 2, :, :])
        x = torch.stack([log_magnitude, cos_phase, sin_phase], dim=1)

        return x

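The three output heads above (softplus log-magnitude, tanh cos/sin phase) suggest how a complex STFT can be recovered downstream. A hedged sketch, where the exp() inverse of the log-magnitude channel and the unit-circle renormalization of the phase pair are assumptions, not repo code:

```python
import torch

# Stand-in decoder output: (B, 3, F, T) = (log-magnitude, cos phase, sin phase)
out = torch.randn(1, 3, 512, 256)
log_mag = torch.nn.functional.softplus(out[:, 0])
cos_p, sin_p = torch.tanh(out[:, 1]), torch.tanh(out[:, 2])

# Assumed inverse mapping; project (cos, sin) back onto the unit circle first
norm = torch.sqrt(cos_p ** 2 + sin_p ** 2).clamp(min=1e-8)
stft = torch.exp(log_mag) * torch.complex(cos_p / norm, sin_p / norm)
print(stft.shape, stft.dtype)  # torch.Size([1, 512, 256]) torch.complex64
```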
class VQGAN_Discriminator(nn.Module):
    """The discriminator employs an 18-layer ResNet architecture, with the first layer replaced by a 2D convolutional
    layer that accommodates spectral representation inputs and the last two layers replaced by a binary classifier
    layer."""

    def __init__(self, in_channels=1):
        super(VQGAN_Discriminator, self).__init__()
        resnet = models.resnet18(pretrained=True)

        # Replace the first layer so the network accepts single-channel (or spectral) inputs
        resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

        # Keep ResNet's feature-extraction backbone (everything except avgpool and fc)
        self.features = nn.Sequential(*list(resnet.children())[:-2])

        # Binary classification head. It outputs raw logits: the training loop scores
        # real/fake pairs with nn.BCEWithLogitsLoss, which applies the sigmoid itself,
        # so the nn.Sigmoid that used to sit here would have squashed the logits twice.
        self.classifier = nn.Sequential(
            nn.Linear(512, 1)
        )

    def forward(self, x):
        x = self.features(x)
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

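A quick forward-pass sketch (assumed 3-channel STFT input and import path; `resnet18(pretrained=True)` fetches torchvision weights on first run):

```python
import torch
from model.VQGAN import VQGAN_Discriminator  # assumed import path

disc = VQGAN_Discriminator(in_channels=3)
logits = disc(torch.randn(4, 3, 512, 256))  # one raw logit per sample
print(logits.shape)                         # torch.Size([4, 1])
```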
class VQGAN(nn.Module):
    """The VQ-GAN model. <https://openaccess.thecvf.com/content/CVPR2021/html/Esser_Taming_Transformers_for_High-Resolution_Image_Synthesis_CVPR_2021_paper.html>"""

    def __init__(self, in_channels, hidden_channels, embedding_dim, out_channels, block_depth=2,
                 attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu",
                 num_embeddings=1024, commitment_cost=0.25, decay=0.99, num_groups=32):
        super(VQGAN, self).__init__()

        # act_type was previously passed as the literal string "act_type" by mistake
        self._encoder = Encoder(in_channels, hidden_channels, embedding_dim, block_depth=block_depth,
                                attn_pos=attn_pos, attn_with_skip=attn_with_skip, norm_type=norm_type,
                                act_type=act_type, num_groups=num_groups)

        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
                                              commitment_cost, decay)
        else:
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                           commitment_cost)
        self._decoder = Decoder(embedding_dim, hidden_channels, out_channels, block_depth=block_depth,
                                attn_pos=attn_pos, attn_with_skip=attn_with_skip, norm_type=norm_type,
                                act_type=act_type, num_groups=num_groups)

    def forward(self, x):
        z = self._encoder(x)
        quantized, vq_loss, (perplexity, _, _) = self._vq_vae(z)
        x_recon = self._decoder(quantized)

        return vq_loss, x_recon, perplexity

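An end-to-end usage sketch; the hyperparameters are illustrative, not the repo's training config, and the import path is assumed. With three entries in hidden_channels the encoder downsamples by 2**3 and the decoder upsamples back:

```python
import torch
from model.VQGAN import VQGAN  # assumed import path

model = VQGAN(in_channels=3, hidden_channels=[32, 64, 128], embedding_dim=64,
              out_channels=3, block_depth=2, num_embeddings=1024,
              commitment_cost=0.25, decay=0.99)
vq_loss, x_recon, perplexity = model(torch.randn(2, 3, 512, 256))
print(x_recon.shape)  # torch.Size([2, 3, 512, 256])
```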
class ReconstructionLoss(nn.Module):
    def __init__(self, w1, w2, epsilon=1e-3):
        super(ReconstructionLoss, self).__init__()
        self.w1 = w1
        self.w2 = w2
        self.epsilon = epsilon

    def weighted_mae_loss(self, y_true, y_pred):
        # clamp the target to avoid division by zero
        y_true_safe = torch.clamp(y_true, min=self.epsilon)

        # compute the relative (target-weighted) MAE
        loss = torch.mean(torch.abs(y_pred - y_true) / y_true_safe)
        return loss

    def mae_loss(self, y_true, y_pred):
        loss = torch.mean(torch.abs(y_pred - y_true))
        return loss

    def forward(self, y_pred, y_true):
        # loss for the magnitude channel
        log_magnitude_loss = self.w1 * self.weighted_mae_loss(y_pred[:, 0, :, :], y_true[:, 0, :, :])

        # loss for the phase channels
        phase_loss = self.w2 * self.mae_loss(y_pred[:, 1:, :, :], y_true[:, 1:, :, :])

        # sum up
        rec_loss = log_magnitude_loss + phase_loss
        return log_magnitude_loss, phase_loss, rec_loss

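A tiny worked example of the weighted magnitude term: because the absolute error is divided by the clamped target, a 0.1 error on a near-silent bin costs far more than the same error on a loud one:

```python
import torch

y_true = torch.tensor([0.001, 1.0])
y_pred = torch.tensor([0.101, 1.1])
eps = 1e-3
w = torch.mean(torch.abs(y_pred - y_true) / torch.clamp(y_true, min=eps))
print(w)  # ≈ (0.1/0.001 + 0.1/1.0) / 2 = 50.05
```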
def evaluate_VQGAN(model, discriminator, iterator, reconstructionLoss, adversarial_loss, trainingConfig):
    model.to(trainingConfig["device"])
    model.eval()
    train_res_error = []
    for i in xrange(100):
        data = next(iter(iterator))
        data = data.to(trainingConfig["device"])

        # real/fake labels
        real_labels = torch.ones(data.size(0), 1).to(trainingConfig["device"])

        vq_loss, data_recon, perplexity = model(data)

        fake_preds = discriminator(data_recon)
        adver_loss = adversarial_loss(fake_preds, real_labels)

        log_magnitude_loss, phase_loss, rec_loss = reconstructionLoss(data_recon, data)
        loss = rec_loss + trainingConfig["vq_weight"] * vq_loss + trainingConfig["adver_weight"] * adver_loss

        train_res_error.append(loss.item())
    initial_loss = np.mean(train_res_error)
    return initial_loss

def get_VQGAN(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    VQVAE = VQGAN(**model_Config)
    print(f"Model initialized, size: {sum(p.numel() for p in VQVAE.parameters() if p.requires_grad)}")
    VQVAE.to(device)

    if load_pretrain:
        print(f"Loading weights from models/{model_name}_imageVQVAE.pth")
        checkpoint = torch.load(f'models/{model_name}_imageVQVAE.pth', map_location=device)
        VQVAE.load_state_dict(checkpoint['model_state_dict'])
    VQVAE.eval()
    return VQVAE

def train_VQGAN(model_Config, trainingConfig, iterator):

    def save_model_hyperparameter(model_Config, trainingConfig, current_iter,
                                  log_magnitude_loss, phase_loss, current_perplexity, current_vq_loss,
                                  current_loss):
        model_name = trainingConfig["model_name"]
        model_hyperparameter = model_Config
        model_hyperparameter.update(trainingConfig)
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["log_magnitude_loss"] = log_magnitude_loss
        model_hyperparameter["phase_loss"] = phase_loss
        model_hyperparameter["perplexity"] = current_perplexity
        model_hyperparameter["vq_loss"] = current_vq_loss
        model_hyperparameter["total_loss"] = current_loss

        with open(f"models/hyperparameters/{model_name}_VQGAN_STFT.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    # initialize VAE
    model = VQGAN(**model_Config)
    print(f"VQ_VAE size: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    model.to(trainingConfig["device"])

    VAE_optimizer = torch.optim.Adam(model.parameters(), lr=trainingConfig["lr"], amsgrad=False)
    model_name = trainingConfig["model_name"]

    if trainingConfig["load_pretrain"]:
        print(f"Loading weights from models/{model_name}_imageVQVAE.pth")
        checkpoint = torch.load(f'models/{model_name}_imageVQVAE.pth', map_location=trainingConfig["device"])
        model.load_state_dict(checkpoint['model_state_dict'])
        VAE_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("VAE initialized.")
    if trainingConfig["max_iter"] == 0:
        print("Return VAE directly.")
        return model

    # initialize discriminator
    discriminator = VQGAN_Discriminator(model_Config["in_channels"])
    print(f"Discriminator size: {sum(p.numel() for p in discriminator.parameters() if p.requires_grad)}")
    discriminator.to(trainingConfig["device"])

    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=trainingConfig["d_lr"], amsgrad=False)

    if trainingConfig["load_pretrain"]:
        print(f"Loading weights from models/{model_name}_imageVQVAE_discriminator.pth")
        checkpoint = torch.load(f'models/{model_name}_imageVQVAE_discriminator.pth', map_location=trainingConfig["device"])
        discriminator.load_state_dict(checkpoint['model_state_dict'])
        discriminator_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Discriminator initialized.")

    # Training
    train_res_phase_loss, train_res_perplexity, train_res_log_magnitude_loss, train_res_vq_loss, train_res_loss = [], [], [], [], []
    train_discriminator_loss, train_adversarial_loss = [], []

    reconstructionLoss = ReconstructionLoss(w1=trainingConfig["w1"], w2=trainingConfig["w2"], epsilon=trainingConfig["threshold"])

    adversarial_loss = nn.BCEWithLogitsLoss()
    writer = SummaryWriter(f'runs/{model_name}_VQVAE_lr=1e-4')

    previous_lowest_loss = evaluate_VQGAN(model, discriminator, iterator,
                                          reconstructionLoss, adversarial_loss, trainingConfig)
    print(f"initial_loss: {previous_lowest_loss}")

    model.train()
    for i in xrange(trainingConfig["max_iter"]):
        data = next(iter(iterator))
        data = data.to(trainingConfig["device"])

        # real/fake labels
        real_labels = torch.ones(data.size(0), 1).to(trainingConfig["device"])
        fake_labels = torch.zeros(data.size(0), 1).to(trainingConfig["device"])

        # update discriminator
        discriminator_optimizer.zero_grad()

        vq_loss, data_recon, perplexity = model(data)

        real_preds = discriminator(data)
        fake_preds = discriminator(data_recon.detach())

        loss_real = adversarial_loss(real_preds, real_labels)
        loss_fake = adversarial_loss(fake_preds, fake_labels)

        loss_D = loss_real + loss_fake
        loss_D.backward()
        discriminator_optimizer.step()

        # update VQVAE
        VAE_optimizer.zero_grad()

        fake_preds = discriminator(data_recon)
        adver_loss = adversarial_loss(fake_preds, real_labels)

        log_magnitude_loss, phase_loss, rec_loss = reconstructionLoss(data_recon, data)

        loss = rec_loss + trainingConfig["vq_weight"] * vq_loss + trainingConfig["adver_weight"] * adver_loss
        loss.backward()
        VAE_optimizer.step()

        train_discriminator_loss.append(loss_D.item())
        train_adversarial_loss.append(trainingConfig["adver_weight"] * adver_loss.item())
        train_res_log_magnitude_loss.append(log_magnitude_loss.item())
        train_res_phase_loss.append(phase_loss.item())
        train_res_perplexity.append(perplexity.item())
        train_res_vq_loss.append(trainingConfig["vq_weight"] * vq_loss.item())
        train_res_loss.append(loss.item())
        step = int(VAE_optimizer.state_dict()['state'][list(VAE_optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy())

        save_steps = trainingConfig["save_steps"]
        if (i + 1) % 100 == 0:
            print('%d step' % (step))

        if (i + 1) % save_steps == 0:
            current_discriminator_loss = np.mean(train_discriminator_loss[-save_steps:])
            current_adversarial_loss = np.mean(train_adversarial_loss[-save_steps:])
            current_log_magnitude_loss = np.mean(train_res_log_magnitude_loss[-save_steps:])
            current_phase_loss = np.mean(train_res_phase_loss[-save_steps:])
            current_perplexity = np.mean(train_res_perplexity[-save_steps:])
            current_vq_loss = np.mean(train_res_vq_loss[-save_steps:])
            current_loss = np.mean(train_res_loss[-save_steps:])

            print('discriminator_loss: %.3f' % current_discriminator_loss)
            print('adversarial_loss: %.3f' % current_adversarial_loss)
            print('log_magnitude_loss: %.3f' % current_log_magnitude_loss)
            print('phase_loss: %.3f' % current_phase_loss)
            print('perplexity: %.3f' % current_perplexity)
            print('vq_loss: %.3f' % current_vq_loss)
            print('total_loss: %.3f' % current_loss)
            writer.add_scalar(f"log_magnitude_loss", current_log_magnitude_loss, step)
            writer.add_scalar(f"phase_loss", current_phase_loss, step)
            writer.add_scalar(f"perplexity", current_perplexity, step)
            writer.add_scalar(f"vq_loss", current_vq_loss, step)
            writer.add_scalar(f"total_loss", current_loss, step)
            if current_loss < previous_lowest_loss:
                previous_lowest_loss = current_loss

                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': VAE_optimizer.state_dict(),
                }, f'models/{model_name}_imageVQVAE.pth')

                torch.save({
                    'model_state_dict': discriminator.state_dict(),
                    'optimizer_state_dict': discriminator_optimizer.state_dict(),
                }, f'models/{model_name}_imageVQVAE_discriminator.pth')

                save_model_hyperparameter(model_Config, trainingConfig, step,
                                          current_log_magnitude_loss, current_phase_loss, current_perplexity, current_vq_loss,
                                          current_loss)

    return model

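For orientation, a sketch of the two dicts train_VQGAN expects. The keys below are exactly the ones the function reads; every value is an illustrative placeholder, not the repo's published setting:

```python
# Illustrative configs only; values are assumptions.
model_Config = dict(in_channels=3, hidden_channels=[32, 64, 128],
                    embedding_dim=64, out_channels=3, block_depth=2,
                    num_embeddings=1024, commitment_cost=0.25, decay=0.99)
trainingConfig = dict(device="cuda", model_name="demo", load_pretrain=False,
                      max_iter=100000, lr=1e-4, d_lr=1e-4, save_steps=1000,
                      w1=1.0, w2=1.0, threshold=1e-3,
                      vq_weight=1.0, adver_weight=0.1)
# model = train_VQGAN(model_Config, trainingConfig, iterator)
```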
model/__pycache__/DiffSynthSampler.cpython-310.pyc
ADDED
Binary file (12.8 kB).

model/__pycache__/GAN.cpython-310.pyc
ADDED
Binary file (7.49 kB).

model/__pycache__/VQGAN.cpython-310.pyc
ADDED
Binary file (18.8 kB).

model/__pycache__/diffusion.cpython-310.pyc
ADDED
Binary file (10.1 kB).

model/__pycache__/diffusion_components.cpython-310.pyc
ADDED
Binary file (11.6 kB).

model/__pycache__/multimodal_model.cpython-310.pyc
ADDED
Binary file (9.88 kB).

model/__pycache__/perceptual_label_predictor.cpython-37.pyc
ADDED
Binary file (1.67 kB).

model/__pycache__/timbre_encoder_pretrain.cpython-310.pyc
ADDED
Binary file (7.96 kB).
model/diffusion.py
ADDED
@@ -0,0 +1,371 @@
import json
from functools import partial

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from six.moves import xrange
from torch.utils.tensorboard import SummaryWriter
import random

from metrics.IS import get_inception_score
from tools import create_key

from model.diffusion_components import default, ConvNextBlock, ResnetBlock, SinusoidalPositionEmbeddings, Residual, \
    PreNorm, \
    Downsample, Upsample, exists, q_sample, get_beta_schedule, pad_and_concat, ConditionalEmbedding, \
    LinearCrossAttention, LinearCrossAttentionAdd

class ConditionedUnet(nn.Module):
    def __init__(
            self,
            in_dim,
            out_dim=None,
            down_dims=None,
            up_dims=None,
            mid_depth=3,
            with_time_emb=True,
            time_dim=None,
            resnet_block_groups=8,
            use_convnext=True,
            convnext_mult=2,
            attn_type="linear_cat",
            n_label_class=11,
            condition_type="instrument_family",
            label_emb_dim=128,
    ):
        super().__init__()

        self.label_embedding = ConditionalEmbedding(int(n_label_class + 1), int(label_emb_dim), condition_type)

        if up_dims is None:
            up_dims = [128, 128, 64, 32]
        if down_dims is None:
            down_dims = [32, 32, 64, 128]

        out_dim = default(out_dim, in_dim)
        assert len(down_dims) == len(up_dims), "len(down_dims) != len(up_dims)"
        assert down_dims[0] == up_dims[-1], "down_dims[0] != up_dims[-1]"
        assert up_dims[0] == down_dims[-1], "up_dims[0] != down_dims[-1]"
        down_in_out = list(zip(down_dims[:-1], down_dims[1:]))
        up_in_out = list(zip(up_dims[:-1], up_dims[1:]))
        print(f"down_in_out: {down_in_out}")
        print(f"up_in_out: {up_in_out}")
        time_dim = default(time_dim, int(down_dims[0] * 4))

        self.init_conv = nn.Conv2d(in_dim, down_dims[0], 7, padding=3)

        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        if attn_type == "linear_cat":
            attn_klass = partial(LinearCrossAttention)
        elif attn_type == "linear_add":
            attn_klass = partial(LinearCrossAttentionAdd)
        else:
            raise NotImplementedError()

        # time embeddings
        if with_time_emb:
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(down_dims[0]),
                nn.Linear(down_dims[0], time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # left (downsampling) layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        skip_dims = []

        for down_dim_in, down_dim_out in down_in_out:
            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(down_dim_in, down_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(down_dim_out, attn_klass(down_dim_out, label_emb_dim=label_emb_dim))),
                        block_klass(down_dim_out, down_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(down_dim_out, attn_klass(down_dim_out, label_emb_dim=label_emb_dim))),
                        Downsample(down_dim_out),
                    ]
                )
            )
            skip_dims.append(down_dim_out)

        # bottleneck
        mid_dim = down_dims[-1]
        self.mid_left = nn.ModuleList([])
        self.mid_right = nn.ModuleList([])
        for _ in range(mid_depth - 1):
            self.mid_left.append(block_klass(mid_dim, mid_dim, time_emb_dim=time_dim))
            self.mid_right.append(block_klass(mid_dim * 2, mid_dim, time_emb_dim=time_dim))
        self.mid_mid = nn.ModuleList(
            [
                block_klass(mid_dim, mid_dim, time_emb_dim=time_dim),
                Residual(PreNorm(mid_dim, attn_klass(mid_dim, label_emb_dim=label_emb_dim))),
                block_klass(mid_dim, mid_dim, time_emb_dim=time_dim),
            ]
        )

        # right (upsampling) layers
        for ind, (up_dim_in, up_dim_out) in enumerate(up_in_out):
            skip_dim = skip_dims.pop()  # down_dim_out
            self.ups.append(
                nn.ModuleList(
                    [
                        # pop&cat (h/2, w/2, down_dim_out)
                        block_klass(up_dim_in + skip_dim, up_dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(up_dim_in, attn_klass(up_dim_in, label_emb_dim=label_emb_dim))),
                        Upsample(up_dim_in),
                        # pop&cat (h, w, down_dim_out)
                        block_klass(up_dim_in + skip_dim, up_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(up_dim_out, attn_klass(up_dim_out, label_emb_dim=label_emb_dim))),
                        # pop&cat (h, w, down_dim_out)
                        block_klass(up_dim_out + skip_dim, up_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(up_dim_out, attn_klass(up_dim_out, label_emb_dim=label_emb_dim))),
                    ]
                )
            )

        self.final_conv = nn.Sequential(
            block_klass(down_dims[0] + up_dims[-1], up_dims[-1]), nn.Conv2d(up_dims[-1], out_dim, 3, padding=1)
        )

    def size(self):
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total parameters: {total_params}")
        print(f"Trainable parameters: {trainable_params}")

    def forward(self, x, time, condition=None):
        if condition is not None:
            condition_emb = self.label_embedding(condition)
        else:
            condition_emb = None

        h = []

        x = self.init_conv(x)
        h.append(x)

        time_emb = self.time_mlp(time) if exists(self.time_mlp) else None

        # downsample
        for block1, attn1, block2, attn2, downsample in self.downs:
            x = block1(x, time_emb)
            x = attn1(x, condition_emb)
            h.append(x)
            x = block2(x, time_emb)
            x = attn2(x, condition_emb)
            h.append(x)
            x = downsample(x)
            h.append(x)

        # bottleneck
        for block in self.mid_left:
            x = block(x, time_emb)
            h.append(x)

        (block1, attn, block2) = self.mid_mid
        x = block1(x, time_emb)
        x = attn(x, condition_emb)
        x = block2(x, time_emb)

        for block in self.mid_right:
            # the defining U-Net move: pop a skip feature and concatenate it
            x = pad_and_concat(h.pop(), x)
            x = block(x, time_emb)

        # upsample
        for block1, attn1, upsample, block2, attn2, block3, attn3 in self.ups:
            x = pad_and_concat(h.pop(), x)
            x = block1(x, time_emb)
            x = attn1(x, condition_emb)
            x = upsample(x)

            x = pad_and_concat(h.pop(), x)
            x = block2(x, time_emb)
            x = attn2(x, condition_emb)

            x = pad_and_concat(h.pop(), x)
            x = block3(x, time_emb)
            x = attn3(x, condition_emb)

        x = pad_and_concat(h.pop(), x)
        x = self.final_conv(x)
        return x

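A forward-pass sketch under assumed settings: the default dims apply three downsamples, so the latent's spatial size should be divisible by 8, and with condition_type="natural_language_prompt" the condition is a float embedding of size label_emb_dim (e.g. a CLAP text feature):

```python
import torch
from model.diffusion import ConditionedUnet  # assumed import path

unet = ConditionedUnet(in_dim=4, condition_type="natural_language_prompt",
                       label_emb_dim=128)
x = torch.randn(2, 4, 64, 32)            # latent batch (divisible by 2**3)
t = torch.randint(0, 1000, (2,)).long()  # diffusion timesteps
cond = torch.randn(2, 128)               # stand-in text-condition embedding
print(unet(x, t, cond).shape)            # torch.Size([2, 4, 64, 32])
```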
def conditional_p_losses(denoise_model, x_start, t, condition, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod,
                         noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod,
                       sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
    predicted_noise = denoise_model(x_noisy, t, condition)

    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss

def evaluate_diffusion_model(device, model, iterator, BATCH_SIZE, timesteps, unetConfig, encodes2embeddings_mapping,
                             uncondition_rate, unconditional_condition):
    model.to(device)
    model.eval()
    eva_loss = []
    sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, _, _ = get_beta_schedule(timesteps)
    for i in xrange(500):
        data, attributes = next(iter(iterator))
        data = data.to(device)

        conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
        selected_conditions = [
            unconditional_condition if random.random() < uncondition_rate else random.choice(conditions_of_one_sample)
            for conditions_of_one_sample in conditions]

        selected_conditions = torch.stack(selected_conditions).float().to(device)

        t = torch.randint(0, timesteps, (BATCH_SIZE,), device=device).long()
        loss = conditional_p_losses(model, data, t, selected_conditions, loss_type="huber",
                                    sqrt_alphas_cumprod=sqrt_alphas_cumprod,
                                    sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod)

        eva_loss.append(loss.item())
    initial_loss = np.mean(eva_loss)
    return initial_loss

def get_diffusion_model(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    UNet = ConditionedUnet(**model_Config)
    print(f"Model initialized, size: {sum(p.numel() for p in UNet.parameters() if p.requires_grad)}")
    UNet.to(device)

    if load_pretrain:
        print(f"Loading weights from models/{model_name}_UNet.pth")
        checkpoint = torch.load(f'models/{model_name}_UNet.pth', map_location=device)
        UNet.load_state_dict(checkpoint['model_state_dict'])
    UNet.eval()
    return UNet

def train_diffusion_model(VAE, text_encoder, CLAP_tokenizer, timbre_encoder, device, init_model_name, unetConfig, BATCH_SIZE, timesteps, lr, max_iter, iterator, load_pretrain,
                          encodes2embeddings_mapping, uncondition_rate, unconditional_condition, save_steps=5000, init_loss=None, save_model_name=None,
                          n_IS_batches=50):

    if save_model_name is None:
        save_model_name = init_model_name

    def save_model_hyperparameter(model_name, unetConfig, BATCH_SIZE, lr, model_size, current_iter, current_loss):
        model_hyperparameter = unetConfig
        model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
        model_hyperparameter["lr"] = lr
        model_hyperparameter["model_size"] = model_size
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["current_loss"] = current_loss
        with open(f"models/hyperparameters/{model_name}_UNet.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    model = ConditionedUnet(**unetConfig)
    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {model_size}")
    model.to(device)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, amsgrad=False)

    if load_pretrain:
        print(f"Loading weights from models/{init_model_name}_UNet.pth")
        checkpoint = torch.load(f'models/{init_model_name}_UNet.pth')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Model initialized.")
    if max_iter == 0:
        print("Return model directly.")
        return model, optimizer

    train_loss = []
    writer = SummaryWriter(f'runs/{save_model_name}_UNet')
    if init_loss is None:
        previous_loss = evaluate_diffusion_model(device, model, iterator, BATCH_SIZE, timesteps, unetConfig, encodes2embeddings_mapping,
                                                 uncondition_rate, unconditional_condition)
    else:
        previous_loss = init_loss
    print(f"initial_loss: {previous_loss}")
    sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, _, _ = get_beta_schedule(timesteps)

    model.train()
    for i in xrange(max_iter):
        data, attributes = next(iter(iterator))
        data = data.to(device)

        conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
        unconditional_condition_copy = torch.tensor(unconditional_condition, dtype=torch.float32).to(device).detach()
        # classifier-free-guidance-style training: with probability uncondition_rate,
        # replace the text condition with the unconditional embedding
        selected_conditions = [unconditional_condition_copy if random.random() < uncondition_rate else random.choice(
            conditions_of_one_sample) for conditions_of_one_sample in conditions]

        selected_conditions = torch.stack(selected_conditions).float().to(device)

        optimizer.zero_grad()

        t = torch.randint(0, timesteps, (BATCH_SIZE,), device=device).long()
        loss = conditional_p_losses(model, data, t, selected_conditions, loss_type="huber",
                                    sqrt_alphas_cumprod=sqrt_alphas_cumprod,
                                    sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod)

        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())
        step = int(optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy())

        if step % 100 == 0:
            print('%d step' % (step))

        if step % save_steps == 0:
            current_loss = np.mean(train_loss[-save_steps:])
            print(f"current_loss = {current_loss}")
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'models/{save_model_name}_UNet.pth')
            save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, lr, model_size, step, current_loss)

        if step % 20000 == 0:
            current_IS = get_inception_score(device, model, VAE, text_encoder, CLAP_tokenizer, timbre_encoder, n_IS_batches,
                                             positive_prompts="", negative_prompts="", CFG=1, sample_steps=20, task="STFT")
            print('current_IS: %.5f' % current_IS)
            current_loss = np.mean(train_loss[-save_steps:])

            writer.add_scalar(f"current_IS", current_IS, step)

            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'models/history/{save_model_name}_{step}_UNet.pth')
            save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, lr, model_size, step, current_loss)

    return model, optimizer

model/diffusion_components.py
ADDED
@@ -0,0 +1,351 @@
import torch.nn.functional as F
import torch
from torch import nn
from einops import rearrange
from inspect import isfunction
import math
from tqdm import tqdm


def exists(x):
    """Return True if x is not None."""
    return x is not None


def default(val, d):
    """Return val if it exists, otherwise the default d (calling it if it is a function)."""
    if exists(val):
        return val
    return d() if isfunction(d) else d


class Residual(nn.Module):
    """Skip connection"""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


def Upsample(dim):
    """Upsample layer, a transposed convolution layer with stride=2"""
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)


def Downsample(dim):
    """Downsample layer, a convolution layer with stride=2"""
    return nn.Conv2d(dim, dim, 4, 2, 1)


class SinusoidalPositionEmbeddings(nn.Module):
    """Return the sinusoidal embedding for an integer time step."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

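A self-contained restatement of the embedding above, handy for checking shapes: each timestep t maps to [sin(t·f_1), ..., sin(t·f_{d/2}), cos(t·f_1), ..., cos(t·f_{d/2})] with geometrically spaced frequencies:

```python
import math
import torch

dim = 8
t = torch.tensor([0., 10., 100.])
half = dim // 2
freqs = torch.exp(torch.arange(half) * -(math.log(10000) / (half - 1)))
emb = torch.cat([(t[:, None] * freqs[None, :]).sin(),
                 (t[:, None] * freqs[None, :]).cos()], dim=-1)
print(emb.shape)  # torch.Size([3, 8])
```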
class Block(nn.Module):
    """Stack of convolution, normalization, and non-linear activation"""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    """Stack of [conv + norm + act (+ scale&shift)], with positional embedding inserted <https://arxiv.org/abs/1512.03385>"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.block1(x)

        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            # Add the positional embedding to the intermediate layer (broadcast along the spatial dimensions)
            h = rearrange(time_emb, "b c -> b c 1 1") + h

        h = self.block2(h)
        return h + self.res_conv(x)


class ConvNextBlock(nn.Module):
    """Stack of [conv7x7 (+ condition(pos)) + norm + conv3x3 + act + norm + conv3x3 + res1x1], with positional embedding inserted"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)

        if exists(self.mlp) and exists(time_emb):
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, "b c -> b c 1 1")

        h = self.net(h)
        return h + self.res_conv(x)

class PreNorm(nn.Module):
    """Apply normalization before 'fn'"""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x, *args, **kwargs):
        x = self.norm(x)
        return self.fn(x, *args, **kwargs)


class ConditionalEmbedding(nn.Module):
    """Return an embedding for an integer label, or a linear projection for a text embedding"""

    def __init__(self, num_labels, embedding_dim, condition_type="instrument_family"):
        super(ConditionalEmbedding, self).__init__()
        if condition_type == "instrument_family":
            self.embedding = nn.Embedding(num_labels, embedding_dim)
        elif condition_type == "natural_language_prompt":
            self.embedding = nn.Linear(embedding_dim, embedding_dim, bias=True)
        else:
            raise NotImplementedError()

    def forward(self, labels):
        return self.embedding(labels)

class LinearCrossAttention(nn.Module):
    """Combination of efficient attention and cross attention."""

    def __init__(self, dim, heads=4, label_emb_dim=128, dim_head=32):
        super().__init__()
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))

        # embeddings for key and value
        self.label_key = nn.Linear(label_emb_dim, hidden_dim)
        self.label_value = nn.Linear(label_emb_dim, hidden_dim)

    def forward(self, x, label_embedding=None):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        if label_embedding is not None:
            label_k = self.label_key(label_embedding).view(b, self.heads, self.dim_head, 1)
            label_v = self.label_value(label_embedding).view(b, self.heads, self.dim_head, 1)

            k = torch.cat([k, label_k], dim=-1)
            v = torch.cat([v, label_v], dim=-1)

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)
        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)


def pad_to_match(encoder_tensor, decoder_tensor):
    """
    Pads the decoder_tensor to match the spatial dimensions of encoder_tensor.

    :param encoder_tensor: The feature map from the encoder.
    :param decoder_tensor: The feature map from the decoder that needs to be upsampled.
    :return: Padded decoder_tensor with the same spatial dimensions as encoder_tensor.
    """
    enc_shape = encoder_tensor.shape[2:]  # spatial dimensions are at index 2 and 3
    dec_shape = decoder_tensor.shape[2:]

    # assume enc_shape >= dec_shape
    delta_w = enc_shape[1] - dec_shape[1]
    delta_h = enc_shape[0] - dec_shape[0]

    # padding
    padding_left = delta_w // 2
    padding_right = delta_w - padding_left
    padding_top = delta_h // 2
    padding_bottom = delta_h - padding_top
    decoder_tensor_padded = F.pad(decoder_tensor, (padding_left, padding_right, padding_top, padding_bottom))

    return decoder_tensor_padded


def pad_and_concat(encoder_tensor, decoder_tensor):
    """
    Pads the decoder_tensor and concatenates it with the encoder_tensor along the channel dimension.

    :param encoder_tensor: The feature map from the encoder.
    :param decoder_tensor: The feature map from the decoder that needs to be concatenated with encoder_tensor.
    :return: Concatenated tensor.
    """
    # pad decoder_tensor
    decoder_tensor_padded = pad_to_match(encoder_tensor, decoder_tensor)
    # concat encoder_tensor and decoder_tensor_padded
    concatenated_tensor = torch.cat((encoder_tensor, decoder_tensor_padded), dim=1)
    return concatenated_tensor

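A shape sketch for the two helpers above: a (2, 16, 30, 30) decoder map is zero-padded to the encoder's 32x32 spatial size, then stacked along channels:

```python
import torch
from model.diffusion_components import pad_and_concat  # assumed import path

enc = torch.randn(2, 8, 32, 32)
dec = torch.randn(2, 16, 30, 30)
print(pad_and_concat(enc, dec).shape)  # torch.Size([2, 24, 32, 32])
```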
class LinearCrossAttentionAdd(nn.Module):
    def __init__(self, dim, heads=4, label_emb_dim=128, dim_head=32):
        super().__init__()
        self.dim = dim
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.label_emb_dim = label_emb_dim

        self.hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(self.dim, self.hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(self.hidden_dim, dim, 1), nn.GroupNorm(1, dim))

        # embeddings for key and query
        self.label_key = nn.Linear(label_emb_dim, self.hidden_dim)
        self.label_query = nn.Linear(label_emb_dim, self.hidden_dim)

    def forward(self, x, condition=None):
        b, c, h, w = x.shape

        qkv = self.to_qkv(x).chunk(3, dim=1)

        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        # if a condition is given, add its key/query projections to the originals
        if condition is not None:
            label_k = self.label_key(condition).view(b, self.heads, self.dim_head, 1)
            label_q = self.label_query(condition).view(b, self.heads, self.dim_head, 1)
            k = k + label_k
            q = q + label_q

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)
        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)


def get_beta_schedule(timesteps):
    betas = linear_beta_schedule(timesteps=timesteps)

    # define alphas
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, axis=0)
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
    sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

    # calculations for diffusion q(x_t | x_{t-1}) and others
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

    # calculations for posterior q(x_{t-1} | x_t, x_0)
    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
    return sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, posterior_variance, sqrt_recip_alphas


def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


# forward diffusion
def q_sample(x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

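A forward-diffusion sketch using the helpers above: as t grows, the coefficient on the clean signal, sqrt(alpha_bar_t), shrinks toward zero while the noise coefficient approaches one:

```python
import torch
from model.diffusion_components import get_beta_schedule, q_sample  # assumed path

sqrt_ac, sqrt_1m_ac, _, _ = get_beta_schedule(timesteps=1000)
x0 = torch.ones(1, 1, 4, 4)
for step in (0, 500, 999):
    t = torch.tensor([step])
    xt = q_sample(x0, t, sqrt_ac, sqrt_1m_ac)
    print(step, round(float(sqrt_ac[step]), 4))  # clean-signal coefficient
```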
model/multimodal_model.py
ADDED
@@ -0,0 +1,274 @@
import itertools
import json
import random

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from tools import create_key
from model.timbre_encoder_pretrain import get_timbre_encoder


class ProjectionLayer(nn.Module):
    """Single-layer linear projection with dropout, layer norm, and GELU activation"""

    def __init__(self, input_dim, output_dim, dropout):
        super(ProjectionLayer, self).__init__()
        self.projection = nn.Linear(input_dim, output_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(output_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x


class ProjectionHead(nn.Module):
    """Stack of 'ProjectionLayer'"""

    def __init__(self, embedding_dim, projection_dim, dropout, num_layers=2):
        super(ProjectionHead, self).__init__()
        self.layers = nn.ModuleList([ProjectionLayer(embedding_dim if i == 0 else projection_dim,
                                                     projection_dim,
                                                     dropout) for i in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class multi_modal_model(nn.Module):
|
| 51 |
+
"""The multi-modal model for contrastive learning"""
|
| 52 |
+
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
timbre_encoder,
|
| 56 |
+
text_encoder,
|
| 57 |
+
spectrogram_feature_dim,
|
| 58 |
+
text_feature_dim,
|
| 59 |
+
multi_modal_emb_dim,
|
| 60 |
+
temperature,
|
| 61 |
+
dropout,
|
| 62 |
+
num_projection_layers=1,
|
| 63 |
+
freeze_spectrogram_encoder=True,
|
| 64 |
+
freeze_text_encoder=True,
|
| 65 |
+
):
|
| 66 |
+
super().__init__()
|
| 67 |
+
self.timbre_encoder = timbre_encoder
|
| 68 |
+
self.text_encoder = text_encoder
|
| 69 |
+
|
| 70 |
+
self.multi_modal_emb_dim = multi_modal_emb_dim
|
| 71 |
+
|
| 72 |
+
self.text_projection = ProjectionHead(embedding_dim=text_feature_dim,
|
| 73 |
+
projection_dim=self.multi_modal_emb_dim, dropout=dropout,
|
| 74 |
+
num_layers=num_projection_layers)
|
| 75 |
+
|
| 76 |
+
self.spectrogram_projection = ProjectionHead(embedding_dim=spectrogram_feature_dim,
|
| 77 |
+
projection_dim=self.multi_modal_emb_dim, dropout=dropout,
|
| 78 |
+
num_layers=num_projection_layers)
|
| 79 |
+
|
| 80 |
+
self.temperature = temperature
|
| 81 |
+
|
| 82 |
+
# Make spectrogram_encoder parameters non-trainable
|
| 83 |
+
for param in self.timbre_encoder.parameters():
|
| 84 |
+
param.requires_grad = not freeze_spectrogram_encoder
|
| 85 |
+
|
| 86 |
+
# Make text_encoder parameters non-trainable
|
| 87 |
+
for param in self.text_encoder.parameters():
|
| 88 |
+
param.requires_grad = not freeze_text_encoder
|
| 89 |
+
|
| 90 |
+
def forward(self, spectrogram_batch, tokenized_text_batch):
|
| 91 |
+
# Getting Image and Text Embeddings (with same dimension)
|
| 92 |
+
spectrogram_features, _, _, _, _ = self.timbre_encoder(spectrogram_batch)
|
| 93 |
+
text_features = self.text_encoder.get_text_features(**tokenized_text_batch)
|
| 94 |
+
|
| 95 |
+
# Concat and apply projection
|
| 96 |
+
spectrogram_embeddings = self.spectrogram_projection(spectrogram_features)
|
| 97 |
+
text_embeddings = self.text_projection(text_features)
|
| 98 |
+
|
| 99 |
+
# Calculating the Loss
|
| 100 |
+
logits = (text_embeddings @ spectrogram_embeddings.T) / self.temperature
|
| 101 |
+
images_similarity = spectrogram_embeddings @ spectrogram_embeddings.T
|
| 102 |
+
texts_similarity = text_embeddings @ text_embeddings.T
|
| 103 |
+
targets = F.softmax(
|
| 104 |
+
(images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
|
| 105 |
+
)
|
| 106 |
+
texts_loss = cross_entropy(logits, targets, reduction='none')
|
| 107 |
+
images_loss = cross_entropy(logits.T, targets.T, reduction='none')
|
| 108 |
+
contrastive_loss = (images_loss + texts_loss) / 2.0 # shape: (batch_size)
|
| 109 |
+
contrastive_loss = contrastive_loss.mean()
|
| 110 |
+
|
| 111 |
+
return contrastive_loss
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def get_text_features(self, input_ids, attention_mask):
|
| 115 |
+
text_features = self.text_encoder.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
|
| 116 |
+
return self.text_projection(text_features)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def get_timbre_features(self, spectrogram_batch):
|
| 120 |
+
spectrogram_features, _, _, _, _ = self.timbre_encoder(spectrogram_batch)
|
| 121 |
+
return self.spectrogram_projection(spectrogram_features)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def cross_entropy(preds, targets, reduction='none'):
|
| 125 |
+
log_softmax = nn.LogSoftmax(dim=-1)
|
| 126 |
+
loss = (-targets * log_softmax(preds)).sum(1)
|
| 127 |
+
if reduction == "none":
|
| 128 |
+
return loss
|
| 129 |
+
elif reduction == "mean":
|
| 130 |
+
return loss.mean()
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def get_multi_modal_model(timbre_encoder, text_encoder, model_Config, load_pretrain=False, model_name=None, device="cpu"):
|
| 134 |
+
mmm = multi_modal_model(timbre_encoder, text_encoder, **model_Config)
|
| 135 |
+
print(f"Model intialized, size: {sum(p.numel() for p in mmm.parameters() if p.requires_grad)}")
|
| 136 |
+
mmm.to(device)
|
| 137 |
+
|
| 138 |
+
if load_pretrain:
|
| 139 |
+
print(f"Loading weights from models/{model_name}_MMM.pth")
|
| 140 |
+
checkpoint = torch.load(f'models/{model_name}_MMM.pth', map_location=device)
|
| 141 |
+
mmm.load_state_dict(checkpoint['model_state_dict'])
|
| 142 |
+
mmm.eval()
|
| 143 |
+
return mmm
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def train_epoch(text_tokenizer, model, train_loader, labels_mapping, optimizer, device):
|
| 147 |
+
(data, attributes) = next(iter(train_loader))
|
| 148 |
+
keys = [create_key(attribute) for attribute in attributes]
|
| 149 |
+
|
| 150 |
+
while(len(set(keys)) != len(keys)):
|
| 151 |
+
(data, attributes) = next(iter(train_loader))
|
| 152 |
+
keys = [create_key(attribute) for attribute in attributes]
|
| 153 |
+
|
| 154 |
+
data = data.to(device)
|
| 155 |
+
|
| 156 |
+
texts = [labels_mapping[create_key(attribute)] for attribute in attributes]
|
| 157 |
+
selected_texts = [l[random.randint(0, len(l) - 1)] for l in texts]
|
| 158 |
+
|
| 159 |
+
tokenized_text = text_tokenizer(selected_texts, padding=True, return_tensors="pt").to(device)
|
| 160 |
+
|
| 161 |
+
loss = model(data, tokenized_text)
|
| 162 |
+
optimizer.zero_grad()
|
| 163 |
+
loss.backward()
|
| 164 |
+
optimizer.step()
|
| 165 |
+
|
| 166 |
+
return loss.item()
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def valid_epoch(text_tokenizer, model, valid_loader, labels_mapping, device):
|
| 170 |
+
(data, attributes) = next(iter(valid_loader))
|
| 171 |
+
keys = [create_key(attribute) for attribute in attributes]
|
| 172 |
+
|
| 173 |
+
while(len(set(keys)) != len(keys)):
|
| 174 |
+
(data, attributes) = next(iter(valid_loader))
|
| 175 |
+
keys = [create_key(attribute) for attribute in attributes]
|
| 176 |
+
|
| 177 |
+
data = data.to(device)
|
| 178 |
+
texts = [labels_mapping[create_key(attribute)] for attribute in attributes]
|
| 179 |
+
selected_texts = [l[random.randint(0, len(l) - 1)] for l in texts]
|
| 180 |
+
|
| 181 |
+
tokenized_text = text_tokenizer(selected_texts, padding=True, return_tensors="pt").to(device)
|
| 182 |
+
|
| 183 |
+
loss = model(data, tokenized_text)
|
| 184 |
+
return loss.item()
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def train_multi_modal_model(device, training_dataloader, labels_mapping, text_tokenizer, text_encoder,
|
| 188 |
+
timbre_encoder_Config, MMM_config, MMM_training_config,
|
| 189 |
+
mmm_name, BATCH_SIZE, max_iter=0, load_pretrain=True,
|
| 190 |
+
timbre_encoder_name=None, init_loss=None, save_steps=2000):
|
| 191 |
+
|
| 192 |
+
def save_model_hyperparameter(model_name, MMM_config, MMM_training_config, BATCH_SIZE, model_size, current_iter,
|
| 193 |
+
current_loss):
|
| 194 |
+
|
| 195 |
+
model_hyperparameter = MMM_config
|
| 196 |
+
model_hyperparameter.update(MMM_training_config)
|
| 197 |
+
model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
|
| 198 |
+
model_hyperparameter["model_size"] = model_size
|
| 199 |
+
model_hyperparameter["current_iter"] = current_iter
|
| 200 |
+
model_hyperparameter["current_loss"] = current_loss
|
| 201 |
+
with open(f"models/hyperparameters/{model_name}_MMM.json", "w") as json_file:
|
| 202 |
+
json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)
|
| 203 |
+
|
| 204 |
+
timbreEncoder = get_timbre_encoder(timbre_encoder_Config, load_pretrain=True, model_name=timbre_encoder_name,
|
| 205 |
+
device=device)
|
| 206 |
+
|
| 207 |
+
mmm = multi_modal_model(timbreEncoder, text_encoder, **MMM_config).to(device)
|
| 208 |
+
|
| 209 |
+
print(f"spectrogram_encoder parameter: {sum(p.numel() for p in mmm.timbre_encoder.parameters())}")
|
| 210 |
+
print(f"text_encoder parameter: {sum(p.numel() for p in mmm.text_encoder.parameters())}")
|
| 211 |
+
print(f"spectrogram_projection parameter: {sum(p.numel() for p in mmm.spectrogram_projection.parameters())}")
|
| 212 |
+
print(f"text_projection parameter: {sum(p.numel() for p in mmm.text_projection.parameters())}")
|
| 213 |
+
total_parameters = sum(p.numel() for p in mmm.parameters())
|
| 214 |
+
trainable_parameters = sum(p.numel() for p in mmm.parameters() if p.requires_grad)
|
| 215 |
+
print(f"Trainable/Total parameter: {trainable_parameters}/{total_parameters}")
|
| 216 |
+
|
| 217 |
+
params = [
|
| 218 |
+
{"params": itertools.chain(
|
| 219 |
+
mmm.spectrogram_projection.parameters(),
|
| 220 |
+
mmm.text_projection.parameters(),
|
| 221 |
+
), "lr": MMM_training_config["head_lr"], "weight_decay": MMM_training_config["head_weight_decay"]},
|
| 222 |
+
]
|
| 223 |
+
if not MMM_config["freeze_text_encoder"]:
|
| 224 |
+
params.append({"params": mmm.text_encoder.parameters(), "lr": MMM_training_config["text_encoder_lr"],
|
| 225 |
+
"weight_decay": MMM_training_config["text_encoder_weight_decay"]})
|
| 226 |
+
if not MMM_config["freeze_spectrogram_encoder"]:
|
| 227 |
+
params.append({"params": mmm.timbre_encoder.parameters(), "lr": MMM_training_config["spectrogram_encoder_lr"],
|
| 228 |
+
"weight_decay": MMM_training_config["timbre_encoder_weight_decay"]})
|
| 229 |
+
|
| 230 |
+
optimizer = torch.optim.AdamW(params, weight_decay=0.)
|
| 231 |
+
|
| 232 |
+
if load_pretrain:
|
| 233 |
+
print(f"Loading weights from models/{mmm_name}_MMM.pt")
|
| 234 |
+
checkpoint = torch.load(f'models/{mmm_name}_MMM.pth')
|
| 235 |
+
mmm.load_state_dict(checkpoint['model_state_dict'])
|
| 236 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 237 |
+
else:
|
| 238 |
+
print("Model initialized.")
|
| 239 |
+
|
| 240 |
+
if max_iter == 0:
|
| 241 |
+
print("Return model directly.")
|
| 242 |
+
return mmm, optimizer
|
| 243 |
+
|
| 244 |
+
if init_loss is None:
|
| 245 |
+
previous_lowest_loss = valid_epoch(text_tokenizer, mmm, training_dataloader, labels_mapping, device)
|
| 246 |
+
else:
|
| 247 |
+
previous_lowest_loss = init_loss
|
| 248 |
+
print(f"Initial total loss: {previous_lowest_loss}")
|
| 249 |
+
|
| 250 |
+
train_loss_list = []
|
| 251 |
+
for i in range(max_iter):
|
| 252 |
+
|
| 253 |
+
mmm.train()
|
| 254 |
+
train_loss = train_epoch(text_tokenizer, mmm, training_dataloader, labels_mapping, optimizer, device)
|
| 255 |
+
train_loss_list.append(train_loss)
|
| 256 |
+
|
| 257 |
+
step = int(
|
| 258 |
+
optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy())
|
| 259 |
+
if (i + 1) % 100 == 0:
|
| 260 |
+
print('%d step' % (step))
|
| 261 |
+
|
| 262 |
+
if (i + 1) % save_steps == 0:
|
| 263 |
+
current_loss = np.mean(train_loss_list[-save_steps:])
|
| 264 |
+
print(f"train_total_loss: {current_loss}")
|
| 265 |
+
if current_loss < previous_lowest_loss:
|
| 266 |
+
previous_lowest_loss = current_loss
|
| 267 |
+
torch.save({
|
| 268 |
+
'model_state_dict': mmm.state_dict(),
|
| 269 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 270 |
+
}, f'models/{mmm_name}_MMM.pth')
|
| 271 |
+
save_model_hyperparameter(mmm_name, MMM_config, MMM_training_config, BATCH_SIZE, total_parameters, step,
|
| 272 |
+
current_loss)
|
| 273 |
+
|
| 274 |
+
return mmm, optimizer
|
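Note (editor): the loss above is the standard CLIP-style symmetric contrastive objective, which is also why train_epoch resamples batches until every label key is unique (duplicate keys would make supposed negatives identical). Below is a hedged sketch of using the two get_*_features helpers to score sounds against candidate prompts; the tokenizer object, the spectrogram tensor shape, and the prompt strings are placeholders, not values from this commit.

    import torch
    import torch.nn.functional as F

    # Assumed to exist: `mmm` from get_multi_modal_model(...) and a matching
    # `tokenizer`; `spectrogram_batch` is a placeholder input tensor.
    prompts = ["a bright organ note", "a plucked string"]
    tokens = tokenizer(prompts, padding=True, return_tensors="pt")
    spectrogram_batch = torch.randn(2, 4, 128, 64)  # placeholder shape

    with torch.no_grad():
        text_emb = mmm.get_text_features(tokens["input_ids"], tokens["attention_mask"])
        timbre_emb = mmm.get_timbre_features(spectrogram_batch)

    # Cosine similarity between each sound and each prompt in the shared space.
    scores = F.normalize(timbre_emb, dim=-1) @ F.normalize(text_emb, dim=-1).T
    best = scores.argmax(dim=-1)  # index of the best-matching prompt per sound
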
model/timbre_encoder_pretrain.py
ADDED
@@ -0,0 +1,220 @@

import json
import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tools import create_key


class TimbreEncoder(nn.Module):
    def __init__(self, input_dim, feature_dim, hidden_dim, num_instrument_classes, num_instrument_family_classes, num_velocity_classes, num_qualities, num_layers=1):
        super(TimbreEncoder, self).__init__()

        # Input layer
        self.input_layer = nn.Linear(input_dim, feature_dim)

        # LSTM layer
        self.lstm = nn.LSTM(feature_dim, hidden_dim, num_layers=num_layers, batch_first=True)

        # Fully connected layers for classification
        self.instrument_classifier_layer = nn.Linear(hidden_dim, num_instrument_classes)
        self.instrument_family_classifier_layer = nn.Linear(hidden_dim, num_instrument_family_classes)
        self.velocity_classifier_layer = nn.Linear(hidden_dim, num_velocity_classes)
        self.qualities_classifier_layer = nn.Linear(hidden_dim, num_qualities)

        # LogSoftmax for converting logits to log-probabilities (pairs with NLLLoss)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        # Merge the channel and frequency dimensions
        batch_size, _, _, seq_len = x.shape
        x = x.view(batch_size, -1, seq_len)  # [batch_size, input_dim, seq_len]

        # Forward propagate the LSTM; keep only the last hidden state as the feature
        x = x.permute(0, 2, 1)
        x = self.input_layer(x)
        feature, _ = self.lstm(x)
        feature = feature[:, -1, :]

        # Apply classification layers
        instrument_logits = self.instrument_classifier_layer(feature)
        instrument_family_logits = self.instrument_family_classifier_layer(feature)
        velocity_logits = self.velocity_classifier_layer(feature)
        qualities = self.qualities_classifier_layer(feature)

        # LogSoftmax for the multi-class heads, sigmoid for the multi-label head
        instrument_logits = self.softmax(instrument_logits)
        instrument_family_logits = self.softmax(instrument_family_logits)
        velocity_logits = self.softmax(velocity_logits)
        qualities = torch.sigmoid(qualities)

        return feature, instrument_logits, instrument_family_logits, velocity_logits, qualities


def get_multiclass_acc(outputs, ground_truth):
    _, predicted = torch.max(outputs.data, 1)
    total = ground_truth.size(0)
    correct = (predicted == ground_truth).sum().item()
    accuracy = 100 * correct / total
    return accuracy


def get_binary_accuracy(y_pred, y_true):
    predictions = (y_pred > 0.5).int()
    correct_predictions = (predictions == y_true).float()
    accuracy = correct_predictions.mean()
    return accuracy.item() * 100.0


def get_timbre_encoder(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    timbreEncoder = TimbreEncoder(**model_Config)
    print(f"Model initialized, size: {sum(p.numel() for p in timbreEncoder.parameters() if p.requires_grad)}")
    timbreEncoder.to(device)

    if load_pretrain:
        print(f"Loading weights from models/{model_name}_timbre_encoder.pth")
        checkpoint = torch.load(f'models/{model_name}_timbre_encoder.pth', map_location=device)
        timbreEncoder.load_state_dict(checkpoint['model_state_dict'])
        timbreEncoder.eval()
    return timbreEncoder


def evaluate_timbre_encoder(device, model, iterator, nll_Loss, bce_Loss, n_sample=100):
    model.to(device)
    model.eval()

    eva_loss = []
    for i in range(n_sample):
        representation, attributes = next(iter(iterator))

        instrument = torch.tensor([s["instrument"] for s in attributes], dtype=torch.long).to(device)
        instrument_family = torch.tensor([s["instrument_family"] for s in attributes], dtype=torch.long).to(device)
        velocity = torch.tensor([s["velocity"] for s in attributes], dtype=torch.long).to(device)
        qualities = torch.tensor([[int(char) for char in create_key(attribute)[-10:]] for attribute in attributes], dtype=torch.float32).to(device)

        _, instrument_logits, instrument_family_logits, velocity_logits, qualities_pred = model(representation.to(device))

        # compute loss
        instrument_loss = nll_Loss(instrument_logits, instrument)
        instrument_family_loss = nll_Loss(instrument_family_logits, instrument_family)
        velocity_loss = nll_Loss(velocity_logits, velocity)
        qualities_loss = bce_Loss(qualities_pred, qualities)

        loss = instrument_loss + instrument_family_loss + velocity_loss + qualities_loss

        eva_loss.append(loss.item())

    eva_loss = np.mean(eva_loss)
    return eva_loss


def train_timbre_encoder(device, model_name, timbre_encoder_Config, BATCH_SIZE, lr, max_iter, training_iterator, load_pretrain):
    def save_model_hyperparameter(model_name, timbre_encoder_Config, BATCH_SIZE, lr, model_size, current_iter,
                                  current_loss):
        model_hyperparameter = timbre_encoder_Config
        model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
        model_hyperparameter["lr"] = lr
        model_hyperparameter["model_size"] = model_size
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["current_loss"] = current_loss
        with open(f"models/hyperparameters/{model_name}_timbre_encoder.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    model = TimbreEncoder(**timbre_encoder_Config)
    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model size: {model_size}")
    model.to(device)
    nll_Loss = torch.nn.NLLLoss()
    bce_Loss = torch.nn.BCELoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=False)

    if load_pretrain:
        print(f"Loading weights from models/{model_name}_timbre_encoder.pth")
        checkpoint = torch.load(f'models/{model_name}_timbre_encoder.pth')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Model initialized.")
    if max_iter == 0:
        print("Return model directly.")
        return model, model

    train_loss, training_instrument_acc, training_instrument_family_acc, training_velocity_acc, training_qualities_acc = [], [], [], [], []
    writer = SummaryWriter(f'runs/{model_name}_timbre_encoder')
    current_best_model = model
    previous_lowest_loss = 100.0
    print(f"initial loss: {previous_lowest_loss}")

    for i in range(max_iter):
        model.train()

        representation, attributes = next(iter(training_iterator))

        instrument = torch.tensor([s["instrument"] for s in attributes], dtype=torch.long).to(device)
        instrument_family = torch.tensor([s["instrument_family"] for s in attributes], dtype=torch.long).to(device)
        velocity = torch.tensor([s["velocity"] for s in attributes], dtype=torch.long).to(device)
        qualities = torch.tensor([[int(char) for char in create_key(attribute)[-10:]] for attribute in attributes], dtype=torch.float32).to(device)

        optimizer.zero_grad()

        _, instrument_logits, instrument_family_logits, velocity_logits, qualities_pred = model(representation.to(device))

        # compute loss
        instrument_loss = nll_Loss(instrument_logits, instrument)
        instrument_family_loss = nll_Loss(instrument_family_logits, instrument_family)
        velocity_loss = nll_Loss(velocity_logits, velocity)
        qualities_loss = bce_Loss(qualities_pred, qualities)

        loss = instrument_loss + instrument_family_loss + velocity_loss + qualities_loss

        loss.backward()
        optimizer.step()
        instrument_acc = get_multiclass_acc(instrument_logits, instrument)
        instrument_family_acc = get_multiclass_acc(instrument_family_logits, instrument_family)
        velocity_acc = get_multiclass_acc(velocity_logits, velocity)
        qualities_acc = get_binary_accuracy(qualities_pred, qualities)

        train_loss.append(loss.item())
        training_instrument_acc.append(instrument_acc)
        training_instrument_family_acc.append(instrument_family_acc)
        training_velocity_acc.append(velocity_acc)
        training_qualities_acc.append(qualities_acc)
        step = int(optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].numpy())

        if (i + 1) % 100 == 0:
            print('%d step' % (step))

        save_steps = 500
        if (i + 1) % save_steps == 0:
            current_loss = np.mean(train_loss[-save_steps:])
            current_instrument_acc = np.mean(training_instrument_acc[-save_steps:])
            current_instrument_family_acc = np.mean(training_instrument_family_acc[-save_steps:])
            current_velocity_acc = np.mean(training_velocity_acc[-save_steps:])
            current_qualities_acc = np.mean(training_qualities_acc[-save_steps:])
            print('train_loss: %.5f' % current_loss)
            print('current_instrument_acc: %.5f' % current_instrument_acc)
            print('current_instrument_family_acc: %.5f' % current_instrument_family_acc)
            print('current_velocity_acc: %.5f' % current_velocity_acc)
            print('current_qualities_acc: %.5f' % current_qualities_acc)
            writer.add_scalar(f"train_loss", current_loss, step)
            writer.add_scalar(f"current_instrument_acc", current_instrument_acc, step)
            writer.add_scalar(f"current_instrument_family_acc", current_instrument_family_acc, step)
            writer.add_scalar(f"current_velocity_acc", current_velocity_acc, step)
            writer.add_scalar(f"current_qualities_acc", current_qualities_acc, step)

            if current_loss < previous_lowest_loss:
                previous_lowest_loss = current_loss
                current_best_model = model
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, f'models/{model_name}_timbre_encoder.pth')
                save_model_hyperparameter(model_name, timbre_encoder_Config, BATCH_SIZE, lr, model_size, step,
                                          current_loss)

    return model, current_best_model

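Note (editor): TimbreEncoder flattens each spectrogram to a [batch, time, features] sequence, runs an LSTM over it, and keeps only the last hidden state as the timbre feature. A quick shape check under assumed configuration values (every dimension below is illustrative, not read from this commit; channels x freq must equal input_dim):

    import torch

    encoder = TimbreEncoder(input_dim=512, feature_dim=256, hidden_dim=256,
                            num_instrument_classes=1006, num_instrument_family_classes=11,
                            num_velocity_classes=128, num_qualities=10)

    x = torch.randn(8, 4, 128, 64)  # [batch, channels, freq, time]; 4 * 128 == input_dim
    feature, inst, family, vel, qual = encoder(x)
    print(feature.shape)  # torch.Size([8, 256]) -- last LSTM hidden state
    print(qual.shape)     # torch.Size([8, 10])  -- per-quality probabilities
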
tools.py
ADDED
@@ -0,0 +1,344 @@

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import librosa
from scipy.io.wavfile import write
import torch

k = 1e-16

def np_log10(x):
    """Safe log function with base 10."""
    numerator = np.log(x + 1e-16)
    denominator = np.log(10)
    return numerator / denominator


def sigmoid(x):
    """Sigmoid function."""
    s = 1 / (1 + np.exp(-x))
    return s


def inv_sigmoid(s):
    """Safe inverse sigmoid function."""
    x = np.log((s / (1 - s)) + 1e-16)
    return x


def spc_to_VAE_input(spc):
    """Restrict value range from [0, infinity) to [0, 1]. (deprecated)"""
    return spc / (1 + spc)


def VAE_out_put_to_spc(o):
    """Inverse transform of function 'spc_to_VAE_input'. (deprecated)"""
    return o / (1 - o + k)


def np_power_to_db(S, amin=1e-16, top_db=80.0):
    """Helper method for numpy data scaling. (deprecated)"""
    ref = S.max()

    log_spec = 10.0 * np_log10(np.maximum(amin, S))
    log_spec -= 10.0 * np_log10(np.maximum(amin, ref))

    log_spec = np.maximum(log_spec, log_spec.max() - top_db)

    return log_spec


def show_spc(spc):
    """Show a spectrogram. (deprecated)"""
    s = np.shape(spc)
    spc = np.reshape(spc, (s[0], s[1]))
    magnitude_spectrum = np.abs(spc)
    log_spectrum = np_power_to_db(magnitude_spectrum)
    plt.imshow(np.flipud(log_spectrum))
    plt.show()


def save_results(spectrogram, spectrogram_image_path, waveform_path):
    """Save the input 'spectrogram' and its waveform (reconstructed by Griffin-Lim)
    to the paths given by 'spectrogram_image_path' and 'waveform_path'."""
    magnitude_spectrum = np.abs(spectrogram)
    log_spc = np_power_to_db(magnitude_spectrum)
    log_spc = np.reshape(log_spc, (512, 256))
    matplotlib.pyplot.imsave(spectrogram_image_path, log_spc, vmin=-100, vmax=0,
                             origin='lower')

    # save waveform
    abs_spec = np.zeros((513, 256))
    abs_spec[:512, :] = abs_spec[:512, :] + np.sqrt(np.reshape(spectrogram, (512, 256)))
    rec_signal = librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=1024)
    write(waveform_path, 16000, rec_signal)


def plot_log_spectrogram(signal: np.ndarray,
                         path: str,
                         n_fft=2048,
                         frame_length=1024,
                         frame_step=256):
    """Compute and save a log-power spectrogram image of 'signal' to 'path'."""
    stft = librosa.stft(signal, n_fft=n_fft, hop_length=frame_step, win_length=frame_length)
    amp = np.square(np.real(stft)) + np.square(np.imag(stft))
    magnitude_spectrum = np.abs(amp)
    log_mel = np_power_to_db(magnitude_spectrum)
    matplotlib.pyplot.imsave(path, log_mel, vmin=-100, vmax=0, origin='lower')


def visualize_feature_maps(device, model, inputs, channel_indices=[0, 3]):
    """
    Visualize feature maps before and after quantization for a given input.

    Parameters:
    - model: Your VQ-VAE model.
    - inputs: A batch of input data.
    - channel_indices: Indices of feature map channels to visualize.
    """
    model.eval()
    inputs = inputs.to(device)

    with torch.no_grad():
        z_e = model._encoder(inputs)
        z_q, loss, (perplexity, min_encodings, min_encoding_indices) = model._vq_vae(z_e)

    # Assuming inputs have shape [batch_size, channels, height, width]
    batch_size = z_e.size(0)

    for idx in range(batch_size):
        fig, axs = plt.subplots(1, len(channel_indices) * 2, figsize=(15, 5))

        for i, channel_idx in enumerate(channel_indices):
            # Plot encoder output
            axs[2 * i].imshow(z_e[idx][channel_idx].cpu().numpy(), cmap='viridis')
            axs[2 * i].set_title(f"Encoder Output - Channel {channel_idx}")

            # Plot quantized output
            axs[2 * i + 1].imshow(z_q[idx][channel_idx].cpu().numpy(), cmap='viridis')
            axs[2 * i + 1].set_title(f"Quantized Output - Channel {channel_idx}")

        plt.show()


def adjust_audio_length(audio, desired_length, original_sample_rate, target_sample_rate):
    """
    Adjust the audio length to the desired length and resample to the target sample rate.

    Parameters:
    - audio (np.array): The input audio signal
    - desired_length (int): The desired length of the output audio
    - original_sample_rate (int): The original sample rate of the audio
    - target_sample_rate (int): The target sample rate for the output audio

    Returns:
    - np.array: The adjusted and resampled audio
    """

    if not (original_sample_rate == target_sample_rate):
        audio = librosa.core.resample(audio, orig_sr=original_sample_rate, target_sr=target_sample_rate)

    if len(audio) > desired_length:
        return audio[:desired_length]
    elif len(audio) < desired_length:
        padded_audio = np.zeros(desired_length)
        padded_audio[:len(audio)] = audio
        return padded_audio
    else:
        return audio


def safe_int(s, default=0):
    try:
        return int(s)
    except ValueError:
        return default


def pad_spectrogram(D):
    """Resize spectrogram to (512, 256). (deprecated)"""
    D = D[1:, :]

    padding_length = 256 - D.shape[1]
    D_padded = np.pad(D, ((0, 0), (0, padding_length)), 'constant')
    return D_padded


def pad_STFT(D, time_resolution=256):
    """Resize a spectral matrix by dropping the DC row and padding along the time axis."""
    D = D[1:, :]

    if time_resolution is None:
        return D

    padding_length = time_resolution - D.shape[1]
    if padding_length > 0:
        D_padded = np.pad(D, ((0, 0), (0, padding_length)), 'constant')
        return D_padded
    else:
        return D


def depad_STFT(D_padded):
    """Inverse function of 'pad_STFT'."""
    zero_row = np.zeros((1, D_padded.shape[1]))
    D_restored = np.concatenate([zero_row, D_padded], axis=0)
    return D_restored


def nnData2Audio(spectrogram_batch, resolution=(512, 256), squared=False):
    """Transform a batch of numpy spectrograms into waveform signals."""
    # Todo: remove resolution hard-coding
    frequency_resolution, time_resolution = resolution

    if isinstance(spectrogram_batch, torch.Tensor):
        spectrogram_batch = spectrogram_batch.to("cpu").detach().numpy()

    origin_signals = []
    for spectrogram in spectrogram_batch:
        spc = VAE_out_put_to_spc(spectrogram)

        # get_audio
        abs_spec = np.zeros((frequency_resolution + 1, time_resolution))

        if squared:
            abs_spec[1:, :] = abs_spec[1:, :] + np.sqrt(np.reshape(spc, (frequency_resolution, time_resolution)))
        else:
            abs_spec[1:, :] = abs_spec[1:, :] + np.reshape(spc, (frequency_resolution, time_resolution))

        origin_signal = librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=1024)
        origin_signals.append(origin_signal)

    return origin_signals


def amp_to_audio(amp, n_iter=50):
    """Reconstruct a waveform from a magnitude spectrogram with the Griffin-Lim algorithm."""
    y_reconstructed = librosa.griffinlim(amp, n_iter=n_iter, hop_length=256, win_length=1024)
    return y_reconstructed


def rescale(amp, method="log1p"):
    """Rescale function."""
    if method == "log1p":
        return np.log1p(amp)
    elif method == "NormalizedLogisticCompression":
        return amp / (1.0 + amp)
    else:
        raise NotImplementedError()


def unrescale(scaled_amp, method="NormalizedLogisticCompression"):
    """Inverse function of 'rescale'."""
    if method == "log1p":
        return np.expm1(scaled_amp)
    elif method == "NormalizedLogisticCompression":
        return scaled_amp / (1.0 - scaled_amp + 1e-10)
    else:
        raise NotImplementedError()


def create_key(attributes):
    """Create a unique key for each multi-label."""
    qualities_str = ''.join(map(str, attributes["qualities"]))
    instrument_source_str = attributes["instrument_source_str"]
    instrument_family = attributes["instrument_family_str"]
    key = f"{instrument_source_str}_{instrument_family}_{qualities_str}"
    return key


def merge_dictionaries(dicts):
    """Merge dictionaries, summing values under shared keys."""
    merged_dict = {}
    for dictionary in dicts:
        for key, value in dictionary.items():
            if key in merged_dict:
                merged_dict[key] += value
            else:
                merged_dict[key] = value
    return merged_dict


def adsr_envelope(signal, sample_rate, duration, attack_time, decay_time, sustain_level, release_time):
    """
    Apply an ADSR envelope to an audio signal.

    :param signal: The original audio signal (numpy array).
    :param sample_rate: The sample rate of the audio signal.
    :param attack_time: Attack time in seconds.
    :param decay_time: Decay time in seconds.
    :param sustain_level: Sustain level as a fraction of the peak (0 to 1).
    :param release_time: Release time in seconds.
    :return: The audio signal with the ADSR envelope applied.
    """
    # Calculate the number of samples for each ADSR phase
    duration_samples = int(duration * sample_rate)

    # assert (duration_samples + int(1.0 * sample_rate)) <= len(signal), "(duration_samples + sample_rate) > len(signal)"
    assert release_time <= 1.0, "release_time > 1.0"

    attack_samples = int(attack_time * sample_rate)
    decay_samples = int(decay_time * sample_rate)
    release_samples = int(release_time * sample_rate)
    sustain_samples = max(0, duration_samples - attack_samples - decay_samples)

    # Create ADSR envelope
    attack_env = np.linspace(0, 1, attack_samples)
    decay_env = np.linspace(1, sustain_level, decay_samples)
    sustain_env = np.full(sustain_samples, sustain_level)
    release_env = np.linspace(sustain_level, 0, release_samples)
    release_env_expand = np.zeros(int(1.0 * sample_rate))
    release_env_expand[:len(release_env)] = release_env

    # Concatenate all phases to create the complete envelope
    envelope = np.concatenate([attack_env, decay_env, sustain_env, release_env_expand])

    # Apply the envelope to the signal
    if len(envelope) <= len(signal):
        applied_signal = signal[:len(envelope)] * envelope
    else:
        signal_expanded = np.zeros(len(envelope))
        signal_expanded[:len(signal)] = signal
        applied_signal = signal_expanded * envelope

    return applied_signal


def rms_normalize(audio, target_rms=0.1):
    """Normalize the RMS value."""
    current_rms = np.sqrt(np.mean(audio ** 2))
    scaling_factor = target_rms / current_rms
    normalized_audio = audio * scaling_factor
    return normalized_audio


def encode_stft(D):
    """'STFT+' function that transforms a complex spectral matrix into a real-valued spectral representation."""
    magnitude = np.abs(D)
    phase = np.angle(D)

    log_magnitude = np.log1p(magnitude)

    cos_phase = np.cos(phase)
    sin_phase = np.sin(phase)

    encoded_D = np.stack([log_magnitude, cos_phase, sin_phase], axis=0)
    return encoded_D


def decode_stft(encoded_D):
    """'ISTFT+' function that reconstructs the complex spectral matrix from a spectral representation."""
    log_magnitude = encoded_D[0, ...]
    cos_phase = encoded_D[1, ...]
    sin_phase = encoded_D[2, ...]

    magnitude = np.expm1(log_magnitude)

    phase = np.arctan2(sin_phase, cos_phase)

    D = magnitude * (np.cos(phase) + 1j * np.sin(phase))
    return D

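Note (editor): encode_stft/decode_stft form a near-exact round trip: log1p compresses magnitude, and the (cos, sin) pair encodes phase without wrap-around discontinuities, so arctan2 recovers it exactly. A self-contained sanity check using only numpy and librosa as imported above (the synthetic test tone is a stand-in for real audio):

    import numpy as np
    import librosa

    sr = 16000
    y = np.sin(2 * np.pi * 440 * np.arange(sr * 2) / sr).astype(np.float32)  # 2 s test tone

    D = librosa.stft(y, n_fft=1024, hop_length=256, win_length=1024)
    D_padded = pad_STFT(D)            # drop the DC row, pad the time axis to 256 frames
    encoded = encode_stft(D_padded)   # (3, 512, 256): log1p magnitude, cos(phase), sin(phase)
    D_rec = decode_stft(encoded)

    print(np.max(np.abs(D_rec - D_padded)))  # near-exact round trip (~1e-6 in float32)
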
webUI/__pycache__/app.cpython-310.pyc
ADDED
Binary file (10.4 kB).

webUI/deprecated/interpolationWithCondition.py
ADDED
@@ -0,0 +1,178 @@

import gradio as gr
import numpy as np
import torch

from model.DiffSynthSampler import DiffSynthSampler
from tools import safe_int
from webUI.natural_language_guided_STFT.utils import encodeBatch2GradioOutput, latent_representation_to_Gradio_image


def get_interpolation_with_condition_module(gradioWebUI, interpolation_with_text_state):
    # Load configurations
    uNet = gradioWebUI.uNet
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels
    timesteps = gradioWebUI.timesteps
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

    def diffusion_random_sample(text2sound_prompts_1, text2sound_prompts_2, text2sound_negative_prompts, text2sound_batchsize,
                                text2sound_duration,
                                text2sound_guidance_scale, text2sound_sampler,
                                text2sound_sample_steps, text2sound_seed,
                                interpolation_with_text_dict):
        text2sound_sample_steps = int(text2sound_sample_steps)
        text2sound_seed = safe_int(text2sound_seed, 12345678)
        # Todo: take care of text2sound_time_resolution/width
        width = int(time_resolution * ((text2sound_duration + 1) / 4) / VAE_scale)
        text2sound_batchsize = int(text2sound_batchsize)

        text2sound_embedding_1 = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts_1], padding=True, return_tensors="pt"))[0].to(device)
        text2sound_embedding_2 = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts_2], padding=True, return_tensors="pt"))[0].to(device)

        CFG = int(text2sound_guidance_scale)

        mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
        unconditional_condition = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[0]
        mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device))

        mySampler.respace(list(np.linspace(0, timesteps - 1, text2sound_sample_steps, dtype=np.int32)))

        condition = torch.linspace(1, 0, steps=text2sound_batchsize).unsqueeze(1).to(device) * text2sound_embedding_1 + \
                    torch.linspace(0, 1, steps=text2sound_batchsize).unsqueeze(1).to(device) * text2sound_embedding_2

        # Todo: move this code
        torch.manual_seed(text2sound_seed)
        initial_noise = torch.randn(text2sound_batchsize, channels, height, width).to(device)

        latent_representations, initial_noise = \
            mySampler.sample(model=uNet, shape=(text2sound_batchsize, channels, height, width), seed=text2sound_seed,
                             return_tensor=True, condition=condition, sampler=text2sound_sampler, initial_noise=initial_noise)

        latent_representations = latent_representations[-1]

        interpolation_with_text_dict["latent_representations"] = latent_representations

        latent_representation_gradio_images = []
        quantized_latent_representation_gradio_images = []
        new_sound_spectrogram_gradio_images = []
        new_sound_rec_signals_gradio = []

        quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations)
        # Todo: remove hard-coding
        flipped_log_spectrums, rec_signals = encodeBatch2GradioOutput(VAE_decoder, quantized_latent_representations,
                                                                      resolution=(512, width * VAE_scale), centralized=False,
                                                                      squared=squared)

        for i in range(text2sound_batchsize):
            latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i]))
            quantized_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(quantized_latent_representations[i]))
            new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i])
            new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i]))

        def concatenate_arrays(arrays_list):
            return np.concatenate(arrays_list, axis=1)

        concatenated_spectrogram_gradio_image = concatenate_arrays(new_sound_spectrogram_gradio_images)

        interpolation_with_text_dict["latent_representation_gradio_images"] = latent_representation_gradio_images
        interpolation_with_text_dict["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images
        interpolation_with_text_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
        interpolation_with_text_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio

        return {text2sound_latent_representation_image: interpolation_with_text_dict["latent_representation_gradio_images"][0],
                text2sound_quantized_latent_representation_image:
                    interpolation_with_text_dict["quantized_latent_representation_gradio_images"][0],
                text2sound_sampled_concatenated_spectrogram_image: concatenated_spectrogram_gradio_image,
                text2sound_sampled_spectrogram_image: interpolation_with_text_dict["new_sound_spectrogram_gradio_images"][0],
                text2sound_sampled_audio: interpolation_with_text_dict["new_sound_rec_signals_gradio"][0],
                text2sound_seed_textbox: text2sound_seed,
                interpolation_with_text_state: interpolation_with_text_dict,
                text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1,
                                                          visible=True,
                                                          label="Sample index.",
                                                          info="Swipe to view other samples")}

    def show_random_sample(sample_index, text2sound_dict):
        sample_index = int(sample_index)
        return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][
                    sample_index],
                text2sound_quantized_latent_representation_image:
                    text2sound_dict["quantized_latent_representation_gradio_images"][sample_index],
                text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][sample_index],
                text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]}

    with gr.Tab("InterpolationCond."):
        gr.Markdown("Use interpolation to generate a gradient sound sequence.")
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):
                text2sound_prompts_1_textbox = gr.Textbox(label="Positive prompt 1", lines=2, value="organ")
                text2sound_prompts_2_textbox = gr.Textbox(label="Positive prompt 2", lines=2, value="string")
                text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")

            with gr.Column(scale=1):
                text2sound_sampling_button = gr.Button(variant="primary",
                                                       value="Generate a batch of samples and show "
                                                             "the first one",
                                                       scale=1)
                text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
                                                           label="Sample index",
                                                           info="Swipe to view other samples")
        with gr.Row(variant="panel"):
            with gr.Column(scale=1, variant="panel"):
                text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
                text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
                text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider(cpu_batchsize=3)
                text2sound_duration_slider = gradioWebUI.get_duration_slider()
                text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
                text2sound_seed_textbox = gradioWebUI.get_seed_textbox()

            with gr.Column(scale=1):
                with gr.Row(variant="panel"):
                    text2sound_sampled_concatenated_spectrogram_image = gr.Image(label="Interpolations", type="numpy",
                                                                                 height=420, scale=8)
                    text2sound_sampled_spectrogram_image = gr.Image(label="Selected spectrogram", type="numpy",
                                                                    height=420, scale=1)
                text2sound_sampled_audio = gr.Audio(type="numpy", label="Play")

        with gr.Row(variant="panel"):
            text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy",
                                                              height=200, width=100)
            text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
                                                                        type="numpy", height=200, width=100)

        text2sound_sampling_button.click(diffusion_random_sample,
                                         inputs=[text2sound_prompts_1_textbox,
                                                 text2sound_prompts_2_textbox,
                                                 text2sound_negative_prompts_textbox,
                                                 text2sound_batchsize_slider,
                                                 text2sound_duration_slider,
                                                 text2sound_guidance_scale_slider, text2sound_sampler_radio,
                                                 text2sound_sample_steps_slider,
                                                 text2sound_seed_textbox,
                                                 interpolation_with_text_state],
                                         outputs=[text2sound_latent_representation_image,
                                                  text2sound_quantized_latent_representation_image,
                                                  text2sound_sampled_concatenated_spectrogram_image,
                                                  text2sound_sampled_spectrogram_image,
                                                  text2sound_sampled_audio,
                                                  text2sound_seed_textbox,
                                                  interpolation_with_text_state,
                                                  text2sound_sample_index_slider])
        text2sound_sample_index_slider.change(show_random_sample,
                                              inputs=[text2sound_sample_index_slider, interpolation_with_text_state],
                                              outputs=[text2sound_latent_representation_image,
                                                       text2sound_quantized_latent_representation_image,
                                                       text2sound_sampled_spectrogram_image,
                                                       text2sound_sampled_audio])

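Note (editor): the core of this module is the `condition` line, which fades one CLAP text embedding into the other across the batch while the initial noise is shared, so the batch becomes a timbral gradient from prompt 1 to prompt 2. The interpolation itself is just a per-sample convex combination; a minimal standalone illustration with placeholder shapes:

    import torch

    batchsize, emb_dim = 8, 512              # illustrative sizes
    emb_1 = torch.randn(emb_dim)             # stands in for the CLAP embedding of prompt 1
    emb_2 = torch.randn(emb_dim)             # stands in for the CLAP embedding of prompt 2

    # Weights run 1 -> 0 for emb_1 and 0 -> 1 for emb_2, one row per sample.
    w = torch.linspace(1, 0, steps=batchsize).unsqueeze(1)   # shape: (batchsize, 1)
    condition = w * emb_1 + (1 - w) * emb_2                  # shape: (batchsize, emb_dim)

    assert torch.allclose(condition[0], emb_1)    # first sample is pure prompt 1
    assert torch.allclose(condition[-1], emb_2)   # last sample is pure prompt 2
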
webUI/deprecated/interpolationWithXT.py
ADDED
@@ -0,0 +1,173 @@

import gradio as gr
import numpy as np
import torch

from model.DiffSynthSampler import DiffSynthSampler
from tools import safe_int
from webUI.natural_language_guided.utils import encodeBatch2GradioOutput, latent_representation_to_Gradio_image


def get_interpolation_with_xT_module(gradioWebUI, interpolation_with_text_state):
    # Load configurations
    uNet = gradioWebUI.uNet
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels
    timesteps = gradioWebUI.timesteps
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

    def diffusion_random_sample(text2sound_prompts, text2sound_negative_prompts, text2sound_batchsize,
                                text2sound_duration,
                                text2sound_noise_variance, text2sound_guidance_scale, text2sound_sampler,
                                text2sound_sample_steps, text2sound_seed,
                                interpolation_with_text_dict):
        text2sound_sample_steps = int(text2sound_sample_steps)
        text2sound_seed = safe_int(text2sound_seed, 12345678)
        # Todo: take care of text2sound_time_resolution/width
        width = int(time_resolution * ((text2sound_duration + 1) / 4) / VAE_scale)
        text2sound_batchsize = int(text2sound_batchsize)

        text2sound_embedding = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(device)

        CFG = int(text2sound_guidance_scale)

        mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
        unconditional_condition = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[0]
        mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device))

        mySampler.respace(list(np.linspace(0, timesteps - 1, text2sound_sample_steps, dtype=np.int32)))

        condition = text2sound_embedding.repeat(text2sound_batchsize, 1)
        latent_representations, initial_noise = \
            mySampler.interpolate(model=uNet, shape=(text2sound_batchsize, channels, height, width),
                                  seed=text2sound_seed,
                                  variance=text2sound_noise_variance,
                                  return_tensor=True, condition=condition, sampler=text2sound_sampler)

        latent_representations = latent_representations[-1]

        interpolation_with_text_dict["latent_representations"] = latent_representations

        latent_representation_gradio_images = []
        quantized_latent_representation_gradio_images = []
        new_sound_spectrogram_gradio_images = []
        new_sound_rec_signals_gradio = []

        quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations)
        # Todo: remove hard-coding
        flipped_log_spectrums, rec_signals = encodeBatch2GradioOutput(VAE_decoder, quantized_latent_representations,
                                                                      resolution=(512, width * VAE_scale), centralized=False,
                                                                      squared=squared)

        for i in range(text2sound_batchsize):
            latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i]))
            quantized_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(quantized_latent_representations[i]))
            new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i])
            new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i]))

        def concatenate_arrays(arrays_list):
            return np.concatenate(arrays_list, axis=1)

        concatenated_spectrogram_gradio_image = concatenate_arrays(new_sound_spectrogram_gradio_images)

        interpolation_with_text_dict["latent_representation_gradio_images"] = latent_representation_gradio_images
        interpolation_with_text_dict["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images
        interpolation_with_text_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
        interpolation_with_text_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio

        return {text2sound_latent_representation_image: interpolation_with_text_dict["latent_representation_gradio_images"][0],
                text2sound_quantized_latent_representation_image:
                    interpolation_with_text_dict["quantized_latent_representation_gradio_images"][0],
                text2sound_sampled_concatenated_spectrogram_image: concatenated_spectrogram_gradio_image,
                text2sound_sampled_spectrogram_image: interpolation_with_text_dict["new_sound_spectrogram_gradio_images"][0],
                text2sound_sampled_audio: interpolation_with_text_dict["new_sound_rec_signals_gradio"][0],
                text2sound_seed_textbox: text2sound_seed,
                interpolation_with_text_state: interpolation_with_text_dict,
                text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1,
                                                          visible=True,
                                                          label="Sample index.",
                                                          info="Swipe to view other samples")}

    def show_random_sample(sample_index, text2sound_dict):
        sample_index = int(sample_index)
        return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][
                    sample_index],
                text2sound_quantized_latent_representation_image:
                    text2sound_dict["quantized_latent_representation_gradio_images"][sample_index],
                text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][sample_index],
                text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]}

    with gr.Tab("InterpolationXT"):
        gr.Markdown("Use interpolation to generate a gradient sound sequence.")
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):
                text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
                text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")

            with gr.Column(scale=1):
                text2sound_sampling_button = gr.Button(variant="primary",
                                                       value="Generate a batch of samples and show "
                                                             "the first one",
                                                       scale=1)
                text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
                                                           label="Sample index",
                                                           info="Swipe to view other samples")
        with gr.Row(variant="panel"):
            with gr.Column(scale=1, variant="panel"):
|
| 127 |
+
text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
|
| 128 |
+
text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
|
| 129 |
+
text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider(cpu_batchsize=3)
|
| 130 |
+
text2sound_duration_slider = gradioWebUI.get_duration_slider()
|
| 131 |
+
text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
|
| 132 |
+
text2sound_seed_textbox = gradioWebUI.get_seed_textbox()
|
| 133 |
+
text2sound_noise_variance_slider = gr.Slider(minimum=0., maximum=5., value=1., step=0.01,
|
| 134 |
+
label="Noise variance",
|
| 135 |
+
info="The larger this value, the more diversity the interpolation has.")
|
| 136 |
+
|
| 137 |
+
with gr.Column(scale=1):
|
| 138 |
+
with gr.Row(variant="panel"):
|
| 139 |
+
text2sound_sampled_concatenated_spectrogram_image = gr.Image(label="Interpolations", type="numpy",
|
| 140 |
+
height=420, scale=8)
|
| 141 |
+
text2sound_sampled_spectrogram_image = gr.Image(label="Selected spectrogram", type="numpy",
|
| 142 |
+
height=420, scale=1)
|
| 143 |
+
text2sound_sampled_audio = gr.Audio(type="numpy", label="Play")
|
| 144 |
+
|
| 145 |
+
with gr.Row(variant="panel"):
|
| 146 |
+
text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy",
|
| 147 |
+
height=200, width=100)
|
| 148 |
+
text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
|
| 149 |
+
type="numpy", height=200, width=100)
|
| 150 |
+
|
| 151 |
+
text2sound_sampling_button.click(diffusion_random_sample,
|
| 152 |
+
inputs=[text2sound_prompts_textbox, text2sound_negative_prompts_textbox,
|
| 153 |
+
text2sound_batchsize_slider,
|
| 154 |
+
text2sound_duration_slider,
|
| 155 |
+
text2sound_noise_variance_slider,
|
| 156 |
+
text2sound_guidance_scale_slider, text2sound_sampler_radio,
|
| 157 |
+
text2sound_sample_steps_slider,
|
| 158 |
+
text2sound_seed_textbox,
|
| 159 |
+
interpolation_with_text_state],
|
| 160 |
+
outputs=[text2sound_latent_representation_image,
|
| 161 |
+
text2sound_quantized_latent_representation_image,
|
| 162 |
+
text2sound_sampled_concatenated_spectrogram_image,
|
| 163 |
+
text2sound_sampled_spectrogram_image,
|
| 164 |
+
text2sound_sampled_audio,
|
| 165 |
+
text2sound_seed_textbox,
|
| 166 |
+
interpolation_with_text_state,
|
| 167 |
+
text2sound_sample_index_slider])
|
| 168 |
+
text2sound_sample_index_slider.change(show_random_sample,
|
| 169 |
+
inputs=[text2sound_sample_index_slider, interpolation_with_text_state],
|
| 170 |
+
outputs=[text2sound_latent_representation_image,
|
| 171 |
+
text2sound_quantized_latent_representation_image,
|
| 172 |
+
text2sound_sampled_spectrogram_image,
|
| 173 |
+
text2sound_sampled_audio])
|
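The interpolation above happens in the sampler's starting noise: `DiffSynthSampler.interpolate` (defined in model/DiffSynthSampler.py, not shown here) denoises a batch whose initial noises blend between two endpoints under the same text condition. As a hedged illustration of how such a batch can be built, the sketch below uses spherical linear interpolation (slerp), a common choice for Gaussian noise; the `slerp` helper and the endpoint construction are assumptions for illustration, not the exact code inside `interpolate`.

import torch

def slerp(x0: torch.Tensor, x1: torch.Tensor, t: float) -> torch.Tensor:
    """Spherical linear interpolation between two equally-shaped tensors."""
    v0, v1 = x0.flatten(), x1.flatten()
    omega = torch.acos(torch.clamp(
        torch.dot(v0 / v0.norm(), v1 / v1.norm()), -1.0, 1.0))
    so = torch.sin(omega)
    if so.abs() < 1e-8:  # nearly parallel endpoints: fall back to linear interpolation
        return (1.0 - t) * x0 + t * x1
    return (torch.sin((1.0 - t) * omega) / so) * x0 + (torch.sin(t * omega) / so) * x1

# Build a batch of starting noises x_T that sweeps from one endpoint to the
# other; denoising each row with the same condition yields a gradient of sounds.
batchsize, channels, height, width = 4, 4, 128, 64  # repo defaults for a 3 s note
g = torch.Generator().manual_seed(12345678)
xT_a = torch.randn(channels, height, width, generator=g)
xT_b = torch.randn(channels, height, width, generator=g)
xT_batch = torch.stack([slerp(xT_a, xT_b, i / (batchsize - 1)) for i in range(batchsize)])
print(xT_batch.shape)  # torch.Size([4, 4, 128, 64])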
webUI/natural_language_guided/GAN.py
ADDED
@@ -0,0 +1,164 @@
import gradio as gr
import numpy as np
import torch

from tools import safe_int
from webUI.natural_language_guided_STFT.utils import encodeBatch2GradioOutput, latent_representation_to_Gradio_image, \
    add_instrument


def get_testGAN(gradioWebUI, text2sound_state, virtual_instruments_state):
    # Load configurations
    gan_generator = gradioWebUI.GAN_generator
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels

    timesteps = gradioWebUI.timesteps
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

    def gan_random_sample(text2sound_prompts, text2sound_negative_prompts, text2sound_batchsize,
                          text2sound_duration,
                          text2sound_guidance_scale, text2sound_sampler,
                          text2sound_sample_steps, text2sound_seed,
                          text2sound_dict):
        text2sound_seed = safe_int(text2sound_seed, 12345678)

        width = int(time_resolution * ((text2sound_duration + 1) / 4) / VAE_scale)

        text2sound_batchsize = int(text2sound_batchsize)

        text2sound_embedding = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(
                device)

        CFG = int(text2sound_guidance_scale)

        condition = text2sound_embedding.repeat(text2sound_batchsize, 1)

        # One forward pass of the GAN generator replaces the iterative diffusion sampling.
        noise = torch.randn(text2sound_batchsize, channels, height, width).to(device)
        latent_representations = gan_generator(noise, condition)

        print(latent_representations[0, 0, :3, :3])  # debug: peek at the generated latents

        latent_representation_gradio_images = []
        quantized_latent_representation_gradio_images = []
        new_sound_spectrogram_gradio_images = []
        new_sound_rec_signals_gradio = []

        quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations)
        # Todo: remove hard-coding
        flipped_log_spectrums, rec_signals = encodeBatch2GradioOutput(VAE_decoder, quantized_latent_representations,
                                                                      resolution=(512, width * VAE_scale),
                                                                      centralized=False,
                                                                      squared=squared)

        for i in range(text2sound_batchsize):
            latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i]))
            quantized_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(quantized_latent_representations[i]))
            new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i])
            new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i]))

        text2sound_dict["latent_representations"] = latent_representations.to("cpu").detach().numpy()
        text2sound_dict["quantized_latent_representations"] = quantized_latent_representations.to("cpu").detach().numpy()
        text2sound_dict["latent_representation_gradio_images"] = latent_representation_gradio_images
        text2sound_dict["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images
        text2sound_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
        text2sound_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio

        text2sound_dict["condition"] = condition.to("cpu").detach().numpy()
        # text2sound_dict["negative_condition"] = negative_condition.to("cpu").detach().numpy()
        text2sound_dict["guidance_scale"] = CFG
        text2sound_dict["sampler"] = text2sound_sampler

        return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][0],
                text2sound_quantized_latent_representation_image:
                    text2sound_dict["quantized_latent_representation_gradio_images"][0],
                text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][0],
                text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][0],
                text2sound_seed_textbox: text2sound_seed,
                text2sound_state: text2sound_dict,
                text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1,
                                                          visible=True,
                                                          label="Sample index",
                                                          info="Swipe to view other samples")}

    def show_random_sample(sample_index, text2sound_dict):
        sample_index = int(sample_index)
        text2sound_dict["sample_index"] = sample_index
        return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][
                    sample_index],
                text2sound_quantized_latent_representation_image:
                    text2sound_dict["quantized_latent_representation_gradio_images"][sample_index],
                text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][
                    sample_index],
                text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]}

    with gr.Tab("Text2sound_GAN"):
        gr.Markdown("Let a GAN generate random sounds of your favorite instrument!")
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):
                text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
                text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")

            with gr.Column(scale=1):
                text2sound_sampling_button = gr.Button(variant="primary",
                                                       value="Generate a batch of samples and show "
                                                             "the first one",
                                                       scale=1)
                text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
                                                           label="Sample index",
                                                           info="Swipe to view other samples")
        with gr.Row(variant="panel"):
            with gr.Column(scale=1, variant="panel"):
                text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
                text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
                text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
                text2sound_duration_slider = gradioWebUI.get_duration_slider()
                text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
                text2sound_seed_textbox = gradioWebUI.get_seed_textbox()

            with gr.Column(scale=1):
                text2sound_sampled_spectrogram_image = gr.Image(label="Sampled spectrogram", type="numpy", height=420)
                text2sound_sampled_audio = gr.Audio(type="numpy", label="Play")

        with gr.Row(variant="panel"):
            text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy",
                                                              height=200, width=100)
            text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
                                                                        type="numpy", height=200, width=100)

        text2sound_sampling_button.click(gan_random_sample,
                                         inputs=[text2sound_prompts_textbox,
                                                 text2sound_negative_prompts_textbox,
                                                 text2sound_batchsize_slider,
                                                 text2sound_duration_slider,
                                                 text2sound_guidance_scale_slider, text2sound_sampler_radio,
                                                 text2sound_sample_steps_slider,
                                                 text2sound_seed_textbox,
                                                 text2sound_state],
                                         outputs=[text2sound_latent_representation_image,
                                                  text2sound_quantized_latent_representation_image,
                                                  text2sound_sampled_spectrogram_image,
                                                  text2sound_sampled_audio,
                                                  text2sound_seed_textbox,
                                                  text2sound_state,
                                                  text2sound_sample_index_slider])

        text2sound_sample_index_slider.change(show_random_sample,
                                              inputs=[text2sound_sample_index_slider, text2sound_state],
                                              outputs=[text2sound_latent_representation_image,
                                                       text2sound_quantized_latent_representation_image,
                                                       text2sound_sampled_spectrogram_image,
                                                       text2sound_sampled_audio])
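In contrast with the diffusion path, `get_testGAN` produces latents in a single forward pass: `gan_generator(noise, condition)` maps Gaussian noise plus a CLAP text embedding straight to a latent grid, with no iterative sampling (so `text2sound_sample_steps` and `text2sound_sampler` are effectively unused here). A minimal sketch of a generator with the same call convention, assuming a 512-dimensional CLAP embedding; the layer sizes are illustrative toys, not those of model/GAN.py.

import torch
import torch.nn as nn

class ToyConditionalGenerator(nn.Module):
    """Same (noise, condition) call convention as gradioWebUI.GAN_generator."""
    def __init__(self, channels=4, cond_dim=512):
        super().__init__()
        self.cond_proj = nn.Linear(cond_dim, channels)
        self.net = nn.Sequential(
            nn.Conv2d(2 * channels, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, channels, 3, padding=1))

    def forward(self, noise, condition):
        # Broadcast the text embedding over the spatial grid and concatenate
        # it with the noise along the channel axis before convolving.
        b, c, h, w = noise.shape
        cond_map = self.cond_proj(condition).view(b, c, 1, 1).expand(b, c, h, w)
        return self.net(torch.cat([noise, cond_map], dim=1))

gen = ToyConditionalGenerator()
latents = gen(torch.randn(2, 4, 128, 64), torch.randn(2, 512))
print(latents.shape)  # torch.Size([2, 4, 128, 64])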
webUI/natural_language_guided/README.py
ADDED
@@ -0,0 +1,53 @@
import gradio as gr

readme_content = """## Stable Diffusion for Sound Generation

This project applies stable diffusion[1] to sound generation. Inspired by the work of AUTOMATIC1111, 2022[2], we have implemented a preliminary version of text2sound, sound2sound, and inpaint, plus an additional interpolation feature, all accessible through a web UI.

### Neural Network Training Data:
The neural network is trained on the filtered NSynth dataset[3], a large-scale, high-quality collection of 305,979 annotated musical notes. For this project, only samples with pitch E3 were used, giving an actual training set of 4,096 samples, which makes this a low-resource project.

The training took place on an NVIDIA Tesla T4 GPU and spanned approximately 10 hours.

### Natural Language Guidance:
Natural language guidance is derived from the multi-label annotations of the NSynth dataset. The labels included in the training are:

- **Instrument Families**: bass, brass, flute, guitar, keyboard, mallet, organ, reed, string, synth lead, vocal.

- **Instrument Sources**: acoustic, electronic, synthetic.

- **Note Qualities**: bright, dark, distortion, fast decay, long release, multiphonic, nonlinear env, percussive, reverb, tempo-synced.

### Usage Hints:

1. **Prompt Format**: It's recommended to use the format "label1, label2, label3", e.g., "organ, dark, long release".

2. **Unique Sounds**: If you keep generating the same sound, try setting a different seed!

3. **Sample Indexing**: Drag the "Sample index" slider to view other samples within the generated batch.

4. **Running on CPU**: Be cautious with the settings for 'batchsize' and 'sample_steps' when running on CPU to avoid timeouts. Recommended settings are batchsize ≤ 4 and sample_steps = 15.

5. **Editing Sounds**: Generated audio can be downloaded and then re-uploaded for further editing in the sound2sound/inpaint sections.

6. **Guidance Scale**: A higher 'guidance_scale' intensifies the influence of the natural language condition on the generation[4]. It's recommended to set it between 3 and 10.

7. **Noising Strength**: A smaller 'noising_strength' value keeps the generated sound closer to the input sound.

References:

[1] Rombach, R., Blattmann, A., Lorenz, D., Esser, P., & Ommer, B. (2022). High-resolution image synthesis with latent diffusion models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10684-10695).

[2] AUTOMATIC1111. (2022). Stable Diffusion Web UI [Computer software]. Retrieved from https://github.com/AUTOMATIC1111/stable-diffusion-webui

[3] Engel, J., Resnick, C., Roberts, A., Dieleman, S., Eck, D., Simonyan, K., & Norouzi, M. (2017). Neural Audio Synthesis of Musical Notes with WaveNet Autoencoders.

[4] Ho, J., & Salimans, T. (2022). Classifier-free diffusion guidance. arXiv preprint arXiv:2207.12598.
"""


def get_readme_module():

    with gr.Tab("README"):
        # gr.Markdown("Use interpolation to generate a gradient sound sequence.")
        with gr.Column(scale=3):
            readme_textbox = gr.Textbox(label="readme", lines=40, value=readme_content, interactive=False)
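Usage hint 6 and reference [4] describe classifier-free guidance, which the samplers in this repo switch on via `activate_classifier_free_guidance(CFG, unconditional_condition)`. A minimal sketch of the underlying update, assuming a denoiser that takes `(x, t, condition)` and returns a noise estimate (the actual signature lives in model/diffusion.py and may differ):

import torch

def cfg_noise_prediction(model, x_t, t, condition, unconditional_condition, guidance_scale):
    # Classifier-free guidance (Ho & Salimans, 2022): extrapolate from the
    # unconditional prediction toward the conditional one.
    eps_uncond = model(x_t, t, unconditional_condition)
    eps_cond = model(x_t, t, condition)
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

# guidance_scale = 0 reduces to the unconditional (negative-prompt) prediction,
# 1 recovers the plain conditional model, and larger values push the sample
# harder toward the prompt, matching the guidance-scale slider's info text.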
webUI/natural_language_guided/__pycache__/README.cpython-310.pyc
ADDED
Binary file (3.51 kB)
webUI/natural_language_guided/__pycache__/README_STFT.cpython-310.pyc
ADDED
Binary file (3.51 kB)
webUI/natural_language_guided/__pycache__/buildInstrument_STFT.cpython-310.pyc
ADDED
Binary file (8.26 kB)
webUI/natural_language_guided/__pycache__/build_instrument.cpython-310.pyc
ADDED
Binary file (8.08 kB)
webUI/natural_language_guided/__pycache__/gradioWebUI.cpython-310.pyc
ADDED
Binary file (3.61 kB)
webUI/natural_language_guided/__pycache__/gradioWebUI_STFT.cpython-310.pyc
ADDED
Binary file (3.62 kB)
webUI/natural_language_guided/__pycache__/gradio_webUI.cpython-310.pyc
ADDED
Binary file (3.61 kB)
webUI/natural_language_guided/__pycache__/inpaintWithText.cpython-310.pyc
ADDED
Binary file (12.5 kB)
webUI/natural_language_guided/__pycache__/inpaintWithText_STFT.cpython-310.pyc
ADDED
Binary file (11.6 kB)
webUI/natural_language_guided/__pycache__/inpaint_with_text.cpython-310.pyc
ADDED
Binary file (12.5 kB)
webUI/natural_language_guided/__pycache__/rec.cpython-310.pyc
ADDED
Binary file (6.11 kB)
webUI/natural_language_guided/__pycache__/recSTFT.cpython-310.pyc
ADDED
Binary file (6.11 kB)
webUI/natural_language_guided/__pycache__/sound2soundWithText.cpython-310.pyc
ADDED
Binary file (10.7 kB)
webUI/natural_language_guided/__pycache__/sound2soundWithText_STFT.cpython-310.pyc
ADDED
Binary file (10.7 kB)
webUI/natural_language_guided/__pycache__/sound2sound_with_text.cpython-310.pyc
ADDED
Binary file (10.7 kB)
webUI/natural_language_guided/__pycache__/text2sound.cpython-310.pyc
ADDED
Binary file (6.45 kB)
webUI/natural_language_guided/__pycache__/text2sound_STFT.cpython-310.pyc
ADDED
Binary file (6.43 kB)
webUI/natural_language_guided/__pycache__/track_maker.cpython-310.pyc
ADDED
Binary file (7 kB)
webUI/natural_language_guided/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (4.78 kB)
webUI/natural_language_guided/build_instrument.py
ADDED
@@ -0,0 +1,274 @@
import librosa
import numpy as np
import torch
import gradio as gr
import mido
from io import BytesIO
import pyrubberband as pyrb

from model.DiffSynthSampler import DiffSynthSampler
from tools import adsr_envelope, adjust_audio_length
from webUI.natural_language_guided.track_maker import DiffSynth
from webUI.natural_language_guided.utils import encodeBatch2GradioOutput_STFT, phase_to_Gradio_image, \
    spectrogram_to_Gradio_image


def get_build_instrument_module(gradioWebUI, virtual_instruments_state):
    # Load configurations
    uNet = gradioWebUI.uNet
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels

    timesteps = gradioWebUI.timesteps
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

    def select_sound(virtual_instrument_name, virtual_instruments_dict):
        virtual_instruments = virtual_instruments_dict["virtual_instruments"]
        virtual_instrument = virtual_instruments[virtual_instrument_name]

        return {source_sound_spectrogram_image: virtual_instrument["spectrogram_gradio_image"],
                source_sound_phase_image: virtual_instrument["phase_gradio_image"],
                source_sound_audio: virtual_instrument["signal"]}

    def make_track(inpaint_steps, midi, noising_strength, attack, before_release, instrument_names, virtual_instruments_dict):

        if noising_strength < 1:
            print(f"Warning: making track with noising_strength = {noising_strength} < 1")
        virtual_instruments = virtual_instruments_dict["virtual_instruments"]
        sample_steps = int(inpaint_steps)

        # Instrument names arrive as a single "@"-separated string.
        instrument_names = instrument_names.split("@")
        instruments_configs = {}
        for virtual_instrument_name in instrument_names:
            virtual_instrument = virtual_instruments[virtual_instrument_name]

            latent_representation = torch.tensor(virtual_instrument["latent_representation"], dtype=torch.float32).to(device)
            sampler = virtual_instrument["sampler"]

            batchsize = 1

            latent_representation = latent_representation.repeat(batchsize, 1, 1, 1)

            mid = mido.MidiFile(file=BytesIO(midi))
            instruments_configs[virtual_instrument_name] = {
                'sample_steps': sample_steps,
                'sampler': sampler,
                'noising_strength': noising_strength,
                'latent_representation': latent_representation,
                'attack': attack,
                'before_release': before_release}

        diffSynth = DiffSynth(instruments_configs, uNet, VAE_quantizer, VAE_decoder, CLAP, CLAP_tokenizer, device)

        full_audio = diffSynth.get_music(mid, instrument_names)

        return {track_audio: (sample_rate, full_audio)}

    def test_duration_inpaint(virtual_instrument_name, inpaint_steps, duration, noising_strength, end_noise_level_ratio, attack, before_release, mask_flexivity, virtual_instruments_dict, use_dynamic_mask):
        width = int(time_resolution * ((duration + 1) / 4) / VAE_scale)

        virtual_instruments = virtual_instruments_dict["virtual_instruments"]
        virtual_instrument = virtual_instruments[virtual_instrument_name]

        latent_representation = torch.tensor(virtual_instrument["latent_representation"], dtype=torch.float32).to(device)
        sample_steps = int(inpaint_steps)
        sampler = virtual_instrument["sampler"]
        batchsize = 1

        mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
        mySampler.respace(list(np.linspace(0, timesteps - 1, sample_steps, dtype=np.int32)))

        latent_representation = latent_representation.repeat(batchsize, 1, 1, 1)

        # mask = 1, freeze: keep the attack and the release tail fixed,
        # and let the sampler inpaint the sustained middle to the new duration.
        latent_mask = torch.zeros((batchsize, 1, height, width), dtype=torch.float32).to(device)
        latent_mask[:, :, :, :int(time_resolution * (attack / 4) / VAE_scale)] = 1.0
        latent_mask[:, :, :, -int(time_resolution * ((before_release + 1) / 4) / VAE_scale):] = 1.0

        text2sound_embedding = \
            CLAP.get_text_features(**CLAP_tokenizer([""], padding=True, return_tensors="pt"))[0].to(
                device)
        condition = text2sound_embedding.repeat(1, 1)

        latent_representations, initial_noise = \
            mySampler.inpaint_sample(model=uNet, shape=(batchsize, channels, height, width),
                                     noising_strength=noising_strength,
                                     guide_img=latent_representation, mask=latent_mask, return_tensor=True,
                                     condition=condition, sampler=sampler,
                                     use_dynamic_mask=use_dynamic_mask,
                                     end_noise_level_ratio=end_noise_level_ratio,
                                     mask_flexivity=mask_flexivity)

        latent_representations = latent_representations[-1]

        quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations)
        # Todo: remove hard-coding
        flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(VAE_decoder,
                                                                                                    quantized_latent_representations,
                                                                                                    resolution=(512, width * VAE_scale),
                                                                                                    original_STFT_batch=None)

        return {test_duration_spectrogram_image: flipped_log_spectrums[0],
                test_duration_phase_image: flipped_phases[0],
                test_duration_audio: (sample_rate, rec_signals[0])}

    def test_duration_envelope(virtual_instrument_name, duration, noising_strength, attack, before_release, release, virtual_instruments_dict):

        virtual_instruments = virtual_instruments_dict["virtual_instruments"]
        virtual_instrument = virtual_instruments[virtual_instrument_name]
        sample_rate, signal = virtual_instrument["signal"]

        applied_signal = adsr_envelope(signal=signal, sample_rate=sample_rate, duration=duration,
                                       attack_time=0.0, decay_time=0.0, sustain_level=1.0, release_time=release)

        D = librosa.stft(applied_signal, n_fft=1024, hop_length=256, win_length=1024)[1:, :]
        spc = np.abs(D)
        phase = np.angle(D)

        flipped_log_spectrum = spectrogram_to_Gradio_image(spc)
        flipped_phase = phase_to_Gradio_image(phase)

        return {test_duration_spectrogram_image: flipped_log_spectrum,
                test_duration_phase_image: flipped_phase,
                test_duration_audio: (sample_rate, applied_signal)}

    def test_duration_stretch(virtual_instrument_name, duration, noising_strength, attack, before_release, release, virtual_instruments_dict):

        virtual_instruments = virtual_instruments_dict["virtual_instruments"]
        virtual_instrument = virtual_instruments[virtual_instrument_name]
        sample_rate, signal = virtual_instrument["signal"]

        # Stretch rate maps the nominal 3 s source note onto the target duration
        # (rates > 1 shorten the signal, rates < 1 lengthen it).
        s = 3 / duration
        applied_signal = pyrb.time_stretch(signal, sample_rate, s)
        applied_signal = adjust_audio_length(applied_signal, int((duration + 1) * sample_rate), sample_rate, sample_rate)

        D = librosa.stft(applied_signal, n_fft=1024, hop_length=256, win_length=1024)[1:, :]
        spc = np.abs(D)
        phase = np.angle(D)

        flipped_log_spectrum = spectrogram_to_Gradio_image(spc)
        flipped_phase = phase_to_Gradio_image(phase)

        return {test_duration_spectrogram_image: flipped_log_spectrum,
                test_duration_phase_image: flipped_phase,
                test_duration_audio: (sample_rate, applied_signal)}

    with gr.Tab("TestInTrack"):
        gr.Markdown("Make music with generated sounds!")
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):
                instrument_name_textbox = gr.Textbox(label="Instrument name", lines=1,
                                                     placeholder="Name of your instrument", scale=1)
                select_instrument_button = gr.Button(variant="primary", value="Select", scale=1)
            with gr.Column(scale=3):
                inpaint_steps_slider = gr.Slider(minimum=5.0, maximum=999.0, value=20.0, step=1.0, label="inpaint_steps")
                noising_strength_slider = gradioWebUI.get_noising_strength_slider(default_noising_strength=1.)
                end_noise_level_ratio_slider = gr.Slider(minimum=0.0, maximum=1., value=0.0, step=0.01, label="end_noise_level_ratio")
                attack_slider = gr.Slider(minimum=0.0, maximum=1.5, value=0.5, step=0.01, label="attack in sec")
                before_release_slider = gr.Slider(minimum=0.0, maximum=1.5, value=0.5, step=0.01, label="before_release in sec")
                release_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="release in sec")
                mask_flexivity_slider = gr.Slider(minimum=0.01, maximum=1.00, value=1., step=0.01, label="mask_flexivity")
            with gr.Column(scale=3):
                use_dynamic_mask_checkbox = gr.Checkbox(label="Use dynamic mask", value=True)
                test_duration_envelope_button = gr.Button(variant="primary", value="Apply envelope", scale=1)
                test_duration_stretch_button = gr.Button(variant="primary", value="Apply stretch", scale=1)
                test_duration_inpaint_button = gr.Button(variant="primary", value="Inpaint different duration", scale=1)
                duration_slider = gradioWebUI.get_duration_slider()

        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                with gr.Row(variant="panel"):
                    source_sound_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy",
                                                              height=600, scale=1)
                    source_sound_phase_image = gr.Image(label="New sound phase", type="numpy",
                                                        height=600, scale=1)
                source_sound_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False)

            with gr.Column(scale=3):
                with gr.Row(variant="panel"):
                    test_duration_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy",
                                                               height=600, scale=1)
                    test_duration_phase_image = gr.Image(label="New sound phase", type="numpy",
                                                         height=600, scale=1)
                test_duration_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False)

        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                # track_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy",
                #                                    height=420, scale=1)
                midi_file = gr.File(label="Upload midi file", type="binary")
                instrument_names_textbox = gr.Textbox(label="Instrument names", lines=2,
                                                      placeholder="Names of your instrument used to play the midi", scale=1)
                track_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False)
                make_track_button = gr.Button(variant="primary", value="Make track", scale=1)

        select_instrument_button.click(select_sound,
                                       inputs=[instrument_name_textbox, virtual_instruments_state],
                                       outputs=[source_sound_spectrogram_image,
                                                source_sound_phase_image,
                                                source_sound_audio])

        test_duration_envelope_button.click(test_duration_envelope,
                                            inputs=[instrument_name_textbox, duration_slider,
                                                    noising_strength_slider,
                                                    attack_slider,
                                                    before_release_slider,
                                                    release_slider,
                                                    virtual_instruments_state],
                                            outputs=[test_duration_spectrogram_image,
                                                     test_duration_phase_image,
                                                     test_duration_audio])

        test_duration_stretch_button.click(test_duration_stretch,
                                           inputs=[instrument_name_textbox, duration_slider,
                                                   noising_strength_slider,
                                                   attack_slider,
                                                   before_release_slider,
                                                   release_slider,
                                                   virtual_instruments_state],
                                           outputs=[test_duration_spectrogram_image,
                                                    test_duration_phase_image,
                                                    test_duration_audio])

        test_duration_inpaint_button.click(test_duration_inpaint,
                                           inputs=[instrument_name_textbox,
                                                   inpaint_steps_slider,
                                                   duration_slider,
                                                   noising_strength_slider,
                                                   end_noise_level_ratio_slider,
                                                   attack_slider,
                                                   before_release_slider,
                                                   mask_flexivity_slider,
                                                   virtual_instruments_state,
                                                   use_dynamic_mask_checkbox],
                                           outputs=[test_duration_spectrogram_image,
                                                    test_duration_phase_image,
                                                    test_duration_audio])

        make_track_button.click(make_track,
                                inputs=[inpaint_steps_slider, midi_file,
                                        noising_strength_slider,
                                        attack_slider,
                                        before_release_slider,
                                        instrument_names_textbox,
                                        virtual_instruments_state],
                                outputs=[track_audio])
webUI/natural_language_guided/gradio_webUI.py
ADDED
@@ -0,0 +1,68 @@
import gradio as gr


class GradioWebUI():

    def __init__(self, device, VAE, uNet, CLAP, CLAP_tokenizer,
                 freq_resolution=512, time_resolution=256, channels=4, timesteps=1000,
                 sample_rate=16000, squared=False, VAE_scale=4,
                 flexible_duration=False, noise_strategy="repeat",
                 GAN_generator=None):
        self.device = device
        self.VAE_encoder, self.VAE_quantizer, self.VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder
        self.uNet = uNet
        self.CLAP, self.CLAP_tokenizer = CLAP, CLAP_tokenizer
        self.freq_resolution, self.time_resolution = freq_resolution, time_resolution
        self.channels = channels
        self.GAN_generator = GAN_generator

        self.timesteps = timesteps
        self.sample_rate = sample_rate
        self.squared = squared
        self.VAE_scale = VAE_scale
        self.flexible_duration = flexible_duration
        self.noise_strategy = noise_strategy

        self.text2sound_state = gr.State(value={})
        self.interpolation_state = gr.State(value={})
        self.sound2sound_state = gr.State(value={})
        self.inpaint_state = gr.State(value={})

    def get_sample_steps_slider(self):
        default_steps = 10 if (self.device == "cpu") else 20
        return gr.Slider(minimum=10, maximum=100, value=default_steps, step=1,
                         label="Sample steps",
                         info="Sampling steps. More steps give theoretically "
                              "better results but take longer.")

    def get_sampler_radio(self):
        # return gr.Radio(choices=["ddpm", "ddim", "dpmsolver++", "dpmsolver"], value="ddim", label="Sampler")
        return gr.Radio(choices=["ddpm", "ddim"], value="ddim", label="Sampler")

    def get_batchsize_slider(self, cpu_batchsize=1):
        return gr.Slider(minimum=1., maximum=16, value=cpu_batchsize if (self.device == "cpu") else 8, step=1, label="Batchsize")

    def get_time_resolution_slider(self):
        return gr.Slider(minimum=16., maximum=int(1024 / self.VAE_scale), value=int(256 / self.VAE_scale), step=1, label="Time resolution", interactive=True)

    def get_duration_slider(self):
        if self.flexible_duration:
            return gr.Slider(minimum=0.25, maximum=8., value=3., step=0.01, label="duration in sec")
        else:
            return gr.Slider(minimum=1., maximum=8., value=3., step=1., label="duration in sec")

    def get_guidance_scale_slider(self):
        return gr.Slider(minimum=0., maximum=20., value=6., step=1.,
                         label="Guidance scale",
                         info="The larger this value, the more the generated sound is "
                              "influenced by the condition. Setting it to 0 generates "
                              "from the negative prompt alone.")

    def get_noising_strength_slider(self, default_noising_strength=0.7):
        return gr.Slider(minimum=0.0, maximum=1.00, value=default_noising_strength, step=0.01,
                         label="noising strength",
                         info="The smaller this value, the closer the generated sound "
                              "stays to the original.")

    def get_seed_textbox(self):
        return gr.Textbox(label="Seed", lines=1, placeholder="seed", value=0)
webUI/natural_language_guided/inpaint_with_text.py
ADDED
@@ -0,0 +1,441 @@
| 1 |
+
import librosa
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import gradio as gr
|
| 5 |
+
from scipy.ndimage import zoom
|
| 6 |
+
|
| 7 |
+
from model.DiffSynthSampler import DiffSynthSampler
|
| 8 |
+
from tools import adjust_audio_length, safe_int, pad_STFT, encode_stft
|
| 9 |
+
from webUI.natural_language_guided.utils import latent_representation_to_Gradio_image, InputBatch2Encode_STFT, encodeBatch2GradioOutput_STFT, add_instrument
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_triangle_mask(height, width):
|
| 13 |
+
mask = np.zeros((height, width))
|
| 14 |
+
slope = 8 / 3
|
| 15 |
+
for i in range(height):
|
| 16 |
+
for j in range(width):
|
| 17 |
+
if i < slope * j:
|
| 18 |
+
mask[i, j] = 1
|
| 19 |
+
return mask
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_inpaint_with_text_module(gradioWebUI, inpaintWithText_state, virtual_instruments_state):
|
| 23 |
+
# Load configurations
|
| 24 |
+
uNet = gradioWebUI.uNet
|
| 25 |
+
freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
|
| 26 |
+
VAE_scale = gradioWebUI.VAE_scale
|
| 27 |
+
height, width, channels = int(freq_resolution/VAE_scale), int(time_resolution/VAE_scale), gradioWebUI.channels
|
| 28 |
+
timesteps = gradioWebUI.timesteps
|
| 29 |
+
VAE_encoder = gradioWebUI.VAE_encoder
|
| 30 |
+
VAE_quantizer = gradioWebUI.VAE_quantizer
|
| 31 |
+
VAE_decoder = gradioWebUI.VAE_decoder
|
| 32 |
+
CLAP = gradioWebUI.CLAP
|
| 33 |
+
CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
|
| 34 |
+
device = gradioWebUI.device
|
| 35 |
+
squared = gradioWebUI.squared
|
| 36 |
+
sample_rate = gradioWebUI.sample_rate
|
| 37 |
+
noise_strategy = gradioWebUI.noise_strategy
|
| 38 |
+
|
| 39 |
+
def receive_uopoad_origin_audio(sound2sound_duration, sound2sound_origin_source, sound2sound_origin_upload, sound2sound_origin_microphone,
|
| 40 |
+
inpaintWithText_dict):
|
| 41 |
+
|
| 42 |
+
if sound2sound_origin_source == "upload":
|
| 43 |
+
origin_sr, origin_audio = sound2sound_origin_upload
|
| 44 |
+
else:
|
| 45 |
+
origin_sr, origin_audio = sound2sound_origin_microphone
|
| 46 |
+
|
| 47 |
+
origin_audio = origin_audio / np.max(np.abs(origin_audio))
|
| 48 |
+
|
| 49 |
+
width = int(time_resolution*((sound2sound_duration+1)/4) / VAE_scale)
|
| 50 |
+
audio_length = 256 * (VAE_scale * width - 1)
|
| 51 |
+
origin_audio = adjust_audio_length(origin_audio, audio_length, origin_sr, sample_rate)
|
| 52 |
+
|
| 53 |
+
D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024)
|
| 54 |
+
padded_D = pad_STFT(D)
|
| 55 |
+
encoded_D = encode_stft(padded_D)
|
| 56 |
+
|
| 57 |
+
# Todo: justify batchsize to 1
|
| 58 |
+
origin_spectrogram_batch_tensor = torch.from_numpy(
|
| 59 |
+
np.repeat(encoded_D[np.newaxis, :, :, :], 1, axis=0)).float().to(device)
|
| 60 |
+
|
| 61 |
+
# Todo: remove hard-coding
|
| 62 |
+
origin_flipped_log_spectrums, origin_flipped_phases, origin_signals, origin_latent_representations, quantized_origin_latent_representations = InputBatch2Encode_STFT(
|
| 63 |
+
VAE_encoder, origin_spectrogram_batch_tensor, resolution=(512, width * VAE_scale), quantizer=VAE_quantizer, squared=squared)
|
| 64 |
+
|
| 65 |
+
if sound2sound_origin_source == "upload":
|
| 66 |
+
inpaintWithText_dict["origin_upload_latent_representations"] = origin_latent_representations.tolist()
|
| 67 |
+
inpaintWithText_dict[
|
| 68 |
+
"sound2sound_origin_upload_latent_representation_image"] = latent_representation_to_Gradio_image(
|
| 69 |
+
origin_latent_representations[0]).tolist()
|
| 70 |
+
inpaintWithText_dict[
|
| 71 |
+
"sound2sound_origin_upload_quantized_latent_representation_image"] = latent_representation_to_Gradio_image(
|
| 72 |
+
quantized_origin_latent_representations[0]).tolist()
|
| 73 |
+
return {sound2sound_origin_spectrogram_upload_image: origin_flipped_log_spectrums[0],
|
| 74 |
+
sound2sound_origin_phase_upload_image: origin_flipped_phases[0],
|
| 75 |
+
sound2sound_origin_spectrogram_microphone_image: gr.update(),
|
| 76 |
+
sound2sound_origin_phase_microphone_image: gr.update(),
|
| 77 |
+
sound2sound_origin_upload_latent_representation_image: latent_representation_to_Gradio_image(
|
| 78 |
+
origin_latent_representations[0]),
|
| 79 |
+
sound2sound_origin_upload_quantized_latent_representation_image: latent_representation_to_Gradio_image(
|
| 80 |
+
quantized_origin_latent_representations[0]),
|
| 81 |
+
sound2sound_origin_microphone_latent_representation_image: gr.update(),
|
| 82 |
+
sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(),
|
| 83 |
+
inpaintWithText_state: inpaintWithText_dict}
|
| 84 |
+
else:
|
| 85 |
+
inpaintWithText_dict["origin_microphone_latent_representations"] = origin_latent_representations.tolist()
|
| 86 |
+
inpaintWithText_dict[
|
| 87 |
+
"sound2sound_origin_microphone_latent_representation_image"] = latent_representation_to_Gradio_image(
|
| 88 |
+
origin_latent_representations[0]).tolist()
|
| 89 |
+
inpaintWithText_dict[
|
| 90 |
+
"sound2sound_origin_microphone_quantized_latent_representation_image"] = latent_representation_to_Gradio_image(
|
| 91 |
+
quantized_origin_latent_representations[0]).tolist()
|
| 92 |
+
return {sound2sound_origin_spectrogram_upload_image: origin_flipped_log_spectrums[0],
|
| 93 |
+
sound2sound_origin_phase_upload_image: origin_flipped_phases[0],
|
| 94 |
+
sound2sound_origin_spectrogram_microphone_image: gr.update(),
|
| 95 |
+
sound2sound_origin_phase_microphone_image: gr.update(),
|
| 96 |
+
sound2sound_origin_upload_latent_representation_image: latent_representation_to_Gradio_image(
|
| 97 |
+
origin_latent_representations[0]),
|
| 98 |
+
sound2sound_origin_upload_quantized_latent_representation_image: latent_representation_to_Gradio_image(
|
| 99 |
+
quantized_origin_latent_representations[0]),
|
| 100 |
+
sound2sound_origin_microphone_latent_representation_image: gr.update(),
|
| 101 |
+
sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(),
|
| 102 |
+
inpaintWithText_state: inpaintWithText_dict}
|
| 103 |
+
|
| 104 |
+
def sound2sound_sample(sound2sound_origin_spectrogram_upload, sound2sound_origin_spectrogram_microphone,
|
| 105 |
+
text2sound_prompts, text2sound_negative_prompts, sound2sound_batchsize,
|
| 106 |
+
sound2sound_guidance_scale, sound2sound_sampler,
|
| 107 |
+
sound2sound_sample_steps, sound2sound_origin_source,
|
| 108 |
+
sound2sound_noising_strength, sound2sound_seed, sound2sound_inpaint_area,
|
| 109 |
+
mask_time_begin, mask_time_end, mask_frequency_begin, mask_frequency_end, inpaintWithText_dict
|
| 110 |
+
):
|
| 111 |
+
|
| 112 |
+
# input preprocessing
|
| 113 |
+
sound2sound_seed = safe_int(sound2sound_seed, 12345678)
|
| 114 |
+
sound2sound_batchsize = int(sound2sound_batchsize)
|
| 115 |
+
noising_strength = sound2sound_noising_strength
|
| 116 |
+
sound2sound_sample_steps = int(sound2sound_sample_steps)
|
| 117 |
+
CFG = int(sound2sound_guidance_scale)
|
| 118 |
+
|
| 119 |
+
text2sound_embedding = \
|
| 120 |
+
CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(device)
|
| 121 |
+
|
| 122 |
+
if sound2sound_origin_source == "upload":
|
| 123 |
+
origin_latent_representations = torch.tensor(
|
| 124 |
+
inpaintWithText_dict["origin_upload_latent_representations"]).repeat(sound2sound_batchsize, 1, 1, 1).to(
|
| 125 |
+
device)
|
| 126 |
+
mask = np.array(sound2sound_origin_spectrogram_upload["mask"])
|
| 127 |
+
elif sound2sound_origin_source == "microphone":
|
| 128 |
+
origin_latent_representations = torch.tensor(
|
| 129 |
+
inpaintWithText_dict["origin_microphone_latent_representations"]).repeat(sound2sound_batchsize, 1, 1, 1).to(
|
| 130 |
+
device)
|
| 131 |
+
mask = np.array(sound2sound_origin_spectrogram_microphone["mask"])
|
| 132 |
+
else:
|
| 133 |
+
print("Input source not in ['upload', 'microphone']!")
|
| 134 |
+
raise NotImplementedError()
|
| 135 |
+
|
| 136 |
+
merged_mask = np.all(mask == 255, axis=2).astype(np.uint8)
|
| 137 |
+
        latent_mask = zoom(merged_mask, (1 / VAE_scale, 1 / VAE_scale))
        latent_mask = np.clip(latent_mask, 0, 1)
        print(f"latent_mask.avg = {np.mean(latent_mask)}")
        latent_mask[int(mask_frequency_begin):int(mask_frequency_end),
                    int(mask_time_begin * time_resolution / (VAE_scale * 4)):int(mask_time_end * time_resolution / (VAE_scale * 4))] = 1

        # latent_mask = get_triangle_mask(128, 64)

        print(f"latent_mask.avg = {np.mean(latent_mask)}")
        if sound2sound_inpaint_area == "inpaint masked":
            latent_mask = 1 - latent_mask
        latent_mask = torch.from_numpy(latent_mask).unsqueeze(0).unsqueeze(1).repeat(
            sound2sound_batchsize, channels, 1, 1).float().to(device)
        latent_mask = torch.flip(latent_mask, [2])

        mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
        unconditional_condition = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[0]
        mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device))

        normalized_sample_steps = int(sound2sound_sample_steps / noising_strength)
        mySampler.respace(list(np.linspace(0, timesteps - 1, normalized_sample_steps, dtype=np.int32)))

        # Todo: remove hard-coding
        width = origin_latent_representations.shape[-1]
        condition = text2sound_embedding.repeat(sound2sound_batchsize, 1)

        new_sound_latent_representations, initial_noise = \
            mySampler.inpaint_sample(model=uNet, shape=(sound2sound_batchsize, channels, height, width),
                                     seed=sound2sound_seed, noising_strength=noising_strength,
                                     guide_img=origin_latent_representations, mask=latent_mask, return_tensor=True,
                                     condition=condition, sampler=sound2sound_sampler)

        new_sound_latent_representations = new_sound_latent_representations[-1]

        # Quantize new sound latent representations
        quantized_new_sound_latent_representations, loss, (_, _, _) = VAE_quantizer(new_sound_latent_representations)
        new_sound_flipped_log_spectrums, new_sound_flipped_phases, new_sound_signals, _, _, _ = \
            encodeBatch2GradioOutput_STFT(VAE_decoder, quantized_new_sound_latent_representations,
                                          resolution=(512, width * VAE_scale), original_STFT_batch=None)

        new_sound_latent_representation_gradio_images = []
        new_sound_quantized_latent_representation_gradio_images = []
        new_sound_spectrogram_gradio_images = []
        new_sound_phase_gradio_images = []
        new_sound_rec_signals_gradio = []
        for i in range(sound2sound_batchsize):
            new_sound_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(new_sound_latent_representations[i]))
            new_sound_quantized_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(quantized_new_sound_latent_representations[i]))
            new_sound_spectrogram_gradio_images.append(new_sound_flipped_log_spectrums[i])
            new_sound_phase_gradio_images.append(new_sound_flipped_phases[i])
            new_sound_rec_signals_gradio.append((sample_rate, new_sound_signals[i]))

        inpaintWithText_dict["new_sound_latent_representation_gradio_images"] = new_sound_latent_representation_gradio_images
        inpaintWithText_dict["new_sound_quantized_latent_representation_gradio_images"] = new_sound_quantized_latent_representation_gradio_images
        inpaintWithText_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
        inpaintWithText_dict["new_sound_phase_gradio_images"] = new_sound_phase_gradio_images
        inpaintWithText_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio

        inpaintWithText_dict["latent_representations"] = new_sound_latent_representations.to("cpu").detach().numpy()
        inpaintWithText_dict["quantized_latent_representations"] = quantized_new_sound_latent_representations.to("cpu").detach().numpy()
        inpaintWithText_dict["sampler"] = sound2sound_sampler

        return {sound2sound_new_sound_latent_representation_image:
                    latent_representation_to_Gradio_image(new_sound_latent_representations[0]),
                sound2sound_new_sound_quantized_latent_representation_image:
                    latent_representation_to_Gradio_image(quantized_new_sound_latent_representations[0]),
                sound2sound_new_sound_spectrogram_image: new_sound_flipped_log_spectrums[0],
                sound2sound_new_sound_phase_image: new_sound_flipped_phases[0],
                sound2sound_new_sound_audio: (sample_rate, new_sound_signals[0]),
                sound2sound_sample_index_slider: gr.update(minimum=0, maximum=sound2sound_batchsize - 1, value=0,
                                                           step=1.0, visible=True, label="Sample index",
                                                           info="Swipe to view other samples"),
                sound2sound_seed_textbox: sound2sound_seed,
                inpaintWithText_state: inpaintWithText_dict}
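A note on the guidance setup above: `activate_classifier_free_guidance(CFG, ...)` arms the sampler with the negative-prompt embedding, and classifier-free guidance then mixes the conditional and unconditional predictions at every denoising step. A minimal sketch of that standard mixing rule (the `eps_*` names are illustrative, not `DiffSynthSampler` internals):

```python
import torch

def cfg_mix(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # cfg_scale == 1 reproduces the conditional prediction;
    # larger values push the sample further toward the positive prompt
    # and away from the negative prompt.
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)
```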
    def show_sound2sound_sample(sound2sound_sample_index, inpaintWithText_dict):
        sample_index = int(sound2sound_sample_index)
        return {sound2sound_new_sound_latent_representation_image:
                    inpaintWithText_dict["new_sound_latent_representation_gradio_images"][sample_index],
                sound2sound_new_sound_quantized_latent_representation_image:
                    inpaintWithText_dict["new_sound_quantized_latent_representation_gradio_images"][sample_index],
                sound2sound_new_sound_spectrogram_image: inpaintWithText_dict["new_sound_spectrogram_gradio_images"][sample_index],
                sound2sound_new_sound_phase_image: inpaintWithText_dict["new_sound_phase_gradio_images"][sample_index],
                sound2sound_new_sound_audio: inpaintWithText_dict["new_sound_rec_signals_gradio"][sample_index]}
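`show_sound2sound_sample` never recomputes anything: the Generate callback caches one list per output in the `gr.State` dict, and the slider callback only indexes into it. A reduced, self-contained sketch of that cache-and-index pattern (component names here are illustrative, not the ones in this file):

```python
import gradio as gr

with gr.Blocks() as demo:
    state = gr.State(value={})  # per-session cache
    generate = gr.Button("Generate")
    index = gr.Slider(minimum=0, maximum=3, step=1, label="Sample index")
    out = gr.Textbox(label="Sample")

    def make_batch(cache):
        # Expensive work happens once; results are stored in the state dict.
        cache["samples"] = [f"sample {i}" for i in range(4)]
        return {out: cache["samples"][0], state: cache}

    def show(i, cache):
        # The slider callback is a cheap lookup into the cached batch.
        return {out: cache["samples"][int(i)]}

    generate.click(make_batch, inputs=[state], outputs=[out, state])
    index.change(show, inputs=[index, state], outputs=[out])
```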
    def sound2sound_switch_origin_source(sound2sound_origin_source):
        if sound2sound_origin_source == "upload":
            return {sound2sound_origin_upload_audio: gr.update(visible=True),
                    sound2sound_origin_microphone_audio: gr.update(visible=False),
                    sound2sound_origin_spectrogram_upload_image: gr.update(visible=True),
                    sound2sound_origin_phase_upload_image: gr.update(visible=True),
                    sound2sound_origin_spectrogram_microphone_image: gr.update(visible=False),
                    sound2sound_origin_phase_microphone_image: gr.update(visible=False),
                    sound2sound_origin_upload_latent_representation_image: gr.update(visible=True),
                    sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=True),
                    sound2sound_origin_microphone_latent_representation_image: gr.update(visible=False),
                    sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=False)}
        elif sound2sound_origin_source == "microphone":
            return {sound2sound_origin_upload_audio: gr.update(visible=False),
                    sound2sound_origin_microphone_audio: gr.update(visible=True),
                    sound2sound_origin_spectrogram_upload_image: gr.update(visible=False),
                    sound2sound_origin_phase_upload_image: gr.update(visible=False),
                    sound2sound_origin_spectrogram_microphone_image: gr.update(visible=True),
                    sound2sound_origin_phase_microphone_image: gr.update(visible=True),
                    sound2sound_origin_upload_latent_representation_image: gr.update(visible=False),
                    sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=False),
                    sound2sound_origin_microphone_latent_representation_image: gr.update(visible=True),
                    sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=True)}
        else:
            print("Input source not in ['upload', 'microphone']!")
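The switch above swaps two parallel sets of components by returning `gr.update(visible=...)` for each; the components themselves are never recreated, so their values survive the toggle. A reduced sketch of the same toggle (component names illustrative):

```python
import gradio as gr

with gr.Blocks() as demo:
    source = gr.Radio(choices=["upload", "microphone"], value="upload", label="Input source")
    upload_audio = gr.Audio(label="Upload", visible=True)
    mic_audio = gr.Audio(label="Record", visible=False)

    def switch(src):
        # Only visibility flags change; component state is preserved.
        show_upload = src == "upload"
        return {upload_audio: gr.update(visible=show_upload),
                mic_audio: gr.update(visible=not show_upload)}

    source.change(switch, inputs=[source], outputs=[upload_audio, mic_audio])
```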
    def save_virtual_instrument(sample_index, virtual_instrument_name, sound2sound_dict, virtual_instruments_dict):
        virtual_instruments_dict = add_instrument(sound2sound_dict, virtual_instruments_dict,
                                                  virtual_instrument_name, sample_index)
        return {virtual_instruments_state: virtual_instruments_dict,
                sound2sound_instrument_name_textbox: gr.Textbox(label="Instrument name", lines=1,
                                                                placeholder=f"Saved as {virtual_instrument_name}!")}
with gr.Tab("Inpaint"):
|
| 272 |
+
gr.Markdown("Select the area to inpaint and use the prompt to guide the synthesis of a new sound!")
|
| 273 |
+
with gr.Row(variant="panel"):
|
| 274 |
+
with gr.Column(scale=3):
|
| 275 |
+
text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
|
| 276 |
+
text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")
|
| 277 |
+
|
| 278 |
+
with gr.Column(scale=1):
|
| 279 |
+
sound2sound_sample_button = gr.Button(variant="primary", value="Generate", scale=1)
|
| 280 |
+
|
| 281 |
+
sound2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
|
| 282 |
+
label="Sample index",
|
| 283 |
+
info="Swipe to view other samples")
|
| 284 |
+
|
| 285 |
+
with gr.Row(variant="panel"):
|
| 286 |
+
with gr.Column(scale=1):
|
| 287 |
+
with gr.Tab("Origin sound"):
|
| 288 |
+
sound2sound_duration_slider = gradioWebUI.get_duration_slider()
|
| 289 |
+
sound2sound_origin_source_radio = gr.Radio(choices=["upload", "microphone"], value="upload",
|
| 290 |
+
label="Input source")
|
| 291 |
+
|
| 292 |
+
sound2sound_origin_upload_audio = gr.Audio(type="numpy", label="Upload", source="upload",
|
| 293 |
+
interactive=True, visible=True)
|
| 294 |
+
sound2sound_origin_microphone_audio = gr.Audio(type="numpy", label="Record", source="microphone",
|
| 295 |
+
interactive=True, visible=False)
|
| 296 |
+
with gr.Row(variant="panel"):
|
| 297 |
+
sound2sound_origin_spectrogram_upload_image = gr.Image(label="Original upload spectrogram",
|
| 298 |
+
type="numpy", height=600,
|
| 299 |
+
visible=True, tool="sketch")
|
| 300 |
+
sound2sound_origin_phase_upload_image = gr.Image(label="Original upload phase",
|
| 301 |
+
type="numpy", height=600,
|
| 302 |
+
visible=True)
|
| 303 |
+
sound2sound_origin_spectrogram_microphone_image = gr.Image(label="Original microphone spectrogram",
|
| 304 |
+
type="numpy", height=600,
|
| 305 |
+
visible=False, tool="sketch")
|
| 306 |
+
sound2sound_origin_phase_microphone_image = gr.Image(label="Original microphone phase",
|
| 307 |
+
type="numpy", height=600,
|
| 308 |
+
visible=False)
|
| 309 |
+
sound2sound_inpaint_area_radio = gr.Radio(choices=["inpaint masked", "inpaint not masked"],
|
| 310 |
+
value="inpaint masked")
|
| 311 |
+
|
| 312 |
+
with gr.Tab("Sound2sound settings"):
|
| 313 |
+
sound2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
|
| 314 |
+
sound2sound_sampler_radio = gradioWebUI.get_sampler_radio()
|
| 315 |
+
sound2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
|
| 316 |
+
sound2sound_noising_strength_slider = gradioWebUI.get_noising_strength_slider()
|
| 317 |
+
sound2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
|
| 318 |
+
sound2sound_seed_textbox = gradioWebUI.get_seed_textbox()
|
| 319 |
+
|
| 320 |
+
with gr.Tab("Mask prototypes"):
|
| 321 |
+
with gr.Tab("Mask along time axis"):
|
| 322 |
+
mask_time_begin_slider = gr.Slider(minimum=0.0, maximum=4.00, value=0.0, step=0.01, label="Begin time")
|
| 323 |
+
mask_time_end_slider = gr.Slider(minimum=0.0, maximum=4.00, value=0.0, step=0.01, label="End time")
|
| 324 |
+
with gr.Tab("Mask along frequency axis"):
|
| 325 |
+
mask_frequency_begin_slider = gr.Slider(minimum=0, maximum=127, value=0, step=1, label="Begin freq pixel")
|
| 326 |
+
mask_frequency_end_slider = gr.Slider(minimum=0, maximum=127, value=0, step=1, label="End freq pixel")
|
| 327 |
+
|
| 328 |
+
with gr.Column(scale=1):
|
| 329 |
+
sound2sound_new_sound_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False)
|
| 330 |
+
with gr.Row(variant="panel"):
|
| 331 |
+
sound2sound_new_sound_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy",
|
| 332 |
+
height=600, scale=1)
|
| 333 |
+
sound2sound_new_sound_phase_image = gr.Image(label="New sound phase", type="numpy",
|
| 334 |
+
height=600, scale=1)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
with gr.Row(variant="panel"):
|
| 338 |
+
sound2sound_instrument_name_textbox = gr.Textbox(label="Instrument name", lines=1,
|
| 339 |
+
placeholder="Name of your instrument")
|
| 340 |
+
sound2sound_save_instrument_button = gr.Button(variant="primary",
|
| 341 |
+
value="Save instrument",
|
| 342 |
+
scale=1)
|
| 343 |
+
|
| 344 |
+
with gr.Row(variant="panel"):
|
| 345 |
+
sound2sound_origin_upload_latent_representation_image = gr.Image(label="Original latent representation",
|
| 346 |
+
type="numpy", height=800,
|
| 347 |
+
visible=True)
|
| 348 |
+
sound2sound_origin_upload_quantized_latent_representation_image = gr.Image(
|
| 349 |
+
label="Original quantized latent representation", type="numpy", height=800, visible=True)
|
| 350 |
+
|
| 351 |
+
sound2sound_origin_microphone_latent_representation_image = gr.Image(label="Original latent representation",
|
| 352 |
+
type="numpy", height=800,
|
| 353 |
+
visible=False)
|
| 354 |
+
sound2sound_origin_microphone_quantized_latent_representation_image = gr.Image(
|
| 355 |
+
label="Original quantized latent representation", type="numpy", height=800, visible=False)
|
| 356 |
+
|
| 357 |
+
sound2sound_new_sound_latent_representation_image = gr.Image(label="New latent representation",
|
| 358 |
+
type="numpy", height=800)
|
| 359 |
+
sound2sound_new_sound_quantized_latent_representation_image = gr.Image(
|
| 360 |
+
label="New sound quantized latent representation", type="numpy", height=800)
|
| 361 |
+
|
| 362 |
+
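The mask sliders above are specified in seconds (time axis) and latent pixels (frequency axis); `sound2sound_sample` converts seconds to latent columns with the factor `time_resolution / (VAE_scale * 4)`. A small worked check of that arithmetic, assuming the illustrative values `time_resolution=256`, `VAE_scale=4` (these are not confirmed by this excerpt):

```python
time_resolution, VAE_scale = 256, 4  # assumed example configuration
cols_per_second = time_resolution / (VAE_scale * 4)  # 16 latent columns per second

def seconds_to_latent_cols(t: float) -> int:
    # Mirrors the slice bounds used when writing the rectangular mask region.
    return int(t * cols_per_second)

print(seconds_to_latent_cols(1.5))  # -> 24
```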
    sound2sound_origin_upload_audio.change(receive_uopoad_origin_audio,
                                           inputs=[sound2sound_duration_slider, sound2sound_origin_source_radio,
                                                   sound2sound_origin_upload_audio,
                                                   sound2sound_origin_microphone_audio, inpaintWithText_state],
                                           outputs=[sound2sound_origin_spectrogram_upload_image,
                                                    sound2sound_origin_phase_upload_image,
                                                    sound2sound_origin_spectrogram_microphone_image,
                                                    sound2sound_origin_phase_microphone_image,
                                                    sound2sound_origin_upload_latent_representation_image,
                                                    sound2sound_origin_upload_quantized_latent_representation_image,
                                                    sound2sound_origin_microphone_latent_representation_image,
                                                    sound2sound_origin_microphone_quantized_latent_representation_image,
                                                    inpaintWithText_state])
    sound2sound_origin_microphone_audio.change(receive_uopoad_origin_audio,
                                               inputs=[sound2sound_duration_slider, sound2sound_origin_source_radio,
                                                       sound2sound_origin_upload_audio,
                                                       sound2sound_origin_microphone_audio, inpaintWithText_state],
                                               outputs=[sound2sound_origin_spectrogram_upload_image,
                                                        sound2sound_origin_phase_upload_image,
                                                        sound2sound_origin_spectrogram_microphone_image,
                                                        sound2sound_origin_phase_microphone_image,
                                                        sound2sound_origin_upload_latent_representation_image,
                                                        sound2sound_origin_upload_quantized_latent_representation_image,
                                                        sound2sound_origin_microphone_latent_representation_image,
                                                        sound2sound_origin_microphone_quantized_latent_representation_image,
                                                        inpaintWithText_state])

    sound2sound_sample_button.click(sound2sound_sample,
                                    inputs=[sound2sound_origin_spectrogram_upload_image,
                                            sound2sound_origin_spectrogram_microphone_image,
                                            text2sound_prompts_textbox,
                                            text2sound_negative_prompts_textbox,
                                            sound2sound_batchsize_slider,
                                            sound2sound_guidance_scale_slider,
                                            sound2sound_sampler_radio,
                                            sound2sound_sample_steps_slider,
                                            sound2sound_origin_source_radio,
                                            sound2sound_noising_strength_slider,
                                            sound2sound_seed_textbox,
                                            sound2sound_inpaint_area_radio,
                                            mask_time_begin_slider,
                                            mask_time_end_slider,
                                            mask_frequency_begin_slider,
                                            mask_frequency_end_slider,
                                            inpaintWithText_state],
                                    outputs=[sound2sound_new_sound_latent_representation_image,
                                             sound2sound_new_sound_quantized_latent_representation_image,
                                             sound2sound_new_sound_spectrogram_image,
                                             sound2sound_new_sound_phase_image,
                                             sound2sound_new_sound_audio,
                                             sound2sound_sample_index_slider,
                                             sound2sound_seed_textbox,
                                             inpaintWithText_state])

    sound2sound_sample_index_slider.change(show_sound2sound_sample,
                                           inputs=[sound2sound_sample_index_slider, inpaintWithText_state],
                                           outputs=[sound2sound_new_sound_latent_representation_image,
                                                    sound2sound_new_sound_quantized_latent_representation_image,
                                                    sound2sound_new_sound_spectrogram_image,
                                                    sound2sound_new_sound_phase_image,
                                                    sound2sound_new_sound_audio])

    sound2sound_origin_source_radio.change(sound2sound_switch_origin_source,
                                           inputs=[sound2sound_origin_source_radio],
                                           outputs=[sound2sound_origin_upload_audio,
                                                    sound2sound_origin_microphone_audio,
                                                    sound2sound_origin_spectrogram_upload_image,
                                                    sound2sound_origin_phase_upload_image,
                                                    sound2sound_origin_spectrogram_microphone_image,
                                                    sound2sound_origin_phase_microphone_image,
                                                    sound2sound_origin_upload_latent_representation_image,
                                                    sound2sound_origin_upload_quantized_latent_representation_image,
                                                    sound2sound_origin_microphone_latent_representation_image,
                                                    sound2sound_origin_microphone_quantized_latent_representation_image])

    sound2sound_save_instrument_button.click(save_virtual_instrument,
                                             inputs=[sound2sound_sample_index_slider,
                                                     sound2sound_instrument_name_textbox,
                                                     inpaintWithText_state,
                                                     virtual_instruments_state],
                                             outputs=[virtual_instruments_state,
                                                      sound2sound_instrument_name_textbox])
webUI/natural_language_guided/rec.py
ADDED
@@ -0,0 +1,190 @@
import gradio as gr

from data_generation.nsynth import get_nsynth_dataloader
from webUI.natural_language_guided_STFT.utils import encodeBatch2GradioOutput_STFT, InputBatch2Encode_STFT, \
    latent_representation_to_Gradio_image


def get_recSTFT_module(gradioWebUI, reconstruction_state):
    # Load configurations
    uNet = gradioWebUI.uNet
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels

    timesteps = gradioWebUI.timesteps
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_encoder = gradioWebUI.VAE_encoder
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

    def generate_reconstruction_samples(sample_source, batchsize_slider, encodeCache, reconstruction_samples):
        vae_batchsize = int(batchsize_slider)

        if sample_source == "text2sound_trainSTFT":
            training_dataset_path = f'data/NSynth/nsynth-STFT-train-52.hdf5'  # Make sure to use your actual path
            iterator = get_nsynth_dataloader(training_dataset_path, batch_size=vae_batchsize, shuffle=True,
                                             get_latent_representation=False, with_meta_data=False,
                                             task="STFT")
        elif sample_source == "text2sound_validSTFT":
            training_dataset_path = f'data/NSynth/nsynth-STFT-valid-52.hdf5'  # Make sure to use your actual path
            iterator = get_nsynth_dataloader(training_dataset_path, batch_size=vae_batchsize, shuffle=True,
                                             get_latent_representation=False, with_meta_data=False,
                                             task="STFT")
        elif sample_source == "text2sound_testSTFT":
            training_dataset_path = f'data/NSynth/nsynth-STFT-test-52.hdf5'  # Make sure to use your actual path
            iterator = get_nsynth_dataloader(training_dataset_path, batch_size=vae_batchsize, shuffle=True,
                                             get_latent_representation=False, with_meta_data=False,
                                             task="STFT")
        else:
            raise NotImplementedError()

        spectrogram_batch = next(iter(iterator))

        origin_flipped_log_spectrums, origin_flipped_phases, origin_signals, latent_representations, quantized_latent_representations = InputBatch2Encode_STFT(
            VAE_encoder, spectrogram_batch, resolution=(512, width * VAE_scale), quantizer=VAE_quantizer, squared=squared)

        latent_representation_gradio_images, quantized_latent_representation_gradio_images = [], []
        for i in range(vae_batchsize):
            latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i]))
            quantized_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(quantized_latent_representations[i]))

        if quantized_latent_representations is None:
            quantized_latent_representations = latent_representations
        reconstruction_flipped_log_spectrums, reconstruction_flipped_phases, reconstruction_signals, \
            reconstruction_flipped_log_spectrums_WOA, reconstruction_flipped_phases_WOA, reconstruction_signals_WOA = \
            encodeBatch2GradioOutput_STFT(VAE_decoder, quantized_latent_representations,
                                          resolution=(512, width * VAE_scale),
                                          original_STFT_batch=spectrogram_batch)

        reconstruction_samples["origin_flipped_log_spectrums"] = origin_flipped_log_spectrums
        reconstruction_samples["origin_flipped_phases"] = origin_flipped_phases
        reconstruction_samples["origin_signals"] = origin_signals
        reconstruction_samples["latent_representation_gradio_images"] = latent_representation_gradio_images
        reconstruction_samples["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images
        reconstruction_samples["reconstruction_flipped_log_spectrums"] = reconstruction_flipped_log_spectrums
        reconstruction_samples["reconstruction_flipped_phases"] = reconstruction_flipped_phases
        reconstruction_samples["reconstruction_signals"] = reconstruction_signals
        reconstruction_samples["reconstruction_flipped_log_spectrums_WOA"] = reconstruction_flipped_log_spectrums_WOA
        reconstruction_samples["reconstruction_flipped_phases_WOA"] = reconstruction_flipped_phases_WOA
        reconstruction_samples["reconstruction_signals_WOA"] = reconstruction_signals_WOA
        reconstruction_samples["sampleRate"] = sample_rate

        latent_representation_gradio_image = reconstruction_samples["latent_representation_gradio_images"][0]
        quantized_latent_representation_gradio_image = \
            reconstruction_samples["quantized_latent_representation_gradio_images"][0]
        origin_flipped_log_spectrum = reconstruction_samples["origin_flipped_log_spectrums"][0]
        origin_flipped_phase = reconstruction_samples["origin_flipped_phases"][0]
        origin_signal = reconstruction_samples["origin_signals"][0]
        reconstruction_flipped_log_spectrum = reconstruction_samples["reconstruction_flipped_log_spectrums"][0]
        reconstruction_flipped_phase = reconstruction_samples["reconstruction_flipped_phases"][0]
        reconstruction_signal = reconstruction_samples["reconstruction_signals"][0]
        reconstruction_flipped_log_spectrum_WOA = reconstruction_samples["reconstruction_flipped_log_spectrums_WOA"][0]
        reconstruction_flipped_phase_WOA = reconstruction_samples["reconstruction_flipped_phases_WOA"][0]
        reconstruction_signal_WOA = reconstruction_samples["reconstruction_signals_WOA"][0]

        return {origin_amplitude_image_output: origin_flipped_log_spectrum,
                origin_phase_image_output: origin_flipped_phase,
                origin_audio_output: (sample_rate, origin_signal),
                latent_representation_image_output: latent_representation_gradio_image,
                quantized_latent_representation_image_output: quantized_latent_representation_gradio_image,
                reconstruction_amplitude_image_output: reconstruction_flipped_log_spectrum,
                reconstruction_phase_image_output: reconstruction_flipped_phase,
                reconstruction_audio_output: (sample_rate, reconstruction_signal),
                reconstruction_amplitude_image_output_WOA: reconstruction_flipped_log_spectrum_WOA,
                reconstruction_phase_image_output_WOA: reconstruction_flipped_phase_WOA,
                reconstruction_audio_output_WOA: (sample_rate, reconstruction_signal_WOA),
                sample_index_slider: gr.update(minimum=0, maximum=vae_batchsize - 1, value=0, step=1.0,
                                               label="Sample index.",
                                               info="Slide to view other samples", scale=1, visible=True),
                reconstruction_state: encodeCache,
                reconstruction_samples_state: reconstruction_samples}
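The three dataset branches in `generate_reconstruction_samples` differ only in the hdf5 path; a table-driven variant would keep them in sync if more splits are added. A sketch of that alternative (not applied in the diff; paths mirror those above):

```python
# Map each radio choice to its dataset file.
DATASET_PATHS = {
    "text2sound_trainSTFT": "data/NSynth/nsynth-STFT-train-52.hdf5",
    "text2sound_validSTFT": "data/NSynth/nsynth-STFT-valid-52.hdf5",
    "text2sound_testSTFT": "data/NSynth/nsynth-STFT-test-52.hdf5",
}

def resolve_dataset_path(sample_source: str) -> str:
    try:
        return DATASET_PATHS[sample_source]
    except KeyError:
        raise NotImplementedError(f"Unknown sample source: {sample_source}")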
    def show_reconstruction_sample(sample_index, encodeCache_state, reconstruction_samples_state):
        sample_index = int(sample_index)
        sampleRate = reconstruction_samples_state["sampleRate"]
        latent_representation_gradio_image = reconstruction_samples_state["latent_representation_gradio_images"][sample_index]
        quantized_latent_representation_gradio_image = \
            reconstruction_samples_state["quantized_latent_representation_gradio_images"][sample_index]
        origin_flipped_log_spectrum = reconstruction_samples_state["origin_flipped_log_spectrums"][sample_index]
        origin_flipped_phase = reconstruction_samples_state["origin_flipped_phases"][sample_index]
        origin_signal = reconstruction_samples_state["origin_signals"][sample_index]
        reconstruction_flipped_log_spectrum = reconstruction_samples_state["reconstruction_flipped_log_spectrums"][sample_index]
        reconstruction_flipped_phase = reconstruction_samples_state["reconstruction_flipped_phases"][sample_index]
        reconstruction_signal = reconstruction_samples_state["reconstruction_signals"][sample_index]
        reconstruction_flipped_log_spectrum_WOA = reconstruction_samples_state["reconstruction_flipped_log_spectrums_WOA"][sample_index]
        reconstruction_flipped_phase_WOA = reconstruction_samples_state["reconstruction_flipped_phases_WOA"][sample_index]
        reconstruction_signal_WOA = reconstruction_samples_state["reconstruction_signals_WOA"][sample_index]
        return origin_flipped_log_spectrum, origin_flipped_phase, (sampleRate, origin_signal), \
            latent_representation_gradio_image, quantized_latent_representation_gradio_image, \
            reconstruction_flipped_log_spectrum, reconstruction_flipped_phase, (sampleRate, reconstruction_signal), \
            reconstruction_flipped_log_spectrum_WOA, reconstruction_flipped_phase_WOA, (sampleRate, reconstruction_signal_WOA), \
            encodeCache_state, reconstruction_samples_state

    with gr.Tab("Reconstruction"):
        reconstruction_samples_state = gr.State(value={})
        gr.Markdown("Test reconstruction.")
        with gr.Row(variant="panel"):
            with gr.Column():
                sample_source_radio = gr.Radio(
                    choices=["synthetic", "external", "text2sound_trainSTFT", "text2sound_testSTFT", "text2sound_validSTFT"],
                    value="text2sound_trainSTFT", info="Info placeholder", scale=2)
                batchsize_slider = gr.Slider(minimum=1., maximum=16., value=4., step=1., label="batchsize")
            with gr.Column():
                generate_button = gr.Button(variant="primary", value="Generate reconstruction samples", scale=1)
                sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, label="Sample index.",
                                                info="Slide to view other samples", scale=1, visible=False)
        with gr.Row(variant="panel"):
            with gr.Column():
                origin_amplitude_image_output = gr.Image(label="Spectrogram", type="numpy", height=300, width=100, scale=1)
                origin_phase_image_output = gr.Image(label="Phase", type="numpy", height=300, width=100, scale=1)
                origin_audio_output = gr.Audio(type="numpy", label="Play the example!")
            with gr.Column():
                reconstruction_amplitude_image_output = gr.Image(label="Spectrogram", type="numpy", height=300, width=100, scale=1)
                reconstruction_phase_image_output = gr.Image(label="Phase", type="numpy", height=300, width=100, scale=1)
                reconstruction_audio_output = gr.Audio(type="numpy", label="Play the example!")
            with gr.Column():
                reconstruction_amplitude_image_output_WOA = gr.Image(label="Spectrogram", type="numpy", height=300, width=100, scale=1)
                reconstruction_phase_image_output_WOA = gr.Image(label="Phase", type="numpy", height=300, width=100, scale=1)
                reconstruction_audio_output_WOA = gr.Audio(type="numpy", label="Play the example!")
        with gr.Row(variant="panel", equal_height=True):
            latent_representation_image_output = gr.Image(label="latent_representation", type="numpy", height=300, width=100)
            quantized_latent_representation_image_output = gr.Image(label="quantized", type="numpy", height=300, width=100)

    generate_button.click(generate_reconstruction_samples,
                          inputs=[sample_source_radio, batchsize_slider, reconstruction_state,
                                  reconstruction_samples_state],
                          outputs=[origin_amplitude_image_output, origin_phase_image_output, origin_audio_output,
                                   latent_representation_image_output, quantized_latent_representation_image_output,
                                   reconstruction_amplitude_image_output, reconstruction_phase_image_output,
                                   reconstruction_audio_output,
                                   reconstruction_amplitude_image_output_WOA, reconstruction_phase_image_output_WOA,
                                   reconstruction_audio_output_WOA,
                                   sample_index_slider, reconstruction_state, reconstruction_samples_state])

    sample_index_slider.change(show_reconstruction_sample,
                               inputs=[sample_index_slider, reconstruction_state, reconstruction_samples_state],
                               outputs=[origin_amplitude_image_output, origin_phase_image_output, origin_audio_output,
                                        latent_representation_image_output, quantized_latent_representation_image_output,
                                        reconstruction_amplitude_image_output, reconstruction_phase_image_output,
                                        reconstruction_audio_output,
                                        reconstruction_amplitude_image_output_WOA, reconstruction_phase_image_output_WOA,
                                        reconstruction_audio_output_WOA,
                                        reconstruction_state, reconstruction_samples_state])
webUI/natural_language_guided/sound2sound_with_text.py
ADDED
@@ -0,0 +1,416 @@
import gradio as gr
import librosa
import numpy as np
import torch

from model.DiffSynthSampler import DiffSynthSampler
from tools import pad_STFT, encode_stft
from tools import safe_int, adjust_audio_length
from webUI.natural_language_guided.utils import InputBatch2Encode_STFT, encodeBatch2GradioOutput_STFT, \
    latent_representation_to_Gradio_image


def get_sound2sound_with_text_module(gradioWebUI, sound2sound_with_text_state, virtual_instruments_state):
    # Load configurations
    uNet = gradioWebUI.uNet
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels
    timesteps = gradioWebUI.timesteps
    VAE_encoder = gradioWebUI.VAE_encoder
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

    def receive_upload_origin_audio(sound2sound_duration, sound2sound_origin_source,
                                    sound2sound_origin_upload, sound2sound_origin_microphone,
                                    sound2sound_with_text_dict, virtual_instruments_dict):
        if sound2sound_origin_source == "upload":
            origin_sr, origin_audio = sound2sound_origin_upload
        else:
            origin_sr, origin_audio = sound2sound_origin_microphone

        origin_audio = origin_audio / np.max(np.abs(origin_audio))

        width = int(time_resolution * ((sound2sound_duration + 1) / 4) / VAE_scale)
        audio_length = 256 * (VAE_scale * width - 1)
        origin_audio = adjust_audio_length(origin_audio, audio_length, origin_sr, sample_rate)

        D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024)
        padded_D = pad_STFT(D)
        encoded_D = encode_stft(padded_D)

        # Todo: justify batchsize to 1
        origin_spectrogram_batch_tensor = torch.from_numpy(
            np.repeat(encoded_D[np.newaxis, :, :, :], 1, axis=0)).float().to(device)

        # Todo: remove hard-coding
        origin_flipped_log_spectrums, origin_flipped_phases, origin_signals, origin_latent_representations, quantized_origin_latent_representations = InputBatch2Encode_STFT(
            VAE_encoder, origin_spectrogram_batch_tensor, resolution=(512, width * VAE_scale),
            quantizer=VAE_quantizer, squared=squared)

        default_condition = CLAP.get_text_features(
            **CLAP_tokenizer([""], padding=True, return_tensors="pt"))[0].to("cpu").detach().numpy()

        if sound2sound_origin_source == "upload":
            sound2sound_with_text_dict["origin_upload_latent_representations"] = origin_latent_representations.tolist()
            sound2sound_with_text_dict["sound2sound_origin_upload_latent_representation_image"] = \
                latent_representation_to_Gradio_image(origin_latent_representations[0]).tolist()
            sound2sound_with_text_dict["sound2sound_origin_upload_quantized_latent_representation_image"] = \
                latent_representation_to_Gradio_image(quantized_origin_latent_representations[0]).tolist()

            virtual_instruments = virtual_instruments_dict["virtual_instruments"]
            virtual_instrument = {"condition": default_condition,
                                  "negative_condition": default_condition,  # care!!!
                                  "CFG": 1,
                                  "latent_representation": origin_latent_representations[0].to("cpu").detach().numpy(),
                                  "quantized_latent_representation": quantized_origin_latent_representations[0].to("cpu").detach().numpy(),
                                  "sampler": "ddim",
                                  "signal": (sample_rate, origin_audio),
                                  "spectrogram_gradio_image": origin_flipped_log_spectrums[0],
                                  "phase_gradio_image": origin_flipped_phases[0]}
            virtual_instruments["s2sup"] = virtual_instrument
            virtual_instruments_dict["virtual_instruments"] = virtual_instruments

            return {sound2sound_origin_spectrogram_upload_image: origin_flipped_log_spectrums[0],
                    sound2sound_origin_phase_upload_image: origin_flipped_phases[0],
                    sound2sound_origin_spectrogram_microphone_image: gr.update(),
                    sound2sound_origin_phase_microphone_image: gr.update(),
                    sound2sound_origin_upload_latent_representation_image: latent_representation_to_Gradio_image(
                        origin_latent_representations[0]),
                    sound2sound_origin_upload_quantized_latent_representation_image: latent_representation_to_Gradio_image(
                        quantized_origin_latent_representations[0]),
                    sound2sound_origin_microphone_latent_representation_image: gr.update(),
                    sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(),
                    sound2sound_with_text_state: sound2sound_with_text_dict,
                    virtual_instruments_state: virtual_instruments_dict}
        else:
            sound2sound_with_text_dict["origin_microphone_latent_representations"] = origin_latent_representations.tolist()
            sound2sound_with_text_dict["sound2sound_origin_microphone_latent_representation_image"] = \
                latent_representation_to_Gradio_image(origin_latent_representations[0]).tolist()
            sound2sound_with_text_dict["sound2sound_origin_microphone_quantized_latent_representation_image"] = \
                latent_representation_to_Gradio_image(quantized_origin_latent_representations[0]).tolist()

            virtual_instruments = virtual_instruments_dict["virtual_instruments"]
            virtual_instrument = {"condition": default_condition,
                                  "negative_condition": default_condition,  # care!!!
                                  "CFG": 1,
                                  "latent_representation": origin_latent_representations[0],
                                  "quantized_latent_representation": quantized_origin_latent_representations[0],
                                  "sampler": "ddim",
                                  "signal": origin_audio,
                                  "spectrogram_gradio_image": origin_flipped_log_spectrums[0]}
            virtual_instruments["s2sre"] = virtual_instrument
            virtual_instruments_dict["virtual_instruments"] = virtual_instruments

            return {sound2sound_origin_spectrogram_upload_image: gr.update(),
                    sound2sound_origin_phase_upload_image: gr.update(),
                    sound2sound_origin_spectrogram_microphone_image: origin_flipped_log_spectrums[0],
                    sound2sound_origin_phase_microphone_image: origin_flipped_phases[0],
                    sound2sound_origin_upload_latent_representation_image: gr.update(),
                    sound2sound_origin_upload_quantized_latent_representation_image: gr.update(),
                    sound2sound_origin_microphone_latent_representation_image: latent_representation_to_Gradio_image(
                        origin_latent_representations[0]),
                    sound2sound_origin_microphone_quantized_latent_representation_image: latent_representation_to_Gradio_image(
                        quantized_origin_latent_representations[0]),
                    sound2sound_with_text_state: sound2sound_with_text_dict,
                    virtual_instruments_state: virtual_instruments_dict}
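`receive_upload_origin_audio` above derives the target sample count from the latent width: with a hop length of 256, a spectrogram of `VAE_scale * width` frames covers `256 * (VAE_scale * width - 1)` samples. A quick check of that arithmetic (a sketch; `n_fft` and `hop_length` are taken from the `librosa.stft` call above, while `width` and `VAE_scale` are illustrative values):

```python
import numpy as np
import librosa

hop_length, n_fft = 256, 1024
width, VAE_scale = 64, 4                               # illustrative latent width and scale
audio_length = hop_length * (VAE_scale * width - 1)    # same formula as in the diff

signal = np.zeros(audio_length, dtype=np.float32)
D = librosa.stft(signal, n_fft=n_fft, hop_length=hop_length, win_length=n_fft)
print(D.shape)  # (513, 256): VAE_scale * width frames, n_fft // 2 + 1 frequency bins
```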
    def sound2sound_sample(sound2sound_prompts, sound2sound_negative_prompts, sound2sound_batchsize,
                           sound2sound_guidance_scale, sound2sound_sampler,
                           sound2sound_sample_steps, sound2sound_origin_source,
                           sound2sound_noising_strength, sound2sound_seed, sound2sound_dict, virtual_instruments_dict):
        # input processing
        sound2sound_seed = safe_int(sound2sound_seed, 12345678)
        sound2sound_batchsize = int(sound2sound_batchsize)
        noising_strength = sound2sound_noising_strength
        sound2sound_sample_steps = int(sound2sound_sample_steps)
        CFG = int(sound2sound_guidance_scale)

        if sound2sound_origin_source == "upload":
            origin_latent_representations = torch.tensor(
                sound2sound_dict["origin_upload_latent_representations"]).repeat(
                sound2sound_batchsize, 1, 1, 1).to(device)
        elif sound2sound_origin_source == "microphone":
            origin_latent_representations = torch.tensor(
                sound2sound_dict["origin_microphone_latent_representations"]).repeat(
                sound2sound_batchsize, 1, 1, 1).to(device)
        else:
            print("Input source not in ['upload', 'microphone']!")
            raise NotImplementedError()

        # sound2sound
        text2sound_embedding = \
            CLAP.get_text_features(**CLAP_tokenizer([sound2sound_prompts], padding=True, return_tensors="pt"))[0].to(device)

        mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
        unconditional_condition = \
            CLAP.get_text_features(**CLAP_tokenizer([sound2sound_negative_prompts], padding=True, return_tensors="pt"))[0]
        mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device))

        normalized_sample_steps = int(sound2sound_sample_steps / noising_strength)
        mySampler.respace(list(np.linspace(0, timesteps - 1, normalized_sample_steps, dtype=np.int32)))

        condition = text2sound_embedding.repeat(sound2sound_batchsize, 1)

        # Todo: remove hard-coding
        width = origin_latent_representations.shape[-1]
        new_sound_latent_representations, initial_noise = \
            mySampler.img_guided_sample(model=uNet, shape=(sound2sound_batchsize, channels, height, width),
                                        seed=sound2sound_seed, noising_strength=noising_strength,
                                        guide_img=origin_latent_representations, return_tensor=True,
                                        condition=condition, sampler=sound2sound_sampler)

        new_sound_latent_representations = new_sound_latent_representations[-1]

        # Quantize new sound latent representations
        quantized_new_sound_latent_representations, loss, (_, _, _) = VAE_quantizer(new_sound_latent_representations)

        new_sound_flipped_log_spectrums, new_sound_flipped_phases, new_sound_signals, _, _, _ = \
            encodeBatch2GradioOutput_STFT(VAE_decoder, quantized_new_sound_latent_representations,
                                          resolution=(512, width * VAE_scale), original_STFT_batch=None)

        new_sound_latent_representation_gradio_images = []
        new_sound_quantized_latent_representation_gradio_images = []
        new_sound_spectrogram_gradio_images = []
        new_sound_phase_gradio_images = []
        new_sound_rec_signals_gradio = []
        for i in range(sound2sound_batchsize):
            new_sound_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(new_sound_latent_representations[i]))
            new_sound_quantized_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(quantized_new_sound_latent_representations[i]))
            new_sound_spectrogram_gradio_images.append(new_sound_flipped_log_spectrums[i])
            new_sound_phase_gradio_images.append(new_sound_flipped_phases[i])
            new_sound_rec_signals_gradio.append((sample_rate, new_sound_signals[i]))

        sound2sound_dict["new_sound_latent_representation_gradio_images"] = new_sound_latent_representation_gradio_images
        sound2sound_dict["new_sound_quantized_latent_representation_gradio_images"] = new_sound_quantized_latent_representation_gradio_images
        sound2sound_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
        sound2sound_dict["new_sound_phase_gradio_images"] = new_sound_phase_gradio_images
        sound2sound_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio

        return {sound2sound_new_sound_latent_representation_image: latent_representation_to_Gradio_image(
                    new_sound_latent_representations[0]),
                sound2sound_new_sound_quantized_latent_representation_image: latent_representation_to_Gradio_image(
                    quantized_new_sound_latent_representations[0]),
                sound2sound_new_sound_spectrogram_image: new_sound_flipped_log_spectrums[0],
                sound2sound_new_sound_phase_image: new_sound_flipped_phases[0],
                sound2sound_new_sound_audio: (sample_rate, new_sound_signals[0]),
                sound2sound_sample_index_slider: gr.update(minimum=0, maximum=sound2sound_batchsize - 1, value=0,
                                                           step=1.0, visible=True, label="Sample index",
                                                           info="Swipe to view other samples"),
                sound2sound_seed_textbox: sound2sound_seed,
                sound2sound_with_text_state: sound2sound_dict,
                virtual_instruments_state: virtual_instruments_dict}
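The line `normalized_sample_steps = int(sound2sound_sample_steps / noising_strength)` sizes the full schedule so that only the final `noising_strength` fraction of it is traversed from the partially noised guide image, keeping the number of denoising steps the user pays for roughly constant. A minimal numeric sketch of that bookkeeping (illustrative; `DiffSynthSampler`'s exact indexing may differ):

```python
import numpy as np

timesteps = 1000
sample_steps = 20          # steps the user asked for
noising_strength = 0.5     # how far the guide image is pushed toward noise

normalized = int(sample_steps / noising_strength)             # 40-step schedule
schedule = np.linspace(0, timesteps - 1, normalized, dtype=np.int32)
start = int(noising_strength * normalized)                    # img2img enters the schedule here
print(len(schedule[:start]))  # -> 20 steps actually executed
```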
def show_sound2sound_sample(sound2sound_sample_index, sound2sound_with_text_dict):
|
| 232 |
+
sample_index = int(sound2sound_sample_index)
|
| 233 |
+
return {sound2sound_new_sound_latent_representation_image:
|
| 234 |
+
sound2sound_with_text_dict["new_sound_latent_representation_gradio_images"][sample_index],
|
| 235 |
+
sound2sound_new_sound_quantized_latent_representation_image:
|
| 236 |
+
sound2sound_with_text_dict["new_sound_quantized_latent_representation_gradio_images"][sample_index],
|
| 237 |
+
sound2sound_new_sound_spectrogram_image: sound2sound_with_text_dict["new_sound_spectrogram_gradio_images"][
|
| 238 |
+
sample_index],
|
| 239 |
+
sound2sound_new_sound_phase_image: sound2sound_with_text_dict["new_sound_phase_gradio_images"][
|
| 240 |
+
sample_index],
|
| 241 |
+
sound2sound_new_sound_audio: sound2sound_with_text_dict["new_sound_rec_signals_gradio"][sample_index]}
|
| 242 |
+
|
| 243 |
+
def sound2sound_switch_origin_source(sound2sound_origin_source):
|
| 244 |
+
|
| 245 |
+
if sound2sound_origin_source == "upload":
|
| 246 |
+
return {sound2sound_origin_upload_audio: gr.update(visible=True),
|
| 247 |
+
sound2sound_origin_microphone_audio: gr.update(visible=False),
|
| 248 |
+
sound2sound_origin_spectrogram_upload_image: gr.update(visible=True),
|
| 249 |
+
sound2sound_origin_phase_upload_image: gr.update(visible=True),
|
| 250 |
+
sound2sound_origin_spectrogram_microphone_image: gr.update(visible=False),
|
| 251 |
+
sound2sound_origin_phase_microphone_image: gr.update(visible=False),
|
| 252 |
+
sound2sound_origin_upload_latent_representation_image: gr.update(visible=True),
|
| 253 |
+
sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=True),
|
| 254 |
+
sound2sound_origin_microphone_latent_representation_image: gr.update(visible=False),
|
| 255 |
+
sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=False)}
|
| 256 |
+
elif sound2sound_origin_source == "microphone":
|
| 257 |
+
return {sound2sound_origin_upload_audio: gr.update(visible=False),
|
| 258 |
+
sound2sound_origin_microphone_audio: gr.update(visible=True),
|
| 259 |
+
sound2sound_origin_spectrogram_upload_image: gr.update(visible=False),
|
| 260 |
+
sound2sound_origin_phase_upload_image: gr.update(visible=False),
|
| 261 |
+
sound2sound_origin_spectrogram_microphone_image: gr.update(visible=True),
|
| 262 |
+
sound2sound_origin_phase_microphone_image: gr.update(visible=True),
|
| 263 |
+
sound2sound_origin_upload_latent_representation_image: gr.update(visible=False),
|
| 264 |
+
sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=False),
|
| 265 |
+
sound2sound_origin_microphone_latent_representation_image: gr.update(visible=True),
|
| 266 |
+
sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=True)}
|
| 267 |
+
else:
|
| 268 |
+
print("Input source not in ['upload', 'microphone']!")
|
| 269 |
+
|
| 270 |
+
with gr.Tab("Sound2Sound"):
|
| 271 |
+
gr.Markdown("Generate new sound based on a given sound!")
|
| 272 |
+
with gr.Row(variant="panel"):
|
| 273 |
+
with gr.Column(scale=3):
|
| 274 |
+
sound2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
|
| 275 |
+
text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")
|
| 276 |
+
|
| 277 |
+
with gr.Column(scale=1):
|
| 278 |
+
sound2sound_sample_button = gr.Button(variant="primary", value="Generate", scale=1)
|
| 279 |
+
|
| 280 |
+
sound2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
|
| 281 |
+
label="Sample index",
|
| 282 |
+
info="Swipe to view other samples")
|
| 283 |
+
|
| 284 |
+
with gr.Row(variant="panel"):
|
| 285 |
+
with gr.Column(scale=1):
|
| 286 |
+
with gr.Tab("Origin sound"):
|
| 287 |
+
sound2sound_duration_slider = gradioWebUI.get_duration_slider()
|
| 288 |
+
sound2sound_origin_source_radio = gr.Radio(choices=["upload", "microphone"], value="upload",
|
| 289 |
+
label="Input source")
|
| 290 |
+
|
| 291 |
+
sound2sound_origin_upload_audio = gr.Audio(type="numpy", label="Upload", source="upload",
|
| 292 |
+
interactive=True, visible=True)
|
| 293 |
+
sound2sound_origin_microphone_audio = gr.Audio(type="numpy", label="Record", source="microphone",
|
| 294 |
+
interactive=True, visible=False)
|
| 295 |
+
with gr.Row(variant="panel"):
|
| 296 |
+
sound2sound_origin_spectrogram_upload_image = gr.Image(label="Original upload spectrogram",
|
| 297 |
+
type="numpy", height=600,
|
| 298 |
+
visible=True)
|
| 299 |
+
sound2sound_origin_phase_upload_image = gr.Image(label="Original upload phase",
|
| 300 |
+
type="numpy", height=600,
|
| 301 |
+
visible=True)
|
| 302 |
+
sound2sound_origin_spectrogram_microphone_image = gr.Image(label="Original microphone spectrogram",
|
| 303 |
+
type="numpy", height=600,
|
| 304 |
+
visible=False)
|
| 305 |
+
sound2sound_origin_phase_microphone_image = gr.Image(label="Original microphone phase",
|
| 306 |
+
type="numpy", height=600,
|
| 307 |
+
visible=False)
|
| 308 |
+
|
| 309 |
+
with gr.Tab("Sound2sound settings"):
|
| 310 |
+
sound2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
|
| 311 |
+
sound2sound_sampler_radio = gradioWebUI.get_sampler_radio()
|
| 312 |
+
sound2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
|
| 313 |
+
sound2sound_noising_strength_slider = gradioWebUI.get_noising_strength_slider()
|
| 314 |
+
sound2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
|
| 315 |
+
sound2sound_seed_textbox = gradioWebUI.get_seed_textbox()
|
| 316 |
+
|
| 317 |
+
with gr.Column(scale=1):
|
| 318 |
+
sound2sound_new_sound_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False)
|
| 319 |
+
with gr.Row(variant="panel"):
|
| 320 |
+
sound2sound_new_sound_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy",
|
| 321 |
+
height=600, scale=1)
|
| 322 |
+
sound2sound_new_sound_phase_image = gr.Image(label="New sound phase", type="numpy",
|
| 323 |
+
height=600, scale=1)
|
| 324 |
+
|
| 325 |
+
with gr.Row(variant="panel"):
|
| 326 |
+
sound2sound_origin_upload_latent_representation_image = gr.Image(label="Original latent representation",
|
| 327 |
+
type="numpy", height=800,
|
| 328 |
+
visible=True)
|
| 329 |
+
sound2sound_origin_upload_quantized_latent_representation_image = gr.Image(
|
| 330 |
+
label="Original quantized latent representation", type="numpy", height=800, visible=True)
|
| 331 |
+
|
| 332 |
+
sound2sound_origin_microphone_latent_representation_image = gr.Image(label="Original latent representation",
|
| 333 |
+
type="numpy", height=800,
|
| 334 |
+
visible=False)
|
| 335 |
+
sound2sound_origin_microphone_quantized_latent_representation_image = gr.Image(
|
| 336 |
+
label="Original quantized latent representation", type="numpy", height=800, visible=False)
|
| 337 |
+
|
| 338 |
+
sound2sound_new_sound_latent_representation_image = gr.Image(label="New latent representation",
|
| 339 |
+
type="numpy", height=800)
|
| 340 |
+
        sound2sound_new_sound_quantized_latent_representation_image = gr.Image(
            label="New sound quantized latent representation", type="numpy", height=800)

        sound2sound_origin_upload_audio.change(receive_upload_origin_audio,
                                               inputs=[sound2sound_duration_slider, sound2sound_origin_source_radio,
                                                       sound2sound_origin_upload_audio,
                                                       sound2sound_origin_microphone_audio, sound2sound_with_text_state,
                                                       virtual_instruments_state],
                                               outputs=[sound2sound_origin_spectrogram_upload_image,
                                                        sound2sound_origin_phase_upload_image,
                                                        sound2sound_origin_spectrogram_microphone_image,
                                                        sound2sound_origin_phase_microphone_image,
                                                        sound2sound_origin_upload_latent_representation_image,
                                                        sound2sound_origin_upload_quantized_latent_representation_image,
                                                        sound2sound_origin_microphone_latent_representation_image,
                                                        sound2sound_origin_microphone_quantized_latent_representation_image,
                                                        sound2sound_with_text_state,
                                                        virtual_instruments_state])

        sound2sound_origin_microphone_audio.change(receive_upload_origin_audio,
                                                   inputs=[sound2sound_duration_slider,
                                                           sound2sound_origin_source_radio, sound2sound_origin_upload_audio,
                                                           sound2sound_origin_microphone_audio, sound2sound_with_text_state,
                                                           virtual_instruments_state],
                                                   outputs=[sound2sound_origin_spectrogram_upload_image,
                                                            sound2sound_origin_phase_upload_image,
                                                            sound2sound_origin_spectrogram_microphone_image,
                                                            sound2sound_origin_phase_microphone_image,
                                                            sound2sound_origin_upload_latent_representation_image,
                                                            sound2sound_origin_upload_quantized_latent_representation_image,
                                                            sound2sound_origin_microphone_latent_representation_image,
                                                            sound2sound_origin_microphone_quantized_latent_representation_image,
                                                            sound2sound_with_text_state,
                                                            virtual_instruments_state])

        sound2sound_sample_button.click(sound2sound_sample,
                                        inputs=[sound2sound_prompts_textbox,
                                                text2sound_negative_prompts_textbox,
                                                sound2sound_batchsize_slider,
                                                sound2sound_guidance_scale_slider,
                                                sound2sound_sampler_radio,
                                                sound2sound_sample_steps_slider,
                                                sound2sound_origin_source_radio,
                                                sound2sound_noising_strength_slider,
                                                sound2sound_seed_textbox,
                                                sound2sound_with_text_state,
                                                virtual_instruments_state],
                                        outputs=[sound2sound_new_sound_latent_representation_image,
                                                 sound2sound_new_sound_quantized_latent_representation_image,
                                                 sound2sound_new_sound_spectrogram_image,
                                                 sound2sound_new_sound_phase_image,
                                                 sound2sound_new_sound_audio,
                                                 sound2sound_sample_index_slider,
                                                 sound2sound_seed_textbox,
                                                 sound2sound_with_text_state,
                                                 virtual_instruments_state])

        sound2sound_sample_index_slider.change(show_sound2sound_sample,
                                               inputs=[sound2sound_sample_index_slider, sound2sound_with_text_state],
                                               outputs=[sound2sound_new_sound_latent_representation_image,
                                                        sound2sound_new_sound_quantized_latent_representation_image,
                                                        sound2sound_new_sound_spectrogram_image,
                                                        sound2sound_new_sound_phase_image,
                                                        sound2sound_new_sound_audio])

        sound2sound_origin_source_radio.change(sound2sound_switch_origin_source,
                                               inputs=[sound2sound_origin_source_radio],
                                               outputs=[sound2sound_origin_upload_audio,
                                                        sound2sound_origin_microphone_audio,
                                                        sound2sound_origin_spectrogram_upload_image,
                                                        sound2sound_origin_phase_upload_image,
                                                        sound2sound_origin_spectrogram_microphone_image,
                                                        sound2sound_origin_phase_microphone_image,
                                                        sound2sound_origin_upload_latent_representation_image,
                                                        sound2sound_origin_upload_quantized_latent_representation_image,
                                                        sound2sound_origin_microphone_latent_representation_image,
                                                        sound2sound_origin_microphone_quantized_latent_representation_image])
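Note on the wiring above: every handler in these modules follows the same Gradio pattern — shared dictionaries travel through gr.State inputs/outputs, and each handler returns a {component: value} mapping for its declared outputs. A minimal self-contained sketch of that pattern (the component and handler names here are illustrative, not taken from this repo):

import gradio as gr

def toggle_source(source):
    # Return a {component: update} mapping, exactly as the handlers above do.
    return {upload_audio: gr.update(visible=source == "upload"),
            mic_audio: gr.update(visible=source == "microphone")}

with gr.Blocks() as demo:
    source_radio = gr.Radio(choices=["upload", "microphone"], value="upload", label="Input source")
    upload_audio = gr.Audio(type="numpy", visible=True)
    mic_audio = gr.Audio(type="numpy", visible=False)
    source_radio.change(toggle_source, inputs=[source_radio], outputs=[upload_audio, mic_audio])

# demo.launch()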
webUI/natural_language_guided/super_resolution_with_text.py
ADDED
@@ -0,0 +1,387 @@
import librosa
import numpy as np
import torch
import gradio as gr
from scipy.ndimage import zoom

from model.DiffSynthSampler import DiffSynthSampler
from tools import adjust_audio_length, rescale, safe_int, pad_STFT, encode_stft
from webUI.natural_language_guided.utils import latent_representation_to_Gradio_image
from webUI.natural_language_guided.utils import InputBatch2Encode_STFT, encodeBatch2GradioOutput_STFT


def get_super_resolution_with_text_module(gradioWebUI, inpaintWithText_state):
    # Load configurations
    uNet = gradioWebUI.uNet
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels
    timesteps = gradioWebUI.timesteps
    VAE_encoder = gradioWebUI.VAE_encoder
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

    def receive_upload_origin_audio(sound2sound_duration, sound2sound_origin_source, sound2sound_origin_upload,
                                    sound2sound_origin_microphone, inpaintWithText_dict):

        if sound2sound_origin_source == "upload":
            origin_sr, origin_audio = sound2sound_origin_upload
        else:
            origin_sr, origin_audio = sound2sound_origin_microphone

        origin_audio = origin_audio / np.max(np.abs(origin_audio))

        width = int(time_resolution * ((sound2sound_duration + 1) / 4) / VAE_scale)
        audio_length = 256 * (VAE_scale * width - 1)
        origin_audio = adjust_audio_length(origin_audio, audio_length, origin_sr, sample_rate)

        D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024)
        padded_D = pad_STFT(D)
        encoded_D = encode_stft(padded_D)

        # Todo: justify batchsize to 1
        origin_spectrogram_batch_tensor = torch.from_numpy(
            np.repeat(encoded_D[np.newaxis, :, :, :], 1, axis=0)).float().to(device)

        # Todo: remove hard-coding
        origin_flipped_log_spectrums, origin_flipped_phases, origin_signals, origin_latent_representations, quantized_origin_latent_representations = InputBatch2Encode_STFT(
            VAE_encoder, origin_spectrogram_batch_tensor, resolution=(512, width * VAE_scale), quantizer=VAE_quantizer,
            squared=squared)

        if sound2sound_origin_source == "upload":
            inpaintWithText_dict["origin_upload_latent_representations"] = origin_latent_representations.tolist()
            inpaintWithText_dict[
                "sound2sound_origin_upload_latent_representation_image"] = latent_representation_to_Gradio_image(
                origin_latent_representations[0]).tolist()
            inpaintWithText_dict[
                "sound2sound_origin_upload_quantized_latent_representation_image"] = latent_representation_to_Gradio_image(
                quantized_origin_latent_representations[0]).tolist()
            return {sound2sound_origin_spectrogram_upload_image: origin_flipped_log_spectrums[0],
                    sound2sound_origin_phase_upload_image: origin_flipped_phases[0],
                    sound2sound_origin_spectrogram_microphone_image: gr.update(),
                    sound2sound_origin_phase_microphone_image: gr.update(),
                    sound2sound_origin_upload_latent_representation_image: latent_representation_to_Gradio_image(
                        origin_latent_representations[0]),
                    sound2sound_origin_upload_quantized_latent_representation_image: latent_representation_to_Gradio_image(
                        quantized_origin_latent_representations[0]),
                    sound2sound_origin_microphone_latent_representation_image: gr.update(),
                    sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(),
                    inpaintWithText_state: inpaintWithText_dict}
        else:
            inpaintWithText_dict["origin_microphone_latent_representations"] = origin_latent_representations.tolist()
            inpaintWithText_dict[
                "sound2sound_origin_microphone_latent_representation_image"] = latent_representation_to_Gradio_image(
                origin_latent_representations[0]).tolist()
            inpaintWithText_dict[
                "sound2sound_origin_microphone_quantized_latent_representation_image"] = latent_representation_to_Gradio_image(
                quantized_origin_latent_representations[0]).tolist()
            # Mirror of the upload branch: update the microphone previews; the
            # upload components keep their current values.
            return {sound2sound_origin_spectrogram_upload_image: gr.update(),
                    sound2sound_origin_phase_upload_image: gr.update(),
                    sound2sound_origin_spectrogram_microphone_image: origin_flipped_log_spectrums[0],
                    sound2sound_origin_phase_microphone_image: origin_flipped_phases[0],
                    sound2sound_origin_upload_latent_representation_image: gr.update(),
                    sound2sound_origin_upload_quantized_latent_representation_image: gr.update(),
                    sound2sound_origin_microphone_latent_representation_image: latent_representation_to_Gradio_image(
                        origin_latent_representations[0]),
                    sound2sound_origin_microphone_quantized_latent_representation_image: latent_representation_to_Gradio_image(
                        quantized_origin_latent_representations[0]),
                    inpaintWithText_state: inpaintWithText_dict}

    def sound2sound_sample(sound2sound_origin_spectrogram_upload, sound2sound_origin_spectrogram_microphone,
                           text2sound_prompts, text2sound_negative_prompts, sound2sound_batchsize,
                           sound2sound_guidance_scale, sound2sound_sampler,
                           sound2sound_sample_steps, sound2sound_origin_source,
                           sound2sound_noising_strength, sound2sound_seed, sound2sound_inpaint_area, inpaintWithText_dict
                           ):

        # input preprocessing
        sound2sound_seed = safe_int(sound2sound_seed, 12345678)
        sound2sound_batchsize = int(sound2sound_batchsize)
        noising_strength = sound2sound_noising_strength
        sound2sound_sample_steps = int(sound2sound_sample_steps)
        CFG = int(sound2sound_guidance_scale)

        text2sound_embedding = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(device)

        if sound2sound_origin_source == "upload":
            origin_latent_representations = torch.tensor(
                inpaintWithText_dict["origin_upload_latent_representations"]).repeat(sound2sound_batchsize, 1, 1, 1).to(
                device)
        elif sound2sound_origin_source == "microphone":
            origin_latent_representations = torch.tensor(
                inpaintWithText_dict["origin_microphone_latent_representations"]).repeat(sound2sound_batchsize, 1, 1, 1).to(
                device)
        else:
            print("Input source not in ['upload', 'microphone']!")
            raise NotImplementedError()

        high_resolution_latent_representations = torch.zeros((sound2sound_batchsize, channels, 256, 64)).to(device)
        high_resolution_latent_representations[:, :, :128, :] = origin_latent_representations
        latent_mask = np.ones((256, 64))
        latent_mask[192:, :] = 0.0
        print(f"latent_mask mean: {np.mean(latent_mask)}")

        if sound2sound_inpaint_area == "inpaint masked":
            latent_mask = 1 - latent_mask
        latent_mask = torch.from_numpy(latent_mask).unsqueeze(0).unsqueeze(1).repeat(sound2sound_batchsize, channels, 1,
                                                                                     1).float().to(device)
        latent_mask = torch.flip(latent_mask, [2])

        mySampler = DiffSynthSampler(timesteps, height=height * 2, channels=channels, noise_strategy=noise_strategy)
        unconditional_condition = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[0]
        mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device))

        normalized_sample_steps = int(sound2sound_sample_steps / noising_strength)

        mySampler.respace(list(np.linspace(0, timesteps - 1, normalized_sample_steps, dtype=np.int32)))

        # Todo: remove hard-coding
        width = high_resolution_latent_representations.shape[-1]
        condition = text2sound_embedding.repeat(sound2sound_batchsize, 1)

        new_sound_latent_representations, initial_noise = \
            mySampler.inpaint_sample(model=uNet, shape=(sound2sound_batchsize, channels, height * 2, width),
                                     seed=sound2sound_seed,
                                     noising_strength=noising_strength,
                                     guide_img=high_resolution_latent_representations, mask=latent_mask, return_tensor=True,
                                     condition=condition, sampler=sound2sound_sampler)

        new_sound_latent_representations = new_sound_latent_representations[-1]

        # Quantize new sound latent representations
        quantized_new_sound_latent_representations, loss, (_, _, _) = VAE_quantizer(new_sound_latent_representations)
        new_sound_flipped_log_spectrums, new_sound_flipped_phases, new_sound_signals, _, _, _ = encodeBatch2GradioOutput_STFT(
            VAE_decoder,
            quantized_new_sound_latent_representations,
            resolution=(1024, width * VAE_scale),
            original_STFT_batch=None
        )

        new_sound_latent_representation_gradio_images = []
        new_sound_quantized_latent_representation_gradio_images = []
        new_sound_spectrogram_gradio_images = []
        new_sound_phase_gradio_images = []
        new_sound_rec_signals_gradio = []
        for i in range(sound2sound_batchsize):
            new_sound_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(new_sound_latent_representations[i]))
            new_sound_quantized_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(quantized_new_sound_latent_representations[i]))
            new_sound_spectrogram_gradio_images.append(new_sound_flipped_log_spectrums[i])
            new_sound_phase_gradio_images.append(new_sound_flipped_phases[i])
            new_sound_rec_signals_gradio.append((sample_rate, new_sound_signals[i]))

        inpaintWithText_dict[
            "new_sound_latent_representation_gradio_images"] = new_sound_latent_representation_gradio_images
        inpaintWithText_dict[
            "new_sound_quantized_latent_representation_gradio_images"] = new_sound_quantized_latent_representation_gradio_images
        inpaintWithText_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
        inpaintWithText_dict["new_sound_phase_gradio_images"] = new_sound_phase_gradio_images
        inpaintWithText_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio

        return {sound2sound_new_sound_latent_representation_image: latent_representation_to_Gradio_image(
            new_sound_latent_representations[0]),
                sound2sound_new_sound_quantized_latent_representation_image: latent_representation_to_Gradio_image(
                    quantized_new_sound_latent_representations[0]),
                sound2sound_new_sound_spectrogram_image: new_sound_flipped_log_spectrums[0],
                sound2sound_new_sound_phase_image: new_sound_flipped_phases[0],
                sound2sound_new_sound_audio: (sample_rate, new_sound_signals[0]),
                sound2sound_sample_index_slider: gr.update(minimum=0, maximum=sound2sound_batchsize - 1, value=0,
                                                           step=1.0,
                                                           visible=True,
                                                           label="Sample index",
                                                           info="Swipe to view other samples"),
                sound2sound_seed_textbox: sound2sound_seed,
                inpaintWithText_state: inpaintWithText_dict}

    def show_sound2sound_sample(sound2sound_sample_index, inpaintWithText_dict):
        sample_index = int(sound2sound_sample_index)
        return {sound2sound_new_sound_latent_representation_image:
                    inpaintWithText_dict["new_sound_latent_representation_gradio_images"][sample_index],
                sound2sound_new_sound_quantized_latent_representation_image:
                    inpaintWithText_dict["new_sound_quantized_latent_representation_gradio_images"][sample_index],
                sound2sound_new_sound_spectrogram_image: inpaintWithText_dict["new_sound_spectrogram_gradio_images"][
                    sample_index],
                sound2sound_new_sound_phase_image: inpaintWithText_dict["new_sound_phase_gradio_images"][
                    sample_index],
                sound2sound_new_sound_audio: inpaintWithText_dict["new_sound_rec_signals_gradio"][sample_index]}

    def sound2sound_switch_origin_source(sound2sound_origin_source):

        if sound2sound_origin_source == "upload":
            return {sound2sound_origin_upload_audio: gr.update(visible=True),
                    sound2sound_origin_microphone_audio: gr.update(visible=False),
                    sound2sound_origin_spectrogram_upload_image: gr.update(visible=True),
                    sound2sound_origin_phase_upload_image: gr.update(visible=True),
                    sound2sound_origin_spectrogram_microphone_image: gr.update(visible=False),
                    sound2sound_origin_phase_microphone_image: gr.update(visible=False),
                    sound2sound_origin_upload_latent_representation_image: gr.update(visible=True),
                    sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=True),
                    sound2sound_origin_microphone_latent_representation_image: gr.update(visible=False),
                    sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=False)}
        elif sound2sound_origin_source == "microphone":
            return {sound2sound_origin_upload_audio: gr.update(visible=False),
                    sound2sound_origin_microphone_audio: gr.update(visible=True),
                    sound2sound_origin_spectrogram_upload_image: gr.update(visible=False),
                    sound2sound_origin_phase_upload_image: gr.update(visible=False),
                    sound2sound_origin_spectrogram_microphone_image: gr.update(visible=True),
                    sound2sound_origin_phase_microphone_image: gr.update(visible=True),
                    sound2sound_origin_upload_latent_representation_image: gr.update(visible=False),
                    sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=False),
                    sound2sound_origin_microphone_latent_representation_image: gr.update(visible=True),
                    sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=True)}
        else:
            print("Input source not in ['upload', 'microphone']!")

    with gr.Tab("Super Resolution"):
        gr.Markdown("Upload or record a sound and use the prompt to guide the synthesis of its missing high-frequency band!")
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):
                text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
                text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")

            with gr.Column(scale=1):
                sound2sound_sample_button = gr.Button(variant="primary", value="Generate", scale=1)

                sound2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
                                                            label="Sample index",
                                                            info="Swipe to view other samples")

        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                with gr.Tab("Origin sound"):
                    sound2sound_duration_slider = gradioWebUI.get_duration_slider()
                    sound2sound_origin_source_radio = gr.Radio(choices=["upload", "microphone"], value="upload",
                                                               label="Input source")

                    sound2sound_origin_upload_audio = gr.Audio(type="numpy", label="Upload", source="upload",
                                                               interactive=True, visible=True)
                    sound2sound_origin_microphone_audio = gr.Audio(type="numpy", label="Record", source="microphone",
                                                                   interactive=True, visible=False)
                    with gr.Row(variant="panel"):
                        sound2sound_origin_spectrogram_upload_image = gr.Image(label="Original upload spectrogram",
                                                                               type="numpy", height=600,
                                                                               visible=True, tool="sketch")
                        sound2sound_origin_phase_upload_image = gr.Image(label="Original upload phase",
                                                                         type="numpy", height=600,
                                                                         visible=True)
                        sound2sound_origin_spectrogram_microphone_image = gr.Image(label="Original microphone spectrogram",
                                                                                   type="numpy", height=600,
                                                                                   visible=False, tool="sketch")
                        sound2sound_origin_phase_microphone_image = gr.Image(label="Original microphone phase",
                                                                             type="numpy", height=600,
                                                                             visible=False)
                    sound2sound_inpaint_area_radio = gr.Radio(choices=["inpaint masked", "inpaint not masked"],
                                                              value="inpaint masked")

                with gr.Tab("Sound2sound settings"):
                    sound2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
                    sound2sound_sampler_radio = gradioWebUI.get_sampler_radio()
                    sound2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
                    sound2sound_noising_strength_slider = gradioWebUI.get_noising_strength_slider(default_noising_strength=1.0)
                    sound2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
                    sound2sound_seed_textbox = gradioWebUI.get_seed_textbox()

            with gr.Column(scale=1):
                sound2sound_new_sound_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False)
                with gr.Row(variant="panel"):
                    sound2sound_new_sound_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy",
                                                                       height=1200, scale=1)
                    sound2sound_new_sound_phase_image = gr.Image(label="New sound phase", type="numpy",
                                                                 height=1200, scale=1)

        with gr.Row(variant="panel"):
            sound2sound_origin_upload_latent_representation_image = gr.Image(label="Original latent representation",
                                                                             type="numpy", height=1200,
                                                                             visible=True)
            sound2sound_origin_upload_quantized_latent_representation_image = gr.Image(
                label="Original quantized latent representation", type="numpy", height=1200, visible=True)

            sound2sound_origin_microphone_latent_representation_image = gr.Image(label="Original latent representation",
                                                                                 type="numpy", height=1200,
                                                                                 visible=False)
            sound2sound_origin_microphone_quantized_latent_representation_image = gr.Image(
                label="Original quantized latent representation", type="numpy", height=1200, visible=False)

            sound2sound_new_sound_latent_representation_image = gr.Image(label="New latent representation",
                                                                         type="numpy", height=1200)
            sound2sound_new_sound_quantized_latent_representation_image = gr.Image(
                label="New sound quantized latent representation", type="numpy", height=1200)

        sound2sound_origin_upload_audio.change(receive_upload_origin_audio,
                                               inputs=[sound2sound_duration_slider, sound2sound_origin_source_radio,
                                                       sound2sound_origin_upload_audio,
                                                       sound2sound_origin_microphone_audio, inpaintWithText_state],
                                               outputs=[sound2sound_origin_spectrogram_upload_image,
                                                        sound2sound_origin_phase_upload_image,
                                                        sound2sound_origin_spectrogram_microphone_image,
                                                        sound2sound_origin_phase_microphone_image,
                                                        sound2sound_origin_upload_latent_representation_image,
                                                        sound2sound_origin_upload_quantized_latent_representation_image,
                                                        sound2sound_origin_microphone_latent_representation_image,
                                                        sound2sound_origin_microphone_quantized_latent_representation_image,
                                                        inpaintWithText_state])
        sound2sound_origin_microphone_audio.change(receive_upload_origin_audio,
                                                   inputs=[sound2sound_duration_slider, sound2sound_origin_source_radio,
                                                           sound2sound_origin_upload_audio,
                                                           sound2sound_origin_microphone_audio, inpaintWithText_state],
                                                   outputs=[sound2sound_origin_spectrogram_upload_image,
                                                            sound2sound_origin_phase_upload_image,
                                                            sound2sound_origin_spectrogram_microphone_image,
                                                            sound2sound_origin_phase_microphone_image,
                                                            sound2sound_origin_upload_latent_representation_image,
                                                            sound2sound_origin_upload_quantized_latent_representation_image,
                                                            sound2sound_origin_microphone_latent_representation_image,
                                                            sound2sound_origin_microphone_quantized_latent_representation_image,
                                                            inpaintWithText_state])

        sound2sound_sample_button.click(sound2sound_sample,
                                        inputs=[sound2sound_origin_spectrogram_upload_image,
                                                sound2sound_origin_spectrogram_microphone_image,
                                                text2sound_prompts_textbox,
                                                text2sound_negative_prompts_textbox,
                                                sound2sound_batchsize_slider,
                                                sound2sound_guidance_scale_slider,
                                                sound2sound_sampler_radio,
                                                sound2sound_sample_steps_slider,
                                                sound2sound_origin_source_radio,
                                                sound2sound_noising_strength_slider,
                                                sound2sound_seed_textbox,
                                                sound2sound_inpaint_area_radio,
                                                inpaintWithText_state],
                                        outputs=[sound2sound_new_sound_latent_representation_image,
                                                 sound2sound_new_sound_quantized_latent_representation_image,
                                                 sound2sound_new_sound_spectrogram_image,
                                                 sound2sound_new_sound_phase_image,
                                                 sound2sound_new_sound_audio,
                                                 sound2sound_sample_index_slider,
                                                 sound2sound_seed_textbox,
                                                 inpaintWithText_state])

        sound2sound_sample_index_slider.change(show_sound2sound_sample,
                                               inputs=[sound2sound_sample_index_slider, inpaintWithText_state],
                                               outputs=[sound2sound_new_sound_latent_representation_image,
                                                        sound2sound_new_sound_quantized_latent_representation_image,
                                                        sound2sound_new_sound_spectrogram_image,
                                                        sound2sound_new_sound_phase_image,
                                                        sound2sound_new_sound_audio])

        sound2sound_origin_source_radio.change(sound2sound_switch_origin_source,
                                               inputs=[sound2sound_origin_source_radio],
                                               outputs=[sound2sound_origin_upload_audio,
                                                        sound2sound_origin_microphone_audio,
                                                        sound2sound_origin_spectrogram_upload_image,
                                                        sound2sound_origin_phase_upload_image,
                                                        sound2sound_origin_spectrogram_microphone_image,
                                                        sound2sound_origin_phase_microphone_image,
                                                        sound2sound_origin_upload_latent_representation_image,
                                                        sound2sound_origin_upload_quantized_latent_representation_image,
                                                        sound2sound_origin_microphone_latent_representation_image,
                                                        sound2sound_origin_microphone_quantized_latent_representation_image])
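The super-resolution module above reuses the inpainting sampler: the 128-row latent of the input sound is copied into a 256-row canvas, and a binary mask tells the sampler which rows to freeze and which to synthesise (following the "mask = 1, freeze" convention noted in track_maker.py). A minimal runnable sketch of that canvas/mask construction; the shapes are illustrative stand-ins, not loaded model outputs:

import numpy as np
import torch

batchsize, channels = 2, 4
low_res = torch.randn(batchsize, channels, 128, 64)   # stand-in for the encoded input sound

# Place the known low-frequency band into a double-height canvas.
canvas = torch.zeros(batchsize, channels, 256, 64)
canvas[:, :, :128, :] = low_res

# 1 = keep the guide content, 0 = let the diffusion model repaint.
mask = np.ones((256, 64), dtype=np.float32)
mask[192:, :] = 0.0                                   # repaint one quarter of the rows
mask = torch.from_numpy(mask)[None, None].repeat(batchsize, channels, 1, 1)

# The sampler then receives canvas as guide_img and mask as mask,
# analogous to the inpaint_sample call above.
print(canvas.shape, mask.shape, mask.mean().item())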
webUI/natural_language_guided/text2sound.py
ADDED
@@ -0,0 +1,212 @@
import gradio as gr
import numpy as np
import torch
from matplotlib import pyplot as plt

from model.DiffSynthSampler import DiffSynthSampler
from tools import safe_int
from webUI.natural_language_guided.utils import latent_representation_to_Gradio_image, \
    encodeBatch2GradioOutput_STFT, add_instrument


def get_text2sound_module(gradioWebUI, text2sound_state, virtual_instruments_state):
    # Load configurations
    uNet = gradioWebUI.uNet
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels

    timesteps = gradioWebUI.timesteps
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

    def diffusion_random_sample(text2sound_prompts, text2sound_negative_prompts, text2sound_batchsize,
                                text2sound_duration,
                                text2sound_guidance_scale, text2sound_sampler,
                                text2sound_sample_steps, text2sound_seed,
                                text2sound_dict):
        text2sound_sample_steps = int(text2sound_sample_steps)
        text2sound_seed = safe_int(text2sound_seed, 12345678)

        width = int(time_resolution * ((text2sound_duration + 1) / 4) / VAE_scale)

        text2sound_batchsize = int(text2sound_batchsize)

        text2sound_embedding = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(
                device)

        CFG = int(text2sound_guidance_scale)

        mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
        negative_condition = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[
                0]
        mySampler.activate_classifier_free_guidance(CFG, negative_condition.to(device))

        mySampler.respace(list(np.linspace(0, timesteps - 1, text2sound_sample_steps, dtype=np.int32)))

        condition = text2sound_embedding.repeat(text2sound_batchsize, 1)

        latent_representations, initial_noise = \
            mySampler.sample(model=uNet, shape=(text2sound_batchsize, channels, height, width), seed=text2sound_seed,
                             return_tensor=True, condition=condition, sampler=text2sound_sampler)

        latent_representations = latent_representations[-1]
        print(latent_representations[0, 0, :3, :3])  # debug: peek at the first latent patch

        latent_representation_gradio_images = []
        quantized_latent_representation_gradio_images = []
        new_sound_spectrogram_gradio_images = []
        new_sound_phase_gradio_images = []
        new_sound_rec_signals_gradio = []

        quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations)
        # Todo: remove hard-coding
        flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(
            VAE_decoder,
            quantized_latent_representations,
            resolution=(512, width * VAE_scale),
            original_STFT_batch=None
        )

        for i in range(text2sound_batchsize):
            latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i]))
            quantized_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(quantized_latent_representations[i]))
            new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i])
            new_sound_phase_gradio_images.append(flipped_phases[i])
            new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i]))

        text2sound_dict["latent_representations"] = latent_representations.to("cpu").detach().numpy()
        text2sound_dict["quantized_latent_representations"] = quantized_latent_representations.to("cpu").detach().numpy()
        text2sound_dict["latent_representation_gradio_images"] = latent_representation_gradio_images
        text2sound_dict["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images
        text2sound_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
        text2sound_dict["new_sound_phase_gradio_images"] = new_sound_phase_gradio_images
        text2sound_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio

        text2sound_dict["condition"] = condition.to("cpu").detach().numpy()
        text2sound_dict["negative_condition"] = negative_condition.to("cpu").detach().numpy()
        text2sound_dict["guidance_scale"] = CFG
        text2sound_dict["sampler"] = text2sound_sampler

        return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][0],
                text2sound_quantized_latent_representation_image:
                    text2sound_dict["quantized_latent_representation_gradio_images"][0],
                text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][0],
                text2sound_sampled_phase_image: text2sound_dict["new_sound_phase_gradio_images"][0],
                text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][0],
                text2sound_seed_textbox: text2sound_seed,
                text2sound_state: text2sound_dict,
                text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1,
                                                          visible=True,
                                                          label="Sample index",
                                                          info="Swipe to view other samples")}

    def show_random_sample(sample_index, text2sound_dict):
        sample_index = int(sample_index)
        text2sound_dict["sample_index"] = sample_index
        return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][
                    sample_index],
                text2sound_quantized_latent_representation_image:
                    text2sound_dict["quantized_latent_representation_gradio_images"][sample_index],
                text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][
                    sample_index],
                text2sound_sampled_phase_image: text2sound_dict["new_sound_phase_gradio_images"][
                    sample_index],
                text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]}

    def save_virtual_instrument(sample_index, virtual_instrument_name, text2sound_dict, virtual_instruments_dict):

        virtual_instruments_dict = add_instrument(text2sound_dict, virtual_instruments_dict, virtual_instrument_name, sample_index)

        return {virtual_instruments_state: virtual_instruments_dict,
                text2sound_instrument_name_textbox: gr.Textbox(label="Instrument name", lines=1,
                                                               placeholder=f"Saved as {virtual_instrument_name}!")}

    with gr.Tab("Text2sound"):
        gr.Markdown("Use the prompt to sample random sounds of your favorite instrument!")
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):
                text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
                text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")

            with gr.Column(scale=1):
                text2sound_sampling_button = gr.Button(variant="primary",
                                                       value="Generate a batch of samples and show "
                                                             "the first one",
                                                       scale=1)
                text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
                                                           label="Sample index",
                                                           info="Swipe to view other samples")
        with gr.Row(variant="panel"):
            with gr.Column(variant="panel"):
                text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
                text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
                text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
                text2sound_duration_slider = gradioWebUI.get_duration_slider()
                text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
                text2sound_seed_textbox = gradioWebUI.get_seed_textbox()

            with gr.Column(variant="panel"):
                with gr.Row(variant="panel"):
                    text2sound_sampled_spectrogram_image = gr.Image(label="Sampled spectrogram", type="numpy", height=600)
                    text2sound_sampled_phase_image = gr.Image(label="Sampled phase", type="numpy", height=600)
                text2sound_sampled_audio = gr.Audio(type="numpy", label="Play")

                with gr.Row(variant="panel"):
                    text2sound_instrument_name_textbox = gr.Textbox(label="Instrument name", lines=1,
                                                                    placeholder="Name of your instrument")
                    text2sound_save_instrument_button = gr.Button(variant="primary",
                                                                  value="Save instrument",
                                                                  scale=1)

        with gr.Row(variant="panel"):
            text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy",
                                                              height=200, width=100)
            text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
                                                                        type="numpy", height=200, width=100)

        text2sound_sampling_button.click(diffusion_random_sample,
                                         inputs=[text2sound_prompts_textbox,
                                                 text2sound_negative_prompts_textbox,
                                                 text2sound_batchsize_slider,
                                                 text2sound_duration_slider,
                                                 text2sound_guidance_scale_slider, text2sound_sampler_radio,
                                                 text2sound_sample_steps_slider,
                                                 text2sound_seed_textbox,
                                                 text2sound_state],
                                         outputs=[text2sound_latent_representation_image,
                                                  text2sound_quantized_latent_representation_image,
                                                  text2sound_sampled_spectrogram_image,
                                                  text2sound_sampled_phase_image,
                                                  text2sound_sampled_audio,
                                                  text2sound_seed_textbox,
                                                  text2sound_state,
                                                  text2sound_sample_index_slider])

        text2sound_save_instrument_button.click(save_virtual_instrument,
                                                inputs=[text2sound_sample_index_slider,
                                                        text2sound_instrument_name_textbox,
                                                        text2sound_state,
                                                        virtual_instruments_state],
                                                outputs=[virtual_instruments_state,
                                                         text2sound_instrument_name_textbox])

        text2sound_sample_index_slider.change(show_random_sample,
                                              inputs=[text2sound_sample_index_slider, text2sound_state],
                                              outputs=[text2sound_latent_representation_image,
                                                       text2sound_quantized_latent_representation_image,
                                                       text2sound_sampled_spectrogram_image,
                                                       text2sound_sampled_phase_image,
                                                       text2sound_sampled_audio])
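activate_classifier_free_guidance(CFG, negative_condition) above implies the standard classifier-free-guidance combination of a conditional and an unconditional (here: negative-prompt) noise prediction. A runnable sketch of that combination under the usual formulation; the sampler's real internals live in model/DiffSynthSampler.py and may differ in detail, and fake_model here is a toy stand-in:

import torch

def cfg_noise_prediction(model, x_t, t, condition, negative_condition, guidance_scale):
    # Standard classifier-free guidance: push the conditional prediction
    # away from the unconditional/negative one.
    eps_cond = model(x_t, t, condition)
    eps_uncond = model(x_t, t, negative_condition)
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

# Toy check with a fake "model" so the sketch runs standalone.
fake_model = lambda x, t, c: x * 0 + c.mean()
x = torch.randn(1, 4, 128, 64)
cond, neg = torch.ones(1, 512), torch.zeros(1, 512)
print(cfg_noise_prediction(fake_model, x, 0, cond, neg, guidance_scale=3.0).shape)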
webUI/natural_language_guided/track_maker.py
ADDED
@@ -0,0 +1,192 @@
import numpy as np
import torch

from model.DiffSynthSampler import DiffSynthSampler
from webUI.natural_language_guided.utils import encodeBatch2GradioOutput_STFT
import mido
import pyrubberband as pyrb
from tqdm import tqdm


class NoteEvent:
    def __init__(self, note, velocity, start_time, duration):
        self.note = note
        self.velocity = velocity
        self.start_time = start_time  # In ticks
        self.duration = duration  # In ticks

    def __str__(self):
        return f"Note {self.note}, velocity {self.velocity}, start_time {self.start_time}, duration {self.duration}"


class Track:
    def __init__(self, track, ticks_per_beat):
        self.tempo_events = self._parse_tempo_events(track)
        self.events = self._parse_note_events(track)
        self.ticks_per_beat = ticks_per_beat

    def _parse_tempo_events(self, track):
        tempo_events = []
        current_tempo = 500000  # Default MIDI tempo is 120 BPM, i.e. 500000 microseconds per beat
        for msg in track:
            if msg.type == 'set_tempo':
                tempo_events.append((msg.time, msg.tempo))
            elif not msg.is_meta:
                tempo_events.append((msg.time, current_tempo))
        return tempo_events

    def _parse_note_events(self, track):
        events = []
        start_time = 0
        note_on_time = 0
        for msg in track:
            if not msg.is_meta:
                start_time += msg.time
                if msg.type == 'note_on' and msg.velocity > 0:
                    note_on_time = start_time
                elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
                    # A note end is either an explicit note_off or, as many
                    # files encode it, a note_on with velocity 0.
                    duration = start_time - note_on_time
                    events.append(NoteEvent(msg.note, msg.velocity, note_on_time, duration))
        return events

    def synthesize_track(self, diffSynthSampler, sample_rate=16000):
        track_audio = np.zeros(int(self._get_total_time() * sample_rate), dtype=np.float32)
        current_tempo = 500000  # Start with the default MIDI tempo (120 BPM)
        duration_note_mapping = {}

        for event in tqdm(self.events[:50]):  # only the first 50 note events are rendered
            current_tempo = self._get_tempo_at(event.start_time)
            seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
            start_time_sec = event.start_time * seconds_per_tick
            # Todo: set a minimum duration
            duration_sec = event.duration * seconds_per_tick
            duration_sec = max(duration_sec, 0.75)
            start_sample = int(start_time_sec * sample_rate)
            if not (str(duration_sec) in duration_note_mapping):
                note_sample = diffSynthSampler(event.velocity, duration_sec)
                duration_note_mapping[str(duration_sec)] = note_sample / np.max(np.abs(note_sample))

            note_audio = pyrb.pitch_shift(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)
            # Clip to the buffer so a note running past the precomputed track
            # length cannot overflow it.
            end_sample = min(start_sample + len(note_audio), len(track_audio))
            track_audio[start_sample:end_sample] += note_audio[:max(end_sample - start_sample, 0)]

        return track_audio

    def _get_tempo_at(self, time_tick):
        current_tempo = 500000  # Start with the default MIDI tempo (120 BPM)
        elapsed_ticks = 0

        for tempo_change in self.tempo_events:
            if elapsed_ticks + tempo_change[0] > time_tick:
                return current_tempo
            elapsed_ticks += tempo_change[0]
            current_tempo = tempo_change[1]

        return current_tempo

    def _get_total_time(self):
        total_time = 0
        current_tempo = 500000  # Start with the default MIDI tempo (120 BPM)

        for event in self.events:
            current_tempo = self._get_tempo_at(event.start_time)
            seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
            total_time += event.duration * seconds_per_tick

        return total_time


class DiffSynth:
    def __init__(self, instruments_configs, noise_prediction_model, VAE_quantizer, VAE_decoder, text_encoder, CLAP_tokenizer, device,
                 model_sample_rate=16000, timesteps=1000, channels=4, freq_resolution=512, time_resolution=256, VAE_scale=4, squared=False):

        self.noise_prediction_model = noise_prediction_model
        self.VAE_quantizer = VAE_quantizer
        self.VAE_decoder = VAE_decoder
        self.device = device
        self.model_sample_rate = model_sample_rate
        self.timesteps = timesteps
        self.channels = channels
        self.freq_resolution = freq_resolution
        self.time_resolution = time_resolution
        self.height = int(freq_resolution / VAE_scale)
        self.VAE_scale = VAE_scale
        self.squared = squared
        self.text_encoder = text_encoder
        self.CLAP_tokenizer = CLAP_tokenizer

        # instruments_configs maps an instrument name (string) to
        # (condition, negative_condition, guidance_scale, sample_steps, seed, initial_noise, sampler)
        self.instruments_configs = instruments_configs
        self.diffSynthSamplers = {}
        self._update_instruments()

    def _update_instruments(self):

        def diffSynthSamplerWrapper(instruments_config):

            def diffSynthSampler(velocity, duration_sec, sample_rate=16000):

                condition = self.text_encoder.get_text_features(
                    **self.CLAP_tokenizer([""], padding=True, return_tensors="pt")).to(self.device)
                sample_steps = instruments_config['sample_steps']
                sampler = instruments_config['sampler']
                noising_strength = instruments_config['noising_strength']
                latent_representation = instruments_config['latent_representation']
                attack = instruments_config['attack']
                before_release = instruments_config['before_release']

                assert sample_rate == self.model_sample_rate, "sample_rate != model_sample_rate"

                width = int(self.time_resolution * ((duration_sec + 1) / 4) / self.VAE_scale)

                mySampler = DiffSynthSampler(self.timesteps, height=128, channels=4, noise_strategy="repeat", mute=True)
                mySampler.respace(list(np.linspace(0, self.timesteps - 1, sample_steps, dtype=np.int32)))

                # mask = 1, freeze
                latent_mask = torch.zeros((1, 1, self.height, width), dtype=torch.float32).to(self.device)
                latent_mask[:, :, :, :int(self.time_resolution * (attack / 4) / self.VAE_scale)] = 1.0
                latent_mask[:, :, :, -int(self.time_resolution * ((before_release + 1) / 4) / self.VAE_scale):] = 1.0

                latent_representations, _ = \
                    mySampler.inpaint_sample(model=self.noise_prediction_model,
                                             shape=(1, self.channels, self.height, width),
                                             noising_strength=noising_strength, condition=condition,
                                             guide_img=latent_representation, mask=latent_mask, return_tensor=True,
                                             sampler=sampler,
                                             use_dynamic_mask=True, end_noise_level_ratio=0.0,
                                             mask_flexivity=1.0)

                latent_representations = latent_representations[-1]

                quantized_latent_representations, _, (_, _, _) = self.VAE_quantizer(latent_representations)
                # Todo: remove hard-coding
                flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(
                    self.VAE_decoder,
                    quantized_latent_representations,
                    resolution=(512, width * self.VAE_scale),
                    original_STFT_batch=None,
                )

                return rec_signals[0]

            return diffSynthSampler

        for key in self.instruments_configs.keys():
            self.diffSynthSamplers[key] = diffSynthSamplerWrapper(self.instruments_configs[key])

    def get_music(self, mid, instrument_names, sample_rate=16000):
        tracks = [Track(t, mid.ticks_per_beat) for t in mid.tracks]
        assert len(tracks) == len(instrument_names), f"len(tracks) = {len(tracks)} != {len(instrument_names)} = len(instrument_names)"

        track_audios = [track.synthesize_track(self.diffSynthSamplers[instrument_names[i]], sample_rate=sample_rate) for i, track in enumerate(tracks)]

        # Pad every track to the length of the longest one so they can be mixed.
        max_length = max(len(audio) for audio in track_audios)
        full_audio = np.zeros(max_length, dtype=np.float32)  # initialise the mix with silence
        for audio in track_audios:
            # A track may be shorter than the mix, so pad it with zeros.
            padded_audio = np.pad(audio, (0, max_length - len(audio)), 'constant')
            full_audio += padded_audio  # mix the track in

        return full_audio
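A runnable sketch of the MIDI-side machinery above: build a one-note MIDI track in memory with mido and parse it with Track. The tempo/tick numbers are illustrative; only mido and the file above are required:

import mido
from webUI.natural_language_guided.track_maker import Track

mid = mido.MidiFile(ticks_per_beat=480)
track = mido.MidiTrack()
track.append(mido.Message('note_on', note=60, velocity=100, time=0))
track.append(mido.Message('note_on', note=60, velocity=0, time=480))  # note end, one beat later
mid.tracks.append(track)

parsed = Track(track, mid.ticks_per_beat)
for event in parsed.events:
    print(event)  # -> Note 60, velocity 0, start_time 0, duration 480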
webUI/natural_language_guided/utils.py
ADDED
@@ -0,0 +1,174 @@
import librosa
import numpy as np
import torch

from tools import np_power_to_db, decode_stft, depad_STFT


def spectrogram_to_Gradio_image(spc):
    ### input: spc [np.ndarray]
    frequency_resolution, time_resolution = spc.shape[-2], spc.shape[-1]
    spc = np.reshape(spc, (frequency_resolution, time_resolution))

    # Todo:
    magnitude_spectrum = np.abs(spc)
    log_spectrum = np_power_to_db(magnitude_spectrum)
    flipped_log_spectrum = np.flipud(log_spectrum)

    colorful_spc = np.ones((frequency_resolution, time_resolution, 3)) * -80.0
    colorful_spc[:, :, 0] = flipped_log_spectrum
    colorful_spc[:, :, 1] = flipped_log_spectrum
    colorful_spc[:, :, 2] = np.ones((frequency_resolution, time_resolution)) * -60.0
    # Rescale to 0-255 and convert to uint8
    rescaled = (colorful_spc + 80.0) / 80.0
    rescaled = (255.0 * rescaled).astype(np.uint8)
    return rescaled


def phase_to_Gradio_image(phase):
    ### input: phase [np.ndarray]
    frequency_resolution, time_resolution = phase.shape[-2], phase.shape[-1]
    phase = np.reshape(phase, (frequency_resolution, time_resolution))

    # Todo:
    flipped_phase = np.flipud(phase)
    flipped_phase = (flipped_phase + 1.0) / 2.0

    colorful_spc = np.zeros((frequency_resolution, time_resolution, 3))
    colorful_spc[:, :, 0] = flipped_phase
    colorful_spc[:, :, 1] = flipped_phase
    colorful_spc[:, :, 2] = 0.2
    # Rescale to 0-255 and convert to uint8
    rescaled = (255.0 * colorful_spc).astype(np.uint8)
    return rescaled


def latent_representation_to_Gradio_image(latent_representation):
    # input: latent_representation [torch.tensor]
    if not isinstance(latent_representation, np.ndarray):
        latent_representation = latent_representation.to("cpu").detach().numpy()
    image = latent_representation

    def normalize_image(img):
        min_val = img.min()
        max_val = img.max()
        normalized_img = ((img - min_val) / (max_val - min_val) * 255)
        return normalized_img

    image[0, :, :] = normalize_image(image[0, :, :])
    image[1, :, :] = normalize_image(image[1, :, :])
    image[2, :, :] = normalize_image(image[2, :, :])
    image[3, :, :] = normalize_image(image[3, :, :])
    image_transposed = np.transpose(image, (1, 2, 0))
    enlarged_image = np.repeat(image_transposed, 8, axis=0)
    enlarged_image = np.repeat(enlarged_image, 8, axis=1)
    return np.flipud(enlarged_image).astype(np.uint8)


def InputBatch2Encode_STFT(encoder, STFT_batch, resolution=(512, 256), quantizer=None, squared=True):
    """Transform a batch of numpy spectrograms into signals and encodings."""
    # Todo: remove resolution hard-coding
    frequency_resolution, time_resolution = resolution

    device = next(encoder.parameters()).device
    if not (quantizer is None):
        latent_representation_batch = encoder(STFT_batch.to(device))
        quantized_latent_representation_batch, loss, (_, _, _) = quantizer(latent_representation_batch)
    else:
        mu, logvar, latent_representation_batch = encoder(STFT_batch.to(device))
        quantized_latent_representation_batch = None

    STFT_batch = STFT_batch.to("cpu").detach().numpy()

    origin_flipped_log_spectrums, origin_flipped_phases, origin_signals = [], [], []
    for STFT in STFT_batch:

        padded_D_rec = decode_stft(STFT)
        D_rec = depad_STFT(padded_D_rec)
        spc = np.abs(D_rec)
        phase = np.angle(D_rec)

        flipped_log_spectrum = spectrogram_to_Gradio_image(spc)
        flipped_phase = phase_to_Gradio_image(phase)

        # get_audio
        rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024)

        origin_flipped_log_spectrums.append(flipped_log_spectrum)
        origin_flipped_phases.append(flipped_phase)
        origin_signals.append(rec_signal)

    return origin_flipped_log_spectrums, origin_flipped_phases, origin_signals, \
        latent_representation_batch, quantized_latent_representation_batch


def encodeBatch2GradioOutput_STFT(decoder, latent_vector_batch, resolution=(512, 256), original_STFT_batch=None):
    """Decode a batch of latent vectors into Gradio-ready spectrogram images, phase images, and audio signals."""
    # Todo: remove resolution hard-coding
    frequency_resolution, time_resolution = resolution

    if isinstance(latent_vector_batch, np.ndarray):
        latent_vector_batch = torch.from_numpy(latent_vector_batch).to(next(decoder.parameters()).device)

    reconstruction_batch = decoder(latent_vector_batch).to("cpu").detach().numpy()

    flipped_log_spectrums, flipped_phases, rec_signals = [], [], []
    flipped_log_spectrums_with_original_amp, flipped_phases_with_original_amp, rec_signals_with_original_amp = [], [], []

    for index, STFT in enumerate(reconstruction_batch):
        padded_D_rec = decode_stft(STFT)
        D_rec = depad_STFT(padded_D_rec)
        spc = np.abs(D_rec)
        phase = np.angle(D_rec)

        flipped_log_spectrum = spectrogram_to_Gradio_image(spc)
        flipped_phase = phase_to_Gradio_image(phase)

        # get_audio
        rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024)

        flipped_log_spectrums.append(flipped_log_spectrum)
        flipped_phases.append(flipped_phase)
        rec_signals.append(rec_signal)

        ##########################################

        if original_STFT_batch is not None:
            STFT[0, :, :] = original_STFT_batch[index, 0, :, :]

            padded_D_rec = decode_stft(STFT)
            D_rec = depad_STFT(padded_D_rec)
            spc = np.abs(D_rec)
            phase = np.angle(D_rec)

            flipped_log_spectrum = spectrogram_to_Gradio_image(spc)
            flipped_phase = phase_to_Gradio_image(phase)

            # get_audio
            rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024)

            flipped_log_spectrums_with_original_amp.append(flipped_log_spectrum)
            flipped_phases_with_original_amp.append(flipped_phase)
            rec_signals_with_original_amp.append(rec_signal)

    return flipped_log_spectrums, flipped_phases, rec_signals, \
        flipped_log_spectrums_with_original_amp, flipped_phases_with_original_amp, rec_signals_with_original_amp


def add_instrument(source_dict, virtual_instruments_dict, virtual_instrument_name, sample_index):

    virtual_instruments = virtual_instruments_dict["virtual_instruments"]
    virtual_instrument = {
        "latent_representation": source_dict["latent_representations"][sample_index],
        "quantized_latent_representation": source_dict["quantized_latent_representations"][sample_index],
        "sampler": source_dict["sampler"],
        "signal": source_dict["new_sound_rec_signals_gradio"][sample_index],
        "spectrogram_gradio_image": source_dict["new_sound_spectrogram_gradio_images"][
            sample_index],
        "phase_gradio_image": source_dict["new_sound_phase_gradio_images"][
            sample_index]}
    virtual_instruments[virtual_instrument_name] = virtual_instrument
    virtual_instruments_dict["virtual_instruments"] = virtual_instruments
    return virtual_instruments_dict
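A runnable sketch of the display helpers above: take one second of a sine tone, STFT it with the same window settings used throughout these modules (n_fft=1024, hop_length=256), and turn magnitude and phase into the uint8 images Gradio displays. The tone itself is an illustrative stand-in for a decoded sound:

import librosa
import numpy as np
from webUI.natural_language_guided.utils import spectrogram_to_Gradio_image, phase_to_Gradio_image

sr = 16000
signal = np.sin(2 * np.pi * 440 * np.arange(sr) / sr).astype(np.float32)
D = librosa.stft(signal, n_fft=1024, hop_length=256, win_length=1024)

spc_image = spectrogram_to_Gradio_image(np.abs(D))      # uint8 RGB, frequency axis flipped for display
phase_image = phase_to_Gradio_image(np.angle(D))
print(spc_image.shape, spc_image.dtype)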