Update train.py
Browse files
train.py
CHANGED
|
@@ -174,14 +174,15 @@ def train(colorizer, discriminator, content, dataloader, epochs, colorizer_optim
|
|
| 174 |
|
| 175 |
|
| 176 |
if disc_step:
|
| 177 |
-
step_loss = discriminator_step(inputs, colorizer, discriminator,
|
| 178 |
sum_disc_loss += step_loss
|
|
|
|
| 179 |
else:
|
| 180 |
-
step_loss = generator_step(inputs, colorizer, discriminator, content,
|
| 181 |
sum_gen_loss += step_loss
|
|
|
|
| 182 |
|
| 183 |
-
disc_step = disc_step ^ True
|
| 184 |
-
|
| 185 |
|
| 186 |
print(datetime.datetime.now().time())
|
| 187 |
print('Epoch : %d Discr loss: %.4f Gen loss : %.4f \n'%(epoch, sum_disc_loss / (n // 2 + 1), sum_gen_loss / (n // 2 + 1)))
|
|
|
|
| 174 |
|
| 175 |
|
| 176 |
if disc_step:
|
| 177 |
+
step_loss = discriminator_step(inputs, colorizer, discriminator, optimizer, device)
|
| 178 |
sum_disc_loss += step_loss
|
| 179 |
+
disc_step = False
|
| 180 |
else:
|
| 181 |
+
step_loss = generator_step(inputs, colorizer, discriminator, content, loss_function, optimizer, device)
|
| 182 |
sum_gen_loss += step_loss
|
| 183 |
+
disc_step = True
|
| 184 |
|
| 185 |
+
disc_step = disc_step ^ True # Toggle the disc_step variable at the end of the inner for loop
|
|
|
|
| 186 |
|
| 187 |
print(datetime.datetime.now().time())
|
| 188 |
print('Epoch : %d Discr loss: %.4f Gen loss : %.4f \n'%(epoch, sum_disc_loss / (n // 2 + 1), sum_gen_loss / (n // 2 + 1)))
|