| import time
|
| from generative.networks.nets import VQVAE
|
| import matplotlib.pyplot as plt
|
| import torch
|
| from monai.config import print_config
|
| from torch.utils.data import DataLoader
|
| from monai.utils import set_determinism
|
| from tqdm import tqdm
|
| from generative.losses import PatchAdversarialLoss, PerceptualLoss
|
| from generative.networks.nets import PatchDiscriminator
|
| from datetime import date
|
| import torch.nn.functional as F
|
| from torch.optim.lr_scheduler import MultiStepLR
|
| from torchvision import transforms
|
| import cv2 as cv
|
| import numpy as np
|
| from torch.utils.data import Dataset
|
| import pandas as pd
|
| import os
|
| from depth_loss import depth_loss
|
|
|
# Print the MONAI configuration / library versions into the run log.
print_config()

# Seed the random number generators so repeated runs are comparable.
set_determinism(42)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
_accelerator = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_accelerator)

# ---- Run settings ----------------------------------------------------
image_size = 1024        # images are resized to image_size x image_size
vae_batch_size = 4       # batch size of both data loaders
n_example_images = 4     # reconstructions kept from the first validation batch
vae_epoch_number = 200   # number of training epochs
val_interval = 10        # validate every `val_interval` epochs

# ---- Data locations --------------------------------------------------
# Tab-separated lists naming the image files of each split.
train_file_list = "SZCH-X-Rays_trainset.txt"
test_file_list = "SZCH-X-Rays_valset.txt"

# Both image folders live under the same dataset root.
_dataset_root = "SZCH-X-Rays-741"
cxr_path = f"{_dataset_root}/CXR"
bs_path = f"{_dataset_root}/BS"
|
|
|
# Every encoder / decoder stage uses the same convolution settings, so the
# per-stage tuples are written once and repeated for the three stages.
# NOTE(review): per the MONAI Generative docs the entries are presumably
# (stride, kernel_size, dilation, padding[, output_padding]) - confirm.
_down_stage = (2, 4, 1, 1)
_up_stage = (2, 4, 1, 1, 0)

# Single-channel 2D VQ-VAE used as the generator of the VQ-GAN.
myVQGANModel = VQVAE(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(128, 256, 512),
    num_res_channels=512,
    num_res_layers=2,
    downsample_parameters=(_down_stage,) * 3,
    upsample_parameters=(_up_stage,) * 3,
    num_embeddings=1024,
    embedding_dim=4,
)
|
|
|
|
|
|
|
class myTransformMethod:
    """Callable transform: resize an image array and drop colour channels.

    The image is resized to ``image_size`` x ``image_size`` (module-level
    setting). If the resized array has a last axis of length 3 it is
    converted with ``cv.COLOR_BGR2GRAY``; otherwise it is returned as is.
    """

    def __call__(self, img):
        resized = cv.resize(img, (image_size, image_size))

        # Anything whose last axis is not of length 3 is passed through.
        if resized.shape[-1] != 3:
            return resized

        return cv.cvtColor(resized, cv.COLOR_BGR2GRAY)
|
|
|
|
|
# Shared preprocessing: resize / grayscale, convert to a tensor, then
# normalise with mean 0.5 and std 0.5.
_grayscale_pipeline = transforms.Compose([
    myTransformMethod(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Registry of named transforms; only 'Transform1' is defined.
myTransform = {
    'Transform1': _grayscale_pipeline,
}
|
|
|
|
|
class mySingleDataset(Dataset):
    """Dataset of images named in a tab-separated list file.

    Args:
        filelist: Path to a header-less, tab-separated text file. The first
            column of each row is an image file name.
        img_dir: Directory the file names are resolved against.
        transform: Optional callable applied to the loaded image array.
    """

    def __init__(self, filelist, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.filelist = pd.read_csv(filelist, sep="\t", header=None)

    def __len__(self):
        """Return the number of rows in the list file."""
        return len(self.filelist)

    def __getitem__(self, idx):
        """Load one image.

        Returns:
            A ``(image, file)`` pair: the (optionally transformed) image and
            the file name taken from the list.

        Raises:
            OSError: If the image file is missing or cannot be decoded.
        """
        file = self.filelist.iloc[idx, 0]
        full_path = os.path.join(self.img_dir, file)

        image = cv.imread(full_path)
        # cv.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than deep inside the
        # transform or the DataLoader collate step.
        if image is None:
            raise OSError(f"Failed to read image (missing or unreadable): {full_path}")

        if self.transform is not None:
            image = self.transform(image)
        return image, file
|
|
|
|
|
def _cxr_plus_bs(file_list):
    """Return the CXR images followed by the BS images named in ``file_list``."""
    shared_transform = myTransform['Transform1']
    cxr_part = mySingleDataset(file_list, cxr_path, shared_transform)
    bs_part = mySingleDataset(file_list, bs_path, shared_transform)
    # Adding two torch Datasets yields their concatenation.
    return cxr_part + bs_part


# Each split contains both image folders for the same file list.
myTrainSet = _cxr_plus_bs(train_file_list)
myTestSet = _cxr_plus_bs(test_file_list)

# Only the training data is shuffled.
myTrainLoader = DataLoader(myTrainSet, batch_size=vae_batch_size, shuffle=True)
myTestLoader = DataLoader(myTestSet, batch_size=vae_batch_size, shuffle=False)

print("Number of batches in train set:", len(myTrainLoader))
print("Train set size:", len(myTrainSet))
print("Number of batches in test set:", len(myTestLoader))
print("Test set size:", len(myTestSet))

# Select the compute device and report it.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using {device}")
|
|
|
# ---- Networks ----------------------------------------------------------
# Construction order is kept as is: it determines how the seeded RNG is
# consumed during weight initialisation.
model = myVQGANModel.to(device)

discriminator = PatchDiscriminator(
    spatial_dims=2,
    in_channels=1,
    num_layers_d=3,
    num_channels=64,
)
discriminator = discriminator.to(device)

perceptual_loss = PerceptualLoss(spatial_dims=2, network_type="vgg")
perceptual_loss = perceptual_loss.to(device)

# ---- Optimisers and learning-rate schedules ----------------------------
optimizer_g = torch.optim.Adam(params=model.parameters(), lr=1e-4)
optimizer_d = torch.optim.Adam(params=discriminator.parameters(), lr=5e-4)

# The schedulers are stepped once per batch in the training loop, so the
# milestone is expressed in iterations (200 epochs' worth of batches).
# NOTE(review): with vae_epoch_number = 200 this milestone coincides with
# the very last step, so the halved learning rate is never trained with -
# confirm this is intended.
_lr_milestones = [200 * len(myTrainLoader)]
optimizer_scheduler_g = MultiStepLR(optimizer_g, milestones=_lr_milestones, gamma=0.5)
optimizer_scheduler_d = MultiStepLR(optimizer_d, milestones=_lr_milestones, gamma=0.5)

# ---- Loss terms and their weights ---------------------------------------
adv_loss = PatchAdversarialLoss(criterion="least_squares")
adv_weight = 0.01
perceptual_weight = 0.001

depth_weight = 1

# ---- Per-epoch history, filled by the training loop ---------------------
epoch_recon_loss_list, epoch_gen_loss_list, epoch_disc_loss_list = [], [], []
val_recon_epoch_loss_list = []
intermediary_images = []
|
|
|
def _save_generator():
    """Pickle the whole generator module to a dated ``.pth`` file.

    The file name is rebuilt on every call, so it always carries the date
    on which the save actually happens.
    """
    torch.save(model, str(date.today()) + "-SZCH-X-Rays-VQGAN" + str(depth_weight) + ".pth")


# NOTE: the loop is intentionally kept flat at module level - the plotting
# code that follows reads ``images`` and ``reconstruction`` from the last
# processed batch.
total_start = time.time()
for epoch in range(vae_epoch_number):
    model.train()
    discriminator.train()
    epoch_loss = 0
    gen_epoch_loss = 0
    disc_epoch_loss = 0
    saved_this_epoch = False
    progress_bar = tqdm(enumerate(myTrainLoader), total=len(myTrainLoader), ncols=110)
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in progress_bar:
        images = batch[0].to(device=device, non_blocking=True)

        # ---- Generator update ------------------------------------------
        optimizer_g.zero_grad(set_to_none=True)

        reconstruction, quantization_loss = model(images=images)
        logits_fake = discriminator(reconstruction.contiguous().float())[-1]

        recons_loss = F.mse_loss(reconstruction.float(), images.float())
        p_loss = perceptual_loss(reconstruction.float(), images.float())
        d_loss = depth_loss(reconstruction.float(), images.float())
        generator_loss = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)

        loss_g = (
            recons_loss
            + quantization_loss
            + perceptual_weight * p_loss
            + adv_weight * generator_loss
            + depth_weight * d_loss
        )

        loss_g.backward()
        optimizer_g.step()
        optimizer_scheduler_g.step()

        # ---- Discriminator update --------------------------------------
        optimizer_d.zero_grad(set_to_none=True)

        # Inputs are detached so this step only trains the discriminator.
        logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
        loss_d_fake = adv_loss(logits_fake, target_is_real=False, for_discriminator=True)
        logits_real = discriminator(images.contiguous().detach())[-1]
        loss_d_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True)
        discriminator_loss = (loss_d_fake + loss_d_real) * 0.5

        loss_d = adv_weight * discriminator_loss

        loss_d.backward()
        optimizer_d.step()
        optimizer_scheduler_d.step()

        # ---- Running statistics ----------------------------------------
        epoch_loss += recons_loss.item()
        gen_epoch_loss += generator_loss.item()
        disc_epoch_loss += discriminator_loss.item()

        progress_bar.set_postfix(
            {
                "recons_loss": epoch_loss / (step + 1),
                "gen_loss": gen_epoch_loss / (step + 1),
                "disc_loss": disc_epoch_loss / (step + 1),
            }
        )

    # One averaged entry per epoch (``step`` is the last batch index here).
    epoch_recon_loss_list.append(epoch_loss / (step + 1))
    epoch_gen_loss_list.append(gen_epoch_loss / (step + 1))
    epoch_disc_loss_list.append(disc_epoch_loss / (step + 1))

    if (epoch + 1) % val_interval == 0:
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for val_step, batch in enumerate(myTestLoader, start=1):
                images = batch[0].to(device=device, non_blocking=True)

                reconstruction, quantization_loss = model(images=images)

                if val_step == 1:
                    # Keep the examples on the CPU: holding them as GPU
                    # tensors for the whole run pins device memory.
                    intermediary_images.append(reconstruction[:n_example_images, 0].cpu())

                recons_loss = F.mse_loss(reconstruction.float(), images.float())

                val_loss += recons_loss.item()

        val_loss /= val_step
        val_recon_epoch_loss_list.append(val_loss)

        # Periodic checkpoint after every validation round.
        _save_generator()
        saved_this_epoch = True

# Make sure the final weights are on disk even when the last epoch was not
# a validation epoch (or no epoch ran at all).
if vae_epoch_number == 0 or not saved_this_epoch:
    _save_generator()

total_time = time.time() - total_start
print(f"train completed, total time: {total_time}.")
|
|
|
plt.style.use("seaborn-v0_8")

# Draw on a dedicated figure and close it after saving, so that nothing
# plotted here leaks into figures created later with the pyplot interface.
plt.figure()
plt.title("Learning Curves", fontsize=20)

# One training point per recorded epoch: 1, 2, 3, ...
_train_epochs = np.arange(1, len(epoch_recon_loss_list) + 1)
plt.plot(_train_epochs, epoch_recon_loss_list, color="C0", linewidth=2.0, label="Train")

# Validation ran after every `val_interval`-th epoch. Deriving the x
# positions from the recorded values keeps them correct even when the
# number of epochs is not a multiple of `val_interval`.
_val_epochs = val_interval * np.arange(1, len(val_recon_epoch_loss_list) + 1)
plt.plot(
    _val_epochs,
    val_recon_epoch_loss_list,
    color="C1",
    linewidth=2.0,
    label="Validation",
)

plt.yticks(fontsize=12)
plt.xticks(fontsize=12)
plt.xlabel("Epochs", fontsize=16)
plt.ylabel("Loss", fontsize=16)
plt.legend(prop={"size": 14})
plt.savefig("Learning-S" + str(depth_weight) + ".png")
plt.close()
|
|
|
# Start a new figure: without it these curves are added to whatever pyplot
# figure is still current (the learning-curve plot), and the saved image
# ends up containing both sets of lines.
plt.figure()
plt.title("Adversarial Training Curves", fontsize=20)

# One point per recorded epoch: 1, 2, 3, ...
_adv_epochs = np.arange(1, len(epoch_gen_loss_list) + 1)
plt.plot(_adv_epochs, epoch_gen_loss_list, color="C0", linewidth=2.0, label="Generator")
plt.plot(
    np.arange(1, len(epoch_disc_loss_list) + 1),
    epoch_disc_loss_list,
    color="C1",
    linewidth=2.0,
    label="Discriminator",
)

plt.yticks(fontsize=12)
plt.xticks(fontsize=12)
plt.xlabel("Epochs", fontsize=16)
plt.ylabel("Loss", fontsize=16)
plt.legend(prop={"size": 14})
plt.savefig("Adversarial-S" + str(depth_weight) + ".png")
plt.close()
|
|
|
# Side-by-side view of the first sample of the last processed batch and
# its reconstruction.
fig, ax = plt.subplots(nrows=1, ncols=2)

# Map the normalised values (mean 0.5, std 0.5) back to the 0-255 range.
images = (images[0, 0] * 0.5 + 0.5) * 255
reconstructions = (reconstruction[0, 0] * 0.5 + 0.5) * 255

_panels = zip(ax, (images, reconstructions), ("Inputted Image", "Reconstruction"))
for _axis, _picture, _caption in _panels:
    _axis.imshow(_picture.detach().cpu(), vmin=0, vmax=255, cmap="gray")
    _axis.axis("off")
    _axis.title.set_text(_caption)

plt.savefig("reconstruction images-S" + str(depth_weight) + ".png")
|
|
|