# DeBoneDiT / code / dae_train.py
# (Hugging Face file-viewer residue kept for provenance: uploaded by diaoquesang,
#  "Upload 15 files", commit 6107278 verified, 10.3 kB)
import time
from generative.networks.nets import VQVAE
import matplotlib.pyplot as plt
import torch
from monai.config import print_config
from torch.utils.data import DataLoader
from monai.utils import set_determinism
from tqdm import tqdm
from generative.losses import PatchAdversarialLoss, PerceptualLoss
from generative.networks.nets import PatchDiscriminator
from datetime import date
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
from torchvision import transforms
import cv2 as cv
import numpy as np
from torch.utils.data import Dataset
import pandas as pd
import os
from depth_loss import depth_loss
# Report the MONAI/runtime configuration, then seed everything so that the
# model initialisation and data shuffling below are repeatable.
print_config()
set_determinism(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Training hyper-parameters ----
image_size = 1024  # images are resized to image_size x image_size
vae_batch_size = 4
n_example_images = 4  # reconstructions kept from the first validation batch
vae_epoch_number = 200
val_interval = 10  # run validation every `val_interval` epochs

# ---- Data locations ----
# Both image folders are indexed by the same file lists.
# NOTE(review): presumably CXR = chest X-rays and BS = bone-suppressed images — confirm.
train_file_list = "SZCH-X-Rays_trainset.txt"
test_file_list = "SZCH-X-Rays_valset.txt"
cxr_path = "SZCH-X-Rays-741/CXR"
bs_path = "SZCH-X-Rays-741/BS"

# ---- Autoencoder (single-channel in, single-channel out) ----
# Three identical down-sampling and three identical up-sampling stages.
# NOTE(review): tuple layout is presumably (stride, kernel_size, dilation, padding)
# with an extra output_padding entry for up-sampling — confirm against the
# MONAI Generative VQVAE documentation.
myVQGANModel = VQVAE(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(128, 256, 512),
    num_res_channels=512,
    num_res_layers=2,
    downsample_parameters=((2, 4, 1, 1),) * 3,
    upsample_parameters=((2, 4, 1, 1, 0),) * 3,
    num_embeddings=1024,
    embedding_dim=4,
)
class myTransformMethod():
    """Callable preprocessing step: resize to a square and reduce BGR to grayscale.

    Instances are meant to be used as a stage of a ``transforms.Compose``
    pipeline, so the object is invoked like a function on one image.
    """

    def __call__(self, img):
        """Resize ``img`` to ``image_size`` x ``image_size`` and return it as grayscale.

        Args:
            img: Image array as produced by ``cv.imread`` (H x W x 3 in BGR
                order) or an already single-channel H x W array.

        Returns:
            The resized image; three-channel input is converted to a 2-D
            grayscale array, anything else is returned resized but unconverted.

        Raises:
            ValueError: If ``img`` is ``None``, which is what ``cv.imread``
                returns for a missing or unreadable file.
        """
        if img is None:
            # Fail here with a readable message instead of an opaque OpenCV
            # assertion raised from inside cv.resize.
            raise ValueError("myTransformMethod received None instead of an image array")
        img = cv.resize(img, (image_size, image_size))  # resize to the training resolution
        # Only a 3-D HWC array whose last axis has 3 entries is BGR; a 2-D array
        # is already single-channel and its last axis is the width, not channels.
        if img.ndim == 3 and img.shape[-1] == 3:
            img = cv.cvtColor(img, cv.COLOR_BGR2GRAY)  # OpenCV loads colour images as BGR
        return img
# Named preprocessing pipelines, looked up by key when datasets are built.
# 'Transform1': resize/grayscale -> tensor -> normalise with mean 0.5 and std 0.5
# (the plotting code at the end of the script undoes this with `x * 0.5 + 0.5`).
myTransform = dict(
    Transform1=transforms.Compose(
        [
            myTransformMethod(),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    ),
)
class mySingleDataset(Dataset):
    """Dataset of the images in one directory, selected by a file-name list.

    Each item is a ``(image, file_name)`` pair, where ``image`` is the result
    of the optional ``transform`` applied to the array read by OpenCV.
    """

    def __init__(self, filelist, img_dir, transform=None):
        """Store the image directory and transform, and load the file-name list.

        Args:
            filelist: Path to a tab-separated text file without a header row;
                the first column holds one image file name per line.
            img_dir: Directory that contains the listed image files.
            transform: Optional callable applied to every loaded image.
        """
        self.img_dir = img_dir
        self.transform = transform
        # dtype=str keeps names such as "001" intact; without it pandas parses
        # numeric-looking names as integers, which breaks os.path.join.
        self.filelist = pd.read_csv(filelist, sep="\t", header=None, dtype=str)

    def __len__(self):
        """Return the number of listed files."""
        return len(self.filelist)

    def __getitem__(self, idx):
        """Load, preprocess and return the ``idx``-th image with its file name.

        Args:
            idx: Row index into the file list.

        Returns:
            Tuple ``(image, file_name)``.

        Raises:
            FileNotFoundError: If the listed file does not exist in ``img_dir``.
            ValueError: If the file exists but OpenCV cannot decode it.
        """
        file = self.filelist.iloc[idx, 0]
        img_path = os.path.join(self.img_dir, file)
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image listed in the file list was not found: {img_path}")
        # cv.imread signals failure by returning None rather than raising.
        image = cv.imread(img_path)
        if image is None:
            raise ValueError(f"OpenCV could not decode image: {img_path}")
        if self.transform:
            image = self.transform(image)
        return image, file
# ---- Datasets ----
# Each split concatenates two directory-backed datasets (cxr_path first, then
# bs_path) that share one file list; `+` on torch Datasets gives a ConcatDataset.
myTrainSet = (
    mySingleDataset(train_file_list, cxr_path, myTransform['Transform1'])
    + mySingleDataset(train_file_list, bs_path, myTransform['Transform1'])
)
myTestSet = (
    mySingleDataset(test_file_list, cxr_path, myTransform['Transform1'])
    + mySingleDataset(test_file_list, bs_path, myTransform['Transform1'])
)

# ---- Loaders: shuffle only the training split ----
myTrainLoader = DataLoader(myTrainSet, batch_size=vae_batch_size, shuffle=True)
myTestLoader = DataLoader(myTestSet, batch_size=vae_batch_size, shuffle=False)

print("Number of batches in train set:", len(myTrainLoader))  # train batch count
print("Train set size:", len(myTrainSet))  # train sample count
print("Number of batches in test set:", len(myTestLoader))  # test batch count
print("Test set size:", len(myTestSet))  # test sample count

# Same expression as the earlier `device` assignment; re-evaluated before reporting it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

# ---- Networks and the perceptual-loss module, all moved to the chosen device ----
# Creation order is kept as-is: module construction may draw from the seeded RNG.
model = myVQGANModel.to(device)
discriminator = PatchDiscriminator(
    spatial_dims=2,
    in_channels=1,
    num_layers_d=3,
    num_channels=64,
).to(device)
perceptual_loss = PerceptualLoss(spatial_dims=2, network_type="vgg").to(device)

# ---- Optimizers: one for the autoencoder ("generator"), one for the discriminator ----
optimizer_g = torch.optim.Adam(params=model.parameters(), lr=1e-4)
optimizer_d = torch.optim.Adam(params=discriminator.parameters(), lr=5e-4)

# The training loop steps these schedulers once per batch, so the single
# milestone is counted in iterations: 200 passes over the training loader.
optimizer_scheduler_g = MultiStepLR(
    optimizer_g,
    milestones=[200 * len(myTrainLoader)],
    gamma=0.5,
)
optimizer_scheduler_d = MultiStepLR(
    optimizer_d,
    milestones=[200 * len(myTrainLoader)],
    gamma=0.5,
)

# ---- Loss terms and their weights in the generator objective ----
adv_loss = PatchAdversarialLoss(criterion="least_squares")
adv_weight = 0.01
perceptual_weight = 0.001
# msssim_weight = 1
depth_weight = 1  # also embedded in the names of the saved model and figures

# ---- Per-epoch history consumed by the plots at the end of the script ----
epoch_recon_loss_list = []
epoch_gen_loss_list = []
epoch_disc_loss_list = []
val_recon_epoch_loss_list = []
intermediary_images = []

total_start = time.time()
# Adversarial training of the autoencoder: every batch performs one generator
# (autoencoder) update followed by one discriminator update. Every
# `val_interval` epochs the reconstruction MSE on the test loader is recorded.
for epoch in range(vae_epoch_number):
    model.train()
    discriminator.train()
    # Running sums over the batches of this epoch.
    epoch_loss = 0
    gen_epoch_loss = 0
    disc_epoch_loss = 0
    progress_bar = tqdm(enumerate(myTrainLoader), total=len(myTrainLoader), ncols=110)
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in progress_bar:
        # Each batch is (images, file_names); only the images are used.
        images = batch[0].to(device=device, non_blocking=True)

        # ---- Generator part ----
        optimizer_g.zero_grad(set_to_none=True)
        reconstruction, quantization_loss = model(images=images)
        logits_fake = discriminator(reconstruction.contiguous().float())[-1]
        recons_loss = F.mse_loss(reconstruction.float(), images.float())
        p_loss = perceptual_loss(reconstruction.float(), images.float())
        d_loss = depth_loss(reconstruction.float(), images.float())
        # The generator is rewarded when the discriminator labels its output as real.
        generator_loss = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)
        # msssim = pytorch_msssim.MSSSIM(window_size=11, size_average=True, channel=1, normalize='relu')
        # msssim_loss = 1 - msssim(reconstruction.float(), images.float())
        loss_g = (
            recons_loss
            + quantization_loss
            + perceptual_weight * p_loss
            + adv_weight * generator_loss
            + depth_weight * d_loss
        )
        loss_g.backward()
        optimizer_g.step()
        optimizer_scheduler_g.step()

        # ---- Discriminator part ----
        # zero_grad also discards the discriminator gradients produced by loss_g.backward().
        optimizer_d.zero_grad(set_to_none=True)
        # detach(): the discriminator loss must not back-propagate into the autoencoder.
        logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
        loss_d_fake = adv_loss(logits_fake, target_is_real=False, for_discriminator=True)
        logits_real = discriminator(images.contiguous().detach())[-1]
        loss_d_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True)
        discriminator_loss = (loss_d_fake + loss_d_real) * 0.5
        loss_d = adv_weight * discriminator_loss
        loss_d.backward()
        optimizer_d.step()
        optimizer_scheduler_d.step()

        epoch_loss += recons_loss.item()
        gen_epoch_loss += generator_loss.item()
        disc_epoch_loss += discriminator_loss.item()
        # Show the running means for the epoch so far.
        progress_bar.set_postfix(
            {
                "recons_loss": epoch_loss / (step + 1),
                "gen_loss": gen_epoch_loss / (step + 1),
                "disc_loss": disc_epoch_loss / (step + 1),
            }
        )

    # `step` is the index of the last batch, so step + 1 is the batch count.
    epoch_recon_loss_list.append(epoch_loss / (step + 1))
    epoch_gen_loss_list.append(gen_epoch_loss / (step + 1))
    epoch_disc_loss_list.append(disc_epoch_loss / (step + 1))

    is_validation_epoch = (epoch + 1) % val_interval == 0
    if is_validation_epoch:
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for val_step, batch in enumerate(myTestLoader, start=1):
                images = batch[0].to(device=device, non_blocking=True)
                reconstruction, quantization_loss = model(images=images)
                # Keep a few reconstructions from the first validation batch
                # for visualization purposes. They are detached and moved to
                # the CPU so the history does not pin device memory for the
                # rest of the run.
                if val_step == 1:
                    intermediary_images.append(
                        reconstruction[:n_example_images, 0].detach().cpu()
                    )
                recons_loss = F.mse_loss(reconstruction.float(), images.float())
                val_loss += recons_loss.item()
        # val_step started at 1, so after the loop it equals the batch count.
        val_loss /= val_step
        val_recon_epoch_loss_list.append(val_loss)

    # NOTE(review): the original indentation of this save was lost, so it is
    # unclear whether it ran per validation interval or once after training.
    # Saving on validation epochs AND on the final epoch covers both readings;
    # the file name has no epoch number, so same-day saves overwrite each other.
    if is_validation_epoch or epoch + 1 == vae_epoch_number:
        torch.save(model, str(date.today()) + "-SZCH-X-Rays-VQGAN"+str(depth_weight)+".pth")

total_time = time.time() - total_start
print(f"train completed, total time: {total_time}.")
# Write three figures to disk: reconstruction learning curves, adversarial
# loss curves, and one input/reconstruction pair. Each figure is opened
# explicitly and closed after saving; without that, pyplot keeps drawing on
# the same implicit figure and the second PNG would also contain the curves
# and legend entries of the first.
plt.style.use("seaborn-v0_8")

# One x position per training epoch: 1, 2, ..., vae_epoch_number.
epoch_axis = np.linspace(1, vae_epoch_number, vae_epoch_number)
# One x position per recorded validation pass: val_interval, 2 * val_interval, ...
# Derived from the number of recorded values so it stays aligned even when
# vae_epoch_number is not a multiple of val_interval.
val_epoch_axis = val_interval * np.arange(1, len(val_recon_epoch_loss_list) + 1)

# ---- Figure 1: train vs. validation reconstruction loss ----
plt.figure()
plt.title("Learning Curves", fontsize=20)
plt.plot(epoch_axis, epoch_recon_loss_list, color="C0", linewidth=2.0, label="Train")
plt.plot(
    val_epoch_axis,
    val_recon_epoch_loss_list,
    color="C1",
    linewidth=2.0,
    label="Validation",
)
plt.yticks(fontsize=12)
plt.xticks(fontsize=12)
plt.xlabel("Epochs", fontsize=16)
plt.ylabel("Loss", fontsize=16)
plt.legend(prop={"size": 14})
plt.savefig("Learning-S"+str(depth_weight)+".png")
plt.close()

# ---- Figure 2: generator vs. discriminator adversarial loss ----
plt.figure()
plt.title("Adversarial Training Curves", fontsize=20)
plt.plot(epoch_axis, epoch_gen_loss_list, color="C0", linewidth=2.0, label="Generator")
plt.plot(epoch_axis, epoch_disc_loss_list, color="C1", linewidth=2.0, label="Discriminator")
plt.yticks(fontsize=12)
plt.xticks(fontsize=12)
plt.xlabel("Epochs", fontsize=16)
plt.ylabel("Loss", fontsize=16)
plt.legend(prop={"size": 14})
plt.savefig("Adversarial-S"+str(depth_weight)+".png")
plt.close()

# ---- Figure 3: first sample of the most recently processed batch ----
# `images` and `reconstruction` are whatever the last loop iteration left
# behind (the last validation batch whenever a validation pass has run).
fig, ax = plt.subplots(nrows=1, ncols=2)
# Undo Normalize([0.5], [0.5]) and rescale to the 0-255 display range.
images = (images[0, 0] * 0.5 + 0.5) * 255
ax[0].imshow(images.detach().cpu(), vmin=0, vmax=255, cmap="gray")
ax[0].axis("off")
ax[0].title.set_text("Inputted Image")
reconstructions = (reconstruction[0, 0] * 0.5 + 0.5) * 255
ax[1].imshow(reconstructions.detach().cpu(), vmin=0, vmax=255, cmap="gray")
ax[1].axis("off")
ax[1].title.set_text("Reconstruction")
plt.savefig("reconstruction images-S"+str(depth_weight)+".png")
plt.close(fig)