File size: 4,252 Bytes
582b238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
import torch.nn.functional as F
from IPython.display import clear_output

# @title Trainer
class Trainer:
    """Adversarial autoencoder trainer with hinge-loss discriminator and VGG feature loss.

    Each ``train_step`` optionally performs one hinge-loss update of the
    discriminator ``D`` and one update of ``encoder``/``decoder``.
    Running (EMA-smoothed) loss values are printed as a notebook progress line.

    Inputs ``x`` are treated as NCHW by ``D`` and ``vgg``. When ``isViT`` is
    False, the encoder is fed ``x.permute(0, 2, 3, 1)`` and the decoder output is
    permuted back with ``permute(0, 3, 1, 2)``. Presumably this means those
    models work channels-last — confirm against the model definitions.
    """

    def __init__(self, encoder, decoder, D, vgg, losses, data_len, ema=3, a_disc=1, a_vae=1,
                 a_KL=0.1, isViT=True, device=None, ae_lr=1e-5, disc_lr=4e-5, T_max=50):
        """Store the models, move them to ``device`` and build Adam + cosine schedulers.

        Args:
            encoder, decoder: autoencoder halves, trained jointly.
            D: discriminator; its output is averaged into hinge / adversarial losses.
            vgg: feature extractor; ``vgg(x)`` must return an indexable sequence of
                tensors, and each pair is compared with MSE.
            losses: names of the loss entries to track and print, a subset of
                ``"mse"``, ``"gan"``, ``"vgg"``, ``"KL"``.
            data_len: number of batches per epoch (used only for the progress bar).
            ema: span of the exponential moving average; the smoothing factor is
                ``2 / (ema + 1)``.
            a_disc: stored but not used inside this class.
            a_vae: weight of the adversarial (generator) term in the autoencoder loss.
            a_KL: weight of the latent-mean penalty in the autoencoder loss.
            isViT: if False, permute inputs/outputs around encoder/decoder
                (see class docstring).
            device: target device. Defaults to CUDA when available, otherwise CPU.
            ae_lr: Adam learning rate for encoder and decoder.
            disc_lr: Adam learning rate for the discriminator.
            T_max: ``CosineAnnealingLR`` period for all three schedulers.
        """
        self.vgg_schedule = None  # not read anywhere in this class
        self.ema = 2 / (ema + 1)  # EMA smoothing factor derived from the span `ema`
        self.a_disc = a_disc
        self.a_vae = a_vae
        self.a_KL = a_KL

        self.isViT = isViT
        self.encoder = encoder
        self.decoder = decoder
        self.D = D
        self.vgg = vgg

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        # Move the models before building optimizers so the optimizers see final parameters.
        self.encoder.to(self.device)
        self.decoder.to(self.device)
        self.D.to(self.device)
        self.vgg.to(self.device)

        self.encoder_optimizer = torch.optim.Adam(self.encoder.parameters(), lr=ae_lr)
        self.encoder_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.encoder_optimizer, T_max=T_max)
        self.decoder_optimizer = torch.optim.Adam(self.decoder.parameters(), lr=ae_lr)
        self.decoder_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.decoder_optimizer, T_max=T_max)
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), lr=disc_lr)
        self.D_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.D_optimizer, T_max=T_max)

        self.losses = losses
        self.loss_vals = {loss: 0 for loss in losses}
        self.data_len = data_len
        self.loss_record = []  # one formatted summary line per finished epoch
        self.epoch = 1
        # Batches seen in the current epoch. Starts at 0 to match update_epoch();
        # train_step pre-increments, so the first batch shows as 1.
        self.index = 0

    def _encode(self, x):
        """Encode an NCHW batch, permuting to channels-last first when not a ViT."""
        return self.encoder(x) if self.isViT else self.encoder(x.permute(0, 2, 3, 1))

    def _decode(self, z):
        """Decode latents back to NCHW, permuting from channels-last when not a ViT."""
        out = self.decoder(z)
        return out if self.isViT else out.permute(0, 3, 1, 2)

    def _disc_step(self, x, x_hat):
        """Run one hinge-loss update of ``D`` on real ``x`` vs. reconstructed ``x_hat``.

        Returns the discriminator loss as a Python float.
        """
        disc_loss = F.relu(1. - self.D(x)).mean() + F.relu(1. + self.D(x_hat)).mean()  # Hinge
        self.D_optimizer.zero_grad()
        disc_loss.backward()
        self.D_optimizer.step()
        # NOTE(review): the schedulers are stepped once per batch, so with the default
        # T_max=50 the LR completes a cosine half-cycle every 50 batches. Pass
        # T_max=data_len*epochs (or step per epoch) if per-run annealing was intended.
        self.D_scheduler.step()
        return disc_loss.item()

    def _ae_step(self, x, with_mse, use_adv):
        """Run one update of encoder/decoder.

        The loss is ``mse * with_mse + a_KL * KL + vgg + a_vae * adv``. ``KL`` is
        ``0.5 * mean(z) ** 2``, a penalty on the latent mean rather than a full KL
        divergence. ``adv`` is ``-mean(D(x_hat))`` and is included only when
        ``use_adv`` is True.

        Returns ``(mse, vgg_loss, z_mean)`` as Python floats.
        """
        z = self._encode(x)
        x_hat = self._decode(z)
        mse = F.mse_loss(x_hat, x)
        z_mean = z.mean()
        KL = 0.5 * (z_mean ** 2)

        # Real-image features depend on no trainable encoder/decoder weights, so no graph
        # is needed for them.
        with torch.no_grad():
            vgg_real = self.vgg(x)
        vgg_fake = self.vgg(x_hat)
        vgg_loss = x.new_zeros(())
        for real, fake in zip(vgg_real, vgg_fake):
            vgg_loss = vgg_loss + F.mse_loss(real, fake)

        # Score the reconstruction already computed above. This keeps the non-ViT
        # permutes and avoids a second autoencoder forward pass.
        adv_loss = -self.D(x_hat).mean() if use_adv else 0

        loss = mse * with_mse + self.a_KL * KL + vgg_loss + self.a_vae * adv_loss
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        loss.backward()
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()
        self.encoder_scheduler.step()
        self.decoder_scheduler.step()
        # Log plain floats: keeping tensors here would chain autograd graphs via the EMA.
        return mse.item(), vgg_loss.item(), z_mean.item()

    def train_step(self, x, with_mse=False, freeze_ae=False, freeze_disc=False):
        """Perform one training step on batch ``x`` and refresh the progress display.

        Args:
            x: input batch, moved to ``self.device``.
            with_mse: include pixel MSE in the autoencoder loss.
            freeze_ae: skip the encoder/decoder update (their logged losses are 0).
            freeze_disc: skip the discriminator update and drop the adversarial term
                from the autoencoder loss.
        """
        self.index += 1
        x = x.to(self.device)
        stats = {"mse": 0, "gan": 0, "vgg": 0, "KL": 0}

        if not freeze_disc:
            # The reconstruction is a fixed "fake" sample for D, so no autoencoder grads.
            with torch.no_grad():
                x_hat = self._decode(self._encode(x))
            stats["gan"] = self._disc_step(x, x_hat)

        if not freeze_ae:
            stats["mse"], stats["vgg"], stats["KL"] = self._ae_step(
                x, with_mse, use_adv=not freeze_disc)

        self.update_batch(stats)

    def _format_losses(self):
        """Return ``"epoch:<n> "`` followed by each tracked EMA loss to 3 decimals."""
        parts = "".join(f"{loss}: {self.loss_vals[loss]:.3f} " for loss in self.losses)
        return f"epoch:{self.epoch} " + parts

    @staticmethod
    def _progress_bar(index, data_len, width=20):
        """Return a ``width``-char bar of ``=`` (done) then ``-`` (remaining), clamped."""
        filled = int(index * width / data_len) if data_len > 0 else width
        filled = max(0, min(filled, width))
        return "=" * filled + "-" * (width - filled)

    def update_batch(self, loss_vals):
        """Fold one step's losses into the EMA and redraw the notebook output.

        Clears the cell output, reprints the finished-epoch records, then prints the
        current epoch's smoothed losses and progress bar without a trailing newline.
        ``loss_vals`` must contain every name in ``self.losses``.
        """
        clear_output(wait=True)
        for record in self.loss_record:
            print(record)
        self.loss_vals = {
            loss: (1 - self.ema) * self.loss_vals[loss] + self.ema * loss_vals[loss]
            for loss in self.losses
        }
        print(self._format_losses() + self._progress_bar(self.index, self.data_len), end="")

    def update_epoch(self):
        """Close the current epoch: store its loss summary, reset the index, advance."""
        self.index = 0
        self.loss_record.append(self._format_losses())
        self.epoch += 1