Spaces:
Paused
Paused
| import json | |
| from functools import partial | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from six.moves import xrange | |
| from torch.utils.tensorboard import SummaryWriter | |
| import random | |
| from metrics.IS import get_inception_score | |
| from tools import create_key | |
| from model.diffusion_components import default, ConvNextBlock, ResnetBlock, SinusoidalPositionEmbeddings, Residual, \ | |
| PreNorm, \ | |
| Downsample, Upsample, exists, q_sample, get_beta_schedule, pad_and_concat, ConditionalEmbedding, \ | |
| LinearCrossAttention, LinearCrossAttentionAdd | |
class ConditionedUnet(nn.Module):
    """Conditional U-Net noise predictor for DDPM-style diffusion training.

    Layout: an encoder (``downs``) that stores skip activations, a bottleneck
    (``mid_left`` / ``mid_mid`` / ``mid_right``), and a decoder (``ups``) that
    consumes the stored skips via ``pad_and_concat``.  At every resolution a
    ConvNext/ResNet block is followed by a cross-attention layer that injects
    the label embedding, so the prediction can be conditioned on a class label.
    ``forward`` returns a tensor with ``out_dim`` channels (defaults to
    ``in_dim``): the predicted noise at diffusion step ``time``.
    """

    def __init__(
            self,
            in_dim,
            out_dim=None,
            down_dims=None,
            up_dims=None,
            mid_depth=3,
            with_time_emb=True,
            time_dim=None,
            resnet_block_groups=8,
            use_convnext=True,
            convnext_mult=2,
            attn_type="linear_cat",
            n_label_class=11,
            condition_type="instrument_family",
            label_emb_dim=128,
    ):
        super().__init__()
        # n_label_class + 1 embedding rows — the extra row presumably serves as
        # the "unconditional" label; TODO confirm against ConditionalEmbedding.
        self.label_embedding = ConditionalEmbedding(int(n_label_class + 1), int(label_emb_dim), condition_type)
        if up_dims is None:
            up_dims = [128, 128, 64, 32]
        if down_dims is None:
            down_dims = [32, 32, 64, 128]
        out_dim = default(out_dim, in_dim)
        # The decoder must mirror the encoder, otherwise the skip-connection
        # channel counts below would not line up.
        assert len(down_dims) == len(up_dims), "len(down_dims) != len(up_dims)"
        assert down_dims[0] == up_dims[-1], "down_dims[0] != up_dims[-1]"
        assert up_dims[0] == down_dims[-1], "up_dims[0] != down_dims[-1]"
        # Per-stage (in_channels, out_channels) pairs for encoder and decoder.
        down_in_out = list(zip(down_dims[:-1], down_dims[1:]))
        up_in_out = list(zip(up_dims[:-1], up_dims[1:]))
        print(f"down_in_out: {down_in_out}")
        print(f"up_in_out: {up_in_out}")
        time_dim = default(time_dim, int(down_dims[0] * 4))
        self.init_conv = nn.Conv2d(in_dim, down_dims[0], 7, padding=3)
        # Choose the residual block flavor once; both accept (dim_in, dim_out,
        # time_emb_dim=...) at call time.
        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)
        # Cross-attention flavor: concatenation- vs addition-based conditioning.
        if attn_type == "linear_cat":
            attn_klass = partial(LinearCrossAttention)
        elif attn_type == "linear_add":
            attn_klass = partial(LinearCrossAttentionAdd)
        else:
            raise NotImplementedError()
        # Time embeddings: sinusoidal encoding followed by a small MLP.
        if with_time_emb:
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(down_dims[0]),
                nn.Linear(down_dims[0], time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None
        # Encoder (left) and decoder (right) stages.
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        # Channel counts of the stored skip activations, popped in reverse
        # order while building the decoder.
        skip_dims = []
        for down_dim_in, down_dim_out in down_in_out:
            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(down_dim_in, down_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(down_dim_out, attn_klass(down_dim_out, label_emb_dim=label_emb_dim, ))),
                        block_klass(down_dim_out, down_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(down_dim_out, attn_klass(down_dim_out, label_emb_dim=label_emb_dim, ))),
                        Downsample(down_dim_out),
                    ]
                )
            )
            skip_dims.append(down_dim_out)
        # Bottleneck: (mid_depth - 1) plain blocks on each side of a central
        # block-attention-block sandwich.  Each right block consumes the
        # matching left activation via a skip (hence mid_dim * 2 inputs).
        mid_dim = down_dims[-1]
        self.mid_left = nn.ModuleList([])
        self.mid_right = nn.ModuleList([])
        for _ in range(mid_depth - 1):
            self.mid_left.append(block_klass(mid_dim, mid_dim, time_emb_dim=time_dim))
            self.mid_right.append(block_klass(mid_dim * 2, mid_dim, time_emb_dim=time_dim))
        self.mid_mid = nn.ModuleList(
            [
                block_klass(mid_dim, mid_dim, time_emb_dim=time_dim),
                Residual(PreNorm(mid_dim, attn_klass(mid_dim, label_emb_dim=label_emb_dim, ))),
                block_klass(mid_dim, mid_dim, time_emb_dim=time_dim),
            ]
        )
        # Decoder stages; each consumes three stored skips (see forward()).
        for ind, (up_dim_in, up_dim_out) in enumerate(up_in_out):
            skip_dim = skip_dims.pop()  # channel count of the matching encoder stage
            self.ups.append(
                nn.ModuleList(
                    [
                        # pop & concat skip of shape (h/2, w/2, down_dim_out)
                        block_klass(up_dim_in + skip_dim, up_dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(up_dim_in, attn_klass(up_dim_in, label_emb_dim=label_emb_dim, ))),
                        Upsample(up_dim_in),
                        # pop & concat skip of shape (h, w, down_dim_out)
                        block_klass(up_dim_in + skip_dim, up_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(up_dim_out, attn_klass(up_dim_out, label_emb_dim=label_emb_dim, ))),
                        # pop & concat skip of shape (h, w, down_dim_out)
                        block_klass(up_dim_out + skip_dim, up_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(up_dim_out, attn_klass(up_dim_out, label_emb_dim=label_emb_dim, ))),
                    ]
                )
            )
        # Final skip is the init_conv output (down_dims[0] channels).
        self.final_conv = nn.Sequential(
            block_klass(down_dims[0] + up_dims[-1], up_dims[-1]), nn.Conv2d(up_dims[-1], out_dim, 3, padding=1)
        )

    def size(self):
        """Print total and trainable parameter counts."""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total parameters: {total_params}")
        print(f"Trainable parameters: {trainable_params}")

    def forward(self, x, time, condition=None):
        """Predict noise for `x` at step `time`, optionally conditioned on a label.

        `h` acts as a stack of skip activations: everything pushed on the way
        down is popped (LIFO) and concatenated on the way up.
        """
        if condition is not None:
            condition_emb = self.label_embedding(condition)
        else:
            condition_emb = None
        h = []
        x = self.init_conv(x)
        h.append(x)
        time_emb = self.time_mlp(time) if exists(self.time_mlp) else None
        # Encoder: push three activations per stage (after each attention and
        # after downsampling) — consumed in reverse by the decoder.
        for block1, attn1, block2, attn2, downsample in self.downs:
            x = block1(x, time_emb)
            x = attn1(x, condition_emb)
            h.append(x)
            x = block2(x, time_emb)
            x = attn2(x, condition_emb)
            h.append(x)
            x = downsample(x)
            h.append(x)
        # Bottleneck with its own symmetric skips between mid_left and mid_right.
        for block in self.mid_left:
            x = block(x, time_emb)
            h.append(x)
        (block1, attn, block2) = self.mid_mid
        x = block1(x, time_emb)
        x = attn(x, condition_emb)
        x = block2(x, time_emb)
        for block in self.mid_right:
            # U-Net skip: concat the matching mid_left activation.
            x = pad_and_concat(h.pop(), x)
            x = block(x, time_emb)
        # Decoder: pop the three skips pushed by the matching encoder stage.
        for block1, attn1, upsample, block2, attn2, block3, attn3 in self.ups:
            x = pad_and_concat(h.pop(), x)
            x = block1(x, time_emb)
            x = attn1(x, condition_emb)
            x = upsample(x)
            x = pad_and_concat(h.pop(), x)
            x = block2(x, time_emb)
            x = attn2(x, condition_emb)
            x = pad_and_concat(h.pop(), x)
            x = block3(x, time_emb)
            x = attn3(x, condition_emb)
        # Last skip is the init_conv output pushed at the very start.
        x = pad_and_concat(h.pop(), x)
        x = self.final_conv(x)
        return x
def conditional_p_losses(denoise_model, x_start, t, condition, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod,
                         noise=None, loss_type="l1"):
    """Diffusion training loss: compare predicted noise against the true noise.

    A noisy sample is produced from `x_start` at timesteps `t` via `q_sample`,
    the model predicts the injected noise (given `condition`), and the chosen
    criterion ("l1", "l2" or "huber") scores the prediction.  If `noise` is
    None, fresh Gaussian noise shaped like `x_start` is drawn.

    Raises NotImplementedError for an unknown `loss_type`.
    """
    criteria = {
        "l1": F.l1_loss,
        "l2": F.mse_loss,
        "huber": F.smooth_l1_loss,
    }
    if loss_type not in criteria:
        raise NotImplementedError()
    target_noise = torch.randn_like(x_start) if noise is None else noise
    noised_input = q_sample(x_start=x_start, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod,
                            sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=target_noise)
    model_output = denoise_model(noised_input, t, condition)
    return criteria[loss_type](target_noise, model_output)
def evaluate_diffusion_model(device, model, iterator, BATCH_SIZE, timesteps, unetConfig, encodes2embeddings_mapping,
                             uncondition_rate, unconditional_condition):
    """Estimate the model's mean huber denoising loss over 500 sampled batches.

    Args:
        device: torch device for the model and data.
        model: noise-prediction network (called through conditional_p_losses).
        iterator: yields (data, attributes) batches.
        BATCH_SIZE: samples per batch; one random timestep is drawn per sample.
        timesteps: length of the diffusion noise schedule.
        unetConfig: unused here; kept for call-site compatibility.
        encodes2embeddings_mapping: maps create_key(attribute) to a list of
            candidate condition embeddings for that sample.
        uncondition_rate: probability of replacing a sample's condition with
            the unconditional embedding (classifier-free guidance dropout).
        unconditional_condition: embedding used for the unconditional case.

    Returns:
        float: mean loss over all evaluated batches.
    """
    model.to(device)
    model.eval()
    eva_loss = []
    sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, _, _ = get_beta_schedule(timesteps)
    # Fix: run evaluation under no_grad — the original built autograd graphs
    # for all 500 batches, wasting memory and time for .item()-only results.
    with torch.no_grad():
        for i in xrange(500):
            # NOTE(review): next(iter(iterator)) restarts iteration each pass;
            # if `iterator` is a non-shuffling DataLoader this re-reads the same
            # first batch every time — confirm callers pass a shuffling or
            # infinite iterator.
            data, attributes = next(iter(iterator))
            data = data.to(device)
            conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
            # Condition dropout: unconditional with prob. uncondition_rate,
            # otherwise a random candidate condition for the sample.
            selected_conditions = [
                unconditional_condition if random.random() < uncondition_rate else random.choice(
                    conditions_of_one_sample)
                for conditions_of_one_sample in conditions]
            selected_conditions = torch.stack(selected_conditions).float().to(device)
            t = torch.randint(0, timesteps, (BATCH_SIZE,), device=device).long()
            loss = conditional_p_losses(model, data, t, selected_conditions, loss_type="huber",
                                        sqrt_alphas_cumprod=sqrt_alphas_cumprod,
                                        sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod)
            eva_loss.append(loss.item())
    initial_loss = np.mean(eva_loss)
    return initial_loss
def get_diffusion_model(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    """Build a ConditionedUnet and optionally restore pretrained weights.

    Args:
        model_Config: kwargs forwarded to ConditionedUnet(**model_Config).
        load_pretrain: if True, load weights from models/{model_name}_UNet.pth
            and switch the model to eval mode.
        model_name: checkpoint base name; required when load_pretrain is True.
        device: device to move the model (and mapped weights) to.

    Returns:
        The constructed (and possibly pretrained) ConditionedUnet.
    """
    UNet = ConditionedUnet(**model_Config)
    # Fix: corrected "intialized" typo in the log message.
    print(f"Model initialized, size: {sum(p.numel() for p in UNet.parameters() if p.requires_grad)}")
    UNet.to(device)
    if load_pretrain:
        print(f"Loading weights from models/{model_name}_UNet.pth")
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(f'models/{model_name}_UNet.pth', map_location=device)
        UNet.load_state_dict(checkpoint['model_state_dict'])
        UNet.eval()
    return UNet
def train_diffusion_model(VAE, text_encoder, CLAP_tokenizer, timbre_encoder, device, init_model_name, unetConfig, BATCH_SIZE, timesteps, lr, max_iter, iterator, load_pretrain,
                          encodes2embeddings_mapping, uncondition_rate, unconditional_condition, save_steps=5000, init_loss=None, save_model_name=None,
                          n_IS_batches=50):
    """Train a ConditionedUnet noise predictor with classifier-free condition dropout.

    Builds the model/optimizer (optionally resuming from
    models/{init_model_name}_UNet.pth), then runs up to `max_iter` steps of
    huber-loss denoising training.  Checkpoints and hyperparameters are saved
    every `save_steps` optimizer steps; an Inception Score is logged to
    TensorBoard every 20000 steps.

    Returns:
        (model, optimizer) — immediately if max_iter == 0, otherwise after training.
    """
    if save_model_name is None:
        save_model_name = init_model_name

    def save_model_hyperparameter(model_name, unetConfig, BATCH_SIZE, lr, model_size, current_iter, current_loss):
        """Write the training hyperparameters next to the checkpoint as JSON."""
        # Fix: copy before annotating — the original wrote the extra keys into
        # the caller's unetConfig dict, corrupting it for any subsequent
        # ConditionedUnet(**unetConfig) call.
        model_hyperparameter = dict(unetConfig)
        model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
        model_hyperparameter["lr"] = lr
        model_hyperparameter["model_size"] = model_size
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["current_loss"] = current_loss
        with open(f"models/hyperparameters/{model_name}_UNet.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    model = ConditionedUnet(**unetConfig)
    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {model_size}")
    model.to(device)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, amsgrad=False)
    if load_pretrain:
        # Fix: the message said "_UNet.pt" while the code loads "_UNet.pth".
        print(f"Loading weights from models/{init_model_name}_UNet.pth")
        checkpoint = torch.load(f'models/{init_model_name}_UNet.pth')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Model initialized.")
    if max_iter == 0:
        print("Return model directly.")
        return model, optimizer
    train_loss = []
    writer = SummaryWriter(f'runs/{save_model_name}_UNet')
    if init_loss is None:
        previous_loss = evaluate_diffusion_model(device, model, iterator, BATCH_SIZE, timesteps, unetConfig, encodes2embeddings_mapping,
                                                 uncondition_rate, unconditional_condition)
    else:
        previous_loss = init_loss
    # Fix: the original message labelled this loss value "initial_IS".
    print(f"initial_loss: {previous_loss}")
    sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, _, _ = get_beta_schedule(timesteps)
    # Loop-invariant hoisted: build the unconditional embedding tensor once
    # instead of re-creating it on every iteration.
    unconditional_condition_copy = torch.tensor(unconditional_condition, dtype=torch.float32).to(device).detach()
    model.train()
    for i in xrange(max_iter):
        # NOTE(review): next(iter(iterator)) restarts iteration each step; if
        # `iterator` is a non-shuffling DataLoader this re-reads the first
        # batch every time — confirm callers pass a shuffling/infinite iterator.
        data, attributes = next(iter(iterator))
        data = data.to(device)
        conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
        # Classifier-free guidance dropout: unconditional embedding with
        # probability uncondition_rate, otherwise a random candidate condition.
        selected_conditions = [unconditional_condition_copy if random.random() < uncondition_rate else random.choice(
            conditions_of_one_sample) for conditions_of_one_sample in conditions]
        selected_conditions = torch.stack(selected_conditions).float().to(device)
        optimizer.zero_grad()
        t = torch.randint(0, timesteps, (BATCH_SIZE,), device=device).long()
        loss = conditional_p_losses(model, data, t, selected_conditions, loss_type="huber",
                                    sqrt_alphas_cumprod=sqrt_alphas_cumprod,
                                    sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod)
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
        # Global step count comes from the optimizer state so it survives
        # checkpoint resumption (unlike the local loop index).
        opt_state = optimizer.state_dict()['state']
        step = int(opt_state[next(iter(opt_state))]['step'].item())
        if step % 100 == 0:
            print('%d step' % (step))
        if step % save_steps == 0:
            current_loss = np.mean(train_loss[-save_steps:])
            print(f"current_loss = {current_loss}")
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'models/{save_model_name}_UNet.pth')
            save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, lr, model_size, step, current_loss)
        if step % 20000 == 0:
            current_IS = get_inception_score(device, model, VAE, text_encoder, CLAP_tokenizer, timbre_encoder, n_IS_batches,
                                             positive_prompts="", negative_prompts="", CFG=1, sample_steps=20, task="STFT")
            print('current_IS: %.5f' % current_IS)
            current_loss = np.mean(train_loss[-save_steps:])
            writer.add_scalar(f"current_IS", current_IS, step)
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'models/history/{save_model_name}_{step}_UNet.pth')
            save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, lr, model_size, step, current_loss)
    return model, optimizer