Raid41 commited on
Commit
bfadb5d
·
1 Parent(s): 5be6be6

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +5 -4
train.py CHANGED
@@ -174,14 +174,15 @@ def train(colorizer, discriminator, content, dataloader, epochs, colorizer_optim
174
 
175
 
176
  if disc_step:
177
- step_loss = discriminator_step(inputs, colorizer, discriminator, discriminator_optimizer, device)
178
  sum_disc_loss += step_loss
 
179
  else:
180
- step_loss = generator_step(inputs, colorizer, discriminator, content, generator_loss, colorizer_optimizer, device)
181
  sum_gen_loss += step_loss
 
182
 
183
- disc_step = disc_step ^ True
184
-
185
 
186
  print(datetime.datetime.now().time())
187
  print('Epoch : %d Discr loss: %.4f Gen loss : %.4f \n'%(epoch, sum_disc_loss / (n // 2 + 1), sum_gen_loss / (n // 2 + 1)))
 
174
 
175
 
176
  if disc_step:
177
+ step_loss = discriminator_step(inputs, colorizer, discriminator, optimizer, device)
178
  sum_disc_loss += step_loss
179
+ disc_step = False
180
  else:
181
+ step_loss = generator_step(inputs, colorizer, discriminator, content, loss_function, optimizer, device)
182
  sum_gen_loss += step_loss
183
+ disc_step = True
184
 
185
+ disc_step = disc_step ^ True # Toggle the disc_step variable at the end of the inner for loop
 
186
 
187
  print(datetime.datetime.now().time())
188
  print('Epoch : %d Discr loss: %.4f Gen loss : %.4f \n'%(epoch, sum_disc_loss / (n // 2 + 1), sum_gen_loss / (n // 2 + 1)))