| | import torch
|
| | import torch.nn.functional as F
|
| | from IPython.display import clear_output
|
| |
|
| |
|
class Trainer():
    """Training loop for an adversarial autoencoder (VAE-GAN style).

    Trains ``encoder``/``decoder`` against a hinge-loss discriminator ``D``
    while also minimising a VGG feature-matching (perceptual) loss and a
    simple KL-style regulariser on the latent mean.  Running loss values are
    smoothed with an EMA and rendered in-place via IPython's ``clear_output``.
    """

    def __init__(self, encoder, decoder, D, vgg, losses, data_len, ema=3, a_disc=1, a_vae=1, a_KL=0.1, isViT=True):
        """
        Args:
            encoder, decoder: autoencoder halves.  When ``isViT`` is False the
                networks expect channels-last input, so batches are permuted
                NCHW -> NHWC on the way in and back on the way out.
            D: discriminator scoring real vs. reconstructed images.
            vgg: feature extractor returning a list of feature maps.
            losses: iterable of tracked loss names ("mse", "gan", "vgg", "KL").
            data_len: number of batches per epoch (drives the progress bar).
            ema: EMA span for displayed losses (smoothing factor 2/(ema+1)).
            a_disc: reserved weight for the discriminator loss (currently unused).
            a_vae: weight of the adversarial (generator) loss term.
            a_KL: weight of the KL-style latent regulariser.
            isViT: True if the networks consume NCHW tensors directly.
        """
        self.vgg_schedule = None
        self.ema = 2 / (ema + 1)  # EMA smoothing factor
        self.a_disc = a_disc
        self.a_vae = a_vae
        self.a_KL = a_KL

        self.isViT = isViT
        self.encoder = encoder
        self.decoder = decoder
        self.D = D
        self.vgg = vgg
        self.encoder_optimizer = torch.optim.Adam(self.encoder.parameters(), lr=1e-5)
        self.encoder_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.encoder_optimizer, T_max=50)
        self.decoder_optimizer = torch.optim.Adam(self.decoder.parameters(), lr=1e-5)
        self.decoder_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.decoder_optimizer, T_max=50)
        # Discriminator learns faster than the autoencoder (4e-5 vs 1e-5),
        # a common two-timescale GAN setup.
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), lr=4e-5)
        self.D_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.D_optimizer, T_max=50)
        self.losses = losses
        self.loss_vals = {loss: 0 for loss in losses}
        self.data_len = data_len
        self.loss_record = []
        self.epoch = 1
        # Batch counter; update_epoch() resets it to 0, so start there too
        # (the original started at 1, which skewed the first progress bar).
        self.index = 0
        # Fall back to CPU when CUDA is unavailable instead of crashing.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.encoder.to(self.device)
        self.decoder.to(self.device)
        self.D.to(self.device)
        self.vgg.to(self.device)

    def train_step(self, x, with_mse=False, freeze_ae=False, freeze_disc=False):
        """Run one optimisation step on a batch ``x`` (NCHW image tensor).

        Args:
            x: batch of images; moved to ``self.device`` here.
            with_mse: include the pixel MSE term in the autoencoder loss.
            freeze_ae: skip the autoencoder update (train only the discriminator).
            freeze_disc: skip the discriminator update and the adversarial term.
        """
        self.index += 1
        x = x.to(self.device)

        disc_loss = None
        if not freeze_disc:
            # Reconstruction used only as a discriminator target: no grads needed.
            with torch.no_grad():
                x_fake = (self.decoder(self.encoder(x)) if self.isViT
                          else self.decoder(self.encoder(x.permute(0, 2, 3, 1))).permute(0, 3, 1, 2))
            # Hinge loss: push real scores above +1 and fake scores below -1.
            disc_loss = F.relu(1. - self.D(x)).mean() + F.relu(1. + self.D(x_fake)).mean()
            self.D_optimizer.zero_grad()
            disc_loss.backward()
            self.D_optimizer.step()
            self.D_scheduler.step()

        mse = KL = vgg_loss = None
        if not freeze_ae:
            z = self.encoder(x) if self.isViT else self.encoder(x.permute(0, 2, 3, 1))
            x_hat = self.decoder(z) if self.isViT else self.decoder(z).permute(0, 3, 1, 2)
            mse = F.mse_loss(x_hat, x)
            # Cheap KL surrogate: penalise the squared mean of the latent code.
            KL = 0.5 * (z.mean() ** 2)
            # Perceptual loss: match VGG feature maps of real vs. reconstruction.
            vgg_loss = 0
            for real_feat, fake_feat in zip(self.vgg(x), self.vgg(x_hat)):
                vgg_loss += F.mse_loss(real_feat, fake_feat)

            adv_loss = 0
            if not freeze_disc:
                # BUGFIX: reuse the grad-enabled reconstruction x_hat.  The
                # original re-ran decoder(encoder(x)) without the NHWC permutes,
                # which was redundant for ViT models and shape-incorrect for the
                # non-ViT path.
                adv_loss = -self.D(x_hat).mean()

            # with_mse acts as a 0/1 gate on the pixel-space MSE term.
            loss = mse * with_mse + self.a_KL * KL + vgg_loss + self.a_vae * adv_loss
            self.encoder_optimizer.zero_grad()
            self.decoder_optimizer.zero_grad()
            loss.backward()
            self.encoder_optimizer.step()
            self.decoder_optimizer.step()
            self.encoder_scheduler.step()
            self.decoder_scheduler.step()

        self.update_batch({"mse": mse.item() if not freeze_ae else 0,
                           "gan": disc_loss.item() if not freeze_disc else 0,
                           # float() also handles the empty-feature-list case,
                           # where vgg_loss is still the plain int 0.
                           "vgg": float(vgg_loss) if not freeze_ae else 0,
                           # BUGFIX: log the scalar KL term; the original stored
                           # the tensor z.mean(), corrupting the EMA dict and
                           # breaking the :.3f formatting in update_batch.
                           "KL": KL.item() if not freeze_ae else 0})

    def update_batch(self, loss_vals):
        """Fold the batch losses into the EMA and redraw the progress display.

        Args:
            loss_vals: dict of scalar losses keyed by the names in ``self.losses``.
        """
        clear_output(wait=True)
        # Re-print the per-epoch history above the live line.
        for record in self.loss_record:
            print(record)
        self.loss_vals = {loss: (1 - self.ema) * self.loss_vals[loss] + self.ema * loss_vals[loss]
                          for loss in self.losses}
        print(f"epoch:{self.epoch} ", end="")
        for loss in self.losses:
            print(f"{loss}: {self.loss_vals[loss]:.3f} ", end="")
        # 20-character text progress bar: '=' for done, '-' for remaining.
        done = int(self.index * 20 / self.data_len)
        for _ in range(done):
            print("=", end="")
        for _ in range(done, 20):
            print("-", end="")

    def update_epoch(self):
        """Archive the current EMA losses as one history line and start a new epoch."""
        self.index = 0
        record = f"epoch:{self.epoch} "
        for loss in self.losses:
            record += f"{loss}: {self.loss_vals[loss]:.3f} "
        self.loss_record.append(record)
        self.epoch += 1