YashNagraj75 committed on
Commit
7e6af10
·
1 Parent(s): bf96a55

Add wandb logging (Life is more peaceful)

Browse files
Files changed (1) hide show
  1. train_vqvae.py +44 -19
train_vqvae.py CHANGED
@@ -2,17 +2,19 @@ import os
2
  import argparse
3
  import torch
4
  import torch.nn as nn
5
- import torchvision
6
- from torchvision import models
7
  from models.vqvae import VQVAE
8
  from models.discriminator import Discriminator
9
  from torch.optim import Adam
10
  from models.lpips import LPIPS
11
  from dataset.celeba import create_dataloader
12
  from torchvision.utils import make_grid
 
13
  import yaml
14
  import numpy as np
15
  from tqdm import tqdm
 
 
 
16
 
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
 
@@ -77,6 +79,8 @@ def train(args):
77
  optimizer_g = Adam(
78
  model.parameters(), lr=train_config["autoencoder_lr"], betas=(0.5, 0.999)
79
  )
 
 
80
  # LPIPS model
81
  lpips_model = LPIPS().eval().to(device)
82
 
@@ -115,24 +119,10 @@ def train(args):
115
  grid = make_grid(
116
  torch.cat([save_input, save_output], dim=0), nrow=sample_size
117
  )
118
- img = torchvision.transforms.ToPILImage()(grid)
119
- if not os.path.exists(
120
- os.path.join(train_config["task_name"], "vqvae_autoencoder_samples")
121
- ):
122
- os.mkdir(
123
- os.path.join(
124
- train_config["task_name"], "vqvae_autoencoder_samples"
125
- )
126
- )
127
- img.save(
128
- os.path.join(
129
- train_config["task_name"],
130
- "vqvae_autoencoder_samples",
131
- "current_autoencoder_sample_{}.png".format(img_saved),
132
- )
133
- )
134
  img_saved += 1
135
- img.close()
136
 
137
  steps += 1
138
 
@@ -189,6 +179,19 @@ def train(args):
189
  optimizer_g.zero_grad()
190
  optimizer_d.step()
191
  optimizer_d.zero_grad()
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  if len(disc_losses) > 0:
193
  print(
194
  "Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | "
@@ -235,7 +238,29 @@ def train(args):
235
  ),
236
  )
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  print("Done Training....")
 
239
 
240
 
241
  if __name__ == "__main__":
 
2
  import argparse
3
  import torch
4
  import torch.nn as nn
 
 
5
  from models.vqvae import VQVAE
6
  from models.discriminator import Discriminator
7
  from torch.optim import Adam
8
  from models.lpips import LPIPS
9
  from dataset.celeba import create_dataloader
10
  from torchvision.utils import make_grid
11
+ from torchvision.transforms import ToPILImage
12
  import yaml
13
  import numpy as np
14
  from tqdm import tqdm
15
+ import wandb
16
+
17
+ wandb.init(project="vqvae")
18
 
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
 
 
79
  optimizer_g = Adam(
80
  model.parameters(), lr=train_config["autoencoder_lr"], betas=(0.5, 0.999)
81
  )
82
+ wandb.watch(model,log="all", log_freq=100)
83
+
84
  # LPIPS model
85
  lpips_model = LPIPS().eval().to(device)
86
 
 
119
  grid = make_grid(
120
  torch.cat([save_input, save_output], dim=0), nrow=sample_size
121
  )
122
+
123
+ grid_image = ToPILImage(grid)
124
+ wandb.log({"Latent generation": wandb.Image(grid_image,caption=f"Epoch: {epoch+1}, Step: {steps}")})
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  img_saved += 1
 
126
 
127
  steps += 1
128
 
 
179
  optimizer_g.zero_grad()
180
  optimizer_d.step()
181
  optimizer_d.zero_grad()
182
+
183
+ wandb.log({
184
+ "epoch": epoch + 1,
185
+ "step": steps,
186
+ "image_saved": img_saved,
187
+ "recon_loss": np.mean(recon_losses),
188
+ "perceptual_loss": np.mean(perceptual_losses),
189
+ "codebook_loss": np.mean(codebook_losses),
190
+ "gen_loss": np.mean(gen_losses),
191
+ "disc_loss": np.mean(disc_losses),
192
+ "overall_loss": np.mean(losses)
193
+ })
194
+
195
  if len(disc_losses) > 0:
196
  print(
197
  "Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | "
 
238
  ),
239
  )
240
 
241
+ wandb.save(
242
+ os.path.join(
243
+ train_config["task_name"], train_config["vqvae_autoencoder_ckpt_name"]
244
+ )
245
+ )
246
+ wandb.save(
247
+ os.path.join(
248
+ train_config["task_name"], train_config["vqvae_discriminator_ckpt_name"]
249
+ )
250
+ )
251
+ wandb.save(
252
+ os.path.join(
253
+ train_config["task_name"], train_config["vqvae_optim_d_ckpt_name"]
254
+ )
255
+ )
256
+ wandb.save(
257
+ os.path.join(
258
+ train_config["task_name"], train_config["vqvae_optim_g_ckpt_name"]
259
+ )
260
+ )
261
+
262
  print("Done Training....")
263
+ wandb.finish()
264
 
265
 
266
  if __name__ == "__main__":