Update train.py
Browse files
train.py
CHANGED
|
@@ -246,7 +246,7 @@ def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_op
|
|
| 246 |
|
| 247 |
|
| 248 |
|
| 249 |
-
def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer,
|
| 250 |
colorizer.generator.train()
|
| 251 |
discriminator.train()
|
| 252 |
|
|
@@ -265,6 +265,7 @@ def fine_tuning(colorizer, discriminator, content, dataloader, iterations, color
|
|
| 265 |
disc_step = disc_step ^ True
|
| 266 |
|
| 267 |
if n % 10 == 5:
|
|
|
|
| 268 |
fine_tuning_step(data_iter, colorizer, discriminator, colorizer_optimizer, discriminator_optimizer, device)
|
| 269 |
|
| 270 |
if __name__ == '__main__':
|
|
|
|
| 246 |
|
| 247 |
|
| 248 |
|
| 249 |
+
def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer, ft_dataloader, device = 'cpu'):
|
| 250 |
colorizer.generator.train()
|
| 251 |
discriminator.train()
|
| 252 |
|
|
|
|
| 265 |
disc_step = disc_step ^ True
|
| 266 |
|
| 267 |
if n % 10 == 5:
|
| 268 |
+
data_iter = iter(ft_dataloader)
|
| 269 |
fine_tuning_step(data_iter, colorizer, discriminator, colorizer_optimizer, discriminator_optimizer, device)
|
| 270 |
|
| 271 |
if __name__ == '__main__':
|