"""Train and evaluate the VAE with MSE + KL + TV + LPIPS + SSIM losses."""
import os
import sys

# Both of these must run before the imports below: the path tweak makes the
# sibling project modules importable, and the env var has to be set before
# torch is imported.
sys.path.insert(0, os.path.dirname(__file__))
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import torch
from torch import nn, optim
import torch.nn.functional as F
from dataloader import image_dataloader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from config import *
from vae import VAE
import matplotlib.pyplot as plt
from seed import seed_everything
from sample_vae import reconstruct
import lpips
from pytorch_msssim import SSIM


# ---------------------------------------------------------------------------
# Disabled adversarial (PatchGAN) experiment, kept for reference only.
# None of the code in this comment block is executed.
# ---------------------------------------------------------------------------
# # Simple PatchGAN Discriminator
# class PatchDiscriminator(nn.Module):
#     def __init__(self, in_channels=3):
#         super().__init__()
#         def block(in_c, out_c, stride=2, normalize=True):
#             layers = [nn.Conv2d(in_c, out_c, 4, stride, 1, bias=False)]
#             if normalize: layers.append(nn.GroupNorm(vae_group_size, out_c))
#             layers.append(nn.LeakyReLU(0.2, inplace=True))
#             return layers
#         self.model = nn.Sequential(
#             *block(in_channels, 64, normalize=False),  # 128->64
#             *block(64, 128),  # 64->32
#             *block(128, 256),  # 32->16
#             *block(256, 512),  # 16->8
#             nn.Conv2d(512, 1, 4, 1, 0)  # output 1x1 map
#         )
#     def forward(self, x):
#         return self.model(x)
#
# def vae_loss(recon_x, x, mu, logvar, discriminator=None, beta_kld=1.0):
#     # lpips_weight = 1
#     # lpips_loss = lpips_fn(x, recon_x).mean()
#     # gan_weight = 0
#     # rec_loss = F.l1_loss(recon_x, x, reduction="mean")
#     # rec_loss = torch.sum((recon_x - x) ** 2, dim=[1,2,3]).mean()
#     # kld = torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=[1,2,3]))
#     # rec_loss = F.mse_loss(recon_x, x, reduction="sum") / recon_x.size(0)
#     # rec_loss = F.mse_loss(recon_x, x, reduction="mean")
#     # kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=[1,2,3]).mean() / (4*16*16)
#     rec_loss = F.mse_loss(recon_x, x, reduction="none").view(x.size(0), -1).sum(dim=1).mean()
#     kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=[1, 2, 3]).mean()
#     total_loss = (rec_loss + beta_kld * kld)
#     # gan_loss = -discriminator(recon_x).mean()
#     # total_loss = rec_loss + kld * beta_kld + tvl * vae_lambda_tvl + lpips_loss * lpips_weight + gan_loss * gan_weight
#     # return total_loss, rec_loss, kld, tvl, lpips_loss, gan_loss
#     # total_loss = rec_loss + kld * beta_kld
#     return total_loss, rec_loss, kld, 0, 0
#
# def update_discriminator(discriminator, optimizer_D, real_imgs, fake_imgs):
#     optimizer_D.zero_grad()
#     real_pred = discriminator(real_imgs)
#     fake_pred = discriminator(fake_imgs.detach())
#     loss_D = torch.mean(F.relu(1.0 - real_pred)) + torch.mean(F.relu(1.0 + fake_pred))  # Hinge loss
#     loss_D.backward()
#     optimizer_D.step()
#     return loss_D.item()


def total_variance_loss(x):
    """Return the squared (L2) total-variation penalty of ``x`` per sample.

    Sums the squared differences between neighbouring elements along
    dimensions 2 and 3 and divides by ``x.size(0)``.

    Args:
        x: Tensor with at least four dimensions; presumably an image batch
            laid out as (batch, channels, height, width) -- confirm with caller.

    Returns:
        A scalar tensor.
    """
    tvl_h = torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2).sum()  # TV-L2 along dim 2
    tvl_w = torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2).sum()  # TV-L2 along dim 3
    return (tvl_h + tvl_w) / x.size(0)


def get_annealed_beta(epoch, warmup_epochs=100, max_beta=1.0):
    """Return the KL weight for ``epoch`` under a linear warm-up schedule.

    The result is ``vae_beta_kld * min(max_beta, epoch / warmup_epochs)``,
    so it is 0 at epoch 0 and grows linearly until the cap is reached.

    Args:
        epoch: Zero-based epoch index.
        warmup_epochs: Length of the ramp in epochs. A value <= 0 disables
            the warm-up (the cap is returned immediately) instead of raising
            ``ZeroDivisionError``.
        max_beta: Upper bound of the multiplier applied to ``vae_beta_kld``.
    """
    if warmup_epochs <= 0:
        return vae_beta_kld * max_beta
    return vae_beta_kld * min(max_beta, epoch / warmup_epochs)


def vae_loss(recon_x, x, mu, logvar, beta_kld=1.0, lpips_fn=None, ssim_fn=None, ssim_weight=0.1):
    """Compute the composite VAE training loss.

    total = mse + beta_kld * kld + vae_lambda_tvl * tvl
            + vae_lpips_weight * lpips + ssim_weight * (1 - ssim)

    Args:
        recon_x: Reconstructed batch produced by the VAE.
        x: Target batch, same shape as ``recon_x``.
        mu: Posterior mean tensor.
        logvar: Posterior log-variance tensor.
        beta_kld: Weight applied to the KL term.
        lpips_fn: Callable perceptual metric, or ``None`` to skip the term.
            It is fed ``t * 2 - 1``, which presumably maps images from
            [0, 1] to [-1, 1] -- confirm against the data pipeline.
        ssim_fn: Callable SSIM metric, or ``None`` to skip the term.
        ssim_weight: Weight of the ``1 - SSIM`` term.

    Returns:
        Tuple ``(total_loss, mse, kld_loss, tvl, lpips_loss, ssim_loss)``.
        ``tvl`` is the float ``0.0`` when ``vae_lambda_tvl <= 0``. A skipped
        LPIPS/SSIM term is returned as a zero scalar tensor so callers can
        still call ``.item()`` on it.
    """
    batch_size = recon_x.size(0)

    tvl = total_variance_loss(recon_x) if vae_lambda_tvl > 0 else 0.0

    # Both terms are summed over every element and normalised by batch size only.
    mse = F.mse_loss(recon_x, x, reduction="sum") / batch_size
    kld_loss = torch.sum(-0.5 * (1 + logvar - mu.pow(2) - logvar.exp())) / batch_size
    # Alternative (mean reduction), kept for reference:
    # mse = F.mse_loss(recon_x, x, reduction="mean")
    # kld_loss = torch.mean(-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()))  # divides by latent_ch×H×W×B

    skipped = torch.zeros((), device=recon_x.device)

    # The perceptual metrics are evaluated with autocast disabled on
    # explicitly float32 inputs.
    with torch.amp.autocast("cuda", enabled=False):
        recon_fp32 = recon_x.float()
        target_fp32 = x.float()
        if ssim_fn is not None:
            ssim_loss = 1 - ssim_fn(recon_fp32, target_fp32)
        else:
            ssim_loss = skipped
        if lpips_fn is not None:
            lpips_loss = lpips_fn(recon_fp32 * 2 - 1, target_fp32 * 2 - 1).mean()
        else:
            lpips_loss = skipped

    total_loss = (
        mse
        + (beta_kld * kld_loss)
        + (vae_lambda_tvl * tvl)
        + (vae_lpips_weight * lpips_loss)
        + ssim_loss * ssim_weight
    )
    return total_loss, mse, kld_loss, tvl, lpips_loss, ssim_loss


def _train_one_epoch(vae, train_loader, optimizer, scaler, lpips_fn, ssim_fn, device, epoch):
    """Run one optimisation pass over ``train_loader`` and return per-batch averages.

    Returns:
        Dict with keys ``loss``, ``rec``, ``kld``, ``tvl``, ``lpips`` and
        ``ssim``; each value is the unweighted term averaged over batches.
    """
    vae.train()
    use_amp = torch.cuda.is_available()
    beta = get_annealed_beta(epoch)  # constant within an epoch
    totals = {"loss": 0.0, "rec": 0.0, "kld": 0.0, "tvl": 0.0, "lpips": 0.0, "ssim": 0.0}

    for image_gt in tqdm(train_loader, desc=f"Epoch {epoch+1}/{vae_num_epochs}", colour="#CC00FF"):
        x = image_gt.to(device)

        # Disabled discriminator step (see the reference block at the top of the file):
        # discriminator.train()
        # for param in discriminator.parameters(): param.requires_grad = True
        # with torch.no_grad(): recon_x_disc, _, _ = vae(x)
        # d_loss = update_discriminator(discriminator, optimizer_D, x, recon_x_disc.detach())  # Update discriminator
        # discriminator.eval()  # Optional: turn off dropout/batchnorm
        # for param in discriminator.parameters(): param.requires_grad = False

        with torch.amp.autocast("cuda", enabled=use_amp):
            recon_x, mu, logvar = vae(x)
            loss, rec, kld, tvl, lpips_loss, ssim_loss = vae_loss(
                recon_x, x, mu, logvar,
                beta_kld=beta, lpips_fn=lpips_fn, ssim_fn=ssim_fn,
            )  # Update VAE (generator)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        # Gradients must be unscaled before clipping so the norm threshold
        # applies to their true magnitude.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        totals["loss"] += loss.item()
        totals["rec"] += rec.item()
        totals["kld"] += kld.item()
        totals["tvl"] += float(tvl)  # tvl is a plain 0.0 when the TV term is disabled
        totals["lpips"] += lpips_loss.item()
        totals["ssim"] += ssim_loss.item()

    num_batches = len(train_loader)
    return {name: value / num_batches for name, value in totals.items()}


def _evaluate(vae, test_loader, lpips_fn, ssim_fn, device, epoch):
    """Return the average total loss over ``test_loader``.

    Uses the full ``vae_beta_kld`` weight rather than the annealed one.
    """
    vae.eval()
    use_amp = torch.cuda.is_available()
    test_loss = 0.0
    with torch.no_grad():
        for image_gt in tqdm(test_loader, desc=f"Epoch {epoch+1}/{vae_num_epochs}", colour="#FFDD22"):
            x = image_gt.to(device)
            with torch.amp.autocast("cuda", enabled=use_amp):
                recon_x, mu, logvar = vae(x)
                loss, *_ = vae_loss(
                    recon_x, x, mu, logvar,
                    beta_kld=vae_beta_kld, lpips_fn=lpips_fn, ssim_fn=ssim_fn,
                )
            test_loss += loss.item()
    return test_loss / len(test_loader)


def train_test_vae():
    """Train the VAE, checkpoint the best model, and plot the loss curves.

    Each epoch runs a training pass and a test pass. The test loss drives
    the ``ReduceLROnPlateau`` scheduler, best-model checkpointing and early
    stopping (after ``vae_stopping_patience`` epochs without improvement).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader = image_dataloader()

    vae = VAE().to(device)
    optimizer = optim.AdamW(vae.parameters(), lr=vae_optim_lr, betas=(0.9, 0.999), weight_decay=0)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
    os.makedirs(f"{vae_checkpoint_dir}", exist_ok=True)

    early_stopping_counter = 0
    best_test_loss = float("inf")
    train_rec_history = []
    train_kld_history = []

    # The perceptual network is a fixed metric: eval mode, no gradients for
    # its own parameters.
    lpips_fn = lpips.LPIPS(net='vgg').to(device)
    lpips_fn.eval()
    for param in lpips_fn.parameters():
        param.requires_grad = False

    ssim_fn = SSIM(data_range=1, size_average=True, channel=3)
    scaler = torch.amp.GradScaler("cuda", enabled=torch.cuda.is_available())

    # discriminator = PatchDiscriminator().to(device)
    # optimizer_D = optim.AdamW(discriminator.parameters(), lr=vae_optim_lr, betas=(0.5, 0.999))

    for epoch in range(vae_num_epochs):
        # Train VAE
        avg = _train_one_epoch(vae, train_loader, optimizer, scaler, lpips_fn, ssim_fn, device, epoch)
        train_rec_history.append(avg["rec"])
        train_kld_history.append(avg["kld"])

        # KLD and LPIPS are reported with their loss weights applied.
        avg_train_loss = avg["loss"]
        avg_rec = avg["rec"]
        weighted_kld = get_annealed_beta(epoch) * avg["kld"]
        avg_tvl = avg["tvl"]
        weighted_lpips = vae_lpips_weight * avg["lpips"]
        avg_ssim = avg["ssim"]
        print(f"🔥 Epoch {epoch+1}: Avg Train Loss={avg_train_loss:.6f} | Recon={avg_rec:.6f} | KLD={weighted_kld:.6f} | TVL={avg_tvl:.6f} | LPIPS={weighted_lpips:.6f} | SSIM={avg_ssim:.6f}")
        # print(f"Epoch {epoch}: D_loss={d_loss:.6f}")

        # Test VAE
        avg_test_loss = _evaluate(vae, test_loader, lpips_fn, ssim_fn, device, epoch)
        scheduler.step(avg_test_loss)
        print(f"🧪 Test Loss = {avg_test_loss:.6f}")

        if avg_test_loss < best_test_loss:
            best_test_loss = avg_test_loss
            torch.save({
                "epoch": epoch,
                "vae": vae.state_dict(),
                "optimizer": optimizer.state_dict(),
                "test_loss": best_test_loss
            }, f"{vae_checkpoint_dir}/{vae_weight}")
            print(f"Saved best model at {epoch+1}")
            early_stopping_counter = 0

            # Generate samples every time a new best checkpoint is saved
            vae.eval()
            reconstruct(vae=vae, device=device, epoch=epoch)
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= vae_stopping_patience:
                print("Early stopping triggered")
                break
        # torch.cuda.empty_cache()

    # The first two epochs are left out of the plot; pass the matching
    # 1-based epoch number so the x-axis lines up with the printed logs.
    skipped_epochs = 2
    plot_recon_vs_kld(
        train_rec_history[skipped_epochs:],
        train_kld_history[skipped_epochs:],
        first_epoch=skipped_epochs + 1,
    )


def plot_recon_vs_kld(train_bce_loss, train_kld_loss, first_epoch=0):
    """Plot per-epoch reconstruction loss against KL divergence.

    Args:
        train_bce_loss: Sequence of per-epoch reconstruction losses. The
            name is historical; this script passes MSE values.
        train_kld_loss: Sequence of per-epoch KL divergences, same length.
        first_epoch: Epoch number shown for the first data point.
    """
    epochs = list(range(first_epoch, first_epoch + len(train_bce_loss)))

    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_bce_loss, label='Reconstruction Loss (MSE)', color='blue', linewidth=2)
    plt.plot(epochs, train_kld_loss, label='KL Divergence', color='red', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Reconstruction Loss vs KL Divergence over Epochs')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    seed_everything(42)
    train_test_vae()