# flickr8k-backend / core / train_vae.py
# Last commit a625e96 by Rohan3: "Updated: VAE, UNet, config, text embeddings, model and main"
import sys, os
sys.path.insert(0, os.path.dirname(__file__))
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import torch
from torch import nn, optim
import torch.nn.functional as F
from dataloader import image_dataloader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from config import *
from vae import VAE
import matplotlib.pyplot as plt
from seed import seed_everything
from sample_vae import reconstruct
import lpips
from pytorch_msssim import SSIM
# # Simple PatchGAN Discriminator
# class PatchDiscriminator(nn.Module):
# def __init__(self, in_channels=3):
# super().__init__()
# def block(in_c, out_c, stride=2, normalize=True):
# layers = [nn.Conv2d(in_c, out_c, 4, stride, 1, bias=False)]
# if normalize: layers.append(nn.GroupNorm(vae_group_size, out_c))
# layers.append(nn.LeakyReLU(0.2, inplace=True))
# return layers
# self.model = nn.Sequential(
# *block(in_channels, 64, normalize=False), # 128->64
# *block(64, 128), # 64->32
# *block(128, 256), # 32->16
# *block(256, 512), # 16->8
# nn.Conv2d(512, 1, 4, 1, 0) # output 1x1 map
# )
# def forward(self, x):
# return self.model(x)
# def vae_loss(recon_x, x, mu, logvar, discriminator=None, beta_kld=1.0):
# # lpips_weight = 1
# # lpips_loss = lpips_fn(x, recon_x).mean()
# # gan_weight = 0
# # rec_loss = F.l1_loss(recon_x, x, reduction="mean")
# # rec_loss = torch.sum((recon_x - x) ** 2, dim=[1,2,3]).mean()
# # kld = torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=[1,2,3]))
# # rec_loss = F.mse_loss(recon_x, x, reduction="sum") / recon_x.size(0)
# # rec_loss = F.mse_loss(recon_x, x, reduction="mean")
# # kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=[1,2,3]).mean() / (4*16*16)
# rec_loss = F.mse_loss(recon_x, x, reduction="none").view(x.size(0), -1).sum(dim=1).mean()
# kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=[1, 2, 3]).mean()
# total_loss = (rec_loss + beta_kld * kld)
# # gan_loss = -discriminator(recon_x).mean()
# # total_loss = rec_loss + kld * beta_kld + tvl * vae_lambda_tvl + lpips_loss * lpips_weight + gan_loss * gan_weight
# # return total_loss, rec_loss, kld, tvl, lpips_loss, gan_loss
# # total_loss = rec_loss + kld * beta_kld
# return total_loss, rec_loss, kld, 0, 0
# def update_discriminator(discriminator, optimizer_D, real_imgs, fake_imgs):
# optimizer_D.zero_grad()
# real_pred = discriminator(real_imgs)
# fake_pred = discriminator(fake_imgs.detach())
# loss_D = torch.mean(F.relu(1.0 - real_pred)) + torch.mean(F.relu(1.0 + fake_pred)) # Hinge loss
# loss_D.backward()
# optimizer_D.step()
# return loss_D.item()
def total_variance_loss(x):
    """Squared (TV-L2) total-variation penalty, averaged over the batch.

    Expects a 4-D tensor. Neighbouring-element differences are taken along
    dims 2 and 3 (presumably height and width) and squared. Every squared
    difference is summed, and the total is divided by ``x.size(0)`` (the
    batch size).
    """
    batch_size = x.size(0)
    # Differences between neighbours along dim 2 and along dim 3.
    diff_dim2 = x[:, :, 1:, :] - x[:, :, :-1, :]
    diff_dim3 = x[:, :, :, 1:] - x[:, :, :, :-1]
    return (diff_dim2.pow(2).sum() + diff_dim3.pow(2).sum()) / batch_size
def get_annealed_beta(epoch, warmup_epochs=100, max_beta=1.0):
    """Return the KL weight for ``epoch`` under a linear warm-up schedule.

    The weight is ``vae_beta_kld * min(max_beta, epoch / warmup_epochs)``.
    It is 0 at epoch 0, grows linearly, and saturates at
    ``vae_beta_kld * max_beta``.

    A non-positive ``warmup_epochs`` means "no warm-up": the saturated weight
    is returned directly, instead of dividing by zero or producing a negative
    weight. Negative epochs are clamped to zero progress, so the returned
    weight is never negative for a non-negative ``max_beta``.
    """
    if warmup_epochs <= 0:
        return vae_beta_kld * max_beta
    progress = max(0.0, epoch / warmup_epochs)
    return vae_beta_kld * min(max_beta, progress)
def vae_loss(recon_x, x, mu, logvar, beta_kld=1.0, lpips_fn=None, ssim_fn=None, ssim_weight=0.1):
    """Compute the composite VAE objective.

    total = MSE + beta_kld * KL + vae_lambda_tvl * TV
            + vae_lpips_weight * LPIPS + ssim_weight * (1 - SSIM)

    Every term is evaluated in float32 with autocast disabled on the inputs'
    device type. The result therefore does not depend on whether the caller
    runs inside a mixed-precision region.

    Args:
        recon_x: Decoder output, same shape as ``x``. The LPIPS term maps
            inputs with ``* 2 - 1``, so images are assumed to lie in [0, 1]
            (TODO confirm against the dataloader).
        x: Ground-truth images.
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior. Together with
            ``mu`` it is used in the closed-form KL to a standard normal.
        beta_kld: Weight on the KL term (annealed by the trainer via
            ``get_annealed_beta``).
        lpips_fn: Perceptual-distance callable such as ``lpips.LPIPS``; its
            output is averaged. If ``None``, the LPIPS term is zero.
        ssim_fn: SSIM callable, expected to return a scalar similarity. If
            ``None``, the SSIM term is zero.
        ssim_weight: Weight on ``1 - SSIM``. The default 0.1 matches the
            previously hard-coded value.

    Returns:
        Tuple ``(total_loss, mse, kld_loss, tvl, lpips_loss, ssim_loss)``.
        ``mse`` and ``kld_loss`` are sums over all elements divided by the
        batch size, i.e. per-sample sums. ``tvl`` is the float ``0.0`` when
        ``vae_lambda_tvl <= 0``. Every other entry is a 0-dim tensor.
    """
    batch_size = recon_x.size(0)
    # Disable autocast for the device actually in use. A hard-coded "cuda"
    # would leave CPU autocast active. Only cuda/cpu are named because other
    # device types are not accepted by torch.autocast on older releases.
    amp_device = "cuda" if recon_x.is_cuda else "cpu"
    with torch.amp.autocast(amp_device, enabled=False):
        recon32, x32 = recon_x.float(), x.float()
        mu32, logvar32 = mu.float(), logvar.float()
        tvl = total_variance_loss(recon32) if vae_lambda_tvl > 0 else 0.0
        mse = F.mse_loss(recon32, x32, reduction="sum") / batch_size
        # Closed-form KL(N(mu, exp(logvar)) || N(0, 1)), summed over every latent element.
        kld_loss = torch.sum(-0.5 * (1 + logvar32 - mu32.pow(2) - logvar32.exp())) / batch_size
        zero = recon32.new_zeros(())
        if ssim_fn is not None:
            ssim_loss = 1 - ssim_fn(recon32, x32)
        else:
            ssim_loss = zero
        if lpips_fn is not None:
            # LPIPS expects inputs in [-1, 1]; recon/x are assumed to be in [0, 1].
            lpips_loss = lpips_fn(recon32 * 2 - 1, x32 * 2 - 1).mean()
        else:
            lpips_loss = zero
    total_loss = (
        mse
        + beta_kld * kld_loss
        + vae_lambda_tvl * tvl
        + vae_lpips_weight * lpips_loss
        + ssim_weight * ssim_loss
    )
    return total_loss, mse, kld_loss, tvl, lpips_loss, ssim_loss
def _train_one_epoch(vae, train_loader, optimizer, scaler, lpips_fn, ssim_fn, device, epoch, beta):
    """Run one optimisation pass over ``train_loader`` with KL weight ``beta``.

    Returns the per-batch averages ``(loss, rec, kld, tvl, lpips, ssim)`` of the
    terms returned by ``vae_loss``. ``kld`` is the raw, un-weighted KL term.
    """
    use_amp = torch.cuda.is_available()
    vae.train()
    totals = [0.0] * 6
    for image_gt in tqdm(train_loader, desc=f"Epoch {epoch+1}/{vae_num_epochs}", colour="#CC00FF"):
        x = image_gt.to(device)
        with torch.amp.autocast("cuda", enabled=use_amp):
            recon_x, mu, logvar = vae(x)
            terms = vae_loss(recon_x, x, mu, logvar, beta_kld=beta, lpips_fn=lpips_fn, ssim_fn=ssim_fn)
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(terms[0]).backward()
        # Unscale first so that gradient clipping acts on the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        # float() accepts 0-dim tensors and the plain 0.0 TV term returned when TV is disabled.
        for i, term in enumerate(terms):
            totals[i] += float(term)
    num_batches = len(train_loader)
    return tuple(total / num_batches for total in totals)


def _test_epoch(vae, test_loader, lpips_fn, ssim_fn, device, epoch):
    """Return the mean total loss over ``test_loader``, using the full KL weight ``vae_beta_kld``."""
    use_amp = torch.cuda.is_available()
    vae.eval()
    test_loss = 0.0
    with torch.no_grad():
        for image_gt in tqdm(test_loader, desc=f"Epoch {epoch+1}/{vae_num_epochs}", colour="#FFDD22"):
            x = image_gt.to(device)
            with torch.amp.autocast("cuda", enabled=use_amp):
                recon_x, mu, logvar = vae(x)
                loss, *_ = vae_loss(recon_x, x, mu, logvar, beta_kld=vae_beta_kld, lpips_fn=lpips_fn, ssim_fn=ssim_fn)
            test_loss += loss.item()
    return test_loss / len(test_loader)


def train_test_vae():
    """Train the VAE, evaluate it every epoch, checkpoint the best model and plot the curves.

    Training uses AdamW with gradient clipping (max norm 1.0) and CUDA AMP when
    available. The KL weight is annealed per epoch via ``get_annealed_beta``.
    After each epoch the test loss, computed with the full ``vae_beta_kld``,
    drives a ``ReduceLROnPlateau`` scheduler (patience 3, factor 0.5).

    Whenever the test loss improves, a dict with keys ``epoch`` (0-based),
    ``vae``, ``optimizer`` and ``test_loss`` is saved to
    ``{vae_checkpoint_dir}/{vae_weight}`` and ``reconstruct`` is called.
    Training stops early after ``vae_stopping_patience`` epochs without
    improvement.

    Raises:
        ValueError: if ``image_dataloader()`` yields an empty train or test loader.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader = image_dataloader()
    # Fail fast instead of dividing by zero after a whole (empty) epoch.
    if len(train_loader) == 0 or len(test_loader) == 0:
        raise ValueError("image_dataloader() returned an empty train or test loader")
    vae = VAE().to(device)
    optimizer = optim.AdamW(vae.parameters(), lr=vae_optim_lr, betas=(0.9, 0.999), weight_decay=0)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
    os.makedirs(f"{vae_checkpoint_dir}", exist_ok=True)

    # Frozen perceptual metric: only used as a loss, never trained.
    lpips_fn = lpips.LPIPS(net='vgg').to(device)
    lpips_fn.eval()
    for param in lpips_fn.parameters():
        param.requires_grad = False
    ssim_fn = SSIM(data_range=1, size_average=True, channel=3)
    scaler = torch.amp.GradScaler("cuda", enabled=torch.cuda.is_available())

    early_stopping_counter = 0
    best_test_loss = float("inf")
    # Per-epoch reconstruction (summed MSE from vae_loss) and raw KL averages, for plotting.
    train_rec_loss = []
    train_kld_loss = []

    for epoch in range(vae_num_epochs):
        beta = get_annealed_beta(epoch)  # constant within an epoch; computed once
        avg_train_loss, avg_rec, avg_kld, avg_tvl, avg_lpips, avg_ssim = _train_one_epoch(
            vae, train_loader, optimizer, scaler, lpips_fn, ssim_fn, device, epoch, beta
        )
        train_rec_loss.append(avg_rec)
        train_kld_loss.append(avg_kld)
        print(f"🔥 Epoch {epoch+1}: Avg Train Loss={avg_train_loss:.6f} | Recon={avg_rec:.6f} | KLD={beta * avg_kld:.6f} | TVL={avg_tvl:.6f} | LPIPS={vae_lpips_weight * avg_lpips:.6f} | SSIM={avg_ssim:.6f}")

        avg_test_loss = _test_epoch(vae, test_loader, lpips_fn, ssim_fn, device, epoch)
        scheduler.step(avg_test_loss)
        print(f"🧪 Test Loss = {avg_test_loss:.6f}")

        if avg_test_loss < best_test_loss:
            best_test_loss = avg_test_loss
            torch.save({
                "epoch": epoch,
                "vae": vae.state_dict(),
                "optimizer": optimizer.state_dict(),
                "test_loss": best_test_loss
            }, f"{vae_checkpoint_dir}/{vae_weight}")
            print(f"Saved best model at {epoch+1}")
            early_stopping_counter = 0
            # Generate samples every time a new best checkpoint is saved.
            vae.eval()
            reconstruct(vae=vae, device=device, epoch=epoch)
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= vae_stopping_patience:
                print("Early stopping triggered")
                break

    # The first two epochs are skipped in the plot.
    plot_recon_vs_kld(train_rec_loss[2:], train_kld_loss[2:])
def plot_recon_vs_kld(train_bce_loss, train_kld_loss):
    """Plot per-epoch reconstruction loss and KL divergence on shared axes and show the figure.

    Args:
        train_bce_loss: Per-epoch reconstruction losses. Despite the historical
            parameter name, ``train_test_vae`` passes the summed-MSE term from
            ``vae_loss``, so the curve is labelled MSE.
        train_kld_loss: Per-epoch raw (un-weighted) KL values. Plotted against
            the same x positions as ``train_bce_loss``, so it should have the
            same length.

    The x axis is the 0-based index into the given lists. ``train_test_vae``
    passes the lists with the first two epochs dropped.
    """
    epochs = list(range(len(train_bce_loss)))
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_bce_loss, label='Reconstruction Loss (MSE)', color='blue', linewidth=2)
    plt.plot(epochs, train_kld_loss, label='KL Divergence', color='red', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Reconstruction Loss vs KL Divergence over Epochs')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
if __name__ == "__main__":
    # Fixed seed, passed to the project's seed_everything helper (presumably
    # seeds all RNGs) before any data or model is built.
    _SEED = 42
    seed_everything(_SEED)
    train_test_vae()