Commit 7e6af10
Parent(s): bf96a55
Add wandb logging (Life is more peaceful)
train_vqvae.py +44 -19
train_vqvae.py
CHANGED
@@ -2,17 +2,19 @@ import os
 import argparse
 import torch
 import torch.nn as nn
-import torchvision
-from torchvision import models
 from models.vqvae import VQVAE
 from models.discriminator import Discriminator
 from torch.optim import Adam
 from models.lpips import LPIPS
 from dataset.celeba import create_dataloader
 from torchvision.utils import make_grid
+from torchvision.transforms import ToPILImage
 import yaml
 import numpy as np
 from tqdm import tqdm
+import wandb
+
+wandb.init(project="vqvae")
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
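wandb.init(project="vqvae") opens the run at module scope, before the CLI args and YAML config are parsed, so only the project name is recorded with the run. wandb.init also accepts a config dict that makes the hyperparameters filterable in the dashboard; a minimal sketch of that variant (the config path, config key, and run name below are illustrative assumptions, not part of this file):

    import yaml
    import wandb

    with open("config/vqvae.yaml") as f:                   # hypothetical config path
        train_params = yaml.safe_load(f)["train_params"]   # assumed config key

    wandb.init(
        project="vqvae",
        name="celeba-vqvae",     # illustrative run name
        config=train_params,     # stored as searchable run metadata
    )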
@@ -77,6 +79,8 @@ def train(args):
     optimizer_g = Adam(
         model.parameters(), lr=train_config["autoencoder_lr"], betas=(0.5, 0.999)
     )
+    wandb.watch(model, log="all", log_freq=100)
+
     # LPIPS model
     lpips_model = LPIPS().eval().to(device)
 
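wandb.watch registers hooks on the module so that parameter and gradient histograms are uploaded every log_freq steps; log="all" captures both, whereas the default log="gradients" records gradients only. The discriminator could be instrumented the same way; a one-line sketch, assuming its instance is named discriminator (a name this diff does not show):

    wandb.watch(discriminator, log="all", log_freq=100)  # hypothetical: mirrors the hook on the autoencoder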
@@ -115,24 +119,10 @@ def train(args):
                 grid = make_grid(
                     torch.cat([save_input, save_output], dim=0), nrow=sample_size
                 )
-                img = torchvision.transforms.ToPILImage()(grid)
-                if not os.path.exists(
-                    os.path.join(train_config["task_name"], "vqvae_autoencoder_samples")
-                ):
-                    os.mkdir(
-                        os.path.join(
-                            train_config["task_name"], "vqvae_autoencoder_samples"
-                        )
-                    )
-                img.save(
-                    os.path.join(
-                        train_config["task_name"],
-                        "vqvae_autoencoder_samples",
-                        "current_autoencoder_sample_{}.png".format(img_saved),
-                    )
-                )
+
+                grid_image = ToPILImage()(grid)
+                wandb.log({"Latent generation": wandb.Image(grid_image, caption=f"Epoch: {epoch+1}, Step: {steps}")})
                 img_saved += 1
-                img.close()
 
                 steps += 1
 
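ToPILImage is a transform class: it is instantiated first and then applied, as in ToPILImage()(grid) above; passing the tensor straight to the constructor would set it as the transform's mode argument rather than converting anything. wandb.Image can also take a CHW tensor or numpy array directly, so the PIL round trip is optional; a sketch of the shorter form, assuming grid holds float values in [0, 1]:

    wandb.log({
        "Latent generation": wandb.Image(grid, caption=f"Epoch: {epoch+1}, Step: {steps}")
    })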
@@ -189,6 +179,19 @@ def train(args):
             optimizer_g.zero_grad()
             optimizer_d.step()
             optimizer_d.zero_grad()
+
+        wandb.log({
+            "epoch": epoch + 1,
+            "step": steps,
+            "image_saved": img_saved,
+            "recon_loss": np.mean(recon_losses),
+            "perceptual_loss": np.mean(perceptual_losses),
+            "codebook_loss": np.mean(codebook_losses),
+            "gen_loss": np.mean(gen_losses),
+            "disc_loss": np.mean(disc_losses),
+            "overall_loss": np.mean(losses)
+        })
+
         if len(disc_losses) > 0:
             print(
                 "Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | "
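Each np.mean(...) collapses a per-batch loss list into one scalar, so the dashboard gets one point per key per epoch. For epochs where the discriminator never stepped (the empty-list case the print below guards against), np.mean(disc_losses) evaluates to nan with a RuntimeWarning; a guarded sketch of the same call:

    wandb.log({
        "epoch": epoch + 1,
        "recon_loss": np.mean(recon_losses),
        "disc_loss": np.mean(disc_losses) if len(disc_losses) > 0 else 0.0,  # avoid nan before the discriminator starts
    })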
@@ -235,7 +238,29 @@ def train(args):
             ),
         )
 
+    wandb.save(
+        os.path.join(
+            train_config["task_name"], train_config["vqvae_autoencoder_ckpt_name"]
+        )
+    )
+    wandb.save(
+        os.path.join(
+            train_config["task_name"], train_config["vqvae_discriminator_ckpt_name"]
+        )
+    )
+    wandb.save(
+        os.path.join(
+            train_config["task_name"], train_config["vqvae_optim_d_ckpt_name"]
+        )
+    )
+    wandb.save(
+        os.path.join(
+            train_config["task_name"], train_config["vqvae_optim_g_ckpt_name"]
+        )
+    )
+
     print("Done Training....")
+    wandb.finish()
 
 
 if __name__ == "__main__":
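wandb.save uploads each checkpoint to the run's Files tab; the default policy="live" re-syncs a file whenever it changes, which fits checkpoints that are overwritten every epoch, and wandb.finish() then flushes any pending uploads before marking the run finished. The four calls could also be collapsed into one glob; a sketch assuming all four checkpoints are .pth files directly under task_name (the actual extensions come from the config and are not guaranteed):

    wandb.save(os.path.join(train_config["task_name"], "*.pth"), policy="end")  # upload once, when the run ends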