Raid41 committed
Commit 7b8a199 · 1 Parent(s): 572ec06

Update train.py

Files changed (1):
train.py +209 -11
train.py CHANGED
@@ -4,6 +4,7 @@ import torch.optim as optim
import numpy as np
import albumentations as albu
import argparse
+ import datetime

from utils.utils import open_json, weights_init, weights_init_spectr, generate_mask
from model.models import Colorizer, Generator, Content, Discriminator
@@ -26,8 +27,8 @@ def parse_args():
def get_transforms():
    return albu.Compose([albu.RandomCrop(512, 512, always_apply = True), albu.HorizontalFlip(p = 0.5)], p = 1.)

- def get_dataloaders(data_path, transforms, batch_size, fine_tuning):
-     train_dataset = TrainDataset(data_path, transforms)
+ def get_dataloaders(data_path, transforms, batch_size, fine_tuning, mult_number):
+     train_dataset = TrainDataset(data_path, transforms, mult_number)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)

    if fine_tuning:
@@ -58,12 +59,17 @@ def set_weights(colorizer, discriminator):

    discriminator.apply(weights_init_spectr)

- def get_losses():
-     L1_loss = nn.L1Loss()
-     mse_loss = nn.MSELoss()
-     bce_loss = nn.BCEWithLogitsLoss()
+ def generator_loss(disc_output, true_labels, main_output, guide_output, real_image, content_gen, content_true, dist_loss = nn.L1Loss(), content_dist_loss = nn.MSELoss(), class_loss = nn.BCEWithLogitsLoss()):
+     sim_loss_full = dist_loss(main_output, real_image)
+     sim_loss_guide = dist_loss(guide_output, real_image)

-     return L1_loss, bce_loss, mse_loss
+     adv_loss = class_loss(disc_output, true_labels)
+
+     content_loss = content_dist_loss(content_gen, content_true)
+
+     sum_loss = 10 * (sim_loss_full + 0.9 * sim_loss_guide) + adv_loss + content_loss
+
+     return sum_loss

def get_optimizers(colorizer, discriminator, generator_lr, discriminator_lr):
    optimizerG = optim.Adam(colorizer.generator.parameters(), lr = generator_lr, betas=(0.5, 0.9))
@@ -71,8 +77,196 @@ def get_optimizers(colorizer, discriminator, generator_lr, discriminator_lr):

    return optimizerG, optimizerD

-
-
+ def generator_step(inputs, colorizer, discriminator, content, loss_function, optimizer, device, white_penalty = True):
+     for p in discriminator.parameters():
+         p.requires_grad = False
+     for p in colorizer.generator.parameters():
+         p.requires_grad = True
+
+     colorizer.generator.zero_grad()
+
+     bw, color, hint, dfm = inputs
+     bw, color, hint, dfm = bw.to(device), color.to(device), hint.to(device), dfm.to(device)
+
+     fake, guide = colorizer(torch.cat([bw, dfm, hint], 1))
+
+     logits_fake = discriminator(fake)
+     y_real = torch.ones((bw.size(0), 1), device = device)
+
+     content_fake = content(fake)
+     with torch.no_grad():
+         content_true = content(color)
+
+     generator_loss = loss_function(logits_fake, y_real, fake, guide, color, content_fake, content_true)
+
+     if white_penalty:
+         mask = (~((color > 0.85).float().sum(dim = 1) == 3).unsqueeze(1).repeat((1, 3, 1, 1))).float()
+         white_zones = mask * (fake + 1) / 2
+         white_penalty = (torch.pow(white_zones.sum(dim = 1), 2).sum(dim = (1, 2)) / (mask.sum(dim = (1, 2, 3)) + 1)).mean()
+
+         generator_loss += white_penalty
+
+     generator_loss.backward()
+
+     optimizer.step()
+
+     return generator_loss.item()
+
+ def discriminator_step(inputs, colorizer, discriminator, optimizer, device, loss_function = nn.BCEWithLogitsLoss()):
+     for p in discriminator.parameters():
+         p.requires_grad = True
+     for p in colorizer.generator.parameters():
+         p.requires_grad = False
+
+     discriminator.zero_grad()
+
+     bw, color, hint, dfm = inputs
+     bw, color, hint, dfm = bw.to(device), color.to(device), hint.to(device), dfm.to(device)
+
+     y_real = torch.full((bw.size(0), 1), 0.9, device = device)
+     y_fake = torch.zeros((bw.size(0), 1), device = device)
+
+     with torch.no_grad():
+         fake_color, _ = colorizer(torch.cat([bw, dfm, hint], 1))
+     fake_color = fake_color.detach()
+
+     logits_fake = discriminator(fake_color)
+     logits_real = discriminator(color)
+
+     fake_loss = loss_function(logits_fake, y_fake)
+     real_loss = loss_function(logits_real, y_real)
+
+     discriminator_loss = real_loss + fake_loss
+
+     discriminator_loss.backward()
+     optimizer.step()
+
+     return discriminator_loss.item()
+
+ def decrease_lr(optimizer, rate):
+     for group in optimizer.param_groups:
+         group['lr'] /= rate
+
+ def set_lr(optimizer, value):
+     for group in optimizer.param_groups:
+         group['lr'] = value
+
+ def train(colorizer, discriminator, content, dataloader, epochs, colorizer_optimizer, discriminator_optimizer, lr_decay_epoch = -1, device = 'cpu'):
+     colorizer.generator.train()
+     discriminator.train()
+
+     disc_step = True
+
+     for epoch in range(epochs):
+         if epoch == lr_decay_epoch:
+             decrease_lr(colorizer_optimizer, 10)
+             decrease_lr(discriminator_optimizer, 10)
+
+         sum_disc_loss = 0
+         sum_gen_loss = 0
+
+         for n, inputs in enumerate(dataloader):
+             if n % 5 == 0:
+                 print(datetime.datetime.now().time())
+                 print('Step : %d Discr loss: %.4f Gen loss : %.4f \n' % (n, sum_disc_loss / (n // 2 + 1), sum_gen_loss / (n // 2 + 1)))
+
+             if disc_step:
+                 step_loss = discriminator_step(inputs, colorizer, discriminator, discriminator_optimizer, device)
+                 sum_disc_loss += step_loss
+             else:
+                 step_loss = generator_step(inputs, colorizer, discriminator, content, generator_loss, colorizer_optimizer, device)
+                 sum_gen_loss += step_loss
+
+             disc_step = not disc_step
+
+         print(datetime.datetime.now().time())
+         print('Epoch : %d Discr loss: %.4f Gen loss : %.4f \n' % (epoch, sum_disc_loss / (n // 2 + 1), sum_gen_loss / (n // 2 + 1)))
+
+ def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_optimizer, device, loss_function = nn.BCEWithLogitsLoss()):
+     for p in discriminator.parameters():
+         p.requires_grad = True
+     for p in colorizer.generator.parameters():
+         p.requires_grad = False
+
+     for cur_disc_step in range(5):
+         discriminator.zero_grad()
+
+         bw, dfm, color_for_real = next(data_iter)
+         bw, dfm, color_for_real = bw.to(device), dfm.to(device), color_for_real.to(device)
+
+         y_real = torch.full((bw.size(0), 1), 0.9, device = device)
+         y_fake = torch.zeros((bw.size(0), 1), device = device)
+
+         empty_hint = torch.zeros(bw.shape[0], 4, bw.shape[2], bw.shape[3]).float().to(device)
+
+         with torch.no_grad():
+             fake_color_manga, _ = colorizer(torch.cat([bw, dfm, empty_hint], 1))
+         fake_color_manga = fake_color_manga.detach()
+
+         logits_fake = discriminator(fake_color_manga)
+         logits_real = discriminator(color_for_real)
+
+         fake_loss = loss_function(logits_fake, y_fake)
+         real_loss = loss_function(logits_real, y_real)
+         discriminator_loss = real_loss + fake_loss
+
+         discriminator_loss.backward()
+         disc_optimizer.step()
+
+     for p in discriminator.parameters():
+         p.requires_grad = False
+     for p in colorizer.generator.parameters():
+         p.requires_grad = True
+
+     colorizer.generator.zero_grad()
+
+     bw, dfm, _ = next(data_iter)
+     bw, dfm = bw.to(device), dfm.to(device)
+
+     y_real = torch.ones((bw.size(0), 1), device = device)
+
+     empty_hint = torch.zeros(bw.shape[0], 4, bw.shape[2], bw.shape[3]).float().to(device)
+
+     fake_manga, _ = colorizer(torch.cat([bw, dfm, empty_hint], 1))
+
+     logits_fake = discriminator(fake_manga)
+     adv_loss = loss_function(logits_fake, y_real)
+
+     generator_loss = adv_loss
+
+     generator_loss.backward()
+     gen_optimizer.step()
+
+ def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer, data_iter, device = 'cpu'):
+     colorizer.generator.train()
+     discriminator.train()
+
+     disc_step = True
+
+     for n, inputs in enumerate(dataloader):
+
+         if n == iterations:
+             return
+
+         if disc_step:
+             discriminator_step(inputs, colorizer, discriminator, discriminator_optimizer, device)
+         else:
+             generator_step(inputs, colorizer, discriminator, content, generator_loss, colorizer_optimizer, device)
+
+         disc_step = not disc_step
+
+         if n % 10 == 5:
+             fine_tuning_step(data_iter, colorizer, discriminator, colorizer_optimizer, discriminator_optimizer, device)
+
if __name__ == '__main__':
    args = parse_args()
    config = open_json('configs/train_config.json')
@@ -84,13 +278,17 @@ if __name__ == '__main__':

    augmentations = get_transforms()

-     train_dataloader, ft_dataloader = get_dataloaders(args.path, augmentations, config['batch_size'], args.fine_tuning)
+     train_dataloader, ft_dataloader = get_dataloaders(args.path, augmentations, config['batch_size'], args.fine_tuning, config['number_of_mults'])

    colorizer, discriminator, content = get_models(device)
    set_weights(colorizer, discriminator)

-     l1_loss, bce_loss, mse_loss = get_losses()
-
    gen_optimizer, disc_optimizer = get_optimizers(colorizer, discriminator, config['generator_lr'], config['discriminator_lr'])

-
+     train(colorizer, discriminator, content, train_dataloader, config['epochs'], gen_optimizer, disc_optimizer, config['lr_decrease_epoch'], device)
+
+     if args.fine_tuning:
+         set_lr(gen_optimizer, config["finetuning_generator_lr"])
+         fine_tuning(colorizer, discriminator, content, train_dataloader, config['finetuning_iterations'], gen_optimizer, disc_optimizer, iter(ft_dataloader), device)
+
+     torch.save(colorizer.generator.state_dict(), str(datetime.datetime.now().time()))
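
For reference, the __main__ block now reads at least these keys from configs/train_config.json: batch_size, number_of_mults, generator_lr, discriminator_lr, epochs, lr_decrease_epoch, finetuning_generator_lr, and finetuning_iterations. A minimal sketch of such a config, where the keys come from this diff but the values are placeholder assumptions rather than the repository's tuned defaults:

    # Illustrative sketch only: keys taken from the config reads in this commit;
    # values are placeholder assumptions, not the project's actual settings.
    import json

    train_config = {
        "batch_size": 2,
        "number_of_mults": 1,
        "generator_lr": 2e-4,
        "discriminator_lr": 2e-4,
        "epochs": 40,
        "lr_decrease_epoch": 30,
        "finetuning_generator_lr": 1e-5,
        "finetuning_iterations": 1000,
    }

    with open("configs/train_config.json", "w") as f:
        json.dump(train_config, f, indent=4)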
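
The white_penalty term added in generator_step penalizes the generator for painting bright pixels in regions that are not white in the ground-truth color image. A standalone toy check (illustrative only, not part of the commit) that reproduces the computation on a single 4x4 image:

    import torch

    # Ground truth: white on the right half, black on the left half.
    color = torch.ones(1, 3, 4, 4)
    color[:, :, :, :2] = 0.0
    # Generator output in [-1, 1]: all white, so the left half should be penalized.
    fake = torch.ones(1, 3, 4, 4)

    # Same expressions as in generator_step: mask = 1 where the ground truth is NOT white.
    mask = (~((color > 0.85).float().sum(dim=1) == 3).unsqueeze(1).repeat((1, 3, 1, 1))).float()
    white_zones = mask * (fake + 1) / 2
    penalty = (torch.pow(white_zones.sum(dim=1), 2).sum(dim=(1, 2)) / (mask.sum(dim=(1, 2, 3)) + 1)).mean()

    print(penalty.item())  # 2.88 for this all-white fake; 0.0 if fake is -1 (black) in the non-white zone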