Yash Nagraj committed on
Commit
275907d
·
1 Parent(s): 89e0ef4

Add the training scripts for cloud training

Browse files
Files changed (5) hide show
  1. Discriminators.py +0 -0
  2. Generators.py +0 -0
  3. models.py +121 -0
  4. train.py +102 -0
  5. utils.py +109 -0
Discriminators.py DELETED
File without changes
Generators.py DELETED
File without changes
models.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch.nn as nn
3
class ResidualBlock(nn.Module):
    """Residual block: two reflect-padded 3x3 convs with InstanceNorm and
    ReLU, plus a skip connection. Channel count and spatial size are
    preserved, so the input can be added back to the conv output."""

    def __init__(self, input_channels) -> None:
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels, 3, 1, padding=1, padding_mode='reflect')
        self.conv2 = nn.Conv2d(input_channels, input_channels, 3, 1, padding=1, padding_mode='reflect')
        self.instanceNorm = nn.InstanceNorm2d(input_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        # BUG FIX: torch.Tensor has no .copy() (that is a numpy method), so
        # `x.copy()` raised AttributeError. A plain reference is enough here
        # because the convs do not modify x in place.
        original = x
        x = self.conv1(x)
        x = self.instanceNorm(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.instanceNorm(x)
        return original + x
19
+
20
+
21
+
22
class ContractingBlock(nn.Module):
    """Downsampling block: a stride-2 reflect-padded conv that doubles the
    channel count, optional InstanceNorm, then ReLU (or LeakyReLU(0.2) when
    activation='lrelu', as used by the discriminator)."""

    def __init__(self, input_channels, use_bn=True, kernel_size=3, activation='relu') -> None:
        super(ContractingBlock, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels * 2, kernel_size, padding=1, stride=2, padding_mode='reflect')
        self.activation = nn.ReLU() if activation == 'relu' else nn.LeakyReLU(0.2)
        if use_bn:
            # BUG FIX: the conv doubles the channels, so the norm must be
            # declared over input_channels * 2 (was input_channels).
            self.normalization = nn.InstanceNorm2d(input_channels * 2)
        self.use_bn = use_bn

    def forward(self, x):
        x = self.conv1(x)
        if self.use_bn:
            # BUG FIX: the normalized tensor was computed and discarded;
            # assign it back so normalization actually takes effect.
            x = self.normalization(x)
        x = self.activation(x)
        return x
37
+
38
+
39
class ExpandingBlock(nn.Module):
    """Upsampling block: a transposed 3x3 conv (stride 2) that halves the
    channel count and doubles the spatial size, optional InstanceNorm, then
    ReLU."""

    def __init__(self, input_channels, use_bn=True) -> None:
        super(ExpandingBlock, self).__init__()
        self.use_bn = use_bn
        self.conv1 = nn.ConvTranspose2d(
            input_channels,
            input_channels // 2,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        self.activation = nn.ReLU()
        if use_bn:
            self.normalization = nn.InstanceNorm2d(input_channels // 2)

    def forward(self, x):
        out = self.conv1(x)
        if self.use_bn:
            out = self.normalization(out)
        return self.activation(out)
54
+
55
+
56
+
57
class FeatureMapBlock(nn.Module):
    """Maps between channel spaces with a single 7x7 reflect-padded conv;
    spatial dimensions are unchanged."""

    def __init__(self, input_channels, output_channels) -> None:
        super(FeatureMapBlock, self).__init__()
        self.conv = nn.Conv2d(
            input_channels,
            output_channels,
            kernel_size=7,
            padding=3,
            padding_mode='reflect',
        )

    def forward(self, x):
        return self.conv(x)
65
+
66
class Generator(nn.Module):
    """CycleGAN generator: 7x7 feature map in, two contracting blocks, nine
    residual blocks, two expanding blocks, 7x7 feature map out, then tanh.
    Translates input_channels-channel images to output_channels channels at
    the same spatial resolution."""

    def __init__(self, input_channels, output_channels, hidden_dim=64) -> None:
        super(Generator, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_dim)
        self.contract1 = ContractingBlock(hidden_dim)
        self.contract2 = ContractingBlock(hidden_dim * 2)
        # Channels have doubled twice by the residual stage.
        res_mult = 4
        for i in range(9):
            setattr(self, f'res{i}', ResidualBlock(hidden_dim * res_mult))
        self.expand1 = ExpandingBlock(hidden_dim * res_mult)
        self.expand2 = ExpandingBlock(hidden_dim * 2)
        self.downfeature = FeatureMapBlock(hidden_dim, output_channels)
        self.tanh = nn.Tanh()

    def forward(self, x):
        out = self.upfeature(x)
        out = self.contract1(out)
        out = self.contract2(out)
        for i in range(9):
            out = getattr(self, f'res{i}')(out)
        out = self.expand1(out)
        out = self.expand2(out)
        return self.tanh(self.downfeature(out))
104
+
105
+
106
class Discriminator(nn.Module):
    """PatchGAN-style discriminator: 7x7 feature map, three stride-2
    contracting blocks with LeakyReLU (the first without normalization),
    then a 1x1 conv producing a one-channel map of patch logits."""

    def __init__(self, input_channels, hidden_channels=64) -> None:
        super(Discriminator, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels, False, kernel_size=4, activation='lrelu')
        self.contract2 = ContractingBlock(hidden_channels * 2, kernel_size=4, activation='lrelu')
        self.contract3 = ContractingBlock(hidden_channels * 4, kernel_size=4, activation='lrelu')
        self.conv = nn.Conv2d(hidden_channels * 8, 1, kernel_size=1)

    def forward(self, x):
        out = self.upfeature(x)
        for stage in (self.contract1, self.contract2, self.contract3):
            out = stage(out)
        return self.conv(out)
train.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# BUG FIX: `torch` was only in scope via the `from utils import *` wildcard;
# import it explicitly since this module uses torch.optim / torch.save.
import torch
import torch.nn as nn
from torchvision import transforms
from utils import *
from models import Generator, Discriminator
from tqdm.auto import tqdm

# Losses: LSGAN-style MSE for the adversarial terms, L1 for identity/cycle.
adv_criterion = nn.MSELoss()
recon_criterion = nn.L1Loss()

# Training hyperparameters.
n_epochs = 60
dim_A = 3            # channels of domain-A images
dim_B = 3            # channels of domain-B images
display_step = 200   # steps between logging / checkpointing
batch_size = 1
lr = 0.0002
load_shape = 286     # resize target before the random crop
target_shape = 256   # final training resolution
# BUG FIX: was hard-coded 'cuda', which crashes on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


transform = transforms.Compose([
    transforms.Resize(load_shape),
    transforms.RandomCrop(target_shape),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

dataset = ImageDataset("horse2zebra", transform=transform)

# Both generators share a single optimizer, as in the CycleGAN paper.
gen_AB = Generator(dim_A, dim_B).to(device)
gen_BA = Generator(dim_B, dim_A).to(device)
gen_opt = torch.optim.Adam(list(gen_AB.parameters()) + list(gen_BA.parameters()), lr=lr, betas=(0.5, 0.999))
disc_A = Discriminator(dim_A).to(device)
disc_A_opt = torch.optim.Adam(disc_A.parameters(), lr=lr, betas=(0.5, 0.999))
disc_B = Discriminator(dim_B).to(device)
disc_B_opt = torch.optim.Adam(disc_B.parameters(), lr=lr, betas=(0.5, 0.999))


gen_AB = gen_AB.apply(weights_init)
gen_BA = gen_BA.apply(weights_init)
disc_A = disc_A.apply(weights_init)
disc_B = disc_B.apply(weights_init)
43
+
44
+
45
+
46
def train():
    """Run CycleGAN training: per batch, update disc_A and disc_B on real
    vs. generated images, then update both generators jointly; log losses
    and save a full checkpoint every `display_step` steps."""
    mean_gen_loss = 0
    mean_disc_loss = 0
    dataloader = DataLoader(dataset, batch_size, shuffle=True)
    cur_step = 0

    for epoch in range(n_epochs):
        for real_A, real_B in tqdm(dataloader):
            real_A = nn.functional.interpolate(real_A, size=target_shape)
            real_B = nn.functional.interpolate(real_B, size=target_shape)
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # --- Discriminator A: real A vs. A-domain fakes. ---
            disc_A_opt.zero_grad()
            with torch.no_grad():
                # BUG FIX: gen_BA maps B -> A, so fake A images must come
                # from real_B (was gen_BA(real_A)).
                fake_A = gen_BA(real_B)
            disc_A_loss = get_disc_loss(real_A, fake_A, disc_A, adv_criterion)
            disc_A_loss.backward(retain_graph=True)
            disc_A_opt.step()

            # --- Discriminator B: real B vs. B-domain fakes. ---
            disc_B_opt.zero_grad()
            with torch.no_grad():
                # BUG FIX: gen_AB maps A -> B, so fake B images must come
                # from real_A (was gen_AB(real_B)).
                fake_B = gen_AB(real_A)
            disc_B_loss = get_disc_loss(real_B, fake_B, disc_B, adv_criterion)
            disc_B_loss.backward(retain_graph=True)
            disc_B_opt.step()

            # --- Generators (both directions, one optimizer). ---
            gen_opt.zero_grad()
            # BUG FIX: `adv_criterion=` had no value — a syntax error that
            # made the whole module unimportable.
            gen_loss, fake_A, fake_B = get_gen_loss(
                real_A, real_B, gen_AB, gen_BA, disc_B, disc_A,
                adv_criterion=adv_criterion,
                identity_criterion=recon_criterion,
                cycle_criterion=recon_criterion,
            )
            gen_loss.backward()
            gen_opt.step()

            mean_gen_loss += gen_loss.item() / display_step
            # NOTE(review): only disc_A's loss is tracked; averaging in
            # disc_B_loss as well may be intended — confirm.
            mean_disc_loss += disc_A_loss.item() / display_step

            if cur_step % display_step == 0 and cur_step > 0:
                print(f"Epoch: {epoch} | Step: {cur_step} | Gen_loss: {mean_gen_loss} | Disc_loss: {mean_disc_loss} |")
                show_tensor_images(torch.cat([real_A, real_B]), size=(dim_A, target_shape, target_shape))
                show_tensor_images(torch.cat([fake_A, fake_B]), size=(dim_B, target_shape, target_shape))
                mean_gen_loss = 0
                mean_disc_loss = 0
                torch.save({
                    'gen_AB': gen_AB,
                    'gen_BA': gen_BA,
                    'gen_opt': gen_opt,
                    'disc_A': disc_A,
                    'disc_A_opt': disc_A_opt,
                    'disc_B': disc_B,
                    'disc_B_opt': disc_B_opt
                }, f"checkpoints/cycleGAN_{cur_step}.pth")

            cur_step += 1

if __name__ == "__main__":
    train()
utils.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import transforms
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision.utils import make_grid
5
+ from torch.utils.data import DataLoader
6
+ import matplotlib.pyplot as plt
7
+ import glob
8
+ import os
9
+ from torch.utils.data import Dataset
10
+ from PIL import Image
11
+
12
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Visualize a batch of images: rescale from [-1, 1] to [0, 1], arrange up
    to num_images of them in a grid of 5 per row, and display the plot.
    '''
    rescaled = (image_tensor + 1) / 2
    flat_images = rescaled.detach().cpu().view(-1, *size)
    grid = make_grid(flat_images[:num_images], nrow=5)
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()
23
+
24
+
25
class ImageDataset(Dataset):
    """Unpaired two-domain dataset over `<root>/<mode>A` and `<root>/<mode>B`.

    files_A is kept as the smaller side (the lists are swapped if needed),
    and B images are matched to A images through a random permutation that
    is refreshed once per pass. Items are returned rescaled to [-1, 1],
    with single-channel images expanded to 3 channels.
    """

    def __init__(self, root, transform=None, mode='train'):
        self.transform = transform
        self.files_A = sorted(glob.glob(os.path.join(root, '%sA' % mode) + '/*.*'))
        self.files_B = sorted(glob.glob(os.path.join(root, '%sB' % mode) + '/*.*'))
        if len(self.files_A) > len(self.files_B):
            self.files_A, self.files_B = self.files_B, self.files_A
        self.new_perm()
        assert len(self.files_A) > 0, "Make sure you downloaded the horse2zebra images!"

    def new_perm(self):
        # Draw a fresh random pairing of B files onto A indices.
        self.randperm = torch.randperm(len(self.files_B))[:len(self.files_A)]

    def __getitem__(self, index):
        path_A = self.files_A[index % len(self.files_A)]
        path_B = self.files_B[self.randperm[index]]
        item_A = self.transform(Image.open(path_A))
        item_B = self.transform(Image.open(path_B))
        if item_A.shape[0] != 3:
            item_A = item_A.repeat(3, 1, 1)
        if item_B.shape[0] != 3:
            item_B = item_B.repeat(3, 1, 1)
        if index == len(self) - 1:
            self.new_perm()
        # Old versions of PyTorch didn't support normalization for
        # different-channeled images, so rescale to [-1, 1] by hand.
        return (item_A - 0.5) * 2, (item_B - 0.5) * 2

    def __len__(self):
        return min(len(self.files_A), len(self.files_B))
52
+
53
+
54
def weights_init(m):
    """Initialize GAN layer weights in place (use via ``model.apply(weights_init)``).

    Conv / transposed-conv weights are drawn from N(0, 0.02) — the standard
    DCGAN/CycleGAN initialization; BatchNorm weights likewise, with bias 0.
    """
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        # BUG FIX: was normal_(m.weight, 1.0, 0.2) — a mean-1, std-0.2 draw,
        # an order of magnitude off the intended N(0, 0.02) init.
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
60
+
61
+
62
def get_disc_loss(real_X, fake_X, disc_X, adv_criterion):
    """Discriminator loss: push predictions on real images toward 1 and on
    (detached) fake images toward 0, averaging the two terms.

    real_X / fake_X: batches of real and generated images for domain X.
    disc_X: the discriminator; adv_criterion: e.g. MSELoss (LSGAN).
    """
    real_pred = disc_X(real_X.detach())
    disc_real_loss = adv_criterion(real_pred, torch.ones_like(real_pred))
    # BUG FIX: `fake_X.deatch()` raised AttributeError (typo), and the fake
    # *prediction* was detached, which cut the gradient path to the
    # discriminator for the fake term. Detach the image, not the prediction.
    fake_pred = disc_X(fake_X.detach())
    disc_fake_loss = adv_criterion(fake_pred, torch.zeros_like(fake_pred))
    disc_loss = (disc_real_loss + disc_fake_loss) / 2
    return disc_loss
69
+
70
+
71
def get_gen_adversarial_loss(real_X, disc_Y, gen_XY, adv_criterion):
    """Adversarial generator loss: gen_XY translates real_X into domain Y
    and is rewarded when disc_Y scores the result as real (target = 1).
    Returns the loss together with the generated fake_Y batch.
    """
    fake_Y = gen_XY(real_X.detach())
    fake_Y_pred = disc_Y(fake_Y)
    adversarial_loss = adv_criterion(fake_Y_pred, torch.ones_like(fake_Y_pred))
    return adversarial_loss, fake_Y
76
+
77
def get_identity_loss(real_X, gen_YX, identity_criterion):
    """Identity loss: feeding gen_YX an image already in its target domain X
    should return it unchanged. Returns the loss and the mapped batch.
    """
    identity_X = gen_YX(real_X)
    return identity_criterion(identity_X, real_X), identity_X
81
+
82
+
83
+
84
def get_cycle_consistency_loss(real_X, fake_Y, gen_YX, cycle_criterion):
    """Cycle-consistency loss: translating fake_Y back with gen_YX should
    reconstruct real_X. Returns the loss and the reconstructed batch.
    """
    cycle_X = gen_YX(fake_Y)
    return cycle_criterion(cycle_X, real_X), cycle_X
88
+
89
+
90
+
91
def get_gen_loss(real_A, real_B, gen_AB, gen_BA, disc_B, disc_A, adv_criterion, cycle_criterion, identity_criterion, lambda_identity=0.2, lambda_cycle=10):
    """Total CycleGAN generator objective for both translation directions.

    Combines adversarial, identity, and cycle-consistency terms as
    ``lambda_identity * identity + lambda_cycle * cycle + adversarial``.
    Returns the scalar loss together with fake_A and fake_B.
    """
    # Adversarial terms, one per direction (B->A judged by disc_A, A->B by disc_B).
    adv_loss_BA, fake_A = get_gen_adversarial_loss(real_B, disc_A, gen_BA, adv_criterion)
    adv_loss_AB, fake_B = get_gen_adversarial_loss(real_A, disc_B, gen_AB, adv_criterion)
    gen_adversarial_loss = adv_loss_BA + adv_loss_AB

    # Identity terms: each generator should leave its own target domain alone.
    identity_loss_A, _identity_A = get_identity_loss(real_A, gen_BA, identity_criterion)
    identity_loss_B, _identity_B = get_identity_loss(real_B, gen_AB, identity_criterion)
    gen_identity_loss = identity_loss_A + identity_loss_B

    # Cycle terms: A -> B -> A and B -> A -> B round trips.
    cycle_loss_BA, _cycle_A = get_cycle_consistency_loss(real_A, fake_B, gen_BA, cycle_criterion)
    cycle_loss_AB, _cycle_B = get_cycle_consistency_loss(real_B, fake_A, gen_AB, cycle_criterion)
    gen_cycle_loss = cycle_loss_BA + cycle_loss_AB

    # Weighted total (same summation order as before).
    gen_loss = lambda_identity * gen_identity_loss + lambda_cycle * gen_cycle_loss + gen_adversarial_loss

    return gen_loss, fake_A, fake_B