Update train.py
Browse files
train.py
CHANGED
|
@@ -197,7 +197,12 @@ def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_op
|
|
| 197 |
for cur_disc_step in range(5):
|
| 198 |
discriminator.zero_grad()
|
| 199 |
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
bw, dfm, color_for_real = bw.to(device), dfm.to(device), color_for_real.to(device)
|
| 202 |
|
| 203 |
y_real = torch.full((bw.size(0), 1), 0.9, device = device)
|
|
@@ -227,7 +232,12 @@ def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_op
|
|
| 227 |
|
| 228 |
colorizer.generator.zero_grad()
|
| 229 |
|
| 230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
bw, dfm = bw.to(device), dfm.to(device)
|
| 232 |
|
| 233 |
y_real = torch.ones((bw.size(0), 1), device = device)
|
|
@@ -243,8 +253,6 @@ def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_op
|
|
| 243 |
|
| 244 |
generator_loss.backward()
|
| 245 |
gen_optimizer.step()
|
| 246 |
-
|
| 247 |
-
|
| 248 |
|
| 249 |
def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer, ft_dataloader, device = 'cpu'):
|
| 250 |
colorizer.generator.train()
|
|
|
|
| 197 |
for cur_disc_step in range(5):
|
| 198 |
discriminator.zero_grad()
|
| 199 |
|
| 200 |
+
try:
|
| 201 |
+
bw, dfm, color_for_real = next(data_iter)
|
| 202 |
+
except StopIteration:
|
| 203 |
+
data_iter = iter(ft_dataloader)
|
| 204 |
+
bw, dfm, color_for_real = next(data_iter)
|
| 205 |
+
|
| 206 |
bw, dfm, color_for_real = bw.to(device), dfm.to(device), color_for_real.to(device)
|
| 207 |
|
| 208 |
y_real = torch.full((bw.size(0), 1), 0.9, device = device)
|
|
|
|
| 232 |
|
| 233 |
colorizer.generator.zero_grad()
|
| 234 |
|
| 235 |
+
try:
|
| 236 |
+
bw, dfm, _ = next(data_iter)
|
| 237 |
+
except StopIteration:
|
| 238 |
+
data_iter = iter(ft_dataloader)
|
| 239 |
+
bw, dfm, _ = next(data_iter)
|
| 240 |
+
|
| 241 |
bw, dfm = bw.to(device), dfm.to(device)
|
| 242 |
|
| 243 |
y_real = torch.ones((bw.size(0), 1), device = device)
|
|
|
|
| 253 |
|
| 254 |
generator_loss.backward()
|
| 255 |
gen_optimizer.step()
|
|
|
|
|
|
|
| 256 |
|
| 257 |
def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer, ft_dataloader, device = 'cpu'):
|
| 258 |
colorizer.generator.train()
|