Raid41 committed on
Commit
8a8d138
·
1 Parent(s): 2051ea3

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +11 -209
train.py CHANGED
@@ -4,7 +4,6 @@ import torch.optim as optim
4
  import numpy as np
5
  import albumentations as albu
6
  import argparse
7
- import datetime
8
 
9
  from utils.utils import open_json, weights_init, weights_init_spectr, generate_mask
10
  from model.models import Colorizer, Generator, Content, Discriminator
@@ -27,8 +26,8 @@ def parse_args():
27
  def get_transforms():
28
  return albu.Compose([albu.RandomCrop(512, 512, always_apply = True), albu.HorizontalFlip(p = 0.5)], p = 1.)
29
 
30
- def get_dataloaders(data_path, transforms, batch_size, fine_tuning, mult_number):
31
- train_dataset = TrainDataset(data_path, transforms, mult_number)
32
  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
33
 
34
  if fine_tuning:
@@ -59,17 +58,12 @@ def set_weights(colorizer, discriminator):
59
 
60
  discriminator.apply(weights_init_spectr)
61
 
62
- def generator_loss(disc_output, true_labels, main_output, guide_output, real_image, content_gen, content_true, dist_loss = nn.L1Loss(), content_dist_loss = nn.MSELoss(), class_loss = nn.BCEWithLogitsLoss()):
63
- sim_loss_full = dist_loss(main_output, real_image)
64
- sim_loss_guide = dist_loss(guide_output, real_image)
 
65
 
66
- adv_loss = class_loss(disc_output, true_labels)
67
-
68
- content_loss = content_dist_loss(content_gen, content_true)
69
-
70
- sum_loss = 10 * (sim_loss_full + 0.9 * sim_loss_guide) + adv_loss + content_loss
71
-
72
- return sum_loss
73
 
74
  def get_optimizers(colorizer, discriminator, generator_lr, discriminator_lr):
75
  optimizerG = optim.Adam(colorizer.generator.parameters(), lr = generator_lr, betas=(0.5, 0.9))
@@ -77,196 +71,8 @@ def get_optimizers(colorizer, discriminator, generator_lr, discriminator_lr):
77
 
78
  return optimizerG, optimizerD
79
 
80
- def generator_step(inputs, colorizer, discriminator, content, loss_function, optimizer, device, white_penalty = True):
81
- for p in discriminator.parameters():
82
- p.requires_grad = False
83
- for p in colorizer.generator.parameters():
84
- p.requires_grad = True
85
-
86
- colorizer.generator.zero_grad()
87
-
88
- bw, color, hint, dfm = inputs
89
- bw, color, hint, dfm = bw.to(device), color.to(device), hint.to(device), dfm.to(device)
90
 
91
- fake, guide = colorizer(torch.cat([bw, dfm, hint], 1))
92
 
93
- logits_fake = discriminator(fake)
94
- y_real = torch.ones((bw.size(0), 1), device = device)
95
-
96
- content_fake = content(fake)
97
- with torch.no_grad():
98
- content_true = content(color)
99
-
100
- generator_loss = loss_function(logits_fake, y_real, fake, guide, color, content_fake, content_true)
101
-
102
- if white_penalty:
103
- mask = (~((color > 0.85).float().sum(dim = 1) == 3).unsqueeze(1).repeat((1, 3, 1, 1 ))).float()
104
- white_zones = mask * (fake + 1) / 2
105
- white_penalty = (torch.pow(white_zones.sum(dim = 1), 2).sum(dim = (1, 2)) / (mask.sum(dim = (1, 2, 3)) + 1)).mean()
106
-
107
- generator_loss += white_penalty
108
-
109
- generator_loss.backward()
110
-
111
- optimizer.step()
112
-
113
- return generator_loss.item()
114
-
115
- def discriminator_step(inputs, colorizer, discriminator, optimizer, device, loss_function = nn.BCEWithLogitsLoss()):
116
-
117
- for p in discriminator.parameters():
118
- p.requires_grad = True
119
- for p in colorizer.generator.parameters():
120
- p.requires_grad = False
121
-
122
- discriminator.zero_grad()
123
-
124
- bw, color, hint, dfm = inputs
125
- bw, color, hint, dfm = bw.to(device), color.to(device), hint.to(device), dfm.to(device)
126
-
127
- y_real = torch.full((bw.size(0), 1), 0.9, device = device)
128
-
129
- y_fake = torch.zeros((bw.size(0), 1), device = device)
130
-
131
- with torch.no_grad():
132
- fake_color, _ = colorizer(torch.cat([bw, dfm, hint], 1))
133
- fake_color.detach()
134
-
135
- logits_fake = discriminator(fake_color)
136
- logits_real = discriminator(color)
137
-
138
- fake_loss = loss_function(logits_fake, y_fake)
139
- real_loss = loss_function(logits_real, y_real)
140
-
141
- discriminator_loss = real_loss + fake_loss
142
-
143
- discriminator_loss.backward()
144
- optimizer.step()
145
-
146
- return discriminator_loss.item()
147
-
148
- def decrease_lr(optimizer, rate):
149
- for group in optimizer.param_groups:
150
- group['lr'] /= rate
151
-
152
- def set_lr(optimizer, value):
153
- for group in optimizer.param_groups:
154
- group['lr'] = value
155
-
156
- def train(colorizer, discriminator, content, dataloader, epochs, colorizer_optimizer, discriminator_optimizer, lr_decay_epoch = -1, device = 'cpu'):
157
- colorizer.generator.train()
158
- discriminator.train()
159
-
160
- disc_step = True
161
-
162
- for epoch in range(epochs):
163
- if (epoch == lr_decay_epoch):
164
- decrease_lr(colorizer_optimizer, 10)
165
- decrease_lr(discriminator_optimizer, 10)
166
-
167
- sum_disc_loss = 0
168
- sum_gen_loss = 0
169
-
170
- for n, inputs in enumerate(dataloader):
171
- if n % 5 == 0:
172
- print(datetime.datetime.now().time())
173
- print('Step : %d Discr loss: %.4f Gen loss : %.4f \n'%(n, sum_disc_loss / (n // 2 + 1), sum_gen_loss / (n // 2 + 1)))
174
-
175
-
176
- if disc_step:
177
- step_loss = discriminator_step(inputs, colorizer, discriminator, discriminator_optimizer, device)
178
- sum_disc_loss += step_loss
179
- else:
180
- step_loss = generator_step(inputs, colorizer, discriminator, content, generator_loss, colorizer_optimizer, device)
181
- sum_gen_loss += step_loss
182
-
183
- disc_step = disc_step ^ True
184
-
185
-
186
- print(datetime.datetime.now().time())
187
- print('Epoch : %d Discr loss: %.4f Gen loss : %.4f \n'%(epoch, sum_disc_loss / (n // 2 + 1), sum_gen_loss / (n // 2 + 1)))
188
-
189
-
190
- def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_optimizer, device, loss_function = nn.BCEWithLogitsLoss()):
191
-
192
- for p in discriminator.parameters():
193
- p.requires_grad = True
194
- for p in colorizer.generator.parameters():
195
- p.requires_grad = False
196
-
197
- for cur_disc_step in range(5):
198
- discriminator.zero_grad()
199
-
200
- bw, dfm, color_for_real = data_iter.next()
201
- bw, dfm, color_for_real = bw.to(device), dfm.to(device), color_for_real.to(device)
202
-
203
- y_real = torch.full((bw.size(0), 1), 0.9, device = device)
204
- y_fake = torch.zeros((bw.size(0), 1), device = device)
205
-
206
- empty_hint = torch.zeros(bw.shape[0], 4, bw.shape[2] , bw.shape[3] ).float().to(device)
207
-
208
- with torch.no_grad():
209
- fake_color_manga, _ = colorizer(torch.cat([bw, dfm, empty_hint ], 1))
210
- fake_color_manga.detach()
211
-
212
- logits_fake = discriminator(fake_color_manga)
213
- logits_real = discriminator(color_for_real)
214
-
215
- fake_loss = loss_function(logits_fake, y_fake)
216
- real_loss = loss_function(logits_real, y_real)
217
- discriminator_loss = real_loss + fake_loss
218
-
219
- discriminator_loss.backward()
220
- disc_optimizer.step()
221
-
222
-
223
- for p in discriminator.parameters():
224
- p.requires_grad = False
225
- for p in colorizer.generator.parameters():
226
- p.requires_grad = True
227
-
228
- colorizer.generator.zero_grad()
229
-
230
- bw, dfm, _ = data_iter.next()
231
- bw, dfm = bw.to(device), dfm.to(device)
232
-
233
- y_real = torch.ones((bw.size(0), 1), device = device)
234
-
235
- empty_hint = torch.zeros(bw.shape[0], 4, bw.shape[2] , bw.shape[3]).float().to(device)
236
-
237
- fake_manga, _ = colorizer(torch.cat([bw, dfm, empty_hint], 1))
238
-
239
- logits_fake = discriminator(fake_manga)
240
- adv_loss = loss_function(logits_fake, y_real)
241
-
242
- generator_loss = adv_loss
243
-
244
- generator_loss.backward()
245
- gen_optimizer.step()
246
-
247
-
248
-
249
- def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer, data_iter, device = 'cpu'):
250
- colorizer.generator.train()
251
- discriminator.train()
252
-
253
- disc_step = True
254
-
255
- for n, inputs in enumerate(dataloader):
256
-
257
- if n == iterations:
258
- return
259
-
260
- if disc_step:
261
- discriminator_step(inputs, colorizer, discriminator, discriminator_optimizer, device)
262
- else:
263
- generator_step(inputs, colorizer, discriminator, content, generator_loss, colorizer_optimizer, device)
264
-
265
- disc_step = disc_step ^ True
266
-
267
- if n % 10 == 5:
268
- fine_tuning_step(data_iter, colorizer, discriminator, colorizer_optimizer, discriminator_optimizer, device)
269
-
270
  if __name__ == '__main__':
271
  args = parse_args()
272
  config = open_json('configs/train_config.json')
@@ -278,17 +84,13 @@ if __name__ == '__main__':
278
 
279
  augmentations = get_transforms()
280
 
281
- train_dataloader, ft_dataloader = get_dataloaders(args.path, augmentations, config['batch_size'], args.fine_tuning, config['number_of_mults'])
282
 
283
  colorizer, discriminator, content = get_models(device)
284
  set_weights(colorizer, discriminator)
285
 
286
- gen_optimizer, disc_optimizer = get_optimizers(colorizer, discriminator, config['generator_lr'], config['discriminator_lr'])
287
-
288
- train(colorizer, discriminator, content, train_dataloader, config['epochs'], gen_optimizer, disc_optimizer, config['lr_decrease_epoch'], device)
289
 
290
- if args.fine_tuning:
291
- set_lr(gen_optimizer, config["finetuning_generator_lr"])
292
- fine_tuning(colorizer, discriminator, content, train_dataloader, config['finetuning_iterations'], gen_optimizer, disc_optimizer, iter(ft_dataloader), device)
293
 
294
- torch.save(colorizer.generator.state_dict(), str(datetime.datetime.now().time()))
 
4
  import numpy as np
5
  import albumentations as albu
6
  import argparse
 
7
 
8
  from utils.utils import open_json, weights_init, weights_init_spectr, generate_mask
9
  from model.models import Colorizer, Generator, Content, Discriminator
 
26
  def get_transforms():
27
  return albu.Compose([albu.RandomCrop(512, 512, always_apply = True), albu.HorizontalFlip(p = 0.5)], p = 1.)
28
 
29
+ def get_dataloaders(data_path, transforms, batch_size, fine_tuning):
30
+ train_dataset = TrainDataset(data_path, transforms)
31
  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
32
 
33
  if fine_tuning:
 
58
 
59
  discriminator.apply(weights_init_spectr)
60
 
61
+ def get_losses():
62
+ L1_loss = nn.L1Loss()
63
+ mse_loss = nn.MSELoss()
64
+ bce_loss = nn.BCEWithLogitsLoss()
65
 
66
+ return L1_loss, bce_loss, mse_loss
 
 
 
 
 
 
67
 
68
  def get_optimizers(colorizer, discriminator, generator_lr, discriminator_lr):
69
  optimizerG = optim.Adam(colorizer.generator.parameters(), lr = generator_lr, betas=(0.5, 0.9))
 
71
 
72
  return optimizerG, optimizerD
73
 
 
 
 
 
 
 
 
 
 
 
74
 
 
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  if __name__ == '__main__':
77
  args = parse_args()
78
  config = open_json('configs/train_config.json')
 
84
 
85
  augmentations = get_transforms()
86
 
87
+ train_dataloader, ft_dataloader = get_dataloaders(args.path, augmentations, config['batch_size'], args.fine_tuning)
88
 
89
  colorizer, discriminator, content = get_models(device)
90
  set_weights(colorizer, discriminator)
91
 
92
+ l1_loss, bce_loss, mse_loss = get_losses()
 
 
93
 
94
+ gen_optimizer, disc_optimizer = get_optimizers(colorizer, discriminator, config['generator_lr'], config['discriminator_lr'])
 
 
95
 
96
+