| from matplotlib import pyplot as plt
|
| from config import config
|
| from dataset import myDataset
|
| from transform import myTransform
|
| from torch.utils.data import DataLoader
|
| from model import myUnet, myVQGANModel
|
| from diffusers import LCMScheduler,DDPMScheduler
|
| from torch.optim.lr_scheduler import MultiStepLR
|
| from tqdm import tqdm
|
| from datetime import date
|
|
|
| import torch.nn.functional as F
|
| import torch
|
| import time
|
| from monai.utils import set_determinism
|
|
|
# Seed MONAI's determinism helper once, at import time, with a fixed seed so
# repeated runs start from the same random state.
# NOTE(review): which libraries get seeded (torch / numpy / random) is decided
# by MONAI, not by this file - see the MONAI documentation.
set_determinism(seed=42)
|
|
|
|
|
def _log(message, log_file):
    """Print ``message`` prefixed with the current wall-clock time (HH:MM:SS).

    Args:
        message: Text to emit after the timestamp.
        log_file: Open text file to write to, or ``None`` to print to stdout.
    """
    print(time.strftime("%H:%M:%S", time.localtime()), message, file=log_file)


def _denoising_loss(model, vqgan, noise_scheduler, batch, device):
    """Run one noising + denoising pass on ``batch`` and return the MSE loss.

    Shared by the training and validation loops so both always compute the
    loss in exactly the same way.

    Args:
        model: Network called as ``model(noisy_latents, timesteps)``.
        vqgan: Frozen encoder exposing ``encode_stage_2_inputs``.
        noise_scheduler: Scheduler providing ``add_noise``.
        batch: Indexable batch; ``batch[0]`` is the CXR image tensor and
            ``batch[1]`` the BS image tensor (names taken from the original
            variables - TODO confirm against ``myDataset``).
        device: Device the tensors are moved to.

    Returns:
        Scalar loss tensor. It carries gradients unless the caller wrapped
        the call in ``torch.no_grad()``.
    """
    cxr_i, bs_i = batch[0].to(device), batch[1].to(device)

    # The VQGAN is only used as a fixed encoder; never build a graph through it.
    with torch.no_grad():
        cxr = vqgan.encode_stage_2_inputs(cxr_i)
        bs = vqgan.encode_stage_2_inputs(bs_i)

    # Channel-wise concatenation: BS latent first, CXR latent second.
    cat = torch.cat((bs, cxr), dim=-3)

    noise = torch.randn_like(cxr)
    if config.offset_noise:
        # One extra random value per (sample, channel), broadcast over the
        # spatial dims. It is deliberately still drawn on the CPU and then
        # moved, exactly like the original, so the random stream is unchanged.
        noise = noise + config.offset_noise_coefficient * torch.randn(
            cxr.shape[0], cxr.shape[1], 1, 1).to(device)

    # Zero "noise" for the second (CXR) half so no random noise is injected
    # into those channels.
    blank = torch.zeros_like(cxr)
    noise = torch.cat((noise, blank), dim=-3)

    timesteps = torch.randint(0, config.num_train_timesteps, (cxr.shape[0],), device=device).long()

    noisy_images = noise_scheduler.add_noise(cat, noise, timesteps)
    noise_pred = model(noisy_images, timesteps)

    # Only the first 4 channels are scored.
    # NOTE(review): this assumes the VQGAN latent has exactly 4 channels, so
    # that ``:4`` selects the whole noised (BS) half - TODO confirm.
    return F.mse_loss(noise_pred[:, :4].float(), noise[:, :4].float())


def _run_training(log_file):
    """Build the data pipeline, train the U-Net, validate periodically,
    plot the loss curves and save the trained model.

    Args:
        log_file: Open text file receiving the progress messages, or ``None``
            to print them to stdout.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_file_list = "JSRT_trainset.txt"
    test_file_list = "JSRT_valset.txt"

    cxr_path = "/mntcephfs/med_dataset/SYF/JSRTnew1024-241/CXR"
    bs_path = "/mntcephfs/med_dataset/SYF/JSRTnew1024-241/BS"
    masked_cxr_path = "/mntcephfs/med_dataset/SYF/JSRTnew1024-241/Masked_CXR"
    masked_bs_path = "/mntcephfs/med_dataset/SYF/JSRTnew1024-241/Masked_BS"

    # Each split is the plain pairs plus the masked pairs, joined with ``+``
    # (presumably torch's Dataset concatenation - depends on ``myDataset``).
    myTrainSet = (
        myDataset(train_file_list, cxr_path, bs_path, myTransform['trainTransform'])
        + myDataset(train_file_list, masked_cxr_path, masked_bs_path, myTransform['trainTransform'])
    )
    myTestSet = (
        myDataset(test_file_list, cxr_path, bs_path, myTransform['testTransform'])
        + myDataset(test_file_list, masked_cxr_path, masked_bs_path, myTransform['testTransform'])
    )

    myTrainLoader = DataLoader(myTrainSet, batch_size=config.batch_size, shuffle=True)
    myTestLoader = DataLoader(myTestSet, batch_size=config.batch_size, shuffle=True)

    print("Number of batches in train set:", len(myTrainLoader))
    print("Train set size:", len(myTrainSet))
    print("Number of batches in test set:", len(myTestLoader))
    print("Test set size:", len(myTestSet))

    model = myUnet.to(device).train()

    noise_scheduler = LCMScheduler(num_train_timesteps=config.num_train_timesteps)
    noise_scheduler.set_timesteps(config.num_infer_timesteps)

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.initial_learning_rate, eps=1e-6)
    # The LR scheduler is stepped once per batch (see the loop below), so the
    # epoch-based milestones from the config are converted to iteration counts.
    milestones = [x * len(myTrainLoader) for x in config.milestones]
    optimizer_scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    train_losses = []
    test_losses = []
    plt_train_loss_epoch = []
    plt_test_loss_epoch = []
    train_epoch_list = list(range(0, config.epoch_number))
    test_epoch_list = list(range(0, int(config.epoch_number / config.test_epoch_interval)))

    # SECURITY NOTE: ``torch.load`` on a fully pickled model executes arbitrary
    # code from the checkpoint; only load files from a trusted source. Recent
    # PyTorch releases changed the default of ``weights_only`` - check the
    # installed version if this call starts rejecting the checkpoint.
    VQGAN = torch.load("2025-02-04-Mask-JSRT-VQGAN.pth").to(device).eval()

    _log("----------Begin Training----------", log_file)
    for epoch in range(config.epoch_number):
        model.train()
        _log(f"Epoch:{epoch},learning rate:{optimizer.param_groups[0]['lr']}", log_file)

        for batch in tqdm(myTrainLoader):
            loss = _denoising_loss(model, VQGAN, noise_scheduler, batch, device)

            loss.backward()
            train_losses.append(loss.item())

            optimizer.step()
            optimizer.zero_grad()
            optimizer_scheduler.step()

        # Mean over the batches of this epoch only (the tail of the running list).
        train_loss_epoch = sum(train_losses[-len(myTrainLoader):]) / len(myTrainLoader)
        _log(f"Epoch:{epoch},train losses:{train_loss_epoch}", log_file)
        plt_train_loss_epoch.append(train_loss_epoch)

        if (epoch + 1) % config.test_epoch_interval == 0:
            model.eval()
            _log("----------Stop Training----------", log_file)
            _log("----------Begin Testing----------", log_file)

            with torch.no_grad():
                for batch in tqdm(myTestLoader):
                    loss = _denoising_loss(model, VQGAN, noise_scheduler, batch, device)
                    test_losses.append(loss.item())

            test_loss_epoch = sum(test_losses[-len(myTestLoader):]) / len(myTestLoader)
            _log(f"Epoch:{epoch},test losses:{test_loss_epoch}", log_file)
            plt_test_loss_epoch.append(test_loss_epoch)
            _log("----------End Validation----------", log_file)
            _log("----------Continue to Train----------", log_file)

    _log("----------End Training Normally----------", log_file)

    _, (ax1, ax2) = plt.subplots(1, 2)
    ax1.plot(train_epoch_list, plt_train_loss_epoch, color="red")
    ax1.set_title('Train loss')
    # The x axis of the test curve is the index of the validation run, not the
    # epoch number at which it happened.
    ax2.plot(test_epoch_list, plt_test_loss_epoch, color="blue")
    ax2.set_title('Test loss')
    plt.savefig("./loss.png")
    if not config.use_server:
        plt.show()

    torch.save(model, "masked_lcm-600JSRT-" + str(date.today()) + "-myModel.pth")


def train():
    """Train the diffusion U-Net on VQGAN latents and save the result.

    When ``config.use_server`` is true, progress messages go to ``log.txt``
    (overwritten on every run); otherwise they are printed to stdout. The log
    file is always closed - and therefore flushed - even if training raises.

    Side effects: writes ``./loss.png`` and a dated ``*-myModel.pth`` file in
    the working directory, and opens a plot window when not on the server.
    """
    log_file = open('log.txt', 'w') if config.use_server else None
    try:
        _run_training(log_file)
    finally:
        if log_file is not None:
            log_file.close()
|
|
|
|
|
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train()
|
|
|