import json

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.utils.tensorboard import SummaryWriter
from torchvision import models

def Normalize(in_channels, num_groups=32, norm_type="groupnorm"):
    """Return a normalization layer: BatchNorm2d, or (default) GroupNorm."""
    if norm_type == "batchnorm":
        return torch.nn.BatchNorm2d(in_channels)
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


def nonlinearity(x, act_type="relu"):
    """Nonlinear activation: ReLU, or swish (x * sigmoid(x)) for any other act_type."""
    if act_type == "relu":
        return F.relu(x)
    # swish
    return x * torch.sigmoid(x)

class VectorQuantizer(nn.Module):
    """Vector-quantization layer: maps each encoder output vector to its nearest codebook entry."""

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        # Flatten BHWC -> (B*H*W, C)
        flat_input = inputs.view(-1, self._embedding_dim)
        # Squared L2 distances: ||z||^2 + ||e||^2 - 2 * z.e
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
        # One-hot encoding of the nearest codebook entry
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)
        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
        # Codebook and commitment losses
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss
        # Straight-through estimator: gradients flow from quantized back to inputs
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        # convert quantized from BHWC -> BCHW
        min_encodings, min_encoding_indices = None, None
        return quantized.permute(0, 3, 1, 2).contiguous(), loss, (perplexity, min_encodings, min_encoding_indices)

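# Usage sketch (illustrative values, not from the configs used elsewhere in this file):
# a (B, D, H, W) encoder output is quantized to the same shape, plus a scalar VQ loss
# and a codebook-usage perplexity.
#
#   vq = VectorQuantizer(num_embeddings=1024, embedding_dim=64, commitment_cost=0.25)
#   z = torch.randn(8, 64, 16, 16)                    # BCHW encoder output
#   quantized, vq_loss, (perplexity, _, _) = vq(z)
#   assert quantized.shape == z.shape                 # straight-through keeps the shape
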
class VectorQuantizerEMA(nn.Module):
    """Vector-quantization layer whose codebook is updated by an exponential moving average
    instead of a gradient-based codebook loss."""

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost
        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()
        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)
        # Calculate distances
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)
        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)
            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                    (self._ema_cluster_size + self._epsilon)
                    / (n + self._num_embeddings * self._epsilon) * n)
            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss
        # Straight Through Estimator
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        # convert quantized from BHWC -> BCHW
        min_encodings, min_encoding_indices = None, None
        return quantized.permute(0, 3, 1, 2).contiguous(), loss, (perplexity, min_encodings, min_encoding_indices)

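# EMA codebook update performed above (decay gamma, epsilon for Laplace smoothing):
#   N_i <- gamma * N_i + (1 - gamma) * n_i          (per-code cluster size)
#   m_i <- gamma * m_i + (1 - gamma) * sum_j z_j    (sum of vectors assigned to code i)
#   e_i  = m_i / N_i                                (smoothed codebook vector)
# Only the commitment term contributes to the gradient-based loss.
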
class DownSample(nn.Module):
    """Halve the spatial resolution with a strided 4x4 convolution."""

    def __init__(self, in_channels, out_channels):
        super(DownSample, self).__init__()
        self._conv2d = nn.Conv2d(in_channels=in_channels,
                                 out_channels=out_channels,
                                 kernel_size=4,
                                 stride=2, padding=1)

    def forward(self, x):
        return self._conv2d(x)


class UpSample(nn.Module):
    """Double the spatial resolution with a strided 4x4 transposed convolution."""

    def __init__(self, in_channels, out_channels):
        super(UpSample, self).__init__()
        self._conv2d = nn.ConvTranspose2d(in_channels=in_channels,
                                          out_channels=out_channels,
                                          kernel_size=4,
                                          stride=2, padding=1)

    def forward(self, x):
        return self._conv2d(x)

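# Shape check for the k=4, s=2, p=1 settings above: out = floor((in + 2 - 4) / 2) + 1 = in / 2
# for DownSample, and the transposed convolution inverts it exactly (in -> 2 * in), e.g.:
#   DownSample(3, 64)(torch.randn(1, 3, 256, 256)).shape == (1, 64, 128, 128)
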
class ResnetBlock(nn.Module):
    """Pre-activation residual block (normalize -> activate -> convolve), with an optional
    second conv stage and an optional timestep-embedding projection."""

    def __init__(self, *, in_channels, out_channels=None, double_conv=False, conv_shortcut=False,
                 dropout=0.0, temb_channels=512, norm_type="groupnorm", act_type="relu", num_groups=32):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.act_type = act_type
        self.norm1 = Normalize(in_channels, norm_type=norm_type, num_groups=num_groups)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.double_conv = double_conv
        if self.double_conv:
            self.norm2 = Normalize(out_channels, norm_type=norm_type, num_groups=num_groups)
            self.dropout = torch.nn.Dropout(dropout)
            self.conv2 = torch.nn.Conv2d(out_channels,
                                         out_channels,
                                         kernel_size=3,
                                         stride=1,
                                         padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb=None):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h, act_type=self.act_type)
        h = self.conv1(h)
        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb, act_type=self.act_type))[:, :, None, None]
        if self.double_conv:
            h = self.norm2(h)
            h = nonlinearity(h, act_type=self.act_type)
            h = self.dropout(h)
            h = self.conv2(h)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h

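# A minimal shape sketch (illustrative values): when channel counts change, the skip path is
# routed through a 1x1 (or 3x3, if conv_shortcut=True) projection so the residual sum is
# well-defined.
#   block = ResnetBlock(in_channels=64, out_channels=128, temb_channels=0)
#   block(torch.randn(2, 64, 32, 32)).shape == (2, 128, 32, 32)
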
class LinearAttention(nn.Module):
    """Efficient attention block based on <https://proceedings.mlr.press/v119/katharopoulos20a.html>"""

    def __init__(self, dim, heads=4, dim_head=32, with_skip=True):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
        self.with_skip = with_skip
        if self.with_skip:
            self.nin_shortcut = torch.nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        # aggregate values into a compact per-head context, then read it out with the queries
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        if self.with_skip:
            return self.to_out(out) + self.nin_shortcut(x)
        return self.to_out(out)

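# Because the (d x d) context k v^T is formed first, the cost scales linearly with the number
# of spatial positions N = h * w rather than the O(N^2) of standard attention. Shape sketch:
#   attn = LinearAttention(dim=64, heads=1, dim_head=32)
#   attn(torch.randn(2, 64, 32, 32)).shape == (2, 64, 32, 32)
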
class Encoder(nn.Module):
    """The encoder, consisting of alternating stacks of ResNet blocks, efficient attention modules, and downsampling layers."""

    def __init__(self, in_channels, hidden_channels, embedding_dim, block_depth=2,
                 attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu", num_groups=32):
        super(Encoder, self).__init__()
        if attn_pos is None:
            attn_pos = []
        self._layers = nn.ModuleList([DownSample(in_channels, hidden_channels[0])])
        current_channel = hidden_channels[0]
        for i in range(1, len(hidden_channels)):
            for _ in range(block_depth - 1):
                self._layers.append(ResnetBlock(in_channels=current_channel,
                                                out_channels=current_channel,
                                                double_conv=False,
                                                conv_shortcut=False,
                                                norm_type=norm_type,
                                                act_type=act_type,
                                                num_groups=num_groups))
            if current_channel in attn_pos:
                self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
            self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
            self._layers.append(nn.ReLU())
            self._layers.append(DownSample(current_channel, hidden_channels[i]))
            current_channel = hidden_channels[i]
        for _ in range(block_depth - 1):
            self._layers.append(ResnetBlock(in_channels=current_channel,
                                            out_channels=current_channel,
                                            double_conv=False,
                                            conv_shortcut=False,
                                            norm_type=norm_type,
                                            act_type=act_type,
                                            num_groups=num_groups))
        if current_channel in attn_pos:
            self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
        # Conv1x1: hidden_channels[-1] -> embedding_dim
        self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
        self._layers.append(nn.ReLU())
        self._layers.append(nn.Conv2d(in_channels=current_channel,
                                      out_channels=embedding_dim,
                                      kernel_size=1,
                                      stride=1))

    def forward(self, x):
        for layer in self._layers:
            x = layer(x)
        return x

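# Encoder usage sketch (illustrative sizes): with L = len(hidden_channels) there are L
# DownSample stages, so spatial dimensions shrink by a factor of 2**L.
#   enc = Encoder(in_channels=3, hidden_channels=[64, 128, 256], embedding_dim=64, attn_pos=[256])
#   enc(torch.randn(1, 3, 256, 256)).shape == (1, 64, 32, 32)   # 256 / 2**3 = 32
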
class Decoder(nn.Module):
    """The decoder, consisting of alternating stacks of ResNet blocks, efficient attention modules, and upsampling layers."""

    def __init__(self, embedding_dim, hidden_channels, out_channels, block_depth=2,
                 attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu",
                 num_groups=32):
        super(Decoder, self).__init__()
        if attn_pos is None:
            attn_pos = []
        reversed_hidden_channels = list(reversed(hidden_channels))
        # Conv1x1: embedding_dim -> hidden_channels[-1]
        self._layers = nn.ModuleList([nn.Conv2d(in_channels=embedding_dim,
                                                out_channels=reversed_hidden_channels[0],
                                                kernel_size=1, stride=1, bias=False)])
        current_channel = reversed_hidden_channels[0]
        for _ in range(block_depth - 1):
            if current_channel in attn_pos:
                self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
            self._layers.append(ResnetBlock(in_channels=current_channel,
                                            out_channels=current_channel,
                                            double_conv=False,
                                            conv_shortcut=False,
                                            norm_type=norm_type,
                                            act_type=act_type,
                                            num_groups=num_groups))
        for i in range(1, len(reversed_hidden_channels)):
            self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
            self._layers.append(nn.ReLU())
            self._layers.append(UpSample(current_channel, reversed_hidden_channels[i]))
            current_channel = reversed_hidden_channels[i]
            for _ in range(block_depth - 1):
                if current_channel in attn_pos:
                    self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
                self._layers.append(ResnetBlock(in_channels=current_channel,
                                                out_channels=current_channel,
                                                double_conv=False,
                                                conv_shortcut=False,
                                                norm_type=norm_type,
                                                act_type=act_type,
                                                num_groups=num_groups))
        self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
        self._layers.append(nn.ReLU())
        self._layers.append(UpSample(current_channel, current_channel))
        # final layers
        self._layers.append(ResnetBlock(in_channels=current_channel,
                                        out_channels=out_channels,
                                        double_conv=False,
                                        conv_shortcut=False,
                                        norm_type=norm_type,
                                        act_type=act_type,
                                        num_groups=num_groups))

    def forward(self, x):
        for layer in self._layers:
            x = layer(x)
        # output head (assumes out_channels == 3): non-negative log-magnitude,
        # plus cos/sin phase components bounded to [-1, 1]
        log_magnitude = torch.nn.functional.softplus(x[:, 0, :, :])
        cos_phase = torch.tanh(x[:, 1, :, :])
        sin_phase = torch.tanh(x[:, 2, :, :])
        x = torch.stack([log_magnitude, cos_phase, sin_phase], dim=1)
        return x

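# Decoder usage sketch (mirrors the encoder example above; out_channels must be 3 to match
# the (log-magnitude, cos, sin) output head in forward()):
#   dec = Decoder(embedding_dim=64, hidden_channels=[64, 128, 256], out_channels=3)
#   dec(torch.randn(1, 64, 32, 32)).shape == (1, 3, 256, 256)
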
class VQGAN_Discriminator(nn.Module):
    """The discriminator employs an 18-layer ResNet architecture, with the first layer replaced
    by a 2D convolution that accommodates spectral-representation inputs, and the final layers
    replaced by a single-logit binary classifier."""

    def __init__(self, in_channels=1):
        super(VQGAN_Discriminator, self).__init__()
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Replace the first layer so it accepts `in_channels` (e.g. single-channel) inputs
        resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # Keep ResNet's feature-extraction trunk (drop avgpool and fc)
        self.features = nn.Sequential(*list(resnet.children())[:-2])
        # Binary classification head. It outputs raw logits: the training loop uses
        # BCEWithLogitsLoss, which applies the sigmoid itself.
        self.classifier = nn.Sequential(
            nn.Linear(512, 1)
        )

    def forward(self, x):
        x = self.features(x)
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

class VQGAN(nn.Module):
    """The VQ-GAN model.
    <https://openaccess.thecvf.com/content/CVPR2021/html/Esser_Taming_Transformers_for_High-Resolution_Image_Synthesis_CVPR_2021_paper.html>"""

    def __init__(self, in_channels, hidden_channels, embedding_dim, out_channels, block_depth=2,
                 attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu",
                 num_embeddings=1024, commitment_cost=0.25, decay=0.99, num_groups=32):
        super(VQGAN, self).__init__()
        self._encoder = Encoder(in_channels, hidden_channels, embedding_dim, block_depth=block_depth,
                                attn_pos=attn_pos, attn_with_skip=attn_with_skip, norm_type=norm_type,
                                act_type=act_type, num_groups=num_groups)
        # Use the EMA codebook when decay > 0, otherwise the gradient-based quantizer
        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
                                              commitment_cost, decay)
        else:
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                           commitment_cost)
        self._decoder = Decoder(embedding_dim, hidden_channels, out_channels, block_depth=block_depth,
                                attn_pos=attn_pos, attn_with_skip=attn_with_skip, norm_type=norm_type,
                                act_type=act_type, num_groups=num_groups)

    def forward(self, x):
        z = self._encoder(x)
        quantized, vq_loss, (perplexity, _, _) = self._vq_vae(z)
        x_recon = self._decoder(quantized)
        return vq_loss, x_recon, perplexity

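# End-to-end sketch (illustrative config; the real configs are passed in via model_Config):
#   model = VQGAN(in_channels=3, hidden_channels=[64, 128, 256], embedding_dim=64,
#                 out_channels=3, attn_pos=[256])
#   vq_loss, x_recon, perplexity = model(torch.randn(1, 3, 256, 256))
#   x_recon.shape == (1, 3, 256, 256)
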
class ReconstructionLoss(nn.Module):
    """Weighted reconstruction loss for (log-magnitude, cos-phase, sin-phase) outputs: a
    relative (weighted) MAE on the magnitude channel and a plain MAE on the phase channels."""

    def __init__(self, w1, w2, epsilon=1e-3):
        super(ReconstructionLoss, self).__init__()
        self.w1 = w1
        self.w2 = w2
        self.epsilon = epsilon

    def weighted_mae_loss(self, y_true, y_pred):
        # avoid division by zero
        y_true_safe = torch.clamp(y_true, min=self.epsilon)
        # relative MAE: |y_pred - y_true| / max(y_true, epsilon)
        loss = torch.mean(torch.abs(y_pred - y_true) / y_true_safe)
        return loss

    def mae_loss(self, y_true, y_pred):
        loss = torch.mean(torch.abs(y_pred - y_true))
        return loss

    def forward(self, y_pred, y_true):
        # loss for the magnitude channel (targets first, matching the parameter names)
        log_magnitude_loss = self.w1 * self.weighted_mae_loss(y_true[:, 0, :, :], y_pred[:, 0, :, :])
        # loss for the phase channels
        phase_loss = self.w2 * self.mae_loss(y_true[:, 1:, :, :], y_pred[:, 1:, :, :])
        # sum up
        rec_loss = log_magnitude_loss + phase_loss
        return log_magnitude_loss, phase_loss, rec_loss

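# Worked example of the weighted (relative) MAE above, with epsilon = 1e-3:
#   y_true = 2.0, y_pred = 1.5  ->  |1.5 - 2.0| / max(2.0, 1e-3) = 0.25
# so the same absolute error weighs more on small-magnitude bins than on large ones.
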
def evaluate_VQGAN(model, discriminator, iterator, reconstructionLoss, adversarial_loss, trainingConfig):
    """Estimate the current total loss over 100 batches, without gradient updates."""
    model.to(trainingConfig["device"])
    model.eval()
    eval_errors = []
    with torch.no_grad():
        for i in range(100):
            # draw a batch (a fresh iterator each time; assumes a shuffling DataLoader)
            data = next(iter(iterator))
            data = data.to(trainingConfig["device"])
            # the generator is scored on how "real" its reconstructions look
            real_labels = torch.ones(data.size(0), 1).to(trainingConfig["device"])
            vq_loss, data_recon, perplexity = model(data)
            fake_preds = discriminator(data_recon)
            adver_loss = adversarial_loss(fake_preds, real_labels)
            log_magnitude_loss, phase_loss, rec_loss = reconstructionLoss(data_recon, data)
            loss = rec_loss + trainingConfig["vq_weight"] * vq_loss + trainingConfig["adver_weight"] * adver_loss
            eval_errors.append(loss.item())
    initial_loss = np.mean(eval_errors)
    return initial_loss

def get_VQGAN(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    VQVAE = VQGAN(**model_Config)
    print(f"Model initialized, trainable parameters: {sum(p.numel() for p in VQVAE.parameters() if p.requires_grad)}")
    VQVAE.to(device)
    if load_pretrain:
        print(f"Loading weights from models/{model_name}_imageVQVAE.pth")
        checkpoint = torch.load(f'models/{model_name}_imageVQVAE.pth', map_location=device)
        VQVAE.load_state_dict(checkpoint['model_state_dict'])
    VQVAE.eval()
    return VQVAE

def train_VQGAN(model_Config, trainingConfig, iterator):
    def save_model_hyperparameter(model_Config, trainingConfig, current_iter,
                                  log_magnitude_loss, phase_loss, current_perplexity, current_vq_loss,
                                  current_loss):
        model_name = trainingConfig["model_name"]
        # copy so the caller's model_Config is not mutated
        model_hyperparameter = dict(model_Config)
        model_hyperparameter.update(trainingConfig)
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["log_magnitude_loss"] = log_magnitude_loss
        model_hyperparameter["phase_loss"] = phase_loss
        model_hyperparameter["perplexity"] = current_perplexity
        model_hyperparameter["vq_loss"] = current_vq_loss
        model_hyperparameter["total_loss"] = current_loss
        with open(f"models/hyperparameters/{model_name}_VQGAN_STFT.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    # initialize the VQ-VAE (generator)
    model = VQGAN(**model_Config)
    print(f"VQ_VAE size: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    model.to(trainingConfig["device"])
    VAE_optimizer = torch.optim.Adam(model.parameters(), lr=trainingConfig["lr"], amsgrad=False)
    model_name = trainingConfig["model_name"]
    if trainingConfig["load_pretrain"]:
        print(f"Loading weights from models/{model_name}_imageVQVAE.pth")
        checkpoint = torch.load(f'models/{model_name}_imageVQVAE.pth', map_location=trainingConfig["device"])
        model.load_state_dict(checkpoint['model_state_dict'])
        VAE_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("VAE initialized.")
    if trainingConfig["max_iter"] == 0:
        print("Return VAE directly.")
        return model

    # initialize the discriminator
    discriminator = VQGAN_Discriminator(model_Config["in_channels"])
    print(f"Discriminator size: {sum(p.numel() for p in discriminator.parameters() if p.requires_grad)}")
    discriminator.to(trainingConfig["device"])
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=trainingConfig["d_lr"], amsgrad=False)
    if trainingConfig["load_pretrain"]:
        print(f"Loading weights from models/{model_name}_imageVQVAE_discriminator.pth")
        checkpoint = torch.load(f'models/{model_name}_imageVQVAE_discriminator.pth', map_location=trainingConfig["device"])
        discriminator.load_state_dict(checkpoint['model_state_dict'])
        discriminator_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Discriminator initialized.")

    # Training
    train_res_phase_loss, train_res_perplexity, train_res_log_magnitude_loss, train_res_vq_loss, train_res_loss = [], [], [], [], []
    train_discriminator_loss, train_adversarial_loss = [], []
    reconstructionLoss = ReconstructionLoss(w1=trainingConfig["w1"], w2=trainingConfig["w2"], epsilon=trainingConfig["threshold"])
    adversarial_loss = nn.BCEWithLogitsLoss()
    writer = SummaryWriter(f'runs/{model_name}_VQVAE_lr=1e-4')
    previous_lowest_loss = evaluate_VQGAN(model, discriminator, iterator,
                                          reconstructionLoss, adversarial_loss, trainingConfig)
    print(f"initial_loss: {previous_lowest_loss}")
    model.train()
    for i in range(trainingConfig["max_iter"]):
        data = next(iter(iterator))
        data = data.to(trainingConfig["device"])
        # real/fake labels
        real_labels = torch.ones(data.size(0), 1).to(trainingConfig["device"])
        fake_labels = torch.zeros(data.size(0), 1).to(trainingConfig["device"])
        # update the discriminator on real data vs. detached reconstructions
        discriminator_optimizer.zero_grad()
        vq_loss, data_recon, perplexity = model(data)
        real_preds = discriminator(data)
        fake_preds = discriminator(data_recon.detach())
        loss_real = adversarial_loss(real_preds, real_labels)
        loss_fake = adversarial_loss(fake_preds, fake_labels)
        loss_D = loss_real + loss_fake
        loss_D.backward()
        discriminator_optimizer.step()
        # update the VQ-VAE: reconstruction + VQ + adversarial terms
        VAE_optimizer.zero_grad()
        fake_preds = discriminator(data_recon)
        adver_loss = adversarial_loss(fake_preds, real_labels)
        log_magnitude_loss, phase_loss, rec_loss = reconstructionLoss(data_recon, data)
        loss = rec_loss + trainingConfig["vq_weight"] * vq_loss + trainingConfig["adver_weight"] * adver_loss
        loss.backward()
        VAE_optimizer.step()
        train_discriminator_loss.append(loss_D.item())
        train_adversarial_loss.append(trainingConfig["adver_weight"] * adver_loss.item())
        train_res_log_magnitude_loss.append(log_magnitude_loss.item())
        train_res_phase_loss.append(phase_loss.item())
        train_res_perplexity.append(perplexity.item())
        train_res_vq_loss.append(trainingConfig["vq_weight"] * vq_loss.item())
        train_res_loss.append(loss.item())
        # read the global step from the Adam state (it persists across checkpoint reloads)
        step = int(VAE_optimizer.state_dict()['state'][list(VAE_optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy())
        save_steps = trainingConfig["save_steps"]
        if (i + 1) % 100 == 0:
            print('%d step' % step)
        if (i + 1) % save_steps == 0:
            current_discriminator_loss = np.mean(train_discriminator_loss[-save_steps:])
            current_adversarial_loss = np.mean(train_adversarial_loss[-save_steps:])
            current_log_magnitude_loss = np.mean(train_res_log_magnitude_loss[-save_steps:])
            current_phase_loss = np.mean(train_res_phase_loss[-save_steps:])
            current_perplexity = np.mean(train_res_perplexity[-save_steps:])
            current_vq_loss = np.mean(train_res_vq_loss[-save_steps:])
            current_loss = np.mean(train_res_loss[-save_steps:])
            print('discriminator_loss: %.3f' % current_discriminator_loss)
            print('adversarial_loss: %.3f' % current_adversarial_loss)
            print('log_magnitude_loss: %.3f' % current_log_magnitude_loss)
            print('phase_loss: %.3f' % current_phase_loss)
            print('perplexity: %.3f' % current_perplexity)
            print('vq_loss: %.3f' % current_vq_loss)
            print('total_loss: %.3f' % current_loss)
            writer.add_scalar("log_magnitude_loss", current_log_magnitude_loss, step)
            writer.add_scalar("phase_loss", current_phase_loss, step)
            writer.add_scalar("perplexity", current_perplexity, step)
            writer.add_scalar("vq_loss", current_vq_loss, step)
            writer.add_scalar("total_loss", current_loss, step)
            # checkpoint only when the smoothed loss improves
            if current_loss < previous_lowest_loss:
                previous_lowest_loss = current_loss
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': VAE_optimizer.state_dict(),
                }, f'models/{model_name}_imageVQVAE.pth')
                torch.save({
                    'model_state_dict': discriminator.state_dict(),
                    'optimizer_state_dict': discriminator_optimizer.state_dict(),
                }, f'models/{model_name}_imageVQVAE_discriminator.pth')
                save_model_hyperparameter(model_Config, trainingConfig, step,
                                          current_log_magnitude_loss, current_phase_loss, current_perplexity,
                                          current_vq_loss, current_loss)
    return model
