| import os
|
| import torch
|
| import torchvision
|
| from torch import nn
|
| from torchvision.utils import save_image
|
| from torch.utils.data import DataLoader
|
| from torch.optim import Adam, SGD
|
| from Diffusion.diffuser import DeformDDPM
|
| from Diffusion.networks import get_net, STN
|
| from torchvision.transforms import Lambda
|
| import Diffusion.losses as losses
|
| import random
|
| import glob
|
| import numpy as np
|
| import utils
|
|
|
| from Dataloader.dataloader0 import get_dataloader
|
|
|
| from Dataloader.dataloader_utils import thresh_img
|
| import yaml
|
| import argparse
|
|
|
|
|
| import torch.multiprocessing as mp
|
| from torch.utils.data.distributed import DistributedSampler
|
| from torch.nn.parallel import DistributedDataParallel as DDP
|
| from torch.distributed import init_process_group, destroy_process_group
|
|
|
def ddp_setup(rank, world_size):
    """
    Join the NCCL process group for this worker and select its CUDA device.

    ``MASTER_ADDR`` and ``MASTER_PORT`` fall back to ``localhost`` and
    ``12355`` only when they are not already present in the environment, so
    a launcher (or the user) can override the rendezvous endpoint, e.g. to
    avoid a port clash between two jobs on the same machine.

    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    # The rank is used directly as the CUDA device index of this process.
    torch.cuda.set_device(rank)
|
|
|
# Module-level switches; both are off by default.
use_parallel = False
use_distributed = False

# Small constant (not referenced in the visible part of this file).
EPS = 1e-5

# Command-line interface: a single optional path to the YAML configuration.
# NOTE: the arguments are parsed at import time, so every process that
# imports this module sees the same `args` namespace.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    type=str,
    default="Config/config_cmr.yaml",
    required=False,
    help="Path for the config file",
)
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
def _resume_from_latest_checkpoint(ddp_model, model_dir, rank):
    """Restore the newest ``*.pth`` file found in ``model_dir``.

    Args:
        ddp_model: The DDP-wrapped model; weights are loaded into
            ``ddp_model.module`` because checkpoints are written from
            ``ddp_model.module.state_dict()`` (keys without the ``module.``
            prefix).
        model_dir: Directory searched for checkpoints; created if missing.
        rank: CUDA device index the checkpoint tensors are mapped to.

    Returns:
        The epoch to start from: one past the epoch encoded in the first six
        characters of the newest file name, or 0 when no checkpoint exists.
    """
    os.makedirs(model_dir, exist_ok=True)
    model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
    print(model_files)
    if not model_files:
        return 0

    latest_model_file = model_files[-1]
    # Map onto this worker's GPU so every rank does not materialise the
    # checkpoint on the device it was saved from.
    state_dict = torch.load(latest_model_file, map_location=f"cuda:{rank}")
    # strict=False is kept from the original so partially matching
    # checkpoints still load.
    ddp_model.module.load_state_dict(state_dict, strict=False)
    return int(os.path.basename(latest_model_file).split('.')[0][:6]) + 1


def _augment_batch(x0, hyp_parameters, transformer):
    """Apply the random resampling, transform and intensity noise to a batch."""
    # For ndims > 2, resample with probability 0.6 (the random draw is only
    # made in that case, as in the original nested condition).
    if hyp_parameters["ndims"] > 2 and np.random.uniform(0, 1) < 0.6:
        x0 = utils.random_resample(x0, deform_scale=0)
    x0 = transformer(x0)
    if hyp_parameters['noise_scale'] > 0:
        x0 = thresh_img(x0, [0, 2*hyp_parameters['noise_scale']])
        # One random gain and one random offset are drawn per batch.
        x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
    return x0


def _run_training(gpu_id, hyp_parameters):
    """Build data, model, losses and optimizer, then run the epoch loop.

    Args:
        gpu_id: Rank of this worker, also used as its CUDA device index.
        hyp_parameters: Dictionary loaded from the YAML configuration file.
    """
    epoch_per_save = hyp_parameters['epoch_per_save']
    data_name = hyp_parameters['data_name']
    net_name = hyp_parameters['net_name']
    device = hyp_parameters["device"]

    Net = get_net(net_name)

    # Checkpoints are named "<6-digit epoch>_<data>_<net>.pth" inside
    # "Models/<data>_<net>/".
    suffix_pth = f'_{data_name}_{net_name}.pth'
    model_save_path = os.path.join('Models', f'{data_name}_{net_name}/')

    transformer = utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])
    Data_Loader = get_dataloader(data_name=hyp_parameters['data_name'], mode='train')

    dataset = Data_Loader(
        target_res=[hyp_parameters["img_size"]]*hyp_parameters["ndims"],
        transforms=None,
        noise_scale=hyp_parameters['noise_scale'],
    )
    # NOTE(review): no DistributedSampler is used, so every rank iterates the
    # whole shuffled dataset each epoch -- confirm this is intended.
    train_loader = DataLoader(
        dataset,
        batch_size=hyp_parameters['batchsize'],
        shuffle=True,
        drop_last=True,
    )

    Deformddpm = DeformDDPM(
        network=Net(n_steps=hyp_parameters["timesteps"], ndims=hyp_parameters["ndims"], num_input_chn=1),
        n_steps=hyp_parameters["timesteps"],
        image_chw=[1] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
        device=device,
        batch_size=hyp_parameters["batchsize"],
        img_pad_mode=hyp_parameters["img_pad_mode"],
        v_scale=hyp_parameters["v_scale"],
    )

    ddf_stn = STN(
        img_sz=hyp_parameters["img_size"],
        ndims=hyp_parameters["ndims"],
        padding_mode=hyp_parameters["padding_mode"],
        device=device,
    )

    Deformddpm.to(gpu_id)
    Deformddpm = DDP(Deformddpm, device_ids=[gpu_id])
    ddf_stn.to(gpu_id)

    loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=hyp_parameters["ndims"])
    loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
    loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])

    optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])

    initial_epoch = _resume_from_latest_checkpoint(Deformddpm, model_save_path, gpu_id)

    print('len_train_data: ',len(dataset))
    for epoch in range(initial_epoch,hyp_parameters["epoch"]):
        epoch_loss_tot = 0.0
        epoch_loss_gen_d = 0.0
        epoch_loss_gen_a = 0.0
        epoch_loss_reg = 0.0

        Deformddpm.train()

        for x0, _, _ in train_loader:
            # NOTE(review): inputs are moved to the configured device rather
            # than to `gpu_id`; this presumably relies on the config value
            # resolving to this worker's GPU (e.g. "cuda") -- confirm.
            x0 = x0.to(device).type(torch.float32)
            n = x0.size()[0]

            x0 = _augment_batch(x0, hyp_parameters, transformer)

            # One random diffusion step index per sample.
            t = torch.randint(0, hyp_parameters["timesteps"], (n,)).to(device)

            # Always call the DDP wrapper (not `.module`) so the distributed
            # forward/backward machinery is used.
            pre_dvf_I, dvf_I = Deformddpm(x0, t)

            loss_ddf = loss_reg(pre_dvf_I)
            trm_pred = ddf_stn(pre_dvf_I, dvf_I)
            loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None)
            loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None)

            loss_tot = 1.0 * loss_gen_d + 1.0 * loss_gen_a + 10 * loss_ddf

            optimizer.zero_grad()
            loss_tot.backward()
            optimizer.step()

            # Running per-sample averages over the dataset size.
            epoch_loss_tot += loss_tot.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_gen_d += loss_gen_d.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_gen_a += loss_gen_a.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_reg += loss_ddf.item() * len(x0) / len(train_loader.dataset)

        if gpu_id == 0:
            print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')

        # Only rank 0 writes checkpoints, which avoids concurrent writes to
        # the same file and racing directory creation.
        if 0 == epoch % epoch_per_save and gpu_id == 0:
            save_dir = model_save_path + str(epoch).rjust(6, '0') + suffix_pth
            os.makedirs(model_save_path, exist_ok=True)
            torch.save(Deformddpm.module.state_dict(), save_dir)
            print(f"saved in {save_dir}")


def main_train(rank, world_size):
    """Entry point of one DDP worker process.

    Sets up the process group, loads the YAML configuration given by
    ``args.config``, runs training and always tears the process group down
    again, even when training raises.

    Args:
        rank: Index of this worker; also used as its CUDA device index.
        world_size: Total number of worker processes.
    """
    ddp_setup(rank,world_size)
    try:
        with open(args.config, 'r') as file:
            hyp_parameters = yaml.safe_load(file)
        print(hyp_parameters)

        _run_training(rank, hyp_parameters)
    finally:
        destroy_process_group()
|
|
|
|
|
if __name__ == "__main__":
    # One worker process is spawned per visible CUDA device.
    world_size = torch.cuda.device_count()
    print(f"world size = {world_size}")
    if world_size < 1:
        # With zero devices no worker would be started and the script would
        # finish without training anything; fail loudly instead.
        raise SystemExit("No CUDA devices available: DDP training with the NCCL backend needs at least one GPU.")
    mp.spawn(main_train, args=(world_size,), nprocs=world_size)