Commit 2c4ef68
Parent(s): 90d71da

Made changes in the training script

Files changed:
- celebhq/vqvae_autoencoder_ckpt.pth +1 -1
- celebhq/vqvae_autoencoder_samples/current_autoencoder_sample_11.png +3 -0
- celebhq/vqvae_autoencoder_samples/current_autoencoder_sample_12.png +3 -0
- celebhq/vqvae_autoencoder_samples/current_autoencoder_sample_13.png +3 -0
- celebhq/vqvae_autoencoder_samples/current_autoencoder_sample_14.png +3 -0
- celebhq/vqvae_optim_g_ckpt.pth +1 -1
- config/celebahq.yaml +1 -1
- train_vqvae.py +6 -6
celebhq/vqvae_autoencoder_ckpt.pth
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:34e9a28cd01122493eed58be315fb4757b3513701b47a4bf6058a59506f3f7fd
 size 88110787
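An LFS pointer records only the object's SHA-256 and byte size, so retraining produced a new oid while the checkpoint stayed at 88110787 bytes. A minimal standard-library sketch for checking a checked-out file against the pointer (the path is the one from this commit):

import hashlib

def sha256_of(path: str) -> str:
    # Stream in 1 MiB chunks so large checkpoints need not fit in memory.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

# For the checked-out LFS object this should print the oid in the pointer above:
# 34e9a28cd01122493eed58be315fb4757b3513701b47a4bf6058a59506f3f7fd
print(sha256_of("celebhq/vqvae_autoencoder_ckpt.pth"))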
celebhq/vqvae_autoencoder_samples/current_autoencoder_sample_11.png
ADDED (Git LFS)

celebhq/vqvae_autoencoder_samples/current_autoencoder_sample_12.png
ADDED (Git LFS)

celebhq/vqvae_autoencoder_samples/current_autoencoder_sample_13.png
ADDED (Git LFS)

celebhq/vqvae_autoencoder_samples/current_autoencoder_sample_14.png
ADDED (Git LFS)
celebhq/vqvae_optim_g_ckpt.pth
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:f29a6752be23312bbe39525029bd015a401b32d25de7b38f3200b1e0473a22ef
 size 264209698
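Both checkpoint pointers change while their recorded sizes stay constant, which is consistent with the script overwriting whole-object checkpoints in place. A hedged sketch of the save side (the diff only shows the load side; the call form is assumed from the matching torch.load usage in train_vqvae.py below):

import torch

def save_checkpoints(model, optimizer_g):
    # Assumption: whole objects are pickled, since train_vqvae.py resumes
    # with torch.load(...) directly rather than load_state_dict(...).
    # Overwriting in place keeps the byte size constant for a fixed
    # architecture while changing the content hash (the LFS oid).
    torch.save(model, "celebhq/vqvae_autoencoder_ckpt.pth")
    torch.save(optimizer_g, "celebhq/vqvae_optim_g_ckpt.pth")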
config/celebahq.yaml
CHANGED

@@ -49,7 +49,7 @@ train_config:
   ldm_batch_size: 16
   autoencoder_batch_size: 4
   disc_start: 15000
-
+  disc_beta: 0.5
   codebook_beta: 1
   commitment_beta: 0.2
   perceptual_weight: 1
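The hunk above restores disc_beta: 0.5 under train_config; the weight is consumed by the adversarial terms in train_vqvae.py below. A minimal sketch of loading it (PyYAML assumed; variable names are illustrative):

import yaml

with open("config/celebahq.yaml", "r") as f:
    config = yaml.safe_load(f)

train_config = config["train_config"]
# Used below as train_config["disc_beta"] * disc_fake_loss, so 0.5 halves
# the adversarial contribution relative to the other loss terms.
disc_beta = train_config["disc_beta"]  # 0.5 after this commit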
train_vqvae.py
CHANGED

@@ -53,8 +53,8 @@ def train(args):
         print("Loading checkpoint...")
         model = torch.load(vqvae_ckpt_path).to(device)
         discriminator = torch.load(discriminator_ckpt_path).to(device)
-        optimizer_d = torch.load(optimizer_d_ckpt)
-        optimizer_g = torch.load(optimizer_g_ckpt)
+        optimizer_d = torch.load(optimizer_d_ckpt)
+        optimizer_g = torch.load(optimizer_g_ckpt)
 
     else:
         model = VQVAE(
@@ -69,10 +69,10 @@ def train(args):
 
 
     img_save_steps = train_config["autoencoder_img_save_steps"]
-    img_saved =
+    img_saved = 14
 
     disc_step_start = train_config['disc_start']
-    steps =
+    steps = 15000
 
     for epoch in range(train_config["autoencoder_epochs"]):
         recon_losses = []
@@ -129,7 +129,7 @@ def train(args):
             if steps > disc_step_start:
                 disc_fake_pred = discriminator(model_output[0])
                 disc_fake_loss = disc_criterion(
-                    disc_fake_pred, torch.
+                    disc_fake_pred, torch.ones(disc_fake_pred.shape, device=disc_fake_pred.device))
                 gen_losses.append(
                     train_config["disc_beta"] * disc_fake_loss.item())
                 g_loss += train_config["disc_beta"] * disc_fake_loss
@@ -149,7 +149,7 @@ def train(args):
                     disc_fake_pred, device=disc_fake_pred.device))
                 disc_real_loss = disc_criterion(disc_real_pred, torch.ones_like(
                     disc_real_pred, device=disc_real_pred.device))
-                disc_loss = train_config["
+                disc_loss = train_config["disc_beta"] * \
                     (disc_real_loss + disc_fake_loss) / 2
                 disc_losses.append(disc_loss)
                 disc_loss.backward()
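The restored constants hardcode the resume point: img_saved = 14 continues sample numbering after current_autoencoder_sample_14.png (added in this commit), and steps = 15000 starts exactly at disc_start, so the discriminator branch is active from the first resumed step. A hedged alternative, not what this repo does, is to persist the counters with the checkpoints (file name and helpers are hypothetical):

import torch

TRAIN_STATE_PATH = "celebhq/vqvae_train_state.pth"  # hypothetical file

def save_train_state(steps: int, img_saved: int) -> None:
    # Saving the counters alongside the model/optimizer checkpoints would
    # avoid editing steps = 15000 / img_saved = 14 by hand on each restart.
    torch.save({"steps": steps, "img_saved": img_saved}, TRAIN_STATE_PATH)

def load_train_state() -> tuple[int, int]:
    state = torch.load(TRAIN_STATE_PATH)
    return state["steps"], state["img_saved"]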
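Taken together, the repaired lines form a standard BCE GAN objective: the generator term pushes the discriminator's prediction on reconstructions toward ones, and the discriminator term averages its real and fake losses, both scaled by disc_beta. A self-contained sketch (disc_criterion as BCEWithLogitsLoss and the prediction shape are assumptions; the hunks show neither):

import torch
import torch.nn as nn

disc_criterion = nn.BCEWithLogitsLoss()  # assumed; not shown in the diff
disc_beta = 0.5                          # from config/celebahq.yaml above

# Stand-ins for patch-discriminator logits on a batch of 4 (hypothetical shape).
disc_fake_pred = torch.randn(4, 1, 30, 30, requires_grad=True)
disc_real_pred = torch.randn(4, 1, 30, 30)

# Generator side: reward reconstructions the discriminator labels real (ones),
# mirroring disc_fake_loss in the @@ -129,7 hunk.
g_adv = disc_beta * disc_criterion(disc_fake_pred,
                                   torch.ones_like(disc_fake_pred))

# Discriminator side: ones for real, zeros for fake, averaged and scaled by
# the same disc_beta, mirroring the @@ -149,7 hunk.
fake_loss = disc_criterion(disc_fake_pred.detach(),
                           torch.zeros_like(disc_fake_pred))
real_loss = disc_criterion(disc_real_pred,
                           torch.ones_like(disc_real_pred))
disc_loss = disc_beta * (real_loss + fake_loss) / 2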