| """
|
| Train a diffusion model on images with proper distributed training support.
|
| """
|
| import sys
|
| import os
|
| import argparse
|
| import numpy as np
|
| import socket
|
| from datetime import datetime
|
|
|
| sys.path.append("..")
|
| sys.path.append(".")
|
| from guided_diffusion import dist_util, logger |
| from guided_diffusion.resample import create_named_schedule_sampler |
| from guided_diffusion.custom_lidc_dataset import CustomLIDCDataset |
| from guided_diffusion.script_util import ( |
| model_and_diffusion_defaults, |
| create_model_and_diffusion, |
| args_to_dict,
|
| add_dict_to_argparser,
|
| )
|
| from scripts.metrics import model_size
|
| import torch as th
|
| from torch import nn
|
| import torch.distributed as dist
|
| import torch.multiprocessing as mp
|
| from guided_diffusion.train_util import TrainLoop
|
|
|
|
|
def expanduservars(path: str) -> str:
    """Expand environment variables, then a leading ``~``, in *path*.

    ``$VAR`` / ``${VAR}`` references are substituted first, so a variable
    whose value starts with ``~`` is also expanded to the home directory.
    """
    with_vars = os.path.expandvars(path)
    return os.path.expanduser(with_vars)
|
|
|
|
|
def find_free_port():
    """Ask the OS for an unused TCP port and return its number.

    The probe socket is closed before returning, so the port is only
    *likely* to still be free; use it promptly.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with sock:
        # Binding to port 0 lets the kernel choose any free ephemeral port.
        sock.bind(('', 0))
        sock.listen(1)
        _, chosen_port = sock.getsockname()
    return chosen_port
|
|
|
|
|
def setup_distributed(rank, world_size, port):
    """Initialise the default process group for a single-node run.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of participating processes.
        port: TCP port on ``localhost`` used for the rendezvous.

    Uses the NCCL backend when CUDA is available, otherwise Gloo.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(port)

    use_cuda = th.cuda.is_available()
    if use_cuda:
        # Bind this process to its own GPU *before* creating the NCCL group,
        # so communicators created during/after init use this device rather
        # than every rank defaulting to cuda:0.
        th.cuda.set_device(rank)

    backend = 'nccl' if use_cuda else 'gloo'
    dist.init_process_group(backend, rank=rank, world_size=world_size)
|
|
|
|
|
def cleanup_distributed():
    """Tear down the default process group if one is active.

    Does nothing when no group has been initialised (e.g. setup never ran or
    cleanup already happened), instead of letting
    ``destroy_process_group`` raise.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
|
|
|
|
def get_device(rank):
    """Return ``cuda:<rank>`` when CUDA is available, otherwise the CPU device."""
    device_name = f'cuda:{rank}' if th.cuda.is_available() else 'cpu'
    return th.device(device_name)
|
|
|
|
|
def worker_init_fn(worker_id):
    """DataLoader worker initialiser.

    Seeds NumPy's global RNG from the current torch initial seed, reduced
    modulo 2**32 (the range ``np.random.seed`` accepts). It then resets
    torch's default dtype to float32 inside the worker. ``worker_id`` is
    not used.
    """
    numpy_seed = th.initial_seed() % (2 ** 32)
    np.random.seed(numpy_seed)
    th.set_default_dtype(th.float32)
|
|
|
|
|
def _configure_logging(rank, world_size, args):
    """Configure the logger and record the run configuration (rank 0 only).

    With ``args.use_mose_dataset`` the log directory is
    ``./results/<dataset_type>/train_<dataset_type>_<YYYYmmdd_HHMMSS>``.
    Otherwise ``logger.configure()`` is called with no arguments.
    """
    if rank != 0:
        return
    if args.use_mose_dataset:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        run_name = f"train_{args.dataset_type}_{timestamp}"
        out_dir = f"./results/{args.dataset_type}/{run_name}"
        logger.configure(dir=out_dir)
        logger.log(f"Log dir: {out_dir}")
    else:
        logger.configure()

    logger.log(f"Arguments: {args.__dict__}")
    logger.log(f"World size: {world_size}, Rank: {rank}")


def _load_train_dataset(args):
    """Build the training dataset selected by ``use_mose_dataset`` / ``dataset_type``.

    Raises:
        ValueError: if ``use_mose_dataset`` is set and ``dataset_type`` is
            neither ``"lidc"`` nor ``"msmri"``.
    """
    if not args.use_mose_dataset:
        if args.dataset_type in {"lidc", "multiannotator"}:
            ds = CustomLIDCDataset(
                data_root=args.data_dir,
                split="train",
                image_size=args.image_size,
                dataset_type=args.dataset_type,
                split_strategy=args.split_strategy,
            )
            logger.log(
                f"Using {args.dataset_type} dataset from: {args.data_dir} "
                f"(strategy={args.split_strategy})"
            )
        else:
            # Any other dataset_type falls back to the legacy loader.
            from guided_diffusion.lidcloader import LIDCDataset
            ds = LIDCDataset(args.data_dir, test_flag=False)
            logger.log(f"Using converted legacy dataset from: {args.data_dir}")
        return ds

    if args.dataset_type == "lidc":
        from guided_diffusion.lidcloader_mose import lidc_Dataloader
        ds = lidc_Dataloader(
            data_folder=args.lidc_data_folder,
            transform_train=None,
            transform_test=None
        ).train_ds
        logger.log(f"Using LIDC NPY dataset from: {args.lidc_data_folder}")
    elif args.dataset_type == "msmri":
        from guided_diffusion.msmri_dataset_mose import msmri_Dataloader
        ds = msmri_Dataloader(
            data_folder=args.msmri_data_folder,
            transform_train=None,
            transform_test=None
        ).train_ds
        logger.log(f"Using MSMRI NPY dataset from: {args.msmri_data_folder}")
    else:
        raise ValueError(f"Unknown dataset type: {args.dataset_type}")
    return ds


def _make_train_loader(ds, rank, world_size, args):
    """Wrap *ds* in a DataLoader and return ``(loader, per_rank_batch_size)``.

    When ``world_size > 1``, the global ``args.batch_size`` is split evenly
    across ranks and a ``DistributedSampler`` shards the data. Otherwise the
    loader shuffles the whole dataset itself.

    Raises:
        ValueError: if ``args.batch_size`` is smaller than ``world_size``,
            which would leave each rank with a batch size of zero.
    """
    if world_size > 1:
        batch_size = args.batch_size // world_size
        if batch_size < 1:
            raise ValueError(
                f"batch_size ({args.batch_size}) must be at least world_size "
                f"({world_size}) so that every rank gets a non-empty batch"
            )
        if args.batch_size % world_size != 0 and rank == 0:
            logger.log(
                f"Warning: batch_size {args.batch_size} is not divisible by "
                f"world_size {world_size}; using {batch_size} per rank "
                f"({batch_size * world_size} in total)"
            )
        train_sampler = th.utils.data.distributed.DistributedSampler(
            ds,
            rank=rank,
            num_replicas=world_size,
            shuffle=True
        )
    else:
        train_sampler = None
        batch_size = args.batch_size

    loader = th.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        sampler=train_sampler,
        # DataLoader rejects shuffle=True together with a sampler, so only
        # shuffle here when no DistributedSampler is in charge.
        shuffle=(train_sampler is None),
        drop_last=True,
        pin_memory=th.cuda.is_available(),
        num_workers=args.mp_loaders,
        worker_init_fn=worker_init_fn
    )
    return loader, batch_size


def run_train(rank, world_size, port, args):
    """Train the model in one process (rank ``rank`` of ``world_size``).

    Args:
        rank: process index. Rank 0 configures logging and reports model size.
        world_size: number of processes. Values > 1 enable the process group,
            SyncBatchNorm, DistributedDataParallel and a DistributedSampler.
        port: rendezvous port; only used when ``world_size > 1``.
        args: parsed namespace from ``create_argparser``.
    """
    th.set_default_dtype(th.float32)

    if world_size > 1:
        setup_distributed(rank, world_size, port)

    try:
        _configure_logging(rank, world_size, args)

        logger.log("creating model, diffusion, prior and posterior distribution...")
        model, diffusion, prior, posterior = create_model_and_diffusion(
            **args_to_dict(args, model_and_diffusion_defaults().keys())
        )

        if world_size > 1:
            device = get_device(rank)
        else:
            device = th.device("cuda") if th.cuda.is_available() else th.device("cpu")
        model.to(device)
        prior.to(device)
        posterior.to(device)

        # Force float32 parameters, matching the default dtype set above.
        model = model.float()
        prior = prior.float()
        posterior = posterior.float()

        if world_size > 1:
            model = th.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = th.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[rank],
                output_device=rank,
                find_unused_parameters=True
            )

        # NOTE(review): maxt is assumed to bound the timesteps the sampler
        # draws (as guided_diffusion's UniformSampler does with
        # diffusion.num_timesteps). It used to be hard-coded to 1000, which
        # mismatches any non-default step count. Fall back to 1000 if the
        # diffusion object does not expose num_timesteps.
        maxt = getattr(diffusion, "num_timesteps", 1000)
        schedule_sampler = create_named_schedule_sampler(
            args.schedule_sampler, diffusion, maxt=maxt
        )

        if rank == 0:
            try:
                model_size(model, diffusion, prior, posterior, logger)
            except Exception as e:
                # Size reporting is informational only; never abort training for it.
                logger.log(f"Could not compute model size: {e}")

        logger.log("Loading dataset...")
        ds = _load_train_dataset(args)
        logger.log(f"Dataset size: {len(ds)} samples")

        datal, batch_size = _make_train_loader(ds, rank, world_size, args)
        data = iter(datal)

        if rank == 0:
            logger.log("Starting training...")

        TrainLoop(
            model=model,
            diffusion=diffusion,
            classifier=None,
            data=data,
            dataloader=datal,
            prior=prior,
            posterior=posterior,
            batch_size=batch_size,
            microbatch=args.microbatch,
            lr=args.lr,
            ema_rate=args.ema_rate,
            log_interval=args.log_interval,
            save_interval=args.save_interval,
            resume_checkpoint=args.resume_checkpoint,
            use_fp16=args.use_fp16,
            fp16_scale_growth=args.fp16_scale_growth,
            schedule_sampler=schedule_sampler,
            weight_decay=args.weight_decay,
            lr_anneal_steps=args.lr_anneal_steps,
            total_steps=args.total_steps,
        ).run_loop()
    finally:
        # Release the process group even when training raises.
        if world_size > 1:
            cleanup_distributed()
|
|
|
|
|
def create_argparser():
    """Return an ``ArgumentParser`` holding this script's options and defaults.

    The training/data options below are merged with
    ``model_and_diffusion_defaults()``, whose values win on any key clash.
    The merged mapping is then registered on the parser via
    ``add_dict_to_argparser``.
    """
    train_opts = {
        "data_dir": "./data/training",
        "schedule_sampler": "uniform",
        "lr": 1e-4,
        "weight_decay": 0.0,
        "lr_anneal_steps": 0,
        "batch_size": 8,
        "microbatch": -1,
        "ema_rate": "0.9999",
        "log_interval": 100,
        "save_interval": 25000,
        "mp_loaders": 4,
        "resume_checkpoint": '',
        "use_fp16": False,
        "fp16_scale_growth": 1e-3,
        # Dataset selection.
        "use_mose_dataset": False,
        "dataset_type": "lidc",
        "split_strategy": "all_annotations",
        "lidc_data_folder": "./data/lidc_npy",
        "msmri_data_folder": "./data/msmri_npy",
        # Run length / parallelism.
        "world_size": 1,
        "total_steps": 50000,
    }
    merged_opts = {**train_opts, **model_and_diffusion_defaults()}

    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, merged_opts)
    return parser
|
|
|
|
|
def main():
    """Parse arguments, choose the number of processes and launch training.

    When more than one GPU will be used (up to ``--world_size``), one training
    process per GPU is started with ``torch.multiprocessing.spawn``.
    Otherwise training runs in the current process.
    """
    th.set_default_dtype(th.float32)

    args = create_argparser().parse_args()

    # NOTE(review): removes SLURM_JOBID from the environment, presumably so
    # SLURM-aware helpers do not treat this run as a SLURM job — confirm.
    os.environ.pop("SLURM_JOBID", None)

    if th.cuda.is_available():
        # Clamp to [1, device_count]. A non-positive --world_size would
        # otherwise print "Using 0 GPU(s)" and then train on one process anyway.
        world_size = max(1, min(args.world_size, th.cuda.device_count()))
        print(f"Using {world_size} GPU(s)")
    else:
        world_size = 1
        print("Using CPU")

    if world_size > 1:
        port = find_free_port()
        print(f"Starting distributed training on port {port}")
        mp.spawn(run_train, args=(world_size, port, args), nprocs=world_size, join=True)
    else:
        print("Starting single process training")
        # run_train only uses ``port`` when world_size > 1, so there is no
        # need to reserve a socket for the single-process case.
        run_train(0, 1, None, args)


if __name__ == "__main__":
    main()
|
|
|