"""Training script for the DeformDDPM model, optionally with DistributedDataParallel.

Hyper-parameters are read from a YAML config (``--config``). Checkpoints are written to
``Models/<data_name>_<net_name>/`` every ``epoch_per_save`` epochs, and training resumes
from the lexicographically last ``*.pth`` file found in that directory.
"""
import os
import gc
import torch
import torchvision
from torch import nn
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD
from Diffusion.diffuser import DeformDDPM
from Diffusion.networks import get_net, STN
from torchvision.transforms import Lambda
import Diffusion.losses as losses
import random
import glob
import numpy as np
import utils
from Dataloader.dataloader0 import get_dataloader
from Dataloader.dataLoader import *
from Dataloader.dataloader_utils import thresh_img
import yaml
import argparse

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist


def ddp_setup(rank, world_size):
    """Join the NCCL process group and bind this process to CUDA device ``rank``.

    Args:
        rank: Unique identifier of each process.
        world_size: Total number of processes.
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


# Set to False to train in a single process without DDP / NCCL.
use_distributed = True

EPS = 1e-5

parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    help="Path for the config file",
    type=str,
    # Other configs previously used here: Config/config_cmr.yaml, Config/config_lct.yaml
    default="Config/config_all.yaml",
    required=False,
)
args = parser.parse_args()


# =======================================================================================
def main_train(rank=0, world_size=1):
    """Train DeformDDPM in this process (one process per GPU when ``use_distributed``).

    Args:
        rank: Process index; under DDP it is also the CUDA device index.
        world_size: Total number of training processes.
    """
    if use_distributed:
        ddp_setup(rank, world_size)
    gpu_id = rank

    # Load the YAML file into a dictionary.
    with open(args.config, 'r') as file:
        hyp_parameters = yaml.safe_load(file)
    if gpu_id == 0:
        print(hyp_parameters)

    # NOTE(review): data tensors are moved to this configured device. Under DDP this is
    # only correct if it resolves to the current rank's GPU (e.g. plain "cuda").
    device = hyp_parameters["device"]
    ndims = hyp_parameters["ndims"]
    img_size = hyp_parameters["img_size"]
    batch_size = hyp_parameters["batchsize"]
    noise_scale = hyp_parameters["noise_scale"]

    epoch_per_save = hyp_parameters['epoch_per_save']
    data_name = hyp_parameters['data_name']
    net_name = hyp_parameters['net_name']
    Net = get_net(net_name)
    suffix_pth = f'_{data_name}_{net_name}.pth'
    model_save_path = os.path.join('Models', f'{data_name}_{net_name}/')
    model_dir = model_save_path

    transformer = utils.get_transformer(img_sz=ndims * [img_size])

    dataset = OminiDataset_v1(transform=None)
    # Under DDP each rank must train on its own shard. Without a DistributedSampler every
    # rank would iterate the full dataset, doing world_size-times redundant work per epoch.
    sampler = (
        DistributedSampler(dataset, shuffle=True, drop_last=True)
        if use_distributed
        else None
    )
    train_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=sampler is None,  # DataLoader forbids shuffle=True together with a sampler
        sampler=sampler,
        drop_last=True,
    )

    Deformddpm = DeformDDPM(
        network=Net(
            n_steps=hyp_parameters["timesteps"],
            ndims=ndims,
            num_input_chn=hyp_parameters["num_input_chn"],
            res=img_size,
        ),
        n_steps=hyp_parameters["timesteps"],
        image_chw=[1] + [img_size] * ndims,
        device=hyp_parameters["device"],
        batch_size=batch_size,
        img_pad_mode=hyp_parameters["img_pad_mode"],
        v_scale=hyp_parameters["v_scale"],
    )

    ddf_stn = STN(
        img_sz=img_size,
        ndims=ndims,
        padding_mode=hyp_parameters["padding_mode"],
        device=hyp_parameters["device"],
    )

    if use_distributed:
        Deformddpm.to(rank)
        Deformddpm = DDP(Deformddpm, device_ids=[rank])
        ddf_stn.to(rank)
    else:
        Deformddpm.to(hyp_parameters["device"])
        ddf_stn.to(hyp_parameters["device"])

    loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=ndims)
    loss_dist = losses.MRSE(img_sz=img_size)
    loss_ang = losses.NCC(img_sz=img_size)

    optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])

    # Resume from the latest checkpoint, if any. File names are zero-padded epoch
    # numbers, so lexicographic order equals epoch order.
    os.makedirs(model_dir, exist_ok=True)
    model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
    if model_files:
        if gpu_id == 0:
            print(model_files)
        initial_epoch, Deformddpm, optimizer = ddp_load_dict(
            gpu_id, Deformddpm, optimizer, model_files[-1],
            use_distributed=use_distributed,
        )
    else:
        initial_epoch = 0

    if gpu_id == 0:
        print('len_train_data: ', len(dataset))

    # With drop_last=True every batch is full, so this is the number of samples this
    # process sees per epoch. It is used to report per-sample mean losses.
    samples_per_epoch = len(train_loader) * batch_size

    for epoch in range(initial_epoch, hyp_parameters["epoch"]):
        if sampler is not None:
            sampler.set_epoch(epoch)  # give each epoch a different shuffle of the shards
        epoch_loss_tot = 0.0
        epoch_loss_gen_d = 0.0
        epoch_loss_gen_a = 0.0
        epoch_loss_reg = 0.0

        Deformddpm.train()
        for step, batch in enumerate(train_loader):
            # The batch itself is used as the image tensor (no label is unpacked).
            x0 = batch.to(device).type(torch.float32)
            n = x0.size(0)

            blind_mask = utils.get_random_deformed_mask(
                x0.shape[2:], apply_possibility=0.6
            ).to(device)

            # Random deformation + rotation.
            if ndims > 2 and np.random.uniform(0, 1) < 0.6:
                x0 = utils.random_resample(x0, deform_scale=0)
            # NOTE(review): confirm whether `transformer` should apply to every batch
            # (as here) or only when ndims > 2.
            x0 = transformer(x0)
            if noise_scale > 0:
                x0 = thresh_img(x0, [0, 2 * noise_scale])
                # Random global intensity scale and offset.
                x0 = x0 * np.random.normal(1, noise_scale) + np.random.normal(0, noise_scale)

            # Draw a random diffusion timestep in [0, timesteps) for each image in the batch.
            t = torch.randint(0, hyp_parameters["timesteps"], (n,)).to(device)
            pre_dvf_I, dvf_I = Deformddpm(img_org=x0, t=t, mask=blind_mask)

            loss_ddf = loss_reg(pre_dvf_I)
            trm_pred = ddf_stn(pre_dvf_I, dvf_I)
            loss_gen_d = loss_dist(pred=trm_pred, inv_lab=dvf_I, ddf_stn=None, mask=blind_mask)
            loss_gen_a = loss_ang(pred=trm_pred, inv_lab=dvf_I, ddf_stn=None, mask=blind_mask)
            loss_tot = loss_gen_d + loss_gen_a + loss_ddf

            optimizer.zero_grad()
            loss_tot.backward()
            optimizer.step()

            weight = len(x0) / samples_per_epoch
            epoch_loss_tot += loss_tot.item() * weight
            epoch_loss_gen_d += loss_gen_d.item() * weight
            epoch_loss_gen_a += loss_gen_a.item() * weight
            epoch_loss_reg += loss_ddf.item() * weight

        if gpu_id == 0:
            print(epoch, ':', epoch_loss_tot, '=', epoch_loss_gen_a, '+', epoch_loss_gen_d,
                  '+', epoch_loss_reg, ' (ang+dist+regul)')

        # Only one process writes checkpoints.
        if epoch % epoch_per_save == 0 and (not use_distributed or gpu_id == 0):
            save_dir = model_save_path + str(epoch).rjust(6, '0') + suffix_pth
            os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
            print(f"saved in {save_dir}")
            # Save the unwrapped module so the same file loads with or without DDP.
            model = Deformddpm.module if use_distributed else Deformddpm
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
            }, save_dir)

    if use_distributed:
        dist.destroy_process_group()


def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file, use_distributed=True):
    """Restore model and optimizer state from a checkpoint saved by ``main_train``.

    Every process loads the checkpoint itself, mapped onto its own GPU under DDP. This
    keeps both the weights and the optimizer state (e.g. Adam moments) identical on all
    ranks. Broadcasting only the weights from rank 0 would leave the other ranks with a
    fresh optimizer state, and the replicas would drift apart.

    Args:
        gpu_id: Rank / CUDA device index of this process.
        Deformddpm: Model to restore; a DDP wrapper when ``use_distributed`` is True.
        optimizer: Optimizer whose state is restored in place.
        model_file: Path to a checkpoint dict with ``model_state_dict`` and
            ``optimizer_state_dict`` entries.
        use_distributed: Whether ``Deformddpm`` is DDP-wrapped and a process group exists.

    Returns:
        ``(initial_epoch, Deformddpm, optimizer)``. ``initial_epoch`` is one past the
        epoch encoded in the first six characters of the checkpoint's file name.
    """
    if gpu_id == 0:
        utils.print_memory_usage("Before Loading Model")
    gc.collect()
    torch.cuda.empty_cache()

    # Map onto this rank's GPU. Otherwise every rank would materialise the tensors on
    # the device they were saved from (cuda:0).
    map_location = f"cuda:{gpu_id}" if use_distributed else None
    checkpoint = torch.load(model_file, map_location=map_location)
    model = Deformddpm.module if use_distributed else Deformddpm
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    del checkpoint

    if gpu_id == 0:
        utils.print_memory_usage("After Loading Checkpoint on GPU")

    if use_distributed:
        # Make sure no rank starts training before every rank has finished loading.
        dist.barrier()

    # Get the epoch number from the file name and add 1 to resume after it.
    initial_epoch = int(os.path.basename(model_file).split('.')[0][:6]) + 1
    return initial_epoch, Deformddpm, optimizer


if __name__ == "__main__":
    if use_distributed:
        world_size = torch.cuda.device_count()
        print(f"Distributed GPU number = {world_size}")
        if world_size < 1:
            raise RuntimeError("use_distributed=True requires at least one visible CUDA device")
        mp.spawn(main_train, args=(world_size,), nprocs=world_size)
    else:
        main_train(0, 1)