YashNagraj75 committed
Commit 2c4ef68 · 1 Parent(s): 90d71da

Made changes in the training script

celebhq/vqvae_autoencoder_ckpt.pth CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:cb6f542f508f916324196d1d64aa13657905de80419b8b638387bd67040f0ed5
+oid sha256:34e9a28cd01122493eed58be315fb4757b3513701b47a4bf6058a59506f3f7fd
 size 88110787
celebhq/vqvae_autoencoder_samples/current_autoencoder_sample_11.png ADDED

Git LFS Details

  • SHA256: 4571cf6b7a3980893f7509fb8c33a27ea5417eb048ade9ac6c2dcd955b084f18
  • Pointer size: 132 Bytes
  • Size of remote file: 1.84 MB
celebhq/vqvae_autoencoder_samples/current_autoencoder_sample_12.png ADDED

Git LFS Details

  • SHA256: f9f45ff7b47782d60e794abe113e2b2d77397ba4461791edf02ad729c977a13e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.9 MB
celebhq/vqvae_autoencoder_samples/current_autoencoder_sample_13.png ADDED

Git LFS Details

  • SHA256: d85d5cb185a1fbb8b6541480c53c4057146967268464256896b8d1d1c40c8d40
  • Pointer size: 132 Bytes
  • Size of remote file: 1.87 MB
celebhq/vqvae_autoencoder_samples/current_autoencoder_sample_14.png ADDED

Git LFS Details

  • SHA256: a93797bf8516552597fb66c2bf3383baae505846be549ad9daac6f3b6d4e6183
  • Pointer size: 132 Bytes
  • Size of remote file: 1.89 MB
celebhq/vqvae_optim_g_ckpt.pth CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6f8ae23a0f51866d3116b110a66b983f271e6461c5892c1d945870d84a1fcee4
+oid sha256:f29a6752be23312bbe39525029bd015a401b32d25de7b38f3200b1e0473a22ef
 size 264209698
config/celebahq.yaml CHANGED
@@ -49,7 +49,7 @@ train_config:
   ldm_batch_size: 16
   autoencoder_batch_size: 4
   disc_start: 15000
-  disc_weight: 0.5
+  disc_beta: 0.5
   codebook_beta: 1
   commitment_beta: 0.2
   perceptual_weight: 1
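The YAML key is renamed from disc_weight to disc_beta so it matches the key the training script actually reads (train_config["disc_beta"]); with the old name that lookup would fail once the adversarial loss is applied. A minimal sketch of how the value is consumed, assuming the config is loaded with PyYAML into a plain dict (the repo's exact loader may differ):

    import yaml

    # Load the training config; the path and loader here are assumptions for illustration.
    with open("config/celebahq.yaml", "r") as f:
        config = yaml.safe_load(f)
    train_config = config["train_config"]

    # The script scales both adversarial terms by this key, so the YAML name
    # must match or the lookup raises a KeyError.
    disc_beta = train_config["disc_beta"]   # was "disc_weight" before this commit
    disc_start = train_config["disc_start"]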
train_vqvae.py CHANGED
@@ -53,8 +53,8 @@ def train(args):
         print("Loading checkpoint...")
         model = torch.load(vqvae_ckpt_path).to(device)
         discriminator = torch.load(discriminator_ckpt_path).to(device)
-        optimizer_d = torch.load(optimizer_d_ckpt).to(device)
-        optimizer_g = torch.load(optimizer_g_ckpt).to(device)
+        optimizer_d = torch.load(optimizer_d_ckpt)
+        optimizer_g = torch.load(optimizer_g_ckpt)
 
     else:
         model = VQVAE(
@@ -69,10 +69,10 @@ def train(args):
 
 
     img_save_steps = train_config["autoencoder_img_save_steps"]
-    img_saved = 0
+    img_saved = 14
 
     disc_step_start = train_config['disc_start']
-    steps = 0
+    steps = 15000
 
     for epoch in range(train_config["autoencoder_epochs"]):
         recon_losses = []
@@ -129,7 +129,7 @@
             if steps > disc_step_start:
                 disc_fake_pred = discriminator(model_output[0])
                 disc_fake_loss = disc_criterion(
-                    disc_fake_pred, torch.ones_like(disc_fake_pred.shape, device=disc_fake_pred.device))
+                    disc_fake_pred, torch.ones(disc_fake_pred.shape, device=disc_fake_pred.device))
                 gen_losses.append(
                     train_config["disc_beta"] * disc_fake_loss.item())
                 g_loss += train_config["disc_beta"] * disc_fake_loss
@@ -149,7 +149,7 @@
                     disc_fake_pred, device=disc_fake_pred.device))
                 disc_real_loss = disc_criterion(disc_real_pred, torch.ones_like(
                     disc_real_pred, device=disc_real_pred.device))
-                disc_loss = train_config["disc_weight"] * \
+                disc_loss = train_config["disc_beta"] * \
                     (disc_real_loss + disc_fake_loss) / 2
                 disc_losses.append(disc_loss)
                 disc_loss.backward()
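Two of the changes above concern resuming from a checkpoint: the .to(device) calls are dropped from the optimizer loads because torch.optim optimizers are not nn.Modules and have no .to() method, and the counters are hard-coded to img_saved = 14 / steps = 15000 so training picks up past the discriminator warm-up (disc_start: 15000). Below is a minimal sketch of an alternative, state_dict-based resume that stores those counters in the checkpoint instead of hard-coding them; this is not the pattern used in train_vqvae.py, and the checkpoint keys and resume_path name are hypothetical:

    import torch

    def save_ckpt(resume_path, model, discriminator, optimizer_g, optimizer_d, steps, img_saved):
        torch.save({
            "model": model.state_dict(),
            "discriminator": discriminator.state_dict(),
            "optimizer_g": optimizer_g.state_dict(),
            "optimizer_d": optimizer_d.state_dict(),
            "steps": steps,          # avoids hard-coding steps = 15000 on resume
            "img_saved": img_saved,  # avoids hard-coding img_saved = 14 on resume
        }, resume_path)

    def load_ckpt(resume_path, model, discriminator, optimizer_g, optimizer_d, device):
        ckpt = torch.load(resume_path, map_location=device)
        model.load_state_dict(ckpt["model"])
        discriminator.load_state_dict(ckpt["discriminator"])
        # Optimizers have no .to(device); restore their state dicts instead.
        optimizer_g.load_state_dict(ckpt["optimizer_g"])
        optimizer_d.load_state_dict(ckpt["optimizer_d"])
        return ckpt["steps"], ckpt["img_saved"]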
 
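The remaining fix is the "real" label tensor for the generator's adversarial term: torch.ones_like expects a tensor, not a shape, so the old torch.ones_like(disc_fake_pred.shape, ...) call would fail; torch.ones(disc_fake_pred.shape, ...) as committed, or torch.ones_like(disc_fake_pred), both produce the intended all-ones target. A small self-contained sketch; the discriminator output shape is made up for illustration:

    import torch

    # Stand-in for the discriminator's prediction on reconstructed images.
    disc_fake_pred = torch.randn(4, 1, 30, 30)

    # Pre-commit form (broken): torch.ones_like takes a tensor, not a torch.Size.
    # torch.ones_like(disc_fake_pred.shape, device=disc_fake_pred.device)  # TypeError

    # Committed fix, and an equivalent alternative:
    labels_a = torch.ones(disc_fake_pred.shape, device=disc_fake_pred.device)
    labels_b = torch.ones_like(disc_fake_pred)  # inherits shape, dtype and device

    assert torch.equal(labels_a, labels_b)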