# File size: 10,295 Bytes
# 6107278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
import time
from generative.networks.nets import VQVAE
import matplotlib.pyplot as plt
import torch
from monai.config import print_config
from torch.utils.data import DataLoader
from monai.utils import set_determinism
from tqdm import tqdm
from generative.losses import PatchAdversarialLoss, PerceptualLoss
from generative.networks.nets import PatchDiscriminator
from datetime import date
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
from torchvision import transforms
import cv2 as cv
import numpy as np
from torch.utils.data import Dataset
import pandas as pd
import os
from depth_loss import depth_loss

# Print the MONAI configuration report, then seed everything with 42.
# Order matters: seeding must precede any model construction below.
print_config()
set_determinism(42)

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---- Training hyper-parameters -------------------------------------------
image_size = 1024        # images are resized to image_size x image_size
vae_batch_size = 4       # batch size for both the train and the test loader
n_example_images = 4     # reconstructions kept per validation pass
vae_epoch_number = 200   # total number of training epochs
val_interval = 10        # validate (and checkpoint) every this many epochs

# ---- Data locations ------------------------------------------------------
# Tab-separated list files naming the images of each split.
train_file_list = "SZCH-X-Rays_trainset.txt"
test_file_list = "SZCH-X-Rays_valset.txt"

# Image folders; both are read using the same list files.
cxr_path = "SZCH-X-Rays-741/CXR"
bs_path = "SZCH-X-Rays-741/BS"

# ---- Generator: 2-D, single-channel VQ-VAE with three resampling stages ---
# Every stage uses the same parameter tuple, so the tuples are repeated.
# NOTE(review): tuple fields presumably follow MONAI Generative's
# (stride, kernel_size, dilation, padding[, output_padding]) — confirm.
myVQGANModel = VQVAE(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(128, 256, 512),
    num_res_channels=512,
    num_res_layers=2,
    downsample_parameters=((2, 4, 1, 1),) * 3,
    upsample_parameters=((2, 4, 1, 1, 0),) * 3,
    num_embeddings=1024,
    embedding_dim=4,
)



class myTransformMethod():
    """Callable pre-processing step: resize an image, then reduce colour to grayscale.

    Implementing ``__call__`` lets an instance be used like a function, e.g. as
    one stage of a ``transforms.Compose`` pipeline.
    """

    def __call__(self, img):
        """Resize ``img`` to ``image_size`` x ``image_size`` and return it as grayscale.

        Args:
            img: Image array accepted by ``cv.resize``; presumably the HWC BGR
                array returned by ``cv.imread`` — confirm against callers.

        Returns:
            The resized image. A three-channel (HWC) input is converted with
            ``cv.COLOR_BGR2GRAY``; any other input is returned resized but
            otherwise unchanged.
        """
        img = cv.resize(img, (image_size, image_size))
        # Only a 3-D array has a channel axis. Checking the rank first keeps a
        # 2-D grayscale image whose width happens to be 3 (image_size == 3)
        # from being mistaken for a BGR image and sent to cvtColor.
        if img.ndim == 3 and img.shape[-1] == 3:
            # OpenCV loads colour images in BGR order, hence BGR2GRAY.
            img = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
        return img


# Named pre-processing pipelines.
# 'Transform1': resize + grayscale (myTransformMethod) -> tensor ->
# normalisation with mean 0.5 and std 0.5 on the single channel.
myTransform = dict(
    Transform1=transforms.Compose(
        [
            myTransformMethod(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ]
    ),
)


class mySingleDataset(Dataset):
    """Dataset of images from one directory, selected by a file-name list.

    Each item is the pair ``(image, file_name)``.

    Args:
        filelist: Path to a tab-separated text file without a header row; the
            first column of every row is an image file name.
        img_dir: Directory the file names are resolved against.
        transform: Optional callable applied to the loaded image.
    """

    def __init__(self, filelist, img_dir, transform=None):
        self.img_dir = img_dir  # directory containing the images
        self.transform = transform  # optional pre-processing callable
        # One row per image; header=None keeps the first line as data.
        self.filelist = pd.read_csv(filelist, sep="\t", header=None)

    def __len__(self):
        """Return the number of file names listed."""
        return len(self.filelist)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and return ``(image, file_name)``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        file = self.filelist.iloc[idx, 0]
        full_path = os.path.join(self.img_dir, file)
        image = cv.imread(full_path)

        # cv.imread signals failure by returning None instead of raising; fail
        # here with the offending path rather than later inside the transform.
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {full_path}")

        if self.transform:
            image = self.transform(image)
        return image, file


# ---- Datasets --------------------------------------------------------------
# Each split is the concatenation (Dataset `+`) of the CXR folder and the BS
# folder, both indexed by the same list file and pre-processed identically.
myTrainSet = (
    mySingleDataset(train_file_list, cxr_path, myTransform['Transform1'])
    + mySingleDataset(train_file_list, bs_path, myTransform['Transform1'])
)
myTestSet = (
    mySingleDataset(test_file_list, cxr_path, myTransform['Transform1'])
    + mySingleDataset(test_file_list, bs_path, myTransform['Transform1'])
)

# Only the training loader shuffles.
myTrainLoader = DataLoader(myTrainSet, batch_size=vae_batch_size, shuffle=True)
myTestLoader = DataLoader(myTestSet, batch_size=vae_batch_size, shuffle=False)

# Report split sizes (batch count first, then sample count).
print("Number of batches in train set:", len(myTrainLoader))
print("Train set size:", len(myTrainSet))
print("Number of batches in test set:", len(myTestLoader))
print("Test set size:", len(myTestSet))

# ---- Device ----------------------------------------------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {device}")

# ---- Networks and losses (construction order kept: it consumes seeded RNG) --
model = myVQGANModel.to(device)

discriminator = PatchDiscriminator(
    spatial_dims=2,
    in_channels=1,
    num_layers_d=3,
    num_channels=64,
)
discriminator = discriminator.to(device)

perceptual_loss = PerceptualLoss(spatial_dims=2, network_type="vgg")
perceptual_loss = perceptual_loss.to(device)

# ---- Optimisers: the discriminator uses a 5x larger learning rate ----------
optimizer_g = torch.optim.Adam(model.parameters(), lr=1e-4)
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=5e-4)

# Both schedulers halve the learning rate at a single milestone placed at
# 200 * (batches per epoch).
optimizer_scheduler_g = MultiStepLR(
    optimizer_g,
    milestones=[200 * len(myTrainLoader)],
    gamma=0.5,
)
optimizer_scheduler_d = MultiStepLR(
    optimizer_d,
    milestones=[200 * len(myTrainLoader)],
    gamma=0.5,
)

# ---- Loss terms and their weights in the generator objective ---------------
adv_loss = PatchAdversarialLoss(criterion="least_squares")
adv_weight = 0.01
perceptual_weight = 0.001
# msssim_weight = 1  (MS-SSIM term is currently disabled)
depth_weight = 1

# ---- History buffers filled during training --------------------------------
epoch_recon_loss_list, epoch_gen_loss_list, epoch_disc_loss_list = [], [], []
val_recon_epoch_loss_list = []
intermediary_images = []

# Train the VQ-VAE generator against the patch discriminator for
# `vae_epoch_number` epochs, validating and checkpointing every
# `val_interval` epochs, and record per-epoch loss histories.
total_start = time.time()  # wall-clock start of the whole run
for epoch in range(vae_epoch_number):
    # Both networks go to training mode; the running sums restart every epoch.
    model.train()
    discriminator.train()
    epoch_loss = 0
    gen_epoch_loss = 0
    disc_epoch_loss = 0
    progress_bar = tqdm(enumerate(myTrainLoader), total=len(myTrainLoader), ncols=110)
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in progress_bar:
        # batch[0] is the image tensor; batch[1] (the file names) is unused here.
        images = batch[0].to(device=device, non_blocking=True)

        optimizer_g.zero_grad(set_to_none=True)

        # Generator part
        reconstruction, quantization_loss = model(images=images)
        # Only the last element of the discriminator output is used as logits
        # (presumably earlier elements are intermediate features — confirm
        # against PatchDiscriminator).
        logits_fake = discriminator(reconstruction.contiguous().float())[-1]

        recons_loss = F.mse_loss(reconstruction.float(), images.float())
        p_loss = perceptual_loss(reconstruction.float(), images.float())
        d_loss = depth_loss(reconstruction.float(), images.float())
        # The generator is rewarded when the discriminator labels its output "real".
        generator_loss = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)
        # msssim = pytorch_msssim.MSSSIM(window_size=11, size_average=True, channel=1, normalize='relu')
        # msssim_loss = 1 - msssim(reconstruction.float(), images.float())
        # Total generator objective: MSE and quantization terms are unweighted,
        # the perceptual, adversarial and depth terms use their configured weights.
        loss_g = recons_loss + quantization_loss + perceptual_weight * p_loss + adv_weight * generator_loss + depth_weight * d_loss

        loss_g.backward()
        optimizer_g.step()
        # The scheduler is advanced once per batch, not once per epoch.
        optimizer_scheduler_g.step()

        # Discriminator part
        optimizer_d.zero_grad(set_to_none=True)

        # The reconstruction is detached so this loss does not back-propagate
        # into the generator.
        logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
        loss_d_fake = adv_loss(logits_fake, target_is_real=False, for_discriminator=True)
        logits_real = discriminator(images.contiguous().detach())[-1]
        loss_d_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True)
        # Average of the "fake" and "real" halves of the discriminator loss.
        discriminator_loss = (loss_d_fake + loss_d_real) * 0.5

        # The discriminator objective is scaled by the same adversarial weight.
        loss_d = adv_weight * discriminator_loss

        loss_d.backward()
        optimizer_d.step()
        optimizer_scheduler_d.step()

        # Running sums use the unweighted loss values.
        epoch_loss += recons_loss.item()
        gen_epoch_loss += generator_loss.item()
        disc_epoch_loss += discriminator_loss.item()

        # Show the running means over the batches seen so far in this epoch.
        progress_bar.set_postfix(
            {
                "recons_loss": epoch_loss / (step + 1),
                "gen_loss": gen_epoch_loss / (step + 1),
                "disc_loss": disc_epoch_loss / (step + 1),
            }
        )
    # `step` is the index of the last batch, so step + 1 is the batch count.
    epoch_recon_loss_list.append(epoch_loss / (step + 1))
    epoch_gen_loss_list.append(gen_epoch_loss / (step + 1))
    epoch_disc_loss_list.append(disc_epoch_loss / (step + 1))

    # Validation pass every `val_interval` epochs (reconstruction MSE only).
    if (epoch + 1) % val_interval == 0:
        model.eval()
        val_loss = 0
        with torch.no_grad():
            # Counting from 1 makes the final val_step equal the number of batches.
            for val_step, batch in enumerate(myTestLoader, start=1):
                images = batch[0].to(device=device, non_blocking=True)

                reconstruction, quantization_loss = model(images=images)

                # From the first validation batch only, keep channel 0 of up to
                # `n_example_images` reconstructions for later visualization.
                if val_step == 1:
                    intermediary_images.append(reconstruction[:n_example_images, 0])

                recons_loss = F.mse_loss(reconstruction.float(), images.float())

                val_loss += recons_loss.item()

        # Mean validation MSE over all validation batches.
        val_loss /= val_step
        val_recon_epoch_loss_list.append(val_loss)

        # Saves the whole model object. The file name holds only today's date
        # and depth_weight, so checkpoints written on the same day overwrite
        # each other.
        torch.save(model, str(date.today()) + "-SZCH-X-Rays-VQGAN"+str(depth_weight)+".pth")

# Elapsed wall-clock time, in seconds, for the whole run.
total_time = time.time() - total_start
print(f"train completed, total time: {total_time}.")

# Save three PNG figures summarising the run: reconstruction learning curves,
# adversarial loss curves, and one input/reconstruction pair.
plt.style.use("seaborn-v0_8")

# X positions: one point per training epoch, and one point per validation pass.
# Validation runs after epochs val_interval, 2 * val_interval, ...; np.arange
# reproduces exactly those positions even when vae_epoch_number is not a
# multiple of val_interval.
train_epochs = np.arange(1, vae_epoch_number + 1)
val_epochs = np.arange(val_interval, vae_epoch_number + 1, val_interval)

# --- Figure 1: reconstruction loss, train vs. validation --------------------
plt.figure()
plt.title("Learning Curves", fontsize=20)
plt.plot(train_epochs, epoch_recon_loss_list, color="C0", linewidth=2.0, label="Train")
plt.plot(val_epochs, val_recon_epoch_loss_list, color="C1", linewidth=2.0, label="Validation")
plt.yticks(fontsize=12)
plt.xticks(fontsize=12)
plt.xlabel("Epochs", fontsize=16)
plt.ylabel("Loss", fontsize=16)
plt.legend(prop={"size": 14})
plt.savefig("Learning-S" + str(depth_weight) + ".png")
plt.close()

# --- Figure 2: adversarial losses --------------------------------------------
# A fresh figure is required here: drawing on the current one would overlay
# these curves on the learning curves saved above.
plt.figure()
plt.title("Adversarial Training Curves", fontsize=20)
plt.plot(train_epochs, epoch_gen_loss_list, color="C0", linewidth=2.0, label="Generator")
plt.plot(train_epochs, epoch_disc_loss_list, color="C1", linewidth=2.0, label="Discriminator")
plt.yticks(fontsize=12)
plt.xticks(fontsize=12)
plt.xlabel("Epochs", fontsize=16)
plt.ylabel("Loss", fontsize=16)
plt.legend(prop={"size": 14})
plt.savefig("Adversarial-S" + str(depth_weight) + ".png")
plt.close()

# --- Figure 3: one input next to its reconstruction --------------------------
# `images` and `reconstruction` hold whatever batch was processed last.
# x * 0.5 + 0.5 undoes the Normalize([0.5], [0.5]) of the input transform;
# the result is then scaled to the 0-255 range used for display.
fig, ax = plt.subplots(nrows=1, ncols=2)
images = (images[0, 0] * 0.5 + 0.5) * 255
ax[0].imshow(images.detach().cpu(), vmin=0, vmax=255, cmap="gray")
ax[0].axis("off")
ax[0].title.set_text("Inputted Image")
reconstructions = (reconstruction[0, 0] * 0.5 + 0.5) * 255
ax[1].imshow(reconstructions.detach().cpu(), vmin=0, vmax=255, cmap="gray")
ax[1].axis("off")
ax[1].title.set_text("Reconstruction")
plt.savefig("reconstruction images-S" + str(depth_weight) + ".png")