import torch
import torch.nn.functional as F
from IPython.display import clear_output
# @title Trainer
class Trainer():
    """VAE-GAN trainer: an autoencoder optimized with pixel MSE, VGG
    perceptual loss, a latent regularizer, and an adversarial term,
    against a hinge-loss discriminator. Loss values are EMA-smoothed and
    redrawn in place as a 20-char progress bar.

    NOTE(review): the "KL" term is 0.5 * mean(z)^2 — a penalty on the
    latent mean only, not a true KL divergence; kept as in the original.
    """

    def __init__(self, encoder, decoder, D, vgg, losses, data_len,
                 ema=3, a_disc=1, a_vae=1, a_KL=0.1, isViT=True):
        """
        Args:
            encoder, decoder: autoencoder halves. For isViT=True they
                consume/produce (N, C, H, W) directly; otherwise they are
                channels-last and train_step permutes NCHW <-> NHWC.
            D: discriminator scoring real vs. reconstructed images.
            vgg: feature extractor returning a list of feature maps,
                used for the perceptual loss.
            losses: iterable of tracked loss names, e.g.
                ["mse", "gan", "vgg", "KL"].
            data_len: batches per epoch (drives the progress bar).
            ema: EMA span for displayed losses (factor = 2/(ema+1)).
            a_disc, a_vae, a_KL: loss weights. a_disc is stored but not
                used in the visible code.
            isViT: selects the permute-free forward path.
        """
        self.vgg_schedule = None
        self.ema = 2 / (ema + 1)  # classic EMA smoothing factor
        self.a_disc = a_disc
        self.a_vae = a_vae
        self.a_KL = a_KL
        self.isViT = isViT
        self.encoder = encoder
        self.decoder = decoder
        self.D = D
        self.vgg = vgg
        self.encoder_optimizer = torch.optim.Adam(self.encoder.parameters(), lr=1e-5)
        self.encoder_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.encoder_optimizer, T_max=50)
        self.decoder_optimizer = torch.optim.Adam(self.decoder.parameters(), lr=1e-5)
        self.decoder_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.decoder_optimizer, T_max=50)
        # Discriminator runs at a higher LR (TTUR-style 4:1 ratio).
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), lr=4e-5)
        self.D_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.D_optimizer, T_max=50)
        self.losses = losses
        self.loss_vals = {loss: 0 for loss in losses}
        self.data_len = data_len
        self.loss_record = []
        self.epoch = 1
        self.index = 1
        # Fix: fall back to CPU so the trainer is usable without CUDA
        # (the original hard-coded "cuda" and crashed on CPU-only hosts).
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.encoder.to(self.device)
        self.decoder.to(self.device)
        self.D.to(self.device)
        self.vgg.to(self.device)

    def train_step(self, x, with_mse=False, freeze_ae=False, freeze_disc=False):
        """One optimization step: discriminator first (on a detached
        reconstruction), then the autoencoder.

        Args:
            x: input batch, (N, C, H, W) — moved to self.device here.
            with_mse: 0/1 gate on the pixel MSE term.
            freeze_ae: skip the autoencoder update.
            freeze_disc: skip the discriminator update and the
                adversarial term of the autoencoder loss.
        """
        self.index += 1
        x = x.to(self.device)
        disc_loss = None
        # --- discriminator update (reconstruction detached via no_grad) ---
        with torch.no_grad():
            if self.isViT:
                x_hat = self.decoder(self.encoder(x))
            else:
                # Channels-last models: NCHW -> NHWC in, NHWC -> NCHW out.
                x_hat = self.decoder(self.encoder(x.permute(0, 2, 3, 1))).permute(0, 3, 1, 2)
        if not freeze_disc:
            # Hinge loss: real scores pushed above +1, fake below -1.
            disc_loss = F.relu(1. - self.D(x)).mean() + F.relu(1. + self.D(x_hat)).mean()
            self.D_optimizer.zero_grad()
            disc_loss.backward()
            self.D_optimizer.step()
            self.D_scheduler.step()
        # --- autoencoder update ---
        if not freeze_ae:
            if self.isViT:
                z = self.encoder(x)
                x_hat = self.decoder(z)
            else:
                z = self.encoder(x.permute(0, 2, 3, 1))
                x_hat = self.decoder(z).permute(0, 3, 1, 2)
            mse = F.mse_loss(x_hat, x)
            KL = 0.5 * (z.mean() ** 2)  # penalizes latent mean drift
            vgg_loss = 0
            for real_feat, fake_feat in zip(self.vgg(x), self.vgg(x_hat)):
                vgg_loss += F.mse_loss(real_feat, fake_feat)
            adv_loss = 0
            if not freeze_disc:
                # Fix: reuse the grad-attached x_hat. The original re-ran
                # D(decoder(encoder(x))), which both wasted a forward pass
                # and skipped the permutes required in the non-ViT path.
                adv_loss = -self.D(x_hat).mean()
            loss = mse * with_mse + self.a_KL * KL + vgg_loss + self.a_vae * adv_loss
            self.encoder_optimizer.zero_grad()
            self.decoder_optimizer.zero_grad()
            loss.backward()
            self.encoder_optimizer.step()
            self.decoder_optimizer.step()
            self.encoder_scheduler.step()
            self.decoder_scheduler.step()
        self.update_batch({
            "mse": mse.item() if not freeze_ae else 0,
            "gan": disc_loss.item() if not freeze_disc else 0,
            "vgg": vgg_loss.item() if not freeze_ae else 0,
            # Fix: .item() — the original stored a live autograd tensor in
            # the EMA dict, keeping computation graphs alive (memory leak).
            "KL": z.mean().item() if not freeze_ae else 0,
        })

    def update_batch(self, loss_vals):
        """Blend the new batch losses into the EMA display values and
        redraw the progress line in place."""
        # Fix: IPython is optional so the trainer also runs outside
        # notebooks; without it we simply keep appending output.
        try:
            from IPython.display import clear_output
            clear_output(wait=True)
        except ImportError:
            pass
        for record in self.loss_record:
            print(record)
        self.loss_vals = {
            loss: (1 - self.ema) * self.loss_vals[loss] + self.ema * loss_vals[loss]
            for loss in self.losses
        }
        print(f"epoch:{self.epoch} ", end="")
        for loss in self.losses:
            print(f"{loss}: {self.loss_vals[loss]:.3f} ", end="")
        # 20-char bar; clamp so it never overruns when index > data_len.
        filled = min(20, int(self.index * 20 / self.data_len))
        print("=" * filled + "-" * (20 - filled), end="")

    def update_epoch(self):
        """Archive this epoch's EMA losses as a text record, reset the
        batch counter, and advance the epoch number."""
        self.index = 0
        record = f"epoch:{self.epoch} "
        for loss in self.losses:
            record += f"{loss}: {self.loss_vals[loss]:.3f} "
        self.loss_record.append(record)
        self.epoch += 1