| from matplotlib import pyplot as plt
|
| from config import config
|
| from dataset import myDataset, myDiTDataset
|
| from transform import myTransform, myDiTTransform
|
| from torch.utils.data import DataLoader
|
| from model import mySiTModel
|
| from BBDMScheduler import BBDMScheduler
|
| from torch.optim.lr_scheduler import MultiStepLR
|
| from tqdm import tqdm
|
| from datetime import date
|
| import torch.nn.functional as F
|
|
|
| import torch
|
| import time
|
| from monai.utils import set_determinism
|
| from torch_ema import ExponentialMovingAverage
|
|
|
# Fix the global random seed through MONAI so repeated runs are reproducible.
set_determinism(seed=42)
|
|
|
|
|
def _timestamp():
    """Return the current local wall-clock time as ``HH:MM:SS`` for log lines."""
    return time.strftime("%H:%M:%S", time.localtime())


def _batch_loss(model, autoencoder, noise_scheduler, batch, device, cond_keep_prob=0.85):
    """Compute the MSE loss of ``model`` on one ``(cxr, bs)`` image batch.

    The same computation is used for training and validation, so it lives in
    one place instead of being duplicated in both loops.

    Args:
        model: Network called as ``model(inputs, timesteps)``; only the first
            four output channels are compared with the target.
        autoencoder: Frozen model exposing ``encode_stage_2_inputs``.
        noise_scheduler: Scheduler exposing ``add_noise`` and ``sqrd_sigma``.
        batch: Indexable whose items 0 and 1 are the CXR and BS image tensors.
        device: Device the tensors are moved to / created on.
        cond_keep_prob: Per-sample probability of keeping the CXR conditioning
            latent; otherwise it is multiplied by zero.

    Returns:
        Scalar loss tensor (with graph when gradients are enabled).
    """
    cxr_img, bs_img = batch[0].to(device), batch[1].to(device)

    # The autoencoder is frozen; never build a graph through it.
    with torch.no_grad():
        cxr = autoencoder.encode_stage_2_inputs(cxr_img)
        bs = autoencoder.encode_stage_2_inputs(bs_img)

    noise = torch.randn_like(cxr)
    timesteps = torch.randint(
        0, config.num_train_timesteps, (cxr.shape[0],), device=device
    ).long()
    noisy_latents = noise_scheduler.add_noise(bs, cxr, noise, timesteps)

    # One Bernoulli draw per sample: 1 keeps the conditioning latent, 0 blanks it.
    # The mask is sampled on the CPU and then moved, as the original code did.
    keep_mask = torch.bernoulli(
        torch.full((cxr.shape[0], 1, 1, 1), cond_keep_prob)
    ).to(device)
    model_input = torch.cat((noisy_latents, cxr * keep_mask), dim=1)

    pred = model(model_input, timesteps)[:, :4]

    if config.prediction_type == "noise":
        target = noise
    else:
        # Target is (t / T) * (cxr - bs) + sqrd_sigma(t) * noise, broadcast per sample.
        scale = (timesteps / config.num_train_timesteps).view(-1, 1, 1, 1)
        sigma = noise_scheduler.sqrd_sigma(timesteps).view(-1, 1, 1, 1)
        target = scale * (cxr - bs) + sigma * noise

    return F.mse_loss(pred.float(), target.float())


def train():
    """Train ``mySiTModel`` on autoencoder latents of paired CXR/BS images.

    Builds the train/validation loaders from hard-coded file lists, trains for
    ``config.epoch_number`` epochs with AdamW and a per-step ``MultiStepLR``,
    validates every ``config.test_epoch_interval`` epochs (with EMA weights
    when ``config.ema`` is set), then saves a loss plot and the whole model.

    Fixes relative to the previous version:
      * ``ema.update()`` was called even when ``config.ema`` was false, where
        ``ema`` had never been created (``NameError``).
      * ``set_timesteps`` was pinned to ``"cuda"`` even though the rest of the
        function falls back to the CPU when CUDA is unavailable.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_file_list = "SZCH-X-Rays_trainset.txt"
    test_file_list = "SZCH-X-Rays_valset.txt"

    cxr_path = "SZCH-X-Rays-741/CXR"
    bs_path = "SZCH-X-Rays-741/BS"

    train_set = myDiTDataset(train_file_list, cxr_path, bs_path,
                             myDiTTransform['trainTransform'])
    test_set = myDataset(test_file_list, cxr_path, bs_path,
                         myTransform['testTransform'])

    train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=config.batch_size, shuffle=False)

    print("Number of batches in train set:", len(train_loader))
    print("Train set size:", len(train_set))
    print("Number of batches in test set:", len(test_loader))
    print("Test set size:", len(test_set))

    model = mySiTModel.to(device).train()
    noise_scheduler = BBDMScheduler(num_train_timesteps=config.num_train_timesteps)
    # Use the selected device rather than a hard-coded "cuda" so CPU runs work.
    noise_scheduler.set_timesteps(config.num_infer_timesteps, device=device,
                                  original_inference_steps=config.num_train_timesteps)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.initial_learning_rate, eps=1e-6)
    # config.milestones are in epochs; the LR scheduler is stepped once per
    # batch, so convert them to optimizer steps.
    milestones = [x * len(train_loader) for x in config.milestones]
    optimizer_scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
    # ``ema`` is None when EMA is disabled so every later use can be guarded.
    ema = ExponentialMovingAverage(model.parameters(), decay=0.995) if config.ema else None

    plt_train_loss_epoch = []
    plt_test_loss_epoch = []

    # NOTE(review): this unpickles a complete module. Only load checkpoints you
    # trust; recent PyTorch releases default ``weights_only=True`` and may need
    # ``weights_only=False`` here -- confirm against the installed version.
    autoencoder = torch.load("YOUR DAE MODEL PATH").to(device).eval().requires_grad_(False)

    print(_timestamp(), "----------Begin Training----------")
    for epoch in range(config.epoch_number):
        model.train()
        print(_timestamp(),
              f"Epoch:{epoch},learning rate:{optimizer.param_groups[0]['lr']}")

        epoch_train_losses = []
        # Wrap the loader itself so tqdm knows the number of batches.
        for batch in tqdm(train_loader):
            loss = _batch_loss(model, autoencoder, noise_scheduler, batch, device)

            loss.backward()
            epoch_train_losses.append(loss.item())

            optimizer.step()
            optimizer.zero_grad()
            optimizer_scheduler.step()
            if ema is not None:
                ema.update()

        train_loss_epoch = sum(epoch_train_losses) / len(train_loader)
        print(_timestamp(), f"Epoch:{epoch},train losses:{train_loss_epoch}")
        plt_train_loss_epoch.append(train_loss_epoch)

        if (epoch + 1) % config.test_epoch_interval == 0:
            model.eval()
            print(_timestamp(), "----------Stop Training----------")
            print(_timestamp(), "----------Begin Testing----------")

            epoch_test_losses = []
            with torch.no_grad():
                if ema is not None:
                    # Evaluate with the averaged weights, keeping the raw ones aside.
                    ema.store()
                    ema.copy_to()
                # NOTE(review): validation draws random timesteps and applies the
                # same conditioning dropout as training, as in the original code.
                for batch in tqdm(test_loader):
                    loss = _batch_loss(model, autoencoder, noise_scheduler, batch, device)
                    epoch_test_losses.append(loss.item())
                if ema is not None:
                    # Put the raw training weights back before continuing.
                    ema.restore()

            test_loss_epoch = sum(epoch_test_losses) / len(test_loader)
            print(_timestamp(), f"Epoch:{epoch},test losses:{test_loss_epoch}")
            plt_test_loss_epoch.append(test_loss_epoch)
            print(_timestamp(), "----------End Validation----------")
            print(_timestamp(), "----------Continue to Train----------")
    print(_timestamp(), "----------End Training Normally----------")

    # X axes are derived from the recorded values so lengths always match.
    train_epoch_list = list(range(len(plt_train_loss_epoch)))
    test_epoch_list = list(range(len(plt_test_loss_epoch)))

    _, (ax1, ax2) = plt.subplots(1, 2)
    ax1.plot(train_epoch_list, plt_train_loss_epoch, color="red")
    ax1.set_title('Train loss')
    ax2.plot(test_epoch_list, plt_test_loss_epoch, color="blue")
    ax2.set_title('Test loss')
    plt.savefig("./loss-S-ema-noise-2000-dl-8c-s1-dep1.png")
    if not config.use_server:
        plt.show()

    if ema is not None:
        # Persist the averaged weights rather than the last raw step.
        ema.copy_to()
    torch.save(
        model,
        f"dit-{config.epoch_number}-{date.today()}-S-ema-noise-2000-dl-8c-s1-dep1.pth",
    )
|
|
|
|
|
# Start training only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    train()
|
|
|