"""
Train a diffusion model on images with proper distributed training support.
"""

import sys
import os
import argparse
import numpy as np
import socket
from datetime import datetime

# Extend the import path with the parent and current directories *before*
# the project-local imports below, so the script can be launched from either
# the repository root or a sub-directory.
sys.path.append("..")
sys.path.append(".")
from guided_diffusion import dist_util, logger
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.custom_lidc_dataset import CustomLIDCDataset
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    args_to_dict,
    add_dict_to_argparser,
)
from scripts.metrics import model_size
import torch as th
from torch import nn
import torch.distributed as dist
import torch.multiprocessing as mp
from guided_diffusion.train_util import TrainLoop


def expanduservars(path: str) -> str:
    """Return *path* with environment variables and a leading ``~`` expanded."""
    return os.path.expanduser(os.path.expandvars(path))


def find_free_port():
    """Find a free port for distributed training.

    The OS picks an unused port by binding to port 0. The socket is closed
    before returning, so the port is only *probably* still free when the
    caller binds it; there is an unavoidable small race window.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        s.listen(1)
        port = s.getsockname()[1]
    return port


def setup_distributed(rank, world_size, port):
    """Initialize the distributed environment.

    Args:
        rank: Index of this process within the group.
        world_size: Total number of processes in the group.
        port: TCP port of the rendezvous master on ``localhost``.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(port)

    # NCCL when GPUs are present, Gloo as the CPU fallback.
    use_cuda = th.cuda.is_available()
    backend = 'nccl' if use_cuda else 'gloo'
    dist.init_process_group(backend, rank=rank, world_size=world_size)

    if use_cuda:
        th.cuda.set_device(rank)


def cleanup_distributed():
    """Clean up the distributed environment.

    Safe to call when no process group was initialized; in that case it is
    a no-op instead of raising.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def get_device(rank):
    """Get the appropriate device for the given rank."""
    if th.cuda.is_available():
        return th.device(f'cuda:{rank}')
    return th.device('cpu')


def worker_init_fn(worker_id):
    """Seed dataloader workers for reproducibility and ensure consistent tensor types."""
    # NumPy seeds must fit in 32 bits, hence the modulo.
    np.random.seed(th.initial_seed() % 2 ** 32)
    # Ensure workers use float32 as default
    th.set_default_dtype(th.float32)


def _configure_rank0_logging(args, world_size, rank):
    """Configure the logger and record the run configuration (rank 0 only)."""
    if args.use_mose_dataset:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        run_name = f"train_{args.dataset_type}_{timestamp}"
        out_dir = f"./results/{args.dataset_type}/{run_name}"
        logger.configure(dir=out_dir)
        logger.log(f"Log dir: {out_dir}")
    else:
        logger.configure()

    logger.log(f"Arguments: {args.__dict__}")
    logger.log(f"World size: {world_size}, Rank: {rank}")


def _build_dataset(args):
    """Build the training dataset selected by ``use_mose_dataset``/``dataset_type``.

    Raises:
        ValueError: If ``use_mose_dataset`` is set and ``dataset_type`` is
            neither ``"lidc"`` nor ``"msmri"``.
    """
    if not args.use_mose_dataset:
        if args.dataset_type in {"lidc", "multiannotator"}:
            ds = CustomLIDCDataset(
                data_root=args.data_dir,
                split="train",
                image_size=args.image_size,
                dataset_type=args.dataset_type,
                split_strategy=args.split_strategy,
            )
            logger.log(
                f"Using {args.dataset_type} dataset from: {args.data_dir} "
                f"(strategy={args.split_strategy})"
            )
        else:
            # Backward-compatible fallback to legacy folder-structured loader.
            from guided_diffusion.lidcloader import LIDCDataset

            ds = LIDCDataset(args.data_dir, test_flag=False)
            logger.log(f"Using converted legacy dataset from: {args.data_dir}")
        return ds

    # Modern dataset (for future use with NPY format)
    if args.dataset_type == "lidc":
        from guided_diffusion.lidcloader_mose import lidc_Dataloader

        ds = lidc_Dataloader(
            data_folder=args.lidc_data_folder,
            transform_train=None,
            transform_test=None,
        ).train_ds
        logger.log(f"Using LIDC NPY dataset from: {args.lidc_data_folder}")
    elif args.dataset_type == "msmri":
        from guided_diffusion.msmri_dataset_mose import msmri_Dataloader

        ds = msmri_Dataloader(
            data_folder=args.msmri_data_folder,
            transform_train=None,
            transform_test=None,
        ).train_ds
        logger.log(f"Using MSMRI NPY dataset from: {args.msmri_data_folder}")
    else:
        raise ValueError(f"Unknown dataset type: {args.dataset_type}")
    return ds


def _train(rank, world_size, batch_size, args):
    """Build model, data pipeline and run the training loop for one process."""
    distributed = world_size > 1
    use_cuda = th.cuda.is_available()

    # Configure logging (only for rank 0)
    if rank == 0:
        _configure_rank0_logging(args, world_size, rank)

    logger.log("creating model, diffusion, prior and posterior distribution...")
    model, diffusion, prior, posterior = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )

    # Move models to appropriate device
    if distributed:
        device = get_device(rank)
    else:
        device = th.device("cuda") if use_cuda else th.device("cpu")
    model.to(device)
    prior.to(device)
    posterior.to(device)

    # Ensure all models use float32
    model = model.float()
    prior = prior.float()
    posterior = posterior.float()

    # Setup distributed model
    if distributed:
        model = th.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        # DDP requires device_ids/output_device to be None for CPU modules,
        # so only pin to the rank's GPU when CUDA is actually in use.
        model = th.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[rank] if use_cuda else None,
            output_device=rank if use_cuda else None,
            find_unused_parameters=True,
        )

    schedule_sampler = create_named_schedule_sampler(
        args.schedule_sampler, diffusion, maxt=1000
    )

    if rank == 0:
        # Best-effort diagnostics: a failure here must not abort training.
        try:
            model_size(model, diffusion, prior, posterior, logger)
        except Exception as e:
            logger.log(f"Could not compute model size: {e}")

    # Create dataset
    logger.log("Loading dataset...")
    ds = _build_dataset(args)
    logger.log(f"Dataset size: {len(ds)} samples")

    # Setup distributed sampler
    if distributed:
        train_sampler = th.utils.data.distributed.DistributedSampler(
            ds, rank=rank, num_replicas=world_size, shuffle=True
        )
    else:
        train_sampler = None

    # Create data loader (shuffle and sampler are mutually exclusive)
    datal = th.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        sampler=train_sampler,
        shuffle=(train_sampler is None),
        drop_last=True,
        pin_memory=use_cuda,
        num_workers=args.mp_loaders,
        worker_init_fn=worker_init_fn,
    )
    data = iter(datal)

    if rank == 0:
        logger.log("Starting training...")

    # Start training
    TrainLoop(
        model=model,
        diffusion=diffusion,
        classifier=None,
        data=data,
        dataloader=datal,
        prior=prior,
        posterior=posterior,
        batch_size=batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
        total_steps=args.total_steps,
    ).run_loop()


def run_train(rank, world_size, port, args):
    """Main training function for each process.

    Args:
        rank: Index of this process (0 for single-process training).
        world_size: Number of participating processes. Distributed setup is
            performed only when this is greater than 1.
        port: Rendezvous port; ignored when ``world_size`` is 1.
        args: Parsed command-line namespace from ``create_argparser``.

    Raises:
        ValueError: If the per-process batch size would be smaller than 1.
    """
    # Set default tensor type to float32
    th.set_default_dtype(th.float32)

    distributed = world_size > 1

    # The global batch is split evenly across processes. Fail early with a
    # clear message instead of letting the DataLoader reject a zero batch.
    batch_size = args.batch_size // world_size if distributed else args.batch_size
    if batch_size < 1:
        raise ValueError(
            f"batch_size={args.batch_size} is too small for "
            f"world_size={world_size}; each process needs a batch of at least 1"
        )

    # Setup distributed training only if world_size > 1
    if distributed:
        setup_distributed(rank, world_size, port)

    try:
        _train(rank, world_size, batch_size, args)
    finally:
        # Always tear the process group down, even when training fails, so
        # the remaining ranks are not left with a stale group.
        if distributed:
            cleanup_distributed()


def create_argparser():
    """Build the argument parser from script defaults plus model/diffusion defaults."""
    defaults = dict(
        data_dir="./data/training",  # This uses our converted data!
        schedule_sampler="uniform",
        lr=1e-4,
        weight_decay=0.0,
        lr_anneal_steps=0,
        batch_size=8,  # Global batch size; divided by world_size per process
        microbatch=-1,  # -1 disables microbatches
        ema_rate="0.9999",  # comma-separated list of EMA values
        log_interval=100,
        save_interval=25000,  # Save every 25k steps as requested
        mp_loaders=4,
        resume_checkpoint='',
        use_fp16=False,
        fp16_scale_growth=1e-3,
        use_mose_dataset=False,  # FALSE = Use our converted data in ./data/training
        dataset_type="lidc",  # "lidc" or "msmri"
        split_strategy="all_annotations",
        lidc_data_folder="./data/lidc_npy",  # Only used if use_mose_dataset=True
        msmri_data_folder="./data/msmri_npy",  # Only used if use_mose_dataset=True
        world_size=1,  # Number of GPUs/processes
        total_steps=50000,
    )
    defaults.update(model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


def main():
    """Parse arguments and launch single-process or multi-GPU training."""
    # Set default tensor type to float32 globally
    th.set_default_dtype(th.float32)

    args = create_argparser().parse_args()

    # Clean up SLURM environment variables that might interfere
    os.environ.pop("SLURM_JOBID", None)

    # Check available GPUs. Clamp to at least one process so a zero or
    # negative --world_size falls back to single-process training cleanly.
    if th.cuda.is_available():
        world_size = max(1, min(args.world_size, th.cuda.device_count()))
        print(f"Using {world_size} GPU(s)")
    else:
        world_size = 1
        print("Using CPU")

    if world_size > 1:
        # Multi-GPU distributed training using mp.spawn
        port = find_free_port()
        print(f"Starting distributed training on port {port}")

        # Spawn processes for distributed training; mp.spawn passes the
        # process index as the first argument (rank).
        mp.spawn(run_train, args=(world_size, port, args), nprocs=world_size, join=True)
    else:
        # Single GPU/CPU training
        print("Starting single process training")
        # For single process, we still call run_train but without distributed
        # setup; no rendezvous port is needed, so none is allocated.
        run_train(0, 1, None, args)


if __name__ == "__main__":
    main()