Spaces:
Sleeping
Sleeping
| import sys, os | |
| sys.path.insert(0, os.path.dirname(__file__)) | |
| os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" | |
| import torch | |
| from torch import nn, optim | |
| import torch.nn.functional as F | |
| from dataloader import image_dataloader | |
| from torch.optim.lr_scheduler import ReduceLROnPlateau | |
| from tqdm import tqdm | |
| from config import * | |
| from vae import VAE | |
| import matplotlib.pyplot as plt | |
| from seed import seed_everything | |
| from sample_vae import reconstruct | |
| import lpips | |
| from pytorch_msssim import SSIM | |
| # # Simple PatchGAN Discriminator | |
| # class PatchDiscriminator(nn.Module): | |
| # def __init__(self, in_channels=3): | |
| # super().__init__() | |
| # def block(in_c, out_c, stride=2, normalize=True): | |
| # layers = [nn.Conv2d(in_c, out_c, 4, stride, 1, bias=False)] | |
| # if normalize: layers.append(nn.GroupNorm(vae_group_size, out_c)) | |
| # layers.append(nn.LeakyReLU(0.2, inplace=True)) | |
| # return layers | |
| # self.model = nn.Sequential( | |
| # *block(in_channels, 64, normalize=False), # 128->64 | |
| # *block(64, 128), # 64->32 | |
| # *block(128, 256), # 32->16 | |
| # *block(256, 512), # 16->8 | |
| # nn.Conv2d(512, 1, 4, 1, 0) # output 1x1 map | |
| # ) | |
| # def forward(self, x): | |
| # return self.model(x) | |
| # def vae_loss(recon_x, x, mu, logvar, discriminator=None, beta_kld=1.0): | |
| # # lpips_weight = 1 | |
| # # lpips_loss = lpips_fn(x, recon_x).mean() | |
| # # gan_weight = 0 | |
| # # rec_loss = F.l1_loss(recon_x, x, reduction="mean") | |
| # # rec_loss = torch.sum((recon_x - x) ** 2, dim=[1,2,3]).mean() | |
| # # kld = torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=[1,2,3])) | |
| # # rec_loss = F.mse_loss(recon_x, x, reduction="sum") / recon_x.size(0) | |
| # # rec_loss = F.mse_loss(recon_x, x, reduction="mean") | |
| # # kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=[1,2,3]).mean() / (4*16*16) | |
| # rec_loss = F.mse_loss(recon_x, x, reduction="none").view(x.size(0), -1).sum(dim=1).mean() | |
| # kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=[1, 2, 3]).mean() | |
| # total_loss = (rec_loss + beta_kld * kld) | |
| # # gan_loss = -discriminator(recon_x).mean() | |
| # # total_loss = rec_loss + kld * beta_kld + tvl * vae_lambda_tvl + lpips_loss * lpips_weight + gan_loss * gan_weight | |
| # # return total_loss, rec_loss, kld, tvl, lpips_loss, gan_loss | |
| # # total_loss = rec_loss + kld * beta_kld | |
| # return total_loss, rec_loss, kld, 0, 0 | |
| # def update_discriminator(discriminator, optimizer_D, real_imgs, fake_imgs): | |
| # optimizer_D.zero_grad() | |
| # real_pred = discriminator(real_imgs) | |
| # fake_pred = discriminator(fake_imgs.detach()) | |
| # loss_D = torch.mean(F.relu(1.0 - real_pred)) + torch.mean(F.relu(1.0 + fake_pred)) # Hinge loss | |
| # loss_D.backward() | |
| # optimizer_D.step() | |
| # return loss_D.item() | |
def total_variance_loss(x):
    """Return the TV-L2 (squared-difference total-variation) penalty of ``x``, averaged over the batch.

    Squares the differences between neighbours along dim 2 and along dim 3, sums
    both over the whole tensor, and divides the total by ``x.size(0)``.
    ``x`` is indexed as a 4-D tensor, presumably (batch, channels, height, width) —
    confirm against the decoder output. Despite the function name, this measures
    total *variation* (local smoothness), not statistical variance.
    """
    vertical = (torch.diff(x, dim=2) ** 2).sum()    # neighbours along dim 2
    horizontal = (torch.diff(x, dim=3) ** 2).sum()  # neighbours along dim 3
    batch = x.size(0)
    return (vertical + horizontal) / batch
def get_annealed_beta(epoch, warmup_epochs=100, max_beta=1.0):
    """Return the KL weight for ``epoch`` under a linear warm-up schedule.

    The result is ``vae_beta_kld * ramp``, where ``ramp`` rises linearly from 0 at
    epoch 0 and is capped at ``max_beta`` (so ``max_beta`` caps the multiplier
    applied to ``vae_beta_kld``, not the returned beta itself).

    Args:
        epoch: current epoch index (0-based in train_test_vae).
        warmup_epochs: number of epochs for the ramp to reach 1. A value <= 0
            means "no warm-up": the full ``max_beta`` multiplier is used
            immediately instead of dividing by zero.
        max_beta: upper bound on the ramp multiplier.

    Returns:
        The annealed KL weight (float).
    """
    if warmup_epochs <= 0:
        ramp = max_beta
    else:
        # Clamp at 0 so a negative epoch can never flip the sign of the KL term.
        ramp = min(max_beta, max(0.0, epoch / warmup_epochs))
    return vae_beta_kld * ramp
def vae_loss(recon_x, x, mu, logvar, beta_kld=1.0, lpips_fn=None, ssim_fn=None, ssim_weight=0.1):
    """Compute the composite VAE objective and its individual terms.

    total = MSE + beta_kld * KLD + vae_lambda_tvl * TV
            + vae_lpips_weight * LPIPS + ssim_weight * (1 - SSIM)

    MSE and KLD are summed over all non-batch elements and divided by the batch
    size (``recon_x.size(0)``). ``vae_lambda_tvl`` and ``vae_lpips_weight`` are read
    from the star-imported config module.

    Args:
        recon_x: decoder output, same shape as ``x``. Assumed to lie in [0, 1]
            (the LPIPS inputs are mapped with ``* 2 - 1``; train_test_vae builds
            SSIM with ``data_range=1``) — confirm against the decoder/dataloader.
        x: target images.
        mu: posterior mean.
        logvar: posterior log-variance, same shape as ``mu``.
        beta_kld: weight of the KL term.
        lpips_fn: perceptual-distance callable, or None to skip the LPIPS term.
        ssim_fn: SSIM callable returning a similarity, or None to skip the SSIM term.
        ssim_weight: weight of the ``1 - SSIM`` term; 0.1 keeps the previous behaviour.

    Returns:
        ``(total_loss, mse, kld_loss, tvl, lpips_loss, ssim_loss)``. Disabled terms
        (TV when ``vae_lambda_tvl <= 0``, LPIPS/SSIM when their callable is None)
        are returned as zero tensors, so ``.item()`` works on every component.
    """
    batch_size = recon_x.size(0)
    zero = recon_x.new_zeros((), dtype=torch.float32)

    tvl = total_variance_loss(recon_x) if vae_lambda_tvl > 0 else zero
    mse = F.mse_loss(recon_x, x, reduction="sum") / batch_size
    kld_loss = torch.sum(-0.5 * (1 + logvar - mu.pow(2) - logvar.exp())) / batch_size

    # Perceptual metrics are evaluated with autocast disabled on explicit float32
    # copies, so the caller's mixed-precision context does not affect them.
    ssim_loss = zero
    lpips_loss = zero
    with torch.amp.autocast("cuda", enabled=False):
        if ssim_fn is not None:
            ssim_loss = 1 - ssim_fn(recon_x.float(), x.float())
        if lpips_fn is not None:
            # LPIPS is fed inputs rescaled from [0, 1] to [-1, 1].
            lpips_loss = lpips_fn(recon_x.float() * 2 - 1, x.float() * 2 - 1).mean()

    total_loss = (
        mse
        + (beta_kld * kld_loss)
        + (vae_lambda_tvl * tvl)
        + (vae_lpips_weight * lpips_loss)
        + ssim_weight * ssim_loss
    )
    return total_loss, mse, kld_loss, tvl, lpips_loss, ssim_loss
def train_test_vae():
    """Train the VAE, evaluate it every epoch, and keep the best checkpoint.

    Each epoch runs one training pass over ``train_loader``, using ``vae_loss`` with
    the KL weight ``get_annealed_beta(epoch)``. It then runs one evaluation pass over
    ``test_loader`` with the fixed KL weight ``vae_beta_kld``. The average test
    loss drives ``ReduceLROnPlateau``, checkpointing and early stopping.

    Hyper-parameters (``vae_optim_lr``, ``vae_num_epochs``, ``vae_checkpoint_dir``,
    ``vae_weight``, ``vae_stopping_patience``, ``vae_beta_kld``, ``vae_lambda_tvl``,
    ``vae_lpips_weight``) come from the star-imported ``config`` module.

    Side effects:
        - creates ``vae_checkpoint_dir``;
        - writes the best checkpoint to ``f"{vae_checkpoint_dir}/{vae_weight}"``;
        - calls ``reconstruct`` after every improvement;
        - prints progress;
        - finally shows a plot of reconstruction loss vs. KL divergence.
    """
    # NOTE(review): with seed_everything() called beforehand, the construction order
    # below (dataloaders, VAE, LPIPS) may matter for reproducibility — keep it stable.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader = image_dataloader()
    vae = VAE().to(device)
    optimizer = optim.AdamW(vae.parameters(), lr=vae_optim_lr, betas=(0.9, 0.999), weight_decay=0)
    # Halve the LR after 3 epochs without test-loss improvement.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
    os.makedirs(f"{vae_checkpoint_dir}", exist_ok=True)
    early_stopping_counter = 0
    best_test_loss = float("inf")
    # Per-epoch history for the final plot. Despite the "bce" name, this collects the
    # MSE reconstruction term returned by vae_loss.
    train_bce_loss = []
    train_kld_loss = []
    # LPIPS (VGG backbone) is used only as a frozen loss network: eval mode, no grads
    # for its own weights (gradients still flow to recon_x through it).
    lpips_fn = lpips.LPIPS(net='vgg').to(device)
    lpips_fn.eval()
    for param in lpips_fn.parameters():
        param.requires_grad = False
    # data_range=1 assumes images in [0, 1] — TODO confirm against image_dataloader.
    ssim_fn = SSIM(data_range=1, size_average=True, channel=3)
    # Mixed precision (GradScaler + autocast) is enabled only when CUDA is available.
    scaler = torch.amp.GradScaler("cuda", enabled=torch.cuda.is_available())
    # Disabled adversarial (PatchGAN) path:
    # discriminator = PatchDiscriminator().to(device)
    # optimizer_D = optim.AdamW(discriminator.parameters(), lr=vae_optim_lr, betas=(0.5, 0.999))
    # Train VAE
    for epoch in range(vae_num_epochs):
        vae.train()
        train_loss, total_rec, total_kld, total_tvl, total_lpips, total_ssim = 0, 0, 0, 0, 0, 0
        for image_gt in tqdm(train_loader, desc=f"Epoch {epoch+1}/{vae_num_epochs}", colour="#CC00FF"):
            x = image_gt.to(device)
            # Disabled discriminator update step:
            # discriminator.train()
            # for param in discriminator.parameters(): param.requires_grad = True
            # with torch.no_grad(): recon_x_disc, _, _ = vae(x)
            # d_loss = update_discriminator(discriminator, optimizer_D, x, recon_x_disc.detach()) # Update discriminator
            # discriminator.eval() # Optional: turn off dropout/batchnorm
            # for param in discriminator.parameters(): param.requires_grad = False
            with torch.amp.autocast("cuda", enabled=torch.cuda.is_available()):
                recon_x, mu, logvar = vae(x)
                # KL weight ramps up linearly during training (see get_annealed_beta).
                loss, rec, kld, tvl, lpips_loss, ssim_loss = vae_loss(recon_x, x, mu, logvar, beta_kld=get_annealed_beta(epoch), lpips_fn=lpips_fn, ssim_fn=ssim_fn) # Update VAE (generator)
            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            # Unscale before clipping so max_norm=1.0 applies to the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            train_loss += loss.item()
            total_rec += rec.item()
            total_kld += kld.item()
            # tvl is only a tensor when the TV term is enabled.
            total_tvl += tvl.item() if vae_lambda_tvl > 0 else 0.0
            total_lpips += lpips_loss.item()
            total_ssim += ssim_loss.item()
            # total_gan += gan_loss.item()
        # NOTE(review): an empty train_loader would raise ZeroDivisionError here.
        avg_train_loss = train_loss / len(train_loader)
        avg_rec = total_rec / len(train_loader)
        avg_kld = total_kld / len(train_loader)
        train_bce_loss.append(avg_rec)
        train_kld_loss.append(avg_kld)
        # KLD and LPIPS are printed weighted (annealed beta, vae_lpips_weight);
        # Recon, TVL and SSIM are printed unweighted.
        print(f"🔥 Epoch {epoch+1}: Avg Train Loss={avg_train_loss:.6f} | Recon={total_rec/len(train_loader):.6f} | KLD={get_annealed_beta(epoch) * total_kld/len(train_loader):.6f} | TVL={total_tvl/len(train_loader):.6f} | LPIPS={vae_lpips_weight * total_lpips/len(train_loader):.6f} | SSIM={total_ssim/len(train_loader):.6f}")
        # print(f"Epoch {epoch}: D_loss={d_loss:.6f}")
        # Test VAE
        vae.eval()
        # discriminator.eval()
        test_loss = 0
        with torch.no_grad():
            for image_gt in tqdm(test_loader, desc=f"Epoch {epoch+1}/{vae_num_epochs}", colour="#FFDD22"):
                x = image_gt.to(device)
                with torch.amp.autocast("cuda", enabled=torch.cuda.is_available()):
                    recon_x, mu, logvar = vae(x)
                    # Evaluation uses the full, un-annealed KL weight so test losses
                    # are comparable across epochs.
                    loss, *_ = vae_loss(recon_x, x, mu, logvar, beta_kld=vae_beta_kld, lpips_fn=lpips_fn, ssim_fn=ssim_fn)
                test_loss += loss.item()
        avg_test_loss = test_loss / len(test_loader)
        scheduler.step(avg_test_loss)
        print(f"🧪 Test Loss = {avg_test_loss:.6f}")
        if avg_test_loss < best_test_loss:
            best_test_loss = avg_test_loss
            # Checkpoint holds epoch, model, optimizer and loss only; the GradScaler and
            # scheduler states are not saved.
            torch.save({
                "epoch": epoch,
                "vae": vae.state_dict(),
                "optimizer": optimizer.state_dict(),
                "test_loss": best_test_loss
            }, f"{vae_checkpoint_dir}/{vae_weight}")
            print(f"Saved best model at {epoch+1}")
            early_stopping_counter = 0
            # Generate samples every test epoch save
            vae.eval(); reconstruct(vae=vae, device=device, epoch=epoch)
        else:
            # Stop after vae_stopping_patience consecutive epochs without improvement.
            early_stopping_counter += 1
            if early_stopping_counter >= vae_stopping_patience:
                print("Early stopping triggered")
                break
        # torch.cuda.empty_cache()
    # The first two epochs are excluded from the plot.
    plot_recon_vs_kld(train_bce_loss[2:], train_kld_loss[2:])
def plot_recon_vs_kld(train_bce_loss, train_kld_loss, start_epoch=0):
    """Plot per-epoch reconstruction loss and KL divergence on shared axes and show the figure.

    Args:
        train_bce_loss: per-epoch reconstruction losses. Despite the parameter name,
            train_test_vae fills this list with the MSE term returned by vae_loss,
            so the curve is labelled "MSE".
        train_kld_loss: per-epoch KL divergence values, same length as
            ``train_bce_loss``.
        start_epoch: x-axis value of the first entry. The default 0 reproduces the
            previous behaviour; pass the slice offset (e.g. 2 for ``history[2:]``)
            to label points with their true epoch index.
    """
    epochs = list(range(start_epoch, start_epoch + len(train_bce_loss)))
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_bce_loss, label='Reconstruction Loss (MSE)', color='blue', linewidth=2)
    plt.plot(epochs, train_kld_loss, label='KL Divergence', color='red', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Reconstruction Loss vs KL Divergence over Epochs')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
if __name__ == "__main__":
    # Script entry point: seed via the project helper (fixed seed 42, presumably
    # seeding all RNGs — see seed.seed_everything), then run the full train/eval loop.
    seed_everything(42)
    train_test_vae()