File size: 3,232 Bytes
52c73e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

# Project-local module; presumably provides UNet, Discriminator, get_gen_loss
# and show_tensor_images (used below but not defined in this file).
from UNet import *
from tqdm.auto import tqdm  # kept after the star import so UNet cannot shadow it

# --- Losses ----------------------------------------------------------------
adv_criterion = nn.BCEWithLogitsLoss()  # adversarial loss; expects raw discriminator logits
recon_criterion = nn.L1Loss()  # pixel-wise reconstruction loss
lambda_recon = 200  # presumably the weight of the reconstruction term in get_gen_loss

# --- Training / data hyperparameters ---------------------------------------
n_epochs = 50
input_dim = 3  # channels of the condition image (generator input)
real_dim = 3  # channels of the real / generated image (generator output)
display_step = 200  # steps between progress printouts, image dumps and checkpoints
batch_size = 4
lr = 0.0002
target_shape = 256  # both halves of each image pair are resized to this side length

# Fall back to the CPU when no CUDA device is present, instead of crashing on
# the first .to(device) call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Only tensor conversion happens at load time; no resizing or normalisation is
# applied here (each half of an image pair is resized later, per batch).
transform = transforms.Compose([transforms.ToTensor()])

dataset = torchvision.datasets.ImageFolder(
    "/datasets/maps",
    transform=transform,
)


# Generator: input_dim-channel condition image -> real_dim-channel image.
gen = UNet(input_dim, real_dim).to(device)
# Discriminator is called with (image, condition), hence it is sized for
# input_dim + real_dim channels (presumably the two are concatenated inside).
disc = Discriminator(input_dim + real_dim).to(device)

# One Adam optimizer per network, both with the same learning rate.
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

def weights_init(m):
    """Initialise one module's parameters in the DCGAN / pix2pix style.

    Meant to be passed to ``nn.Module.apply``, which calls it once for every
    submodule.

    - ``Conv2d`` / ``ConvTranspose2d``: weight ~ N(0, 0.02); bias untouched.
    - ``BatchNorm2d`` with learnable affine parameters: weight (scale)
      ~ N(1, 0.02), bias (shift) = 0. Centring the scale on 1 starts the
      layer near an identity affine map; centring it on 0 would multiply
      every normalised activation by roughly 0.02 and nearly silence it.
      ``BatchNorm2d(affine=False)`` has no weight/bias and is skipped.
    - Any other module is left unchanged.

    Args:
        m: The module to initialise (modified in place).
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.affine:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0.0)

# Re-initialise both networks with weights_init. Module.apply visits every
# submodule recursively and returns the module itself, so names stay bound.
gen, disc = (net.apply(weights_init) for net in (gen, disc))


# Checkpoints are written periodically below; make sure the directory exists
# so torch.save does not fail with FileNotFoundError.
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)


def _split_pair(image):
    """Split a batch of side-by-side image pairs into ``(condition, real)``.

    The left half of each image is used as the generator's condition and the
    right half as the real target. Both halves are resized to
    ``target_shape`` x ``target_shape`` and then moved to ``device``.

    Args:
        image: Batch tensor; the split along ``image.shape[3]`` assumes a
            (batch, channel, height, width) layout.

    Returns:
        Tuple ``(condition, real)`` of tensors on ``device``.
    """
    half_width = image.shape[3] // 2
    condition = nn.functional.interpolate(image[:, :, :, :half_width], size=target_shape)
    real = nn.functional.interpolate(image[:, :, :, half_width:], size=target_shape)
    return condition.to(device), real.to(device)


# Averages of the losses over the steps since the last display; updated at
# each display step.
mean_gen_loss = 0
mean_disc_loss = 0
# Running sums behind those averages.
gen_loss_sum = 0.0
disc_loss_sum = 0.0
steps_since_display = 0
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

cur_step = 0
for epoch in range(n_epochs):
    for image, _ in tqdm(dataloader):
        condition, real = _split_pair(image)

        # --- Discriminator update ------------------------------------------
        disc_opt.zero_grad()
        # The generator is frozen for this step. No graph is built through
        # it, so neither detach() nor retain_graph is needed below.
        with torch.no_grad():
            fake = gen(condition)
        disc_fake_pred = disc(fake, condition)
        disc_fake_loss = adv_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_pred = disc(real, condition)
        disc_real_loss = adv_criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2
        disc_loss.backward()
        disc_opt.step()

        # --- Generator update ----------------------------------------------
        # Gradients this backward leaves on disc's parameters are cleared by
        # disc_opt.zero_grad() at the start of the next iteration.
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon)
        gen_loss.backward()
        gen_opt.step()

        gen_loss_sum += gen_loss.item()
        disc_loss_sum += disc_loss.item()
        steps_since_display += 1

        if cur_step % display_step == 0:
            # Average over the steps actually accumulated (just 1 at step 0),
            # rather than always dividing by display_step.
            mean_gen_loss = gen_loss_sum / steps_since_display
            mean_disc_loss = disc_loss_sum / steps_since_display
            print(f"Epoch: {epoch} | Step: {cur_step} | Gen-Loss: {mean_gen_loss} | Disc-loss: {mean_disc_loss}")
            show_tensor_images(condition, epoch, cur_step, "condition", size=(input_dim, target_shape, target_shape))
            show_tensor_images(real, epoch, cur_step, "real", size=(real_dim, target_shape, target_shape))
            show_tensor_images(fake, epoch, cur_step, "generated", size=(real_dim, target_shape, target_shape))
            gen_loss_sum = 0.0
            disc_loss_sum = 0.0
            steps_since_display = 0

            # NOTE(review): whole module/optimizer objects are pickled here,
            # and that format is kept for existing loaders. Saving
            # .state_dict()s instead would survive code refactors.
            torch.save({
                "gen": gen,
                "disc": disc,
                "gen_opt": gen_opt,
                "disc_opt": disc_opt
            }, f"{checkpoint_dir}/Pix2Pix_Epoch{epoch}.pth")
        cur_step += 1